Refactor API request flow and streamline response handling

- Replaced `SendMessageStream` with synchronous `SendMessage` in API handlers for better manageability.
- Simplified `ConvertCliToOpenAINonStream` to reduce complexity and improve efficiency.
- Adjusted `client.go` functions to handle both streaming and non-streaming API requests more effectively.
- Improved error handling and channel communication in API handlers.
- Removed redundant and unused code for cleaner implementation.
This commit is contained in:
Luis Pater
2025-07-05 02:27:34 +08:00
parent bf086464dd
commit 512f2d5247
3 changed files with 199 additions and 123 deletions

View File

@@ -214,61 +214,6 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
return nil
}
// StreamAPIRequest handles making streaming requests to the CLI API endpoints.
func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body interface{}) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte
var err error
if byteBody, ok := body.([]byte); ok {
jsonBody = byteBody
} else {
jsonBody, err = json.Marshal(body)
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
}
}
// log.Debug(string(jsonBody))
reqBody := bytes.NewBuffer(jsonBody)
// Add alt=sse for streaming
url := fmt.Sprintf("%s/%s:%s?alt=sse", codeAssistEndpoint, apiVersion, endpoint)
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)}
}
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)}
}
// Set headers
metadataStr := getClientMetadataString()
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", getUserAgent())
req.Header.Set("Client-Metadata", metadataStr)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)}
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
// return nil, fmt.Errorf("api streaming request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
}
return resp.Body, nil
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
@@ -331,7 +276,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
// log.Debug(string(byteRequestBody))
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", byteRequestBody)
stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
if err != nil {
// log.Println(err)
errChan <- err
@@ -360,6 +305,129 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan
}
// APIRequest handles making requests to the CLI API endpoints.
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte
var err error
if byteBody, ok := body.([]byte); ok {
jsonBody = byteBody
} else {
jsonBody, err = json.Marshal(body)
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
}
}
// log.Debug(string(jsonBody))
reqBody := bytes.NewBuffer(jsonBody)
// Add alt=sse for streaming
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
if stream {
url = url + "?alt=sse"
}
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)}
}
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)}
}
// Set headers
metadataStr := getClientMetadataString()
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", getUserAgent())
req.Header.Set("Client-Metadata", metadataStr)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)}
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
}
return resp.Body, nil
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: true,
},
},
}
request.Tools = tools
requestBody := map[string]interface{}{
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
"request": request,
"model": model,
}
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
// log.Debug(string(byteRequestBody))
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
if err != nil {
return nil, err
}
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
// that the Cloud AI API is enabled for the user's project. It provides
// an activation URL if the API is disabled.
@@ -374,7 +442,7 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
// A simple request to test the API endpoint.
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody))
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true)
if err != nil {
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
if err.StatusCode == 403 {