mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b26129c82 | ||
|
|
d4bb4e6624 | ||
|
|
0766c49f93 | ||
|
|
a7ffc77e3d | ||
|
|
e641fde25c | ||
|
|
5717c7f2f4 | ||
|
|
8734d4cb90 | ||
|
|
5baa753539 | ||
|
|
ead98e4bca | ||
|
|
1d2fe55310 | ||
|
|
c175821cc4 | ||
|
|
239a28793c | ||
|
|
c421d653e7 | ||
|
|
2542c2920d | ||
|
|
52e46ced1b | ||
|
|
cf9daf470c | ||
|
|
5977af96a0 |
@@ -287,6 +287,67 @@ func GetGeminiVertexModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
},
|
},
|
||||||
|
// Imagen image generation models - use :predict action
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Generate",
|
||||||
|
Description: "Imagen 4.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-ultra-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-ultra-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Ultra Generate",
|
||||||
|
Description: "Imagen 4.0 Ultra high-quality image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-generate-002",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-generate-002",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Generate",
|
||||||
|
Description: "Imagen 3.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-fast-generate-001",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Fast Generate",
|
||||||
|
Description: "Imagen 3.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-fast-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Fast Generate",
|
||||||
|
Description: "Imagen 4.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -765,21 +826,23 @@ func GetIFlowModels() []*ModelInfo {
|
|||||||
type AntigravityModelConfig struct {
|
type AntigravityModelConfig struct {
|
||||||
Thinking *ThinkingSupport
|
Thinking *ThinkingSupport
|
||||||
MaxCompletionTokens int
|
MaxCompletionTokens int
|
||||||
Name string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||||
return map[string]*AntigravityModelConfig{
|
return map[string]*AntigravityModelConfig{
|
||||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/rev19-uic3-1p"},
|
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-high"},
|
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image"},
|
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash"},
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||||
|
"gpt-oss-120b-medium": {},
|
||||||
|
"tab_flash_lite_preview": {},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -809,10 +872,9 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check Antigravity static config
|
// Check Antigravity static config
|
||||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil && cfg.Thinking != nil {
|
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
||||||
return &ModelInfo{
|
return &ModelInfo{
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Name: cfg.Name,
|
|
||||||
Thinking: cfg.Thinking,
|
Thinking: cfg.Thinking,
|
||||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,6 +78,8 @@ type ThinkingSupport struct {
|
|||||||
type ModelRegistration struct {
|
type ModelRegistration struct {
|
||||||
// Info contains the model metadata
|
// Info contains the model metadata
|
||||||
Info *ModelInfo
|
Info *ModelInfo
|
||||||
|
// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
|
||||||
|
InfoByProvider map[string]*ModelInfo
|
||||||
// Count is the number of active clients that can provide this model
|
// Count is the number of active clients that can provide this model
|
||||||
Count int
|
Count int
|
||||||
// LastUpdated tracks when this registration was last modified
|
// LastUpdated tracks when this registration was last modified
|
||||||
@@ -132,16 +134,19 @@ func GetGlobalRegistry() *ModelRegistry {
|
|||||||
return globalRegistry
|
return globalRegistry
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupModelInfo searches the dynamic registry first, then falls back to static model definitions.
|
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||||
//
|
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||||
// This helper exists because some code paths only have a model ID and still need Thinking and
|
|
||||||
// max completion token metadata even when the dynamic registry hasn't been populated.
|
|
||||||
func LookupModelInfo(modelID string) *ModelInfo {
|
|
||||||
modelID = strings.TrimSpace(modelID)
|
modelID = strings.TrimSpace(modelID)
|
||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if info := GetGlobalRegistry().GetModelInfo(modelID); info != nil {
|
|
||||||
|
p := ""
|
||||||
|
if len(provider) > 0 {
|
||||||
|
p = strings.ToLower(strings.TrimSpace(provider[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
return LookupStaticModelInfo(modelID)
|
return LookupStaticModelInfo(modelID)
|
||||||
@@ -297,6 +302,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
if count, okProv := reg.Providers[oldProvider]; okProv {
|
if count, okProv := reg.Providers[oldProvider]; okProv {
|
||||||
if count <= toRemove {
|
if count <= toRemove {
|
||||||
delete(reg.Providers, oldProvider)
|
delete(reg.Providers, oldProvider)
|
||||||
|
if reg.InfoByProvider != nil {
|
||||||
|
delete(reg.InfoByProvider, oldProvider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
reg.Providers[oldProvider] = count - toRemove
|
reg.Providers[oldProvider] = count - toRemove
|
||||||
}
|
}
|
||||||
@@ -346,6 +354,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
model := newModels[id]
|
model := newModels[id]
|
||||||
if reg, ok := r.models[id]; ok {
|
if reg, ok := r.models[id]; ok {
|
||||||
reg.Info = cloneModelInfo(model)
|
reg.Info = cloneModelInfo(model)
|
||||||
|
if provider != "" {
|
||||||
|
if reg.InfoByProvider == nil {
|
||||||
|
reg.InfoByProvider = make(map[string]*ModelInfo)
|
||||||
|
}
|
||||||
|
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
|
}
|
||||||
reg.LastUpdated = now
|
reg.LastUpdated = now
|
||||||
if reg.QuotaExceededClients != nil {
|
if reg.QuotaExceededClients != nil {
|
||||||
delete(reg.QuotaExceededClients, clientID)
|
delete(reg.QuotaExceededClients, clientID)
|
||||||
@@ -409,11 +423,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
if existing.SuspendedClients == nil {
|
if existing.SuspendedClients == nil {
|
||||||
existing.SuspendedClients = make(map[string]string)
|
existing.SuspendedClients = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
if existing.InfoByProvider == nil {
|
||||||
|
existing.InfoByProvider = make(map[string]*ModelInfo)
|
||||||
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
if existing.Providers == nil {
|
if existing.Providers == nil {
|
||||||
existing.Providers = make(map[string]int)
|
existing.Providers = make(map[string]int)
|
||||||
}
|
}
|
||||||
existing.Providers[provider]++
|
existing.Providers[provider]++
|
||||||
|
existing.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
||||||
return
|
return
|
||||||
@@ -421,6 +439,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
|
|
||||||
registration := &ModelRegistration{
|
registration := &ModelRegistration{
|
||||||
Info: cloneModelInfo(model),
|
Info: cloneModelInfo(model),
|
||||||
|
InfoByProvider: make(map[string]*ModelInfo),
|
||||||
Count: 1,
|
Count: 1,
|
||||||
LastUpdated: now,
|
LastUpdated: now,
|
||||||
QuotaExceededClients: make(map[string]*time.Time),
|
QuotaExceededClients: make(map[string]*time.Time),
|
||||||
@@ -428,6 +447,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
|||||||
}
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
registration.Providers = map[string]int{provider: 1}
|
registration.Providers = map[string]int{provider: 1}
|
||||||
|
registration.InfoByProvider[provider] = cloneModelInfo(model)
|
||||||
}
|
}
|
||||||
r.models[modelID] = registration
|
r.models[modelID] = registration
|
||||||
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
||||||
@@ -453,6 +473,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
|
|||||||
if count, ok := registration.Providers[provider]; ok {
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
if count <= 1 {
|
if count <= 1 {
|
||||||
delete(registration.Providers, provider)
|
delete(registration.Providers, provider)
|
||||||
|
if registration.InfoByProvider != nil {
|
||||||
|
delete(registration.InfoByProvider, provider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
registration.Providers[provider] = count - 1
|
registration.Providers[provider] = count - 1
|
||||||
}
|
}
|
||||||
@@ -534,6 +557,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
|||||||
if count, ok := registration.Providers[provider]; ok {
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
if count <= 1 {
|
if count <= 1 {
|
||||||
delete(registration.Providers, provider)
|
delete(registration.Providers, provider)
|
||||||
|
if registration.InfoByProvider != nil {
|
||||||
|
delete(registration.InfoByProvider, provider)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
registration.Providers[provider] = count - 1
|
registration.Providers[provider] = count - 1
|
||||||
}
|
}
|
||||||
@@ -940,12 +966,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelInfo returns the registered ModelInfo for the given model ID, if present.
|
// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available.
|
||||||
// Returns nil if the model is unknown to the registry.
|
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||||
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
|
|
||||||
r.mutex.RLock()
|
r.mutex.RLock()
|
||||||
defer r.mutex.RUnlock()
|
defer r.mutex.RUnlock()
|
||||||
if reg, ok := r.models[modelID]; ok && reg != nil {
|
if reg, ok := r.models[modelID]; ok && reg != nil {
|
||||||
|
// Try provider specific definition first
|
||||||
|
if provider != "" && reg.InfoByProvider != nil {
|
||||||
|
if reg.Providers != nil {
|
||||||
|
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||||
|
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback to global info (last registered)
|
||||||
return reg.Info
|
return reg.Info
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -393,7 +393,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
|||||||
}
|
}
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, translatedPayload{}, err
|
return nil, translatedPayload{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -256,7 +256,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -622,7 +622,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -802,7 +802,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
// Prepare payload once (doesn't depend on baseURL)
|
// Prepare payload once (doesn't depend on baseURL)
|
||||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
@@ -1005,9 +1005,6 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
}
|
}
|
||||||
modelCfg := modelConfig[modelID]
|
modelCfg := modelConfig[modelID]
|
||||||
modelName := modelID
|
modelName := modelID
|
||||||
if modelCfg != nil && modelCfg.Name != "" {
|
|
||||||
modelName = modelCfg.Name
|
|
||||||
}
|
|
||||||
modelInfo := ®istry.ModelInfo{
|
modelInfo := ®istry.ModelInfo{
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Name: modelName,
|
Name: modelName,
|
||||||
@@ -1410,13 +1407,6 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
|||||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||||
|
|
||||||
if !strings.HasPrefix(modelName, "gemini-3-") {
|
|
||||||
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
|
|
||||||
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
|
|
||||||
template, _ = sjson.Set(template, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") {
|
if strings.Contains(modelName, "claude") {
|
||||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
||||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -106,7 +105,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -119,9 +118,6 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||||
body = disableThinkingIfToolChoiceForced(body)
|
body = disableThinkingIfToolChoiceForced(body)
|
||||||
|
|
||||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
|
||||||
body = ensureMaxTokensForThinking(baseModel, body)
|
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
@@ -239,7 +235,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -250,9 +246,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||||
body = disableThinkingIfToolChoiceForced(body)
|
body = disableThinkingIfToolChoiceForced(body)
|
||||||
|
|
||||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
|
||||||
body = ensureMaxTokensForThinking(baseModel, body)
|
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
extraBetas, body = extractAndRemoveBetas(body)
|
extraBetas, body = extractAndRemoveBetas(body)
|
||||||
@@ -541,81 +534,6 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
|
|
||||||
// Anthropic API requires this constraint; violating it returns a 400 error.
|
|
||||||
// This function should be called after all thinking configuration is finalized.
|
|
||||||
// It looks up the model's MaxCompletionTokens from the registry to use as the cap.
|
|
||||||
func ensureMaxTokensForThinking(modelName string, body []byte) []byte {
|
|
||||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
|
||||||
if thinkingType != "enabled" {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
budgetTokens := gjson.GetBytes(body, "thinking.budget_tokens").Int()
|
|
||||||
if budgetTokens <= 0 {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
|
|
||||||
|
|
||||||
// Look up the model's max completion tokens from the registry
|
|
||||||
maxCompletionTokens := 0
|
|
||||||
if modelInfo := registry.LookupModelInfo(modelName); modelInfo != nil {
|
|
||||||
maxCompletionTokens = modelInfo.MaxCompletionTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to budget + buffer if registry lookup fails or returns 0
|
|
||||||
const fallbackBuffer = 4000
|
|
||||||
requiredMaxTokens := budgetTokens + fallbackBuffer
|
|
||||||
if maxCompletionTokens > 0 {
|
|
||||||
requiredMaxTokens = int64(maxCompletionTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxTokens < requiredMaxTokens {
|
|
||||||
body, _ = sjson.SetBytes(body, "max_tokens", requiredMaxTokens)
|
|
||||||
}
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ClaudeExecutor) resolveClaudeConfig(auth *cliproxyauth.Auth) *config.ClaudeKey {
|
|
||||||
if auth == nil || e.cfg == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var attrKey, attrBase string
|
|
||||||
if auth.Attributes != nil {
|
|
||||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
|
||||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
|
||||||
}
|
|
||||||
for i := range e.cfg.ClaudeKey {
|
|
||||||
entry := &e.cfg.ClaudeKey[i]
|
|
||||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
|
||||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
|
||||||
if attrKey != "" && attrBase != "" {
|
|
||||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
|
||||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if attrKey != "" {
|
|
||||||
for i := range e.cfg.ClaudeKey {
|
|
||||||
entry := &e.cfg.ClaudeKey[i]
|
|
||||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type compositeReadCloser struct {
|
type compositeReadCloser struct {
|
||||||
io.Reader
|
io.Reader
|
||||||
closers []func() error
|
closers []func() error
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||||
body = misc.StripCodexUserAgent(body)
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -208,7 +208,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
|
||||||
body = misc.StripCodexUserAgent(body)
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -316,7 +316,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||||
body = misc.StripCodexUserAgent(body)
|
body = misc.StripCodexUserAgent(body)
|
||||||
|
|
||||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -272,7 +272,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -479,7 +479,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
for range models {
|
for range models {
|
||||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -222,7 +222,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -338,7 +338,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -31,6 +32,143 @@ const (
|
|||||||
vertexAPIVersion = "v1"
|
vertexAPIVersion = "v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isImagenModel checks if the model name is an Imagen image generation model.
|
||||||
|
// Imagen models use the :predict action instead of :generateContent.
|
||||||
|
func isImagenModel(model string) bool {
|
||||||
|
lowerModel := strings.ToLower(model)
|
||||||
|
return strings.Contains(lowerModel, "imagen")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getVertexAction returns the appropriate action for the given model.
|
||||||
|
// Imagen models use "predict", while Gemini models use "generateContent".
|
||||||
|
func getVertexAction(model string, isStream bool) string {
|
||||||
|
if isImagenModel(model) {
|
||||||
|
return "predict"
|
||||||
|
}
|
||||||
|
if isStream {
|
||||||
|
return "streamGenerateContent"
|
||||||
|
}
|
||||||
|
return "generateContent"
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertImagenToGeminiResponse converts Imagen API response to Gemini format
|
||||||
|
// so it can be processed by the standard translation pipeline.
|
||||||
|
// This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview.
|
||||||
|
func convertImagenToGeminiResponse(data []byte, model string) []byte {
|
||||||
|
predictions := gjson.GetBytes(data, "predictions")
|
||||||
|
if !predictions.Exists() || !predictions.IsArray() {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Gemini-compatible response with inlineData
|
||||||
|
parts := make([]map[string]any, 0)
|
||||||
|
for _, pred := range predictions.Array() {
|
||||||
|
imageData := pred.Get("bytesBase64Encoded").String()
|
||||||
|
mimeType := pred.Get("mimeType").String()
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/png"
|
||||||
|
}
|
||||||
|
if imageData != "" {
|
||||||
|
parts = append(parts, map[string]any{
|
||||||
|
"inlineData": map[string]any{
|
||||||
|
"mimeType": mimeType,
|
||||||
|
"data": imageData,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique response ID using timestamp
|
||||||
|
responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
response := map[string]any{
|
||||||
|
"candidates": []map[string]any{{
|
||||||
|
"content": map[string]any{
|
||||||
|
"parts": parts,
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
"finishReason": "STOP",
|
||||||
|
}},
|
||||||
|
"responseId": responseId,
|
||||||
|
"modelVersion": model,
|
||||||
|
// Imagen API doesn't return token counts, set to 0 for tracking purposes
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": 0,
|
||||||
|
"candidatesTokenCount": 0,
|
||||||
|
"totalTokenCount": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToImagenRequest converts a Gemini-style request to Imagen API format.
|
||||||
|
// Imagen API uses a different structure: instances[].prompt instead of contents[].
|
||||||
|
func convertToImagenRequest(payload []byte) ([]byte, error) {
|
||||||
|
// Extract prompt from Gemini-style contents
|
||||||
|
prompt := ""
|
||||||
|
|
||||||
|
// Try to get prompt from contents[0].parts[0].text
|
||||||
|
contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text")
|
||||||
|
if contentsText.Exists() {
|
||||||
|
prompt = contentsText.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no contents, try messages format (OpenAI-compatible)
|
||||||
|
if prompt == "" {
|
||||||
|
messagesText := gjson.GetBytes(payload, "messages.#.content")
|
||||||
|
if messagesText.Exists() && messagesText.IsArray() {
|
||||||
|
for _, msg := range messagesText.Array() {
|
||||||
|
if msg.String() != "" {
|
||||||
|
prompt = msg.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If still no prompt, try direct prompt field
|
||||||
|
if prompt == "" {
|
||||||
|
directPrompt := gjson.GetBytes(payload, "prompt")
|
||||||
|
if directPrompt.Exists() {
|
||||||
|
prompt = directPrompt.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prompt == "" {
|
||||||
|
return nil, fmt.Errorf("imagen: no prompt found in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Imagen API request
|
||||||
|
imagenReq := map[string]any{
|
||||||
|
"instances": []map[string]any{
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"sampleCount": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract optional parameters
|
||||||
|
if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String()
|
||||||
|
}
|
||||||
|
if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int())
|
||||||
|
}
|
||||||
|
if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() {
|
||||||
|
imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(imagenReq)
|
||||||
|
}
|
||||||
|
|
||||||
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
||||||
type GeminiVertexExecutor struct {
|
type GeminiVertexExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -160,26 +298,38 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
var body []byte
|
||||||
to := sdktranslator.FromString("gemini")
|
|
||||||
|
|
||||||
originalPayload := bytes.Clone(req.Payload)
|
// Handle Imagen models with special request format
|
||||||
if len(opts.OriginalRequest) > 0 {
|
if isImagenModel(baseModel) {
|
||||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
imagenBody, errImagen := convertToImagenRequest(req.Payload)
|
||||||
}
|
if errImagen != nil {
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
return resp, errImagen
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
}
|
||||||
|
body = imagenBody
|
||||||
|
} else {
|
||||||
|
// Standard Gemini translation flow
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
if err != nil {
|
if len(opts.OriginalRequest) > 0 {
|
||||||
return resp, err
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
action := getVertexAction(baseModel, false)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
|
||||||
|
|
||||||
action := "generateContent"
|
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
@@ -249,6 +399,16 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
|
|
||||||
|
// For Imagen models, convert response to Gemini format before translation
|
||||||
|
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
||||||
|
if isImagenModel(baseModel) {
|
||||||
|
data = convertImagenToGeminiResponse(data, baseModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard Gemini translation (works for both Gemini and converted Imagen responses)
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
@@ -272,7 +432,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -281,7 +441,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := getVertexAction(baseModel, false)
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
@@ -375,7 +535,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -384,12 +544,16 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
baseURL := vertexBaseURL(location)
|
baseURL := vertexBaseURL(location)
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
@@ -494,7 +658,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -503,15 +667,19 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
// For API key auth, use simpler URL format without project/location
|
// For API key auth, use simpler URL format without project/location
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
@@ -605,7 +773,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
|
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
@@ -689,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
|
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -190,7 +190,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -187,7 +187,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -297,7 +297,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
|||||||
|
|
||||||
modelForCounting := baseModel
|
modelForCounting := baseModel
|
||||||
|
|
||||||
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cliproxyexecutor.Response{}, err
|
return cliproxyexecutor.Response{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -172,7 +172,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
|||||||
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
||||||
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
||||||
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
|
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
|
||||||
|
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - Modified request body JSON with thinking configuration applied
|
// - Modified request body JSON with thinking configuration applied
|
||||||
@@ -79,12 +80,16 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
|||||||
// Example:
|
// Example:
|
||||||
//
|
//
|
||||||
// // With suffix - suffix config takes priority
|
// // With suffix - suffix config takes priority
|
||||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini")
|
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini")
|
||||||
//
|
//
|
||||||
// // Without suffix - uses body config
|
// // Without suffix - uses body config
|
||||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini")
|
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini")
|
||||||
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string) ([]byte, error) {
|
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) {
|
||||||
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
|
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
|
||||||
|
providerKey = strings.ToLower(strings.TrimSpace(providerKey))
|
||||||
|
if providerKey == "" {
|
||||||
|
providerKey = providerFormat
|
||||||
|
}
|
||||||
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
|
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
|
||||||
if fromFormat == "" {
|
if fromFormat == "" {
|
||||||
fromFormat = providerFormat
|
fromFormat = providerFormat
|
||||||
@@ -102,7 +107,8 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string
|
|||||||
// 2. Parse suffix and get modelInfo
|
// 2. Parse suffix and get modelInfo
|
||||||
suffixResult := ParseSuffix(model)
|
suffixResult := ParseSuffix(model)
|
||||||
baseModel := suffixResult.ModelName
|
baseModel := suffixResult.ModelName
|
||||||
modelInfo := registry.LookupModelInfo(baseModel)
|
// Use provider-specific lookup to handle capability differences across providers.
|
||||||
|
modelInfo := registry.LookupModelInfo(baseModel, providerKey)
|
||||||
|
|
||||||
// 3. Model capability check
|
// 3. Model capability check
|
||||||
// Unknown models are treated as user-defined so thinking config can still be applied.
|
// Unknown models are treated as user-defined so thinking config can still be applied.
|
||||||
|
|||||||
@@ -80,9 +80,66 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||||
|
|
||||||
|
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
|
||||||
|
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
|
||||||
|
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||||
|
func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo *registry.ModelInfo) []byte {
|
||||||
|
if budgetTokens <= 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the request satisfies Claude constraints:
|
||||||
|
// 1) Determine effective max_tokens (request overrides model default)
|
||||||
|
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
|
||||||
|
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
|
||||||
|
// 4) If max_tokens came from model default, write it back into the request
|
||||||
|
|
||||||
|
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
|
||||||
|
if setDefaultMax && effectiveMax > 0 {
|
||||||
|
body, _ = sjson.SetBytes(body, "max_tokens", effectiveMax)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the budget we would apply after enforcing budget_tokens < max_tokens.
|
||||||
|
adjustedBudget := budgetTokens
|
||||||
|
if effectiveMax > 0 && adjustedBudget >= effectiveMax {
|
||||||
|
adjustedBudget = effectiveMax - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
minBudget := 0
|
||||||
|
if modelInfo != nil && modelInfo.Thinking != nil {
|
||||||
|
minBudget = modelInfo.Thinking.Min
|
||||||
|
}
|
||||||
|
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
|
||||||
|
// If enforcing the max_tokens constraint would push the budget below the model minimum,
|
||||||
|
// leave the request unchanged.
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
if adjustedBudget != budgetTokens {
|
||||||
|
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", adjustedBudget)
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// effectiveMaxTokens returns the max tokens to cap thinking:
|
||||||
|
// prefer request-provided max_tokens; otherwise fall back to model default.
|
||||||
|
// The boolean indicates whether the value came from the model default (and thus should be written back).
|
||||||
|
func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo) (max int, fromModel bool) {
|
||||||
|
if maxTok := gjson.GetBytes(body, "max_tokens"); maxTok.Exists() && maxTok.Int() > 0 {
|
||||||
|
return int(maxTok.Int()), false
|
||||||
|
}
|
||||||
|
if modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||||
|
return modelInfo.MaxCompletionTokens, true
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||||
return body, nil
|
return body, nil
|
||||||
|
|||||||
@@ -124,11 +124,11 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||||
// Use GetThinkingText to handle wrapped thinking objects
|
// Use GetThinkingText to handle wrapped thinking objects
|
||||||
thinkingText := thinking.GetThinkingText(contentResult)
|
thinkingText := thinking.GetThinkingText(contentResult)
|
||||||
signatureResult := contentResult.Get("signature")
|
// signatureResult := contentResult.Get("signature")
|
||||||
clientSignature := ""
|
// clientSignature := ""
|
||||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
// if signatureResult.Exists() && signatureResult.String() != "" {
|
||||||
clientSignature = signatureResult.String()
|
// clientSignature = signatureResult.String()
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Always try cached signature first (more reliable than client-provided)
|
// Always try cached signature first (more reliable than client-provided)
|
||||||
// Client may send stale or invalid signatures from different sessions
|
// Client may send stale or invalid signatures from different sessions
|
||||||
@@ -140,11 +140,11 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to client signature only if cache miss and client signature is valid
|
// NOTE: We do NOT fallback to client signature anymore.
|
||||||
if signature == "" && cache.HasValidSignature(clientSignature) {
|
// Client signatures from Claude models are incompatible with Antigravity/Gemini API.
|
||||||
signature = clientSignature
|
// When switching between models (e.g., Claude Opus -> Gemini Flash), the Claude
|
||||||
// log.Debugf("Using client-provided signature for thinking block")
|
// signatures will cause "Corrupted thought signature" errors.
|
||||||
}
|
// If we have no cached signature, the thinking block will be skipped below.
|
||||||
|
|
||||||
// Store for subsequent tool_use in the same message
|
// Store for subsequent tool_use in the same message
|
||||||
if cache.HasValidSignature(signature) {
|
if cache.HasValidSignature(signature) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -75,28 +76,42 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
|||||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||||
// Valid signature must be at least 50 characters
|
// Valid signature must be at least 50 characters
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Let me think..."
|
||||||
|
|
||||||
|
// Pre-cache the signature (simulating a response from the same session)
|
||||||
|
// The session ID is derived from the first user message hash
|
||||||
|
// Since there's no user message in this test, we need to add one
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||||
{"type": "text", "text": "Answer"}
|
{"type": "text", "text": "Answer"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
// Derive session ID and cache the signature
|
||||||
|
sessionID := deriveSessionID(inputJSON)
|
||||||
|
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||||
|
defer cache.ClearSignatureCache(sessionID)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Check thinking block conversion
|
// Check thinking block conversion (now in contents.1 due to user message)
|
||||||
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
|
firstPart := gjson.Get(outputStr, "request.contents.1.parts.0")
|
||||||
if !firstPart.Get("thought").Bool() {
|
if !firstPart.Get("thought").Bool() {
|
||||||
t.Error("thinking block should have thought: true")
|
t.Error("thinking block should have thought: true")
|
||||||
}
|
}
|
||||||
if firstPart.Get("text").String() != "Let me think..." {
|
if firstPart.Get("text").String() != thinkingText {
|
||||||
t.Error("thinking text mismatch")
|
t.Error("thinking text mismatch")
|
||||||
}
|
}
|
||||||
if firstPart.Get("thoughtSignature").String() != validSignature {
|
if firstPart.Get("thoughtSignature").String() != validSignature {
|
||||||
@@ -227,13 +242,19 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
|||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Let me think..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||||
{
|
{
|
||||||
"type": "tool_use",
|
"type": "tool_use",
|
||||||
"id": "call_123",
|
"id": "call_123",
|
||||||
@@ -245,11 +266,16 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
|||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
// Derive session ID and cache the signature
|
||||||
|
sessionID := deriveSessionID(inputJSON)
|
||||||
|
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||||
|
defer cache.ClearSignatureCache(sessionID)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Check function call has the signature from the preceding thinking block
|
// Check function call has the signature from the preceding thinking block (now in contents.1)
|
||||||
part := gjson.Get(outputStr, "request.contents.0.parts.1")
|
part := gjson.Get(outputStr, "request.contents.1.parts.1")
|
||||||
if part.Get("functionCall.name").String() != "get_weather" {
|
if part.Get("functionCall.name").String() != "get_weather" {
|
||||||
t.Errorf("Expected functionCall, got %s", part.Raw)
|
t.Errorf("Expected functionCall, got %s", part.Raw)
|
||||||
}
|
}
|
||||||
@@ -261,24 +287,35 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
|||||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Planning..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Test user message"}]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "Here is the plan."},
|
{"type": "text", "text": "Here is the plan."},
|
||||||
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
// Derive session ID and cache the signature
|
||||||
|
sessionID := deriveSessionID(inputJSON)
|
||||||
|
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||||
|
defer cache.ClearSignatureCache(sessionID)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
// Verify order: Thinking block MUST be first
|
// Verify order: Thinking block MUST be first (now in contents.1 due to user message)
|
||||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||||
}
|
}
|
||||||
@@ -460,6 +497,9 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
|
|||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||||
// Last assistant message ends with signed thinking block - should be kept
|
// Last assistant message ends with signed thinking block - should be kept
|
||||||
|
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||||
|
thinkingText := "Valid thinking..."
|
||||||
|
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
"model": "claude-sonnet-4-5-thinking",
|
"model": "claude-sonnet-4-5-thinking",
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -471,12 +511,17 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin
|
|||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [
|
"content": [
|
||||||
{"type": "text", "text": "Here is my answer"},
|
{"type": "text", "text": "Here is my answer"},
|
||||||
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
|
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}`)
|
}`)
|
||||||
|
|
||||||
|
// Derive session ID and cache the signature
|
||||||
|
sessionID := deriveSessionID(inputJSON)
|
||||||
|
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||||
|
defer cache.ClearSignatureCache(sessionID)
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||||
outputStr := string(output)
|
outputStr := string(output)
|
||||||
|
|
||||||
|
|||||||
@@ -117,8 +117,12 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "usage.input_tokens", rootResult.Get("response.usage.input_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
|
||||||
template, _ = sjson.Set(template, "usage.output_tokens", rootResult.Get("response.usage.output_tokens").Int())
|
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
|
||||||
|
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
output = "event: message_delta\n"
|
output = "event: message_delta\n"
|
||||||
output += fmt.Sprintf("data: %s\n\n", template)
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
@@ -204,8 +208,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
|||||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
||||||
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||||
|
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
|
|
||||||
hasToolCall := false
|
hasToolCall := false
|
||||||
|
|
||||||
@@ -308,12 +316,27 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
|||||||
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() {
|
return out
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
}
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
|
||||||
|
func extractResponsesUsage(usage gjson.Result) (int64, int64, int64) {
|
||||||
|
if !usage.Exists() || usage.Type == gjson.Null {
|
||||||
|
return 0, 0, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
inputTokens := usage.Get("input_tokens").Int()
|
||||||
|
outputTokens := usage.Get("output_tokens").Int()
|
||||||
|
cachedTokens := usage.Get("input_tokens_details.cached_tokens").Int()
|
||||||
|
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
if inputTokens >= cachedTokens {
|
||||||
|
inputTokens -= cachedTokens
|
||||||
|
} else {
|
||||||
|
inputTokens = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputTokens, outputTokens, cachedTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.
|
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
|||||||
var messagesJSON = "[]"
|
var messagesJSON = "[]"
|
||||||
|
|
||||||
// Handle system message first
|
// Handle system message first
|
||||||
systemMsgJSON := `{"role":"system","content":[{"type":"text","text":"Use ANY tool, the parameters MUST accord with RFC 8259 (The JavaScript Object Notation (JSON) Data Interchange Format), the keys and value MUST be enclosed in double quotes."}]}`
|
systemMsgJSON := `{"role":"system","content":[]}`
|
||||||
if system := root.Get("system"); system.Exists() {
|
if system := root.Get("system"); system.Exists() {
|
||||||
if system.Type == gjson.String {
|
if system.Type == gjson.String {
|
||||||
if system.String() != "" {
|
if system.String() != "" {
|
||||||
|
|||||||
@@ -289,21 +289,17 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
|||||||
// Only process if usage has actual values (not null)
|
// Only process if usage has actual values (not null)
|
||||||
if param.FinishReason != "" {
|
if param.FinishReason != "" {
|
||||||
usage := root.Get("usage")
|
usage := root.Get("usage")
|
||||||
var inputTokens, outputTokens int64
|
var inputTokens, outputTokens, cachedTokens int64
|
||||||
if usage.Exists() && usage.Type != gjson.Null {
|
if usage.Exists() && usage.Type != gjson.Null {
|
||||||
// Check if usage has actual token counts
|
inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage)
|
||||||
promptTokens := usage.Get("prompt_tokens")
|
|
||||||
completionTokens := usage.Get("completion_tokens")
|
|
||||||
|
|
||||||
if promptTokens.Exists() && completionTokens.Exists() {
|
|
||||||
inputTokens = promptTokens.Int()
|
|
||||||
outputTokens = completionTokens.Int()
|
|
||||||
}
|
|
||||||
// Send message_delta with usage
|
// Send message_delta with usage
|
||||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
||||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
||||||
param.MessageDeltaSent = true
|
param.MessageDeltaSent = true
|
||||||
|
|
||||||
@@ -423,13 +419,12 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
|
|||||||
|
|
||||||
// Set usage information
|
// Set usage information
|
||||||
if usage := root.Get("usage"); usage.Exists() {
|
if usage := root.Get("usage"); usage.Exists() {
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", usage.Get("prompt_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(usage)
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", usage.Get("completion_tokens").Int())
|
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||||
reasoningTokens := int64(0)
|
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||||
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
|
if cachedTokens > 0 {
|
||||||
reasoningTokens = v.Int()
|
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
}
|
}
|
||||||
out, _ = sjson.Set(out, "usage.reasoning_tokens", reasoningTokens)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{out}
|
return []string{out}
|
||||||
@@ -674,8 +669,12 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
|||||||
}
|
}
|
||||||
|
|
||||||
if respUsage := root.Get("usage"); respUsage.Exists() {
|
if respUsage := root.Get("usage"); respUsage.Exists() {
|
||||||
out, _ = sjson.Set(out, "usage.input_tokens", respUsage.Get("prompt_tokens").Int())
|
inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(respUsage)
|
||||||
out, _ = sjson.Set(out, "usage.output_tokens", respUsage.Get("completion_tokens").Int())
|
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||||
|
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !stopReasonSet {
|
if !stopReasonSet {
|
||||||
@@ -692,3 +691,23 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
|||||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractOpenAIUsage(usage gjson.Result) (int64, int64, int64) {
|
||||||
|
if !usage.Exists() || usage.Type == gjson.Null {
|
||||||
|
return 0, 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
inputTokens := usage.Get("prompt_tokens").Int()
|
||||||
|
outputTokens := usage.Get("completion_tokens").Int()
|
||||||
|
cachedTokens := usage.Get("prompt_tokens_details.cached_tokens").Int()
|
||||||
|
|
||||||
|
if cachedTokens > 0 {
|
||||||
|
if inputTokens >= cachedTokens {
|
||||||
|
inputTokens -= cachedTokens
|
||||||
|
} else {
|
||||||
|
inputTokens = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputTokens, outputTokens, cachedTokens
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// thinkingTestCase represents a common test case structure for both suffix and body tests.
|
// thinkingTestCase represents a common test case structure for both suffix and body tests.
|
||||||
@@ -2707,8 +2708,11 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) {
|
|||||||
[]byte(tc.inputJSON),
|
[]byte(tc.inputJSON),
|
||||||
true,
|
true,
|
||||||
)
|
)
|
||||||
|
if applyTo == "claude" {
|
||||||
|
body, _ = sjson.SetBytes(body, "max_tokens", 200000)
|
||||||
|
}
|
||||||
|
|
||||||
body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo)
|
body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo, applyTo)
|
||||||
|
|
||||||
if tc.expectErr {
|
if tc.expectErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user