针对mnn模型处理,输出增加shape
This commit is contained in:
@ -34,6 +34,13 @@ static size_t getShapeSize(const std::vector<int> &shape)
|
|||||||
return sum;
|
return sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector<int> &shape)
|
||||||
|
{
|
||||||
|
auto result = Napi::Array::New(env, shape.size());
|
||||||
|
for (int i = 0; i < shape.size(); ++i) result.Set(i, Napi::Number::New(env, shape[i]));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
class MNNSessionRunWorker : public AsyncWorker {
|
class MNNSessionRunWorker : public AsyncWorker {
|
||||||
public:
|
public:
|
||||||
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
|
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
|
||||||
@ -46,7 +53,6 @@ class MNNSessionRunWorker : public AsyncWorker {
|
|||||||
|
|
||||||
void Execute()
|
void Execute()
|
||||||
{
|
{
|
||||||
interpreter_->resizeSession(session_);
|
|
||||||
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
|
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
|
||||||
SetError(std::string("Run session failed"));
|
SetError(std::string("Run session failed"));
|
||||||
}
|
}
|
||||||
@ -61,9 +67,12 @@ class MNNSessionRunWorker : public AsyncWorker {
|
|||||||
auto result = Object::New(Env());
|
auto result = Object::New(Env());
|
||||||
for (auto it : interpreter_->getSessionOutputAll(session_)) {
|
for (auto it : interpreter_->getSessionOutputAll(session_)) {
|
||||||
auto tensor = it.second;
|
auto tensor = it.second;
|
||||||
|
auto item = Object::New(Env());
|
||||||
auto buffer = ArrayBuffer::New(Env(), tensor->size());
|
auto buffer = ArrayBuffer::New(Env(), tensor->size());
|
||||||
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
|
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
|
||||||
result.Set(it.first, buffer);
|
item.Set("data", buffer);
|
||||||
|
item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape()));
|
||||||
|
result.Set(it.first, item);
|
||||||
}
|
}
|
||||||
Callback().Call({Env().Undefined(), result});
|
Callback().Call({Env().Undefined(), result});
|
||||||
}
|
}
|
||||||
@ -83,8 +92,10 @@ class MNNSessionRunWorker : public AsyncWorker {
|
|||||||
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (shape.size()) interpreter_->resizeTensor(tensor, shape);
|
if (shape.size()) {
|
||||||
|
interpreter_->resizeTensor(tensor, shape);
|
||||||
|
interpreter_->resizeSession(session_);
|
||||||
|
}
|
||||||
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
|
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
|
||||||
if (tensorBytes != dataBytes) {
|
if (tensorBytes != dataBytes) {
|
||||||
SetError(std::string("input name #" + name + " data size not matched"));
|
SetError(std::string("input name #" + name + " data size not matched"));
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
export interface SessionNodeInfo {
|
export interface SessionNodeInfo {
|
||||||
name: string
|
name: string
|
||||||
type: number
|
type: DataType
|
||||||
shape: number[]
|
shape: number[]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,8 +30,13 @@ export interface SessionRunInputOption {
|
|||||||
shape?: number[]
|
shape?: number[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface SessionRunOutput {
|
||||||
|
shape: number[]
|
||||||
|
data: Float32Array
|
||||||
|
}
|
||||||
|
|
||||||
export abstract class CommonSession {
|
export abstract class CommonSession {
|
||||||
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>>
|
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>>
|
||||||
|
|
||||||
public abstract get inputs(): Record<string, SessionNodeInfo>;
|
public abstract get inputs(): Record<string, SessionNodeInfo>;
|
||||||
public abstract get outputs(): Record<string, SessionNodeInfo>;
|
public abstract get outputs(): Record<string, SessionNodeInfo>;
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
export * as common from "./common";
|
export { SessionNodeInfo, DataTypeString, DataType, SessionNodeData, SessionRunInputOption, SessionRunOutput, CommonSession } from "./common";
|
||||||
export * as ort from "./ort";
|
export * as ort from "./ort";
|
||||||
export * as mnn from "./mnn";
|
export * as mnn from "./mnn";
|
@ -1,4 +1,4 @@
|
|||||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
|
||||||
|
|
||||||
export class MNNSession extends CommonSession {
|
export class MNNSession extends CommonSession {
|
||||||
#session: any
|
#session: any
|
||||||
@ -11,16 +11,19 @@ export class MNNSession extends CommonSession {
|
|||||||
this.#session = new addon.MNNSession(modelData);
|
this.#session = new addon.MNNSession(modelData);
|
||||||
}
|
}
|
||||||
|
|
||||||
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>> {
|
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>> {
|
||||||
const inputArgs: Record<string, any> = {};
|
const inputArgs: Record<string, any> = {};
|
||||||
for (const [name, option] of Object.entries(inputs)) {
|
for (const [name, option] of Object.entries(inputs)) {
|
||||||
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||||
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||||
}
|
}
|
||||||
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, { data: ArrayBuffer, shape: number[] }>) => {
|
||||||
if (err) return reject(err);
|
if (err) return reject(err);
|
||||||
const result: Record<string, Float32Array> = {};
|
const result: Record<string, SessionRunOutput> = {};
|
||||||
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
for (const [name, val] of Object.entries(res)) result[name] = {
|
||||||
|
shape: val.shape,
|
||||||
|
data: new Float32Array(val.data),
|
||||||
|
}
|
||||||
resolve(result);
|
resolve(result);
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
|
||||||
|
|
||||||
export class OrtSession extends CommonSession {
|
export class OrtSession extends CommonSession {
|
||||||
#session: any;
|
#session: any;
|
||||||
@ -22,10 +22,10 @@ export class OrtSession extends CommonSession {
|
|||||||
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
return new Promise<Record<string, SessionRunOutput>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, ArrayBuffer>) => {
|
||||||
if (err) return reject(err);
|
if (err) return reject(err);
|
||||||
const result: Record<string, Float32Array> = {};
|
const result: Record<string, SessionRunOutput> = {};
|
||||||
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
for (const [name, val] of Object.entries(res)) result[name] = { data: new Float32Array(val), shape: this.outputs[name].shape };
|
||||||
resolve(result);
|
resolve(result);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { backend } from "../../backend";
|
import { backend } from "../../backend";
|
||||||
import { cv } from "../../cv";
|
import { cv } from "../../cv";
|
||||||
|
|
||||||
export type ModelConstructor<T> = new (session: backend.common.CommonSession) => T;
|
export type ModelConstructor<T> = new (session: backend.CommonSession) => T;
|
||||||
|
|
||||||
export type ImageSource = cv.Mat | Uint8Array | string;
|
export type ImageSource = cv.Mat | Uint8Array | string;
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ export interface ModelCacheResult<T, Create extends boolean> {
|
|||||||
const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise<string> }> = [];
|
const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise<string> }> = [];
|
||||||
|
|
||||||
export abstract class Model {
|
export abstract class Model {
|
||||||
protected session: backend.common.CommonSession;
|
protected session: backend.CommonSession;
|
||||||
|
|
||||||
protected static async resolveImage<R>(image: ImageSource, resolver: (image: cv.Mat) => R | Promise<R>): Promise<R> {
|
protected static async resolveImage<R>(image: ImageSource, resolver: (image: cv.Mat) => R | Promise<R>): Promise<R> {
|
||||||
if (typeof image === "string") {
|
if (typeof image === "string") {
|
||||||
@ -148,7 +148,7 @@ export abstract class Model {
|
|||||||
return { modelPath, modelType, model: model as any }
|
return { modelPath, modelType, model: model as any }
|
||||||
}
|
}
|
||||||
|
|
||||||
public constructor(session: backend.common.CommonSession) { this.session = session; }
|
public constructor(session: backend.CommonSession) { this.session = session; }
|
||||||
|
|
||||||
public get inputs() { return this.session.inputs; }
|
public get inputs() { return this.session.inputs; }
|
||||||
public get outputs() { return this.session.outputs; }
|
public get outputs() { return this.session.outputs; }
|
||||||
|
@ -46,7 +46,7 @@ export class FaceLandmark1000 extends Model {
|
|||||||
data: nchwImageData,
|
data: nchwImageData,
|
||||||
type: "float32",
|
type: "float32",
|
||||||
}
|
}
|
||||||
}).then(res => res[this.output.name]);
|
}).then(res => res[this.output.name].data);
|
||||||
|
|
||||||
const points: FacePoint[] = [];
|
const points: FacePoint[] = [];
|
||||||
for (let i = 0; i < res.length; i += 2) {
|
for (let i = 0; i < res.length; i += 2) {
|
||||||
|
@ -51,7 +51,7 @@ export class PFLD extends Model {
|
|||||||
shape: [1, 3, input.shape[2], input.shape[3]],
|
shape: [1, 3, input.shape[2], input.shape[3]],
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
const pointsBuffer = res[pointsOutput.name];
|
const pointsBuffer = res[pointsOutput.name].data;
|
||||||
|
|
||||||
const points: FacePoint[] = [];
|
const points: FacePoint[] = [];
|
||||||
for (let i = 0; i < pointsBuffer.length; i += 2) {
|
for (let i = 0; i < pointsBuffer.length; i += 2) {
|
||||||
|
@ -36,7 +36,7 @@ export class GenderAge extends Model {
|
|||||||
data: nchwImage,
|
data: nchwImage,
|
||||||
type: "float32",
|
type: "float32",
|
||||||
}
|
}
|
||||||
}).then(res => res[output.name]);
|
}).then(res => res[output.name].data);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
gender: result[0] > result[1] ? "F" : "M",
|
gender: result[0] > result[1] ? "F" : "M",
|
||||||
|
@ -27,7 +27,7 @@ export class Yolov5Face extends FaceDetector {
|
|||||||
const threshold = option?.threshold ?? 0.5;
|
const threshold = option?.threshold ?? 0.5;
|
||||||
for (let i = 0; i < outShape[1]; i++) {
|
for (let i = 0; i < outShape[1]; i++) {
|
||||||
const beg = i * outShape[2];
|
const beg = i * outShape[2];
|
||||||
const rectData = outputData.slice(beg, beg + outShape[2]);
|
const rectData = outputData.data.slice(beg, beg + outShape[2]);
|
||||||
const x = parseInt(rectData[0] * ratioWidth as any);
|
const x = parseInt(rectData[0] * ratioWidth as any);
|
||||||
const y = parseInt(rectData[1] * ratioHeight as any);
|
const y = parseInt(rectData[1] * ratioHeight as any);
|
||||||
const w = parseInt(rectData[2] * ratioWidth as any);
|
const w = parseInt(rectData[2] * ratioWidth as any);
|
||||||
|
@ -28,7 +28,7 @@ export class AdaFace extends FaceRecognition {
|
|||||||
data: nchwImageData,
|
data: nchwImageData,
|
||||||
shape: [1, 3, input.shape[2], input.shape[3]],
|
shape: [1, 3, input.shape[2], input.shape[3]],
|
||||||
}
|
}
|
||||||
}).then(res => res[output.name]);
|
}).then(res => res[output.name].data);
|
||||||
|
|
||||||
return new Array(...embedding);
|
return new Array(...embedding);
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ export class Insightface extends FaceRecognition {
|
|||||||
}
|
}
|
||||||
}).then(res => res[output.name]);
|
}).then(res => res[output.name]);
|
||||||
|
|
||||||
return new Array(...embedding);
|
return new Array(...embedding.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
11
src/test.ts
11
src/test.ts
@ -115,10 +115,13 @@ async function testFaceAlign() {
|
|||||||
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||||
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
||||||
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
||||||
let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
|
let image = await cv.Mat.load("https://i0.hdslb.com/bfs/archive/64e47ec9fdac9e24bc2b49b5aaad5560da1bfe3e.jpg");
|
||||||
image = image.rotate(image.width / 2, image.height / 2, 0);
|
image = image.rotate(image.width / 2, image.height / 2, 0);
|
||||||
|
|
||||||
const face = await fd.predict(image).then(res => res[0].toSquare());
|
const face = await fd.predict(image).then(res => {
|
||||||
|
console.log(res);
|
||||||
|
return res[0].toSquare()
|
||||||
|
});
|
||||||
const points = await fa.predict(image, { crop: { sx: face.left, sy: face.top, sw: face.width, sh: face.height } });
|
const points = await fa.predict(image, { crop: { sx: face.left, sy: face.top, sw: face.width, sh: face.height } });
|
||||||
|
|
||||||
points.points.forEach((point, idx) => {
|
points.points.forEach((point, idx) => {
|
||||||
@ -137,8 +140,8 @@ async function testFaceAlign() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function test() {
|
async function test() {
|
||||||
await testGenderTest();
|
// await testGenderTest();
|
||||||
await testFaceID();
|
// await testFaceID();
|
||||||
await testFaceAlign();
|
await testFaceAlign();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user