mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
refactor: improve thinking logic
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -379,7 +380,7 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
@@ -388,16 +389,13 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
req.Metadata = cloned
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
opts.Metadata = reqMeta
|
||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -420,7 +418,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
// ExecuteCountWithAuthManager executes a non-streaming count request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
@@ -429,16 +427,13 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
req.Metadata = cloned
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
opts.Metadata = reqMeta
|
||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -461,7 +456,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
errChan <- errMsg
|
||||
@@ -473,16 +468,13 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
req.Metadata = cloned
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: true,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
opts.Metadata = reqMeta
|
||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
@@ -595,38 +587,40 @@ func statusFromError(err error) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
// Normalize the model name to handle dynamic thinking suffixes before determining the provider.
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
// Use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
}
|
||||
}
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
|
||||
resolvedModelName := modelName
|
||||
initialSuffix := thinking.ParseSuffix(modelName)
|
||||
if initialSuffix.ModelName == "auto" {
|
||||
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||
if initialSuffix.HasSuffix {
|
||||
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||
} else {
|
||||
resolvedModelName = resolvedBase
|
||||
}
|
||||
} else {
|
||||
resolvedModelName = util.ResolveAutoModel(modelName)
|
||||
}
|
||||
|
||||
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||
|
||||
providers = util.GetProviderName(baseModel)
|
||||
// Fallback: if baseModel has no provider but differs from resolvedModelName,
|
||||
// try using the full model name. This handles edge cases where custom models
|
||||
// may be registered with their full suffixed name (e.g., "my-model(8192)").
|
||||
// Evaluated in Story 11.8: This fallback is intentionally preserved to support
|
||||
// custom model registrations that include thinking suffixes.
|
||||
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||
providers = util.GetProviderName(resolvedModelName)
|
||||
}
|
||||
|
||||
if len(providers) == 0 {
|
||||
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
return nil, "", &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
}
|
||||
|
||||
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
|
||||
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
|
||||
// So, normalizedModel is already correctly set at this point.
|
||||
|
||||
return providers, normalizedModel, metadata, nil
|
||||
// The thinking suffix is preserved in the model name itself, so no
|
||||
// metadata-based configuration passing is needed.
|
||||
return providers, resolvedModelName, nil
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
@@ -638,10 +632,6 @@ func cloneBytes(src []byte) []byte {
|
||||
return dst
|
||||
}
|
||||
|
||||
func normalizeModelMetadata(modelName string) (string, map[string]any) {
|
||||
return util.NormalizeThinkingModel(modelName)
|
||||
}
|
||||
|
||||
func cloneMetadata(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user