feat(registry): unify Gemini models and add AI Studio set

This commit is contained in:
hkfires
2025-10-28 19:00:25 +08:00
parent 5891785125
commit 5dced4c0a6
5 changed files with 106 additions and 164 deletions

View File

@@ -19,27 +19,27 @@ import (
"github.com/tidwall/sjson"
)
// AistudioExecutor routes AI Studio requests through a websocket-backed transport.
type AistudioExecutor struct {
// AIStudioExecutor routes AI Studio requests through a websocket-backed transport.
type AIStudioExecutor struct {
provider string
relay *wsrelay.Manager
cfg *config.Config
}
// NewAistudioExecutor constructs a websocket executor for the provider name.
func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor {
return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
// NewAIStudioExecutor constructs a websocket executor for the provider name.
func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor {
return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
}
// Identifier returns the provider key served by this executor.
func (e *AistudioExecutor) Identifier() string { return e.provider }
// Identifier returns the logical provider key for routing.
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
// PrepareRequest is a no-op because websocket transport already injects headers.
func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
return nil
}
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -66,7 +66,7 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
@@ -92,7 +92,7 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
return resp, nil
}
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -118,7 +118,7 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
@@ -151,7 +151,7 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
filtered := filterAistudioUsageMetadata(event.Payload)
filtered := filterAIStudioUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail)
}
@@ -188,7 +188,7 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return stream, nil
}
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_, body, err := e.translateRequest(req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
@@ -215,7 +215,7 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
@@ -241,7 +241,7 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
func (e *AIStudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
_ = ctx
return auth, nil
}
@@ -252,7 +252,7 @@ type translatedPayload struct {
toFormat sdktranslator.Format
}
func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
@@ -275,7 +275,7 @@ func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil
}
func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
func (e *AIStudioExecutor) buildEndpoint(model, action, alt string) string {
base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
if action == "streamGenerateContent" {
if alt == "" {
@@ -289,9 +289,9 @@ func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
return base
}
// filterAistudioUsageMetadata removes usageMetadata from intermediate SSE events so that
// filterAIStudioUsageMetadata removes usageMetadata from intermediate SSE events so that
// only the terminal chunk retains token statistics.
func filterAistudioUsageMetadata(payload []byte) []byte {
func filterAIStudioUsageMetadata(payload []byte) []byte {
if len(payload) == 0 {
return payload
}