mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Fixes thinking signature validation errors
Addresses an issue where thinking signature validation fails due to model mapping and empty internal registry. - Implements a fallback mechanism in the router to use the global model registry when the internal registry is empty. This ensures that models registered via API keys are correctly resolved even without local provider configurations. - Modifies `GetModelGroup` to use registry-based grouping in addition to name pattern matching, covering cases where models are registered with API keys but lack provider names in their names. - Updates signature validation to compare model groups instead of exact model names. These changes resolve thinking signature validation errors and improve the accuracy of model resolution.
This commit is contained in:
19
internal/cache/signature_cache.go
vendored
19
internal/cache/signature_cache.go
vendored
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
// SignatureEntry holds a cached thinking signature with timestamp
|
||||
@@ -184,6 +186,7 @@ func HasValidSignature(modelName, signature string) bool {
|
||||
}
|
||||
|
||||
func GetModelGroup(modelName string) string {
|
||||
// Fast path: check model name patterns first
|
||||
if strings.Contains(modelName, "gpt") {
|
||||
return "gpt"
|
||||
} else if strings.Contains(modelName, "claude") {
|
||||
@@ -191,5 +194,21 @@ func GetModelGroup(modelName string) string {
|
||||
} else if strings.Contains(modelName, "gemini") {
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
// Slow path: check registry for provider-based grouping
|
||||
// This handles models registered via claude-api-key, gemini-api-key, etc.
|
||||
// that don't have provider name in their model name (e.g., kimi-k2.5 via claude-api-key)
|
||||
if providers := registry.GetGlobalRegistry().GetModelProviders(modelName); len(providers) > 0 {
|
||||
provider := strings.ToLower(providers[0])
|
||||
switch provider {
|
||||
case "claude":
|
||||
return "claude"
|
||||
case "gemini", "gemini-cli", "aistudio", "vertex", "antigravity":
|
||||
return "gemini"
|
||||
case "codex":
|
||||
return "gpt"
|
||||
}
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
81
internal/cache/signature_cache_test.go
vendored
81
internal/cache/signature_cache_test.go
vendored
@@ -208,3 +208,84 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
// but the logic is verified by the implementation
|
||||
_ = time.Now() // Acknowledge we're not testing time passage
|
||||
}
|
||||
|
||||
// === GetModelGroup Tests ===
|
||||
// These tests verify that GetModelGroup correctly identifies model groups
|
||||
// both by name pattern (fast path) and by registry provider lookup (slow path).
|
||||
|
||||
func TestGetModelGroup_ByNamePattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelName string
|
||||
expectedGroup string
|
||||
}{
|
||||
{"gpt-4o", "gpt"},
|
||||
{"gpt-4-turbo", "gpt"},
|
||||
{"claude-sonnet-4-20250514", "claude"},
|
||||
{"claude-opus-4-5-thinking", "claude"},
|
||||
{"gemini-2.5-pro", "gemini"},
|
||||
{"gemini-3-pro-preview", "gemini"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelName, func(t *testing.T) {
|
||||
result := GetModelGroup(tt.modelName)
|
||||
if result != tt.expectedGroup {
|
||||
t.Errorf("GetModelGroup(%q) = %q, expected %q", tt.modelName, result, tt.expectedGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelGroup_UnknownModel(t *testing.T) {
|
||||
// For unknown models with no registry entry, should return the model name itself
|
||||
result := GetModelGroup("unknown-model-xyz")
|
||||
if result != "unknown-model-xyz" {
|
||||
t.Errorf("GetModelGroup for unknown model should return model name, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetModelGroup_RegistryFallback tests that models registered via
|
||||
// provider-specific API keys (e.g., kimi-k2.5 via claude-api-key) are
|
||||
// correctly grouped by their provider.
|
||||
// This test requires a populated global registry.
|
||||
func TestGetModelGroup_RegistryFallback(t *testing.T) {
|
||||
// This test only makes sense when the global registry is populated
|
||||
// In unit test context, skip if registry is empty
|
||||
|
||||
// Example: kimi-k2.5 registered via claude-api-key should group as "claude"
|
||||
// The model name doesn't contain "claude", so name pattern matching fails.
|
||||
// The registry should be checked to find the provider.
|
||||
|
||||
// Skip for now - this requires integration test setup
|
||||
t.Skip("Requires populated global registry - run as integration test")
|
||||
}
|
||||
|
||||
// === Cross-Model Signature Validation Tests ===
|
||||
// These tests verify that signatures cached under one model name can be
|
||||
// validated under mapped model names (same provider group).
|
||||
|
||||
func TestCacheSignature_CrossModelValidation(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Original request uses "claude-opus-4-5-20251101"
|
||||
originalModel := "claude-opus-4-5-20251101"
|
||||
// Mapped model is "claude-opus-4-5-thinking"
|
||||
mappedModel := "claude-opus-4-5-thinking"
|
||||
|
||||
text := "Some thinking block content"
|
||||
sig := "validSignature123456789012345678901234567890123456789012"
|
||||
|
||||
// Cache signature under the original model
|
||||
CacheSignature(originalModel, text, sig)
|
||||
|
||||
// Both should return the same signature because they're in the same group
|
||||
retrieved1 := GetCachedSignature(originalModel, text)
|
||||
retrieved2 := GetCachedSignature(mappedModel, text)
|
||||
|
||||
if retrieved1 != sig {
|
||||
t.Errorf("Original model signature mismatch: got %q", retrieved1)
|
||||
}
|
||||
if retrieved2 != sig {
|
||||
t.Errorf("Mapped model signature mismatch: got %q", retrieved2)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user