refactor(cache, translator): refine signature caching logic and tests, replace session-based logic with model group handling

This commit is contained in:
Luis Pater
2026-01-21 18:30:05 +08:00
parent ef4508dbc8
commit d9c6317c84
7 changed files with 129 additions and 126 deletions

View File

@@ -3,7 +3,6 @@ package cache
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"time"
@@ -25,18 +24,18 @@ const (
// MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50
// SessionCleanupInterval controls how often stale sessions are purged
SessionCleanupInterval = 10 * time.Minute
// CacheCleanupInterval controls how often stale entries are purged
CacheCleanupInterval = 10 * time.Minute
)
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
var signatureCache sync.Map
// sessionCleanupOnce ensures the background cleanup goroutine starts only once
var sessionCleanupOnce sync.Once
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
var cacheCleanupOnce sync.Once
// sessionCache is the inner map type
type sessionCache struct {
// groupCache is the per-model-group bucket: a mutex-guarded map from text hash to SignatureEntry
type groupCache struct {
mu sync.RWMutex
entries map[string]SignatureEntry
}
@@ -47,36 +46,36 @@ func hashText(text string) string {
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
}
// getOrCreateSession gets or creates a session cache
func getOrCreateSession(sessionID string) *sessionCache {
// getOrCreateGroupCache gets or creates a cache bucket for a model group
func getOrCreateGroupCache(groupKey string) *groupCache {
// Start background cleanup on first access
sessionCleanupOnce.Do(startSessionCleanup)
cacheCleanupOnce.Do(startCacheCleanup)
if val, ok := signatureCache.Load(sessionID); ok {
return val.(*sessionCache)
if val, ok := signatureCache.Load(groupKey); ok {
return val.(*groupCache)
}
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
return actual.(*sessionCache)
sc := &groupCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(groupKey, sc)
return actual.(*groupCache)
}
// startSessionCleanup launches a background goroutine that periodically
// removes sessions where all entries have expired.
func startSessionCleanup() {
// startCacheCleanup launches a background goroutine that periodically
// removes caches where all entries have expired.
func startCacheCleanup() {
go func() {
ticker := time.NewTicker(SessionCleanupInterval)
ticker := time.NewTicker(CacheCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredSessions()
purgeExpiredCaches()
}
}()
}
// purgeExpiredSessions removes sessions with no valid (non-expired) entries.
func purgeExpiredSessions() {
// purgeExpiredCaches deletes expired entries from every cache bucket and then removes any bucket left empty.
func purgeExpiredCaches() {
now := time.Now()
signatureCache.Range(func(key, value any) bool {
sc := value.(*sessionCache)
sc := value.(*groupCache)
sc.mu.Lock()
// Remove expired entries
for k, entry := range sc.entries {
@@ -86,7 +85,7 @@ func purgeExpiredSessions() {
}
isEmpty := len(sc.entries) == 0
sc.mu.Unlock()
// Remove session if empty
// Remove cache bucket if empty
if isEmpty {
signatureCache.Delete(key)
}
@@ -94,7 +93,7 @@ func purgeExpiredSessions() {
})
}
// CacheSignature stores a thinking signature for a given session and text.
// CacheSignature stores a thinking signature for a given model group and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(modelName, text, signature string) {
if text == "" || signature == "" {
@@ -104,9 +103,9 @@ func CacheSignature(modelName, text, signature string) {
return
}
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
groupKey := GetModelGroup(modelName)
textHash := hashText(text)
sc := getOrCreateSession(textHash)
sc := getOrCreateGroupCache(groupKey)
sc.mu.Lock()
defer sc.mu.Unlock()
@@ -116,26 +115,25 @@ func CacheSignature(modelName, text, signature string) {
}
}
// GetCachedSignature retrieves a cached signature for a given session and text.
// GetCachedSignature retrieves a cached signature for a given model group and text.
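// Returns empty string if not found or expired, except for the "gemini" model group,
// which receives the "skip_thought_signature_validator" sentinel on every miss.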
func GetCachedSignature(modelName, text string) string {
family := GetModelGroup(modelName)
groupKey := GetModelGroup(modelName)
if text == "" {
if family == "gemini" {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
text = fmt.Sprintf("%s#%s", GetModelGroup(modelName), text)
val, ok := signatureCache.Load(hashText(text))
val, ok := signatureCache.Load(groupKey)
if !ok {
if family == "gemini" {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
}
sc := val.(*sessionCache)
sc := val.(*groupCache)
textHash := hashText(text)
@@ -145,7 +143,7 @@ func GetCachedSignature(modelName, text string) string {
entry, exists := sc.entries[textHash]
if !exists {
sc.mu.Unlock()
if family == "gemini" {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
@@ -153,7 +151,7 @@ func GetCachedSignature(modelName, text string) string {
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, textHash)
sc.mu.Unlock()
if family == "gemini" {
if groupKey == "gemini" {
return "skip_thought_signature_validator"
}
return ""
@@ -167,22 +165,17 @@ func GetCachedSignature(modelName, text string) string {
return entry.Signature
}
// ClearSignatureCache clears signature cache for a specific session or all sessions.
func ClearSignatureCache(sessionID string) {
if sessionID != "" {
signatureCache.Range(func(key, _ any) bool {
kStr, ok := key.(string)
if ok && strings.HasSuffix(kStr, "#"+sessionID) {
signatureCache.Delete(key)
}
return true
})
} else {
// ClearSignatureCache clears the signature cache for the model group that modelName resolves to,
// or for every group when modelName is empty.
func ClearSignatureCache(modelName string) {
if modelName == "" {
signatureCache.Range(func(key, _ any) bool {
signatureCache.Delete(key)
return true
})
return
}
groupKey := GetModelGroup(modelName)
signatureCache.Delete(groupKey)
}
// HasValidSignature checks if a signature is valid (non-empty and long enough)

View File

@@ -21,33 +21,33 @@ func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
}
}
func TestCacheSignature_DifferentSessions(t *testing.T) {
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
ClearSignatureCache("")
text := "Same text in different sessions"
text := "Same text across models"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature("test-model", text, sig1)
CacheSignature("test-model", text, sig2)
CacheSignature("claude-sonnet-4-5-thinking", text, sig1)
CacheSignature("gpt-4o", text, sig2)
if GetCachedSignature("test-model", text) != sig1 {
t.Error("Session-a signature mismatch")
if GetCachedSignature("claude-sonnet-4-5-thinking", text) != sig1 {
t.Error("Claude signature mismatch")
}
if GetCachedSignature("test-model", text) != sig2 {
t.Error("Session-b signature mismatch")
if GetCachedSignature("gpt-4o", text) != sig2 {
t.Error("GPT signature mismatch")
}
}
func TestCacheSignature_NotFound(t *testing.T) {
ClearSignatureCache("")
// Non-existent session
// Non-existent cache entry
if got := GetCachedSignature("test-model", "some text"); got != "" {
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
t.Errorf("Expected empty string for missing entry, got '%s'", got)
}
// Existing session but different text
// Existing cache but different text
CacheSignature("test-model", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("test-model", "text-b"); got != "" {
t.Errorf("Expected empty string for different text, got '%s'", got)
@@ -58,7 +58,6 @@ func TestCacheSignature_EmptyInputs(t *testing.T) {
ClearSignatureCache("")
// All empty/invalid inputs should be no-ops
CacheSignature("test-model", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("test-model", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("test-model", "text", "")
CacheSignature("test-model", "text", "short") // Too short
@@ -81,20 +80,21 @@ func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
}
}
func TestClearSignatureCache_SpecificSession(t *testing.T) {
func TestClearSignatureCache_ModelGroup(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("test-model", "text", sig)
CacheSignature("test-model", "text", sig)
sigClaude := "validSig1234567890123456789012345678901234567890123456"
sigGpt := "validSig9876543210987654321098765432109876543210987654"
CacheSignature("claude-sonnet-4-5-thinking", "text", sigClaude)
CacheSignature("gpt-4o", "text", sigGpt)
ClearSignatureCache("session-1")
ClearSignatureCache("claude-sonnet-4-5-thinking")
if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-1 should be cleared")
if got := GetCachedSignature("claude-sonnet-4-5-thinking", "text"); got != "" {
t.Error("Claude cache should be cleared")
}
if got := GetCachedSignature("test-model", "text"); got != sig {
t.Error("session-2 should still exist")
if got := GetCachedSignature("gpt-4o", "text"); got != sigGpt {
t.Error("GPT cache should still exist")
}
}
@@ -108,10 +108,10 @@ func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("")
if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-1 should be cleared")
t.Error("cache should be cleared")
}
if got := GetCachedSignature("test-model", "text"); got != "" {
t.Error("session-2 should be cleared")
t.Error("cache should be cleared")
}
}

View File

@@ -98,9 +98,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Use GetThinkingText to handle wrapped thinking objects
thinkingText := thinking.GetThinkingText(contentResult)
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
signatureResult := contentResult.Get("signature")
hasClientSignature := signatureResult.Exists() && signatureResult.String() != ""
// Only consider cached signatures when the client provided a signature.
// Unsigned thinking blocks must be dropped.
if hasClientSignature {
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from other requests
if thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
signature = cachedSig
@@ -110,7 +116,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" {
signatureResult := contentResult.Get("signature")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
@@ -125,6 +130,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
// log.Debugf("Using client-provided signature for thinking block")
}
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(modelName, signature) {

View File

@@ -78,9 +78,7 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
thinkingText := "Let me think..."
// Pre-cache the signature (simulating a response from the same session)
// The session ID is derived from the first user message hash
// Since there's no user message in this test, we need to add one
// Pre-cache the signature (simulating a previous response for the same thinking text)
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [

View File

@@ -139,7 +139,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
// log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
// log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}

View File

@@ -12,10 +12,10 @@ import (
// Signature Caching Tests
// ============================================================================
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) {
cache.ClearSignatureCache("")
// Request with user message - should derive session ID
// Request with user message - should initialize params
requestJSON := []byte(`{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
@@ -37,10 +37,12 @@ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
ctx := context.Background()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, &param)
// Verify session ID was set
params := param.(*Params)
if params.SessionID == "" {
t.Error("SessionID should be derived from request")
if !params.HasFirstResponse {
t.Error("HasFirstResponse should be set after first chunk")
}
if params.CurrentThinkingText.Len() == 0 {
t.Error("Thinking text should be accumulated")
}
}
@@ -130,12 +132,8 @@ func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
// Process thinking chunk
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, &param)
params := param.(*Params)
sessionID := params.SessionID
thinkingText := params.CurrentThinkingText.String()
if sessionID == "" {
t.Fatal("SessionID should be set")
}
if thinkingText == "" {
t.Fatal("Thinking text should be accumulated")
}

View File

@@ -99,36 +99,44 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
// Gemini-specific handling for non-Claude models:
// - Remove thinking parts entirely.
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
if !strings.Contains(modelName, "claude") {
const skipSentinel = "skip_thought_signature_validator"
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" {
// First pass: collect indices of thinking parts to mark with skip sentinel
var thinkingIndicesToSkipSignature []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Collect indices of thinking blocks to mark with skip sentinel
if part.Get("thought").Bool() {
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
if content.Get("role").String() != "model" {
return true
}
// Add skip sentinel to functionCall parts
partsResult := content.Get("parts")
if !partsResult.IsArray() {
return true
}
parts := partsResult.Array()
newParts := make([]interface{}, 0, len(parts))
for _, part := range parts {
if part.Get("thought").Bool() {
continue
}
partRaw := part.Raw
if part.Get("functionCall").Exists() {
existingSig := part.Get("thoughtSignature").String()
if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
updatedPart, errSet := sjson.Set(partRaw, "thoughtSignature", skipSentinel)
if errSet != nil {
log.WithError(errSet).Debug("failed to set thoughtSignature on functionCall part")
} else {
partRaw = updatedPart
}
}
}
return true
})
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
idx := thinkingIndicesToSkipSignature[i]
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
}
newParts = append(newParts, gjson.Parse(partRaw).Value())
}
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts", contentIdx.Int()), newParts)
return true
})
}