Files
ai-box/src/backend/ort/session.ts
2025-03-10 18:12:32 +08:00

34 lines
1.5 KiB
TypeScript

import { getConfig } from "../../config";
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
/**
 * Session implementation backed by a native addon (presumably ONNX Runtime,
 * judging by the `ORT_ADDON_FILE` config key — confirm against the addon).
 *
 * The addon module is located through the `ORT_ADDON_FILE` configuration
 * entry and loaded when the session is constructed.
 */
export class OrtSession extends CommonSession {
  /** Native session object created by the addon; the addon exposes no typings here. */
  #session: any;
  /** Lazily populated cache of the addon's `GetInputsInfo()` result. */
  #inputs: Record<string, SessionNodeInfo> | null = null;
  /** Lazily populated cache of the addon's `GetOutputsInfo()` result. */
  #outputs: Record<string, SessionNodeInfo> | null = null;

  /**
   * Creates a native session from serialized model bytes.
   *
   * @param modelData Raw model contents handed straight to the addon's `OrtSession` constructor.
   */
  public constructor(modelData: Uint8Array) {
    super();
    // The addon path is only known at runtime (from config), hence the dynamic require.
    const addon = require(getConfig("ORT_ADDON_FILE"));
    this.#session = new addon.OrtSession(modelData);
  }

  /** Input node descriptions, fetched from the addon on first access and cached afterwards. */
  public get inputs(): Record<string, SessionNodeInfo> {
    return (this.#inputs ??= this.#session.GetInputsInfo());
  }

  /** Output node descriptions, fetched from the addon on first access and cached afterwards. */
  public get outputs(): Record<string, SessionNodeInfo> {
    return (this.#outputs ??= this.#session.GetOutputsInfo());
  }

  /**
   * Runs the model with the given inputs.
   *
   * @param inputs Map from input name to either a bare typed array or an option object.
   *   Typed arrays are wrapped as `{ data }`; option objects are forwarded with their
   *   `type` (when set) converted through `dataTypeFrom`.
   * @returns A promise resolving to a map from output name to `{ data, shape }`.
   *   The promise rejects with the addon's error, or with any error raised while
   *   converting the addon's result.
   */
  public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>> {
    const inputArgs: Record<string, any> = {};
    for (const [name, option] of Object.entries(inputs)) {
      if (isTypedArray(option)) inputArgs[name] = { data: option };
      else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
    }

    return new Promise<Record<string, SessionRunOutput>>((resolve, reject) => {
      this.#session.Run(inputArgs, (err: any, res: Record<string, ArrayBuffer>) => {
        if (err) return reject(err);
        // This callback is invoked by the addon, outside the promise executor, so an
        // exception thrown here would not reject the promise by itself. Catch it and
        // reject explicitly so callers never wait on a promise that cannot settle.
        try {
          const outputsInfo = this.outputs;
          const result: Record<string, SessionRunOutput> = {};
          for (const [name, val] of Object.entries(res)) {
            const info = outputsInfo[name];
            if (!info) throw new Error(`OrtSession: no output info for "${name}"`);
            // NOTE(review): every output buffer is viewed as Float32Array regardless of the
            // declared output type, and the shape comes from the session's output metadata
            // rather than from this run — confirm both match what the addon returns.
            result[name] = { data: new Float32Array(val), shape: info.shape };
          }
          resolve(result);
        } catch (conversionError) {
          reject(conversionError);
        }
      });
    });
  }
}