mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 13:00:52 +08:00
feat: Add support for VertexAI compatible service (#375)
feat: consolidate Vertex AI compatibility with API key support in Gemini
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -15,6 +15,7 @@ pgstore/*
|
|||||||
gitstore/*
|
gitstore/*
|
||||||
objectstore/*
|
objectstore/*
|
||||||
static/*
|
static/*
|
||||||
|
refs/*
|
||||||
|
|
||||||
# Authentication data
|
# Authentication data
|
||||||
auths/*
|
auths/*
|
||||||
@@ -30,3 +31,7 @@ GEMINI.md
|
|||||||
.vscode/*
|
.vscode/*
|
||||||
.claude/*
|
.claude/*
|
||||||
.serena/*
|
.serena/*
|
||||||
|
|
||||||
|
# macOS
|
||||||
|
.DS_Store
|
||||||
|
._*
|
||||||
|
|||||||
@@ -934,6 +934,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
geminiAPIKeyCount := len(cfg.GeminiKey)
|
geminiAPIKeyCount := len(cfg.GeminiKey)
|
||||||
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
||||||
codexAPIKeyCount := len(cfg.CodexKey)
|
codexAPIKeyCount := len(cfg.CodexKey)
|
||||||
|
vertexAICompatCount := len(cfg.VertexCompatAPIKey)
|
||||||
openAICompatCount := 0
|
openAICompatCount := 0
|
||||||
for i := range cfg.OpenAICompatibility {
|
for i := range cfg.OpenAICompatibility {
|
||||||
entry := cfg.OpenAICompatibility[i]
|
entry := cfg.OpenAICompatibility[i]
|
||||||
@@ -944,13 +945,14 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
openAICompatCount += len(entry.APIKeys)
|
openAICompatCount += len(entry.APIKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)\n",
|
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||||
total,
|
total,
|
||||||
authFiles,
|
authFiles,
|
||||||
geminiAPIKeyCount,
|
geminiAPIKeyCount,
|
||||||
claudeAPIKeyCount,
|
claudeAPIKeyCount,
|
||||||
codexAPIKeyCount,
|
codexAPIKeyCount,
|
||||||
|
vertexAICompatCount,
|
||||||
openAICompatCount,
|
openAICompatCount,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,6 +70,10 @@ type Config struct {
|
|||||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||||
|
|
||||||
|
// VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers.
|
||||||
|
// Used for services that use Vertex AI-style paths but with simple API key authentication.
|
||||||
|
VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"`
|
||||||
|
|
||||||
// RequestRetry defines the retry times when the request failed.
|
// RequestRetry defines the retry times when the request failed.
|
||||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||||
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
||||||
@@ -343,6 +347,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||||
cfg.SanitizeGeminiKeys()
|
cfg.SanitizeGeminiKeys()
|
||||||
|
|
||||||
|
// Sanitize Vertex-compatible API keys: drop entries without base-url
|
||||||
|
cfg.SanitizeVertexCompatKeys()
|
||||||
|
|
||||||
// Sanitize Codex keys: drop entries without base-url
|
// Sanitize Codex keys: drop entries without base-url
|
||||||
cfg.SanitizeCodexKeys()
|
cfg.SanitizeCodexKeys()
|
||||||
|
|
||||||
@@ -831,6 +838,7 @@ func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
|
|||||||
switch key {
|
switch key {
|
||||||
case "generative-language-api-key",
|
case "generative-language-api-key",
|
||||||
"gemini-api-key",
|
"gemini-api-key",
|
||||||
|
"vertex-api-key",
|
||||||
"claude-api-key",
|
"claude-api-key",
|
||||||
"codex-api-key",
|
"codex-api-key",
|
||||||
"openai-compatibility":
|
"openai-compatibility":
|
||||||
|
|||||||
84
internal/config/vertex_compat.go
Normal file
84
internal/config/vertex_compat.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// VertexCompatKey represents the configuration for Vertex AI-compatible API keys.
|
||||||
|
// This supports third-party services that use Vertex AI-style endpoint paths
|
||||||
|
// (/publishers/google/models/{model}:streamGenerateContent) but authenticate
|
||||||
|
// with simple API keys instead of Google Cloud service account credentials.
|
||||||
|
//
|
||||||
|
// Example services: zenmux.ai and similar Vertex-compatible providers.
|
||||||
|
type VertexCompatKey struct {
|
||||||
|
// APIKey is the authentication key for accessing the Vertex-compatible API.
|
||||||
|
// Maps to the x-goog-api-key header.
|
||||||
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||||
|
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||||
|
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||||
|
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||||
|
|
||||||
|
// ProxyURL optionally overrides the global proxy for this API key.
|
||||||
|
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||||
|
|
||||||
|
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||||
|
// Commonly used for cookies, user-agent, and other authentication headers.
|
||||||
|
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||||
|
|
||||||
|
// Models defines the model configurations including aliases for routing.
|
||||||
|
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VertexCompatModel pairs an upstream model name with the alias exposed to
// clients for a Vertex-compatible provider.
type VertexCompatModel struct {
	// Name is the model name as the external provider knows it.
	Name string `yaml:"name" json:"name"`

	// Alias is the model name clients use to refer to this model.
	Alias string `yaml:"alias" json:"alias"`
}
|
||||||
|
|
||||||
|
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
|
||||||
|
func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, len(cfg.VertexCompatAPIKey))
|
||||||
|
out := cfg.VertexCompatAPIKey[:0]
|
||||||
|
for i := range cfg.VertexCompatAPIKey {
|
||||||
|
entry := cfg.VertexCompatAPIKey[i]
|
||||||
|
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||||
|
if entry.APIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||||
|
if entry.BaseURL == "" {
|
||||||
|
// BaseURL is required for vertex-compat keys
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
|
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||||
|
|
||||||
|
// Sanitize models: remove entries without valid alias
|
||||||
|
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
|
||||||
|
for _, model := range entry.Models {
|
||||||
|
model.Alias = strings.TrimSpace(model.Alias)
|
||||||
|
model.Name = strings.TrimSpace(model.Name)
|
||||||
|
if model.Alias != "" && model.Name != "" {
|
||||||
|
sanitizedModels = append(sanitizedModels, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entry.Models = sanitizedModels
|
||||||
|
|
||||||
|
// Use API key + base URL as uniqueness key
|
||||||
|
uniqueKey := entry.APIKey + "|" + entry.BaseURL
|
||||||
|
if _, exists := seen[uniqueKey]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[uniqueKey] = struct{}{}
|
||||||
|
out = append(out, entry)
|
||||||
|
}
|
||||||
|
cfg.VertexCompatAPIKey = out
|
||||||
|
}
|
||||||
@@ -44,6 +44,22 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
|||||||
// Identifier returns provider key for manager routing.
|
// Identifier returns provider key for manager routing.
|
||||||
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
||||||
|
|
||||||
|
// GeminiVertexCompatExecutor is a thin wrapper around GeminiVertexExecutor
|
||||||
|
// that provides the correct identifier for vertex-compat routing.
|
||||||
|
type GeminiVertexCompatExecutor struct {
|
||||||
|
*GeminiVertexExecutor
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGeminiVertexCompatExecutor constructs the Vertex-compatible executor.
|
||||||
|
func NewGeminiVertexCompatExecutor(cfg *config.Config) *GeminiVertexCompatExecutor {
|
||||||
|
return &GeminiVertexCompatExecutor{
|
||||||
|
GeminiVertexExecutor: NewGeminiVertexExecutor(cfg),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifier returns provider key for manager routing.
|
||||||
|
func (e *GeminiVertexCompatExecutor) Identifier() string { return "vertex-compat" }
|
||||||
|
|
||||||
// PrepareRequest is a no-op for Vertex.
|
// PrepareRequest is a no-op for Vertex.
|
||||||
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||||
return nil
|
return nil
|
||||||
@@ -51,11 +67,238 @@ func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.A
|
|||||||
|
|
||||||
// Execute handles non-streaming requests.
|
// Execute handles non-streaming requests.
|
||||||
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
// Try API key authentication first
|
||||||
if errCreds != nil {
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
return resp, errCreds
|
|
||||||
|
// If no API key found, fall back to service account authentication
|
||||||
|
if apiKey == "" {
|
||||||
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||||
|
if errCreds != nil {
|
||||||
|
return resp, errCreds
|
||||||
|
}
|
||||||
|
return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use API key authentication
|
||||||
|
return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteStream handles SSE streaming for Vertex.
|
||||||
|
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
// Try API key authentication first
|
||||||
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
|
|
||||||
|
// If no API key found, fall back to service account authentication
|
||||||
|
if apiKey == "" {
|
||||||
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||||
|
if errCreds != nil {
|
||||||
|
return nil, errCreds
|
||||||
|
}
|
||||||
|
return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use API key authentication
|
||||||
|
return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountTokens calls Vertex countTokens endpoint.
|
||||||
|
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
// Try API key authentication first
|
||||||
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
|
|
||||||
|
// If no API key found, fall back to service account authentication
|
||||||
|
if apiKey == "" {
|
||||||
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||||
|
if errCreds != nil {
|
||||||
|
return cliproxyexecutor.Response{}, errCreds
|
||||||
|
}
|
||||||
|
return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use API key authentication
|
||||||
|
return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// countTokensWithServiceAccount handles token counting using service account credentials.
|
||||||
|
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||||
|
if budgetOverride != nil {
|
||||||
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||||
|
budgetOverride = &norm
|
||||||
|
}
|
||||||
|
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||||
|
}
|
||||||
|
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||||
|
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||||
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||||
|
|
||||||
|
baseURL := vertexBaseURL(location)
|
||||||
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
||||||
|
|
||||||
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||||
|
if errNewReq != nil {
|
||||||
|
return cliproxyexecutor.Response{}, errNewReq
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
} else if errTok != nil {
|
||||||
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||||
|
}
|
||||||
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: translatedReq,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
|
return cliproxyexecutor.Response{}, errDo
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
}
|
||||||
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return cliproxyexecutor.Response{}, errRead
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
}
|
||||||
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||||
|
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||||
|
if budgetOverride != nil {
|
||||||
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||||
|
budgetOverride = &norm
|
||||||
|
}
|
||||||
|
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||||
|
}
|
||||||
|
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||||
|
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||||
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||||
|
|
||||||
|
// For API key auth, use simpler URL format without project/location
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
||||||
|
|
||||||
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||||
|
if errNewReq != nil {
|
||||||
|
return cliproxyexecutor.Response{}, errNewReq
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
if apiKey != "" {
|
||||||
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||||
|
}
|
||||||
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: translatedReq,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
|
return cliproxyexecutor.Response{}, errDo
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
}
|
||||||
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return cliproxyexecutor.Response{}, errRead
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
}
|
||||||
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh is a no-op for service account based credentials.
|
||||||
|
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeWithServiceAccount handles authentication using service account credentials.
|
||||||
|
// This method contains the original service account authentication logic.
|
||||||
|
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
@@ -149,13 +392,105 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream handles SSE streaming for Vertex.
|
// executeWithAPIKey handles authentication using API key credentials.
|
||||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
// This method follows the vertex-compat pattern for API key authentication.
|
||||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
||||||
if errCreds != nil {
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
return nil, errCreds
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||||
|
if budgetOverride != nil {
|
||||||
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||||
|
budgetOverride = &norm
|
||||||
|
}
|
||||||
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
|
}
|
||||||
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
|
|
||||||
|
action := "generateContent"
|
||||||
|
if req.Metadata != nil {
|
||||||
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
|
action = "countTokens"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For API key auth, use simpler URL format without project/location
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, action)
|
||||||
|
if opts.Alt != "" && action != "countTokens" {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if errNewReq != nil {
|
||||||
|
return resp, errNewReq
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
if apiKey != "" {
|
||||||
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||||
|
}
|
||||||
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
|
return resp, errDo
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return resp, errRead
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
|
var param any
|
||||||
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||||
|
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
@@ -266,42 +601,44 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens calls Vertex countTokens endpoint.
|
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||||
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
if errCreds != nil {
|
defer reporter.trackFailure(ctx, &err)
|
||||||
return cliproxyexecutor.Response{}, errCreds
|
|
||||||
}
|
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||||
if budgetOverride != nil {
|
if budgetOverride != nil {
|
||||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||||
budgetOverride = &norm
|
budgetOverride = &norm
|
||||||
}
|
}
|
||||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||||
}
|
}
|
||||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
|
||||||
|
|
||||||
baseURL := vertexBaseURL(location)
|
// For API key auth, use simpler URL format without project/location
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
if baseURL == "" {
|
||||||
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "streamGenerateContent")
|
||||||
|
if opts.Alt == "" {
|
||||||
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
if errNewReq != nil {
|
if errNewReq != nil {
|
||||||
return cliproxyexecutor.Response{}, errNewReq
|
return nil, errNewReq
|
||||||
}
|
}
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
if apiKey != "" {
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||||
} else if errTok != nil {
|
|
||||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
|
||||||
}
|
}
|
||||||
applyGeminiHeaders(httpReq, auth)
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
|
||||||
@@ -315,7 +652,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
|||||||
URL: url,
|
URL: url,
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
Body: translatedReq,
|
Body: body,
|
||||||
Provider: e.Identifier(),
|
Provider: e.Identifier(),
|
||||||
AuthID: authID,
|
AuthID: authID,
|
||||||
AuthLabel: authLabel,
|
AuthLabel: authLabel,
|
||||||
@@ -327,38 +664,53 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
|||||||
httpResp, errDo := httpClient.Do(httpReq)
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
return cliproxyexecutor.Response{}, errDo
|
return nil, errDo
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
return cliproxyexecutor.Response{}, errRead
|
|
||||||
}
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
||||||
}
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh is a no-op for service account based credentials.
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
stream = out
|
||||||
return auth, nil
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(httpResp.Body)
|
||||||
|
scanner.Buffer(nil, 20_971_520)
|
||||||
|
var param any
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
|
||||||
|
for i := range lines {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), &param)
|
||||||
|
for i := range lines {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||||
@@ -401,6 +753,23 @@ func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccou
|
|||||||
return projectID, location, saJSON, nil
|
return projectID, location, saJSON, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern.
|
||||||
|
func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||||
|
if a == nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if a.Attributes != nil {
|
||||||
|
apiKey = a.Attributes["api_key"]
|
||||||
|
baseURL = a.Attributes["base_url"]
|
||||||
|
}
|
||||||
|
if apiKey == "" && a.Metadata != nil {
|
||||||
|
if v, ok := a.Metadata["access_token"].(string); ok {
|
||||||
|
apiKey = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func vertexBaseURL(location string) string {
|
func vertexBaseURL(location string) string {
|
||||||
loc := strings.TrimSpace(location)
|
loc := strings.TrimSpace(location)
|
||||||
if loc == "" {
|
if loc == "" {
|
||||||
|
|||||||
@@ -498,6 +498,18 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str
|
|||||||
return hex.EncodeToString(sum[:])
|
return hex.EncodeToString(sum[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func computeVertexCompatModelsHash(models []config.VertexCompatModel) string {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(models)
|
||||||
|
if err != nil || len(data) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
// computeClaudeModelsHash returns a stable hash for Claude model aliases.
|
// computeClaudeModelsHash returns a stable hash for Claude model aliases.
|
||||||
func computeClaudeModelsHash(models []config.ClaudeModel) string {
|
func computeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||||
if len(models) == 0 {
|
if len(models) == 0 {
|
||||||
@@ -920,8 +932,8 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
|||||||
// no legacy clients to unregister
|
// no legacy clients to unregister
|
||||||
|
|
||||||
// Create new API key clients based on the new config
|
// Create new API key clients based on the new config
|
||||||
geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
|
geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
|
||||||
totalAPIKeyClients := geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||||
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
|
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
|
||||||
|
|
||||||
var authFileCount int
|
var authFileCount int
|
||||||
@@ -964,7 +976,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
|||||||
w.clientsMutex.Unlock()
|
w.clientsMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
totalNewClients := authFileCount + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||||
|
|
||||||
// Ensure consumers observe the new configuration before auth updates dispatch.
|
// Ensure consumers observe the new configuration before auth updates dispatch.
|
||||||
if w.reloadCallback != nil {
|
if w.reloadCallback != nil {
|
||||||
@@ -974,10 +986,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
|||||||
|
|
||||||
w.refreshAuthState()
|
w.refreshAuthState()
|
||||||
|
|
||||||
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
|
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex-compat keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
|
||||||
totalNewClients,
|
totalNewClients,
|
||||||
authFileCount,
|
authFileCount,
|
||||||
geminiAPIKeyCount,
|
geminiAPIKeyCount,
|
||||||
|
vertexCompatAPIKeyCount,
|
||||||
claudeAPIKeyCount,
|
claudeAPIKeyCount,
|
||||||
codexAPIKeyCount,
|
codexAPIKeyCount,
|
||||||
openAICompatCount,
|
openAICompatCount,
|
||||||
@@ -1092,6 +1105,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey")
|
applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey")
|
||||||
out = append(out, a)
|
out = append(out, a)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Claude API keys -> synthesize auths
|
// Claude API keys -> synthesize auths
|
||||||
for i := range cfg.ClaudeKey {
|
for i := range cfg.ClaudeKey {
|
||||||
ck := cfg.ClaudeKey[i]
|
ck := cfg.ClaudeKey[i]
|
||||||
@@ -1258,6 +1272,42 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process Vertex compatibility providers
|
||||||
|
for i := range cfg.VertexCompatAPIKey {
|
||||||
|
compat := &cfg.VertexCompatAPIKey[i]
|
||||||
|
providerName := "vertex-compat"
|
||||||
|
base := strings.TrimSpace(compat.BaseURL)
|
||||||
|
|
||||||
|
key := strings.TrimSpace(compat.APIKey)
|
||||||
|
proxyURL := strings.TrimSpace(compat.ProxyURL)
|
||||||
|
idKind := fmt.Sprintf("vertex-compatibility:%s", base)
|
||||||
|
id, token := idGen.next(idKind, key, base, proxyURL)
|
||||||
|
attrs := map[string]string{
|
||||||
|
"source": fmt.Sprintf("config:vertex-compatibility[%s]", token),
|
||||||
|
"base_url": base,
|
||||||
|
"provider_key": providerName,
|
||||||
|
}
|
||||||
|
if key != "" {
|
||||||
|
attrs["api_key"] = key
|
||||||
|
}
|
||||||
|
if hash := computeVertexCompatModelsHash(compat.Models); hash != "" {
|
||||||
|
attrs["models_hash"] = hash
|
||||||
|
}
|
||||||
|
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||||
|
a := &coreauth.Auth{
|
||||||
|
ID: id,
|
||||||
|
Provider: providerName,
|
||||||
|
Label: "Vertex Compatibility",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Attributes: attrs,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
|
|
||||||
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
|
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
|
||||||
entries, _ := os.ReadDir(w.authDir)
|
entries, _ := os.ReadDir(w.authDir)
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
@@ -1474,8 +1524,9 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
|||||||
return authFileCount
|
return authFileCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
|
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
|
||||||
geminiAPIKeyCount := 0
|
geminiAPIKeyCount := 0
|
||||||
|
vertexCompatAPIKeyCount := 0
|
||||||
claudeAPIKeyCount := 0
|
claudeAPIKeyCount := 0
|
||||||
codexAPIKeyCount := 0
|
codexAPIKeyCount := 0
|
||||||
openAICompatCount := 0
|
openAICompatCount := 0
|
||||||
@@ -1484,6 +1535,9 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
|
|||||||
// Stateless executor handles Gemini API keys; avoid constructing legacy clients.
|
// Stateless executor handles Gemini API keys; avoid constructing legacy clients.
|
||||||
geminiAPIKeyCount += len(cfg.GeminiKey)
|
geminiAPIKeyCount += len(cfg.GeminiKey)
|
||||||
}
|
}
|
||||||
|
if len(cfg.VertexCompatAPIKey) > 0 {
|
||||||
|
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
|
||||||
|
}
|
||||||
if len(cfg.ClaudeKey) > 0 {
|
if len(cfg.ClaudeKey) > 0 {
|
||||||
claudeAPIKeyCount += len(cfg.ClaudeKey)
|
claudeAPIKeyCount += len(cfg.ClaudeKey)
|
||||||
}
|
}
|
||||||
@@ -1501,7 +1555,7 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func NewAPIKeyClientProvider() APIKeyClientProvider {
|
|||||||
type apiKeyClientProvider struct{}
|
type apiKeyClientProvider struct{}
|
||||||
|
|
||||||
func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) {
|
func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) {
|
||||||
geminiCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg)
|
geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg)
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -38,9 +38,10 @@ func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*A
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &APIKeyClientResult{
|
return &APIKeyClientResult{
|
||||||
GeminiKeyCount: geminiCount,
|
GeminiKeyCount: geminiCount,
|
||||||
ClaudeKeyCount: claudeCount,
|
VertexCompatKeyCount: vertexCompatCount,
|
||||||
CodexKeyCount: codexCount,
|
ClaudeKeyCount: claudeCount,
|
||||||
OpenAICompatCount: openAICompat,
|
CodexKeyCount: codexCount,
|
||||||
|
OpenAICompatCount: openAICompat,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -324,7 +324,7 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
|
|||||||
if len(a.Attributes) > 0 {
|
if len(a.Attributes) > 0 {
|
||||||
providerKey = strings.TrimSpace(a.Attributes["provider_key"])
|
providerKey = strings.TrimSpace(a.Attributes["provider_key"])
|
||||||
compatName = strings.TrimSpace(a.Attributes["compat_name"])
|
compatName = strings.TrimSpace(a.Attributes["compat_name"])
|
||||||
if providerKey != "" || compatName != "" {
|
if compatName != "" {
|
||||||
if providerKey == "" {
|
if providerKey == "" {
|
||||||
providerKey = compatName
|
providerKey = compatName
|
||||||
}
|
}
|
||||||
@@ -362,6 +362,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
|||||||
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
|
||||||
case "vertex":
|
case "vertex":
|
||||||
s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg))
|
||||||
|
case "vertex-compat":
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewGeminiVertexCompatExecutor(s.cfg))
|
||||||
case "gemini-cli":
|
case "gemini-cli":
|
||||||
s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
|
||||||
case "aistudio":
|
case "aistudio":
|
||||||
@@ -498,7 +500,7 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
fmt.Println("API server started successfully")
|
fmt.Printf("API server started successfully on: %d\n", s.cfg.Port)
|
||||||
|
|
||||||
if s.hooks.OnAfterStart != nil {
|
if s.hooks.OnAfterStart != nil {
|
||||||
s.hooks.OnAfterStart(s)
|
s.hooks.OnAfterStart(s)
|
||||||
@@ -680,6 +682,35 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
// Vertex AI Gemini supports the same model identifiers as Gemini.
|
// Vertex AI Gemini supports the same model identifiers as Gemini.
|
||||||
models = registry.GetGeminiVertexModels()
|
models = registry.GetGeminiVertexModels()
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
|
case "vertex-compat":
|
||||||
|
// Handle Vertex AI compatibility providers with custom model definitions
|
||||||
|
if s.cfg != nil && len(s.cfg.VertexCompatAPIKey) > 0 {
|
||||||
|
// Create models for all Vertex compatibility providers
|
||||||
|
allModels := make([]*ModelInfo, 0)
|
||||||
|
for i := range s.cfg.VertexCompatAPIKey {
|
||||||
|
compat := &s.cfg.VertexCompatAPIKey[i]
|
||||||
|
for j := range compat.Models {
|
||||||
|
m := compat.Models[j]
|
||||||
|
// Use alias as model ID, fallback to name if alias is empty
|
||||||
|
modelID := m.Alias
|
||||||
|
if modelID == "" {
|
||||||
|
modelID = m.Name
|
||||||
|
}
|
||||||
|
if modelID != "" {
|
||||||
|
allModels = append(allModels, &ModelInfo{
|
||||||
|
ID: modelID,
|
||||||
|
Object: "model",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
OwnedBy: "vertex-compat",
|
||||||
|
Type: "vertex-compat",
|
||||||
|
DisplayName: m.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
models = allModels
|
||||||
|
}
|
||||||
|
|
||||||
case "gemini-cli":
|
case "gemini-cli":
|
||||||
models = registry.GetGeminiCLIModels()
|
models = registry.GetGeminiCLIModels()
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
|
|||||||
@@ -49,19 +49,21 @@ type APIKeyClientProvider interface {
|
|||||||
Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error)
|
Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyClientResult contains API key based clients along with type counts.
|
// APIKeyClientResult is returned by APIKeyClientProvider.Load()
|
||||||
// It provides metadata about the number of clients loaded for each provider type.
|
|
||||||
type APIKeyClientResult struct {
|
type APIKeyClientResult struct {
|
||||||
// GeminiKeyCount is the number of Gemini API key clients loaded.
|
// GeminiKeyCount is the number of Gemini API keys loaded
|
||||||
GeminiKeyCount int
|
GeminiKeyCount int
|
||||||
|
|
||||||
// ClaudeKeyCount is the number of Claude API key clients loaded.
|
// VertexCompatKeyCount is the number of Vertex-compatible API keys loaded
|
||||||
|
VertexCompatKeyCount int
|
||||||
|
|
||||||
|
// ClaudeKeyCount is the number of Claude API keys loaded
|
||||||
ClaudeKeyCount int
|
ClaudeKeyCount int
|
||||||
|
|
||||||
// CodexKeyCount is the number of Codex API key clients loaded.
|
// CodexKeyCount is the number of Codex API keys loaded
|
||||||
CodexKeyCount int
|
CodexKeyCount int
|
||||||
|
|
||||||
// OpenAICompatCount is the number of OpenAI-compatible API key clients loaded.
|
// OpenAICompatCount is the number of OpenAI compatibility API keys loaded
|
||||||
OpenAICompatCount int
|
OpenAICompatCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user