mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
fix: improve streaming bootstrap and forwarding
This commit is contained in:
@@ -22,6 +22,21 @@ type SDKConfig struct {
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
}
|
||||
|
||||
// StreamingConfig holds server streaming behavior configuration.
|
||||
type StreamingConfig struct {
|
||||
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
||||
// nil means default (15 seconds). <= 0 disables keep-alives.
|
||||
KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||
|
||||
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
||||
// to allow auth rotation / transient recovery.
|
||||
// nil means default (2). 0 disables bootstrap retries.
|
||||
BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
// - c: The Gin context for the request.
|
||||
// - rawJSON: The raw JSON request body.
|
||||
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||
// These headers are essential for maintaining a persistent connection
|
||||
// and enabling real-time streaming of chat completions
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
@@ -213,56 +204,72 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
// OpenAI-style stream forwarding: write each SSE chunk and flush immediately.
|
||||
// This guarantees clients see incremental output even for small responses.
|
||||
for {
|
||||
// Peek at the first chunk to determine success or failure before setting headers
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
|
||||
case chunk, ok := <-data:
|
||||
case errMsg := <-errChan:
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers now.
|
||||
setSSEHeaders()
|
||||
|
||||
// Write the first chunk
|
||||
if len(chunk) > 0 {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
// Continue streaming the rest
|
||||
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
if errMsg != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
c.Status(status)
|
||||
|
||||
// An error occurred: emit as a proper SSE error event
|
||||
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type claudeErrorDetail struct {
|
||||
|
||||
@@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
||||
}
|
||||
|
||||
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
cancel(nil)
|
||||
return
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
KeepAliveInterval: keepAliveInterval,
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if alt == "" {
|
||||
if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(chunk, []byte("data:")) {
|
||||
@@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
if alt == "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(body)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -226,13 +226,6 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
|
||||
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
|
||||
alt := h.GetAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -247,8 +240,57 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek at the first chunk
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg := <-errChan:
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Closed without data
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers.
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
|
||||
// Write first chunk
|
||||
if alt == "" {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Continue
|
||||
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCountTokens handles token counting requests for Gemini models.
|
||||
@@ -297,16 +339,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
||||
}
|
||||
|
||||
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
cancel(nil)
|
||||
return
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
KeepAliveInterval: keepAliveInterval,
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if alt == "" {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
@@ -314,22 +355,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
if alt == "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(body)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,8 +9,10 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -40,6 +42,115 @@ type ErrorDetail struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
const idempotencyKeyMetadataKey = "idempotency_key"
|
||||
|
||||
const (
|
||||
defaultStreamingKeepAliveSeconds = 15
|
||||
defaultStreamingBootstrapRetries = 2
|
||||
)
|
||||
|
||||
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||
if status <= 0 {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if strings.TrimSpace(errText) == "" {
|
||||
errText = http.StatusText(status)
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
|
||||
errType := "invalid_request_error"
|
||||
var code string
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
code = "invalid_api_key"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
code = "insufficient_quota"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
errType = "invalid_request_error"
|
||||
code = "model_not_found"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
code = "internal_server_error"
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
Code: code,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server.
|
||||
// Returning 0 disables keep-alives.
|
||||
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||
seconds := defaultStreamingKeepAliveSeconds
|
||||
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
|
||||
seconds = *cfg.Streaming.KeepAliveSeconds
|
||||
}
|
||||
if seconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||
retries := defaultStreamingBootstrapRetries
|
||||
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
|
||||
retries = *cfg.Streaming.BootstrapRetries
|
||||
}
|
||||
if retries < 0 {
|
||||
retries = 0
|
||||
}
|
||||
return retries
|
||||
}
|
||||
|
||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
key := ""
|
||||
if ctx != nil {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
|
||||
}
|
||||
}
|
||||
if key == "" {
|
||||
key = uuid.NewString()
|
||||
}
|
||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||
}
|
||||
|
||||
func mergeMetadata(base, overlay map[string]any) map[string]any {
|
||||
if len(base) == 0 && len(overlay) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(base)+len(overlay))
|
||||
for k, v := range base {
|
||||
out[k] = v
|
||||
}
|
||||
for k, v := range overlay {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BaseAPIHandler contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service and manages
|
||||
// load balancing, client selection, and configuration.
|
||||
@@ -182,6 +293,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -195,9 +307,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -224,6 +334,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -237,9 +348,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -269,6 +378,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -282,9 +392,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
@@ -309,31 +417,81 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
go func() {
|
||||
defer close(dataChan)
|
||||
defer close(errChan)
|
||||
sentPayload := false
|
||||
bootstrapRetries := 0
|
||||
maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
|
||||
|
||||
bootstrapEligible := func(err error) bool {
|
||||
status := statusFromError(err)
|
||||
if status == 0 {
|
||||
return true
|
||||
}
|
||||
switch status {
|
||||
case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
|
||||
http.StatusRequestTimeout, http.StatusTooManyRequests:
|
||||
return true
|
||||
default:
|
||||
return status >= http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
outer:
|
||||
for {
|
||||
for chunk := range chunks {
|
||||
if chunk.Err != nil {
|
||||
streamErr := chunk.Err
|
||||
// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
|
||||
// retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback.
|
||||
if !sentPayload {
|
||||
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||
bootstrapRetries++
|
||||
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if retryErr == nil {
|
||||
chunks = retryChunks
|
||||
continue outer
|
||||
}
|
||||
streamErr = retryErr
|
||||
}
|
||||
}
|
||||
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
status = code
|
||||
}
|
||||
}
|
||||
var addon http.Header
|
||||
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
|
||||
if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
|
||||
if hdr := he.Headers(); hdr != nil {
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}
|
||||
return
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
sentPayload = true
|
||||
dataChan <- cloneBytes(chunk.Payload)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
return code
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
@@ -417,38 +575,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer preserving upstream JSON error bodies when possible.
|
||||
buildJSONBody := func() []byte {
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
errType := "invalid_request_error"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
body := buildJSONBody()
|
||||
body := BuildErrorResponseBody(status, errText)
|
||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
||||
|
||||
if !c.Writer.Written() {
|
||||
|
||||
120
sdk/api/handlers/handlers_stream_bootstrap_test.go
Normal file
120
sdk/api/handlers/handlers_stream_bootstrap_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
type failOnceStreamExecutor struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
call := e.calls
|
||||
e.mu.Unlock()
|
||||
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
if call == 1 {
|
||||
ch <- coreexecutor.StreamChunk{
|
||||
Err: &coreauth.Error{
|
||||
Code: "unauthorized",
|
||||
Message: "unauthorized",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Calls() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.calls
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager, nil)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got) != "ok" {
|
||||
t.Fatalf("expected payload ok, got %q", string(got))
|
||||
}
|
||||
if executor.Calls() != 2 {
|
||||
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -443,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -463,8 +458,48 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek at the first chunk to determine success or failure before setting headers
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg := <-errChan:
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Commit to streaming headers.
|
||||
setSSEHeaders()
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||
flusher.Flush()
|
||||
|
||||
// Continue streaming the rest
|
||||
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
||||
// It converts completions request to chat completions format, sends to backend,
|
||||
@@ -500,11 +535,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
|
||||
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -524,71 +554,101 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
|
||||
for {
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek at the first chunk
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, isOk := <-dataChan:
|
||||
if !isOk {
|
||||
case errMsg := <-errChan:
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
|
||||
// Write the first chunk
|
||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||
if converted != nil {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
|
||||
flusher.Flush()
|
||||
}
|
||||
case errMsg, isOk := <-errChan:
|
||||
if !isOk {
|
||||
|
||||
done := make(chan struct{})
|
||||
var doneOnce sync.Once
|
||||
stop := func() { doneOnce.Do(func() { close(done) }) }
|
||||
|
||||
convertedChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(convertedChan)
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||
if converted == nil {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cliCancel(execErr)
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case convertedChan <- converted:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
h.handleStreamResult(c, flusher, func(err error) {
|
||||
stop()
|
||||
cliCancel(err)
|
||||
}, convertedChan, errChan)
|
||||
}
|
||||
}
|
||||
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
return
|
||||
}
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
|
||||
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -149,46 +143,80 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
for {
|
||||
// Peek at the first chunk
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
case errMsg := <-errChan:
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Stream closed without data? Send headers and done.
|
||||
setSSEHeaders()
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
|
||||
// Write first chunk logic (matching forwardResponsesStream)
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
// Continue
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
cancel(execErr)
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
121
sdk/api/handlers/stream_forwarder.go
Normal file
121
sdk/api/handlers/stream_forwarder.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
)
|
||||
|
||||
type StreamForwardOptions struct {
|
||||
// KeepAliveInterval overrides the configured streaming keep-alive interval.
|
||||
// If nil, the configured default is used. If set to <= 0, keep-alives are disabled.
|
||||
KeepAliveInterval *time.Duration
|
||||
|
||||
// WriteChunk writes a single data chunk to the response body. It should not flush.
|
||||
WriteChunk func(chunk []byte)
|
||||
|
||||
// WriteTerminalError writes an error payload to the response body when streaming fails
|
||||
// after headers have already been committed. It should not flush.
|
||||
WriteTerminalError func(errMsg *interfaces.ErrorMessage)
|
||||
|
||||
// WriteDone optionally writes a terminal marker when the upstream data channel closes
|
||||
// without an error (e.g. OpenAI's `[DONE]`). It should not flush.
|
||||
WriteDone func()
|
||||
|
||||
// WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush.
|
||||
// When nil, a standard SSE comment heartbeat is used.
|
||||
WriteKeepAlive func()
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
writeChunk := opts.WriteChunk
|
||||
if writeChunk == nil {
|
||||
writeChunk = func([]byte) {}
|
||||
}
|
||||
|
||||
writeKeepAlive := opts.WriteKeepAlive
|
||||
if writeKeepAlive == nil {
|
||||
writeKeepAlive = func() {
|
||||
_, _ = c.Writer.Write([]byte(": keep-alive\n\n"))
|
||||
}
|
||||
}
|
||||
|
||||
keepAliveInterval := StreamingKeepAliveInterval(h.Cfg)
|
||||
if opts.KeepAliveInterval != nil {
|
||||
keepAliveInterval = *opts.KeepAliveInterval
|
||||
}
|
||||
var keepAlive *time.Ticker
|
||||
var keepAliveC <-chan time.Time
|
||||
if keepAliveInterval > 0 {
|
||||
keepAlive = time.NewTicker(keepAliveInterval)
|
||||
defer keepAlive.Stop()
|
||||
keepAliveC = keepAlive.C
|
||||
}
|
||||
|
||||
var terminalErr *interfaces.ErrorMessage
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
// Prefer surfacing a terminal error if one is pending.
|
||||
if terminalErr == nil {
|
||||
select {
|
||||
case errMsg, ok := <-errs:
|
||||
if ok && errMsg != nil {
|
||||
terminalErr = errMsg
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
if terminalErr != nil {
|
||||
if opts.WriteTerminalError != nil {
|
||||
opts.WriteTerminalError(terminalErr)
|
||||
}
|
||||
flusher.Flush()
|
||||
cancel(terminalErr.Error)
|
||||
return
|
||||
}
|
||||
if opts.WriteDone != nil {
|
||||
opts.WriteDone()
|
||||
}
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
return
|
||||
}
|
||||
writeChunk(chunk)
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
terminalErr = errMsg
|
||||
if opts.WriteTerminalError != nil {
|
||||
opts.WriteTerminalError(errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-keepAliveC:
|
||||
writeKeepAlive()
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type AccessProvider = internalconfig.AccessProvider
|
||||
|
||||
type Config = internalconfig.Config
|
||||
|
||||
type StreamingConfig = internalconfig.StreamingConfig
|
||||
type TLSConfig = internalconfig.TLSConfig
|
||||
type RemoteManagement = internalconfig.RemoteManagement
|
||||
type AmpCode = internalconfig.AmpCode
|
||||
|
||||
Reference in New Issue
Block a user