mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(amp): add per-client upstream API key mapping support
This commit is contained in:
@@ -35,6 +35,7 @@ auth-dir: "~/.cli-proxy-api"
|
|||||||
api-keys:
|
api-keys:
|
||||||
- "your-api-key-1"
|
- "your-api-key-1"
|
||||||
- "your-api-key-2"
|
- "your-api-key-2"
|
||||||
|
- "your-api-key-3"
|
||||||
|
|
||||||
# Enable debug logging
|
# Enable debug logging
|
||||||
debug: false
|
debug: false
|
||||||
@@ -166,6 +167,18 @@ ws-auth: false
|
|||||||
# upstream-url: "https://ampcode.com"
|
# upstream-url: "https://ampcode.com"
|
||||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||||
# upstream-api-key: ""
|
# upstream-api-key: ""
|
||||||
|
# # Per-client upstream API key mapping
|
||||||
|
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
|
||||||
|
# # Useful when different clients need to use different Amp accounts/quotas.
|
||||||
|
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
|
||||||
|
# upstream-api-keys:
|
||||||
|
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
|
||||||
|
# api-keys: # Client keys that use this upstream key
|
||||||
|
# - "your-api-key-1"
|
||||||
|
# - "your-api-key-2"
|
||||||
|
# - upstream-api-key: "amp_key_for_team_b"
|
||||||
|
# api-keys:
|
||||||
|
# - "your-api-key-3"
|
||||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||||
# restrict-management-to-localhost: false
|
# restrict-management-to-localhost: false
|
||||||
# # Force model mappings to run before checking local API keys (default: false)
|
# # Force model mappings to run before checking local API keys (default: false)
|
||||||
|
|||||||
@@ -940,3 +940,151 @@ func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
|
|||||||
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
|
||||||
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAmpUpstreamAPIKeys returns the ampcode upstream API keys mapping.
|
||||||
|
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
if h == nil || h.cfg == nil {
|
||||||
|
c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutAmpUpstreamAPIKeys replaces all ampcode upstream API keys mappings.
|
||||||
|
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Normalize entries: trim whitespace, filter empty
|
||||||
|
normalized := normalizeAmpUpstreamAPIKeyEntries(body.Value)
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = normalized
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PatchAmpUpstreamAPIKeys adds or updates upstream API keys entries.
|
||||||
|
// Matching is done by upstream-api-key value.
|
||||||
|
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
existing := make(map[string]int)
|
||||||
|
for i, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||||
|
existing[strings.TrimSpace(entry.UpstreamAPIKey)] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, newEntry := range body.Value {
|
||||||
|
upstreamKey := strings.TrimSpace(newEntry.UpstreamAPIKey)
|
||||||
|
if upstreamKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalizedEntry := config.AmpUpstreamAPIKeyEntry{
|
||||||
|
UpstreamAPIKey: upstreamKey,
|
||||||
|
APIKeys: normalizeAPIKeysList(newEntry.APIKeys),
|
||||||
|
}
|
||||||
|
if idx, ok := existing[upstreamKey]; ok {
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys[idx] = normalizedEntry
|
||||||
|
} else {
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, normalizedEntry)
|
||||||
|
existing[upstreamKey] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAmpUpstreamAPIKeys removes specified upstream API keys entries.
|
||||||
|
// Body must be JSON: {"value": ["<upstream-api-key>", ...]}.
|
||||||
|
// If "value" is an empty array, clears all entries.
|
||||||
|
// If JSON is invalid or "value" is missing/null, returns 400 and does not persist any change.
|
||||||
|
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value []string `json:"value"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&body); err != nil {
|
||||||
|
c.JSON(400, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if body.Value == nil {
|
||||||
|
c.JSON(400, gin.H{"error": "missing value"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty array means clear all
|
||||||
|
if len(body.Value) == 0 {
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = nil
|
||||||
|
h.persist(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toRemove := make(map[string]bool)
|
||||||
|
for _, key := range body.Value {
|
||||||
|
trimmed := strings.TrimSpace(key)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toRemove[trimmed] = true
|
||||||
|
}
|
||||||
|
if len(toRemove) == 0 {
|
||||||
|
c.JSON(400, gin.H{"error": "empty value"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newEntries := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
|
||||||
|
for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
|
||||||
|
if !toRemove[strings.TrimSpace(entry.UpstreamAPIKey)] {
|
||||||
|
newEntries = append(newEntries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.cfg.AmpCode.UpstreamAPIKeys = newEntries
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeAmpUpstreamAPIKeyEntries normalizes a list of upstream API key entries.
|
||||||
|
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]config.AmpUpstreamAPIKeyEntry, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||||
|
if upstreamKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
apiKeys := normalizeAPIKeysList(entry.APIKeys)
|
||||||
|
out = append(out, config.AmpUpstreamAPIKeyEntry{
|
||||||
|
UpstreamAPIKey: upstreamKey,
|
||||||
|
APIKeys: apiKeys,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeAPIKeysList trims and filters empty strings from a list of API keys.
|
||||||
|
func normalizeAPIKeysList(keys []string) []string {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(keys))
|
||||||
|
for _, k := range keys {
|
||||||
|
trimmed := strings.TrimSpace(k)
|
||||||
|
if trimmed != "" {
|
||||||
|
out = append(out, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -227,11 +227,20 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check API key change
|
// Check API key change (both default and per-client mappings)
|
||||||
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
|
||||||
if apiKeyChanged {
|
upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
|
||||||
|
if apiKeyChanged || upstreamAPIKeysChanged {
|
||||||
if m.secretSource != nil {
|
if m.secretSource != nil {
|
||||||
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||||
|
if apiKeyChanged {
|
||||||
|
ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
|
||||||
|
ms.InvalidateCache()
|
||||||
|
}
|
||||||
|
if upstreamAPIKeysChanged {
|
||||||
|
ms.UpdateMappings(newSettings.UpstreamAPIKeys)
|
||||||
|
}
|
||||||
|
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||||
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
|
||||||
ms.InvalidateCache()
|
ms.InvalidateCache()
|
||||||
}
|
}
|
||||||
@@ -251,10 +260,22 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
|||||||
|
|
||||||
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
|
||||||
if m.secretSource == nil {
|
if m.secretSource == nil {
|
||||||
m.secretSource = NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
// Create MultiSourceSecret as the default source, then wrap with MappedSecretSource
|
||||||
|
defaultSource := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
|
||||||
|
mappedSource := NewMappedSecretSource(defaultSource)
|
||||||
|
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||||
|
m.secretSource = mappedSource
|
||||||
|
} else if ms, ok := m.secretSource.(*MappedSecretSource); ok {
|
||||||
|
ms.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
|
||||||
|
ms.InvalidateCache()
|
||||||
|
ms.UpdateMappings(settings.UpstreamAPIKeys)
|
||||||
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
|
||||||
|
// Legacy path: wrap existing MultiSourceSecret with MappedSecretSource
|
||||||
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
|
||||||
ms.InvalidateCache()
|
ms.InvalidateCache()
|
||||||
|
mappedSource := NewMappedSecretSource(ms)
|
||||||
|
mappedSource.UpdateMappings(settings.UpstreamAPIKeys)
|
||||||
|
m.secretSource = mappedSource
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
proxy, err := createReverseProxy(upstreamURL, m.secretSource)
|
||||||
@@ -313,6 +334,66 @@ func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) b
|
|||||||
return oldKey != newKey
|
return oldKey != newKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hasUpstreamAPIKeysChanged compares old and new per-client upstream API key mappings.
|
||||||
|
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
|
||||||
|
if old == nil {
|
||||||
|
return len(new.UpstreamAPIKeys) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build map for comparison: upstreamKey -> set of clientKeys
|
||||||
|
type entryInfo struct {
|
||||||
|
upstreamKey string
|
||||||
|
clientKeys map[string]struct{}
|
||||||
|
}
|
||||||
|
oldEntries := make([]entryInfo, len(old.UpstreamAPIKeys))
|
||||||
|
for i, entry := range old.UpstreamAPIKeys {
|
||||||
|
clientKeys := make(map[string]struct{}, len(entry.APIKeys))
|
||||||
|
for _, k := range entry.APIKeys {
|
||||||
|
trimmed := strings.TrimSpace(k)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
clientKeys[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
oldEntries[i] = entryInfo{
|
||||||
|
upstreamKey: strings.TrimSpace(entry.UpstreamAPIKey),
|
||||||
|
clientKeys: clientKeys,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, newEntry := range new.UpstreamAPIKeys {
|
||||||
|
if i >= len(oldEntries) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
oldE := oldEntries[i]
|
||||||
|
if strings.TrimSpace(newEntry.UpstreamAPIKey) != oldE.upstreamKey {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newKeys := make(map[string]struct{}, len(newEntry.APIKeys))
|
||||||
|
for _, k := range newEntry.APIKeys {
|
||||||
|
trimmed := strings.TrimSpace(k)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newKeys[trimmed] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(newKeys) != len(oldE.clientKeys) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for k := range newKeys {
|
||||||
|
if _, ok := oldE.clientKeys[k]; !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||||
return m.modelMapper
|
return m.modelMapper
|
||||||
|
|||||||
@@ -312,3 +312,41 @@ func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
|
||||||
|
m := &AmpModule{}
|
||||||
|
|
||||||
|
oldCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||||
|
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
|
||||||
|
m := &AmpModule{}
|
||||||
|
|
||||||
|
oldCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
newCfg := &config.AmpCode{
|
||||||
|
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
|
||||||
|
t.Fatal("expected no change when only whitespace/empty entries differ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,33 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func removeQueryValuesMatching(req *http.Request, key string, match string) {
|
||||||
|
if req == nil || req.URL == nil || match == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q := req.URL.Query()
|
||||||
|
values, ok := q[key]
|
||||||
|
if !ok || len(values) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
kept := make([]string, 0, len(values))
|
||||||
|
for _, v := range values {
|
||||||
|
if v == match {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(kept) == 0 {
|
||||||
|
q.Del(key)
|
||||||
|
} else {
|
||||||
|
q[key] = kept
|
||||||
|
}
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
// readCloser wraps a reader and forwards Close to a separate closer.
|
// readCloser wraps a reader and forwards Close to a separate closer.
|
||||||
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
// Used to restore peeked bytes while preserving upstream body Close behavior.
|
||||||
type readCloser struct {
|
type readCloser struct {
|
||||||
@@ -45,6 +72,14 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
// We will set our own Authorization using the configured upstream-api-key
|
// We will set our own Authorization using the configured upstream-api-key
|
||||||
req.Header.Del("Authorization")
|
req.Header.Del("Authorization")
|
||||||
req.Header.Del("X-Api-Key")
|
req.Header.Del("X-Api-Key")
|
||||||
|
req.Header.Del("X-Goog-Api-Key")
|
||||||
|
|
||||||
|
// Remove query-based credentials if they match the authenticated client API key.
|
||||||
|
// This prevents leaking client auth material to the Amp upstream while avoiding
|
||||||
|
// breaking unrelated upstream query parameters.
|
||||||
|
clientKey := getClientAPIKeyFromContext(req.Context())
|
||||||
|
removeQueryValuesMatching(req, "key", clientKey)
|
||||||
|
removeQueryValuesMatching(req, "auth_token", clientKey)
|
||||||
|
|
||||||
// Preserve correlation headers for debugging
|
// Preserve correlation headers for debugging
|
||||||
if req.Header.Get("X-Request-ID") == "" {
|
if req.Header.Get("X-Request-ID") == "" {
|
||||||
|
|||||||
@@ -3,11 +3,15 @@ package amp
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Helper: compress data with gzip
|
// Helper: compress data with gzip
|
||||||
@@ -306,6 +310,159 @@ func TestReverseProxy_EmptySecret(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
|
||||||
|
type captured struct {
|
||||||
|
headers http.Header
|
||||||
|
query string
|
||||||
|
}
|
||||||
|
got := make(chan captured, 1)
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`ok`))
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||||
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
|
||||||
|
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer client-key")
|
||||||
|
req.Header.Set("X-Api-Key", "client-key")
|
||||||
|
req.Header.Set("X-Goog-Api-Key", "client-key")
|
||||||
|
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
c := <-got
|
||||||
|
|
||||||
|
// These are client-provided credentials and must not reach the upstream.
|
||||||
|
if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
|
||||||
|
t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
|
||||||
|
if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
|
||||||
|
t.Fatalf("Authorization should be upstream-injected, got: %q", v)
|
||||||
|
}
|
||||||
|
if v := c.headers.Get("X-Api-Key"); v != "upstream" {
|
||||||
|
t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query-based credentials should be stripped only when they match the authenticated client key.
|
||||||
|
// Should keep unrelated values and parameters.
|
||||||
|
if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
|
||||||
|
t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
|
||||||
|
}
|
||||||
|
if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
|
||||||
|
t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
|
||||||
|
gotHeaders := make(chan http.Header, 1)
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotHeaders <- r.Header.Clone()
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`ok`))
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
mapped := NewMappedSecretSource(defaultSource)
|
||||||
|
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Simulate clientAPIKeyMiddleware injection (per-request)
|
||||||
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
|
||||||
|
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
res, err := http.Get(srv.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
hdr := <-gotHeaders
|
||||||
|
if hdr.Get("X-Api-Key") != "u1" {
|
||||||
|
t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||||
|
}
|
||||||
|
if hdr.Get("Authorization") != "Bearer u1" {
|
||||||
|
t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
|
||||||
|
gotHeaders := make(chan http.Header, 1)
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotHeaders <- r.Header.Clone()
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`ok`))
|
||||||
|
}))
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
mapped := NewMappedSecretSource(defaultSource)
|
||||||
|
mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy, err := createReverseProxy(upstream.URL, mapped)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
|
||||||
|
proxy.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
res, err := http.Get(srv.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
|
||||||
|
hdr := <-gotHeaders
|
||||||
|
if hdr.Get("X-Api-Key") != "default" {
|
||||||
|
t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
|
||||||
|
}
|
||||||
|
if hdr.Get("Authorization") != "Bearer default" {
|
||||||
|
t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||||
// Point proxy to a non-routable address to trigger error
|
// Point proxy to a non-routable address to trigger error
|
||||||
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package amp
|
package amp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -16,6 +17,37 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// clientAPIKeyContextKey is the context key used to pass the client API key
|
||||||
|
// from gin.Context to the request context for SecretSource lookup.
|
||||||
|
type clientAPIKeyContextKey struct{}
|
||||||
|
|
||||||
|
// clientAPIKeyMiddleware injects the authenticated client API key from gin.Context["apiKey"]
|
||||||
|
// into the request context so that SecretSource can look it up for per-client upstream routing.
|
||||||
|
func clientAPIKeyMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Extract the client API key from gin context (set by AuthMiddleware)
|
||||||
|
if apiKey, exists := c.Get("apiKey"); exists {
|
||||||
|
if keyStr, ok := apiKey.(string); ok && keyStr != "" {
|
||||||
|
// Inject into request context for SecretSource.Get(ctx) to read
|
||||||
|
ctx := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientAPIKeyFromContext retrieves the client API key from request context.
|
||||||
|
// Returns empty string if not present.
|
||||||
|
func getClientAPIKeyFromContext(ctx context.Context) string {
|
||||||
|
if val := ctx.Value(clientAPIKeyContextKey{}); val != nil {
|
||||||
|
if keyStr, ok := val.(string); ok {
|
||||||
|
return keyStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
|
||||||
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
|
||||||
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
|
||||||
@@ -129,6 +161,9 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Inject client API key into request context for per-client upstream routing
|
||||||
|
ampAPI.Use(clientAPIKeyMiddleware())
|
||||||
|
|
||||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||||
proxyHandler := func(c *gin.Context) {
|
proxyHandler := func(c *gin.Context) {
|
||||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
||||||
@@ -175,6 +210,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
if authWithBypass != nil {
|
if authWithBypass != nil {
|
||||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||||
}
|
}
|
||||||
|
// Add clientAPIKeyMiddleware after auth for per-client upstream routing
|
||||||
|
rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
|
||||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||||
@@ -244,6 +281,8 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
if auth != nil {
|
if auth != nil {
|
||||||
ampProviders.Use(auth)
|
ampProviders.Use(auth)
|
||||||
}
|
}
|
||||||
|
// Inject client API key into request context for per-client upstream routing
|
||||||
|
ampProviders.Use(clientAPIKeyMiddleware())
|
||||||
|
|
||||||
provider := ampProviders.Group("/:provider")
|
provider := ampProviders.Group("/:provider")
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecretSource provides Amp API keys with configurable precedence and caching
|
// SecretSource provides Amp API keys with configurable precedence and caching
|
||||||
@@ -164,3 +167,82 @@ func NewStaticSecretSource(key string) *StaticSecretSource {
|
|||||||
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
func (s *StaticSecretSource) Get(ctx context.Context) (string, error) {
|
||||||
return s.key, nil
|
return s.key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MappedSecretSource wraps a default SecretSource and adds per-client API key mapping.
|
||||||
|
// When a request context contains a client API key that matches a configured mapping,
|
||||||
|
// the corresponding upstream key is returned. Otherwise, falls back to the default source.
|
||||||
|
type MappedSecretSource struct {
|
||||||
|
defaultSource SecretSource
|
||||||
|
mu sync.RWMutex
|
||||||
|
lookup map[string]string // clientKey -> upstreamKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMappedSecretSource creates a MappedSecretSource wrapping the given default source.
|
||||||
|
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
|
||||||
|
return &MappedSecretSource{
|
||||||
|
defaultSource: defaultSource,
|
||||||
|
lookup: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves the Amp API key, checking per-client mappings first.
|
||||||
|
// If the request context contains a client API key that matches a configured mapping,
|
||||||
|
// returns the corresponding upstream key. Otherwise, falls back to the default source.
|
||||||
|
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
|
||||||
|
// Try to get client API key from request context
|
||||||
|
clientKey := getClientAPIKeyFromContext(ctx)
|
||||||
|
if clientKey != "" {
|
||||||
|
s.mu.RLock()
|
||||||
|
if upstreamKey, ok := s.lookup[clientKey]; ok && upstreamKey != "" {
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return upstreamKey, nil
|
||||||
|
}
|
||||||
|
s.mu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to default source
|
||||||
|
return s.defaultSource.Get(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMappings rebuilds the client-to-upstream key mapping from configuration entries.
|
||||||
|
// If the same client key appears in multiple entries, logs a warning and uses the first one.
|
||||||
|
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
|
||||||
|
newLookup := make(map[string]string)
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
upstreamKey := strings.TrimSpace(entry.UpstreamAPIKey)
|
||||||
|
if upstreamKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, clientKey := range entry.APIKeys {
|
||||||
|
trimmedKey := strings.TrimSpace(clientKey)
|
||||||
|
if trimmedKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := newLookup[trimmedKey]; exists {
|
||||||
|
// Log warning for duplicate client key, first one wins
|
||||||
|
log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newLookup[trimmedKey] = upstreamKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.lookup = newLookup
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDefaultExplicitKey updates the explicit key on the underlying MultiSourceSecret (if applicable).
|
||||||
|
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
|
||||||
|
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||||
|
ms.UpdateExplicitKey(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateCache invalidates cache on the underlying MultiSourceSecret (if applicable).
|
||||||
|
func (s *MappedSecretSource) InvalidateCache() {
|
||||||
|
if ms, ok := s.defaultSource.(*MultiSourceSecret); ok {
|
||||||
|
ms.InvalidateCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/sirupsen/logrus/hooks/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
|
||||||
@@ -278,3 +282,85 @@ func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
|
|||||||
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
t.Fatalf("after cache expiry, expected new-value, got %q", got3)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
s := NewMappedSecretSource(defaultSource)
|
||||||
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||||
|
got, err := s.Get(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "u1" {
|
||||||
|
t.Fatalf("want u1, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k2")
|
||||||
|
got, err = s.Get(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "default" {
|
||||||
|
t.Fatalf("want default fallback, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
s := NewMappedSecretSource(defaultSource)
|
||||||
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u2",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1")
|
||||||
|
got, err := s.Get(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "u1" {
|
||||||
|
t.Fatalf("want u1 (first wins), got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
|
||||||
|
hook := test.NewLocal(log.StandardLogger())
|
||||||
|
defer hook.Reset()
|
||||||
|
|
||||||
|
defaultSource := NewStaticSecretSource("default")
|
||||||
|
s := NewMappedSecretSource(defaultSource)
|
||||||
|
s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u1",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UpstreamAPIKey: "u2",
|
||||||
|
APIKeys: []string{"k1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
foundWarning := false
|
||||||
|
for _, entry := range hook.AllEntries() {
|
||||||
|
if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
|
||||||
|
foundWarning = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundWarning {
|
||||||
|
t.Fatal("expected warning log for duplicate client key, but none was found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -551,6 +551,10 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
|
||||||
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||||
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
|
||||||
|
mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
|
||||||
|
|
||||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||||
|
|||||||
@@ -163,6 +163,11 @@ type AmpCode struct {
|
|||||||
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
|
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
|
||||||
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||||
|
|
||||||
|
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
|
||||||
|
// When a client authenticates with a key that matches an entry, that upstream key is used.
|
||||||
|
// If no match is found, falls back to UpstreamAPIKey (default behavior).
|
||||||
|
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
|
||||||
|
|
||||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||||
@@ -178,6 +183,17 @@ type AmpCode struct {
|
|||||||
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
|
||||||
|
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
|
||||||
|
// is used for the upstream Amp request.
|
||||||
|
type AmpUpstreamAPIKeyEntry struct {
|
||||||
|
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
|
||||||
|
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
|
||||||
|
|
||||||
|
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
|
||||||
|
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||||
|
}
|
||||||
|
|
||||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||||
type PayloadConfig struct {
|
type PayloadConfig struct {
|
||||||
// Default defines rules that only set parameters when they are missing in the payload.
|
// Default defines rules that only set parameters when they are missing in the payload.
|
||||||
|
|||||||
@@ -614,71 +614,6 @@ func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
|
|
||||||
// propertyNames is used to validate object property names (e.g., must match a pattern)
|
|
||||||
// Gemini doesn't support this keyword and will reject requests containing it
|
|
||||||
input := `{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"metadata": {
|
|
||||||
"type": "object",
|
|
||||||
"propertyNames": {
|
|
||||||
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
|
|
||||||
},
|
|
||||||
"additionalProperties": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
expected := `{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"metadata": {
|
|
||||||
"type": "object"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
result := CleanJSONSchemaForGemini(input)
|
|
||||||
compareJSON(t, expected, result)
|
|
||||||
|
|
||||||
// Verify propertyNames is completely removed
|
|
||||||
if strings.Contains(result, "propertyNames") {
|
|
||||||
t.Errorf("propertyNames keyword should be removed, got: %s", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
|
|
||||||
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
|
|
||||||
input := `{
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"items": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"config": {
|
|
||||||
"type": "object",
|
|
||||||
"propertyNames": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
result := CleanJSONSchemaForGemini(input)
|
|
||||||
|
|
||||||
if strings.Contains(result, "propertyNames") {
|
|
||||||
t.Errorf("Nested propertyNames should be removed, got: %s", result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
||||||
var expMap, actMap map[string]interface{}
|
var expMap, actMap map[string]interface{}
|
||||||
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
||||||
|
|||||||
@@ -185,6 +185,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
||||||
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
||||||
}
|
}
|
||||||
|
oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
|
||||||
|
newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
|
||||||
|
if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
|
||||||
|
changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
|
||||||
|
}
|
||||||
|
|
||||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||||
changes = append(changes, entries...)
|
changes = append(changes, entries...)
|
||||||
@@ -301,3 +306,43 @@ func formatProxyURL(raw string) string {
|
|||||||
}
|
}
|
||||||
return scheme + "://" + host
|
return scheme + "://" + host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func equalStringSet(a, b []string) bool {
|
||||||
|
if len(a) == 0 && len(b) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
aSet := make(map[string]struct{}, len(a))
|
||||||
|
for _, k := range a {
|
||||||
|
aSet[strings.TrimSpace(k)] = struct{}{}
|
||||||
|
}
|
||||||
|
bSet := make(map[string]struct{}, len(b))
|
||||||
|
for _, k := range b {
|
||||||
|
bSet[strings.TrimSpace(k)] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(aSet) != len(bSet) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k := range aSet {
|
||||||
|
if _, ok := bSet[k]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// equalUpstreamAPIKeys compares two slices of AmpUpstreamAPIKeyEntry for equality.
|
||||||
|
// Comparison is done by count and content (upstream key and client keys).
|
||||||
|
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if strings.TrimSpace(a[i].UpstreamAPIKey) != strings.TrimSpace(b[i].UpstreamAPIKey) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !equalStringSet(a[i].APIKeys, b[i].APIKeys) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -56,6 +56,10 @@ func setupAmpRouter(h *management.Handler) *gin.Engine {
|
|||||||
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
|
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
|
||||||
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
|
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
|
||||||
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
|
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
|
||||||
|
mgmt.GET("/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PUT("/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys)
|
||||||
|
mgmt.PATCH("/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys)
|
||||||
|
mgmt.DELETE("/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys)
|
||||||
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
|
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
|
||||||
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
|
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
|
||||||
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
|
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
|
||||||
@@ -188,6 +192,90 @@ func TestPutAmpUpstreamAPIKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) {
|
||||||
|
h, configPath := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
body := `{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it was persisted to disk
|
||||||
|
loaded, err := config.LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load config from disk: %v", err)
|
||||||
|
}
|
||||||
|
if len(loaded.AmpCode.UpstreamAPIKeys) != 1 {
|
||||||
|
t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(loaded.AmpCode.UpstreamAPIKeys))
|
||||||
|
}
|
||||||
|
entry := loaded.AmpCode.UpstreamAPIKeys[0]
|
||||||
|
if entry.UpstreamAPIKey != "u1" {
|
||||||
|
t.Fatalf("expected upstream-api-key u1, got %q", entry.UpstreamAPIKey)
|
||||||
|
}
|
||||||
|
if len(entry.APIKeys) != 2 || entry.APIKeys[0] != "k1" || entry.APIKeys[1] != "k2" {
|
||||||
|
t.Fatalf("expected api-keys [k1 k2], got %#v", entry.APIKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it is returned by GET /ampcode
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
var resp map[string]config.AmpCode
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if got := resp["ampcode"].UpstreamAPIKeys; len(got) != 1 || got[0].UpstreamAPIKey != "u1" {
|
||||||
|
t.Fatalf("expected upstream-api-keys to be present after update, got %#v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) {
|
||||||
|
h, _ := newAmpTestHandler(t)
|
||||||
|
r := setupAmpRouter(h)
|
||||||
|
|
||||||
|
// Seed with one entry
|
||||||
|
putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
deleteBody := `{"value":[]}`
|
||||||
|
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil)
|
||||||
|
w = httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
var resp map[string][]config.AmpUpstreamAPIKeyEntry
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
if resp["upstream-api-keys"] != nil && len(resp["upstream-api-keys"]) != 0 {
|
||||||
|
t.Fatalf("expected cleared list, got %#v", resp["upstream-api-keys"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
|
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
|
||||||
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
|
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
|
||||||
h, _ := newAmpTestHandler(t)
|
h, _ := newAmpTestHandler(t)
|
||||||
|
|||||||
Reference in New Issue
Block a user