fix(amp): pass mapped model to gemini bridge via context

Gemini handler extracts model from URL path, not JSON body, so
rewriting the request body alone wasn't sufficient for model mapping.

- Add MappedModelContextKey constant for context passing
- Update routes.go to use NewFallbackHandlerWithMapper
- Add check for valid mapping before routing to local handler
- Add tests for gemini bridge model mapping
This commit is contained in:
huynguyen03.dev
2025-12-06 18:59:44 +07:00
parent 7ea14479fb
commit 08586334af
4 changed files with 150 additions and 3 deletions

View File

@@ -28,6 +28,9 @@ const (
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
)
// MappedModelContextKey is the Gin context key for passing mapped model names.
const MappedModelContextKey = "mapped_model"
// logAmpRouting logs the routing decision for an Amp request with structured fields
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
fields := log.Fields{
@@ -141,6 +144,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Mapping found - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true

View File

@@ -26,6 +26,18 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
// Extract everything after "/models/"
actionPart := path[idx+8:] // Skip "/models/"
// Check if model was mapped by FallbackHandler
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
if strModel, ok := mappedModel.(string); ok && strModel != "" {
// Replace the model part in the action
// actionPart is like "model-name:method"
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
method := actionPart[colonIdx:] // ":method"
actionPart = strModel + method
}
}
}
// Set this as the :action parameter that the Gemini handler expects
c.Params = append(c.Params, gin.Param{
Key: "action",

View File

@@ -0,0 +1,120 @@
package amp
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
)
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
path string
mappedModel string // empty string means no mapping
expectedAction string
}{
{
name: "no_mapping_uses_url_model",
path: "/publishers/google/models/gemini-pro:generateContent",
mappedModel: "",
expectedAction: "gemini-pro:generateContent",
},
{
name: "mapped_model_replaces_url_model",
path: "/publishers/google/models/gemini-exp:generateContent",
mappedModel: "gemini-2.0-flash",
expectedAction: "gemini-2.0-flash:generateContent",
},
{
name: "mapping_preserves_method",
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
mappedModel: "gemini-flash",
expectedAction: "gemini-flash:streamGenerateContent",
},
{
name: "empty_mapped_model_ignored",
path: "/publishers/google/models/gemini-pro:generateContent",
mappedModel: "",
expectedAction: "gemini-pro:generateContent",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedAction string
mockGeminiHandler := func(c *gin.Context) {
capturedAction = c.Param("action")
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
}
// Mirror the bridge logic from gemini_bridge.go
bridgeHandler := func(c *gin.Context) {
path := c.Param("path")
if idx := strings.Index(path, "/models/"); idx >= 0 {
actionPart := path[idx+8:]
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
if strModel, ok := mappedModel.(string); ok && strModel != "" {
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
method := actionPart[colonIdx:]
actionPart = strModel + method
}
}
}
c.Params = append(c.Params, gin.Param{Key: "action", Value: actionPart})
mockGeminiHandler(c)
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid path"})
}
r := gin.New()
if tt.mappedModel != "" {
r.Use(func(c *gin.Context) {
c.Set(MappedModelContextKey, tt.mappedModel)
c.Next()
})
}
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected status 200, got %d", w.Code)
}
if capturedAction != tt.expectedAction {
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
}
})
}
}
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
gin.SetMode(gin.TestMode)
base := &handlers.BaseAPIHandler{}
geminiHandlers := gemini.NewGeminiAPIHandler(base)
bridgeHandler := createGeminiBridgeHandler(geminiHandlers)
r := gin.New()
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
}
}

View File

@@ -170,9 +170,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// If no local OAuth is available, falls back to ampcode.com proxy.
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
})
}, m.modelMapper)
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
@@ -187,8 +187,18 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
}
if modelPart != "" {
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
// Only handle locally when we have a provider; otherwise fall back to proxy
// Only handle locally when we have a provider or a valid mapping; otherwise fall back to proxy
hasProvider := false
if providers := util.GetProviderName(normalized); len(providers) > 0 {
hasProvider = true
} else if m.modelMapper != nil {
// Check if mapped model has provider (MapModel returns target only if it has providers)
if mapped := m.modelMapper.MapModel(normalized); mapped != "" {
hasProvider = true
}
}
if hasProvider {
geminiV1Beta1Handler(c)
return
}