Files
CLIProxyAPI/internal/api/handlers/openai/openai_handlers.go
2025-09-25 10:32:48 +08:00

569 lines
18 KiB
Go

// Package openai provides HTTP handlers for OpenAI API endpoints.
// This package implements the OpenAI-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAI API requests to the appropriate backend format and
// convert responses back to OpenAI-compatible format.
package openai
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// OpenAIAPIHandler contains the handlers for OpenAI API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIAPIHandler struct {
*handlers.BaseAPIHandler
}
// NewOpenAIAPIHandler creates a new OpenAI API handlers instance.
// It takes an BaseAPIHandler instance as input and returns an OpenAIAPIHandler.
//
// Parameters:
// - apiHandlers: The base API handlers instance
//
// Returns:
// - *OpenAIAPIHandler: A new OpenAI API handlers instance
func NewOpenAIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *OpenAIAPIHandler {
return &OpenAIAPIHandler{
BaseAPIHandler: apiHandlers,
}
}
// HandlerType returns the identifier for this handler implementation.
func (h *OpenAIAPIHandler) HandlerType() string {
return OpenAI
}
// Models returns the OpenAI-compatible model metadata supported by this handler.
func (h *OpenAIAPIHandler) Models() []map[string]any {
// Get dynamic models from the global registry
modelRegistry := registry.GetGlobalRegistry()
return modelRegistry.GetAvailableModels("openai")
}
// OpenAIModels handles the /v1/models endpoint.
// It returns a list of available AI models with their capabilities
// and specifications in OpenAI-compatible format.
func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
// Get all available models
allModels := h.Models()
// Filter to only include the 4 required fields: id, object, created, owned_by
filteredModels := make([]map[string]any, len(allModels))
for i, model := range allModels {
filteredModel := map[string]any{
"id": model["id"],
"object": model["object"],
}
// Add created field if it exists
if created, exists := model["created"]; exists {
filteredModel["created"] = created
}
// Add owned_by field if it exists
if ownedBy, exists := model["owned_by"]; exists {
filteredModel["owned_by"] = ownedBy
}
filteredModels[i] = filteredModel
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": filteredModels,
})
}
// ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJSON)
} else {
h.handleNonStreamingResponse(c, rawJSON)
}
}
// Completions handles the /v1/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider.
// This endpoint follows the OpenAI completions API specification.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True {
h.handleCompletionsStreamingResponse(c, rawJSON)
} else {
h.handleCompletionsNonStreamingResponse(c, rawJSON)
}
}
// convertCompletionsRequestToChatCompletions converts OpenAI completions API request to chat completions format.
// This allows the completions endpoint to use the existing chat completions infrastructure.
//
// Parameters:
// - rawJSON: The raw JSON bytes of the completions request
//
// Returns:
// - []byte: The converted chat completions request
func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
root := gjson.ParseBytes(rawJSON)
// Extract prompt from completions request
prompt := root.Get("prompt").String()
if prompt == "" {
prompt = "Complete this:"
}
// Create chat completions structure
out := `{"model":"","messages":[{"role":"user","content":""}]}`
// Set model
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
}
// Set the prompt as user message content
out, _ = sjson.Set(out, "messages.0.content", prompt)
// Copy other parameters from completions to chat completions
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
}
if temperature := root.Get("temperature"); temperature.Exists() {
out, _ = sjson.Set(out, "temperature", temperature.Float())
}
if topP := root.Get("top_p"); topP.Exists() {
out, _ = sjson.Set(out, "top_p", topP.Float())
}
if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
}
if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
}
if stop := root.Get("stop"); stop.Exists() {
out, _ = sjson.SetRaw(out, "stop", stop.Raw)
}
if stream := root.Get("stream"); stream.Exists() {
out, _ = sjson.Set(out, "stream", stream.Bool())
}
if logprobs := root.Get("logprobs"); logprobs.Exists() {
out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
}
if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
}
if echo := root.Get("echo"); echo.Exists() {
out, _ = sjson.Set(out, "echo", echo.Bool())
}
return []byte(out)
}
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
// This ensures the completions endpoint returns data in the expected format.
//
// Parameters:
// - rawJSON: The raw JSON bytes of the chat completions response
//
// Returns:
// - []byte: The converted completions response
func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
root := gjson.ParseBytes(rawJSON)
// Base completions response structure
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
// Copy basic fields
if id := root.Get("id"); id.Exists() {
out, _ = sjson.Set(out, "id", id.String())
}
if created := root.Get("created"); created.Exists() {
out, _ = sjson.Set(out, "created", created.Int())
}
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
}
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
}
// Convert choices from chat completions to completions format
var choices []interface{}
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
completionsChoice := map[string]interface{}{
"index": choice.Get("index").Int(),
}
// Extract text content from message.content
if message := choice.Get("message"); message.Exists() {
if content := message.Get("content"); content.Exists() {
completionsChoice["text"] = content.String()
}
} else if delta := choice.Get("delta"); delta.Exists() {
// For streaming responses, use delta.content
if content := delta.Get("content"); content.Exists() {
completionsChoice["text"] = content.String()
}
}
// Copy finish_reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
completionsChoice["finish_reason"] = finishReason.String()
}
// Copy logprobs if present
if logprobs := choice.Get("logprobs"); logprobs.Exists() {
completionsChoice["logprobs"] = logprobs.Value()
}
choices = append(choices, completionsChoice)
return true
})
}
if len(choices) > 0 {
choicesJSON, _ := json.Marshal(choices)
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
}
return []byte(out)
}
// convertChatCompletionsStreamChunkToCompletions converts a streaming chat completions chunk to completions format.
// This handles the real-time conversion of streaming response chunks and filters out empty text responses.
//
// Parameters:
// - chunkData: The raw JSON bytes of a single chat completions stream chunk
//
// Returns:
// - []byte: The converted completions stream chunk, or nil if should be filtered out
func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
root := gjson.ParseBytes(chunkData)
// Check if this chunk has any meaningful content
hasContent := false
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
// Check if delta has content or finish_reason
if delta := choice.Get("delta"); delta.Exists() {
if content := delta.Get("content"); content.Exists() && content.String() != "" {
hasContent = true
return false // Break out of forEach
}
}
// Also check for finish_reason to ensure we don't skip final chunks
if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "" && finishReason.String() != "null" {
hasContent = true
return false // Break out of forEach
}
return true
})
}
// If no meaningful content, return nil to indicate this chunk should be skipped
if !hasContent {
return nil
}
// Base completions stream response structure
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
// Copy basic fields
if id := root.Get("id"); id.Exists() {
out, _ = sjson.Set(out, "id", id.String())
}
if created := root.Get("created"); created.Exists() {
out, _ = sjson.Set(out, "created", created.Int())
}
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
}
// Convert choices from chat completions delta to completions format
var choices []interface{}
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
completionsChoice := map[string]interface{}{
"index": choice.Get("index").Int(),
}
// Extract text content from delta.content
if delta := choice.Get("delta"); delta.Exists() {
if content := delta.Get("content"); content.Exists() && content.String() != "" {
completionsChoice["text"] = content.String()
} else {
completionsChoice["text"] = ""
}
} else {
completionsChoice["text"] = ""
}
// Copy finish_reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "null" {
completionsChoice["finish_reason"] = finishReason.String()
}
// Copy logprobs if present
if logprobs := choice.Get("logprobs"); logprobs.Exists() {
completionsChoice["logprobs"] = logprobs.Value()
}
choices = append(choices, completionsChoice)
return true
})
}
if len(choices) > 0 {
choicesJSON, _ := json.Marshal(choices)
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
}
return []byte(out)
}
// handleNonStreamingResponse handles non-streaming chat completion responses
// for Gemini models. It selects a client from the pool, sends the request, and
// aggregates the response before sending it back to the client in OpenAI format.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
_, _ = c.Writer.Write(resp)
cliCancel()
}
// handleStreamingResponse handles streaming responses for Gemini models.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
}
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
// It converts completions request to chat completions format, sends to backend,
// then converts the response back to completions format before sending to client.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
// Convert completions request to chat completions format
chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
completionsResp := convertChatCompletionsResponseToCompletions(resp)
_, _ = c.Writer.Write(completionsResp)
cliCancel()
}
// handleCompletionsStreamingResponse handles streaming completions responses.
// It converts completions request to chat completions format, streams from backend,
// then converts each response chunk back to completions format before sending to client.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
// Convert completions request to chat completions format
chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case chunk, isOk := <-dataChan:
if !isOk {
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
return
}
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
if converted != nil {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
flusher.Flush()
}
case errMsg, isOk := <-errChan:
if !isOk {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cliCancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cancel(nil)
return
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}