Merge pull request #173 from tobwen/feature/dynamic-model-routing

Add support for dynamic model providers
This commit is contained in:
Luis Pater
2025-10-28 08:55:08 +08:00
committed by GitHub
4 changed files with 97 additions and 18 deletions

View File

@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
return ch, nil return ch, nil
} }
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("not implemented")
}
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) { func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
return a, nil return a, nil
} }

View File

@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
envManagementSecret := envAdminPasswordSet && envAdminPassword != "" envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
// Create server instance // Create server instance
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
for _, p := range cfg.OpenAICompatibility {
providerNames = append(providerNames, p.Name)
}
s := &Server{ s := &Server{
engine: engine, engine: engine,
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager), handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
cfg: cfg, cfg: cfg,
accessManager: accessManager, accessManager: accessManager,
requestLogger: requestLogger, requestLogger: requestLogger,
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
managementasset.SetCurrentConfig(cfg) managementasset.SetCurrentConfig(cfg)
// Save YAML snapshot for next comparison // Save YAML snapshot for next comparison
s.oldConfigYaml, _ = yaml.Marshal(cfg) s.oldConfigYaml, _ = yaml.Marshal(cfg)
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
for _, p := range cfg.OpenAICompatibility {
providerNames = append(providerNames, p.Name)
}
s.handlers.OpenAICompatProviders = providerNames
s.handlers.UpdateClients(&cfg.SDKConfig) s.handlers.UpdateClients(&cfg.SDKConfig)
if !cfg.RemoteManagement.DisableControlPanel { if !cfg.RemoteManagement.DisableControlPanel {
@@ -904,4 +915,4 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
} }
} }
// legacy clientsToSlice removed; handlers no longer consume legacy client slices // legacy clientsToSlice removed; handlers no longer consume legacy client slices

View File

@@ -4,6 +4,7 @@ package chat_completions
import ( import (
"bytes" "bytes"
"github.com/tidwall/sjson"
) )
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON) // ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
@@ -17,5 +18,14 @@ import (
// Returns: // Returns:
// - []byte: The transformed request data in Gemini CLI API format // - []byte: The transformed request data in Gemini CLI API format
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
return bytes.Clone(inputRawJSON) // Update the "model" field in the JSON payload with the provided modelName
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
if err != nil {
// If there's an error, return the original JSON or handle the error appropriately.
// For now, we'll return the original, but in a real scenario, logging or a more robust error
// handling mechanism would be needed.
return bytes.Clone(inputRawJSON)
}
return updatedJSON
} }

View File

@@ -6,6 +6,7 @@ package handlers
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
// Cfg holds the current application configuration. // Cfg holds the current application configuration.
Cfg *config.SDKConfig Cfg *config.SDKConfig
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
OpenAICompatProviders []string
} }
// NewBaseAPIHandlers creates a new API handlers instance. // NewBaseAPIHandlers creates a new API handlers instance.
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
// //
// Returns: // Returns:
// - *BaseAPIHandler: A new API handlers instance // - *BaseAPIHandler: A new API handlers instance
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler { func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
return &BaseAPIHandler{ return &BaseAPIHandler{
Cfg: cfg, Cfg: cfg,
AuthManager: authManager, AuthManager: authManager,
OpenAICompatProviders: openAICompatProviders,
} }
} }
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager. // ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route. // This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName) providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
providers := util.GetProviderName(normalizedModel) if errMsg != nil {
if len(providers) == 0 { return nil, errMsg
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
} }
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: normalizedModel, Model: normalizedModel,
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route. // This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName) providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
providers := util.GetProviderName(normalizedModel) if errMsg != nil {
if len(providers) == 0 { return nil, errMsg
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
} }
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: normalizedModel, Model: normalizedModel,
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
// This path is the only supported execution route. // This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName) providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
providers := util.GetProviderName(normalizedModel) if errMsg != nil {
if len(providers) == 0 {
errChan := make(chan *interfaces.ErrorMessage, 1) errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} errChan <- errMsg
close(errChan) close(errChan)
return nil, errChan return nil, errChan
} }
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return dataChan, errChan return dataChan, errChan
} }
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
// First, normalize the model name to handle suffixes like "-thinking-128"
// This needs to happen before determining the provider for non-dynamic models.
normalizedModel, metadata = normalizeModelMetadata(modelName)
if isDynamic {
providers = []string{providerName}
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
// so we use it as the final normalizedModel.
normalizedModel = extractedModelName
} else {
// For non-dynamic models, use the normalizedModel to get the provider name.
providers = util.GetProviderName(normalizedModel)
}
if len(providers) == 0 {
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
}
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
// So, normalizedModel is already correctly set at this point.
return providers, normalizedModel, metadata, nil
}
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
var providerPart, modelPart string
for _, sep := range []string{"://"} {
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
providerPart = parts[0]
modelPart = parts[1]
break
}
}
if providerPart == "" {
return "", modelName, false
}
// Check if the provider is a configured openai-compatibility provider
for _, pName := range h.OpenAICompatProviders {
if pName == providerPart {
return providerPart, modelPart, true
}
}
return "", modelName, false
}
func cloneBytes(src []byte) []byte { func cloneBytes(src []byte) []byte {
if len(src) == 0 { if len(src) == 0 {
return nil return nil