mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Enhance quota management and refactor configuration handling
- Introduced `QuotaExceeded` settings in configuration to handle quota limits more effectively. - Added preview model switching logic to `Client` to automatically use fallback models on quota exhaustion. - Refactored `APIHandlers` to leverage new configuration structure. - Simplified server initialization and removed redundant `ServerConfig` structure. - Streamlined client initialization by unifying configuration handling throughout the project. - Improved error handling and response mechanisms in both streaming and non-streaming flows.
This commit is contained in:
@@ -1,7 +1,10 @@
|
|||||||
port: 8317
|
port: 8317
|
||||||
auth_dir: "~/.cli-proxy-api"
|
auth-dir: "~/.cli-proxy-api"
|
||||||
debug: true
|
debug: true
|
||||||
proxy-url: ""
|
proxy-url: ""
|
||||||
api_keys:
|
quota-exceeded:
|
||||||
|
switch-project: true
|
||||||
|
switch-preview-model: true
|
||||||
|
api-keys:
|
||||||
- "12345"
|
- "12345"
|
||||||
- "23456"
|
- "23456"
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -23,15 +24,15 @@ var (
|
|||||||
// It holds a pool of clients to interact with the backend service.
|
// It holds a pool of clients to interact with the backend service.
|
||||||
type APIHandlers struct {
|
type APIHandlers struct {
|
||||||
cliClients []*client.Client
|
cliClients []*client.Client
|
||||||
debug bool
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIHandlers creates a new API handlers instance.
|
// NewAPIHandlers creates a new API handlers instance.
|
||||||
// It takes a slice of clients and a debug flag as input.
|
// It takes a slice of clients and a debug flag as input.
|
||||||
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
|
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
|
||||||
return &APIHandlers{
|
return &APIHandlers{
|
||||||
cliClients: cliClients,
|
cliClients: cliClients,
|
||||||
debug: debug,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,22 +217,37 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Lock the mutex to update the last used page index
|
for {
|
||||||
|
// Lock the mutex to update the last used client index
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
startIndex := lastUsedClientIndex
|
startIndex := lastUsedClientIndex
|
||||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||||
lastUsedClientIndex = currentIndex
|
lastUsedClientIndex = currentIndex
|
||||||
mutex.Unlock()
|
mutex.Unlock()
|
||||||
|
|
||||||
// Reorder the pages to start from the last used index
|
// Reorder the client to start from the last used index
|
||||||
reorderedPages := make([]*client.Client, len(h.cliClients))
|
reorderedClients := make([]*client.Client, 0)
|
||||||
for i := 0; i < len(h.cliClients); i++ {
|
for i := 0; i < len(h.cliClients); i++ {
|
||||||
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||||
|
if cliClient.IsModelQuotaExceeded(modelName) {
|
||||||
|
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
cliClient = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reorderedClients = append(reorderedClients, cliClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(reorderedClients) == 0 {
|
||||||
|
c.Status(429)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
locked := false
|
locked := false
|
||||||
for i := 0; i < len(reorderedPages); i++ {
|
for i := 0; i < len(reorderedClients); i++ {
|
||||||
cliClient = reorderedPages[i]
|
cliClient = reorderedClients[i]
|
||||||
if cliClient.RequestMutex.TryLock() {
|
if cliClient.RequestMutex.TryLock() {
|
||||||
locked = true
|
locked = true
|
||||||
break
|
break
|
||||||
@@ -246,10 +262,15 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
|
|
||||||
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
|
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
c.Status(err.StatusCode)
|
c.Status(err.StatusCode)
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel()
|
cliCancel()
|
||||||
|
}
|
||||||
|
break
|
||||||
} else {
|
} else {
|
||||||
openAIFormat := translator.ConvertCliToOpenAINonStream(resp)
|
openAIFormat := translator.ConvertCliToOpenAINonStream(resp)
|
||||||
if openAIFormat != "" {
|
if openAIFormat != "" {
|
||||||
@@ -257,6 +278,8 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
cliCancel()
|
cliCancel()
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -290,34 +313,48 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Use a round-robin approach to select the next available client.
|
outLoop:
|
||||||
// This distributes the load among the available clients.
|
for {
|
||||||
|
// Lock the mutex to update the last used client index
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
startIndex := lastUsedClientIndex
|
startIndex := lastUsedClientIndex
|
||||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||||
lastUsedClientIndex = currentIndex
|
lastUsedClientIndex = currentIndex
|
||||||
mutex.Unlock()
|
mutex.Unlock()
|
||||||
|
|
||||||
// Reorder the clients to start from the next client in the rotation.
|
// Reorder the client to start from the last used index
|
||||||
reorderedPages := make([]*client.Client, len(h.cliClients))
|
reorderedClients := make([]*client.Client, 0)
|
||||||
for i := 0; i < len(h.cliClients); i++ {
|
for i := 0; i < len(h.cliClients); i++ {
|
||||||
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||||
|
if cliClient.IsModelQuotaExceeded(modelName) {
|
||||||
|
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
cliClient = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reorderedClients = append(reorderedClients, cliClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(reorderedClients) == 0 {
|
||||||
|
c.Status(429)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to lock a client for the request.
|
|
||||||
locked := false
|
locked := false
|
||||||
for i := 0; i < len(reorderedPages); i++ {
|
for i := 0; i < len(reorderedClients); i++ {
|
||||||
cliClient = reorderedPages[i]
|
cliClient = reorderedClients[i]
|
||||||
if cliClient.RequestMutex.TryLock() {
|
if cliClient.RequestMutex.TryLock() {
|
||||||
locked = true
|
locked = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If no client is available, block and wait for the first client.
|
|
||||||
if !locked {
|
if !locked {
|
||||||
cliClient = h.cliClients[0]
|
cliClient = h.cliClients[0]
|
||||||
cliClient.RequestMutex.Lock()
|
cliClient.RequestMutex.Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
// Send the message and receive response chunks and errors via channels.
|
// Send the message and receive response chunks and errors via channels.
|
||||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
||||||
@@ -351,10 +388,14 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
// Handle errors from the backend.
|
// Handle errors from the backend.
|
||||||
case err, okError := <-errChan:
|
case err, okError := <-errChan:
|
||||||
if okError {
|
if okError {
|
||||||
|
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
c.Status(err.StatusCode)
|
c.Status(err.StatusCode)
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel()
|
cliCancel()
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Send a keep-alive signal to the client.
|
// Send a keep-alive signal to the client.
|
||||||
@@ -365,4 +406,5 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,29 +18,19 @@ type Server struct {
|
|||||||
engine *gin.Engine
|
engine *gin.Engine
|
||||||
server *http.Server
|
server *http.Server
|
||||||
handlers *APIHandlers
|
handlers *APIHandlers
|
||||||
cfg *ServerConfig
|
cfg *config.Config
|
||||||
}
|
|
||||||
|
|
||||||
// ServerConfig contains the configuration for the API server.
|
|
||||||
type ServerConfig struct {
|
|
||||||
// Port is the port number the server will listen on.
|
|
||||||
Port string
|
|
||||||
// Debug enables or disables debug mode for the server and Gin.
|
|
||||||
Debug bool
|
|
||||||
// ApiKeys is a list of valid API keys for authentication.
|
|
||||||
ApiKeys []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates and initializes a new API server instance.
|
// NewServer creates and initializes a new API server instance.
|
||||||
// It sets up the Gin engine, middleware, routes, and handlers.
|
// It sets up the Gin engine, middleware, routes, and handlers.
|
||||||
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
|
||||||
// Set gin mode
|
// Set gin mode
|
||||||
if !config.Debug {
|
if !cfg.Debug {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create handlers
|
// Create handlers
|
||||||
handlers := NewAPIHandlers(cliClients, config.Debug)
|
handlers := NewAPIHandlers(cliClients, cfg)
|
||||||
|
|
||||||
// Create gin engine
|
// Create gin engine
|
||||||
engine := gin.New()
|
engine := gin.New()
|
||||||
@@ -53,7 +44,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
|||||||
s := &Server{
|
s := &Server{
|
||||||
engine: engine,
|
engine: engine,
|
||||||
handlers: handlers,
|
handlers: handlers,
|
||||||
cfg: config,
|
cfg: cfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
@@ -61,7 +52,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
|||||||
|
|
||||||
// Create HTTP server
|
// Create HTTP server
|
||||||
s.server = &http.Server{
|
s.server = &http.Server{
|
||||||
Addr: ":" + config.Port,
|
Addr: fmt.Sprintf(":%d", cfg.Port),
|
||||||
Handler: engine,
|
Handler: engine,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +129,7 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||||
// using API keys. If no API keys are configured, it allows all requests.
|
// using API keys. If no API keys are configured, it allows all requests.
|
||||||
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
|
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if len(cfg.ApiKeys) == 0 {
|
if len(cfg.ApiKeys) == 0 {
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@@ -29,12 +29,20 @@ const (
|
|||||||
pluginVersion = "0.1.9"
|
pluginVersion = "0.1.9"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
previewModels = map[string][]string{
|
||||||
|
"gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
|
||||||
|
"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// Client is the main client for interacting with the CLI API.
|
// Client is the main client for interacting with the CLI API.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
RequestMutex sync.Mutex
|
RequestMutex sync.Mutex
|
||||||
tokenStorage *auth.TokenStorage
|
tokenStorage *auth.TokenStorage
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
modelQuotaExceeded map[string]*time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new CLI API client.
|
// NewClient creates a new CLI API client.
|
||||||
@@ -43,6 +51,7 @@ func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Confi
|
|||||||
httpClient: httpClient,
|
httpClient: httpClient,
|
||||||
tokenStorage: ts,
|
tokenStorage: ts,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
modelQuotaExceeded: make(map[string]*time.Time),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,97 +223,6 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageStream handles a single conversational turn, including tool calls.
|
|
||||||
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
|
|
||||||
dataTag := []byte("data: ")
|
|
||||||
errChan := make(chan *ErrorMessage)
|
|
||||||
dataChan := make(chan []byte)
|
|
||||||
go func() {
|
|
||||||
defer close(errChan)
|
|
||||||
defer close(dataChan)
|
|
||||||
|
|
||||||
request := GenerateContentRequest{
|
|
||||||
Contents: contents,
|
|
||||||
GenerationConfig: GenerationConfig{
|
|
||||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
|
||||||
IncludeThoughts: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
request.Tools = tools
|
|
||||||
|
|
||||||
requestBody := map[string]interface{}{
|
|
||||||
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
|
|
||||||
"request": request,
|
|
||||||
"model": model,
|
|
||||||
}
|
|
||||||
|
|
||||||
byteRequestBody, _ := json.Marshal(requestBody)
|
|
||||||
|
|
||||||
// log.Debug(string(byteRequestBody))
|
|
||||||
|
|
||||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
|
||||||
if reasoningEffortResult.String() == "none" {
|
|
||||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
|
||||||
} else if reasoningEffortResult.String() == "auto" {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
|
||||||
} else if reasoningEffortResult.String() == "low" {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
|
||||||
} else if reasoningEffortResult.String() == "medium" {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
|
||||||
} else if reasoningEffortResult.String() == "high" {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
|
||||||
} else {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
|
||||||
}
|
|
||||||
|
|
||||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
|
||||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
|
||||||
}
|
|
||||||
|
|
||||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
|
||||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
|
||||||
}
|
|
||||||
|
|
||||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
|
||||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
|
||||||
}
|
|
||||||
|
|
||||||
// log.Debug(string(byteRequestBody))
|
|
||||||
|
|
||||||
stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
|
|
||||||
if err != nil {
|
|
||||||
// log.Println(err)
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(stream)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
// log.Printf("Received stream chunk: %s", line)
|
|
||||||
if bytes.HasPrefix(line, dataTag) {
|
|
||||||
dataChan <- line[6:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if errScanner := scanner.Err(); errScanner != nil {
|
|
||||||
// log.Println(err)
|
|
||||||
errChan <- &ErrorMessage{500, errScanner}
|
|
||||||
_ = stream.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = stream.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return dataChan, errChan
|
|
||||||
}
|
|
||||||
|
|
||||||
// APIRequest handles making requests to the CLI API endpoints.
|
// APIRequest handles making requests to the CLI API endpoints.
|
||||||
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
|
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
|
||||||
var jsonBody []byte
|
var jsonBody []byte
|
||||||
@@ -415,17 +333,192 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
|
|||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelName := model
|
||||||
// log.Debug(string(byteRequestBody))
|
// log.Debug(string(byteRequestBody))
|
||||||
|
for {
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
|
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||||
if errReadAll != nil {
|
if errReadAll != nil {
|
||||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||||
}
|
}
|
||||||
return bodyBytes, nil
|
return bodyBytes, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessageStream handles a single conversational turn, including tool calls.
|
||||||
|
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
|
dataTag := []byte("data: ")
|
||||||
|
errChan := make(chan *ErrorMessage)
|
||||||
|
dataChan := make(chan []byte)
|
||||||
|
go func() {
|
||||||
|
defer close(errChan)
|
||||||
|
defer close(dataChan)
|
||||||
|
|
||||||
|
request := GenerateContentRequest{
|
||||||
|
Contents: contents,
|
||||||
|
GenerationConfig: GenerationConfig{
|
||||||
|
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||||
|
IncludeThoughts: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
request.Tools = tools
|
||||||
|
|
||||||
|
requestBody := map[string]interface{}{
|
||||||
|
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
|
||||||
|
"request": request,
|
||||||
|
"model": model,
|
||||||
|
}
|
||||||
|
|
||||||
|
byteRequestBody, _ := json.Marshal(requestBody)
|
||||||
|
|
||||||
|
// log.Debug(string(byteRequestBody))
|
||||||
|
|
||||||
|
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||||
|
if reasoningEffortResult.String() == "none" {
|
||||||
|
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||||
|
} else if reasoningEffortResult.String() == "auto" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
} else if reasoningEffortResult.String() == "low" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
|
} else if reasoningEffortResult.String() == "medium" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
|
} else if reasoningEffortResult.String() == "high" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||||
|
} else {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||||
|
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||||
|
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||||
|
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
// log.Debug(string(byteRequestBody))
|
||||||
|
modelName := model
|
||||||
|
var stream io.ReadCloser
|
||||||
|
for {
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errChan <- &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var err *ErrorMessage
|
||||||
|
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(stream)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
// log.Printf("Received stream chunk: %s", line)
|
||||||
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
|
dataChan <- line[6:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errScanner := scanner.Err(); errScanner != nil {
|
||||||
|
// log.Println(err)
|
||||||
|
errChan <- &ErrorMessage{500, errScanner}
|
||||||
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return dataChan, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) isModelQuotaExceeded(model string) bool {
|
||||||
|
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
|
||||||
|
duration := time.Now().Sub(*lastExceededTime)
|
||||||
|
if duration > 30*time.Minute {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getPreviewModel(model string) string {
|
||||||
|
if models, hasKey := previewModels[model]; hasKey {
|
||||||
|
for i := 0; i < len(models); i++ {
|
||||||
|
if !c.isModelQuotaExceeded(models[i]) {
|
||||||
|
return models[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) IsModelQuotaExceeded(model string) bool {
|
||||||
|
if c.isModelQuotaExceeded(model) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||||
|
return c.getPreviewModel(model) == ""
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api"
|
"github.com/luispater/CLIProxyAPI/internal/api"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
@@ -22,13 +21,6 @@ import (
|
|||||||
// It loads all available authentication tokens, creates a pool of clients,
|
// It loads all available authentication tokens, creates a pool of clients,
|
||||||
// starts the API server, and handles graceful shutdown signals.
|
// starts the API server, and handles graceful shutdown signals.
|
||||||
func StartService(cfg *config.Config) {
|
func StartService(cfg *config.Config) {
|
||||||
// Configure the API server based on the main application config.
|
|
||||||
apiConfig := &api.ServerConfig{
|
|
||||||
Port: fmt.Sprintf("%d", cfg.Port),
|
|
||||||
Debug: cfg.Debug,
|
|
||||||
ApiKeys: cfg.ApiKeys,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a pool of API clients, one for each token file found.
|
// Create a pool of API clients, one for each token file found.
|
||||||
cliClients := make([]*client.Client, 0)
|
cliClients := make([]*client.Client, 0)
|
||||||
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||||
@@ -73,8 +65,8 @@ func StartService(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create and start the API server with the pool of clients.
|
// Create and start the API server with the pool of clients.
|
||||||
apiServer := api.NewServer(apiConfig, cliClients)
|
apiServer := api.NewServer(cfg, cliClients)
|
||||||
log.Infof("Starting API server on port %s", apiConfig.Port)
|
log.Infof("Starting API server on port %d", cfg.Port)
|
||||||
if err = apiServer.Start(); err != nil {
|
if err = apiServer.Start(); err != nil {
|
||||||
log.Fatalf("API server failed to start: %v", err)
|
log.Fatalf("API server failed to start: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,13 +11,22 @@ type Config struct {
|
|||||||
// Port is the network port on which the API server will listen.
|
// Port is the network port on which the API server will listen.
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
// AuthDir is the directory where authentication token files are stored.
|
// AuthDir is the directory where authentication token files are stored.
|
||||||
AuthDir string `yaml:"auth_dir"`
|
AuthDir string `yaml:"auth-dir"`
|
||||||
// Debug enables or disables debug-level logging and other debug features.
|
// Debug enables or disables debug-level logging and other debug features.
|
||||||
Debug bool `yaml:"debug"`
|
Debug bool `yaml:"debug"`
|
||||||
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
|
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
|
||||||
ProxyUrl string `yaml:"proxy-url"`
|
ProxyUrl string `yaml:"proxy-url"`
|
||||||
// ApiKeys is a list of keys for authenticating clients to this proxy server.
|
// ApiKeys is a list of keys for authenticating clients to this proxy server.
|
||||||
ApiKeys []string `yaml:"api_keys"`
|
ApiKeys []string `yaml:"api-keys"`
|
||||||
|
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||||
|
QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConfigQuotaExceeded struct {
|
||||||
|
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
|
||||||
|
SwitchProject bool `yaml:"switch-project"`
|
||||||
|
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
||||||
|
SwitchPreviewModel bool `yaml:"switch-preview-model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfig reads a YAML configuration file from the given path,
|
// LoadConfig reads a YAML configuration file from the given path,
|
||||||
|
|||||||
Reference in New Issue
Block a user