feat(gemini-web): Add conversation affinity selector

This commit is contained in:
hkfires
2025-09-29 19:59:03 +08:00
parent f4977e5ef6
commit 82187bffba
15 changed files with 795 additions and 121 deletions

View File

@@ -0,0 +1,80 @@
package conversation
import (
"strings"
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
var (
aliasOnce sync.Once
aliasMap map[string]string
)
// EnsureGeminiWebAliasMap populates the alias map once.
func EnsureGeminiWebAliasMap() {
aliasOnce.Do(func() {
aliasMap = make(map[string]string)
for _, m := range registry.GetGeminiModels() {
if m.ID == "gemini-2.5-flash-lite" {
continue
}
if m.ID == "gemini-2.5-flash" {
aliasMap["gemini-2.5-flash-image-preview"] = "gemini-2.5-flash"
}
alias := AliasFromModelID(m.ID)
aliasMap[strings.ToLower(alias)] = strings.ToLower(m.ID)
}
})
}
// MapAliasToUnderlying normalizes a model alias to its underlying identifier.
func MapAliasToUnderlying(name string) string {
EnsureGeminiWebAliasMap()
n := strings.ToLower(strings.TrimSpace(name))
if n == "" {
return n
}
if u, ok := aliasMap[n]; ok {
return u
}
const suffix = "-web"
if strings.HasSuffix(n, suffix) {
return strings.TrimSuffix(n, suffix)
}
return n
}
// AliasFromModelID mirrors the original helper for deriving alias IDs.
func AliasFromModelID(modelID string) string {
return modelID + "-web"
}
// NormalizeModel returns the canonical identifier used for hashing.
func NormalizeModel(model string) string {
return MapAliasToUnderlying(model)
}
// GetGeminiWebAliasedModels returns alias metadata for registry exposure.
func GetGeminiWebAliasedModels() []*registry.ModelInfo {
EnsureGeminiWebAliasMap()
aliased := make([]*registry.ModelInfo, 0)
for _, m := range registry.GetGeminiModels() {
if m.ID == "gemini-2.5-flash-lite" {
continue
} else if m.ID == "gemini-2.5-flash" {
cpy := *m
cpy.ID = "gemini-2.5-flash-image-preview"
cpy.Name = "gemini-2.5-flash-image-preview"
cpy.DisplayName = "Nano Banana"
cpy.Description = "Gemini 2.5 Flash Preview Image"
aliased = append(aliased, &cpy)
}
cpy := *m
cpy.ID = AliasFromModelID(m.ID)
cpy.Name = cpy.ID
aliased = append(aliased, &cpy)
}
return aliased
}

View File

@@ -0,0 +1,74 @@
package conversation
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
)
// Message represents a minimal role-text pair used for hashing and comparison.
type Message struct {
Role string `json:"role"`
Text string `json:"text"`
}
// StoredMessage mirrors the persisted conversation message structure.
type StoredMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"`
}
// Sha256Hex computes SHA-256 hex digest for the specified string.
func Sha256Hex(s string) string {
sum := sha256.Sum256([]byte(s))
return hex.EncodeToString(sum[:])
}
// ToStoredMessages converts in-memory messages into the persisted representation.
func ToStoredMessages(msgs []Message) []StoredMessage {
out := make([]StoredMessage, 0, len(msgs))
for _, m := range msgs {
out = append(out, StoredMessage{Role: m.Role, Content: m.Text})
}
return out
}
// StoredToMessages converts stored messages back into the in-memory representation.
func StoredToMessages(msgs []StoredMessage) []Message {
out := make([]Message, 0, len(msgs))
for _, m := range msgs {
out = append(out, Message{Role: m.Role, Text: m.Content})
}
return out
}
// hashMessage normalizes message data and returns a stable digest.
func hashMessage(m StoredMessage) string {
s := fmt.Sprintf(`{"content":%q,"role":%q}`, m.Content, strings.ToLower(m.Role))
return Sha256Hex(s)
}
// HashConversationWithPrefix computes a conversation hash using the provided prefix (client identifier) and model.
func HashConversationWithPrefix(prefix, model string, msgs []StoredMessage) string {
var b strings.Builder
b.WriteString(strings.ToLower(strings.TrimSpace(prefix)))
b.WriteString("|")
b.WriteString(strings.ToLower(strings.TrimSpace(model)))
for _, m := range msgs {
b.WriteString("|")
b.WriteString(hashMessage(m))
}
return Sha256Hex(b.String())
}
// HashConversationForAccount keeps compatibility with the per-account hash previously used.
func HashConversationForAccount(clientID, model string, msgs []StoredMessage) string {
return HashConversationWithPrefix(clientID, model, msgs)
}
// HashConversationGlobal produces a hash suitable for cross-account lookups.
func HashConversationGlobal(model string, msgs []StoredMessage) string {
return HashConversationWithPrefix("global", model, msgs)
}

View File

@@ -0,0 +1,187 @@
package conversation
import (
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
"sync"
"time"
bolt "go.etcd.io/bbolt"
)
const (
bucketMatches = "matches"
defaultIndexFile = "gemini-web-index.bolt"
)
// MatchRecord stores persisted mapping metadata for a conversation prefix.
type MatchRecord struct {
AccountLabel string `json:"account_label"`
Metadata []string `json:"metadata,omitempty"`
PrefixLen int `json:"prefix_len"`
UpdatedAt int64 `json:"updated_at"`
}
// MatchResult combines a persisted record with the hash that produced it.
type MatchResult struct {
Hash string
Record MatchRecord
Model string
}
var (
indexOnce sync.Once
indexDB *bolt.DB
indexErr error
)
func openIndex() (*bolt.DB, error) {
indexOnce.Do(func() {
path := indexPath()
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
indexErr = err
return
}
db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 2 * time.Second})
if err != nil {
indexErr = err
return
}
indexDB = db
})
return indexDB, indexErr
}
func indexPath() string {
wd, err := os.Getwd()
if err != nil || wd == "" {
wd = "."
}
return filepath.Join(wd, "conv", defaultIndexFile)
}
// StoreMatch persists or updates a conversation hash mapping.
func StoreMatch(hash string, record MatchRecord) error {
if strings.TrimSpace(hash) == "" {
return errors.New("gemini-web conversation: empty hash")
}
db, err := openIndex()
if err != nil {
return err
}
record.UpdatedAt = time.Now().UTC().Unix()
payload, err := json.Marshal(record)
if err != nil {
return err
}
return db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists([]byte(bucketMatches))
if err != nil {
return err
}
return bucket.Put([]byte(hash), payload)
})
}
// LookupMatch retrieves a stored mapping.
func LookupMatch(hash string) (MatchRecord, bool, error) {
db, err := openIndex()
if err != nil {
return MatchRecord{}, false, err
}
var record MatchRecord
err = db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(bucketMatches))
if bucket == nil {
return nil
}
raw := bucket.Get([]byte(hash))
if len(raw) == 0 {
return nil
}
return json.Unmarshal(raw, &record)
})
if err != nil {
return MatchRecord{}, false, err
}
if record.AccountLabel == "" || record.PrefixLen <= 0 {
return MatchRecord{}, false, nil
}
return record, true, nil
}
// RemoveMatch deletes a mapping for the given hash.
func RemoveMatch(hash string) error {
db, err := openIndex()
if err != nil {
return err
}
return db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(bucketMatches))
if bucket == nil {
return nil
}
return bucket.Delete([]byte(hash))
})
}
// RemoveMatchesByLabel removes all entries associated with the specified label.
func RemoveMatchesByLabel(label string) error {
label = strings.TrimSpace(label)
if label == "" {
return nil
}
db, err := openIndex()
if err != nil {
return err
}
return db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(bucketMatches))
if bucket == nil {
return nil
}
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
if len(v) == 0 {
continue
}
var record MatchRecord
if err := json.Unmarshal(v, &record); err != nil {
_ = bucket.Delete(k)
continue
}
if strings.EqualFold(strings.TrimSpace(record.AccountLabel), label) {
if err := bucket.Delete(k); err != nil {
return err
}
}
}
return nil
})
}
// StoreConversation updates all hashes representing the provided conversation snapshot.
func StoreConversation(label, model string, msgs []Message, metadata []string) error {
label = strings.TrimSpace(label)
if label == "" || len(msgs) == 0 {
return nil
}
hashes := BuildStorageHashes(model, msgs)
if len(hashes) == 0 {
return nil
}
for _, h := range hashes {
rec := MatchRecord{
AccountLabel: label,
Metadata: append([]string(nil), metadata...),
PrefixLen: h.PrefixLen,
}
if err := StoreMatch(h.Hash, rec); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,40 @@
package conversation
import "strings"
// PrefixHash represents a hash candidate for a specific prefix length.
type PrefixHash struct {
Hash string
PrefixLen int
}
// BuildLookupHashes generates hash candidates ordered from longest to shortest prefix.
func BuildLookupHashes(model string, msgs []Message) []PrefixHash {
if len(msgs) < 2 {
return nil
}
model = NormalizeModel(model)
sanitized := SanitizeAssistantMessages(msgs)
result := make([]PrefixHash, 0, len(sanitized))
for end := len(sanitized); end >= 2; end-- {
tailRole := strings.ToLower(strings.TrimSpace(sanitized[end-1].Role))
if tailRole != "assistant" && tailRole != "system" {
continue
}
prefix := sanitized[:end]
hash := HashConversationGlobal(model, ToStoredMessages(prefix))
result = append(result, PrefixHash{Hash: hash, PrefixLen: end})
}
return result
}
// BuildStorageHashes returns hashes representing the full conversation snapshot.
func BuildStorageHashes(model string, msgs []Message) []PrefixHash {
if len(msgs) == 0 {
return nil
}
model = NormalizeModel(model)
sanitized := SanitizeAssistantMessages(msgs)
hash := HashConversationGlobal(model, ToStoredMessages(sanitized))
return []PrefixHash{{Hash: hash, PrefixLen: len(sanitized)}}
}

View File

@@ -0,0 +1,6 @@
package conversation
const (
MetadataMessagesKey = "gemini_web_messages"
MetadataMatchKey = "gemini_web_match"
)

View File

@@ -0,0 +1,102 @@
package conversation
import (
"strings"
"github.com/tidwall/gjson"
)
// ExtractMessages attempts to build a message list from the inbound request payload.
func ExtractMessages(handlerType string, raw []byte) []Message {
if len(raw) == 0 {
return nil
}
if msgs := extractOpenAIStyle(raw); len(msgs) > 0 {
return msgs
}
if msgs := extractGeminiContents(raw); len(msgs) > 0 {
return msgs
}
return nil
}
func extractOpenAIStyle(raw []byte) []Message {
root := gjson.ParseBytes(raw)
messages := root.Get("messages")
if !messages.Exists() {
return nil
}
out := make([]Message, 0, 8)
messages.ForEach(func(_, entry gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(entry.Get("role").String()))
if role == "" {
return true
}
if role == "system" {
return true
}
var contentBuilder strings.Builder
content := entry.Get("content")
if !content.Exists() {
out = append(out, Message{Role: role, Text: ""})
return true
}
switch content.Type {
case gjson.String:
contentBuilder.WriteString(content.String())
case gjson.JSON:
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
if text := part.Get("text"); text.Exists() {
if contentBuilder.Len() > 0 {
contentBuilder.WriteString("\n")
}
contentBuilder.WriteString(text.String())
}
return true
})
}
}
out = append(out, Message{Role: role, Text: contentBuilder.String()})
return true
})
if len(out) == 0 {
return nil
}
return out
}
func extractGeminiContents(raw []byte) []Message {
contents := gjson.GetBytes(raw, "contents")
if !contents.Exists() {
return nil
}
out := make([]Message, 0, 8)
contents.ForEach(func(_, entry gjson.Result) bool {
role := strings.TrimSpace(entry.Get("role").String())
if role == "" {
role = "user"
} else {
role = strings.ToLower(role)
if role == "model" {
role = "assistant"
}
}
var builder strings.Builder
entry.Get("parts").ForEach(func(_, part gjson.Result) bool {
if text := part.Get("text"); text.Exists() {
if builder.Len() > 0 {
builder.WriteString("\n")
}
builder.WriteString(text.String())
}
return true
})
out = append(out, Message{Role: role, Text: builder.String()})
return true
})
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -0,0 +1,39 @@
package conversation
import (
"regexp"
"strings"
)
var reThink = regexp.MustCompile(`(?is)<think>.*?</think>`)
// RemoveThinkTags strips <think>...</think> blocks and trims whitespace.
func RemoveThinkTags(s string) string {
return strings.TrimSpace(reThink.ReplaceAllString(s, ""))
}
// SanitizeAssistantMessages removes think tags from assistant messages while leaving others untouched.
func SanitizeAssistantMessages(msgs []Message) []Message {
out := make([]Message, 0, len(msgs))
for _, m := range msgs {
if strings.EqualFold(strings.TrimSpace(m.Role), "assistant") {
out = append(out, Message{Role: m.Role, Text: RemoveThinkTags(m.Text)})
continue
}
out = append(out, m)
}
return out
}
// EqualMessages compares two message slices for equality.
func EqualMessages(a, b []Message) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Role != b[i].Role || a[i].Text != b[i].Text {
return false
}
}
return true
}

