From 8af5846490c268225cee40ebc5673e7e0de0b9a4 Mon Sep 17 00:00:00 2001
From: Yizhi <946185759@qq.com>
Date: Fri, 7 Mar 2025 10:19:26 +0800
Subject: [PATCH] Add quick model loading methods
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/deploy/common/model.ts           | 94 +++++++++++++++++++++++++++-
 src/deploy/facealign/landmark1000.ts |  9 +++
 src/deploy/facealign/pfld.ts         | 10 +++
 src/deploy/faceattr/gender-age.ts    | 13 ++--
 src/deploy/facedet/common.ts         |  1 -
 src/deploy/facedet/yolov5.ts         |  8 +++
 src/deploy/faceid/adaface.ts         |  8 +++
 src/deploy/faceid/insightface.ts     | 36 ++++++++++-
 src/test.ts                          | 30 +++++----
 9 files changed, 187 insertions(+), 22 deletions(-)

diff --git a/src/deploy/common/model.ts b/src/deploy/common/model.ts
index 95e64b1..c1621e8 100644
--- a/src/deploy/common/model.ts
+++ b/src/deploy/common/model.ts
@@ -10,6 +10,20 @@ export interface ImageCropOption {
     crop?: { sx: number, sy: number, sw: number, sh: number }
 }
 
+export type ModelType = "onnx" | "mnn"
+
+export interface ModelCacheOption<Create extends boolean = boolean> {
+    cacheDir?: string
+    saveType?: ModelType,
+    createModel?: Create
+}
+
+export interface ModelCacheResult<T extends Model, Create extends boolean = boolean> {
+    modelPath: string
+    modelType: ModelType
+    model: Create extends true ? T : never
+}
+
 export abstract class Model {
     protected session: backend.common.CommonSession;
 
@@ -23,13 +37,87 @@ export abstract class Model {
         else throw new Error("Invalid image");
     }
 
-    public static fromOnnx<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array) {
-        return new this(new backend.ort.Session(modelData));
+    public static async fromOnnx<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array | string) {
+        if (typeof modelData === "string") {
+            if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
+            else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string));
+        }
+        return new this(new backend.ort.Session(modelData as Uint8Array));
+    }
+
+    protected static async cacheModel<T extends Model, Create extends boolean = boolean>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
+        // Initialize the cache directory
+        const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
+        const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
+        await fs.promises.mkdir(cacheDir, { recursive: true });
+        // Load the cache manifest
+        const cacheJsonFile = path.join(cacheDir, "config.json");
+        let cacheJsonData: Array<{ url: string, filename: string }> = [];
+        if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) {
+            try {
+                cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8"));
+            } catch (e) { console.error(e); }
+        }
+        // Download the model if it is not cached yet
+        let cache = cacheJsonData.find(c => c.url === url);
+        if (!cache) {
+            let saveType = option?.saveType ?? null;
+            const saveTypeDict: Record<string, ModelType> = {
+                ".onnx": "onnx",
+                ".mnn": "mnn",
+            };
+            const _url = new URL(url);
+            const res = await fetch(_url).then(res => {
+                if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`);
+                // Infer the save type from the content-disposition filename, then from the URL path
+                const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1];
+                if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType;
+                if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx";
+                return res.blob();
+            }).then(blob => blob.stream()).then(async stream => {
+                // Stream to a temporary file while hashing the content
+                const cacheFilename = path.join(cacheDir, Date.now().toString());
+                let fsStream!: ReturnType<typeof fs.createWriteStream>;
+                let hashStream!: ReturnType<typeof crypto.createHash>;
+                let hash!: string;
+                await stream.pipeTo(new WritableStream({
+                    start() {
+                        fsStream = fs.createWriteStream(cacheFilename);
+                        hashStream = crypto.createHash("md5");
+                    },
+                    async write(chunk) {
+                        await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
+                        await new Promise<void>((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve()));
+                    },
+                    close() {
+                        fsStream.end();
+                        hashStream.end();
+                        hash = hashStream.digest("hex");
+                    },
+                    abort() { }
+                }));
+                return { filename: cacheFilename, hash };
+            });
+            // Rename the temporary file to <md5>.<type>
+            const filename = `${res.hash}.${saveType}`;
+            await fs.promises.rename(res.filename, path.join(cacheDir, filename));
+            // Record the download in the manifest
+            cache = { url, filename };
+            cacheJsonData.push(cache);
+            await fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
+        }
+        // Resolve the cached model, optionally instantiating it
+        const modelPath = path.join(cacheDir, cache.filename);
+        const modelType = path.extname(cache.filename).substring(1) as ModelType;
+        let model: T | undefined = undefined;
+        if (option?.createModel) {
+            if (modelType === "onnx") model = await (this as any).fromOnnx(modelPath);
+        }
+        return { modelPath, modelType, model: model as any }
     }
 
     public constructor(session: backend.common.CommonSession) { this.session = session; }
-
     public get inputs() { return this.session.inputs; }
     public get outputs() { return this.session.outputs; }
     public get input() { return Object.entries(this.inputs)[0][1]; }
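
With cacheModel in place, a subclass can expose a one-line static load(). A minimal
sketch of the intended call pattern, using a hypothetical MyModel subclass and model
URL; the ~/.aibox_cache/models default and the config.json manifest shape come from
the implementation above, and the generic typing follows the reconstruction above:

    class MyModel extends Model {
        public static async load() {
            // First call downloads to ~/.aibox_cache/models/<md5>.<type> and records
            // { url, filename } in config.json; later calls resolve from the cache.
            const { model } = await this.cacheModel("https://example.com/my_model.onnx", { createModel: true });
            return model; // typed as MyModel because createModel is true
        }
    }
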
"onnx"; + if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`); + return res.blob(); + }).then(blob => blob.stream()).then(async stream => { + const cacheFilename = path.join(cacheDir, Date.now().toString()); + let fsStream!: ReturnType; + let hashStream!: ReturnType; + let hash!: string; + await stream.pipeTo(new WritableStream({ + start(controller) { + fsStream = fs.createWriteStream(cacheFilename); + hashStream = crypto.createHash("md5"); + }, + async write(chunk, controller) { + await new Promise((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve())); + await new Promise((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve())); + }, + close() { + fsStream.end(); + hashStream.end(); + hash = hashStream.digest("hex") + }, + abort() { } + })); + return { filename: cacheFilename, hash }; + }); + //重命名 + const filename = `${res.hash}.${saveType}`; + fs.promises.rename(res.filename, path.join(cacheDir, filename)); + //保存缓存 + cache = { url, filename }; + cacheJsonData.push(cache); + fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4)); + } + //返回模型数据 + const modelPath = path.join(cacheDir, cache.filename); + + const modelType = path.extname(cache.filename).substring(1) as ModelType; + let model: T | undefined = undefined; + if (option?.createModel) { + if (modelType === "onnx") model = (this as any).fromOnnx(modelPath); + } + + return { modelPath, modelType, model: model as any } } public constructor(session: backend.common.CommonSession) { this.session = session; } - public get inputs() { return this.session.inputs; } public get outputs() { return this.session.outputs; } public get input() { return Object.entries(this.inputs)[0][1]; } diff --git a/src/deploy/facealign/landmark1000.ts b/src/deploy/facealign/landmark1000.ts index 0826a7a..3dde638 100644 --- a/src/deploy/facealign/landmark1000.ts +++ b/src/deploy/facealign/landmark1000.ts @@ -17,8 +17,17 @@ class FaceLandmark1000Result extends FaceAlignmentResult { protected contourPointIndex(): number[] { return this.indexFromTo(0, 272); } } + +const MODEL_URL_CONFIG = { + FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`, +}; + export class FaceLandmark1000 extends Model { + public static async load(type?: keyof typeof MODEL_URL_CONFIG) { + return this.cacheModel(MODEL_URL_CONFIG[type ?? 
"FACELANDMARK1000_ONNX"], { createModel: true }).then(r => r.model); + } + public predict(image: ImageSource, option?: FaceLandmark1000PredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); } public async doPredict(image: cv.Mat, option?: FaceLandmark1000PredictOption) { diff --git a/src/deploy/facealign/pfld.ts b/src/deploy/facealign/pfld.ts index 41cc24e..0e34024 100644 --- a/src/deploy/facealign/pfld.ts +++ b/src/deploy/facealign/pfld.ts @@ -17,7 +17,17 @@ class PFLDResult extends FaceAlignmentResult { protected contourPointIndex(): number[] { return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]; } } +const MODEL_URL_CONFIG = { + PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`, + PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`, + PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`, +}; export class PFLD extends Model { + + public static async load(type?: keyof typeof MODEL_URL_CONFIG) { + return this.cacheModel(MODEL_URL_CONFIG[type ?? "PFLD_106_LITE_ONNX"], { createModel: true }).then(r => r.model); + } + public predict(image: ImageSource, option?: PFLDPredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); } private async doPredict(image: cv.Mat, option?: PFLDPredictOption) { diff --git a/src/deploy/faceattr/gender-age.ts b/src/deploy/faceattr/gender-age.ts index 693a936..c3b8c76 100644 --- a/src/deploy/faceattr/gender-age.ts +++ b/src/deploy/faceattr/gender-age.ts @@ -10,8 +10,17 @@ export interface GenderAgePredictResult { age: number } +const MODEL_URL_CONFIG = { + INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`, +}; export class GenderAge extends Model { + public static async load(type?: keyof typeof MODEL_URL_CONFIG) { + return this.cacheModel(MODEL_URL_CONFIG[type ?? 
"INSIGHT_GENDER_AGE_ONNX"], { createModel: true }).then(r => r.model); + } + + public predict(image: ImageSource, option?: GenderAgePredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); } + private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise { const input = this.input; const output = this.output; @@ -33,8 +42,4 @@ export class GenderAge extends Model { age: parseInt(result[2] * 100 as any), } } - - public predict(image: ImageSource, option?: GenderAgePredictOption) { - return Model.resolveImage(image, im => this.doPredict(im, option)); - } } diff --git a/src/deploy/facedet/common.ts b/src/deploy/facedet/common.ts index 9493ab1..a27be8e 100644 --- a/src/deploy/facedet/common.ts +++ b/src/deploy/facedet/common.ts @@ -41,7 +41,6 @@ export class FaceBox { const { imw, imh } = this.#option; let size = Math.max(this.width, this.height) / 2; const cx = this.centerX, cy = this.centerY; - console.log(this) if (cx - size < 0) size = cx; if (cx + size > imw) size = imw - cx; diff --git a/src/deploy/facedet/yolov5.ts b/src/deploy/facedet/yolov5.ts index 6692c42..982a54f 100644 --- a/src/deploy/facedet/yolov5.ts +++ b/src/deploy/facedet/yolov5.ts @@ -2,8 +2,16 @@ import { cv } from "../../cv"; import { convertImage } from "../common/processors"; import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common"; +const MODEL_URL_CONFIG = { + YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`, +}; + export class Yolov5Face extends FaceDetector { + public static async load(type?: keyof typeof MODEL_URL_CONFIG) { + return this.cacheModel(MODEL_URL_CONFIG[type ?? "YOLOV5S_ONNX"], { createModel: true }).then(r => r.model); + } + public async doPredict(image: cv.Mat, option?: FaceDetectOption): Promise { const input = this.input; const resizedImage = image.resize(input.shape[2], input.shape[3]); diff --git a/src/deploy/faceid/adaface.ts b/src/deploy/faceid/adaface.ts index 9d61966..fa0974b 100644 --- a/src/deploy/faceid/adaface.ts +++ b/src/deploy/faceid/adaface.ts @@ -2,8 +2,16 @@ import { Mat } from "../../cv/mat"; import { convertImage } from "../common/processors"; import { FaceRecognition, FaceRecognitionPredictOption } from "./common"; +const MODEL_URL_CONFIG = { + MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`, +}; + export class AdaFace extends FaceRecognition { + public static async load(type?: keyof typeof MODEL_URL_CONFIG) { + return this.cacheModel(MODEL_URL_CONFIG[type ?? 
"MOBILEFACENET_ADAFACE_ONNX"], { createModel: true }).then(r => r.model); + } + public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise { const input = this.input; const output = this.output; diff --git a/src/deploy/faceid/insightface.ts b/src/deploy/faceid/insightface.ts index 90dc06b..e08b78d 100644 --- a/src/deploy/faceid/insightface.ts +++ b/src/deploy/faceid/insightface.ts @@ -2,6 +2,24 @@ import { Mat } from "../../cv/mat"; import { convertImage } from "../common/processors"; import { FaceRecognition, FaceRecognitionPredictOption } from "./common"; +const MODEL_URL_CONFIG_ARC_FACE = { + INSIGHTFACE_ARCFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r100.onnx`, + INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`, + INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`, + INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`, +}; +const MODEL_URL_CONFIG_COS_FACE = { + INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`, + INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`, + INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`, + INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`, +}; +const MODEL_URL_CONFIG_PARTIAL_FC = { + INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`, + INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`, +}; + + export class Insightface extends FaceRecognition { public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise { @@ -24,9 +42,21 @@ export class Insightface extends FaceRecognition { } -export class ArcFace extends Insightface { } +export class ArcFace extends Insightface { + public static async load(type?: keyof typeof MODEL_URL_CONFIG_ARC_FACE) { + return this.cacheModel(MODEL_URL_CONFIG_ARC_FACE[type ?? "INSIGHTFACE_ARCFACE_R100_ONNX"], { createModel: true }).then(r => r.model); + } +} -export class CosFace extends Insightface { } +export class CosFace extends Insightface { + public static async load(type?: keyof typeof MODEL_URL_CONFIG_COS_FACE) { + return this.cacheModel(MODEL_URL_CONFIG_COS_FACE[type ?? "INSIGHTFACE_COSFACE_R100_ONNX"], { createModel: true }).then(r => r.model); + } +} -export class PartialFC extends Insightface { } +export class PartialFC extends Insightface { + public static async load(type?: keyof typeof MODEL_URL_CONFIG_PARTIAL_FC) { + return this.cacheModel(MODEL_URL_CONFIG_PARTIAL_FC[type ?? 
"INSIGHTFACE_PARTIALFC_R100_ONNX"], { createModel: true }).then(r => r.model); + } +} diff --git a/src/test.ts b/src/test.ts index 3931b30..67ce90f 100644 --- a/src/test.ts +++ b/src/test.ts @@ -30,8 +30,8 @@ async function cacheImage(group: string, url: string) { } async function testGenderTest() { - const facedet = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx")); - const detector = deploy.faceattr.GenderAgeDetector.fromOnnx(fs.readFileSync("models/faceattr/insight_gender_age.onnx")); + const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX"); + const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_ONNX"); const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg"); const boxes = await facedet.predict(image); @@ -44,8 +44,9 @@ async function testGenderTest() { } async function testFaceID() { - const facedet = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx")); - const faceid = deploy.faceid.CosFace.fromOnnx(fs.readFileSync("models/faceid/insightface/glint360k_cosface_r100.onnx")); + const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX"); + const faceid = await deploy.faceid.PartialFC.load(); + const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX"); const { basic, tests } = faceidTestData.stars; @@ -63,9 +64,17 @@ async function testFaceID() { basicFaceIndex[name] = basicDetectedFaces.findIndex(box => box.x1 < x && box.x2 > x && box.y1 < y && box.y2 > y); } + async function getEmbd(image: cv.Mat, box: deploy.facedet.FaceBox) { + box = box.toSquare(); + const alignResult = await facealign.predict(image, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } }); + let faceImage = image.rotate(box.centerX, box.centerY, -alignResult.direction * 180 / Math.PI); + return faceid.predict(faceImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } }); + } + const basicEmbds: number[][] = []; for (const box of basicDetectedFaces) { - const embd = await faceid.predict(basicImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } }); + // const embd = await faceid.predict(basicImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } }); + const embd = await getEmbd(basicImage, box); basicEmbds.push(embd); } @@ -86,7 +95,8 @@ async function testFaceID() { continue } - const embd = await faceid.predict(img, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } }); + // const embd = await faceid.predict(img, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } }); + const embd = await getEmbd(img, box); const compareEmbds = basicEmbds.map(e => deploy.faceid.cosineDistance(e, embd)); const max = Math.max(...compareEmbds); @@ -100,9 +110,9 @@ async function testFaceID() { } async function testFaceAlign() { - const fd = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx")); - // const fa = deploy.facealign.PFLD.fromOnnx(fs.readFileSync("models/facealign/pfld-106-lite.onnx")); - const fa = deploy.facealign.FaceLandmark1000.fromOnnx(fs.readFileSync("models/facealign/FaceLandmark1000.onnx")); + const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX"); + const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX"); + // const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX"); let image = await 
cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712"); image = image.rotate(image.width / 2, image.height / 2, 0); @@ -122,10 +132,8 @@ async function testFaceAlign() { console.log(points); console.log(points.direction * 180 / Math.PI); - debugger } - async function test() { await testGenderTest(); await testFaceID();
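
A note on the new getEmbd helper in testFaceID: PFLD's direction is the roll angle in
radians, while cv.Mat.rotate takes degrees, hence the -direction * 180 / Math.PI
conversion; the face is counter-rotated upright around the box center before the
embedding crop. A standalone sketch, with facealign and faceid loaded as above:

    async function embedUprightFace(image: cv.Mat, box: deploy.facedet.FaceBox) {
        box = box.toSquare();
        const crop = { sx: box.left, sy: box.top, sw: box.width, sh: box.height };
        const { direction } = await facealign.predict(image, { crop }); // roll angle in radians
        const upright = image.rotate(box.centerX, box.centerY, -direction * 180 / Math.PI);
        return faceid.predict(upright, { crop });
    }
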