mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-04 21:40:51 +08:00
Compare commits
58 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef4508dbc8 | ||
|
|
f775e46fe2 | ||
|
|
65ad5c0c9d | ||
|
|
88bf4e77ec | ||
|
|
a4f8015caa | ||
|
|
ffd129909e | ||
|
|
9332316383 | ||
|
|
6dcbbf64c3 | ||
|
|
2ce3553612 | ||
|
|
2e14f787d4 | ||
|
|
523b41ccd2 | ||
|
|
c6fa1d0e67 | ||
|
|
ac56e1e88b | ||
|
|
9b72ea9efa | ||
|
|
9f364441e8 | ||
|
|
e49a1c07bf | ||
|
|
8d9f4edf9b | ||
|
|
020e61d0da | ||
|
|
6184c43319 | ||
|
|
2cbe4a790c | ||
|
|
68b3565d7b | ||
|
|
3f385a8572 | ||
|
|
9823dc35e1 | ||
|
|
059bfee91b | ||
|
|
7beaf0eaa2 | ||
|
|
1fef90ff58 | ||
|
|
8447fd27a0 | ||
|
|
7831cba9f6 | ||
|
|
e02b2d58d5 | ||
|
|
28726632a9 | ||
|
|
3b26129c82 | ||
|
|
d4bb4e6624 | ||
|
|
0766c49f93 | ||
|
|
a7ffc77e3d | ||
|
|
e641fde25c | ||
|
|
5717c7f2f4 | ||
|
|
8734d4cb90 | ||
|
|
5baa753539 | ||
|
|
ead98e4bca | ||
|
|
1d2fe55310 | ||
|
|
c175821cc4 | ||
|
|
239a28793c | ||
|
|
c421d653e7 | ||
|
|
2542c2920d | ||
|
|
52e46ced1b | ||
|
|
cf9daf470c | ||
|
|
5977af96a0 | ||
|
|
5bb9c2a2bd | ||
|
|
0b5bbe9234 | ||
|
|
14c74e5e84 | ||
|
|
6448d0ee7c | ||
|
|
b0c17af2cf | ||
|
|
aa8526edc0 | ||
|
|
ac3ca0ad8e | ||
|
|
08d21b76e2 | ||
|
|
33aa665555 | ||
|
|
00280b6fe8 | ||
|
|
52760a4eaa |
@@ -141,6 +141,15 @@ codex-instructions-enabled: false
|
||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||
# cloak: # optional: request cloaking for non-Claude-Code clients
|
||||
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
|
||||
# # "always": always apply cloaking
|
||||
# # "never": never apply cloaking
|
||||
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
|
||||
# # true: strip all user system messages, keep only Claude Code prompt
|
||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||
# - "API"
|
||||
# - "proxy"
|
||||
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
|
||||
@@ -3,6 +3,8 @@ package management
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -1383,9 +1385,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
|
||||
email := ""
|
||||
accountID := ""
|
||||
planType := ""
|
||||
if claims != nil {
|
||||
email = claims.GetUserEmail()
|
||||
accountID = claims.GetAccountID()
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
}
|
||||
hashAccountID := ""
|
||||
if accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
// Build bundle compatible with existing storage
|
||||
bundle := &codex.CodexAuthBundle{
|
||||
@@ -1402,10 +1411,11 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
// Create token storage and persist
|
||||
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
|
||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||
record := &coreauth.Auth{
|
||||
ID: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
||||
ID: fileName,
|
||||
Provider: "codex",
|
||||
FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: map[string]any{
|
||||
"email": tokenStorage.Email,
|
||||
@@ -2110,7 +2120,20 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
finalProjectID := projectID
|
||||
if responseProjectID != "" {
|
||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||
strings.EqualFold(tierID, "FREE") ||
|
||||
strings.EqualFold(tierID, "LEGACY")
|
||||
|
||||
if isFreeUser {
|
||||
// For free users, use backend project ID for preview model access
|
||||
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
|
||||
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
|
||||
finalProjectID = responseProjectID
|
||||
} else {
|
||||
// Pro users: keep requested project ID (original behavior)
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
}
|
||||
} else {
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
|
||||
57
internal/auth/codex/filename.go
Normal file
57
internal/auth/codex/filename.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
|
||||
// When planType is available (e.g. "plus", "team"), it is appended after the email
|
||||
// as a suffix to disambiguate subscriptions.
|
||||
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
|
||||
email = strings.TrimSpace(email)
|
||||
plan := normalizePlanTypeForFilename(planType)
|
||||
|
||||
prefix := ""
|
||||
if includeProviderPrefix {
|
||||
prefix = "codex"
|
||||
}
|
||||
|
||||
if plan == "" {
|
||||
return fmt.Sprintf("%s-%s.json", prefix, email)
|
||||
} else if plan == "team" {
|
||||
return fmt.Sprintf("%s-%s-%s-%s.json", prefix, hashAccountID, email, plan)
|
||||
}
|
||||
return fmt.Sprintf("%s-%s-%s.json", prefix, email, plan)
|
||||
}
|
||||
|
||||
func normalizePlanTypeForFilename(planType string) string {
|
||||
planType = strings.TrimSpace(planType)
|
||||
if planType == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(planType, func(r rune) bool {
|
||||
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for i, part := range parts {
|
||||
parts[i] = titleToken(part)
|
||||
}
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
|
||||
func titleToken(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
return cases.Title(language.English).String(token)
|
||||
}
|
||||
55
internal/cache/signature_cache.go
vendored
55
internal/cache/signature_cache.go
vendored
@@ -3,6 +3,8 @@ package cache
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -94,17 +96,17 @@ func purgeExpiredSessions() {
|
||||
|
||||
// CacheSignature stores a thinking signature for a given session and text.
|
||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||
func CacheSignature(sessionID, text, signature string) {
|
||||
if sessionID == "" || text == "" || signature == "" {
|
||||
func CacheSignature(modelName, text, signature string) {
|
||||
if text == "" || signature == "" {
|
||||
return
|
||||
}
|
||||
if len(signature) < MinValidSignatureLen {
|
||||
return
|
||||
}
|
||||
|
||||
sc := getOrCreateSession(sessionID)
|
||||
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
|
||||
textHash := hashText(text)
|
||||
|
||||
sc := getOrCreateSession(textHash)
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
@@ -116,13 +118,21 @@ func CacheSignature(sessionID, text, signature string) {
|
||||
|
||||
// GetCachedSignature retrieves a cached signature for a given session and text.
|
||||
// Returns empty string if not found or expired.
|
||||
func GetCachedSignature(sessionID, text string) string {
|
||||
if sessionID == "" || text == "" {
|
||||
func GetCachedSignature(modelName, text string) string {
|
||||
family := GetModelGroup(modelName)
|
||||
|
||||
if text == "" {
|
||||
if family == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
val, ok := signatureCache.Load(sessionID)
|
||||
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
|
||||
val, ok := signatureCache.Load(hashText(text))
|
||||
if !ok {
|
||||
if family == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
sc := val.(*sessionCache)
|
||||
@@ -135,11 +145,17 @@ func GetCachedSignature(sessionID, text string) string {
|
||||
entry, exists := sc.entries[textHash]
|
||||
if !exists {
|
||||
sc.mu.Unlock()
|
||||
if family == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, textHash)
|
||||
sc.mu.Unlock()
|
||||
if family == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -154,7 +170,13 @@ func GetCachedSignature(sessionID, text string) string {
|
||||
// ClearSignatureCache clears signature cache for a specific session or all sessions.
|
||||
func ClearSignatureCache(sessionID string) {
|
||||
if sessionID != "" {
|
||||
signatureCache.Delete(sessionID)
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
kStr, ok := key.(string)
|
||||
if ok && strings.HasSuffix(kStr, "#"+sessionID) {
|
||||
signatureCache.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else {
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
signatureCache.Delete(key)
|
||||
@@ -164,6 +186,17 @@ func ClearSignatureCache(sessionID string) {
|
||||
}
|
||||
|
||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||
func HasValidSignature(signature string) bool {
|
||||
return signature != "" && len(signature) >= MinValidSignatureLen
|
||||
func HasValidSignature(modelName, signature string) bool {
|
||||
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
|
||||
}
|
||||
|
||||
func GetModelGroup(modelName string) string {
|
||||
if strings.Contains(modelName, "gpt") {
|
||||
return "gpt"
|
||||
} else if strings.Contains(modelName, "claude") {
|
||||
return "claude"
|
||||
} else if strings.Contains(modelName, "gemini") {
|
||||
return "gemini"
|
||||
}
|
||||
return modelName
|
||||
}
|
||||
|
||||
79
internal/cache/signature_cache_test.go
vendored
79
internal/cache/signature_cache_test.go
vendored
@@ -8,15 +8,14 @@ import (
|
||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-session-1"
|
||||
text := "This is some thinking text content"
|
||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
|
||||
// Store signature
|
||||
CacheSignature(sessionID, text, signature)
|
||||
CacheSignature("test-model", text, signature)
|
||||
|
||||
// Retrieve signature
|
||||
retrieved := GetCachedSignature(sessionID, text)
|
||||
retrieved := GetCachedSignature("test-model", text)
|
||||
if retrieved != signature {
|
||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||
}
|
||||
@@ -29,13 +28,13 @@ func TestCacheSignature_DifferentSessions(t *testing.T) {
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature("session-a", text, sig1)
|
||||
CacheSignature("session-b", text, sig2)
|
||||
CacheSignature("test-model", text, sig1)
|
||||
CacheSignature("test-model", text, sig2)
|
||||
|
||||
if GetCachedSignature("session-a", text) != sig1 {
|
||||
if GetCachedSignature("test-model", text) != sig1 {
|
||||
t.Error("Session-a signature mismatch")
|
||||
}
|
||||
if GetCachedSignature("session-b", text) != sig2 {
|
||||
if GetCachedSignature("test-model", text) != sig2 {
|
||||
t.Error("Session-b signature mismatch")
|
||||
}
|
||||
}
|
||||
@@ -44,13 +43,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Non-existent session
|
||||
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
|
||||
if got := GetCachedSignature("test-model", "some text"); got != "" {
|
||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||
}
|
||||
|
||||
// Existing session but different text
|
||||
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature("session-x", "text-b"); got != "" {
|
||||
CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature("test-model", "text-b"); got != "" {
|
||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -59,12 +58,12 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// All empty/invalid inputs should be no-ops
|
||||
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "text", "")
|
||||
CacheSignature("session", "text", "short") // Too short
|
||||
CacheSignature("test-model", "text", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("test-model", "text", "")
|
||||
CacheSignature("test-model", "text", "short") // Too short
|
||||
|
||||
if got := GetCachedSignature("session", "text"); got != "" {
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -72,13 +71,12 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-short-sig"
|
||||
text := "Some text"
|
||||
shortSig := "abc123" // Less than 50 chars
|
||||
|
||||
CacheSignature(sessionID, text, shortSig)
|
||||
CacheSignature("test-model", text, shortSig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != "" {
|
||||
if got := GetCachedSignature("test-model", text); got != "" {
|
||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -87,15 +85,15 @@ func TestClearSignatureCache_SpecificSession(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
CacheSignature("test-model", "text", sig)
|
||||
CacheSignature("test-model", "text", sig)
|
||||
|
||||
ClearSignatureCache("session-1")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != sig {
|
||||
if got := GetCachedSignature("test-model", "text"); got != sig {
|
||||
t.Error("session-2 should still exist")
|
||||
}
|
||||
}
|
||||
@@ -104,15 +102,15 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
CacheSignature("test-model", "text", sig)
|
||||
CacheSignature("test-model", "text", sig)
|
||||
|
||||
ClearSignatureCache("")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != "" {
|
||||
if got := GetCachedSignature("test-model", "text"); got != "" {
|
||||
t.Error("session-2 should be cleared")
|
||||
}
|
||||
}
|
||||
@@ -132,7 +130,7 @@ func TestHasValidSignature(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasValidSignature(tt.signature)
|
||||
result := HasValidSignature("claude-sonnet-4-5-thinking", tt.signature)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||
}
|
||||
@@ -143,21 +141,19 @@ func TestHasValidSignature(t *testing.T) {
|
||||
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "hash-test-session"
|
||||
|
||||
// Different texts should produce different hashes
|
||||
text1 := "First thinking text"
|
||||
text2 := "Second thinking text"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text1, sig1)
|
||||
CacheSignature(sessionID, text2, sig2)
|
||||
CacheSignature("test-model", text1, sig1)
|
||||
CacheSignature("test-model", text2, sig2)
|
||||
|
||||
if GetCachedSignature(sessionID, text1) != sig1 {
|
||||
if GetCachedSignature("test-model", text1) != sig1 {
|
||||
t.Error("text1 signature mismatch")
|
||||
}
|
||||
if GetCachedSignature(sessionID, text2) != sig2 {
|
||||
if GetCachedSignature("test-model", text2) != sig2 {
|
||||
t.Error("text2 signature mismatch")
|
||||
}
|
||||
}
|
||||
@@ -165,13 +161,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "unicode-session"
|
||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
CacheSignature("test-model", text, sig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
if got := GetCachedSignature("test-model", text); got != sig {
|
||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -179,15 +174,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
func TestCacheSignature_Overwrite(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "overwrite-session"
|
||||
text := "Same text"
|
||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||
|
||||
CacheSignature(sessionID, text, sig1)
|
||||
CacheSignature(sessionID, text, sig2) // Overwrite
|
||||
CacheSignature("test-model", text, sig1)
|
||||
CacheSignature("test-model", text, sig2) // Overwrite
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig2 {
|
||||
if got := GetCachedSignature("test-model", text); got != sig2 {
|
||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||
}
|
||||
}
|
||||
@@ -199,14 +193,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
|
||||
// This test verifies the expiration check exists
|
||||
// In a real scenario, we'd mock time.Now()
|
||||
sessionID := "expiration-test"
|
||||
text := "text"
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
CacheSignature("test-model", text, sig)
|
||||
|
||||
// Fresh entry should be retrievable
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
if got := GetCachedSignature("test-model", text); got != sig {
|
||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||
}
|
||||
|
||||
|
||||
@@ -118,6 +118,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
@@ -134,6 +135,13 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
@@ -261,7 +269,39 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
finalProjectID := projectID
|
||||
if responseProjectID != "" {
|
||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||
strings.EqualFold(tierID, "FREE") ||
|
||||
strings.EqualFold(tierID, "LEGACY")
|
||||
|
||||
if isFreeUser {
|
||||
// Interactive prompt for free users
|
||||
fmt.Printf("\nGoogle returned a different project ID:\n")
|
||||
fmt.Printf(" Requested (frontend): %s\n", projectID)
|
||||
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
|
||||
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
|
||||
fmt.Printf(" This is normal for free tier users.\n\n")
|
||||
fmt.Printf("Which project ID would you like to use?\n")
|
||||
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
|
||||
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
|
||||
fmt.Printf("Enter choice [1]: ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
choice, _ := reader.ReadString('\n')
|
||||
choice = strings.TrimSpace(choice)
|
||||
|
||||
if choice == "2" {
|
||||
log.Infof("Using frontend project ID: %s", projectID)
|
||||
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
|
||||
finalProjectID = projectID
|
||||
} else {
|
||||
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
} else {
|
||||
// Pro users: keep requested project ID (original behavior)
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
}
|
||||
} else {
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
|
||||
@@ -248,6 +248,25 @@ type PayloadModelRule struct {
|
||||
Protocol string `yaml:"protocol" json:"protocol"`
|
||||
}
|
||||
|
||||
// CloakConfig configures request cloaking for non-Claude-Code clients.
|
||||
// Cloaking disguises API requests to appear as originating from the official Claude Code CLI.
|
||||
type CloakConfig struct {
|
||||
// Mode controls cloaking behavior: "auto" (default), "always", or "never".
|
||||
// - "auto": cloak only when client is not Claude Code (based on User-Agent)
|
||||
// - "always": always apply cloaking regardless of client
|
||||
// - "never": never apply cloaking
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
|
||||
// StrictMode controls how system prompts are handled when cloaking.
|
||||
// - false (default): prepend Claude Code prompt to user system messages
|
||||
// - true: strip all user system messages, keep only Claude Code prompt
|
||||
StrictMode bool `yaml:"strict-mode,omitempty" json:"strict-mode,omitempty"`
|
||||
|
||||
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||
// This can help bypass certain content filters.
|
||||
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeKey represents the configuration for a Claude API key,
|
||||
// including the API key itself and an optional base URL for the API endpoint.
|
||||
type ClaudeKey struct {
|
||||
@@ -276,6 +295,9 @@ type ClaudeKey struct {
|
||||
|
||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
|
||||
// Cloak configures request cloaking for non-Claude-Code clients.
|
||||
Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
|
||||
}
|
||||
|
||||
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
@@ -112,6 +113,11 @@ func isAIAPIPath(path string) bool {
|
||||
// - gin.HandlerFunc: A middleware handler for panic recovery
|
||||
func GinLogrusRecovery() gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"panic": recovered,
|
||||
"stack": string(debug.Stack()),
|
||||
|
||||
60
internal/logging/gin_logger_test.go
Normal file
60
internal/logging/gin_logger_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(GinLogrusRecovery())
|
||||
engine.GET("/abort", func(c *gin.Context) {
|
||||
panic(http.ErrAbortHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
defer func() {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
t.Fatalf("expected panic, got nil")
|
||||
}
|
||||
err, ok := recovered.(error)
|
||||
if !ok {
|
||||
t.Fatalf("expected error panic, got %T", recovered)
|
||||
}
|
||||
if !errors.Is(err, http.ErrAbortHandler) {
|
||||
t.Fatalf("expected ErrAbortHandler, got %v", err)
|
||||
}
|
||||
if err != http.ErrAbortHandler {
|
||||
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
engine.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(GinLogrusRecovery())
|
||||
engine.GET("/panic", func(c *gin.Context) {
|
||||
panic("boom")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
engine.ServeHTTP(recorder, req)
|
||||
if recorder.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
@@ -287,6 +287,67 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
// Imagen image generation models - use :predict action
|
||||
{
|
||||
ID: "imagen-4.0-generate-001",
|
||||
Object: "model",
|
||||
Created: 1750000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-4.0-generate-001",
|
||||
Version: "4.0",
|
||||
DisplayName: "Imagen 4.0 Generate",
|
||||
Description: "Imagen 4.0 image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-4.0-ultra-generate-001",
|
||||
Object: "model",
|
||||
Created: 1750000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-4.0-ultra-generate-001",
|
||||
Version: "4.0",
|
||||
DisplayName: "Imagen 4.0 Ultra Generate",
|
||||
Description: "Imagen 4.0 Ultra high-quality image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-3.0-generate-002",
|
||||
Object: "model",
|
||||
Created: 1740000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-3.0-generate-002",
|
||||
Version: "3.0",
|
||||
DisplayName: "Imagen 3.0 Generate",
|
||||
Description: "Imagen 3.0 image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-3.0-fast-generate-001",
|
||||
Object: "model",
|
||||
Created: 1740000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-3.0-fast-generate-001",
|
||||
Version: "3.0",
|
||||
DisplayName: "Imagen 3.0 Fast Generate",
|
||||
Description: "Imagen 3.0 fast image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
{
|
||||
ID: "imagen-4.0-fast-generate-001",
|
||||
Object: "model",
|
||||
Created: 1750000000,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/imagen-4.0-fast-generate-001",
|
||||
Version: "4.0",
|
||||
DisplayName: "Imagen 4.0 Fast Generate",
|
||||
Description: "Imagen 4.0 fast image generation model",
|
||||
SupportedGenerationMethods: []string{"predict"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -765,21 +826,23 @@ func GetIFlowModels() []*ModelInfo {
|
||||
type AntigravityModelConfig struct {
|
||||
Thinking *ThinkingSupport
|
||||
MaxCompletionTokens int
|
||||
Name string
|
||||
}
|
||||
|
||||
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
return map[string]*AntigravityModelConfig{
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/rev19-uic3-1p"},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-high"},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image"},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash"},
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -809,10 +872,9 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
}
|
||||
|
||||
// Check Antigravity static config
|
||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil && cfg.Thinking != nil {
|
||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
||||
return &ModelInfo{
|
||||
ID: modelID,
|
||||
Name: cfg.Name,
|
||||
Thinking: cfg.Thinking,
|
||||
MaxCompletionTokens: cfg.MaxCompletionTokens,
|
||||
}
|
||||
|
||||
@@ -78,6 +78,8 @@ type ThinkingSupport struct {
|
||||
type ModelRegistration struct {
|
||||
// Info contains the model metadata
|
||||
Info *ModelInfo
|
||||
// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
|
||||
InfoByProvider map[string]*ModelInfo
|
||||
// Count is the number of active clients that can provide this model
|
||||
Count int
|
||||
// LastUpdated tracks when this registration was last modified
|
||||
@@ -132,16 +134,19 @@ func GetGlobalRegistry() *ModelRegistry {
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
// LookupModelInfo searches the dynamic registry first, then falls back to static model definitions.
|
||||
//
|
||||
// This helper exists because some code paths only have a model ID and still need Thinking and
|
||||
// max completion token metadata even when the dynamic registry hasn't been populated.
|
||||
func LookupModelInfo(modelID string) *ModelInfo {
|
||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return nil
|
||||
}
|
||||
if info := GetGlobalRegistry().GetModelInfo(modelID); info != nil {
|
||||
|
||||
p := ""
|
||||
if len(provider) > 0 {
|
||||
p = strings.ToLower(strings.TrimSpace(provider[0]))
|
||||
}
|
||||
|
||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||
return info
|
||||
}
|
||||
return LookupStaticModelInfo(modelID)
|
||||
@@ -297,6 +302,9 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
if count, okProv := reg.Providers[oldProvider]; okProv {
|
||||
if count <= toRemove {
|
||||
delete(reg.Providers, oldProvider)
|
||||
if reg.InfoByProvider != nil {
|
||||
delete(reg.InfoByProvider, oldProvider)
|
||||
}
|
||||
} else {
|
||||
reg.Providers[oldProvider] = count - toRemove
|
||||
}
|
||||
@@ -346,6 +354,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
model := newModels[id]
|
||||
if reg, ok := r.models[id]; ok {
|
||||
reg.Info = cloneModelInfo(model)
|
||||
if provider != "" {
|
||||
if reg.InfoByProvider == nil {
|
||||
reg.InfoByProvider = make(map[string]*ModelInfo)
|
||||
}
|
||||
reg.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
reg.LastUpdated = now
|
||||
if reg.QuotaExceededClients != nil {
|
||||
delete(reg.QuotaExceededClients, clientID)
|
||||
@@ -409,11 +423,15 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
||||
if existing.SuspendedClients == nil {
|
||||
existing.SuspendedClients = make(map[string]string)
|
||||
}
|
||||
if existing.InfoByProvider == nil {
|
||||
existing.InfoByProvider = make(map[string]*ModelInfo)
|
||||
}
|
||||
if provider != "" {
|
||||
if existing.Providers == nil {
|
||||
existing.Providers = make(map[string]int)
|
||||
}
|
||||
existing.Providers[provider]++
|
||||
existing.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
||||
return
|
||||
@@ -421,6 +439,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
||||
|
||||
registration := &ModelRegistration{
|
||||
Info: cloneModelInfo(model),
|
||||
InfoByProvider: make(map[string]*ModelInfo),
|
||||
Count: 1,
|
||||
LastUpdated: now,
|
||||
QuotaExceededClients: make(map[string]*time.Time),
|
||||
@@ -428,6 +447,7 @@ func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *Mo
|
||||
}
|
||||
if provider != "" {
|
||||
registration.Providers = map[string]int{provider: 1}
|
||||
registration.InfoByProvider[provider] = cloneModelInfo(model)
|
||||
}
|
||||
r.models[modelID] = registration
|
||||
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
||||
@@ -453,6 +473,9 @@ func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider stri
|
||||
if count, ok := registration.Providers[provider]; ok {
|
||||
if count <= 1 {
|
||||
delete(registration.Providers, provider)
|
||||
if registration.InfoByProvider != nil {
|
||||
delete(registration.InfoByProvider, provider)
|
||||
}
|
||||
} else {
|
||||
registration.Providers[provider] = count - 1
|
||||
}
|
||||
@@ -534,6 +557,9 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
if count, ok := registration.Providers[provider]; ok {
|
||||
if count <= 1 {
|
||||
delete(registration.Providers, provider)
|
||||
if registration.InfoByProvider != nil {
|
||||
delete(registration.InfoByProvider, provider)
|
||||
}
|
||||
} else {
|
||||
registration.Providers[provider] = count - 1
|
||||
}
|
||||
@@ -940,12 +966,22 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetModelInfo returns the registered ModelInfo for the given model ID, if present.
|
||||
// Returns nil if the model is unknown to the registry.
|
||||
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
|
||||
// GetModelInfo returns ModelInfo, prioritizing provider-specific definition if available.
|
||||
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
if reg, ok := r.models[modelID]; ok && reg != nil {
|
||||
// Try provider specific definition first
|
||||
if provider != "" && reg.InfoByProvider != nil {
|
||||
if reg.Providers != nil {
|
||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback to global info (last registered)
|
||||
return reg.Info
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -393,7 +393,7 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, translatedPayload{}, err
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -622,7 +622,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -802,7 +802,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
// Prepare payload once (doesn't depend on baseURL)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
@@ -1005,9 +1005,6 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
}
|
||||
modelCfg := modelConfig[modelID]
|
||||
modelName := modelID
|
||||
if modelCfg != nil && modelCfg.Name != "" {
|
||||
modelName = modelCfg.Name
|
||||
}
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: modelID,
|
||||
Name: modelName,
|
||||
@@ -1205,7 +1202,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||
@@ -1408,16 +1405,9 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
// template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
|
||||
if !strings.HasPrefix(modelName, "gemini-3-") {
|
||||
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
template, _ = sjson.Set(template, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||
@@ -1429,7 +1419,9 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
})
|
||||
return true
|
||||
})
|
||||
} else {
|
||||
}
|
||||
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -106,22 +105,20 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(baseModel, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -239,20 +236,20 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body = checkSystemInstructions(body)
|
||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(baseModel, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
@@ -541,81 +538,6 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
|
||||
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||
// This function should be called after all thinking configuration is finalized.
|
||||
// It looks up the model's MaxCompletionTokens from the registry to use as the cap.
|
||||
func ensureMaxTokensForThinking(modelName string, body []byte) []byte {
|
||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||
if thinkingType != "enabled" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetTokens := gjson.GetBytes(body, "thinking.budget_tokens").Int()
|
||||
if budgetTokens <= 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
|
||||
|
||||
// Look up the model's max completion tokens from the registry
|
||||
maxCompletionTokens := 0
|
||||
if modelInfo := registry.LookupModelInfo(modelName); modelInfo != nil {
|
||||
maxCompletionTokens = modelInfo.MaxCompletionTokens
|
||||
}
|
||||
|
||||
// Fall back to budget + buffer if registry lookup fails or returns 0
|
||||
const fallbackBuffer = 4000
|
||||
requiredMaxTokens := budgetTokens + fallbackBuffer
|
||||
if maxCompletionTokens > 0 {
|
||||
requiredMaxTokens = int64(maxCompletionTokens)
|
||||
}
|
||||
|
||||
if maxTokens < requiredMaxTokens {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", requiredMaxTokens)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) resolveClaudeConfig(auth *cliproxyauth.Auth) *config.ClaudeKey {
|
||||
if auth == nil || e.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range e.cfg.ClaudeKey {
|
||||
entry := &e.cfg.ClaudeKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range e.cfg.ClaudeKey {
|
||||
entry := &e.cfg.ClaudeKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type compositeReadCloser struct {
|
||||
io.Reader
|
||||
closers []func() error
|
||||
@@ -901,3 +823,163 @@ func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// getClientUserAgent extracts the client User-Agent from the gin context.
|
||||
func getClientUserAgent(ctx context.Context) string {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
return ginCtx.GetHeader("User-Agent")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||
// Returns (cloakMode, strictMode, sensitiveWords).
|
||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
||||
if auth == nil || auth.Attributes == nil {
|
||||
return "auto", false, nil
|
||||
}
|
||||
|
||||
cloakMode := auth.Attributes["cloak_mode"]
|
||||
if cloakMode == "" {
|
||||
cloakMode = "auto"
|
||||
}
|
||||
|
||||
strictMode := strings.ToLower(auth.Attributes["cloak_strict_mode"]) == "true"
|
||||
|
||||
var sensitiveWords []string
|
||||
if wordsStr := auth.Attributes["cloak_sensitive_words"]; wordsStr != "" {
|
||||
sensitiveWords = strings.Split(wordsStr, ",")
|
||||
for i := range sensitiveWords {
|
||||
sensitiveWords[i] = strings.TrimSpace(sensitiveWords[i])
|
||||
}
|
||||
}
|
||||
|
||||
return cloakMode, strictMode, sensitiveWords
|
||||
}
|
||||
|
||||
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
|
||||
if cfg == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey, baseURL := claudeCreds(auth)
|
||||
if apiKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
|
||||
// Match by API key
|
||||
if strings.EqualFold(cfgKey, apiKey) {
|
||||
// If baseURL is specified, also check it
|
||||
if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
|
||||
continue
|
||||
}
|
||||
return entry.Cloak
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
||||
func injectFakeUserID(payload []byte) []byte {
|
||||
metadata := gjson.GetBytes(payload, "metadata")
|
||||
if !metadata.Exists() {
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||
return payload
|
||||
}
|
||||
|
||||
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if existingUserID == "" || !isValidUserID(existingUserID) {
|
||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// checkSystemInstructionsWithMode injects Claude Code system prompt.
|
||||
// In strict mode, it replaces all user system messages.
|
||||
// In non-strict mode (default), it prepends to existing system messages.
|
||||
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||
|
||||
if strictMode {
|
||||
// Strict mode: replace all system messages with Claude Code prompt only
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
return payload
|
||||
}
|
||||
|
||||
// Non-strict mode (default): prepend Claude Code prompt to existing system messages
|
||||
if system.IsArray() {
|
||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
||||
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
|
||||
clientUserAgent := getClientUserAgent(ctx)
|
||||
|
||||
// Get cloak config from ClaudeKey configuration
|
||||
cloakCfg := resolveClaudeKeyCloakConfig(cfg, auth)
|
||||
|
||||
// Determine cloak settings
|
||||
var cloakMode string
|
||||
var strictMode bool
|
||||
var sensitiveWords []string
|
||||
|
||||
if cloakCfg != nil {
|
||||
cloakMode = cloakCfg.Mode
|
||||
strictMode = cloakCfg.StrictMode
|
||||
sensitiveWords = cloakCfg.SensitiveWords
|
||||
}
|
||||
|
||||
// Fallback to auth attributes if no config found
|
||||
if cloakMode == "" {
|
||||
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
|
||||
cloakMode = attrMode
|
||||
if !strictMode {
|
||||
strictMode = attrStrict
|
||||
}
|
||||
if len(sensitiveWords) == 0 {
|
||||
sensitiveWords = attrWords
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if cloaking should be applied
|
||||
if !shouldCloak(cloakMode, clientUserAgent) {
|
||||
return payload
|
||||
}
|
||||
|
||||
// Skip system instructions for claude-3-5-haiku models
|
||||
if !strings.HasPrefix(model, "claude-3-5-haiku") {
|
||||
payload = checkSystemInstructionsWithMode(payload, strictMode)
|
||||
}
|
||||
|
||||
// Inject fake user ID
|
||||
payload = injectFakeUserID(payload)
|
||||
|
||||
// Apply sensitive word obfuscation
|
||||
if len(sensitiveWords) > 0 {
|
||||
matcher := buildSensitiveWordMatcher(sensitiveWords)
|
||||
payload = obfuscateSensitiveWords(payload, matcher)
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
176
internal/runtime/executor/cloak_obfuscate.go
Normal file
176
internal/runtime/executor/cloak_obfuscate.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// zeroWidthSpace is the Unicode zero-width space character used for obfuscation.
|
||||
const zeroWidthSpace = "\u200B"
|
||||
|
||||
// SensitiveWordMatcher holds the compiled regex for matching sensitive words.
|
||||
type SensitiveWordMatcher struct {
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
// buildSensitiveWordMatcher compiles a regex from the word list.
|
||||
// Words are sorted by length (longest first) for proper matching.
|
||||
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
|
||||
if len(words) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter and normalize words
|
||||
var validWords []string
|
||||
for _, w := range words {
|
||||
w = strings.TrimSpace(w)
|
||||
if utf8.RuneCountInString(w) >= 2 && !strings.Contains(w, zeroWidthSpace) {
|
||||
validWords = append(validWords, w)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validWords) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort by length (longest first) for proper matching
|
||||
sort.Slice(validWords, func(i, j int) bool {
|
||||
return len(validWords[i]) > len(validWords[j])
|
||||
})
|
||||
|
||||
// Escape and join
|
||||
escaped := make([]string, len(validWords))
|
||||
for i, w := range validWords {
|
||||
escaped[i] = regexp.QuoteMeta(w)
|
||||
}
|
||||
|
||||
pattern := "(?i)" + strings.Join(escaped, "|")
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &SensitiveWordMatcher{regex: re}
|
||||
}
|
||||
|
||||
// obfuscateWord inserts a zero-width space after the first grapheme.
|
||||
func obfuscateWord(word string) string {
|
||||
if strings.Contains(word, zeroWidthSpace) {
|
||||
return word
|
||||
}
|
||||
|
||||
// Get first rune
|
||||
r, size := utf8.DecodeRuneInString(word)
|
||||
if r == utf8.RuneError || size >= len(word) {
|
||||
return word
|
||||
}
|
||||
|
||||
return string(r) + zeroWidthSpace + word[size:]
|
||||
}
|
||||
|
||||
// obfuscateText replaces all sensitive words in the text.
|
||||
func (m *SensitiveWordMatcher) obfuscateText(text string) string {
|
||||
if m == nil || m.regex == nil {
|
||||
return text
|
||||
}
|
||||
return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
|
||||
}
|
||||
|
||||
// obfuscateSensitiveWords processes the payload and obfuscates sensitive words
|
||||
// in system blocks and message content.
|
||||
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
if matcher == nil || matcher.regex == nil {
|
||||
return payload
|
||||
}
|
||||
|
||||
// Obfuscate in system blocks
|
||||
payload = obfuscateSystemBlocks(payload, matcher)
|
||||
|
||||
// Obfuscate in messages
|
||||
payload = obfuscateMessages(payload, matcher)
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// obfuscateSystemBlocks obfuscates sensitive words in system blocks.
|
||||
func obfuscateSystemBlocks(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
if !system.Exists() {
|
||||
return payload
|
||||
}
|
||||
|
||||
if system.IsArray() {
|
||||
modified := false
|
||||
system.ForEach(func(key, value gjson.Result) bool {
|
||||
if value.Get("type").String() == "text" {
|
||||
text := value.Get("text").String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
path := "system." + key.String() + ".text"
|
||||
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if modified {
|
||||
return payload
|
||||
}
|
||||
} else if system.Type == gjson.String {
|
||||
text := system.String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
payload, _ = sjson.SetBytes(payload, "system", obfuscated)
|
||||
}
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// obfuscateMessages obfuscates sensitive words in message content.
|
||||
func obfuscateMessages(payload []byte, matcher *SensitiveWordMatcher) []byte {
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return payload
|
||||
}
|
||||
|
||||
messages.ForEach(func(msgKey, msg gjson.Result) bool {
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() {
|
||||
return true
|
||||
}
|
||||
|
||||
msgPath := "messages." + msgKey.String()
|
||||
|
||||
if content.Type == gjson.String {
|
||||
// Simple string content
|
||||
text := content.String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
payload, _ = sjson.SetBytes(payload, msgPath+".content", obfuscated)
|
||||
}
|
||||
} else if content.IsArray() {
|
||||
// Array of content blocks
|
||||
content.ForEach(func(blockKey, block gjson.Result) bool {
|
||||
if block.Get("type").String() == "text" {
|
||||
text := block.Get("text").String()
|
||||
obfuscated := matcher.obfuscateText(text)
|
||||
if obfuscated != text {
|
||||
path := msgPath + ".content." + blockKey.String() + ".text"
|
||||
payload, _ = sjson.SetBytes(payload, path, obfuscated)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return payload
|
||||
}
|
||||
47
internal/runtime/executor/cloak_utils.go
Normal file
47
internal/runtime/executor/cloak_utils.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// userIDPattern matches Claude Code format: user_[64-hex]_account__session_[uuid-v4]
|
||||
var userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
|
||||
|
||||
// generateFakeUserID generates a fake user ID in Claude Code format.
|
||||
// Format: user_[64-hex-chars]_account__session_[UUID-v4]
|
||||
func generateFakeUserID() string {
|
||||
hexBytes := make([]byte, 32)
|
||||
_, _ = rand.Read(hexBytes)
|
||||
hexPart := hex.EncodeToString(hexBytes)
|
||||
uuidPart := uuid.New().String()
|
||||
return "user_" + hexPart + "_account__session_" + uuidPart
|
||||
}
|
||||
|
||||
// isValidUserID checks if a user ID matches Claude Code format.
|
||||
func isValidUserID(userID string) bool {
|
||||
return userIDPattern.MatchString(userID)
|
||||
}
|
||||
|
||||
// shouldCloak determines if request should be cloaked based on config and client User-Agent.
|
||||
// Returns true if cloaking should be applied.
|
||||
func shouldCloak(cloakMode string, userAgent string) bool {
|
||||
switch strings.ToLower(cloakMode) {
|
||||
case "always":
|
||||
return true
|
||||
case "never":
|
||||
return false
|
||||
default: // "auto" or empty
|
||||
// If client is Claude Code, don't cloak
|
||||
return !strings.HasPrefix(userAgent, "claude-cli")
|
||||
}
|
||||
}
|
||||
|
||||
// isClaudeCodeClient checks if the User-Agent indicates a Claude Code client.
|
||||
func isClaudeCodeClient(userAgent string) bool {
|
||||
return strings.HasPrefix(userAgent, "claude-cli")
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -208,7 +208,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, true)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -316,7 +316,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, body, false)
|
||||
body = misc.StripCodexUserAgent(body)
|
||||
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -272,7 +272,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String())
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -479,7 +479,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
for range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String())
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -222,7 +222,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -338,7 +338,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -31,6 +32,143 @@ const (
|
||||
vertexAPIVersion = "v1"
|
||||
)
|
||||
|
||||
// isImagenModel checks if the model name is an Imagen image generation model.
|
||||
// Imagen models use the :predict action instead of :generateContent.
|
||||
func isImagenModel(model string) bool {
|
||||
lowerModel := strings.ToLower(model)
|
||||
return strings.Contains(lowerModel, "imagen")
|
||||
}
|
||||
|
||||
// getVertexAction returns the appropriate action for the given model.
|
||||
// Imagen models use "predict", while Gemini models use "generateContent".
|
||||
func getVertexAction(model string, isStream bool) string {
|
||||
if isImagenModel(model) {
|
||||
return "predict"
|
||||
}
|
||||
if isStream {
|
||||
return "streamGenerateContent"
|
||||
}
|
||||
return "generateContent"
|
||||
}
|
||||
|
||||
// convertImagenToGeminiResponse converts Imagen API response to Gemini format
|
||||
// so it can be processed by the standard translation pipeline.
|
||||
// This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview.
|
||||
func convertImagenToGeminiResponse(data []byte, model string) []byte {
|
||||
predictions := gjson.GetBytes(data, "predictions")
|
||||
if !predictions.Exists() || !predictions.IsArray() {
|
||||
return data
|
||||
}
|
||||
|
||||
// Build Gemini-compatible response with inlineData
|
||||
parts := make([]map[string]any, 0)
|
||||
for _, pred := range predictions.Array() {
|
||||
imageData := pred.Get("bytesBase64Encoded").String()
|
||||
mimeType := pred.Get("mimeType").String()
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
if imageData != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"inlineData": map[string]any{
|
||||
"mimeType": mimeType,
|
||||
"data": imageData,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Generate unique response ID using timestamp
|
||||
responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano())
|
||||
|
||||
response := map[string]any{
|
||||
"candidates": []map[string]any{{
|
||||
"content": map[string]any{
|
||||
"parts": parts,
|
||||
"role": "model",
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
}},
|
||||
"responseId": responseId,
|
||||
"modelVersion": model,
|
||||
// Imagen API doesn't return token counts, set to 0 for tracking purposes
|
||||
"usageMetadata": map[string]any{
|
||||
"promptTokenCount": 0,
|
||||
"candidatesTokenCount": 0,
|
||||
"totalTokenCount": 0,
|
||||
},
|
||||
}
|
||||
|
||||
result, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// convertToImagenRequest converts a Gemini-style request to Imagen API format.
|
||||
// Imagen API uses a different structure: instances[].prompt instead of contents[].
|
||||
func convertToImagenRequest(payload []byte) ([]byte, error) {
|
||||
// Extract prompt from Gemini-style contents
|
||||
prompt := ""
|
||||
|
||||
// Try to get prompt from contents[0].parts[0].text
|
||||
contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text")
|
||||
if contentsText.Exists() {
|
||||
prompt = contentsText.String()
|
||||
}
|
||||
|
||||
// If no contents, try messages format (OpenAI-compatible)
|
||||
if prompt == "" {
|
||||
messagesText := gjson.GetBytes(payload, "messages.#.content")
|
||||
if messagesText.Exists() && messagesText.IsArray() {
|
||||
for _, msg := range messagesText.Array() {
|
||||
if msg.String() != "" {
|
||||
prompt = msg.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If still no prompt, try direct prompt field
|
||||
if prompt == "" {
|
||||
directPrompt := gjson.GetBytes(payload, "prompt")
|
||||
if directPrompt.Exists() {
|
||||
prompt = directPrompt.String()
|
||||
}
|
||||
}
|
||||
|
||||
if prompt == "" {
|
||||
return nil, fmt.Errorf("imagen: no prompt found in request")
|
||||
}
|
||||
|
||||
// Build Imagen API request
|
||||
imagenReq := map[string]any{
|
||||
"instances": []map[string]any{
|
||||
{
|
||||
"prompt": prompt,
|
||||
},
|
||||
},
|
||||
"parameters": map[string]any{
|
||||
"sampleCount": 1,
|
||||
},
|
||||
}
|
||||
|
||||
// Extract optional parameters
|
||||
if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() {
|
||||
imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String()
|
||||
}
|
||||
if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() {
|
||||
imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int())
|
||||
}
|
||||
if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() {
|
||||
imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String()
|
||||
}
|
||||
|
||||
return json.Marshal(imagenReq)
|
||||
}
|
||||
|
||||
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
||||
type GeminiVertexExecutor struct {
|
||||
cfg *config.Config
|
||||
@@ -160,26 +298,38 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var body []byte
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
// Handle Imagen models with special request format
|
||||
if isImagenModel(baseModel) {
|
||||
imagenBody, errImagen := convertToImagenRequest(req.Payload)
|
||||
if errImagen != nil {
|
||||
return resp, errImagen
|
||||
}
|
||||
body = imagenBody
|
||||
} else {
|
||||
// Standard Gemini translation flow
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := "generateContent"
|
||||
action := getVertexAction(baseModel, false)
|
||||
if req.Metadata != nil {
|
||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||
action = "countTokens"
|
||||
@@ -249,6 +399,16 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
|
||||
// For Imagen models, convert response to Gemini format before translation
|
||||
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
||||
if isImagenModel(baseModel) {
|
||||
data = convertImagenToGeminiResponse(data, baseModel)
|
||||
}
|
||||
|
||||
// Standard Gemini translation (works for both Gemini and converted Imagen responses)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
@@ -272,7 +432,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -281,7 +441,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := "generateContent"
|
||||
action := getVertexAction(baseModel, false)
|
||||
if req.Metadata != nil {
|
||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||
action = "countTokens"
|
||||
@@ -375,7 +535,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -384,12 +544,16 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
|
||||
// Imagen models don't support streaming, skip SSE params
|
||||
if !isImagenModel(baseModel) {
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||
|
||||
@@ -494,7 +658,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -503,15 +667,19 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||
// Imagen models don't support streaming, skip SSE params
|
||||
if !isImagenModel(baseModel) {
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||
|
||||
@@ -605,7 +773,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
@@ -689,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String())
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -190,7 +190,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow")
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
modelForCounting := baseModel
|
||||
|
||||
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String())
|
||||
translated, err := thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
@@ -172,7 +172,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
||||
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
||||
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
||||
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow)
|
||||
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
|
||||
//
|
||||
// Returns:
|
||||
// - Modified request body JSON with thinking configuration applied
|
||||
@@ -79,12 +80,16 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
||||
// Example:
|
||||
//
|
||||
// // With suffix - suffix config takes priority
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini")
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini")
|
||||
//
|
||||
// // Without suffix - uses body config
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini")
|
||||
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string) ([]byte, error) {
|
||||
// result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini")
|
||||
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) {
|
||||
providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
|
||||
providerKey = strings.ToLower(strings.TrimSpace(providerKey))
|
||||
if providerKey == "" {
|
||||
providerKey = providerFormat
|
||||
}
|
||||
fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
|
||||
if fromFormat == "" {
|
||||
fromFormat = providerFormat
|
||||
@@ -102,7 +107,8 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string
|
||||
// 2. Parse suffix and get modelInfo
|
||||
suffixResult := ParseSuffix(model)
|
||||
baseModel := suffixResult.ModelName
|
||||
modelInfo := registry.LookupModelInfo(baseModel)
|
||||
// Use provider-specific lookup to handle capability differences across providers.
|
||||
modelInfo := registry.LookupModelInfo(baseModel, providerKey)
|
||||
|
||||
// 3. Model capability check
|
||||
// Unknown models are treated as user-defined so thinking config can still be applied.
|
||||
|
||||
@@ -80,9 +80,66 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
|
||||
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens.
|
||||
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||
func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo *registry.ModelInfo) []byte {
|
||||
if budgetTokens <= 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
// Ensure the request satisfies Claude constraints:
|
||||
// 1) Determine effective max_tokens (request overrides model default)
|
||||
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
|
||||
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
|
||||
// 4) If max_tokens came from model default, write it back into the request
|
||||
|
||||
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
|
||||
if setDefaultMax && effectiveMax > 0 {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", effectiveMax)
|
||||
}
|
||||
|
||||
// Compute the budget we would apply after enforcing budget_tokens < max_tokens.
|
||||
adjustedBudget := budgetTokens
|
||||
if effectiveMax > 0 && adjustedBudget >= effectiveMax {
|
||||
adjustedBudget = effectiveMax - 1
|
||||
}
|
||||
|
||||
minBudget := 0
|
||||
if modelInfo != nil && modelInfo.Thinking != nil {
|
||||
minBudget = modelInfo.Thinking.Min
|
||||
}
|
||||
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
|
||||
// If enforcing the max_tokens constraint would push the budget below the model minimum,
|
||||
// leave the request unchanged.
|
||||
return body
|
||||
}
|
||||
|
||||
if adjustedBudget != budgetTokens {
|
||||
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", adjustedBudget)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// effectiveMaxTokens returns the max tokens to cap thinking:
|
||||
// prefer request-provided max_tokens; otherwise fall back to model default.
|
||||
// The boolean indicates whether the value came from the model default (and thus should be written back).
|
||||
func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo) (max int, fromModel bool) {
|
||||
if maxTok := gjson.GetBytes(body, "max_tokens"); maxTok.Exists() && maxTok.Int() > 0 {
|
||||
return int(maxTok.Int()), false
|
||||
}
|
||||
if modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
|
||||
return modelInfo.MaxCompletionTokens, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto {
|
||||
return body, nil
|
||||
|
||||
@@ -7,8 +7,6 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
@@ -19,29 +17,6 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// deriveSessionID generates a stable session ID from the request.
|
||||
// Uses the hash of the first user message to identify the conversation.
|
||||
func deriveSessionID(rawJSON []byte) string {
|
||||
messages := gjson.GetBytes(rawJSON, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "user" {
|
||||
content := msg.Get("content").String()
|
||||
if content == "" {
|
||||
// Try to get text from content array
|
||||
content = msg.Get("content.0.text").String()
|
||||
}
|
||||
if content != "" {
|
||||
h := sha256.Sum256([]byte(content))
|
||||
return hex.EncodeToString(h[:16])
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
|
||||
@@ -61,11 +36,9 @@ func deriveSessionID(rawJSON []byte) string {
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
enableThoughtTranslate := true
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
// Derive session ID for signature caching
|
||||
sessionID := deriveSessionID(rawJSON)
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
hasSystemInstruction := false
|
||||
@@ -124,41 +97,49 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||
// Use GetThinkingText to handle wrapped thinking objects
|
||||
thinkingText := thinking.GetThinkingText(contentResult)
|
||||
signatureResult := contentResult.Get("signature")
|
||||
clientSignature := ""
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
clientSignature = signatureResult.String()
|
||||
}
|
||||
|
||||
// Always try cached signature first (more reliable than client-provided)
|
||||
// Client may send stale or invalid signatures from different sessions
|
||||
signature := ""
|
||||
if sessionID != "" && thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
|
||||
if thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
|
||||
signature = cachedSig
|
||||
// log.Debugf("Using cached signature for thinking block")
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to client signature only if cache miss and client signature is valid
|
||||
if signature == "" && cache.HasValidSignature(clientSignature) {
|
||||
signature = clientSignature
|
||||
if signature == "" {
|
||||
signatureResult := contentResult.Get("signature")
|
||||
clientSignature := ""
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||
if len(arrayClientSignatures) == 2 {
|
||||
if modelName == arrayClientSignatures[0] {
|
||||
clientSignature = arrayClientSignatures[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
if cache.HasValidSignature(modelName, clientSignature) {
|
||||
signature = clientSignature
|
||||
}
|
||||
// log.Debugf("Using client-provided signature for thinking block")
|
||||
}
|
||||
|
||||
// Store for subsequent tool_use in the same message
|
||||
if cache.HasValidSignature(signature) {
|
||||
if cache.HasValidSignature(modelName, signature) {
|
||||
currentMessageThinkingSignature = signature
|
||||
}
|
||||
|
||||
// Skip trailing unsigned thinking blocks on last assistant message
|
||||
isUnsigned := !cache.HasValidSignature(signature)
|
||||
isUnsigned := !cache.HasValidSignature(modelName, signature)
|
||||
|
||||
// If unsigned, skip entirely (don't convert to text)
|
||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||
// Converting to text would break this requirement
|
||||
if isUnsigned {
|
||||
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||
enableThoughtTranslate = false
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -206,7 +187,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||
// and also works for Claude through Antigravity API
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
if cache.HasValidSignature(currentMessageThinkingSignature) {
|
||||
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
} else {
|
||||
// No valid signature - use skip sentinel to bypass validation
|
||||
@@ -386,7 +367,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -75,28 +76,39 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
// Valid signature must be at least 50 characters
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Let me think..."
|
||||
|
||||
// Pre-cache the signature (simulating a response from the same session)
|
||||
// The session ID is derived from the first user message hash
|
||||
// Since there's no user message in this test, we need to add one
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Test user message"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
||||
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check thinking block conversion
|
||||
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
|
||||
// Check thinking block conversion (now in contents.1 due to user message)
|
||||
firstPart := gjson.Get(outputStr, "request.contents.1.parts.0")
|
||||
if !firstPart.Get("thought").Bool() {
|
||||
t.Error("thinking block should have thought: true")
|
||||
}
|
||||
if firstPart.Get("text").String() != "Let me think..." {
|
||||
if firstPart.Get("text").String() != thinkingText {
|
||||
t.Error("thinking text mismatch")
|
||||
}
|
||||
if firstPart.Get("thoughtSignature").String() != validSignature {
|
||||
@@ -227,13 +239,19 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Let me think..."
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Test user message"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
||||
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
@@ -245,11 +263,13 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
]
|
||||
}`)
|
||||
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function call has the signature from the preceding thinking block
|
||||
part := gjson.Get(outputStr, "request.contents.0.parts.1")
|
||||
// Check function call has the signature from the preceding thinking block (now in contents.1)
|
||||
part := gjson.Get(outputStr, "request.contents.1.parts.1")
|
||||
if part.Get("functionCall.name").String() != "get_weather" {
|
||||
t.Errorf("Expected functionCall, got %s", part.Raw)
|
||||
}
|
||||
@@ -261,24 +281,32 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Planning..."
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Test user message"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is the plan."},
|
||||
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
|
||||
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify order: Thinking block MUST be first
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
// Verify order: Thinking block MUST be first (now in contents.1 due to user message)
|
||||
parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||
}
|
||||
@@ -460,6 +488,9 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||
// Last assistant message ends with signed thinking block - should be kept
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Valid thinking..."
|
||||
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
@@ -471,12 +502,14 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is my answer"},
|
||||
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
|
||||
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ type Params struct {
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
|
||||
// Signature caching support
|
||||
SessionID string // Session ID derived from request for signature caching
|
||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||
}
|
||||
|
||||
@@ -70,9 +69,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
SessionID: deriveSessionID(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
|
||||
params := (*param).(*Params)
|
||||
|
||||
@@ -138,14 +137,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||
// log.Debug("Branch: signature_delta")
|
||||
|
||||
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
if params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||
@@ -372,7 +371,7 @@ func resolveStopReason(params *Params) string {
|
||||
// - string: A Claude-compatible JSON response.
|
||||
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
_ = originalRequestRawJSON
|
||||
_ = requestRawJSON
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
promptTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
|
||||
@@ -437,7 +436,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
if thinkingSignature != "" {
|
||||
block, _ = sjson.Set(block, "signature", thinkingSignature)
|
||||
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
|
||||
@@ -97,6 +97,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
|
||||
}`)
|
||||
|
||||
@@ -143,7 +144,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m)
|
||||
|
||||
// Verify signature was cached
|
||||
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
|
||||
cachedSig := cache.GetCachedSignature("claude-sonnet-4-5-thinking", thinkingText)
|
||||
if cachedSig != validSignature {
|
||||
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
|
||||
}
|
||||
@@ -158,6 +159,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
|
||||
}`)
|
||||
|
||||
@@ -221,13 +223,12 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
||||
// Process first thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
firstThinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m)
|
||||
|
||||
// Verify first signature cached
|
||||
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
|
||||
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", firstThinkingText) != validSig1 {
|
||||
t.Error("First thinking block signature should be cached")
|
||||
}
|
||||
|
||||
@@ -241,76 +242,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m)
|
||||
|
||||
// Verify second signature cached
|
||||
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
|
||||
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", secondThinkingText) != validSig2 {
|
||||
t.Error("Second thinking block signature should be cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "valid user message",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "user message with content array",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "no user message",
|
||||
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty messages",
|
||||
input: []byte(`{"messages": []}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "no messages field",
|
||||
input: []byte(`{}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := deriveSessionID(tt.input)
|
||||
if tt.wantEmpty && result != "" {
|
||||
t.Errorf("Expected empty session ID, got '%s'", result)
|
||||
}
|
||||
if !tt.wantEmpty && result == "" {
|
||||
t.Error("Expected non-empty session ID")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
|
||||
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input)
|
||||
id2 := deriveSessionID(input)
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
|
||||
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
|
||||
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input1)
|
||||
id2 := deriveSessionID(input2)
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("Different messages should produce different session IDs")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -32,12 +33,12 @@ import (
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
|
||||
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
||||
@@ -97,37 +98,40 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
|
||||
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
// Gemini-specific handling for non-Claude models:
|
||||
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
|
||||
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
// First pass: collect indices of thinking parts to remove
|
||||
var thinkingIndicesToRemove []int64
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
// Mark thinking blocks for removal
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
|
||||
}
|
||||
// Add skip sentinel to functionCall parts
|
||||
if part.Get("functionCall").Exists() {
|
||||
existingSig := part.Get("thoughtSignature").String()
|
||||
if existingSig == "" || len(existingSig) < 50 {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
// First pass: collect indices of thinking parts to mark with skip sentinel
|
||||
var thinkingIndicesToSkipSignature []int64
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
// Collect indices of thinking blocks to mark with skip sentinel
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Add skip sentinel to functionCall parts
|
||||
if part.Get("functionCall").Exists() {
|
||||
existingSig := part.Get("thoughtSignature").String()
|
||||
if existingSig == "" || len(existingSig) < 50 {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Remove thinking blocks in reverse order to preserve indices
|
||||
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
|
||||
idx := thinkingIndicesToRemove[i]
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
|
||||
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
|
||||
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
|
||||
idx := thinkingIndicesToSkipSignature[i]
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
@@ -66,6 +66,13 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
|
||||
}
|
||||
|
||||
// Candidate count (OpenAI 'n' parameter)
|
||||
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
|
||||
if val := n.Int(); val > 1 {
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
|
||||
}
|
||||
}
|
||||
|
||||
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||
|
||||
@@ -117,8 +117,12 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||
}
|
||||
template, _ = sjson.Set(template, "usage.input_tokens", rootResult.Get("response.usage.input_tokens").Int())
|
||||
template, _ = sjson.Set(template, "usage.output_tokens", rootResult.Get("response.usage.output_tokens").Int())
|
||||
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
|
||||
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
|
||||
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
|
||||
output = "event: message_delta\n"
|
||||
output += fmt.Sprintf("data: %s\n\n", template)
|
||||
@@ -204,8 +208,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
||||
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
|
||||
hasToolCall := false
|
||||
|
||||
@@ -308,12 +316,27 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
|
||||
}
|
||||
|
||||
if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() {
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
|
||||
return out
|
||||
}
|
||||
|
||||
func extractResponsesUsage(usage gjson.Result) (int64, int64, int64) {
|
||||
if !usage.Exists() || usage.Type == gjson.Null {
|
||||
return 0, 0, 0
|
||||
}
|
||||
|
||||
return out
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
outputTokens := usage.Get("output_tokens").Int()
|
||||
cachedTokens := usage.Get("input_tokens_details.cached_tokens").Int()
|
||||
|
||||
if cachedTokens > 0 {
|
||||
if inputTokens >= cachedTokens {
|
||||
inputTokens -= cachedTokens
|
||||
} else {
|
||||
inputTokens = 0
|
||||
}
|
||||
}
|
||||
|
||||
return inputTokens, outputTokens, cachedTokens
|
||||
}
|
||||
|
||||
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.
|
||||
|
||||
@@ -63,6 +63,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
|
||||
}
|
||||
|
||||
// Candidate count (OpenAI 'n' parameter)
|
||||
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
|
||||
if val := n.Int(); val > 1 {
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
|
||||
}
|
||||
}
|
||||
|
||||
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
|
||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||
|
||||
@@ -63,6 +63,13 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.topK", tkr.Num)
|
||||
}
|
||||
|
||||
// Candidate count (OpenAI 'n' parameter)
|
||||
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
|
||||
if val := n.Int(); val > 1 {
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.candidateCount", val)
|
||||
}
|
||||
}
|
||||
|
||||
// Map OpenAI modalities -> Gemini generationConfig.responseModalities
|
||||
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
|
||||
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
|
||||
|
||||
@@ -21,7 +21,8 @@ import (
|
||||
// convertGeminiResponseToOpenAIChatParams holds parameters for response conversion.
|
||||
type convertGeminiResponseToOpenAIChatParams struct {
|
||||
UnixTimestamp int64
|
||||
FunctionIndex int
|
||||
// FunctionIndex tracks tool call indices per candidate index to support multiple candidates.
|
||||
FunctionIndex map[int]int
|
||||
}
|
||||
|
||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
@@ -42,13 +43,20 @@ var functionCallIDCounter uint64
|
||||
// Returns:
|
||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||
func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
// Initialize parameters if nil.
|
||||
if *param == nil {
|
||||
*param = &convertGeminiResponseToOpenAIChatParams{
|
||||
UnixTimestamp: 0,
|
||||
FunctionIndex: 0,
|
||||
FunctionIndex: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the Map is initialized (handling cases where param might be reused from older context).
|
||||
p := (*param).(*convertGeminiResponseToOpenAIChatParams)
|
||||
if p.FunctionIndex == nil {
|
||||
p.FunctionIndex = make(map[int]int)
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
}
|
||||
@@ -57,151 +65,179 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
// Initialize the OpenAI SSE base template.
|
||||
// We use a base template and clone it for each candidate to support multiple candidates.
|
||||
baseTemplate := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
if createTimeResult := gjson.GetBytes(rawJSON, "createTime"); createTimeResult.Exists() {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
if err == nil {
|
||||
(*param).(*convertGeminiResponseToOpenAIChatParams).UnixTimestamp = t.Unix()
|
||||
p.UnixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertGeminiResponseToOpenAIChatParams).UnixTimestamp)
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "created", p.UnixTimestamp)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", (*param).(*convertGeminiResponseToOpenAIChatParams).UnixTimestamp)
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "created", p.UnixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "responseId"); responseIDResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
// Usage is applied to the base template so it appears in the chunks.
|
||||
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
baseTemplate, err = sjson.Set(baseTemplate, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partsResult := gjson.GetBytes(rawJSON, "candidates.0.content.parts")
|
||||
hasFunctionCall := false
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
partResult := partResults[i]
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
}
|
||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
var responseStrings []string
|
||||
candidates := gjson.GetBytes(rawJSON, "candidates")
|
||||
|
||||
// Iterate over all candidates to support candidate_count > 1.
|
||||
if candidates.IsArray() {
|
||||
candidates.ForEach(func(_, candidate gjson.Result) bool {
|
||||
// Clone the template for the current candidate.
|
||||
template := baseTemplate
|
||||
|
||||
// Set the specific index for this candidate.
|
||||
candidateIndex := int(candidate.Get("index").Int())
|
||||
template, _ = sjson.Set(template, "choices.0.index", candidateIndex)
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
}
|
||||
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
partsResult := candidate.Get("content.parts")
|
||||
hasFunctionCall := false
|
||||
|
||||
// Skip pure thoughtSignature parts but keep any actual payload in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
partResult := partResults[i]
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
}
|
||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
|
||||
// Skip pure thoughtSignature parts but keep any actual payload in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
}
|
||||
|
||||
if partTextResult.Exists() {
|
||||
text := partTextResult.String()
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", text)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", text)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
hasFunctionCall = true
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
|
||||
// Retrieve the function index for this specific candidate.
|
||||
functionCallIndex := p.FunctionIndex[candidateIndex]
|
||||
p.FunctionIndex[candidateIndex]++
|
||||
|
||||
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
||||
functionCallIndex = len(toolCallsResult.Array())
|
||||
} else {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
}
|
||||
|
||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
} else if inlineDataResult.Exists() {
|
||||
data := inlineDataResult.Get("data").String()
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
mimeType := inlineDataResult.Get("mimeType").String()
|
||||
if mimeType == "" {
|
||||
mimeType = inlineDataResult.Get("mime_type").String()
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if partTextResult.Exists() {
|
||||
text := partTextResult.String()
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", text)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", text)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
hasFunctionCall = true
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
functionCallIndex := (*param).(*convertGeminiResponseToOpenAIChatParams).FunctionIndex
|
||||
(*param).(*convertGeminiResponseToOpenAIChatParams).FunctionIndex++
|
||||
if toolCallsResult.Exists() && toolCallsResult.IsArray() {
|
||||
functionCallIndex = len(toolCallsResult.Array())
|
||||
} else {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
}
|
||||
|
||||
functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
|
||||
} else if inlineDataResult.Exists() {
|
||||
data := inlineDataResult.Get("data").String()
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
mimeType := inlineDataResult.Get("mimeType").String()
|
||||
if mimeType == "" {
|
||||
mimeType = inlineDataResult.Get("mime_type").String()
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
}
|
||||
|
||||
responseStrings = append(responseStrings, template)
|
||||
return true // continue loop
|
||||
})
|
||||
} else {
|
||||
// If there are no candidates (e.g., a pure usageMetadata chunk), return the usage chunk if present.
|
||||
if gjson.GetBytes(rawJSON, "usageMetadata").Exists() && len(responseStrings) == 0 {
|
||||
responseStrings = append(responseStrings, baseTemplate)
|
||||
}
|
||||
}
|
||||
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
return responseStrings
|
||||
}
|
||||
|
||||
// ConvertGeminiResponseToOpenAINonStream converts a non-streaming Gemini response to a non-streaming OpenAI response.
|
||||
@@ -219,7 +255,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
// - string: An OpenAI-compatible JSON response containing all message content and metadata
|
||||
func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
var unixTimestamp int64
|
||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
// Initialize template with an empty choices array to support multiple candidates.
|
||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[]}`
|
||||
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
@@ -238,11 +276,6 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
}
|
||||
|
||||
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
@@ -267,74 +300,96 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partsResult := gjson.GetBytes(rawJSON, "candidates.0.content.parts")
|
||||
hasFunctionCall := false
|
||||
if partsResult.IsArray() {
|
||||
partsResults := partsResult.Array()
|
||||
for i := 0; i < len(partsResults); i++ {
|
||||
partResult := partsResults[i]
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
// Process the main content part of the response for all candidates.
|
||||
candidates := gjson.GetBytes(rawJSON, "candidates")
|
||||
if candidates.IsArray() {
|
||||
candidates.ForEach(func(_, candidate gjson.Result) bool {
|
||||
// Construct a single Choice object.
|
||||
choiceTemplate := `{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}`
|
||||
|
||||
// Set the index for this choice.
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "index", candidate.Get("index").Int())
|
||||
|
||||
// Set finish reason.
|
||||
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
}
|
||||
|
||||
if partTextResult.Exists() {
|
||||
// Append text content, distinguishing between regular content and reasoning.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Append function call content to the tool_calls array.
|
||||
hasFunctionCall = true
|
||||
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
|
||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||
}
|
||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
||||
} else if inlineDataResult.Exists() {
|
||||
data := inlineDataResult.Get("data").String()
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
mimeType := inlineDataResult.Get("mimeType").String()
|
||||
if mimeType == "" {
|
||||
mimeType = inlineDataResult.Get("mime_type").String()
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagesResult := gjson.Get(template, "choices.0.message.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
partsResult := candidate.Get("content.parts")
|
||||
hasFunctionCall := false
|
||||
if partsResult.IsArray() {
|
||||
partsResults := partsResult.Array()
|
||||
for i := 0; i < len(partsResults); i++ {
|
||||
partResult := partsResults[i]
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
}
|
||||
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
if partTextResult.Exists() {
|
||||
// Append text content, distinguishing between regular content and reasoning.
|
||||
if partResult.Get("thought").Bool() {
|
||||
oldVal := gjson.Get(choiceTemplate, "message.reasoning_content").String()
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.reasoning_content", oldVal+partTextResult.String())
|
||||
} else {
|
||||
oldVal := gjson.Get(choiceTemplate, "message.content").String()
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.content", oldVal+partTextResult.String())
|
||||
}
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Append function call content to the tool_calls array.
|
||||
hasFunctionCall = true
|
||||
toolCallsResult := gjson.Get(choiceTemplate, "message.tool_calls")
|
||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
||||
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.tool_calls", `[]`)
|
||||
}
|
||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.role", "assistant")
|
||||
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.tool_calls.-1", functionCallItemTemplate)
|
||||
} else if inlineDataResult.Exists() {
|
||||
data := inlineDataResult.Get("data").String()
|
||||
if data != "" {
|
||||
mimeType := inlineDataResult.Get("mimeType").String()
|
||||
if mimeType == "" {
|
||||
mimeType = inlineDataResult.Get("mime_type").String()
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagesResult := gjson.Get(choiceTemplate, "message.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(choiceTemplate, "message.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "message.role", "assistant")
|
||||
choiceTemplate, _ = sjson.SetRaw(choiceTemplate, "message.images.-1", imagePayload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasFunctionCall {
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "finish_reason", "tool_calls")
|
||||
choiceTemplate, _ = sjson.Set(choiceTemplate, "native_finish_reason", "tool_calls")
|
||||
}
|
||||
|
||||
// Append the constructed choice to the main choices array.
|
||||
template, _ = sjson.SetRaw(template, "choices.-1", choiceTemplate)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return template
|
||||
|
||||
@@ -298,6 +298,15 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
|
||||
|
||||
case "reasoning":
|
||||
thoughtContent := `{"role":"model","parts":[]}`
|
||||
thought := `{"text":"","thoughtSignature":"","thought":true}`
|
||||
thought, _ = sjson.Set(thought, "text", item.Get("summary.0.text").String())
|
||||
thought, _ = sjson.Set(thought, "thoughtSignature", item.Get("encrypted_content").String())
|
||||
|
||||
thoughtContent, _ = sjson.SetRaw(thoughtContent, "parts.-1", thought)
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", thoughtContent)
|
||||
}
|
||||
}
|
||||
} else if input.Exists() && input.Type == gjson.String {
|
||||
|
||||
@@ -20,6 +20,7 @@ type geminiToResponsesState struct {
|
||||
|
||||
// message aggregation
|
||||
MsgOpened bool
|
||||
MsgClosed bool
|
||||
MsgIndex int
|
||||
CurrentMsgID string
|
||||
TextBuf strings.Builder
|
||||
@@ -29,6 +30,7 @@ type geminiToResponsesState struct {
|
||||
ReasoningOpened bool
|
||||
ReasoningIndex int
|
||||
ReasoningItemID string
|
||||
ReasoningEnc string
|
||||
ReasoningBuf strings.Builder
|
||||
ReasoningClosed bool
|
||||
|
||||
@@ -37,6 +39,7 @@ type geminiToResponsesState struct {
|
||||
FuncArgsBuf map[int]*strings.Builder
|
||||
FuncNames map[int]string
|
||||
FuncCallIDs map[int]string
|
||||
FuncDone map[int]bool
|
||||
}
|
||||
|
||||
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
||||
@@ -45,6 +48,39 @@ var responseIDCounter uint64
|
||||
// funcCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
var funcCallIDCounter uint64
|
||||
|
||||
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
|
||||
if len(originalRequestRawJSON) > 0 && gjson.ValidBytes(originalRequestRawJSON) {
|
||||
return originalRequestRawJSON
|
||||
}
|
||||
if len(requestRawJSON) > 0 && gjson.ValidBytes(requestRawJSON) {
|
||||
return requestRawJSON
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unwrapRequestRoot(root gjson.Result) gjson.Result {
|
||||
req := root.Get("request")
|
||||
if !req.Exists() {
|
||||
return root
|
||||
}
|
||||
if req.Get("model").Exists() || req.Get("input").Exists() || req.Get("instructions").Exists() {
|
||||
return req
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func unwrapGeminiResponseRoot(root gjson.Result) gjson.Result {
|
||||
resp := root.Get("response")
|
||||
if !resp.Exists() {
|
||||
return root
|
||||
}
|
||||
// Vertex-style Gemini responses wrap the actual payload in a "response" object.
|
||||
if resp.Get("candidates").Exists() || resp.Get("responseId").Exists() || resp.Get("usageMetadata").Exists() {
|
||||
return resp
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func emitEvent(event string, payload string) string {
|
||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||
}
|
||||
@@ -56,18 +92,37 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncDone: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
st := (*param).(*geminiToResponsesState)
|
||||
if st.FuncArgsBuf == nil {
|
||||
st.FuncArgsBuf = make(map[int]*strings.Builder)
|
||||
}
|
||||
if st.FuncNames == nil {
|
||||
st.FuncNames = make(map[int]string)
|
||||
}
|
||||
if st.FuncCallIDs == nil {
|
||||
st.FuncCallIDs = make(map[int]string)
|
||||
}
|
||||
if st.FuncDone == nil {
|
||||
st.FuncDone = make(map[int]bool)
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
}
|
||||
|
||||
rawJSON = bytes.TrimSpace(rawJSON)
|
||||
if len(rawJSON) == 0 || bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
if !root.Exists() {
|
||||
return []string{}
|
||||
}
|
||||
root = unwrapGeminiResponseRoot(root)
|
||||
|
||||
var out []string
|
||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||
@@ -98,19 +153,54 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", st.ReasoningItemID)
|
||||
itemDone, _ = sjson.Set(itemDone, "output_index", st.ReasoningIndex)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.encrypted_content", st.ReasoningEnc)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.summary.0.text", full)
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
|
||||
st.ReasoningClosed = true
|
||||
}
|
||||
|
||||
// Helper to finalize the assistant message in correct order.
|
||||
// It emits response.output_text.done, response.content_part.done,
|
||||
// and response.output_item.done exactly once.
|
||||
finalizeMessage := func() {
|
||||
if !st.MsgOpened || st.MsgClosed {
|
||||
return
|
||||
}
|
||||
fullText := st.ItemTextBuf.String()
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
||||
done, _ = sjson.Set(done, "text", fullText)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
|
||||
st.MsgClosed = true
|
||||
}
|
||||
|
||||
// Initialize per-response fields and emit created/in_progress once
|
||||
if !st.Started {
|
||||
if v := root.Get("responseId"); v.Exists() {
|
||||
st.ResponseID = v.String()
|
||||
st.ResponseID = root.Get("responseId").String()
|
||||
if st.ResponseID == "" {
|
||||
st.ResponseID = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
|
||||
}
|
||||
if !strings.HasPrefix(st.ResponseID, "resp_") {
|
||||
st.ResponseID = fmt.Sprintf("resp_%s", st.ResponseID)
|
||||
}
|
||||
if v := root.Get("createTime"); v.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
|
||||
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
|
||||
st.CreatedAt = t.Unix()
|
||||
}
|
||||
}
|
||||
@@ -143,15 +233,21 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// Ignore any late thought chunks after reasoning is finalized.
|
||||
return true
|
||||
}
|
||||
if sig := part.Get("thoughtSignature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
|
||||
st.ReasoningEnc = sig.String()
|
||||
} else if sig = part.Get("thought_signature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
|
||||
st.ReasoningEnc = sig.String()
|
||||
}
|
||||
if !st.ReasoningOpened {
|
||||
st.ReasoningOpened = true
|
||||
st.ReasoningIndex = st.NextIndex
|
||||
st.NextIndex++
|
||||
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, st.ReasoningIndex)
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","encrypted_content":"","summary":[]}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "output_index", st.ReasoningIndex)
|
||||
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "item.encrypted_content", st.ReasoningEnc)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
partAdded := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
||||
partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq())
|
||||
@@ -191,9 +287,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
|
||||
out = append(out, emitEvent("response.content_part.added", partAdded))
|
||||
st.ItemTextBuf.Reset()
|
||||
st.ItemTextBuf.WriteString(t.String())
|
||||
}
|
||||
st.TextBuf.WriteString(t.String())
|
||||
st.ItemTextBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
|
||||
@@ -205,8 +301,10 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
|
||||
// Function call
|
||||
if fc := part.Get("functionCall"); fc.Exists() {
|
||||
// Before emitting function-call outputs, finalize reasoning if open.
|
||||
// Before emitting function-call outputs, finalize reasoning and the message (if open).
|
||||
// Responses streaming requires message done events before the next output_item.added.
|
||||
finalizeReasoning()
|
||||
finalizeMessage()
|
||||
name := fc.Get("name").String()
|
||||
idx := st.NextIndex
|
||||
st.NextIndex++
|
||||
@@ -219,6 +317,14 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
st.FuncNames[idx] = name
|
||||
|
||||
argsJSON := "{}"
|
||||
if args := fc.Get("args"); args.Exists() {
|
||||
argsJSON = args.Raw
|
||||
}
|
||||
if st.FuncArgsBuf[idx].Len() == 0 && argsJSON != "" {
|
||||
st.FuncArgsBuf[idx].WriteString(argsJSON)
|
||||
}
|
||||
|
||||
// Emit item.added for function call
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
@@ -228,10 +334,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
item, _ = sjson.Set(item, "item.name", name)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
|
||||
// Emit arguments delta (full args in one chunk)
|
||||
if args := fc.Get("args"); args.Exists() {
|
||||
argsJSON := args.Raw
|
||||
st.FuncArgsBuf[idx].WriteString(argsJSON)
|
||||
// Emit arguments delta (full args in one chunk).
|
||||
// When Gemini omits args, emit "{}" to keep Responses streaming event order consistent.
|
||||
if argsJSON != "" {
|
||||
ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
|
||||
ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
|
||||
ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
@@ -240,6 +345,27 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
out = append(out, emitEvent("response.function_call_arguments.delta", ad))
|
||||
}
|
||||
|
||||
// Gemini emits the full function call payload at once, so we can finalize it immediately.
|
||||
if !st.FuncDone[idx] {
|
||||
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
|
||||
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
|
||||
fcDone, _ = sjson.Set(fcDone, "arguments", argsJSON)
|
||||
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
|
||||
|
||||
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
|
||||
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
itemDone, _ = sjson.Set(itemDone, "item.arguments", argsJSON)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
|
||||
st.FuncDone[idx] = true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -251,28 +377,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
if fr := root.Get("candidates.0.finishReason"); fr.Exists() && fr.String() != "" {
|
||||
// Finalize reasoning first to keep ordering tight with last delta
|
||||
finalizeReasoning()
|
||||
// Close message output if opened
|
||||
if st.MsgOpened {
|
||||
fullText := st.ItemTextBuf.String()
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
||||
done, _ = sjson.Set(done, "text", fullText)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
}
|
||||
finalizeMessage()
|
||||
|
||||
// Close function calls
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
@@ -289,6 +394,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
}
|
||||
for _, idx := range idxs {
|
||||
if st.FuncDone[idx] {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
|
||||
args = b.String()
|
||||
@@ -308,6 +416,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
|
||||
st.FuncDone[idx] = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,8 +429,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
|
||||
|
||||
if requestRawJSON != nil {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
|
||||
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.instructions", v.String())
|
||||
}
|
||||
@@ -383,41 +493,34 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
}
|
||||
|
||||
// Compose outputs in encountered order: reasoning, message, function_calls
|
||||
// Compose outputs in output_index order.
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
if st.ReasoningOpened {
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if st.MsgOpened {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
idxs := make([]int, 0, len(st.FuncArgsBuf))
|
||||
for idx := range st.FuncArgsBuf {
|
||||
idxs = append(idxs, idx)
|
||||
for idx := 0; idx < st.NextIndex; idx++ {
|
||||
if st.ReasoningOpened && idx == st.ReasoningIndex {
|
||||
item := `{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "encrypted_content", st.ReasoningEnc)
|
||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
continue
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
if st.MsgOpened && idx == st.MsgIndex {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
continue
|
||||
}
|
||||
for _, idx := range idxs {
|
||||
args := ""
|
||||
if b := st.FuncArgsBuf[idx]; b != nil {
|
||||
|
||||
if callID, ok := st.FuncCallIDs[idx]; ok && callID != "" {
|
||||
args := "{}"
|
||||
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
|
||||
args = b.String()
|
||||
}
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx])
|
||||
item, _ = sjson.Set(item, "call_id", callID)
|
||||
item, _ = sjson.Set(item, "name", st.FuncNames[idx])
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
@@ -431,8 +534,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
// output tokens
|
||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
|
||||
@@ -460,6 +563,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// ConvertGeminiResponseToOpenAIResponsesNonStream aggregates Gemini response JSON into a single OpenAI Responses JSON object.
|
||||
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
root = unwrapGeminiResponseRoot(root)
|
||||
|
||||
// Base response scaffold
|
||||
resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`
|
||||
@@ -478,15 +582,15 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
// created_at: map from createTime if available
|
||||
createdAt := time.Now().Unix()
|
||||
if v := root.Get("createTime"); v.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
|
||||
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
|
||||
createdAt = t.Unix()
|
||||
}
|
||||
}
|
||||
resp, _ = sjson.Set(resp, "created_at", createdAt)
|
||||
|
||||
// Echo request fields when present; fallback model from response modelVersion
|
||||
if len(requestRawJSON) > 0 {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
|
||||
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
resp, _ = sjson.Set(resp, "instructions", v.String())
|
||||
}
|
||||
@@ -636,8 +740,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", 0)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
// output tokens
|
||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||
resp, _ = sjson.Set(resp, "usage.output_tokens", v.Int())
|
||||
|
||||
@@ -0,0 +1,353 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func parseSSEEvent(t *testing.T, chunk string) (string, gjson.Result) {
|
||||
t.Helper()
|
||||
|
||||
lines := strings.Split(chunk, "\n")
|
||||
if len(lines) < 2 {
|
||||
t.Fatalf("unexpected SSE chunk: %q", chunk)
|
||||
}
|
||||
|
||||
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||
if !gjson.Valid(dataLine) {
|
||||
t.Fatalf("invalid SSE data JSON: %q", dataLine)
|
||||
}
|
||||
return event, gjson.Parse(dataLine)
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_UnwrapAndAggregateText(t *testing.T) {
|
||||
// Vertex-style Gemini stream wraps the actual response payload under "response".
|
||||
// This test ensures we unwrap and that output_text.done contains the full text.
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"让"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"我先"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"了解"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"mcp__serena__list_dir","args":{"recursive":false,"relative_path":"internal"},"id":"toolu_1"}}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":2},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
}
|
||||
|
||||
originalReq := []byte(`{"instructions":"test instructions","model":"gpt-5","max_output_tokens":123}`)
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", originalReq, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var (
|
||||
gotTextDone bool
|
||||
gotMessageDone bool
|
||||
gotResponseDone bool
|
||||
gotFuncDone bool
|
||||
|
||||
textDone string
|
||||
messageText string
|
||||
responseID string
|
||||
instructions string
|
||||
cachedTokens int64
|
||||
|
||||
funcName string
|
||||
funcArgs string
|
||||
|
||||
posTextDone = -1
|
||||
posPartDone = -1
|
||||
posMessageDone = -1
|
||||
posFuncAdded = -1
|
||||
)
|
||||
|
||||
for i, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_text.done":
|
||||
gotTextDone = true
|
||||
if posTextDone == -1 {
|
||||
posTextDone = i
|
||||
}
|
||||
textDone = data.Get("text").String()
|
||||
case "response.content_part.done":
|
||||
if posPartDone == -1 {
|
||||
posPartDone = i
|
||||
}
|
||||
case "response.output_item.done":
|
||||
switch data.Get("item.type").String() {
|
||||
case "message":
|
||||
gotMessageDone = true
|
||||
if posMessageDone == -1 {
|
||||
posMessageDone = i
|
||||
}
|
||||
messageText = data.Get("item.content.0.text").String()
|
||||
case "function_call":
|
||||
gotFuncDone = true
|
||||
funcName = data.Get("item.name").String()
|
||||
funcArgs = data.Get("item.arguments").String()
|
||||
}
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() == "function_call" && posFuncAdded == -1 {
|
||||
posFuncAdded = i
|
||||
}
|
||||
case "response.completed":
|
||||
gotResponseDone = true
|
||||
responseID = data.Get("response.id").String()
|
||||
instructions = data.Get("response.instructions").String()
|
||||
cachedTokens = data.Get("response.usage.input_tokens_details.cached_tokens").Int()
|
||||
}
|
||||
}
|
||||
|
||||
if !gotTextDone {
|
||||
t.Fatalf("missing response.output_text.done event")
|
||||
}
|
||||
if posTextDone == -1 || posPartDone == -1 || posMessageDone == -1 || posFuncAdded == -1 {
|
||||
t.Fatalf("missing ordering events: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
|
||||
}
|
||||
if !(posTextDone < posPartDone && posPartDone < posMessageDone && posMessageDone < posFuncAdded) {
|
||||
t.Fatalf("unexpected message/function ordering: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
|
||||
}
|
||||
if !gotMessageDone {
|
||||
t.Fatalf("missing message response.output_item.done event")
|
||||
}
|
||||
if !gotFuncDone {
|
||||
t.Fatalf("missing function_call response.output_item.done event")
|
||||
}
|
||||
if !gotResponseDone {
|
||||
t.Fatalf("missing response.completed event")
|
||||
}
|
||||
|
||||
if textDone != "让我先了解" {
|
||||
t.Fatalf("unexpected output_text.done text: got %q", textDone)
|
||||
}
|
||||
if messageText != "让我先了解" {
|
||||
t.Fatalf("unexpected message done text: got %q", messageText)
|
||||
}
|
||||
|
||||
if responseID != "resp_req_vrtx_1" {
|
||||
t.Fatalf("unexpected response id: got %q", responseID)
|
||||
}
|
||||
if instructions != "test instructions" {
|
||||
t.Fatalf("unexpected instructions echo: got %q", instructions)
|
||||
}
|
||||
if cachedTokens != 2 {
|
||||
t.Fatalf("unexpected cached token count: got %d", cachedTokens)
|
||||
}
|
||||
|
||||
if funcName != "mcp__serena__list_dir" {
|
||||
t.Fatalf("unexpected function name: got %q", funcName)
|
||||
}
|
||||
if !gjson.Valid(funcArgs) {
|
||||
t.Fatalf("invalid function arguments JSON: %q", funcArgs)
|
||||
}
|
||||
if gjson.Get(funcArgs, "recursive").Bool() != false {
|
||||
t.Fatalf("unexpected recursive arg: %v", gjson.Get(funcArgs, "recursive").Value())
|
||||
}
|
||||
if gjson.Get(funcArgs, "relative_path").String() != "internal" {
|
||||
t.Fatalf("unexpected relative_path arg: %q", gjson.Get(funcArgs, "relative_path").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_ReasoningEncryptedContent(t *testing.T) {
|
||||
sig := "RXE0RENrZ0lDeEFDR0FJcVFOZDdjUzlleGFuRktRdFcvSzNyZ2MvWDNCcDQ4RmxSbGxOWUlOVU5kR1l1UHMrMGdkMVp0Vkg3ekdKU0g4YVljc2JjN3lNK0FrdGpTNUdqamI4T3Z0VVNETzdQd3pmcFhUOGl3U3hXUEJvTVFRQ09mWTFyMEtTWGZxUUlJakFqdmFGWk83RW1XRlBKckJVOVpkYzdDKw=="
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"thoughtSignature":"` + sig + `","text":""}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"a"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hello"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
}
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var (
|
||||
addedEnc string
|
||||
doneEnc string
|
||||
)
|
||||
for _, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() == "reasoning" {
|
||||
addedEnc = data.Get("item.encrypted_content").String()
|
||||
}
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() == "reasoning" {
|
||||
doneEnc = data.Get("item.encrypted_content").String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addedEnc != sig {
|
||||
t.Fatalf("unexpected encrypted_content in response.output_item.added: got %q", addedEnc)
|
||||
}
|
||||
if doneEnc != sig {
|
||||
t.Fatalf("unexpected encrypted_content in response.output_item.done: got %q", doneEnc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_FunctionCallEventOrder(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool1"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool2","args":{"a":1}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
}
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
posAdded := []int{-1, -1, -1}
|
||||
posArgsDelta := []int{-1, -1, -1}
|
||||
posArgsDone := []int{-1, -1, -1}
|
||||
posItemDone := []int{-1, -1, -1}
|
||||
posCompleted := -1
|
||||
deltaByIndex := map[int]string{}
|
||||
|
||||
for i, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posAdded) {
|
||||
posAdded[idx] = i
|
||||
}
|
||||
case "response.function_call_arguments.delta":
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posArgsDelta) {
|
||||
posArgsDelta[idx] = i
|
||||
deltaByIndex[idx] = data.Get("delta").String()
|
||||
}
|
||||
case "response.function_call_arguments.done":
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posArgsDone) {
|
||||
posArgsDone[idx] = i
|
||||
}
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posItemDone) {
|
||||
posItemDone[idx] = i
|
||||
}
|
||||
case "response.completed":
|
||||
posCompleted = i
|
||||
|
||||
output := data.Get("response.output")
|
||||
if !output.Exists() || !output.IsArray() {
|
||||
t.Fatalf("missing response.output in response.completed")
|
||||
}
|
||||
if len(output.Array()) != 3 {
|
||||
t.Fatalf("unexpected response.output length: got %d", len(output.Array()))
|
||||
}
|
||||
if data.Get("response.output.0.name").String() != "tool0" || data.Get("response.output.0.arguments").String() != "{}" {
|
||||
t.Fatalf("unexpected output[0]: %s", data.Get("response.output.0").Raw)
|
||||
}
|
||||
if data.Get("response.output.1.name").String() != "tool1" || data.Get("response.output.1.arguments").String() != "{}" {
|
||||
t.Fatalf("unexpected output[1]: %s", data.Get("response.output.1").Raw)
|
||||
}
|
||||
if data.Get("response.output.2.name").String() != "tool2" {
|
||||
t.Fatalf("unexpected output[2] name: %s", data.Get("response.output.2").Raw)
|
||||
}
|
||||
if !gjson.Valid(data.Get("response.output.2.arguments").String()) {
|
||||
t.Fatalf("unexpected output[2] arguments: %q", data.Get("response.output.2.arguments").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if posCompleted == -1 {
|
||||
t.Fatalf("missing response.completed event")
|
||||
}
|
||||
for idx := 0; idx < 3; idx++ {
|
||||
if posAdded[idx] == -1 || posArgsDelta[idx] == -1 || posArgsDone[idx] == -1 || posItemDone[idx] == -1 {
|
||||
t.Fatalf("missing function call events for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
|
||||
}
|
||||
if !(posAdded[idx] < posArgsDelta[idx] && posArgsDelta[idx] < posArgsDone[idx] && posArgsDone[idx] < posItemDone[idx]) {
|
||||
t.Fatalf("unexpected ordering for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
|
||||
}
|
||||
if idx > 0 && !(posItemDone[idx-1] < posAdded[idx]) {
|
||||
t.Fatalf("function call events overlap between %d and %d: prevDone=%d nextAdded=%d", idx-1, idx, posItemDone[idx-1], posAdded[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if deltaByIndex[0] != "{}" {
|
||||
t.Fatalf("unexpected delta for output_index 0: got %q", deltaByIndex[0])
|
||||
}
|
||||
if deltaByIndex[1] != "{}" {
|
||||
t.Fatalf("unexpected delta for output_index 1: got %q", deltaByIndex[1])
|
||||
}
|
||||
if deltaByIndex[2] == "" || !gjson.Valid(deltaByIndex[2]) || gjson.Get(deltaByIndex[2], "a").Int() != 1 {
|
||||
t.Fatalf("unexpected delta for output_index 2: got %q", deltaByIndex[2])
|
||||
}
|
||||
if !(posItemDone[2] < posCompleted) {
|
||||
t.Fatalf("response.completed should be after last output_item.done: last=%d completed=%d", posItemDone[2], posCompleted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_ResponseOutputOrdering(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0","args":{"x":"y"}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hi"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||
}
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
posFuncDone := -1
|
||||
posMsgAdded := -1
|
||||
posCompleted := -1
|
||||
|
||||
for i, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() == "function_call" && data.Get("output_index").Int() == 0 {
|
||||
posFuncDone = i
|
||||
}
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() == "message" && data.Get("output_index").Int() == 1 {
|
||||
posMsgAdded = i
|
||||
}
|
||||
case "response.completed":
|
||||
posCompleted = i
|
||||
if data.Get("response.output.0.type").String() != "function_call" {
|
||||
t.Fatalf("expected response.output[0] to be function_call: %s", data.Get("response.output.0").Raw)
|
||||
}
|
||||
if data.Get("response.output.1.type").String() != "message" {
|
||||
t.Fatalf("expected response.output[1] to be message: %s", data.Get("response.output.1").Raw)
|
||||
}
|
||||
if data.Get("response.output.1.content.0.text").String() != "hi" {
|
||||
t.Fatalf("unexpected message text in response.output[1]: %s", data.Get("response.output.1").Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if posFuncDone == -1 || posMsgAdded == -1 || posCompleted == -1 {
|
||||
t.Fatalf("missing required events: funcDone=%d msgAdded=%d completed=%d", posFuncDone, posMsgAdded, posCompleted)
|
||||
}
|
||||
if !(posFuncDone < posMsgAdded) {
|
||||
t.Fatalf("expected function_call to complete before message is added: funcDone=%d msgAdded=%d", posFuncDone, posMsgAdded)
|
||||
}
|
||||
if !(posMsgAdded < posCompleted) {
|
||||
t.Fatalf("expected response.completed after message added: msgAdded=%d completed=%d", posMsgAdded, posCompleted)
|
||||
}
|
||||
}
|
||||
@@ -88,7 +88,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
var messagesJSON = "[]"
|
||||
|
||||
// Handle system message first
|
||||
systemMsgJSON := `{"role":"system","content":[{"type":"text","text":"Use ANY tool, the parameters MUST accord with RFC 8259 (The JavaScript Object Notation (JSON) Data Interchange Format), the keys and value MUST be enclosed in double quotes."}]}`
|
||||
systemMsgJSON := `{"role":"system","content":[]}`
|
||||
if system := root.Get("system"); system.Exists() {
|
||||
if system.Type == gjson.String {
|
||||
if system.String() != "" {
|
||||
|
||||
@@ -289,21 +289,17 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Only process if usage has actual values (not null)
|
||||
if param.FinishReason != "" {
|
||||
usage := root.Get("usage")
|
||||
var inputTokens, outputTokens int64
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
if usage.Exists() && usage.Type != gjson.Null {
|
||||
// Check if usage has actual token counts
|
||||
promptTokens := usage.Get("prompt_tokens")
|
||||
completionTokens := usage.Get("completion_tokens")
|
||||
|
||||
if promptTokens.Exists() && completionTokens.Exists() {
|
||||
inputTokens = promptTokens.Int()
|
||||
outputTokens = completionTokens.Int()
|
||||
}
|
||||
inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage)
|
||||
// Send message_delta with usage
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
||||
param.MessageDeltaSent = true
|
||||
|
||||
@@ -423,13 +419,12 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
|
||||
|
||||
// Set usage information
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", usage.Get("prompt_tokens").Int())
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", usage.Get("completion_tokens").Int())
|
||||
reasoningTokens := int64(0)
|
||||
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
|
||||
reasoningTokens = v.Int()
|
||||
inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(usage)
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
out, _ = sjson.Set(out, "usage.reasoning_tokens", reasoningTokens)
|
||||
}
|
||||
|
||||
return []string{out}
|
||||
@@ -674,8 +669,12 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
|
||||
if respUsage := root.Get("usage"); respUsage.Exists() {
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", respUsage.Get("prompt_tokens").Int())
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", respUsage.Get("completion_tokens").Int())
|
||||
inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(respUsage)
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if !stopReasonSet {
|
||||
@@ -692,3 +691,23 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"input_tokens":%d}`, count)
|
||||
}
|
||||
|
||||
func extractOpenAIUsage(usage gjson.Result) (int64, int64, int64) {
|
||||
if !usage.Exists() || usage.Type == gjson.Null {
|
||||
return 0, 0, 0
|
||||
}
|
||||
|
||||
inputTokens := usage.Get("prompt_tokens").Int()
|
||||
outputTokens := usage.Get("completion_tokens").Int()
|
||||
cachedTokens := usage.Get("prompt_tokens_details.cached_tokens").Int()
|
||||
|
||||
if cachedTokens > 0 {
|
||||
if inputTokens >= cachedTokens {
|
||||
inputTokens -= cachedTokens
|
||||
} else {
|
||||
inputTokens = 0
|
||||
}
|
||||
}
|
||||
|
||||
return inputTokens, outputTokens, cachedTokens
|
||||
}
|
||||
|
||||
@@ -77,7 +77,13 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
}
|
||||
|
||||
// Candidate count (OpenAI 'n' parameter)
|
||||
if candidateCount := genConfig.Get("candidateCount"); candidateCount.Exists() {
|
||||
out, _ = sjson.Set(out, "n", candidateCount.Int())
|
||||
}
|
||||
|
||||
// Map Gemini thinkingConfig to OpenAI reasoning_effort.
|
||||
// Always perform conversion to support allowCompat models that may not be in registry
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
|
||||
@@ -2,6 +2,8 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -191,7 +193,19 @@ waitForCallback:
|
||||
return nil, fmt.Errorf("codex token storage missing account information")
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("codex-%s.json", tokenStorage.Email)
|
||||
planType := ""
|
||||
hashAccountID := ""
|
||||
if tokenStorage.IDToken != "" {
|
||||
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
||||
if accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
}
|
||||
}
|
||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||
metadata := map[string]any{
|
||||
"email": tokenStorage.Email,
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -76,7 +75,7 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
||||
if existing, errRead := os.ReadFile(path); errRead == nil {
|
||||
// Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change.
|
||||
// This prevents the token refresh loop caused by timestamp/expired/expires_in changes.
|
||||
if metadataEqualIgnoringTimestamps(existing, raw) {
|
||||
if metadataEqualIgnoringTimestamps(existing, raw, auth.Provider) {
|
||||
return path, nil
|
||||
}
|
||||
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
|
||||
@@ -300,28 +299,101 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
||||
return s.baseDir
|
||||
}
|
||||
|
||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs, ignoring volatile fields that
|
||||
// change on every refresh but don't affect authentication logic.
|
||||
func metadataEqualIgnoringTimestamps(a, b []byte) bool {
|
||||
var objA map[string]any
|
||||
var objB map[string]any
|
||||
if errUnmarshalA := json.Unmarshal(a, &objA); errUnmarshalA != nil {
|
||||
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata.
|
||||
// This function is kept for backward compatibility but can cause refresh loops.
|
||||
func jsonEqual(a, b []byte) bool {
|
||||
var objA any
|
||||
var objB any
|
||||
if err := json.Unmarshal(a, &objA); err != nil {
|
||||
return false
|
||||
}
|
||||
if errUnmarshalB := json.Unmarshal(b, &objB); errUnmarshalB != nil {
|
||||
if err := json.Unmarshal(b, &objB); err != nil {
|
||||
return false
|
||||
}
|
||||
stripVolatileMetadataFields(objA)
|
||||
stripVolatileMetadataFields(objB)
|
||||
return reflect.DeepEqual(objA, objB)
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
func stripVolatileMetadataFields(metadata map[string]any) {
|
||||
if metadata == nil {
|
||||
return
|
||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
|
||||
// ignoring fields that change on every refresh but don't affect functionality.
|
||||
// This prevents unnecessary file writes that would trigger watcher events and
|
||||
// create refresh loops.
|
||||
// The provider parameter controls whether access_token is ignored: providers like
|
||||
// Google OAuth (gemini, gemini-cli) can re-fetch tokens when needed, while others
|
||||
// like iFlow require the refreshed token to be persisted.
|
||||
func metadataEqualIgnoringTimestamps(a, b []byte, provider string) bool {
|
||||
var objA, objB map[string]any
|
||||
if err := json.Unmarshal(a, &objA); err != nil {
|
||||
return false
|
||||
}
|
||||
// These fields change on refresh and would otherwise trigger watcher reload loops.
|
||||
for _, field := range []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"} {
|
||||
delete(metadata, field)
|
||||
if err := json.Unmarshal(b, &objB); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fields to ignore: these change on every refresh but don't affect authentication logic.
|
||||
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
|
||||
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh"}
|
||||
|
||||
// For providers that can re-fetch tokens when needed (e.g., Google OAuth),
|
||||
// we ignore access_token to avoid unnecessary file writes.
|
||||
switch provider {
|
||||
case "gemini", "gemini-cli", "antigravity":
|
||||
ignoredFields = append(ignoredFields, "access_token")
|
||||
}
|
||||
|
||||
for _, field := range ignoredFields {
|
||||
delete(objA, field)
|
||||
delete(objB, field)
|
||||
}
|
||||
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
func deepEqualJSON(a, b any) bool {
|
||||
switch valA := a.(type) {
|
||||
case map[string]any:
|
||||
valB, ok := b.(map[string]any)
|
||||
if !ok || len(valA) != len(valB) {
|
||||
return false
|
||||
}
|
||||
for key, subA := range valA {
|
||||
subB, ok1 := valB[key]
|
||||
if !ok1 || !deepEqualJSON(subA, subB) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case []any:
|
||||
sliceB, ok := b.([]any)
|
||||
if !ok || len(valA) != len(sliceB) {
|
||||
return false
|
||||
}
|
||||
for i := range valA {
|
||||
if !deepEqualJSON(valA[i], sliceB[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case float64:
|
||||
valB, ok := b.(float64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case string:
|
||||
valB, ok := b.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case bool:
|
||||
valB, ok := b.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case nil:
|
||||
return b == nil
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// thinkingTestCase represents a common test case structure for both suffix and body tests.
|
||||
@@ -2707,8 +2708,11 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) {
|
||||
[]byte(tc.inputJSON),
|
||||
true,
|
||||
)
|
||||
if applyTo == "claude" {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", 200000)
|
||||
}
|
||||
|
||||
body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo)
|
||||
body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo, applyTo)
|
||||
|
||||
if tc.expectErr {
|
||||
if err == nil {
|
||||
|
||||
Reference in New Issue
Block a user