mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-12 17:30:51 +08:00
Compare commits
44 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c32e2a8196 | ||
|
|
873d41582f | ||
|
|
6fb7d85558 | ||
|
|
d5e3e32d58 | ||
|
|
f353a54555 | ||
|
|
1d6e2e751d | ||
|
|
cc50b63422 | ||
|
|
15ae83a15b | ||
|
|
81b369aed9 | ||
|
|
ecc850bfb7 | ||
|
|
19b4ef33e0 | ||
|
|
7ca045d8b9 | ||
|
|
abfca6aab2 | ||
|
|
3c71c075db | ||
|
|
9c2992bfb2 | ||
|
|
269a1c5452 | ||
|
|
22ce65ac72 | ||
|
|
a2f8f59192 | ||
|
|
30a59168d7 | ||
|
|
c8884f5e25 | ||
|
|
d9c6317c84 | ||
|
|
d29ec95526 | ||
|
|
ef4508dbc8 | ||
|
|
f775e46fe2 | ||
|
|
65ad5c0c9d | ||
|
|
88bf4e77ec | ||
|
|
a4f8015caa | ||
|
|
ffd129909e | ||
|
|
9332316383 | ||
|
|
6dcbbf64c3 | ||
|
|
2ce3553612 | ||
|
|
2e14f787d4 | ||
|
|
523b41ccd2 | ||
|
|
09970dc7af | ||
|
|
d81abd401c | ||
|
|
a6cba25bc1 | ||
|
|
c6fa1d0e67 | ||
|
|
ac56e1e88b | ||
|
|
9b72ea9efa | ||
|
|
9f364441e8 | ||
|
|
e49a1c07bf | ||
|
|
8d9f4edf9b | ||
|
|
2f6004d74a | ||
|
|
a1634909e8 |
2
go.mod
2
go.mod
@@ -21,6 +21,7 @@ require (
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/text v0.31.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -70,7 +71,6 @@ require (
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
@@ -749,6 +749,72 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
return err
|
||||
}
|
||||
|
||||
// PatchAuthFileStatus toggles the disabled state of an auth file
|
||||
func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
|
||||
if h.authManager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Disabled *bool `json:"disabled"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(req.Name)
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
if req.Disabled == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "disabled is required"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Find auth by name or ID
|
||||
var targetAuth *coreauth.Auth
|
||||
if auth, ok := h.authManager.GetByID(name); ok {
|
||||
targetAuth = auth
|
||||
} else {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name {
|
||||
targetAuth = auth
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if targetAuth == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update disabled state
|
||||
targetAuth.Disabled = *req.Disabled
|
||||
if *req.Disabled {
|
||||
targetAuth.Status = coreauth.StatusDisabled
|
||||
targetAuth.StatusMessage = "disabled via management API"
|
||||
} else {
|
||||
targetAuth.Status = coreauth.StatusActive
|
||||
targetAuth.StatusMessage = ""
|
||||
}
|
||||
targetAuth.UpdatedAt = time.Now()
|
||||
|
||||
if _, err := h.authManager.Update(ctx, targetAuth); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": *req.Disabled})
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
|
||||
@@ -610,6 +610,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
|
||||
@@ -4,9 +4,6 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
|
||||
@@ -43,15 +40,7 @@ func normalizePlanTypeForFilename(planType string) string {
|
||||
}
|
||||
|
||||
for i, part := range parts {
|
||||
parts[i] = titleToken(part)
|
||||
parts[i] = strings.ToLower(strings.TrimSpace(part))
|
||||
}
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
|
||||
func titleToken(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
return cases.Title(language.English).String(token)
|
||||
}
|
||||
|
||||
115
internal/cache/signature_cache.go
vendored
115
internal/cache/signature_cache.go
vendored
@@ -3,7 +3,7 @@ package cache
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -24,18 +24,18 @@ const (
|
||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||
MinValidSignatureLen = 50
|
||||
|
||||
// SessionCleanupInterval controls how often stale sessions are purged
|
||||
SessionCleanupInterval = 10 * time.Minute
|
||||
// CacheCleanupInterval controls how often stale entries are purged
|
||||
CacheCleanupInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
|
||||
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
|
||||
var signatureCache sync.Map
|
||||
|
||||
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
|
||||
var sessionCleanupOnce sync.Once
|
||||
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
|
||||
var cacheCleanupOnce sync.Once
|
||||
|
||||
// sessionCache is the inner map type
|
||||
type sessionCache struct {
|
||||
// groupCache is the inner map type
|
||||
type groupCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]SignatureEntry
|
||||
}
|
||||
@@ -46,36 +46,36 @@ func hashText(text string) string {
|
||||
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
||||
}
|
||||
|
||||
// getOrCreateSession gets or creates a session cache
|
||||
func getOrCreateSession(sessionID string) *sessionCache {
|
||||
// getOrCreateGroupCache gets or creates a cache bucket for a model group
|
||||
func getOrCreateGroupCache(groupKey string) *groupCache {
|
||||
// Start background cleanup on first access
|
||||
sessionCleanupOnce.Do(startSessionCleanup)
|
||||
cacheCleanupOnce.Do(startCacheCleanup)
|
||||
|
||||
if val, ok := signatureCache.Load(sessionID); ok {
|
||||
return val.(*sessionCache)
|
||||
if val, ok := signatureCache.Load(groupKey); ok {
|
||||
return val.(*groupCache)
|
||||
}
|
||||
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
|
||||
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
|
||||
return actual.(*sessionCache)
|
||||
sc := &groupCache{entries: make(map[string]SignatureEntry)}
|
||||
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
|
||||
return actual.(*groupCache)
|
||||
}
|
||||
|
||||
// startSessionCleanup launches a background goroutine that periodically
|
||||
// removes sessions where all entries have expired.
|
||||
func startSessionCleanup() {
|
||||
// startCacheCleanup launches a background goroutine that periodically
|
||||
// removes caches where all entries have expired.
|
||||
func startCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(SessionCleanupInterval)
|
||||
ticker := time.NewTicker(CacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
purgeExpiredSessions()
|
||||
purgeExpiredCaches()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// purgeExpiredSessions removes sessions with no valid (non-expired) entries.
|
||||
func purgeExpiredSessions() {
|
||||
// purgeExpiredCaches removes caches with no valid (non-expired) entries.
|
||||
func purgeExpiredCaches() {
|
||||
now := time.Now()
|
||||
signatureCache.Range(func(key, value any) bool {
|
||||
sc := value.(*sessionCache)
|
||||
sc := value.(*groupCache)
|
||||
sc.mu.Lock()
|
||||
// Remove expired entries
|
||||
for k, entry := range sc.entries {
|
||||
@@ -85,7 +85,7 @@ func purgeExpiredSessions() {
|
||||
}
|
||||
isEmpty := len(sc.entries) == 0
|
||||
sc.mu.Unlock()
|
||||
// Remove session if empty
|
||||
// Remove cache bucket if empty
|
||||
if isEmpty {
|
||||
signatureCache.Delete(key)
|
||||
}
|
||||
@@ -93,19 +93,19 @@ func purgeExpiredSessions() {
|
||||
})
|
||||
}
|
||||
|
||||
// CacheSignature stores a thinking signature for a given session and text.
|
||||
// CacheSignature stores a thinking signature for a given model group and text.
|
||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||
func CacheSignature(modelName, sessionID, text, signature string) {
|
||||
if sessionID == "" || text == "" || signature == "" {
|
||||
func CacheSignature(modelName, text, signature string) {
|
||||
if text == "" || signature == "" {
|
||||
return
|
||||
}
|
||||
if len(signature) < MinValidSignatureLen {
|
||||
return
|
||||
}
|
||||
|
||||
sc := getOrCreateSession(fmt.Sprintf("%s#%s", modelName, sessionID))
|
||||
groupKey := GetModelGroup(modelName)
|
||||
textHash := hashText(text)
|
||||
|
||||
sc := getOrCreateGroupCache(groupKey)
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
@@ -115,18 +115,25 @@ func CacheSignature(modelName, sessionID, text, signature string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetCachedSignature retrieves a cached signature for a given session and text.
|
||||
// GetCachedSignature retrieves a cached signature for a given model group and text.
|
||||
// Returns empty string if not found or expired.
|
||||
func GetCachedSignature(modelName, sessionID, text string) string {
|
||||
if sessionID == "" || text == "" {
|
||||
return ""
|
||||
}
|
||||
func GetCachedSignature(modelName, text string) string {
|
||||
groupKey := GetModelGroup(modelName)
|
||||
|
||||
val, ok := signatureCache.Load(fmt.Sprintf("%s#%s", modelName, sessionID))
|
||||
if !ok {
|
||||
if text == "" {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
sc := val.(*sessionCache)
|
||||
val, ok := signatureCache.Load(groupKey)
|
||||
if !ok {
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
sc := val.(*groupCache)
|
||||
|
||||
textHash := hashText(text)
|
||||
|
||||
@@ -136,11 +143,17 @@ func GetCachedSignature(modelName, sessionID, text string) string {
|
||||
entry, exists := sc.entries[textHash]
|
||||
if !exists {
|
||||
sc.mu.Unlock()
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, textHash)
|
||||
sc.mu.Unlock()
|
||||
if groupKey == "gemini" {
|
||||
return "skip_thought_signature_validator"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -152,19 +165,31 @@ func GetCachedSignature(modelName, sessionID, text string) string {
|
||||
return entry.Signature
|
||||
}
|
||||
|
||||
// ClearSignatureCache clears signature cache for a specific session or all sessions.
|
||||
func ClearSignatureCache(sessionID string) {
|
||||
if sessionID != "" {
|
||||
signatureCache.Delete(sessionID)
|
||||
} else {
|
||||
// ClearSignatureCache clears signature cache for a specific model group or all groups.
|
||||
func ClearSignatureCache(modelName string) {
|
||||
if modelName == "" {
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
signatureCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
groupKey := GetModelGroup(modelName)
|
||||
signatureCache.Delete(groupKey)
|
||||
}
|
||||
|
||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||
func HasValidSignature(signature string) bool {
|
||||
return signature != "" && len(signature) >= MinValidSignatureLen
|
||||
func HasValidSignature(modelName, signature string) bool {
|
||||
return (signature != "" && len(signature) >= MinValidSignatureLen) || (signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini")
|
||||
}
|
||||
|
||||
func GetModelGroup(modelName string) string {
|
||||
if strings.Contains(modelName, "gpt") {
|
||||
return "gpt"
|
||||
} else if strings.Contains(modelName, "claude") {
|
||||
return "claude"
|
||||
} else if strings.Contains(modelName, "gemini") {
|
||||
return "gemini"
|
||||
}
|
||||
return modelName
|
||||
}
|
||||
|
||||
110
internal/cache/signature_cache_test.go
vendored
110
internal/cache/signature_cache_test.go
vendored
@@ -5,38 +5,40 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const testModelName = "claude-sonnet-4-5"
|
||||
|
||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-session-1"
|
||||
text := "This is some thinking text content"
|
||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
|
||||
// Store signature
|
||||
CacheSignature(sessionID, text, signature)
|
||||
CacheSignature(testModelName, text, signature)
|
||||
|
||||
// Retrieve signature
|
||||
retrieved := GetCachedSignature(sessionID, text)
|
||||
retrieved := GetCachedSignature(testModelName, text)
|
||||
if retrieved != signature {
|
||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_DifferentSessions(t *testing.T) {
|
||||
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "Same text in different sessions"
|
||||
text := "Same text across models"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature("session-a", text, sig1)
|
||||
CacheSignature("session-b", text, sig2)
|
||||
geminiModel := "gemini-3-pro-preview"
|
||||
CacheSignature(testModelName, text, sig1)
|
||||
CacheSignature(geminiModel, text, sig2)
|
||||
|
||||
if GetCachedSignature("session-a", text) != sig1 {
|
||||
t.Error("Session-a signature mismatch")
|
||||
if GetCachedSignature(testModelName, text) != sig1 {
|
||||
t.Error("Claude signature mismatch")
|
||||
}
|
||||
if GetCachedSignature("session-b", text) != sig2 {
|
||||
t.Error("Session-b signature mismatch")
|
||||
if GetCachedSignature(geminiModel, text) != sig2 {
|
||||
t.Error("Gemini signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,13 +46,13 @@ func TestCacheSignature_NotFound(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Non-existent session
|
||||
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
|
||||
if got := GetCachedSignature(testModelName, "some text"); got != "" {
|
||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||
}
|
||||
|
||||
// Existing session but different text
|
||||
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature("session-x", "text-b"); got != "" {
|
||||
CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature(testModelName, "text-b"); got != "" {
|
||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -59,12 +61,11 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// All empty/invalid inputs should be no-ops
|
||||
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "text", "")
|
||||
CacheSignature("session", "text", "short") // Too short
|
||||
CacheSignature(testModelName, "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature(testModelName, "text", "")
|
||||
CacheSignature(testModelName, "text", "short") // Too short
|
||||
|
||||
if got := GetCachedSignature("session", "text"); got != "" {
|
||||
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -72,31 +73,27 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-short-sig"
|
||||
text := "Some text"
|
||||
shortSig := "abc123" // Less than 50 chars
|
||||
|
||||
CacheSignature(sessionID, text, shortSig)
|
||||
CacheSignature(testModelName, text, shortSig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != "" {
|
||||
if got := GetCachedSignature(testModelName, text); got != "" {
|
||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_SpecificSession(t *testing.T) {
|
||||
func TestClearSignatureCache_ModelGroup(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
CacheSignature(testModelName, "text", sig)
|
||||
CacheSignature(testModelName, "text-2", sig)
|
||||
|
||||
ClearSignatureCache("session-1")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != sig {
|
||||
t.Error("session-2 should still exist")
|
||||
if got := GetCachedSignature(testModelName, "text"); got != sig {
|
||||
t.Error("signature should remain when clearing unknown session")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,35 +101,37 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
CacheSignature(testModelName, "text", sig)
|
||||
CacheSignature(testModelName, "text-2", sig)
|
||||
|
||||
ClearSignatureCache("")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
if got := GetCachedSignature(testModelName, "text"); got != "" {
|
||||
t.Error("text should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != "" {
|
||||
t.Error("session-2 should be cleared")
|
||||
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
|
||||
t.Error("text-2 should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasValidSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
signature string
|
||||
expected bool
|
||||
}{
|
||||
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
|
||||
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
|
||||
{"empty string", "", false},
|
||||
{"short signature", "abc", false},
|
||||
{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||
{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
|
||||
{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
|
||||
{"empty string", testModelName, "", false},
|
||||
{"short signature", testModelName, "abc", false},
|
||||
{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasValidSignature(tt.signature)
|
||||
result := HasValidSignature(tt.modelName, tt.signature)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||
}
|
||||
@@ -143,21 +142,19 @@ func TestHasValidSignature(t *testing.T) {
|
||||
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "hash-test-session"
|
||||
|
||||
// Different texts should produce different hashes
|
||||
text1 := "First thinking text"
|
||||
text2 := "Second thinking text"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text1, sig1)
|
||||
CacheSignature(sessionID, text2, sig2)
|
||||
CacheSignature(testModelName, text1, sig1)
|
||||
CacheSignature(testModelName, text2, sig2)
|
||||
|
||||
if GetCachedSignature(sessionID, text1) != sig1 {
|
||||
if GetCachedSignature(testModelName, text1) != sig1 {
|
||||
t.Error("text1 signature mismatch")
|
||||
}
|
||||
if GetCachedSignature(sessionID, text2) != sig2 {
|
||||
if GetCachedSignature(testModelName, text2) != sig2 {
|
||||
t.Error("text2 signature mismatch")
|
||||
}
|
||||
}
|
||||
@@ -165,13 +162,12 @@ func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "unicode-session"
|
||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
CacheSignature(testModelName, text, sig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||
}
|
||||
}
|
||||
@@ -179,15 +175,14 @@ func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
func TestCacheSignature_Overwrite(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "overwrite-session"
|
||||
text := "Same text"
|
||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||
|
||||
CacheSignature(sessionID, text, sig1)
|
||||
CacheSignature(sessionID, text, sig2) // Overwrite
|
||||
CacheSignature(testModelName, text, sig1)
|
||||
CacheSignature(testModelName, text, sig2) // Overwrite
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig2 {
|
||||
if got := GetCachedSignature(testModelName, text); got != sig2 {
|
||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||
}
|
||||
}
|
||||
@@ -199,14 +194,13 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
|
||||
// This test verifies the expiration check exists
|
||||
// In a real scenario, we'd mock time.Now()
|
||||
sessionID := "expiration-test"
|
||||
text := "text"
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
CacheSignature(testModelName, text, sig)
|
||||
|
||||
// Fresh entry should be retrievable
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
if got := GetCachedSignature(testModelName, text); got != sig {
|
||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
@@ -112,6 +113,11 @@ func isAIAPIPath(path string) bool {
|
||||
// - gin.HandlerFunc: A middleware handler for panic recovery
|
||||
func GinLogrusRecovery() gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
// Let net/http handle ErrAbortHandler so the connection is aborted without noisy stack logs.
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"panic": recovered,
|
||||
"stack": string(debug.Stack()),
|
||||
|
||||
60
internal/logging/gin_logger_test.go
Normal file
60
internal/logging/gin_logger_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(GinLogrusRecovery())
|
||||
engine.GET("/abort", func(c *gin.Context) {
|
||||
panic(http.ErrAbortHandler)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/abort", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
defer func() {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
t.Fatalf("expected panic, got nil")
|
||||
}
|
||||
err, ok := recovered.(error)
|
||||
if !ok {
|
||||
t.Fatalf("expected error panic, got %T", recovered)
|
||||
}
|
||||
if !errors.Is(err, http.ErrAbortHandler) {
|
||||
t.Fatalf("expected ErrAbortHandler, got %v", err)
|
||||
}
|
||||
if err != http.ErrAbortHandler {
|
||||
t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
engine.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(GinLogrusRecovery())
|
||||
engine.GET("/panic", func(c *gin.Context) {
|
||||
panic("boom")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
engine.ServeHTTP(recorder, req)
|
||||
if recorder.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
@@ -398,7 +398,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
return nil, translatedPayload{}, err
|
||||
}
|
||||
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
||||
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
payload = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||
|
||||
@@ -142,7 +142,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
return resp, err
|
||||
}
|
||||
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -261,7 +262,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
return resp, err
|
||||
}
|
||||
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -627,7 +629,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
return nil, err
|
||||
}
|
||||
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -1202,7 +1205,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||
@@ -1213,7 +1216,17 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||
|
||||
payload = []byte(strJSON)
|
||||
} else {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
|
||||
// without adding empty-schema placeholders.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
|
||||
@@ -1230,6 +1243,12 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
@@ -1405,24 +1424,10 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||
template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), funcDecl.Get("parametersJsonSchema").Raw)
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int()))
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()))
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
} else {
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
||||
if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.toolConfig", toolConfig.Raw)
|
||||
template, _ = sjson.Delete(template, "toolConfig")
|
||||
}
|
||||
|
||||
return []byte(template)
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +114,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
@@ -245,7 +246,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
// based on client type and configuration.
|
||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||
body = disableThinkingIfToolChoiceForced(body)
|
||||
|
||||
@@ -101,7 +101,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
@@ -213,7 +214,8 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
|
||||
@@ -129,7 +129,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
@@ -278,7 +279,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
}
|
||||
|
||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
||||
|
||||
projectID := resolveGeminiProjectID(auth)
|
||||
|
||||
|
||||
@@ -126,7 +126,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := "generateContent"
|
||||
@@ -228,7 +229,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
baseURL := resolveGeminiBaseURL(auth)
|
||||
|
||||
@@ -325,7 +325,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
}
|
||||
|
||||
@@ -438,7 +439,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, false)
|
||||
@@ -541,7 +543,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
@@ -664,7 +667,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
|
||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
|
||||
@@ -98,7 +98,8 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
body = preserveReasoningContentInMessages(body)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
@@ -201,7 +202,8 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
body = ensureToolsArray(body)
|
||||
}
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
|
||||
|
||||
@@ -90,7 +90,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -185,7 +186,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -12,8 +14,9 @@ import (
|
||||
// applyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter
|
||||
// paths as relative to the provided root path (for example, "request" for Gemini CLI)
|
||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||
// against the original payload when provided.
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte) []byte {
|
||||
// against the original payload when provided. requestedModel carries the client-visible
|
||||
// model name before alias resolution so payload rules can target aliases precisely.
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
@@ -22,10 +25,11 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
return payload
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return payload
|
||||
}
|
||||
candidates := payloadModelCandidates(cfg, model, protocol)
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
out := payload
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
@@ -163,65 +167,42 @@ func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) b
|
||||
return false
|
||||
}
|
||||
|
||||
func payloadModelCandidates(cfg *config.Config, model, protocol string) []string {
|
||||
func payloadModelCandidates(model, requestedModel string) []string {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
candidates := []string{model}
|
||||
if cfg == nil {
|
||||
return candidates
|
||||
}
|
||||
aliases := payloadModelAliases(cfg, model, protocol)
|
||||
if len(aliases) == 0 {
|
||||
return candidates
|
||||
}
|
||||
seen := map[string]struct{}{strings.ToLower(model): struct{}{}}
|
||||
for _, alias := range aliases {
|
||||
alias = strings.TrimSpace(alias)
|
||||
if alias == "" {
|
||||
continue
|
||||
candidates := make([]string, 0, 3)
|
||||
seen := make(map[string]struct{}, 3)
|
||||
addCandidate := func(value string) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(alias)
|
||||
key := strings.ToLower(value)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
candidates = append(candidates, alias)
|
||||
candidates = append(candidates, value)
|
||||
}
|
||||
if model != "" {
|
||||
addCandidate(model)
|
||||
}
|
||||
if requestedModel != "" {
|
||||
parsed := thinking.ParseSuffix(requestedModel)
|
||||
base := strings.TrimSpace(parsed.ModelName)
|
||||
if base != "" {
|
||||
addCandidate(base)
|
||||
}
|
||||
if parsed.HasSuffix {
|
||||
addCandidate(requestedModel)
|
||||
}
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func payloadModelAliases(cfg *config.Config, model, protocol string) []string {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return nil
|
||||
}
|
||||
channel := strings.ToLower(strings.TrimSpace(protocol))
|
||||
if channel == "" {
|
||||
return nil
|
||||
}
|
||||
entries := cfg.OAuthModelAlias[channel]
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
aliases := make([]string, 0, 2)
|
||||
for _, entry := range entries {
|
||||
if !strings.EqualFold(strings.TrimSpace(entry.Name), model) {
|
||||
continue
|
||||
}
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if alias == "" {
|
||||
continue
|
||||
}
|
||||
aliases = append(aliases, alias)
|
||||
}
|
||||
return aliases
|
||||
}
|
||||
|
||||
// buildPayloadPath combines an optional root path with a relative parameter path.
|
||||
// When root is empty, the parameter path is used as-is. When root is non-empty,
|
||||
// the parameter path is treated as relative to root.
|
||||
@@ -258,6 +239,35 @@ func payloadRawValue(value any) ([]byte, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||
fallback = strings.TrimSpace(fallback)
|
||||
if len(opts.Metadata) == 0 {
|
||||
return fallback
|
||||
}
|
||||
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return fallback
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return fallback
|
||||
}
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return fallback
|
||||
}
|
||||
trimmed := strings.TrimSpace(string(v))
|
||||
if trimmed == "" {
|
||||
return fallback
|
||||
}
|
||||
return trimmed
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
||||
// Examples:
|
||||
//
|
||||
|
||||
@@ -91,7 +91,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
@@ -184,7 +185,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
|
||||
@@ -7,8 +7,6 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
@@ -19,37 +17,6 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// deriveSessionID generates a stable session ID from the request.
|
||||
// Uses the hash of the first user message to identify the conversation.
|
||||
func deriveSessionID(rawJSON []byte) string {
|
||||
userIDResult := gjson.GetBytes(rawJSON, "metadata.user_id")
|
||||
if userIDResult.Exists() {
|
||||
userID := userIDResult.String()
|
||||
idx := strings.Index(userID, "session_")
|
||||
if idx != -1 {
|
||||
return userID[idx+8:]
|
||||
}
|
||||
}
|
||||
messages := gjson.GetBytes(rawJSON, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "user" {
|
||||
content := msg.Get("content").String()
|
||||
if content == "" {
|
||||
// Try to get text from content array
|
||||
content = msg.Get("content.0.text").String()
|
||||
}
|
||||
if content != "" {
|
||||
h := sha256.Sum256([]byte(content))
|
||||
return hex.EncodeToString(h[:16])
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
|
||||
@@ -72,9 +39,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
enableThoughtTranslate := true
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
// Derive session ID for signature caching
|
||||
sessionID := deriveSessionID(rawJSON)
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
hasSystemInstruction := false
|
||||
@@ -137,8 +101,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Always try cached signature first (more reliable than client-provided)
|
||||
// Client may send stale or invalid signatures from different sessions
|
||||
signature := ""
|
||||
if sessionID != "" && thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(modelName, sessionID, thinkingText); cachedSig != "" {
|
||||
if thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
|
||||
signature = cachedSig
|
||||
// log.Debugf("Using cached signature for thinking block")
|
||||
}
|
||||
@@ -156,19 +120,19 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
}
|
||||
}
|
||||
if cache.HasValidSignature(clientSignature) {
|
||||
if cache.HasValidSignature(modelName, clientSignature) {
|
||||
signature = clientSignature
|
||||
}
|
||||
// log.Debugf("Using client-provided signature for thinking block")
|
||||
}
|
||||
|
||||
// Store for subsequent tool_use in the same message
|
||||
if cache.HasValidSignature(signature) {
|
||||
if cache.HasValidSignature(modelName, signature) {
|
||||
currentMessageThinkingSignature = signature
|
||||
}
|
||||
|
||||
// Skip trailing unsigned thinking blocks on last assistant message
|
||||
isUnsigned := !cache.HasValidSignature(signature)
|
||||
isUnsigned := !cache.HasValidSignature(modelName, signature)
|
||||
|
||||
// If unsigned, skip entirely (don't convert to text)
|
||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||
@@ -223,7 +187,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||
// and also works for Claude through Antigravity API
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
if cache.HasValidSignature(currentMessageThinkingSignature) {
|
||||
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
} else {
|
||||
// No valid signature - use skip sentinel to bypass validation
|
||||
|
||||
@@ -74,13 +74,13 @@ func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Valid signature must be at least 50 characters
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Let me think..."
|
||||
|
||||
// Pre-cache the signature (simulating a response from the same session)
|
||||
// The session ID is derived from the first user message hash
|
||||
// Since there's no user message in this test, we need to add one
|
||||
// Pre-cache the signature (simulating a previous response for the same thinking text)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
@@ -98,10 +98,7 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
]
|
||||
}`)
|
||||
|
||||
// Derive session ID and cache the signature
|
||||
sessionID := deriveSessionID(inputJSON)
|
||||
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||
defer cache.ClearSignatureCache(sessionID)
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
@@ -120,6 +117,8 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
@@ -241,6 +240,8 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Let me think..."
|
||||
|
||||
@@ -266,10 +267,7 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
]
|
||||
}`)
|
||||
|
||||
// Derive session ID and cache the signature
|
||||
sessionID := deriveSessionID(inputJSON)
|
||||
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||
defer cache.ClearSignatureCache(sessionID)
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
@@ -285,6 +283,8 @@ func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Planning..."
|
||||
@@ -306,10 +306,7 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||
]
|
||||
}`)
|
||||
|
||||
// Derive session ID and cache the signature
|
||||
sessionID := deriveSessionID(inputJSON)
|
||||
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||
defer cache.ClearSignatureCache(sessionID)
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
@@ -496,6 +493,8 @@ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *t
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Last assistant message ends with signed thinking block - should be kept
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
thinkingText := "Valid thinking..."
|
||||
@@ -517,10 +516,7 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin
|
||||
]
|
||||
}`)
|
||||
|
||||
// Derive session ID and cache the signature
|
||||
sessionID := deriveSessionID(inputJSON)
|
||||
cache.CacheSignature(sessionID, thinkingText, validSignature)
|
||||
defer cache.ClearSignatureCache(sessionID)
|
||||
cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
@@ -41,7 +41,6 @@ type Params struct {
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
|
||||
// Signature caching support
|
||||
SessionID string // Session ID derived from request for signature caching
|
||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||
}
|
||||
|
||||
@@ -70,7 +69,6 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
SessionID: deriveSessionID(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
modelName := gjson.GetBytes(requestRawJSON, "model").String()
|
||||
@@ -139,14 +137,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||
// log.Debug("Branch: signature_delta")
|
||||
|
||||
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(modelName, params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
||||
if params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
// log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", modelName, thoughtSignature.String()))
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||
@@ -438,7 +436,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
if thinkingSignature != "" {
|
||||
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", modelName, thinkingSignature))
|
||||
block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
// Signature Caching Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
||||
func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Request with user message - should derive session ID
|
||||
// Request with user message - should initialize params
|
||||
requestJSON := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
|
||||
@@ -37,10 +37,12 @@ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m)
|
||||
|
||||
// Verify session ID was set
|
||||
params := param.(*Params)
|
||||
if params.SessionID == "" {
|
||||
t.Error("SessionID should be derived from request")
|
||||
if !params.HasFirstResponse {
|
||||
t.Error("HasFirstResponse should be set after first chunk")
|
||||
}
|
||||
if params.CurrentThinkingText.Len() == 0 {
|
||||
t.Error("Thinking text should be accumulated")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +99,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
|
||||
}`)
|
||||
|
||||
@@ -129,12 +132,8 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
// Process thinking chunk
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
thinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
if sessionID == "" {
|
||||
t.Fatal("SessionID should be set")
|
||||
}
|
||||
if thinkingText == "" {
|
||||
t.Fatal("Thinking text should be accumulated")
|
||||
}
|
||||
@@ -143,7 +142,7 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m)
|
||||
|
||||
// Verify signature was cached
|
||||
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
|
||||
cachedSig := cache.GetCachedSignature("claude-sonnet-4-5-thinking", thinkingText)
|
||||
if cachedSig != validSignature {
|
||||
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
|
||||
}
|
||||
@@ -158,6 +157,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
|
||||
}`)
|
||||
|
||||
@@ -221,13 +221,12 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
||||
// Process first thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
firstThinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m)
|
||||
|
||||
// Verify first signature cached
|
||||
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
|
||||
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", firstThinkingText) != validSig1 {
|
||||
t.Error("First thinking block signature should be cached")
|
||||
}
|
||||
|
||||
@@ -241,76 +240,7 @@ func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T)
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m)
|
||||
|
||||
// Verify second signature cached
|
||||
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
|
||||
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", secondThinkingText) != validSig2 {
|
||||
t.Error("Second thinking block signature should be cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "valid user message",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "user message with content array",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "no user message",
|
||||
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty messages",
|
||||
input: []byte(`{"messages": []}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "no messages field",
|
||||
input: []byte(`{}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := deriveSessionID(tt.input)
|
||||
if tt.wantEmpty && result != "" {
|
||||
t.Errorf("Expected empty session ID, got '%s'", result)
|
||||
}
|
||||
if !tt.wantEmpty && result == "" {
|
||||
t.Error("Expected non-empty session ID")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
|
||||
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input)
|
||||
id2 := deriveSessionID(input)
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
|
||||
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
|
||||
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input1)
|
||||
id2 := deriveSessionID(input2)
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("Different messages should produce different session IDs")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package gemini
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -32,12 +33,12 @@ import (
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||
template, _ = sjson.Set(template, "model", modelName)
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
|
||||
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
||||
@@ -97,37 +98,40 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
|
||||
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
// Gemini-specific handling for non-Claude models:
|
||||
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
|
||||
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
// First pass: collect indices of thinking parts to remove
|
||||
var thinkingIndicesToRemove []int64
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
// Mark thinking blocks for removal
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
|
||||
}
|
||||
// Add skip sentinel to functionCall parts
|
||||
if part.Get("functionCall").Exists() {
|
||||
existingSig := part.Get("thoughtSignature").String()
|
||||
if existingSig == "" || len(existingSig) < 50 {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
// First pass: collect indices of thinking parts to mark with skip sentinel
|
||||
var thinkingIndicesToSkipSignature []int64
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
// Collect indices of thinking blocks to mark with skip sentinel
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Add skip sentinel to functionCall parts
|
||||
if part.Get("functionCall").Exists() {
|
||||
existingSig := part.Get("thoughtSignature").String()
|
||||
if existingSig == "" || len(existingSig) < 50 {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Remove thinking blocks in reverse order to preserve indices
|
||||
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
|
||||
idx := thinkingIndicesToRemove[i]
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
|
||||
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
|
||||
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
|
||||
idx := thinkingIndicesToSkipSignature[i]
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
@@ -62,40 +62,6 @@ func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *test
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
|
||||
// Thinking blocks should be removed entirely for Gemini
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
|
||||
{"text": "Here is my response"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`, validSignature))
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that thinking block is removed
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed for Gemini")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Here is my response" {
|
||||
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||
// Multiple functionCalls should all get skip_thought_signature_validator
|
||||
inputJSON := []byte(`{
|
||||
|
||||
@@ -98,9 +98,8 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := genConfig.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
}
|
||||
// Top P setting for nucleus sampling
|
||||
if topP := genConfig.Get("topP"); topP.Exists() {
|
||||
} else if topP := genConfig.Get("topP"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
}
|
||||
// Stop sequences configuration for custom termination conditions
|
||||
|
||||
@@ -110,10 +110,8 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Temperature setting for controlling response randomness
|
||||
if temp := root.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
}
|
||||
|
||||
// Top P setting for nucleus sampling
|
||||
if topP := root.Get("top_p"); topP.Exists() {
|
||||
} else if topP := root.Get("top_p"); topP.Exists() {
|
||||
// Top P setting for nucleus sampling (filtered out if temperature is set)
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
|
||||
@@ -298,6 +298,15 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
|
||||
|
||||
case "reasoning":
|
||||
thoughtContent := `{"role":"model","parts":[]}`
|
||||
thought := `{"text":"","thoughtSignature":"","thought":true}`
|
||||
thought, _ = sjson.Set(thought, "text", item.Get("summary.0.text").String())
|
||||
thought, _ = sjson.Set(thought, "thoughtSignature", item.Get("encrypted_content").String())
|
||||
|
||||
thoughtContent, _ = sjson.SetRaw(thoughtContent, "parts.-1", thought)
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", thoughtContent)
|
||||
}
|
||||
}
|
||||
} else if input.Exists() && input.Type == gjson.String {
|
||||
|
||||
@@ -20,6 +20,7 @@ type geminiToResponsesState struct {
|
||||
|
||||
// message aggregation
|
||||
MsgOpened bool
|
||||
MsgClosed bool
|
||||
MsgIndex int
|
||||
CurrentMsgID string
|
||||
TextBuf strings.Builder
|
||||
@@ -29,6 +30,7 @@ type geminiToResponsesState struct {
|
||||
ReasoningOpened bool
|
||||
ReasoningIndex int
|
||||
ReasoningItemID string
|
||||
ReasoningEnc string
|
||||
ReasoningBuf strings.Builder
|
||||
ReasoningClosed bool
|
||||
|
||||
@@ -37,6 +39,7 @@ type geminiToResponsesState struct {
|
||||
FuncArgsBuf map[int]*strings.Builder
|
||||
FuncNames map[int]string
|
||||
FuncCallIDs map[int]string
|
||||
FuncDone map[int]bool
|
||||
}
|
||||
|
||||
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
|
||||
@@ -45,6 +48,39 @@ var responseIDCounter uint64
|
||||
// funcCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||
var funcCallIDCounter uint64
|
||||
|
||||
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
|
||||
if len(originalRequestRawJSON) > 0 && gjson.ValidBytes(originalRequestRawJSON) {
|
||||
return originalRequestRawJSON
|
||||
}
|
||||
if len(requestRawJSON) > 0 && gjson.ValidBytes(requestRawJSON) {
|
||||
return requestRawJSON
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unwrapRequestRoot(root gjson.Result) gjson.Result {
|
||||
req := root.Get("request")
|
||||
if !req.Exists() {
|
||||
return root
|
||||
}
|
||||
if req.Get("model").Exists() || req.Get("input").Exists() || req.Get("instructions").Exists() {
|
||||
return req
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func unwrapGeminiResponseRoot(root gjson.Result) gjson.Result {
|
||||
resp := root.Get("response")
|
||||
if !resp.Exists() {
|
||||
return root
|
||||
}
|
||||
// Vertex-style Gemini responses wrap the actual payload in a "response" object.
|
||||
if resp.Get("candidates").Exists() || resp.Get("responseId").Exists() || resp.Get("usageMetadata").Exists() {
|
||||
return resp
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func emitEvent(event string, payload string) string {
|
||||
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
|
||||
}
|
||||
@@ -56,18 +92,37 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
FuncArgsBuf: make(map[int]*strings.Builder),
|
||||
FuncNames: make(map[int]string),
|
||||
FuncCallIDs: make(map[int]string),
|
||||
FuncDone: make(map[int]bool),
|
||||
}
|
||||
}
|
||||
st := (*param).(*geminiToResponsesState)
|
||||
if st.FuncArgsBuf == nil {
|
||||
st.FuncArgsBuf = make(map[int]*strings.Builder)
|
||||
}
|
||||
if st.FuncNames == nil {
|
||||
st.FuncNames = make(map[int]string)
|
||||
}
|
||||
if st.FuncCallIDs == nil {
|
||||
st.FuncCallIDs = make(map[int]string)
|
||||
}
|
||||
if st.FuncDone == nil {
|
||||
st.FuncDone = make(map[int]bool)
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
}
|
||||
|
||||
rawJSON = bytes.TrimSpace(rawJSON)
|
||||
if len(rawJSON) == 0 || bytes.Equal(rawJSON, []byte("[DONE]")) {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
if !root.Exists() {
|
||||
return []string{}
|
||||
}
|
||||
root = unwrapGeminiResponseRoot(root)
|
||||
|
||||
var out []string
|
||||
nextSeq := func() int { st.Seq++; return st.Seq }
|
||||
@@ -98,19 +153,54 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", st.ReasoningItemID)
|
||||
itemDone, _ = sjson.Set(itemDone, "output_index", st.ReasoningIndex)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.encrypted_content", st.ReasoningEnc)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.summary.0.text", full)
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
|
||||
st.ReasoningClosed = true
|
||||
}
|
||||
|
||||
// Helper to finalize the assistant message in correct order.
|
||||
// It emits response.output_text.done, response.content_part.done,
|
||||
// and response.output_item.done exactly once.
|
||||
finalizeMessage := func() {
|
||||
if !st.MsgOpened || st.MsgClosed {
|
||||
return
|
||||
}
|
||||
fullText := st.ItemTextBuf.String()
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
||||
done, _ = sjson.Set(done, "text", fullText)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
|
||||
st.MsgClosed = true
|
||||
}
|
||||
|
||||
// Initialize per-response fields and emit created/in_progress once
|
||||
if !st.Started {
|
||||
if v := root.Get("responseId"); v.Exists() {
|
||||
st.ResponseID = v.String()
|
||||
st.ResponseID = root.Get("responseId").String()
|
||||
if st.ResponseID == "" {
|
||||
st.ResponseID = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
|
||||
}
|
||||
if !strings.HasPrefix(st.ResponseID, "resp_") {
|
||||
st.ResponseID = fmt.Sprintf("resp_%s", st.ResponseID)
|
||||
}
|
||||
if v := root.Get("createTime"); v.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
|
||||
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
|
||||
st.CreatedAt = t.Unix()
|
||||
}
|
||||
}
|
||||
@@ -143,15 +233,21 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// Ignore any late thought chunks after reasoning is finalized.
|
||||
return true
|
||||
}
|
||||
if sig := part.Get("thoughtSignature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
|
||||
st.ReasoningEnc = sig.String()
|
||||
} else if sig = part.Get("thought_signature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
|
||||
st.ReasoningEnc = sig.String()
|
||||
}
|
||||
if !st.ReasoningOpened {
|
||||
st.ReasoningOpened = true
|
||||
st.ReasoningIndex = st.NextIndex
|
||||
st.NextIndex++
|
||||
st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, st.ReasoningIndex)
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","encrypted_content":"","summary":[]}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
item, _ = sjson.Set(item, "output_index", st.ReasoningIndex)
|
||||
item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "item.encrypted_content", st.ReasoningEnc)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
partAdded := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
|
||||
partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq())
|
||||
@@ -191,9 +287,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
|
||||
out = append(out, emitEvent("response.content_part.added", partAdded))
|
||||
st.ItemTextBuf.Reset()
|
||||
st.ItemTextBuf.WriteString(t.String())
|
||||
}
|
||||
st.TextBuf.WriteString(t.String())
|
||||
st.ItemTextBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
|
||||
@@ -205,8 +301,10 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
|
||||
// Function call
|
||||
if fc := part.Get("functionCall"); fc.Exists() {
|
||||
// Before emitting function-call outputs, finalize reasoning if open.
|
||||
// Before emitting function-call outputs, finalize reasoning and the message (if open).
|
||||
// Responses streaming requires message done events before the next output_item.added.
|
||||
finalizeReasoning()
|
||||
finalizeMessage()
|
||||
name := fc.Get("name").String()
|
||||
idx := st.NextIndex
|
||||
st.NextIndex++
|
||||
@@ -219,6 +317,14 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
st.FuncNames[idx] = name
|
||||
|
||||
argsJSON := "{}"
|
||||
if args := fc.Get("args"); args.Exists() {
|
||||
argsJSON = args.Raw
|
||||
}
|
||||
if st.FuncArgsBuf[idx].Len() == 0 && argsJSON != "" {
|
||||
st.FuncArgsBuf[idx].WriteString(argsJSON)
|
||||
}
|
||||
|
||||
// Emit item.added for function call
|
||||
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
|
||||
item, _ = sjson.Set(item, "sequence_number", nextSeq())
|
||||
@@ -228,10 +334,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
item, _ = sjson.Set(item, "item.name", name)
|
||||
out = append(out, emitEvent("response.output_item.added", item))
|
||||
|
||||
// Emit arguments delta (full args in one chunk)
|
||||
if args := fc.Get("args"); args.Exists() {
|
||||
argsJSON := args.Raw
|
||||
st.FuncArgsBuf[idx].WriteString(argsJSON)
|
||||
// Emit arguments delta (full args in one chunk).
|
||||
// When Gemini omits args, emit "{}" to keep Responses streaming event order consistent.
|
||||
if argsJSON != "" {
|
||||
ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
|
||||
ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
|
||||
ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
@@ -240,6 +345,27 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
out = append(out, emitEvent("response.function_call_arguments.delta", ad))
|
||||
}
|
||||
|
||||
// Gemini emits the full function call payload at once, so we can finalize it immediately.
|
||||
if !st.FuncDone[idx] {
|
||||
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
|
||||
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
|
||||
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
fcDone, _ = sjson.Set(fcDone, "output_index", idx)
|
||||
fcDone, _ = sjson.Set(fcDone, "arguments", argsJSON)
|
||||
out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
|
||||
|
||||
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
|
||||
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
itemDone, _ = sjson.Set(itemDone, "item.arguments", argsJSON)
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
|
||||
st.FuncDone[idx] = true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -251,28 +377,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
if fr := root.Get("candidates.0.finishReason"); fr.Exists() && fr.String() != "" {
|
||||
// Finalize reasoning first to keep ordering tight with last delta
|
||||
finalizeReasoning()
|
||||
// Close message output if opened
|
||||
if st.MsgOpened {
|
||||
fullText := st.ItemTextBuf.String()
|
||||
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
|
||||
done, _ = sjson.Set(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
|
||||
done, _ = sjson.Set(done, "output_index", st.MsgIndex)
|
||||
done, _ = sjson.Set(done, "text", fullText)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
|
||||
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
|
||||
partDone, _ = sjson.Set(partDone, "part.text", fullText)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
|
||||
final, _ = sjson.Set(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.Set(final, "output_index", st.MsgIndex)
|
||||
final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
|
||||
final, _ = sjson.Set(final, "item.content.0.text", fullText)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
}
|
||||
finalizeMessage()
|
||||
|
||||
// Close function calls
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
@@ -289,6 +394,9 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
}
|
||||
for _, idx := range idxs {
|
||||
if st.FuncDone[idx] {
|
||||
continue
|
||||
}
|
||||
args := "{}"
|
||||
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
|
||||
args = b.String()
|
||||
@@ -308,6 +416,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
|
||||
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||
|
||||
st.FuncDone[idx] = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,8 +429,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
|
||||
completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
|
||||
|
||||
if requestRawJSON != nil {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
|
||||
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.instructions", v.String())
|
||||
}
|
||||
@@ -383,41 +493,34 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
}
|
||||
|
||||
// Compose outputs in encountered order: reasoning, message, function_calls
|
||||
// Compose outputs in output_index order.
|
||||
outputsWrapper := `{"arr":[]}`
|
||||
if st.ReasoningOpened {
|
||||
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if st.MsgOpened {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
if len(st.FuncArgsBuf) > 0 {
|
||||
idxs := make([]int, 0, len(st.FuncArgsBuf))
|
||||
for idx := range st.FuncArgsBuf {
|
||||
idxs = append(idxs, idx)
|
||||
for idx := 0; idx < st.NextIndex; idx++ {
|
||||
if st.ReasoningOpened && idx == st.ReasoningIndex {
|
||||
item := `{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]}`
|
||||
item, _ = sjson.Set(item, "id", st.ReasoningItemID)
|
||||
item, _ = sjson.Set(item, "encrypted_content", st.ReasoningEnc)
|
||||
item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
continue
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
if st.MsgOpened && idx == st.MsgIndex {
|
||||
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
|
||||
item, _ = sjson.Set(item, "id", st.CurrentMsgID)
|
||||
item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
continue
|
||||
}
|
||||
for _, idx := range idxs {
|
||||
args := ""
|
||||
if b := st.FuncArgsBuf[idx]; b != nil {
|
||||
|
||||
if callID, ok := st.FuncCallIDs[idx]; ok && callID != "" {
|
||||
args := "{}"
|
||||
if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
|
||||
args = b.String()
|
||||
}
|
||||
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
|
||||
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
|
||||
item, _ = sjson.Set(item, "arguments", args)
|
||||
item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx])
|
||||
item, _ = sjson.Set(item, "call_id", callID)
|
||||
item, _ = sjson.Set(item, "name", st.FuncNames[idx])
|
||||
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
|
||||
}
|
||||
@@ -431,8 +534,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
// output tokens
|
||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
|
||||
@@ -460,6 +563,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// ConvertGeminiResponseToOpenAIResponsesNonStream aggregates Gemini response JSON into a single OpenAI Responses JSON object.
|
||||
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
root = unwrapGeminiResponseRoot(root)
|
||||
|
||||
// Base response scaffold
|
||||
resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`
|
||||
@@ -478,15 +582,15 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
// created_at: map from createTime if available
|
||||
createdAt := time.Now().Unix()
|
||||
if v := root.Get("createTime"); v.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339Nano, v.String()); err == nil {
|
||||
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
|
||||
createdAt = t.Unix()
|
||||
}
|
||||
}
|
||||
resp, _ = sjson.Set(resp, "created_at", createdAt)
|
||||
|
||||
// Echo request fields when present; fallback model from response modelVersion
|
||||
if len(requestRawJSON) > 0 {
|
||||
req := gjson.ParseBytes(requestRawJSON)
|
||||
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
|
||||
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
|
||||
if v := req.Get("instructions"); v.Exists() {
|
||||
resp, _ = sjson.Set(resp, "instructions", v.String())
|
||||
}
|
||||
@@ -636,8 +740,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||
// cached_tokens not provided by Gemini; default to 0 for structure compatibility
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", 0)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
// output tokens
|
||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||
resp, _ = sjson.Set(resp, "usage.output_tokens", v.Int())
|
||||
|
||||
@@ -0,0 +1,353 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func parseSSEEvent(t *testing.T, chunk string) (string, gjson.Result) {
|
||||
t.Helper()
|
||||
|
||||
lines := strings.Split(chunk, "\n")
|
||||
if len(lines) < 2 {
|
||||
t.Fatalf("unexpected SSE chunk: %q", chunk)
|
||||
}
|
||||
|
||||
event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||
dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||
if !gjson.Valid(dataLine) {
|
||||
t.Fatalf("invalid SSE data JSON: %q", dataLine)
|
||||
}
|
||||
return event, gjson.Parse(dataLine)
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_UnwrapAndAggregateText(t *testing.T) {
|
||||
// Vertex-style Gemini stream wraps the actual response payload under "response".
|
||||
// This test ensures we unwrap and that output_text.done contains the full text.
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"让"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"我先"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"了解"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"mcp__serena__list_dir","args":{"recursive":false,"relative_path":"internal"},"id":"toolu_1"}}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":2},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
}
|
||||
|
||||
originalReq := []byte(`{"instructions":"test instructions","model":"gpt-5","max_output_tokens":123}`)
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", originalReq, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var (
|
||||
gotTextDone bool
|
||||
gotMessageDone bool
|
||||
gotResponseDone bool
|
||||
gotFuncDone bool
|
||||
|
||||
textDone string
|
||||
messageText string
|
||||
responseID string
|
||||
instructions string
|
||||
cachedTokens int64
|
||||
|
||||
funcName string
|
||||
funcArgs string
|
||||
|
||||
posTextDone = -1
|
||||
posPartDone = -1
|
||||
posMessageDone = -1
|
||||
posFuncAdded = -1
|
||||
)
|
||||
|
||||
for i, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_text.done":
|
||||
gotTextDone = true
|
||||
if posTextDone == -1 {
|
||||
posTextDone = i
|
||||
}
|
||||
textDone = data.Get("text").String()
|
||||
case "response.content_part.done":
|
||||
if posPartDone == -1 {
|
||||
posPartDone = i
|
||||
}
|
||||
case "response.output_item.done":
|
||||
switch data.Get("item.type").String() {
|
||||
case "message":
|
||||
gotMessageDone = true
|
||||
if posMessageDone == -1 {
|
||||
posMessageDone = i
|
||||
}
|
||||
messageText = data.Get("item.content.0.text").String()
|
||||
case "function_call":
|
||||
gotFuncDone = true
|
||||
funcName = data.Get("item.name").String()
|
||||
funcArgs = data.Get("item.arguments").String()
|
||||
}
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() == "function_call" && posFuncAdded == -1 {
|
||||
posFuncAdded = i
|
||||
}
|
||||
case "response.completed":
|
||||
gotResponseDone = true
|
||||
responseID = data.Get("response.id").String()
|
||||
instructions = data.Get("response.instructions").String()
|
||||
cachedTokens = data.Get("response.usage.input_tokens_details.cached_tokens").Int()
|
||||
}
|
||||
}
|
||||
|
||||
if !gotTextDone {
|
||||
t.Fatalf("missing response.output_text.done event")
|
||||
}
|
||||
if posTextDone == -1 || posPartDone == -1 || posMessageDone == -1 || posFuncAdded == -1 {
|
||||
t.Fatalf("missing ordering events: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
|
||||
}
|
||||
if !(posTextDone < posPartDone && posPartDone < posMessageDone && posMessageDone < posFuncAdded) {
|
||||
t.Fatalf("unexpected message/function ordering: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
|
||||
}
|
||||
if !gotMessageDone {
|
||||
t.Fatalf("missing message response.output_item.done event")
|
||||
}
|
||||
if !gotFuncDone {
|
||||
t.Fatalf("missing function_call response.output_item.done event")
|
||||
}
|
||||
if !gotResponseDone {
|
||||
t.Fatalf("missing response.completed event")
|
||||
}
|
||||
|
||||
if textDone != "让我先了解" {
|
||||
t.Fatalf("unexpected output_text.done text: got %q", textDone)
|
||||
}
|
||||
if messageText != "让我先了解" {
|
||||
t.Fatalf("unexpected message done text: got %q", messageText)
|
||||
}
|
||||
|
||||
if responseID != "resp_req_vrtx_1" {
|
||||
t.Fatalf("unexpected response id: got %q", responseID)
|
||||
}
|
||||
if instructions != "test instructions" {
|
||||
t.Fatalf("unexpected instructions echo: got %q", instructions)
|
||||
}
|
||||
if cachedTokens != 2 {
|
||||
t.Fatalf("unexpected cached token count: got %d", cachedTokens)
|
||||
}
|
||||
|
||||
if funcName != "mcp__serena__list_dir" {
|
||||
t.Fatalf("unexpected function name: got %q", funcName)
|
||||
}
|
||||
if !gjson.Valid(funcArgs) {
|
||||
t.Fatalf("invalid function arguments JSON: %q", funcArgs)
|
||||
}
|
||||
if gjson.Get(funcArgs, "recursive").Bool() != false {
|
||||
t.Fatalf("unexpected recursive arg: %v", gjson.Get(funcArgs, "recursive").Value())
|
||||
}
|
||||
if gjson.Get(funcArgs, "relative_path").String() != "internal" {
|
||||
t.Fatalf("unexpected relative_path arg: %q", gjson.Get(funcArgs, "relative_path").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_ReasoningEncryptedContent(t *testing.T) {
|
||||
sig := "RXE0RENrZ0lDeEFDR0FJcVFOZDdjUzlleGFuRktRdFcvSzNyZ2MvWDNCcDQ4RmxSbGxOWUlOVU5kR1l1UHMrMGdkMVp0Vkg3ekdKU0g4YVljc2JjN3lNK0FrdGpTNUdqamI4T3Z0VVNETzdQd3pmcFhUOGl3U3hXUEJvTVFRQ09mWTFyMEtTWGZxUUlJakFqdmFGWk83RW1XRlBKckJVOVpkYzdDKw=="
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"thoughtSignature":"` + sig + `","text":""}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"a"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hello"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
|
||||
}
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
var (
|
||||
addedEnc string
|
||||
doneEnc string
|
||||
)
|
||||
for _, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() == "reasoning" {
|
||||
addedEnc = data.Get("item.encrypted_content").String()
|
||||
}
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() == "reasoning" {
|
||||
doneEnc = data.Get("item.encrypted_content").String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if addedEnc != sig {
|
||||
t.Fatalf("unexpected encrypted_content in response.output_item.added: got %q", addedEnc)
|
||||
}
|
||||
if doneEnc != sig {
|
||||
t.Fatalf("unexpected encrypted_content in response.output_item.done: got %q", doneEnc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_FunctionCallEventOrder(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool1"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool2","args":{"a":1}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
|
||||
}
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
posAdded := []int{-1, -1, -1}
|
||||
posArgsDelta := []int{-1, -1, -1}
|
||||
posArgsDone := []int{-1, -1, -1}
|
||||
posItemDone := []int{-1, -1, -1}
|
||||
posCompleted := -1
|
||||
deltaByIndex := map[int]string{}
|
||||
|
||||
for i, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posAdded) {
|
||||
posAdded[idx] = i
|
||||
}
|
||||
case "response.function_call_arguments.delta":
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posArgsDelta) {
|
||||
posArgsDelta[idx] = i
|
||||
deltaByIndex[idx] = data.Get("delta").String()
|
||||
}
|
||||
case "response.function_call_arguments.done":
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posArgsDone) {
|
||||
posArgsDone[idx] = i
|
||||
}
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() != "function_call" {
|
||||
continue
|
||||
}
|
||||
idx := int(data.Get("output_index").Int())
|
||||
if idx >= 0 && idx < len(posItemDone) {
|
||||
posItemDone[idx] = i
|
||||
}
|
||||
case "response.completed":
|
||||
posCompleted = i
|
||||
|
||||
output := data.Get("response.output")
|
||||
if !output.Exists() || !output.IsArray() {
|
||||
t.Fatalf("missing response.output in response.completed")
|
||||
}
|
||||
if len(output.Array()) != 3 {
|
||||
t.Fatalf("unexpected response.output length: got %d", len(output.Array()))
|
||||
}
|
||||
if data.Get("response.output.0.name").String() != "tool0" || data.Get("response.output.0.arguments").String() != "{}" {
|
||||
t.Fatalf("unexpected output[0]: %s", data.Get("response.output.0").Raw)
|
||||
}
|
||||
if data.Get("response.output.1.name").String() != "tool1" || data.Get("response.output.1.arguments").String() != "{}" {
|
||||
t.Fatalf("unexpected output[1]: %s", data.Get("response.output.1").Raw)
|
||||
}
|
||||
if data.Get("response.output.2.name").String() != "tool2" {
|
||||
t.Fatalf("unexpected output[2] name: %s", data.Get("response.output.2").Raw)
|
||||
}
|
||||
if !gjson.Valid(data.Get("response.output.2.arguments").String()) {
|
||||
t.Fatalf("unexpected output[2] arguments: %q", data.Get("response.output.2.arguments").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if posCompleted == -1 {
|
||||
t.Fatalf("missing response.completed event")
|
||||
}
|
||||
for idx := 0; idx < 3; idx++ {
|
||||
if posAdded[idx] == -1 || posArgsDelta[idx] == -1 || posArgsDone[idx] == -1 || posItemDone[idx] == -1 {
|
||||
t.Fatalf("missing function call events for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
|
||||
}
|
||||
if !(posAdded[idx] < posArgsDelta[idx] && posArgsDelta[idx] < posArgsDone[idx] && posArgsDone[idx] < posItemDone[idx]) {
|
||||
t.Fatalf("unexpected ordering for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
|
||||
}
|
||||
if idx > 0 && !(posItemDone[idx-1] < posAdded[idx]) {
|
||||
t.Fatalf("function call events overlap between %d and %d: prevDone=%d nextAdded=%d", idx-1, idx, posItemDone[idx-1], posAdded[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if deltaByIndex[0] != "{}" {
|
||||
t.Fatalf("unexpected delta for output_index 0: got %q", deltaByIndex[0])
|
||||
}
|
||||
if deltaByIndex[1] != "{}" {
|
||||
t.Fatalf("unexpected delta for output_index 1: got %q", deltaByIndex[1])
|
||||
}
|
||||
if deltaByIndex[2] == "" || !gjson.Valid(deltaByIndex[2]) || gjson.Get(deltaByIndex[2], "a").Int() != 1 {
|
||||
t.Fatalf("unexpected delta for output_index 2: got %q", deltaByIndex[2])
|
||||
}
|
||||
if !(posItemDone[2] < posCompleted) {
|
||||
t.Fatalf("response.completed should be after last output_item.done: last=%d completed=%d", posItemDone[2], posCompleted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiResponseToOpenAIResponses_ResponseOutputOrdering(t *testing.T) {
|
||||
in := []string{
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0","args":{"x":"y"}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hi"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
|
||||
}
|
||||
|
||||
var param any
|
||||
var out []string
|
||||
for _, line := range in {
|
||||
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
|
||||
}
|
||||
|
||||
posFuncDone := -1
|
||||
posMsgAdded := -1
|
||||
posCompleted := -1
|
||||
|
||||
for i, chunk := range out {
|
||||
ev, data := parseSSEEvent(t, chunk)
|
||||
switch ev {
|
||||
case "response.output_item.done":
|
||||
if data.Get("item.type").String() == "function_call" && data.Get("output_index").Int() == 0 {
|
||||
posFuncDone = i
|
||||
}
|
||||
case "response.output_item.added":
|
||||
if data.Get("item.type").String() == "message" && data.Get("output_index").Int() == 1 {
|
||||
posMsgAdded = i
|
||||
}
|
||||
case "response.completed":
|
||||
posCompleted = i
|
||||
if data.Get("response.output.0.type").String() != "function_call" {
|
||||
t.Fatalf("expected response.output[0] to be function_call: %s", data.Get("response.output.0").Raw)
|
||||
}
|
||||
if data.Get("response.output.1.type").String() != "message" {
|
||||
t.Fatalf("expected response.output[1] to be message: %s", data.Get("response.output.1").Raw)
|
||||
}
|
||||
if data.Get("response.output.1.content.0.text").String() != "hi" {
|
||||
t.Fatalf("unexpected message text in response.output[1]: %s", data.Get("response.output.1").Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if posFuncDone == -1 || posMsgAdded == -1 || posCompleted == -1 {
|
||||
t.Fatalf("missing required events: funcDone=%d msgAdded=%d completed=%d", posFuncDone, posMsgAdded, posCompleted)
|
||||
}
|
||||
if !(posFuncDone < posMsgAdded) {
|
||||
t.Fatalf("expected function_call to complete before message is added: funcDone=%d msgAdded=%d", posFuncDone, posMsgAdded)
|
||||
}
|
||||
if !(posMsgAdded < posCompleted) {
|
||||
t.Fatalf("expected response.completed after message added: msgAdded=%d completed=%d", posMsgAdded, posCompleted)
|
||||
}
|
||||
}
|
||||
@@ -89,12 +89,14 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
// Handle system message first
|
||||
systemMsgJSON := `{"role":"system","content":[]}`
|
||||
hasSystemContent := false
|
||||
if system := root.Get("system"); system.Exists() {
|
||||
if system.Type == gjson.String {
|
||||
if system.String() != "" {
|
||||
oldSystem := `{"type":"text","text":""}`
|
||||
oldSystem, _ = sjson.Set(oldSystem, "text", system.String())
|
||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", oldSystem)
|
||||
hasSystemContent = true
|
||||
}
|
||||
} else if system.Type == gjson.JSON {
|
||||
if system.IsArray() {
|
||||
@@ -102,12 +104,16 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
if contentItem, ok := convertClaudeContentPart(systemResults[i]); ok {
|
||||
systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", contentItem)
|
||||
hasSystemContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
|
||||
// Only add system message if it has content
|
||||
if hasSystemContent {
|
||||
messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
|
||||
}
|
||||
|
||||
// Process Anthropic messages
|
||||
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
||||
|
||||
@@ -12,10 +12,23 @@ import (
|
||||
|
||||
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
|
||||
const placeholderReasonDescription = "Brief explanation of why you are calling this tool"
|
||||
|
||||
// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
|
||||
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
||||
// semantic information as description hints.
|
||||
func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
||||
return cleanJSONSchema(jsonStr, true)
|
||||
}
|
||||
|
||||
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini tool calling.
|
||||
// It removes unsupported keywords and simplifies schemas, without adding empty-schema placeholders.
|
||||
func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||
return cleanJSONSchema(jsonStr, false)
|
||||
}
|
||||
|
||||
// cleanJSONSchema performs the core cleaning operations on the JSON schema.
|
||||
func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
|
||||
// Phase 1: Convert and add hints
|
||||
jsonStr = convertRefsToHints(jsonStr)
|
||||
jsonStr = convertConstToEnum(jsonStr)
|
||||
@@ -31,10 +44,94 @@ func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
||||
|
||||
// Phase 3: Cleanup
|
||||
jsonStr = removeUnsupportedKeywords(jsonStr)
|
||||
if !addPlaceholder {
|
||||
// Gemini schema cleanup: remove nullable/title and placeholder-only fields.
|
||||
jsonStr = removeKeywords(jsonStr, []string{"nullable", "title"})
|
||||
jsonStr = removePlaceholderFields(jsonStr)
|
||||
}
|
||||
jsonStr = cleanupRequiredFields(jsonStr)
|
||||
|
||||
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
|
||||
jsonStr = addEmptySchemaPlaceholder(jsonStr)
|
||||
if addPlaceholder {
|
||||
jsonStr = addEmptySchemaPlaceholder(jsonStr)
|
||||
}
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// removeKeywords removes all occurrences of specified keywords from the JSON schema.
|
||||
func removeKeywords(jsonStr string, keywords []string) string {
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// removePlaceholderFields removes placeholder-only properties ("_" and "reason") and their required entries.
|
||||
func removePlaceholderFields(jsonStr string) string {
|
||||
// Remove "_" placeholder properties.
|
||||
paths := findPaths(jsonStr, "_")
|
||||
sortByDepth(paths)
|
||||
for _, p := range paths {
|
||||
if !strings.HasSuffix(p, ".properties._") {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
parentPath := trimSuffix(p, ".properties._")
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
req := gjson.Get(jsonStr, reqPath)
|
||||
if req.IsArray() {
|
||||
var filtered []string
|
||||
for _, r := range req.Array() {
|
||||
if r.String() != "_" {
|
||||
filtered = append(filtered, r.String())
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove placeholder-only "reason" objects.
|
||||
reasonPaths := findPaths(jsonStr, "reason")
|
||||
sortByDepth(reasonPaths)
|
||||
for _, p := range reasonPaths {
|
||||
if !strings.HasSuffix(p, ".properties.reason") {
|
||||
continue
|
||||
}
|
||||
parentPath := trimSuffix(p, ".properties.reason")
|
||||
props := gjson.Get(jsonStr, joinPath(parentPath, "properties"))
|
||||
if !props.IsObject() || len(props.Map()) != 1 {
|
||||
continue
|
||||
}
|
||||
desc := gjson.Get(jsonStr, p+".description").String()
|
||||
if desc != placeholderReasonDescription {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
req := gjson.Get(jsonStr, reqPath)
|
||||
if req.IsArray() {
|
||||
var filtered []string
|
||||
for _, r := range req.Array() {
|
||||
if r.String() != "reason" {
|
||||
filtered = append(filtered, r.String())
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
@@ -409,7 +506,7 @@ func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||
// Add placeholder "reason" property
|
||||
reasonPath := joinPath(propsPath, "reason")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", placeholderReasonDescription)
|
||||
|
||||
// Add to required array
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
||||
|
||||
@@ -385,6 +385,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -423,6 +424,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -464,6 +466,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return nil, errChan
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
|
||||
@@ -570,6 +570,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
@@ -597,6 +598,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
@@ -619,6 +623,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
@@ -646,6 +651,9 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
@@ -668,6 +676,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
opts = ensureRequestedModelMetadata(opts, routeModel)
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
@@ -694,6 +703,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
if errCtx := execCtx.Err(); errCtx != nil {
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
@@ -729,167 +741,42 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return opts
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
if hasRequestedModelMetadata(opts.Metadata) {
|
||||
return opts
|
||||
}
|
||||
if len(opts.Metadata) == 0 {
|
||||
opts.Metadata = map[string]any{cliproxyexecutor.RequestedModelMetadataKey: requestedModel}
|
||||
return opts
|
||||
}
|
||||
meta := make(map[string]any, len(opts.Metadata)+1)
|
||||
for k, v := range opts.Metadata {
|
||||
meta[k] = v
|
||||
}
|
||||
meta[cliproxyexecutor.RequestedModelMetadataKey] = requestedModel
|
||||
opts.Metadata = meta
|
||||
return opts
|
||||
}
|
||||
|
||||
func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
func hasRequestedModelMetadata(meta map[string]any) bool {
|
||||
if len(meta) == 0 {
|
||||
return false
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
result.RetryAfter = ra
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
raw, ok := meta[cliproxyexecutor.RequestedModelMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if provider == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v) != ""
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v)) != ""
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1140,35 +1027,6 @@ func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// rotateProviders returns a rotated view of the providers list starting from the
|
||||
// current offset for the model, and atomically increments the offset for the next call.
|
||||
// This ensures concurrent requests get different starting providers.
|
||||
func (m *Manager) rotateProviders(model string, providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Atomic read-and-increment: get current offset and advance cursor in one lock
|
||||
m.mu.Lock()
|
||||
offset := m.providerOffsets[model]
|
||||
m.providerOffsets[model] = (offset + 1) % len(providers)
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(providers) > 0 {
|
||||
offset %= len(providers)
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if offset == 0 {
|
||||
return providers
|
||||
}
|
||||
rotated := make([]string, 0, len(providers))
|
||||
rotated = append(rotated, providers[offset:]...)
|
||||
rotated = append(rotated, providers[:offset]...)
|
||||
return rotated
|
||||
}
|
||||
|
||||
func (m *Manager) retrySettings() (int, time.Duration) {
|
||||
if m == nil {
|
||||
return 0, 0
|
||||
@@ -1250,42 +1108,6 @@ func waitForCooldown(ctx context.Context, wait time.Duration) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (cliproxyexecutor.Response, error)) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
resp, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (<-chan cliproxyexecutor.StreamChunk, error)) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
chunks, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// MarkResult records an execution result and notifies hooks.
|
||||
func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
if result.AuthID == "" {
|
||||
@@ -1371,8 +1193,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
shouldSuspendModel = true
|
||||
setModelQuota = true
|
||||
case 408, 500, 502, 503, 504:
|
||||
next := now.Add(1 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
if quotaCooldownDisabled.Load() {
|
||||
state.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
next := now.Add(1 * time.Minute)
|
||||
state.NextRetryAfter = next
|
||||
}
|
||||
default:
|
||||
state.NextRetryAfter = time.Time{}
|
||||
}
|
||||
@@ -1623,7 +1449,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
||||
auth.NextRetryAfter = next
|
||||
case 408, 500, 502, 503, 504:
|
||||
auth.StatusMessage = "transient upstream error"
|
||||
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
||||
if quotaCooldownDisabled.Load() {
|
||||
auth.NextRetryAfter = time.Time{}
|
||||
} else {
|
||||
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
||||
}
|
||||
default:
|
||||
if auth.StatusMessage == "" {
|
||||
auth.StatusMessage = "request failed"
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
||||
const RequestedModelMetadataKey = "requested_model"
|
||||
|
||||
// Request encapsulates the translated payload that will be sent to a provider executor.
|
||||
type Request struct {
|
||||
// Model is the upstream model identifier after translation.
|
||||
|
||||
Reference in New Issue
Block a user