修复模型下载问题

This commit is contained in:
2025-03-10 15:50:55 +08:00
parent 92e46c2c33
commit e362890c96

View File

@ -24,6 +24,8 @@ export interface ModelCacheResult<T, Create extends boolean> {
model: Create extends true ? T : never model: Create extends true ? T : never
} }
const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise<string> }> = [];
export abstract class Model { export abstract class Model {
protected session: backend.common.CommonSession; protected session: backend.common.CommonSession;
@ -58,6 +60,9 @@ export abstract class Model {
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]); const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models"); const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
await fs.promises.mkdir(cacheDir, { recursive: true }); await fs.promises.mkdir(cacheDir, { recursive: true });
//定义函数用于加载模型信息
async function resolveModel() {
//加载模型配置 //加载模型配置
const cacheJsonFile = path.join(cacheDir, "config.json"); const cacheJsonFile = path.join(cacheDir, "config.json");
let cacheJsonData: Array<{ url: string, filename: string }> = []; let cacheJsonData: Array<{ url: string, filename: string }> = [];
@ -120,15 +125,26 @@ export abstract class Model {
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4)); fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
} }
//返回模型数据 //返回模型数据
const modelPath = path.join(cacheDir, cache.filename); return path.join(cacheDir, cache.filename);
}
const modelType = path.extname(cache.filename).substring(1) as ModelType; //查找任务
let cache = cacheDownloadTasks.find(c => c.url === url && c.cacheDir === cacheDir);
if (!cache) {
cache = { url, cacheDir, modelPath: resolveModel() }
cacheDownloadTasks.push(cache);
}
//获取模型数据
const modelPath = await cache.modelPath;
const modelType = path.extname(modelPath).substring(1) as ModelType;
let model: T | undefined = undefined; let model: T | undefined = undefined;
if (option?.createModel) { if (option?.createModel) {
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath); if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
else if (modelType == "mnn") model = (this as any).fromMNN(modelPath); else if (modelType == "mnn") model = (this as any).fromMNN(modelPath);
} }
//返回结果
return { modelPath, modelType, model: model as any } return { modelPath, modelType, model: model as any }
} }