diff --git a/internal/provider/gemini-web/auth.go b/internal/provider/gemini-web/auth.go deleted file mode 100644 index c10f76ee..00000000 --- a/internal/provider/gemini-web/auth.go +++ /dev/null @@ -1,240 +0,0 @@ -package geminiwebapi - -import ( - "crypto/tls" - "errors" - "io" - "net/http" - "net/http/cookiejar" - "net/url" - "os" - "path/filepath" - "regexp" - "strings" - "time" - - log "github.com/sirupsen/logrus" -) - -type httpOptions struct { - ProxyURL string - Insecure bool - FollowRedirects bool -} - -func newHTTPClient(opts httpOptions) *http.Client { - transport := &http.Transport{} - if opts.ProxyURL != "" { - if pu, err := url.Parse(opts.ProxyURL); err == nil { - transport.Proxy = http.ProxyURL(pu) - } - } - if opts.Insecure { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - jar, _ := cookiejar.New(nil) - client := &http.Client{Transport: transport, Timeout: 60 * time.Second, Jar: jar} - if !opts.FollowRedirects { - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - } - return client -} - -func applyHeaders(req *http.Request, headers http.Header) { - for k, v := range headers { - for _, vv := range v { - req.Header.Add(k, vv) - } - } -} - -func applyCookies(req *http.Request, cookies map[string]string) { - for k, v := range cookies { - req.AddCookie(&http.Cookie{Name: k, Value: v}) - } -} - -func sendInitRequest(cookies map[string]string, proxy string, insecure bool) (*http.Response, map[string]string, error) { - client := newHTTPClient(httpOptions{ProxyURL: proxy, Insecure: insecure, FollowRedirects: true}) - req, _ := http.NewRequest(http.MethodGet, EndpointInit, nil) - applyHeaders(req, HeadersGemini) - applyCookies(req, cookies) - resp, err := client.Do(req) - if err != nil { - return nil, nil, err - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return resp, nil, &AuthError{Msg: resp.Status} - } - outCookies := map[string]string{} - for _, c := range resp.Cookies() { - outCookies[c.Name] = c.Value - } - for k, v := range cookies { - outCookies[k] = v - } - return resp, outCookies, nil -} - -func getAccessToken(baseCookies map[string]string, proxy string, verbose bool, insecure bool) (string, map[string]string, error) { - // Warm-up google.com to gain extra cookies (NID, etc.) and capture them. - extraCookies := map[string]string{} - { - client := newHTTPClient(httpOptions{ProxyURL: proxy, Insecure: insecure, FollowRedirects: true}) - req, _ := http.NewRequest(http.MethodGet, EndpointGoogle, nil) - resp, _ := client.Do(req) - if resp != nil { - if u, err := url.Parse(EndpointGoogle); err == nil { - for _, c := range client.Jar.Cookies(u) { - extraCookies[c.Name] = c.Value - } - } - _ = resp.Body.Close() - } - } - - trySets := make([]map[string]string, 0, 8) - - if v1, ok1 := baseCookies["__Secure-1PSID"]; ok1 { - if v2, ok2 := baseCookies["__Secure-1PSIDTS"]; ok2 { - merged := map[string]string{"__Secure-1PSID": v1, "__Secure-1PSIDTS": v2} - if nid, ok := baseCookies["NID"]; ok { - merged["NID"] = nid - } - trySets = append(trySets, merged) - } else if verbose { - log.Debug("Skipping base cookies: __Secure-1PSIDTS missing") - } - } - - cacheDir := "temp" - _ = os.MkdirAll(cacheDir, 0o755) - if v1, ok1 := baseCookies["__Secure-1PSID"]; ok1 { - cacheFile := filepath.Join(cacheDir, ".cached_1psidts_"+v1+".txt") - if b, err := os.ReadFile(cacheFile); err == nil { - cv := strings.TrimSpace(string(b)) - if cv != "" { - merged := map[string]string{"__Secure-1PSID": v1, "__Secure-1PSIDTS": cv} - trySets = append(trySets, merged) - } - } - } - - if len(extraCookies) > 0 { - trySets = append(trySets, extraCookies) - } - - reToken := regexp.MustCompile(`"SNlM0e":"([^"]+)"`) - - for _, cookies := range trySets { - resp, mergedCookies, err := sendInitRequest(cookies, proxy, insecure) - if err != nil { - if verbose { - log.Warnf("Failed init request: %v", err) - } - continue - } - body, err := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if err != nil { - return "", nil, err - } - matches := reToken.FindStringSubmatch(string(body)) - if len(matches) >= 2 { - token := matches[1] - if verbose { - log.Infof("Gemini access token acquired.") - } - return token, mergedCookies, nil - } - } - return "", nil, &AuthError{Msg: "Failed to retrieve token."} -} - -// rotate1PSIDTS refreshes __Secure-1PSIDTS -func rotate1PSIDTS(cookies map[string]string, proxy string, insecure bool) (string, error) { - _, ok := cookies["__Secure-1PSID"] - if !ok { - return "", &AuthError{Msg: "__Secure-1PSID missing"} - } - - tr := &http.Transport{} - if proxy != "" { - if pu, err := url.Parse(proxy); err == nil { - tr.Proxy = http.ProxyURL(pu) - } - } - if insecure { - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - client := &http.Client{Transport: tr, Timeout: 60 * time.Second} - - req, _ := http.NewRequest(http.MethodPost, EndpointRotateCookies, io.NopCloser(stringsReader("[000,\"-0000000000000000000\"]"))) - applyHeaders(req, HeadersRotateCookies) - applyCookies(req, cookies) - - resp, err := client.Do(req) - if err != nil { - return "", err - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode == http.StatusUnauthorized { - return "", &AuthError{Msg: "unauthorized"} - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return "", errors.New(resp.Status) - } - - for _, c := range resp.Cookies() { - if c.Name == "__Secure-1PSIDTS" { - return c.Value, nil - } - } - return "", nil -} - -// Minimal reader helpers to avoid importing strings everywhere. -type constReader struct { - s string - i int -} - -func (r *constReader) Read(p []byte) (int, error) { - if r.i >= len(r.s) { - return 0, io.EOF - } - n := copy(p, r.s[r.i:]) - r.i += n - return n, nil -} - -func stringsReader(s string) io.Reader { return &constReader{s: s} } - -func MaskToken28(s string) string { - n := len(s) - if n == 0 { - return "" - } - if n < 20 { - return strings.Repeat("*", n) - } - midStart := n/2 - 2 - if midStart < 8 { - midStart = 8 - } - if midStart+4 > n-8 { - midStart = n - 8 - 4 - if midStart < 8 { - midStart = 8 - } - } - prefixByte := s[:8] - middle := s[midStart : midStart+4] - suffix := s[n-8:] - return prefixByte + strings.Repeat("*", 4) + middle + strings.Repeat("*", 4) + suffix -} diff --git a/internal/provider/gemini-web/client.go b/internal/provider/gemini-web/client.go index 829f21ee..396a9dc9 100644 --- a/internal/provider/gemini-web/client.go +++ b/internal/provider/gemini-web/client.go @@ -1,12 +1,16 @@ package geminiwebapi import ( + "crypto/tls" "encoding/json" "errors" "fmt" "io" "net/http" + "net/http/cookiejar" "net/url" + "os" + "path/filepath" "regexp" "strings" "time" @@ -25,6 +29,227 @@ type GeminiClient struct { insecure bool } +// HTTP bootstrap utilities ------------------------------------------------- +type httpOptions struct { + ProxyURL string + Insecure bool + FollowRedirects bool +} + +func newHTTPClient(opts httpOptions) *http.Client { + transport := &http.Transport{} + if opts.ProxyURL != "" { + if pu, err := url.Parse(opts.ProxyURL); err == nil { + transport.Proxy = http.ProxyURL(pu) + } + } + if opts.Insecure { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + jar, _ := cookiejar.New(nil) + client := &http.Client{Transport: transport, Timeout: 60 * time.Second, Jar: jar} + if !opts.FollowRedirects { + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + return client +} + +func applyHeaders(req *http.Request, headers http.Header) { + for k, v := range headers { + for _, vv := range v { + req.Header.Add(k, vv) + } + } +} + +func applyCookies(req *http.Request, cookies map[string]string) { + for k, v := range cookies { + req.AddCookie(&http.Cookie{Name: k, Value: v}) + } +} + +func sendInitRequest(cookies map[string]string, proxy string, insecure bool) (*http.Response, map[string]string, error) { + client := newHTTPClient(httpOptions{ProxyURL: proxy, Insecure: insecure, FollowRedirects: true}) + req, _ := http.NewRequest(http.MethodGet, EndpointInit, nil) + applyHeaders(req, HeadersGemini) + applyCookies(req, cookies) + resp, err := client.Do(req) + if err != nil { + return nil, nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return resp, nil, &AuthError{Msg: resp.Status} + } + outCookies := map[string]string{} + for _, c := range resp.Cookies() { + outCookies[c.Name] = c.Value + } + for k, v := range cookies { + outCookies[k] = v + } + return resp, outCookies, nil +} + +func getAccessToken(baseCookies map[string]string, proxy string, verbose bool, insecure bool) (string, map[string]string, error) { + extraCookies := map[string]string{} + { + client := newHTTPClient(httpOptions{ProxyURL: proxy, Insecure: insecure, FollowRedirects: true}) + req, _ := http.NewRequest(http.MethodGet, EndpointGoogle, nil) + resp, _ := client.Do(req) + if resp != nil { + if u, err := url.Parse(EndpointGoogle); err == nil { + for _, c := range client.Jar.Cookies(u) { + extraCookies[c.Name] = c.Value + } + } + _ = resp.Body.Close() + } + } + + trySets := make([]map[string]string, 0, 8) + + if v1, ok1 := baseCookies["__Secure-1PSID"]; ok1 { + if v2, ok2 := baseCookies["__Secure-1PSIDTS"]; ok2 { + merged := map[string]string{"__Secure-1PSID": v1, "__Secure-1PSIDTS": v2} + if nid, ok := baseCookies["NID"]; ok { + merged["NID"] = nid + } + trySets = append(trySets, merged) + } else if verbose { + log.Debug("Skipping base cookies: __Secure-1PSIDTS missing") + } + } + + cacheDir := "temp" + _ = os.MkdirAll(cacheDir, 0o755) + if v1, ok1 := baseCookies["__Secure-1PSID"]; ok1 { + cacheFile := filepath.Join(cacheDir, ".cached_1psidts_"+v1+".txt") + if b, err := os.ReadFile(cacheFile); err == nil { + cv := strings.TrimSpace(string(b)) + if cv != "" { + merged := map[string]string{"__Secure-1PSID": v1, "__Secure-1PSIDTS": cv} + trySets = append(trySets, merged) + } + } + } + + if len(extraCookies) > 0 { + trySets = append(trySets, extraCookies) + } + + reToken := regexp.MustCompile(`"SNlM0e":"([^"]+)"`) + + for _, cookies := range trySets { + resp, mergedCookies, err := sendInitRequest(cookies, proxy, insecure) + if err != nil { + if verbose { + log.Warnf("Failed init request: %v", err) + } + continue + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + return "", nil, err + } + matches := reToken.FindStringSubmatch(string(body)) + if len(matches) >= 2 { + token := matches[1] + if verbose { + log.Infof("Gemini access token acquired.") + } + return token, mergedCookies, nil + } + } + return "", nil, &AuthError{Msg: "Failed to retrieve token."} +} + +func rotate1PSIDTS(cookies map[string]string, proxy string, insecure bool) (string, error) { + _, ok := cookies["__Secure-1PSID"] + if !ok { + return "", &AuthError{Msg: "__Secure-1PSID missing"} + } + + tr := &http.Transport{} + if proxy != "" { + if pu, err := url.Parse(proxy); err == nil { + tr.Proxy = http.ProxyURL(pu) + } + } + if insecure { + tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + client := &http.Client{Transport: tr, Timeout: 60 * time.Second} + + req, _ := http.NewRequest(http.MethodPost, EndpointRotateCookies, io.NopCloser(stringsReader("[000,\"-0000000000000000000\"]"))) + applyHeaders(req, HeadersRotateCookies) + applyCookies(req, cookies) + + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode == http.StatusUnauthorized { + return "", &AuthError{Msg: "unauthorized"} + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", errors.New(resp.Status) + } + + for _, c := range resp.Cookies() { + if c.Name == "__Secure-1PSIDTS" { + return c.Value, nil + } + } + return "", nil +} + +type constReader struct { + s string + i int +} + +func (r *constReader) Read(p []byte) (int, error) { + if r.i >= len(r.s) { + return 0, io.EOF + } + n := copy(p, r.s[r.i:]) + r.i += n + return n, nil +} + +func stringsReader(s string) io.Reader { return &constReader{s: s} } + +func MaskToken28(s string) string { + n := len(s) + if n == 0 { + return "" + } + if n < 20 { + return strings.Repeat("*", n) + } + midStart := n/2 - 2 + if midStart < 8 { + midStart = 8 + } + if midStart+4 > n-8 { + midStart = n - 8 - 4 + if midStart < 8 { + midStart = 8 + } + } + prefixByte := s[:8] + middle := s[midStart : midStart+4] + suffix := s[n-8:] + return prefixByte + strings.Repeat("*", 4) + middle + strings.Repeat("*", 4) + suffix +} + var NanoBananaModel = map[string]struct{}{ "gemini-2.5-flash-image-preview": {}, } diff --git a/internal/provider/gemini-web/convert_ext.go b/internal/provider/gemini-web/convert_ext.go deleted file mode 100644 index db5a1e50..00000000 --- a/internal/provider/gemini-web/convert_ext.go +++ /dev/null @@ -1,178 +0,0 @@ -package geminiwebapi - -import ( - "bytes" - "encoding/json" - "fmt" - "math" - "regexp" - "strings" - "time" - "unicode/utf8" -) - -var ( - reGoogle = regexp.MustCompile("(\\()?\\[`([^`]+?)`\\]\\(https://www\\.google\\.com/search\\?q=[^)]*\\)(\\))?") - reColonNum = regexp.MustCompile(`([^:]+:\d+)`) - reInline = regexp.MustCompile("`(\\[[^\\]]+\\]\\([^\\)]+\\))`") -) - -func unescapeGeminiText(s string) string { - if s == "" { - return s - } - s = strings.ReplaceAll(s, "<", "<") - s = strings.ReplaceAll(s, "\\<", "<") - s = strings.ReplaceAll(s, "\\_", "_") - s = strings.ReplaceAll(s, "\\>", ">") - return s -} - -func postProcessModelText(text string) string { - text = reGoogle.ReplaceAllStringFunc(text, func(m string) string { - subs := reGoogle.FindStringSubmatch(m) - if len(subs) < 4 { - return m - } - outerOpen := subs[1] - display := subs[2] - target := display - if loc := reColonNum.FindString(display); loc != "" { - target = loc - } - newSeg := "[`" + display + "`](" + target + ")" - if outerOpen != "" { - return "(" + newSeg + ")" - } - return newSeg - }) - text = reInline.ReplaceAllString(text, "$1") - return text -} - -func estimateTokens(s string) int { - if s == "" { - return 0 - } - rc := float64(utf8.RuneCountInString(s)) - if rc <= 0 { - return 0 - } - est := int(math.Ceil(rc / 4.0)) - if est < 0 { - return 0 - } - return est -} - -// ConvertOutputToGemini converts simplified ModelOutput to Gemini API-like JSON. -// promptText is used only to estimate usage tokens to populate usage fields. -func ConvertOutputToGemini(output *ModelOutput, modelName string, promptText string) ([]byte, error) { - if output == nil || len(output.Candidates) == 0 { - return nil, fmt.Errorf("empty output") - } - - parts := make([]map[string]any, 0, 2) - - var thoughtsText string - if output.Candidates[0].Thoughts != nil { - if t := strings.TrimSpace(*output.Candidates[0].Thoughts); t != "" { - thoughtsText = unescapeGeminiText(t) - parts = append(parts, map[string]any{ - "text": thoughtsText, - "thought": true, - }) - } - } - - visible := unescapeGeminiText(output.Candidates[0].Text) - finalText := postProcessModelText(visible) - if finalText != "" { - parts = append(parts, map[string]any{"text": finalText}) - } - - if imgs := output.Candidates[0].GeneratedImages; len(imgs) > 0 { - for _, gi := range imgs { - if mime, data, err := FetchGeneratedImageData(gi); err == nil && data != "" { - parts = append(parts, map[string]any{ - "inlineData": map[string]any{ - "mimeType": mime, - "data": data, - }, - }) - } - } - } - - promptTokens := estimateTokens(promptText) - completionTokens := estimateTokens(finalText) - thoughtsTokens := 0 - if thoughtsText != "" { - thoughtsTokens = estimateTokens(thoughtsText) - } - totalTokens := promptTokens + completionTokens - - now := time.Now() - resp := map[string]any{ - "candidates": []any{ - map[string]any{ - "content": map[string]any{ - "parts": parts, - "role": "model", - }, - "finishReason": "stop", - "index": 0, - }, - }, - "createTime": now.Format(time.RFC3339Nano), - "responseId": fmt.Sprintf("gemini-web-%d", now.UnixNano()), - "modelVersion": modelName, - "usageMetadata": map[string]any{ - "promptTokenCount": promptTokens, - "candidatesTokenCount": completionTokens, - "thoughtsTokenCount": thoughtsTokens, - "totalTokenCount": totalTokens, - }, - } - b, err := json.Marshal(resp) - if err != nil { - return nil, fmt.Errorf("failed to marshal gemini response: %w", err) - } - return ensureColonSpacing(b), nil -} - -// ensureColonSpacing inserts a single space after JSON key-value colons while -// leaving string content untouched. This matches the relaxed formatting used by -// Gemini responses and keeps downstream text-processing tools compatible with -// the proxy output. -func ensureColonSpacing(b []byte) []byte { - if len(b) == 0 { - return b - } - var out bytes.Buffer - out.Grow(len(b) + len(b)/8) - inString := false - escaped := false - for i := 0; i < len(b); i++ { - ch := b[i] - out.WriteByte(ch) - if escaped { - escaped = false - continue - } - switch ch { - case '\\': - escaped = true - case '"': - inString = !inString - case ':': - if !inString && i+1 < len(b) { - next := b[i+1] - if next != ' ' && next != '\n' && next != '\r' && next != '\t' { - out.WriteByte(' ') - } - } - } - } - return out.Bytes() -} diff --git a/internal/provider/gemini-web/errors.go b/internal/provider/gemini-web/errors.go deleted file mode 100644 index 6341b696..00000000 --- a/internal/provider/gemini-web/errors.go +++ /dev/null @@ -1,47 +0,0 @@ -package geminiwebapi - -type AuthError struct{ Msg string } - -func (e *AuthError) Error() string { - if e.Msg == "" { - return "authentication error" - } - return e.Msg -} - -type APIError struct{ Msg string } - -func (e *APIError) Error() string { - if e.Msg == "" { - return "api error" - } - return e.Msg -} - -type ImageGenerationError struct{ APIError } - -type GeminiError struct{ Msg string } - -func (e *GeminiError) Error() string { - if e.Msg == "" { - return "gemini error" - } - return e.Msg -} - -type TimeoutError struct{ GeminiError } - -type UsageLimitExceeded struct{ GeminiError } - -type ModelInvalid struct{ GeminiError } - -type TemporarilyBlocked struct{ GeminiError } - -type ValueError struct{ Msg string } - -func (e *ValueError) Error() string { - if e.Msg == "" { - return "value error" - } - return e.Msg -} diff --git a/internal/provider/gemini-web/media.go b/internal/provider/gemini-web/media.go index 3c843c62..c21bc262 100644 --- a/internal/provider/gemini-web/media.go +++ b/internal/provider/gemini-web/media.go @@ -4,9 +4,11 @@ import ( "bytes" "crypto/tls" "encoding/base64" + "encoding/json" "errors" "fmt" "io" + "math" "mime/multipart" "net/http" "net/http/cookiejar" @@ -17,6 +19,7 @@ import ( "sort" "strings" "time" + "unicode/utf8" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" @@ -393,3 +396,171 @@ func parseFileName(path string) (string, error) { } return filepath.Base(path), nil } + +// Response formatting helpers ---------------------------------------------- + +var ( + reGoogle = regexp.MustCompile("(\\()?\\[`([^`]+?)`\\]\\(https://www\\.google\\.com/search\\?q=[^)]*\\)(\\))?") + reColonNum = regexp.MustCompile(`([^:]+:\d+)`) + reInline = regexp.MustCompile("`(\\[[^\\]]+\\]\\([^\\)]+\\))`") +) + +func unescapeGeminiText(s string) string { + if s == "" { + return s + } + s = strings.ReplaceAll(s, "<", "<") + s = strings.ReplaceAll(s, "\\<", "<") + s = strings.ReplaceAll(s, "\\_", "_") + s = strings.ReplaceAll(s, "\\>", ">") + return s +} + +func postProcessModelText(text string) string { + text = reGoogle.ReplaceAllStringFunc(text, func(m string) string { + subs := reGoogle.FindStringSubmatch(m) + if len(subs) < 4 { + return m + } + outerOpen := subs[1] + display := subs[2] + target := display + if loc := reColonNum.FindString(display); loc != "" { + target = loc + } + newSeg := "[`" + display + "`](" + target + ")" + if outerOpen != "" { + return "(" + newSeg + ")" + } + return newSeg + }) + text = reInline.ReplaceAllString(text, "$1") + return text +} + +func estimateTokens(s string) int { + if s == "" { + return 0 + } + rc := float64(utf8.RuneCountInString(s)) + if rc <= 0 { + return 0 + } + est := int(math.Ceil(rc / 4.0)) + if est < 0 { + return 0 + } + return est +} + +// ConvertOutputToGemini converts simplified ModelOutput to Gemini API-like JSON. +// promptText is used only to estimate usage tokens to populate usage fields. +func ConvertOutputToGemini(output *ModelOutput, modelName string, promptText string) ([]byte, error) { + if output == nil || len(output.Candidates) == 0 { + return nil, fmt.Errorf("empty output") + } + + parts := make([]map[string]any, 0, 2) + + var thoughtsText string + if output.Candidates[0].Thoughts != nil { + if t := strings.TrimSpace(*output.Candidates[0].Thoughts); t != "" { + thoughtsText = unescapeGeminiText(t) + parts = append(parts, map[string]any{ + "text": thoughtsText, + "thought": true, + }) + } + } + + visible := unescapeGeminiText(output.Candidates[0].Text) + finalText := postProcessModelText(visible) + if finalText != "" { + parts = append(parts, map[string]any{"text": finalText}) + } + + if imgs := output.Candidates[0].GeneratedImages; len(imgs) > 0 { + for _, gi := range imgs { + if mime, data, err := FetchGeneratedImageData(gi); err == nil && data != "" { + parts = append(parts, map[string]any{ + "inlineData": map[string]any{ + "mimeType": mime, + "data": data, + }, + }) + } + } + } + + promptTokens := estimateTokens(promptText) + completionTokens := estimateTokens(finalText) + thoughtsTokens := 0 + if thoughtsText != "" { + thoughtsTokens = estimateTokens(thoughtsText) + } + totalTokens := promptTokens + completionTokens + + now := time.Now() + resp := map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{ + "parts": parts, + "role": "model", + }, + "finishReason": "stop", + "index": 0, + }, + }, + "createTime": now.Format(time.RFC3339Nano), + "responseId": fmt.Sprintf("gemini-web-%d", now.UnixNano()), + "modelVersion": modelName, + "usageMetadata": map[string]any{ + "promptTokenCount": promptTokens, + "candidatesTokenCount": completionTokens, + "thoughtsTokenCount": thoughtsTokens, + "totalTokenCount": totalTokens, + }, + } + b, err := json.Marshal(resp) + if err != nil { + return nil, fmt.Errorf("failed to marshal gemini response: %w", err) + } + return ensureColonSpacing(b), nil +} + +// ensureColonSpacing inserts a single space after JSON key-value colons while +// leaving string content untouched. This matches the relaxed formatting used by +// Gemini responses and keeps downstream text-processing tools compatible with +// the proxy output. +func ensureColonSpacing(b []byte) []byte { + if len(b) == 0 { + return b + } + var out bytes.Buffer + out.Grow(len(b) + len(b)/8) + inString := false + escaped := false + for i := 0; i < len(b); i++ { + ch := b[i] + out.WriteByte(ch) + if escaped { + escaped = false + continue + } + switch ch { + case '\\': + escaped = true + case '"': + inString = !inString + case ':': + if !inString && i+1 < len(b) { + next := b[i+1] + if next != ' ' && next != '\n' && next != '\r' && next != '\t' { + out.WriteByte(' ') + } + } + } + } + return out.Bytes() +} diff --git a/internal/provider/gemini-web/models.go b/internal/provider/gemini-web/models.go index 2d4f3f0c..c4cb29e8 100644 --- a/internal/provider/gemini-web/models.go +++ b/internal/provider/gemini-web/models.go @@ -1,14 +1,17 @@ package geminiwebapi import ( + "fmt" + "html" "net/http" "strings" "sync" + "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" ) -// Endpoints used by the Gemini web app +// Gemini web endpoints and default headers ---------------------------------- const ( EndpointGoogle = "https://www.google.com" EndpointInit = "https://gemini.google.com/app" @@ -17,7 +20,6 @@ const ( EndpointUpload = "https://content-push.googleapis.com/upload" ) -// Default headers var ( HeadersGemini = http.Header{ "Content-Type": []string{"application/x-www-form-urlencoded;charset=utf-8"}, @@ -35,7 +37,7 @@ var ( } ) -// Model defines available model names and headers +// Model metadata ------------------------------------------------------------- type Model struct { Name string ModelHeader http.Header @@ -62,14 +64,14 @@ var ( }, AdvancedOnly: false, } - ModelG20Flash = Model{ // Deprecated, still supported + ModelG20Flash = Model{ Name: "gemini-2.0-flash", ModelHeader: http.Header{ "x-goog-ext-525001261-jspb": []string{"[1,null,null,null,\"f299729663a2343f\"]"}, }, AdvancedOnly: false, } - ModelG20FlashThinking = Model{ // Deprecated, still supported + ModelG20FlashThinking = Model{ Name: "gemini-2.0-flash-thinking", ModelHeader: http.Header{ "x-goog-ext-525001261-jspb": []string{"[null,null,null,null,\"7ca48d02d802f20a\"]"}, @@ -78,7 +80,6 @@ var ( } ) -// ModelFromName returns a model by name or error if not found func ModelFromName(name string) (Model, error) { switch name { case ModelUnspecified.Name: @@ -96,7 +97,7 @@ func ModelFromName(name string) (Model, error) { } } -// Known error codes returned from server +// Known error codes returned from the server. const ( ErrorUsageLimitExceeded = 1037 ErrorModelInconsistent = 1050 @@ -109,7 +110,6 @@ var ( GeminiWebAliasMap map[string]string ) -// EnsureGeminiWebAliasMap initializes alias lookup lazily. func EnsureGeminiWebAliasMap() { GeminiWebAliasOnce.Do(func() { GeminiWebAliasMap = make(map[string]string) @@ -125,7 +125,6 @@ func EnsureGeminiWebAliasMap() { }) } -// GetGeminiWebAliasedModels returns Gemini models exposed with web aliases. func GetGeminiWebAliasedModels() []*registry.ModelInfo { EnsureGeminiWebAliasMap() aliased := make([]*registry.ModelInfo, 0) @@ -148,7 +147,6 @@ func GetGeminiWebAliasedModels() []*registry.ModelInfo { return aliased } -// MapAliasToUnderlying normalizes web aliases back to canonical Gemini IDs. func MapAliasToUnderlying(name string) string { EnsureGeminiWebAliasMap() n := strings.ToLower(name) @@ -162,7 +160,151 @@ func MapAliasToUnderlying(name string) string { return name } -// AliasFromModelID builds the web alias for a Gemini model identifier. func AliasFromModelID(modelID string) string { return modelID + "-web" } + +// Conversation domain structures ------------------------------------------- +type RoleText struct { + Role string + Text string +} + +type StoredMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` +} + +type ConversationRecord struct { + Model string `json:"model"` + ClientID string `json:"client_id"` + Metadata []string `json:"metadata,omitempty"` + Messages []StoredMessage `json:"messages"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type Candidate struct { + RCID string + Text string + Thoughts *string + WebImages []WebImage + GeneratedImages []GeneratedImage +} + +func (c Candidate) String() string { + t := c.Text + if len(t) > 20 { + t = t[:20] + "..." + } + return fmt.Sprintf("Candidate(rcid='%s', text='%s', images=%d)", c.RCID, t, len(c.WebImages)+len(c.GeneratedImages)) +} + +func (c Candidate) Images() []Image { + images := make([]Image, 0, len(c.WebImages)+len(c.GeneratedImages)) + for _, wi := range c.WebImages { + images = append(images, wi.Image) + } + for _, gi := range c.GeneratedImages { + images = append(images, gi.Image) + } + return images +} + +type ModelOutput struct { + Metadata []string + Candidates []Candidate + Chosen int +} + +func (m ModelOutput) String() string { return m.Text() } + +func (m ModelOutput) Text() string { + if len(m.Candidates) == 0 { + return "" + } + return m.Candidates[m.Chosen].Text +} + +func (m ModelOutput) Thoughts() *string { + if len(m.Candidates) == 0 { + return nil + } + return m.Candidates[m.Chosen].Thoughts +} + +func (m ModelOutput) Images() []Image { + if len(m.Candidates) == 0 { + return nil + } + return m.Candidates[m.Chosen].Images() +} + +func (m ModelOutput) RCID() string { + if len(m.Candidates) == 0 { + return "" + } + return m.Candidates[m.Chosen].RCID +} + +type Gem struct { + ID string + Name string + Description *string + Prompt *string + Predefined bool +} + +func (g Gem) String() string { + return fmt.Sprintf("Gem(id='%s', name='%s', description='%v', prompt='%v', predefined=%v)", g.ID, g.Name, g.Description, g.Prompt, g.Predefined) +} + +func decodeHTML(s string) string { return html.UnescapeString(s) } + +// Error hierarchy ----------------------------------------------------------- +type AuthError struct{ Msg string } + +func (e *AuthError) Error() string { + if e.Msg == "" { + return "authentication error" + } + return e.Msg +} + +type APIError struct{ Msg string } + +func (e *APIError) Error() string { + if e.Msg == "" { + return "api error" + } + return e.Msg +} + +type ImageGenerationError struct{ APIError } + +type GeminiError struct{ Msg string } + +func (e *GeminiError) Error() string { + if e.Msg == "" { + return "gemini error" + } + return e.Msg +} + +type TimeoutError struct{ GeminiError } + +type UsageLimitExceeded struct{ GeminiError } + +type ModelInvalid struct{ GeminiError } + +type TemporarilyBlocked struct{ GeminiError } + +type ValueError struct{ Msg string } + +func (e *ValueError) Error() string { + if e.Msg == "" { + return "value error" + } + return e.Msg +} diff --git a/internal/provider/gemini-web/persistence.go b/internal/provider/gemini-web/persistence.go deleted file mode 100644 index 59e14ddf..00000000 --- a/internal/provider/gemini-web/persistence.go +++ /dev/null @@ -1,364 +0,0 @@ -package geminiwebapi - -import ( - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - bolt "go.etcd.io/bbolt" -) - -// StoredMessage represents a single message in a conversation record. -type StoredMessage struct { - Role string `json:"role"` - Content string `json:"content"` - Name string `json:"name,omitempty"` -} - -// ConversationRecord stores a full conversation with its metadata for persistence. -type ConversationRecord struct { - Model string `json:"model"` - ClientID string `json:"client_id"` - Metadata []string `json:"metadata,omitempty"` - Messages []StoredMessage `json:"messages"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// Sha256Hex computes the SHA256 hash of a string and returns its hex representation. -func Sha256Hex(s string) string { - sum := sha256.Sum256([]byte(s)) - return hex.EncodeToString(sum[:]) -} - -// RoleText represents a turn in a conversation with a role and text content. -type RoleText struct { - Role string - Text string -} - -func ToStoredMessages(msgs []RoleText) []StoredMessage { - out := make([]StoredMessage, 0, len(msgs)) - for _, m := range msgs { - out = append(out, StoredMessage{ - Role: m.Role, - Content: m.Text, - }) - } - return out -} - -func HashMessage(m StoredMessage) string { - s := fmt.Sprintf(`{"content":%q,"role":%q}`, m.Content, strings.ToLower(m.Role)) - return Sha256Hex(s) -} - -func HashConversation(clientID, model string, msgs []StoredMessage) string { - var b strings.Builder - b.WriteString(clientID) - b.WriteString("|") - b.WriteString(model) - for _, m := range msgs { - b.WriteString("|") - b.WriteString(HashMessage(m)) - } - return Sha256Hex(b.String()) -} - -// ConvStorePath returns the path for account-level metadata persistence based on token file path. -func ConvStorePath(tokenFilePath string) string { - wd, err := os.Getwd() - if err != nil || wd == "" { - wd = "." - } - convDir := filepath.Join(wd, "conv") - base := strings.TrimSuffix(filepath.Base(tokenFilePath), filepath.Ext(tokenFilePath)) - return filepath.Join(convDir, base+".bolt") -} - -// ConvDataPath returns the path for full conversation persistence based on token file path. -func ConvDataPath(tokenFilePath string) string { - wd, err := os.Getwd() - if err != nil || wd == "" { - wd = "." - } - convDir := filepath.Join(wd, "conv") - base := strings.TrimSuffix(filepath.Base(tokenFilePath), filepath.Ext(tokenFilePath)) - return filepath.Join(convDir, base+".bolt") -} - -// LoadConvStore reads the account-level metadata store from disk. -func LoadConvStore(path string) (map[string][]string, error) { - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return nil, err - } - db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: time.Second}) - if err != nil { - return nil, err - } - defer func() { - _ = db.Close() - }() - out := map[string][]string{} - err = db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("account_meta")) - if b == nil { - return nil - } - return b.ForEach(func(k, v []byte) error { - var arr []string - if len(v) > 0 { - if e := json.Unmarshal(v, &arr); e != nil { - // Skip malformed entries instead of failing the whole load - return nil - } - } - out[string(k)] = arr - return nil - }) - }) - if err != nil { - return nil, err - } - return out, nil -} - -// SaveConvStore writes the account-level metadata store to disk atomically. -func SaveConvStore(path string, data map[string][]string) error { - if data == nil { - data = map[string][]string{} - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 2 * time.Second}) - if err != nil { - return err - } - defer func() { - _ = db.Close() - }() - return db.Update(func(tx *bolt.Tx) error { - // Recreate bucket to reflect the given snapshot exactly. - if b := tx.Bucket([]byte("account_meta")); b != nil { - if err = tx.DeleteBucket([]byte("account_meta")); err != nil { - return err - } - } - b, errCreateBucket := tx.CreateBucket([]byte("account_meta")) - if errCreateBucket != nil { - return errCreateBucket - } - for k, v := range data { - enc, e := json.Marshal(v) - if e != nil { - return e - } - if e = b.Put([]byte(k), enc); e != nil { - return e - } - } - return nil - }) -} - -// AccountMetaKey builds the key for account-level metadata map. -func AccountMetaKey(email, modelName string) string { - return fmt.Sprintf("account-meta|%s|%s", email, modelName) -} - -// LoadConvData reads the full conversation data and index from disk. -func LoadConvData(path string) (map[string]ConversationRecord, map[string]string, error) { - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return nil, nil, err - } - db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: time.Second}) - if err != nil { - return nil, nil, err - } - defer func() { - _ = db.Close() - }() - items := map[string]ConversationRecord{} - index := map[string]string{} - err = db.View(func(tx *bolt.Tx) error { - // Load conv_items - if b := tx.Bucket([]byte("conv_items")); b != nil { - if e := b.ForEach(func(k, v []byte) error { - var rec ConversationRecord - if len(v) > 0 { - if e2 := json.Unmarshal(v, &rec); e2 != nil { - // Skip malformed - return nil - } - items[string(k)] = rec - } - return nil - }); e != nil { - return e - } - } - // Load conv_index - if b := tx.Bucket([]byte("conv_index")); b != nil { - if e := b.ForEach(func(k, v []byte) error { - index[string(k)] = string(v) - return nil - }); e != nil { - return e - } - } - return nil - }) - if err != nil { - return nil, nil, err - } - return items, index, nil -} - -// SaveConvData writes the full conversation data and index to disk atomically. -func SaveConvData(path string, items map[string]ConversationRecord, index map[string]string) error { - if items == nil { - items = map[string]ConversationRecord{} - } - if index == nil { - index = map[string]string{} - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 2 * time.Second}) - if err != nil { - return err - } - defer func() { - _ = db.Close() - }() - return db.Update(func(tx *bolt.Tx) error { - // Recreate items bucket - if b := tx.Bucket([]byte("conv_items")); b != nil { - if err = tx.DeleteBucket([]byte("conv_items")); err != nil { - return err - } - } - bi, errCreateBucket := tx.CreateBucket([]byte("conv_items")) - if errCreateBucket != nil { - return errCreateBucket - } - for k, rec := range items { - enc, e := json.Marshal(rec) - if e != nil { - return e - } - if e = bi.Put([]byte(k), enc); e != nil { - return e - } - } - - // Recreate index bucket - if b := tx.Bucket([]byte("conv_index")); b != nil { - if err = tx.DeleteBucket([]byte("conv_index")); err != nil { - return err - } - } - bx, errCreateBucket := tx.CreateBucket([]byte("conv_index")) - if errCreateBucket != nil { - return errCreateBucket - } - for k, v := range index { - if e := bx.Put([]byte(k), []byte(v)); e != nil { - return e - } - } - return nil - }) -} - -// BuildConversationRecord constructs a ConversationRecord from history and the latest output. -// Returns false when output is empty or has no candidates. -func BuildConversationRecord(model, clientID string, history []RoleText, output *ModelOutput, metadata []string) (ConversationRecord, bool) { - if output == nil || len(output.Candidates) == 0 { - return ConversationRecord{}, false - } - text := "" - if t := output.Candidates[0].Text; t != "" { - text = RemoveThinkTags(t) - } - final := append([]RoleText{}, history...) - final = append(final, RoleText{Role: "assistant", Text: text}) - rec := ConversationRecord{ - Model: model, - ClientID: clientID, - Metadata: metadata, - Messages: ToStoredMessages(final), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - return rec, true -} - -// FindByMessageListIn looks up a conversation record by hashed message list. -// It attempts both the stable client ID and a legacy email-based ID. -func FindByMessageListIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, bool) { - stored := ToStoredMessages(msgs) - stableHash := HashConversation(stableClientID, model, stored) - fallbackHash := HashConversation(email, model, stored) - - // Try stable hash via index indirection first - if key, ok := index["hash:"+stableHash]; ok { - if rec, ok2 := items[key]; ok2 { - return rec, true - } - } - if rec, ok := items[stableHash]; ok { - return rec, true - } - // Fallback to legacy hash (email-based) - if key, ok := index["hash:"+fallbackHash]; ok { - if rec, ok2 := items[key]; ok2 { - return rec, true - } - } - if rec, ok := items[fallbackHash]; ok { - return rec, true - } - return ConversationRecord{}, false -} - -// FindConversationIn tries exact then sanitized assistant messages. -func FindConversationIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, bool) { - if len(msgs) == 0 { - return ConversationRecord{}, false - } - if rec, ok := FindByMessageListIn(items, index, stableClientID, email, model, msgs); ok { - return rec, true - } - if rec, ok := FindByMessageListIn(items, index, stableClientID, email, model, SanitizeAssistantMessages(msgs)); ok { - return rec, true - } - return ConversationRecord{}, false -} - -// FindReusableSessionIn returns reusable metadata and the remaining message suffix. -func FindReusableSessionIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) ([]string, []RoleText) { - if len(msgs) < 2 { - return nil, nil - } - searchEnd := len(msgs) - for searchEnd >= 2 { - sub := msgs[:searchEnd] - tail := sub[len(sub)-1] - if strings.EqualFold(tail.Role, "assistant") || strings.EqualFold(tail.Role, "system") { - if rec, ok := FindConversationIn(items, index, stableClientID, email, model, sub); ok { - remain := msgs[searchEnd:] - return rec.Metadata, remain - } - } - searchEnd-- - } - return nil, nil -} diff --git a/internal/provider/gemini-web/prompt.go b/internal/provider/gemini-web/prompt.go index 50760b36..1f9cd8be 100644 --- a/internal/provider/gemini-web/prompt.go +++ b/internal/provider/gemini-web/prompt.go @@ -1,11 +1,13 @@ package geminiwebapi import ( + "fmt" "math" "regexp" "strings" "unicode/utf8" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/tidwall/gjson" ) @@ -128,3 +130,98 @@ func EstimateTotalTokensFromRawJSON(rawJSON []byte) int { } return int(math.Ceil(float64(totalChars) / 4.0)) } + +// Request chunking helpers ------------------------------------------------ + +const continuationHint = "\n(More messages to come, please reply with just 'ok.')" + +func ChunkByRunes(s string, size int) []string { + if size <= 0 { + return []string{s} + } + chunks := make([]string, 0, (len(s)/size)+1) + var buf strings.Builder + count := 0 + for _, r := range s { + buf.WriteRune(r) + count++ + if count >= size { + chunks = append(chunks, buf.String()) + buf.Reset() + count = 0 + } + } + if buf.Len() > 0 { + chunks = append(chunks, buf.String()) + } + if len(chunks) == 0 { + return []string{""} + } + return chunks +} + +func MaxCharsPerRequest(cfg *config.Config) int { + // Read max characters per request from config with a conservative default. + if cfg != nil { + if v := cfg.GeminiWeb.MaxCharsPerRequest; v > 0 { + return v + } + } + return 1_000_000 +} + +func SendWithSplit(chat *ChatSession, text string, files []string, cfg *config.Config) (ModelOutput, error) { + // Validate chat session + if chat == nil { + return ModelOutput{}, fmt.Errorf("nil chat session") + } + + // Resolve maxChars characters per request + maxChars := MaxCharsPerRequest(cfg) + if maxChars <= 0 { + maxChars = 1_000_000 + } + + // If within limit, send directly + if utf8.RuneCountInString(text) <= maxChars { + return chat.SendMessage(text, files) + } + + // Decide whether to use continuation hint (enabled by default) + useHint := true + if cfg != nil && cfg.GeminiWeb.DisableContinuationHint { + useHint = false + } + + // Compute chunk size in runes. If the hint does not fit, disable it for this request. + hintLen := 0 + if useHint { + hintLen = utf8.RuneCountInString(continuationHint) + } + chunkSize := maxChars - hintLen + if chunkSize <= 0 { + // maxChars is too small to accommodate the hint; fall back to no-hint splitting + useHint = false + chunkSize = maxChars + } + + // Split into rune-safe chunks + chunks := ChunkByRunes(text, chunkSize) + if len(chunks) == 0 { + chunks = []string{""} + } + + // Send all but the last chunk without files, optionally appending hint + for i := 0; i < len(chunks)-1; i++ { + part := chunks[i] + if useHint { + part += continuationHint + } + if _, err := chat.SendMessage(part, nil); err != nil { + return ModelOutput{}, err + } + } + + // Send final chunk with files and return the actual output + return chat.SendMessage(chunks[len(chunks)-1], files) +} diff --git a/internal/provider/gemini-web/request.go b/internal/provider/gemini-web/request.go deleted file mode 100644 index 2e9a4830..00000000 --- a/internal/provider/gemini-web/request.go +++ /dev/null @@ -1,102 +0,0 @@ -package geminiwebapi - -import ( - "fmt" - "strings" - "unicode/utf8" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" -) - -const continuationHint = "\n(More messages to come, please reply with just 'ok.')" - -func ChunkByRunes(s string, size int) []string { - if size <= 0 { - return []string{s} - } - chunks := make([]string, 0, (len(s)/size)+1) - var buf strings.Builder - count := 0 - for _, r := range s { - buf.WriteRune(r) - count++ - if count >= size { - chunks = append(chunks, buf.String()) - buf.Reset() - count = 0 - } - } - if buf.Len() > 0 { - chunks = append(chunks, buf.String()) - } - if len(chunks) == 0 { - return []string{""} - } - return chunks -} - -func MaxCharsPerRequest(cfg *config.Config) int { - // Read max characters per request from config with a conservative default. - if cfg != nil { - if v := cfg.GeminiWeb.MaxCharsPerRequest; v > 0 { - return v - } - } - return 1_000_000 -} - -func SendWithSplit(chat *ChatSession, text string, files []string, cfg *config.Config) (ModelOutput, error) { - // Validate chat session - if chat == nil { - return ModelOutput{}, fmt.Errorf("nil chat session") - } - - // Resolve maxChars characters per request - maxChars := MaxCharsPerRequest(cfg) - if maxChars <= 0 { - maxChars = 1_000_000 - } - - // If within limit, send directly - if utf8.RuneCountInString(text) <= maxChars { - return chat.SendMessage(text, files) - } - - // Decide whether to use continuation hint (enabled by default) - useHint := true - if cfg != nil && cfg.GeminiWeb.DisableContinuationHint { - useHint = false - } - - // Compute chunk size in runes. If the hint does not fit, disable it for this request. - hintLen := 0 - if useHint { - hintLen = utf8.RuneCountInString(continuationHint) - } - chunkSize := maxChars - hintLen - if chunkSize <= 0 { - // maxChars is too small to accommodate the hint; fall back to no-hint splitting - useHint = false - chunkSize = maxChars - } - - // Split into rune-safe chunks - chunks := ChunkByRunes(text, chunkSize) - if len(chunks) == 0 { - chunks = []string{""} - } - - // Send all but the last chunk without files, optionally appending hint - for i := 0; i < len(chunks)-1; i++ { - part := chunks[i] - if useHint { - part += continuationHint - } - if _, err := chat.SendMessage(part, nil); err != nil { - return ModelOutput{}, err - } - } - - // Send final chunk with files and return the actual output - return chat.SendMessage(chunks[len(chunks)-1], files) -} diff --git a/internal/provider/gemini-web/state.go b/internal/provider/gemini-web/state.go index aed61b74..4442dad7 100644 --- a/internal/provider/gemini-web/state.go +++ b/internal/provider/gemini-web/state.go @@ -3,8 +3,12 @@ package geminiwebapi import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" "fmt" + "os" "path/filepath" "strings" "sync" @@ -19,6 +23,7 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + bolt "go.etcd.io/bbolt" ) const ( @@ -512,3 +517,332 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt ginCtx.Set("API_RESPONSE", data) } } + +// Persistence helpers -------------------------------------------------- + +// Sha256Hex computes the SHA256 hash of a string and returns its hex representation. +func Sha256Hex(s string) string { + sum := sha256.Sum256([]byte(s)) + return hex.EncodeToString(sum[:]) +} + +func ToStoredMessages(msgs []RoleText) []StoredMessage { + out := make([]StoredMessage, 0, len(msgs)) + for _, m := range msgs { + out = append(out, StoredMessage{ + Role: m.Role, + Content: m.Text, + }) + } + return out +} + +func HashMessage(m StoredMessage) string { + s := fmt.Sprintf(`{"content":%q,"role":%q}`, m.Content, strings.ToLower(m.Role)) + return Sha256Hex(s) +} + +func HashConversation(clientID, model string, msgs []StoredMessage) string { + var b strings.Builder + b.WriteString(clientID) + b.WriteString("|") + b.WriteString(model) + for _, m := range msgs { + b.WriteString("|") + b.WriteString(HashMessage(m)) + } + return Sha256Hex(b.String()) +} + +// ConvStorePath returns the path for account-level metadata persistence based on token file path. +func ConvStorePath(tokenFilePath string) string { + wd, err := os.Getwd() + if err != nil || wd == "" { + wd = "." + } + convDir := filepath.Join(wd, "conv") + base := strings.TrimSuffix(filepath.Base(tokenFilePath), filepath.Ext(tokenFilePath)) + return filepath.Join(convDir, base+".bolt") +} + +// ConvDataPath returns the path for full conversation persistence based on token file path. +func ConvDataPath(tokenFilePath string) string { + wd, err := os.Getwd() + if err != nil || wd == "" { + wd = "." + } + convDir := filepath.Join(wd, "conv") + base := strings.TrimSuffix(filepath.Base(tokenFilePath), filepath.Ext(tokenFilePath)) + return filepath.Join(convDir, base+".bolt") +} + +// LoadConvStore reads the account-level metadata store from disk. +func LoadConvStore(path string) (map[string][]string, error) { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return nil, err + } + db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: time.Second}) + if err != nil { + return nil, err + } + defer func() { + _ = db.Close() + }() + out := map[string][]string{} + err = db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("account_meta")) + if b == nil { + return nil + } + return b.ForEach(func(k, v []byte) error { + var arr []string + if len(v) > 0 { + if e := json.Unmarshal(v, &arr); e != nil { + // Skip malformed entries instead of failing the whole load + return nil + } + } + out[string(k)] = arr + return nil + }) + }) + if err != nil { + return nil, err + } + return out, nil +} + +// SaveConvStore writes the account-level metadata store to disk atomically. +func SaveConvStore(path string, data map[string][]string) error { + if data == nil { + data = map[string][]string{} + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 2 * time.Second}) + if err != nil { + return err + } + defer func() { + _ = db.Close() + }() + return db.Update(func(tx *bolt.Tx) error { + // Recreate bucket to reflect the given snapshot exactly. + if b := tx.Bucket([]byte("account_meta")); b != nil { + if err = tx.DeleteBucket([]byte("account_meta")); err != nil { + return err + } + } + b, errCreateBucket := tx.CreateBucket([]byte("account_meta")) + if errCreateBucket != nil { + return errCreateBucket + } + for k, v := range data { + enc, e := json.Marshal(v) + if e != nil { + return e + } + if e = b.Put([]byte(k), enc); e != nil { + return e + } + } + return nil + }) +} + +// AccountMetaKey builds the key for account-level metadata map. +func AccountMetaKey(email, modelName string) string { + return fmt.Sprintf("account-meta|%s|%s", email, modelName) +} + +// LoadConvData reads the full conversation data and index from disk. +func LoadConvData(path string) (map[string]ConversationRecord, map[string]string, error) { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return nil, nil, err + } + db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: time.Second}) + if err != nil { + return nil, nil, err + } + defer func() { + _ = db.Close() + }() + items := map[string]ConversationRecord{} + index := map[string]string{} + err = db.View(func(tx *bolt.Tx) error { + // Load conv_items + if b := tx.Bucket([]byte("conv_items")); b != nil { + if e := b.ForEach(func(k, v []byte) error { + var rec ConversationRecord + if len(v) > 0 { + if e2 := json.Unmarshal(v, &rec); e2 != nil { + // Skip malformed + return nil + } + items[string(k)] = rec + } + return nil + }); e != nil { + return e + } + } + // Load conv_index + if b := tx.Bucket([]byte("conv_index")); b != nil { + if e := b.ForEach(func(k, v []byte) error { + index[string(k)] = string(v) + return nil + }); e != nil { + return e + } + } + return nil + }) + if err != nil { + return nil, nil, err + } + return items, index, nil +} + +// SaveConvData writes the full conversation data and index to disk atomically. +func SaveConvData(path string, items map[string]ConversationRecord, index map[string]string) error { + if items == nil { + items = map[string]ConversationRecord{} + } + if index == nil { + index = map[string]string{} + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + db, err := bolt.Open(path, 0o600, &bolt.Options{Timeout: 2 * time.Second}) + if err != nil { + return err + } + defer func() { + _ = db.Close() + }() + return db.Update(func(tx *bolt.Tx) error { + // Recreate items bucket + if b := tx.Bucket([]byte("conv_items")); b != nil { + if err = tx.DeleteBucket([]byte("conv_items")); err != nil { + return err + } + } + bi, errCreateBucket := tx.CreateBucket([]byte("conv_items")) + if errCreateBucket != nil { + return errCreateBucket + } + for k, rec := range items { + enc, e := json.Marshal(rec) + if e != nil { + return e + } + if e = bi.Put([]byte(k), enc); e != nil { + return e + } + } + + // Recreate index bucket + if b := tx.Bucket([]byte("conv_index")); b != nil { + if err = tx.DeleteBucket([]byte("conv_index")); err != nil { + return err + } + } + bx, errCreateBucket := tx.CreateBucket([]byte("conv_index")) + if errCreateBucket != nil { + return errCreateBucket + } + for k, v := range index { + if e := bx.Put([]byte(k), []byte(v)); e != nil { + return e + } + } + return nil + }) +} + +// BuildConversationRecord constructs a ConversationRecord from history and the latest output. +// Returns false when output is empty or has no candidates. +func BuildConversationRecord(model, clientID string, history []RoleText, output *ModelOutput, metadata []string) (ConversationRecord, bool) { + if output == nil || len(output.Candidates) == 0 { + return ConversationRecord{}, false + } + text := "" + if t := output.Candidates[0].Text; t != "" { + text = RemoveThinkTags(t) + } + final := append([]RoleText{}, history...) + final = append(final, RoleText{Role: "assistant", Text: text}) + rec := ConversationRecord{ + Model: model, + ClientID: clientID, + Metadata: metadata, + Messages: ToStoredMessages(final), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + return rec, true +} + +// FindByMessageListIn looks up a conversation record by hashed message list. +// It attempts both the stable client ID and a legacy email-based ID. +func FindByMessageListIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, bool) { + stored := ToStoredMessages(msgs) + stableHash := HashConversation(stableClientID, model, stored) + fallbackHash := HashConversation(email, model, stored) + + // Try stable hash via index indirection first + if key, ok := index["hash:"+stableHash]; ok { + if rec, ok2 := items[key]; ok2 { + return rec, true + } + } + if rec, ok := items[stableHash]; ok { + return rec, true + } + // Fallback to legacy hash (email-based) + if key, ok := index["hash:"+fallbackHash]; ok { + if rec, ok2 := items[key]; ok2 { + return rec, true + } + } + if rec, ok := items[fallbackHash]; ok { + return rec, true + } + return ConversationRecord{}, false +} + +// FindConversationIn tries exact then sanitized assistant messages. +func FindConversationIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) (ConversationRecord, bool) { + if len(msgs) == 0 { + return ConversationRecord{}, false + } + if rec, ok := FindByMessageListIn(items, index, stableClientID, email, model, msgs); ok { + return rec, true + } + if rec, ok := FindByMessageListIn(items, index, stableClientID, email, model, SanitizeAssistantMessages(msgs)); ok { + return rec, true + } + return ConversationRecord{}, false +} + +// FindReusableSessionIn returns reusable metadata and the remaining message suffix. +func FindReusableSessionIn(items map[string]ConversationRecord, index map[string]string, stableClientID, email, model string, msgs []RoleText) ([]string, []RoleText) { + if len(msgs) < 2 { + return nil, nil + } + searchEnd := len(msgs) + for searchEnd >= 2 { + sub := msgs[:searchEnd] + tail := sub[len(sub)-1] + if strings.EqualFold(tail.Role, "assistant") || strings.EqualFold(tail.Role, "system") { + if rec, ok := FindConversationIn(items, index, stableClientID, email, model, sub); ok { + remain := msgs[searchEnd:] + return rec.Metadata, remain + } + } + searchEnd-- + } + return nil, nil +} diff --git a/internal/provider/gemini-web/types.go b/internal/provider/gemini-web/types.go deleted file mode 100644 index 7edacbdf..00000000 --- a/internal/provider/gemini-web/types.go +++ /dev/null @@ -1,83 +0,0 @@ -package geminiwebapi - -import ( - "fmt" - "html" -) - -type Candidate struct { - RCID string - Text string - Thoughts *string - WebImages []WebImage - GeneratedImages []GeneratedImage -} - -func (c Candidate) String() string { - t := c.Text - if len(t) > 20 { - t = t[:20] + "..." - } - return fmt.Sprintf("Candidate(rcid='%s', text='%s', images=%d)", c.RCID, t, len(c.WebImages)+len(c.GeneratedImages)) -} - -func (c Candidate) Images() []Image { - images := make([]Image, 0, len(c.WebImages)+len(c.GeneratedImages)) - for _, wi := range c.WebImages { - images = append(images, wi.Image) - } - for _, gi := range c.GeneratedImages { - images = append(images, gi.Image) - } - return images -} - -type ModelOutput struct { - Metadata []string - Candidates []Candidate - Chosen int -} - -func (m ModelOutput) String() string { return m.Text() } - -func (m ModelOutput) Text() string { - if len(m.Candidates) == 0 { - return "" - } - return m.Candidates[m.Chosen].Text -} - -func (m ModelOutput) Thoughts() *string { - if len(m.Candidates) == 0 { - return nil - } - return m.Candidates[m.Chosen].Thoughts -} - -func (m ModelOutput) Images() []Image { - if len(m.Candidates) == 0 { - return nil - } - return m.Candidates[m.Chosen].Images() -} - -func (m ModelOutput) RCID() string { - if len(m.Candidates) == 0 { - return "" - } - return m.Candidates[m.Chosen].RCID -} - -type Gem struct { - ID string - Name string - Description *string - Prompt *string - Predefined bool -} - -func (g Gem) String() string { - return fmt.Sprintf("Gem(id='%s', name='%s', description='%v', prompt='%v', predefined=%v)", g.ID, g.Name, g.Description, g.Prompt, g.Predefined) -} - -func decodeHTML(s string) string { return html.UnescapeString(s) }