diff --git a/internal/api/cli-handlers.go b/internal/api/cli-handlers.go index b4fcf146..9e41bd54 100644 --- a/internal/api/cli-handlers.go +++ b/internal/api/cli-handlers.go @@ -141,7 +141,7 @@ outLoop: log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson) + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "") hasFirstResponse := false for { select { @@ -220,7 +220,7 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } - resp, err := cliClient.SendRawMessage(cliCtx, rawJson) + resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "") if err != nil { if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { continue diff --git a/internal/api/gemini-handlers.go b/internal/api/gemini-handlers.go index 7c8f3082..798bb3da 100644 --- a/internal/api/gemini-handlers.go +++ b/internal/api/gemini-handlers.go @@ -15,7 +15,7 @@ import ( ) func (h *APIHandlers) GeminiModels(c *gin.Context) { - c.Status(200) + c.Status(http.StatusOK) c.Header("Content-Type", "application/json; charset=UTF-8") _, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `)) _, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`)) @@ -30,11 +30,11 @@ func (h *APIHandlers) GeminiModels(c *gin.Context) { _, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`)) } -func (h *APIHandlers) GeminiHandler(c *gin.Context) { - var person struct { +func (h *APIHandlers) GeminiGetHandler(c *gin.Context) { + var request struct { Action string `uri:"action" binding:"required"` } - if err := c.ShouldBindUri(&person); err != nil { + if err := c.ShouldBindUri(&request); err != nil { c.JSON(http.StatusBadRequest, ErrorResponse{ Error: ErrorDetail{ Message: fmt.Sprintf("Invalid request: %v", err), @@ -43,7 +43,45 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) { }) return } - action := strings.Split(person.Action, ":") + if request.Action == "gemini-2.5-pro" { + c.Status(http.StatusOK) + c.Header("Content-Type", "application/json; charset=UTF-8") + _, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`)) + _, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`)) + _, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`)) + _, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`)) + _, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`)) + } else if request.Action == "gemini-2.5-flash" { + c.Status(http.StatusOK) + c.Header("Content-Type", "application/json; charset=UTF-8") + _, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`)) + _, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `)) + _, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`)) + _, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`)) + _, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`)) + _, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`)) + } else { + c.Status(http.StatusNotFound) + _, _ = c.Writer.Write([]byte( + `{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`, + )) + } +} + +func (h *APIHandlers) GeminiHandler(c *gin.Context) { + var request struct { + Action string `uri:"action" binding:"required"` + } + if err := c.ShouldBindUri(&request); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + action := strings.Split(request.Action, ":") if len(action) != 2 { c.JSON(http.StatusNotFound, ErrorResponse{ Error: ErrorDetail{ @@ -63,6 +101,8 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) { h.geminiGenerateContent(c, rawJson) } else if method == "streamGenerateContent" { h.geminiStreamGenerateContent(c, rawJson) + } else if method == "countTokens" { + h.geminiCountTokens(c, rawJson) } } @@ -82,6 +122,8 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte modelResult := gjson.GetBytes(rawJson, "model") modelName := modelResult.String() + alt := h.getAlt(c) + cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client defer func() { @@ -134,7 +176,7 @@ outLoop: } // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson) + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt) for { select { // Handle client disconnection. @@ -151,14 +193,33 @@ outLoop: return } else { if cliClient.GetGenerativeLanguageAPIKey() == "" { - responseResult := gjson.GetBytes(chunk, "response") - if responseResult.Exists() { - chunk = []byte(responseResult.Raw) + if alt == "" { + responseResult := gjson.GetBytes(chunk, "response") + if responseResult.Exists() { + chunk = []byte(responseResult.Raw) + } + } else { + chunkTemplate := "[]" + responseResult := gjson.ParseBytes(chunk) + if responseResult.IsArray() { + responseResultItems := responseResult.Array() + for i := 0; i < len(responseResultItems); i++ { + responseResultItem := responseResultItems[i] + if responseResultItem.Get("response").Exists() { + chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) + } + } + } + chunk = []byte(chunkTemplate) } } - _, _ = c.Writer.Write([]byte("data: ")) - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) + if alt == "" { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + } else { + _, _ = c.Writer.Write(chunk) + } flusher.Flush() } // Handle errors from the backend. @@ -181,9 +242,71 @@ outLoop: } } +func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) { + c.Header("Content-Type", "application/json") + + alt := h.getAlt(c) + + modelResult := gjson.GetBytes(rawJson, "model") + modelName := modelResult.String() + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient *client.Client + defer func() { + if cliClient != nil { + cliClient.RequestMutex.Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.getClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + cliCancel() + return + } + + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + + template := `{"request":{}}` + template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw) + template, _ = sjson.Delete(template, "generateContentRequest") + rawJson = []byte(template) + } + + resp, err := cliClient.SendRawTokenCount(cliCtx, rawJson, alt) + if err != nil { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel() + } + break + } else { + if cliClient.GetGenerativeLanguageAPIKey() == "" { + responseResult := gjson.GetBytes(resp, "response") + if responseResult.Exists() { + resp = []byte(responseResult.Raw) + } + } + _, _ = c.Writer.Write(resp) + cliCancel() + break + } + } +} + func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { c.Header("Content-Type", "application/json") + alt := h.getAlt(c) + modelResult := gjson.GetBytes(rawJson, "model") modelName := modelResult.String() cliCtx, cliCancel := context.WithCancel(context.Background()) @@ -233,7 +356,7 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { } else { log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } - resp, err := cliClient.SendRawMessage(cliCtx, rawJson) + resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt) if err != nil { if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { continue @@ -256,3 +379,16 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { } } } + +func (h *APIHandlers) getAlt(c *gin.Context) string { + var alt string + var hasAlt bool + alt, hasAlt = c.GetQuery("alt") + if !hasAlt { + alt, _ = c.GetQuery("$alt") + } + if alt == "sse" { + return "" + } + return alt +} diff --git a/internal/api/server.go b/internal/api/server.go index 83003914..5774d6aa 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -77,6 +77,7 @@ func (s *Server) setupRoutes() { { v1beta.GET("/models", s.handlers.GeminiModels) v1beta.POST("/models/:action", s.handlers.GeminiHandler) + v1beta.GET("/models/:action", s.handlers.GeminiGetHandler) } // Root endpoint diff --git a/internal/client/client.go b/internal/client/client.go index 60b9567c..041bb24b 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -28,7 +28,7 @@ const ( apiVersion = "v1internal" pluginVersion = "0.1.9" - glEndPoint = "https://generativelanguage.googleapis.com/" + glEndPoint = "https://generativelanguage.googleapis.com" glApiVersion = "v1beta" ) @@ -241,7 +241,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo } // APIRequest handles making requests to the CLI API endpoints. -func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) { +func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) { var jsonBody []byte var err error if byteBody, ok := body.([]byte); ok { @@ -257,21 +257,30 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface if c.glAPIKey == "" { // Add alt=sse for streaming url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint) - if stream { + if alt == "" && stream { url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", alt) } } else { - modelResult := gjson.GetBytes(jsonBody, "model") - url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) - if stream { - url = url + "?alt=sse" - } - jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw) - systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction") - if systemInstructionResult.Exists() { - jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw)) - jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction") - jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id") + if endpoint == "countTokens" { + modelResult := gjson.GetBytes(jsonBody, "model") + url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) + } else { + modelResult := gjson.GetBytes(jsonBody, "model") + url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) + if alt == "" && stream { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", alt) + } + jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw) + systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction") + if systemInstructionResult.Exists() { + jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw)) + jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction") + jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id") + } } } @@ -392,7 +401,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, } } - respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false) + respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false) if err != nil { if err.StatusCode == 429 { now := time.Now() @@ -544,7 +553,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st // Attempt to establish a streaming connection with the API var err *ErrorMessage - stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true) + stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true) if err != nil { // Handle quota exceeded errors by marking the model and potentially retrying if err.StatusCode == 429 { @@ -593,8 +602,49 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st return dataChan, errChan } +// SendRawTokenCount handles a token count. +func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) { + modelResult := gjson.GetBytes(rawJson, "model") + model := modelResult.String() + modelName := model + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) + rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) + continue + } + } + return nil, &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + } + + respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + continue + } + } + return nil, err + } + delete(c.modelQuotaExceeded, modelName) + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil + } +} + // SendRawMessage handles a single conversational turn, including tool calls. -func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) { +func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) { if c.glAPIKey == "" { rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) } @@ -618,7 +668,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E } } - respBody, err := c.APIRequest(ctx, "generateContent", rawJson, false) + respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false) if err != nil { if err.StatusCode == 429 { now := time.Now() @@ -639,7 +689,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E } // SendRawMessageStream handles a single conversational turn, including tool calls. -func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-chan []byte, <-chan *ErrorMessage) { +func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { dataTag := []byte("data: ") errChan := make(chan *ErrorMessage) dataChan := make(chan []byte) @@ -672,7 +722,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch return } var err *ErrorMessage - stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, true) + stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true) if err != nil { if err.StatusCode == 429 { now := time.Now() @@ -688,21 +738,32 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch break } - scanner := bufio.NewScanner(stream) - for scanner.Scan() { - line := scanner.Bytes() - if bytes.HasPrefix(line, dataTag) { - dataChan <- line[6:] + if alt == "" { + scanner := bufio.NewScanner(stream) + for scanner.Scan() { + line := scanner.Bytes() + if bytes.HasPrefix(line, dataTag) { + dataChan <- line[6:] + } } - } - if errScanner := scanner.Err(); errScanner != nil { - errChan <- &ErrorMessage{500, errScanner} - _ = stream.Close() - return - } + if errScanner := scanner.Err(); errScanner != nil { + errChan <- &ErrorMessage{500, errScanner} + _ = stream.Close() + return + } + } else { + data, err := io.ReadAll(stream) + if err != nil { + errChan <- &ErrorMessage{500, err} + _ = stream.Close() + return + } + dataChan <- data + } _ = stream.Close() + }() return dataChan, errChan @@ -754,7 +815,7 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { // A simple request to test the API endpoint. requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID) - stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true) + stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true) if err != nil { // If a 403 Forbidden error occurs, it likely means the API is not enabled. if err.StatusCode == 403 {