diff --git a/examples/custom-provider/main.go b/examples/custom-provider/main.go index e6a76475..eb1755d0 100644 --- a/examples/custom-provider/main.go +++ b/examples/custom-provider/main.go @@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe return ch, nil } +func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) { + return clipexec.Response{}, errors.New("not implemented") +} + func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) { return a, nil } diff --git a/internal/api/server.go b/internal/api/server.go index f52ca80a..afc19318 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk envManagementSecret := envAdminPasswordSet && envAdminPassword != "" // Create server instance + providerNames := make([]string, 0, len(cfg.OpenAICompatibility)) + for _, p := range cfg.OpenAICompatibility { + providerNames = append(providerNames, p.Name) + } s := &Server{ engine: engine, - handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager), + handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames), cfg: cfg, accessManager: accessManager, requestLogger: requestLogger, @@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) { managementasset.SetCurrentConfig(cfg) // Save YAML snapshot for next comparison s.oldConfigYaml, _ = yaml.Marshal(cfg) + + providerNames := make([]string, 0, len(cfg.OpenAICompatibility)) + for _, p := range cfg.OpenAICompatibility { + providerNames = append(providerNames, p.Name) + } + s.handlers.OpenAICompatProviders = providerNames + s.handlers.UpdateClients(&cfg.SDKConfig) if !cfg.RemoteManagement.DisableControlPanel { @@ -904,4 +915,4 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc { } } -// legacy clientsToSlice removed; handlers no longer consume legacy client slices +// legacy clientsToSlice removed; handlers no longer consume legacy client slices \ No newline at end of file diff --git a/internal/translator/openai/openai/chat-completions/openai_openai_request.go b/internal/translator/openai/openai/chat-completions/openai_openai_request.go index 1ff0f7c8..211c0eb4 100644 --- a/internal/translator/openai/openai/chat-completions/openai_openai_request.go +++ b/internal/translator/openai/openai/chat-completions/openai_openai_request.go @@ -4,6 +4,7 @@ package chat_completions import ( "bytes" + "github.com/tidwall/sjson" ) // ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON) @@ -17,5 +18,14 @@ import ( // Returns: // - []byte: The transformed request data in Gemini CLI API format func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte { - return bytes.Clone(inputRawJSON) + // Update the "model" field in the JSON payload with the provided modelName + // The sjson.SetBytes function returns a new byte slice with the updated JSON. + updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName) + if err != nil { + // If there's an error, return the original JSON or handle the error appropriately. + // For now, we'll return the original, but in a real scenario, logging or a more robust error + // handling mechanism would be needed. + return bytes.Clone(inputRawJSON) + } + return updatedJSON } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 73c647f3..0a1df939 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -6,6 +6,7 @@ package handlers import ( "fmt" "net/http" + "strings" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" @@ -46,6 +47,9 @@ type BaseAPIHandler struct { // Cfg holds the current application configuration. Cfg *config.SDKConfig + + // OpenAICompatProviders is a list of provider names for OpenAI compatibility. + OpenAICompatProviders []string } // NewBaseAPIHandlers creates a new API handlers instance. @@ -57,10 +61,11 @@ type BaseAPIHandler struct { // // Returns: // - *BaseAPIHandler: A new API handlers instance -func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler { +func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler { return &BaseAPIHandler{ - Cfg: cfg, - AuthManager: authManager, + Cfg: cfg, + AuthManager: authManager, + OpenAICompatProviders: openAICompatProviders, } } @@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c * // ExecuteWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { - normalizedModel, metadata := normalizeModelMetadata(modelName) - providers := util.GetProviderName(normalizedModel) - if len(providers) == 0 { - return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} + providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName) + if errMsg != nil { + return nil, errMsg } req := coreexecutor.Request{ Model: normalizedModel, @@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { - normalizedModel, metadata := normalizeModelMetadata(modelName) - providers := util.GetProviderName(normalizedModel) - if len(providers) == 0 { - return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} + providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName) + if errMsg != nil { + return nil, errMsg } req := coreexecutor.Request{ Model: normalizedModel, @@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // This path is the only supported execution route. func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { - normalizedModel, metadata := normalizeModelMetadata(modelName) - providers := util.GetProviderName(normalizedModel) - if len(providers) == 0 { + providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName) + if errMsg != nil { errChan := make(chan *interfaces.ErrorMessage, 1) - errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} + errChan <- errMsg close(errChan) return nil, errChan } @@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return dataChan, errChan } +func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) { + providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName) + + // First, normalize the model name to handle suffixes like "-thinking-128" + // This needs to happen before determining the provider for non-dynamic models. + normalizedModel, metadata = normalizeModelMetadata(modelName) + + if isDynamic { + providers = []string{providerName} + // For dynamic models, the extractedModelName is already normalized by parseDynamicModel + // so we use it as the final normalizedModel. + normalizedModel = extractedModelName + } else { + // For non-dynamic models, use the normalizedModel to get the provider name. + providers = util.GetProviderName(normalizedModel) + } + + if len(providers) == 0 { + return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} + } + + // If it's a dynamic model, the normalizedModel was already set to extractedModelName. + // If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata. + // So, normalizedModel is already correctly set at this point. + + return providers, normalizedModel, metadata, nil +} + +func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) { + var providerPart, modelPart string + for _, sep := range []string{"://"} { + if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 { + providerPart = parts[0] + modelPart = parts[1] + break + } + } + + if providerPart == "" { + return "", modelName, false + } + + // Check if the provider is a configured openai-compatibility provider + for _, pName := range h.OpenAICompatProviders { + if pName == providerPart { + return providerPart, modelPart, true + } + } + + return "", modelName, false +} + func cloneBytes(src []byte) []byte { if len(src) == 0 { return nil