diff --git a/internal/provider/gemini-web/state.go b/internal/provider/gemini-web/state.go
index cd6c5af5..e9ee8f13 100644
--- a/internal/provider/gemini-web/state.go
+++ b/internal/provider/gemini-web/state.go
@@ -56,6 +56,15 @@ type GeminiWebState struct {
 	pendingMatch *conversation.MatchResult
 }
 
+// reuseComputation describes how an incoming request maps onto a stored
+// conversation: the session metadata to reuse, the stored message history,
+// and how many leading incoming messages that history already covers.
+type reuseComputation struct {
+	metadata []string
+	history  []RoleText
+	overlap  int
+}
+
 func NewGeminiWebState(cfg *config.Config, token *gemini.GeminiWebTokenStorage, storagePath, authLabel string) *GeminiWebState {
 	state := &GeminiWebState{
 		cfg: cfg,
@@ -155,6 +164,93 @@ func (s *GeminiWebState) convPath() string {
 	return ConvBoltPath(base)
 }
 
+// cloneRoleTextSlice returns a copy of in backed by a new array, or nil when
+// in is empty.
+func cloneRoleTextSlice(in []RoleText) []RoleText {
+	if len(in) == 0 {
+		return nil
+	}
+	out := make([]RoleText, len(in))
+	copy(out, in)
+	return out
+}
+
+// cloneStringSlice returns a copy of in backed by a new array, or nil when in
+// is empty.
+func cloneStringSlice(in []string) []string {
+	if len(in) == 0 {
+		return nil
+	}
+	out := make([]string, len(in))
+	copy(out, in)
+	return out
+}
+
+// longestHistoryOverlap returns the largest n such that the last n messages of
+// history equal the first n messages of incoming according to
+// conversation.EqualMessages, or 0 when no such n exists.
+func longestHistoryOverlap(history, incoming []RoleText) int {
+	limit := len(history)
+	if len(incoming) < limit {
+		limit = len(incoming)
+	}
+	for overlap := limit; overlap > 0; overlap-- {
+		if conversation.EqualMessages(history[len(history)-overlap:], incoming[:overlap]) {
+			return overlap
+		}
+	}
+	return 0
+}
+
+// equalStringSlice reports whether a and b have the same length and identical
+// elements in the same order.
+func equalStringSlice(a, b []string) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		if a[i] != b[i] {
+			return false
+		}
+	}
+	return true
+}
+
+// storedMessagesToRoleText converts stored messages to RoleText values,
+// mapping Role to Role and Content to Text. It returns nil for empty input.
+func storedMessagesToRoleText(stored []conversation.StoredMessage) []RoleText {
+	if len(stored) == 0 {
+		return nil
+	}
+	converted := make([]RoleText, len(stored))
+	for i, msg := range stored {
+		converted[i] = RoleText{Role: msg.Role, Text: msg.Content}
+	}
+	return converted
+}
+
+// findConversationByMetadata scans s.convData under the read lock for a record
+// whose model matches model (trimmed, case-insensitive) and whose metadata
+// equals metadata exactly. It returns that record's messages as RoleText and
+// true, or nil and false when metadata is empty or no record matches.
+func (s *GeminiWebState) findConversationByMetadata(model string, metadata []string) ([]RoleText, bool) {
+	if len(metadata) == 0 {
+		return nil, false
+	}
+	s.convMu.RLock()
+	defer s.convMu.RUnlock()
+	for _, rec := range s.convData {
+		if !strings.EqualFold(strings.TrimSpace(rec.Model), strings.TrimSpace(model)) {
+			continue
+		}
+		if !equalStringSlice(rec.Metadata, metadata) {
+			continue
+		}
+		return storedMessagesToRoleText(rec.Messages), true
+	}
+	return nil, false
+}
+
 func (s *GeminiWebState) GetRequestMutex() *sync.Mutex { return &s.reqMu }
 
 func (s *GeminiWebState) EnsureClient() error {
@@ -248,7 +344,7 @@ func (s *GeminiWebState) prepare(ctx context.Context, modelName string, rawJSON
 		return nil, &interfaces.ErrorMessage{StatusCode: 400, Error: fmt.Errorf("bad request: %w", err)}
 	}
 	cleaned := SanitizeAssistantMessages(messages)
-	res.cleaned = cleaned
+	fullCleaned := cloneRoleTextSlice(cleaned)
 	res.underlying = MapAliasToUnderlying(modelName)
 	model, err := ModelFromName(res.underlying)
 	if err != nil {
@@ -261,18 +357,27 @@ func (s *GeminiWebState) prepare(ctx context.Context, modelName string, rawJSON
 	mimesSubset := mimes
 
 	if s.useReusableContext() {
-		reuseMeta, remaining := s.reuseFromPending(res.underlying, cleaned)
-		if len(reuseMeta) == 0 {
-			reuseMeta, remaining = s.findReusableSession(res.underlying, cleaned)
+		reusePlan := s.reuseFromPending(res.underlying, cleaned)
+		if reusePlan == nil {
+			reusePlan = s.findReusableSession(res.underlying, cleaned)
 		}
-		if len(reuseMeta) > 0 {
+		if reusePlan != nil {
 			res.reuse = true
-			meta = reuseMeta
-			if len(remaining) == 1 {
-				useMsgs = []RoleText{remaining[0]}
-			} else if len(remaining) > 1 {
-				useMsgs = remaining
-			} else if len(cleaned) > 0 {
+			meta = cloneStringSlice(reusePlan.metadata)
+			overlap := reusePlan.overlap
+			if overlap > len(cleaned) {
+				overlap = len(cleaned)
+			} else if overlap < 0 {
+				overlap = 0
+			}
+			delta := cloneRoleTextSlice(cleaned[overlap:])
+			if len(reusePlan.history) > 0 {
+				fullCleaned = append(cloneRoleTextSlice(reusePlan.history), delta...)
+			} else {
+				fullCleaned = append(cloneRoleTextSlice(cleaned[:overlap]), delta...)
+			}
+			useMsgs = delta
+			if len(delta) == 0 && len(cleaned) > 0 {
 				useMsgs = []RoleText{cleaned[len(cleaned)-1]}
 			}
 			if len(useMsgs) == 1 && len(messages) > 0 && len(msgFileIdx) == len(messages) {
@@ -330,6 +435,8 @@ func (s *GeminiWebState) prepare(ctx context.Context, modelName string, rawJSON
 		s.convMu.RUnlock()
 	}
 
+	res.cleaned = fullCleaned
+
 	res.tagged = NeedRoleTags(useMsgs)
 	if res.reuse && len(useMsgs) == 1 {
 		res.tagged = false
@@ -533,33 +640,51 @@ func (s *GeminiWebState) useReusableContext() bool {
 	return s.cfg.GeminiWeb.Context
 }
 
-func (s *GeminiWebState) reuseFromPending(modelName string, msgs []RoleText) ([]string, []RoleText) {
+// reuseFromPending builds a reuse plan from the pending match returned by
+// consumePendingMatch. It returns nil when there is no pending match, the
+// match's model differs from modelName, the matched record has no metadata, or
+// no stored conversation carries that metadata.
+func (s *GeminiWebState) reuseFromPending(modelName string, msgs []RoleText) *reuseComputation {
 	match := s.consumePendingMatch()
 	if match == nil {
-		return nil, nil
+		return nil
 	}
 	if !strings.EqualFold(strings.TrimSpace(match.Model), strings.TrimSpace(modelName)) {
-		return nil, nil
+		return nil
 	}
-	prefixLen := match.Record.PrefixLen
-	if prefixLen <= 0 || prefixLen > len(msgs) {
-		return nil, nil
-	}
-	metadata := match.Record.Metadata
+	metadata := cloneStringSlice(match.Record.Metadata)
 	if len(metadata) == 0 {
-		return nil, nil
+		return nil
 	}
-	remaining := make([]RoleText, len(msgs)-prefixLen)
-	copy(remaining, msgs[prefixLen:])
-	return metadata, remaining
+	history, ok := s.findConversationByMetadata(modelName, metadata)
+	if !ok {
+		return nil
+	}
+	overlap := longestHistoryOverlap(history, msgs)
+	return &reuseComputation{metadata: metadata, history: history, overlap: overlap}
 }
 
-func (s *GeminiWebState) findReusableSession(modelName string, msgs []RoleText) ([]string, []RoleText) {
+// findReusableSession looks up a reusable stored conversation for msgs via
+// FindReusableSessionIn and returns a reuse plan for it, or nil when no record
+// matches or the matched record has no stored messages.
+func (s *GeminiWebState) findReusableSession(modelName string, msgs []RoleText) *reuseComputation {
 	s.convMu.RLock()
 	items := s.convData
 	index := s.convIndex
 	s.convMu.RUnlock()
-	return FindReusableSessionIn(items, index, s.stableClientID, s.accountID, modelName, msgs)
+	rec, metadata, overlap, ok := FindReusableSessionIn(items, index, s.stableClientID, s.accountID, modelName, msgs)
+	if !ok {
+		return nil
+	}
+	history := storedMessagesToRoleText(rec.Messages)
+	if len(history) == 0 {
+		return nil
+	}
+	// Ensure overlap reflects the actual history alignment.
+	if computed := longestHistoryOverlap(history, msgs); computed > 0 {
+		overlap = computed
+	}
+	return &reuseComputation{metadata: cloneStringSlice(metadata), history: history, overlap: overlap}
 }
 
 func (s *GeminiWebState) getConfiguredGem() *Gem {
@@ -865,9 +990,12 @@ func FindConversationIn(items map[string]ConversationRecord, index map[string]st
 }
 
-// FindReusableSessionIn returns reusable metadata and the remaining message suffix.
-func FindReusableSessionIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) ([]string, []RoleText) {
+// FindReusableSessionIn searches msgs, longest candidate first, for a stored
+// conversation whose last message is from the assistant or system role. On
+// success it returns the matched record, its metadata, the number of leading
+// messages of msgs covered by the match, and true; otherwise zero values and false.
+func FindReusableSessionIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, []string, int, bool) {
 	if len(msgs) < 2 {
-		return nil, nil
+		return ConversationRecord{}, nil, 0, false
 	}
 	searchEnd := len(msgs)
 	for searchEnd >= 2 {
@@ -875,11 +1003,10 @@ func FindReusableSessionIn(items map[string]ConversationRecord, index map[string
 		tail := sub[len(sub)-1]
 		if strings.EqualFold(tail.Role, "assistant") || strings.EqualFold(tail.Role, "system") {
 			if rec, ok := FindConversationIn(items, index, stableClientID, email, model, sub); ok {
-				remain := msgs[searchEnd:]
-				return rec.Metadata, remain
+				return rec, rec.Metadata, searchEnd, true
 			}
 		}
 		searchEnd--
 	}
-	return nil, nil
+	return ConversationRecord{}, nil, 0, false
 }