import { backend } from "../../backend";
import { cv } from "../../cv";

/** Constructor signature shared by concrete models: builds an instance around a backend session. */
export type ModelConstructor<T> = new (session: backend.common.CommonSession) => T;

/** An image given as a decoded `cv.Mat`, encoded image bytes, an http(s) URL, or a local file path. */
export type ImageSource = cv.Mat | Uint8Array | string;

export interface ImageCropOption {
    /** Image crop region. */
    crop?: { sx: number, sy: number, sw: number, sh: number }
}

/** Model file format; also used as the cached file's extension. */
export type ModelType = "onnx" | "mnn"

export interface ModelCacheOption<Create extends boolean = boolean> {
    /** Cache directory; defaults to `<home>/.aibox_cache/models`. */
    cacheDir?: string
    /**
     * Format to save a downloaded model as. A recognised extension in the response's Content-Disposition
     * filename takes precedence over this; this takes precedence over the URL path's extension; "onnx" is the
     * final fallback.
     */
    saveType?: ModelType,
    /** When true, also instantiate the model (only done for ONNX files). */
    createModel?: Create
}

export interface ModelCacheResult<T, Create extends boolean = boolean> {
    /** Path of the cached model file. */
    modelPath: string
    /** Format of the cached file, taken from its extension. */
    modelType: ModelType
    /** The instantiated model when `createModel` was set (left undefined at runtime for non-ONNX files). */
    model: Create extends true ? T : never
}

/** One record of the cache index (`config.json`): which file a URL was downloaded to. */
interface CacheEntry {
    url: string
    filename: string
}

/** Maps a recognised file extension to the model format it denotes. */
const SAVE_TYPE_BY_EXT: ReadonlyMap<string, ModelType> = new Map<string, ModelType>([
    [".onnx", "onnx"],
    [".mnn", "mnn"],
]);

/**
 * Loads raw bytes from an http(s) URL or, for any other string, from a local file path.
 * @throws Error when the URL responds with a non-2xx status; file-system errors propagate.
 */
async function loadBytes(source: string): Promise<Uint8Array> {
    if (/^https?:\/\//.test(source)) {
        const res = await fetch(source);
        if (!res.ok) throw new Error(`HTTP ${res.status} ${res.statusText}`);
        return new Uint8Array(await res.arrayBuffer());
    }
    const fs = await import("fs");
    return fs.promises.readFile(source);
}

/**
 * Reads the cache index. A missing index yields an empty list. An unreadable or malformed index is logged and
 * also yields an empty list, because the cache is best-effort. Entries without string `url`/`filename` are dropped.
 */
async function readCacheIndex(fs: typeof import("fs"), indexFile: string): Promise<CacheEntry[]> {
    if (!fs.existsSync(indexFile) || !(await fs.promises.stat(indexFile)).isFile()) return [];
    let parsed: unknown;
    try {
        parsed = JSON.parse(await fs.promises.readFile(indexFile, "utf-8"));
    } catch (e) {
        console.error(e);
        return [];
    }
    if (!Array.isArray(parsed)) {
        console.error(new Error(`Invalid model cache index (expected a JSON array): ${indexFile}`));
        return [];
    }
    return parsed.filter((e): e is CacheEntry => typeof e?.url === "string" && typeof e?.filename === "string");
}

/**
 * Extracts the filename from a Content-Disposition header. Handles quoted, unquoted, and RFC 5987
 * (`filename*=charset''name`) forms.
 */
function filenameFromContentDisposition(header: string | null): string | undefined {
    return header?.match(/filename\*?=(?:[\w-]+'[\w-]*')?"?([^";]+)"?/i)?.[1];
}

/**
 * Writes every chunk of `stream` to `filePath` (truncating it) and returns the hex MD5 of the bytes written.
 * A null stream produces an empty file. The file handle is closed even when reading or writing fails.
 */
async function saveStreamWithMd5(
    fs: typeof import("fs"),
    crypto: typeof import("crypto"),
    stream: ReadableStream<Uint8Array> | null,
    filePath: string,
): Promise<string> {
    const hash = crypto.createHash("md5");
    const handle = await fs.promises.open(filePath, "w");
    try {
        if (stream) {
            const reader = stream.getReader();
            for (;;) {
                const result = await reader.read();
                if (result.done) break;
                const chunk = result.value;
                hash.update(chunk);
                // FileHandle.write may write fewer bytes than requested, so loop until the chunk is fully written.
                for (let offset = 0; offset < chunk.byteLength;) {
                    const { bytesWritten } = await handle.write(chunk, offset);
                    offset += bytesWritten;
                }
            }
        }
    } finally {
        await handle.close();
    }
    return hash.digest("hex");
}

/** Base class for inference models: wraps a backend session and provides image/model loading helpers. */
export abstract class Model {
    protected session: backend.common.CommonSession;

