import { getConfig } from "../../config";
import {
  CommonSession,
  dataTypeFrom,
  isTypedArray,
  SessionNodeData,
  SessionNodeInfo,
  SessionRunInputOption,
  SessionRunOutput,
} from "../common";

/**
 * Minimal shape of the session object created by the native ORT addon,
 * limited to the members this file calls. The addon is loaded via `require`
 * and carries no typings, so this interface is the type-checking boundary.
 */
interface NativeOrtSession {
  GetInputsInfo(): Record<string, SessionNodeInfo>;
  GetOutputsInfo(): Record<string, SessionNodeInfo>;
  // NOTE(review): `res` values are assumed to be raw output buffers or numeric
  // array-likes, since they are passed straight to `new Float32Array` — confirm
  // against the addon's implementation.
  Run(
    inputs: Record<string, object>,
    callback: (err: unknown, res: Record<string, ArrayBuffer | ArrayLike<number>>) => void,
  ): void;
}

/**
 * Session implementation backed by the native ORT addon whose file path is
 * read from the `ORT_ADDON_FILE` config entry.
 */
export class OrtSession extends CommonSession {
  /** Native session handle, created once in the constructor. */
  readonly #session: NativeOrtSession;
  /** Cached input metadata; `null` until `inputs` is first read. */
  #inputs: Record<string, SessionNodeInfo> | null = null;
  /** Cached output metadata; `null` until `outputs` is first read. */
  #outputs: Record<string, SessionNodeInfo> | null = null;

  /**
   * @param modelData Serialized model bytes, passed unchanged to the native
   *   session constructor.
   */
  public constructor(modelData: Uint8Array) {
    super();
    // The addon path is only known at runtime (from config), so a static
    // `import` cannot be used here.
    // eslint-disable-next-line @typescript-eslint/no-require-imports
    const addon = require(getConfig("ORT_ADDON_FILE"));
    this.#session = new addon.OrtSession(modelData);
  }

  /** Input metadata keyed by input name; fetched from the native session once, then cached. */
  public get inputs(): Record<string, SessionNodeInfo> {
    return (this.#inputs ??= this.#session.GetInputsInfo());
  }

  /** Output metadata keyed by output name; fetched from the native session once, then cached. */
  public get outputs(): Record<string, SessionNodeInfo> {
    return (this.#outputs ??= this.#session.GetOutputsInfo());
  }

  /**
   * Runs the model asynchronously on the native session.
   *
   * @param inputs Inputs keyed by input name. A bare typed array is sent as
   *   `{ data }`; an option object is forwarded as-is except that a set
   *   `type` is converted with `dataTypeFrom` (an unset one is sent as `undefined`).
   * @returns Outputs keyed by output name, each holding its data as a
   *   `Float32Array` and the shape reported by `outputs`.
   * @throws The promise rejects with the addon's error when the native run
   *   fails, or with an `Error` when an output has no metadata or its result
   *   cannot be converted.
   */
  public run(
    inputs: Record<string, SessionNodeData | SessionRunInputOption>,
  ): Promise<Record<string, SessionRunOutput>> {
    const inputArgs: Record<string, object> = {};
    for (const [name, option] of Object.entries(inputs)) {
      inputArgs[name] = isTypedArray(option)
        ? { data: option }
        : { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
    }

    return new Promise<Record<string, SessionRunOutput>>((resolve, reject) => {
      this.#session.Run(inputArgs, (err, res) => {
        if (err) {
          reject(err);
          return;
        }
        // This callback is invoked by the addon, outside the Promise executor:
        // an exception here would not reject the promise (callers would wait
        // forever) and would surface as an uncaught exception instead.
        try {
          resolve(this.#convertOutputs(res));
        } catch (e: unknown) {
          reject(e);
        }
      });
    });
  }

  /** Wraps each raw output in a `Float32Array` and attaches its shape from the output metadata. */
  #convertOutputs(
    res: Record<string, ArrayBuffer | ArrayLike<number>>,
  ): Record<string, SessionRunOutput> {
    const result: Record<string, SessionRunOutput> = {};
    for (const [name, raw] of Object.entries(res)) {
      const info = this.outputs[name];
      if (!info) throw new Error(`OrtSession: no metadata for output "${name}"`);
      // NOTE(review): every output is decoded as float32 regardless of the
      // element type the model declares; non-float outputs would be misread.
      result[name] = { data: new Float32Array(raw), shape: info.shape };
    }
    return result;
  }
}