diff --git a/internal/api/gemini-handlers.go b/internal/api/gemini-handlers.go index cf56d36d..1ae70c7d 100644 --- a/internal/api/gemini-handlers.go +++ b/internal/api/gemini-handlers.go @@ -266,7 +266,7 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) { for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.getClient(modelName, false) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 0d88a9b7..d05ffe54 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -85,7 +85,7 @@ func (h *APIHandlers) Models(c *gin.Context) { }) } -func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.ErrorMessage) { +func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) { if len(h.cliClients) == 0 { return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")} } @@ -95,8 +95,10 @@ func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.Error // Lock the mutex to update the last used client index mutex.Lock() startIndex := lastUsedClientIndex - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex + if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 { + currentIndex := (startIndex + 1) % len(h.cliClients) + lastUsedClientIndex = currentIndex + } mutex.Unlock() // Reorder the client to start from the last used index