增加MNN模型支持

This commit is contained in:
2025-03-07 17:18:20 +08:00
parent 4a6d092de1
commit bd90f2f6f6
24 changed files with 497 additions and 153 deletions

View File

@ -6,12 +6,26 @@ export interface SessionNodeInfo {
shape: number[]
}
export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
export type DataTypeString = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
export enum DataType {
Unknown,
Float32,
Float64,
Int32,
Uint32,
Int16,
Uint16,
Int8,
Uint8,
Int64,
Uint64,
}
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
export interface SessionRunInputOption {
type?: SessionNodeType
type?: DataTypeString
data: SessionNodeData
shape?: number[]
}
@ -26,3 +40,36 @@ export abstract class CommonSession {
export function isTypedArray(val: any): val is SessionNodeData {
return val?.buffer instanceof ArrayBuffer;
}
export function dataTypeFrom(str: DataTypeString) {
return {
"float32": DataType.Float32,
"float64": DataType.Float64,
"float": DataType.Float32,
"double": DataType.Float64,
"int32": DataType.Int32,
"uint32": DataType.Uint32,
"int16": DataType.Int16,
"uint16": DataType.Uint16,
"int8": DataType.Int8,
"uint8": DataType.Uint8,
"int64": DataType.Int64,
"uint64": DataType.Uint64,
}[str] ?? DataType.Unknown;
}
export function dataTypeToString(type: DataType): DataTypeString | null {
switch (type) {
case DataType.Float32: return "float32";
case DataType.Float64: return "float64";
case DataType.Int32: return "int32";
case DataType.Uint32: return "uint32";
case DataType.Int16: return "int16";
case DataType.Uint16: return "uint16";
case DataType.Int8: return "int8";
case DataType.Uint8: return "uint8";
case DataType.Int64: return "int64";
case DataType.Uint64: return "uint64";
default: return null;
}
}

View File

@ -1 +1 @@
export * as backend from "./main";
export * as backend from "./main";

View File

@ -1,2 +1,3 @@
export * as common from "./common";
export * as ort from "./ort";
export * as mnn from "./mnn";

1
src/backend/mnn/index.ts Normal file
View File

@ -0,0 +1 @@
export { MNNSession as Session } from "./session";

View File

@ -0,0 +1,30 @@
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
export class MNNSession extends CommonSession {
#session: any
#inputs: Record<string, SessionNodeInfo> | null = null;
#outputs: Record<string, SessionNodeInfo> | null = null;
public constructor(modelData: Uint8Array) {
super();
const addon = require("../../../build/mnn.node")
this.#session = new addon.MNNSession(modelData);
}
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>> {
const inputArgs: Record<string, any> = {};
for (const [name, option] of Object.entries(inputs)) {
if (isTypedArray(option)) inputArgs[name] = { data: option }
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
}
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
if (err) return reject(err);
const result: Record<string, Float32Array> = {};
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
resolve(result);
}))
}
public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
}

View File

@ -1,4 +1,4 @@
import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
export class OrtSession extends CommonSession {
#session: any;
@ -19,7 +19,7 @@ export class OrtSession extends CommonSession {
const inputArgs: Record<string, any> = {};
for (const [name, option] of Object.entries(inputs)) {
if (isTypedArray(option)) inputArgs[name] = { data: option }
else inputArgs[name] = option;
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
}
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {