mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 20:30:51 +08:00
feat(gemini-web): Add conversation affinity selector
This commit is contained in:
80
internal/provider/gemini-web/conversation/alias.go
Normal file
80
internal/provider/gemini-web/conversation/alias.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
aliasOnce sync.Once
|
||||
aliasMap map[string]string
|
||||
)
|
||||
|
||||
// EnsureGeminiWebAliasMap populates the alias map once.
|
||||
func EnsureGeminiWebAliasMap() {
|
||||
aliasOnce.Do(func() {
|
||||
aliasMap = make(map[string]string)
|
||||
for _, m := range registry.GetGeminiModels() {
|
||||
if m.ID == "gemini-2.5-flash-lite" {
|
||||
continue
|
||||
}
|
||||
if m.ID == "gemini-2.5-flash" {
|
||||
aliasMap["gemini-2.5-flash-image-preview"] = "gemini-2.5-flash"
|
||||
}
|
||||
alias := AliasFromModelID(m.ID)
|
||||
aliasMap[strings.ToLower(alias)] = strings.ToLower(m.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// MapAliasToUnderlying normalizes a model alias to its underlying identifier.
|
||||
func MapAliasToUnderlying(name string) string {
|
||||
EnsureGeminiWebAliasMap()
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
if n == "" {
|
||||
return n
|
||||
}
|
||||
if u, ok := aliasMap[n]; ok {
|
||||
return u
|
||||
}
|
||||
const suffix = "-web"
|
||||
if strings.HasSuffix(n, suffix) {
|
||||
return strings.TrimSuffix(n, suffix)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// AliasFromModelID mirrors the original helper for deriving alias IDs.
|
||||
func AliasFromModelID(modelID string) string {
|
||||
return modelID + "-web"
|
||||
}
|
||||
|
||||
// NormalizeModel returns the canonical identifier used for hashing.
|
||||
func NormalizeModel(model string) string {
|
||||
return MapAliasToUnderlying(model)
|
||||
}
|
||||
|
||||
// GetGeminiWebAliasedModels returns alias metadata for registry exposure.
|
||||
func GetGeminiWebAliasedModels() []*registry.ModelInfo {
|
||||
EnsureGeminiWebAliasMap()
|
||||
aliased := make([]*registry.ModelInfo, 0)
|
||||
for _, m := range registry.GetGeminiModels() {
|
||||
if m.ID == "gemini-2.5-flash-lite" {
|
||||
continue
|
||||
} else if m.ID == "gemini-2.5-flash" {
|
||||
cpy := *m
|
||||
cpy.ID = "gemini-2.5-flash-image-preview"
|
||||
cpy.Name = "gemini-2.5-flash-image-preview"
|
||||
cpy.DisplayName = "Nano Banana"
|
||||
cpy.Description = "Gemini 2.5 Flash Preview Image"
|
||||
aliased = append(aliased, &cpy)
|
||||
}
|
||||
cpy := *m
|
||||
cpy.ID = AliasFromModelID(m.ID)
|
||||
cpy.Name = cpy.ID
|
||||
aliased = append(aliased, &cpy)
|
||||
}
|
||||
return aliased
|
||||
}
|
||||
74
internal/provider/gemini-web/conversation/hash.go
Normal file
74
internal/provider/gemini-web/conversation/hash.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Message represents a minimal role-text pair used for hashing and comparison.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// StoredMessage mirrors the persisted conversation message structure.
|
||||
type StoredMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
// Sha256Hex computes SHA-256 hex digest for the specified string.
|
||||
func Sha256Hex(s string) string {
|
||||
sum := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// ToStoredMessages converts in-memory messages into the persisted representation.
|
||||
func ToStoredMessages(msgs []Message) []StoredMessage {
|
||||
out := make([]StoredMessage, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, StoredMessage{Role: m.Role, Content: m.Text})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// StoredToMessages converts stored messages back into the in-memory representation.
|
||||
func StoredToMessages(msgs []StoredMessage) []Message {
|
||||
out := make([]Message, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
out = append(out, Message{Role: m.Role, Text: m.Content})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// hashMessage normalizes message data and returns a stable digest.
|
||||
func hashMessage(m StoredMessage) string {
|
||||
s := fmt.Sprintf(`{"content":%q,"role":%q}`, m.Content, strings.ToLower(m.Role))
|
||||
return Sha256Hex(s)
|
||||
}
|
||||
|
||||
// HashConversationWithPrefix computes a conversation hash using the provided prefix (client identifier) and model.
|
||||
func HashConversationWithPrefix(prefix, model string, msgs []StoredMessage) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(strings.ToLower(strings.TrimSpace(prefix)))
|
||||
b.WriteString("|")
|
||||
b.WriteString(strings.ToLower(strings.TrimSpace(model)))
|
||||
for _, m := range msgs {
|
||||
b.WriteString("|")
|
||||
b.WriteString(hashMessage(m))
|
||||
}
|
||||
return Sha256Hex(b.String())
|
||||
}
|
||||
|
||||
// HashConversationForAccount keeps compatibility with the per-account hash previously used.
|
||||
func HashConversationForAccount(clientID, model string, msgs []StoredMessage) string {
|
||||
return HashConversationWithPrefix(clientID, model, msgs)
|
||||
}
|
||||
|
||||
// HashConversationGlobal produces a hash suitable for cross-account lookups.
|
||||
func HashConversationGlobal(model string, msgs []StoredMessage) string {
|
||||
return HashConversationWithPrefix("global", model, msgs)
|
||||
}
|
||||
187
internal/provider/gemini-web/conversation/index.go
Normal file
187
internal/provider/gemini-web/conversation/index.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
bucketMatches = "matches"
|
||||
defaultIndexFile = "gemini-web-index.bolt"
|
||||
)
|
||||
|
||||
// MatchRecord stores persisted mapping metadata for a conversation prefix.
|
||||
type MatchRecord struct {
|
||||
AccountLabel string `json:"account_label"`
|
||||
Metadata []string `json:"metadata,omitempty"`
|
||||
PrefixLen int `json:"prefix_len"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MatchResult combines a persisted record with the hash that produced it.
|
||||
type MatchResult struct {
|
||||
Hash string
|
||||
Record MatchRecord
|
||||
Model string
|
||||
}
|
||||
|
||||
var (
|
||||
indexOnce sync.Once
|
||||
indexDB *bolt.DB
|
||||
indexErr error
|
||||
)
|
||||
|
||||
func openIndex() (*bolt.DB, error) {
|
||||
indexOnce.Do(func() {
|
||||
path := indexPath()
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
indexErr = err
|
||||
return
|
||||
}
|
||||
db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 2 * time.Second})
|
||||
if err != nil {
|
||||
indexErr = err
|
||||
return
|
||||
}
|
||||
indexDB = db
|
||||
})
|
||||
return indexDB, indexErr
|
||||
}
|
||||
|
||||
func indexPath() string {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil || wd == "" {
|
||||
wd = "."
|
||||
}
|
||||
return filepath.Join(wd, "conv", defaultIndexFile)
|
||||
}
|
||||
|
||||
// StoreMatch persists or updates a conversation hash mapping.
|
||||
func StoreMatch(hash string, record MatchRecord) error {
|
||||
if strings.TrimSpace(hash) == "" {
|
||||
return errors.New("gemini-web conversation: empty hash")
|
||||
}
|
||||
db, err := openIndex()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record.UpdatedAt = time.Now().UTC().Unix()
|
||||
payload, err := json.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Update(func(tx *bolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists([]byte(bucketMatches))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bucket.Put([]byte(hash), payload)
|
||||
})
|
||||
}
|
||||
|
||||
// LookupMatch retrieves a stored mapping.
|
||||
func LookupMatch(hash string) (MatchRecord, bool, error) {
|
||||
db, err := openIndex()
|
||||
if err != nil {
|
||||
return MatchRecord{}, false, err
|
||||
}
|
||||
var record MatchRecord
|
||||
err = db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketMatches))
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
raw := bucket.Get([]byte(hash))
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(raw, &record)
|
||||
})
|
||||
if err != nil {
|
||||
return MatchRecord{}, false, err
|
||||
}
|
||||
if record.AccountLabel == "" || record.PrefixLen <= 0 {
|
||||
return MatchRecord{}, false, nil
|
||||
}
|
||||
return record, true, nil
|
||||
}
|
||||
|
||||
// RemoveMatch deletes a mapping for the given hash.
|
||||
func RemoveMatch(hash string) error {
|
||||
db, err := openIndex()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Update(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketMatches))
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
return bucket.Delete([]byte(hash))
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveMatchesByLabel removes all entries associated with the specified label.
|
||||
func RemoveMatchesByLabel(label string) error {
|
||||
label = strings.TrimSpace(label)
|
||||
if label == "" {
|
||||
return nil
|
||||
}
|
||||
db, err := openIndex()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Update(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(bucketMatches))
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
cursor := bucket.Cursor()
|
||||
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
|
||||
if len(v) == 0 {
|
||||
continue
|
||||
}
|
||||
var record MatchRecord
|
||||
if err := json.Unmarshal(v, &record); err != nil {
|
||||
_ = bucket.Delete(k)
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(record.AccountLabel), label) {
|
||||
if err := bucket.Delete(k); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// StoreConversation updates all hashes representing the provided conversation snapshot.
|
||||
func StoreConversation(label, model string, msgs []Message, metadata []string) error {
|
||||
label = strings.TrimSpace(label)
|
||||
if label == "" || len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
hashes := BuildStorageHashes(model, msgs)
|
||||
if len(hashes) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, h := range hashes {
|
||||
rec := MatchRecord{
|
||||
AccountLabel: label,
|
||||
Metadata: append([]string(nil), metadata...),
|
||||
PrefixLen: h.PrefixLen,
|
||||
}
|
||||
if err := StoreMatch(h.Hash, rec); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
40
internal/provider/gemini-web/conversation/lookup.go
Normal file
40
internal/provider/gemini-web/conversation/lookup.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package conversation
|
||||
|
||||
import "strings"
|
||||
|
||||
// PrefixHash represents a hash candidate for a specific prefix length.
|
||||
type PrefixHash struct {
|
||||
Hash string
|
||||
PrefixLen int
|
||||
}
|
||||
|
||||
// BuildLookupHashes generates hash candidates ordered from longest to shortest prefix.
|
||||
func BuildLookupHashes(model string, msgs []Message) []PrefixHash {
|
||||
if len(msgs) < 2 {
|
||||
return nil
|
||||
}
|
||||
model = NormalizeModel(model)
|
||||
sanitized := SanitizeAssistantMessages(msgs)
|
||||
result := make([]PrefixHash, 0, len(sanitized))
|
||||
for end := len(sanitized); end >= 2; end-- {
|
||||
tailRole := strings.ToLower(strings.TrimSpace(sanitized[end-1].Role))
|
||||
if tailRole != "assistant" && tailRole != "system" {
|
||||
continue
|
||||
}
|
||||
prefix := sanitized[:end]
|
||||
hash := HashConversationGlobal(model, ToStoredMessages(prefix))
|
||||
result = append(result, PrefixHash{Hash: hash, PrefixLen: end})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildStorageHashes returns hashes representing the full conversation snapshot.
|
||||
func BuildStorageHashes(model string, msgs []Message) []PrefixHash {
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
model = NormalizeModel(model)
|
||||
sanitized := SanitizeAssistantMessages(msgs)
|
||||
hash := HashConversationGlobal(model, ToStoredMessages(sanitized))
|
||||
return []PrefixHash{{Hash: hash, PrefixLen: len(sanitized)}}
|
||||
}
|
||||
6
internal/provider/gemini-web/conversation/metadata.go
Normal file
6
internal/provider/gemini-web/conversation/metadata.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package conversation
|
||||
|
||||
const (
|
||||
MetadataMessagesKey = "gemini_web_messages"
|
||||
MetadataMatchKey = "gemini_web_match"
|
||||
)
|
||||
102
internal/provider/gemini-web/conversation/parse.go
Normal file
102
internal/provider/gemini-web/conversation/parse.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ExtractMessages attempts to build a message list from the inbound request payload.
|
||||
func ExtractMessages(handlerType string, raw []byte) []Message {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
if msgs := extractOpenAIStyle(raw); len(msgs) > 0 {
|
||||
return msgs
|
||||
}
|
||||
if msgs := extractGeminiContents(raw); len(msgs) > 0 {
|
||||
return msgs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractOpenAIStyle(raw []byte) []Message {
|
||||
root := gjson.ParseBytes(raw)
|
||||
messages := root.Get("messages")
|
||||
if !messages.Exists() {
|
||||
return nil
|
||||
}
|
||||
out := make([]Message, 0, 8)
|
||||
messages.ForEach(func(_, entry gjson.Result) bool {
|
||||
role := strings.ToLower(strings.TrimSpace(entry.Get("role").String()))
|
||||
if role == "" {
|
||||
return true
|
||||
}
|
||||
if role == "system" {
|
||||
return true
|
||||
}
|
||||
var contentBuilder strings.Builder
|
||||
content := entry.Get("content")
|
||||
if !content.Exists() {
|
||||
out = append(out, Message{Role: role, Text: ""})
|
||||
return true
|
||||
}
|
||||
switch content.Type {
|
||||
case gjson.String:
|
||||
contentBuilder.WriteString(content.String())
|
||||
case gjson.JSON:
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
if text := part.Get("text"); text.Exists() {
|
||||
if contentBuilder.Len() > 0 {
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
contentBuilder.WriteString(text.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
out = append(out, Message{Role: role, Text: contentBuilder.String()})
|
||||
return true
|
||||
})
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractGeminiContents(raw []byte) []Message {
|
||||
contents := gjson.GetBytes(raw, "contents")
|
||||
if !contents.Exists() {
|
||||
return nil
|
||||
}
|
||||
out := make([]Message, 0, 8)
|
||||
contents.ForEach(func(_, entry gjson.Result) bool {
|
||||
role := strings.TrimSpace(entry.Get("role").String())
|
||||
if role == "" {
|
||||
role = "user"
|
||||
} else {
|
||||
role = strings.ToLower(role)
|
||||
if role == "model" {
|
||||
role = "assistant"
|
||||
}
|
||||
}
|
||||
var builder strings.Builder
|
||||
entry.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||
if text := part.Get("text"); text.Exists() {
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString(text.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
out = append(out, Message{Role: role, Text: builder.String()})
|
||||
return true
|
||||
})
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
39
internal/provider/gemini-web/conversation/sanitize.go
Normal file
39
internal/provider/gemini-web/conversation/sanitize.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var reThink = regexp.MustCompile(`(?is)<think>.*?</think>`)
|
||||
|
||||
// RemoveThinkTags strips <think>...</think> blocks and trims whitespace.
|
||||
func RemoveThinkTags(s string) string {
|
||||
return strings.TrimSpace(reThink.ReplaceAllString(s, ""))
|
||||
}
|
||||
|
||||
// SanitizeAssistantMessages removes think tags from assistant messages while leaving others untouched.
|
||||
func SanitizeAssistantMessages(msgs []Message) []Message {
|
||||
out := make([]Message, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
if strings.EqualFold(strings.TrimSpace(m.Role), "assistant") {
|
||||
out = append(out, Message{Role: m.Role, Text: RemoveThinkTags(m.Text)})
|
||||
continue
|
||||
}
|
||||
out = append(out, m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// EqualMessages compares two message slices for equality.
|
||||
func EqualMessages(a, b []Message) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i].Role != b[i].Role || a[i].Text != b[i].Text {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user