mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 12:20:52 +08:00
Enhance quota management and refactor configuration handling
- Introduced `QuotaExceeded` settings in configuration to handle quota limits more effectively. - Added preview model switching logic to `Client` to automatically use fallback models on quota exhaustion. - Refactored `APIHandlers` to leverage new configuration structure. - Simplified server initialization and removed redundant `ServerConfig` structure. - Streamlined client initialization by unifying configuration handling throughout the project. - Improved error handling and response mechanisms in both streaming and non-streaming flows.
This commit is contained in:
@@ -29,20 +29,29 @@ const (
|
||||
pluginVersion = "0.1.9"
|
||||
)
|
||||
|
||||
var (
|
||||
previewModels = map[string][]string{
|
||||
"gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
|
||||
"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
|
||||
}
|
||||
)
|
||||
|
||||
// Client is the main client for interacting with the CLI API.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
RequestMutex sync.Mutex
|
||||
tokenStorage *auth.TokenStorage
|
||||
cfg *config.Config
|
||||
httpClient *http.Client
|
||||
RequestMutex sync.Mutex
|
||||
tokenStorage *auth.TokenStorage
|
||||
cfg *config.Config
|
||||
modelQuotaExceeded map[string]*time.Time
|
||||
}
|
||||
|
||||
// NewClient creates a new CLI API client.
|
||||
func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client {
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
tokenStorage: ts,
|
||||
cfg: cfg,
|
||||
httpClient: httpClient,
|
||||
tokenStorage: ts,
|
||||
cfg: cfg,
|
||||
modelQuotaExceeded: make(map[string]*time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,97 +223,6 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessageStream handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
dataTag := []byte("data: ")
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
request := GenerateContentRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: GenerationConfig{
|
||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
request.Tools = tools
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
|
||||
"request": request,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
byteRequestBody, _ := json.Marshal(requestBody)
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
} else if reasoningEffortResult.String() == "auto" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
} else if reasoningEffortResult.String() == "low" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
} else if reasoningEffortResult.String() == "medium" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
} else if reasoningEffortResult.String() == "high" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||
} else {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
|
||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||
}
|
||||
|
||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||
}
|
||||
|
||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
|
||||
if err != nil {
|
||||
// log.Println(err)
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
// log.Printf("Received stream chunk: %s", line)
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
dataChan <- line[6:]
|
||||
}
|
||||
}
|
||||
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
// log.Println(err)
|
||||
errChan <- &ErrorMessage{500, errScanner}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// APIRequest handles making requests to the CLI API endpoints.
|
||||
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
|
||||
var jsonBody []byte
|
||||
@@ -415,17 +333,192 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
modelName := model
|
||||
// log.Debug(string(byteRequestBody))
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
}
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
// SendMessageStream handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
dataTag := []byte("data: ")
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
request := GenerateContentRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: GenerationConfig{
|
||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
request.Tools = tools
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
|
||||
"request": request,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
byteRequestBody, _ := json.Marshal(requestBody)
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
} else if reasoningEffortResult.String() == "auto" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
} else if reasoningEffortResult.String() == "low" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
} else if reasoningEffortResult.String() == "medium" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
} else if reasoningEffortResult.String() == "high" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||
} else {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
|
||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||
}
|
||||
|
||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||
}
|
||||
|
||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
modelName := model
|
||||
var stream io.ReadCloser
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
errChan <- &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
return
|
||||
}
|
||||
var err *ErrorMessage
|
||||
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||
continue
|
||||
}
|
||||
}
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
break
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
// log.Printf("Received stream chunk: %s", line)
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
dataChan <- line[6:]
|
||||
}
|
||||
}
|
||||
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
// log.Println(err)
|
||||
errChan <- &ErrorMessage{500, errScanner}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func (c *Client) isModelQuotaExceeded(model string) bool {
|
||||
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
|
||||
duration := time.Now().Sub(*lastExceededTime)
|
||||
if duration > 30*time.Minute {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
return bodyBytes, nil
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) getPreviewModel(model string) string {
|
||||
if models, hasKey := previewModels[model]; hasKey {
|
||||
for i := 0; i < len(models); i++ {
|
||||
if !c.isModelQuotaExceeded(models[i]) {
|
||||
return models[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Client) IsModelQuotaExceeded(model string) bool {
|
||||
if c.isModelQuotaExceeded(model) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||
return c.getPreviewModel(model) == ""
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
||||
|
||||
Reference in New Issue
Block a user