From afcab5efda90f05b51ffb22380fb71b6888ae5a9 Mon Sep 17 00:00:00 2001 From: huynhgiabuu Date: Sun, 7 Dec 2025 22:47:43 +0700 Subject: [PATCH] feat: add prioritize-model-mappings config option Add a configuration option to control whether model mappings take precedence over local API keys for Amp CLI requests. - Add PrioritizeModelMappings field to AmpCode config struct - When false (default): Local API keys take precedence (original behavior) - When true: Model mappings take precedence over local API keys - Add management API endpoints GET/PUT /prioritize-model-mappings This allows users who want mapping priority to enable it explicitly while preserving backward compatibility. Config example: ampcode: model-mappings: - from: claude-opus-4-5-20251101 to: gemini-claude-opus-4-5-thinking prioritize-model-mappings: true --- .../api/handlers/management/config_basic.go | 8 ++ internal/api/modules/amp/amp.go | 10 +++ internal/api/modules/amp/fallback_handlers.go | 89 +++++++++++++------ internal/api/modules/amp/routes.go | 4 +- internal/api/server.go | 4 + internal/config/config.go | 4 + 6 files changed, 90 insertions(+), 29 deletions(-) diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index ae292982..e61c695e 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -241,3 +241,11 @@ func (h *Handler) DeleteProxyURL(c *gin.Context) { h.cfg.ProxyURL = "" h.persist(c) } + +// Prioritize Model Mappings (for Amp CLI) +func (h *Handler) GetPrioritizeModelMappings(c *gin.Context) { + c.JSON(200, gin.H{"prioritize-model-mappings": h.cfg.AmpCode.PrioritizeModelMappings}) +} +func (h *Handler) PutPrioritizeModelMappings(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.PrioritizeModelMappings = v }) +} diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index dabb7404..5c7c2708 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -100,6 +100,16 @@ func (m *AmpModule) Name() string { return "amp-routing" } +// getPrioritizeModelMappings returns whether model mappings should take precedence over local API keys +func (m *AmpModule) getPrioritizeModelMappings() bool { + m.configMu.RLock() + defer m.configMu.RUnlock() + if m.lastConfig == nil { + return false + } + return m.lastConfig.PrioritizeModelMappings +} + // Register sets up Amp routes if configured. // This implements the RouteModuleV2 interface with Context. // Routes are registered only once via sync.Once for idempotent behavior. diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 0cbe0e1a..771e2713 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -77,23 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid // FallbackHandler wraps a standard handler with fallback logic to ampcode.com // when the model's provider is not available in CLIProxyAPI type FallbackHandler struct { - getProxy func() *httputil.ReverseProxy - modelMapper ModelMapper + getProxy func() *httputil.ReverseProxy + modelMapper ModelMapper + getPrioritizeModelMappings func() bool } // NewFallbackHandler creates a new fallback handler wrapper // The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes) func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler { return &FallbackHandler{ - getProxy: getProxy, + getProxy: getProxy, + getPrioritizeModelMappings: func() bool { return false }, } } // NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support -func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler { +func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, getPrioritize func() bool) *FallbackHandler { + if getPrioritize == nil { + getPrioritize = func() bool { return false } + } return &FallbackHandler{ - getProxy: getProxy, - modelMapper: mapper, + getProxy: getProxy, + modelMapper: mapper, + getPrioritizeModelMappings: getPrioritize, } } @@ -130,34 +136,65 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Normalize model (handles Gemini thinking suffixes) normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName) - // Check if we have providers for this model - providers := util.GetProviderName(normalizedModel) - // Track resolved model for logging (may change if mapping is applied) resolvedModel := normalizedModel usedMapping := false + var providers []string - if len(providers) == 0 { - // No providers configured - check if we have a model mapping + // Check if model mappings should take priority over local API keys + prioritizeMappings := fh.getPrioritizeModelMappings != nil && fh.getPrioritizeModelMappings() + + if prioritizeMappings { + // PRIORITY MODE: Check model mappings FIRST (takes precedence over local API keys) + // This allows users to route Amp requests to their preferred OAuth providers if fh.modelMapper != nil { if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - - // Get providers for the mapped model - providers = util.GetProviderName(mappedModel) - - // Continue to handler with remapped model - goto handleRequest + // Mapping found - check if we have a provider for the mapped model + mappedProviders := util.GetProviderName(mappedModel) + if len(mappedProviders) > 0 { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders + } } } - // No mapping found - check if we have a proxy for fallback + // If no mapping applied, check for local providers + if !usedMapping { + providers = util.GetProviderName(normalizedModel) + } + } else { + // DEFAULT MODE: Check local providers first, then mappings as fallback + providers = util.GetProviderName(normalizedModel) + + if len(providers) == 0 { + // No providers configured - check if we have a model mapping + if fh.modelMapper != nil { + if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { + // Mapping found - check if we have a provider for the mapped model + mappedProviders := util.GetProviderName(mappedModel) + if len(mappedProviders) > 0 { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders + } + } + } + } + } + + // If no providers available, fallback to ampcode.com + if len(providers) == 0 { proxy := fh.getProxy() if proxy != nil { // Log: Forwarding to ampcode.com (uses Amp credits) @@ -175,8 +212,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath) } - handleRequest: - // Log the routing decision providerName := "" if len(providers) > 0 { diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 6826dbbe..dedbd444 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -171,7 +171,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler) geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper) + }, m.modelMapper, m.getPrioritizeModelMappings) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) // Route POST model calls through Gemini bridge with FallbackHandler. @@ -209,7 +209,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han // Also includes model mapping support for routing unavailable models to alternatives fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return m.getProxy() - }, m.modelMapper) + }, m.modelMapper, m.getPrioritizeModelMappings) // Provider-specific routes under /api/provider/:provider ampProviders := engine.Group("/api/provider") diff --git a/internal/api/server.go b/internal/api/server.go index 9e1c5848..93d13557 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -520,6 +520,10 @@ func (s *Server) registerManagementRoutes() { mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) + mgmt.GET("/prioritize-model-mappings", s.mgmt.GetPrioritizeModelMappings) + mgmt.PUT("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings) + mgmt.PATCH("/prioritize-model-mappings", s.mgmt.PutPrioritizeModelMappings) + mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/config/config.go b/internal/config/config.go index 2681d049..d7e455f6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -143,6 +143,10 @@ type AmpCode struct { // When Amp requests a model that isn't available locally, these mappings // allow routing to an alternative model that IS available. ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"` + + // PrioritizeModelMappings when true, model mappings take precedence over local API keys. + // When false (default), local API keys are used first if available. + PrioritizeModelMappings bool `yaml:"prioritize-model-mappings" json:"prioritize-model-mappings"` } // PayloadConfig defines default and override parameter rules applied to provider payloads.