From a966b82963c568191c4f3f8a5da737f8f6298c85 Mon Sep 17 00:00:00 2001 From: Yizhi <946185759@qq.com> Date: Mon, 10 Mar 2025 17:49:53 +0800 Subject: [PATCH] =?UTF-8?q?=E9=92=88=E5=AF=B9mnn=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=A4=84=E7=90=86=EF=BC=8C=E8=BE=93=E5=87=BA=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?shape?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cxx/mnn/node.cc | 19 +++++++++++++++---- src/backend/common/session.ts | 9 +++++++-- src/backend/main.ts | 2 +- src/backend/mnn/session.ts | 13 ++++++++----- src/backend/ort/session.ts | 8 ++++---- src/deploy/common/model.ts | 6 +++--- src/deploy/facealign/landmark1000.ts | 2 +- src/deploy/facealign/pfld.ts | 2 +- src/deploy/faceattr/gender-age.ts | 2 +- src/deploy/facedet/yolov5.ts | 2 +- src/deploy/faceid/adaface.ts | 2 +- src/deploy/faceid/insightface.ts | 2 +- src/test.ts | 11 +++++++---- 13 files changed, 51 insertions(+), 29 deletions(-) diff --git a/cxx/mnn/node.cc b/cxx/mnn/node.cc index 53ed4bf..0caf19a 100644 --- a/cxx/mnn/node.cc +++ b/cxx/mnn/node.cc @@ -34,6 +34,13 @@ static size_t getShapeSize(const std::vector &shape) return sum; } +static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector &shape) +{ + auto result = Napi::Array::New(env, shape.size()); + for (int i = 0; i < shape.size(); ++i) result.Set(i, Napi::Number::New(env, shape[i])); + return result; +} + class MNNSessionRunWorker : public AsyncWorker { public: MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session) @@ -46,7 +53,6 @@ class MNNSessionRunWorker : public AsyncWorker { void Execute() { - interpreter_->resizeSession(session_); if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) { SetError(std::string("Run session failed")); } @@ -61,9 +67,12 @@ class MNNSessionRunWorker : public AsyncWorker { auto result = Object::New(Env()); for (auto it : interpreter_->getSessionOutputAll(session_)) { auto tensor = it.second; + auto item = Object::New(Env()); auto buffer = ArrayBuffer::New(Env(), tensor->size()); memcpy(buffer.Data(), tensor->host(), tensor->size()); - result.Set(it.first, buffer); + item.Set("data", buffer); + item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape())); + result.Set(it.first, item); } Callback().Call({Env().Undefined(), result}); } @@ -83,8 +92,10 @@ class MNNSessionRunWorker : public AsyncWorker { if (it != DATA_TYPE_MAP.end()) type = it->second; } - if (shape.size()) interpreter_->resizeTensor(tensor, shape); - + if (shape.size()) { + interpreter_->resizeTensor(tensor, shape); + interpreter_->resizeSession(session_); + } auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8; if (tensorBytes != dataBytes) { SetError(std::string("input name #" + name + " data size not matched")); diff --git a/src/backend/common/session.ts b/src/backend/common/session.ts index 9c479ef..a1f66b3 100644 --- a/src/backend/common/session.ts +++ b/src/backend/common/session.ts @@ -2,7 +2,7 @@ export interface SessionNodeInfo { name: string - type: number + type: DataType shape: number[] } @@ -30,8 +30,13 @@ export interface SessionRunInputOption { shape?: number[] } +export interface SessionRunOutput { + shape: number[] + data: Float32Array +} + export abstract class CommonSession { - public abstract run(inputs: Record): Promise> + public abstract run(inputs: Record): Promise> public abstract get inputs(): Record; public abstract get outputs(): Record; diff --git a/src/backend/main.ts b/src/backend/main.ts index a573c1a..33481a0 100644 --- a/src/backend/main.ts +++ b/src/backend/main.ts @@ -1,3 +1,3 @@ -export * as common from "./common"; +export { SessionNodeInfo, DataTypeString, DataType, SessionNodeData, SessionRunInputOption, SessionRunOutput, CommonSession } from "./common"; export * as ort from "./ort"; export * as mnn from "./mnn"; \ No newline at end of file diff --git a/src/backend/mnn/session.ts b/src/backend/mnn/session.ts index e523748..ec60ba1 100644 --- a/src/backend/mnn/session.ts +++ b/src/backend/mnn/session.ts @@ -1,4 +1,4 @@ -import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common"; +import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common"; export class MNNSession extends CommonSession { #session: any @@ -11,16 +11,19 @@ export class MNNSession extends CommonSession { this.#session = new addon.MNNSession(modelData); } - public run(inputs: Record): Promise> { + public run(inputs: Record): Promise> { const inputArgs: Record = {}; for (const [name, option] of Object.entries(inputs)) { if (isTypedArray(option)) inputArgs[name] = { data: option } else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined }; } - return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => { + return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record) => { if (err) return reject(err); - const result: Record = {}; - for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer); + const result: Record = {}; + for (const [name, val] of Object.entries(res)) result[name] = { + shape: val.shape, + data: new Float32Array(val.data), + } resolve(result); })) } diff --git a/src/backend/ort/session.ts b/src/backend/ort/session.ts index 76c46d8..80ca63f 100644 --- a/src/backend/ort/session.ts +++ b/src/backend/ort/session.ts @@ -1,4 +1,4 @@ -import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common"; +import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common"; export class OrtSession extends CommonSession { #session: any; @@ -22,10 +22,10 @@ export class OrtSession extends CommonSession { else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined }; } - return new Promise>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => { + return new Promise>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record) => { if (err) return reject(err); - const result: Record = {}; - for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer); + const result: Record = {}; + for (const [name, val] of Object.entries(res)) result[name] = { data: new Float32Array(val), shape: this.outputs[name].shape }; resolve(result); })); } diff --git a/src/deploy/common/model.ts b/src/deploy/common/model.ts index 078e1e4..cd4c735 100644 --- a/src/deploy/common/model.ts +++ b/src/deploy/common/model.ts @@ -1,7 +1,7 @@ import { backend } from "../../backend"; import { cv } from "../../cv"; -export type ModelConstructor = new (session: backend.common.CommonSession) => T; +export type ModelConstructor = new (session: backend.CommonSession) => T; export type ImageSource = cv.Mat | Uint8Array | string; @@ -27,7 +27,7 @@ export interface ModelCacheResult { const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise }> = []; export abstract class Model { - protected session: backend.common.CommonSession; + protected session: backend.CommonSession; protected static async resolveImage(image: ImageSource, resolver: (image: cv.Mat) => R | Promise): Promise { if (typeof image === "string") { @@ -148,7 +148,7 @@ export abstract class Model { return { modelPath, modelType, model: model as any } } - public constructor(session: backend.common.CommonSession) { this.session = session; } + public constructor(session: backend.CommonSession) { this.session = session; } public get inputs() { return this.session.inputs; } public get outputs() { return this.session.outputs; } diff --git a/src/deploy/facealign/landmark1000.ts b/src/deploy/facealign/landmark1000.ts index 86a2e87..4b093ee 100644 --- a/src/deploy/facealign/landmark1000.ts +++ b/src/deploy/facealign/landmark1000.ts @@ -46,7 +46,7 @@ export class FaceLandmark1000 extends Model { data: nchwImageData, type: "float32", } - }).then(res => res[this.output.name]); + }).then(res => res[this.output.name].data); const points: FacePoint[] = []; for (let i = 0; i < res.length; i += 2) { diff --git a/src/deploy/facealign/pfld.ts b/src/deploy/facealign/pfld.ts index ab06eaf..fdb34c0 100644 --- a/src/deploy/facealign/pfld.ts +++ b/src/deploy/facealign/pfld.ts @@ -51,7 +51,7 @@ export class PFLD extends Model { shape: [1, 3, input.shape[2], input.shape[3]], } }); - const pointsBuffer = res[pointsOutput.name]; + const pointsBuffer = res[pointsOutput.name].data; const points: FacePoint[] = []; for (let i = 0; i < pointsBuffer.length; i += 2) { diff --git a/src/deploy/faceattr/gender-age.ts b/src/deploy/faceattr/gender-age.ts index db1c281..e5249e9 100644 --- a/src/deploy/faceattr/gender-age.ts +++ b/src/deploy/faceattr/gender-age.ts @@ -36,7 +36,7 @@ export class GenderAge extends Model { data: nchwImage, type: "float32", } - }).then(res => res[output.name]); + }).then(res => res[output.name].data); return { gender: result[0] > result[1] ? "F" : "M", diff --git a/src/deploy/facedet/yolov5.ts b/src/deploy/facedet/yolov5.ts index 6666a7d..0a39327 100644 --- a/src/deploy/facedet/yolov5.ts +++ b/src/deploy/facedet/yolov5.ts @@ -27,7 +27,7 @@ export class Yolov5Face extends FaceDetector { const threshold = option?.threshold ?? 0.5; for (let i = 0; i < outShape[1]; i++) { const beg = i * outShape[2]; - const rectData = outputData.slice(beg, beg + outShape[2]); + const rectData = outputData.data.slice(beg, beg + outShape[2]); const x = parseInt(rectData[0] * ratioWidth as any); const y = parseInt(rectData[1] * ratioHeight as any); const w = parseInt(rectData[2] * ratioWidth as any); diff --git a/src/deploy/faceid/adaface.ts b/src/deploy/faceid/adaface.ts index 7462929..6d771a1 100644 --- a/src/deploy/faceid/adaface.ts +++ b/src/deploy/faceid/adaface.ts @@ -28,7 +28,7 @@ export class AdaFace extends FaceRecognition { data: nchwImageData, shape: [1, 3, input.shape[2], input.shape[3]], } - }).then(res => res[output.name]); + }).then(res => res[output.name].data); return new Array(...embedding); } diff --git a/src/deploy/faceid/insightface.ts b/src/deploy/faceid/insightface.ts index ac19bd6..606ef2b 100644 --- a/src/deploy/faceid/insightface.ts +++ b/src/deploy/faceid/insightface.ts @@ -44,7 +44,7 @@ export class Insightface extends FaceRecognition { } }).then(res => res[output.name]); - return new Array(...embedding); + return new Array(...embedding.data); } } diff --git a/src/test.ts b/src/test.ts index 929825a..e089278 100644 --- a/src/test.ts +++ b/src/test.ts @@ -115,10 +115,13 @@ async function testFaceAlign() { const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN"); const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN"); // const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX"); - let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712"); + let image = await cv.Mat.load("https://i0.hdslb.com/bfs/archive/64e47ec9fdac9e24bc2b49b5aaad5560da1bfe3e.jpg"); image = image.rotate(image.width / 2, image.height / 2, 0); - const face = await fd.predict(image).then(res => res[0].toSquare()); + const face = await fd.predict(image).then(res => { + console.log(res); + return res[0].toSquare() + }); const points = await fa.predict(image, { crop: { sx: face.left, sy: face.top, sw: face.width, sh: face.height } }); points.points.forEach((point, idx) => { @@ -137,8 +140,8 @@ async function testFaceAlign() { } async function test() { - await testGenderTest(); - await testFaceID(); + // await testGenderTest(); + // await testFaceID(); await testFaceAlign(); }