34 lines
1.5 KiB
TypeScript
34 lines
1.5 KiB
TypeScript
import { getConfig } from "../../config";
|
|
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
|
|
|
|
/**
 * Session implementation that delegates to a native addon.
 *
 * The addon module is loaded with `require` from the path stored under the
 * `ORT_ADDON_FILE` config key, and its `OrtSession` class is instantiated with
 * the raw model bytes.
 *
 * NOTE(review): "Ort" presumably stands for ONNX Runtime — confirm against the
 * addon's own documentation.
 */
export class OrtSession extends CommonSession {
    /** Native session handle created by the addon; untyped because the addon ships no typings here. */
    #session: any;

    /** Cached result of `GetInputsInfo()`; `null` until first requested. */
    #inputs: Record<string, SessionNodeInfo> | null = null;

    /** Cached result of `GetOutputsInfo()`; `null` until first requested. */
    #outputs: Record<string, SessionNodeInfo> | null = null;

    /**
     * Loads the native addon and creates a session from the given model.
     *
     * @param modelData Raw model bytes, passed to the addon's `OrtSession` constructor as-is.
     */
    public constructor(modelData: Uint8Array) {
        super();
        const addon = require(getConfig("ORT_ADDON_FILE"));
        this.#session = new addon.OrtSession(modelData);
    }

    /** Input node descriptors, fetched from the native session on first access and cached afterwards. */
    public get inputs(): Record<string, SessionNodeInfo> {
        return this.#inputs ??= this.#session.GetInputsInfo();
    }

    /** Output node descriptors, fetched from the native session on first access and cached afterwards. */
    public get outputs(): Record<string, SessionNodeInfo> {
        return this.#outputs ??= this.#session.GetOutputsInfo();
    }

    /**
     * Runs the session with the given named inputs.
     *
     * A bare typed array is wrapped as `{ data }`. Any other value is copied,
     * with its `type` passed through `dataTypeFrom` when set and replaced by
     * `undefined` otherwise.
     *
     * @param inputs Input values keyed by input name.
     * @returns A promise resolving to the outputs keyed by output name. Each
     *   output's `shape` is taken from the cached output descriptors.
     *   Rejects with the error reported by the addon, or with any error raised
     *   while converting the raw results.
     */
    public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>> {
        const inputArgs: Record<string, any> = {};
        for (const [name, option] of Object.entries(inputs)) {
            if (isTypedArray(option)) inputArgs[name] = { data: option };
            else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
        }

        return new Promise<Record<string, SessionRunOutput>>((resolve, reject) => {
            this.#session.Run(inputArgs, (err: unknown, res: Record<string, ArrayBuffer>) => {
                if (err) {
                    reject(err);
                    return;
                }

                // This callback is invoked by the addon, outside the Promise
                // executor's own exception capture. Anything thrown here must
                // be forwarded to `reject`, otherwise the promise never settles.
                try {
                    const result: Record<string, SessionRunOutput> = {};
                    for (const [name, buffer] of Object.entries(res)) {
                        const info = this.outputs[name];
                        if (!info) throw new Error(`Session returned unknown output "${name}"`);
                        // NOTE(review): every output is viewed as Float32Array regardless of
                        // its declared type — confirm the addon only ever returns float32 data.
                        result[name] = { data: new Float32Array(buffer), shape: info.shape };
                    }
                    resolve(result);
                } catch (cause: unknown) {
                    reject(cause);
                }
            });
        });
    }
}
|