Add support for dynamic model providers

Implements functionality to parse model names with provider information in the format "provider://model" This allows dynamic provider selection rather than relying only on predefined mappings.

The change affects all execution methods to properly handle these dynamic model specifications while maintaining compatibility with the existing approach for standard model names.
This commit is contained in:
tobwen
2025-10-28 00:30:56 +01:00
parent c7196ba7dc
commit e5ed2cba4a
4 changed files with 97 additions and 18 deletions

View File

@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
return ch, nil return ch, nil
} }
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("not implemented")
}
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) { func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
return a, nil return a, nil
} }

View File

@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
envManagementSecret := envAdminPasswordSet && envAdminPassword != "" envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
// Create server instance // Create server instance
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
for _, p := range cfg.OpenAICompatibility {
providerNames = append(providerNames, p.Name)
}
s := &Server{ s := &Server{
engine: engine, engine: engine,
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager), handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
cfg: cfg, cfg: cfg,
accessManager: accessManager, accessManager: accessManager,
requestLogger: requestLogger, requestLogger: requestLogger,
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
managementasset.SetCurrentConfig(cfg) managementasset.SetCurrentConfig(cfg)
// Save YAML snapshot for next comparison // Save YAML snapshot for next comparison
s.oldConfigYaml, _ = yaml.Marshal(cfg) s.oldConfigYaml, _ = yaml.Marshal(cfg)
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
for _, p := range cfg.OpenAICompatibility {
providerNames = append(providerNames, p.Name)
}
s.handlers.OpenAICompatProviders = providerNames
s.handlers.UpdateClients(&cfg.SDKConfig) s.handlers.UpdateClients(&cfg.SDKConfig)
if !cfg.RemoteManagement.DisableControlPanel { if !cfg.RemoteManagement.DisableControlPanel {

View File

@@ -4,6 +4,7 @@ package chat_completions
import ( import (
"bytes" "bytes"
"github.com/tidwall/sjson"
) )
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON) // ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
@@ -17,5 +18,14 @@ import (
// Returns: // Returns:
// - []byte: The transformed request data in Gemini CLI API format // - []byte: The transformed request data in Gemini CLI API format
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
// Update the "model" field in the JSON payload with the provided modelName
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
if err != nil {
// If there's an error, return the original JSON or handle the error appropriately.
// For now, we'll return the original, but in a real scenario, logging or a more robust error
// handling mechanism would be needed.
return bytes.Clone(inputRawJSON) return bytes.Clone(inputRawJSON)
}
return updatedJSON
} }

View File

@@ -6,6 +6,7 @@ package handlers
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
// Cfg holds the current application configuration. // Cfg holds the current application configuration.
Cfg *config.SDKConfig Cfg *config.SDKConfig
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
OpenAICompatProviders []string
} }
// NewBaseAPIHandlers creates a new API handlers instance. // NewBaseAPIHandlers creates a new API handlers instance.
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
// //
// Returns: // Returns:
// - *BaseAPIHandler: A new API handlers instance // - *BaseAPIHandler: A new API handlers instance
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler { func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
return &BaseAPIHandler{ return &BaseAPIHandler{
Cfg: cfg, Cfg: cfg,
AuthManager: authManager, AuthManager: authManager,
OpenAICompatProviders: openAICompatProviders,
} }
} }
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager. // ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route. // This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName) providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
providers := util.GetProviderName(normalizedModel) if errMsg != nil {
if len(providers) == 0 { return nil, errMsg
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
} }
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: normalizedModel, Model: normalizedModel,
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route. // This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName) providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
providers := util.GetProviderName(normalizedModel) if errMsg != nil {
if len(providers) == 0 { return nil, errMsg
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
} }
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: normalizedModel, Model: normalizedModel,
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
// This path is the only supported execution route. // This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName) providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
providers := util.GetProviderName(normalizedModel) if errMsg != nil {
if len(providers) == 0 {
errChan := make(chan *interfaces.ErrorMessage, 1) errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} errChan <- errMsg
close(errChan) close(errChan)
return nil, errChan return nil, errChan
} }
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return dataChan, errChan return dataChan, errChan
} }
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
// First, normalize the model name to handle suffixes like "-thinking-128"
// This needs to happen before determining the provider for non-dynamic models.
normalizedModel, metadata = normalizeModelMetadata(modelName)
if isDynamic {
providers = []string{providerName}
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
// so we use it as the final normalizedModel.
normalizedModel = extractedModelName
} else {
// For non-dynamic models, use the normalizedModel to get the provider name.
providers = util.GetProviderName(normalizedModel)
}
if len(providers) == 0 {
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
}
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
// So, normalizedModel is already correctly set at this point.
return providers, normalizedModel, metadata, nil
}
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
var providerPart, modelPart string
for _, sep := range []string{"://"} {
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
providerPart = parts[0]
modelPart = parts[1]
break
}
}
if providerPart == "" {
return "", modelName, false
}
// Check if the provider is a configured openai-compatibility provider
for _, pName := range h.OpenAICompatProviders {
if pName == providerPart {
return providerPart, modelPart, true
}
}
return "", modelName, false
}
func cloneBytes(src []byte) []byte { func cloneBytes(src []byte) []byte {
if len(src) == 0 { if len(src) == 0 {
return nil return nil