feat(aistudio): track Gemini usage and improve stream errors

This commit is contained in:
hkfires
2025-10-25 14:45:42 +08:00
parent 3839d93ba0
commit c32e013605
3 changed files with 52 additions and 22 deletions

View File

@@ -14,6 +14,7 @@ import (
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -37,10 +38,13 @@ func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth)
return nil
}
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
return resp, err
}
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
@@ -68,24 +72,29 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
AuthValue: authValue,
})
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
wsResp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
if len(resp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
if len(wsResp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
}
if resp.Status < 200 || resp.Status >= 300 {
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
if wsResp.Status < 200 || wsResp.Status >= 300 {
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
}
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, true)
if err != nil {
return nil, err
@@ -114,20 +123,22 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
AuthType: authType,
AuthValue: authValue,
})
stream, err := e.relay.Stream(ctx, e.provider, wsReq)
wsStream, err := e.relay.Stream(ctx, e.provider, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
var param any
metadataLogged := false
for event := range stream {
for event := range wsStream {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
out <- cliproxyexecutor.StreamChunk{Err: event.Err}
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return
}
switch event.Type {
@@ -139,6 +150,9 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
if detail, ok := parseGeminiStreamUsage(event.Payload); ok {
reporter.publish(ctx, detail)
}
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param)
for i := range lines {
@@ -158,19 +172,21 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
return
case wsrelay.MessageTypeError:
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return
}
}
}()
return out, nil
return stream, nil
}
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
translatedReq, body, err := e.translateRequest(req, opts, false)
_, body, err := e.translateRequest(req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
}
@@ -210,9 +226,12 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
if resp.Status < 200 || resp.Status >= 300 {
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
}
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int()
if totalTokens <= 0 {
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
}
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {

View File

@@ -142,11 +142,16 @@ func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
s.provider = strings.ToLower(s.id)
}
m.sessMutex.Lock()
var replaced *session
if existing, ok := m.sessions[s.provider]; ok {
existing.cleanup(errors.New("replaced by new connection"))
replaced = existing
}
m.sessions[s.provider] = s
m.sessMutex.Unlock()
if replaced != nil {
replaced.cleanup(errors.New("replaced by new connection"))
}
if m.onConnected != nil {
m.onConnected(s.provider)
}