Enhance quota management and refactor configuration handling

- Introduced `QuotaExceeded` settings in configuration to handle quota limits more effectively.
- Added preview model switching logic to `Client` to automatically use fallback models on quota exhaustion.
- Refactored `APIHandlers` to leverage new configuration structure.
- Simplified server initialization and removed redundant `ServerConfig` structure.
- Streamlined client initialization by unifying configuration handling throughout the project.
- Improved error handling and response mechanisms in both streaming and non-streaming flows.
This commit is contained in:
Luis Pater
2025-07-05 07:53:46 +08:00
parent e73f165070
commit 7cb76ae1a5
6 changed files with 374 additions and 244 deletions

View File

@@ -1,7 +1,10 @@
port: 8317 port: 8317
auth_dir: "~/.cli-proxy-api" auth-dir: "~/.cli-proxy-api"
debug: true debug: true
proxy-url: "" proxy-url: ""
api_keys: quota-exceeded:
switch-project: true
switch-preview-model: true
api-keys:
- "12345" - "12345"
- "23456" - "23456"

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator" "github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"net/http" "net/http"
@@ -23,15 +24,15 @@ var (
// It holds a pool of clients to interact with the backend service. // It holds a pool of clients to interact with the backend service.
type APIHandlers struct { type APIHandlers struct {
cliClients []*client.Client cliClients []*client.Client
debug bool cfg *config.Config
} }
// NewAPIHandlers creates a new API handlers instance. // NewAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and a debug flag as input. // It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers { func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
return &APIHandlers{ return &APIHandlers{
cliClients: cliClients, cliClients: cliClients,
debug: debug, cfg: cfg,
} }
} }
@@ -216,22 +217,37 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
} }
}() }()
// Lock the mutex to update the last used page index for {
// Lock the mutex to update the last used client index
mutex.Lock() mutex.Lock()
startIndex := lastUsedClientIndex startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients) currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex lastUsedClientIndex = currentIndex
mutex.Unlock() mutex.Unlock()
// Reorder the pages to start from the last used index // Reorder the client to start from the last used index
reorderedPages := make([]*client.Client, len(h.cliClients)) reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ { for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)] cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
} }
locked := false locked := false
for i := 0; i < len(reorderedPages); i++ { for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedPages[i] cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() { if cliClient.RequestMutex.TryLock() {
locked = true locked = true
break break
@@ -246,10 +262,15 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools) resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
if err != nil { if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode) c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error()) _, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush() flusher.Flush()
cliCancel() cliCancel()
}
break
} else { } else {
openAIFormat := translator.ConvertCliToOpenAINonStream(resp) openAIFormat := translator.ConvertCliToOpenAINonStream(resp)
if openAIFormat != "" { if openAIFormat != "" {
@@ -257,6 +278,8 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
flusher.Flush() flusher.Flush()
} }
cliCancel() cliCancel()
break
}
} }
} }
@@ -290,34 +313,48 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
} }
}() }()
// Use a round-robin approach to select the next available client. outLoop:
// This distributes the load among the available clients. for {
// Lock the mutex to update the last used client index
mutex.Lock() mutex.Lock()
startIndex := lastUsedClientIndex startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients) currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex lastUsedClientIndex = currentIndex
mutex.Unlock() mutex.Unlock()
// Reorder the clients to start from the next client in the rotation. // Reorder the client to start from the last used index
reorderedPages := make([]*client.Client, len(h.cliClients)) reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ { for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)] cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
} }
// Attempt to lock a client for the request.
locked := false locked := false
for i := 0; i < len(reorderedPages); i++ { for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedPages[i] cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() { if cliClient.RequestMutex.TryLock() {
locked = true locked = true
break break
} }
} }
// If no client is available, block and wait for the first client.
if !locked { if !locked {
cliClient = h.cliClients[0] cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock() cliClient.RequestMutex.Lock()
} }
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
// Send the message and receive response chunks and errors via channels. // Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
@@ -351,10 +388,14 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
// Handle errors from the backend. // Handle errors from the backend.
case err, okError := <-errChan: case err, okError := <-errChan:
if okError { if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode) c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error()) _, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush() flusher.Flush()
cliCancel() cliCancel()
}
return return
} }
// Send a keep-alive signal to the client. // Send a keep-alive signal to the client.
@@ -365,4 +406,5 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
} }
} }
} }
}
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http" "net/http"
"strings" "strings"
@@ -17,29 +18,19 @@ type Server struct {
engine *gin.Engine engine *gin.Engine
server *http.Server server *http.Server
handlers *APIHandlers handlers *APIHandlers
cfg *ServerConfig cfg *config.Config
}
// ServerConfig contains the configuration for the API server.
type ServerConfig struct {
// Port is the port number the server will listen on.
Port string
// Debug enables or disables debug mode for the server and Gin.
Debug bool
// ApiKeys is a list of valid API keys for authentication.
ApiKeys []string
} }
// NewServer creates and initializes a new API server instance. // NewServer creates and initializes a new API server instance.
// It sets up the Gin engine, middleware, routes, and handlers. // It sets up the Gin engine, middleware, routes, and handlers.
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server { func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
// Set gin mode // Set gin mode
if !config.Debug { if !cfg.Debug {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
// Create handlers // Create handlers
handlers := NewAPIHandlers(cliClients, config.Debug) handlers := NewAPIHandlers(cliClients, cfg)
// Create gin engine // Create gin engine
engine := gin.New() engine := gin.New()
@@ -53,7 +44,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
s := &Server{ s := &Server{
engine: engine, engine: engine,
handlers: handlers, handlers: handlers,
cfg: config, cfg: cfg,
} }
// Setup routes // Setup routes
@@ -61,7 +52,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
// Create HTTP server // Create HTTP server
s.server = &http.Server{ s.server = &http.Server{
Addr: ":" + config.Port, Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: engine, Handler: engine,
} }
@@ -138,7 +129,7 @@ func corsMiddleware() gin.HandlerFunc {
// AuthMiddleware returns a Gin middleware handler that authenticates requests // AuthMiddleware returns a Gin middleware handler that authenticates requests
// using API keys. If no API keys are configured, it allows all requests. // using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc { func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if len(cfg.ApiKeys) == 0 { if len(cfg.ApiKeys) == 0 {
c.Next() c.Next()

View File

@@ -29,12 +29,20 @@ const (
pluginVersion = "0.1.9" pluginVersion = "0.1.9"
) )
var (
previewModels = map[string][]string{
"gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
}
)
// Client is the main client for interacting with the CLI API. // Client is the main client for interacting with the CLI API.
type Client struct { type Client struct {
httpClient *http.Client httpClient *http.Client
RequestMutex sync.Mutex RequestMutex sync.Mutex
tokenStorage *auth.TokenStorage tokenStorage *auth.TokenStorage
cfg *config.Config cfg *config.Config
modelQuotaExceeded map[string]*time.Time
} }
// NewClient creates a new CLI API client. // NewClient creates a new CLI API client.
@@ -43,6 +51,7 @@ func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Confi
httpClient: httpClient, httpClient: httpClient,
tokenStorage: ts, tokenStorage: ts,
cfg: cfg, cfg: cfg,
modelQuotaExceeded: make(map[string]*time.Time),
} }
} }
@@ -214,97 +223,6 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
return nil return nil
} }
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
go func() {
defer close(errChan)
defer close(dataChan)
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: true,
},
},
}
request.Tools = tools
requestBody := map[string]interface{}{
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
"request": request,
"model": model,
}
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
// log.Debug(string(byteRequestBody))
stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
if err != nil {
// log.Println(err)
errChan <- err
return
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
// log.Printf("Received stream chunk: %s", line)
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
if errScanner := scanner.Err(); errScanner != nil {
// log.Println(err)
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
_ = stream.Close()
}()
return dataChan, errChan
}
// APIRequest handles making requests to the CLI API endpoints. // APIRequest handles making requests to the CLI API endpoints.
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) { func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte var jsonBody []byte
@@ -415,17 +333,192 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
} }
modelName := model
// log.Debug(string(byteRequestBody)) // log.Debug(string(byteRequestBody))
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false) respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
if err != nil { if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel {
continue
}
}
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
} }
return bodyBytes, nil return bodyBytes, nil
}
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
go func() {
defer close(errChan)
defer close(dataChan)
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: true,
},
},
}
request.Tools = tools
requestBody := map[string]interface{}{
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
"request": request,
"model": model,
}
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
// log.Debug(string(byteRequestBody))
modelName := model
var stream io.ReadCloser
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
continue
}
}
errChan <- &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel {
continue
}
}
errChan <- err
return
}
delete(c.modelQuotaExceeded, modelName)
break
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
// log.Printf("Received stream chunk: %s", line)
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
if errScanner := scanner.Err(); errScanner != nil {
// log.Println(err)
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
_ = stream.Close()
}()
return dataChan, errChan
}
func (c *Client) isModelQuotaExceeded(model string) bool {
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
duration := time.Now().Sub(*lastExceededTime)
if duration > 30*time.Minute {
return false
}
return true
}
return false
}
func (c *Client) getPreviewModel(model string) string {
if models, hasKey := previewModels[model]; hasKey {
for i := 0; i < len(models); i++ {
if !c.isModelQuotaExceeded(models[i]) {
return models[i]
}
}
}
return ""
}
func (c *Client) IsModelQuotaExceeded(model string) bool {
if c.isModelQuotaExceeded(model) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
return c.getPreviewModel(model) == ""
}
return true
}
return false
} }
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify // CheckCloudAPIIsEnabled sends a simple test request to the API to verify

