diff --git a/src/deploy/common/model.ts b/src/deploy/common/model.ts index 70dfe09..078e1e4 100644 --- a/src/deploy/common/model.ts +++ b/src/deploy/common/model.ts @@ -24,6 +24,8 @@ export interface ModelCacheResult { model: Create extends true ? T : never } +const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise }> = []; + export abstract class Model { protected session: backend.common.CommonSession; @@ -58,77 +60,91 @@ export abstract class Model { const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]); const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models"); await fs.promises.mkdir(cacheDir, { recursive: true }); - //加载模型配置 - const cacheJsonFile = path.join(cacheDir, "config.json"); - let cacheJsonData: Array<{ url: string, filename: string }> = []; - if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) { - try { - cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8")); - } catch (e) { console.error(e); } - } - //不存在则下载 - let cache = cacheJsonData.find(c => c.url === url); - if (!cache || !fs.existsSync(path.join(cacheDir, cache.filename))) { - let saveType = option?.saveType ?? null; - const saveTypeDict: Record = { - ".onnx": "onnx", - ".mnn": "mnn", - }; - const _url = new URL(url); - const res = await fetch(_url).then(res => { - const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1]; - if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType; - if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx"; - if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`); - return res.blob(); - }).then(blob => blob.stream()).then(async stream => { - const cacheFilename = path.join(cacheDir, Date.now().toString()); - const hash = await new Promise((resolve, reject) => { - let fsStream!: ReturnType; - let hashStream!: ReturnType; - stream.pipeTo(new WritableStream({ - start(controller) { - fsStream = fs.createWriteStream(cacheFilename); - hashStream = crypto.createHash("md5"); - }, - async write(chunk, controller) { - await new Promise((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve())); - await new Promise((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve())); - }, - close() { - fsStream.end(); - hashStream.end(); - resolve(hashStream.digest("hex")); - }, - abort() { } - })).catch(reject); - }) - return { filename: cacheFilename, hash }; - }); - //重命名 - const filename = `${res.hash}.${saveType}`; - fs.promises.rename(res.filename, path.join(cacheDir, filename)); - //保存缓存 - if (!cache) { - cache = { url, filename }; - cacheJsonData.push(cache); - } - else { - cache.filename = filename; - cache.url = url; - } - fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4)); - } - //返回模型数据 - const modelPath = path.join(cacheDir, cache.filename); - const modelType = path.extname(cache.filename).substring(1) as ModelType; + //定义函数用于加载模型信息 + async function resolveModel() { + //加载模型配置 + const cacheJsonFile = path.join(cacheDir, "config.json"); + let cacheJsonData: Array<{ url: string, filename: string }> = []; + if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) { + try { + cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8")); + } catch (e) { console.error(e); } + } + //不存在则下载 + let cache = cacheJsonData.find(c => c.url === url); + if (!cache || !fs.existsSync(path.join(cacheDir, cache.filename))) { + let saveType = option?.saveType ?? null; + const saveTypeDict: Record = { + ".onnx": "onnx", + ".mnn": "mnn", + }; + const _url = new URL(url); + const res = await fetch(_url).then(res => { + const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1]; + if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType; + if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx"; + if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`); + return res.blob(); + }).then(blob => blob.stream()).then(async stream => { + const cacheFilename = path.join(cacheDir, Date.now().toString()); + const hash = await new Promise((resolve, reject) => { + let fsStream!: ReturnType; + let hashStream!: ReturnType; + stream.pipeTo(new WritableStream({ + start(controller) { + fsStream = fs.createWriteStream(cacheFilename); + hashStream = crypto.createHash("md5"); + }, + async write(chunk, controller) { + await new Promise((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve())); + await new Promise((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve())); + }, + close() { + fsStream.end(); + hashStream.end(); + resolve(hashStream.digest("hex")); + }, + abort() { } + })).catch(reject); + }) + return { filename: cacheFilename, hash }; + }); + //重命名 + const filename = `${res.hash}.${saveType}`; + fs.promises.rename(res.filename, path.join(cacheDir, filename)); + //保存缓存 + if (!cache) { + cache = { url, filename }; + cacheJsonData.push(cache); + } + else { + cache.filename = filename; + cache.url = url; + } + fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4)); + } + //返回模型数据 + return path.join(cacheDir, cache.filename); + } + + //查找任务 + let cache = cacheDownloadTasks.find(c => c.url === url && c.cacheDir === cacheDir); + if (!cache) { + cache = { url, cacheDir, modelPath: resolveModel() } + cacheDownloadTasks.push(cache); + } + + //获取模型数据 + const modelPath = await cache.modelPath; + const modelType = path.extname(modelPath).substring(1) as ModelType; let model: T | undefined = undefined; if (option?.createModel) { if (modelType === "onnx") model = (this as any).fromOnnx(modelPath); else if (modelType == "mnn") model = (this as any).fromMNN(modelPath); } + //返回结果 return { modelPath, modelType, model: model as any } }