mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-19 12:50:51 +08:00
fix: improve streaming bootstrap and forwarding
This commit is contained in:
@@ -9,8 +9,10 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -40,6 +42,115 @@ type ErrorDetail struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
const idempotencyKeyMetadataKey = "idempotency_key"
|
||||
|
||||
const (
|
||||
defaultStreamingKeepAliveSeconds = 15
|
||||
defaultStreamingBootstrapRetries = 2
|
||||
)
|
||||
|
||||
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||
if status <= 0 {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if strings.TrimSpace(errText) == "" {
|
||||
errText = http.StatusText(status)
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
|
||||
errType := "invalid_request_error"
|
||||
var code string
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
code = "invalid_api_key"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
code = "insufficient_quota"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
errType = "invalid_request_error"
|
||||
code = "model_not_found"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
code = "internal_server_error"
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
Code: code,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server.
|
||||
// Returning 0 disables keep-alives.
|
||||
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||
seconds := defaultStreamingKeepAliveSeconds
|
||||
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
|
||||
seconds = *cfg.Streaming.KeepAliveSeconds
|
||||
}
|
||||
if seconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||
retries := defaultStreamingBootstrapRetries
|
||||
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
|
||||
retries = *cfg.Streaming.BootstrapRetries
|
||||
}
|
||||
if retries < 0 {
|
||||
retries = 0
|
||||
}
|
||||
return retries
|
||||
}
|
||||
|
||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
key := ""
|
||||
if ctx != nil {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
|
||||
}
|
||||
}
|
||||
if key == "" {
|
||||
key = uuid.NewString()
|
||||
}
|
||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||
}
|
||||
|
||||
func mergeMetadata(base, overlay map[string]any) map[string]any {
|
||||
if len(base) == 0 && len(overlay) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(base)+len(overlay))
|
||||
for k, v := range base {
|
||||
out[k] = v
|
||||
}
|
||||
for k, v := range overlay {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BaseAPIHandler contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service and manages
|
||||
// load balancing, client selection, and configuration.
|
||||
@@ -182,6 +293,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -195,9 +307,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -224,6 +334,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -237,9 +348,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -269,6 +378,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -282,9 +392,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
@@ -309,31 +417,81 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
go func() {
|
||||
defer close(dataChan)
|
||||
defer close(errChan)
|
||||
for chunk := range chunks {
|
||||
if chunk.Err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
status = code
|
||||
}
|
||||
}
|
||||
var addon http.Header
|
||||
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
|
||||
if hdr := he.Headers(); hdr != nil {
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
|
||||
return
|
||||
sentPayload := false
|
||||
bootstrapRetries := 0
|
||||
maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
|
||||
|
||||
bootstrapEligible := func(err error) bool {
|
||||
status := statusFromError(err)
|
||||
if status == 0 {
|
||||
return true
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
dataChan <- cloneBytes(chunk.Payload)
|
||||
switch status {
|
||||
case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
|
||||
http.StatusRequestTimeout, http.StatusTooManyRequests:
|
||||
return true
|
||||
default:
|
||||
return status >= http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
outer:
|
||||
for {
|
||||
for chunk := range chunks {
|
||||
if chunk.Err != nil {
|
||||
streamErr := chunk.Err
|
||||
// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
|
||||
// retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback.
|
||||
if !sentPayload {
|
||||
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||
bootstrapRetries++
|
||||
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if retryErr == nil {
|
||||
chunks = retryChunks
|
||||
continue outer
|
||||
}
|
||||
streamErr = retryErr
|
||||
}
|
||||
}
|
||||
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
status = code
|
||||
}
|
||||
}
|
||||
var addon http.Header
|
||||
if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
|
||||
if hdr := he.Headers(); hdr != nil {
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}
|
||||
return
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
sentPayload = true
|
||||
dataChan <- cloneBytes(chunk.Payload)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
return code
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
@@ -417,38 +575,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer preserving upstream JSON error bodies when possible.
|
||||
buildJSONBody := func() []byte {
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
errType := "invalid_request_error"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
body := buildJSONBody()
|
||||
body := BuildErrorResponseBody(status, errText)
|
||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
||||
|
||||
if !c.Writer.Written() {
|
||||
|
||||
Reference in New Issue
Block a user