mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-19 04:40:52 +08:00
Merge pull request #173 from tobwen/feature/dynamic-model-routing
Add support for dynamic model providers
This commit is contained in:
@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
|
||||
return clipexec.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.OpenAICompatProviders = providerNames
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
@@ -904,4 +915,4 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// legacy clientsToSlice removed; handlers no longer consume legacy client slices
|
||||
// legacy clientsToSlice removed; handlers no longer consume legacy client slices
|
||||
@@ -4,6 +4,7 @@ package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
|
||||
@@ -17,5 +18,14 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
return bytes.Clone(inputRawJSON)
|
||||
// Update the "model" field in the JSON payload with the provided modelName
|
||||
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
|
||||
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
|
||||
if err != nil {
|
||||
// If there's an error, return the original JSON or handle the error appropriately.
|
||||
// For now, we'll return the original, but in a real scenario, logging or a more robust error
|
||||
// handling mechanism would be needed.
|
||||
return bytes.Clone(inputRawJSON)
|
||||
}
|
||||
return updatedJSON
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package handlers
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
|
||||
|
||||
// Cfg holds the current application configuration.
|
||||
Cfg *config.SDKConfig
|
||||
|
||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||
OpenAICompatProviders []string
|
||||
}
|
||||
|
||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
|
||||
//
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
OpenAICompatProviders: openAICompatProviders,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 {
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 {
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
errChan <- errMsg
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
}
|
||||
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
|
||||
|
||||
// First, normalize the model name to handle suffixes like "-thinking-128"
|
||||
// This needs to happen before determining the provider for non-dynamic models.
|
||||
normalizedModel, metadata = normalizeModelMetadata(modelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
|
||||
// so we use it as the final normalizedModel.
|
||||
normalizedModel = extractedModelName
|
||||
} else {
|
||||
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
}
|
||||
|
||||
if len(providers) == 0 {
|
||||
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
}
|
||||
|
||||
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
|
||||
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
|
||||
// So, normalizedModel is already correctly set at this point.
|
||||
|
||||
return providers, normalizedModel, metadata, nil
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||
var providerPart, modelPart string
|
||||
for _, sep := range []string{"://"} {
|
||||
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||
providerPart = parts[0]
|
||||
modelPart = parts[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if providerPart == "" {
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
// Check if the provider is a configured openai-compatibility provider
|
||||
for _, pName := range h.OpenAICompatProviders {
|
||||
if pName == providerPart {
|
||||
return providerPart, modelPart, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user