From 589ae6d3aafe7fde8cb0b93b3175dd1201fdcb9b Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 6 Jul 2025 02:04:39 +0800 Subject: [PATCH] Add support for Generative Language API Key and improve client initialization - Added `GlAPIKey` support in configuration to enable Generative Language API. - Integrated `GenerativeLanguageAPIKey` handling in client and API handlers. - Updated response translators to manage generative language responses properly. - Enhanced HTTP client initialization logic with proxy support for API requests. - Refactored streaming and non-streaming flows to account for generative language-specific logic. --- README.md | 16 +++--- config.yaml | 7 ++- internal/api/handlers.go | 20 ++++++-- internal/api/translator/response.go | 19 +++++-- internal/client/client.go | 78 ++++++++++++++++++++--------- internal/cmd/run.go | 38 ++++++++++++++ internal/config/config.go | 2 + 7 files changed, 140 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 018d64bb..0006863a 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ A proxy server that provides an OpenAI-compatible API interface for CLI. This al - Multimodal input support (text and images) - Multiple account support with load balancing - Simple CLI authentication flow +- Support for Generative Language API Key ## Installation @@ -146,13 +147,14 @@ The server uses a YAML configuration file (`config.yaml`) located in the project ### Configuration Options -| Parameter | Type | Default | Description | -|-------------|----------|--------------------|----------------------------------------------------------------------------------------------| -| `port` | integer | 8317 | The port number on which the server will listen | -| `auth_dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory | -| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ | -| `debug` | boolean | false | Enable debug mode for verbose logging | -| `api_keys` | string[] | [] | List of API keys that can be used to authenticate requests | +| Parameter | Type | Default | Description | +|-------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------| +| `port` | integer | 8317 | The port number on which the server will listen | +| `auth_dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory | +| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ | +| `debug` | boolean | false | Enable debug mode for verbose logging | +| `api_keys` | string[] | [] | List of API keys that can be used to authenticate requests | +| `generative-language-api-key` | string[] | [] | List of Generative Language API keys | ### Example Configuration File diff --git a/config.yaml b/config.yaml index d5e7e16a..552276d4 100644 --- a/config.yaml +++ b/config.yaml @@ -7,4 +7,9 @@ quota-exceeded: switch-preview-model: true api-keys: - "12345" - - "23456" \ No newline at end of file + - "23456" +generative-language-api-key: + - "AIzaSy...01" + - "AIzaSy...02" + - "AIzaSy...03" + - "AIzaSy...04" diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 9c399d9b..edf8ddc7 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -258,7 +258,13 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) cliClient.RequestMutex.Lock() } - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + isGlAPIKey := false + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + isGlAPIKey = true + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools) if err != nil { @@ -272,7 +278,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) } break } else { - openAIFormat := translator.ConvertCliToOpenAINonStream(resp) + openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey) if openAIFormat != "" { _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) flusher.Flush() @@ -355,7 +361,13 @@ outLoop: cliClient.RequestMutex.Lock() } - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + isGlAPIKey := false + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + isGlAPIKey = true + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } // Send the message and receive response chunks and errors via channels. respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) hasFirstResponse := false @@ -379,7 +391,7 @@ outLoop: } else { // Convert the chunk to OpenAI format and send it to the client. hasFirstResponse = true - openAIFormat := translator.ConvertCliToOpenAI(chunk) + openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey) if openAIFormat != "" { _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) flusher.Flush() diff --git a/internal/api/translator/response.go b/internal/api/translator/response.go index 41e4fd01..2fce1679 100644 --- a/internal/api/translator/response.go +++ b/internal/api/translator/response.go @@ -10,7 +10,11 @@ import ( // ConvertCliToOpenAI translates a single chunk of a streaming response from the // backend client format to the OpenAI Server-Sent Events (SSE) format. // It returns an empty string if the chunk contains no useful data. -func ConvertCliToOpenAI(rawJson []byte) string { +func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string { + if isGlAPIKey { + rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson) + } + // Initialize the OpenAI SSE template. template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` @@ -22,11 +26,12 @@ func ConvertCliToOpenAI(rawJson []byte) string { // Extract and set the creation timestamp. if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - unixTimestamp := time.Now().Unix() if err == nil { unixTimestamp = t.Unix() } template, _ = sjson.Set(template, "created", unixTimestamp) + } else { + template, _ = sjson.Set(template, "created", unixTimestamp) } // Extract and set the response ID. @@ -90,19 +95,25 @@ func ConvertCliToOpenAI(rawJson []byte) string { // ConvertCliToOpenAINonStream aggregates response from the backend client // convert a single, non-streaming OpenAI-compatible JSON response. -func ConvertCliToOpenAINonStream(rawJson []byte) string { +func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string { + if isGlAPIKey { + rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson) + } template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { template, _ = sjson.Set(template, "model", modelVersionResult.String()) } + if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - unixTimestamp := time.Now().Unix() if err == nil { unixTimestamp = t.Unix() } template, _ = sjson.Set(template, "created", unixTimestamp) + } else { + template, _ = sjson.Set(template, "created", unixTimestamp) } + if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { template, _ = sjson.Set(template, "id", responseIdResult.String()) } diff --git a/internal/client/client.go b/internal/client/client.go index c99dfe55..12d4f8f6 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -27,6 +27,9 @@ const ( codeAssistEndpoint = "https://cloudcode-pa.googleapis.com" apiVersion = "v1internal" pluginVersion = "0.1.9" + + glEndPoint = "https://generativelanguage.googleapis.com/" + glApiVersion = "v1beta" ) var ( @@ -43,15 +46,21 @@ type Client struct { tokenStorage *auth.TokenStorage cfg *config.Config modelQuotaExceeded map[string]*time.Time + glAPIKey string } // NewClient creates a new CLI API client. -func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client { +func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config, glAPIKey ...string) *Client { + var glKey string + if len(glAPIKey) > 0 { + glKey = glAPIKey[0] + } return &Client{ httpClient: httpClient, tokenStorage: ts, cfg: cfg, modelQuotaExceeded: make(map[string]*time.Time), + glAPIKey: glKey, } } @@ -80,7 +89,14 @@ func (c *Client) GetEmail() string { } func (c *Client) GetProjectID() string { - return c.tokenStorage.ProjectID + if c.tokenStorage != nil { + return c.tokenStorage.ProjectID + } + return "" +} + +func (c *Client) GetGenerativeLanguageAPIKey() string { + return c.glAPIKey } // SetupUser performs the initial user onboarding and setup. @@ -235,35 +251,49 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)} } } + + var url string + if c.glAPIKey == "" { + // Add alt=sse for streaming + url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint) + if stream { + url = url + "?alt=sse" + } + } else { + modelResult := gjson.GetBytes(jsonBody, "model") + url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) + if stream { + url = url + "?alt=sse" + } + jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw) + } + // log.Debug(string(jsonBody)) reqBody := bytes.NewBuffer(jsonBody) - // Add alt=sse for streaming - url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint) - if stream { - url = url + "?alt=sse" - } - req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody) if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)} - } - - token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)} + return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)} } // Set headers metadataStr := getClientMetadataString() req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", getUserAgent()) - req.Header.Set("Client-Metadata", metadataStr) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) + if c.glAPIKey == "" { + token, errToken := c.httpClient.Transport.(*oauth2.Transport).Source.Token() + if errToken != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)} + } + req.Header.Set("User-Agent", getUserAgent()) + req.Header.Set("Client-Metadata", metadataStr) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) + } else { + req.Header.Set("x-goog-api-key", c.glAPIKey) + } resp, err := c.httpClient.Do(req) if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)} + return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)} } if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -293,7 +323,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, request.Tools = tools requestBody := map[string]interface{}{ - "project": c.tokenStorage.ProjectID, // Assuming ProjectID is available + "project": c.GetProjectID(), // Assuming ProjectID is available "request": request, "model": model, } @@ -337,7 +367,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, // log.Debug(string(byteRequestBody)) for { if c.isModelQuotaExceeded(modelName) { - if c.cfg.QuotaExceeded.SwitchPreviewModel { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { modelName = c.getPreviewModel(model) if modelName != "" { log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) @@ -356,7 +386,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now - if c.cfg.QuotaExceeded.SwitchPreviewModel { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { continue } } @@ -391,7 +421,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st request.Tools = tools requestBody := map[string]interface{}{ - "project": c.tokenStorage.ProjectID, // Assuming ProjectID is available + "project": c.GetProjectID(), // Assuming ProjectID is available "request": request, "model": model, } @@ -436,7 +466,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st var stream io.ReadCloser for { if c.isModelQuotaExceeded(modelName) { - if c.cfg.QuotaExceeded.SwitchPreviewModel { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { modelName = c.getPreviewModel(model) if modelName != "" { log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) @@ -456,7 +486,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now - if c.cfg.QuotaExceeded.SwitchPreviewModel { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { continue } } diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 12bfc032..64df3043 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -8,7 +8,11 @@ import ( "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" + "golang.org/x/net/proxy" "io/fs" + "net" + "net/http" + "net/url" "os" "os/signal" "path/filepath" @@ -64,6 +68,40 @@ func StartService(cfg *config.Config) { log.Fatalf("Error walking auth directory: %v", err) } + if len(cfg.GlAPIKey) > 0 { + var transport *http.Transport + proxyURL, errParse := url.Parse(cfg.ProxyUrl) + if errParse == nil { + if proxyURL.Scheme == "socks5" { + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + proxyAuth := &proxy.Auth{User: username, Password: password} + dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) + if errSOCKS5 != nil { + log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) + } + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + } + } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + // Handle HTTP/HTTPS proxy. + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} + } + } + + for i := 0; i < len(cfg.GlAPIKey); i++ { + httpClient := &http.Client{} + if transport != nil { + httpClient.Transport = transport + } + log.Debug("Initializing with Generative Language API key...") + cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i]) + cliClients = append(cliClients, cliClient) + } + } + // Create and start the API server with the pool of clients. apiServer := api.NewServer(cfg, cliClients) log.Infof("Starting API server on port %d", cfg.Port) diff --git a/internal/config/config.go b/internal/config/config.go index 4af1d9a8..534c565b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,8 @@ type Config struct { ApiKeys []string `yaml:"api-keys"` // QuotaExceeded defines the behavior when a quota is exceeded. QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"` + // GlAPIKey is the API key for the generative language API. + GlAPIKey []string `yaml:"generative-language-api-key"` } type ConfigQuotaExceeded struct {