mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
- Replaced `config.Config` with `config.SDKConfig` across components for simpler configuration management. - Updated proxy setup functions and handlers to align with `SDKConfig` improvements. - Reorganized handler imports to match new SDK structure.
569 lines
18 KiB
Go
569 lines
18 KiB
Go
// Package openai provides HTTP handlers for OpenAI API endpoints.
|
|
// This package implements the OpenAI-compatible API interface, including model listing
|
|
// and chat completion functionality. It supports both streaming and non-streaming responses,
|
|
// and manages a pool of clients to interact with backend services.
|
|
// The handlers translate OpenAI API requests to the appropriate backend format and
|
|
// convert responses back to OpenAI-compatible format.
|
|
package openai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// OpenAIAPIHandler contains the handlers for OpenAI API endpoints.
|
|
// It holds a pool of clients to interact with the backend service.
|
|
type OpenAIAPIHandler struct {
|
|
*handlers.BaseAPIHandler
|
|
}
|
|
|
|
// NewOpenAIAPIHandler creates a new OpenAI API handlers instance.
|
|
// It takes an BaseAPIHandler instance as input and returns an OpenAIAPIHandler.
|
|
//
|
|
// Parameters:
|
|
// - apiHandlers: The base API handlers instance
|
|
//
|
|
// Returns:
|
|
// - *OpenAIAPIHandler: A new OpenAI API handlers instance
|
|
func NewOpenAIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *OpenAIAPIHandler {
|
|
return &OpenAIAPIHandler{
|
|
BaseAPIHandler: apiHandlers,
|
|
}
|
|
}
|
|
|
|
// HandlerType returns the identifier for this handler implementation.
|
|
func (h *OpenAIAPIHandler) HandlerType() string {
|
|
return OpenAI
|
|
}
|
|
|
|
// Models returns the OpenAI-compatible model metadata supported by this handler.
|
|
func (h *OpenAIAPIHandler) Models() []map[string]any {
|
|
// Get dynamic models from the global registry
|
|
modelRegistry := registry.GetGlobalRegistry()
|
|
return modelRegistry.GetAvailableModels("openai")
|
|
}
|
|
|
|
// OpenAIModels serves the /v1/models endpoint.
//
// Each registry entry is trimmed to the fields OpenAI clients expect:
// "id" and "object" are always emitted, while "created" and "owned_by" are
// emitted only when the entry defines them. The result is wrapped in an
// OpenAI list envelope: {"object":"list","data":[...]}.
func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
	models := h.Models()

	data := make([]map[string]any, 0, len(models))
	for _, m := range models {
		entry := map[string]any{
			"id":     m["id"],
			"object": m["object"],
		}
		// Optional fields are copied only when present in the source entry.
		for _, key := range []string{"created", "owned_by"} {
			if v, ok := m[key]; ok {
				entry[key] = v
			}
		}
		data = append(data, entry)
	}

	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   data,
	})
}
|
|
|
|
// ChatCompletions serves the /v1/chat/completions endpoint.
//
// The request body is read once. Requests whose "stream" field is the JSON
// literal true are answered with Server-Sent Events; all others receive a
// single JSON response. An unreadable body yields 400 invalid_request_error.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
	body, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	if wantsStream := gjson.GetBytes(body, "stream").Type == gjson.True; !wantsStream {
		h.handleNonStreamingResponse(c, body)
		return
	}
	h.handleStreamingResponse(c, body)
}
|
|
|
|
// Completions serves the legacy /v1/completions endpoint.
//
// The request body is read once and dispatched on its "stream" field: the
// JSON literal true selects the Server-Sent Events path, anything else the
// single JSON response path. Both paths translate through the chat
// completions pipeline. An unreadable body yields 400 invalid_request_error.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
	body, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	switch gjson.GetBytes(body, "stream").Type {
	case gjson.True:
		h.handleCompletionsStreamingResponse(c, body)
	default:
		h.handleCompletionsNonStreamingResponse(c, body)
	}
}
|
|
|
|
// convertCompletionsRequestToChatCompletions converts OpenAI completions API request to chat completions format.
|
|
// This allows the completions endpoint to use the existing chat completions infrastructure.
|
|
//
|
|
// Parameters:
|
|
// - rawJSON: The raw JSON bytes of the completions request
|
|
//
|
|
// Returns:
|
|
// - []byte: The converted chat completions request
|
|
func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
|
|
root := gjson.ParseBytes(rawJSON)
|
|
|
|
// Extract prompt from completions request
|
|
prompt := root.Get("prompt").String()
|
|
if prompt == "" {
|
|
prompt = "Complete this:"
|
|
}
|
|
|
|
// Create chat completions structure
|
|
out := `{"model":"","messages":[{"role":"user","content":""}]}`
|
|
|
|
// Set model
|
|
if model := root.Get("model"); model.Exists() {
|
|
out, _ = sjson.Set(out, "model", model.String())
|
|
}
|
|
|
|
// Set the prompt as user message content
|
|
out, _ = sjson.Set(out, "messages.0.content", prompt)
|
|
|
|
// Copy other parameters from completions to chat completions
|
|
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
|
|
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
|
}
|
|
|
|
if temperature := root.Get("temperature"); temperature.Exists() {
|
|
out, _ = sjson.Set(out, "temperature", temperature.Float())
|
|
}
|
|
|
|
if topP := root.Get("top_p"); topP.Exists() {
|
|
out, _ = sjson.Set(out, "top_p", topP.Float())
|
|
}
|
|
|
|
if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
|
|
out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
|
|
}
|
|
|
|
if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
|
|
out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
|
|
}
|
|
|
|
if stop := root.Get("stop"); stop.Exists() {
|
|
out, _ = sjson.SetRaw(out, "stop", stop.Raw)
|
|
}
|
|
|
|
if stream := root.Get("stream"); stream.Exists() {
|
|
out, _ = sjson.Set(out, "stream", stream.Bool())
|
|
}
|
|
|
|
if logprobs := root.Get("logprobs"); logprobs.Exists() {
|
|
out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
|
|
}
|
|
|
|
if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
|
|
out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
|
|
}
|
|
|
|
if echo := root.Get("echo"); echo.Exists() {
|
|
out, _ = sjson.Set(out, "echo", echo.Bool())
|
|
}
|
|
|
|
return []byte(out)
|
|
}
|
|
|
|
// convertChatCompletionsResponseToCompletions converts a chat completions API
// response back into completions (text_completion) format.
//
// id, created, model and usage are copied. Each choice's text is taken from
// message.content (or delta.content when no message is present) and always
// emitted, defaulting to "" so clients can rely on the field. A JSON-null
// finish_reason is emitted as null rather than as an empty string. Logprobs
// are passed through unchanged.
//
// Parameters:
//   - rawJSON: The raw JSON bytes of the chat completions response
//
// Returns:
//   - []byte: The converted completions response
func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
	root := gjson.ParseBytes(rawJSON)

	// Base completions response structure
	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`

	if id := root.Get("id"); id.Exists() {
		out, _ = sjson.Set(out, "id", id.String())
	}
	if created := root.Get("created"); created.Exists() {
		out, _ = sjson.Set(out, "created", created.Int())
	}
	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}
	if usage := root.Get("usage"); usage.Exists() {
		out, _ = sjson.SetRaw(out, "usage", usage.Raw)
	}

	var choices []any
	if chatChoices := root.Get("choices"); chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			completionsChoice := map[string]any{
				"index": choice.Get("index").Int(),
				"text":  "",
			}

			if message := choice.Get("message"); message.Exists() {
				if content := message.Get("content"); content.Exists() {
					completionsChoice["text"] = content.String()
				}
			} else if delta := choice.Get("delta"); delta.Exists() {
				// For streaming-shaped payloads, use delta.content
				if content := delta.Get("content"); content.Exists() {
					completionsChoice["text"] = content.String()
				}
			}

			// gjson reports JSON null as existing with String() == "", so null
			// must be checked explicitly to avoid emitting "finish_reason":"".
			if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
				if finishReason.Type == gjson.Null {
					completionsChoice["finish_reason"] = nil
				} else {
					completionsChoice["finish_reason"] = finishReason.String()
				}
			}

			if logprobs := choice.Get("logprobs"); logprobs.Exists() {
				completionsChoice["logprobs"] = logprobs.Value()
			}

			choices = append(choices, completionsChoice)
			return true
		})
	}

	if len(choices) > 0 {
		// Values are strings, ints, nil and gjson-decoded JSON, so Marshal cannot fail.
		choicesJSON, _ := json.Marshal(choices)
		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
	}

	return []byte(out)
}
|
|
|
|
// convertChatCompletionsStreamChunkToCompletions converts one streaming chat
// completions chunk into a completions (text_completion) stream chunk.
//
// A chunk in which no choice carries a non-empty delta.content or a real
// finish_reason (for example a role-only opening chunk) is dropped by
// returning nil, so clients never receive empty events.
//
// Every emitted choice has a "text" field ("" when the delta has no content).
// A null finish_reason is emitted as JSON null, matching the completions
// stream format, rather than as an empty string.
//
// Parameters:
//   - chunkData: The raw JSON bytes of a single chat completions stream chunk
//
// Returns:
//   - []byte: The converted completions stream chunk, or nil if it should be skipped
func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
	root := gjson.ParseBytes(chunkData)
	chatChoices := root.Get("choices")

	// Check whether any choice carries meaningful content or a final finish_reason.
	hasContent := false
	if chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			if choice.Get("delta.content").String() != "" {
				hasContent = true
				return false // stop iterating
			}
			if fr := choice.Get("finish_reason").String(); fr != "" && fr != "null" {
				hasContent = true
				return false // keep final chunks even when they carry no text
			}
			return true
		})
	}
	if !hasContent {
		return nil
	}

	// Base completions stream response structure
	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`

	if id := root.Get("id"); id.Exists() {
		out, _ = sjson.Set(out, "id", id.String())
	}
	if created := root.Get("created"); created.Exists() {
		out, _ = sjson.Set(out, "created", created.Int())
	}
	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}

	// hasContent implies choices is an array with at least one element.
	var choices []any
	chatChoices.ForEach(func(_, choice gjson.Result) bool {
		completionsChoice := map[string]any{
			"index": choice.Get("index").Int(),
			"text":  choice.Get("delta.content").String(), // "" when absent or null
		}

		// gjson reports JSON null as existing with String() == "", so compare
		// the type rather than the string form.
		if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
			if finishReason.Type == gjson.Null || finishReason.String() == "null" {
				completionsChoice["finish_reason"] = nil
			} else {
				completionsChoice["finish_reason"] = finishReason.String()
			}
		}

		if logprobs := choice.Get("logprobs"); logprobs.Exists() {
			completionsChoice["logprobs"] = logprobs.Value()
		}

		choices = append(choices, completionsChoice)
		return true
	})

	if len(choices) > 0 {
		// Values are strings, ints, nil and gjson-decoded JSON, so Marshal cannot fail.
		choicesJSON, _ := json.Marshal(choices)
		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
	}

	return []byte(out)
}
|
|
|
|
// handleNonStreamingResponse executes a chat completion request through the
// auth manager and writes the backend's response body to the client unchanged.
// On failure the error is written via WriteErrorResponse and the request
// context is cancelled with that error.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	model := gjson.GetBytes(rawJSON, "model").String()
	execCtx, cancelExec := h.GetContextWithCancel(h, c, context.Background())

	payload, execErr := h.ExecuteWithAuthManager(execCtx, h.HandlerType(), model, rawJSON, h.GetAlt(c))
	if execErr != nil {
		h.WriteErrorResponse(c, execErr)
		cancelExec(execErr.Error)
		return
	}

	_, _ = c.Writer.Write(payload)
	cancelExec()
}
|
|
|
|
// handleStreamingResponse executes a streaming chat completion request and
// relays the backend chunks to the client as Server-Sent Events.
//
// SSE headers are written before the flusher check, so the 500 "Streaming not
// supported" reply is sent with those headers already in place.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
	sseHeaders := [...][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
		{"Access-Control-Allow-Origin", "*"},
	}
	for _, kv := range sseHeaders {
		c.Header(kv[0], kv[1])
	}

	// Events must be pushed immediately, which requires an http.Flusher.
	flusher, canFlush := c.Writer.(http.Flusher)
	if !canFlush {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	model := gjson.GetBytes(rawJSON, "model").String()
	execCtx, cancelExec := h.GetContextWithCancel(h, c, context.Background())
	chunks, failures := h.ExecuteStreamWithAuthManager(execCtx, h.HandlerType(), model, rawJSON, h.GetAlt(c))

	h.handleStreamResult(c, flusher, func(err error) { cancelExec(err) }, chunks, failures)
}
|
|
|
|
// handleCompletionsNonStreamingResponse serves non-streaming /v1/completions
// requests. The request is translated to chat completions format, executed
// with an empty alt value, and the chat response is translated back to
// completions format before being written to the client.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	chatRequest := convertCompletionsRequestToChatCompletions(rawJSON)
	model := gjson.GetBytes(chatRequest, "model").String()
	execCtx, cancelExec := h.GetContextWithCancel(h, c, context.Background())

	chatResponse, execErr := h.ExecuteWithAuthManager(execCtx, h.HandlerType(), model, chatRequest, "")
	if execErr != nil {
		h.WriteErrorResponse(c, execErr)
		cancelExec(execErr.Error)
		return
	}

	_, _ = c.Writer.Write(convertChatCompletionsResponseToCompletions(chatResponse))
	cancelExec()
}
|
|
|
|
// handleCompletionsStreamingResponse serves streaming /v1/completions requests.
//
// The request is converted to chat completions format and streamed from the
// backend. Every chunk is converted back to completions format and relayed as
// a Server-Sent Event; chunks the converter rejects (nil) are skipped. The
// stream ends with "data: [DONE]" when the backend data channel closes. A
// backend error is written via WriteErrorResponse and ends the stream, and a
// client disconnect cancels the backend request.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")

	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	// Convert completions request to chat completions format
	chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)

	modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")

	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return
		case chunk, isOk := <-dataChan:
			if !isOk {
				_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
				flusher.Flush()
				cliCancel()
				return
			}
			if converted := convertChatCompletionsStreamChunkToCompletions(chunk); converted != nil {
				_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
				flusher.Flush()
			}
		case errMsg, isOk := <-errChan:
			if !isOk {
				// A closed channel is always ready, so select would keep picking
				// it and spin the CPU while data is still streaming. A nil
				// channel blocks forever, removing this case from the select.
				errChan = nil
				continue
			}
			if errMsg != nil {
				h.WriteErrorResponse(c, errMsg)
				flusher.Flush()
			}
			var execErr error
			if errMsg != nil {
				execErr = errMsg.Error
			}
			cliCancel(execErr)
			return
		case <-time.After(500 * time.Millisecond):
			// Periodic wake-up only; no work is done on timeout.
		}
	}
}
|
|
// handleStreamResult relays backend stream chunks to the client as Server-Sent
// Events until the stream finishes, fails, or the client disconnects.
//
// Each chunk from data is written as "data: <chunk>" and flushed immediately.
// When data is closed the stream is terminated with "data: [DONE]". A message
// received on errs is written via WriteErrorResponse and ends the stream. cancel
// is invoked exactly once before returning, with the terminating error (nil on
// normal completion, the request context error on client disconnect).
//
// Parameters:
//   - c: The Gin context for the client connection
//   - flusher: Flusher used to push each event to the client immediately
//   - cancel: Callback that releases the backend request with its final error
//   - data: Channel of raw response chunks from the backend
//   - errs: Channel of backend errors
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	for {
		select {
		case <-c.Request.Context().Done():
			cancel(c.Request.Context().Err())
			return
		case chunk, ok := <-data:
			if !ok {
				_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
				flusher.Flush()
				cancel(nil)
				return
			}
			_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
			flusher.Flush()
		case errMsg, ok := <-errs:
			if !ok {
				// A closed channel is always ready, so select would keep picking
				// it and spin the CPU while data is still streaming. A nil
				// channel blocks forever, removing this case from the select.
				errs = nil
				continue
			}
			if errMsg != nil {
				h.WriteErrorResponse(c, errMsg)
				flusher.Flush()
			}
			var execErr error
			if errMsg != nil {
				execErr = errMsg.Error
			}
			cancel(execErr)
			return
		case <-time.After(500 * time.Millisecond):
			// Periodic wake-up only; no work is done on timeout.
		}
	}
}
|