Add AMP fallback proxy and shared Gemini normalization

- add fallback handler that forwards Amp provider requests to ampcode.com when the provider isn’t configured locally
- wrap AMP provider routes with the fallback so requests always have a handler
- share Gemini thinking model normalization helper between core handlers and AMP fallback
This commit is contained in:
Ben Vargas
2025-10-24 09:17:12 -06:00
parent 9ad0f3f91e
commit 8193392bfe
4 changed files with 152 additions and 34 deletions

View File

@@ -0,0 +1,105 @@
package amp
import (
"bytes"
"encoding/json"
"io"
"net/http/httputil"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
}
// NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
}
}
// WrapHandler wraps a gin.HandlerFunc with fallback logic
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Read the request body to extract the model name
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Errorf("amp fallback: failed to read request body: %v", err)
handler(c)
return
}
// Restore the body for the handler to read
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Try to extract model from request body or URL path (for Gemini)
modelName := extractModelFromRequest(bodyBytes, c)
if modelName == "" {
// Can't determine model, proceed with normal handler
handler(c)
return
}
// Normalize model (handles Gemini thinking suffixes)
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
// Check if we have providers for this model
providers := util.GetProviderName(normalizedModel)
if len(providers) == 0 {
// No providers configured - check if we have a proxy for fallback
proxy := fh.getProxy()
if proxy != nil {
// Fallback to ampcode.com
log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName)
// Restore body again for the proxy
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Forward to ampcode.com
proxy.ServeHTTP(c.Writer, c.Request)
return
}
// No proxy available, let the normal handler return the error
log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName)
}
// Providers available or no proxy for fallback, restore body and use normal handler
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
}
}
// extractModelFromRequest attempts to extract the model name from various request formats
func extractModelFromRequest(body []byte, c *gin.Context) string {
// First try to parse from JSON body (OpenAI, Claude, etc.)
var payload map[string]interface{}
if err := json.Unmarshal(body, &payload); err == nil {
// Check common model field names
if model, ok := payload["model"].(string); ok {
return model
}
}
// For Gemini requests, model is in the URL path: /models/{model}:generateContent
// Extract from :action parameter (e.g., "gemini-pro:generateContent")
if action := c.Param("action"); action != "" {
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
parts := strings.Split(action, ":")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
}
return ""
}

View File

