增加MNN模型支持
This commit is contained in:
@ -6,12 +6,26 @@ export interface SessionNodeInfo {
|
||||
shape: number[]
|
||||
}
|
||||
|
||||
export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
|
||||
export type DataTypeString = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
|
||||
|
||||
export enum DataType {
|
||||
Unknown,
|
||||
Float32,
|
||||
Float64,
|
||||
Int32,
|
||||
Uint32,
|
||||
Int16,
|
||||
Uint16,
|
||||
Int8,
|
||||
Uint8,
|
||||
Int64,
|
||||
Uint64,
|
||||
}
|
||||
|
||||
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
|
||||
|
||||
export interface SessionRunInputOption {
|
||||
type?: SessionNodeType
|
||||
type?: DataTypeString
|
||||
data: SessionNodeData
|
||||
shape?: number[]
|
||||
}
|
||||
@ -26,3 +40,36 @@ export abstract class CommonSession {
|
||||
export function isTypedArray(val: any): val is SessionNodeData {
|
||||
return val?.buffer instanceof ArrayBuffer;
|
||||
}
|
||||
|
||||
export function dataTypeFrom(str: DataTypeString) {
|
||||
return {
|
||||
"float32": DataType.Float32,
|
||||
"float64": DataType.Float64,
|
||||
"float": DataType.Float32,
|
||||
"double": DataType.Float64,
|
||||
"int32": DataType.Int32,
|
||||
"uint32": DataType.Uint32,
|
||||
"int16": DataType.Int16,
|
||||
"uint16": DataType.Uint16,
|
||||
"int8": DataType.Int8,
|
||||
"uint8": DataType.Uint8,
|
||||
"int64": DataType.Int64,
|
||||
"uint64": DataType.Uint64,
|
||||
}[str] ?? DataType.Unknown;
|
||||
}
|
||||
|
||||
export function dataTypeToString(type: DataType): DataTypeString | null {
|
||||
switch (type) {
|
||||
case DataType.Float32: return "float32";
|
||||
case DataType.Float64: return "float64";
|
||||
case DataType.Int32: return "int32";
|
||||
case DataType.Uint32: return "uint32";
|
||||
case DataType.Int16: return "int16";
|
||||
case DataType.Uint16: return "uint16";
|
||||
case DataType.Int8: return "int8";
|
||||
case DataType.Uint8: return "uint8";
|
||||
case DataType.Int64: return "int64";
|
||||
case DataType.Uint64: return "uint64";
|
||||
default: return null;
|
||||
}
|
||||
}
|
@ -1 +1 @@
|
||||
export * as backend from "./main";
|
||||
export * as backend from "./main";
|
||||
|
@ -1,2 +1,3 @@
|
||||
export * as common from "./common";
|
||||
export * as ort from "./ort";
|
||||
export * as mnn from "./mnn";
|
1
src/backend/mnn/index.ts
Normal file
1
src/backend/mnn/index.ts
Normal file
@ -0,0 +1 @@
|
||||
export { MNNSession as Session } from "./session";
|
30
src/backend/mnn/session.ts
Normal file
30
src/backend/mnn/session.ts
Normal file
@ -0,0 +1,30 @@
|
||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||
|
||||
export class MNNSession extends CommonSession {
|
||||
#session: any
|
||||
#inputs: Record<string, SessionNodeInfo> | null = null;
|
||||
#outputs: Record<string, SessionNodeInfo> | null = null;
|
||||
|
||||
public constructor(modelData: Uint8Array) {
|
||||
super();
|
||||
const addon = require("../../../build/mnn.node")
|
||||
this.#session = new addon.MNNSession(modelData);
|
||||
}
|
||||
|
||||
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>> {
|
||||
const inputArgs: Record<string, any> = {};
|
||||
for (const [name, option] of Object.entries(inputs)) {
|
||||
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||
}
|
||||
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
||||
if (err) return reject(err);
|
||||
const result: Record<string, Float32Array> = {};
|
||||
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
||||
resolve(result);
|
||||
}))
|
||||
}
|
||||
public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
|
||||
public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
|
||||
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||
|
||||
export class OrtSession extends CommonSession {
|
||||
#session: any;
|
||||
@ -19,7 +19,7 @@ export class OrtSession extends CommonSession {
|
||||
const inputArgs: Record<string, any> = {};
|
||||
for (const [name, option] of Object.entries(inputs)) {
|
||||
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||
else inputArgs[name] = option;
|
||||
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||
}
|
||||
|
||||
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
||||
|
@ -45,6 +45,14 @@ export abstract class Model {
|
||||
return new this(new backend.ort.Session(modelData as Uint8Array));
|
||||
}
|
||||
|
||||
public static async fromMNN<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array | string) {
|
||||
if (typeof modelData === "string") {
|
||||
if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
|
||||
else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string));
|
||||
}
|
||||
return new this(new backend.mnn.Session(modelData as Uint8Array));
|
||||
}
|
||||
|
||||
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
|
||||
//初始化目录
|
||||
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
||||
@ -111,6 +119,7 @@ export abstract class Model {
|
||||
let model: T | undefined = undefined;
|
||||
if (option?.createModel) {
|
||||
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
||||
else if (modelType == "mnn") model = (this as any).fromMNN(modelPath);
|
||||
}
|
||||
|
||||
return { modelPath, modelType, model: model as any }
|
||||
|
@ -20,6 +20,7 @@ class FaceLandmark1000Result extends FaceAlignmentResult {
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`,
|
||||
FACELANDMARK1000_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.mnn`,
|
||||
};
|
||||
|
||||
export class FaceLandmark1000 extends Model {
|
||||
|
@ -21,6 +21,9 @@ const MODEL_URL_CONFIG = {
|
||||
PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`,
|
||||
PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`,
|
||||
PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`,
|
||||
PFLD_106_LITE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.mnn`,
|
||||
PFLD_106_V2_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.mnn`,
|
||||
PFLD_106_V3_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.mnn`,
|
||||
};
|
||||
export class PFLD extends Model {
|
||||
|
||||
|
@ -12,6 +12,7 @@ export interface GenderAgePredictResult {
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
|
||||
INSIGHT_GENDER_AGE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.mnn`,
|
||||
};
|
||||
|
||||
export class GenderAge extends Model {
|
||||
|
@ -4,6 +4,7 @@ import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common";
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`,
|
||||
YOLOV5S_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.mnn`,
|
||||
};
|
||||
|
||||
export class Yolov5Face extends FaceDetector {
|
||||
|
@ -4,6 +4,7 @@ import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
||||
|
||||
const MODEL_URL_CONFIG = {
|
||||
MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`,
|
||||
MOBILEFACENET_ADAFACE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.mnn`,
|
||||
};
|
||||
|
||||
export class AdaFace extends FaceRecognition {
|
||||
|
@ -7,16 +7,23 @@ const MODEL_URL_CONFIG_ARC_FACE = {
|
||||
INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`,
|
||||
INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`,
|
||||
INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`,
|
||||
INSIGHTFACE_ARCFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.mnn`,
|
||||
INSIGHTFACE_ARCFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.mnn`,
|
||||
INSIGHTFACE_ARCFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.mnn`,
|
||||
};
|
||||
const MODEL_URL_CONFIG_COS_FACE = {
|
||||
INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`,
|
||||
INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`,
|
||||
INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`,
|
||||
INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`,
|
||||
INSIGHTFACE_COSFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.mnn`,
|
||||
INSIGHTFACE_COSFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.mnn`,
|
||||
INSIGHTFACE_COSFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.mnn`,
|
||||
};
|
||||
const MODEL_URL_CONFIG_PARTIAL_FC = {
|
||||
INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`,
|
||||
INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`,
|
||||
INSIGHTFACE_PARTIALFC_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.mnn`,
|
||||
};
|
||||
|
||||
|
||||
|
19
src/test.ts
19
src/test.ts
@ -30,8 +30,8 @@ async function cacheImage(group: string, url: string) {
|
||||
}
|
||||
|
||||
async function testGenderTest() {
|
||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_ONNX");
|
||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_MNN");
|
||||
|
||||
const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
|
||||
const boxes = await facedet.predict(image);
|
||||
@ -44,9 +44,11 @@ async function testGenderTest() {
|
||||
}
|
||||
|
||||
async function testFaceID() {
|
||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||
const faceid = await deploy.faceid.PartialFC.load();
|
||||
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
||||
console.log("初始化模型")
|
||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||
const faceid = await deploy.faceid.CosFace.load("INSIGHTFACE_COSFACE_R50_MNN");
|
||||
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
||||
console.log("初始化模型完成")
|
||||
|
||||
const { basic, tests } = faceidTestData.stars;
|
||||
|
||||
@ -110,8 +112,8 @@ async function testFaceID() {
|
||||
}
|
||||
|
||||
async function testFaceAlign() {
|
||||
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
||||
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
||||
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
||||
let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
|
||||
image = image.rotate(image.width / 2, image.height / 2, 0);
|
||||
@ -142,5 +144,4 @@ async function test() {
|
||||
|
||||
test().catch(err => {
|
||||
console.error(err);
|
||||
debugger
|
||||
});
|
||||
});
|
||||
|
Reference in New Issue
Block a user