Files
CLIProxyAPI/internal/managementasset/updater.go

469 lines
13 KiB
Go

package managementasset
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
)
const (
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
managementAssetName = "management.html"
httpUserAgent = "CLIProxyAPI-management-updater"
updateCheckInterval = 3 * time.Hour
)
// ManagementFileName exposes the control panel asset filename.
const ManagementFileName = managementAssetName
var (
lastUpdateCheckMu sync.Mutex
lastUpdateCheckTime time.Time
currentConfigPtr atomic.Pointer[config.Config]
disableControlPanel atomic.Bool
schedulerOnce sync.Once
schedulerConfigPath atomic.Value
)
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
func SetCurrentConfig(cfg *config.Config) {
if cfg == nil {
currentConfigPtr.Store(nil)
return
}
prevDisabled := disableControlPanel.Load()
currentConfigPtr.Store(cfg)
disableControlPanel.Store(cfg.RemoteManagement.DisableControlPanel)
if prevDisabled && !cfg.RemoteManagement.DisableControlPanel {
lastUpdateCheckMu.Lock()
lastUpdateCheckTime = time.Time{}
lastUpdateCheckMu.Unlock()
}
}
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
// It respects the disable-control-panel flag on every iteration and supports hot-reloaded configurations.
func StartAutoUpdater(ctx context.Context, configFilePath string) {
configFilePath = strings.TrimSpace(configFilePath)
if configFilePath == "" {
log.Debug("management asset auto-updater skipped: empty config path")
return
}
schedulerConfigPath.Store(configFilePath)
schedulerOnce.Do(func() {
go runAutoUpdater(ctx)
})
}
func runAutoUpdater(ctx context.Context) {
if ctx == nil {
ctx = context.Background()
}
ticker := time.NewTicker(updateCheckInterval)
defer ticker.Stop()
runOnce := func() {
cfg := currentConfigPtr.Load()
if cfg == nil {
log.Debug("management asset auto-updater skipped: config not yet available")
return
}
if disableControlPanel.Load() {
log.Debug("management asset auto-updater skipped: control panel disabled")
return
}
configPath, _ := schedulerConfigPath.Load().(string)
staticDir := StaticDir(configPath)
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
}
runOnce()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
runOnce()
}
}
}
func newHTTPClient(proxyURL string) *http.Client {
client := &http.Client{Timeout: 15 * time.Second}
sdkCfg := &sdkconfig.SDKConfig{ProxyURL: strings.TrimSpace(proxyURL)}
util.SetProxy(sdkCfg, client)
return client
}
type releaseAsset struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
Digest string `json:"digest"`
}
type releaseResponse struct {
Assets []releaseAsset `json:"assets"`
}
// StaticDir resolves the directory that stores the management control panel asset.
func StaticDir(configFilePath string) string {
if override := strings.TrimSpace(os.Getenv("MANAGEMENT_STATIC_PATH")); override != "" {
cleaned := filepath.Clean(override)
if strings.EqualFold(filepath.Base(cleaned), managementAssetName) {
return filepath.Dir(cleaned)
}
return cleaned
}
if writable := util.WritablePath(); writable != "" {
return filepath.Join(writable, "static")
}
configFilePath = strings.TrimSpace(configFilePath)
if configFilePath == "" {
return ""
}
base := filepath.Dir(configFilePath)
fileInfo, err := os.Stat(configFilePath)
if err == nil {
if fileInfo.IsDir() {
base = configFilePath
}
}
return filepath.Join(base, "static")
}
// FilePath resolves the absolute path to the management control panel asset.
func FilePath(configFilePath string) string {
if override := strings.TrimSpace(os.Getenv("MANAGEMENT_STATIC_PATH")); override != "" {
cleaned := filepath.Clean(override)
if strings.EqualFold(filepath.Base(cleaned), managementAssetName) {
return cleaned
}
return filepath.Join(cleaned, ManagementFileName)
}
dir := StaticDir(configFilePath)
if dir == "" {
return ""
}
return filepath.Join(dir, ManagementFileName)
}
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
// The function is designed to run in a background goroutine and will never panic.
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
if ctx == nil {
ctx = context.Background()
}
if disableControlPanel.Load() {
log.Debug("management asset sync skipped: control panel disabled by configuration")
return
}
staticDir = strings.TrimSpace(staticDir)
if staticDir == "" {
log.Debug("management asset sync skipped: empty static directory")
return
}
localPath := filepath.Join(staticDir, managementAssetName)
localFileMissing := false
if _, errStat := os.Stat(localPath); errStat != nil {
if errors.Is(errStat, os.ErrNotExist) {
localFileMissing = true
} else {
log.WithError(errStat).Debug("failed to stat local management asset")
}
}
// Rate limiting: check only once every 3 hours
lastUpdateCheckMu.Lock()
now := time.Now()
timeSinceLastCheck := now.Sub(lastUpdateCheckTime)
if timeSinceLastCheck < updateCheckInterval {
lastUpdateCheckMu.Unlock()
log.Debugf("management asset update check skipped: last check was %v ago (interval: %v)", timeSinceLastCheck.Round(time.Second), updateCheckInterval)
return
}
lastUpdateCheckTime = now
lastUpdateCheckMu.Unlock()
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
return
}
releaseURL := resolveReleaseURL(panelRepository)
client := newHTTPClient(proxyURL)
localHash, err := fileSHA256(localPath)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.WithError(err).Debug("failed to read local management asset hash")
}
localHash = ""
}
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to fetch latest management release information")
return
}
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
log.Debug("management asset is already up to date")
return
}
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
if err != nil {
if localFileMissing {
log.WithError(err).Warn("failed to download management asset, trying fallback page")
if ensureFallbackManagementHTML(ctx, client, localPath) {
return
}
return
}
log.WithError(err).Warn("failed to download management asset")
return
}
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to update management asset on disk")
return
}
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
}
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
data, downloadedHash, err := downloadAsset(ctx, client, defaultManagementFallbackURL)
if err != nil {
log.WithError(err).Warn("failed to download fallback management control panel page")
return false
}
if err = atomicWriteFile(localPath, data); err != nil {
log.WithError(err).Warn("failed to persist fallback management control panel page")
return false
}
log.Infof("management asset updated from fallback page successfully (hash=%s)", downloadedHash)
return true
}
func resolveReleaseURL(repo string) string {
repo = strings.TrimSpace(repo)
if repo == "" {
return defaultManagementReleaseURL
}
parsed, err := url.Parse(repo)
if err != nil || parsed.Host == "" {
return defaultManagementReleaseURL
}
host := strings.ToLower(parsed.Host)
parsed.Path = strings.TrimSuffix(parsed.Path, "/")
if host == "api.github.com" {
if !strings.HasSuffix(strings.ToLower(parsed.Path), "/releases/latest") {
parsed.Path = parsed.Path + "/releases/latest"
}
return parsed.String()
}
if host == "github.com" {
parts := strings.Split(strings.Trim(parsed.Path, "/"), "/")
if len(parts) >= 2 && parts[0] != "" && parts[1] != "" {
repoName := strings.TrimSuffix(parts[1], ".git")
return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", parts[0], repoName)
}
}
return defaultManagementReleaseURL
}
func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL string) (*releaseAsset, string, error) {
if strings.TrimSpace(releaseURL) == "" {
releaseURL = defaultManagementReleaseURL
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
if err != nil {
return nil, "", fmt.Errorf("create release request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("User-Agent", httpUserAgent)
gitURL := strings.ToLower(strings.TrimSpace(os.Getenv("GITSTORE_GIT_URL")))
if tok := strings.TrimSpace(os.Getenv("GITSTORE_GIT_TOKEN")); tok != "" && strings.Contains(gitURL, "github.com") {
req.Header.Set("Authorization", "Bearer "+tok)
}
resp, err := client.Do(req)
if err != nil {
return nil, "", fmt.Errorf("execute release request: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, "", fmt.Errorf("unexpected release status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var release releaseResponse
if err = json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, "", fmt.Errorf("decode release response: %w", err)
}
for i := range release.Assets {
asset := &release.Assets[i]
if strings.EqualFold(asset.Name, managementAssetName) {
remoteHash := parseDigest(asset.Digest)
return asset, remoteHash, nil
}
}
return nil, "", fmt.Errorf("management asset %s not found in latest release", managementAssetName)
}
func downloadAsset(ctx context.Context, client *http.Client, downloadURL string) ([]byte, string, error) {
if strings.TrimSpace(downloadURL) == "" {
return nil, "", fmt.Errorf("empty download url")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
if err != nil {
return nil, "", fmt.Errorf("create download request: %w", err)
}
req.Header.Set("User-Agent", httpUserAgent)
resp, err := client.Do(req)
if err != nil {
return nil, "", fmt.Errorf("execute download request: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return nil, "", fmt.Errorf("unexpected download status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", fmt.Errorf("read download body: %w", err)
}
sum := sha256.Sum256(data)
return data, hex.EncodeToString(sum[:]), nil
}
func fileSHA256(path string) (string, error) {
file, err := os.Open(path)
if err != nil {
return "", err
}
defer func() {
_ = file.Close()
}()
h := sha256.New()
if _, err = io.Copy(h, file); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
func atomicWriteFile(path string, data []byte) error {
tmpFile, err := os.CreateTemp(filepath.Dir(path), "management-*.html")
if err != nil {
return err
}
tmpName := tmpFile.Name()
defer func() {
_ = tmpFile.Close()
_ = os.Remove(tmpName)
}()
if _, err = tmpFile.Write(data); err != nil {
return err
}
if err = tmpFile.Chmod(0o644); err != nil {
return err
}
if err = tmpFile.Close(); err != nil {
return err
}
if err = os.Rename(tmpName, path); err != nil {
return err
}
return nil
}
func parseDigest(digest string) string {
digest = strings.TrimSpace(digest)
if digest == "" {
return ""
}
if idx := strings.Index(digest, ":"); idx >= 0 {
digest = digest[idx+1:]
}
return strings.ToLower(strings.TrimSpace(digest))
}