mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 12:20:52 +08:00
feat: Add support for iFlow provider
This commit is contained in:
143
internal/auth/iflow/oauth_server.go
Normal file
143
internal/auth/iflow/oauth_server.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package iflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const errorRedirectURL = "https://iflow.cn/oauth/error"
|
||||
|
||||
// OAuthResult captures the outcome of the local OAuth callback.
|
||||
type OAuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// OAuthServer provides a minimal HTTP server for handling the iFlow OAuth callback.
|
||||
type OAuthServer struct {
|
||||
server *http.Server
|
||||
port int
|
||||
result chan *OAuthResult
|
||||
errChan chan error
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// NewOAuthServer constructs a new OAuthServer bound to the provided port.
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
result: make(chan *OAuthResult, 1),
|
||||
errChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the callback listener.
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.running {
|
||||
return fmt.Errorf("iflow oauth server already running")
|
||||
}
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth2callback", s.handleCallback)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
s.running = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully terminates the callback listener.
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
s.running = false
|
||||
s.server = nil
|
||||
}()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// WaitForCallback blocks until a callback result, server error, or timeout occurs.
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case res := <-s.result:
|
||||
return res, nil
|
||||
case err := <-s.errChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
query := r.URL.Query()
|
||||
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||
s.sendResult(&OAuthResult{Error: errParam})
|
||||
http.Redirect(w, r, errorRedirectURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
if code == "" {
|
||||
s.sendResult(&OAuthResult{Error: "missing_code"})
|
||||
http.Redirect(w, r, errorRedirectURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
state := query.Get("state")
|
||||
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||
http.Redirect(w, r, SuccessRedirectURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) sendResult(res *OAuthResult) {
|
||||
select {
|
||||
case s.result <- res:
|
||||
default:
|
||||
log.Debug("iflow oauth result channel full, dropping result")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
addr := fmt.Sprintf(":%d", s.port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = listener.Close()
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user