【增加模型快捷加载方法】
This commit is contained in:
@ -10,6 +10,20 @@ export interface ImageCropOption {
|
||||
crop?: { sx: number, sy: number, sw: number, sh: number }
|
||||
}
|
||||
|
||||
export type ModelType = "onnx" | "mnn"
|
||||
|
||||
export interface ModelCacheOption<Create extends boolean> {
|
||||
cacheDir?: string
|
||||
saveType?: ModelType,
|
||||
createModel?: Create
|
||||
}
|
||||
|
||||
export interface ModelCacheResult<T, Create extends boolean> {
|
||||
modelPath: string
|
||||
modelType: ModelType
|
||||
model: Create extends true ? T : never
|
||||
}
|
||||
|
||||
export abstract class Model {
|
||||
protected session: backend.common.CommonSession;
|
||||
|
||||
@ -23,13 +37,87 @@ export abstract class Model {
|
||||
else throw new Error("Invalid image");
|
||||
}
|
||||
|
||||
public static fromOnnx<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array) {
|
||||
return new this(new backend.ort.Session(modelData));
|
||||
public static async fromOnnx<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array | string) {
|
||||
if (typeof modelData === "string") {
|
||||
if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
|
||||
else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string));
|
||||
}
|
||||
return new this(new backend.ort.Session(modelData as Uint8Array));
|
||||
}
|
||||
|
||||
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
|
||||
//初始化目录
|
||||
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
||||
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
|
||||
await fs.promises.mkdir(cacheDir, { recursive: true });
|
||||
//加载模型配置
|
||||
const cacheJsonFile = path.join(cacheDir, "config.json");
|
||||
let cacheJsonData: Array<{ url: string, filename: string }> = [];
|
||||
if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) {
|
||||
try {
|
||||
cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8"));
|
||||
} catch (e) { console.error(e); }
|
||||
}
|
||||
//不存在则下载
|
||||
let cache = cacheJsonData.find(c => c.url === url);
|
||||
if (!cache) {
|
||||
let saveType = option?.saveType ?? null;
|
||||
const saveTypeDict: Record<string, ModelType> = {
|
||||
".onnx": "onnx",
|
||||
".mnn": "mnn",
|
||||
};
|
||||
const _url = new URL(url);
|
||||
const res = await fetch(_url).then(res => {
|
||||
const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1];
|
||||
if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType;
|
||||
if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx";
|
||||
if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`);
|
||||
return res.blob();
|
||||
}).then(blob => blob.stream()).then(async stream => {
|
||||
const cacheFilename = path.join(cacheDir, Date.now().toString());
|
||||
let fsStream!: ReturnType<typeof fs.createWriteStream>;
|
||||
let hashStream!: ReturnType<typeof crypto.createHash>;
|
||||
let hash!: string;
|
||||
await stream.pipeTo(new WritableStream({
|
||||
start(controller) {
|
||||
fsStream = fs.createWriteStream(cacheFilename);
|
||||
hashStream = crypto.createHash("md5");
|
||||
},
|
||||
async write(chunk, controller) {
|
||||
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||
await new Promise<void>((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||
},
|
||||
close() {
|
||||
fsStream.end();
|
||||
hashStream.end();
|
||||
hash = hashStream.digest("hex")
|
||||
},
|
||||
abort() { }
|
||||
}));
|
||||
return { filename: cacheFilename, hash };
|
||||
});
|
||||
//重命名
|
||||
const filename = `${res.hash}.${saveType}`;
|
||||
fs.promises.rename(res.filename, path.join(cacheDir, filename));
|
||||
//保存缓存
|
||||
cache = { url, filename };
|
||||
cacheJsonData.push(cache);
|
||||
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
|
||||
}
|
||||
//返回模型数据
|
||||
const modelPath = path.join(cacheDir, cache.filename);
|
||||
|
||||
const modelType = path.extname(cache.filename).substring(1) as ModelType;
|
||||
let model: T | undefined = undefined;
|
||||
if (option?.createModel) {
|
||||
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
||||
}
|
||||
|
||||
return { modelPath, modelType, model: model as any }
|
||||
}
|
||||
|
||||
public constructor(session: backend.common.CommonSession) { this.session = session; }
|
||||
|
||||
|
||||
public get inputs() { return this.session.inputs; }
|
||||
public get outputs() { return this.session.outputs; }
|
||||
public get input() { return Object.entries(this.inputs)[0][1]; }
|
||||
|
@ -17,8 +17,17 @@ class FaceLandmark1000Result extends FaceAlignmentResult {
|
||||
protected contourPointIndex(): number[] { return this.indexFromTo(0, 272); }
|
||||
}
|
||||
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`,
|
||||
};
|
||||
|
||||
export class FaceLandmark1000 extends Model {
|
||||
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG[type ?? "FACELANDMARK1000_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
|
||||
public predict(image: ImageSource, option?: FaceLandmark1000PredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); }
|
||||
|
||||
public async doPredict(image: cv.Mat, option?: FaceLandmark1000PredictOption) {
|
||||
|
@ -17,7 +17,17 @@ class PFLDResult extends FaceAlignmentResult {
|
||||
protected contourPointIndex(): number[] { return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]; }
|
||||
}
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`,
|
||||
PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`,
|
||||
PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`,
|
||||
};
|
||||
export class PFLD extends Model {
|
||||
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG[type ?? "PFLD_106_LITE_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
|
||||
public predict(image: ImageSource, option?: PFLDPredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); }
|
||||
|
||||
private async doPredict(image: cv.Mat, option?: PFLDPredictOption) {
|
||||
|
@ -10,8 +10,17 @@ export interface GenderAgePredictResult {
|
||||
age: number
|
||||
}
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
|
||||
};
|
||||
|
||||
export class GenderAge extends Model {
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG[type ?? "INSIGHT_GENDER_AGE_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
|
||||
public predict(image: ImageSource, option?: GenderAgePredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); }
|
||||
|
||||
private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise<GenderAgePredictResult> {
|
||||
const input = this.input;
|
||||
const output = this.output;
|
||||
@ -33,8 +42,4 @@ export class GenderAge extends Model {
|
||||
age: parseInt(result[2] * 100 as any),
|
||||
}
|
||||
}
|
||||
|
||||
public predict(image: ImageSource, option?: GenderAgePredictOption) {
|
||||
return Model.resolveImage(image, im => this.doPredict(im, option));
|
||||
}
|
||||
}
|
||||
|
@ -41,7 +41,6 @@ export class FaceBox {
|
||||
const { imw, imh } = this.#option;
|
||||
let size = Math.max(this.width, this.height) / 2;
|
||||
const cx = this.centerX, cy = this.centerY;
|
||||
console.log(this)
|
||||
|
||||
if (cx - size < 0) size = cx;
|
||||
if (cx + size > imw) size = imw - cx;
|
||||
|
@ -2,8 +2,16 @@ import { cv } from "../../cv";
|
||||
import { convertImage } from "../common/processors";
|
||||
import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common";
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`,
|
||||
};
|
||||
|
||||
export class Yolov5Face extends FaceDetector {
|
||||
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG[type ?? "YOLOV5S_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
|
||||
public async doPredict(image: cv.Mat, option?: FaceDetectOption): Promise<FaceBox[]> {
|
||||
const input = this.input;
|
||||
const resizedImage = image.resize(input.shape[2], input.shape[3]);
|
||||
|
@ -2,8 +2,16 @@ import { Mat } from "../../cv/mat";
|
||||
import { convertImage } from "../common/processors";
|
||||
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`,
|
||||
};
|
||||
|
||||
export class AdaFace extends FaceRecognition {
|
||||
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG[type ?? "MOBILEFACENET_ADAFACE_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
|
||||
public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
|
||||
const input = this.input;
|
||||
const output = this.output;
|
||||
|
@ -2,6 +2,24 @@ import { Mat } from "../../cv/mat";
|
||||
import { convertImage } from "../common/processors";
|
||||
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
||||
|
||||
const MODEL_URL_CONFIG_ARC_FACE = {
|
||||
INSIGHTFACE_ARCFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r100.onnx`,
|
||||
INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`,
|
||||
INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`,
|
||||
INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`,
|
||||
};
|
||||
const MODEL_URL_CONFIG_COS_FACE = {
|
||||
INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`,
|
||||
INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`,
|
||||
INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`,
|
||||
INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`,
|
||||
};
|
||||
const MODEL_URL_CONFIG_PARTIAL_FC = {
|
||||
INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`,
|
||||
INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`,
|
||||
};
|
||||
|
||||
|
||||
export class Insightface extends FaceRecognition {
|
||||
|
||||
public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
|
||||
@ -24,9 +42,21 @@ export class Insightface extends FaceRecognition {
|
||||
|
||||
}
|
||||
|
||||
export class ArcFace extends Insightface { }
|
||||
export class ArcFace extends Insightface {
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG_ARC_FACE) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG_ARC_FACE[type ?? "INSIGHTFACE_ARCFACE_R100_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
}
|
||||
|
||||
export class CosFace extends Insightface { }
|
||||
export class CosFace extends Insightface {
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG_COS_FACE) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG_COS_FACE[type ?? "INSIGHTFACE_COSFACE_R100_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
}
|
||||
|
||||
export class PartialFC extends Insightface { }
|
||||
export class PartialFC extends Insightface {
|
||||
public static async load(type?: keyof typeof MODEL_URL_CONFIG_PARTIAL_FC) {
|
||||
return this.cacheModel(MODEL_URL_CONFIG_PARTIAL_FC[type ?? "INSIGHTFACE_PARTIALFC_R100_ONNX"], { createModel: true }).then(r => r.model);
|
||||
}
|
||||
}
|
||||
|
||||
|
30
src/test.ts
30
src/test.ts
@ -30,8 +30,8 @@ async function cacheImage(group: string, url: string) {
|
||||
}
|
||||
|
||||
async function testGenderTest() {
|
||||
const facedet = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx"));
|
||||
const detector = deploy.faceattr.GenderAgeDetector.fromOnnx(fs.readFileSync("models/faceattr/insight_gender_age.onnx"));
|
||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_ONNX");
|
||||
|
||||
const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
|
||||
const boxes = await facedet.predict(image);
|
||||
@ -44,8 +44,9 @@ async function testGenderTest() {
|
||||
}
|
||||
|
||||
async function testFaceID() {
|
||||
const facedet = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx"));
|
||||
const faceid = deploy.faceid.CosFace.fromOnnx(fs.readFileSync("models/faceid/insightface/glint360k_cosface_r100.onnx"));
|
||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||
const faceid = await deploy.faceid.PartialFC.load();
|
||||
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
||||
|
||||
const { basic, tests } = faceidTestData.stars;
|
||||
|
||||
@ -63,9 +64,17 @@ async function testFaceID() {
|
||||
basicFaceIndex[name] = basicDetectedFaces.findIndex(box => box.x1 < x && box.x2 > x && box.y1 < y && box.y2 > y);
|
||||
}
|
||||
|
||||
async function getEmbd(image: cv.Mat, box: deploy.facedet.FaceBox) {
|
||||
box = box.toSquare();
|
||||
const alignResult = await facealign.predict(image, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
||||
let faceImage = image.rotate(box.centerX, box.centerY, -alignResult.direction * 180 / Math.PI);
|
||||
return faceid.predict(faceImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
||||
}
|
||||
|
||||
const basicEmbds: number[][] = [];
|
||||
for (const box of basicDetectedFaces) {
|
||||
const embd = await faceid.predict(basicImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
||||
// const embd = await faceid.predict(basicImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
||||
const embd = await getEmbd(basicImage, box);
|
||||
basicEmbds.push(embd);
|
||||
}
|
||||
|
||||
@ -86,7 +95,8 @@ async function testFaceID() {
|
||||
continue
|
||||
}
|
||||
|
||||
const embd = await faceid.predict(img, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
||||
// const embd = await faceid.predict(img, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
||||
const embd = await getEmbd(img, box);
|
||||
|
||||
const compareEmbds = basicEmbds.map(e => deploy.faceid.cosineDistance(e, embd));
|
||||
const max = Math.max(...compareEmbds);
|
||||
@ -100,9 +110,9 @@ async function testFaceID() {
|
||||
}
|
||||
|
||||
async function testFaceAlign() {
|
||||
const fd = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx"));
|
||||
// const fa = deploy.facealign.PFLD.fromOnnx(fs.readFileSync("models/facealign/pfld-106-lite.onnx"));
|
||||
const fa = deploy.facealign.FaceLandmark1000.fromOnnx(fs.readFileSync("models/facealign/FaceLandmark1000.onnx"));
|
||||
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
||||
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
||||
let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
|
||||
image = image.rotate(image.width / 2, image.height / 2, 0);
|
||||
|
||||
@ -122,10 +132,8 @@ async function testFaceAlign() {
|
||||
|
||||
console.log(points);
|
||||
console.log(points.direction * 180 / Math.PI);
|
||||
debugger
|
||||
}
|
||||
|
||||
|
||||
async function test() {
|
||||
await testGenderTest();
|
||||
await testFaceID();
|
||||
|
Reference in New Issue
Block a user