From e5ed2cba4aab5737b770b2c94c615d45927f7270 Mon Sep 17 00:00:00 2001
From: tobwen <1864057+tobwen@users.noreply.github.com>
Date: Tue, 28 Oct 2025 00:30:56 +0100
Subject: [PATCH] Add support for dynamic model providers

Implements parsing of model names that carry provider information in the
format "provider://model". This allows the provider to be selected
dynamically per request rather than relying only on the predefined
provider mappings.

The change updates all execution methods to handle these dynamic model
specifications while maintaining compatibility with the existing handling
of standard model names.
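
Illustration of the intended routing rule (a sketch only: the provider
name "my-compat" and the model names are hypothetical, and
splitDynamicModel merely mirrors the behavior of
BaseAPIHandler.parseDynamicModel introduced below):

    package main

    import (
        "fmt"
        "strings"
    )

    // splitDynamicModel sketches the new rule: a model name of the form
    // "provider://model" selects that provider directly, but only when the
    // provider matches a configured openai-compatibility entry. Any other
    // name falls through to the existing provider lookup.
    func splitDynamicModel(name string, configured []string) (provider, model string, dynamic bool) {
        parts := strings.SplitN(name, "://", 2)
        if len(parts) != 2 {
            return "", name, false
        }
        for _, p := range configured {
            if p == parts[0] {
                return parts[0], parts[1], true
            }
        }
        return "", name, false
    }

    func main() {
        configured := []string{"my-compat"}

        // Dynamic: routed to the "my-compat" provider, model "gpt-4o-mini".
        fmt.Println(splitDynamicModel("my-compat://gpt-4o-mini", configured))

        // Not dynamic: no known provider prefix, the existing mapping applies.
        fmt.Println(splitDynamicModel("gemini-2.5-pro", configured))
    }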
---
 examples/custom-provider/main.go              |  4 +
 internal/api/server.go                        | 15 +++-
 .../chat-completions/openai_openai_request.go | 12 ++-
 sdk/api/handlers/handlers.go                  | 84 +++++++++++++++----
 4 files changed, 97 insertions(+), 18 deletions(-)

diff --git a/examples/custom-provider/main.go b/examples/custom-provider/main.go
index e6a76475..eb1755d0 100644
--- a/examples/custom-provider/main.go
+++ b/examples/custom-provider/main.go
@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
 	return ch, nil
 }
 
+func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
+	return clipexec.Response{}, errors.New("not implemented")
+}
+
 func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
 	return a, nil
 }
diff --git a/internal/api/server.go b/internal/api/server.go
index f52ca80a..afc19318 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
 	envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
 
 	// Create server instance
+	providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
+	for _, p := range cfg.OpenAICompatibility {
+		providerNames = append(providerNames, p.Name)
+	}
 	s := &Server{
 		engine:        engine,
-		handlers:      handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
+		handlers:      handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
 		cfg:           cfg,
 		accessManager: accessManager,
 		requestLogger: requestLogger,
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
 	managementasset.SetCurrentConfig(cfg)
 	// Save YAML snapshot for next comparison
 	s.oldConfigYaml, _ = yaml.Marshal(cfg)
+
+	providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
+	for _, p := range cfg.OpenAICompatibility {
+		providerNames = append(providerNames, p.Name)
+	}
+	s.handlers.OpenAICompatProviders = providerNames
+
 	s.handlers.UpdateClients(&cfg.SDKConfig)
 
 	if !cfg.RemoteManagement.DisableControlPanel {
@@ -904,4 +915,4 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
 	}
 }
 
-// legacy clientsToSlice removed; handlers no longer consume legacy client slices
+// legacy clientsToSlice removed; handlers no longer consume legacy client slices
\ No newline at end of file
diff --git a/internal/translator/openai/openai/chat-completions/openai_openai_request.go b/internal/translator/openai/openai/chat-completions/openai_openai_request.go
index 1ff0f7c8..211c0eb4 100644
--- a/internal/translator/openai/openai/chat-completions/openai_openai_request.go
+++ b/internal/translator/openai/openai/chat-completions/openai_openai_request.go
@@ -4,6 +4,7 @@ package chat_completions
 
 import (
 	"bytes"
+	"github.com/tidwall/sjson"
 )
 
 // ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
@@ -17,5 +18,14 @@ import (
 // Returns:
 // - []byte: The transformed request data in Gemini CLI API format
 func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
-	return bytes.Clone(inputRawJSON)
+	// Update the "model" field in the JSON payload with the provided modelName.
+	// sjson.SetBytes returns a new byte slice with the updated JSON.
+	updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
+	if err != nil {
+		// If the update fails, fall back to the original payload unchanged;
+		// the translator signature has no error return, so the failure is
+		// tolerated silently rather than propagated.
+		return bytes.Clone(inputRawJSON)
+	}
+	return updatedJSON
 }
diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go
index 73c647f3..0a1df939 100644
--- a/sdk/api/handlers/handlers.go
+++ b/sdk/api/handlers/handlers.go
@@ -6,6 +6,7 @@ package handlers
 import (
 	"fmt"
 	"net/http"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 	"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
 
 	// Cfg holds the current application configuration.
 	Cfg *config.SDKConfig
+
+	// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
+	OpenAICompatProviders []string
 }
 
 // NewBaseAPIHandlers creates a new API handlers instance.
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
 //
 // Returns:
 // - *BaseAPIHandler: A new API handlers instance
-func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
+func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
 	return &BaseAPIHandler{
-		Cfg:         cfg,
-		AuthManager: authManager,
+		Cfg:                   cfg,
+		AuthManager:           authManager,
+		OpenAICompatProviders: openAICompatProviders,
 	}
 }
 
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
 // ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
 // This path is the only supported execution route.
 func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
-	normalizedModel, metadata := normalizeModelMetadata(modelName)
-	providers := util.GetProviderName(normalizedModel)
-	if len(providers) == 0 {
-		return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
+	providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
+	if errMsg != nil {
+		return nil, errMsg
 	}
 	req := coreexecutor.Request{
 		Model:    normalizedModel,
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
 // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
 // This path is the only supported execution route.
 func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
-	normalizedModel, metadata := normalizeModelMetadata(modelName)
-	providers := util.GetProviderName(normalizedModel)
-	if len(providers) == 0 {
-		return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
+	providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
+	if errMsg != nil {
+		return nil, errMsg
 	}
 	req := coreexecutor.Request{
 		Model:    normalizedModel,
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
 // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
 // This path is the only supported execution route.
 func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
-	normalizedModel, metadata := normalizeModelMetadata(modelName)
-	providers := util.GetProviderName(normalizedModel)
-	if len(providers) == 0 {
+	providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
+	if errMsg != nil {
 		errChan := make(chan *interfaces.ErrorMessage, 1)
-		errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
+		errChan <- errMsg
 		close(errChan)
 		return nil, errChan
 	}
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
 	return dataChan, errChan
 }
 
+func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
+	providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
+
+	// First, normalize the model name to handle suffixes like "-thinking-128".
+	// This needs to happen before determining the provider for non-dynamic models.
+	normalizedModel, metadata = normalizeModelMetadata(modelName)
+
+	if isDynamic {
+		providers = []string{providerName}
+		// For dynamic models, the extractedModelName is already normalized by parseDynamicModel,
+		// so we use it as the final normalizedModel.
+		normalizedModel = extractedModelName
+	} else {
+		// For non-dynamic models, use the normalizedModel to get the provider name.
+		providers = util.GetProviderName(normalizedModel)
+	}
+
+	if len(providers) == 0 {
+		return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
+	}
+
+	// If it's a dynamic model, normalizedModel was already set to extractedModelName;
+	// if it's a non-dynamic model, it was set by normalizeModelMetadata.
+	// Either way, normalizedModel is correct at this point.
+
+	return providers, normalizedModel, metadata, nil
+}
+
+func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
+	var providerPart, modelPart string
+	for _, sep := range []string{"://"} {
+		if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
+			providerPart = parts[0]
+			modelPart = parts[1]
+			break
+		}
+	}
+
+	if providerPart == "" {
+		return "", modelName, false
+	}
+
+	// Check if the provider is a configured openai-compatibility provider
+	for _, pName := range h.OpenAICompatProviders {
+		if pName == providerPart {
+			return providerPart, modelPart, true
+		}
+	}
+
+	return "", modelName, false
+}
+
 func cloneBytes(src []byte) []byte {
 	if len(src) == 0 {
 		return nil