mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-12 17:30:51 +08:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ebabf5152 | ||
|
|
d7564173dd | ||
|
|
c44c46dd80 | ||
|
|
412148af0e | ||
|
|
d28258501a | ||
|
|
55cd31fb96 | ||
|
|
c5df8e7897 | ||
|
|
d4d529833d | ||
|
|
caa48e7c6f | ||
|
|
acdfb3bceb | ||
|
|
89d68962b1 | ||
|
|
361443db10 | ||
|
|
d6352dd4d4 | ||
|
|
a7eeb06f3d | ||
|
|
9426be7a5c | ||
|
|
4a135f1986 | ||
|
|
c4c02f4ad0 | ||
|
|
b87b9b455f | ||
|
|
db03ae9663 | ||
|
|
969ff6bb68 |
@@ -1472,6 +1472,17 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
projectID := ""
|
||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
||||
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
||||
if errProject != nil {
|
||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||
} else {
|
||||
projectID = fetchedProjectID
|
||||
log.Infof("antigravity: obtained project ID %s", projectID)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
metadata := map[string]any{
|
||||
"type": "antigravity",
|
||||
@@ -1484,6 +1495,9 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
if email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
if projectID != "" {
|
||||
metadata["project_id"] = projectID
|
||||
}
|
||||
|
||||
fileName := sanitizeAntigravityFileName(email)
|
||||
label := strings.TrimSpace(email)
|
||||
@@ -1507,6 +1521,9 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
delete(oauthStatus, state)
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if projectID != "" {
|
||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||
}
|
||||
fmt.Println("You can now use Antigravity services through this CLI")
|
||||
}()
|
||||
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest"
|
||||
latestReleaseUserAgent = "CLIProxyAPI"
|
||||
)
|
||||
|
||||
func (h *Handler) GetConfig(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(200, gin.H{})
|
||||
@@ -20,6 +32,66 @@ func (h *Handler) GetConfig(c *gin.Context) {
|
||||
c.JSON(200, &cfgCopy)
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
TagName string `json:"tag_name"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// GetLatestVersion returns the latest release version from GitHub without downloading assets.
|
||||
func (h *Handler) GetLatestVersion(c *gin.Context) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
proxyURL := ""
|
||||
if h != nil && h.cfg != nil {
|
||||
proxyURL = strings.TrimSpace(h.cfg.ProxyURL)
|
||||
}
|
||||
if proxyURL != "" {
|
||||
sdkCfg := &sdkconfig.SDKConfig{ProxyURL: proxyURL}
|
||||
util.SetProxy(sdkCfg, client)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, latestReleaseURL, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "request_create_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("User-Agent", latestReleaseUserAgent)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "request_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.WithError(errClose).Debug("failed to close latest version response body")
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "unexpected_status", "message": fmt.Sprintf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))})
|
||||
return
|
||||
}
|
||||
|
||||
var info releaseInfo
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "decode_failed", "message": errDecode.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
version := strings.TrimSpace(info.TagName)
|
||||
if version == "" {
|
||||
version = strings.TrimSpace(info.Name)
|
||||
}
|
||||
if version == "" {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "invalid_response", "message": "missing release version"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"latest-version": version})
|
||||
}
|
||||
|
||||
func WriteConfig(path string, data []byte) error {
|
||||
data = config.NormalizeCommentIndentation(data)
|
||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
|
||||
@@ -112,5 +112,10 @@ func shouldLogRequest(path string) bool {
|
||||
if strings.HasPrefix(path, "/v0/management") || strings.HasPrefix(path, "/management") {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(path, "/api") {
|
||||
return strings.HasPrefix(path, "/api/provider")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -27,11 +27,20 @@ type Option func(*AmpModule)
|
||||
type AmpModule struct {
|
||||
secretSource SecretSource
|
||||
proxy *httputil.ReverseProxy
|
||||
proxyMu sync.RWMutex // protects proxy for hot-reload
|
||||
accessManager *sdkaccess.Manager
|
||||
authMiddleware_ gin.HandlerFunc
|
||||
modelMapper *DefaultModelMapper
|
||||
enabled bool
|
||||
registerOnce sync.Once
|
||||
|
||||
// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
|
||||
restrictToLocalhost bool
|
||||
restrictMu sync.RWMutex
|
||||
|
||||
// configMu protects lastConfig for partial reload comparison
|
||||
configMu sync.RWMutex
|
||||
lastConfig *config.AmpCode
|
||||
}
|
||||
|
||||
// New creates a new Amp routing module with the given options.
|
||||
@@ -107,9 +116,19 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
m.lastConfig = &settingsCopy
|
||||
|
||||
// Initialize localhost restriction setting (hot-reloadable)
|
||||
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||
|
||||
// Always register provider aliases - these work without an upstream
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
log.Debug("amp upstream proxy disabled (no upstream URL configured)")
|
||||
@@ -118,28 +137,11 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
return
|
||||
}
|
||||
|
||||
// Create secret source with precedence: config > env > file
|
||||
// Cache secrets for 5 minutes to reduce file I/O
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
}
|
||||
|
||||
// Create reverse proxy with gzip handling via ModifyResponse
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
|
||||
regErr = fmt.Errorf("failed to create amp proxy: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.proxy = proxy
|
||||
m.enabled = true
|
||||
|
||||
// Register management proxy routes (requires upstream)
|
||||
// Restrict to localhost by default for security (prevents drive-by browser attacks)
|
||||
handler := proxyHandler(proxy)
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost)
|
||||
|
||||
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
||||
log.Debug("amp provider alias routes registered")
|
||||
})
|
||||
|
||||
@@ -162,44 +164,169 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// OnConfigUpdated handles configuration updates.
|
||||
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
|
||||
// OnConfigUpdated handles configuration updates with partial reload support.
|
||||
// Only updates components that have actually changed to avoid unnecessary work.
|
||||
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
|
||||
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
settings := cfg.AmpCode
|
||||
newSettings := cfg.AmpCode
|
||||
|
||||
// Update model mappings (hot-reload supported)
|
||||
if m.modelMapper != nil {
|
||||
m.modelMapper.UpdateMappings(settings.ModelMappings)
|
||||
if m.enabled {
|
||||
log.Infof("amp config updated: reloading %d model mapping(s)", len(settings.ModelMappings))
|
||||
}
|
||||
} else if m.enabled {
|
||||
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
||||
}
|
||||
// Get previous config for comparison
|
||||
m.configMu.RLock()
|
||||
oldSettings := m.lastConfig
|
||||
m.configMu.RUnlock()
|
||||
|
||||
if !m.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreamURL := strings.TrimSpace(settings.UpstreamURL)
|
||||
if upstreamURL == "" {
|
||||
log.Warn("amp upstream URL removed from config, restart required to disable")
|
||||
return nil
|
||||
}
|
||||
|
||||
// If API key changed, invalidate the cache
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.InvalidateCache()
|
||||
log.Debug("amp secret cache invalidated due to config update")
|
||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||
if !newSettings.RestrictManagementToLocalhost {
|
||||
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("amp config updated (restart required for URL changes)")
|
||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||
oldUpstreamURL := ""
|
||||
if oldSettings != nil {
|
||||
oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
|
||||
}
|
||||
|
||||
if !m.enabled && newUpstreamURL != "" {
|
||||
if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
|
||||
log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check model mappings change
|
||||
modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
|
||||
if modelMappingsChanged {
|
||||
if m.modelMapper != nil {
|
||||
m.modelMapper.UpdateMappings(newSettings.ModelMappings)
|
||||
} else if m.enabled {
|
||||
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
||||
}
|
||||
}
|
||||
|
||||
if m.enabled {
|
||||
// Check upstream URL change - now supports hot-reload
|
||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||
m.setProxy(nil)
|
||||
m.enabled = false
|
||||
} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
|
||||
// Recreate proxy with new URL
|
||||
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
|
||||
} else {
|
||||
m.setProxy(proxy)
|
||||
}
|
||||
}
|
||||
|
||||
// Check API key change
|
||||
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
||||
if apiKeyChanged {
|
||||
if m.secretSource != nil {
|
||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Store current config for next comparison
|
||||
m.configMu.Lock()
|
||||
settingsCopy := newSettings // copy struct
|
||||
m.lastConfig = &settingsCopy
|
||||
m.configMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||
if m.secretSource == nil {
|
||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||
ms.InvalidateCache()
|
||||
}
|
||||
|
||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.setProxy(proxy)
|
||||
m.enabled = true
|
||||
|
||||
log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasModelMappingsChanged compares old and new model mappings.
|
||||
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
if old == nil {
|
||||
return len(new.ModelMappings) > 0
|
||||
}
|
||||
|
||||
if len(old.ModelMappings) != len(new.ModelMappings) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for efficient comparison
|
||||
oldMap := make(map[string]string, len(old.ModelMappings))
|
||||
for _, mapping := range old.ModelMappings {
|
||||
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
|
||||
}
|
||||
|
||||
for _, mapping := range new.ModelMappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// hasAPIKeyChanged compares old and new API keys.
|
||||
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||
oldKey := ""
|
||||
if old != nil {
|
||||
oldKey = strings.TrimSpace(old.UpstreamAPIKey)
|
||||
}
|
||||
newKey := strings.TrimSpace(new.UpstreamAPIKey)
|
||||
return oldKey != newKey
|
||||
}
|
||||
|
||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||
return m.modelMapper
|
||||
}
|
||||
|
||||
// getProxy returns the current proxy instance (thread-safe for hot-reload).
|
||||
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
|
||||
m.proxyMu.RLock()
|
||||
defer m.proxyMu.RUnlock()
|
||||
return m.proxy
|
||||
}
|
||||
|
||||
// setProxy updates the proxy instance (thread-safe for hot-reload).
|
||||
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
|
||||
m.proxyMu.Lock()
|
||||
defer m.proxyMu.Unlock()
|
||||
m.proxy = proxy
|
||||
}
|
||||
|
||||
// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
|
||||
func (m *AmpModule) IsRestrictedToLocalhost() bool {
|
||||
m.restrictMu.RLock()
|
||||
defer m.restrictMu.RUnlock()
|
||||
return m.restrictToLocalhost
|
||||
}
|
||||
|
||||
// setRestrictToLocalhost updates the localhost restriction setting.
|
||||
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
|
||||
m.restrictMu.Lock()
|
||||
defer m.restrictMu.Unlock()
|
||||
m.restrictToLocalhost = restrict
|
||||
}
|
||||
|
||||
@@ -48,25 +48,25 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
case RouteTypeLocalProvider:
|
||||
fields["cost"] = "free"
|
||||
fields["source"] = "local_oauth"
|
||||
log.WithFields(fields).Infof("[amp] using local provider for model: %s", requestedModel)
|
||||
log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
|
||||
|
||||
case RouteTypeModelMapping:
|
||||
fields["cost"] = "free"
|
||||
fields["source"] = "local_oauth"
|
||||
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
||||
log.WithFields(fields).Infof("[amp] model mapped: %s -> %s", requestedModel, resolvedModel)
|
||||
// model mapping already logged in mapper; avoid duplicate here
|
||||
|
||||
case RouteTypeAmpCredits:
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("[amp] forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
fields["source"] = "error"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("[amp] no provider available for model_id: %s", requestedModel)
|
||||
log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -152,9 +152,9 @@ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
|
||||
mapper := NewModelMapper(nil)
|
||||
|
||||
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||
{From: "", To: "model-b"}, // Invalid: empty from
|
||||
{From: "model-a", To: ""}, // Invalid: empty to
|
||||
{From: " ", To: "model-b"}, // Invalid: whitespace from
|
||||
{From: "", To: "model-b"}, // Invalid: empty from
|
||||
{From: "model-a", To: ""}, // Invalid: empty to
|
||||
{From: " ", To: "model-b"}, // Invalid: whitespace from
|
||||
{From: "model-c", To: "model-d"}, // Valid
|
||||
})
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
@@ -14,15 +17,16 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only.
|
||||
// Returns 403 Forbidden for non-localhost clients.
|
||||
//
|
||||
// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent
|
||||
// header spoofing attacks via X-Forwarded-For or similar headers. This means the
|
||||
// middleware will not work correctly behind reverse proxies - users deploying behind
|
||||
// nginx/Cloudflare should disable this feature and use firewall rules instead.
|
||||
func localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Check current setting (hot-reloadable)
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Use actual TCP connection address (RemoteAddr) to prevent header spoofing
|
||||
// This cannot be forged by X-Forwarded-For or other client-controlled headers
|
||||
remoteAddr := c.Request.RemoteAddr
|
||||
@@ -77,23 +81,58 @@ func noCORSMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// managementAvailabilityMiddleware short-circuits management routes when the upstream
|
||||
// proxy is disabled, preventing noisy localhost warnings and accidental exposure.
|
||||
func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if m.getProxy() == nil {
|
||||
logging.SkipGinRequestLogging(c)
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": "amp upstream proxy not available",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) {
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
ampAPI.Use(noCORSMiddleware())
|
||||
ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware())
|
||||
|
||||
// Apply localhost-only restriction if configured
|
||||
if restrictToLocalhost {
|
||||
ampAPI.Use(localhostOnlyMiddleware())
|
||||
log.Info("amp management routes restricted to localhost only (CORS disabled)")
|
||||
} else {
|
||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
}
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
// Upstream already wrote the status (often 404) before the client/stream ended.
|
||||
return
|
||||
}
|
||||
panic(rec)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy := m.getProxy()
|
||||
if proxy == nil {
|
||||
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||
return
|
||||
}
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
// Management routes - these are proxied directly to Amp upstream
|
||||
ampAPI.Any("/internal", proxyHandler)
|
||||
ampAPI.Any("/internal/*path", proxyHandler)
|
||||
@@ -114,11 +153,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
ampAPI.Any("/tab/*path", proxyHandler)
|
||||
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes
|
||||
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()}
|
||||
if restrictToLocalhost {
|
||||
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
|
||||
}
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
// Root-level auth routes for CLI login flow
|
||||
@@ -134,7 +171,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
||||
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
return m.proxy
|
||||
return m.getProxy()
|
||||
})
|
||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||
|
||||
@@ -177,10 +214,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||
|
||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||
// Uses lazy evaluation to access proxy (which is created after routes are registered)
|
||||
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
||||
// Also includes model mapping support for routing unavailable models to alternatives
|
||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.proxy
|
||||
return m.getProxy()
|
||||
}, m.modelMapper)
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
|
||||
@@ -13,16 +13,26 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Spy to track if proxy handler was called
|
||||
proxyCalled := false
|
||||
proxyHandler := func(c *gin.Context) {
|
||||
proxyCalled = true
|
||||
c.String(200, "proxied")
|
||||
// Create module with proxy for testing
|
||||
m := &AmpModule{
|
||||
restrictToLocalhost: false, // disable localhost restriction for tests
|
||||
}
|
||||
|
||||
m := &AmpModule{}
|
||||
// Create a mock proxy that tracks calls
|
||||
proxyCalled := false
|
||||
mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxyCalled = true
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("proxied"))
|
||||
}))
|
||||
defer mockProxy.Close()
|
||||
|
||||
// Create real proxy to mock server
|
||||
proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests
|
||||
m.registerManagementRoutes(r, base)
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -37,13 +47,14 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
{"/api/meta", http.MethodGet},
|
||||
{"/api/telemetry", http.MethodGet},
|
||||
{"/api/threads", http.MethodGet},
|
||||
{"/threads/", http.MethodGet},
|
||||
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
|
||||
{"/api/otel", http.MethodGet},
|
||||
{"/api/tab", http.MethodGet},
|
||||
{"/api/tab/some/path", http.MethodGet},
|
||||
{"/auth", http.MethodGet}, // Root-level auth route
|
||||
{"/auth/cli-login", http.MethodGet}, // CLI login flow
|
||||
{"/auth/callback", http.MethodGet}, // OAuth callback
|
||||
{"/auth", http.MethodGet}, // Root-level auth route
|
||||
{"/auth/cli-login", http.MethodGet}, // CLI login flow
|
||||
{"/auth/callback", http.MethodGet}, // OAuth callback
|
||||
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
|
||||
{"/api/provider/google/v1beta1/models", http.MethodGet},
|
||||
{"/api/provider/google/v1beta1/models", http.MethodPost},
|
||||
@@ -231,8 +242,13 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Apply localhost-only middleware
|
||||
r.Use(localhostOnlyMiddleware())
|
||||
// Create module with localhost restriction enabled
|
||||
m := &AmpModule{
|
||||
restrictToLocalhost: true,
|
||||
}
|
||||
|
||||
// Apply dynamic localhost-only middleware
|
||||
r.Use(m.localhostOnlyMiddleware())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
@@ -305,3 +321,53 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
|
||||
// Create module with localhost restriction initially enabled
|
||||
m := &AmpModule{
|
||||
restrictToLocalhost: true,
|
||||
}
|
||||
|
||||
// Apply dynamic localhost-only middleware
|
||||
r.Use(m.localhostOnlyMiddleware())
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
// Test 1: Remote IP should be blocked when restriction is enabled
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected 403 when restriction enabled, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test 2: Hot-reload - disable restriction
|
||||
m.setRestrictToLocalhost(false)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200 after disabling restriction, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test 3: Hot-reload - re-enable restriction
|
||||
m.setRestrictToLocalhost(true)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w = httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,6 +139,17 @@ func (s *MultiSourceSecret) InvalidateCache() {
|
||||
s.cache = nil
|
||||
}
|
||||
|
||||
// UpdateExplicitKey refreshes the config-provided key and clears cache.
|
||||
func (s *MultiSourceSecret) UpdateExplicitKey(key string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.explicitKey = strings.TrimSpace(key)
|
||||
s.cache = nil
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// StaticSecretSource returns a fixed API key (for testing)
|
||||
type StaticSecretSource struct {
|
||||
key string
|
||||
|
||||
@@ -472,6 +472,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||
mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
|
||||
|
||||
mgmt.GET("/debug", s.mgmt.GetDebug)
|
||||
mgmt.PUT("/debug", s.mgmt.PutDebug)
|
||||
|
||||
@@ -85,6 +85,9 @@ type InlineData struct {
|
||||
// FunctionCall represents a tool call requested by the model.
|
||||
// It includes the function name and its arguments that the model wants to execute.
|
||||
type FunctionCall struct {
|
||||
// ID is the identifier of the function to be called.
|
||||
ID string `json:"id,omitempty"`
|
||||
|
||||
// Name is the identifier of the function to be called.
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -95,6 +98,9 @@ type FunctionCall struct {
|
||||
// FunctionResponse represents the result of a tool execution.
|
||||
// This is sent back to the model after a tool call has been processed.
|
||||
type FunctionResponse struct {
|
||||
// ID is the identifier of the function to be called.
|
||||
ID string `json:"id,omitempty"`
|
||||
|
||||
// Name is the identifier of the function that was called.
|
||||
Name string `json:"name"`
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const skipGinLogKey = "__gin_skip_request_logging__"
|
||||
|
||||
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
||||
// using logrus. It captures request details including method, path, status code, latency,
|
||||
// client IP, and any error messages, formatting them in a Gin-style log format.
|
||||
@@ -28,6 +30,10 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
|
||||
c.Next()
|
||||
|
||||
if shouldSkipGinRequestLogging(c) {
|
||||
return
|
||||
}
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
@@ -77,3 +83,24 @@ func GinLogrusRecovery() gin.HandlerFunc {
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
})
|
||||
}
|
||||
|
||||
// SkipGinRequestLogging marks the provided Gin context so that GinLogrusLogger
|
||||
// will skip emitting a log line for the associated request.
|
||||
func SkipGinRequestLogging(c *gin.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.Set(skipGinLogKey, true)
|
||||
}
|
||||
|
||||
func shouldSkipGinRequestLogging(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
val, exists := c.Get(skipGinLogKey)
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
flag, ok := val.(bool)
|
||||
return ok && flag
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -508,8 +509,46 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
requestURL.WriteString(url.QueryEscape(alt))
|
||||
}
|
||||
|
||||
payload = geminiToAntigravity(modelName, payload)
|
||||
// Extract project_id from auth metadata if available
|
||||
projectID := ""
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if pid, ok := auth.Metadata["project_id"].(string); ok {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
}
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
strJSON = util.DeleteKey(strJSON, "$schema")
|
||||
strJSON = util.DeleteKey(strJSON, "maxItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minLength")
|
||||
strJSON = util.DeleteKey(strJSON, "maxLength")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
|
||||
|
||||
paths = make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
|
||||
for _, p := range paths {
|
||||
anyOf := gjson.Get(strJSON, p)
|
||||
if anyOf.IsArray() {
|
||||
anyOfItems := anyOf.Array()
|
||||
if len(anyOfItems) > 0 {
|
||||
strJSON, _ = sjson.SetRaw(strJSON, p[:len(p)-len(".anyOf")], anyOfItems[0].Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
@@ -670,10 +709,16 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func geminiToAntigravity(modelName string, payload []byte) []byte {
|
||||
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
|
||||
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
|
||||
if projectID != "" {
|
||||
template, _ = sjson.Set(template, "project", projectID)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateSessionID())
|
||||
|
||||
|
||||
@@ -89,10 +89,11 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
functionID := contentResult.Get("id").String()
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{Name: functionName, Args: args},
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
||||
})
|
||||
}
|
||||
@@ -105,7 +106,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,35 +141,38 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
params.ResponseType = 2 // Set state to thinking
|
||||
}
|
||||
} else {
|
||||
// Process regular text content (user-visible output)
|
||||
// Continue existing text block if already in content state
|
||||
if params.ResponseType == 1 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
} else {
|
||||
// Transition from another state to text content
|
||||
// First, close any existing content block
|
||||
if params.ResponseType != 0 {
|
||||
if params.ResponseType == 2 {
|
||||
// output = output + "event: content_block_delta\n"
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")
|
||||
if partTextResult.String() != "" || !finishReasonResult.Exists() {
|
||||
// Process regular text content (user-visible output)
|
||||
// Continue existing text block if already in content state
|
||||
if params.ResponseType == 1 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
} else {
|
||||
// Transition from another state to text content
|
||||
// First, close any existing content block
|
||||
if params.ResponseType != 0 {
|
||||
if params.ResponseType == 2 {
|
||||
// output = output + "event: content_block_delta\n"
|
||||
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
|
||||
// output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
params.ResponseIndex++
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
params.ResponseIndex++
|
||||
}
|
||||
|
||||
// Start a new text content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.ResponseType = 1 // Set state to content
|
||||
// Start a new text content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.ResponseType = 1 // Set state to content
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if functionCallResult.Exists() {
|
||||
|
||||
@@ -251,6 +251,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
@@ -266,6 +267,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
|
||||
@@ -327,7 +327,7 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
|
||||
func mustMarshalJSON(v interface{}) string {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
@@ -249,6 +249,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
functionCall := `{"functionCall":{"name":"","args":{}}}`
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
|
||||
functionCall, _ = sjson.Set(functionCall, "thoughtSignature", geminiResponsesThoughtSignature)
|
||||
functionCall, _ = sjson.Set(functionCall, "functionCall.id", item.Get("call_id").String())
|
||||
|
||||
// Parse arguments JSON string and set as args object
|
||||
if arguments != "" {
|
||||
@@ -285,6 +286,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.id", callID)
|
||||
|
||||
// Set the raw JSON output directly (preserves string encoding)
|
||||
if outputRaw != "" && outputRaw != "null" {
|
||||
|
||||
@@ -79,6 +79,15 @@ func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
|
||||
return finalJson, nil
|
||||
}
|
||||
|
||||
func DeleteKey(jsonStr, keyName string) string {
|
||||
paths := make([]string, 0)
|
||||
Walk(gjson.Parse(jsonStr), "", keyName, &paths)
|
||||
for _, p := range paths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
||||
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
||||
// double-quoted strings with proper escaping.
|
||||
|
||||
@@ -570,6 +570,35 @@ func summarizeExcludedModels(list []string) excludedModelsSummary {
|
||||
}
|
||||
}
|
||||
|
||||
type ampModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
func summarizeAmpModelMappings(mappings []config.AmpModelMapping) ampModelMappingsSummary {
|
||||
if len(mappings) == 0 {
|
||||
return ampModelMappingsSummary{}
|
||||
}
|
||||
entries := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if from == "" && to == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, from+"->"+to)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return ampModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(entries)
|
||||
sum := sha256.Sum256([]byte(strings.Join(entries, "|")))
|
||||
return ampModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
@@ -1762,6 +1791,31 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
}
|
||||
}
|
||||
|
||||
// AmpCode settings (redacted where needed)
|
||||
oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
|
||||
newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
|
||||
if oldAmpURL != newAmpURL {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
|
||||
}
|
||||
oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
|
||||
newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
|
||||
switch {
|
||||
case oldAmpKey == "" && newAmpKey != "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: added")
|
||||
case oldAmpKey != "" && newAmpKey == "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: removed")
|
||||
case oldAmpKey != newAmpKey:
|
||||
changes = append(changes, "ampcode.upstream-api-key: updated")
|
||||
}
|
||||
if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
|
||||
}
|
||||
oldMappings := summarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
|
||||
newMappings := summarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
|
||||
if oldMappings.hash != newMappings.hash {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
|
||||
}
|
||||
|
||||
if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -127,6 +128,18 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
|
||||
projectID := ""
|
||||
if tokenResp.AccessToken != "" {
|
||||
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
||||
if errProject != nil {
|
||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||
} else {
|
||||
projectID = fetchedProjectID
|
||||
log.Infof("antigravity: obtained project ID %s", projectID)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
metadata := map[string]any{
|
||||
"type": "antigravity",
|
||||
@@ -139,6 +152,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
if email != "" {
|
||||
metadata["email"] = email
|
||||
}
|
||||
if projectID != "" {
|
||||
metadata["project_id"] = projectID
|
||||
}
|
||||
|
||||
fileName := sanitizeAntigravityFileName(email)
|
||||
label := email
|
||||
@@ -147,6 +163,9 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
}
|
||||
|
||||
fmt.Println("Antigravity authentication successful")
|
||||
if projectID != "" {
|
||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||
}
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "antigravity",
|
||||
@@ -291,3 +310,89 @@ func sanitizeAntigravityFileName(email string) string {
|
||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
||||
}
|
||||
|
||||
// Antigravity API constants for project discovery
|
||||
const (
|
||||
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityAPIVersion = "v1internal"
|
||||
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
||||
)
|
||||
|
||||
// FetchAntigravityProjectID exposes project discovery for external callers.
|
||||
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
||||
return fetchAntigravityProjectID(ctx, accessToken, httpClient)
|
||||
}
|
||||
|
||||
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
|
||||
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
|
||||
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
||||
// Call loadCodeAssist to get the project
|
||||
loadReqBody := map[string]any{
|
||||
"metadata": map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
},
|
||||
}
|
||||
|
||||
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||
if errMarshal != nil {
|
||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||
}
|
||||
|
||||
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
return "", fmt.Errorf("execute request: %w", errDo)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||
if errRead != nil {
|
||||
return "", fmt.Errorf("read response: %w", errRead)
|
||||
}
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
|
||||
var loadResp map[string]any
|
||||
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||
}
|
||||
|
||||
// Extract projectID from response
|
||||
projectID := ""
|
||||
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
if projectID == "" {
|
||||
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||
if id, okID := projectMap["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return "", fmt.Errorf("no cloudaicompanionProject in response")
|
||||
}
|
||||
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user