Add token refresh handling for 401 responses across clients

- Implemented `RefreshTokens` method in client interfaces and Gemini clients.
- Updated handlers to call `RefreshTokens` on 401 responses and retry requests if token refresh succeeds.
- Enhanced error handling and retry logic to accommodate token refresh flow.
This commit is contained in:
Luis Pater
2025-08-30 16:10:56 +08:00
parent 1aad033fec
commit 512c8b600a
7 changed files with 46 additions and 0 deletions

View File

@@ -197,6 +197,14 @@ outLoop:
log.Debugf("http status code %d, switch client, %s", errInfo.StatusCode, util.HideAPIKey(cliClient.GetEmail())) log.Debugf("http status code %d, switch client, %s", errInfo.StatusCode, util.HideAPIKey(cliClient.GetEmail()))
retryCount++ retryCount++
continue outLoop continue outLoop
case 401:
log.Debugf("unauthorized request, try to refresh token, %s", util.HideAPIKey(cliClient.GetEmail()))
err := cliClient.RefreshTokens(cliCtx)
if err != nil {
log.Debugf("refresh token failed, switch client, %s", util.HideAPIKey(cliClient.GetEmail()))
}
retryCount++
continue outLoop
default: default:
// Forward other errors directly to the client // Forward other errors directly to the client
c.Status(errInfo.StatusCode) c.Status(errInfo.StatusCode)

View File

@@ -275,6 +275,14 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
log.Debugf("http status code %d, switch client", err.StatusCode) log.Debugf("http status code %d, switch client", err.StatusCode)
retryCount++ retryCount++
continue continue
case 401:
log.Debugf("unauthorized request, try to refresh token, %s", util.HideAPIKey(cliClient.GetEmail()))
errRefreshTokens := cliClient.RefreshTokens(cliCtx)
if errRefreshTokens != nil {
log.Debugf("refresh token failed, switch client, %s", util.HideAPIKey(cliClient.GetEmail()))
}
retryCount++
continue
default: default:
// Forward other errors directly to the client // Forward other errors directly to the client
c.Status(err.StatusCode) c.Status(err.StatusCode)

View File

@@ -17,6 +17,7 @@ import (
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -387,6 +388,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
log.Debugf("http status code %d, switch client", err.StatusCode) log.Debugf("http status code %d, switch client", err.StatusCode)
retryCount++ retryCount++
continue continue
case 401:
log.Debugf("unauthorized request, try to refresh token, %s", util.HideAPIKey(cliClient.GetEmail()))
errRefreshTokens := cliClient.RefreshTokens(cliCtx)
if errRefreshTokens != nil {
log.Debugf("refresh token failed, switch client, %s", util.HideAPIKey(cliClient.GetEmail()))
}
retryCount++
continue
default: default:
// Forward other errors directly to the client // Forward other errors directly to the client
c.Status(err.StatusCode) c.Status(err.StatusCode)

View File

@@ -18,6 +18,7 @@ import (
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -409,6 +410,14 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
log.Debugf("http status code %d, switch client", err.StatusCode) log.Debugf("http status code %d, switch client", err.StatusCode)
retryCount++ retryCount++
continue continue
case 401:
log.Debugf("unauthorized request, try to refresh token, %s", util.HideAPIKey(cliClient.GetEmail()))
errRefreshTokens := cliClient.RefreshTokens(cliCtx)
if errRefreshTokens != nil {
log.Debugf("refresh token failed, switch client, %s", util.HideAPIKey(cliClient.GetEmail()))
}
retryCount++
continue
default: default:
// Forward other errors directly to the client // Forward other errors directly to the client
c.Status(err.StatusCode) c.Status(err.StatusCode)

View File

@@ -860,3 +860,8 @@ func (c *GeminiCLIClient) GetUserAgent() string {
func (c *GeminiCLIClient) GetRequestMutex() *sync.Mutex { func (c *GeminiCLIClient) GetRequestMutex() *sync.Mutex {
return nil return nil
} }
func (c *GeminiCLIClient) RefreshTokens(ctx context.Context) error {
// API keys don't need refreshing
return nil
}

View File

@@ -434,3 +434,8 @@ func (c *GeminiClient) GetUserAgent() string {
func (c *GeminiClient) GetRequestMutex() *sync.Mutex { func (c *GeminiClient) GetRequestMutex() *sync.Mutex {
return nil return nil
} }
func (c *GeminiClient) RefreshTokens(ctx context.Context) error {
// API keys don't need refreshing
return nil
}

View File

@@ -51,4 +51,6 @@ type Client interface {
// Provider returns the name of the AI service provider (e.g., "gemini", "claude"). // Provider returns the name of the AI service provider (e.g., "gemini", "claude").
Provider() string Provider() string
RefreshTokens(ctx context.Context) error
} }