From c32e0136050279c1d7da444e8d776147a0864dd5 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 14:45:42 +0800 Subject: [PATCH] feat(aistudio): track Gemini usage and improve stream errors --- .../runtime/executor/aistudio_executor.go | 59 ++++++++++++------- internal/wsrelay/manager.go | 7 ++- sdk/cliproxy/service.go | 8 ++- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 3eb9af24..4bcdab3a 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -14,6 +14,7 @@ import ( cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -37,10 +38,13 @@ func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) return nil } -func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { +func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + translatedReq, body, err := e.translateRequest(req, opts, false) if err != nil { - return cliproxyexecutor.Response{}, err + return resp, err } endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ @@ -68,24 +72,29 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, AuthValue: authValue, }) - resp, err := e.relay.RoundTrip(ctx, e.provider, 
wsReq) + wsResp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) - return cliproxyexecutor.Response{}, err + return resp, err } - recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) - if len(resp.Body) > 0 { - appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body)) + recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone()) + if len(wsResp.Body) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body)) } - if resp.Status < 200 || resp.Status >= 300 { - return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} + if wsResp.Status < 200 || wsResp.Status >= 300 { + return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)} } + reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) var param any - out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), &param) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil } -func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { +func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + translatedReq, body, err := e.translateRequest(req, opts, true) if err != nil { return nil, err @@ -114,20 +123,22 @@ func (e 
*AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth AuthType: authType, AuthValue: authValue, }) - stream, err := e.relay.Stream(ctx, e.provider, wsReq) + wsStream, err := e.relay.Stream(ctx, e.provider, wsReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) return nil, err } out := make(chan cliproxyexecutor.StreamChunk) + stream = out go func() { defer close(out) var param any metadataLogged := false - for event := range stream { + for event := range wsStream { if event.Err != nil { recordAPIResponseError(ctx, e.cfg, event.Err) - out <- cliproxyexecutor.StreamChunk{Err: event.Err} + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} return } switch event.Type { @@ -139,6 +150,9 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth case wsrelay.MessageTypeStreamChunk: if len(event.Payload) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) + if detail, ok := parseGeminiStreamUsage(event.Payload); ok { + reporter.publish(ctx, detail) + } } lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param) for i := range lines { @@ -158,19 +172,21 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth for i := range lines { out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} } + reporter.publish(ctx, parseGeminiUsage(event.Payload)) return case wsrelay.MessageTypeError: recordAPIResponseError(ctx, e.cfg, event.Err) + reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} return } } }() - return out, nil + return stream, nil } func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - translatedReq, body, err := 
e.translateRequest(req, opts, false) + _, body, err := e.translateRequest(req, opts, false) if err != nil { return cliproxyexecutor.Response{}, err } @@ -210,9 +226,12 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A if resp.Status < 200 || resp.Status >= 300 { return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} } - var param any - out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil + totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int() + if totalTokens <= 0 { + return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response") + } + translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body)) + return cliproxyexecutor.Response{Payload: []byte(translated)}, nil } func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { diff --git a/internal/wsrelay/manager.go b/internal/wsrelay/manager.go index ab32f9f3..ae28234c 100644 --- a/internal/wsrelay/manager.go +++ b/internal/wsrelay/manager.go @@ -142,11 +142,16 @@ func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) { s.provider = strings.ToLower(s.id) } m.sessMutex.Lock() + var replaced *session if existing, ok := m.sessions[s.provider]; ok { - existing.cleanup(errors.New("replaced by new connection")) + replaced = existing } m.sessions[s.provider] = s m.sessMutex.Unlock() + + if replaced != nil { + replaced.cleanup(errors.New("replaced by new connection")) + } if m.onConnected != nil { m.onConnected(s.provider) } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index ccbdf903..b0f4605b 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -203,7 +203,9 @@ func (s *Service) 
wsOnConnected(provider string) { } if s.coreManager != nil { if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil { - return + if !existing.Disabled && existing.Status == coreauth.StatusActive { + return + } } } now := time.Now().UTC() @@ -225,6 +227,10 @@ func (s *Service) wsOnDisconnected(provider string, reason error) { return } if reason != nil { + if strings.Contains(reason.Error(), "replaced by new connection") { + log.Infof("websocket provider replaced: %s", provider) + return + } log.Warnf("websocket provider disconnected: %s (%v)", provider, reason) } else { log.Infof("websocket provider disconnected: %s", provider)