mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(config): add support for model prefixes and prefix normalization
Refactor model management to include an optional `prefix` field for model credentials, enabling better namespace handling. Update affected configuration files, APIs, and handlers to support prefix normalization and routing. Remove unused OpenAI compatibility provider logic to simplify processing.
This commit is contained in:
@@ -49,9 +49,6 @@ type BaseAPIHandler struct {
|
||||
|
||||
// Cfg holds the current application configuration.
|
||||
Cfg *config.SDKConfig
|
||||
|
||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||
OpenAICompatProviders []string
|
||||
}
|
||||
|
||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||
@@ -63,11 +60,10 @@ type BaseAPIHandler struct {
|
||||
//
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
OpenAICompatProviders: openAICompatProviders,
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,30 +338,19 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
targetModelName := resolvedModelName
|
||||
if isDynamic {
|
||||
targetModelName = extractedModelName
|
||||
}
|
||||
|
||||
// Normalize the model name to handle dynamic thinking suffixes before determining the provider.
|
||||
normalizedModel, metadata = normalizeModelMetadata(targetModelName)
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
} else {
|
||||
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
// Use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -383,30 +368,6 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
return providers, normalizedModel, metadata, nil
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||
var providerPart, modelPart string
|
||||
for _, sep := range []string{"://"} {
|
||||
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||
providerPart = parts[0]
|
||||
modelPart = parts[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if providerPart == "" {
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
// Check if the provider is a configured openai-compatibility provider
|
||||
for _, pName := range h.OpenAICompatProviders {
|
||||
if pName == providerPart {
|
||||
return providerPart, modelPart, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -363,10 +363,11 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -396,8 +397,10 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.Execute(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -420,10 +423,11 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -453,8 +457,10 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -477,10 +483,11 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if provider == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
@@ -510,14 +517,16 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, req, opts)
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
@@ -535,18 +544,66 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: false, Error: rerr})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: true})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) {
|
||||
if auth == nil || model == "" {
|
||||
return model, metadata
|
||||
}
|
||||
prefix := strings.TrimSpace(auth.Prefix)
|
||||
if prefix == "" {
|
||||
return model, metadata
|
||||
}
|
||||
needle := prefix + "/"
|
||||
if !strings.HasPrefix(model, needle) {
|
||||
return model, metadata
|
||||
}
|
||||
rewritten := strings.TrimPrefix(model, needle)
|
||||
return rewritten, stripPrefixFromMetadata(metadata, needle)
|
||||
}
|
||||
|
||||
func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {
|
||||
if len(metadata) == 0 || needle == "" {
|
||||
return metadata
|
||||
}
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
raw, ok := metadata[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
value, okStr := raw.(string)
|
||||
if !okStr || !strings.HasPrefix(value, needle) {
|
||||
continue
|
||||
}
|
||||
if out == nil {
|
||||
out = make(map[string]any, len(metadata))
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
out[key] = strings.TrimPrefix(value, needle)
|
||||
}
|
||||
if out == nil {
|
||||
return metadata
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -19,6 +19,8 @@ type Auth struct {
|
||||
Index uint64 `json:"-"`
|
||||
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
||||
Provider string `json:"provider"`
|
||||
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
// FileName stores the relative or absolute path of the backing auth file.
|
||||
FileName string `json:"-"`
|
||||
// Storage holds the token persistence implementation used during login flows.
|
||||
|
||||
@@ -787,7 +787,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if providerKey == "" {
|
||||
providerKey = "openai-compatibility"
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
} else {
|
||||
// Ensure stale registrations are cleared when model list becomes empty.
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -807,7 +807,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if key == "" {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -987,6 +987,48 @@ func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||
return filtered
|
||||
}
|
||||
|
||||
func applyModelPrefixes(models []*ModelInfo, prefix string, forceModelPrefix bool) []*ModelInfo {
|
||||
trimmedPrefix := strings.TrimSpace(prefix)
|
||||
if trimmedPrefix == "" || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
out := make([]*ModelInfo, 0, len(models)*2)
|
||||
seen := make(map[string]struct{}, len(models)*2)
|
||||
|
||||
addModel := func(model *ModelInfo) {
|
||||
if model == nil {
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(model.ID)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
return
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, model)
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
baseID := strings.TrimSpace(model.ID)
|
||||
if baseID == "" {
|
||||
continue
|
||||
}
|
||||
if !forceModelPrefix || trimmedPrefix == baseID {
|
||||
addModel(model)
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = trimmedPrefix + "/" + baseID
|
||||
addModel(&clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring.
|
||||
func matchWildcard(pattern, value string) bool {
|
||||
if pattern == "" {
|
||||
|
||||
@@ -9,6 +9,11 @@ type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
|
||||
Reference in New Issue
Block a user