@@ -2,6 +2,7 @@ package amp
import ( import (
"net" "net"
"net/http/httputil"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -17,7 +18,7 @@ import (
func localhostOnlyMiddleware() gin.HandlerFunc { func localhostOnlyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
clientIP := c.ClientIP() clientIP := c.ClientIP()
// Parse the IP to handle both IPv4 and IPv6 // Parse the IP to handle both IPv4 and IPv6
ip := net.ParseIP(clientIP) ip := net.ParseIP(clientIP)
if ip == nil { if ip == nil {
@@ -27,7 +28,7 @@ func localhostOnlyMiddleware() gin.HandlerFunc {
}) })
return return
} }
// Check if IP is loopback (127.0.0.1 or ::1) // Check if IP is loopback (127.0.0.1 or ::1)
if !ip.IsLoopback() { if !ip.IsLoopback() {
log.Warnf("Amp management: non-localhost IP %s attempted access, denying", clientIP) log.Warnf("Amp management: non-localhost IP %s attempted access, denying", clientIP)
@@ -36,7 +37,7 @@ func localhostOnlyMiddleware() gin.HandlerFunc {
}) })
return return
} }
c.Next() c.Next()
} }
} }
@@ -50,13 +51,13 @@ func noCORSMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Methods", "") c.Header("Access-Control-Allow-Methods", "")
c.Header("Access-Control-Allow-Headers", "") c.Header("Access-Control-Allow-Headers", "")
c.Header("Access-Control-Allow-Credentials", "") c.Header("Access-Control-Allow-Credentials", "")
// For OPTIONS preflight, deny with 403 // For OPTIONS preflight, deny with 403
if c.Request.Method == "OPTIONS" { if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(403) c.AbortWithStatus(403)
return return
} }
c.Next() c.Next()
} }
} }
@@ -66,10 +67,10 @@ func noCORSMiddleware() gin.HandlerFunc {
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1. // If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1.
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) { func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) {
ampAPI := engine.Group("/api") ampAPI := engine.Group("/api")
// Always disable CORS for management routes to prevent browser-based attacks // Always disable CORS for management routes to prevent browser-based attacks
ampAPI.Use(noCORSMiddleware()) ampAPI.Use(noCORSMiddleware())
// Apply localhost-only restriction if configured // Apply localhost-only restriction if configured
if restrictToLocalhost { if restrictToLocalhost {
ampAPI.Use(localhostOnlyMiddleware()) ampAPI.Use(localhostOnlyMiddleware())
@@ -112,6 +113,12 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler) claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses lazy evaluation to access proxy (which is created after routes are registered)
fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy {
return m.proxy
})
// Provider-specific routes under /api/provider/:provider // Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider") ampProviders := engine.Group("/api/provider")
if auth != nil { if auth != nil {
@@ -136,31 +143,33 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
} }
// Root-level routes (for providers that omit /v1, like groq/cerebras) // Root-level routes (for providers that omit /v1, like groq/cerebras)
provider.GET("/models", ampModelsHandler) // Wrap handlers with fallback logic to forward to ampcode.com when provider not found
provider.POST("/chat/completions", openaiHandlers.ChatCompletions) provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
provider.POST("/completions", openaiHandlers.Completions) provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
provider.POST("/responses", openaiResponsesHandlers.Responses) provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
// /v1 routes (OpenAI/Claude-compatible endpoints) // /v1 routes (OpenAI/Claude-compatible endpoints)
v1Amp := provider.Group("/v1") v1Amp := provider.Group("/v1")
{ {
v1Amp.GET("/models", ampModelsHandler) v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
// OpenAI-compatible endpoints // OpenAI-compatible endpoints with fallback
v1Amp.POST("/chat/completions", openaiHandlers.ChatCompletions) v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
v1Amp.POST("/completions", openaiHandlers.Completions) v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
v1Amp.POST("/responses", openaiResponsesHandlers.Responses) v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
// Claude/Anthropic-compatible endpoints // Claude/Anthropic-compatible endpoints with fallback
v1Amp.POST("/messages", claudeCodeHandlers.ClaudeMessages) v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
v1Amp.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
} }
// /v1beta routes (Gemini native API) // /v1beta routes (Gemini native API)
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
v1betaAmp := provider.Group("/v1beta") v1betaAmp := provider.Group("/v1beta")
{ {
v1betaAmp.GET("/models", geminiHandlers.GeminiModels) v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
v1betaAmp.POST("/models/:action", geminiHandlers.GeminiHandler) v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler) v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
} }
} }

View File

@@ -62,6 +62,23 @@ func ParseGeminiThinkingSuffix(model string) (string, *int, *bool, bool) {
return base, &budgetValue, nil, true return base, &budgetValue, nil, true
} }
func NormalizeGeminiThinkingModel(modelName string) (string, map[string]any) {
baseModel, budget, include, matched := ParseGeminiThinkingSuffix(modelName)
if !matched {
return baseModel, nil
}
metadata := map[string]any{
GeminiOriginalModelMetadataKey: modelName,
}
if budget != nil {
metadata[GeminiThinkingBudgetMetadataKey] = *budget
}
if include != nil {
metadata[GeminiIncludeThoughtsMetadataKey] = *include
}
return baseModel, metadata
}
func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte { func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte {
if budget == nil && includeThoughts == nil { if budget == nil && includeThoughts == nil {
return body return body

View File

@@ -359,20 +359,7 @@ func cloneBytes(src []byte) []byte {
} }
func normalizeModelMetadata(modelName string) (string, map[string]any) { func normalizeModelMetadata(modelName string) (string, map[string]any) {
baseModel, budget, include, matched := util.ParseGeminiThinkingSuffix(modelName) return util.NormalizeGeminiThinkingModel(modelName)
if !matched {
return baseModel, nil
}
metadata := map[string]any{
util.GeminiOriginalModelMetadataKey: modelName,
}
if budget != nil {
metadata[util.GeminiThinkingBudgetMetadataKey] = *budget
}
if include != nil {
metadata[util.GeminiIncludeThoughtsMetadataKey] = *include
}
return baseModel, metadata
} }
func cloneMetadata(src map[string]any) map[string]any { func cloneMetadata(src map[string]any) map[string]any {