增加下载工具

This commit is contained in:
2025-03-20 18:24:12 +08:00
parent 6bf7db1f4c
commit 6cdf4cbcd6
12 changed files with 131 additions and 9 deletions

12
src/backend/config.ts Normal file
View File

@ -0,0 +1,12 @@
import path from "path";
const defaultAddonDir = path.join(__dirname, "../../build")
const aiConfig = {
"MNN_ADDON_FILE": path.join(defaultAddonDir, "mnn.node"),
"ORT_ADDON_FILE": path.join(defaultAddonDir, "ort.node"),
};
export function setConfig<K extends keyof typeof aiConfig>(key: K, value: typeof aiConfig[K]) { aiConfig[key] = value; }
export function getConfig<K extends keyof typeof aiConfig>(key: K): typeof aiConfig[K] { return aiConfig[key]; }

95
src/backend/download.ts Normal file
View File

@ -0,0 +1,95 @@
import os from "os";
import fs from "fs";
import path from "path";
import { getConfig } from "./config";
const URLS = {
GITHUB: "https://github.com/kangkang520/node-addons/releases/download/ai{{version}}/{{backend}}_{{platform}}_{{arch}}.node",
XXXXX: `http://git.urnas.cn:5200/yizhi-js-lib/ai-box/releases/download/{{version}}/{{backend}}_{{platform}}_{{arch}}.node`,
}
function releaseVersion() { return require("../../package.json").releaseVersion }
function getURL(backend: "ort" | "mnn", template: string) {
let platform = "";
let arch = "";
switch (os.platform()) {
case "win32":
platform = "windows";
break;
case "linux":
platform = "linux";
break;
case "darwin":
platform = "macos";
break
default:
throw new Error(`Unsupported platform: ${os.platform()}, Please compile the addon yourself.`);
}
switch (os.arch()) {
case "x64":
arch = "x64";
break;
case "arm64":
arch = "arm64";
break;
default:
throw new Error(`Unsupported architecture: ${os.arch()}, Please compile the addon yourself.`);
}
return template.replaceAll("{{backend}}", backend).replaceAll("{{version}}", releaseVersion()).replaceAll("{{platform}}", platform).replaceAll("{{arch}}", arch);
}
async function getStream(backend: "mnn" | "ort") {
for (const [name, url] of Object.entries(URLS)) {
try {
return await fetch(getURL(backend, url)).then(res => {
if (res.status != 200) throw new Error("Failed to download addon.");
return res.blob().then(b => b.stream());
})
} catch (e) { }
}
throw new Error("Failed to download addon.");
}
export async function downloadBackend(backend: "ort" | "mnn", savename?: string) {
const backendConfigNameDict = { ort: "ORT_ADDON_FILE" as const, mnn: "MNN_ADDON_FILE" as const };
const defaultAddon = path.resolve(process.cwd(), getConfig(backendConfigNameDict[backend]));
const saveName = savename ? path.resolve(path.dirname(defaultAddon), savename) : defaultAddon;
if (fs.existsSync(saveName)) {
try {
const addon = require(saveName);
if (addon.__release__ === releaseVersion()) return saveName;
//清除缓存
delete require.cache[saveName];
} catch (err) { }
}
await fs.promises.mkdir(path.dirname(saveName), { recursive: true });
const stream = await getStream(backend);
const cacheFile = await new Promise<string>((resolve, reject) => {
const cacheFile = path.join(os.tmpdir(), Date.now() + ".cv.node");
let fsStream!: ReturnType<typeof fs.createWriteStream>;
stream.pipeTo(new WritableStream({
start(controller) {
fsStream = fs.createWriteStream(cacheFile);
},
async write(chunk, controller) {
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
},
close() {
fsStream.end();
resolve(cacheFile);
},
abort() { }
})).catch(reject);
});
if (fs.existsSync(saveName)) await fs.promises.rm(saveName, { recursive: true, force: true });
await fs.promises.cp(cacheFile, saveName);
await fs.promises.rm(cacheFile);
return saveName;
}

View File

@ -1,3 +1,4 @@
export { SessionNodeInfo, DataTypeString, DataType, SessionNodeData, SessionRunInputOption, SessionRunOutput, CommonSession } from "./common";
export * as ort from "./ort";
export * as mnn from "./mnn";
export * as mnn from "./mnn";
export { downloadBackend } from "./download";

View File

@ -1,4 +1,4 @@
import { getConfig } from "../../config";
import { getConfig } from "../config";
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
export class MNNSession extends CommonSession {

View File

@ -1,4 +1,4 @@
import { getConfig } from "../../config";
import { getConfig } from "../config";
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
export class OrtSession extends CommonSession {