    /**
     * Normalises `image` to a `cv.Mat` and hands it to `resolver`.
     *
     * A string starting with `http://` or `https://` is fetched; any other string is read as a local file path.
     * Loaded bytes, and any `Uint8Array` passed directly, are decoded with `cv.ImreadModes.IMREAD_COLOR_BGR`.
     * @returns the (awaited) result of `resolver`.
     * @throws Error when a URL responds with a non-2xx status, or when `image` is not a supported kind.
     */
    protected static async resolveImage<R>(image: ImageSource, resolver: (image: cv.Mat) => R | Promise<R>): Promise<R> {
        const source = typeof image === "string" ? await loadBytes(image) : image;
        const mat = source instanceof Uint8Array ? new cv.Mat(source, { mode: cv.ImreadModes.IMREAD_COLOR_BGR }) : source;
        // Runtime guard for untyped (JavaScript) callers passing something that is not an ImageSource.
        if (!(mat instanceof cv.Mat)) throw new Error("Invalid image");
        return await resolver(mat);
    }

    /**
     * Creates an instance of the calling class from ONNX model bytes, an http(s) URL, or a local file path.
     * @throws Error when a URL responds with a non-2xx status; file-system errors propagate.
     */
    public static async fromOnnx<T>(this: ModelConstructor<T>, modelData: Uint8Array | string): Promise<T> {
        const bytes = typeof modelData === "string" ? await loadBytes(modelData) : modelData;
        return new this(new backend.ort.Session(bytes));
    }

    /**
     * Downloads the model at `url` into a local cache (only on a cache miss) and reports where it is stored.
     *
     * The cache directory holds `config.json`, a JSON array of `{ url, filename }` records, plus model files named
     * `<md5 of contents>.<type>`. If the record for `url` points at a file that no longer exists, the model is
     * downloaded again.
     * @param url URL of the model to download.
     * @param option cache directory, preferred save type, and whether to also instantiate the model.
     * @returns the cached file's path and type, plus the model when `createModel` is set and the file is ONNX.
     * @throws Error when the download responds with a status other than 200; network/file-system errors propagate.
     */
    protected static async cacheModel<T, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
        const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);

        // Initialise the cache directory.
        const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
        await fs.promises.mkdir(cacheDir, { recursive: true });

        // Load the URL -> filename index.
        const cacheJsonFile = path.join(cacheDir, "config.json");
        let cacheJsonData = await readCacheIndex(fs, cacheJsonFile);
        let cache = cacheJsonData.find(c => c.url === url);
        if (cache && !fs.existsSync(path.join(cacheDir, cache.filename))) {
            // The index points at a file that no longer exists: forget the record and download again.
            cacheJsonData = cacheJsonData.filter(c => c.url !== url);
            cache = undefined;
        }

        // Download on a cache miss.
        if (!cache) {
            const parsedUrl = new URL(url);
            const response = await fetch(parsedUrl);
            if (response.status !== 200) {
                // Release the connection before failing; the error body is not needed.
                await response.body?.cancel().catch(() => undefined);
                throw new Error(`HTTP ${response.status} ${response.statusText}`);
            }
            const headerFilename = filenameFromContentDisposition(response.headers.get("content-disposition"));
            const saveType: ModelType =
                (headerFilename ? SAVE_TYPE_BY_EXT.get(path.extname(headerFilename)) : undefined)
                ?? option?.saveType
                ?? SAVE_TYPE_BY_EXT.get(path.extname(parsedUrl.pathname))
                ?? "onnx";

            // Stream into a uniquely named temp file, then rename it to <md5>.<type>.
            const tempFile = path.join(cacheDir, `${Date.now()}-${crypto.randomBytes(4).toString("hex")}.download`);
            let filename: string;
            try {
                const hash = await saveStreamWithMd5(fs, crypto, response.body, tempFile);
                filename = `${hash}.${saveType}`;
                await fs.promises.rename(tempFile, path.join(cacheDir, filename));
            } catch (err) {
                // Don't leave a partial download behind.
                await fs.promises.rm(tempFile, { force: true });
                throw err;
            }

            // Persist the new index record only after the model file is in place.
            cache = { url, filename };
            cacheJsonData.push(cache);
            await fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
        }

        // Describe (and optionally instantiate) the cached model.
        const modelPath = path.join(cacheDir, cache.filename);
        const modelType = path.extname(cache.filename).substring(1) as ModelType;
        let model: T | undefined;
        // Only ONNX files are instantiated here; for other types `model` stays undefined.
        if (option?.createModel && modelType === "onnx") {
            // Resolve `fromOnnx` through `this` so a subclass override is the one that runs.
            const ctor = this as ModelConstructor<T> & { fromOnnx(modelData: Uint8Array | string): Promise<T> };
            model = await ctor.fromOnnx(modelPath);
        }
        // The conditional `model` type cannot be narrowed from the runtime `createModel` flag, hence the cast.
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        return { modelPath, modelType, model: model as any };
    }

    public constructor(session: backend.common.CommonSession) {
        this.session = session;
    }

    /** All inputs of the underlying session. */
    public get inputs() {
        return this.session.inputs;
    }

    /** All outputs of the underlying session. */
    public get outputs() {
        return this.session.outputs;
    }

    /**
     * The first input of the session, in `Object.values` order.
     * @throws Error when the session has no inputs.
     */
    public get input() {
        const [first] = Object.values(this.inputs);
        if (first === undefined) throw new Error("Model has no inputs");
        return first;
    }

    /**
     * The first output of the session, in `Object.values` order.
     * @throws Error when the session has no outputs.
     */
    public get output() {
        const [first] = Object.values(this.outputs);
        if (first === undefined) throw new Error("Model has no outputs");
        return first;
    }
}