【增加模型快捷加载方法】
This commit is contained in:
@ -10,6 +10,20 @@ export interface ImageCropOption {
|
||||
crop?: { sx: number, sy: number, sw: number, sh: number }
|
||||
}
|
||||
|
||||
export type ModelType = "onnx" | "mnn"
|
||||
|
||||
export interface ModelCacheOption<Create extends boolean> {
|
||||
cacheDir?: string
|
||||
saveType?: ModelType,
|
||||
createModel?: Create
|
||||
}
|
||||
|
||||
export interface ModelCacheResult<T, Create extends boolean> {
|
||||
modelPath: string
|
||||
modelType: ModelType
|
||||
model: Create extends true ? T : never
|
||||
}
|
||||
|
||||
export abstract class Model {
|
||||
protected session: backend.common.CommonSession;
|
||||
|
||||
@ -23,13 +37,87 @@ export abstract class Model {
|
||||
else throw new Error("Invalid image");
|
||||
}
|
||||
|
||||
public static fromOnnx<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array) {
|
||||
return new this(new backend.ort.Session(modelData));
|
||||
public static async fromOnnx<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array | string) {
|
||||
if (typeof modelData === "string") {
|
||||
if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
|
||||
else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string));
|
||||
}
|
||||
return new this(new backend.ort.Session(modelData as Uint8Array));
|
||||
}
|
||||
|
||||
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
|
||||
//初始化目录
|
||||
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
||||
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
|
||||
await fs.promises.mkdir(cacheDir, { recursive: true });
|
||||
//加载模型配置
|
||||
const cacheJsonFile = path.join(cacheDir, "config.json");
|
||||
let cacheJsonData: Array<{ url: string, filename: string }> = [];
|
||||
if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) {
|
||||
try {
|
||||
cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8"));
|
||||
} catch (e) { console.error(e); }
|
||||
}
|
||||
//不存在则下载
|
||||
let cache = cacheJsonData.find(c => c.url === url);
|
||||
if (!cache) {
|
||||
let saveType = option?.saveType ?? null;
|
||||
const saveTypeDict: Record<string, ModelType> = {
|
||||
".onnx": "onnx",
|
||||
".mnn": "mnn",
|
||||
};
|
||||
const _url = new URL(url);
|
||||
const res = await fetch(_url).then(res => {
|
||||
const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1];
|
||||
if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType;
|
||||
if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx";
|
||||
if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`);
|
||||
return res.blob();
|
||||
}).then(blob => blob.stream()).then(async stream => {
|
||||
const cacheFilename = path.join(cacheDir, Date.now().toString());
|
||||
let fsStream!: ReturnType<typeof fs.createWriteStream>;
|
||||
let hashStream!: ReturnType<typeof crypto.createHash>;
|
||||
let hash!: string;
|
||||
await stream.pipeTo(new WritableStream({
|
||||
start(controller) {
|
||||
fsStream = fs.createWriteStream(cacheFilename);
|
||||
hashStream = crypto.createHash("md5");
|
||||
},
|
||||
async write(chunk, controller) {
|
||||
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||
await new Promise<void>((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||
},
|
||||
close() {
|
||||
fsStream.end();
|
||||
hashStream.end();
|
||||
hash = hashStream.digest("hex")
|
||||
},
|
||||
abort() { }
|
||||
}));
|
||||
return { filename: cacheFilename, hash };
|
||||
});
|
||||
//重命名
|
||||
const filename = `${res.hash}.${saveType}`;
|
||||
fs.promises.rename(res.filename, path.join(cacheDir, filename));
|
||||
//保存缓存
|
||||
cache = { url, filename };
|
||||
cacheJsonData.push(cache);
|
||||
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
|
||||
}
|
||||
//返回模型数据
|
||||
const modelPath = path.join(cacheDir, cache.filename);
|
||||
|
||||
const modelType = path.extname(cache.filename).substring(1) as ModelType;
|
||||
let model: T | undefined = undefined;
|
||||
if (option?.createModel) {
|
||||
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
||||
}
|
||||
|
||||
return { modelPath, modelType, model: model as any }
|
||||
}
|
||||
|
||||
public constructor(session: backend.common.CommonSession) { this.session = session; }
|
||||
|
||||
|
||||
public get inputs() { return this.session.inputs; }
|
||||
public get outputs() { return this.session.outputs; }
|
||||
public get input() { return Object.entries(this.inputs)[0][1]; }
|
||||
|
Reference in New Issue
Block a user