Enhance quota management and refactor configuration handling

- Introduced `QuotaExceeded` settings in configuration to handle quota limits more effectively.
- Added preview model switching logic to `Client` to automatically use fallback models on quota exhaustion.
- Refactored `APIHandlers` to leverage new configuration structure.
- Simplified server initialization and removed redundant `ServerConfig` structure.
- Streamlined client initialization by unifying configuration handling throughout the project.
- Improved error handling and response mechanisms in both streaming and non-streaming flows.
This commit is contained in:
Luis Pater
2025-07-05 07:53:46 +08:00
parent e73f165070
commit 7cb76ae1a5
6 changed files with 374 additions and 244 deletions

View File

@@ -1,7 +1,10 @@
port: 8317
auth_dir: "~/.cli-proxy-api"
auth-dir: "~/.cli-proxy-api"
debug: true
proxy-url: ""
api_keys:
quota-exceeded:
switch-project: true
switch-preview-model: true
api-keys:
- "12345"
- "23456"

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"net/http"
@@ -23,15 +24,15 @@ var (
// It holds a pool of clients to interact with the backend service.
type APIHandlers struct {
cliClients []*client.Client
debug bool
cfg *config.Config
}
// NewAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
return &APIHandlers{
cliClients: cliClients,
debug: debug,
cfg: cfg,
}
}
@@ -216,22 +217,37 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
}
}()
// Lock the mutex to update the last used page index
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the pages to start from the last used index
reorderedPages := make([]*client.Client, len(h.cliClients))
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedPages); i++ {
cliClient = reorderedPages[i]
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
@@ -246,10 +262,15 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
break
} else {
openAIFormat := translator.ConvertCliToOpenAINonStream(resp)
if openAIFormat != "" {
@@ -257,6 +278,8 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
flusher.Flush()
}
cliCancel()
break
}
}
}
@@ -290,34 +313,48 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
}
}()
// Use a round-robin approach to select the next available client.
// This distributes the load among the available clients.
outLoop:
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the clients to start from the next client in the rotation.
reorderedPages := make([]*client.Client, len(h.cliClients))
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
}
// Attempt to lock a client for the request.
locked := false
for i := 0; i < len(reorderedPages); i++ {
cliClient = reorderedPages[i]
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
// If no client is available, block and wait for the first client.
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
@@ -351,10 +388,14 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
@@ -365,4 +406,5 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
}
}
}
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"net/http"
"strings"
@@ -17,29 +18,19 @@ type Server struct {
engine *gin.Engine
server *http.Server
handlers *APIHandlers
cfg *ServerConfig
}
// ServerConfig contains the configuration for the API server.
type ServerConfig struct {
// Port is the port number the server will listen on.
Port string
// Debug enables or disables debug mode for the server and Gin.
Debug bool
// ApiKeys is a list of valid API keys for authentication.
ApiKeys []string
cfg *config.Config
}
// NewServer creates and initializes a new API server instance.
// It sets up the Gin engine, middleware, routes, and handlers.
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
// Set gin mode
if !config.Debug {
if !cfg.Debug {
gin.SetMode(gin.ReleaseMode)
}
// Create handlers
handlers := NewAPIHandlers(cliClients, config.Debug)
handlers := NewAPIHandlers(cliClients, cfg)
// Create gin engine
engine := gin.New()
@@ -53,7 +44,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
s := &Server{
engine: engine,
handlers: handlers,
cfg: config,
cfg: cfg,
}
// Setup routes
@@ -61,7 +52,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
// Create HTTP server
s.server = &http.Server{
Addr: ":" + config.Port,
Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: engine,
}
@@ -138,7 +129,7 @@ func corsMiddleware() gin.HandlerFunc {
// AuthMiddleware returns a Gin middleware handler that authenticates requests
// using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
if len(cfg.ApiKeys) == 0 {
c.Next()

View File

@@ -29,12 +29,20 @@ const (
pluginVersion = "0.1.9"
)
var (
previewModels = map[string][]string{
"gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
}
)
// Client is the main client for interacting with the CLI API.
type Client struct {
httpClient *http.Client
RequestMutex sync.Mutex
tokenStorage *auth.TokenStorage
cfg *config.Config
modelQuotaExceeded map[string]*time.Time
}
// NewClient creates a new CLI API client.
@@ -43,6 +51,7 @@ func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Confi
httpClient: httpClient,
tokenStorage: ts,
cfg: cfg,
modelQuotaExceeded: make(map[string]*time.Time),
}
}
@@ -214,97 +223,6 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
return nil
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
go func() {
defer close(errChan)
defer close(dataChan)
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: true,
},
},
}
request.Tools = tools
requestBody := map[string]interface{}{
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
"request": request,
"model": model,
}
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
// log.Debug(string(byteRequestBody))
stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
if err != nil {
// log.Println(err)
errChan <- err
return
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
// log.Printf("Received stream chunk: %s", line)
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
if errScanner := scanner.Err(); errScanner != nil {
// log.Println(err)
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
_ = stream.Close()
}()
return dataChan, errChan
}
// APIRequest handles making requests to the CLI API endpoints.
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte
@@ -415,17 +333,192 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
modelName := model
// log.Debug(string(byteRequestBody))
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel {
continue
}
}
return nil, err
}
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
go func() {
defer close(errChan)
defer close(dataChan)
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: true,
},
},
}
request.Tools = tools
requestBody := map[string]interface{}{
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
"request": request,
"model": model,
}
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
// log.Debug(string(byteRequestBody))
modelName := model
var stream io.ReadCloser
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
continue
}
}
errChan <- &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel {
continue
}
}
errChan <- err
return
}
delete(c.modelQuotaExceeded, modelName)
break
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
// log.Printf("Received stream chunk: %s", line)
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
if errScanner := scanner.Err(); errScanner != nil {
// log.Println(err)
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
_ = stream.Close()
}()
return dataChan, errChan
}
func (c *Client) isModelQuotaExceeded(model string) bool {
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
duration := time.Now().Sub(*lastExceededTime)
if duration > 30*time.Minute {
return false
}
return true
}
return false
}
func (c *Client) getPreviewModel(model string) string {
if models, hasKey := previewModels[model]; hasKey {
for i := 0; i < len(models); i++ {
if !c.isModelQuotaExceeded(models[i]) {
return models[i]
}
}
}
return ""
}
func (c *Client) IsModelQuotaExceeded(model string) bool {
if c.isModelQuotaExceeded(model) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
return c.getPreviewModel(model) == ""
}
return true
}
return false
}
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify

