mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
Add /v1/completions endpoint with OpenAI compatibility
- Implemented `/v1/completions` endpoint mirroring OpenAI's completions API specification. - Added conversion functions to translate between completions and chat completions formats. - Introduced streaming and non-streaming response handling for completions requests. - Updated `server.go` to register the new endpoint and include it in the API's metadata.
This commit is contained in:
@@ -8,6 +8,7 @@ package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/luispater/CLIProxyAPI/internal/registry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// OpenAIAPIHandler contains the handlers for OpenAI API endpoints.
|
||||
@@ -92,6 +94,276 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
// Completions handles the /v1/completions endpoint.
|
||||
// It determines whether the request is for a streaming or non-streaming response
|
||||
// and calls the appropriate handler based on the model provider.
|
||||
// This endpoint follows the OpenAI completions API specification.
|
||||
//
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
|
||||
rawJSON, err := c.GetRawData()
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
h.handleCompletionsStreamingResponse(c, rawJSON)
|
||||
} else {
|
||||
h.handleCompletionsNonStreamingResponse(c, rawJSON)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// convertCompletionsRequestToChatCompletions converts OpenAI completions API request to chat completions format.
|
||||
// This allows the completions endpoint to use the existing chat completions infrastructure.
|
||||
//
|
||||
// Parameters:
|
||||
// - rawJSON: The raw JSON bytes of the completions request
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The converted chat completions request
|
||||
func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Extract prompt from completions request
|
||||
prompt := root.Get("prompt").String()
|
||||
if prompt == "" {
|
||||
prompt = "Complete this:"
|
||||
}
|
||||
|
||||
// Create chat completions structure
|
||||
out := `{"model":"","messages":[{"role":"user","content":""}]}`
|
||||
|
||||
// Set model
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
out, _ = sjson.Set(out, "model", model.String())
|
||||
}
|
||||
|
||||
// Set the prompt as user message content
|
||||
out, _ = sjson.Set(out, "messages.0.content", prompt)
|
||||
|
||||
// Copy other parameters from completions to chat completions
|
||||
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
|
||||
if temperature := root.Get("temperature"); temperature.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temperature.Float())
|
||||
}
|
||||
|
||||
if topP := root.Get("top_p"); topP.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
|
||||
out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
|
||||
}
|
||||
|
||||
if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
|
||||
out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
|
||||
}
|
||||
|
||||
if stop := root.Get("stop"); stop.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "stop", stop.Raw)
|
||||
}
|
||||
|
||||
if stream := root.Get("stream"); stream.Exists() {
|
||||
out, _ = sjson.Set(out, "stream", stream.Bool())
|
||||
}
|
||||
|
||||
if logprobs := root.Get("logprobs"); logprobs.Exists() {
|
||||
out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
|
||||
}
|
||||
|
||||
if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
|
||||
out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
|
||||
}
|
||||
|
||||
if echo := root.Get("echo"); echo.Exists() {
|
||||
out, _ = sjson.Set(out, "echo", echo.Bool())
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
|
||||
// This ensures the completions endpoint returns data in the expected format.
|
||||
//
|
||||
// Parameters:
|
||||
// - rawJSON: The raw JSON bytes of the chat completions response
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The converted completions response
|
||||
func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Base completions response structure
|
||||
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
|
||||
|
||||
// Copy basic fields
|
||||
if id := root.Get("id"); id.Exists() {
|
||||
out, _ = sjson.Set(out, "id", id.String())
|
||||
}
|
||||
|
||||
if created := root.Get("created"); created.Exists() {
|
||||
out, _ = sjson.Set(out, "created", created.Int())
|
||||
}
|
||||
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
out, _ = sjson.Set(out, "model", model.String())
|
||||
}
|
||||
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
|
||||
}
|
||||
|
||||
// Convert choices from chat completions to completions format
|
||||
var choices []interface{}
|
||||
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
|
||||
chatChoices.ForEach(func(_, choice gjson.Result) bool {
|
||||
completionsChoice := map[string]interface{}{
|
||||
"index": choice.Get("index").Int(),
|
||||
}
|
||||
|
||||
// Extract text content from message.content
|
||||
if message := choice.Get("message"); message.Exists() {
|
||||
if content := message.Get("content"); content.Exists() {
|
||||
completionsChoice["text"] = content.String()
|
||||
}
|
||||
} else if delta := choice.Get("delta"); delta.Exists() {
|
||||
// For streaming responses, use delta.content
|
||||
if content := delta.Get("content"); content.Exists() {
|
||||
completionsChoice["text"] = content.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Copy finish_reason
|
||||
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
|
||||
completionsChoice["finish_reason"] = finishReason.String()
|
||||
}
|
||||
|
||||
// Copy logprobs if present
|
||||
if logprobs := choice.Get("logprobs"); logprobs.Exists() {
|
||||
completionsChoice["logprobs"] = logprobs.Value()
|
||||
}
|
||||
|
||||
choices = append(choices, completionsChoice)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if len(choices) > 0 {
|
||||
choicesJSON, _ := json.Marshal(choices)
|
||||
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
// convertChatCompletionsStreamChunkToCompletions converts a streaming chat completions chunk to completions format.
|
||||
// This handles the real-time conversion of streaming response chunks and filters out empty text responses.
|
||||
//
|
||||
// Parameters:
|
||||
// - chunkData: The raw JSON bytes of a single chat completions stream chunk
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The converted completions stream chunk, or nil if should be filtered out
|
||||
func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
root := gjson.ParseBytes(chunkData)
|
||||
|
||||
// Check if this chunk has any meaningful content
|
||||
hasContent := false
|
||||
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
|
||||
chatChoices.ForEach(func(_, choice gjson.Result) bool {
|
||||
// Check if delta has content or finish_reason
|
||||
if delta := choice.Get("delta"); delta.Exists() {
|
||||
if content := delta.Get("content"); content.Exists() && content.String() != "" {
|
||||
hasContent = true
|
||||
return false // Break out of forEach
|
||||
}
|
||||
}
|
||||
// Also check for finish_reason to ensure we don't skip final chunks
|
||||
if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "" && finishReason.String() != "null" {
|
||||
hasContent = true
|
||||
return false // Break out of forEach
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// If no meaningful content, return nil to indicate this chunk should be skipped
|
||||
if !hasContent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Base completions stream response structure
|
||||
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
|
||||
|
||||
// Copy basic fields
|
||||
if id := root.Get("id"); id.Exists() {
|
||||
out, _ = sjson.Set(out, "id", id.String())
|
||||
}
|
||||
|
||||
if created := root.Get("created"); created.Exists() {
|
||||
out, _ = sjson.Set(out, "created", created.Int())
|
||||
}
|
||||
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
out, _ = sjson.Set(out, "model", model.String())
|
||||
}
|
||||
|
||||
// Convert choices from chat completions delta to completions format
|
||||
var choices []interface{}
|
||||
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
|
||||
chatChoices.ForEach(func(_, choice gjson.Result) bool {
|
||||
completionsChoice := map[string]interface{}{
|
||||
"index": choice.Get("index").Int(),
|
||||
}
|
||||
|
||||
// Extract text content from delta.content
|
||||
if delta := choice.Get("delta"); delta.Exists() {
|
||||
if content := delta.Get("content"); content.Exists() && content.String() != "" {
|
||||
completionsChoice["text"] = content.String()
|
||||
} else {
|
||||
completionsChoice["text"] = ""
|
||||
}
|
||||
} else {
|
||||
completionsChoice["text"] = ""
|
||||
}
|
||||
|
||||
// Copy finish_reason
|
||||
if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "null" {
|
||||
completionsChoice["finish_reason"] = finishReason.String()
|
||||
}
|
||||
|
||||
// Copy logprobs if present
|
||||
if logprobs := choice.Get("logprobs"); logprobs.Exists() {
|
||||
completionsChoice["logprobs"] = logprobs.Value()
|
||||
}
|
||||
|
||||
choices = append(choices, completionsChoice)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if len(choices) > 0 {
|
||||
choicesJSON, _ := json.Marshal(choices)
|
||||
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
}
|
||||
|
||||
// handleNonStreamingResponse handles non-streaming chat completion responses
|
||||
// for Gemini models. It selects a client from the pool, sends the request, and
|
||||
// aggregates the response before sending it back to the client in OpenAI format.
|
||||
@@ -251,3 +523,177 @@ outLoop:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
||||
// It converts completions request to chat completions format, sends to backend,
|
||||
// then converts the response back to completions format before sending to client.
|
||||
//
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
|
||||
func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// Convert completions request to chat completions format
|
||||
chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
|
||||
|
||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
|
||||
var cliClient interfaces.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
retryCount := 0
|
||||
for retryCount <= h.Cfg.RequestRetry {
|
||||
var errorResponse *interfaces.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Send the converted chat completions request
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, modelName, chatCompletionsJSON, "")
|
||||
if err != nil {
|
||||
switch err.StatusCode {
|
||||
case 429:
|
||||
if h.Cfg.QuotaExceeded.SwitchProject {
|
||||
log.Debugf("quota exceeded, switch client")
|
||||
continue // Restart the client selection process
|
||||
}
|
||||
case 403, 408, 500, 502, 503, 504:
|
||||
log.Debugf("http status code %d, switch client", err.StatusCode)
|
||||
retryCount++
|
||||
continue
|
||||
default:
|
||||
// Forward other errors directly to the client
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
// Convert chat completions response back to completions format
|
||||
completionsResp := convertChatCompletionsResponseToCompletions(resp)
|
||||
_, _ = c.Writer.Write(completionsResp)
|
||||
cliCancel(completionsResp)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCompletionsStreamingResponse handles streaming completions responses.
|
||||
// It converts completions request to chat completions format, streams from backend,
|
||||
// then converts each response chunk back to completions format before sending to client.
|
||||
//
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
|
||||
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert completions request to chat completions format
|
||||
chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
|
||||
|
||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
|
||||
var cliClient interfaces.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
retryCount := 0
|
||||
outLoop:
|
||||
for retryCount <= h.Cfg.RequestRetry {
|
||||
var errorResponse *interfaces.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Send the converted chat completions request and receive response chunks
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, modelName, chatCompletionsJSON, "")
|
||||
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
// Stream is closed, send the final [DONE] message.
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Convert chat completions chunk to completions chunk format
|
||||
completionsChunk := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||
// Skip this chunk if it has no meaningful content (empty text)
|
||||
if completionsChunk != nil {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(completionsChunk))
|
||||
flusher.Flush()
|
||||
}
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
switch err.StatusCode {
|
||||
case 429:
|
||||
if h.Cfg.QuotaExceeded.SwitchProject {
|
||||
log.Debugf("quota exceeded, switch client")
|
||||
continue outLoop // Restart the client selection process
|
||||
}
|
||||
case 403, 408, 500, 502, 503, 504:
|
||||
log.Debugf("http status code %d, switch client", err.StatusCode)
|
||||
retryCount++
|
||||
continue outLoop
|
||||
default:
|
||||
// Forward other errors directly to the client
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,7 @@ func (s *Server) setupRoutes() {
|
||||
{
|
||||
v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
|
||||
v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
|
||||
v1.POST("/completions", openaiHandlers.Completions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
}
|
||||
|
||||
@@ -123,6 +124,7 @@ func (s *Server) setupRoutes() {
|
||||
"version": "1.0.0",
|
||||
"endpoints": []string{
|
||||
"POST /v1/chat/completions",
|
||||
"POST /v1/completions",
|
||||
"GET /v1/models",
|
||||
},
|
||||
})
|
||||
|
||||
@@ -185,7 +185,7 @@ func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
delete(registration.QuotaExceededClients, clientID)
|
||||
log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user