View File

@@ -4,10 +4,9 @@ import (
"fmt" "fmt"
"html" "html"
"net/http" "net/http"
"strings"
"sync"
"time" "time"
conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
) )
@@ -105,76 +104,20 @@ const (
ErrorIPTemporarilyBlocked = 1060 ErrorIPTemporarilyBlocked = 1060
) )
var ( func EnsureGeminiWebAliasMap() { conversation.EnsureGeminiWebAliasMap() }
GeminiWebAliasOnce sync.Once
GeminiWebAliasMap map[string]string
)
func EnsureGeminiWebAliasMap() {
GeminiWebAliasOnce.Do(func() {
GeminiWebAliasMap = make(map[string]string)
for _, m := range registry.GetGeminiModels() {
if m.ID == "gemini-2.5-flash-lite" {
continue
} else if m.ID == "gemini-2.5-flash" {
GeminiWebAliasMap["gemini-2.5-flash-image-preview"] = "gemini-2.5-flash"
}
alias := AliasFromModelID(m.ID)
GeminiWebAliasMap[strings.ToLower(alias)] = strings.ToLower(m.ID)
}
})
}
func GetGeminiWebAliasedModels() []*registry.ModelInfo { func GetGeminiWebAliasedModels() []*registry.ModelInfo {
EnsureGeminiWebAliasMap() return conversation.GetGeminiWebAliasedModels()
aliased := make([]*registry.ModelInfo, 0)
for _, m := range registry.GetGeminiModels() {
if m.ID == "gemini-2.5-flash-lite" {
continue
} else if m.ID == "gemini-2.5-flash" {
cpy := *m
cpy.ID = "gemini-2.5-flash-image-preview"
cpy.Name = "gemini-2.5-flash-image-preview"
cpy.DisplayName = "Nano Banana"
cpy.Description = "Gemini 2.5 Flash Preview Image"
aliased = append(aliased, &cpy)
}
cpy := *m
cpy.ID = AliasFromModelID(m.ID)
cpy.Name = cpy.ID
aliased = append(aliased, &cpy)
}
return aliased
} }
func MapAliasToUnderlying(name string) string { func MapAliasToUnderlying(name string) string { return conversation.MapAliasToUnderlying(name) }
EnsureGeminiWebAliasMap()
n := strings.ToLower(name)
if u, ok := GeminiWebAliasMap[n]; ok {
return u
}
const suffix = "-web"
if strings.HasSuffix(n, suffix) {
return strings.TrimSuffix(n, suffix)
}
return name
}
func AliasFromModelID(modelID string) string { func AliasFromModelID(modelID string) string { return conversation.AliasFromModelID(modelID) }
return modelID + "-web"
}
// Conversation domain structures ------------------------------------------- // Conversation domain structures -------------------------------------------
type RoleText struct { type RoleText = conversation.Message
Role string
Text string
}
type StoredMessage struct { type StoredMessage = conversation.StoredMessage
Role string `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"`
}
type ConversationRecord struct { type ConversationRecord struct {
Model string `json:"model"` Model string `json:"model"`

View File

@@ -8,11 +8,11 @@ import (
"unicode/utf8" "unicode/utf8"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
var ( var (
reThink = regexp.MustCompile(`(?s)^\s*<think>.*?</think>\s*`)
reXMLAnyTag = regexp.MustCompile(`(?s)<\s*[^>]+>`) reXMLAnyTag = regexp.MustCompile(`(?s)<\s*[^>]+>`)
) )
@@ -77,20 +77,13 @@ func BuildPrompt(msgs []RoleText, tagged bool, appendAssistant bool) string {
// RemoveThinkTags strips <think>...</think> blocks from a string. // RemoveThinkTags strips <think>...</think> blocks from a string.
func RemoveThinkTags(s string) string { func RemoveThinkTags(s string) string {
return strings.TrimSpace(reThink.ReplaceAllString(s, "")) return conversation.RemoveThinkTags(s)
} }
// SanitizeAssistantMessages removes think tags from assistant messages. // SanitizeAssistantMessages removes think tags from assistant messages.
func SanitizeAssistantMessages(msgs []RoleText) []RoleText { func SanitizeAssistantMessages(msgs []RoleText) []RoleText {
out := make([]RoleText, 0, len(msgs)) cleaned := conversation.SanitizeAssistantMessages(msgs)
for _, m := range msgs { return cleaned
if strings.ToLower(m.Role) == "assistant" {
out = append(out, RoleText{Role: m.Role, Text: RemoveThinkTags(m.Text)})
} else {
out = append(out, m)
}
}
return out
} }
// AppendXMLWrapHintIfNeeded appends an XML wrap hint to messages containing XML-like blocks. // AppendXMLWrapHintIfNeeded appends an XML wrap hint to messages containing XML-like blocks.

View File

@@ -3,8 +3,6 @@ package geminiwebapi
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -19,6 +17,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/constant" "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -51,6 +50,9 @@ type GeminiWebState struct {
convIndex map[string]string convIndex map[string]string
lastRefresh time.Time lastRefresh time.Time
pendingMatchMu sync.Mutex
pendingMatch *conversation.MatchResult
} }
func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage, storagePath string) *GeminiWebState { func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage, storagePath string) *GeminiWebState {
@@ -62,7 +64,7 @@ func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage,
convData: make(map[string]ConversationRecord), convData: make(map[string]ConversationRecord),
convIndex: make(map[string]string), convIndex: make(map[string]string),
} }
suffix := Sha256Hex(token.Secure1PSID) suffix := conversation.Sha256Hex(token.Secure1PSID)
if len(suffix) > 16 { if len(suffix) > 16 {
suffix = suffix[:16] suffix = suffix[:16]
} }
@@ -81,6 +83,28 @@ func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage,
return state return state
} }
func (s *GeminiWebState) setPendingMatch(match *conversation.MatchResult) {
if s == nil {
return
}
s.pendingMatchMu.Lock()
s.pendingMatch = match
s.pendingMatchMu.Unlock()
}
func (s *GeminiWebState) consumePendingMatch() *conversation.MatchResult {
s.pendingMatchMu.Lock()
defer s.pendingMatchMu.Unlock()
match := s.pendingMatch
s.pendingMatch = nil
return match
}
// SetPendingMatch makes a cached conversation match available for the next request.
func (s *GeminiWebState) SetPendingMatch(match *conversation.MatchResult) {
s.setPendingMatch(match)
}
// Label returns a stable account label for logging and persistence. // Label returns a stable account label for logging and persistence.
// If a storage file path is known, it uses the file base name (without extension). // If a storage file path is known, it uses the file base name (without extension).
// Otherwise, it falls back to the stable client ID (e.g., "gemini-web-<hash>"). // Otherwise, it falls back to the stable client ID (e.g., "gemini-web-<hash>").
@@ -232,7 +256,10 @@ func (s *GeminiWebState) prepare(ctx context.Context, modelName string, rawJSON
mimesSubset := mimes mimesSubset := mimes
if s.useReusableContext() { if s.useReusableContext() {
reuseMeta, remaining := s.findReusableSession(res.underlying, cleaned) reuseMeta, remaining := s.reuseFromPending(res.underlying, cleaned)
if len(reuseMeta) == 0 {
reuseMeta, remaining = s.findReusableSession(res.underlying, cleaned)
}
if len(reuseMeta) > 0 { if len(reuseMeta) > 0 {
res.reuse = true res.reuse = true
meta = reuseMeta meta = reuseMeta
@@ -421,8 +448,16 @@ func (s *GeminiWebState) persistConversation(modelName string, prep *geminiWebPr
if !ok { if !ok {
return return
} }
stableHash := HashConversation(rec.ClientID, prep.underlying, rec.Messages) label := strings.TrimSpace(s.Label())
accountHash := HashConversation(s.accountID, prep.underlying, rec.Messages) if label == "" {
label = s.accountID
}
conversationMsgs := conversation.StoredToMessages(rec.Messages)
if err := conversation.StoreConversation(label, prep.underlying, conversationMsgs, metadata); err != nil {
log.Debugf("gemini web: failed to persist global conversation index: %v", err)
}
stableHash := conversation.HashConversationForAccount(rec.ClientID, prep.underlying, rec.Messages)
accountHash := conversation.HashConversationForAccount(s.accountID, prep.underlying, rec.Messages)
s.convMu.Lock() s.convMu.Lock()
s.convData[stableHash] = rec s.convData[stableHash] = rec
@@ -493,6 +528,27 @@ func (s *GeminiWebState) useReusableContext() bool {
return s.cfg.GeminiWeb.Context return s.cfg.GeminiWeb.Context
} }
func (s *GeminiWebState) reuseFromPending(modelName string, msgs []RoleText) ([]string, []RoleText) {
match := s.consumePendingMatch()
if match == nil {
return nil, nil
}
if !strings.EqualFold(strings.TrimSpace(match.Model), strings.TrimSpace(modelName)) {
return nil, nil
}
prefixLen := match.Record.PrefixLen
if prefixLen <= 0 || prefixLen > len(msgs) {
return nil, nil
}
metadata := match.Record.Metadata
if len(metadata) == 0 {
return nil, nil
}
remaining := make([]RoleText, len(msgs)-prefixLen)
copy(remaining, msgs[prefixLen:])
return metadata, remaining
}
func (s *GeminiWebState) findReusableSession(modelName string, msgs []RoleText) ([]string, []RoleText) { func (s *GeminiWebState) findReusableSession(modelName string, msgs []RoleText) ([]string, []RoleText) {
s.convMu.RLock() s.convMu.RLock()
items := s.convData items := s.convData
@@ -540,42 +596,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
} }
} }
// Persistence helpers --------------------------------------------------
// Sha256Hex computes the SHA256 hash of a string and returns its hex representation.
func Sha256Hex(s string) string {
sum := sha256.Sum256([]byte(s))
return hex.EncodeToString(sum[:])
}
func ToStoredMessages(msgs []RoleText) []StoredMessage {
out := make([]StoredMessage, 0, len(msgs))
for _, m := range msgs {
out = append(out, StoredMessage{
Role: m.Role,
Content: m.Text,
})
}
return out
}
func HashMessage(m StoredMessage) string {
s := fmt.Sprintf(`{"content":%q,"role":%q}`, m.Content, strings.ToLower(m.Role))
return Sha256Hex(s)
}
func HashConversation(clientID, model string, msgs []StoredMessage) string {
var b strings.Builder
b.WriteString(clientID)
b.WriteString("|")
b.WriteString(model)
for _, m := range msgs {
b.WriteString("|")
b.WriteString(HashMessage(m))
}
return Sha256Hex(b.String())
}
// ConvBoltPath returns the BoltDB file path used for both account metadata and conversation data. // ConvBoltPath returns the BoltDB file path used for both account metadata and conversation data.
// Different logical datasets are kept in separate buckets within this single DB file. // Different logical datasets are kept in separate buckets within this single DB file.
func ConvBoltPath(tokenFilePath string) string { func ConvBoltPath(tokenFilePath string) string {
@@ -790,7 +810,7 @@ func BuildConversationRecord(model, clientID string, history []RoleText, output
Model: model, Model: model,
ClientID: clientID, ClientID: clientID,
Metadata: metadata, Metadata: metadata,
Messages: ToStoredMessages(final), Messages: conversation.ToStoredMessages(final),
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
} }
@@ -800,9 +820,9 @@ func BuildConversationRecord(model, clientID string, history []RoleText, output
// FindByMessageListIn looks up a conversation record by hashed message list. // FindByMessageListIn looks up a conversation record by hashed message list.
// It attempts both the stable client ID and a legacy email-based ID. // It attempts both the stable client ID and a legacy email-based ID.
func FindByMessageListIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, bool) { func FindByMessageListIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, bool) {
stored := ToStoredMessages(msgs) stored := conversation.ToStoredMessages(msgs)
stableHash := HashConversation(stableClientID, model, stored) stableHash := conversation.HashConversationForAccount(stableClientID, model, stored)
fallbackHash := HashConversation(email, model, stored) fallbackHash := conversation.HashConversationForAccount(email, model, stored)
// Try stable hash via index indirection first // Try stable hash via index indirection first
if key, ok := index["hash:"+stableHash]; ok { if key, ok := index["hash:"+stableHash]; ok {

View File

@@ -13,6 +13,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
geminiwebapi "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web" geminiwebapi "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web"
conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -40,12 +41,18 @@ func (e *GeminiWebExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
if err = state.EnsureClient(); err != nil { if err = state.EnsureClient(); err != nil {
return cliproxyexecutor.Response{}, err return cliproxyexecutor.Response{}, err
} }
match := extractGeminiWebMatch(opts.Metadata)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
mutex := state.GetRequestMutex() mutex := state.GetRequestMutex()
if mutex != nil { if mutex != nil {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
if match != nil {
state.SetPendingMatch(match)
}
} else if match != nil {
state.SetPendingMatch(match)
} }
payload := bytes.Clone(req.Payload) payload := bytes.Clone(req.Payload)
@@ -72,11 +79,18 @@ func (e *GeminiWebExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if err = state.EnsureClient(); err != nil { if err = state.EnsureClient(); err != nil {
return nil, err return nil, err
} }
match := extractGeminiWebMatch(opts.Metadata)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
mutex := state.GetRequestMutex() mutex := state.GetRequestMutex()
if mutex != nil { if mutex != nil {
mutex.Lock() mutex.Lock()
if match != nil {
state.SetPendingMatch(match)
}
}
if mutex == nil && match != nil {
state.SetPendingMatch(match)
} }
gemBytes, errMsg, prep := state.Send(ctx, req.Model, bytes.Clone(req.Payload), opts) gemBytes, errMsg, prep := state.Send(ctx, req.Model, bytes.Clone(req.Payload), opts)
@@ -242,3 +256,21 @@ func (e geminiWebError) StatusCode() int {
} }
return e.message.StatusCode return e.message.StatusCode
} }
func extractGeminiWebMatch(metadata map[string]any) *conversation.MatchResult {
if metadata == nil {
return nil
}
value, ok := metadata[conversation.MetadataMatchKey]
if !ok {
return nil
}
switch v := value.(type) {
case *conversation.MatchResult:
return v
case conversation.MatchResult:
return &v
default:
return nil
}
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -48,6 +49,8 @@ type BaseAPIHandler struct {
Cfg *config.SDKConfig Cfg *config.SDKConfig
} }
const geminiWebProvider = "gemini-web"
// NewBaseAPIHandlers creates a new API handlers instance. // NewBaseAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and configuration as input. // It takes a slice of clients and configuration as input.
// //
@@ -137,6 +140,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
if len(providers) == 0 { if len(providers) == 0 {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
} }
metadata := h.buildGeminiWebMetadata(handlerType, providers, rawJSON)
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: modelName, Model: modelName,
Payload: cloneBytes(rawJSON), Payload: cloneBytes(rawJSON),
@@ -146,6 +150,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
Alt: alt, Alt: alt,
OriginalRequest: cloneBytes(rawJSON), OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType), SourceFormat: sdktranslator.FromString(handlerType),
Metadata: metadata,
} }
resp, err := h.AuthManager.Execute(ctx, providers, req, opts) resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil { if err != nil {
@@ -161,6 +166,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
if len(providers) == 0 { if len(providers) == 0 {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
} }
metadata := h.buildGeminiWebMetadata(handlerType, providers, rawJSON)
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: modelName, Model: modelName,
Payload: cloneBytes(rawJSON), Payload: cloneBytes(rawJSON),
@@ -170,6 +176,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
Alt: alt, Alt: alt,
OriginalRequest: cloneBytes(rawJSON), OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType), SourceFormat: sdktranslator.FromString(handlerType),
Metadata: metadata,
} }
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil { if err != nil {
@@ -188,6 +195,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
close(errChan) close(errChan)
return nil, errChan return nil, errChan
} }
metadata := h.buildGeminiWebMetadata(handlerType, providers, rawJSON)
req := coreexecutor.Request{ req := coreexecutor.Request{
Model: modelName, Model: modelName,
Payload: cloneBytes(rawJSON), Payload: cloneBytes(rawJSON),
@@ -197,6 +205,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
Alt: alt, Alt: alt,
OriginalRequest: cloneBytes(rawJSON), OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType), SourceFormat: sdktranslator.FromString(handlerType),
Metadata: metadata,
} }
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil { if err != nil {
@@ -232,6 +241,18 @@ func cloneBytes(src []byte) []byte {
return dst return dst
} }
func (h *BaseAPIHandler) buildGeminiWebMetadata(handlerType string, providers []string, rawJSON []byte) map[string]any {
if !util.InArray(providers, geminiWebProvider) {
return nil
}
meta := make(map[string]any)
msgs := conversation.ExtractMessages(handlerType, rawJSON)
if len(msgs) > 0 {
meta[conversation.MetadataMessagesKey] = msgs
}
return meta
}
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. // WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
status := http.StatusInternalServerError status := http.StatusInternalServerError

View File

@@ -0,0 +1,125 @@
package auth
import (
"context"
"strings"
conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
const (
geminiWebProviderKey = "gemini-web"
)
type geminiWebStickySelector struct {
base Selector
}
func NewGeminiWebStickySelector(base Selector) Selector {
if selector, ok := base.(*geminiWebStickySelector); ok {
return selector
}
if base == nil {
base = &RoundRobinSelector{}
}
return &geminiWebStickySelector{base: base}
}
func (m *Manager) EnableGeminiWebStickySelector() {
if m == nil {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.selector.(*geminiWebStickySelector); ok {
return
}
m.selector = NewGeminiWebStickySelector(m.selector)
}
func (s *geminiWebStickySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
if !strings.EqualFold(provider, geminiWebProviderKey) {
if opts.Metadata != nil {
delete(opts.Metadata, conversation.MetadataMatchKey)
}
return s.base.Pick(ctx, provider, model, opts, auths)
}
messages := extractGeminiWebMessages(opts.Metadata)
if len(messages) >= 2 {
normalizedModel := conversation.NormalizeModel(model)
candidates := conversation.BuildLookupHashes(normalizedModel, messages)
for _, candidate := range candidates {
record, ok, err := conversation.LookupMatch(candidate.Hash)
if err != nil {
log.Warnf("gemini-web selector: lookup failed for hash %s: %v", candidate.Hash, err)
continue
}
if !ok {
continue
}
label := strings.TrimSpace(record.AccountLabel)
if label == "" {
continue
}
auth := findAuthByLabel(auths, label)
if auth != nil {
if opts.Metadata != nil {
opts.Metadata[conversation.MetadataMatchKey] = &conversation.MatchResult{
Hash: candidate.Hash,
Record: record,
Model: normalizedModel,
}
}
return auth, nil
}
_ = conversation.RemoveMatch(candidate.Hash)
}
}
return s.base.Pick(ctx, provider, model, opts, auths)
}
func extractGeminiWebMessages(metadata map[string]any) []conversation.Message {
if metadata == nil {
return nil
}
raw, ok := metadata[conversation.MetadataMessagesKey]
if !ok {
return nil
}
switch v := raw.(type) {
case []conversation.Message:
return v
case *[]conversation.Message:
if v == nil {
return nil
}
return *v
default:
return nil
}
}
func findAuthByLabel(auths []*Auth, label string) *Auth {
if len(auths) == 0 {
return nil
}
normalized := strings.ToLower(strings.TrimSpace(label))
for _, auth := range auths {
if auth == nil {
continue
}
if strings.ToLower(strings.TrimSpace(auth.Label)) == normalized {
return auth
}
if auth.Metadata != nil {
if v, ok := auth.Metadata["label"].(string); ok && strings.ToLower(strings.TrimSpace(v)) == normalized {
return auth
}
}
}
return nil
}

View File

@@ -33,6 +33,8 @@ type Options struct {
OriginalRequest []byte OriginalRequest []byte
// SourceFormat identifies the inbound schema. // SourceFormat identifies the inbound schema.
SourceFormat sdktranslator.Format SourceFormat sdktranslator.Format
// Metadata carries extra execution hints shared across selection and executors.
Metadata map[string]any
} }
// Response wraps either a full provider response or metadata for streaming flows. // Response wraps either a full provider response or metadata for streaming flows.

View File

@@ -15,6 +15,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/api" "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
geminiwebclient "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web" geminiwebclient "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web"
conversation "github.com/router-for-me/CLIProxyAPI/v6/internal/provider/gemini-web/conversation"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
@@ -206,6 +207,14 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
} }
GlobalModelRegistry().UnregisterClient(id) GlobalModelRegistry().UnregisterClient(id)
if existing, ok := s.coreManager.GetByID(id); ok && existing != nil { if existing, ok := s.coreManager.GetByID(id); ok && existing != nil {
if strings.EqualFold(existing.Provider, "gemini-web") {
label := strings.TrimSpace(existing.Label)
if label != "" {
if err := conversation.RemoveMatchesByLabel(label); err != nil {
log.Debugf("failed to remove gemini web sticky entries for %s: %v", label, err)
}
}
}
existing.Disabled = true existing.Disabled = true
existing.Status = coreauth.StatusDisabled existing.Status = coreauth.StatusDisabled
if _, err := s.coreManager.Update(ctx, existing); err != nil { if _, err := s.coreManager.Update(ctx, existing); err != nil {
@@ -225,6 +234,7 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg)) s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
case "gemini-web": case "gemini-web":
s.coreManager.RegisterExecutor(executor.NewGeminiWebExecutor(s.cfg)) s.coreManager.RegisterExecutor(executor.NewGeminiWebExecutor(s.cfg))
s.coreManager.EnableGeminiWebStickySelector()
case "claude": case "claude":
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg)) s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "codex": case "codex":