Add GeminiGetHandler, enhance Gemini functionality, and enable token counting

- Added `GeminiGetHandler` for handling GET requests with extended Gemini model support.
- Introduced `geminiCountTokens` function to calculate token usage.
- Refactored `APIRequest` and related methods to support `alt` parameter for enhanced flexibility.
- Updated routes and request processing to integrate new handler and functions.
This commit is contained in:
Luis Pater
2025-07-26 06:51:49 +08:00
parent 423faae3da
commit 31a9e2d11f
4 changed files with 245 additions and 47 deletions

View File

@@ -28,7 +28,7 @@ const (
apiVersion = "v1internal"
pluginVersion = "0.1.9"
glEndPoint = "https://generativelanguage.googleapis.com/"
glEndPoint = "https://generativelanguage.googleapis.com"
glApiVersion = "v1beta"
)
@@ -241,7 +241,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
}
// APIRequest handles making requests to the CLI API endpoints.
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte
var err error
if byteBody, ok := body.([]byte); ok {
@@ -257,21 +257,30 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
if c.glAPIKey == "" {
// Add alt=sse for streaming
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
if stream {
if alt == "" && stream {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", alt)
}
} else {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
if stream {
url = url + "?alt=sse"
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
if systemInstructionResult.Exists() {
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
if endpoint == "countTokens" {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
} else {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
if alt == "" && stream {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", alt)
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
if systemInstructionResult.Exists() {
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
}
}
}
@@ -392,7 +401,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
}
}
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -544,7 +553,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
// Attempt to establish a streaming connection with the API
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true)
if err != nil {
// Handle quota exceeded errors by marking the model and potentially retrying
if err.StatusCode == 429 {
@@ -593,8 +602,49 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan
}
// SendRawTokenCount handles a token count.
func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
modelResult := gjson.GetBytes(rawJson, "model")
model := modelResult.String()
modelName := model
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
return nil, err
}
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
}
// SendRawMessage handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) {
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
if c.glAPIKey == "" {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
}
@@ -618,7 +668,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E
}
}
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, false)
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -639,7 +689,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E
}
// SendRawMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-chan []byte, <-chan *ErrorMessage) {
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
@@ -672,7 +722,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, true)
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -688,21 +738,32 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
break
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
if alt == "" {
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
}
if errScanner := scanner.Err(); errScanner != nil {
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
if errScanner := scanner.Err(); errScanner != nil {
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
} else {
data, err := io.ReadAll(stream)
if err != nil {
errChan <- &ErrorMessage{500, err}
_ = stream.Close()
return
}
dataChan <- data
}
_ = stream.Close()
}()
return dataChan, errChan
@@ -754,7 +815,7 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
// A simple request to test the API endpoint.
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true)
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true)
if err != nil {
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
if err.StatusCode == 403 {