refactor(cache, translator): refine signature caching logic and tests, replace session-based logic with model group handling

This commit is contained in:
Luis Pater
2026-01-21 18:30:05 +08:00
parent ef4508dbc8
commit d9c6317c84
7 changed files with 129 additions and 126 deletions

View File

@@ -3,7 +3,6 @@ package cache
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -25,18 +24,18 @@ const (
// MinValidSignatureLen is the minimum length for a signature to be considered valid // MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50 MinValidSignatureLen = 50
// SessionCleanupInterval controls how often stale sessions are purged // CacheCleanupInterval controls how often stale entries are purged
SessionCleanupInterval = 10 * time.Minute CacheCleanupInterval = 10 * time.Minute
) )
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry // signatureCache stores signatures by model group -> textHash -> SignatureEntry
var signatureCache sync.Map var signatureCache sync.Map
// sessionCleanupOnce ensures the background cleanup goroutine starts only once // cacheCleanupOnce ensures the background cleanup goroutine starts only once
var sessionCleanupOnce sync.Once var cacheCleanupOnce sync.Once
// sessionCache is the inner map type // groupCache is the inner map type
type sessionCache struct { type groupCache struct {
mu sync.RWMutex mu sync.RWMutex
entries map[string]SignatureEntry entries map[string]SignatureEntry
} }
@@ -47,36 +46,36 @@ func hashText(text string) string {
return hex.EncodeToString(h[:])[:SignatureTextHashLen] return hex.EncodeToString(h[:])[:SignatureTextHashLen]
} }
// getOrCreateSession gets or creates a session cache // getOrCreateGroupCache gets or creates a cache bucket for a model group
func getOrCreateSession(sessionID string) *sessionCache { func getOrCreateGroupCache(groupKey string) *groupCache {
// Start background cleanup on first access // Start background cleanup on first access
sessionCleanupOnce.Do(startSessionCleanup) cacheCleanupOnce.Do(startCacheCleanup)
if val, ok := signatureCache.Load(sessionID); ok { if val, ok := signatureCache.Load(groupKey); ok {
return val.(*sessionCache) return val.(*groupCache)
} }
sc := &sessionCache{entries: make(map[string]SignatureEntry)} sc := &groupCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(sessionID, sc) actual, _ := signatureCache.LoadOrStore(groupKey, sc)
return actual.(*sessionCache) return actual.(*groupCache)
} }
// startSessionCleanup launches a background goroutine that periodically // startCacheCleanup launches a background goroutine that periodically
// removes sessions where all entries have expired. // removes caches where all entries have expired.
func startSessionCleanup() { func startCacheCleanup() {
go func() { go func() {
ticker := time.NewTicker(SessionCleanupInterval) ticker := time.NewTicker(CacheCleanupInterval)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
purgeExpiredSessions() purgeExpiredCaches()
} }
}() }()
} }
// purgeExpiredSessions removes sessions with no valid (non-expired) entries. // purgeExpiredCaches removes caches with no valid (non-expired) entries.
func purgeExpiredSessions() { func purgeExpiredCaches() {
now := time.Now() now := time.Now()
signatureCache.Range(func(key, value any) bool { signatureCache.Range(func(key, value any) bool {
sc := value.(*sessionCache) sc := value.(*groupCache)
sc.mu.Lock() sc.mu.Lock()
// Remove expired entries // Remove expired entries
for k, entry := range sc.entries { for k, entry := range sc.entries {
@@ -86,7 +85,7 @@ func purgeExpiredSessions() {
} }
isEmpty := len(sc.entries) == 0 isEmpty := len(sc.entries) == 0
sc.mu.Unlock() sc.mu.Unlock()
// Remove session if empty // Remove cache bucket if empty
if isEmpty { if isEmpty {
signatureCache.Delete(key) signatureCache.Delete(key)
} }
@@ -94,7 +93,7 @@ func purgeExpiredSessions() {
}) })
} }
// CacheSignature stores a thinking signature for a given session and text. // CacheSignature stores a thinking signature for a given model group and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations. // Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(modelName, text, signature string) { func CacheSignature(modelName, text, signature string) {
if text == "" || signature == "" { if text == "" || signature == "" {
@@ -104,9 +103,9 @@ func CacheSignature(modelName, text, signature string) {
return return
} }
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text) groupKey := GetModelGroup(modelName)
textHash := hashText(text) textHash := hashText(text)
sc := getOrCreateSession(textHash) sc := getOrCreateGroupCache(groupKey)
sc.mu.Lock() sc.mu.Lock()
defer sc.mu.Unlock() defer sc.mu.Unlock()
@@ -116,26 +115,25 @@ func CacheSignature(modelName, text, signature string) {
} }
} }
// GetCachedSignature retrieves a cached signature for a given session and text. // GetCachedSignature retrieves a cached signature for a given model group and text.
// Returns empty string if not found or expired. // Returns empty string if not found or expired.
func GetCachedSignature(modelName, text string) string { func GetCachedSignature(modelName, text string) string {
family := GetModelGroup(modelName) groupKey := GetModelGroup(modelName)
if text == "" { if text == "" {
if family == "gemini" { if groupKey == "gemini" {
return "skip_thought_signature_validator" return "skip_thought_signature_validator"
} }
return "" return ""
} }
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text) val, ok := signatureCache.Load(groupKey)
val, ok := signatureCache.Load(hashText(text))
if !ok { if !ok {
if family == "gemini" { if groupKey == "gemini" {
return "skip_thought_signature_validator" return "skip_thought_signature_validator"
} }
return "" return ""
} }
sc := val.(*sessionCache) sc := val.(*groupCache)
textHash := hashText(text) textHash := hashText(text)
@@ -145,7 +143,7 @@ func GetCachedSignature(modelName, text string) string {
entry, exists := sc.entries[textHash] entry, exists := sc.entries[textHash]
if !exists { if !exists {
sc.mu.Unlock() sc.mu.Unlock()
if family == "gemini" { if groupKey == "gemini" {
return "skip_thought_signature_validator" return "skip_thought_signature_validator"
} }
return "" return ""
@@ -153,7 +151,7 @@ func GetCachedSignature(modelName, text string) string {
if now.Sub(entry.Timestamp) > SignatureCacheTTL { if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, textHash) delete(sc.entries, textHash)
sc.mu.Unlock() sc.mu.Unlock()
if family == "gemini" { if groupKey == "gemini" {
return "skip_thought_signature_validator" return "skip_thought_signature_validator"
} }
return "" return ""
@@ -167,22 +165,17 @@ func GetCachedSignature(modelName, text string) string {
return entry.Signature return entry.Signature
} }
// ClearSignatureCache clears signature cache for a specific session or all sessions. // ClearSignatureCache clears signature cache for a specific model group or all groups.
func ClearSignatureCache(sessionID string) { func ClearSignatureCache(modelName string) {
if sessionID != "" { if modelName == "" {
signatureCache.Range(func(key, _ any) bool {
kStr, ok := key.(string)
if ok && strings.HasSuffix(kStr, "#"+sessionID) {
signatureCache.Delete(key)
}
return true
})
} else {
signatureCache.Range(func(key, _ any) bool { signatureCache.Range(func(key, _ any) bool {
signatureCache.Delete(key) signatureCache.Delete(key)
return true return true
}) })
return
} }
groupKey := GetModelGroup(modelName)
signatureCache.Delete(groupKey)
} }
// HasValidSignature checks if a signature is valid (non-empty and long enough) // HasValidSignature checks if a signature is valid (non-empty and long enough)

View File

@@ -21,33 +21,33 @@ func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
} }
} }
func TestCacheSignature_DifferentSessions(t *testing.T) { func TestCacheSignature_DifferentModelGroups(t *testing.T) {
ClearSignatureCache("") ClearSignatureCache("")
text := "Same text in different sessions" text := "Same text across models"
sig1 := "signature1_1234567890123456789012345678901234567890123456" sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456" sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature("test-model", text, sig1) CacheSignature("claude-sonnet-4-5-thinking", text, sig1)
CacheSignature("test-model", text, sig2) CacheSignature("gpt-4o", text, sig2)
if GetCachedSignature("test-model", text) != sig1 { if GetCachedSignature("claude-sonnet-4-5-thinking", text) != sig1 {
t.Error("Session-a signature mismatch") t.Error("Claude signature mismatch")
} }
if GetCachedSignature("test-model", text) != sig2 { if GetCachedSignature("gpt-4o", text) != sig2 {
t.Error("Session-b signature mismatch") t.Error("GPT signature mismatch")
} }
} }
func TestCacheSignature_NotFound(t *testing.T) { func TestCacheSignature_NotFound(t *testing.T) {
ClearSignatureCache("") ClearSignatureCache("")
// Non-existent session // Non-existent cache entry
if got := GetCachedSignature("test-model", "some text"); got != "" { if got := GetCachedSignature("test-model", "some text"); got != "" {
t.Errorf("Expected empty string for nonexistent session, got '%s'", got) t.Errorf("Expected empty string for missing entry, got '%s'", got)
} }
// Existing session but different text // Existing cache but different text
CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890") CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("test-model", "text-b"); got != "" { if got := GetCachedSignature("test-model", "text-b"); got != "" {
t.Errorf("Expected empty string for different text, got '%s'", got) t.Errorf("Expected empty string for different text, got '%s'", got)
@@ -58,7 +58,6 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
ClearSignatureCache("") ClearSignatureCache("")
// All empty/invalid inputs should be no-ops // All empty/invalid inputs should be no-ops
CacheSignature("test-model", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890") CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("test-model", "text", "") CacheSignature("test-model", "text", "")
CacheSignature("test-model", "text", "short") // Too short CacheSignature("test-model", "text", "short") // Too short
@@ -81,20 +80,21 @@ func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
} }
} }
func TestClearSignatureCache_SpecificSession(t *testing.T) { func TestClearSignatureCache_ModelGroup(t *testing.T) {
ClearSignatureCache("") ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456" sigClaude := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("test-model", "text", sig) sigGpt := "validSig9876543210987654321098765432109876543210987654"
CacheSignature("test-model", "text", sig) CacheSignature("claude-sonnet-4-5-thinking", "text", sigClaude)
CacheSignature("gpt-4o", "text", sigGpt)
ClearSignatureCache("session-1") ClearSignatureCache("claude-sonnet-4-5-thinking")
if got := GetCachedSignature("test-model", "text"); got != "" { if got := GetCachedSignature("claude-sonnet-4-5-thinking", "text"); got != "" {
t.Error("session-1 should be cleared") t.Error("Claude cache should be cleared")
} }
if got := GetCachedSignature("test-model", "text"); got != sig { if got := GetCachedSignature("gpt-4o", "text"); got != sigGpt {
t.Error("session-2 should still exist") t.Error("GPT cache should still exist")
} }
} }
@@ -108,10 +108,10 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("") ClearSignatureCache("")
if got := GetCachedSignature("test-model", "text"); got != "" { if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-1 should be cleared") t.Error("cache should be cleared")
} }
if got := GetCachedSignature("test-model", "text"); got != "" { if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-2 should be cleared") t.Error("cache should be cleared")
} }
} }

