feat(management): add generic API call handler to management endpoints

This commit is contained in:
Luis Pater
2025-12-28 04:40:32 +08:00
parent 3a436e116a
commit ade7194792
2 changed files with 540 additions and 0 deletions

View File

@@ -0,0 +1,538 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const defaultAPICallTimeout = 60 * time.Second
const (
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOAuthClientSecret = "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"
)
var geminiOAuthScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
type apiCallRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
Method string `json:"method"`
URL string `json:"url"`
Header map[string]string `json:"header"`
Data string `json:"data"`
}
type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
// POST /v0/management/api-call
//
// Authentication:
//
// Same as other management APIs (requires a management key and remote-management rules).
// You can provide the key via:
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON:
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
// - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
// - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
// - header (optional): Request headers map.
// Supports magic variable "$TOKEN$" which is replaced using the selected credential:
// 1) metadata.access_token
// 2) attributes.api_key
// 3) metadata.token / metadata.id_token / metadata.cookie
// Example: {"Authorization":"Bearer $TOKEN$"}.
// Note: if you need to override the HTTP Host header, set header["Host"].
// - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first):
// 1. Selected credential proxy_url
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
//
// Example:
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
// curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
// -H "Authorization: Bearer 831227" \
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
if method == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
return
}
urlStr := strings.TrimSpace(body.URL)
if urlStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
return
}
parsedURL, errParseURL := url.Parse(urlStr)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
auth := h.authByIndex(authIndex)
reqHeaders := body.Header
if reqHeaders == nil {
reqHeaders = map[string]string{}
}
var hostOverride string
var token string
var tokenResolved bool
var tokenErr error
for key, value := range reqHeaders {
if !strings.Contains(value, "$TOKEN$") {
continue
}
if !tokenResolved {
token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
tokenResolved = true
}
if auth != nil && token == "" {
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
return
}
if token == "" {
continue
}
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
if errNewRequest != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
return
}
for key, value := range reqHeaders {
if strings.EqualFold(key, "host") {
hostOverride = strings.TrimSpace(value)
continue
}
req.Header.Set(key, value)
}
if hostOverride != "" {
req.Host = hostOverride
}
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
}
httpClient.Transport = h.apiCallTransport(auth)
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("management APICall request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
c.JSON(http.StatusOK, apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
})
}
func firstNonEmptyString(values ...*string) string {
for _, v := range values {
if v == nil {
continue
}
if out := strings.TrimSpace(*v); out != "" {
return out
}
}
return ""
}
func tokenValueForAuth(auth *coreauth.Auth) string {
if auth == nil {
return ""
}
if v := tokenValueFromMetadata(auth.Metadata); v != "" {
return v
}
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
return v
}
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
if v := tokenValueFromMetadata(shared.MetadataSnapshot()); v != "" {
return v
}
}
return ""
}
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "gemini-cli" {
token, errToken := h.refreshGeminiOAuthAccessToken(ctx, auth)
return token, errToken
}
return tokenValueForAuth(auth), nil
}
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
if ctx == nil {
ctx = context.Background()
}
if auth == nil {
return "", nil
}
metadata, updater := geminiOAuthMetadata(auth)
if len(metadata) == 0 {
return "", fmt.Errorf("gemini oauth metadata missing")
}
base := make(map[string]any)
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
base = cloneMap(tokenRaw)
}
var token oauth2.Token
if len(base) > 0 {
if raw, errMarshal := json.Marshal(base); errMarshal == nil {
_ = json.Unmarshal(raw, &token)
}
}
if token.AccessToken == "" {
token.AccessToken = stringValue(metadata, "access_token")
}
if token.RefreshToken == "" {
token.RefreshToken = stringValue(metadata, "refresh_token")
}
if token.TokenType == "" {
token.TokenType = stringValue(metadata, "token_type")
}
if token.Expiry.IsZero() {
if expiry := stringValue(metadata, "expiry"); expiry != "" {
if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
token.Expiry = ts
}
}
}
conf := &oauth2.Config{
ClientID: geminiOAuthClientID,
ClientSecret: geminiOAuthClientSecret,
Scopes: geminiOAuthScopes,
Endpoint: google.Endpoint,
}
ctxToken := ctx
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
ctxToken = context.WithValue(ctxToken, oauth2.HTTPClient, httpClient)
src := conf.TokenSource(ctxToken, &token)
currentToken, errToken := src.Token()
if errToken != nil {
return "", errToken
}
merged := buildOAuthTokenMap(base, currentToken)
fields := buildOAuthTokenFields(currentToken, merged)
if updater != nil {
updater(fields)
}
return strings.TrimSpace(currentToken.AccessToken), nil
}
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
if auth == nil {
return nil, nil
}
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
snapshot := shared.MetadataSnapshot()
return snapshot, func(fields map[string]any) { shared.MergeMetadata(fields) }
}
return auth.Metadata, func(fields map[string]any) {
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
for k, v := range fields {
auth.Metadata[k] = v
}
}
}
func stringValue(metadata map[string]any, key string) string {
if len(metadata) == 0 || key == "" {
return ""
}
if v, ok := metadata[key].(string); ok {
return strings.TrimSpace(v)
}
return ""
}
func cloneMap(in map[string]any) map[string]any {
if len(in) == 0 {
return nil
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
merged := cloneMap(base)
if merged == nil {
merged = make(map[string]any)
}
if tok == nil {
return merged
}
if raw, errMarshal := json.Marshal(tok); errMarshal == nil {
var tokenMap map[string]any
if errUnmarshal := json.Unmarshal(raw, &tokenMap); errUnmarshal == nil {
for k, v := range tokenMap {
merged[k] = v
}
}
}
return merged
}
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
fields := make(map[string]any, 5)
if tok != nil && tok.AccessToken != "" {
fields["access_token"] = tok.AccessToken
}
if tok != nil && tok.TokenType != "" {
fields["token_type"] = tok.TokenType
}
if tok != nil && tok.RefreshToken != "" {
fields["refresh_token"] = tok.RefreshToken
}
if tok != nil && !tok.Expiry.IsZero() {
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
}
if len(merged) > 0 {
fields["token"] = cloneMap(merged)
}
return fields
}
func tokenValueFromMetadata(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
if v, ok := metadata["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if tokenRaw, ok := metadata["token"]; ok && tokenRaw != nil {
switch typed := tokenRaw.(type) {
case string:
if v := strings.TrimSpace(typed); v != "" {
return v
}
case map[string]any:
if v, ok := typed["access_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := typed["accessToken"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
case map[string]string:
if v := strings.TrimSpace(typed["access_token"]); v != "" {
return v
}
if v := strings.TrimSpace(typed["accessToken"]); v != "" {
return v
}
}
}
if v, ok := metadata["token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["id_token"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
if v, ok := metadata["cookie"].(string); ok && strings.TrimSpace(v) != "" {
return strings.TrimSpace(v)
}
return ""
}
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
authIndex = strings.TrimSpace(authIndex)
if authIndex == "" || h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
for _, auth := range auths {
if auth == nil {
continue
}
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
return nil
}
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
var proxyCandidates []string
if auth != nil {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
if h != nil && h.cfg != nil {
if proxyStr := strings.TrimSpace(h.cfg.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
}
}
for _, proxyStr := range proxyCandidates {
if transport := buildProxyTransport(proxyStr); transport != nil {
return transport
}
}
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok || transport == nil {
return &http.Transport{Proxy: nil}
}
clone := transport.Clone()
clone.Proxy = nil
return clone
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}