diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 220826c0..8f78547b 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -129,18 +129,60 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth recordAPIResponseError(ctx, e.cfg, err) return nil, err } + firstEvent, ok := <-wsStream + if !ok { + err = fmt.Errorf("wsrelay: stream closed before start") + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK { + metadataLogged := false + if firstEvent.Status > 0 { + recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone()) + metadataLogged = true + } + var body bytes.Buffer + if len(firstEvent.Payload) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload)) + body.Write(firstEvent.Payload) + } + if firstEvent.Type == wsrelay.MessageTypeStreamEnd { + return nil, statusErr{code: firstEvent.Status, msg: body.String()} + } + for event := range wsStream { + if event.Err != nil { + recordAPIResponseError(ctx, e.cfg, event.Err) + if body.Len() == 0 { + body.WriteString(event.Err.Error()) + } + break + } + if !metadataLogged && event.Status > 0 { + recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) + metadataLogged = true + } + if len(event.Payload) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) + body.Write(event.Payload) + } + if event.Type == wsrelay.MessageTypeStreamEnd { + break + } + } + return nil, statusErr{code: firstEvent.Status, msg: body.String()} + } out := make(chan cliproxyexecutor.StreamChunk) stream = out - go func() { + go func(first wsrelay.StreamEvent) { defer close(out) var param any metadataLogged := false - for event := range wsStream { + processEvent := func(event wsrelay.StreamEvent) bool { if event.Err != nil { recordAPIResponseError(ctx, e.cfg, event.Err) reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} - return + return false } switch event.Type { case wsrelay.MessageTypeStreamStart: @@ -162,7 +204,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth break } case wsrelay.MessageTypeStreamEnd: - return + return false case wsrelay.MessageTypeHTTPResp: if !metadataLogged && event.Status > 0 { recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) @@ -176,15 +218,24 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))} } reporter.publish(ctx, parseGeminiUsage(event.Payload)) - return + return false case wsrelay.MessageTypeError: recordAPIResponseError(ctx, e.cfg, event.Err) reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} + return false + } + return true + } + if !processEvent(first) { + return + } + for event := range wsStream { + if !processEvent(event) { return } } - }() + }(firstEvent) return stream, nil }