mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
Merge pull request #173 from tobwen/feature/dynamic-model-routing
Add support for dynamic model providers
This commit is contained in:
@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
|
|||||||
return ch, nil
|
return ch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
|
||||||
|
return clipexec.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||||
|
|
||||||
// Create server instance
|
// Create server instance
|
||||||
|
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||||
|
for _, p := range cfg.OpenAICompatibility {
|
||||||
|
providerNames = append(providerNames, p.Name)
|
||||||
|
}
|
||||||
s := &Server{
|
s := &Server{
|
||||||
engine: engine,
|
engine: engine,
|
||||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
accessManager: accessManager,
|
accessManager: accessManager,
|
||||||
requestLogger: requestLogger,
|
requestLogger: requestLogger,
|
||||||
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
managementasset.SetCurrentConfig(cfg)
|
managementasset.SetCurrentConfig(cfg)
|
||||||
// Save YAML snapshot for next comparison
|
// Save YAML snapshot for next comparison
|
||||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||||
|
|
||||||
|
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||||
|
for _, p := range cfg.OpenAICompatibility {
|
||||||
|
providerNames = append(providerNames, p.Name)
|
||||||
|
}
|
||||||
|
s.handlers.OpenAICompatProviders = providerNames
|
||||||
|
|
||||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||||
|
|
||||||
if !cfg.RemoteManagement.DisableControlPanel {
|
if !cfg.RemoteManagement.DisableControlPanel {
|
||||||
@@ -904,4 +915,4 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// legacy clientsToSlice removed; handlers no longer consume legacy client slices
|
// legacy clientsToSlice removed; handlers no longer consume legacy client slices
|
||||||
@@ -4,6 +4,7 @@ package chat_completions
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
|
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
|
||||||
@@ -17,5 +18,14 @@ import (
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []byte: The transformed request data in Gemini CLI API format
|
// - []byte: The transformed request data in Gemini CLI API format
|
||||||
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
return bytes.Clone(inputRawJSON)
|
// Update the "model" field in the JSON payload with the provided modelName
|
||||||
|
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
|
||||||
|
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
|
||||||
|
if err != nil {
|
||||||
|
// If there's an error, return the original JSON or handle the error appropriately.
|
||||||
|
// For now, we'll return the original, but in a real scenario, logging or a more robust error
|
||||||
|
// handling mechanism would be needed.
|
||||||
|
return bytes.Clone(inputRawJSON)
|
||||||
|
}
|
||||||
|
return updatedJSON
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package handlers
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
|
|||||||
|
|
||||||
// Cfg holds the current application configuration.
|
// Cfg holds the current application configuration.
|
||||||
Cfg *config.SDKConfig
|
Cfg *config.SDKConfig
|
||||||
|
|
||||||
|
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||||
|
OpenAICompatProviders []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||||
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *BaseAPIHandler: A new API handlers instance
|
// - *BaseAPIHandler: A new API handlers instance
|
||||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||||
return &BaseAPIHandler{
|
return &BaseAPIHandler{
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
AuthManager: authManager,
|
AuthManager: authManager,
|
||||||
|
OpenAICompatProviders: openAICompatProviders,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
|||||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||||
providers := util.GetProviderName(normalizedModel)
|
if errMsg != nil {
|
||||||
if len(providers) == 0 {
|
return nil, errMsg
|
||||||
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
|
||||||
}
|
}
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||||
providers := util.GetProviderName(normalizedModel)
|
if errMsg != nil {
|
||||||
if len(providers) == 0 {
|
return nil, errMsg
|
||||||
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
|
||||||
}
|
}
|
||||||
req := coreexecutor.Request{
|
req := coreexecutor.Request{
|
||||||
Model: normalizedModel,
|
Model: normalizedModel,
|
||||||
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||||
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||||
providers := util.GetProviderName(normalizedModel)
|
if errMsg != nil {
|
||||||
if len(providers) == 0 {
|
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
errChan <- errMsg
|
||||||
close(errChan)
|
close(errChan)
|
||||||
return nil, errChan
|
return nil, errChan
|
||||||
}
|
}
|
||||||
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return dataChan, errChan
|
return dataChan, errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||||
|
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
|
||||||
|
|
||||||
|
// First, normalize the model name to handle suffixes like "-thinking-128"
|
||||||
|
// This needs to happen before determining the provider for non-dynamic models.
|
||||||
|
normalizedModel, metadata = normalizeModelMetadata(modelName)
|
||||||
|
|
||||||
|
if isDynamic {
|
||||||
|
providers = []string{providerName}
|
||||||
|
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
|
||||||
|
// so we use it as the final normalizedModel.
|
||||||
|
normalizedModel = extractedModelName
|
||||||
|
} else {
|
||||||
|
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||||
|
providers = util.GetProviderName(normalizedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
|
||||||
|
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
|
||||||
|
// So, normalizedModel is already correctly set at this point.
|
||||||
|
|
||||||
|
return providers, normalizedModel, metadata, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||||
|
var providerPart, modelPart string
|
||||||
|
for _, sep := range []string{"://"} {
|
||||||
|
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||||
|
providerPart = parts[0]
|
||||||
|
modelPart = parts[1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if providerPart == "" {
|
||||||
|
return "", modelName, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the provider is a configured openai-compatibility provider
|
||||||
|
for _, pName := range h.OpenAICompatProviders {
|
||||||
|
if pName == providerPart {
|
||||||
|
return providerPart, modelPart, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", modelName, false
|
||||||
|
}
|
||||||
|
|
||||||
func cloneBytes(src []byte) []byte {
|
func cloneBytes(src []byte) []byte {
|
||||||
if len(src) == 0 {
|
if len(src) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
Reference in New Issue
Block a user