From 71a6dffbb6299c92646ed371becb16514c50bce6 Mon Sep 17 00:00:00 2001 From: gwizz Date: Mon, 22 Dec 2025 17:21:29 +1100 Subject: [PATCH] fix: improve streaming bootstrap and forwarding --- internal/config/sdk_config.go | 15 ++ sdk/api/handlers/claude/code_handlers.go | 111 ++++---- .../handlers/gemini/gemini-cli_handlers.go | 54 ++-- sdk/api/handlers/gemini/gemini_handlers.go | 112 +++++--- sdk/api/handlers/handlers.go | 245 +++++++++++++----- .../handlers_stream_bootstrap_test.go | 120 +++++++++ sdk/api/handlers/openai/openai_handlers.go | 202 ++++++++++----- .../openai/openai_responses_handlers.go | 102 +++++--- sdk/api/handlers/stream_forwarder.go | 121 +++++++++ sdk/config/config.go | 1 + 10 files changed, 804 insertions(+), 279 deletions(-) create mode 100644 sdk/api/handlers/handlers_stream_bootstrap_test.go create mode 100644 sdk/api/handlers/stream_forwarder.go diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go index f6f20d5c..7f019520 100644 --- a/internal/config/sdk_config.go +++ b/internal/config/sdk_config.go @@ -22,6 +22,21 @@ type SDKConfig struct { // Access holds request authentication provider configuration. Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"` + + // Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries). + Streaming StreamingConfig `yaml:"streaming" json:"streaming"` +} + +// StreamingConfig holds server streaming behavior configuration. +type StreamingConfig struct { + // KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n"). + // nil means default (15 seconds). <= 0 disables keep-alives. + KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"` + + // BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent, + // to allow auth rotation / transient recovery. + // nil means default (2). 0 disables bootstrap retries. + BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"` } // AccessConfig groups request authentication providers. diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index 8a4c4806..bdf7c9c7 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -14,7 +14,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO // - c: The Gin context for the request. // - rawJSON: The raw JSON request body. func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - // Set up Server-Sent Events (SSE) headers for streaming response - // These headers are essential for maintaining a persistent connection - // and enabling real-time streaming of chat completions - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. // This is crucial for streaming as it allows immediate sending of data chunks flusher, ok := c.Writer.(http.Flusher) @@ -213,56 +204,72 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) - return + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk to determine success or failure before setting headers + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers now. + setSSEHeaders() + + // Write the first chunk + if len(chunk) > 0 { + _, _ = c.Writer.Write(chunk) + flusher.Flush() + } + + // Continue streaming the rest + h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + } } func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - // OpenAI-style stream forwarding: write each SSE chunk and flush immediately. - // This guarantees clients see incremental output even for small responses. - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - - case chunk, ok := <-data: - if !ok { - flusher.Flush() - cancel(nil) + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { + if len(chunk) == 0 { return } - if len(chunk) > 0 { - _, _ = c.Writer.Write(chunk) - flusher.Flush() + _, _ = c.Writer.Write(chunk) + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode + } + c.Status(status) - case errMsg, ok := <-errs: - if !ok { - continue - } - if errMsg != nil { - status := http.StatusInternalServerError - if errMsg.StatusCode > 0 { - status = errMsg.StatusCode - } - c.Status(status) - - // An error occurred: emit as a proper SSE error event - errorBytes, _ := json.Marshal(h.toClaudeError(errMsg)) - _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes) - flusher.Flush() - } - - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + errorBytes, _ := json.Marshal(h.toClaudeError(errMsg)) + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes) + }, + }) } type claudeErrorDetail struct { diff --git a/sdk/api/handlers/gemini/gemini-cli_handlers.go b/sdk/api/handlers/gemini/gemini-cli_handlers.go index 5224faf8..ea78657d 100644 --- a/sdk/api/handlers/gemini/gemini-cli_handlers.go +++ b/sdk/api/handlers/gemini/gemini-cli_handlers.go @@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ } func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - cancel(nil) - return - } + var keepAliveInterval *time.Duration + if alt != "" { + disabled := time.Duration(0) + keepAliveInterval = &disabled + } + + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + KeepAliveInterval: keepAliveInterval, + WriteChunk: func(chunk []byte) { if alt == "" { if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) { - continue + return } if !bytes.HasPrefix(chunk, []byte("data:")) { @@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus } else { _, _ = c.Writer.Write(chunk) } - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + if alt == "" { + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) + } else { + _, _ = c.Writer.Write(body) + } + }, + }) } diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index 901421b5..baf68aac 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -226,13 +226,6 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) { func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) { alt := h.GetAlt(c) - if alt == "" { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - } - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -247,8 +240,57 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) - h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) - return + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Closed without data + if alt == "" { + setSSEHeaders() + } + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. + if alt == "" { + setSSEHeaders() + } + + // Write first chunk + if alt == "" { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + } else { + _, _ = c.Writer.Write(chunk) + } + flusher.Flush() + + // Continue + h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan) + } } // handleCountTokens handles token counting requests for Gemini models. @@ -297,16 +339,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin } func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - cancel(nil) - return - } + var keepAliveInterval *time.Duration + if alt != "" { + disabled := time.Duration(0) + keepAliveInterval = &disabled + } + + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + KeepAliveInterval: keepAliveInterval, + WriteChunk: func(chunk []byte) { if alt == "" { _, _ = c.Writer.Write([]byte("data: ")) _, _ = c.Writer.Write(chunk) @@ -314,22 +355,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus } else { _, _ = c.Writer.Write(chunk) } - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + if alt == "" { + _, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body)) + } else { + _, _ = c.Writer.Write(body) + } + }, + }) } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index e5b4fc93..5d33fe0e 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -9,8 +9,10 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -40,6 +42,115 @@ type ErrorDetail struct { Code string `json:"code,omitempty"` } +const idempotencyKeyMetadataKey = "idempotency_key" + +const ( + defaultStreamingKeepAliveSeconds = 15 + defaultStreamingBootstrapRetries = 2 +) + +// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body. +// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads. +func BuildErrorResponseBody(status int, errText string) []byte { + if status <= 0 { + status = http.StatusInternalServerError + } + if strings.TrimSpace(errText) == "" { + errText = http.StatusText(status) + } + + trimmed := strings.TrimSpace(errText) + if trimmed != "" && json.Valid([]byte(trimmed)) { + return []byte(trimmed) + } + + errType := "invalid_request_error" + var code string + switch status { + case http.StatusUnauthorized: + errType = "authentication_error" + code = "invalid_api_key" + case http.StatusForbidden: + errType = "permission_error" + code = "insufficient_quota" + case http.StatusTooManyRequests: + errType = "rate_limit_error" + code = "rate_limit_exceeded" + case http.StatusNotFound: + errType = "invalid_request_error" + code = "model_not_found" + default: + if status >= http.StatusInternalServerError { + errType = "server_error" + code = "internal_server_error" + } + } + + payload, err := json.Marshal(ErrorResponse{ + Error: ErrorDetail{ + Message: errText, + Type: errType, + Code: code, + }, + }) + if err != nil { + return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText)) + } + return payload +} + +// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server. +// Returning 0 disables keep-alives. +func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration { + seconds := defaultStreamingKeepAliveSeconds + if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil { + seconds = *cfg.Streaming.KeepAliveSeconds + } + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent. +func StreamingBootstrapRetries(cfg *config.SDKConfig) int { + retries := defaultStreamingBootstrapRetries + if cfg != nil && cfg.Streaming.BootstrapRetries != nil { + retries = *cfg.Streaming.BootstrapRetries + } + if retries < 0 { + retries = 0 + } + return retries +} + +func requestExecutionMetadata(ctx context.Context) map[string]any { + key := "" + if ctx != nil { + if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { + key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")) + } + } + if key == "" { + key = uuid.NewString() + } + return map[string]any{idempotencyKeyMetadataKey: key} +} + +func mergeMetadata(base, overlay map[string]any) map[string]any { + if len(base) == 0 && len(overlay) == 0 { + return nil + } + out := make(map[string]any, len(base)+len(overlay)) + for k, v := range base { + out[k] = v + } + for k, v := range overlay { + out[k] = v + } + return out +} + // BaseAPIHandler contains the handlers for API endpoints. // It holds a pool of clients to interact with the backend service and manages // load balancing, client selection, and configuration. @@ -182,6 +293,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType if errMsg != nil { return nil, errMsg } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -195,9 +307,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { status := http.StatusInternalServerError @@ -224,6 +334,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle if errMsg != nil { return nil, errMsg } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -237,9 +348,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) if err != nil { status := http.StatusInternalServerError @@ -269,6 +378,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl close(errChan) return nil, errChan } + reqMeta := requestExecutionMetadata(ctx) req := coreexecutor.Request{ Model: normalizedModel, Payload: cloneBytes(rawJSON), @@ -282,9 +392,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } - if cloned := cloneMetadata(metadata); cloned != nil { - opts.Metadata = cloned - } + opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta) chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { errChan := make(chan *interfaces.ErrorMessage, 1) @@ -309,31 +417,81 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl go func() { defer close(dataChan) defer close(errChan) - for chunk := range chunks { - if chunk.Err != nil { - status := http.StatusInternalServerError - if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil { - if code := se.StatusCode(); code > 0 { - status = code - } - } - var addon http.Header - if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil { - if hdr := he.Headers(); hdr != nil { - addon = hdr.Clone() - } - } - errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon} - return + sentPayload := false + bootstrapRetries := 0 + maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg) + + bootstrapEligible := func(err error) bool { + status := statusFromError(err) + if status == 0 { + return true } - if len(chunk.Payload) > 0 { - dataChan <- cloneBytes(chunk.Payload) + switch status { + case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired, + http.StatusRequestTimeout, http.StatusTooManyRequests: + return true + default: + return status >= http.StatusInternalServerError } } + + outer: + for { + for chunk := range chunks { + if chunk.Err != nil { + streamErr := chunk.Err + // Safe bootstrap recovery: if the upstream fails before any payload bytes are sent, + // retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback. + if !sentPayload { + if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) { + bootstrapRetries++ + retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + if retryErr == nil { + chunks = retryChunks + continue outer + } + streamErr = retryErr + } + } + + status := http.StatusInternalServerError + if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } + } + var addon http.Header + if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon} + return + } + if len(chunk.Payload) > 0 { + sentPayload = true + dataChan <- cloneBytes(chunk.Payload) + } + } + return + } }() return dataChan, errChan } +func statusFromError(err error) int { + if err == nil { + return 0 + } + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + return code + } + } + return 0 +} + func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) { // Resolve "auto" model to an actual available model first resolvedModelName := util.ResolveAutoModel(modelName) @@ -417,38 +575,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro } } - // Prefer preserving upstream JSON error bodies when possible. - buildJSONBody := func() []byte { - trimmed := strings.TrimSpace(errText) - if trimmed != "" && json.Valid([]byte(trimmed)) { - return []byte(trimmed) - } - errType := "invalid_request_error" - switch status { - case http.StatusUnauthorized: - errType = "authentication_error" - case http.StatusForbidden: - errType = "permission_error" - case http.StatusTooManyRequests: - errType = "rate_limit_error" - default: - if status >= http.StatusInternalServerError { - errType = "server_error" - } - } - payload, err := json.Marshal(ErrorResponse{ - Error: ErrorDetail{ - Message: errText, - Type: errType, - }, - }) - if err != nil { - return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText)) - } - return payload - } - - body := buildJSONBody() + body := BuildErrorResponseBody(status, errText) c.Set("API_RESPONSE", bytes.Clone(body)) if !c.Writer.Written() { diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go new file mode 100644 index 00000000..cd2fdf4d --- /dev/null +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -0,0 +1,120 @@ +package handlers + +import ( + "context" + "net/http" + "sync" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +type failOnceStreamExecutor struct { + mu sync.Mutex + calls int +} + +func (e *failOnceStreamExecutor) Identifier() string { return "codex" } + +func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { + e.mu.Lock() + e.calls++ + call := e.calls + e.mu.Unlock() + + ch := make(chan coreexecutor.StreamChunk, 1) + if call == 1 { + ch <- coreexecutor.StreamChunk{ + Err: &coreauth.Error{ + Code: "unauthorized", + Message: "unauthorized", + Retryable: false, + HTTPStatus: http.StatusUnauthorized, + }, + } + close(ch) + return ch, nil + } + + ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} + close(ch) + return ch, nil +} + +func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *failOnceStreamExecutor) Calls() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.calls +} + +func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { + executor := &failOnceStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + auth2 := &coreauth.Auth{ + ID: "auth2", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test2@example.com"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("manager.Register(auth2): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager, nil) + dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + + for msg := range errChan { + if msg != nil { + t.Fatalf("unexpected error: %+v", msg) + } + } + + if string(got) != "ok" { + t.Fatalf("expected payload ok, got %q", string(got)) + } + if executor.Calls() != 2 { + t.Fatalf("expected 2 stream attempts, got %d", executor.Calls()) + } +} diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index ae925f91..d5962ea7 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -11,7 +11,7 @@ import ( "encoding/json" "fmt" "net/http" - "time" + "sync" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -443,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAI-compatible request func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -463,7 +458,47 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c)) - h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk to determine success or failure before setting headers + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send DONE or just headers. + setSSEHeaders() + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Commit to streaming headers. + setSSEHeaders() + + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) + flusher.Flush() + + // Continue streaming the rest + h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + } } // handleCompletionsNonStreamingResponse handles non-streaming completions responses. @@ -500,11 +535,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -524,71 +554,101 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") - for { - select { - case <-c.Request.Context().Done(): - cliCancel(c.Request.Context().Err()) - return - case chunk, isOk := <-dataChan: - if !isOk { - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") - flusher.Flush() - cliCancel() - return - } - converted := convertChatCompletionsStreamChunkToCompletions(chunk) - if converted != nil { - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) - flusher.Flush() - } - case errMsg, isOk := <-errChan: - if !isOk { - continue - } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() - } - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cliCancel(execErr) - return - case <-time.After(500 * time.Millisecond): + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) } + return + case chunk, ok := <-dataChan: + if !ok { + setSSEHeaders() + _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. + setSSEHeaders() + + // Write the first chunk + converted := convertChatCompletionsStreamChunkToCompletions(chunk) + if converted != nil { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted)) + flusher.Flush() + } + + done := make(chan struct{}) + var doneOnce sync.Once + stop := func() { doneOnce.Do(func() { close(done) }) } + + convertedChan := make(chan []byte) + go func() { + defer close(convertedChan) + for { + select { + case <-done: + return + case chunk, ok := <-dataChan: + if !ok { + return + } + converted := convertChatCompletionsStreamChunkToCompletions(chunk) + if converted == nil { + continue + } + select { + case <-done: + return + case convertedChan <- converted: + } + } + } + }() + + h.handleStreamResult(c, flusher, func(err error) { + stop() + cliCancel(err) + }, convertedChan, errChan) } } func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") - flusher.Flush() - cancel(nil) + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { return } - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk)) - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - var execErr error - if errMsg != nil { - execErr = errMsg.Error - } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body)) + }, + WriteDone: func() { + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + }, + }) } diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index ace02313..dd63deeb 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -11,7 +11,6 @@ import ( "context" "fmt" "net/http" - "time" "github.com/gin-gonic/gin" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" @@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r // - c: The Gin context containing the HTTP request and response // - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { @@ -149,46 +143,80 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ modelName := gjson.GetBytes(rawJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") - h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) - return + + setSSEHeaders := func() { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Peek at the first chunk + select { + case <-c.Request.Context().Done(): + cliCancel(c.Request.Context().Err()) + return + case errMsg := <-errChan: + // Upstream failed immediately. Return proper error status and JSON. + h.WriteErrorResponse(c, errMsg) + if errMsg != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + case chunk, ok := <-dataChan: + if !ok { + // Stream closed without data? Send headers and done. + setSSEHeaders() + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + cliCancel(nil) + return + } + + // Success! Set headers. + setSSEHeaders() + + // Write first chunk logic (matching forwardResponsesStream) + if bytes.HasPrefix(chunk, []byte("event:")) { + _, _ = c.Writer.Write([]byte("\n")) + } + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + + // Continue + h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + } } func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { - for { - select { - case <-c.Request.Context().Done(): - cancel(c.Request.Context().Err()) - return - case chunk, ok := <-data: - if !ok { - _, _ = c.Writer.Write([]byte("\n")) - flusher.Flush() - cancel(nil) - return - } - + h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ + WriteChunk: func(chunk []byte) { if bytes.HasPrefix(chunk, []byte("event:")) { _, _ = c.Writer.Write([]byte("\n")) } _, _ = c.Writer.Write(chunk) _, _ = c.Writer.Write([]byte("\n")) - - flusher.Flush() - case errMsg, ok := <-errs: - if !ok { - continue + }, + WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + if errMsg == nil { + return } - if errMsg != nil { - h.WriteErrorResponse(c, errMsg) - flusher.Flush() + status := http.StatusInternalServerError + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode } - var execErr error - if errMsg != nil { - execErr = errMsg.Error + errText := http.StatusText(status) + if errMsg.Error != nil && errMsg.Error.Error() != "" { + errText = errMsg.Error.Error() } - cancel(execErr) - return - case <-time.After(500 * time.Millisecond): - } - } + body := handlers.BuildErrorResponseBody(status, errText) + _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body)) + }, + WriteDone: func() { + _, _ = c.Writer.Write([]byte("\n")) + }, + }) } diff --git a/sdk/api/handlers/stream_forwarder.go b/sdk/api/handlers/stream_forwarder.go new file mode 100644 index 00000000..401baca8 --- /dev/null +++ b/sdk/api/handlers/stream_forwarder.go @@ -0,0 +1,121 @@ +package handlers + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" +) + +type StreamForwardOptions struct { + // KeepAliveInterval overrides the configured streaming keep-alive interval. + // If nil, the configured default is used. If set to <= 0, keep-alives are disabled. + KeepAliveInterval *time.Duration + + // WriteChunk writes a single data chunk to the response body. It should not flush. + WriteChunk func(chunk []byte) + + // WriteTerminalError writes an error payload to the response body when streaming fails + // after headers have already been committed. It should not flush. + WriteTerminalError func(errMsg *interfaces.ErrorMessage) + + // WriteDone optionally writes a terminal marker when the upstream data channel closes + // without an error (e.g. OpenAI's `[DONE]`). It should not flush. + WriteDone func() + + // WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush. + // When nil, a standard SSE comment heartbeat is used. + WriteKeepAlive func() +} + +func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) { + if c == nil { + return + } + if cancel == nil { + return + } + + writeChunk := opts.WriteChunk + if writeChunk == nil { + writeChunk = func([]byte) {} + } + + writeKeepAlive := opts.WriteKeepAlive + if writeKeepAlive == nil { + writeKeepAlive = func() { + _, _ = c.Writer.Write([]byte(": keep-alive\n\n")) + } + } + + keepAliveInterval := StreamingKeepAliveInterval(h.Cfg) + if opts.KeepAliveInterval != nil { + keepAliveInterval = *opts.KeepAliveInterval + } + var keepAlive *time.Ticker + var keepAliveC <-chan time.Time + if keepAliveInterval > 0 { + keepAlive = time.NewTicker(keepAliveInterval) + defer keepAlive.Stop() + keepAliveC = keepAlive.C + } + + var terminalErr *interfaces.ErrorMessage + for { + select { + case <-c.Request.Context().Done(): + cancel(c.Request.Context().Err()) + return + case chunk, ok := <-data: + if !ok { + // Prefer surfacing a terminal error if one is pending. + if terminalErr == nil { + select { + case errMsg, ok := <-errs: + if ok && errMsg != nil { + terminalErr = errMsg + } + default: + } + } + if terminalErr != nil { + if opts.WriteTerminalError != nil { + opts.WriteTerminalError(terminalErr) + } + flusher.Flush() + cancel(terminalErr.Error) + return + } + if opts.WriteDone != nil { + opts.WriteDone() + } + flusher.Flush() + cancel(nil) + return + } + writeChunk(chunk) + flusher.Flush() + case errMsg, ok := <-errs: + if !ok { + continue + } + if errMsg != nil { + terminalErr = errMsg + if opts.WriteTerminalError != nil { + opts.WriteTerminalError(errMsg) + flusher.Flush() + } + } + var execErr error + if errMsg != nil { + execErr = errMsg.Error + } + cancel(execErr) + return + case <-keepAliveC: + writeKeepAlive() + flusher.Flush() + } + } +} diff --git a/sdk/config/config.go b/sdk/config/config.go index 6e4efad5..b471e5e0 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -12,6 +12,7 @@ type AccessProvider = internalconfig.AccessProvider type Config = internalconfig.Config +type StreamingConfig = internalconfig.StreamingConfig type TLSConfig = internalconfig.TLSConfig type RemoteManagement = internalconfig.RemoteManagement type AmpCode = internalconfig.AmpCode