diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 85bdfef9..cf28f010 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -243,8 +243,18 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) } log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - jsonTemplate := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` - respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) + + respChan := make(chan []byte) + errChan := make(chan *client.ErrorMessage) + go func() { + resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools) + if err != nil { + errChan <- err + } else { + respChan <- resp + } + }() + for { select { case <-c.Request.Context().Done(): @@ -253,23 +263,20 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) cliCancel() return } - case chunk, okStream := <-respChan: - if !okStream { - _, _ = fmt.Fprint(c.Writer, jsonTemplate) + case respBody := <-respChan: + openAIFormat := translator.ConvertCliToOpenAINonStream(respBody) + if openAIFormat != "" { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) flusher.Flush() - cliCancel() - return - } else { - jsonTemplate = translator.ConvertCliToOpenAINonStream(jsonTemplate, chunk) - } - case err, okError := <-errChan: - if okError { - c.Status(err.StatusCode) - _, _ = fmt.Fprint(c.Writer, err.Error.Error()) - flusher.Flush() - cliCancel() - return } + cliCancel() + return + case err := <-errChan: + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + return case <-time.After(500 * time.Millisecond): _, _ = c.Writer.Write([]byte("\n")) flusher.Flush() diff --git a/internal/api/translator/response.go b/internal/api/translator/response.go index 885b5f30..41e4fd01 100644 --- a/internal/api/translator/response.go +++ b/internal/api/translator/response.go @@ -88,34 +88,30 @@ func ConvertCliToOpenAI(rawJson []byte) string { return template } -// ConvertCliToOpenAINonStream aggregates response chunks from the backend client -// into a single, non-streaming OpenAI-compatible JSON response. -func ConvertCliToOpenAINonStream(template string, rawJson []byte) string { - // Extract and set metadata fields that are typically set once per response. - if gjson.Get(template, "id").String() == "" { - if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { - template, _ = sjson.Set(template, "model", modelVersionResult.String()) - } - if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { - t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - unixTimestamp := time.Now().Unix() - if err == nil { - unixTimestamp = t.Unix() - } - template, _ = sjson.Set(template, "created", unixTimestamp) - } - if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { - template, _ = sjson.Set(template, "id", responseIdResult.String()) +// ConvertCliToOpenAINonStream aggregates response from the backend client +// convert a single, non-streaming OpenAI-compatible JSON response. +func ConvertCliToOpenAINonStream(rawJson []byte) string { + template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` + if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { + template, _ = sjson.Set(template, "model", modelVersionResult.String()) + } + if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { + t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) + unixTimestamp := time.Now().Unix() + if err == nil { + unixTimestamp = t.Unix() } + template, _ = sjson.Set(template, "created", unixTimestamp) + } + if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { + template, _ = sjson.Set(template, "id", responseIdResult.String()) } - // Extract and set the finish reason. if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() { template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) } - // Extract and set usage metadata (token counts). if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() { if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) @@ -132,37 +128,42 @@ func ConvertCliToOpenAINonStream(template string, rawJson []byte) string { } // Process the main content part of the response. - partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0") - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") + partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts") + if partsResult.IsArray() { + partsResults := partsResult.Array() + for i := 0; i < len(partsResults); i++ { + partResult := partsResults[i] + partTextResult := partResult.Get("text") + functionCallResult := partResult.Get("functionCall") - if partTextResult.Exists() { - // Append text content, distinguishing between regular content and reasoning. - if partResult.Get("thought").Bool() { - currentContent := gjson.Get(template, "choices.0.message.reasoning_content").String() - template, _ = sjson.Set(template, "choices.0.message.reasoning_content", currentContent+partTextResult.String()) - } else { - currentContent := gjson.Get(template, "choices.0.message.content").String() - template, _ = sjson.Set(template, "choices.0.message.content", currentContent+partTextResult.String()) + if partTextResult.Exists() { + // Append text content, distinguishing between regular content and reasoning. + if partResult.Get("thought").Bool() { + template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String()) + } else { + template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String()) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } else if functionCallResult.Exists() { + // Append function call content to the tool_calls array. + toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls") + if !toolCallsResult.Exists() || !toolCallsResult.IsArray() { + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) + } + functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` + fcName := functionCallResult.Get("name").String() + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName) + if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate) + } else { + // If no usable content is found, return an empty string. + return "" + } } - template, _ = sjson.Set(template, "choices.0.message.role", "assistant") - } else if functionCallResult.Exists() { - // Append function call content to the tool_calls array. - if !gjson.Get(template, "choices.0.message.tool_calls").Exists() { - template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) - } - functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` - fcName := functionCallResult.Get("name").String() - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName) - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName) - if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) - } - template, _ = sjson.Set(template, "choices.0.message.role", "assistant") - template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate) - } else { - // If no usable content is found, return an empty string. - return "" } return template diff --git a/internal/client/client.go b/internal/client/client.go index 07fb6d9c..0f189060 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -214,61 +214,6 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo return nil } -// StreamAPIRequest handles making streaming requests to the CLI API endpoints. -func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body interface{}) (io.ReadCloser, *ErrorMessage) { - var jsonBody []byte - var err error - if byteBody, ok := body.([]byte); ok { - jsonBody = byteBody - } else { - jsonBody, err = json.Marshal(body) - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)} - } - } - // log.Debug(string(jsonBody)) - reqBody := bytes.NewBuffer(jsonBody) - - // Add alt=sse for streaming - url := fmt.Sprintf("%s/%s:%s?alt=sse", codeAssistEndpoint, apiVersion, endpoint) - - req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody) - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)} - } - - token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)} - } - - // Set headers - metadataStr := getClientMetadataString() - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", getUserAgent()) - req.Header.Set("Client-Metadata", metadataStr) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)} - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", err) - } - }() - bodyBytes, _ := io.ReadAll(resp.Body) - - return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))} - // return nil, fmt.Errorf("api streaming request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - return resp.Body, nil -} - // SendMessageStream handles a single conversational turn, including tool calls. func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) { dataTag := []byte("data: ") @@ -331,7 +276,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st // log.Debug(string(byteRequestBody)) - stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", byteRequestBody) + stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true) if err != nil { // log.Println(err) errChan <- err @@ -360,6 +305,129 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st return dataChan, errChan } +// APIRequest handles making requests to the CLI API endpoints. +func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) { + var jsonBody []byte + var err error + if byteBody, ok := body.([]byte); ok { + jsonBody = byteBody + } else { + jsonBody, err = json.Marshal(body) + if err != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)} + } + } + // log.Debug(string(jsonBody)) + reqBody := bytes.NewBuffer(jsonBody) + + // Add alt=sse for streaming + url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint) + if stream { + url = url + "?alt=sse" + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody) + if err != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)} + } + + token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() + if err != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)} + } + + // Set headers + metadataStr := getClientMetadataString() + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", getUserAgent()) + req.Header.Set("Client-Metadata", metadataStr) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)} + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer func() { + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } + }() + bodyBytes, _ := io.ReadAll(resp.Body) + + return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))} + } + + return resp.Body, nil +} + +// SendMessageStream handles a single conversational turn, including tool calls. +func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { + request := GenerateContentRequest{ + Contents: contents, + GenerationConfig: GenerationConfig{ + ThinkingConfig: GenerationConfigThinkingConfig{ + IncludeThoughts: true, + }, + }, + } + request.Tools = tools + + requestBody := map[string]interface{}{ + "project": c.tokenStorage.ProjectID, // Assuming ProjectID is available + "request": request, + "model": model, + } + + byteRequestBody, _ := json.Marshal(requestBody) + + // log.Debug(string(byteRequestBody)) + + reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") + if reasoningEffortResult.String() == "none" { + byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) + } else if reasoningEffortResult.String() == "auto" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + } else if reasoningEffortResult.String() == "low" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) + } else if reasoningEffortResult.String() == "medium" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) + } else if reasoningEffortResult.String() == "high" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576) + } else { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + } + + temperatureResult := gjson.GetBytes(rawJson, "temperature") + if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) + } + + topPResult := gjson.GetBytes(rawJson, "top_p") + if topPResult.Exists() && topPResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) + } + + topKResult := gjson.GetBytes(rawJson, "top_k") + if topKResult.Exists() && topKResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) + } + + // log.Debug(string(byteRequestBody)) + + respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false) + if err != nil { + return nil, err + } + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil +} + // CheckCloudAPIIsEnabled sends a simple test request to the API to verify // that the Cloud AI API is enabled for the user's project. It provides // an activation URL if the API is disabled. @@ -374,7 +442,7 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { // A simple request to test the API endpoint. requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID) - stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody)) + stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true) if err != nil { // If a 403 Forbidden error occurs, it likely means the API is not enabled. if err.StatusCode == 403 {