Compare commits


17 Commits

SHA1 Message Date
2d98a1c764 Improve the automatic download tool for prebuilt addons 2025-03-21 12:40:49 +08:00
aab09dbb86 Bump version number 2025-03-21 08:53:37 +08:00
6cdf4cbcd6 Add download tool 2025-03-20 18:24:12 +08:00
6bf7db1f4c Remove cv; use the @yizhi/cv package instead 2025-03-17 12:24:30 +08:00
b48d2daffb Fix issue with the face box toSquare function 2025-03-13 18:16:26 +08:00
eea943bba5 Remove redundant log output 2025-03-13 15:43:53 +08:00
232d480ca0 Fix TypeScript compile-time issue 2025-03-11 11:34:05 +08:00
524dcaecbd Code cleanup 2025-03-10 18:12:32 +08:00
a966b82963 Add shape to outputs for MNN model handling 2025-03-10 17:49:53 +08:00
e362890c96 Fix model download issue 2025-03-10 15:50:55 +08:00
92e46c2c33 Fix model download issue 2025-03-10 15:34:24 +08:00
2dfe063049 Fix Windows build issue 2025-03-10 15:18:30 +08:00
358d21b2bd Fix Windows third-party library build issue 2025-03-10 15:08:24 +08:00
e831b8e862 Fix build issue on Windows 2025-03-10 14:45:15 +08:00
bd90f2f6f6 Add MNN model support 2025-03-07 17:18:20 +08:00
4a6d092de1 Merge branch 'main' of http://git.urnas.cn:5200/yizhi-js-lib/ai-box 2025-03-07 17:07:04 +08:00
a6fd117736 Add model conversion tool 2025-03-07 17:07:02 +08:00
41 changed files with 983 additions and 650 deletions

12
.npmignore Normal file
View File

@ -0,0 +1,12 @@
/build
/cache
/cxx
/models
/node_modules
/src
/testdata
/thirdpart
/tool
/.clang-format
/CMakeLists.txt
/tsconfig.json

View File

@ -10,7 +10,7 @@ if(NOT DEFINED CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
set(NODE_COMMON_SOURCES)
set(NODE_ADDON_FOUND OFF)
include_directories(cxx)
@ -38,6 +38,20 @@ if(CMAKE_JS_RESULT EQUAL 0)
)
endif()
# ReleaseVersion
if(CMAKE_JS_RESULT EQUAL 0)
execute_process(
COMMAND node ${CMAKE_SOURCE_DIR}/thirdpart/cmake-js-util.js --release
RESULT_VARIABLE CMAKE_JS_RESULT
OUTPUT_VARIABLE RELEASE_VERSION
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(CMAKE_JS_RESULT EQUAL 0)
message(STATUS "RELEASE_VERSION: ${RELEASE_VERSION}")
add_compile_definitions(RELEASE_VERSION="${RELEASE_VERSION}")
endif()
endif()
# NAPI
if(CMAKE_JS_RESULT EQUAL 0)
execute_process(
@ -74,26 +88,14 @@ if(EXISTS ${MNN_CMAKE_FILE})
message(STATUS "MNN_LIB_DIR: ${MNN_LIB_DIR}")
message(STATUS "MNN_INCLUDE_DIR: ${MNN_INCLUDE_DIR}")
message(STATUS "MNN_LIBS: ${MNN_LIBS}")
include_directories(${MNN_INCLUDE_DIRS})
include_directories(${MNN_INCLUDE_DIR})
link_directories(${MNN_LIB_DIR})
add_compile_definitions(USE_MNN)
set(USE_MNN ON)
endif()
# OpenCV
set(OpenCV_CMAKE_FILE ${CMAKE_SOURCE_DIR}/thirdpart/OpenCV/${CMAKE_BUILD_TYPE}/config.cmake)
if(EXISTS ${OpenCV_CMAKE_FILE})
include(${OpenCV_CMAKE_FILE})
message(STATUS "OpenCV_LIB_DIR: ${OpenCV_LIB_DIR}")
message(STATUS "OpenCV_INCLUDE_DIR: ${OpenCV_INCLUDE_DIR}")
message(STATUS "OpenCV_LIBS: ${OpenCV_LIBS}")
include_directories(${OpenCV_INCLUDE_DIRS})
link_directories(${OpenCV_LIB_DIR})
if(NODE_ADDON_FOUND)
add_node_targert(cv cxx/cv/node.cc)
target_link_libraries(cv ${OpenCV_LIBS})
target_compile_definitions(cv PUBLIC USE_OPENCV)
add_node_targert(mnn cxx/mnn/node.cc)
target_link_libraries(mnn ${MNN_LIBS})
target_compile_definitions(mnn PUBLIC USE_MNN)
list(APPEND NODE_COMMON_SOURCES cxx/mnn/node.cc)
endif()
endif()
@ -111,12 +113,29 @@ if(EXISTS ${ONNXRuntime_CMAKE_FILE})
add_node_targert(ort cxx/ort/node.cc)
target_link_libraries(ort ${ONNXRuntime_LIBS})
target_compile_definitions(ort PUBLIC USE_ONNXRUNTIME)
list(APPEND NODE_COMMON_SOURCES cxx/ort/node.cc)
endif()
endif()
# Unified Node.js addon
if(NODE_ADDON_FOUND)
add_node_targert(addon cxx/node.cc)
target_sources(addon PRIVATE ${NODE_COMMON_SOURCES})
target_compile_definitions(addon PUBLIC BUILD_MAIN_WORD)
# MNN
if(EXISTS ${MNN_CMAKE_FILE})
target_link_libraries(addon ${MNN_LIBS})
target_compile_definitions(addon PUBLIC USE_MNN)
endif()
# OnnxRuntime
if(EXISTS ${ONNXRuntime_CMAKE_FILE})
target_link_libraries(addon ${ONNXRuntime_LIBS})
target_compile_definitions(addon PUBLIC USE_ONNXRUNTIME)
endif()
endif()
if(MSVC)
if(MSVC AND NODE_ADDON_FOUND)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
endif()

View File

@ -1 +1,46 @@
## AI Toolbox
# AI Toolbox
## Installation
```
npm install @yizhi/ai
```
## Usage
```typescript
import ai from "@yizhi/ai";
// Configure the corresponding Node addons (the addons need to be compiled yourself)
ai.config("CV_ADDON_FILE", "/path/to/cv.node");
ai.config("MNN_ADDON_FILE", "/path/to/mnn.node");
ai.config("ORT_ADDON_FILE", "/path/to/onnxruntime.node");
// Run inference directly
const facedet = await ai.deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
const boxes = await facedet.predict("/path/to/image");
// Use your own model
const session = new ai.backend.ort.Session(modelBuffer);
const outputs = await session.run(inputs);
```
## Building the addons
1. Dependencies
    1. cmake
    1. ninja
    1. C++ compiler (gcc, clang, Visual Studio, ...)
1. Build the third-party libraries
    ```
    node thirdpart/install.js --with-mnn --with-onnx
    ```
1. Build the addons
    ```
    cmake -B build -G Ninja . -DCMAKE_BUILD_TYPE=Release
    cmake --build build --config Release
    ```
Note: when building on Windows, run these commands from a Visual Studio command prompt.
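The MNN backend added in this comparison is used the same way as the ONNX Runtime example above. A minimal sketch, assuming a prebuilt or self-compiled mnn.node and a single input named `input` (both placeholders); note that `run` now resolves to `{ data, shape }` objects per output:

```typescript
import fs from "fs";
import ai from "@yizhi/ai";

async function main() {
    ai.config("MNN_ADDON_FILE", "/path/to/mnn.node");
    const modelBuffer = await fs.promises.readFile("/path/to/model.mnn");
    const session = new ai.backend.mnn.Session(modelBuffer);
    // "input" and the 1x3x224x224 shape are placeholders; check session.inputs for the real names.
    const outputs = await session.run({
        input: { data: new Float32Array(1 * 3 * 224 * 224), shape: [1, 3, 224, 224], type: "float32" },
    });
    for (const [name, output] of Object.entries(outputs)) {
        console.log(name, output.shape, output.data.length);
    }
}

main().catch(console.error);
```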

View File

@ -6,8 +6,7 @@
#include <napi.h>
#define NODE_INIT_OBJECT(name, function) \
do \
{ \
do { \
auto obj = Napi::Object::New(env); \
function(env, obj); \
exports.Set(Napi::String::New(env, #name), obj); \
@ -21,4 +20,13 @@ inline uint64_t __node_ptr_of__(Napi::Value value)
#define NODE_PTR_OF(type, value) (reinterpret_cast<type *>(__node_ptr_of__(value)))
inline void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
{
auto arr = val.As<Napi::TypedArray>();
auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
bytes = arr.ByteLength();
return static_cast<void *>(data + arr.ByteOffset());
}
#endif

18
cxx/common/tensor.h Normal file
View File

@ -0,0 +1,18 @@
#ifndef __COMMON_TENSOR_H__
#define __COMMON_TENSOR_H__
enum class TensorDataType {
Unknown,
Float32,
Float64,
Int32,
Uint32,
Int16,
Uint16,
Int8,
Uint8,
Int64,
Uint64,
};
#endif
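This header defines the tensor data-type codes shared by the backends. Only the raw integer crosses the N-API boundary (the addons read it with `Int32Value()` and cast it back to `TensorDataType`), so the TypeScript `DataType` enum introduced later in this diff has to keep the same declaration order. A sketch of that contract:

```typescript
// Mirrors cxx/common/tensor.h by declaration order: Unknown = 0, Float32 = 1, ..., Uint64 = 10.
enum DataType { Unknown, Float32, Float64, Int32, Uint32, Int16, Uint16, Int8, Uint8, Int64, Uint64 }

// This integer is what the addons receive as the "type" field of an input option.
console.log(DataType.Float32, DataType.Uint64); // 1 10
```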

View File

@ -1,145 +0,0 @@
#include <iostream>
#include <functional>
#include <opencv2/opencv.hpp>
#include "node.h"
using namespace Napi;
#define MAT_INSTANCE_METHOD(method) InstanceMethod<&CVMat::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
static FunctionReference *constructor = nullptr;
class CVMat : public ObjectWrap<CVMat> {
public:
static Napi::Object Init(Napi::Env env, Napi::Object exports)
{
Function func = DefineClass(env, "Mat", {
MAT_INSTANCE_METHOD(IsEmpty),
MAT_INSTANCE_METHOD(GetCols),
MAT_INSTANCE_METHOD(GetRows),
MAT_INSTANCE_METHOD(GetChannels),
MAT_INSTANCE_METHOD(Resize),
MAT_INSTANCE_METHOD(Crop),
MAT_INSTANCE_METHOD(Rotate),
MAT_INSTANCE_METHOD(Clone),
MAT_INSTANCE_METHOD(DrawCircle),
MAT_INSTANCE_METHOD(Data),
MAT_INSTANCE_METHOD(Encode),
});
constructor = new FunctionReference();
*constructor = Napi::Persistent(func);
exports.Set("Mat", func);
env.SetInstanceData<FunctionReference>(constructor);
return exports;
}
CVMat(const CallbackInfo &info)
: ObjectWrap<CVMat>(info)
{
int mode = cv::IMREAD_COLOR_BGR;
if (info.Length() > 1 && info[1].IsObject()) {
Object options = info[1].As<Object>();
if (options.Has("mode") && options.Get("mode").IsNumber()) mode = options.Get("mode").As<Number>().Int32Value();
}
if (info[0].IsString()) im_ = cv::imread(info[0].As<String>().Utf8Value(), mode);
else if (info[0].IsTypedArray()) {
auto buffer = info[0].As<TypedArray>().ArrayBuffer();
uint8_t *bufferPtr = static_cast<uint8_t *>(buffer.Data());
std::vector<uint8_t> data(bufferPtr, bufferPtr + buffer.ByteLength());
im_ = cv::imdecode(data, mode);
}
}
~CVMat() { im_.release(); }
Napi::Value IsEmpty(const Napi::CallbackInfo &info) { return Boolean::New(info.Env(), im_.empty()); }
Napi::Value GetCols(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.cols); }
Napi::Value GetRows(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.rows); }
Napi::Value GetChannels(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.channels()); }
Napi::Value Resize(const Napi::CallbackInfo &info)
{
return CreateMat(info.Env(), [this, &info](auto &mat) { cv::resize(im_, mat.im_, cv::Size(info[0].As<Number>().Int32Value(), info[1].As<Number>().Int32Value())); });
}
Napi::Value Crop(const Napi::CallbackInfo &info)
{
return CreateMat(info.Env(), [this, &info](auto &mat) {
mat.im_ = im_(cv::Rect(
info[0].As<Number>().Int32Value(), info[1].As<Number>().Int32Value(),
info[2].As<Number>().Int32Value(), info[3].As<Number>().Int32Value()));
});
}
Napi::Value Rotate(const Napi::CallbackInfo &info)
{
return CreateMat(info.Env(), [this, &info](auto &mat) {
auto x = info[0].As<Number>().DoubleValue();
auto y = info[1].As<Number>().DoubleValue();
auto angle = info[2].As<Number>().DoubleValue();
cv::Mat rotation_matix = cv::getRotationMatrix2D(cv::Point2f(x, y), angle, 1.0);
cv::warpAffine(im_, mat.im_, rotation_matix, im_.size());
});
}
Napi::Value Clone(const Napi::CallbackInfo &info)
{
return CreateMat(info.Env(), [this, &info](auto &mat) { mat.im_ = im_.clone(); });
}
Napi::Value DrawCircle(const Napi::CallbackInfo &info)
{
int x = info[0].As<Number>().Int32Value();
int y = info[1].As<Number>().Int32Value();
int radius = info[2].As<Number>().Int32Value();
int b = info[3].As<Number>().Int32Value();
int g = info[4].As<Number>().Int32Value();
int r = info[5].As<Number>().Int32Value();
int thickness = info[6].As<Number>().Int32Value();
int lineType = info[7].As<Number>().Int32Value();
int shift = info[8].As<Number>().Int32Value();
cv::circle(im_, cv::Point(x, y), radius, cv::Scalar(b, g, r), thickness, lineType, shift);
return info.Env().Undefined();
}
Napi::Value Data(const Napi::CallbackInfo &info) { return ArrayBuffer::New(info.Env(), im_.ptr(), im_.elemSize() * im_.total()); }
Napi::Value Encode(const Napi::CallbackInfo &info)
{
auto options = info[0].As<Object>();
auto extname = options.Get("extname").As<String>().Utf8Value();
cv::imencode(extname, im_, encoded_);
return ArrayBuffer::New(info.Env(), encoded_.data(), encoded_.size());
}
private:
inline Napi::Object EmptyMat(Napi::Env env) { return constructor->New({}).As<Object>(); }
inline CVMat &GetMat(Napi::Object obj) { return *ObjectWrap<CVMat>::Unwrap(obj); }
inline Napi::Object CreateMat(Napi::Env env, std::function<void(CVMat &mat)> callback)
{
auto obj = EmptyMat(env);
callback(GetMat(obj));
return obj;
}
private:
cv::Mat im_;
std::vector<uint8_t> encoded_;
};
void InstallOpenCVAPI(Env env, Object exports)
{
CVMat::Init(env, exports);
}
#ifdef USE_OPENCV
static Object Init(Env env, Object exports)
{
InstallOpenCVAPI(env, exports);
return exports;
}
NODE_API_MODULE(addon, Init)
#endif

View File

@ -1,8 +0,0 @@
#ifndef __CV_NODE_H__
#define __CV_NODE_H__
#include "common/node.h"
void InstallOpenCVAPI(Napi::Env env, Napi::Object exports);
#endif

247
cxx/mnn/node.cc Normal file
View File

@ -0,0 +1,247 @@
#include <iostream>
#include <vector>
#include <map>
#include <cstring>
#include <MNN/Interpreter.hpp>
#include <MNN/ImageProcess.hpp>
#include "common/tensor.h"
#include "node.h"
using namespace Napi;
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&MNNSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
static const std::map<TensorDataType, halide_type_t> DATA_TYPE_MAP = {
{TensorDataType::Float32, halide_type_of<float>()},
{TensorDataType::Float64, halide_type_of<double>()},
{TensorDataType::Int32, halide_type_of<int32_t>()},
{TensorDataType::Uint32, halide_type_of<uint32_t>()},
{TensorDataType::Int16, halide_type_of<int16_t>()},
{TensorDataType::Uint16, halide_type_of<uint16_t>()},
{TensorDataType::Int8, halide_type_of<int8_t>()},
{TensorDataType::Uint8, halide_type_of<uint8_t>()},
{TensorDataType::Int64, halide_type_of<int64_t>()},
{TensorDataType::Uint64, halide_type_of<uint64_t>()},
};
static size_t getShapeSize(const std::vector<int> &shape)
{
if (!shape.size()) return 0;
size_t sum = 1;
for (auto i : shape) {
if (i > 1) sum *= i;
};
return sum;
}
static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector<int> &shape)
{
auto result = Napi::Array::New(env, shape.size());
for (int i = 0; i < shape.size(); ++i) result.Set(i, Napi::Number::New(env, shape[i]));
return result;
}
class MNNSessionRunWorker : public AsyncWorker {
public:
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
: AsyncWorker(callback), interpreter_(interpreter), session_(session) {}
~MNNSessionRunWorker()
{
interpreter_->releaseSession(session_);
}
void Execute()
{
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
SetError(std::string("Run session failed"));
}
}
void OnOK()
{
if (HasError()) {
Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()});
}
else {
auto result = Object::New(Env());
for (auto it : interpreter_->getSessionOutputAll(session_)) {
auto tensor = it.second;
auto item = Object::New(Env());
auto buffer = ArrayBuffer::New(Env(), tensor->size());
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
item.Set("data", buffer);
item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape()));
result.Set(it.first, item);
}
Callback().Call({Env().Undefined(), result});
}
}
void SetInput(const std::string &name, TensorDataType dataType, const std::vector<int> &shape, void *data, size_t dataBytes)
{
auto tensor = interpreter_->getSessionInput(session_, name.c_str());
if (!tensor) {
SetError(std::string("input name #" + name + " not exists"));
return;
}
halide_type_t type = tensor->getType();
if (dataType != TensorDataType::Unknown) {
auto it = DATA_TYPE_MAP.find(dataType);
if (it != DATA_TYPE_MAP.end()) type = it->second;
}
if (shape.size()) {
interpreter_->resizeTensor(tensor, shape);
interpreter_->resizeSession(session_);
}
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
if (tensorBytes != dataBytes) {
SetError(std::string("input name #" + name + " data size not matched"));
return;
}
auto hostTensor = MNN::Tensor::create(tensor->shape(), type, data, MNN::Tensor::CAFFE);
tensor->copyFromHostTensor(hostTensor);
delete hostTensor;
}
inline void SetError(const std::string &what) { errorMessage_ = what; }
inline bool HasError() { return errorMessage_.size() > 0; }
private:
MNN::Interpreter *interpreter_;
MNN::Session *session_;
std::string errorMessage_;
};
class MNNSession : public ObjectWrap<MNNSession> {
public:
static Napi::Object Init(Napi::Env env, Napi::Object exports)
{
Function func = DefineClass(env, "MNNSession", {
SESSION_INSTANCE_METHOD(GetInputsInfo),
SESSION_INSTANCE_METHOD(GetOutputsInfo),
SESSION_INSTANCE_METHOD(Run),
});
FunctionReference *constructor = new FunctionReference();
*constructor = Napi::Persistent(func);
exports.Set("MNNSession", func);
env.SetInstanceData<FunctionReference>(constructor);
return exports;
}
MNNSession(const CallbackInfo &info)
: ObjectWrap(info)
{
try {
if (info[0].IsString()) {
interpreter_ = MNN::Interpreter::createFromFile(info[0].As<String>().Utf8Value().c_str());
}
else if (info[0].IsTypedArray()) {
size_t bufferBytes;
auto buffer = dataFromTypedArray(info[0], bufferBytes);
interpreter_ = MNN::Interpreter::createFromBuffer(buffer, bufferBytes);
}
else interpreter_ = nullptr;
if (interpreter_) {
backendConfig_.precision = MNN::BackendConfig::Precision_High;
backendConfig_.power = MNN::BackendConfig::Power_High;
scheduleConfig_.type = MNN_FORWARD_CPU;
scheduleConfig_.numThread = 1;
scheduleConfig_.backendConfig = &backendConfig_;
session_ = interpreter_->createSession(scheduleConfig_);
}
else session_ = nullptr;
}
catch (std::exception &e) {
Error::New(info.Env(), e.what()).ThrowAsJavaScriptException();
}
}
~MNNSession() {}
Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionInputAll(session_)); }
Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionOutputAll(session_)); }
Napi::Value Run(const Napi::CallbackInfo &info)
{
auto worker = new MNNSessionRunWorker(info[1].As<Function>(), interpreter_, interpreter_->createSession(scheduleConfig_));
auto inputArgument = info[0].As<Object>();
for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) {
auto name = (*it).first.As<String>().Utf8Value();
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
auto type = inputOption.Has("type") ? static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value()) : TensorDataType::Unknown;
size_t dataByteLen;
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : std::vector<int>();
worker->SetInput(name, type, shape, data, dataByteLen);
}
worker->Queue();
return info.Env().Undefined();
}
private:
Napi::Object BuildInputOutputInfo(Napi::Env env, const std::map<std::string, MNN::Tensor *> &tensors)
{
auto result = Object::New(env);
for (auto it : tensors) {
auto item = Object::New(env);
auto name = it.first;
auto shape = it.second->shape();
auto type = it.second->getType();
TensorDataType dataType = TensorDataType::Unknown;
for (auto dt : DATA_TYPE_MAP) {
if (dt.second == type) {
dataType = dt.first;
break;
}
}
auto shapeArr = Array::New(env, shape.size());
for (size_t i = 0; i < shape.size(); i++) {
shapeArr.Set(i, Number::New(env, shape[i]));
}
item.Set("name", String::New(env, name));
item.Set("shape", shapeArr);
item.Set("type", Number::New(env, static_cast<int>(dataType)));
result.Set(name, item);
}
return result;
}
std::vector<int> GetShapeFromJavascript(const Napi::Array &shape)
{
std::vector<int> result;
for (size_t i = 0; i < shape.Length(); i++) {
result.push_back(shape.Get(i).As<Number>().Int32Value());
}
return result;
}
private:
MNN::Interpreter *interpreter_;
MNN::Session *session_;
MNN::BackendConfig backendConfig_;
MNN::ScheduleConfig scheduleConfig_;
};
void InstallMNNAPI(Napi::Env env, Napi::Object exports)
{
MNNSession::Init(env, exports);
}
#if defined(USE_MNN) && !defined(BUILD_MAIN_WORD)
static Object Init(Env env, Object exports)
{
#ifdef RELEASE_VERSION
exports.Set("__release__", Napi::String::New(env, RELEASE_VERSION));
#endif
InstallMNNAPI(env, exports);
return exports;
}
NODE_API_MODULE(addon, Init)
#endif
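A sketch of driving this addon's `MNNSession` directly from Node.js, bypassing the TypeScript wrapper that appears later in the diff; the addon path, model path, and input name are placeholders:

```typescript
import fs from "fs";

// Hypothetical paths; the wrapper normally resolves the addon via getConfig("MNN_ADDON_FILE").
const addon = require("/path/to/mnn.node");
const modelBuffer = fs.readFileSync("/path/to/model.mnn");
const session = new addon.MNNSession(modelBuffer); // also accepts a model file path string

console.log(session.GetInputsInfo(), session.GetOutputsInfo());

session.Run(
    { input: { data: new Float32Array(1 * 3 * 224 * 224), shape: [1, 3, 224, 224] } },
    (err: any, res: Record<string, { data: ArrayBuffer; shape: number[] }>) => {
        if (err) throw err;
        for (const [name, out] of Object.entries(res)) {
            console.log(name, out.shape, new Float32Array(out.data).length);
        }
    },
);
```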

8
cxx/mnn/node.h Normal file
View File

@ -0,0 +1,8 @@
#ifndef __MNN_NODE_H__
#define __MNN_NODE_H__
#include "common/node.h"
void InstallMNNAPI(Napi::Env env, Napi::Object exports);
#endif

View File

@ -1,19 +0,0 @@
#ifndef __MNN_SESSION_H__
#define __MNN_SESSION_H__
#include <MNN/Interpreter.hpp>
#include <MNN/ImageProcess.hpp>
#include "common/session.h"
namespace ai
{
class MNNSession : public Session
{
public:
MNNSession(const void *modelData, size_t size);
~MNNSession();
};
}
#endif

View File

@ -1,65 +1,25 @@
// #include <unistd.h>
// #include <napi.h>
// #include "cv/node.h"
// #ifdef USE_ORT
// #include "ort/node.h"
// #endif
#include "common/node.h"
#include "mnn/node.h"
#include "ort/node.h"
// using namespace Napi;
using namespace Napi;
// class TestWork : public AsyncWorker
// {
// public:
// TestWork(const Napi::Function &callback, int value) : Napi::AsyncWorker(callback), val_(value) {}
// ~TestWork() {}
#if defined(BUILD_MAIN_WORD)
Object Init(Env env, Object exports)
{
#ifdef RELEASE_VERSION
exports.Set("__release__", Napi::String::New(env, RELEASE_VERSION));
#endif
// OnnxRuntime
#ifdef USE_ONNXRUNTIME
InstallOrtAPI(env, exports);
#endif
// MNN
#ifdef USE_MNN
InstallMNNAPI(env, exports);
#endif
// void Execute()
// {
// printf("the worker-thread doing! %d \n", val_);
// sleep(3);
// printf("the worker-thread done! %d \n", val_);
// }
// void OnOK()
// {
// Callback().Call({Env().Undefined(), Number::New(Env(), 0)});
// }
// private:
// int val_;
// };
// Value test(const CallbackInfo &info)
// {
// // ai::ORTSession(nullptr, 0);
// // Function callback = info[1].As<Function>();
// // TestWork *work = new TestWork(callback, info[0].As<Number>().Int32Value());
// // work->Queue();
// return info.Env().Undefined();
// }
// Object Init(Env env, Object exports)
// {
// //OpenCV
// NODE_INIT_OBJECT(cv, InstallOpenCVAPI);
// //OnnxRuntime
// #ifdef USE_ORT
// NODE_INIT_OBJECT(ort, InstallOrtAPI);
// #endif
// Napi::Number::New(env, 0);
// #define ADD_FUNCTION(name) exports.Set(Napi::String::New(env, #name), Napi::Function::New(env, name))
// // ADD_FUNCTION(facedetPredict);
// // ADD_FUNCTION(facedetRelease);
// // ADD_FUNCTION(faceRecognitionCreate);
// // ADD_FUNCTION(faceRecognitionPredict);
// // ADD_FUNCTION(faceRecognitionRelease);
// // ADD_FUNCTION(getDistance);
// #undef ADD_FUNCTION
// return exports;
// }
// NODE_API_MODULE(addon, Init)
return exports;
}
NODE_API_MODULE(addon, Init)
#endif

View File

@ -2,6 +2,7 @@
#include <vector>
#include <onnxruntime_cxx_api.h>
#include "node.h"
#include "common/tensor.h"
#ifdef WIN32
#include <locale>
@ -13,29 +14,20 @@ using namespace Napi;
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
static ONNXTensorElementDataType getDataTypeFromString(const std::string &name)
{
static const std::map<std::string, ONNXTensorElementDataType> dataTypeNameMap = {
{"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
{"float", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
{"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
{"double", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
{"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
{"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
{"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
{"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
{"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
{"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
{"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
{"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
};
auto it = dataTypeNameMap.find(name);
return (it == dataTypeNameMap.end()) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED : it->second;
}
static const std::map<TensorDataType, ONNXTensorElementDataType> DATA_TYPE_MAP = {
{TensorDataType::Float32, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
{TensorDataType::Float64, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
{TensorDataType::Int32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
{TensorDataType::Uint32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
{TensorDataType::Int16, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
{TensorDataType::Uint16, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
{TensorDataType::Int8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
{TensorDataType::Uint8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
{TensorDataType::Int64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
{TensorDataType::Uint64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
};
static size_t getDataTypeSize(ONNXTensorElementDataType type)
{
static const std::map<ONNXTensorElementDataType, size_t> dataTypeSizeMap = {
static const std::map<ONNXTensorElementDataType, size_t> DATA_TYPE_SIZE_MAP = {
{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
{ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
@ -46,18 +38,7 @@ static size_t getDataTypeSize(ONNXTensorElementDataType type)
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
};
auto it = dataTypeSizeMap.find(type);
return (it == dataTypeSizeMap.end()) ? 0 : it->second;
}
static void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
{
auto arr = val.As<TypedArray>();
auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
bytes = arr.ByteLength();
return static_cast<void *>(data + arr.ByteOffset());
}
};
class OrtSessionNodeInfo {
public:
@ -67,7 +48,22 @@ class OrtSessionNodeInfo {
inline const std::string &GetName() const { return name_; }
inline const std::vector<int64_t> &GetShape() const { return shape_; }
inline ONNXTensorElementDataType GetType() const { return type_; }
inline size_t GetElementSize() const { return getDataTypeSize(type_); }
inline size_t GetElementSize() const
{
auto it = DATA_TYPE_SIZE_MAP.find(type_);
return (it == DATA_TYPE_SIZE_MAP.end()) ? 0 : it->second;
}
TensorDataType GetDataType() const
{
auto datatype = TensorDataType::Unknown;
for (auto it : DATA_TYPE_MAP) {
if (it.second == type_) {
datatype = it.first;
break;
}
}
return datatype;
}
size_t GetElementCount() const
{
if (!shape_.size()) return 0;
@ -115,7 +111,8 @@ class OrtSessionRunWorker : public AsyncWorker {
for (int i = 0; i < outputNames_.size(); ++i) {
size_t bytes = outputElementBytes_[i];
Ort::Value &value = outputValues_[i];
auto buffer = ArrayBuffer::New(Env(), value.GetTensorMutableRawData(), bytes);
auto buffer = ArrayBuffer::New(Env(), bytes);
memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes);
result.Set(String::New(Env(), outputNames_[i]), buffer);
}
Callback().Call({Env().Undefined(), result});
@ -236,7 +233,12 @@ class OrtSession : public ObjectWrap<OrtSession> {
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name)));
else {
auto type = inputOption.Has("type") ? getDataTypeFromString(inputOption.Get("type").As<String>().Utf8Value()) : input->GetType();
auto type = input->GetType();
if (inputOption.Has("type")) {
auto t = static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value());
auto it = DATA_TYPE_MAP.find(t);
if (it != DATA_TYPE_MAP.end()) type = it->second;
}
size_t dataByteLen;
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
@ -279,7 +281,7 @@ class OrtSession : public ObjectWrap<OrtSession> {
auto &node = *it.second;
auto item = Object::New(env);
item.Set(String::New(env, "name"), String::New(env, node.GetName()));
item.Set(String::New(env, "type"), Number::New(env, node.GetType()));
item.Set(String::New(env, "type"), Number::New(env, static_cast<int>(node.GetDataType())));
auto &shapeVec = node.GetShape();
auto shape = Array::New(env, shapeVec.size());
for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i]));
@ -303,9 +305,12 @@ void InstallOrtAPI(Napi::Env env, Napi::Object exports)
OrtSession::Init(env, exports);
}
#ifdef USE_ONNXRUNTIME
#if defined(USE_ONNXRUNTIME) && !defined(BUILD_MAIN_WORD)
static Object Init(Env env, Object exports)
{
#ifdef RELEASE_VERSION
exports.Set("__release__", Napi::String::New(env, RELEASE_VERSION));
#endif
InstallOrtAPI(env, exports);
return exports;
}

View File

@ -1,8 +1,11 @@
{
"name": "ai-box",
"version": "1.0.0",
"main": "index.js",
"name": "@yizhi/ai",
"version": "1.0.8",
"releaseVersion": "1.0.6",
"main": "dist/index.js",
"types": "typing/index.d.ts",
"scripts": {
"build": "rm -rf dist typing && tsc",
"watch": "tsc -w --inlineSourceMap"
},
"keywords": [],
@ -15,5 +18,8 @@
"compressing": "^1.10.1",
"node-addon-api": "^8.3.1",
"unbzip2-stream": "^1.4.3"
},
"dependencies": {
"@yizhi/cv": "^1.0.2"
}
}

View File

@ -2,22 +2,41 @@
export interface SessionNodeInfo {
name: string
type: number
type: DataType
shape: number[]
}
export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
export type DataTypeString = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
export enum DataType {
Unknown,
Float32,
Float64,
Int32,
Uint32,
Int16,
Uint16,
Int8,
Uint8,
Int64,
Uint64,
}
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
export interface SessionRunInputOption {
type?: SessionNodeType
type?: DataTypeString
data: SessionNodeData
shape?: number[]
}
export interface SessionRunOutput {
shape: number[]
data: Float32Array
}
export abstract class CommonSession {
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>>
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>>
public abstract get inputs(): Record<string, SessionNodeInfo>;
public abstract get outputs(): Record<string, SessionNodeInfo>;
@ -26,3 +45,36 @@ export abstract class CommonSession {
export function isTypedArray(val: any): val is SessionNodeData {
return val?.buffer instanceof ArrayBuffer;
}
export function dataTypeFrom(str: DataTypeString) {
return {
"float32": DataType.Float32,
"float64": DataType.Float64,
"float": DataType.Float32,
"double": DataType.Float64,
"int32": DataType.Int32,
"uint32": DataType.Uint32,
"int16": DataType.Int16,
"uint16": DataType.Uint16,
"int8": DataType.Int8,
"uint8": DataType.Uint8,
"int64": DataType.Int64,
"uint64": DataType.Uint64,
}[str] ?? DataType.Unknown;
}
export function dataTypeToString(type: DataType): DataTypeString | null {
switch (type) {
case DataType.Float32: return "float32";
case DataType.Float64: return "float64";
case DataType.Int32: return "int32";
case DataType.Uint32: return "uint32";
case DataType.Int16: return "int16";
case DataType.Uint16: return "uint16";
case DataType.Int8: return "int8";
case DataType.Uint8: return "uint8";
case DataType.Int64: return "int64";
case DataType.Uint64: return "uint64";
default: return null;
}
}
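The two helpers above map between the string names accepted in `SessionRunInputOption.type` and the numeric `DataType` codes handed to the native addons; a quick sketch of the round trip:

```typescript
import { DataType, dataTypeFrom, dataTypeToString } from "./common"; // illustrative import path

console.log(dataTypeFrom("float"));               // DataType.Float32 ("float" is an alias of "float32")
console.log(dataTypeFrom("bogus" as any));        // DataType.Unknown for unrecognized names
console.log(dataTypeToString(DataType.Int64));    // "int64"
console.log(dataTypeToString(DataType.Unknown));  // null
```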

12
src/backend/config.ts Normal file
View File

@ -0,0 +1,12 @@
import path from "path";
const defaultAddonDir = path.join(__dirname, "../../build")
const aiConfig = {
"MNN_ADDON_FILE": path.join(defaultAddonDir, "mnn.node"),
"ORT_ADDON_FILE": path.join(defaultAddonDir, "ort.node"),
};
export function setConfig<K extends keyof typeof aiConfig>(key: K, value: typeof aiConfig[K]) { aiConfig[key] = value; }
export function getConfig<K extends keyof typeof aiConfig>(key: K): typeof aiConfig[K] { return aiConfig[key]; }
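A minimal sketch of using this config module (the override path is a placeholder); both keys default to addon files under the package's build/ directory:

```typescript
import { getConfig, setConfig } from "./config"; // illustrative import path

setConfig("ORT_ADDON_FILE", "/opt/ai-addons/ort.node"); // point at a custom build
console.log(getConfig("ORT_ADDON_FILE"));               // "/opt/ai-addons/ort.node"
console.log(getConfig("MNN_ADDON_FILE"));               // still the default <package>/build/mnn.node
```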

87
src/backend/download.ts Normal file
View File

@ -0,0 +1,87 @@
import os from "os";
import fs from "fs";
import path from "path";
import { getConfig } from "./config";
const URLS = {
GITHUB: "https://github.com/kangkang520/node-addons/releases/download/ai{{version}}/{{filename}}",
XXXXX: "http://git.urnas.cn:5200/yizhi-js-lib/ai-box/releases/download/{{version}}/{{filename}}",
}
function releaseVersion() { return require("../../package.json").releaseVersion }
function getURL(backend: "ort" | "mnn", template: string) {
const URL_DICT: Record<string, Record<string, string>> = {
"win32": {
"x64": `${backend}_windows_x64.node`
},
"linux": {
"x64": `${backend}_linux_x64.node`,
"arm64": `${backend}_linux_arm64.node`,
},
"darwin": {
"arm64": `${backend}_macos_arm64.node`,
}
}
const archConfig = URL_DICT[os.platform()];
if (!archConfig) throw new Error(`Unsupported platform: ${os.platform()}`);
const downloadName = archConfig[os.arch()];
if (!downloadName) throw new Error(`Unsupported arch: ${os.arch()}`);
return template.replaceAll("{{version}}", releaseVersion()).replaceAll("{{filename}}", downloadName);
}
async function getStream(backend: "mnn" | "ort") {
for (const [name, url] of Object.entries(URLS)) {
try {
return await fetch(getURL(backend, url)).then(res => {
if (res.status != 200) throw new Error("Failed to download addon.");
return res.blob().then(b => b.stream());
})
} catch (e) { }
}
throw new Error("Failed to download addon.");
}
export async function downloadBackend(backend: "ort" | "mnn", savename?: string) {
const backendConfigNameDict = { ort: "ORT_ADDON_FILE" as const, mnn: "MNN_ADDON_FILE" as const };
const defaultAddon = path.resolve(process.cwd(), getConfig(backendConfigNameDict[backend]));
const saveName = savename ? path.resolve(path.dirname(defaultAddon), savename) : defaultAddon;
if (fs.existsSync(saveName)) {
try {
const addon = require(saveName);
if (addon.__release__ === releaseVersion()) return saveName;
// clear the require cache
delete require.cache[saveName];
} catch (err) { }
}
await fs.promises.mkdir(path.dirname(saveName), { recursive: true });
const stream = await getStream(backend);
const cacheFile = await new Promise<string>((resolve, reject) => {
const cacheFile = path.join(os.tmpdir(), Date.now() + ".cv.node");
let fsStream!: ReturnType<typeof fs.createWriteStream>;
stream.pipeTo(new WritableStream({
start(controller) {
fsStream = fs.createWriteStream(cacheFile);
},
async write(chunk, controller) {
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
},
close() {
fsStream.end();
resolve(cacheFile);
},
abort() { }
})).catch(reject);
});
if (fs.existsSync(saveName)) await fs.promises.rm(saveName, { recursive: true, force: true });
await fs.promises.cp(cacheFile, saveName);
await fs.promises.rm(cacheFile);
return saveName;
}
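A sketch tying the download tool to the deploy API; it assumes `downloadBackend` is reachable from the package's backend namespace (it is re-exported from the backend index below) and that `YOLOV5S_MNN` is one of the model keys added in the face-detection changes further down:

```typescript
import ai from "@yizhi/ai";

async function main() {
    // Downloads the prebuilt mnn.node matching package.json's releaseVersion; if the file
    // already exists and its __release__ matches, the cached addon is reused.
    const addonPath = await ai.backend.downloadBackend("mnn");
    ai.config("MNN_ADDON_FILE", addonPath);

    const facedet = await ai.deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
    const boxes = await facedet.predict("/path/to/image.jpg");
    console.log(boxes);
}

main().catch(console.error);
```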

View File

@ -1,2 +1,4 @@
export * as common from "./common";
export { SessionNodeInfo, DataTypeString, DataType, SessionNodeData, SessionRunInputOption, SessionRunOutput, CommonSession } from "./common";
export * as ort from "./ort";
export * as mnn from "./mnn";
export { downloadBackend } from "./download";

1
src/backend/mnn/index.ts Normal file
View File

@ -0,0 +1 @@
export { MNNSession as Session } from "./session";

View File

@ -0,0 +1,34 @@
import { getConfig } from "../config";
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
export class MNNSession extends CommonSession {
#session: any
#inputs: Record<string, SessionNodeInfo> | null = null;
#outputs: Record<string, SessionNodeInfo> | null = null;
public constructor(modelData: Uint8Array) {
super();
const addon = require(getConfig("MNN_ADDON_FILE"));
this.#session = new addon.MNNSession(modelData);
}
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>> {
const inputArgs: Record<string, any> = {};
for (const [name, option] of Object.entries(inputs)) {
if (isTypedArray(option)) inputArgs[name] = { data: option }
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
}
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, { data: ArrayBuffer, shape: number[] }>) => {
if (err) return reject(err);
const result: Record<string, SessionRunOutput> = {};
for (const [name, val] of Object.entries(res)) result[name] = {
shape: val.shape,
data: new Float32Array(val.data),
}
resolve(result);
}))
}
public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
}
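Unlike the ONNX Runtime wrapper in the next file, which copies `shape` from the session's static output metadata, this wrapper returns the shape reported by the output tensor itself, so resized outputs keep their actual dimensions. A usage sketch (model path and input name are placeholders):

```typescript
import fs from "fs";
import { MNNSession } from "./session"; // illustrative import path

async function demo() {
    const session = new MNNSession(await fs.promises.readFile("/path/to/model.mnn"));
    console.log(session.inputs); // lazily fetched through the addon's GetInputsInfo()
    const outputs = await session.run({
        input: { data: new Float32Array(1 * 3 * 112 * 112), shape: [1, 3, 112, 112], type: "float32" },
    });
    for (const [name, out] of Object.entries(outputs)) console.log(name, out.shape, out.data.length);
}

demo().catch(console.error);
```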

View File

@ -1,4 +1,5 @@
import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
import { getConfig } from "../config";
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
export class OrtSession extends CommonSession {
#session: any;
@ -7,7 +8,7 @@ export class OrtSession extends CommonSession {
public constructor(modelData: Uint8Array) {
super();
const addon = require("../../../build/ort.node")
const addon = require(getConfig("ORT_ADDON_FILE"));
this.#session = new addon.OrtSession(modelData);
}
@ -19,13 +20,13 @@ export class OrtSession extends CommonSession {
const inputArgs: Record<string, any> = {};
for (const [name, option] of Object.entries(inputs)) {
if (isTypedArray(option)) inputArgs[name] = { data: option }
else inputArgs[name] = option;
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
}
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
return new Promise<Record<string, SessionRunOutput>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, ArrayBuffer>) => {
if (err) return reject(err);
const result: Record<string, Float32Array> = {};
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
const result: Record<string, SessionRunOutput> = {};
for (const [name, val] of Object.entries(res)) result[name] = { data: new Float32Array(val), shape: this.outputs[name].shape };
resolve(result);
}));
}

View File

@ -1 +0,0 @@
export * as cv from "./main";

View File

@ -1 +0,0 @@
export { Mat, ImreadModes } from "./mat";

View File

@ -1,79 +0,0 @@
export enum ImreadModes {
IMREAD_UNCHANGED = -1,
IMREAD_GRAYSCALE = 0,
IMREAD_COLOR_BGR = 1,
IMREAD_COLOR = 1,
IMREAD_ANYDEPTH = 2,
IMREAD_ANYCOLOR = 4,
IMREAD_LOAD_GDAL = 8,
IMREAD_REDUCED_GRAYSCALE_2 = 16,
IMREAD_REDUCED_COLOR_2 = 17,
IMREAD_REDUCED_GRAYSCALE_4 = 32,
IMREAD_REDUCED_COLOR_4 = 33,
IMREAD_REDUCED_GRAYSCALE_8 = 64,
IMREAD_REDUCED_COLOR_8 = 65,
IMREAD_IGNORE_ORIENTATION = 128,
IMREAD_COLOR_RGB = 256,
};
interface MatConstructorOption {
mode?: ImreadModes;
}
export class Mat {
#mat: any
public static async load(image: string, option?: MatConstructorOption) {
let buffer: Uint8Array
if (/^https?:\/\//.test(image)) buffer = await fetch(image).then(res => res.arrayBuffer()).then(res => new Uint8Array(res));
else buffer = await import("fs").then(fs => fs.promises.readFile(image));
return new Mat(buffer, option);
}
public constructor(imageData: Uint8Array, option?: MatConstructorOption) {
const addon = require("../../build/cv.node");
if ((imageData as any) instanceof addon.Mat) this.#mat = imageData;
else this.#mat = new addon.Mat(imageData, option);
}
public get empty(): boolean { return this.#mat.IsEmpty() }
public get cols(): number { return this.#mat.GetCols(); }
public get rows(): number { return this.#mat.GetRows(); }
public get width() { return this.cols; }
public get height() { return this.rows; }
public get channels() { return this.#mat.GetChannels(); }
public resize(width: number, height: number) { return new Mat(this.#mat.Resize.bind(this.#mat)(width, height)); }
public crop(sx: number, sy: number, sw: number, sh: number) { return new Mat(this.#mat.Crop(sx, sy, sw, sh)); }
public rotate(sx: number, sy: number, angleDeg: number) { return new Mat(this.#mat.Rotate(sx, sy, angleDeg)); }
public get data() { return new Uint8Array(this.#mat.Data()); }
public encode(extname: string) { return new Uint8Array(this.#mat.Encode({ extname })); }
public clone() { return new Mat(this.#mat.Clone()); }
public circle(x: number, y: number, radius: number, options?: {
color?: { r: number, g: number, b: number },
thickness?: number
lineType?: number
}) {
this.#mat.DrawCircle(
x, y, radius,
options?.color?.b ?? 0,
options?.color?.g ?? 0,
options?.color?.r ?? 0,
options?.thickness ?? 1,
options?.lineType ?? 8,
0,
);
}
}

View File

@ -1,13 +1,13 @@
import { cv } from "@yizhi/cv";
import { backend } from "../../backend";
import { cv } from "../../cv";
export type ModelConstructor<T> = new (session: backend.common.CommonSession) => T;
export type ModelConstructor<T> = new (session: backend.CommonSession) => T;
export type ImageSource = cv.Mat | Uint8Array | string;
export interface ImageCropOption {
/** Image crop region */
crop?: { sx: number, sy: number, sw: number, sh: number }
crop?: { x: number, y: number, width: number, height: number }
}
export type ModelType = "onnx" | "mnn"
@ -24,15 +24,17 @@ export interface ModelCacheResult<T, Create extends boolean> {
model: Create extends true ? T : never
}
const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise<string> }> = [];
export abstract class Model {
protected session: backend.common.CommonSession;
protected session: backend.CommonSession;
protected static async resolveImage<R>(image: ImageSource, resolver: (image: cv.Mat) => R | Promise<R>): Promise<R> {
if (typeof image === "string") {
if (/^https?:\/\//.test(image)) image = await fetch(image).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
else image = await import("fs").then(fs => fs.promises.readFile(image as string));
}
if (image instanceof Uint8Array) image = new cv.Mat(image, { mode: cv.ImreadModes.IMREAD_COLOR_BGR })
if (image instanceof Uint8Array) image = cv.imdecode(image, cv.IMREAD_COLOR_BGR);
if (image instanceof cv.Mat) return await resolver(image);
else throw new Error("Invalid image");
}
@ -45,11 +47,22 @@ export abstract class Model {
return new this(new backend.ort.Session(modelData as Uint8Array));
}
public static async fromMNN<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array | string) {
if (typeof modelData === "string") {
if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string));
}
return new this(new backend.mnn.Session(modelData as Uint8Array));
}
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
// set up the cache directory
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
await fs.promises.mkdir(cacheDir, { recursive: true });
// helper that resolves the model info
async function resolveModel() {
// load the model cache config
const cacheJsonFile = path.join(cacheDir, "config.json");
let cacheJsonData: Array<{ url: string, filename: string }> = [];
@ -60,7 +73,7 @@ export abstract class Model {
}
// download if not cached
let cache = cacheJsonData.find(c => c.url === url);
if (!cache) {
if (!cache || !fs.existsSync(path.join(cacheDir, cache.filename))) {
let saveType = option?.saveType ?? null;
const saveTypeDict: Record<string, ModelType> = {
".onnx": "onnx",
@ -75,10 +88,10 @@ export abstract class Model {
return res.blob();
}).then(blob => blob.stream()).then(async stream => {
const cacheFilename = path.join(cacheDir, Date.now().toString());
const hash = await new Promise<string>((resolve, reject) => {
let fsStream!: ReturnType<typeof fs.createWriteStream>;
let hashStream!: ReturnType<typeof crypto.createHash>;
let hash!: string;
await stream.pipeTo(new WritableStream({
stream.pipeTo(new WritableStream({
start(controller) {
fsStream = fs.createWriteStream(cacheFilename);
hashStream = crypto.createHash("md5");
@ -90,33 +103,52 @@ export abstract class Model {
close() {
fsStream.end();
hashStream.end();
hash = hashStream.digest("hex")
resolve(hashStream.digest("hex"));
},
abort() { }
}));
})).catch(reject);
})
return { filename: cacheFilename, hash };
});
// rename
const filename = `${res.hash}.${saveType}`;
fs.promises.rename(res.filename, path.join(cacheDir, filename));
// save the cache entry
if (!cache) {
cache = { url, filename };
cacheJsonData.push(cache);
}
else {
cache.filename = filename;
cache.url = url;
}
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
}
// return the model data
const modelPath = path.join(cacheDir, cache.filename);
return path.join(cacheDir, cache.filename);
}
const modelType = path.extname(cache.filename).substring(1) as ModelType;
// look up an existing download task
let cache = cacheDownloadTasks.find(c => c.url === url && c.cacheDir === cacheDir);
if (!cache) {
cache = { url, cacheDir, modelPath: resolveModel() }
cacheDownloadTasks.push(cache);
}
// get the model data
const modelPath = await cache.modelPath;
const modelType = path.extname(modelPath).substring(1) as ModelType;
let model: T | undefined = undefined;
if (option?.createModel) {
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
else if (modelType == "mnn") model = (this as any).fromMNN(modelPath);
}
// return the result
return { modelPath, modelType, model: model as any }
}
public constructor(session: backend.common.CommonSession) { this.session = session; }
public constructor(session: backend.CommonSession) { this.session = session; }
public get inputs() { return this.session.inputs; }
public get outputs() { return this.session.outputs; }
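Deploy-layer models gain an MNN entry point here, and concurrent `cacheModel` calls for the same URL now await a single shared download task instead of fetching the file twice. A sketch using the face detector from the README (model paths are placeholders):

```typescript
import ai from "@yizhi/ai";

async function demo() {
    // fromMNN accepts a Uint8Array, a local file path, or an http(s) URL.
    const detector = await ai.deploy.facedet.Yolov5Face.fromMNN("/path/to/yolov5s.mnn");
    const boxes = await detector.predict("/path/to/image.jpg");
    console.log(boxes);
}

demo().catch(console.error);
```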

View File

@ -5,10 +5,41 @@ export interface FacePoint {
type PointType = "leftEye" | "rightEye" | "leftEyebrow" | "rightEyebrow" | "nose" | "mouth" | "contour"
export function indexFromTo(from: number, to: number) {
const indexes: number[] = [];
for (let i = from; i <= to; i++) {
indexes.push(i);
}
return indexes;
}
interface PointGroup {
/** Indexes of the two points used to determine direction (points at the centers of the eyeballs are recommended) */
directionPointIndex: [number, number];
/** Left-eye point indexes */
leftEyePointIndex: number[];
/** Right-eye point indexes */
rightEyePointIndex: number[];
/** Left-eyebrow point indexes */
leftEyebrowPointIndex: number[];
/** Right-eyebrow point indexes */
rightEyebrowPointIndex: number[];
/** Mouth point indexes */
mouthPointIndex: number[];
/** Nose point indexes */
nosePointIndex: number[];
/** Contour point indexes */
contourPointIndex: number[];
}
export abstract class FaceAlignmentResult {
#points: FacePoint[]
#group: PointGroup
public constructor(points: FacePoint[]) { this.#points = points; }
public constructor(points: FacePoint[], group: PointGroup) {
this.#points = points;
this.#group = group;
}
/** Key points */
public get points() { return this.#points; }
@ -18,7 +49,7 @@ export abstract class FaceAlignmentResult {
if (typeof type == "string") type = [type];
const result: FacePoint[] = [];
for (const t of type) {
for (const idx of this[`${t}PointIndex` as const]()) {
for (const idx of this.#group[`${t}PointIndex` as const]) {
result.push(this.points[idx]);
}
}
@ -27,32 +58,7 @@ export abstract class FaceAlignmentResult {
/** Direction */
public get direction() {
const [{ x: x1, y: y1 }, { x: x2, y: y2 }] = this.directionPointIndex().map(idx => this.points[idx]);
const [{ x: x1, y: y1 }, { x: x2, y: y2 }] = this.#group.directionPointIndex.map(idx => this.points[idx]);
return Math.atan2(y1 - y2, x2 - x1)
}
/** Indexes of the two points used to determine direction (points at the centers of the eyeballs are recommended) */
protected abstract directionPointIndex(): [number, number];
/** Left-eye point indexes */
protected abstract leftEyePointIndex(): number[];
/** Right-eye point indexes */
protected abstract rightEyePointIndex(): number[];
/** Left-eyebrow point indexes */
protected abstract leftEyebrowPointIndex(): number[];
/** Right-eyebrow point indexes */
protected abstract rightEyebrowPointIndex(): number[];
/** Mouth point indexes */
protected abstract mouthPointIndex(): number[];
/** Nose point indexes */
protected abstract nosePointIndex(): number[];
/** Contour point indexes */
protected abstract contourPointIndex(): number[];
protected indexFromTo(from: number, to: number) {
const indexes: number[] = [];
for (let i = from; i <= to; i++) {
indexes.push(i);
}
return indexes;
}
}

View File

@ -1,25 +1,30 @@
import { writeFileSync } from "fs";
import { cv } from "../../cv";
import cv from "@yizhi/cv";
import { ImageCropOption, ImageSource, Model } from "../common/model";
import { convertImage } from "../common/processors";
import { FaceAlignmentResult, FacePoint } from "./common";
import { FaceAlignmentResult, FacePoint, indexFromTo } from "./common";
interface FaceLandmark1000PredictOption extends ImageCropOption { }
class FaceLandmark1000Result extends FaceAlignmentResult {
protected directionPointIndex(): [number, number] { return [401, 529]; }
protected leftEyePointIndex(): number[] { return this.indexFromTo(401, 528); }
protected rightEyePointIndex(): number[] { return this.indexFromTo(529, 656); }
protected leftEyebrowPointIndex(): number[] { return this.indexFromTo(273, 336); }
protected rightEyebrowPointIndex(): number[] { return this.indexFromTo(337, 400); }
protected mouthPointIndex(): number[] { return this.indexFromTo(845, 972); }
protected nosePointIndex(): number[] { return this.indexFromTo(657, 844); }
protected contourPointIndex(): number[] { return this.indexFromTo(0, 272); }
public constructor(points: FacePoint[]) {
super(points, {
directionPointIndex: [401, 529],
leftEyePointIndex: indexFromTo(401, 528),
rightEyePointIndex: indexFromTo(529, 656),
leftEyebrowPointIndex: indexFromTo(273, 336),
rightEyebrowPointIndex: indexFromTo(337, 400),
mouthPointIndex: indexFromTo(845, 972),
nosePointIndex: indexFromTo(657, 844),
contourPointIndex: indexFromTo(0, 272),
});
}
}
const MODEL_URL_CONFIG = {
FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`,
FACELANDMARK1000_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.mnn`,
};
export class FaceLandmark1000 extends Model {
@ -32,10 +37,10 @@ export class FaceLandmark1000 extends Model {
public async doPredict(image: cv.Mat, option?: FaceLandmark1000PredictOption) {
const input = this.input;
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
const ratioWidth = image.width / input.shape[3];
const ratioHeight = image.height / input.shape[2];
image = image.resize(input.shape[3], input.shape[2]);
if (option?.crop) image = cv.crop(image, option?.crop);
const ratioWidth = image.cols / input.shape[3];
const ratioHeight = image.rows / input.shape[2];
image = cv.resize(image, input.shape[3], input.shape[2]);
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "gray", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
@ -45,12 +50,12 @@ export class FaceLandmark1000 extends Model {
data: nchwImageData,
type: "float32",
}
}).then(res => res[this.output.name]);
}).then(res => res[this.output.name].data);
const points: FacePoint[] = [];
for (let i = 0; i < res.length; i += 2) {
const x = res[i] * image.width * ratioWidth;
const y = res[i + 1] * image.height * ratioHeight;
const x = res[i] * image.cols * ratioWidth;
const y = res[i + 1] * image.rows * ratioHeight;
points.push({ x, y });
}

View File

@ -1,5 +1,4 @@
import { writeFileSync } from "fs";
import { cv } from "../../cv";
import cv from "@yizhi/cv";
import { ImageCropOption, ImageSource, Model } from "../common/model";
import { convertImage } from "../common/processors";
import { FaceAlignmentResult, FacePoint } from "./common";
@ -7,20 +6,27 @@ import { FaceAlignmentResult, FacePoint } from "./common";
export interface PFLDPredictOption extends ImageCropOption { }
class PFLDResult extends FaceAlignmentResult {
protected directionPointIndex(): [number, number] { return [36, 92]; }
protected leftEyePointIndex(): number[] { return [33, 34, 35, 36, 37, 38, 39, 40, 41, 42]; }
protected rightEyePointIndex(): number[] { return [87, 88, 89, 90, 91, 92, 93, 94, 95, 96]; }
protected leftEyebrowPointIndex(): number[] { return [43, 44, 45, 46, 47, 48, 49, 50, 51]; }
protected rightEyebrowPointIndex(): number[] { return [97, 98, 99, 100, 101, 102, 103, 104, 105]; }
protected mouthPointIndex(): number[] { return [52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]; }
protected nosePointIndex(): number[] { return [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86]; }
protected contourPointIndex(): number[] { return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]; }
public constructor(points: FacePoint[]) {
super(points, {
directionPointIndex: [36, 92],
leftEyePointIndex: [33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
rightEyePointIndex: [87, 88, 89, 90, 91, 92, 93, 94, 95, 96],
leftEyebrowPointIndex: [43, 44, 45, 46, 47, 48, 49, 50, 51],
rightEyebrowPointIndex: [97, 98, 99, 100, 101, 102, 103, 104, 105],
mouthPointIndex: [52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
nosePointIndex: [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86],
contourPointIndex: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
});
}
}
const MODEL_URL_CONFIG = {
PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`,
PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`,
PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`,
PFLD_106_LITE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.mnn`,
PFLD_106_V2_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.mnn`,
PFLD_106_V3_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.mnn`,
};
export class PFLD extends Model {
@ -32,10 +38,10 @@ export class PFLD extends Model {
private async doPredict(image: cv.Mat, option?: PFLDPredictOption) {
const input = this.input;
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
const ratioWidth = image.width / input.shape[3];
const ratioHeight = image.height / input.shape[2];
image = image.resize(input.shape[3], input.shape[2]);
if (option?.crop) image = cv.crop(image, option.crop);
const ratioWidth = image.cols / input.shape[3];
const ratioHeight = image.rows / input.shape[2];
image = cv.resize(image, input.shape[3], input.shape[2]);
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [255] } })
@ -48,12 +54,12 @@ export class PFLD extends Model {
shape: [1, 3, input.shape[2], input.shape[3]],
}
});
const pointsBuffer = res[pointsOutput.name];
const pointsBuffer = res[pointsOutput.name].data;
const points: FacePoint[] = [];
for (let i = 0; i < pointsBuffer.length; i += 2) {
const x = pointsBuffer[i] * image.width * ratioWidth;
const y = pointsBuffer[i + 1] * image.height * ratioHeight;
const x = pointsBuffer[i] * image.cols * ratioWidth;
const y = pointsBuffer[i + 1] * image.rows * ratioHeight;
points.push({ x, y });
}

View File

@ -1,4 +1,4 @@
import { cv } from "../../cv";
import cv from "@yizhi/cv";
import { ImageCropOption, ImageSource, Model } from "../common/model";
import { convertImage } from "../common/processors";
@ -12,6 +12,7 @@ export interface GenderAgePredictResult {
const MODEL_URL_CONFIG = {
INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
INSIGHT_GENDER_AGE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.mnn`,
};
export class GenderAge extends Model {
@ -24,8 +25,8 @@ export class GenderAge extends Model {
private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise<GenderAgePredictResult> {
const input = this.input;
const output = this.output;
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
image = image.resize(input.shape[3], input.shape[2]);
if (option?.crop) image = cv.crop(image, option.crop);
image = cv.resize(image, input.shape[3], input.shape[2]);
const nchwImage = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
@ -35,7 +36,7 @@ export class GenderAge extends Model {
data: nchwImage,
type: "float32",
}
}).then(res => res[output.name]);
}).then(res => res[output.name].data);
return {
gender: result[0] > result[1] ? "F" : "M",

View File

@ -1,4 +1,4 @@
import { cv } from "../../cv"
import cv from "@yizhi/cv"
import { ImageSource, Model } from "../common/model"
interface IFaceBoxConstructorOption {
@ -49,8 +49,8 @@ export class FaceBox {
return new FaceBox({
...this.#option,
x1: this.centerX - size, y1: this.centerY - size,
x2: this.centerX + size, y2: this.centerY + size,
x1: cx - size, y1: cy - size,
x2: cx + size, y2: cy + size,
});
}
}

View File

@ -1,9 +1,10 @@
import { cv } from "../../cv";
import cv from "@yizhi/cv";
import { convertImage } from "../common/processors";
import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common";
const MODEL_URL_CONFIG = {
YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`,
YOLOV5S_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.mnn`,
};
export class Yolov5Face extends FaceDetector {
@ -14,9 +15,9 @@ export class Yolov5Face extends FaceDetector {
public async doPredict(image: cv.Mat, option?: FaceDetectOption): Promise<FaceBox[]> {
const input = this.input;
const resizedImage = image.resize(input.shape[2], input.shape[3]);
const ratioWidth = image.width / resizedImage.width;
const ratioHeight = image.height / resizedImage.height;
const resizedImage = cv.resize(image, input.shape[2], input.shape[3]);
const ratioWidth = image.cols / resizedImage.cols;
const ratioHeight = image.rows / resizedImage.rows;
const nchwImageData = convertImage(resizedImage.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
const outputData = await this.session.run({ input: nchwImageData }).then(r => r.output);
@ -26,7 +27,7 @@ export class Yolov5Face extends FaceDetector {
const threshold = option?.threshold ?? 0.5;
for (let i = 0; i < outShape[1]; i++) {
const beg = i * outShape[2];
const rectData = outputData.slice(beg, beg + outShape[2]);
const rectData = outputData.data.slice(beg, beg + outShape[2]);
const x = parseInt(rectData[0] * ratioWidth as any);
const y = parseInt(rectData[1] * ratioHeight as any);
const w = parseInt(rectData[2] * ratioWidth as any);
@ -36,7 +37,7 @@ export class Yolov5Face extends FaceDetector {
faces.push(new FaceBox({
x1: x - w / 2, y1: y - h / 2,
x2: x + w / 2, y2: y + h / 2,
score, imw: image.width, imh: image.height,
score, imw: image.cols, imh: image.rows,
}))
}
return nms(faces, option?.mnsThreshold ?? 0.3).map(box => box.toInt());
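The loop above decodes each raw output row back into source-image coordinates: center x/y and width/height are scaled by ratioWidth/ratioHeight, converted to corner coordinates, and the surviving boxes are passed through NMS. A hedged sketch of that decode step in isolation; the assumed row layout is [cx, cy, w, h, score, ...]:

// ratioWidth/ratioHeight are (original size / resized input size), as computed above.
function decodeBox(row: Float32Array, ratioWidth: number, ratioHeight: number) {
    const cx = row[0] * ratioWidth;
    const cy = row[1] * ratioHeight;
    const w = row[2] * ratioWidth;
    const h = row[3] * ratioHeight;
    return { x1: cx - w / 2, y1: cy - h / 2, x2: cx + w / 2, y2: cy + h / 2, score: row[4] };
}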

View File

@ -1,9 +1,10 @@
import { Mat } from "../../cv/mat";
import cv from "@yizhi/cv";
import { convertImage } from "../common/processors";
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
const MODEL_URL_CONFIG = {
MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`,
MOBILEFACENET_ADAFACE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.mnn`,
};
export class AdaFace extends FaceRecognition {
@ -12,12 +13,12 @@ export class AdaFace extends FaceRecognition {
return this.cacheModel(MODEL_URL_CONFIG[type ?? "MOBILEFACENET_ADAFACE_ONNX"], { createModel: true }).then(r => r.model);
}
public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
public async doPredict(image: cv.Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
const input = this.input;
const output = this.output;
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
image = image.resize(input.shape[3], input.shape[2]);
if (option?.crop) image = cv.crop(image, option.crop);
image = cv.resize(image, input.shape[3], input.shape[2]);
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
@ -27,7 +28,7 @@ export class AdaFace extends FaceRecognition {
data: nchwImageData,
shape: [1, 3, input.shape[2], input.shape[3]],
}
}).then(res => res[output.name]);
}).then(res => res[output.name].data);
return new Array(...embedding);
}

View File

@ -1,4 +1,4 @@
import { cv } from "../../cv";
import cv from "@yizhi/cv";
import { ImageCropOption, ImageSource, Model } from "../common/model";
export interface FaceRecognitionPredictOption extends ImageCropOption { }

View File

@ -1,4 +1,4 @@
import { Mat } from "../../cv/mat";
import cv from "@yizhi/cv";
import { convertImage } from "../common/processors";
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
@ -7,26 +7,33 @@ const MODEL_URL_CONFIG_ARC_FACE = {
INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`,
INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`,
INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`,
INSIGHTFACE_ARCFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.mnn`,
INSIGHTFACE_ARCFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.mnn`,
INSIGHTFACE_ARCFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.mnn`,
};
const MODEL_URL_CONFIG_COS_FACE = {
INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`,
INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`,
INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`,
INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`,
INSIGHTFACE_COSFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.mnn`,
INSIGHTFACE_COSFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.mnn`,
INSIGHTFACE_COSFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.mnn`,
};
const MODEL_URL_CONFIG_PARTIAL_FC = {
INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`,
INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`,
INSIGHTFACE_PARTIALFC_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.mnn`,
};
export class Insightface extends FaceRecognition {
public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
public async doPredict(image: cv.Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
const input = this.input;
const output = this.output;
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
image = image.resize(input.shape[3], input.shape[2]);
if (option?.crop) image = cv.crop(image, option.crop);
image = cv.resize(image, input.shape[3], input.shape[2]);
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
const embedding = await this.session.run({
@ -37,7 +44,7 @@ export class Insightface extends FaceRecognition {
}
}).then(res => res[output.name]);
return new Array(...embedding);
return new Array(...embedding.data);
}
}
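doPredict now returns the embedding as a plain number[] taken from res[output.name].data. Comparing two such embeddings is typically done with cosine similarity; a minimal sketch, not part of this changeset:

// Cosine similarity between two face embeddings; values near 1 indicate the same identity.
function cosineSimilarity(a: number[], b: number[]): number {
    let dot = 0, na = 0, nb = 0;
    for (let i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        na += a[i] * a[i];
        nb += b[i] * b[i];
    }
    return dot / (Math.sqrt(na) * Math.sqrt(nb));
}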

View File

@ -0,0 +1,4 @@
import * as ai from "./main";
export default ai;
export { ai };

View File

@ -0,0 +1,3 @@
export { deploy } from "./deploy";
export { backend } from "./backend";
export { setConfig as config } from "./backend/config";
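The new entry point re-exports deploy, backend, and the config setter under a single ai namespace. A hedged usage sketch; the package name in the import is an assumption, while downloadBackend("ort") and the Yolov5Face loader both appear elsewhere in this changeset:

import ai from "@yizhi/ai";   // assumed published name

async function bootstrap() {
    await ai.backend.downloadBackend("ort");   // fetch the prebuilt ONNXRuntime addon
    const detector = await ai.deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
    return detector;
}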

View File

@ -1,9 +1,21 @@
import fs from "fs";
import cv from "@yizhi/cv";
import { deploy } from "./deploy";
import { cv } from "./cv";
import { faceidTestData } from "./test_data/faceid";
import path from "path";
import crypto from "crypto";
import ai from ".";
cv.config("ADDON_PATH", path.join(__dirname, "../build/cv.node"));
function loadImage(url: string) {
if (/https?:\/\//.test(url)) return fetch(url).then(res => res.arrayBuffer()).then(data => cv.imdecode(new Uint8Array(data)));
else return import("fs").then(fs => fs.promises.readFile(url)).then(buffer => cv.imdecode(buffer));
}
function rotate(im: cv.Mat, x: number, y: number, angle: number) {
return cv.warpAffine(im, cv.getRotationMatrix2D(x, y, angle, 1), im.cols, im.rows);
}
async function cacheImage(group: string, url: string) {
const _url = new URL(url);
@ -30,31 +42,33 @@ async function cacheImage(group: string, url: string) {
}
async function testGenderTest() {
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_ONNX");
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_MNN");
const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
const image = await loadImage("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
const boxes = await facedet.predict(image);
if (!boxes.length) return console.error("No face detected");
for (const [idx, box] of boxes.entries()) {
const res = await detector.predict(image, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
const res = await detector.predict(image, { crop: { x: box.left, y: box.top, width: box.width, height: box.height } });
console.log(`[${idx + 1}]`, res);
}
}
async function testFaceID() {
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
const faceid = await deploy.faceid.PartialFC.load();
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
console.log("初始化模型")
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
const faceid = await deploy.faceid.CosFace.load("INSIGHTFACE_COSFACE_R50_MNN");
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
console.log("初始化模型完成")
const { basic, tests } = faceidTestData.stars;
console.log("正在加载图片资源");
const basicImage = await cv.Mat.load(await cacheImage("faceid", basic.image));
const basicImage = await loadImage(await cacheImage("faceid", basic.image));
const testsImages: Record<string, cv.Mat[]> = {};
for (const [name, imgs] of Object.entries(tests)) {
testsImages[name] = await Promise.all(imgs.map(img => cacheImage("faceid", img).then(img => cv.Mat.load(img))));
testsImages[name] = await Promise.all(imgs.map(img => cacheImage("faceid", img).then(img => loadImage(img))));
}
console.log("正在检测基本数据");
@ -66,9 +80,9 @@ async function testFaceID() {
async function getEmbd(image: cv.Mat, box: deploy.facedet.FaceBox) {
box = box.toSquare();
const alignResult = await facealign.predict(image, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
let faceImage = image.rotate(box.centerX, box.centerY, -alignResult.direction * 180 / Math.PI);
return faceid.predict(faceImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
const alignResult = await facealign.predict(image, { crop: { x: box.left, y: box.top, width: box.width, height: box.height } });
let faceImage = rotate(image, box.centerX, box.centerY, -alignResult.direction * 180 / Math.PI);
return faceid.predict(faceImage, { crop: { x: box.left, y: box.top, width: box.width, height: box.height } });
}
const basicEmbds: number[][] = [];
@ -111,36 +125,40 @@ async function testFaceID() {
async function testFaceAlign() {
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
image = image.rotate(image.width / 2, image.height / 2, 0);
let image = await loadImage("https://i0.hdslb.com/bfs/archive/64e47ec9fdac9e24bc2b49b5aaad5560da1bfe3e.jpg");
image = rotate(image, image.cols / 2, image.rows / 2, 0);
const face = await fd.predict(image).then(res => res[0].toSquare());
const points = await fa.predict(image, { crop: { sx: face.left, sy: face.top, sw: face.width, sh: face.height } });
const face = await fd.predict(image).then(res => {
console.log(res);
return res[0].toSquare()
});
const points = await fa.predict(image, { crop: { x: face.left, y: face.top, width: face.width, height: face.height } });
points.points.forEach((point, idx) => {
image.circle(face.left + point.x, face.top + point.y, 2);
cv.circle(image, face.left + point.x, face.top + point.y, 2, 0, 0, 0);
})
// const point = points.getPointsOf("rightEye")[1];
// image.circle(face.left + point.x, face.top + point.y, 2);
fs.writeFileSync("testdata/xx.jpg", image.encode(".jpg"));
// fs.writeFileSync("testdata/xx.jpg", image.encode(".jpg"));
cv.imwrite("testdata/xx.jpg", image);
let faceImage = image.rotate(face.centerX, face.centerY, -points.direction * 180 / Math.PI);
faceImage = faceImage.crop(face.left, face.top, face.width, face.height);
fs.writeFileSync("testdata/face.jpg", faceImage.encode(".jpg"));
let faceImage = rotate(image, face.centerX, face.centerY, -points.direction * 180 / Math.PI);
faceImage = cv.crop(faceImage, { x: face.left, y: face.top, width: face.width, height: face.height });
fs.writeFileSync("testdata/face.jpg", cv.imencode(".jpg", faceImage)!);
console.log(points);
console.log(points.direction * 180 / Math.PI);
}
async function test() {
await testGenderTest();
await testFaceID();
await testFaceAlign();
await ai.backend.downloadBackend("ort");
// await testGenderTest();
// await testFaceID();
// await testFaceAlign();
}
test().catch(err => {
console.error(err);
debugger
});
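Throughout test.ts the crop option moves from { sx, sy, sw, sh } to the { x, y, width, height } shape expected by the updated model APIs. A small hedged helper capturing that mapping from a detected FaceBox; the function itself is illustrative and not part of the changeset:

import { deploy } from "./deploy";

// Build the new-style crop option from a FaceBox returned by the detector.
function cropOf(box: deploy.facedet.FaceBox) {
    return { x: box.left, y: box.top, width: box.width, height: box.height };
}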

View File

@ -1,33 +0,0 @@
export namespace utils {
export function rgba2rgb<T extends Uint8Array | Float32Array>(data: T): T {
const pixelCount = data.length / 4;
const result = new (data.constructor as any)(pixelCount * 3) as T;
for (let i = 0; i < pixelCount; i++) {
result[i * 3 + 0] = data[i * 4 + 0];
result[i * 3 + 1] = data[i * 4 + 1];
result[i * 3 + 2] = data[i * 4 + 2];
}
return result;
}
export function rgb2bgr<T extends Uint8Array | Float32Array>(data: T): T {
const pixelCount = data.length / 3;
const result = new (data.constructor as any)(pixelCount * 3) as T;
for (let i = 0; i < pixelCount; i++) {
result[i * 3 + 0] = data[i * 3 + 2];
result[i * 3 + 1] = data[i * 3 + 1];
result[i * 3 + 2] = data[i * 3 + 0];
}
return result;
}
export function normalize(data: Uint8Array | Float32Array, mean: number[], std: number[]): Float32Array {
const result = new Float32Array(data.length);
for (let i = 0; i < data.length; i++) {
result[i] = (data[i] - mean[i % mean.length]) / std[i % std.length];
}
return result;
}
}
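These helpers are removed because channel reordering and mean/std normalization are handled by convertImage in common/processors; with mean 127.5 and std 127.5 an 8-bit pixel value v maps to (v - 127.5) / 127.5, i.e. roughly [-1, 1]. A one-function sketch matching the calls already used by the model files (the import path assumes the same location those files use):

import cv from "@yizhi/cv";
import { convertImage } from "../common/processors";

// BGR Mat -> RGB, NCHW float32 tensor normalized to about [-1, 1].
function toNchwTensor(image: cv.Mat) {
    return convertImage(image.data, {
        sourceImageFormat: "bgr",
        targetColorFormat: "rgb",
        targetShapeFormat: "nchw",
        targetNormalize: { mean: [127.5], std: [127.5] },
    });
}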

View File

@ -16,3 +16,4 @@ if (args.includes("--include")) runCmakeJS(["print-cmakejs-include"]);
else if (args.includes("--src")) runCmakeJS(["print-cmakejs-src"]);
else if (args.includes("--lib")) runCmakeJS(["print-cmakejs-lib"]);
else if (args.includes("--napi")) console.log(require("node-addon-api").include.replace(/^"/, "").replace(/"$/, ""));
else if (args.includes("--release")) console.log(require("../package.json").releaseVersion);

View File

@ -28,7 +28,6 @@ function assert(cond, message) {
const buildOptions = {
withMNN: findArg("with-mnn", true) ?? false,
withOpenCV: findArg("with-opencv", true) ?? false,
withONNX: findArg("with-onnx", true) ?? false,
buildType: findArg("build-type", false) ?? "Release",
proxy: findArg("proxy"),
@ -44,6 +43,8 @@ const spawnOption = {
}
};
function P(path) { return path.replace(/\\/g, "/"); }
function checkFile(...items) {
return fs.existsSync(path.resolve(...items));
}
@ -105,9 +106,6 @@ async function downloadFromURL(name, url, resolver) {
if (!checkFile(saveName)) {
console.log(`Starting download of ${name}, URL: ${url}`);
await fetch(url).then(res => {
console.log(res.status)
})
const result = spawnSync("curl", ["-o", saveName + ".cache", "-L", url, "-s", "-w", "%{http_code}"], { ...spawnOption, stdio: "pipe" });
assert(result.status == 0 && result.stdout.toString() == "200", `Failed to download ${name}`);
fs.renameSync(saveName + ".cache", saveName);
@ -151,37 +149,16 @@ async function main() {
"-DMNN_AVX512=ON",
"-DMNN_BUILD_TOOLS=ON",
"-DMNN_BUILD_CONVERTER=OFF",
"-DMNN_WIN_RUNTIME_MT=OFF",
"-DMNN_WIN_RUNTIME_MT=ON",
], (root) => [
`set(MNN_INCLUDE_DIR ${JSON.stringify(path.join(root, "include"))})`,
`set(MNN_LIB_DIR ${JSON.stringify(path.join(root, "lib"))})`,
`set(MNN_LIBS MNN)`,
`set(MNN_INCLUDE_DIR ${JSON.stringify(P(path.join(root, "include")))})`,
`set(MNN_LIB_DIR ${JSON.stringify(P(path.join(root, "lib")))})`,
`if(WIN32)`,
` set(MNN_LIBS \${MNN_LIB_DIR}/MNN.lib)`,
`else()`,
` set(MNN_LIBS \${MNN_LIB_DIR}/libMNN.a)`,
`endif()`,
].join("\n"));
//OpenCV
if (buildOptions.withOpenCV) cmakeBuildFromSource("OpenCV", "https://github.com/opencv/opencv.git", "4.11.0", null, [
"-DBUILD_SHARED_LIBS=OFF",
"-DBUILD_opencv_apps=OFF",
"-DBUILD_opencv_js=OFF",
"-DBUILD_opencv_python2=OFF",
"-DBUILD_opencv_python3=OFF",
"-DBUILD_ANDROID_PROJECTS=OFF",
"-DBUILD_ANDROID_EXAMPLES=OFF",
"-DBUILD_TESTS=OFF",
"-DBUILD_FAT_JAVA_LIB=OFF",
"-DBUILD_ANDROID_SERVICE=OFF",
"-DBUILD_JAVA=OFF",
"-DBUILD_PERF_TESTS=OFF"
], (root) => [
`set(OpenCV_STATIC ON)`,
os.platform() == "win32" ?
`include(${JSON.stringify(path.join(root, "OpenCVConfig.cmake"))})` :
`include(${JSON.stringify(path.join(root, "lib/cmake/opencv4/OpenCVConfig.cmake"))})`,
`set(OpenCV_INCLUDE_DIR \${OpenCV_INCLUDE_DIRS})`,
os.platform() == "win32" ?
"set(OpenCV_LIB_DIR ${OpenCV_LIB_PATH})" :
`set(OpenCV_LIB_DIR ${JSON.stringify(path.join(root, "lib"))})`,
// `set(OpenCV_LIBS OpenCV_LIBS)`,
].join("\n"))
//ONNXRuntime
if (buildOptions.withONNX && !checkFile(THIRDPARTY_DIR, "ONNXRuntime/config.cmake")) {
let url = "";
@ -229,9 +206,13 @@ async function main() {
fs.cpSync(path.join(savedir, dirname), path.join(THIRDPARTY_DIR, "ONNXRuntime"), { recursive: true });
});
fs.writeFileSync(path.join(THIRDPARTY_DIR, "ONNXRuntime/config.cmake"), [
`set(ONNXRuntime_INCLUDE_DIR ${JSON.stringify(path.join(THIRDPARTY_DIR, "ONNXRuntime/include"))})`,
`set(ONNXRuntime_LIB_DIR ${JSON.stringify(path.join(THIRDPARTY_DIR, "ONNXRuntime/lib"))})`,
`set(ONNXRuntime_LIBS onnxruntime)`,
`set(ONNXRuntime_INCLUDE_DIR ${JSON.stringify(P(path.join(THIRDPARTY_DIR, "ONNXRuntime/include")))})`,
`set(ONNXRuntime_LIB_DIR ${JSON.stringify(P(path.join(THIRDPARTY_DIR, "ONNXRuntime/lib")))})`,
`if(WIN32)`,
` set(ONNXRuntime_LIBS \${ONNXRuntime_LIB_DIR}/onnxruntime.lib)`,
`else()`,
` set(ONNXRuntime_LIBS \${ONNXRuntime_LIB_DIR}/libonnxruntime.a)`,
`endif()`,
].join("\n"));
}
// if (buildOptions.withONNX) cmakeBuildFromSource("ONNXRuntime", "https://github.com/csukuangfj/onnxruntime-build.git", "main", (name, repo, branch) => {

36
tool/convert_mnn.js Normal file
View File

@ -0,0 +1,36 @@
const fs = require("fs");
const path = require("path");
const { spawnSync } = require("child_process");
const CONVERTER_PATH = "C:/Develop/Libs/mnn/bin/MNNConvert.exe";
const MODEL_PATH = path.join(__dirname, "../models");
function onnx2mnn(model) {
const name = path.basename(model, ".onnx");
const outputFile = path.join(path.dirname(model), name + ".mnn");
if (fs.existsSync(outputFile)) console.log(`Skip ${name}`);
else {
const result = spawnSync(CONVERTER_PATH, [
"-f", "ONNX",
"--modelFile", model,
"--MNNModel", outputFile,
"--bizCode", "biz"
]);
if (result.status !== 0) {
console.error(`Failed to convert ${name}`);
console.log(result.stdout.toString());
}
else console.log(`Convert ${name} success`);
}
}
function resolveDir(dir) {
const files = fs.readdirSync(dir);
for (const file of files) {
const filepath = path.join(dir, file);
if (fs.statSync(filepath).isDirectory()) resolveDir(filepath);
else if (path.extname(file) == ".onnx") onnx2mnn(filepath);
}
}
resolveDir(MODEL_PATH)
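tool/convert_mnn.js walks the models directory recursively and converts every .onnx file it finds, skipping models that already have an .mnn counterpart; CONVERTER_PATH is a machine-specific absolute path to MNNConvert. A hedged one-off equivalent for a single model, assuming MNNConvert is on PATH and using an illustrative model path:

import { spawnSync } from "child_process";

// Same converter flags as the tool above, applied to one model.
const result = spawnSync("MNNConvert", [
    "-f", "ONNX",
    "--modelFile", "models/facedet/yolov5s.onnx",   // illustrative path
    "--MNNModel", "models/facedet/yolov5s.mnn",
    "--bizCode", "biz",
], { stdio: "inherit" });
if (result.status !== 0) throw new Error("MNNConvert failed");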