feat(config): add support for model prefixes and prefix normalization

Refactor model management to include an optional `prefix` field for model credentials, enabling better namespace handling. Update affected configuration files, APIs, and handlers to support prefix normalization and routing. Remove unused OpenAI compatibility provider logic to simplify processing.
This commit is contained in:
Luis Pater
2025-12-17 01:07:26 +08:00
parent fcecbc7d46
commit 52b6306388
10 changed files with 210 additions and 87 deletions

View File

@@ -363,10 +363,11 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
if provider == "" {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
@@ -396,8 +397,10 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
resp, errExec := executor.Execute(execCtx, auth, req, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
@@ -420,10 +423,11 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
if provider == "" {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
@@ -453,8 +457,10 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
resp, errExec := executor.CountTokens(execCtx, auth, req, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
@@ -477,10 +483,11 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
if provider == "" {
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return nil, lastErr
@@ -510,14 +517,16 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
chunks, errStream := executor.ExecuteStream(execCtx, auth, req, opts)
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
lastErr = errStream
@@ -535,18 +544,66 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: false, Error: rerr})
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
out <- chunk
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: true})
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
}
}
func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) {
if auth == nil || model == "" {
return model, metadata
}
prefix := strings.TrimSpace(auth.Prefix)
if prefix == "" {
return model, metadata
}
needle := prefix + "/"
if !strings.HasPrefix(model, needle) {
return model, metadata
}
rewritten := strings.TrimPrefix(model, needle)
return rewritten, stripPrefixFromMetadata(metadata, needle)
}
func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {
if len(metadata) == 0 || needle == "" {
return metadata
}
keys := []string{
util.ThinkingOriginalModelMetadataKey,
util.GeminiOriginalModelMetadataKey,
}
var out map[string]any
for _, key := range keys {
raw, ok := metadata[key]
if !ok {
continue
}
value, okStr := raw.(string)
if !okStr || !strings.HasPrefix(value, needle) {
continue
}
if out == nil {
out = make(map[string]any, len(metadata))
for k, v := range metadata {
out[k] = v
}
}
out[key] = strings.TrimPrefix(value, needle)
}
if out == nil {
return metadata
}
return out
}
func (m *Manager) normalizeProviders(providers []string) []string {
if len(providers) == 0 {
return nil

View File

@@ -19,6 +19,8 @@ type Auth struct {
Index uint64 `json:"-"`
// Provider is the upstream provider key (e.g. "gemini", "claude").
Provider string `json:"provider"`
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
Prefix string `json:"prefix,omitempty"`
// FileName stores the relative or absolute path of the backing auth file.
FileName string `json:"-"`
// Storage holds the token persistence implementation used during login flows.

View File

@@ -787,7 +787,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if providerKey == "" {
providerKey = "openai-compatibility"
}
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
} else {
// Ensure stale registrations are cleared when model list becomes empty.
GlobalModelRegistry().UnregisterClient(a.ID)
@@ -807,7 +807,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if key == "" {
key = strings.ToLower(strings.TrimSpace(a.Provider))
}
GlobalModelRegistry().RegisterClient(a.ID, key, models)
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
return
}
@@ -987,6 +987,48 @@ func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
return filtered
}
func applyModelPrefixes(models []*ModelInfo, prefix string, forceModelPrefix bool) []*ModelInfo {
trimmedPrefix := strings.TrimSpace(prefix)
if trimmedPrefix == "" || len(models) == 0 {
return models
}
out := make([]*ModelInfo, 0, len(models)*2)
seen := make(map[string]struct{}, len(models)*2)
addModel := func(model *ModelInfo) {
if model == nil {
return
}
id := strings.TrimSpace(model.ID)
if id == "" {
return
}
if _, exists := seen[id]; exists {
return
}
seen[id] = struct{}{}
out = append(out, model)
}
for _, model := range models {
if model == nil {
continue
}
baseID := strings.TrimSpace(model.ID)
if baseID == "" {
continue
}
if !forceModelPrefix || trimmedPrefix == baseID {
addModel(model)
}
clone := *model
clone.ID = trimmedPrefix + "/" + baseID
addModel(&clone)
}
return out
}
// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring.
func matchWildcard(pattern, value string) bool {
if pattern == "" {