初次提交
This commit is contained in:
1
src/backend/common/index.ts
Normal file
1
src/backend/common/index.ts
Normal file
@ -0,0 +1 @@
|
||||
export * from "./session";
|
28
src/backend/common/session.ts
Normal file
28
src/backend/common/session.ts
Normal file
@ -0,0 +1,28 @@
|
||||
|
||||
|
||||
export interface SessionNodeInfo {
|
||||
name: string
|
||||
type: number
|
||||
shape: number[]
|
||||
}
|
||||
|
||||
export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
|
||||
|
||||
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
|
||||
|
||||
export interface SessionRunInputOption {
|
||||
type?: SessionNodeType
|
||||
data: SessionNodeData
|
||||
shape?: number[]
|
||||
}
|
||||
|
||||
export abstract class CommonSession {
|
||||
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>>
|
||||
|
||||
public abstract get inputs(): Record<string, SessionNodeInfo>;
|
||||
public abstract get outputs(): Record<string, SessionNodeInfo>;
|
||||
}
|
||||
|
||||
export function isTypedArray(val: any): val is SessionNodeData {
|
||||
return val?.buffer instanceof ArrayBuffer;
|
||||
}
|
1
src/backend/index.ts
Normal file
1
src/backend/index.ts
Normal file
@ -0,0 +1 @@
|
||||
export * as backend from "./main";
|
2
src/backend/main.ts
Normal file
2
src/backend/main.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export * as common from "./common";
|
||||
export * as ort from "./ort";
|
1
src/backend/ort/index.ts
Normal file
1
src/backend/ort/index.ts
Normal file
@ -0,0 +1 @@
|
||||
export { OrtSession as Session } from "./session";
|
32
src/backend/ort/session.ts
Normal file
32
src/backend/ort/session.ts
Normal file
@ -0,0 +1,32 @@
|
||||
import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||
|
||||
export class OrtSession extends CommonSession {
|
||||
#session: any;
|
||||
#inputs: Record<string, SessionNodeInfo> | null = null;
|
||||
#outputs: Record<string, SessionNodeInfo> | null = null;
|
||||
|
||||
public constructor(modelData: Uint8Array) {
|
||||
super();
|
||||
const addon = require("../../../build/ort.node")
|
||||
this.#session = new addon.OrtSession(modelData);
|
||||
}
|
||||
|
||||
public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
|
||||
|
||||
public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
|
||||
|
||||
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>) {
|
||||
const inputArgs: Record<string, any> = {};
|
||||
for (const [name, option] of Object.entries(inputs)) {
|
||||
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||
else inputArgs[name] = option;
|
||||
}
|
||||
|
||||
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
||||
if (err) return reject(err);
|
||||
const result: Record<string, Float32Array> = {};
|
||||
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
||||
resolve(result);
|
||||
}));
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user