fix: improve streaming bootstrap and forwarding

This commit is contained in:
gwizz
2025-12-22 17:21:29 +11:00
parent 27b43ed63f
commit 71a6dffbb6
10 changed files with 804 additions and 279 deletions

View File

@@ -9,8 +9,10 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -40,6 +42,115 @@ type ErrorDetail struct {
Code string `json:"code,omitempty"`
}
const idempotencyKeyMetadataKey = "idempotency_key"
const (
defaultStreamingKeepAliveSeconds = 15
defaultStreamingBootstrapRetries = 2
)
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
func BuildErrorResponseBody(status int, errText string) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if strings.TrimSpace(errText) == "" {
errText = http.StatusText(status)
}
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
return []byte(trimmed)
}
errType := "invalid_request_error"
var code string
switch status {
case http.StatusUnauthorized:
errType = "authentication_error"
code = "invalid_api_key"
case http.StatusForbidden:
errType = "permission_error"
code = "insufficient_quota"
case http.StatusTooManyRequests:
errType = "rate_limit_error"
code = "rate_limit_exceeded"
case http.StatusNotFound:
errType = "invalid_request_error"
code = "model_not_found"
default:
if status >= http.StatusInternalServerError {
errType = "server_error"
code = "internal_server_error"
}
}
payload, err := json.Marshal(ErrorResponse{
Error: ErrorDetail{
Message: errText,
Type: errType,
Code: code,
},
})
if err != nil {
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
}
return payload
}
// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server.
// Returning 0 disables keep-alives.
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
seconds := defaultStreamingKeepAliveSeconds
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
seconds = *cfg.Streaming.KeepAliveSeconds
}
if seconds <= 0 {
return 0
}
return time.Duration(seconds) * time.Second
}
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
retries := defaultStreamingBootstrapRetries
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
retries = *cfg.Streaming.BootstrapRetries
}
if retries < 0 {
retries = 0
}
return retries
}
func requestExecutionMetadata(ctx context.Context) map[string]any {
key := ""
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
}
}
if key == "" {
key = uuid.NewString()
}
return map[string]any{idempotencyKeyMetadataKey: key}
}
func mergeMetadata(base, overlay map[string]any) map[string]any {
if len(base) == 0 && len(overlay) == 0 {
return nil
}
out := make(map[string]any, len(base)+len(overlay))
for k, v := range base {
out[k] = v
}
for k, v := range overlay {
out[k] = v
}
return out
}
// BaseAPIHandler contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service and manages
// load balancing, client selection, and configuration.
@@ -182,6 +293,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
if errMsg != nil {
return nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -195,9 +307,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
if cloned := cloneMetadata(metadata); cloned != nil {
opts.Metadata = cloned
}
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil {
status := http.StatusInternalServerError
@@ -224,6 +334,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
if errMsg != nil {
return nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -237,9 +348,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
if cloned := cloneMetadata(metadata); cloned != nil {
opts.Metadata = cloned
}
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil {
status := http.StatusInternalServerError
@@ -269,6 +378,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
close(errChan)
return nil, errChan
}
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -282,9 +392,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
if cloned := cloneMetadata(metadata); cloned != nil {
opts.Metadata = cloned
}
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
@@ -309,31 +417,81 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
go func() {
defer close(dataChan)
defer close(errChan)
for chunk := range chunks {
if chunk.Err != nil {
status := http.StatusInternalServerError
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
return
sentPayload := false
bootstrapRetries := 0
maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
bootstrapEligible := func(err error) bool {
status := statusFromError(err)
if status == 0 {
return true
}
if len(chunk.Payload) > 0 {
dataChan <- cloneBytes(chunk.Payload)
switch status {
case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
http.StatusRequestTimeout, http.StatusTooManyRequests:
return true
default:
return status >= http.StatusInternalServerError
}
}
outer:
for {
for chunk := range chunks {
if chunk.Err != nil {
streamErr := chunk.Err
// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
// retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback.
if !sentPayload {
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
bootstrapRetries++
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if retryErr == nil {
chunks = retryChunks
continue outer
}
streamErr = retryErr
}
}
status := http.StatusInternalServerError
if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}
return
}
if len(chunk.Payload) > 0 {
sentPayload = true
dataChan <- cloneBytes(chunk.Payload)
}
}
return
}
}()
return dataChan, errChan
}
func statusFromError(err error) int {
if err == nil {
return 0
}
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
return code
}
}
return 0
}
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
// Resolve "auto" model to an actual available model first
resolvedModelName := util.ResolveAutoModel(modelName)
@@ -417,38 +575,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
}
}
// Prefer preserving upstream JSON error bodies when possible.
buildJSONBody := func() []byte {
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
return []byte(trimmed)
}
errType := "invalid_request_error"
switch status {
case http.StatusUnauthorized:
errType = "authentication_error"
case http.StatusForbidden:
errType = "permission_error"
case http.StatusTooManyRequests:
errType = "rate_limit_error"
default:
if status >= http.StatusInternalServerError {
errType = "server_error"
}
}
payload, err := json.Marshal(ErrorResponse{
Error: ErrorDetail{
Message: errText,
Type: errType,
},
})
if err != nil {
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
}
return payload
}
body := buildJSONBody()
body := BuildErrorResponseBody(status, errText)
c.Set("API_RESPONSE", bytes.Clone(body))
if !c.Writer.Written() {