mirror of
https://github.com/LifeArchiveProject/WeChatDataAnalysis.git
synced 2026-06-18 15:54:08 +08:00
feat(mcp): 修正局域网接入地址展示
This commit is contained in:
+93
-5
@@ -21,6 +21,7 @@ const crypto = require("crypto");
|
||||
const fs = require("fs");
|
||||
const http = require("http");
|
||||
const net = require("net");
|
||||
const os = require("os");
|
||||
const path = require("path");
|
||||
const { Worker } = require("worker_threads");
|
||||
const {
|
||||
@@ -101,6 +102,84 @@ function getBackendAccessHost() {
|
||||
return host || "127.0.0.1";
|
||||
}
|
||||
|
||||
function getInterfacePenalty(name) {
|
||||
const lower = String(name || "").toLowerCase();
|
||||
if (/(docker|hyper-v|loopback|npcap|tailscale|virtual|virtualbox|vmware|vethernet|wsl|zerotier)/i.test(lower)) {
|
||||
return 30;
|
||||
}
|
||||
if (/(ethernet|wi-fi|wifi|wireless|wlan|以太|无线)/i.test(lower)) {
|
||||
return 0;
|
||||
}
|
||||
return 10;
|
||||
}
|
||||
|
||||
function isReachableClientIpv4(address) {
|
||||
const text = String(address || "").trim();
|
||||
const parts = text.split(".");
|
||||
if (parts.length !== 4) return false;
|
||||
const nums = parts.map((part) => Number(part));
|
||||
if (!nums.every((n) => Number.isInteger(n) && n >= 0 && n <= 255)) return false;
|
||||
if (nums[0] === 0 || nums[0] === 127 || nums[0] >= 224) return false;
|
||||
if (nums[0] === 169 && nums[1] === 254) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
function isPrivateIpv4(address) {
|
||||
const nums = String(address || "").trim().split(".").map((part) => Number(part));
|
||||
if (nums.length !== 4 || !nums.every((n) => Number.isInteger(n))) return false;
|
||||
return (
|
||||
nums[0] === 10 ||
|
||||
(nums[0] === 172 && nums[1] >= 16 && nums[1] <= 31) ||
|
||||
(nums[0] === 192 && nums[1] === 168)
|
||||
);
|
||||
}
|
||||
|
||||
function getLanAccessHost(defaultHost = DEFAULT_BACKEND_HOST) {
|
||||
const candidates = [];
|
||||
const seen = new Set();
|
||||
const addCandidate = (address, interfaceName = "", sourceOrder = 0) => {
|
||||
const value = String(address || "").trim();
|
||||
if (!isReachableClientIpv4(value) || seen.has(value)) return;
|
||||
seen.add(value);
|
||||
candidates.push([
|
||||
isPrivateIpv4(value) ? 0 : 1,
|
||||
getInterfacePenalty(interfaceName),
|
||||
sourceOrder,
|
||||
value,
|
||||
]);
|
||||
};
|
||||
|
||||
try {
|
||||
const interfaces = os.networkInterfaces();
|
||||
for (const [name, addresses] of Object.entries(interfaces || {})) {
|
||||
for (const item of addresses || []) {
|
||||
if (!item || (item.family !== "IPv4" && item.family !== 4) || item.internal) continue;
|
||||
addCandidate(item.address, name, 0);
|
||||
}
|
||||
}
|
||||
} catch {}
|
||||
|
||||
candidates.sort((a, b) => a[0] - b[0] || a[1] - b[1] || a[2] - b[2]);
|
||||
return candidates[0]?.[3] || defaultHost;
|
||||
}
|
||||
|
||||
function getMcpAccessHost(bindHost = getBackendBindHost()) {
|
||||
const host = String(bindHost || "").trim();
|
||||
if (host === LAN_BACKEND_HOST || host === "::") return getLanAccessHost(DEFAULT_BACKEND_HOST);
|
||||
return host || DEFAULT_BACKEND_HOST;
|
||||
}
|
||||
|
||||
function getMcpAccessInfo(bindHost = getBackendBindHost(), port = getBackendPort()) {
|
||||
const accessHost = getMcpAccessHost(bindHost);
|
||||
const origin = `http://${formatHostForUrl(accessHost)}:${port}`;
|
||||
return {
|
||||
accessHost,
|
||||
mcpEndpoint: `${origin}/mcp`,
|
||||
skillBundleUrl: `${origin}/mcp/skill/bundle`,
|
||||
skillMarkdownUrl: `${origin}/mcp/skill`,
|
||||
};
|
||||
}
|
||||
|
||||
function getBackendPort() {
|
||||
const envPort = parsePort(process.env.WECHAT_TOOL_PORT);
|
||||
if (envPort != null) return envPort;
|
||||
@@ -2340,19 +2419,24 @@ function registerWindowIpc() {
|
||||
|
||||
ipcMain.handle("backend:getMcpLanAccess", () => {
|
||||
try {
|
||||
const host = getBackendBindHost();
|
||||
const port = getBackendPort();
|
||||
return {
|
||||
enabled: getMcpLanAccessEnabled(),
|
||||
host: getBackendBindHost(),
|
||||
port: getBackendPort(),
|
||||
host,
|
||||
port,
|
||||
uiUrl: getDesktopUiUrl(),
|
||||
...getMcpAccessInfo(host, port),
|
||||
};
|
||||
} catch (err) {
|
||||
logMain(`[main] backend:getMcpLanAccess failed: ${err?.message || err}`);
|
||||
const port = DEFAULT_BACKEND_PORT;
|
||||
return {
|
||||
enabled: false,
|
||||
host: DEFAULT_BACKEND_HOST,
|
||||
port: DEFAULT_BACKEND_PORT,
|
||||
port,
|
||||
uiUrl: getDesktopUiUrl(),
|
||||
...getMcpAccessInfo(DEFAULT_BACKEND_HOST, port),
|
||||
};
|
||||
}
|
||||
});
|
||||
@@ -2363,13 +2447,16 @@ function registerWindowIpc() {
|
||||
const nextEnabled = !!enabled;
|
||||
const prevEnabled = getMcpLanAccessEnabled();
|
||||
if (nextEnabled === prevEnabled) {
|
||||
const host = getBackendBindHost();
|
||||
const port = getBackendPort();
|
||||
return {
|
||||
success: true,
|
||||
changed: false,
|
||||
enabled: prevEnabled,
|
||||
host: getBackendBindHost(),
|
||||
port: getBackendPort(),
|
||||
host,
|
||||
port,
|
||||
uiUrl: getDesktopUiUrl(),
|
||||
...getMcpAccessInfo(host, port),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -2396,6 +2483,7 @@ function registerWindowIpc() {
|
||||
host: getBackendBindHost(),
|
||||
port: getBackendPort(),
|
||||
uiUrl,
|
||||
...getMcpAccessInfo(),
|
||||
};
|
||||
} finally {
|
||||
backendPortChangeInProgress = false;
|
||||
|
||||
@@ -269,6 +269,7 @@
|
||||
<div class="min-w-0 flex-1">
|
||||
<div class="text-[13px] font-medium text-[#222]">允许手机局域网接入 MCP</div>
|
||||
<div class="mt-0.5 text-[11px] leading-relaxed text-[#909090]">开启后后端监听 0.0.0.0,手机可通过接入提示词中的地址接入。</div>
|
||||
<div class="mt-0.5 text-[11px] leading-relaxed text-[#909090] break-all">当前地址:{{ mcpEndpoint }}</div>
|
||||
<div v-if="mcpLanAccessMessage" class="mt-1 text-[11px] leading-relaxed text-[#1b6b43]">{{ mcpLanAccessMessage }}</div>
|
||||
<div v-if="mcpLanAccessError" class="mt-1 text-[11px] leading-relaxed text-red-600">{{ mcpLanAccessError }}</div>
|
||||
</div>
|
||||
@@ -608,6 +609,8 @@ const mcpSkillBundleText = ref('')
|
||||
const mcpSkillBundleLoading = ref(false)
|
||||
const mcpSkillBundleError = ref('')
|
||||
const mcpCopiedKey = ref('')
|
||||
const mcpAccessHost = ref('')
|
||||
const mcpAccessEndpoint = ref('')
|
||||
let mcpCopiedTimer = null
|
||||
|
||||
const mcpPortText = computed(() => {
|
||||
@@ -617,6 +620,10 @@ const mcpPortText = computed(() => {
|
||||
})
|
||||
|
||||
const mcpEndpoint = computed(() => {
|
||||
const reported = String(mcpAccessEndpoint.value || '').trim()
|
||||
if (/^https?:\/\//i.test(reported)) return reported
|
||||
const reportedHost = String(mcpAccessHost.value || '').trim()
|
||||
if (reportedHost) return `http://${reportedHost}:${mcpPortText.value}/mcp`
|
||||
if (!process.client || typeof window === 'undefined') return `http://127.0.0.1:${mcpPortText.value}/mcp`
|
||||
const apiBase = useApiBase()
|
||||
if (/^https?:\/\//i.test(apiBase)) {
|
||||
@@ -630,6 +637,14 @@ const mcpEndpoint = computed(() => {
|
||||
return `${protocol}//${host}:${mcpPortText.value}/mcp`
|
||||
})
|
||||
|
||||
const applyMcpAccessInfo = (resp) => {
|
||||
if (!resp || typeof resp !== 'object') return
|
||||
const accessHost = String(resp.accessHost || resp.access_host || '').trim()
|
||||
const endpoint = String(resp.mcpEndpoint || resp.mcp_endpoint || '').trim()
|
||||
if (accessHost) mcpAccessHost.value = accessHost
|
||||
if (/^https?:\/\//i.test(endpoint)) mcpAccessEndpoint.value = endpoint
|
||||
}
|
||||
|
||||
const mcpSkillFallback = [
|
||||
'# WeChat MCP Copilot',
|
||||
'',
|
||||
@@ -800,10 +815,12 @@ const refreshMcpLanAccess = async () => {
|
||||
if (window.wechatDesktop?.getMcpLanAccess) {
|
||||
const resp = await window.wechatDesktop.getMcpLanAccess()
|
||||
mcpLanAccessEnabled.value = !!resp?.enabled
|
||||
applyMcpAccessInfo(resp)
|
||||
return
|
||||
}
|
||||
const resp = await fetchAdminEndpoint('/admin/mcp-access')
|
||||
mcpLanAccessEnabled.value = !!resp?.enabled
|
||||
applyMcpAccessInfo(resp)
|
||||
} catch (e) {
|
||||
mcpLanAccessError.value = e?.message || '读取 MCP 接入状态失败'
|
||||
} finally {
|
||||
@@ -881,7 +898,9 @@ const setMcpLanAccess = async (enabled) => {
|
||||
if (window.wechatDesktop?.setMcpLanAccess) {
|
||||
const resp = await window.wechatDesktop.setMcpLanAccess(!!enabled)
|
||||
mcpLanAccessEnabled.value = !!resp?.enabled
|
||||
applyMcpAccessInfo(resp)
|
||||
mcpLanAccessMessage.value = resp?.changed ? 'MCP 局域网接入已更新,后端已重启。' : 'MCP 局域网接入状态未变化。'
|
||||
await refreshMcpSkillBundle()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -890,11 +909,14 @@ const setMcpLanAccess = async (enabled) => {
|
||||
body: { enabled: !!enabled },
|
||||
})
|
||||
mcpLanAccessEnabled.value = !!resp?.enabled
|
||||
applyMcpAccessInfo(resp)
|
||||
mcpLanAccessMessage.value = resp?.changed ? 'MCP 局域网接入已更新,正在等待后端重启。' : 'MCP 局域网接入状态未变化。'
|
||||
if (resp?.changed) {
|
||||
await waitForBackendHealth(30_000)
|
||||
await refreshMcpLanAccess()
|
||||
mcpLanAccessMessage.value = 'MCP 局域网接入已更新,后端已恢复。'
|
||||
}
|
||||
await refreshMcpSkillBundle()
|
||||
} catch (e) {
|
||||
mcpLanAccessEnabled.value = previous
|
||||
mcpLanAccessError.value = e?.message || '设置 MCP 接入状态失败'
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
import uvicorn
|
||||
import os
|
||||
from pathlib import Path
|
||||
from wechat_decrypt_tool.network_access import get_lan_access_host
|
||||
from wechat_decrypt_tool.runtime_settings import read_effective_backend_host, read_effective_backend_port
|
||||
|
||||
def main():
|
||||
@@ -18,6 +19,7 @@ def main():
|
||||
host, host_source = read_effective_backend_host(default="127.0.0.1")
|
||||
port, port_source = read_effective_backend_port(default=10392)
|
||||
access_host = "127.0.0.1" if host in {"0.0.0.0", "::"} else host
|
||||
lan_access_host = get_lan_access_host(default="127.0.0.1") if host in {"0.0.0.0", "::"} else access_host
|
||||
|
||||
print("=" * 60)
|
||||
print("微信解密工具 API 服务")
|
||||
@@ -38,6 +40,8 @@ def main():
|
||||
print(f"监听地址: {host}")
|
||||
print(f"API文档: http://{access_host}:{port}/docs")
|
||||
print(f"健康检查: http://{access_host}:{port}/api/health")
|
||||
if lan_access_host != access_host:
|
||||
print(f"局域网 MCP: http://{lan_access_host}:{port}/mcp")
|
||||
print("按 Ctrl+C 停止服务")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
|
||||
|
||||
_VIRTUAL_INTERFACE_MARKERS = (
|
||||
"docker",
|
||||
"hyper-v",
|
||||
"loopback",
|
||||
"npcap",
|
||||
"tailscale",
|
||||
"virtual",
|
||||
"virtualbox",
|
||||
"vmware",
|
||||
"vethernet",
|
||||
"wsl",
|
||||
"zerotier",
|
||||
)
|
||||
|
||||
_PREFERRED_INTERFACE_MARKERS = (
|
||||
"ethernet",
|
||||
"wi-fi",
|
||||
"wifi",
|
||||
"wireless",
|
||||
"wlan",
|
||||
"以太",
|
||||
"无线",
|
||||
)
|
||||
|
||||
|
||||
def _parse_ipv4(value: object) -> ipaddress.IPv4Address | None:
|
||||
try:
|
||||
ip = ipaddress.ip_address(str(value or "").strip())
|
||||
except ValueError:
|
||||
return None
|
||||
return ip if isinstance(ip, ipaddress.IPv4Address) else None
|
||||
|
||||
|
||||
def _is_reachable_client_ipv4(ip: ipaddress.IPv4Address) -> bool:
|
||||
return not (
|
||||
ip.is_loopback
|
||||
or ip.is_unspecified
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
)
|
||||
|
||||
|
||||
def _interface_penalty(name: str) -> int:
|
||||
lower = str(name or "").lower()
|
||||
if any(marker in lower for marker in _VIRTUAL_INTERFACE_MARKERS):
|
||||
return 30
|
||||
if any(marker in lower for marker in _PREFERRED_INTERFACE_MARKERS):
|
||||
return 0
|
||||
return 10
|
||||
|
||||
|
||||
def _add_candidate(
|
||||
candidates: list[tuple[int, int, int, str]],
|
||||
seen: set[str],
|
||||
value: object,
|
||||
*,
|
||||
interface_name: str = "",
|
||||
source_order: int = 0,
|
||||
) -> None:
|
||||
ip = _parse_ipv4(value)
|
||||
if not ip or not _is_reachable_client_ipv4(ip):
|
||||
return
|
||||
|
||||
text = str(ip)
|
||||
if text in seen:
|
||||
return
|
||||
seen.add(text)
|
||||
|
||||
private_rank = 0 if ip.is_private else 1
|
||||
candidates.append((private_rank, _interface_penalty(interface_name), source_order, text))
|
||||
|
||||
|
||||
def _add_psutil_candidates(candidates: list[tuple[int, int, int, str]], seen: set[str]) -> None:
|
||||
try:
|
||||
import psutil # type: ignore
|
||||
except Exception:
|
||||
return
|
||||
|
||||
try:
|
||||
stats_by_name = psutil.net_if_stats()
|
||||
interfaces = psutil.net_if_addrs()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
for interface_name, addresses in interfaces.items():
|
||||
try:
|
||||
stats = stats_by_name.get(interface_name)
|
||||
if stats is not None and not bool(getattr(stats, "isup", False)):
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for addr in addresses:
|
||||
try:
|
||||
if getattr(addr, "family", None) != socket.AF_INET:
|
||||
continue
|
||||
_add_candidate(
|
||||
candidates,
|
||||
seen,
|
||||
getattr(addr, "address", ""),
|
||||
interface_name=interface_name,
|
||||
source_order=0,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
def _add_route_candidates(candidates: list[tuple[int, int, int, str]], seen: set[str]) -> None:
|
||||
# UDP connect 不会实际发包,只用于询问系统默认出站路由会使用哪个本机地址。
|
||||
for target in ("223.5.5.5", "8.8.8.8", "1.1.1.1"):
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.settimeout(0.2)
|
||||
sock.connect((target, 80))
|
||||
local_ip = sock.getsockname()[0]
|
||||
except Exception:
|
||||
continue
|
||||
_add_candidate(candidates, seen, local_ip, interface_name="", source_order=1)
|
||||
|
||||
|
||||
def _add_hostname_candidates(candidates: list[tuple[int, int, int, str]], seen: set[str]) -> None:
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
_, _, addresses = socket.gethostbyname_ex(hostname)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
for address in addresses:
|
||||
_add_candidate(candidates, seen, address, interface_name="", source_order=2)
|
||||
|
||||
|
||||
def get_lan_access_host(default: str = "127.0.0.1") -> str:
|
||||
"""返回同网段设备可访问的本机 IPv4 地址。"""
|
||||
|
||||
candidates: list[tuple[int, int, int, str]] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
_add_psutil_candidates(candidates, seen)
|
||||
_add_route_candidates(candidates, seen)
|
||||
_add_hostname_candidates(candidates, seen)
|
||||
|
||||
if not candidates:
|
||||
return default
|
||||
|
||||
candidates.sort()
|
||||
return candidates[0][3]
|
||||
@@ -14,6 +14,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..logging_config import get_log_file_path, get_logger
|
||||
from ..network_access import get_lan_access_host
|
||||
from ..path_fix import PathFixRoute
|
||||
from ..runtime_settings import (
|
||||
LAN_BACKEND_HOST,
|
||||
@@ -56,6 +57,28 @@ def _get_backend_access_host() -> str:
|
||||
return host
|
||||
|
||||
|
||||
def _get_mcp_access_host(bind_host: str | None = None) -> str:
|
||||
host = str(bind_host or _get_backend_bind_host() or "").strip()
|
||||
if host in {LAN_BACKEND_HOST, "::"}:
|
||||
return get_lan_access_host(default=LOOPBACK_BACKEND_HOST)
|
||||
return host or LOOPBACK_BACKEND_HOST
|
||||
|
||||
|
||||
def _get_mcp_access_urls(port: int, bind_host: str | None = None) -> dict:
|
||||
access_host = _get_mcp_access_host(bind_host)
|
||||
origin = f"http://{_format_host_for_url(access_host)}:{int(port)}"
|
||||
return {
|
||||
"access_host": access_host,
|
||||
"accessHost": access_host,
|
||||
"mcp_endpoint": f"{origin}/mcp",
|
||||
"mcpEndpoint": f"{origin}/mcp",
|
||||
"skill_bundle_url": f"{origin}/mcp/skill/bundle",
|
||||
"skillBundleUrl": f"{origin}/mcp/skill/bundle",
|
||||
"skill_markdown_url": f"{origin}/mcp/skill",
|
||||
"skillMarkdownUrl": f"{origin}/mcp/skill",
|
||||
}
|
||||
|
||||
|
||||
def _is_loopback_client(request: Request) -> bool:
|
||||
client = request.client
|
||||
host = str(getattr(client, "host", "") or "").strip()
|
||||
@@ -272,6 +295,7 @@ async def get_mcp_access() -> dict:
|
||||
"default_host": LOOPBACK_BACKEND_HOST,
|
||||
"lan_host": LAN_BACKEND_HOST,
|
||||
"restart_required": False,
|
||||
**_get_mcp_access_urls(port, host),
|
||||
}
|
||||
|
||||
|
||||
@@ -407,6 +431,7 @@ async def set_mcp_access(payload: dict, request: Request, background_tasks: Back
|
||||
"port": int(current_port),
|
||||
"ui_url": f"http://{_format_host_for_url(_get_backend_access_host())}:{int(current_port)}/",
|
||||
"env_file": str(env_file) if env_file else None,
|
||||
**_get_mcp_access_urls(int(current_port), next_host),
|
||||
}
|
||||
|
||||
_PORT_CHANGE_IN_PROGRESS = True
|
||||
@@ -428,6 +453,7 @@ async def set_mcp_access(payload: dict, request: Request, background_tasks: Back
|
||||
"ui_url": f"http://{_format_host_for_url(LOOPBACK_BACKEND_HOST)}:{int(current_port)}/",
|
||||
"env_file": str(env_file) if env_file else None,
|
||||
"restart_scheduled": True,
|
||||
**_get_mcp_access_urls(int(current_port), next_host),
|
||||
}
|
||||
finally:
|
||||
_PORT_CHANGE_IN_PROGRESS = False
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
def _close_logging_handlers() -> None:
|
||||
for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"):
|
||||
lg = logging.getLogger(logger_name)
|
||||
for handler in lg.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
lg.removeHandler(handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TestMcpAccessHost(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self._prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
self._prev_host = os.environ.get("WECHAT_TOOL_HOST")
|
||||
self._prev_port = os.environ.get("WECHAT_TOOL_PORT")
|
||||
self._td = TemporaryDirectory()
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = self._td.name
|
||||
os.environ.pop("WECHAT_TOOL_HOST", None)
|
||||
os.environ.pop("WECHAT_TOOL_PORT", None)
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.runtime_settings as runtime_settings
|
||||
import wechat_decrypt_tool.routers.admin as admin_router
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(runtime_settings)
|
||||
importlib.reload(admin_router)
|
||||
|
||||
self.runtime_settings = runtime_settings
|
||||
self.admin_router = admin_router
|
||||
|
||||
def tearDown(self) -> None:
|
||||
_close_logging_handlers()
|
||||
|
||||
if self._prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = self._prev_data_dir
|
||||
|
||||
if self._prev_host is None:
|
||||
os.environ.pop("WECHAT_TOOL_HOST", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_HOST"] = self._prev_host
|
||||
|
||||
if self._prev_port is None:
|
||||
os.environ.pop("WECHAT_TOOL_PORT", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_PORT"] = self._prev_port
|
||||
|
||||
self._td.cleanup()
|
||||
|
||||
def _client(self) -> TestClient:
|
||||
app = FastAPI()
|
||||
app.include_router(self.admin_router.router)
|
||||
return TestClient(app, client=("127.0.0.1", 52010))
|
||||
|
||||
def test_mcp_access_reports_lan_endpoint_when_lan_enabled(self) -> None:
|
||||
self.runtime_settings.write_backend_host_setting(self.runtime_settings.LAN_BACKEND_HOST)
|
||||
self.runtime_settings.write_backend_port_setting(12092)
|
||||
client = self._client()
|
||||
|
||||
with patch.object(self.admin_router, "get_lan_access_host", return_value="192.168.1.23"):
|
||||
resp = client.get("/api/admin/mcp-access")
|
||||
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
payload = resp.json()
|
||||
self.assertTrue(payload["enabled"])
|
||||
self.assertEqual(payload["host"], "0.0.0.0")
|
||||
self.assertEqual(payload["access_host"], "192.168.1.23")
|
||||
self.assertEqual(payload["accessHost"], "192.168.1.23")
|
||||
self.assertEqual(payload["mcp_endpoint"], "http://192.168.1.23:12092/mcp")
|
||||
self.assertEqual(payload["mcpEndpoint"], "http://192.168.1.23:12092/mcp")
|
||||
self.assertEqual(payload["skill_bundle_url"], "http://192.168.1.23:12092/mcp/skill/bundle")
|
||||
self.assertEqual(payload["skill_markdown_url"], "http://192.168.1.23:12092/mcp/skill")
|
||||
|
||||
def test_mcp_access_keeps_loopback_endpoint_when_lan_disabled(self) -> None:
|
||||
self.runtime_settings.write_backend_host_setting(self.runtime_settings.LOOPBACK_BACKEND_HOST)
|
||||
self.runtime_settings.write_backend_port_setting(12092)
|
||||
client = self._client()
|
||||
|
||||
with patch.object(self.admin_router, "get_lan_access_host", return_value="192.168.1.23"):
|
||||
resp = client.get("/api/admin/mcp-access")
|
||||
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
payload = resp.json()
|
||||
self.assertFalse(payload["enabled"])
|
||||
self.assertEqual(payload["host"], "127.0.0.1")
|
||||
self.assertEqual(payload["access_host"], "127.0.0.1")
|
||||
self.assertEqual(payload["mcp_endpoint"], "http://127.0.0.1:12092/mcp")
|
||||
|
||||
|
||||
class TestNetworkAccessHost(unittest.TestCase):
|
||||
def test_get_lan_access_host_prefers_physical_private_ipv4(self) -> None:
|
||||
import wechat_decrypt_tool.network_access as network_access
|
||||
|
||||
with patch.object(network_access, "_add_psutil_candidates") as mocked_psutil, patch.object(
|
||||
network_access, "_add_route_candidates"
|
||||
) as mocked_route, patch.object(network_access, "_add_hostname_candidates") as mocked_hostname:
|
||||
|
||||
def add_psutil(candidates, seen):
|
||||
network_access._add_candidate(candidates, seen, "172.18.0.2", interface_name="Docker", source_order=0)
|
||||
network_access._add_candidate(candidates, seen, "192.168.1.23", interface_name="Wi-Fi", source_order=0)
|
||||
|
||||
mocked_psutil.side_effect = add_psutil
|
||||
mocked_route.side_effect = lambda candidates, seen: network_access._add_candidate(
|
||||
candidates, seen, "10.0.0.9", source_order=1
|
||||
)
|
||||
mocked_hostname.side_effect = lambda candidates, seen: None
|
||||
|
||||
self.assertEqual(network_access.get_lan_access_host(), "192.168.1.23")
|
||||
|
||||
def test_get_lan_access_host_falls_back_when_no_candidate(self) -> None:
|
||||
import wechat_decrypt_tool.network_access as network_access
|
||||
|
||||
with patch.object(network_access, "_add_psutil_candidates", return_value=None), patch.object(
|
||||
network_access, "_add_route_candidates", return_value=None
|
||||
), patch.object(network_access, "_add_hostname_candidates", return_value=None):
|
||||
self.assertEqual(network_access.get_lan_access_host(default="127.0.0.1"), "127.0.0.1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user