View File

@@ -3,7 +3,6 @@ package cmd
import (
"context"
"encoding/json"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
@@ -22,13 +21,6 @@ import (
// It loads all available authentication tokens, creates a pool of clients,
// starts the API server, and handles graceful shutdown signals.
func StartService(cfg *config.Config) {
// Configure the API server based on the main application config.
apiConfig := &api.ServerConfig{
Port: fmt.Sprintf("%d", cfg.Port),
Debug: cfg.Debug,
ApiKeys: cfg.ApiKeys,
}
// Create a pool of API clients, one for each token file found.
cliClients := make([]*client.Client, 0)
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
@@ -73,8 +65,8 @@ func StartService(cfg *config.Config) {
}
// Create and start the API server with the pool of clients.
apiServer := api.NewServer(apiConfig, cliClients)
log.Infof("Starting API server on port %s", apiConfig.Port)
apiServer := api.NewServer(cfg, cliClients)
log.Infof("Starting API server on port %d", cfg.Port)
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
}

View File

@@ -11,13 +11,22 @@ type Config struct {
// Port is the network port on which the API server will listen.
Port int `yaml:"port"`
// AuthDir is the directory where authentication token files are stored.
AuthDir string `yaml:"auth_dir"`
AuthDir string `yaml:"auth-dir"`
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug"`
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
ProxyUrl string `yaml:"proxy-url"`
// ApiKeys is a list of keys for authenticating clients to this proxy server.
ApiKeys []string `yaml:"api_keys"`
ApiKeys []string `yaml:"api-keys"`
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"`
}
type ConfigQuotaExceeded struct {
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
SwitchProject bool `yaml:"switch-project"`
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model"`
}
// LoadConfig reads a YAML configuration file from the given path,