mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
fix(amp): pass mapped model to gemini bridge via context
Gemini handler extracts model from URL path, not JSON body, so rewriting the request body alone wasn't sufficient for model mapping. - Add MappedModelContextKey constant for context passing - Update routes.go to use NewFallbackHandlerWithMapper - Add check for valid mapping before routing to local handler - Add tests for gemini bridge model mapping
This commit is contained in:
@@ -28,6 +28,9 @@ const (
|
||||
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
||||
)
|
||||
|
||||
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||
const MappedModelContextKey = "mapped_model"
|
||||
|
||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||
fields := log.Fields{
|
||||
@@ -141,6 +144,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Mapping found - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
|
||||
|
||||
@@ -26,6 +26,18 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
|
||||
// Extract everything after "/models/"
|
||||
actionPart := path[idx+8:] // Skip "/models/"
|
||||
|
||||
// Check if model was mapped by FallbackHandler
|
||||
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
|
||||
if strModel, ok := mappedModel.(string); ok && strModel != "" {
|
||||
// Replace the model part in the action
|
||||
// actionPart is like "model-name:method"
|
||||
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
|
||||
method := actionPart[colonIdx:] // ":method"
|
||||
actionPart = strModel + method
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set this as the :action parameter that the Gemini handler expects
|
||||
c.Params = append(c.Params, gin.Param{
|
||||
Key: "action",
|
||||
|
||||
120
internal/api/modules/amp/gemini_bridge_test.go
Normal file
120
internal/api/modules/amp/gemini_bridge_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
)
|
||||
|
||||
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
mappedModel string // empty string means no mapping
|
||||
expectedAction string
|
||||
}{
|
||||
{
|
||||
name: "no_mapping_uses_url_model",
|
||||
path: "/publishers/google/models/gemini-pro:generateContent",
|
||||
mappedModel: "",
|
||||
expectedAction: "gemini-pro:generateContent",
|
||||
},
|
||||
{
|
||||
name: "mapped_model_replaces_url_model",
|
||||
path: "/publishers/google/models/gemini-exp:generateContent",
|
||||
mappedModel: "gemini-2.0-flash",
|
||||
expectedAction: "gemini-2.0-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "mapping_preserves_method",
|
||||
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
|
||||
mappedModel: "gemini-flash",
|
||||
expectedAction: "gemini-flash:streamGenerateContent",
|
||||
},
|
||||
{
|
||||
name: "empty_mapped_model_ignored",
|
||||
path: "/publishers/google/models/gemini-pro:generateContent",
|
||||
mappedModel: "",
|
||||
expectedAction: "gemini-pro:generateContent",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedAction string
|
||||
|
||||
mockGeminiHandler := func(c *gin.Context) {
|
||||
capturedAction = c.Param("action")
|
||||
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
|
||||
}
|
||||
|
||||
// Mirror the bridge logic from gemini_bridge.go
|
||||
bridgeHandler := func(c *gin.Context) {
|
||||
path := c.Param("path")
|
||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||
actionPart := path[idx+8:]
|
||||
|
||||
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
|
||||
if strModel, ok := mappedModel.(string); ok && strModel != "" {
|
||||
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
|
||||
method := actionPart[colonIdx:]
|
||||
actionPart = strModel + method
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Params = append(c.Params, gin.Param{Key: "action", Value: actionPart})
|
||||
mockGeminiHandler(c)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid path"})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
if tt.mappedModel != "" {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(MappedModelContextKey, tt.mappedModel)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
if capturedAction != tt.expectedAction {
|
||||
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(base)
|
||||
bridgeHandler := createGeminiBridgeHandler(geminiHandlers)
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -170,9 +170,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
|
||||
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
})
|
||||
}, m.modelMapper)
|
||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
|
||||
@@ -187,8 +187,18 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
}
|
||||
if modelPart != "" {
|
||||
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
|
||||
// Only handle locally when we have a provider; otherwise fall back to proxy
|
||||
// Only handle locally when we have a provider or a valid mapping; otherwise fall back to proxy
|
||||
hasProvider := false
|
||||
if providers := util.GetProviderName(normalized); len(providers) > 0 {
|
||||
hasProvider = true
|
||||
} else if m.modelMapper != nil {
|
||||
// Check if mapped model has provider (MapModel returns target only if it has providers)
|
||||
if mapped := m.modelMapper.MapModel(normalized); mapped != "" {
|
||||
hasProvider = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasProvider {
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user