// Package openai provides HTTP handlers for OpenAI API endpoints.
// This package implements the OpenAI-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAI API requests to the appropriate backend format and
// convert responses back to OpenAI-compatible format.
package openai

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/luispater/CLIProxyAPI/internal/api/handlers"
	. "github.com/luispater/CLIProxyAPI/internal/constant"
	"github.com/luispater/CLIProxyAPI/internal/interfaces"
	"github.com/luispater/CLIProxyAPI/internal/registry"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// OpenAIAPIHandler contains the handlers for OpenAI API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIAPIHandler struct {
	*handlers.BaseAPIHandler
}

// NewOpenAIAPIHandler creates a new OpenAI API handlers instance.
// It takes a BaseAPIHandler instance as input and returns an OpenAIAPIHandler.
//
// Parameters:
// - apiHandlers: The base API handlers instance
//
// Returns:
// - *OpenAIAPIHandler: A new OpenAI API handlers instance
func NewOpenAIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *OpenAIAPIHandler {
	return &OpenAIAPIHandler{
		BaseAPIHandler: apiHandlers,
	}
}
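
// Illustrative wiring sketch (not part of this file): the route paths come from
// the handler doc comments below; the gin server setup itself is an assumption.
//
//	h := NewOpenAIAPIHandler(base) // base: a previously constructed *handlers.BaseAPIHandler
//	r := gin.Default()
//	r.GET("/v1/models", h.OpenAIModels)
//	r.POST("/v1/chat/completions", h.ChatCompletions)
//	r.POST("/v1/completions", h.Completions)
//	_ = r.Run(":8080") // listen address is hypothetical
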
// HandlerType returns the identifier for this handler implementation.
func (h *OpenAIAPIHandler) HandlerType() string {
	return OPENAI
}

// Models returns the OpenAI-compatible model metadata supported by this handler.
func (h *OpenAIAPIHandler) Models() []map[string]any {
	// Get dynamic models from the global registry
	modelRegistry := registry.GetGlobalRegistry()
	return modelRegistry.GetAvailableModels("openai")
}

// OpenAIModels handles the /v1/models endpoint.
// It returns a list of available AI models with their capabilities
// and specifications in OpenAI-compatible format.
func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   h.Models(),
	})
}

// ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
	rawJSON, err := c.GetRawData()
	// If data retrieval fails, return a 400 Bad Request error.
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	// Check if the client requested a streaming response.
	streamResult := gjson.GetBytes(rawJSON, "stream")
	if streamResult.Type == gjson.True {
		h.handleStreamingResponse(c, rawJSON)
	} else {
		h.handleNonStreamingResponse(c, rawJSON)
	}
}
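
// Illustrative request body (field values are placeholders): when "stream" is
// true the SSE path is taken, otherwise a single JSON response is returned.
//
//	{"model":"some-model","messages":[{"role":"user","content":"Hi"}],"stream":true}
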
// Completions handles the /v1/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider.
// This endpoint follows the OpenAI completions API specification.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
	rawJSON, err := c.GetRawData()
	// If data retrieval fails, return a 400 Bad Request error.
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	// Check if the client requested a streaming response.
	streamResult := gjson.GetBytes(rawJSON, "stream")
	if streamResult.Type == gjson.True {
		h.handleCompletionsStreamingResponse(c, rawJSON)
	} else {
		h.handleCompletionsNonStreamingResponse(c, rawJSON)
	}
}

// convertCompletionsRequestToChatCompletions converts an OpenAI completions API request to chat completions format.
// This allows the completions endpoint to use the existing chat completions infrastructure.
//
// Parameters:
// - rawJSON: The raw JSON bytes of the completions request
//
// Returns:
// - []byte: The converted chat completions request
func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
	root := gjson.ParseBytes(rawJSON)

	// Extract prompt from completions request
	prompt := root.Get("prompt").String()
	if prompt == "" {
		prompt = "Complete this:"
	}

	// Create chat completions structure
	out := `{"model":"","messages":[{"role":"user","content":""}]}`

	// Set model
	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}

	// Set the prompt as user message content
	out, _ = sjson.Set(out, "messages.0.content", prompt)

	// Copy other parameters from completions to chat completions
	if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
		out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
	}

	if temperature := root.Get("temperature"); temperature.Exists() {
		out, _ = sjson.Set(out, "temperature", temperature.Float())
	}

	if topP := root.Get("top_p"); topP.Exists() {
		out, _ = sjson.Set(out, "top_p", topP.Float())
	}

	if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
		out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
	}

	if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
		out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
	}

	if stop := root.Get("stop"); stop.Exists() {
		out, _ = sjson.SetRaw(out, "stop", stop.Raw)
	}

	if stream := root.Get("stream"); stream.Exists() {
		out, _ = sjson.Set(out, "stream", stream.Bool())
	}

	if logprobs := root.Get("logprobs"); logprobs.Exists() {
		out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
	}

	if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
		out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
	}

	if echo := root.Get("echo"); echo.Exists() {
		out, _ = sjson.Set(out, "echo", echo.Bool())
	}

	return []byte(out)
}
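
// Illustrative transformation performed by convertCompletionsRequestToChatCompletions
// (field values are placeholders):
//
//	in:  {"model":"m","prompt":"Say hi","max_tokens":16,"stream":false}
//	out: {"model":"m","messages":[{"role":"user","content":"Say hi"}],"max_tokens":16,"stream":false}
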
// convertChatCompletionsResponseToCompletions converts a chat completions API response back to completions format.
// This ensures the completions endpoint returns data in the expected format.
//
// Parameters:
// - rawJSON: The raw JSON bytes of the chat completions response
//
// Returns:
// - []byte: The converted completions response
func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
	root := gjson.ParseBytes(rawJSON)

	// Base completions response structure
	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`

	// Copy basic fields
	if id := root.Get("id"); id.Exists() {
		out, _ = sjson.Set(out, "id", id.String())
	}

	if created := root.Get("created"); created.Exists() {
		out, _ = sjson.Set(out, "created", created.Int())
	}

	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}

	if usage := root.Get("usage"); usage.Exists() {
		out, _ = sjson.SetRaw(out, "usage", usage.Raw)
	}

	// Convert choices from chat completions to completions format
	var choices []interface{}
	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			completionsChoice := map[string]interface{}{
				"index": choice.Get("index").Int(),
			}

			// Extract text content from message.content
			if message := choice.Get("message"); message.Exists() {
				if content := message.Get("content"); content.Exists() {
					completionsChoice["text"] = content.String()
				}
			} else if delta := choice.Get("delta"); delta.Exists() {
				// For streaming responses, use delta.content
				if content := delta.Get("content"); content.Exists() {
					completionsChoice["text"] = content.String()
				}
			}

			// Copy finish_reason
			if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
				completionsChoice["finish_reason"] = finishReason.String()
			}

			// Copy logprobs if present
			if logprobs := choice.Get("logprobs"); logprobs.Exists() {
				completionsChoice["logprobs"] = logprobs.Value()
			}

			choices = append(choices, completionsChoice)
			return true
		})
	}

	if len(choices) > 0 {
		choicesJSON, _ := json.Marshal(choices)
		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
	}

	return []byte(out)
}
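
// Illustrative mapping performed by convertChatCompletionsResponseToCompletions
// (field values are placeholders):
//
//	in:  {"id":"x","created":1,"model":"m",
//	      "choices":[{"index":0,"message":{"role":"assistant","content":"Hi"},"finish_reason":"stop"}]}
//	out: {"id":"x","object":"text_completion","created":1,"model":"m",
//	      "choices":[{"index":0,"text":"Hi","finish_reason":"stop"}]}
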
// convertChatCompletionsStreamChunkToCompletions converts a streaming chat completions chunk to completions format.
// This handles the real-time conversion of streaming response chunks and filters out empty text responses.
//
// Parameters:
// - chunkData: The raw JSON bytes of a single chat completions stream chunk
//
// Returns:
// - []byte: The converted completions stream chunk, or nil if the chunk should be filtered out
func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
	root := gjson.ParseBytes(chunkData)

	// Check if this chunk has any meaningful content
	hasContent := false
	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			// Check if delta has content or finish_reason
			if delta := choice.Get("delta"); delta.Exists() {
				if content := delta.Get("content"); content.Exists() && content.String() != "" {
					hasContent = true
					return false // Break out of forEach
				}
			}
			// Also check for finish_reason to ensure we don't skip final chunks
			if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "" && finishReason.String() != "null" {
				hasContent = true
				return false // Break out of forEach
			}
			return true
		})
	}

	// If no meaningful content, return nil to indicate this chunk should be skipped
	if !hasContent {
		return nil
	}

	// Base completions stream response structure
	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`

	// Copy basic fields
	if id := root.Get("id"); id.Exists() {
		out, _ = sjson.Set(out, "id", id.String())
	}

	if created := root.Get("created"); created.Exists() {
		out, _ = sjson.Set(out, "created", created.Int())
	}

	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}

	// Convert choices from chat completions delta to completions format
	var choices []interface{}
	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			completionsChoice := map[string]interface{}{
				"index": choice.Get("index").Int(),
			}

			// Extract text content from delta.content
			if delta := choice.Get("delta"); delta.Exists() {
				if content := delta.Get("content"); content.Exists() && content.String() != "" {
					completionsChoice["text"] = content.String()
				} else {
					completionsChoice["text"] = ""
				}
			} else {
				completionsChoice["text"] = ""
			}

			// Copy finish_reason
			if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "null" {
				completionsChoice["finish_reason"] = finishReason.String()
			}

			// Copy logprobs if present
			if logprobs := choice.Get("logprobs"); logprobs.Exists() {
				completionsChoice["logprobs"] = logprobs.Value()
			}

			choices = append(choices, completionsChoice)
			return true
		})
	}

	if len(choices) > 0 {
		choicesJSON, _ := json.Marshal(choices)
		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
	}

	return []byte(out)
}
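
// Illustrative chunk handling (field values are placeholders): a delta with no
// content and no finish_reason is dropped; anything else is rewritten.
//
//	in:  {"id":"x","created":1,"model":"m","choices":[{"index":0,"delta":{"content":"Hi"}}]}
//	out: {"id":"x","object":"text_completion","created":1,"model":"m","choices":[{"index":0,"text":"Hi"}]}
//	in:  {"choices":[{"index":0,"delta":{}}]}  ->  nil (chunk skipped)
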
// handleNonStreamingResponse handles non-streaming chat completion responses.
// It selects a client from the pool based on the requested model, sends the request,
// and aggregates the response before sending it back to the client in OpenAI format.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	modelName := gjson.GetBytes(rawJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())

	var cliClient interfaces.Client
	defer func() {
		if cliClient != nil {
			if mutex := cliClient.GetRequestMutex(); mutex != nil {
				mutex.Unlock()
			}
		}
	}()

	retryCount := 0
	for retryCount <= h.Cfg.RequestRetry {
		var errorResponse *interfaces.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
			cliCancel()
			return
		}

		resp, err := cliClient.SendRawMessage(cliCtx, modelName, rawJSON, "")
		if err != nil {
			switch err.StatusCode {
			case 429:
				if h.Cfg.QuotaExceeded.SwitchProject {
					log.Debugf("quota exceeded, switch client")
					continue // Restart the client selection process
				}
			case 403, 408, 500, 502, 503, 504:
				log.Debugf("http status code %d, switch client", err.StatusCode)
				retryCount++
				continue
			default:
				// Forward other errors directly to the client
				c.Status(err.StatusCode)
				_, _ = c.Writer.Write([]byte(err.Error.Error()))
				cliCancel(err.Error)
			}
			break
		} else {
			_, _ = c.Writer.Write(resp)
			cliCancel(resp)
			break
		}
	}
}
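
// Failover behavior shared by the handlers in this file: HTTP 403/408/500/502/503/504
// from the backend consumes one retry (up to h.Cfg.RequestRetry) and switches to
// another client; 429 switches clients without consuming a retry when
// h.Cfg.QuotaExceeded.SwitchProject is enabled; all other errors are forwarded as-is.
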
// handleStreamingResponse handles streaming chat completion responses.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")

	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	modelName := gjson.GetBytes(rawJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())

	var cliClient interfaces.Client
	defer func() {
		// Ensure the client's mutex is unlocked on function exit.
		if cliClient != nil {
			if mutex := cliClient.GetRequestMutex(); mutex != nil {
				mutex.Unlock()
			}
		}
	}()

	retryCount := 0
outLoop:
	for retryCount <= h.Cfg.RequestRetry {
		var errorResponse *interfaces.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
			flusher.Flush()
			cliCancel()
			return
		}

		// Send the message and receive response chunks and errors via channels.
		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, modelName, rawJSON, "")

		for {
			select {
			// Handle client disconnection.
			case <-c.Request.Context().Done():
				if c.Request.Context().Err().Error() == "context canceled" {
					log.Debugf("client disconnected: %v", c.Request.Context().Err())
					cliCancel() // Cancel the backend request.
					return
				}
			// Process incoming response chunks.
			case chunk, okStream := <-respChan:
				if !okStream {
					// Stream is closed, send the final [DONE] message.
					_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
					flusher.Flush()
					cliCancel()
					return
				}

				_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
				flusher.Flush()
			// Handle errors from the backend.
			case err, okError := <-errChan:
				if okError {
					switch err.StatusCode {
					case 429:
						if h.Cfg.QuotaExceeded.SwitchProject {
							log.Debugf("quota exceeded, switch client")
							continue outLoop // Restart the client selection process
						}
					case 403, 408, 500, 502, 503, 504:
						log.Debugf("http status code %d, switch client", err.StatusCode)
						retryCount++
						continue outLoop
					default:
						// Forward other errors directly to the client
						c.Status(err.StatusCode)
						_, _ = fmt.Fprint(c.Writer, err.Error.Error())
						flusher.Flush()
						cliCancel(err.Error)
					}
					return
				}
			// Periodically wake the select loop so client disconnects are noticed.
			case <-time.After(500 * time.Millisecond):
			}
		}
	}
}
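
// Illustrative wire format produced by the streaming handlers (chunk payload is a
// placeholder): each backend chunk is framed as an SSE data event, and the stream
// is terminated with a literal [DONE] marker:
//
//	data: {"id":"x","choices":[{"index":0,"delta":{"content":"Hi"}}]}
//
//	data: [DONE]
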
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
// It converts the completions request to chat completions format, sends it to the backend,
// then converts the response back to completions format before sending it to the client.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	// Convert completions request to chat completions format
	chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)

	modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())

	var cliClient interfaces.Client
	defer func() {
		if cliClient != nil {
			if mutex := cliClient.GetRequestMutex(); mutex != nil {
				mutex.Unlock()
			}
		}
	}()

	retryCount := 0
	for retryCount <= h.Cfg.RequestRetry {
		var errorResponse *interfaces.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
			cliCancel()
			return
		}

		// Send the converted chat completions request
		resp, err := cliClient.SendRawMessage(cliCtx, modelName, chatCompletionsJSON, "")
		if err != nil {
			switch err.StatusCode {
			case 429:
				if h.Cfg.QuotaExceeded.SwitchProject {
					log.Debugf("quota exceeded, switch client")
					continue // Restart the client selection process
				}
			case 403, 408, 500, 502, 503, 504:
				log.Debugf("http status code %d, switch client", err.StatusCode)
				retryCount++
				continue
			default:
				// Forward other errors directly to the client
				c.Status(err.StatusCode)
				_, _ = c.Writer.Write([]byte(err.Error.Error()))
				cliCancel(err.Error)
			}
			break
		} else {
			// Convert chat completions response back to completions format
			completionsResp := convertChatCompletionsResponseToCompletions(resp)
			_, _ = c.Writer.Write(completionsResp)
			cliCancel(completionsResp)
			break
		}
	}
}

// handleCompletionsStreamingResponse handles streaming completions responses.
// It converts the completions request to chat completions format, streams from the backend,
// then converts each response chunk back to completions format before sending it to the client.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")

	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	// Convert completions request to chat completions format
	chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)

	modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())

	var cliClient interfaces.Client
	defer func() {
		// Ensure the client's mutex is unlocked on function exit.
		if cliClient != nil {
			if mutex := cliClient.GetRequestMutex(); mutex != nil {
				mutex.Unlock()
			}
		}
	}()

	retryCount := 0
outLoop:
	for retryCount <= h.Cfg.RequestRetry {
		var errorResponse *interfaces.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
			flusher.Flush()
			cliCancel()
			return
		}

		// Send the converted chat completions request and receive response chunks
		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, modelName, chatCompletionsJSON, "")

		for {
			select {
			// Handle client disconnection.
			case <-c.Request.Context().Done():
				if c.Request.Context().Err().Error() == "context canceled" {
					log.Debugf("client disconnected: %v", c.Request.Context().Err())
					cliCancel() // Cancel the backend request.
					return
				}
			// Process incoming response chunks.
			case chunk, okStream := <-respChan:
				if !okStream {
					// Stream is closed, send the final [DONE] message.
					_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
					flusher.Flush()
					cliCancel()
					return
				}

				// Convert chat completions chunk to completions chunk format
				completionsChunk := convertChatCompletionsStreamChunkToCompletions(chunk)
				// Skip this chunk if it has no meaningful content (empty text)
				if completionsChunk != nil {
					_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(completionsChunk))
					flusher.Flush()
				}
			// Handle errors from the backend.
			case err, okError := <-errChan:
				if okError {
					switch err.StatusCode {
					case 429:
						if h.Cfg.QuotaExceeded.SwitchProject {
							log.Debugf("quota exceeded, switch client")
							continue outLoop // Restart the client selection process
						}
					case 403, 408, 500, 502, 503, 504:
						log.Debugf("http status code %d, switch client", err.StatusCode)
						retryCount++
						continue outLoop
					default:
						// Forward other errors directly to the client
						c.Status(err.StatusCode)
						_, _ = fmt.Fprint(c.Writer, err.Error.Error())
						flusher.Flush()
						cliCancel(err.Error)
					}
					return
				}
			// Periodically wake the select loop so client disconnects are noticed.
			case <-time.After(500 * time.Millisecond):
			}
		}
	}
}