mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-04 05:20:52 +08:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e998b1229a | ||
|
|
bbed134bd1 | ||
|
|
cb56cb250e | ||
|
|
e0381a6ae0 | ||
|
|
2c01b2ef64 | ||
|
|
e947266743 | ||
|
|
c6b0e85b54 | ||
|
|
26efbed05c | ||
|
|
96340bf136 | ||
|
|
b055e00c1a | ||
|
|
857c880f99 | ||
|
|
ce7474d953 | ||
|
|
70fdd70b84 | ||
|
|
08ab6a7d77 | ||
|
|
9fa2a7e9df | ||
|
|
d443c86620 | ||
|
|
7be3f1c36c | ||
|
|
f6ab6d97b9 | ||
|
|
bc866bac49 | ||
|
|
50e6d845f4 | ||
|
|
a8cb01819d | ||
|
|
44b63f0767 | ||
|
|
6705d20194 |
@@ -405,7 +405,7 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
if err = logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -90,6 +90,9 @@ ws-auth: false
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# models:
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
|
||||
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
||||
@@ -106,7 +109,7 @@ ws-auth: false
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# - name: "gpt-5-codex" # upstream model name
|
||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "gpt-5.1" # exclude specific models (exact match)
|
||||
@@ -125,7 +128,7 @@ ws-auth: false
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
|
||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||
@@ -156,9 +159,9 @@ ws-auth: false
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
# - name: "gemini-2.0-flash" # upstream model name
|
||||
# - name: "gemini-2.5-flash" # upstream model name
|
||||
# alias: "vertex-flash" # client-visible alias
|
||||
# - name: "gemini-1.5-pro"
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "vertex-pro"
|
||||
|
||||
# Amp Integration
|
||||
@@ -188,12 +191,42 @@ ws-auth: false
|
||||
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||
# # but you have a similar model available (e.g., Claude Sonnet 4).
|
||||
# model-mappings:
|
||||
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
||||
# to: "claude-sonnet-4" # Route to this available model instead
|
||||
# - from: "gpt-5"
|
||||
# to: "gemini-2.5-pro"
|
||||
# - from: "claude-3-opus-20240229"
|
||||
# to: "claude-3-5-sonnet-20241022"
|
||||
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
|
||||
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
|
||||
# - from: "claude-sonnet-4-5-20250929"
|
||||
# to: "gemini-claude-sonnet-4-5-thinking"
|
||||
# - from: "claude-haiku-4-5-20251001"
|
||||
# to: "gemini-2.5-flash"
|
||||
|
||||
# Global OAuth model name mappings (per channel)
|
||||
# These mappings rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
# NOTE: Mappings do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# oauth-model-mappings:
|
||||
# gemini-cli:
|
||||
# - name: "gemini-2.5-pro" # original model name under this channel
|
||||
# alias: "g2.5p" # client-visible alias
|
||||
# vertex:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-preview"
|
||||
# alias: "g3p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
# codex:
|
||||
# - name: "gpt-5"
|
||||
# alias: "g5"
|
||||
# qwen:
|
||||
# - name: "qwen3-coder-plus"
|
||||
# alias: "qwen-plus"
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# oauth-excluded-models:
|
||||
|
||||
@@ -856,7 +856,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
if err := logging.ConfigureLogOutput(cfg); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
if oldCfg == nil {
|
||||
|
||||
@@ -91,6 +91,14 @@ type Config struct {
|
||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||
|
||||
// OAuthModelMappings defines global model name mappings for OAuth/file-backed auth channels.
|
||||
// These mappings affect both model listing and model routing for supported channels:
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
//
|
||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||
OAuthModelMappings map[string][]ModelNameMapping `yaml:"oauth-model-mappings,omitempty" json:"oauth-model-mappings,omitempty"`
|
||||
|
||||
// Payload defines default and override rules for provider payload parameters.
|
||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||
|
||||
@@ -137,6 +145,13 @@ type RoutingConfig struct {
|
||||
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
|
||||
}
|
||||
|
||||
// ModelNameMapping defines a model ID rename mapping for a specific channel.
|
||||
// It maps the original model name (Name) to the client-visible alias (Alias).
|
||||
type ModelNameMapping struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||
// When Amp requests a model that isn't available locally, this mapping
|
||||
// allows routing to an alternative model that IS available.
|
||||
@@ -253,6 +268,9 @@ type ClaudeModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m ClaudeModel) GetName() string { return m.Name }
|
||||
func (m ClaudeModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// CodexKey represents the configuration for a Codex API key,
|
||||
// including the API key itself and an optional base URL for the API endpoint.
|
||||
type CodexKey struct {
|
||||
@@ -288,6 +306,9 @@ type CodexModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m CodexModel) GetName() string { return m.Name }
|
||||
func (m CodexModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// GeminiKey represents the configuration for a Gemini API key,
|
||||
// including optional overrides for upstream base URL, proxy routing, and headers.
|
||||
type GeminiKey struct {
|
||||
@@ -303,6 +324,9 @@ type GeminiKey struct {
|
||||
// ProxyURL optionally overrides the global proxy for this API key.
|
||||
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||
|
||||
// Models defines upstream model names and aliases for request routing.
|
||||
Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
@@ -310,6 +334,18 @@ type GeminiKey struct {
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiModel describes a mapping between an alias and the actual upstream model name.
|
||||
type GeminiModel struct {
|
||||
// Name is the upstream model identifier used when issuing requests.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the client-facing model name that maps to Name.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m GeminiModel) GetName() string { return m.Name }
|
||||
func (m GeminiModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||
type OpenAICompatibility struct {
|
||||
@@ -461,6 +497,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Normalize OAuth provider model exclusion map.
|
||||
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
||||
|
||||
// Normalize global OAuth model name mappings.
|
||||
cfg.SanitizeOAuthModelMappings()
|
||||
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
@@ -477,6 +516,50 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// SanitizeOAuthModelMappings normalizes and deduplicates global OAuth model name mappings.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// and ensures (From, To) pairs are unique within each channel.
|
||||
func (cfg *Config) SanitizeOAuthModelMappings() {
|
||||
if cfg == nil || len(cfg.OAuthModelMappings) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]ModelNameMapping, len(cfg.OAuthModelMappings))
|
||||
for rawChannel, mappings := range cfg.OAuthModelMappings {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(mappings) == 0 {
|
||||
continue
|
||||
}
|
||||
seenName := make(map[string]struct{}, len(mappings))
|
||||
seenAlias := make(map[string]struct{}, len(mappings))
|
||||
clean := make([]ModelNameMapping, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
name := strings.TrimSpace(mapping.Name)
|
||||
alias := strings.TrimSpace(mapping.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
nameKey := strings.ToLower(name)
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, ok := seenName[nameKey]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := seenAlias[aliasKey]; ok {
|
||||
continue
|
||||
}
|
||||
seenName[nameKey] = struct{}{}
|
||||
seenAlias[aliasKey] = struct{}{}
|
||||
clean = append(clean, ModelNameMapping{Name: name, Alias: alias})
|
||||
}
|
||||
if len(clean) > 0 {
|
||||
out[channel] = clean
|
||||
}
|
||||
}
|
||||
cfg.OAuthModelMappings = out
|
||||
}
|
||||
|
||||
// SanitizeOpenAICompatibility removes OpenAI-compatibility provider entries that are
|
||||
// not actionable, specifically those missing a BaseURL. It trims whitespace before
|
||||
// evaluation and preserves the relative order of remaining entries.
|
||||
|
||||
@@ -42,6 +42,9 @@ type VertexCompatModel struct {
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
func (m VertexCompatModel) GetName() string { return m.Name }
|
||||
func (m VertexCompatModel) GetAlias() string { return m.Alias }
|
||||
|
||||
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
|
||||
func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if cfg == nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@@ -83,10 +84,30 @@ func SetupBaseLogger() {
|
||||
})
|
||||
}
|
||||
|
||||
// isDirWritable checks if the specified directory exists and is writable by attempting to create and remove a test file.
|
||||
func isDirWritable(dir string) bool {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return false
|
||||
}
|
||||
|
||||
testFile := filepath.Join(dir, ".perm_test")
|
||||
f, err := os.Create(testFile)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(testFile)
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||
// until the total size is within the limit.
|
||||
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
func ConfigureLogOutput(cfg *config.Config) error {
|
||||
SetupBaseLogger()
|
||||
|
||||
writerMu.Lock()
|
||||
@@ -95,10 +116,12 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
} else if !isDirWritable(logDir) {
|
||||
logDir = filepath.Join(cfg.AuthDir, "logs")
|
||||
}
|
||||
|
||||
protectedPath := ""
|
||||
if loggingToFile {
|
||||
if cfg.LoggingToFile {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||
}
|
||||
@@ -122,7 +145,7 @@ func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
|
||||
configureLogDirCleanerLocked(logDir, cfg.LogsMaxTotalSizeMB, protectedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -24,10 +24,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -198,6 +199,16 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
@@ -210,15 +221,14 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if err := os.MkdirAll(staticDir, 0o755); err != nil {
|
||||
log.WithError(err).Warn("failed to prepare static directory for management asset")
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
@@ -229,6 +239,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
@@ -240,6 +257,13 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
@@ -256,6 +280,22 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to download fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to persist fallback management control panel page")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
|
||||
return true
|
||||
}
|
||||
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
|
||||
@@ -781,3 +781,29 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
}
|
||||
}
|
||||
|
||||
// LookupStaticModelInfo searches all static model definitions for a model by ID.
|
||||
// Returns nil if no matching model is found.
|
||||
func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
allModels := [][]*ModelInfo{
|
||||
GetClaudeModels(),
|
||||
GetGeminiModels(),
|
||||
GetGeminiVertexModels(),
|
||||
GetGeminiCLIModels(),
|
||||
GetAIStudioModels(),
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
if m != nil && m.ID == modelID {
|
||||
return m
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
@@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||
wsReq := &wsrelay.HTTPRequest{
|
||||
Method: http.MethodPost,
|
||||
|
||||
@@ -76,7 +76,8 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
|
||||
|
||||
// Execute performs a non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if strings.Contains(req.Model, "claude") {
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
if isClaude {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
@@ -98,7 +99,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
@@ -193,7 +194,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
@@ -520,6 +521,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -527,7 +530,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated, isClaude)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
@@ -676,6 +679,8 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
@@ -694,7 +699,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload, isClaude)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
@@ -1308,7 +1313,7 @@ func alias2ModelName(modelName string) string {
|
||||
|
||||
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
|
||||
// For Claude models, it additionally ensures thinking budget < max_tokens.
|
||||
func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||
func normalizeAntigravityThinking(model string, payload []byte, isClaude bool) []byte {
|
||||
payload = util.StripThinkingConfigIfUnsupported(model, payload)
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
@@ -1320,7 +1325,6 @@ func normalizeAntigravityThinking(model string, payload []byte) []byte {
|
||||
raw := int(budget.Int())
|
||||
normalized := util.NormalizeThinkingBudget(model, raw)
|
||||
|
||||
isClaude := strings.Contains(strings.ToLower(model), "claude")
|
||||
if isClaude {
|
||||
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
|
||||
if effectiveMax > 0 && normalized >= effectiveMax {
|
||||
|
||||
@@ -49,36 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
|
||||
body = e.injectThinkingConfig(model, req.Metadata, body)
|
||||
|
||||
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
body = ensureMaxTokensForThinking(model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
@@ -170,29 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
// Inject thinking config based on model metadata for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, req.Metadata, body)
|
||||
body = e.injectThinkingConfig(model, req.Metadata, body)
|
||||
body = checkSystemInstructions(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
body = ensureMaxTokensForThinking(model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
@@ -316,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") {
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
|
||||
|
||||
@@ -49,28 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
|
||||
@@ -156,30 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -266,30 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
|
||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel == "" {
|
||||
upstreamModel = req.Model
|
||||
}
|
||||
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
} else if !strings.EqualFold(upstreamModel, req.Model) {
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
upstreamModel = modelOverride
|
||||
}
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelForCounting := upstreamModel
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
enc, err := tokenizerForCodexModel(modelForCounting)
|
||||
enc, err := tokenizerForCodexModel(model)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -318,7 +318,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response, reqBody []byte, attempt string) {
|
||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
@@ -336,14 +336,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -365,12 +365,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -417,6 +417,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
|
||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
@@ -425,7 +427,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload)
|
||||
payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
|
||||
|
||||
tok, errTok := tokenSource.Token()
|
||||
if errTok != nil {
|
||||
|
||||
@@ -77,19 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -98,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
}
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -173,21 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -287,19 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
apiKey, bearer := geminiCreds(auth)
|
||||
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
|
||||
|
||||
requestBody := bytes.NewReader(translatedReq)
|
||||
|
||||
@@ -398,6 +410,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
|
||||
return base
|
||||
}
|
||||
|
||||
func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveGeminiConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.GeminiKey {
|
||||
entry := &e.cfg.GeminiKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.GeminiKey {
|
||||
entry := &e.cfg.GeminiKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
|
||||
@@ -120,8 +120,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
@@ -137,7 +135,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -146,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
}
|
||||
}
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -220,24 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -250,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
@@ -321,8 +322,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -338,10 +337,10 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -438,30 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(model, body)
|
||||
body = fixGeminiImageAspectRatio(model, body)
|
||||
body = applyPayloadConfig(e.cfg, model, body)
|
||||
body, _ = sjson.SetBytes(body, "model", model)
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
@@ -552,8 +554,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
|
||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
@@ -566,14 +566,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
@@ -641,21 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
model := req.Model
|
||||
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
|
||||
model = override
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
@@ -665,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
@@ -808,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
|
||||
}
|
||||
return tok.AccessToken, nil
|
||||
}
|
||||
|
||||
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
|
||||
// It matches the requested model alias against configured models and returns the actual upstream name.
|
||||
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
trimmed := strings.TrimSpace(alias)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
entry := e.resolveVertexConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
|
||||
|
||||
// Candidate names to match against configured aliases/names.
|
||||
candidates := []string{strings.TrimSpace(normalizedModel)}
|
||||
if !strings.EqualFold(normalizedModel, trimmed) {
|
||||
candidates = append(candidates, trimmed)
|
||||
}
|
||||
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
|
||||
candidates = append(candidates, original)
|
||||
}
|
||||
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
modelAlias := strings.TrimSpace(model.Alias)
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
|
||||
if name != "" {
|
||||
return name
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return name
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
|
||||
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.VertexCompatAPIKey {
|
||||
entry := &e.cfg.VertexCompatAPIKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.VertexCompatAPIKey {
|
||||
entry := &e.cfg.VertexCompatAPIKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -58,12 +58,9 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
@@ -151,12 +148,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
|
||||
@@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
|
||||
@@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -52,12 +51,9 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -132,12 +128,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body, _ = sjson.SetBytes(body, "model", req.Model)
|
||||
body = NormalizeThinkingConfig(body, req.Model, false)
|
||||
if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
|
||||
@@ -23,6 +23,7 @@ type geminiToResponsesState struct {
|
||||
MsgIndex int
|
||||
CurrentMsgID string
|
||||
TextBuf strings.Builder
|
||||
ItemTextBuf strings.Builder
|
||||
|
||||
// reasoning aggregation
|
||||
ReasoningOpened bool
|
||||
@@ -189,6 +190,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID)
|
||||
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
|
||||
out = append(out, emitEvent("response.content_part.added", partAdded))
|
||||
st.ItemTextBuf.Reset()
|
||||
st.ItemTextBuf.WriteString(t.String())
|
||||
}
|
||||
st.TextBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
||||
@@ -250,20 +253,24 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
finalizeReasoning()
|
||||
// Close message output if opened
|
||||
if st.MsgOpened {
|
||||
fullText := st.ItemTextBuf.String()
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
||||
done, _ = sjson.Set(done, "text", fullText)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
}
|
||||
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ThinkingBudgetMetadataKey = "thinking_budget"
|
||||
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
|
||||
ReasoningEffortMetadataKey = "reasoning_effort"
|
||||
ThinkingOriginalModelMetadataKey = "thinking_original_model"
|
||||
ThinkingBudgetMetadataKey = "thinking_budget"
|
||||
ThinkingIncludeThoughtsMetadataKey = "thinking_include_thoughts"
|
||||
ReasoningEffortMetadataKey = "reasoning_effort"
|
||||
ThinkingOriginalModelMetadataKey = "thinking_original_model"
|
||||
ModelMappingOriginalModelMetadataKey = "model_mapping_original_model"
|
||||
)
|
||||
|
||||
// NormalizeThinkingModel parses dynamic thinking suffixes on model names and returns
|
||||
@@ -215,6 +216,13 @@ func ResolveOriginalModel(model string, metadata map[string]any) string {
|
||||
}
|
||||
|
||||
if metadata != nil {
|
||||
if v, ok := metadata[ModelMappingOriginalModelMetadataKey]; ok {
|
||||
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
|
||||
if base := normalize(s); base != "" {
|
||||
return base
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := metadata[ThinkingOriginalModelMetadataKey]; ok {
|
||||
if s, okStr := v.(string); okStr && strings.TrimSpace(s) != "" {
|
||||
if base := normalize(s); base != "" {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -126,7 +127,7 @@ func (w *Watcher) reloadConfig() bool {
|
||||
}
|
||||
|
||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
|
||||
forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelMappings, newConfig.OAuthModelMappings))
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||
|
||||
@@ -90,6 +90,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldModels := SummarizeGeminiModels(o.Models)
|
||||
newModels := SummarizeGeminiModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
@@ -120,6 +125,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldModels := SummarizeClaudeModels(o.Models)
|
||||
newModels := SummarizeClaudeModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
@@ -150,6 +160,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldModels := SummarizeCodexModels(o.Models)
|
||||
newModels := SummarizeCodexModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
@@ -194,6 +209,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
if entries, _ := DiffOAuthModelMappingChanges(oldCfg.OAuthModelMappings, newCfg.OAuthModelMappings); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
|
||||
@@ -71,6 +71,21 @@ func ComputeCodexModelsHash(models []config.CodexModel) string {
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeGeminiModelsHash returns a stable hash for Gemini model aliases.
|
||||
func ComputeGeminiModelsHash(models []config.GeminiModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
|
||||
func ComputeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
|
||||
121
internal/watcher/diff/models_summary.go
Normal file
121
internal/watcher/diff/models_summary.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type GeminiModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
type ClaudeModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
type CodexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeGeminiModels hashes Gemini model aliases for change detection.
|
||||
func SummarizeGeminiModels(models []config.GeminiModel) GeminiModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return GeminiModelsSummary{}
|
||||
}
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return GeminiModelsSummary{
|
||||
hash: hashJoined(keys),
|
||||
count: len(keys),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeClaudeModels hashes Claude model aliases for change detection.
|
||||
func SummarizeClaudeModels(models []config.ClaudeModel) ClaudeModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return ClaudeModelsSummary{}
|
||||
}
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return ClaudeModelsSummary{
|
||||
hash: hashJoined(keys),
|
||||
count: len(keys),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeCodexModels hashes Codex model aliases for change detection.
|
||||
func SummarizeCodexModels(models []config.CodexModel) CodexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return CodexModelsSummary{}
|
||||
}
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return CodexModelsSummary{
|
||||
hash: hashJoined(keys),
|
||||
count: len(keys),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes Vertex-compatible model aliases for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
@@ -116,36 +116,3 @@ func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappin
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes vertex-compatible models for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, m := range models {
|
||||
name := strings.TrimSpace(m.Name)
|
||||
alias := strings.TrimSpace(m.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
|
||||
98
internal/watcher/diff/oauth_model_mappings.go
Normal file
98
internal/watcher/diff/oauth_model_mappings.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type OAuthModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeOAuthModelMappings summarizes OAuth model mappings per channel.
|
||||
func SummarizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string]OAuthModelMappingsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]OAuthModelMappingsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = summarizeOAuthModelMappingList(v)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DiffOAuthModelMappingChanges compares OAuth model mappings maps.
|
||||
func DiffOAuthModelMappingChanges(oldMap, newMap map[string][]config.ModelNameMapping) ([]string, []string) {
|
||||
oldSummary := SummarizeOAuthModelMappings(oldMap)
|
||||
newSummary := SummarizeOAuthModelMappings(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-model-mappings[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
func summarizeOAuthModelMappingList(list []config.ModelNameMapping) OAuthModelMappingsSummary {
|
||||
if len(list) == 0 {
|
||||
return OAuthModelMappingsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, mapping := range list {
|
||||
name := strings.ToLower(strings.TrimSpace(mapping.Name))
|
||||
alias := strings.ToLower(strings.TrimSpace(mapping.Alias))
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
key := name + "->" + alias
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
normalized = append(normalized, key)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return OAuthModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
sum := sha256.Sum256([]byte(strings.Join(normalized, "|")))
|
||||
return OAuthModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
if hash := diff.ComputeGeminiModelsHash(entry.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(entry.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
|
||||
@@ -20,6 +20,7 @@ type ManagementTokenRequester interface {
|
||||
RequestQwenToken(*gin.Context)
|
||||
RequestIFlowToken(*gin.Context)
|
||||
RequestIFlowCookieToken(*gin.Context)
|
||||
GetAuthStatus(c *gin.Context)
|
||||
}
|
||||
|
||||
type managementTokenRequester struct {
|
||||
@@ -60,3 +61,7 @@ func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
|
||||
func (m *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
|
||||
m.handler.RequestIFlowCookieToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) GetAuthStatus(c *gin.Context) {
|
||||
m.handler.GetAuthStatus(c)
|
||||
}
|
||||
|
||||
@@ -111,6 +111,9 @@ type Manager struct {
|
||||
requestRetry atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
|
||||
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
|
||||
modelNameMappings atomic.Value
|
||||
|
||||
// Optional HTTP RoundTripper provider injected by host.
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
@@ -410,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -471,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -532,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
@@ -592,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
util.ModelMappingOriginalModelMetadataKey,
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
|
||||
169
sdk/cliproxy/auth/model_name_mappings.go
Normal file
169
sdk/cliproxy/auth/model_name_mappings.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
type modelNameMappingTable struct {
|
||||
// reverse maps channel -> alias (lower) -> original upstream model name.
|
||||
reverse map[string]map[string]string
|
||||
}
|
||||
|
||||
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
|
||||
if len(mappings) == 0 {
|
||||
return &modelNameMappingTable{}
|
||||
}
|
||||
out := &modelNameMappingTable{
|
||||
reverse: make(map[string]map[string]string, len(mappings)),
|
||||
}
|
||||
for rawChannel, entries := range mappings {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
rev := make(map[string]string, len(entries))
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, exists := rev[aliasKey]; exists {
|
||||
continue
|
||||
}
|
||||
rev[aliasKey] = name
|
||||
}
|
||||
if len(rev) > 0 {
|
||||
out.reverse[channel] = rev
|
||||
}
|
||||
}
|
||||
if len(out.reverse) == 0 {
|
||||
out.reverse = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
|
||||
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
|
||||
// client-visible model name unchanged for translation/response formatting.
|
||||
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
table := compileModelNameMappingTable(mappings)
|
||||
// atomic.Value requires non-nil store values.
|
||||
if table == nil {
|
||||
table = &modelNameMappingTable{}
|
||||
}
|
||||
m.modelNameMappings.Store(table)
|
||||
}
|
||||
|
||||
// applyOAuthModelMapping resolves the upstream model from OAuth model mappings
|
||||
// and returns the resolved model along with updated metadata. If a mapping exists,
|
||||
// the returned model is the upstream model and metadata contains the original
|
||||
// requested model for response translation.
|
||||
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
|
||||
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
|
||||
if upstreamModel == "" {
|
||||
return requestedModel, metadata
|
||||
}
|
||||
out := make(map[string]any, 1)
|
||||
if len(metadata) > 0 {
|
||||
out = make(map[string]any, len(metadata)+1)
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel
|
||||
return upstreamModel, out
|
||||
}
|
||||
|
||||
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
|
||||
if m == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
channel := modelMappingChannel(auth)
|
||||
if channel == "" {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
raw := m.modelNameMappings.Load()
|
||||
table, _ := raw.(*modelNameMappingTable)
|
||||
if table == nil || table.reverse == nil {
|
||||
return ""
|
||||
}
|
||||
rev := table.reverse[channel]
|
||||
if rev == nil {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(rev[key])
|
||||
if original == "" || strings.EqualFold(original, requestedModel) {
|
||||
return ""
|
||||
}
|
||||
return original
|
||||
}
|
||||
|
||||
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
|
||||
// It determines the provider and auth kind from the Auth's attributes and delegates
|
||||
// to OAuthModelMappingChannel for the actual channel resolution.
|
||||
func modelMappingChannel(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
authKind := ""
|
||||
if auth.Attributes != nil {
|
||||
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
|
||||
}
|
||||
if authKind == "" {
|
||||
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
return OAuthModelMappingChannel(provider, authKind)
|
||||
}
|
||||
|
||||
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model mappings (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
func OAuthModelMappingChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
switch provider {
|
||||
case "gemini":
|
||||
// gemini provider uses gemini-api-key config, not oauth-model-mappings.
|
||||
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
|
||||
return ""
|
||||
case "vertex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "vertex"
|
||||
case "claude":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "claude"
|
||||
case "codex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
|
||||
}
|
||||
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
||||
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
|
||||
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
|
||||
|
||||
service := &Service{
|
||||
cfg: b.cfg,
|
||||
|
||||
@@ -552,6 +552,9 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.cfgMu.Lock()
|
||||
s.cfg = newCfg
|
||||
s.cfgMu.Unlock()
|
||||
if s.coreManager != nil {
|
||||
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
|
||||
}
|
||||
s.rebindExecutors()
|
||||
}
|
||||
|
||||
@@ -677,6 +680,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
return
|
||||
}
|
||||
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
|
||||
if authKind == "" {
|
||||
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -702,6 +710,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "gemini":
|
||||
models = registry.GetGeminiModels()
|
||||
if entry := s.resolveConfigGeminiKey(a); entry != nil {
|
||||
if len(entry.Models) > 0 {
|
||||
models = buildGeminiConfigModels(entry)
|
||||
}
|
||||
if authKind == "apikey" {
|
||||
excluded = entry.ExcludedModels
|
||||
}
|
||||
@@ -836,6 +847,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
}
|
||||
}
|
||||
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
|
||||
if len(models) > 0 {
|
||||
key := provider
|
||||
if key == "" {
|
||||
@@ -1107,17 +1119,22 @@ func matchWildcard(pattern, value string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
type modelEntry interface {
|
||||
GetName() string
|
||||
GetAlias() string
|
||||
}
|
||||
|
||||
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
out := make([]*ModelInfo, 0, len(entry.Models))
|
||||
seen := make(map[string]struct{}, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
out := make([]*ModelInfo, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for i := range models {
|
||||
model := models[i]
|
||||
name := strings.TrimSpace(model.GetName())
|
||||
alias := strings.TrimSpace(model.GetAlias())
|
||||
if alias == "" {
|
||||
alias = name
|
||||
}
|
||||
@@ -1133,90 +1150,135 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
if display == "" {
|
||||
display = alias
|
||||
}
|
||||
out = append(out, &ModelInfo{
|
||||
info := &ModelInfo{
|
||||
ID: alias,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "vertex",
|
||||
Type: "vertex",
|
||||
OwnedBy: ownedBy,
|
||||
Type: modelType,
|
||||
DisplayName: display,
|
||||
})
|
||||
}
|
||||
if name != "" {
|
||||
if upstream := registry.LookupStaticModelInfo(name); upstream != nil && upstream.Thinking != nil {
|
||||
info.Thinking = upstream.Thinking
|
||||
}
|
||||
}
|
||||
out = append(out, info)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return buildConfigModels(entry.Models, "google", "vertex")
|
||||
}
|
||||
|
||||
func buildGeminiConfigModels(entry *config.GeminiKey) []*ModelInfo {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
return buildConfigModels(entry.Models, "google", "gemini")
|
||||
}
|
||||
|
||||
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
out := make([]*ModelInfo, 0, len(entry.Models))
|
||||
seen := make(map[string]struct{}, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = name
|
||||
}
|
||||
if alias == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(alias)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
display := name
|
||||
if display == "" {
|
||||
display = alias
|
||||
}
|
||||
out = append(out, &ModelInfo{
|
||||
ID: alias,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "claude",
|
||||
Type: "claude",
|
||||
DisplayName: display,
|
||||
})
|
||||
}
|
||||
return out
|
||||
return buildConfigModels(entry.Models, "anthropic", "claude")
|
||||
}
|
||||
|
||||
func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
out := make([]*ModelInfo, 0, len(entry.Models))
|
||||
seen := make(map[string]struct{}, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = name
|
||||
}
|
||||
if alias == "" {
|
||||
return buildConfigModels(entry.Models, "openai", "openai")
|
||||
}
|
||||
|
||||
func rewriteModelInfoName(name, oldID, newID string) string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return name
|
||||
}
|
||||
oldID = strings.TrimSpace(oldID)
|
||||
newID = strings.TrimSpace(newID)
|
||||
if oldID == "" || newID == "" {
|
||||
return name
|
||||
}
|
||||
if strings.EqualFold(oldID, newID) {
|
||||
return name
|
||||
}
|
||||
if strings.HasSuffix(trimmed, "/"+oldID) {
|
||||
prefix := strings.TrimSuffix(trimmed, oldID)
|
||||
return prefix + newID
|
||||
}
|
||||
if trimmed == "models/"+oldID {
|
||||
return "models/" + newID
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
|
||||
if cfg == nil || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
|
||||
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
|
||||
return models
|
||||
}
|
||||
mappings := cfg.OAuthModelMappings[channel]
|
||||
if len(mappings) == 0 {
|
||||
return models
|
||||
}
|
||||
forward := make(map[string]string, len(mappings))
|
||||
for i := range mappings {
|
||||
name := strings.TrimSpace(mappings[i].Name)
|
||||
alias := strings.TrimSpace(mappings[i].Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(alias)
|
||||
if _, exists := seen[key]; exists {
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
display := name
|
||||
if display == "" {
|
||||
display = alias
|
||||
key := strings.ToLower(name)
|
||||
if _, exists := forward[key]; exists {
|
||||
continue
|
||||
}
|
||||
out = append(out, &ModelInfo{
|
||||
ID: alias,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
DisplayName: display,
|
||||
})
|
||||
forward[key] = alias
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
return models
|
||||
}
|
||||
out := make([]*ModelInfo, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
id := strings.TrimSpace(model.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
mappedID := id
|
||||
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
|
||||
mappedID = strings.TrimSpace(to)
|
||||
}
|
||||
uniqueKey := strings.ToLower(mappedID)
|
||||
if _, exists := seen[uniqueKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[uniqueKey] = struct{}{}
|
||||
if mappedID == id {
|
||||
out = append(out, model)
|
||||
continue
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = mappedID
|
||||
if clone.Name != "" {
|
||||
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
|
||||
}
|
||||
out = append(out, &clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig
|
||||
type TLSConfig = internalconfig.TLSConfig
|
||||
type RemoteManagement = internalconfig.RemoteManagement
|
||||
type AmpCode = internalconfig.AmpCode
|
||||
type ModelNameMapping = internalconfig.ModelNameMapping
|
||||
type PayloadConfig = internalconfig.PayloadConfig
|
||||
type PayloadRule = internalconfig.PayloadRule
|
||||
type PayloadModelRule = internalconfig.PayloadModelRule
|
||||
|
||||
Reference in New Issue
Block a user