针对mnn模型处理,输出增加shape
This commit is contained in:
@ -2,7 +2,7 @@
|
||||
|
||||
export interface SessionNodeInfo {
|
||||
name: string
|
||||
type: number
|
||||
type: DataType
|
||||
shape: number[]
|
||||
}
|
||||
|
||||
@ -30,8 +30,13 @@ export interface SessionRunInputOption {
|
||||
shape?: number[]
|
||||
}
|
||||
|
||||
export interface SessionRunOutput {
|
||||
shape: number[]
|
||||
data: Float32Array
|
||||
}
|
||||
|
||||
export abstract class CommonSession {
|
||||
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>>
|
||||
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>>
|
||||
|
||||
public abstract get inputs(): Record<string, SessionNodeInfo>;
|
||||
public abstract get outputs(): Record<string, SessionNodeInfo>;
|
||||
|
@ -1,3 +1,3 @@
|
||||
export * as common from "./common";
|
||||
export { SessionNodeInfo, DataTypeString, DataType, SessionNodeData, SessionRunInputOption, SessionRunOutput, CommonSession } from "./common";
|
||||
export * as ort from "./ort";
|
||||
export * as mnn from "./mnn";
|
@ -1,4 +1,4 @@
|
||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
|
||||
|
||||
export class MNNSession extends CommonSession {
|
||||
#session: any
|
||||
@ -11,16 +11,19 @@ export class MNNSession extends CommonSession {
|
||||
this.#session = new addon.MNNSession(modelData);
|
||||
}
|
||||
|
||||
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>> {
|
||||
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>> {
|
||||
const inputArgs: Record<string, any> = {};
|
||||
for (const [name, option] of Object.entries(inputs)) {
|
||||
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||
}
|
||||
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
||||
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, { data: ArrayBuffer, shape: number[] }>) => {
|
||||
if (err) return reject(err);
|
||||
const result: Record<string, Float32Array> = {};
|
||||
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
||||
const result: Record<string, SessionRunOutput> = {};
|
||||
for (const [name, val] of Object.entries(res)) result[name] = {
|
||||
shape: val.shape,
|
||||
data: new Float32Array(val.data),
|
||||
}
|
||||
resolve(result);
|
||||
}))
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
|
||||
|
||||
export class OrtSession extends CommonSession {
|
||||
#session: any;
|
||||
@ -22,10 +22,10 @@ export class OrtSession extends CommonSession {
|
||||
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||
}
|
||||
|
||||
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
||||
return new Promise<Record<string, SessionRunOutput>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, ArrayBuffer>) => {
|
||||
if (err) return reject(err);
|
||||
const result: Record<string, Float32Array> = {};
|
||||
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
||||
const result: Record<string, SessionRunOutput> = {};
|
||||
for (const [name, val] of Object.entries(res)) result[name] = { data: new Float32Array(val), shape: this.outputs[name].shape };
|
||||
resolve(result);
|
||||
}));
|
||||
}
|
||||
|
Reference in New Issue
Block a user