mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Add OpenAI compatibility support and improve resource cleanup
- Introduced OpenAI compatibility configurations for external providers, enabling model alias routing via the OpenAI API format. - Enhanced provider logic in `GetProviderName` to handle OpenAI aliases and added new helper functions for compatibility checks. - Updated API handlers and client initialization to support OpenAI compatibility models. - Improved resource cleanup across clients by closing response bodies and streams using deferred functions.
This commit is contained in:
@@ -148,7 +148,7 @@ outLoop:
|
||||
// Detects when the HTTP client has disconnected and cleans up resources
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
|
||||
log.Debugf("claude client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request to prevent resource leaks
|
||||
return
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ outLoop:
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
log.Debugf("gemini cli client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
|
||||
@@ -290,7 +290,7 @@ outLoop:
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
|
||||
log.Debugf("gemini client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
|
||||
@@ -136,6 +136,8 @@ func (h *BaseAPIHandler) GetClient(modelName string, isGenerateContent ...bool)
|
||||
log.Debugf("Claude Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail())
|
||||
} else if cliClient.Provider() == "qwen" {
|
||||
log.Debugf("Qwen Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail())
|
||||
} else if cliClient.Type() == "openai-compatibility" {
|
||||
log.Debugf("OpenAI Compatibility Model %s is quota exceeded for provider %s", modelName, cliClient.Provider())
|
||||
}
|
||||
cliClient = nil
|
||||
continue
|
||||
@@ -145,7 +147,7 @@ func (h *BaseAPIHandler) GetClient(modelName string, isGenerateContent ...bool)
|
||||
}
|
||||
|
||||
if len(reorderedClients) == 0 {
|
||||
if util.GetProviderName(modelName) == "claude" {
|
||||
if util.GetProviderName(modelName, h.Cfg) == "claude" {
|
||||
// log.Debugf("Claude Model %s is quota exceeded for all accounts", modelName)
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed your account's rate limit. Please try again later."}}`)}
|
||||
}
|
||||
|
||||
@@ -278,7 +278,7 @@ outLoop:
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
log.Debugf("qwen client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
|
||||
@@ -183,6 +183,7 @@ func (c *ClaudeClient) SendRawMessage(ctx context.Context, modelName string, raw
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
_ = respBody.Close()
|
||||
c.AddAPIResponseData(ctx, bodyBytes)
|
||||
|
||||
var param any
|
||||
@@ -238,6 +239,9 @@ func (c *ClaudeClient) SendRawMessageStream(ctx context.Context, modelName strin
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
buffer := make([]byte, 10240*1024)
|
||||
|
||||
@@ -132,6 +132,7 @@ func (c *CodexClient) SendRawMessage(ctx context.Context, modelName string, rawJ
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
_ = respBody.Close()
|
||||
c.AddAPIResponseData(ctx, bodyBytes)
|
||||
|
||||
var param any
|
||||
@@ -188,6 +189,9 @@ func (c *CodexClient) SendRawMessageStream(ctx context.Context, modelName string
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
buffer := make([]byte, 10240*1024)
|
||||
|
||||
@@ -497,6 +497,7 @@ func (c *GeminiCLIClient) SendRawMessage(ctx context.Context, modelName string,
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
_ = respBody.Close()
|
||||
c.AddAPIResponseData(ctx, bodyBytes)
|
||||
|
||||
newCtx := context.WithValue(ctx, "alt", alt)
|
||||
@@ -571,6 +572,11 @@ func (c *GeminiCLIClient) SendRawMessageStream(ctx context.Context, modelName st
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
break
|
||||
}
|
||||
defer func() {
|
||||
if stream != nil {
|
||||
_ = stream.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
newCtx := context.WithValue(ctx, "alt", alt)
|
||||
var param any
|
||||
|
||||
@@ -249,6 +249,7 @@ func (c *GeminiClient) SendRawMessage(ctx context.Context, modelName string, raw
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
_ = respBody.Close()
|
||||
c.AddAPIResponseData(ctx, bodyBytes)
|
||||
|
||||
var param any
|
||||
@@ -301,6 +302,9 @@ func (c *GeminiClient) SendRawMessageStream(ctx context.Context, modelName strin
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
newCtx := context.WithValue(ctx, "alt", alt)
|
||||
var param any
|
||||
|
||||
360
internal/client/openai-compatibility_client.go
Normal file
360
internal/client/openai-compatibility_client.go
Normal file
@@ -0,0 +1,360 @@
|
||||
// Package client defines the interface and base structure for AI API clients.
|
||||
// It provides a common interface that all supported AI service clients must implement,
|
||||
// including methods for sending messages, handling streams, and managing authentication.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
. "github.com/luispater/CLIProxyAPI/internal/constant"
|
||||
"github.com/luispater/CLIProxyAPI/internal/interfaces"
|
||||
"github.com/luispater/CLIProxyAPI/internal/translator/translator"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// OpenAICompatibilityClient implements the Client interface for external OpenAI-compatible API providers.
|
||||
// This client handles requests to external services that support OpenAI-compatible APIs,
|
||||
// such as OpenRouter, Together.ai, and other similar services.
|
||||
type OpenAICompatibilityClient struct {
|
||||
ClientBase
|
||||
compatConfig *config.OpenAICompatibility
|
||||
currentAPIKeyIndex int
|
||||
}
|
||||
|
||||
// NewOpenAICompatibilityClient creates a new OpenAI compatibility client instance.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration.
|
||||
// - compatConfig: The OpenAI compatibility configuration for the specific provider.
|
||||
//
|
||||
// Returns:
|
||||
// - *OpenAICompatibilityClient: A new OpenAI compatibility client instance.
|
||||
// - error: An error if the client creation fails.
|
||||
func NewOpenAICompatibilityClient(cfg *config.Config, compatConfig *config.OpenAICompatibility) (*OpenAICompatibilityClient, error) {
|
||||
if compatConfig == nil {
|
||||
return nil, fmt.Errorf("compatibility configuration is required")
|
||||
}
|
||||
|
||||
if len(compatConfig.APIKeys) == 0 {
|
||||
return nil, fmt.Errorf("at least one API key is required for OpenAI compatibility provider: %s", compatConfig.Name)
|
||||
}
|
||||
|
||||
httpClient := util.SetProxy(cfg, &http.Client{})
|
||||
client := &OpenAICompatibilityClient{
|
||||
ClientBase: ClientBase{
|
||||
RequestMutex: &sync.Mutex{},
|
||||
httpClient: httpClient,
|
||||
cfg: cfg,
|
||||
modelQuotaExceeded: make(map[string]*time.Time),
|
||||
},
|
||||
compatConfig: compatConfig,
|
||||
currentAPIKeyIndex: 0,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Type returns the client type.
|
||||
func (c *OpenAICompatibilityClient) Type() string {
|
||||
return OPENAI
|
||||
}
|
||||
|
||||
// Provider returns the provider name for this client.
|
||||
func (c *OpenAICompatibilityClient) Provider() string {
|
||||
return c.compatConfig.Name
|
||||
}
|
||||
|
||||
// CanProvideModel checks if this client can provide the specified model alias.
|
||||
//
|
||||
// Parameters:
|
||||
// - modelName: The name/alias of the model to check.
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the model alias is supported, false otherwise.
|
||||
func (c *OpenAICompatibilityClient) CanProvideModel(modelName string) bool {
|
||||
for _, model := range c.compatConfig.Models {
|
||||
if model.Alias == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetUserAgent returns the user agent string for OpenAI compatibility API requests.
|
||||
func (c *OpenAICompatibilityClient) GetUserAgent() string {
|
||||
return fmt.Sprintf("cli-proxy-api-%s", c.compatConfig.Name)
|
||||
}
|
||||
|
||||
// TokenStorage returns nil as this client doesn't use traditional token storage.
|
||||
func (c *OpenAICompatibilityClient) TokenStorage() auth.TokenStorage {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentAPIKey returns the current API key to use, with rotation support.
|
||||
func (c *OpenAICompatibilityClient) GetCurrentAPIKey() string {
|
||||
if len(c.compatConfig.APIKeys) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
key := c.compatConfig.APIKeys[c.currentAPIKeyIndex]
|
||||
// Rotate to next key for load balancing
|
||||
c.currentAPIKeyIndex = (c.currentAPIKeyIndex + 1) % len(c.compatConfig.APIKeys)
|
||||
return key
|
||||
}
|
||||
|
||||
// GetActualModelName returns the actual model name to use with the external API
|
||||
// based on the provided alias.
|
||||
func (c *OpenAICompatibilityClient) GetActualModelName(alias string) string {
|
||||
for _, model := range c.compatConfig.Models {
|
||||
if model.Alias == alias {
|
||||
return model.Name
|
||||
}
|
||||
}
|
||||
return alias // fallback to alias if not found
|
||||
}
|
||||
|
||||
// APIRequest makes an HTTP request to the OpenAI-compatible API.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request.
|
||||
// - modelName: The model name to use.
|
||||
// - endpoint: The API endpoint path.
|
||||
// - rawJSON: The raw JSON request data.
|
||||
// - alt: Alternative response format (not used for OpenAI compatibility).
|
||||
// - stream: Whether this is a streaming request.
|
||||
//
|
||||
// Returns:
|
||||
// - io.ReadCloser: The response body reader.
|
||||
// - *interfaces.ErrorMessage: An error message if the request fails.
|
||||
func (c *OpenAICompatibilityClient) APIRequest(ctx context.Context, modelName string, endpoint string, rawJSON []byte, alt string, stream bool) (io.ReadCloser, *interfaces.ErrorMessage) {
|
||||
// Replace the model alias with the actual model name in the request
|
||||
actualModelName := c.GetActualModelName(modelName)
|
||||
modifiedJSON, errReplace := sjson.SetBytes(rawJSON, "model", actualModelName)
|
||||
if errReplace != nil {
|
||||
return nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to replace model name: %w", errReplace),
|
||||
}
|
||||
}
|
||||
|
||||
// Create the HTTP request
|
||||
url := strings.TrimSuffix(c.compatConfig.BaseURL, "/") + endpoint
|
||||
req, errReq := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(modifiedJSON))
|
||||
if errReq != nil {
|
||||
return nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: fmt.Errorf("failed to create request: %w", errReq),
|
||||
}
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
apiKey := c.GetCurrentAPIKey()
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||||
}
|
||||
req.Header.Set("User-Agent", c.GetUserAgent())
|
||||
|
||||
if stream {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
}
|
||||
|
||||
log.Debugf("OpenAI Compatibility [%s] API request: %s", c.compatConfig.Name, util.HideAPIKey(apiKey))
|
||||
|
||||
// Send the request
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("failed to execute request: %v", err)}
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
defer func() {
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
// log.Debug(string(jsonBody))
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: resp.StatusCode, Error: fmt.Errorf("%s", string(bodyBytes))}
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// SendRawMessage sends a raw message to the OpenAI-compatible API.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request.
|
||||
// - modelName: The model alias name to use.
|
||||
// - rawJSON: The raw JSON request data.
|
||||
// - alt: Alternative response format parameter.
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The response data from the API.
|
||||
// - *interfaces.ErrorMessage: An error message if the request fails.
|
||||
func (c *OpenAICompatibilityClient) SendRawMessage(ctx context.Context, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
handler := ctx.Value("handler").(interfaces.APIHandler)
|
||||
handlerType := handler.HandlerType()
|
||||
rawJSON = translator.Request(handlerType, c.Type(), modelName, rawJSON, false)
|
||||
|
||||
respBody, err := c.APIRequest(ctx, modelName, "/chat/completions", rawJSON, alt, false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
_ = respBody.Close()
|
||||
c.AddAPIResponseData(ctx, bodyBytes)
|
||||
|
||||
var param any
|
||||
bodyBytes = []byte(translator.ResponseNonStream(handlerType, c.Type(), ctx, modelName, bodyBytes, ¶m))
|
||||
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
// SendRawMessageStream sends a raw streaming message to the OpenAI-compatible API.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request.
|
||||
// - modelName: The model alias name to use.
|
||||
// - rawJSON: The raw JSON request data.
|
||||
// - alt: Alternative response format parameter.
|
||||
//
|
||||
// Returns:
|
||||
// - <-chan []byte: A channel that will receive response chunks.
|
||||
// - <-chan *interfaces.ErrorMessage: A channel that will receive error messages.
|
||||
func (c *OpenAICompatibilityClient) SendRawMessageStream(ctx context.Context, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||
handler := ctx.Value("handler").(interfaces.APIHandler)
|
||||
handlerType := handler.HandlerType()
|
||||
rawJSON = translator.Request(handlerType, c.Type(), modelName, rawJSON, true)
|
||||
|
||||
dataTag := []byte("data: ")
|
||||
doneTag := []byte("data: [DONE]")
|
||||
errChan := make(chan *interfaces.ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
// log.Debugf(string(rawJSON))
|
||||
// return dataChan, errChan
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
// Set streaming flag in the request
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
|
||||
|
||||
newCtx := context.WithValue(ctx, "gin", ctx.Value("gin").(*gin.Context))
|
||||
|
||||
stream, err := c.APIRequest(newCtx, modelName, "/chat/completions", rawJSON, alt, true)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
}
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
|
||||
if translator.NeedConvert(handlerType, c.Type()) {
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
if bytes.Equal(line, doneTag) {
|
||||
break
|
||||
}
|
||||
lines := translator.Response(handlerType, c.Type(), newCtx, modelName, line[6:], ¶m)
|
||||
for i := 0; i < len(lines); i++ {
|
||||
dataChan <- []byte(lines[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No translation needed, stream data directly
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
if bytes.Equal(line, doneTag) {
|
||||
break
|
||||
}
|
||||
c.AddAPIResponseData(newCtx, line[6:])
|
||||
dataChan <- line[6:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: 500, Error: scanner.Err()}
|
||||
}
|
||||
}()
|
||||
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// SendRawTokenCount sends a token count request (not implemented for OpenAI compatibility).
|
||||
// This method is required by the Client interface but not supported by OpenAI compatibility clients.
|
||||
func (c *OpenAICompatibilityClient) SendRawTokenCount(ctx context.Context, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
return nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusNotImplemented,
|
||||
Error: fmt.Errorf("token counting not supported for OpenAI compatibility clients"),
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmail returns a placeholder email for this OpenAI compatibility client.
|
||||
// Since these clients don't use traditional email-based authentication,
|
||||
// we return the provider name as an identifier.
|
||||
func (c *OpenAICompatibilityClient) GetEmail() string {
|
||||
return fmt.Sprintf("openai-compatibility-%s", c.compatConfig.Name)
|
||||
}
|
||||
|
||||
// IsModelQuotaExceeded checks if the specified model has exceeded its quota.
|
||||
// For OpenAI compatibility clients, this is based on tracked quota exceeded times.
|
||||
func (c *OpenAICompatibilityClient) IsModelQuotaExceeded(model string) bool {
|
||||
if quota, exists := c.modelQuotaExceeded[model]; exists && quota != nil {
|
||||
// Check if quota exceeded time is less than 5 minutes ago
|
||||
if time.Since(*quota) < 5*time.Minute {
|
||||
return true
|
||||
}
|
||||
// Clear expired quota tracking
|
||||
delete(c.modelQuotaExceeded, model)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SaveTokenToFile returns nil as this client type doesn't use traditional token storage.
|
||||
func (c *OpenAICompatibilityClient) SaveTokenToFile() error {
|
||||
// No token file to save for OpenAI compatibility clients
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshTokens is not applicable for OpenAI compatibility clients as they use API keys.
|
||||
func (c *OpenAICompatibilityClient) RefreshTokens(ctx context.Context) error {
|
||||
// API keys don't need refreshing
|
||||
return nil
|
||||
}
|
||||
@@ -128,6 +128,7 @@ func (c *QwenClient) SendRawMessage(ctx context.Context, modelName string, rawJS
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
_ = respBody.Close()
|
||||
c.AddAPIResponseData(ctx, bodyBytes)
|
||||
|
||||
var param any
|
||||
@@ -186,6 +187,9 @@ func (c *QwenClient) SendRawMessageStream(ctx context.Context, modelName string,
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
buffer := make([]byte, 10240*1024)
|
||||
|
||||
@@ -150,6 +150,18 @@ func StartService(cfg *config.Config, configPath string) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.OpenAICompatibility) > 0 {
|
||||
// Initialize clients for OpenAI compatibility configurations
|
||||
for _, compatConfig := range cfg.OpenAICompatibility {
|
||||
log.Debugf("Initializing OpenAI compatibility client for provider: %s", compatConfig.Name)
|
||||
compatClient, errClient := client.NewOpenAICompatibilityClient(cfg, &compatConfig)
|
||||
if errClient != nil {
|
||||
log.Fatalf("failed to create OpenAI compatibility client for %s: %v", compatConfig.Name, errClient)
|
||||
}
|
||||
cliClients = append(cliClients, compatClient)
|
||||
}
|
||||
}
|
||||
|
||||
// Create and start the API server with the pool of clients in a separate goroutine.
|
||||
apiServer := api.NewServer(cfg, cliClients)
|
||||
log.Infof("Starting API server on port %d", cfg.Port)
|
||||
|
||||
@@ -41,6 +41,9 @@ type Config struct {
|
||||
RequestRetry int `yaml:"request-retry"`
|
||||
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key"`
|
||||
|
||||
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
|
||||
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility"`
|
||||
}
|
||||
|
||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||
@@ -64,6 +67,32 @@ type ClaudeKey struct {
|
||||
BaseURL string `yaml:"base-url"`
|
||||
}
|
||||
|
||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||
type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name"`
|
||||
|
||||
// BaseURL is the base URL for the external OpenAI-compatible API endpoint.
|
||||
BaseURL string `yaml:"base-url"`
|
||||
|
||||
// APIKeys are the authentication keys for accessing the external API services.
|
||||
APIKeys []string `yaml:"api-keys"`
|
||||
|
||||
// Models defines the model configurations including aliases for routing.
|
||||
Models []OpenAICompatibilityModel `yaml:"models"`
|
||||
}
|
||||
|
||||
// OpenAICompatibilityModel represents a model configuration for OpenAI compatibility,
|
||||
// including the actual model name and its alias for API routing.
|
||||
type OpenAICompatibilityModel struct {
|
||||
// Name is the actual model name used by the external provider.
|
||||
Name string `yaml:"name"`
|
||||
|
||||
// Alias is the model name alias that clients will use to reference this model.
|
||||
Alias string `yaml:"alias"`
|
||||
}
|
||||
|
||||
// LoadConfig reads a YAML configuration file from the given path,
|
||||
// unmarshals it into a Config struct, applies environment variable overrides,
|
||||
// and returns it.
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
GEMINI = "gemini"
|
||||
GEMINICLI = "gemini-cli"
|
||||
CODEX = "codex"
|
||||
CLAUDE = "claude"
|
||||
OPENAI = "openai"
|
||||
GEMINI = "gemini"
|
||||
GEMINICLI = "gemini-cli"
|
||||
CODEX = "codex"
|
||||
CLAUDE = "claude"
|
||||
OPENAI = "openai"
|
||||
OPENAI_COMPATIBILITY = "openai-compatibility"
|
||||
)
|
||||
|
||||
@@ -5,25 +5,33 @@ package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
)
|
||||
|
||||
// GetProviderName determines the AI service provider based on the model name.
|
||||
// It analyzes the model name string to identify which service provider it belongs to.
|
||||
// First checks for OpenAI compatibility aliases, then falls back to standard provider detection.
|
||||
//
|
||||
// Supported providers:
|
||||
// - "gemini" for Google's Gemini models
|
||||
// - "gpt" for OpenAI's GPT models
|
||||
// - "claude" for Anthropic's Claude models
|
||||
// - "qwen" for Alibaba's Qwen models
|
||||
// - "openai-compatibility" for external OpenAI-compatible providers
|
||||
// - "unknow" for unrecognized model names
|
||||
//
|
||||
// Parameters:
|
||||
// - modelName: The name of the model to identify the provider for.
|
||||
// - cfg: The application configuration containing OpenAI compatibility settings.
|
||||
//
|
||||
// Returns:
|
||||
// - string: The name of the provider.
|
||||
func GetProviderName(modelName string) string {
|
||||
if strings.Contains(modelName, "gemini") {
|
||||
func GetProviderName(modelName string, cfg *config.Config) string {
|
||||
// First check if this model name is an OpenAI compatibility alias
|
||||
if IsOpenAICompatibilityAlias(modelName, cfg) {
|
||||
return "openai-compatibility"
|
||||
} else if strings.Contains(modelName, "gemini") { // Fall back to standard provider detection
|
||||
return "gemini"
|
||||
} else if strings.Contains(modelName, "gpt") {
|
||||
return "gpt"
|
||||
@@ -37,6 +45,55 @@ func GetProviderName(modelName string) string {
|
||||
return "unknow"
|
||||
}
|
||||
|
||||
// IsOpenAICompatibilityAlias checks if the given model name is an alias
|
||||
// configured for OpenAI compatibility routing.
|
||||
//
|
||||
// Parameters:
|
||||
// - modelName: The model name to check
|
||||
// - cfg: The application configuration containing OpenAI compatibility settings
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the model name is an OpenAI compatibility alias, false otherwise
|
||||
func IsOpenAICompatibilityAlias(modelName string, cfg *config.Config) bool {
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, compat := range cfg.OpenAICompatibility {
|
||||
for _, model := range compat.Models {
|
||||
if model.Alias == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetOpenAICompatibilityConfig returns the OpenAI compatibility configuration
|
||||
// and model details for the given alias.
|
||||
//
|
||||
// Parameters:
|
||||
// - alias: The model alias to find configuration for
|
||||
// - cfg: The application configuration containing OpenAI compatibility settings
|
||||
//
|
||||
// Returns:
|
||||
// - *config.OpenAICompatibility: The matching compatibility configuration, or nil if not found
|
||||
// - *config.OpenAICompatibilityModel: The matching model configuration, or nil if not found
|
||||
func GetOpenAICompatibilityConfig(alias string, cfg *config.Config) (*config.OpenAICompatibility, *config.OpenAICompatibilityModel) {
|
||||
if cfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, compat := range cfg.OpenAICompatibility {
|
||||
for _, model := range compat.Models {
|
||||
if model.Alias == alias {
|
||||
return &compat, &model
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// InArray checks if a string exists in a slice of strings.
|
||||
// It iterates through the slice and returns true if the target string is found,
|
||||
// otherwise it returns false.
|
||||
|
||||
Reference in New Issue
Block a user