View File

@@ -98,9 +98,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Use GetThinkingText to handle wrapped thinking objects // Use GetThinkingText to handle wrapped thinking objects
thinkingText := thinking.GetThinkingText(contentResult) thinkingText := thinking.GetThinkingText(contentResult)
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := "" signature := ""
signatureResult := contentResult.Get("signature")
hasClientSignature := signatureResult.Exists() && signatureResult.String() != ""
// Only consider cached signatures when the client provided a signature.
// Unsigned thinking blocks must be dropped.
if hasClientSignature {
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from other requests
if thinkingText != "" { if thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" { if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
signature = cachedSig signature = cachedSig
@@ -110,7 +116,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Fallback to client signature only if cache miss and client signature is valid // Fallback to client signature only if cache miss and client signature is valid
if signature == "" { if signature == "" {
signatureResult := contentResult.Get("signature")
clientSignature := "" clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" { if signatureResult.Exists() && signatureResult.String() != "" {
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2) arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
@@ -125,6 +130,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
// log.Debugf("Using client-provided signature for thinking block") // log.Debugf("Using client-provided signature for thinking block")
} }
}
// Store for subsequent tool_use in the same message // Store for subsequent tool_use in the same message
if cache.HasValidSignature(modelName, signature) { if cache.HasValidSignature(modelName, signature) {

View File

@@ -78,9 +78,7 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
validSignature := "abc123validSignature1234567890123456789012345678901234567890" validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think..." thinkingText := "Let me think..."
// Pre-cache the signature (simulating a response from the same session) // Pre-cache the signature (simulating a previous response for the same thinking text)
// The session ID is derived from the first user message hash
// Since there's no user message in this test, we need to add one
inputJSON := []byte(`{ inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking", "model": "claude-sonnet-4-5-thinking",
"messages": [ "messages": [

View File

@@ -139,7 +139,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if params.CurrentThinkingText.Len() > 0 { if params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String()) cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len()) // log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset() params.CurrentThinkingText.Reset()
} }

View File

@@ -12,10 +12,10 @@ import (
// Signature Caching Tests // Signature Caching Tests
// ============================================================================ // ============================================================================
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) { func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) {
cache.ClearSignatureCache("") cache.ClearSignatureCache("")
// Request with user message - should derive session ID // Request with user message - should initialize params
requestJSON := []byte(`{ requestJSON := []byte(`{
"messages": [ "messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]} {"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
@@ -37,10 +37,12 @@ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, &param) ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, &param)
// Verify session ID was set
params := param.(*Params) params := param.(*Params)
if params.SessionID == "" { if !params.HasFirstResponse {
t.Error("SessionID should be derived from request") t.Error("HasFirstResponse should be set after first chunk")
}
if params.CurrentThinkingText.Len() == 0 {
t.Error("Thinking text should be accumulated")
} }
} }
@@ -130,12 +132,8 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
// Process thinking chunk // Process thinking chunk
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, &param) ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, &param)
params := param.(*Params) params := param.(*Params)
sessionID := params.SessionID
thinkingText := params.CurrentThinkingText.String() thinkingText := params.CurrentThinkingText.String()
if sessionID == "" {
t.Fatal("SessionID should be set")
}
if thinkingText == "" { if thinkingText == "" {
t.Fatal("Thinking text should be accumulated") t.Fatal("Thinking text should be accumulated")
} }

View File

@@ -99,36 +99,44 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
// Gemini-specific handling for non-Claude models: // Gemini-specific handling for non-Claude models:
// - Remove thinking parts entirely.
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation. // - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
if !strings.Contains(modelName, "claude") { if !strings.Contains(modelName, "claude") {
const skipSentinel = "skip_thought_signature_validator" const skipSentinel = "skip_thought_signature_validator"
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool { gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" { if content.Get("role").String() != "model" {
// First pass: collect indices of thinking parts to mark with skip sentinel return true
var thinkingIndicesToSkipSignature []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Collect indices of thinking blocks to mark with skip sentinel
if part.Get("thought").Bool() {
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
} }
// Add skip sentinel to functionCall parts partsResult := content.Get("parts")
if !partsResult.IsArray() {
return true
}
parts := partsResult.Array()
newParts := make([]interface{}, 0, len(parts))
for _, part := range parts {
if part.Get("thought").Bool() {
continue
}
partRaw := part.Raw
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
existingSig := part.Get("thoughtSignature").String() existingSig := part.Get("thoughtSignature").String()
if existingSig == "" || len(existingSig) < 50 { if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel) updatedPart, errSet := sjson.Set(partRaw, "thoughtSignature", skipSentinel)
if errSet != nil {
log.WithError(errSet).Debug("failed to set thoughtSignature on functionCall part")
} else {
partRaw = updatedPart
}
} }
} }
return true
})
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices newParts = append(newParts, gjson.Parse(partRaw).Value())
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
idx := thinkingIndicesToSkipSignature[i]
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
}
} }
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts", contentIdx.Int()), newParts)
return true return true
}) })
} }