View File

@@ -3,7 +3,6 @@ package cmd
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api" "github.com/luispater/CLIProxyAPI/internal/api"
"github.com/luispater/CLIProxyAPI/internal/auth" "github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
@@ -22,13 +21,6 @@ import (
// It loads all available authentication tokens, creates a pool of clients, // It loads all available authentication tokens, creates a pool of clients,
// starts the API server, and handles graceful shutdown signals. // starts the API server, and handles graceful shutdown signals.
func StartService(cfg *config.Config) { func StartService(cfg *config.Config) {
// Configure the API server based on the main application config.
apiConfig := &api.ServerConfig{
Port: fmt.Sprintf("%d", cfg.Port),
Debug: cfg.Debug,
ApiKeys: cfg.ApiKeys,
}
// Create a pool of API clients, one for each token file found. // Create a pool of API clients, one for each token file found.
cliClients := make([]*client.Client, 0) cliClients := make([]*client.Client, 0)
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error { err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
@@ -73,8 +65,8 @@ func StartService(cfg *config.Config) {
} }
// Create and start the API server with the pool of clients. // Create and start the API server with the pool of clients.
apiServer := api.NewServer(apiConfig, cliClients) apiServer := api.NewServer(cfg, cliClients)
log.Infof("Starting API server on port %s", apiConfig.Port) log.Infof("Starting API server on port %d", cfg.Port)
if err = apiServer.Start(); err != nil { if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err) log.Fatalf("API server failed to start: %v", err)
} }

View File

@@ -11,13 +11,22 @@ type Config struct {
// Port is the network port on which the API server will listen. // Port is the network port on which the API server will listen.
Port int `yaml:"port"` Port int `yaml:"port"`
// AuthDir is the directory where authentication token files are stored. // AuthDir is the directory where authentication token files are stored.
AuthDir string `yaml:"auth_dir"` AuthDir string `yaml:"auth-dir"`
// Debug enables or disables debug-level logging and other debug features. // Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug"` Debug bool `yaml:"debug"`
// ProxyUrl is the URL of an optional proxy server to use for outbound requests. // ProxyUrl is the URL of an optional proxy server to use for outbound requests.
ProxyUrl string `yaml:"proxy-url"` ProxyUrl string `yaml:"proxy-url"`
// ApiKeys is a list of keys for authenticating clients to this proxy server. // ApiKeys is a list of keys for authenticating clients to this proxy server.
ApiKeys []string `yaml:"api_keys"` ApiKeys []string `yaml:"api-keys"`
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"`
}
type ConfigQuotaExceeded struct {
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
SwitchProject bool `yaml:"switch-project"`
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model"`
} }
// LoadConfig reads a YAML configuration file from the given path, // LoadConfig reads a YAML configuration file from the given path,