From 82187bffba28296c1ebbe7de29f7f91de44a3d6c Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Mon, 29 Sep 2025 19:59:03 +0800 Subject: [PATCH] feat(gemini-web): Add conversation affinity selector --- .../provider/gemini-web/conversation/alias.go | 80 ++++++++ .../provider/gemini-web/conversation/hash.go | 74 +++++++ .../provider/gemini-web/conversation/index.go | 187 ++++++++++++++++++ .../gemini-web/conversation/lookup.go | 40 ++++ .../gemini-web/conversation/metadata.go | 6 + .../provider/gemini-web/conversation/parse.go | 102 ++++++++++ .../gemini-web/conversation/sanitize.go | 39 ++++ internal/provider/gemini-web/models.go | 71 +------ internal/provider/gemini-web/prompt.go | 15 +- internal/provider/gemini-web/state.go | 112 ++++++----- .../runtime/executor/gemini_web_executor.go | 32 +++ sdk/api/handlers/handlers.go | 21 ++ sdk/cliproxy/auth/selector_rr.go | 125 ++++++++++++ sdk/cliproxy/executor/types.go | 2 + sdk/cliproxy/service.go | 10 + 15 files changed, 795 insertions(+), 121 deletions(-) create mode 100644 internal/provider/gemini-web/conversation/alias.go create mode 100644 internal/provider/gemini-web/conversation/hash.go create mode 100644 internal/provider/gemini-web/conversation/index.go create mode 100644 internal/provider/gemini-web/conversation/lookup.go create mode 100644 internal/provider/gemini-web/conversation/metadata.go create mode 100644 internal/provider/gemini-web/conversation/parse.go create mode 100644 internal/provider/gemini-web/conversation/sanitize.go create mode 100644 sdk/cliproxy/auth/selector_rr.go diff --git a/internal/provider/gemini-web/conversation/alias.go b/internal/provider/gemini-web/conversation/alias.go new file mode 100644 index 00000000..b0481883 --- /dev/null +++ b/internal/provider/gemini-web/conversation/alias.go @@ -0,0 +1,80 @@ +package conversation + +import ( + "strings" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +var ( + aliasOnce sync.Once + aliasMap map[string]string +) + +// EnsureGeminiWebAliasMap populates the alias map once. +func EnsureGeminiWebAliasMap() { + aliasOnce.Do(func() { + aliasMap = make(map[string]string) + for _, m := range registry.GetGeminiModels() { + if m.ID == "gemini-2.5-flash-lite" { + continue + } + if m.ID == "gemini-2.5-flash" { + aliasMap["gemini-2.5-flash-image-preview"] = "gemini-2.5-flash" + } + alias := AliasFromModelID(m.ID) + aliasMap[strings.ToLower(alias)] = strings.ToLower(m.ID) + } + }) +} + +// MapAliasToUnderlying normalizes a model alias to its underlying identifier. +func MapAliasToUnderlying(name string) string { + EnsureGeminiWebAliasMap() + n := strings.ToLower(strings.TrimSpace(name)) + if n == "" { + return n + } + if u, ok := aliasMap[n]; ok { + return u + } + const suffix = "-web" + if strings.HasSuffix(n, suffix) { + return strings.TrimSuffix(n, suffix) + } + return n +} + +// AliasFromModelID mirrors the original helper for deriving alias IDs. +func AliasFromModelID(modelID string) string { + return modelID + "-web" +} + +// NormalizeModel returns the canonical identifier used for hashing. +func NormalizeModel(model string) string { + return MapAliasToUnderlying(model) +} + +// GetGeminiWebAliasedModels returns alias metadata for registry exposure. +func GetGeminiWebAliasedModels() []*registry.ModelInfo { + EnsureGeminiWebAliasMap() + aliased := make([]*registry.ModelInfo, 0) + for _, m := range registry.GetGeminiModels() { + if m.ID == "gemini-2.5-flash-lite" { + continue + } else if m.ID == "gemini-2.5-flash" { + cpy := *m + cpy.ID = "gemini-2.5-flash-image-preview" + cpy.Name = "gemini-2.5-flash-image-preview" + cpy.DisplayName = "Nano Banana" + cpy.Description = "Gemini 2.5 Flash Preview Image" + aliased = append(aliased, &cpy) + } + cpy := *m + cpy.ID = AliasFromModelID(m.ID) + cpy.Name = cpy.ID + aliased = append(aliased, &cpy) + } + return aliased +} diff --git a/internal/provider/gemini-web/conversation/hash.go b/internal/provider/gemini-web/conversation/hash.go new file mode 100644 index 00000000..a163a3b2 --- /dev/null +++ b/internal/provider/gemini-web/conversation/hash.go @@ -0,0 +1,74 @@ +package conversation + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" +) + +// Message represents a minimal role-text pair used for hashing and comparison. +type Message struct { + Role string `json:"role"` + Text string `json:"text"` +} + +// StoredMessage mirrors the persisted conversation message structure. +type StoredMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` +} + +// Sha256Hex computes SHA-256 hex digest for the specified string. +func Sha256Hex(s string) string { + sum := sha256.Sum256([]byte(s)) + return hex.EncodeToString(sum[:]) +} + +// ToStoredMessages converts in-memory messages into the persisted representation. +func ToStoredMessages(msgs []Message) []StoredMessage { + out := make([]StoredMessage, 0, len(msgs)) + for _, m := range msgs { + out = append(out, StoredMessage{Role: m.Role, Content: m.Text}) + } + return out +} + +// StoredToMessages converts stored messages back into the in-memory representation. +func StoredToMessages(msgs []StoredMessage) []Message { + out := make([]Message, 0, len(msgs)) + for _, m := range msgs { + out = append(out, Message{Role: m.Role, Text: m.Content}) + } + return out +} + +// hashMessage normalizes message data and returns a stable digest. +func hashMessage(m StoredMessage) string { + s := fmt.Sprintf(`{"content":%q,"role":%q}`, m.Content, strings.ToLower(m.Role)) + return Sha256Hex(s) +} + +// HashConversationWithPrefix computes a conversation hash using the provided prefix (client identifier) and model. +func HashConversationWithPrefix(prefix, model string, msgs []StoredMessage) string { + var b strings.Builder + b.WriteString(strings.ToLower(strings.TrimSpace(prefix))) + b.WriteString("|") + b.WriteString(strings.ToLower(strings.TrimSpace(model))) + for _, m := range msgs { + b.WriteString("|") + b.WriteString(hashMessage(m)) + } + return Sha256Hex(b.String()) +} + +// HashConversationForAccount keeps compatibility with the per-account hash previously used. +func HashConversationForAccount(clientID, model string, msgs []StoredMessage) string { + return HashConversationWithPrefix(clientID, model, msgs) +} + +// HashConversationGlobal produces a hash suitable for cross-account lookups. +func HashConversationGlobal(model string, msgs []StoredMessage) string { + return HashConversationWithPrefix("global", model, msgs) +} diff --git a/internal/provider/gemini-web/conversation/index.go b/internal/provider/gemini-web/conversation/index.go new file mode 100644 index 00000000..cd3b6d47 --- /dev/null +++ b/internal/provider/gemini-web/conversation/index.go @@ -0,0 +1,187 @@ +package conversation + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "sync" + "time" + + bolt "go.etcd.io/bbolt" +) + +const ( + bucketMatches = "matches" + defaultIndexFile = "gemini-web-index.bolt" +) + +// MatchRecord stores persisted mapping metadata for a conversation prefix. +type MatchRecord struct { + AccountLabel string `json:"account_label"` + Metadata []string `json:"metadata,omitempty"` + PrefixLen int `json:"prefix_len"` + UpdatedAt int64 `json:"updated_at"` +} + +// MatchResult combines a persisted record with the hash that produced it. +type MatchResult struct { + Hash string + Record MatchRecord + Model string +} + +var ( + indexOnce sync.Once + indexDB *bolt.DB + indexErr error +) + +func openIndex() (*bolt.DB, error) { + indexOnce.Do(func() { + path := indexPath() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + indexErr = err + return + } + db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 2 * time.Second}) + if err != nil { + indexErr = err + return + } + indexDB = db + }) + return indexDB, indexErr +} + +func indexPath() string { + wd, err := os.Getwd() + if err != nil || wd == "" { + wd = "." + } + return filepath.Join(wd, "conv", defaultIndexFile) +} + +// StoreMatch persists or updates a conversation hash mapping. +func StoreMatch(hash string, record MatchRecord) error { + if strings.TrimSpace(hash) == "" { + return errors.New("gemini-web conversation: empty hash") + } + db, err := openIndex() + if err != nil { + return err + } + record.UpdatedAt = time.Now().UTC().Unix() + payload, err := json.Marshal(record) + if err != nil { + return err + } + return db.Update(func(tx *bolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists([]byte(bucketMatches)) + if err != nil { + return err + } + return bucket.Put([]byte(hash), payload) + }) +} + +// LookupMatch retrieves a stored mapping. +func LookupMatch(hash string) (MatchRecord, bool, error) { + db, err := openIndex() + if err != nil { + return MatchRecord{}, false, err + } + var record MatchRecord + err = db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte(bucketMatches)) + if bucket == nil { + return nil + } + raw := bucket.Get([]byte(hash)) + if len(raw) == 0 { + return nil + } + return json.Unmarshal(raw, &record) + }) + if err != nil { + return MatchRecord{}, false, err + } + if record.AccountLabel == "" || record.PrefixLen <= 0 { + return MatchRecord{}, false, nil + } + return record, true, nil +} + +// RemoveMatch deletes a mapping for the given hash. +func RemoveMatch(hash string) error { + db, err := openIndex() + if err != nil { + return err + } + return db.Update(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte(bucketMatches)) + if bucket == nil { + return nil + } + return bucket.Delete([]byte(hash)) + }) +} + +// RemoveMatchesByLabel removes all entries associated with the specified label. +func RemoveMatchesByLabel(label string) error { + label = strings.TrimSpace(label) + if label == "" { + return nil + } + db, err := openIndex() + if err != nil { + return err + } + return db.Update(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte(bucketMatches)) + if bucket == nil { + return nil + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + if len(v) == 0 { + continue + } + var record MatchRecord + if err := json.Unmarshal(v, &record); err != nil { + _ = bucket.Delete(k) + continue + } + if strings.EqualFold(strings.TrimSpace(record.AccountLabel), label) { + if err := bucket.Delete(k); err != nil { + return err + } + } + } + return nil + }) +} + +// StoreConversation updates all hashes representing the provided conversation snapshot. +func StoreConversation(label, model string, msgs []Message, metadata []string) error { + label = strings.TrimSpace(label) + if label == "" || len(msgs) == 0 { + return nil + } + hashes := BuildStorageHashes(model, msgs) + if len(hashes) == 0 { + return nil + } + for _, h := range hashes { + rec := MatchRecord{ + AccountLabel: label, + Metadata: append([]string(nil), metadata...), + PrefixLen: h.PrefixLen, + } + if err := StoreMatch(h.Hash, rec); err != nil { + return err + } + } + return nil +} diff --git a/internal/provider/gemini-web/conversation/lookup.go b/internal/provider/gemini-web/conversation/lookup.go new file mode 100644 index 00000000..18debc51 --- /dev/null +++ b/internal/provider/gemini-web/conversation/lookup.go @@ -0,0 +1,40 @@ +package conversation + +import "strings" + +// PrefixHash represents a hash candidate for a specific prefix length. +type PrefixHash struct { + Hash string + PrefixLen int +} + +// BuildLookupHashes generates hash candidates ordered from longest to shortest prefix. +func BuildLookupHashes(model string, msgs []Message) []PrefixHash { + if len(msgs) < 2 { + return nil + } + model = NormalizeModel(model) + sanitized := SanitizeAssistantMessages(msgs) + result := make([]PrefixHash, 0, len(sanitized)) + for end := len(sanitized); end >= 2; end-- { + tailRole := strings.ToLower(strings.TrimSpace(sanitized[end-1].Role)) + if tailRole != "assistant" && tailRole != "system" { + continue + } + prefix := sanitized[:end] + hash := HashConversationGlobal(model, ToStoredMessages(prefix)) + result = append(result, PrefixHash{Hash: hash, PrefixLen: end}) + } + return result +} + +// BuildStorageHashes returns hashes representing the full conversation snapshot. +func BuildStorageHashes(model string, msgs []Message) []PrefixHash { + if len(msgs) == 0 { + return nil + } + model = NormalizeModel(model) + sanitized := SanitizeAssistantMessages(msgs) + hash := HashConversationGlobal(model, ToStoredMessages(sanitized)) + return []PrefixHash{{Hash: hash, PrefixLen: len(sanitized)}} +} diff --git a/internal/provider/gemini-web/conversation/metadata.go b/internal/provider/gemini-web/conversation/metadata.go new file mode 100644 index 00000000..ba20f5b3 --- /dev/null +++ b/internal/provider/gemini-web/conversation/metadata.go @@ -0,0 +1,6 @@ +package conversation + +const ( + MetadataMessagesKey = "gemini_web_messages" + MetadataMatchKey = "gemini_web_match" +) diff --git a/internal/provider/gemini-web/conversation/parse.go b/internal/provider/gemini-web/conversation/parse.go new file mode 100644 index 00000000..d27cb708 --- /dev/null +++ b/internal/provider/gemini-web/conversation/parse.go @@ -0,0 +1,102 @@ +package conversation + +import ( + "strings" + + "github.com/tidwall/gjson" +) + +// ExtractMessages attempts to build a message list from the inbound request payload. +func ExtractMessages(handlerType string, raw []byte) []Message { + if len(raw) == 0 { + return nil + } + if msgs := extractOpenAIStyle(raw); len(msgs) > 0 { + return msgs + } + if msgs := extractGeminiContents(raw); len(msgs) > 0 { + return msgs + } + return nil +} + +func extractOpenAIStyle(raw []byte) []Message { + root := gjson.ParseBytes(raw) + messages := root.Get("messages") + if !messages.Exists() { + return nil + } + out := make([]Message, 0, 8) + messages.ForEach(func(_, entry gjson.Result) bool { + role := strings.ToLower(strings.TrimSpace(entry.Get("role").String())) + if role == "" { + return true + } + if role == "system" { + return true + } + var contentBuilder strings.Builder + content := entry.Get("content") + if !content.Exists() { + out = append(out, Message{Role: role, Text: ""}) + return true + } + switch content.Type { + case gjson.String: + contentBuilder.WriteString(content.String()) + case gjson.JSON: + if content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + if text := part.Get("text"); text.Exists() { + if contentBuilder.Len() > 0 { + contentBuilder.WriteString("\n") + } + contentBuilder.WriteString(text.String()) + } + return true + }) + } + } + out = append(out, Message{Role: role, Text: contentBuilder.String()}) + return true + }) + if len(out) == 0 { + return nil + } + return out +} + +func extractGeminiContents(raw []byte) []Message { + contents := gjson.GetBytes(raw, "contents") + if !contents.Exists() { + return nil + } + out := make([]Message, 0, 8) + contents.ForEach(func(_, entry gjson.Result) bool { + role := strings.TrimSpace(entry.Get("role").String()) + if role == "" { + role = "user" + } else { + role = strings.ToLower(role) + if role == "model" { + role = "assistant" + } + } + var builder strings.Builder + entry.Get("parts").ForEach(func(_, part gjson.Result) bool { + if text := part.Get("text"); text.Exists() { + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.WriteString(text.String()) + } + return true + }) + out = append(out, Message{Role: role, Text: builder.String()}) + return true + }) + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/provider/gemini-web/conversation/sanitize.go b/internal/provider/gemini-web/conversation/sanitize.go new file mode 100644 index 00000000..82359702 --- /dev/null +++ b/internal/provider/gemini-web/conversation/sanitize.go @@ -0,0 +1,39 @@ +package conversation + +import ( + "regexp" + "strings" +) + +var reThink = regexp.MustCompile(`(?is).*?`) + +// RemoveThinkTags strips ... blocks and trims whitespace. +func RemoveThinkTags(s string) string { + return strings.TrimSpace(reThink.ReplaceAllString(s, "")) +} + +// SanitizeAssistantMessages removes think tags from assistant messages while leaving others untouched. +func SanitizeAssistantMessages(msgs []Message) []Message { + out := make([]Message, 0, len(msgs)) + for _, m := range msgs { + if strings.EqualFold(strings.TrimSpace(m.Role), "assistant") { + out = append(out, Message{Role: m.Role, Text: RemoveThinkTags(m.Text)}) + continue + } + out = append(out, m) + } + return out +} + +// EqualMessages compares two message slices for equality. +func EqualMessages(a, b []Message) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].Role != b[i].Role || a[i].Text != b[i].Text { + return false + } + } + return true +} diff --git a/internal/provider/gemini-web/models.go b/internal/provider/gemini-web/models.go index c4cb29e8..b1e50dc3 100644 --- a/internal/provider/gemini-web/models.go +++ b/internal/provider/gemini-web/models.go @@ -4,10 +4,9 @@ import ( "fmt" "html" "net/http" - "strings" - "sync" "time" + conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" ) @@ -105,76 +104,20 @@ const ( ErrorIPTemporarilyBlocked = 1060 ) -var ( - GeminiWebAliasOnce sync.Once - GeminiWebAliasMap map[string]string -) - -func EnsureGeminiWebAliasMap() { - GeminiWebAliasOnce.Do(func() { - GeminiWebAliasMap = make(map[string]string) - for _, m := range registry.GetGeminiModels() { - if m.ID == "gemini-2.5-flash-lite" { - continue - } else if m.ID == "gemini-2.5-flash" { - GeminiWebAliasMap["gemini-2.5-flash-image-preview"] = "gemini-2.5-flash" - } - alias := AliasFromModelID(m.ID) - GeminiWebAliasMap[strings.ToLower(alias)] = strings.ToLower(m.ID) - } - }) -} +func EnsureGeminiWebAliasMap() { conversation.EnsureGeminiWebAliasMap() } func GetGeminiWebAliasedModels() []*registry.ModelInfo { - EnsureGeminiWebAliasMap() - aliased := make([]*registry.ModelInfo, 0) - for _, m := range registry.GetGeminiModels() { - if m.ID == "gemini-2.5-flash-lite" { - continue - } else if m.ID == "gemini-2.5-flash" { - cpy := *m - cpy.ID = "gemini-2.5-flash-image-preview" - cpy.Name = "gemini-2.5-flash-image-preview" - cpy.DisplayName = "Nano Banana" - cpy.Description = "Gemini 2.5 Flash Preview Image" - aliased = append(aliased, &cpy) - } - cpy := *m - cpy.ID = AliasFromModelID(m.ID) - cpy.Name = cpy.ID - aliased = append(aliased, &cpy) - } - return aliased + return conversation.GetGeminiWebAliasedModels() } -func MapAliasToUnderlying(name string) string { - EnsureGeminiWebAliasMap() - n := strings.ToLower(name) - if u, ok := GeminiWebAliasMap[n]; ok { - return u - } - const suffix = "-web" - if strings.HasSuffix(n, suffix) { - return strings.TrimSuffix(n, suffix) - } - return name -} +func MapAliasToUnderlying(name string) string { return conversation.MapAliasToUnderlying(name) } -func AliasFromModelID(modelID string) string { - return modelID + "-web" -} +func AliasFromModelID(modelID string) string { return conversation.AliasFromModelID(modelID) } // Conversation domain structures ------------------------------------------- -type RoleText struct { - Role string - Text string -} +type RoleText = conversation.Message -type StoredMessage struct { - Role string `json:"role"` - Content string `json:"content"` - Name string `json:"name,omitempty"` -} +type StoredMessage = conversation.StoredMessage type ConversationRecord struct { Model string `json:"model"` diff --git a/internal/provider/gemini-web/prompt.go b/internal/provider/gemini-web/prompt.go index 1f9cd8be..e3051243 100644 --- a/internal/provider/gemini-web/prompt.go +++ b/internal/provider/gemini-web/prompt.go @@ -8,11 +8,11 @@ import ( "unicode/utf8" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation" "github.com/tidwall/gjson" ) var ( - reThink = regexp.MustCompile(`(?s)^\s*.*?\s*`) reXMLAnyTag = regexp.MustCompile(`(?s)<\s*[^>]+>`) ) @@ -77,20 +77,13 @@ func BuildPrompt(msgs []RoleText, tagged bool, appendAssistant bool) string { // RemoveThinkTags strips ... blocks from a string. func RemoveThinkTags(s string) string { - return strings.TrimSpace(reThink.ReplaceAllString(s, "")) + return conversation.RemoveThinkTags(s) } // SanitizeAssistantMessages removes think tags from assistant messages. func SanitizeAssistantMessages(msgs []RoleText) []RoleText { - out := make([]RoleText, 0, len(msgs)) - for _, m := range msgs { - if strings.ToLower(m.Role) == "assistant" { - out = append(out, RoleText{Role: m.Role, Text: RemoveThinkTags(m.Text)}) - } else { - out = append(out, m) - } - } - return out + cleaned := conversation.SanitizeAssistantMessages(msgs) + return cleaned } // AppendXMLWrapHintIfNeeded appends an XML wrap hint to messages containing XML-like blocks. diff --git a/internal/provider/gemini-web/state.go b/internal/provider/gemini-web/state.go index e0044984..e7b86aef 100644 --- a/internal/provider/gemini-web/state.go +++ b/internal/provider/gemini-web/state.go @@ -3,8 +3,6 @@ package geminiwebapi import ( "bytes" "context" - "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -19,6 +17,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" log "github.com/sirupsen/logrus" @@ -51,6 +50,9 @@ type GeminiWebState struct { convIndex map[string]string lastRefresh time.Time + + pendingMatchMu sync.Mutex + pendingMatch *conversation.MatchResult } func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage, storagePath string) *GeminiWebState { @@ -62,7 +64,7 @@ func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage, convData: make(map[string]ConversationRecord), convIndex: make(map[string]string), } - suffix := Sha256Hex(token.Secure1PSID) + suffix := conversation.Sha256Hex(token.Secure1PSID) if len(suffix) > 16 { suffix = suffix[:16] } @@ -81,6 +83,28 @@ func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage, return state } +func (s *GeminiWebState) setPendingMatch(match *conversation.MatchResult) { + if s == nil { + return + } + s.pendingMatchMu.Lock() + s.pendingMatch = match + s.pendingMatchMu.Unlock() +} + +func (s *GeminiWebState) consumePendingMatch() *conversation.MatchResult { + s.pendingMatchMu.Lock() + defer s.pendingMatchMu.Unlock() + match := s.pendingMatch + s.pendingMatch = nil + return match +} + +// SetPendingMatch makes a cached conversation match available for the next request. +func (s *GeminiWebState) SetPendingMatch(match *conversation.MatchResult) { + s.setPendingMatch(match) +} + // Label returns a stable account label for logging and persistence. // If a storage file path is known, it uses the file base name (without extension). // Otherwise, it falls back to the stable client ID (e.g., "gemini-web-"). @@ -232,7 +256,10 @@ func (s *GeminiWebState) prepare(ctx context.Context, modelName string, rawJSON mimesSubset := mimes if s.useReusableContext() { - reuseMeta, remaining := s.findReusableSession(res.underlying, cleaned) + reuseMeta, remaining := s.reuseFromPending(res.underlying, cleaned) + if len(reuseMeta) == 0 { + reuseMeta, remaining = s.findReusableSession(res.underlying, cleaned) + } if len(reuseMeta) > 0 { res.reuse = true meta = reuseMeta @@ -421,8 +448,16 @@ func (s *GeminiWebState) persistConversation(modelName string, prep *geminiWebPr if !ok { return } - stableHash := HashConversation(rec.ClientID, prep.underlying, rec.Messages) - accountHash := HashConversation(s.accountID, prep.underlying, rec.Messages) + label := strings.TrimSpace(s.Label()) + if label == "" { + label = s.accountID + } + conversationMsgs := conversation.StoredToMessages(rec.Messages) + if err := conversation.StoreConversation(label, prep.underlying, conversationMsgs, metadata); err != nil { + log.Debugf("gemini web: failed to persist global conversation index: %v", err) + } + stableHash := conversation.HashConversationForAccount(rec.ClientID, prep.underlying, rec.Messages) + accountHash := conversation.HashConversationForAccount(s.accountID, prep.underlying, rec.Messages) s.convMu.Lock() s.convData[stableHash] = rec @@ -493,6 +528,27 @@ func (s *GeminiWebState) useReusableContext() bool { return s.cfg.GeminiWeb.Context } +func (s *GeminiWebState) reuseFromPending(modelName string, msgs []RoleText) ([]string, []RoleText) { + match := s.consumePendingMatch() + if match == nil { + return nil, nil + } + if !strings.EqualFold(strings.TrimSpace(match.Model), strings.TrimSpace(modelName)) { + return nil, nil + } + prefixLen := match.Record.PrefixLen + if prefixLen <= 0 || prefixLen > len(msgs) { + return nil, nil + } + metadata := match.Record.Metadata + if len(metadata) == 0 { + return nil, nil + } + remaining := make([]RoleText, len(msgs)-prefixLen) + copy(remaining, msgs[prefixLen:]) + return metadata, remaining +} + func (s *GeminiWebState) findReusableSession(modelName string, msgs []RoleText) ([]string, []RoleText) { s.convMu.RLock() items := s.convData @@ -540,42 +596,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt } } -// Persistence helpers -------------------------------------------------- - -// Sha256Hex computes the SHA256 hash of a string and returns its hex representation. -func Sha256Hex(s string) string { - sum := sha256.Sum256([]byte(s)) - return hex.EncodeToString(sum[:]) -} - -func ToStoredMessages(msgs []RoleText) []StoredMessage { - out := make([]StoredMessage, 0, len(msgs)) - for _, m := range msgs { - out = append(out, StoredMessage{ - Role: m.Role, - Content: m.Text, - }) - } - return out -} - -func HashMessage(m StoredMessage) string { - s := fmt.Sprintf(`{"content":%q,"role":%q}`, m.Content, strings.ToLower(m.Role)) - return Sha256Hex(s) -} - -func HashConversation(clientID, model string, msgs []StoredMessage) string { - var b strings.Builder - b.WriteString(clientID) - b.WriteString("|") - b.WriteString(model) - for _, m := range msgs { - b.WriteString("|") - b.WriteString(HashMessage(m)) - } - return Sha256Hex(b.String()) -} - // ConvBoltPath returns the BoltDB file path used for both account metadata and conversation data. // Different logical datasets are kept in separate buckets within this single DB file. func ConvBoltPath(tokenFilePath string) string { @@ -790,7 +810,7 @@ func BuildConversationRecord(model, clientID string, history []RoleText, output Model: model, ClientID: clientID, Metadata: metadata, - Messages: ToStoredMessages(final), + Messages: conversation.ToStoredMessages(final), CreatedAt: time.Now(), UpdatedAt: time.Now(), } @@ -800,9 +820,9 @@ func BuildConversationRecord(model, clientID string, history []RoleText, output // FindByMessageListIn looks up a conversation record by hashed message list. // It attempts both the stable client ID and a legacy email-based ID. func FindByMessageListIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, bool) { - stored := ToStoredMessages(msgs) - stableHash := HashConversation(stableClientID, model, stored) - fallbackHash := HashConversation(email, model, stored) + stored := conversation.ToStoredMessages(msgs) + stableHash := conversation.HashConversationForAccount(stableClientID, model, stored) + fallbackHash := conversation.HashConversationForAccount(email, model, stored) // Try stable hash via index indirection first if key, ok := index["hash:"+stableHash]; ok { diff --git a/internal/runtime/executor/gemini_web_executor.go b/internal/runtime/executor/gemini_web_executor.go index f026299c..4b0db143 100644 --- a/internal/runtime/executor/gemini_web_executor.go +++ b/internal/runtime/executor/gemini_web_executor.go @@ -13,6 +13,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" geminiwebapi "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web" + conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -40,12 +41,18 @@ func (e *GeminiWebExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth if err = state.EnsureClient(); err != nil { return cliproxyexecutor.Response{}, err } + match := extractGeminiWebMatch(opts.Metadata) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) mutex := state.GetRequestMutex() if mutex != nil { mutex.Lock() defer mutex.Unlock() + if match != nil { + state.SetPendingMatch(match) + } + } else if match != nil { + state.SetPendingMatch(match) } payload := bytes.Clone(req.Payload) @@ -72,11 +79,18 @@ func (e *GeminiWebExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut if err = state.EnsureClient(); err != nil { return nil, err } + match := extractGeminiWebMatch(opts.Metadata) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) mutex := state.GetRequestMutex() if mutex != nil { mutex.Lock() + if match != nil { + state.SetPendingMatch(match) + } + } + if mutex == nil && match != nil { + state.SetPendingMatch(match) } gemBytes, errMsg, prep := state.Send(ctx, req.Model, bytes.Clone(req.Payload), opts) @@ -242,3 +256,21 @@ func (e geminiWebError) StatusCode() int { } return e.message.StatusCode } + +func extractGeminiWebMatch(metadata map[string]any) *conversation.MatchResult { + if metadata == nil { + return nil + } + value, ok := metadata[conversation.MetadataMatchKey] + if !ok { + return nil + } + switch v := value.(type) { + case *conversation.MatchResult: + return v + case conversation.MatchResult: + return &v + default: + return nil + } +} diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index f9f86fd3..cfa61bc0 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -48,6 +49,8 @@ type BaseAPIHandler struct { Cfg *config.SDKConfig } +const geminiWebProvider = "gemini-web" + // NewBaseAPIHandlers creates a new API handlers instance. // It takes a slice of clients and configuration as input. // @@ -137,6 +140,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType if len(providers) == 0 { return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} } + metadata := h.buildGeminiWebMetadata(handlerType, providers, rawJSON) req := coreexecutor.Request{ Model: modelName, Payload: cloneBytes(rawJSON), @@ -146,6 +150,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType Alt: alt, OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), + Metadata: metadata, } resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { @@ -161,6 +166,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle if len(providers) == 0 { return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} } + metadata := h.buildGeminiWebMetadata(handlerType, providers, rawJSON) req := coreexecutor.Request{ Model: modelName, Payload: cloneBytes(rawJSON), @@ -170,6 +176,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle Alt: alt, OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), + Metadata: metadata, } resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) if err != nil { @@ -188,6 +195,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl close(errChan) return nil, errChan } + metadata := h.buildGeminiWebMetadata(handlerType, providers, rawJSON) req := coreexecutor.Request{ Model: modelName, Payload: cloneBytes(rawJSON), @@ -197,6 +205,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl Alt: alt, OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), + Metadata: metadata, } chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { @@ -232,6 +241,18 @@ func cloneBytes(src []byte) []byte { return dst } +func (h *BaseAPIHandler) buildGeminiWebMetadata(handlerType string, providers []string, rawJSON []byte) map[string]any { + if !util.InArray(providers, geminiWebProvider) { + return nil + } + meta := make(map[string]any) + msgs := conversation.ExtractMessages(handlerType, rawJSON) + if len(msgs) > 0 { + meta[conversation.MetadataMessagesKey] = msgs + } + return meta +} + // WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { status := http.StatusInternalServerError diff --git a/sdk/cliproxy/auth/selector_rr.go b/sdk/cliproxy/auth/selector_rr.go new file mode 100644 index 00000000..65b666b8 --- /dev/null +++ b/sdk/cliproxy/auth/selector_rr.go @@ -0,0 +1,125 @@ +package auth + +import ( + "context" + "strings" + + conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + log "github.com/sirupsen/logrus" +) + +const ( + geminiWebProviderKey = "gemini-web" +) + +type geminiWebStickySelector struct { + base Selector +} + +func NewGeminiWebStickySelector(base Selector) Selector { + if selector, ok := base.(*geminiWebStickySelector); ok { + return selector + } + if base == nil { + base = &RoundRobinSelector{} + } + return &geminiWebStickySelector{base: base} +} + +func (m *Manager) EnableGeminiWebStickySelector() { + if m == nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.selector.(*geminiWebStickySelector); ok { + return + } + m.selector = NewGeminiWebStickySelector(m.selector) +} + +func (s *geminiWebStickySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { + if !strings.EqualFold(provider, geminiWebProviderKey) { + if opts.Metadata != nil { + delete(opts.Metadata, conversation.MetadataMatchKey) + } + return s.base.Pick(ctx, provider, model, opts, auths) + } + + messages := extractGeminiWebMessages(opts.Metadata) + if len(messages) >= 2 { + normalizedModel := conversation.NormalizeModel(model) + candidates := conversation.BuildLookupHashes(normalizedModel, messages) + for _, candidate := range candidates { + record, ok, err := conversation.LookupMatch(candidate.Hash) + if err != nil { + log.Warnf("gemini-web selector: lookup failed for hash %s: %v", candidate.Hash, err) + continue + } + if !ok { + continue + } + label := strings.TrimSpace(record.AccountLabel) + if label == "" { + continue + } + auth := findAuthByLabel(auths, label) + if auth != nil { + if opts.Metadata != nil { + opts.Metadata[conversation.MetadataMatchKey] = &conversation.MatchResult{ + Hash: candidate.Hash, + Record: record, + Model: normalizedModel, + } + } + return auth, nil + } + _ = conversation.RemoveMatch(candidate.Hash) + } + } + + return s.base.Pick(ctx, provider, model, opts, auths) +} + +func extractGeminiWebMessages(metadata map[string]any) []conversation.Message { + if metadata == nil { + return nil + } + raw, ok := metadata[conversation.MetadataMessagesKey] + if !ok { + return nil + } + switch v := raw.(type) { + case []conversation.Message: + return v + case *[]conversation.Message: + if v == nil { + return nil + } + return *v + default: + return nil + } +} + +func findAuthByLabel(auths []*Auth, label string) *Auth { + if len(auths) == 0 { + return nil + } + normalized := strings.ToLower(strings.TrimSpace(label)) + for _, auth := range auths { + if auth == nil { + continue + } + if strings.ToLower(strings.TrimSpace(auth.Label)) == normalized { + return auth + } + if auth.Metadata != nil { + if v, ok := auth.Metadata["label"].(string); ok && strings.ToLower(strings.TrimSpace(v)) == normalized { + return auth + } + } + } + return nil +} diff --git a/sdk/cliproxy/executor/types.go b/sdk/cliproxy/executor/types.go index 5b48b11d..c8bb9447 100644 --- a/sdk/cliproxy/executor/types.go +++ b/sdk/cliproxy/executor/types.go @@ -33,6 +33,8 @@ type Options struct { OriginalRequest []byte // SourceFormat identifies the inbound schema. SourceFormat sdktranslator.Format + // Metadata carries extra execution hints shared across selection and executors. + Metadata map[string]any } // Response wraps either a full provider response or metadata for streaming flows. diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index be9f3716..88cdf2a1 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -15,6 +15,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/api" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" geminiwebclient "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web" + conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" @@ -206,6 +207,14 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { } GlobalModelRegistry().UnregisterClient(id) if existing, ok := s.coreManager.GetByID(id); ok && existing != nil { + if strings.EqualFold(existing.Provider, "gemini-web") { + label := strings.TrimSpace(existing.Label) + if label != "" { + if err := conversation.RemoveMatchesByLabel(label); err != nil { + log.Debugf("failed to remove gemini web sticky entries for %s: %v", label, err) + } + } + } existing.Disabled = true existing.Status = coreauth.StatusDisabled if _, err := s.coreManager.Update(ctx, existing); err != nil { @@ -225,6 +234,7 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg)) case "gemini-web": s.coreManager.RegisterExecutor(executor.NewGeminiWebExecutor(s.cfg)) + s.coreManager.EnableGeminiWebStickySelector() case "claude": s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg)) case "codex":