修复模型下载问题
This commit is contained in:
@ -24,6 +24,8 @@ export interface ModelCacheResult<T, Create extends boolean> {
|
|||||||
model: Create extends true ? T : never
|
model: Create extends true ? T : never
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise<string> }> = [];
|
||||||
|
|
||||||
export abstract class Model {
|
export abstract class Model {
|
||||||
protected session: backend.common.CommonSession;
|
protected session: backend.common.CommonSession;
|
||||||
|
|
||||||
@ -58,77 +60,91 @@ export abstract class Model {
|
|||||||
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
||||||
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
|
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
|
||||||
await fs.promises.mkdir(cacheDir, { recursive: true });
|
await fs.promises.mkdir(cacheDir, { recursive: true });
|
||||||
//加载模型配置
|
|
||||||
const cacheJsonFile = path.join(cacheDir, "config.json");
|
|
||||||
let cacheJsonData: Array<{ url: string, filename: string }> = [];
|
|
||||||
if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) {
|
|
||||||
try {
|
|
||||||
cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8"));
|
|
||||||
} catch (e) { console.error(e); }
|
|
||||||
}
|
|
||||||
//不存在则下载
|
|
||||||
let cache = cacheJsonData.find(c => c.url === url);
|
|
||||||
if (!cache || !fs.existsSync(path.join(cacheDir, cache.filename))) {
|
|
||||||
let saveType = option?.saveType ?? null;
|
|
||||||
const saveTypeDict: Record<string, ModelType> = {
|
|
||||||
".onnx": "onnx",
|
|
||||||
".mnn": "mnn",
|
|
||||||
};
|
|
||||||
const _url = new URL(url);
|
|
||||||
const res = await fetch(_url).then(res => {
|
|
||||||
const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1];
|
|
||||||
if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType;
|
|
||||||
if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx";
|
|
||||||
if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`);
|
|
||||||
return res.blob();
|
|
||||||
}).then(blob => blob.stream()).then(async stream => {
|
|
||||||
const cacheFilename = path.join(cacheDir, Date.now().toString());
|
|
||||||
const hash = await new Promise<string>((resolve, reject) => {
|
|
||||||
let fsStream!: ReturnType<typeof fs.createWriteStream>;
|
|
||||||
let hashStream!: ReturnType<typeof crypto.createHash>;
|
|
||||||
stream.pipeTo(new WritableStream({
|
|
||||||
start(controller) {
|
|
||||||
fsStream = fs.createWriteStream(cacheFilename);
|
|
||||||
hashStream = crypto.createHash("md5");
|
|
||||||
},
|
|
||||||
async write(chunk, controller) {
|
|
||||||
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
|
|
||||||
await new Promise<void>((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve()));
|
|
||||||
},
|
|
||||||
close() {
|
|
||||||
fsStream.end();
|
|
||||||
hashStream.end();
|
|
||||||
resolve(hashStream.digest("hex"));
|
|
||||||
},
|
|
||||||
abort() { }
|
|
||||||
})).catch(reject);
|
|
||||||
})
|
|
||||||
return { filename: cacheFilename, hash };
|
|
||||||
});
|
|
||||||
//重命名
|
|
||||||
const filename = `${res.hash}.${saveType}`;
|
|
||||||
fs.promises.rename(res.filename, path.join(cacheDir, filename));
|
|
||||||
//保存缓存
|
|
||||||
if (!cache) {
|
|
||||||
cache = { url, filename };
|
|
||||||
cacheJsonData.push(cache);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
cache.filename = filename;
|
|
||||||
cache.url = url;
|
|
||||||
}
|
|
||||||
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
|
|
||||||
}
|
|
||||||
//返回模型数据
|
|
||||||
const modelPath = path.join(cacheDir, cache.filename);
|
|
||||||
|
|
||||||
const modelType = path.extname(cache.filename).substring(1) as ModelType;
|
//定义函数用于加载模型信息
|
||||||
|
async function resolveModel() {
|
||||||
|
//加载模型配置
|
||||||
|
const cacheJsonFile = path.join(cacheDir, "config.json");
|
||||||
|
let cacheJsonData: Array<{ url: string, filename: string }> = [];
|
||||||
|
if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) {
|
||||||
|
try {
|
||||||
|
cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8"));
|
||||||
|
} catch (e) { console.error(e); }
|
||||||
|
}
|
||||||
|
//不存在则下载
|
||||||
|
let cache = cacheJsonData.find(c => c.url === url);
|
||||||
|
if (!cache || !fs.existsSync(path.join(cacheDir, cache.filename))) {
|
||||||
|
let saveType = option?.saveType ?? null;
|
||||||
|
const saveTypeDict: Record<string, ModelType> = {
|
||||||
|
".onnx": "onnx",
|
||||||
|
".mnn": "mnn",
|
||||||
|
};
|
||||||
|
const _url = new URL(url);
|
||||||
|
const res = await fetch(_url).then(res => {
|
||||||
|
const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1];
|
||||||
|
if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType;
|
||||||
|
if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx";
|
||||||
|
if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`);
|
||||||
|
return res.blob();
|
||||||
|
}).then(blob => blob.stream()).then(async stream => {
|
||||||
|
const cacheFilename = path.join(cacheDir, Date.now().toString());
|
||||||
|
const hash = await new Promise<string>((resolve, reject) => {
|
||||||
|
let fsStream!: ReturnType<typeof fs.createWriteStream>;
|
||||||
|
let hashStream!: ReturnType<typeof crypto.createHash>;
|
||||||
|
stream.pipeTo(new WritableStream({
|
||||||
|
start(controller) {
|
||||||
|
fsStream = fs.createWriteStream(cacheFilename);
|
||||||
|
hashStream = crypto.createHash("md5");
|
||||||
|
},
|
||||||
|
async write(chunk, controller) {
|
||||||
|
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||||
|
await new Promise<void>((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||||
|
},
|
||||||
|
close() {
|
||||||
|
fsStream.end();
|
||||||
|
hashStream.end();
|
||||||
|
resolve(hashStream.digest("hex"));
|
||||||
|
},
|
||||||
|
abort() { }
|
||||||
|
})).catch(reject);
|
||||||
|
})
|
||||||
|
return { filename: cacheFilename, hash };
|
||||||
|
});
|
||||||
|
//重命名
|
||||||
|
const filename = `${res.hash}.${saveType}`;
|
||||||
|
fs.promises.rename(res.filename, path.join(cacheDir, filename));
|
||||||
|
//保存缓存
|
||||||
|
if (!cache) {
|
||||||
|
cache = { url, filename };
|
||||||
|
cacheJsonData.push(cache);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
cache.filename = filename;
|
||||||
|
cache.url = url;
|
||||||
|
}
|
||||||
|
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
|
||||||
|
}
|
||||||
|
//返回模型数据
|
||||||
|
return path.join(cacheDir, cache.filename);
|
||||||
|
}
|
||||||
|
|
||||||
|
//查找任务
|
||||||
|
let cache = cacheDownloadTasks.find(c => c.url === url && c.cacheDir === cacheDir);
|
||||||
|
if (!cache) {
|
||||||
|
cache = { url, cacheDir, modelPath: resolveModel() }
|
||||||
|
cacheDownloadTasks.push(cache);
|
||||||
|
}
|
||||||
|
|
||||||
|
//获取模型数据
|
||||||
|
const modelPath = await cache.modelPath;
|
||||||
|
const modelType = path.extname(modelPath).substring(1) as ModelType;
|
||||||
let model: T | undefined = undefined;
|
let model: T | undefined = undefined;
|
||||||
if (option?.createModel) {
|
if (option?.createModel) {
|
||||||
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
||||||
else if (modelType == "mnn") model = (this as any).fromMNN(modelPath);
|
else if (modelType == "mnn") model = (this as any).fromMNN(modelPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//返回结果
|
||||||
return { modelPath, modelType, model: model as any }
|
return { modelPath, modelType, model: model as any }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user