mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 20:30:51 +08:00
refactor: improve thinking logic
This commit is contained in:
@@ -4,9 +4,15 @@ import (
|
||||
"strings"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
type modelMappingEntry interface {
|
||||
GetName() string
|
||||
GetAlias() string
|
||||
}
|
||||
|
||||
type modelNameMappingTable struct {
|
||||
// reverse maps channel -> alias (lower) -> original upstream model name.
|
||||
reverse map[string]map[string]string
|
||||
@@ -71,9 +77,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod
|
||||
// requested model for response translation.
|
||||
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
|
||||
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
|
||||
return applyUpstreamModelOverride(requestedModel, upstreamModel, metadata)
|
||||
}
|
||||
|
||||
func applyUpstreamModelOverride(requestedModel, upstreamModel string, metadata map[string]any) (string, map[string]any) {
|
||||
if upstreamModel == "" {
|
||||
return requestedModel, metadata
|
||||
}
|
||||
|
||||
out := make(map[string]any, 1)
|
||||
if len(metadata) > 0 {
|
||||
out = make(map[string]any, len(metadata)+1)
|
||||
@@ -81,24 +92,92 @@ func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, meta
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
// Store the requested alias (e.g., "gp") so downstream can use it to look up
|
||||
// model metadata from the global registry where it was registered under this alias.
|
||||
|
||||
// Preserve the original client model string (including any suffix) for downstream.
|
||||
out[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
return upstreamModel, out
|
||||
}
|
||||
|
||||
func resolveModelAliasFromConfigModels(requestedModel string, models []modelMappingEntry) string {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return ""
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
base := requestResult.ModelName
|
||||
candidates := []string{base}
|
||||
if base != requestedModel {
|
||||
candidates = append(candidates, requestedModel)
|
||||
}
|
||||
|
||||
preserveSuffix := func(resolved string) string {
|
||||
resolved = strings.TrimSpace(resolved)
|
||||
if resolved == "" {
|
||||
return ""
|
||||
}
|
||||
if thinking.ParseSuffix(resolved).HasSuffix {
|
||||
return resolved
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
for i := range models {
|
||||
name := strings.TrimSpace(models[i].GetName())
|
||||
alias := strings.TrimSpace(models[i].GetAlias())
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" && strings.EqualFold(alias, candidate) {
|
||||
if name != "" {
|
||||
return preserveSuffix(name)
|
||||
}
|
||||
return preserveSuffix(candidate)
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return preserveSuffix(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveOAuthUpstreamModel resolves the upstream model name from OAuth model mappings.
|
||||
// If a mapping exists, returns the original (upstream) model name that corresponds
|
||||
// to the requested alias.
|
||||
//
|
||||
// If the requested model contains a thinking suffix (e.g., "gemini-2.5-pro(8192)"),
|
||||
// the suffix is preserved in the returned model name. However, if the mapping's
|
||||
// original name already contains a suffix, the config suffix takes priority.
|
||||
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
|
||||
return resolveUpstreamModelFromMappingTable(m, auth, requestedModel, modelMappingChannel(auth))
|
||||
}
|
||||
|
||||
func resolveUpstreamModelFromMappingTable(m *Manager, auth *Auth, requestedModel, channel string) string {
|
||||
if m == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
channel := modelMappingChannel(auth)
|
||||
if channel == "" {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if key == "" {
|
||||
return ""
|
||||
|
||||
// Extract thinking suffix from requested model using ParseSuffix
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
baseModel := requestResult.ModelName
|
||||
|
||||
// Candidate keys to match: base model and raw input (handles suffix-parsing edge cases).
|
||||
candidates := []string{baseModel}
|
||||
if baseModel != requestedModel {
|
||||
candidates = append(candidates, requestedModel)
|
||||
}
|
||||
|
||||
raw := m.modelNameMappings.Load()
|
||||
table, _ := raw.(*modelNameMappingTable)
|
||||
if table == nil || table.reverse == nil {
|
||||
@@ -108,11 +187,32 @@ func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) s
|
||||
if rev == nil {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(rev[key])
|
||||
if original == "" || strings.EqualFold(original, requestedModel) {
|
||||
return ""
|
||||
|
||||
for _, candidate := range candidates {
|
||||
key := strings.ToLower(strings.TrimSpace(candidate))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
original := strings.TrimSpace(rev[key])
|
||||
if original == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(original, baseModel) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// If config already has suffix, it takes priority.
|
||||
if thinking.ParseSuffix(original).HasSuffix {
|
||||
return original
|
||||
}
|
||||
// Preserve user's thinking suffix on the resolved model.
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return original + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return original
|
||||
}
|
||||
return original
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
|
||||
|
||||
Reference in New Issue
Block a user