refactor: improve gemini bridge testability and code quality

- Change createGeminiBridgeHandler to accept gin.HandlerFunc instead of *gemini.GeminiAPIHandler
  This allows tests to inject mock handlers instead of duplicating bridge logic
- Replace magic number 8 with len(modelsPrefix) for better maintainability
- Remove redundant test case that doesn't test edge case in production
- Update routes.go to pass geminiHandlers.GeminiHandler directly

Addresses PR review feedback on test architecture and code clarity.

Amp-Thread-ID: https://ampcode.com/threads/T-1ae2c691-e434-4b99-a49a-10cabd3544db
This commit is contained in:
huynguyen03.dev
2025-12-07 10:15:42 +07:00
parent edc654edf9
commit 396899a530
3 changed files with 16 additions and 41 deletions

View File

@@ -4,7 +4,6 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
) )
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths // createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
@@ -15,16 +14,19 @@ import (
// //
// This extracts the model+method from the AMP path and sets it as the :action parameter // This extracts the model+method from the AMP path and sets it as the :action parameter
// so the standard Gemini handler can process it. // so the standard Gemini handler can process it.
func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc { //
// The handler parameter should be a Gemini-compatible handler that expects the :action param.
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Get the full path from the catch-all parameter // Get the full path from the catch-all parameter
path := c.Param("path") path := c.Param("path")
// Extract model:method from AMP CLI path format // Extract model:method from AMP CLI path format
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent // Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
if idx := strings.Index(path, "/models/"); idx >= 0 { const modelsPrefix = "/models/"
// Extract everything after "/models/" if idx := strings.Index(path, modelsPrefix); idx >= 0 {
actionPart := path[idx+8:] // Skip "/models/" // Extract everything after modelsPrefix
actionPart := path[idx+len(modelsPrefix):]
// Check if model was mapped by FallbackHandler // Check if model was mapped by FallbackHandler
if mappedModel, exists := c.Get(MappedModelContextKey); exists { if mappedModel, exists := c.Get(MappedModelContextKey); exists {
@@ -44,8 +46,8 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
Value: actionPart, Value: actionPart,
}) })
// Call the standard Gemini handler // Call the handler
geminiHandler.GeminiHandler(c) handler(c)
return return
} }

View File

@@ -3,12 +3,9 @@ package amp
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
) )
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) { func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
@@ -38,12 +35,6 @@ func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
mappedModel: "gemini-flash", mappedModel: "gemini-flash",
expectedAction: "gemini-flash:streamGenerateContent", expectedAction: "gemini-flash:streamGenerateContent",
}, },
{
name: "empty_mapped_model_ignored",
path: "/publishers/google/models/gemini-pro:generateContent",
mappedModel: "",
expectedAction: "gemini-pro:generateContent",
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -55,27 +46,8 @@ func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
c.JSON(http.StatusOK, gin.H{"captured": capturedAction}) c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
} }
// Mirror the bridge logic from gemini_bridge.go // Use the actual createGeminiBridgeHandler function
bridgeHandler := func(c *gin.Context) { bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
path := c.Param("path")
if idx := strings.Index(path, "/models/"); idx >= 0 {
actionPart := path[idx+8:]
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
if strModel, ok := mappedModel.(string); ok && strModel != "" {
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
method := actionPart[colonIdx:]
actionPart = strModel + method
}
}
}
c.Params = append(c.Params, gin.Param{Key: "action", Value: actionPart})
mockGeminiHandler(c)
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid path"})
}
r := gin.New() r := gin.New()
if tt.mappedModel != "" { if tt.mappedModel != "" {
@@ -103,9 +75,10 @@ func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) { func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
base := &handlers.BaseAPIHandler{} mockHandler := func(c *gin.Context) {
geminiHandlers := gemini.NewGeminiAPIHandler(base) c.JSON(http.StatusOK, gin.H{"ok": true})
bridgeHandler := createGeminiBridgeHandler(geminiHandlers) }
bridgeHandler := createGeminiBridgeHandler(mockHandler)
r := gin.New() r := gin.New()
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler) r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)

View File

@@ -169,7 +169,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// We bridge these to our standard Gemini handler to enable local OAuth. // We bridge these to our standard Gemini handler to enable local OAuth.
// If no local OAuth is available, falls back to ampcode.com proxy. // If no local OAuth is available, falls back to ampcode.com proxy.
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers) geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy() return m.getProxy()
}, m.modelMapper) }, m.modelMapper)