fix preset error

This commit is contained in:
musistudio
2025-12-30 18:23:44 +08:00
parent 7e11fca0d5
commit 7400941ae8
9 changed files with 152 additions and 66 deletions

View File

@@ -4,14 +4,13 @@ import { homedir } from "os";
import { join } from "path";
import { initConfig, initDir } from "./utils";
import { createServer } from "./server";
import { router } from "./utils/router";
import { apiKeyAuth } from "./middleware/auth";
import {CONFIG_FILE, HOME_DIR, listPresets} from "@CCR/shared";
import { CONFIG_FILE, HOME_DIR, listPresets } from "@CCR/shared";
import { createStream } from 'rotating-file-stream';
import { sessionUsageCache } from "./utils/cache";
import {SSEParserTransform} from "./utils/SSEParser.transform";
import {SSESerializerTransform} from "./utils/SSESerializer.transform";
import {rewriteStream} from "./utils/rewriteStream";
import { sessionUsageCache } from "@musistudio/llms";
import { SSEParserTransform } from "./utils/SSEParser.transform";
import { SSESerializerTransform } from "./utils/SSESerializer.transform";
import { rewriteStream } from "./utils/rewriteStream";
import JSON5 from "json5";
import { IAgent, ITool } from "./agents/type";
import agentsManager from "./agents";
@@ -138,10 +137,9 @@ async function getServer(options: RunOptions = {}) {
logger: loggerConfig,
});
presets.forEach(preset => {
console.log(preset.name, preset.config);
serverInstance.registerNamespace(preset.name, preset.config);
})
await Promise.allSettled(
presets.map(async preset => await serverInstance.registerNamespace(preset.name, preset.config))
)
// Add async preHandler hook for authentication
serverInstance.addHook("preHandler", async (req: any, reply: any) => {
@@ -155,7 +153,15 @@ async function getServer(options: RunOptions = {}) {
});
});
serverInstance.addHook("preHandler", async (req: any, reply: any) => {
if (req.url.startsWith("/v1/messages") && !req.url.startsWith("/v1/messages/count_tokens")) {
const url = new URL(`http://127.0.0.1${req.url}`);
req.pathname = url.pathname;
if (req.pathname.endsWith("/v1/messages") && req.pathname !== "/v1/messages") {
req.preset = req.pathname.replace("/v1/messages", "").replace("/", "");
}
})
serverInstance.addHook("preHandler", async (req: any, reply: any) => {
if (req.pathname.endsWith("/v1/messages")) {
const useAgents = []
for (const agent of agentsManager.getAllAgents()) {
@@ -185,17 +191,13 @@ async function getServer(options: RunOptions = {}) {
if (useAgents.length) {
req.agents = useAgents;
}
await router(req, reply, {
config,
event
});
}
});
serverInstance.addHook("onError", async (request: any, reply: any, error: any) => {
event.emit('onError', request, reply, error);
})
serverInstance.addHook("onSend", (req: any, reply: any, payload: any, done: any) => {
if (req.sessionId && req.url.startsWith("/v1/messages") && !req.url.startsWith("/v1/messages/count_tokens")) {
if (req.sessionId && req.pathname.endsWith("/v1/messages")) {
if (payload instanceof ReadableStream) {
if (req.agents) {
const abortController = new AbortController();

View File

@@ -1,10 +1,9 @@
import Server from "@musistudio/llms";
import Server, { calculateTokenCount } from "@musistudio/llms";
import { readConfigFile, writeConfigFile, backupConfigFile } from "./utils";
import { join } from "path";
import fastifyStatic from "@fastify/static";
import { readdirSync, statSync, readFileSync, writeFileSync, existsSync, mkdirSync, unlinkSync, rmSync } from "fs";
import { homedir } from "os";
import { calculateTokenCount } from "./utils/router";
import {
getPresetDir,
readManifestFromDir,

View File

@@ -1,5 +1,6 @@
declare module "@musistudio/llms" {
import { FastifyInstance } from "fastify";
import { FastifyBaseLogger } from "fastify";
export interface ServerConfig {
jsonPath?: string;
@@ -9,7 +10,7 @@ declare module "@musistudio/llms" {
export interface Server {
app: FastifyInstance;
logger: any;
logger: FastifyBaseLogger;
start(): Promise<void>;
}
@@ -18,4 +19,44 @@ declare module "@musistudio/llms" {
};
export default Server;
// Export cache
export interface Usage {
input_tokens: number;
output_tokens: number;
}
export const sessionUsageCache: any;
// Export router
export interface RouterContext {
configService: any;
event?: any;
}
export const router: (req: any, res: any, context: RouterContext) => Promise<void>;
// Export utilities
export const calculateTokenCount: (messages: any[], system: any, tools: any[]) => number;
export const searchProjectBySession: (sessionId: string) => Promise<string | null>;
// Export services
export class ConfigService {
constructor(options?: any);
get<T = any>(key: string): T | undefined;
get<T = any>(key: string, defaultValue: T): T;
getAll(): any;
has(key: string): boolean;
set(key: string, value: any): void;
reload(): void;
}
export class ProviderService {
constructor(configService: any, transformerService: any, logger: any);
}
export class TransformerService {
constructor(configService: any, logger: any);
initialize(): Promise<void>;
}
}

View File

@@ -1,47 +0,0 @@
// LRU cache for session usage.
// Token counts recorded per request; cached per session (keyed by sessionId).
export interface Usage {
  // Prompt-side token count for the request.
  input_tokens: number;
  // Completion-side token count for the request.
  output_tokens: number;
}
/**
 * Fixed-capacity least-recently-used map.
 *
 * Recency is tracked via Map's insertion-order iteration: every read or
 * write re-inserts the key at the tail, so the head of the Map is always
 * the stalest entry and is evicted when capacity is exceeded.
 */
class LRUCache<K, V> {
  private readonly maxEntries: number;
  private readonly store: Map<K, V>;

  constructor(capacity: number) {
    this.maxEntries = capacity;
    this.store = new Map<K, V>();
  }

  /** Look up `key`, refreshing its recency; undefined when absent. */
  get(key: K): V | undefined {
    if (!this.store.has(key)) {
      return undefined;
    }
    const hit = this.store.get(key) as V;
    // Re-insert so the key moves to the most-recently-used position.
    this.store.delete(key);
    this.store.set(key, hit);
    return hit;
  }

  /** Insert or refresh `key`; evicts the stalest entry when full. */
  put(key: K, value: V): void {
    const existed = this.store.delete(key);
    if (!existed && this.store.size >= this.maxEntries) {
      // Map iterates oldest-first, so the first key is the LRU victim.
      const stalest = this.store.keys().next().value;
      if (stalest !== undefined) {
        this.store.delete(stalest);
      }
    }
    this.store.set(key, value);
  }

  /** Snapshot of all values, oldest to newest. */
  values(): V[] {
    return [...this.store.values()];
  }
}
export const sessionUsageCache = new LRUCache<string, Usage>(100);

View File

@@ -1,317 +0,0 @@
import { get_encoding } from "tiktoken";
import { sessionUsageCache, Usage } from "./cache";
import { readFile, access } from "fs/promises";
import { opendir, stat } from "fs/promises";
import { join } from "path";
import { CLAUDE_PROJECTS_DIR, HOME_DIR } from "@CCR/shared";
import { LRUCache } from "lru-cache";
// Types from @anthropic-ai/sdk — minimal structural copies of the request
// shapes, declaring only the fields this router actually reads.

// A tool definition offered to the model.
interface Tool {
  name: string;
  description?: string;
  // JSON schema describing the tool's arguments.
  input_schema: object;
}

// One part of a message body (e.g. text, tool_use, tool_result).
interface ContentBlockParam {
  type: string;
  [key: string]: any;
}

// A single conversation turn.
interface MessageParam {
  role: string;
  content: string | ContentBlockParam[];
}

// Loosely-typed /v1/messages request body.
interface MessageCreateParamsBase {
  messages?: MessageParam[];
  system?: string | any[];
  tools?: Tool[];
  [key: string]: any;
}
// Shared tokenizer instance (cl100k_base encoding).
const enc = get_encoding("cl100k_base");

/**
 * Estimate the prompt's token count by summing the encoded length of every
 * message part, the system prompt, and each tool's name/description/schema.
 * Used for routing thresholds, so a rough estimate is sufficient.
 */
export const calculateTokenCount = (
  messages: MessageParam[],
  system: any,
  tools: Tool[]
) => {
  let tokenCount = 0;

  // Messages: plain strings, text blocks, tool calls, and tool results.
  if (Array.isArray(messages)) {
    for (const message of messages) {
      const { content } = message;
      if (typeof content === "string") {
        tokenCount += enc.encode(content).length;
        continue;
      }
      if (!Array.isArray(content)) {
        continue;
      }
      for (const part of content as any[]) {
        switch (part.type) {
          case "text":
            tokenCount += enc.encode(part.text).length;
            break;
          case "tool_use":
            tokenCount += enc.encode(JSON.stringify(part.input)).length;
            break;
          case "tool_result": {
            const body =
              typeof part.content === "string"
                ? part.content
                : JSON.stringify(part.content);
            tokenCount += enc.encode(body).length;
            break;
          }
        }
      }
    }
  }

  // System prompt: either a bare string or an array of text blocks
  // (whose .text may itself be a string or a list of strings).
  if (typeof system === "string") {
    tokenCount += enc.encode(system).length;
  } else if (Array.isArray(system)) {
    for (const item of system as any[]) {
      if (item.type !== "text") {
        continue;
      }
      if (typeof item.text === "string") {
        tokenCount += enc.encode(item.text).length;
      } else if (Array.isArray(item.text)) {
        for (const piece of item.text) {
          tokenCount += enc.encode(piece || "").length;
        }
      }
    }
  }

  // Tools: name+description (only counted when a description exists)
  // plus the JSON-encoded input schema.
  if (tools) {
    for (const tool of tools) {
      if (tool.description) {
        tokenCount += enc.encode(tool.name + tool.description).length;
      }
      if (tool.input_schema) {
        tokenCount += enc.encode(JSON.stringify(tool.input_schema)).length;
      }
    }
  }

  return tokenCount;
};
/**
 * Best-effort JSON config loader.
 * Resolves to the parsed object, or null when the file is missing,
 * unreadable, or not valid JSON — callers treat null as "no config".
 */
const readConfigFile = async (filePath: string) => {
  try {
    await access(filePath);
    const raw = await readFile(filePath, "utf8");
    return JSON.parse(raw);
  } catch {
    // Swallow the error by design: absence of a config is not a failure.
    return null;
  }
};
/**
 * Resolve a Router table scoped to the request's session/project.
 *
 * Precedence: <HOME_DIR>/<project>/<sessionId>.json, then
 * <HOME_DIR>/<project>/config.json. Resolves to that file's Router
 * section, or undefined when no override applies (use the global config).
 */
const getProjectSpecificRouter = async (req: any) => {
  if (!req.sessionId) {
    return undefined;
  }
  const project = await searchProjectBySession(req.sessionId);
  if (!project) {
    return undefined;
  }
  // Session-level config wins over project-level config.
  const candidates = [
    join(HOME_DIR, project, `${req.sessionId}.json`),
    join(HOME_DIR, project, "config.json"),
  ];
  for (const candidate of candidates) {
    const parsed = await readConfigFile(candidate);
    if (parsed && parsed.Router) {
      return parsed.Router;
    }
  }
  return undefined;
};
/**
 * Pick the provider/model for a request.
 *
 * Resolution order:
 *   1. Explicit "provider,model" override in req.body.model.
 *   2. Long-context model when the prompt (or the session's previous
 *      input usage) exceeds the configured threshold.
 *   3. <CCR-SUBAGENT-MODEL> directive embedded in the system prompt
 *      (stripped from the prompt once consumed).
 *   4. Background model for Claude Haiku variants.
 *   5. Web-search model when a web_search tool is present (beats think).
 *   6. Think model when extended thinking is requested.
 *   7. Router default.
 *
 * @param req        Request (body, log); mutated when a subagent directive
 *                   is stripped from the system prompt.
 * @param tokenCount Estimated prompt token count.
 * @param config     Global config (Providers, Router).
 * @param lastUsage  Usage recorded for this session's previous request.
 * @returns The model spec to use, typically "provider,model".
 */
const getUseModel = async (
  req: any,
  tokenCount: number,
  config: any,
  lastUsage?: Usage | undefined
) => {
  // Project/session config may override the global Router table.
  const projectSpecificRouter = await getProjectSpecificRouter(req);
  const Router = projectSpecificRouter || config.Router;

  // 1. Explicit "provider,model" override — normalize to canonical names.
  if (req.body.model.includes(",")) {
    const [provider, model] = req.body.model.split(",");
    // Fix: compare case-insensitively on BOTH sides. The original only
    // lowercased the configured name, so mixed-case user input never
    // matched and silently fell through to the raw string.
    const finalProvider = config.Providers.find(
      (p: any) => p.name.toLowerCase() === provider.toLowerCase()
    );
    const finalModel = finalProvider?.models?.find(
      (m: any) => m.toLowerCase() === model.toLowerCase()
    );
    if (finalProvider && finalModel) {
      return `${finalProvider.name},${finalModel}`;
    }
    return req.body.model;
  }

  // 2. Long context: either this prompt is over the threshold, or the
  // session's previous request was (and this prompt is non-trivial).
  const longContextThreshold = Router.longContextThreshold || 60000;
  const lastUsageThreshold =
    lastUsage &&
    lastUsage.input_tokens > longContextThreshold &&
    tokenCount > 20000;
  const tokenCountThreshold = tokenCount > longContextThreshold;
  if ((lastUsageThreshold || tokenCountThreshold) && Router.longContext) {
    req.log.info(
      `Using long context model due to token count: ${tokenCount}, threshold: ${longContextThreshold}`
    );
    return Router.longContext;
  }

  // 3. Subagent directive embedded by the dispatching agent; consume it
  // and remove the marker so it never reaches the provider.
  if (
    req.body?.system?.length > 1 &&
    req.body?.system[1]?.text?.startsWith("<CCR-SUBAGENT-MODEL>")
  ) {
    const model = req.body?.system[1].text.match(
      /<CCR-SUBAGENT-MODEL>(.*?)<\/CCR-SUBAGENT-MODEL>/s
    );
    if (model) {
      req.body.system[1].text = req.body.system[1].text.replace(
        `<CCR-SUBAGENT-MODEL>${model[1]}</CCR-SUBAGENT-MODEL>`,
        ""
      );
      return model[1];
    }
  }

  // 4. Background model for any Claude Haiku variant.
  // Fix: use the resolved Router (was config.Router.background, which
  // ignored the project/session-specific Router override — every other
  // rule in this function honors it).
  if (
    req.body.model?.includes("claude") &&
    req.body.model?.includes("haiku") &&
    Router.background
  ) {
    req.log.info(`Using background model for ${req.body.model}`);
    return Router.background;
  }

  // 5. The priority of websearch must be higher than thinking.
  if (
    Array.isArray(req.body.tools) &&
    req.body.tools.some((tool: any) => tool.type?.startsWith("web_search")) &&
    Router.webSearch
  ) {
    return Router.webSearch;
  }

  // 6. Extended thinking requested.
  if (req.body.thinking && Router.think) {
    req.log.info(`Using think model for ${req.body.thinking}`);
    return Router.think;
  }

  // 7. Fallback.
  return Router.default;
};
/**
 * Routing middleware: decides which provider/model serves the request and
 * writes it back into req.body.model.
 *
 * Side effects on req: sets req.sessionId (parsed from metadata.user_id),
 * sets req.tokenCount when a custom router is configured, and may rewrite
 * the system prompt in place (REWRITE_SYSTEM_PROMPT).
 * Never throws — any routing failure falls back to Router.default.
 */
export const router = async (req: any, _res: any, context: any) => {
  const { config, event } = context;
  // Parse sessionId from metadata.user_id (text after "_session_").
  if (req.body.metadata?.user_id) {
    const parts = req.body.metadata.user_id.split("_session_");
    if (parts.length > 1) {
      req.sessionId = parts[1];
    }
  }
  // Previous request's usage for this session feeds the long-context rule.
  // Note: req.sessionId may be undefined here; the cache lookup then misses.
  const lastMessageUsage = sessionUsageCache.get(req.sessionId);
  const { messages, system = [], tools }: MessageCreateParamsBase = req.body;
  // Optional prompt rewrite: replace everything before "<env>" in the
  // second system block with the contents of a user-supplied file.
  if (
    config.REWRITE_SYSTEM_PROMPT &&
    system.length > 1 &&
    system[1]?.text?.includes("<env>")
  ) {
    const prompt = await readFile(config.REWRITE_SYSTEM_PROMPT, "utf-8");
    system[1].text = `${prompt}<env>${system[1].text.split("<env>").pop()}`;
  }
  try {
    const tokenCount = calculateTokenCount(
      messages as MessageParam[],
      system,
      tools as Tool[]
    );
    let model;
    // A custom router module (loaded via require, so CommonJS and cached)
    // takes precedence; if it fails or returns nothing, fall through to
    // the built-in rules in getUseModel.
    if (config.CUSTOM_ROUTER_PATH) {
      try {
        const customRouter = require(config.CUSTOM_ROUTER_PATH);
        req.tokenCount = tokenCount; // Pass token count to custom router
        model = await customRouter(req, config, {
          event,
        });
      } catch (e: any) {
        req.log.error(`failed to load custom router: ${e.message}`);
      }
    }
    if (!model) {
      model = await getUseModel(req, tokenCount, config, lastMessageUsage);
    }
    req.body.model = model;
  } catch (error: any) {
    // Routing must never fail the request — log and use the default model.
    req.log.error(`Error in router middleware: ${error.message}`);
    req.body.model = config.Router!.default;
  }
  return;
};
// In-memory map from sessionId to the project folder name that contains it.
// An empty-string value is the negative-cache sentinel: "looked up before,
// nothing found" (the original comment said null, but '' is what is stored —
// see searchProjectBySession below).
// LRU-bounded at 1000 entries.
const sessionProjectCache = new LRUCache<string, string>({
  max: 1000,
});
/**
 * Find the Claude project folder whose directory contains
 * `<sessionId>.jsonl`. Hits and misses alike are memoized in
 * sessionProjectCache ("" marks a miss). Resolves to the folder name,
 * or null when no project matches or the projects dir can't be read.
 */
export const searchProjectBySession = async (
  sessionId: string
): Promise<string | null> => {
  // Serve from cache first; the empty string is the negative sentinel.
  if (sessionProjectCache.has(sessionId)) {
    const cached = sessionProjectCache.get(sessionId);
    return cached ? cached : null;
  }
  try {
    // Collect every project folder name.
    const handle = await opendir(CLAUDE_PROJECTS_DIR);
    const folders: string[] = [];
    for await (const entry of handle) {
      if (entry.isDirectory()) {
        folders.push(entry.name);
      }
    }
    // Probe all folders concurrently for the session transcript file.
    const probes = folders.map(async (folder) => {
      const transcript = join(CLAUDE_PROJECTS_DIR, folder, `${sessionId}.jsonl`);
      try {
        const info = await stat(transcript);
        return info.isFile() ? folder : null;
      } catch {
        // File absent in this folder — not a match.
        return null;
      }
    });
    const outcomes = await Promise.all(probes);
    // First match in directory order wins.
    for (const outcome of outcomes) {
      if (outcome) {
        sessionProjectCache.set(sessionId, outcome);
        return outcome;
      }
    }
    // Remember the miss so repeated lookups skip the directory scan.
    sessionProjectCache.set(sessionId, "");
    return null;
  } catch (error) {
    console.error("Error searching for project by session:", error);
    // Cache the miss on error too, to avoid re-triggering the failure.
    sessionProjectCache.set(sessionId, "");
    return null;
  }
};