mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 20:30:51 +08:00
Merge pull request #432 from huynguyen03dev/fix/amp-gemini-model-mapping
fix(amp): pass mapped model to gemini bridge via context
This commit is contained in:
@@ -28,6 +28,9 @@ const (
|
|||||||
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||||
|
const MappedModelContextKey = "mapped_model"
|
||||||
|
|
||||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||||
fields := log.Fields{
|
fields := log.Fields{
|
||||||
@@ -141,6 +144,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
// Mapping found - rewrite the model in request body
|
// Mapping found - rewrite the model in request body
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||||
|
c.Set(MappedModelContextKey, mappedModel)
|
||||||
resolvedModel = mappedModel
|
resolvedModel = mappedModel
|
||||||
usedMapping = true
|
usedMapping = true
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
|
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
|
||||||
@@ -15,16 +14,31 @@ import (
|
|||||||
//
|
//
|
||||||
// This extracts the model+method from the AMP path and sets it as the :action parameter
|
// This extracts the model+method from the AMP path and sets it as the :action parameter
|
||||||
// so the standard Gemini handler can process it.
|
// so the standard Gemini handler can process it.
|
||||||
func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc {
|
//
|
||||||
|
// The handler parameter should be a Gemini-compatible handler that expects the :action param.
|
||||||
|
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// Get the full path from the catch-all parameter
|
// Get the full path from the catch-all parameter
|
||||||
path := c.Param("path")
|
path := c.Param("path")
|
||||||
|
|
||||||
// Extract model:method from AMP CLI path format
|
// Extract model:method from AMP CLI path format
|
||||||
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
const modelsPrefix = "/models/"
|
||||||
// Extract everything after "/models/"
|
if idx := strings.Index(path, modelsPrefix); idx >= 0 {
|
||||||
actionPart := path[idx+8:] // Skip "/models/"
|
// Extract everything after modelsPrefix
|
||||||
|
actionPart := path[idx+len(modelsPrefix):]
|
||||||
|
|
||||||
|
// Check if model was mapped by FallbackHandler
|
||||||
|
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
|
||||||
|
if strModel, ok := mappedModel.(string); ok && strModel != "" {
|
||||||
|
// Replace the model part in the action
|
||||||
|
// actionPart is like "model-name:method"
|
||||||
|
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
|
||||||
|
method := actionPart[colonIdx:] // ":method"
|
||||||
|
actionPart = strModel + method
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Set this as the :action parameter that the Gemini handler expects
|
// Set this as the :action parameter that the Gemini handler expects
|
||||||
c.Params = append(c.Params, gin.Param{
|
c.Params = append(c.Params, gin.Param{
|
||||||
@@ -32,8 +46,8 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
|
|||||||
Value: actionPart,
|
Value: actionPart,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Call the standard Gemini handler
|
// Call the handler
|
||||||
geminiHandler.GeminiHandler(c)
|
handler(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
93
internal/api/modules/amp/gemini_bridge_test.go
Normal file
93
internal/api/modules/amp/gemini_bridge_test.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
mappedModel string // empty string means no mapping
|
||||||
|
expectedAction string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no_mapping_uses_url_model",
|
||||||
|
path: "/publishers/google/models/gemini-pro:generateContent",
|
||||||
|
mappedModel: "",
|
||||||
|
expectedAction: "gemini-pro:generateContent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mapped_model_replaces_url_model",
|
||||||
|
path: "/publishers/google/models/gemini-exp:generateContent",
|
||||||
|
mappedModel: "gemini-2.0-flash",
|
||||||
|
expectedAction: "gemini-2.0-flash:generateContent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mapping_preserves_method",
|
||||||
|
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
|
||||||
|
mappedModel: "gemini-flash",
|
||||||
|
expectedAction: "gemini-flash:streamGenerateContent",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var capturedAction string
|
||||||
|
|
||||||
|
mockGeminiHandler := func(c *gin.Context) {
|
||||||
|
capturedAction = c.Param("action")
|
||||||
|
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the actual createGeminiBridgeHandler function
|
||||||
|
bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
if tt.mappedModel != "" {
|
||||||
|
r.Use(func(c *gin.Context) {
|
||||||
|
c.Set(MappedModelContextKey, tt.mappedModel)
|
||||||
|
c.Next()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
if capturedAction != tt.expectedAction {
|
||||||
|
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mockHandler := func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
|
}
|
||||||
|
bridgeHandler := createGeminiBridgeHandler(mockHandler)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||||
@@ -169,30 +168,22 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
// We bridge these to our standard Gemini handler to enable local OAuth.
|
// We bridge these to our standard Gemini handler to enable local OAuth.
|
||||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||||
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||||
return m.getProxy()
|
return m.getProxy()
|
||||||
})
|
}, m.modelMapper)
|
||||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||||
|
|
||||||
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
|
// Route POST model calls through Gemini bridge with FallbackHandler.
|
||||||
|
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
||||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
if c.Request.Method == "POST" {
|
if c.Request.Method == "POST" {
|
||||||
// Attempt to extract the model name from the AMP-style path
|
|
||||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
modelPart := path[strings.Index(path, "/models/")+len("/models/"):]
|
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||||
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
// FallbackHandler will check provider/mapping and proxy if needed
|
||||||
modelPart = modelPart[:colonIdx]
|
geminiV1Beta1Handler(c)
|
||||||
}
|
return
|
||||||
if modelPart != "" {
|
|
||||||
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
|
|
||||||
// Only handle locally when we have a provider; otherwise fall back to proxy
|
|
||||||
if providers := util.GetProviderName(normalized); len(providers) > 0 {
|
|
||||||
geminiV1Beta1Handler(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Non-POST or no local provider available -> proxy upstream
|
// Non-POST or no local provider available -> proxy upstream
|
||||||
|
|||||||
Reference in New Issue
Block a user