初次提交

This commit is contained in:
2025-03-06 14:38:32 +08:00
commit b0e71d0ebb
46 changed files with 1926 additions and 0 deletions

View File

@ -0,0 +1 @@
export * from "./session";

View File

@ -0,0 +1,28 @@
export interface SessionNodeInfo {
name: string
type: number
shape: number[]
}
export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
export interface SessionRunInputOption {
type?: SessionNodeType
data: SessionNodeData
shape?: number[]
}
export abstract class CommonSession {
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>>
public abstract get inputs(): Record<string, SessionNodeInfo>;
public abstract get outputs(): Record<string, SessionNodeInfo>;
}
export function isTypedArray(val: any): val is SessionNodeData {
return val?.buffer instanceof ArrayBuffer;
}

1
src/backend/index.ts Normal file
View File

@ -0,0 +1 @@
export * as backend from "./main";

2
src/backend/main.ts Normal file
View File

@ -0,0 +1,2 @@
export * as common from "./common";
export * as ort from "./ort";

1
src/backend/ort/index.ts Normal file
View File

@ -0,0 +1 @@
export { OrtSession as Session } from "./session";

View File

@ -0,0 +1,32 @@
import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
export class OrtSession extends CommonSession {
#session: any;
#inputs: Record<string, SessionNodeInfo> | null = null;
#outputs: Record<string, SessionNodeInfo> | null = null;
public constructor(modelData: Uint8Array) {
super();
const addon = require("../../../build/ort.node")
this.#session = new addon.OrtSession(modelData);
}
public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>) {
const inputArgs: Record<string, any> = {};
for (const [name, option] of Object.entries(inputs)) {
if (isTypedArray(option)) inputArgs[name] = { data: option }
else inputArgs[name] = option;
}
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
if (err) return reject(err);
const result: Record<string, Float32Array> = {};
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
resolve(result);
}));
}
}