mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
feat(registry): support provider-specific model info lookup
This commit is contained in:
@@ -78,6 +78,8 @@ type ThinkingSupport struct {
|
||||
type ModelRegistration struct {
|
||||
// Info contains the model metadata
|
||||
Info *ModelInfo
|
||||
// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
|
||||
InfoByProvider map[string]*ModelInfo
|
||||
// Count is the number of active clients that can provide this model
|
||||
Count int
|
||||
// LastUpdated tracks when this registration was last modified
|
||||
@@ -132,16 +134,19 @@ func GetGlobalRegistry() *ModelRegistry {
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
// LookupModelInfo searches the dynamic registry first, then falls back to static model definitions.
|
||||
//
|
||||
// This helper exists because some code paths only have a model ID and still need Thinking and
|
||||
// max completion token metadata even when the dynamic registry hasn't been populated.
|
||||
func LookupModelInfo(modelID string) *ModelInfo {
|
||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
if info := GetGlobalRegistry().GetModelInfo(modelID); info != nil {
|
||||
|
||||
p := ""
|
||||
if len(provider) > 0 {
|
||||
p = strings.ToLower(strings.TrimSpace(provider[0]))
|
||||
}
|
||||
|
||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||
return info
|
||||
}
|
||||
return LookupStaticModelInfo(modelID)
|
||||
@@ -297,6 +302,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
if count, okProv := reg.Providers[oldProvider]; okProv {
|
||||
if count <= toRemove {
|
||||
delete(reg.Providers, oldProvider)
|
||||
if reg.InfoByProvider != nil {
|
||||
delete(reg.InfoByProvider, oldProvider)
|
||||
}
|
||||
} else {
|
||||
reg.Providers[oldProvider] = count - toRemove
|
||||
}
|
||||
@@ -346,6 +354,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
model := newModels[id]
|
||||
if reg, ok := r.models[id]; ok {
|
||||
reg.Info = cloneModelInfo(model)
|
||||
if provider != "" {
|
||||
if reg.InfoByProvider == nil {
|
||||
reg.InfoByProvider = make(map[string]*ModelInfo)
|
||||
}
|
||||
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
reg.LastUpdated = now
|
||||
if reg.QuotaExceededClients != nil {
|
||||
delete(reg.QuotaExceededClients, clientID)
|
||||
@@ -409,11 +423,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
||||
if existing.SuspendedClients == nil {
|
||||
existing.SuspendedClients = make(map[string]string)
|
||||
}
|
||||
if existing.InfoByProvider == nil {
|
||||
existing.InfoByProvider = make(map[string]*ModelInfo)
|
||||
}
|
||||
if provider != "" {
|
||||
if existing.Providers == nil {
|
||||
existing.Providers = make(map[string]int)
|
||||
}
|
||||
existing.Providers[provider]++
|
||||
existing.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
||||
return
|
||||
@@ -421,6 +439,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
||||
|
||||
registration := &ModelRegistration{
|
||||
Info: cloneModelInfo(model),
|
||||
InfoByProvider: make(map[string]*ModelInfo),
|
||||
Count: 1,
|
||||
LastUpdated: now,
|
||||
QuotaExceededClients: make(map[string]*time.Time),
|
||||
@@ -428,6 +447,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
||||
}
|
||||
if provider != "" {
|
||||
registration.Providers = map[string]int{provider: 1}
|
||||
registration.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
r.models[modelID] = registration
|
||||
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
||||
@@ -453,6 +473,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
|
||||
if count, ok := registration.Providers[provider]; ok {
|
||||
if count <= 1 {
|
||||
delete(registration.Providers, provider)
|
||||
if registration.InfoByProvider != nil {
|
||||
delete(registration.InfoByProvider, provider)
|
||||
}
|
||||
} else {
|
||||
registration.Providers[provider] = count - 1
|
||||
}
|
||||
@@ -534,6 +557,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
if count, ok := registration.Providers[provider]; ok {
|
||||
if count <= 1 {
|
||||
delete(registration.Providers, provider)
|
||||
if registration.InfoByProvider != nil {
|
||||
delete(registration.InfoByProvider, provider)
|
||||
}
|
||||
} else {
|
||||
registration.Providers[provider] = count - 1
|
||||
}
|
||||
@@ -940,12 +966,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetModelInfo returns the registered ModelInfo for the given model ID, if present.
|
||||
// Returns nil if the model is unknown to the registry.
|
||||
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
|
||||
// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available.
|
||||
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
if reg, ok := r.models[modelID]; ok && reg != nil {
|
||||
// Try provider specific definition first
|
||||
if provider != "" && reg.InfoByProvider != nil {
|
||||
if reg.Providers != nil {
|
||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback to global info (last registered)
|
||||
return reg.Info
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -393,7 +393,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, translatedPayload{}, err
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -622,7 +622,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -802,7 +802,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
// Prepare payload once (doesn't depend on baseURL)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -235,7 +235,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -208,7 +208,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -316,7 +316,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -272,7 +272,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -479,7 +479,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
for range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -222,7 +222,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -338,7 +338,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -319,7 +319,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -432,7 +432,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -535,7 +535,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -658,7 +658,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -773,7 +773,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
@@ -857,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -190,7 +190,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
modelForCounting := baseModel
|
||||
|
||||
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
||||
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
||||
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
||||
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
|
||||
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
|
||||
//
|
||||
// Returns:
|
||||
// - Modified request body JSON with thinking configuration applied
|
||||
@@ -79,12 +80,16 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
||||
// Example:
|
||||
//
|
||||
// // With suffix - suffix config takes priority
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini")
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini")
|
||||
//
|
||||
// // Without suffix - uses body config
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini")
|
||||
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string) ([]byte, error) {
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini")
|
||||
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) {
|
||||
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
|
||||
providerKey = strings.ToLower(strings.TrimSpace(providerKey))
|
||||
if providerKey == "" {
|
||||
providerKey = providerFormat
|
||||
}
|
||||
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
|
||||
if fromFormat == "" {
|
||||
fromFormat = providerFormat
|
||||
@@ -102,7 +107,8 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string
|
||||
// 2. Parse suffix and get modelInfo
|
||||
suffixResult := ParseSuffix(model)
|
||||
baseModel := suffixResult.ModelName
|
||||
modelInfo := registry.LookupModelInfo(baseModel)
|
||||
// Use provider-specific lookup to handle capability differences across providers.
|
||||
modelInfo := registry.LookupModelInfo(baseModel, providerKey)
|
||||
|
||||
// 3. Model capability check
|
||||
// Unknown models are treated as user-defined so thinking config can still be applied.
|
||||
|
||||
@@ -2712,7 +2712,7 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", 200000)
|
||||
}
|
||||
|
||||
body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo)
|
||||
body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo, applyTo)
|
||||
|
||||
if tc.expectErr {
|
||||
if err == nil {
|
||||
|
||||
Reference in New Issue
Block a user