diff --git a/desktop/src/main.cjs b/desktop/src/main.cjs index 59857f8..3331a89 100644 --- a/desktop/src/main.cjs +++ b/desktop/src/main.cjs @@ -21,6 +21,7 @@ const crypto = require("crypto"); const fs = require("fs"); const http = require("http"); const net = require("net"); +const os = require("os"); const path = require("path"); const { Worker } = require("worker_threads"); const { @@ -101,6 +102,84 @@ function getBackendAccessHost() { return host || "127.0.0.1"; } +function getInterfacePenalty(name) { + const lower = String(name || "").toLowerCase(); + if (/(docker|hyper-v|loopback|npcap|tailscale|virtual|virtualbox|vmware|vethernet|wsl|zerotier)/i.test(lower)) { + return 30; + } + if (/(ethernet|wi-fi|wifi|wireless|wlan|以太|无线)/i.test(lower)) { + return 0; + } + return 10; +} + +function isReachableClientIpv4(address) { + const text = String(address || "").trim(); + const parts = text.split("."); + if (parts.length !== 4) return false; + const nums = parts.map((part) => Number(part)); + if (!nums.every((n) => Number.isInteger(n) && n >= 0 && n <= 255)) return false; + if (nums[0] === 0 || nums[0] === 127 || nums[0] >= 224) return false; + if (nums[0] === 169 && nums[1] === 254) return false; + return true; +} + +function isPrivateIpv4(address) { + const nums = String(address || "").trim().split(".").map((part) => Number(part)); + if (nums.length !== 4 || !nums.every((n) => Number.isInteger(n))) return false; + return ( + nums[0] === 10 || + (nums[0] === 172 && nums[1] >= 16 && nums[1] <= 31) || + (nums[0] === 192 && nums[1] === 168) + ); +} + +function getLanAccessHost(defaultHost = DEFAULT_BACKEND_HOST) { + const candidates = []; + const seen = new Set(); + const addCandidate = (address, interfaceName = "", sourceOrder = 0) => { + const value = String(address || "").trim(); + if (!isReachableClientIpv4(value) || seen.has(value)) return; + seen.add(value); + candidates.push([ + isPrivateIpv4(value) ? 0 : 1, + getInterfacePenalty(interfaceName), + sourceOrder, + value, + ]); + }; + + try { + const interfaces = os.networkInterfaces(); + for (const [name, addresses] of Object.entries(interfaces || {})) { + for (const item of addresses || []) { + if (!item || (item.family !== "IPv4" && item.family !== 4) || item.internal) continue; + addCandidate(item.address, name, 0); + } + } + } catch {} + + candidates.sort((a, b) => a[0] - b[0] || a[1] - b[1] || a[2] - b[2]); + return candidates[0]?.[3] || defaultHost; +} + +function getMcpAccessHost(bindHost = getBackendBindHost()) { + const host = String(bindHost || "").trim(); + if (host === LAN_BACKEND_HOST || host === "::") return getLanAccessHost(DEFAULT_BACKEND_HOST); + return host || DEFAULT_BACKEND_HOST; +} + +function getMcpAccessInfo(bindHost = getBackendBindHost(), port = getBackendPort()) { + const accessHost = getMcpAccessHost(bindHost); + const origin = `http://${formatHostForUrl(accessHost)}:${port}`; + return { + accessHost, + mcpEndpoint: `${origin}/mcp`, + skillBundleUrl: `${origin}/mcp/skill/bundle`, + skillMarkdownUrl: `${origin}/mcp/skill`, + }; +} + function getBackendPort() { const envPort = parsePort(process.env.WECHAT_TOOL_PORT); if (envPort != null) return envPort; @@ -2340,19 +2419,24 @@ function registerWindowIpc() { ipcMain.handle("backend:getMcpLanAccess", () => { try { + const host = getBackendBindHost(); + const port = getBackendPort(); return { enabled: getMcpLanAccessEnabled(), - host: getBackendBindHost(), - port: getBackendPort(), + host, + port, uiUrl: getDesktopUiUrl(), + ...getMcpAccessInfo(host, port), }; } catch (err) { logMain(`[main] backend:getMcpLanAccess failed: ${err?.message || err}`); + const port = DEFAULT_BACKEND_PORT; return { enabled: false, host: DEFAULT_BACKEND_HOST, - port: DEFAULT_BACKEND_PORT, + port, uiUrl: getDesktopUiUrl(), + ...getMcpAccessInfo(DEFAULT_BACKEND_HOST, port), }; } }); @@ -2363,13 +2447,16 @@ function registerWindowIpc() { const nextEnabled = !!enabled; const prevEnabled = getMcpLanAccessEnabled(); if (nextEnabled === prevEnabled) { + const host = getBackendBindHost(); + const port = getBackendPort(); return { success: true, changed: false, enabled: prevEnabled, - host: getBackendBindHost(), - port: getBackendPort(), + host, + port, uiUrl: getDesktopUiUrl(), + ...getMcpAccessInfo(host, port), }; } @@ -2396,6 +2483,7 @@ function registerWindowIpc() { host: getBackendBindHost(), port: getBackendPort(), uiUrl, + ...getMcpAccessInfo(), }; } finally { backendPortChangeInProgress = false; diff --git a/frontend/components/SettingsDialog.vue b/frontend/components/SettingsDialog.vue index ac029bb..07c22ab 100644 --- a/frontend/components/SettingsDialog.vue +++ b/frontend/components/SettingsDialog.vue @@ -269,6 +269,7 @@
允许手机局域网接入 MCP
开启后后端监听 0.0.0.0,手机可通过接入提示词中的地址接入。
+
当前地址:{{ mcpEndpoint }}
{{ mcpLanAccessMessage }}
{{ mcpLanAccessError }}
@@ -608,6 +609,8 @@ const mcpSkillBundleText = ref('') const mcpSkillBundleLoading = ref(false) const mcpSkillBundleError = ref('') const mcpCopiedKey = ref('') +const mcpAccessHost = ref('') +const mcpAccessEndpoint = ref('') let mcpCopiedTimer = null const mcpPortText = computed(() => { @@ -617,6 +620,10 @@ const mcpPortText = computed(() => { }) const mcpEndpoint = computed(() => { + const reported = String(mcpAccessEndpoint.value || '').trim() + if (/^https?:\/\//i.test(reported)) return reported + const reportedHost = String(mcpAccessHost.value || '').trim() + if (reportedHost) return `http://${reportedHost}:${mcpPortText.value}/mcp` if (!process.client || typeof window === 'undefined') return `http://127.0.0.1:${mcpPortText.value}/mcp` const apiBase = useApiBase() if (/^https?:\/\//i.test(apiBase)) { @@ -630,6 +637,14 @@ const mcpEndpoint = computed(() => { return `${protocol}//${host}:${mcpPortText.value}/mcp` }) +const applyMcpAccessInfo = (resp) => { + if (!resp || typeof resp !== 'object') return + const accessHost = String(resp.accessHost || resp.access_host || '').trim() + const endpoint = String(resp.mcpEndpoint || resp.mcp_endpoint || '').trim() + if (accessHost) mcpAccessHost.value = accessHost + if (/^https?:\/\//i.test(endpoint)) mcpAccessEndpoint.value = endpoint +} + const mcpSkillFallback = [ '# WeChat MCP Copilot', '', @@ -800,10 +815,12 @@ const refreshMcpLanAccess = async () => { if (window.wechatDesktop?.getMcpLanAccess) { const resp = await window.wechatDesktop.getMcpLanAccess() mcpLanAccessEnabled.value = !!resp?.enabled + applyMcpAccessInfo(resp) return } const resp = await fetchAdminEndpoint('/admin/mcp-access') mcpLanAccessEnabled.value = !!resp?.enabled + applyMcpAccessInfo(resp) } catch (e) { mcpLanAccessError.value = e?.message || '读取 MCP 接入状态失败' } finally { @@ -881,7 +898,9 @@ const setMcpLanAccess = async (enabled) => { if (window.wechatDesktop?.setMcpLanAccess) { const resp = await window.wechatDesktop.setMcpLanAccess(!!enabled) mcpLanAccessEnabled.value = !!resp?.enabled + applyMcpAccessInfo(resp) mcpLanAccessMessage.value = resp?.changed ? 'MCP 局域网接入已更新,后端已重启。' : 'MCP 局域网接入状态未变化。' + await refreshMcpSkillBundle() return } @@ -890,11 +909,14 @@ const setMcpLanAccess = async (enabled) => { body: { enabled: !!enabled }, }) mcpLanAccessEnabled.value = !!resp?.enabled + applyMcpAccessInfo(resp) mcpLanAccessMessage.value = resp?.changed ? 'MCP 局域网接入已更新,正在等待后端重启。' : 'MCP 局域网接入状态未变化。' if (resp?.changed) { await waitForBackendHealth(30_000) + await refreshMcpLanAccess() mcpLanAccessMessage.value = 'MCP 局域网接入已更新,后端已恢复。' } + await refreshMcpSkillBundle() } catch (e) { mcpLanAccessEnabled.value = previous mcpLanAccessError.value = e?.message || '设置 MCP 接入状态失败' diff --git a/main.py b/main.py index b6acd1f..ecacc45 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ import uvicorn import os from pathlib import Path +from wechat_decrypt_tool.network_access import get_lan_access_host from wechat_decrypt_tool.runtime_settings import read_effective_backend_host, read_effective_backend_port def main(): @@ -18,6 +19,7 @@ def main(): host, host_source = read_effective_backend_host(default="127.0.0.1") port, port_source = read_effective_backend_port(default=10392) access_host = "127.0.0.1" if host in {"0.0.0.0", "::"} else host + lan_access_host = get_lan_access_host(default="127.0.0.1") if host in {"0.0.0.0", "::"} else access_host print("=" * 60) print("微信解密工具 API 服务") @@ -38,6 +40,8 @@ def main(): print(f"监听地址: {host}") print(f"API文档: http://{access_host}:{port}/docs") print(f"健康检查: http://{access_host}:{port}/api/health") + if lan_access_host != access_host: + print(f"局域网 MCP: http://{lan_access_host}:{port}/mcp") print("按 Ctrl+C 停止服务") print("=" * 60) diff --git a/src/wechat_decrypt_tool/network_access.py b/src/wechat_decrypt_tool/network_access.py new file mode 100644 index 0000000..b11cef7 --- /dev/null +++ b/src/wechat_decrypt_tool/network_access.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import ipaddress +import socket + + +_VIRTUAL_INTERFACE_MARKERS = ( + "docker", + "hyper-v", + "loopback", + "npcap", + "tailscale", + "virtual", + "virtualbox", + "vmware", + "vethernet", + "wsl", + "zerotier", +) + +_PREFERRED_INTERFACE_MARKERS = ( + "ethernet", + "wi-fi", + "wifi", + "wireless", + "wlan", + "以太", + "无线", +) + + +def _parse_ipv4(value: object) -> ipaddress.IPv4Address | None: + try: + ip = ipaddress.ip_address(str(value or "").strip()) + except ValueError: + return None + return ip if isinstance(ip, ipaddress.IPv4Address) else None + + +def _is_reachable_client_ipv4(ip: ipaddress.IPv4Address) -> bool: + return not ( + ip.is_loopback + or ip.is_unspecified + or ip.is_link_local + or ip.is_multicast + or ip.is_reserved + ) + + +def _interface_penalty(name: str) -> int: + lower = str(name or "").lower() + if any(marker in lower for marker in _VIRTUAL_INTERFACE_MARKERS): + return 30 + if any(marker in lower for marker in _PREFERRED_INTERFACE_MARKERS): + return 0 + return 10 + + +def _add_candidate( + candidates: list[tuple[int, int, int, str]], + seen: set[str], + value: object, + *, + interface_name: str = "", + source_order: int = 0, +) -> None: + ip = _parse_ipv4(value) + if not ip or not _is_reachable_client_ipv4(ip): + return + + text = str(ip) + if text in seen: + return + seen.add(text) + + private_rank = 0 if ip.is_private else 1 + candidates.append((private_rank, _interface_penalty(interface_name), source_order, text)) + + +def _add_psutil_candidates(candidates: list[tuple[int, int, int, str]], seen: set[str]) -> None: + try: + import psutil # type: ignore + except Exception: + return + + try: + stats_by_name = psutil.net_if_stats() + interfaces = psutil.net_if_addrs() + except Exception: + return + + for interface_name, addresses in interfaces.items(): + try: + stats = stats_by_name.get(interface_name) + if stats is not None and not bool(getattr(stats, "isup", False)): + continue + except Exception: + pass + + for addr in addresses: + try: + if getattr(addr, "family", None) != socket.AF_INET: + continue + _add_candidate( + candidates, + seen, + getattr(addr, "address", ""), + interface_name=interface_name, + source_order=0, + ) + except Exception: + continue + + +def _add_route_candidates(candidates: list[tuple[int, int, int, str]], seen: set[str]) -> None: + # UDP connect 不会实际发包,只用于询问系统默认出站路由会使用哪个本机地址。 + for target in ("223.5.5.5", "8.8.8.8", "1.1.1.1"): + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.settimeout(0.2) + sock.connect((target, 80)) + local_ip = sock.getsockname()[0] + except Exception: + continue + _add_candidate(candidates, seen, local_ip, interface_name="", source_order=1) + + +def _add_hostname_candidates(candidates: list[tuple[int, int, int, str]], seen: set[str]) -> None: + try: + hostname = socket.gethostname() + _, _, addresses = socket.gethostbyname_ex(hostname) + except Exception: + return + + for address in addresses: + _add_candidate(candidates, seen, address, interface_name="", source_order=2) + + +def get_lan_access_host(default: str = "127.0.0.1") -> str: + """返回同网段设备可访问的本机 IPv4 地址。""" + + candidates: list[tuple[int, int, int, str]] = [] + seen: set[str] = set() + + _add_psutil_candidates(candidates, seen) + _add_route_candidates(candidates, seen) + _add_hostname_candidates(candidates, seen) + + if not candidates: + return default + + candidates.sort() + return candidates[0][3] diff --git a/src/wechat_decrypt_tool/routers/admin.py b/src/wechat_decrypt_tool/routers/admin.py index 519a0c7..bfd70c9 100644 --- a/src/wechat_decrypt_tool/routers/admin.py +++ b/src/wechat_decrypt_tool/routers/admin.py @@ -14,6 +14,7 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException from starlette.requests import Request from ..logging_config import get_log_file_path, get_logger +from ..network_access import get_lan_access_host from ..path_fix import PathFixRoute from ..runtime_settings import ( LAN_BACKEND_HOST, @@ -56,6 +57,28 @@ def _get_backend_access_host() -> str: return host +def _get_mcp_access_host(bind_host: str | None = None) -> str: + host = str(bind_host or _get_backend_bind_host() or "").strip() + if host in {LAN_BACKEND_HOST, "::"}: + return get_lan_access_host(default=LOOPBACK_BACKEND_HOST) + return host or LOOPBACK_BACKEND_HOST + + +def _get_mcp_access_urls(port: int, bind_host: str | None = None) -> dict: + access_host = _get_mcp_access_host(bind_host) + origin = f"http://{_format_host_for_url(access_host)}:{int(port)}" + return { + "access_host": access_host, + "accessHost": access_host, + "mcp_endpoint": f"{origin}/mcp", + "mcpEndpoint": f"{origin}/mcp", + "skill_bundle_url": f"{origin}/mcp/skill/bundle", + "skillBundleUrl": f"{origin}/mcp/skill/bundle", + "skill_markdown_url": f"{origin}/mcp/skill", + "skillMarkdownUrl": f"{origin}/mcp/skill", + } + + def _is_loopback_client(request: Request) -> bool: client = request.client host = str(getattr(client, "host", "") or "").strip() @@ -272,6 +295,7 @@ async def get_mcp_access() -> dict: "default_host": LOOPBACK_BACKEND_HOST, "lan_host": LAN_BACKEND_HOST, "restart_required": False, + **_get_mcp_access_urls(port, host), } @@ -407,6 +431,7 @@ async def set_mcp_access(payload: dict, request: Request, background_tasks: Back "port": int(current_port), "ui_url": f"http://{_format_host_for_url(_get_backend_access_host())}:{int(current_port)}/", "env_file": str(env_file) if env_file else None, + **_get_mcp_access_urls(int(current_port), next_host), } _PORT_CHANGE_IN_PROGRESS = True @@ -428,6 +453,7 @@ async def set_mcp_access(payload: dict, request: Request, background_tasks: Back "ui_url": f"http://{_format_host_for_url(LOOPBACK_BACKEND_HOST)}:{int(current_port)}/", "env_file": str(env_file) if env_file else None, "restart_scheduled": True, + **_get_mcp_access_urls(int(current_port), next_host), } finally: _PORT_CHANGE_IN_PROGRESS = False diff --git a/tests/test_mcp_access_host.py b/tests/test_mcp_access_host.py new file mode 100644 index 0000000..185c7e8 --- /dev/null +++ b/tests/test_mcp_access_host.py @@ -0,0 +1,143 @@ +import importlib +import logging +import os +import sys +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import patch + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + + +def _close_logging_handlers() -> None: + for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"): + lg = logging.getLogger(logger_name) + for handler in lg.handlers[:]: + try: + handler.close() + except Exception: + pass + try: + lg.removeHandler(handler) + except Exception: + pass + + +class TestMcpAccessHost(unittest.TestCase): + def setUp(self) -> None: + self._prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR") + self._prev_host = os.environ.get("WECHAT_TOOL_HOST") + self._prev_port = os.environ.get("WECHAT_TOOL_PORT") + self._td = TemporaryDirectory() + os.environ["WECHAT_TOOL_DATA_DIR"] = self._td.name + os.environ.pop("WECHAT_TOOL_HOST", None) + os.environ.pop("WECHAT_TOOL_PORT", None) + + import wechat_decrypt_tool.app_paths as app_paths + import wechat_decrypt_tool.runtime_settings as runtime_settings + import wechat_decrypt_tool.routers.admin as admin_router + + importlib.reload(app_paths) + importlib.reload(runtime_settings) + importlib.reload(admin_router) + + self.runtime_settings = runtime_settings + self.admin_router = admin_router + + def tearDown(self) -> None: + _close_logging_handlers() + + if self._prev_data_dir is None: + os.environ.pop("WECHAT_TOOL_DATA_DIR", None) + else: + os.environ["WECHAT_TOOL_DATA_DIR"] = self._prev_data_dir + + if self._prev_host is None: + os.environ.pop("WECHAT_TOOL_HOST", None) + else: + os.environ["WECHAT_TOOL_HOST"] = self._prev_host + + if self._prev_port is None: + os.environ.pop("WECHAT_TOOL_PORT", None) + else: + os.environ["WECHAT_TOOL_PORT"] = self._prev_port + + self._td.cleanup() + + def _client(self) -> TestClient: + app = FastAPI() + app.include_router(self.admin_router.router) + return TestClient(app, client=("127.0.0.1", 52010)) + + def test_mcp_access_reports_lan_endpoint_when_lan_enabled(self) -> None: + self.runtime_settings.write_backend_host_setting(self.runtime_settings.LAN_BACKEND_HOST) + self.runtime_settings.write_backend_port_setting(12092) + client = self._client() + + with patch.object(self.admin_router, "get_lan_access_host", return_value="192.168.1.23"): + resp = client.get("/api/admin/mcp-access") + + self.assertEqual(resp.status_code, 200) + payload = resp.json() + self.assertTrue(payload["enabled"]) + self.assertEqual(payload["host"], "0.0.0.0") + self.assertEqual(payload["access_host"], "192.168.1.23") + self.assertEqual(payload["accessHost"], "192.168.1.23") + self.assertEqual(payload["mcp_endpoint"], "http://192.168.1.23:12092/mcp") + self.assertEqual(payload["mcpEndpoint"], "http://192.168.1.23:12092/mcp") + self.assertEqual(payload["skill_bundle_url"], "http://192.168.1.23:12092/mcp/skill/bundle") + self.assertEqual(payload["skill_markdown_url"], "http://192.168.1.23:12092/mcp/skill") + + def test_mcp_access_keeps_loopback_endpoint_when_lan_disabled(self) -> None: + self.runtime_settings.write_backend_host_setting(self.runtime_settings.LOOPBACK_BACKEND_HOST) + self.runtime_settings.write_backend_port_setting(12092) + client = self._client() + + with patch.object(self.admin_router, "get_lan_access_host", return_value="192.168.1.23"): + resp = client.get("/api/admin/mcp-access") + + self.assertEqual(resp.status_code, 200) + payload = resp.json() + self.assertFalse(payload["enabled"]) + self.assertEqual(payload["host"], "127.0.0.1") + self.assertEqual(payload["access_host"], "127.0.0.1") + self.assertEqual(payload["mcp_endpoint"], "http://127.0.0.1:12092/mcp") + + +class TestNetworkAccessHost(unittest.TestCase): + def test_get_lan_access_host_prefers_physical_private_ipv4(self) -> None: + import wechat_decrypt_tool.network_access as network_access + + with patch.object(network_access, "_add_psutil_candidates") as mocked_psutil, patch.object( + network_access, "_add_route_candidates" + ) as mocked_route, patch.object(network_access, "_add_hostname_candidates") as mocked_hostname: + + def add_psutil(candidates, seen): + network_access._add_candidate(candidates, seen, "172.18.0.2", interface_name="Docker", source_order=0) + network_access._add_candidate(candidates, seen, "192.168.1.23", interface_name="Wi-Fi", source_order=0) + + mocked_psutil.side_effect = add_psutil + mocked_route.side_effect = lambda candidates, seen: network_access._add_candidate( + candidates, seen, "10.0.0.9", source_order=1 + ) + mocked_hostname.side_effect = lambda candidates, seen: None + + self.assertEqual(network_access.get_lan_access_host(), "192.168.1.23") + + def test_get_lan_access_host_falls_back_when_no_candidate(self) -> None: + import wechat_decrypt_tool.network_access as network_access + + with patch.object(network_access, "_add_psutil_candidates", return_value=None), patch.object( + network_access, "_add_route_candidates", return_value=None + ), patch.object(network_access, "_add_hostname_candidates", return_value=None): + self.assertEqual(network_access.get_lan_access_host(default="127.0.0.1"), "127.0.0.1") + + +if __name__ == "__main__": + unittest.main()