diff --git a/CMakeLists.txt b/CMakeLists.txt index e648432..a89b960 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,7 @@ if(NOT DEFINED CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() - +set(NODE_COMMON_SOURCES) set(NODE_ADDON_FOUND OFF) include_directories(cxx) @@ -74,10 +74,15 @@ if(EXISTS ${MNN_CMAKE_FILE}) message(STATUS "MNN_LIB_DIR: ${MNN_LIB_DIR}") message(STATUS "MNN_INCLUDE_DIR: ${MNN_INCLUDE_DIR}") message(STATUS "MNN_LIBS: ${MNN_LIBS}") - include_directories(${MNN_INCLUDE_DIRS}) + include_directories(${MNN_INCLUDE_DIR}) link_directories(${MNN_LIB_DIR}) - add_compile_definitions(USE_MNN) - set(USE_MNN ON) + + if(NODE_ADDON_FOUND) + add_node_targert(mnn cxx/mnn/node.cc) + target_link_libraries(mnn ${MNN_LIBS}) + target_compile_definitions(mnn PUBLIC USE_MNN) + list(APPEND NODE_COMMON_SOURCES cxx/mnn/node.cc) + endif() endif() # OpenCV @@ -94,6 +99,7 @@ if(EXISTS ${OpenCV_CMAKE_FILE}) add_node_targert(cv cxx/cv/node.cc) target_link_libraries(cv ${OpenCV_LIBS}) target_compile_definitions(cv PUBLIC USE_OPENCV) + list(APPEND NODE_COMMON_SOURCES cxx/cv/node.cc) endif() endif() @@ -111,12 +117,34 @@ if(EXISTS ${ONNXRuntime_CMAKE_FILE}) add_node_targert(ort cxx/ort/node.cc) target_link_libraries(ort ${ONNXRuntime_LIBS}) target_compile_definitions(ort PUBLIC USE_ONNXRUNTIME) + list(APPEND NODE_COMMON_SOURCES cxx/ort/node.cc) endif() endif() +# 统一的NodeJS插件 +if(NODE_ADDON_FOUND) + add_node_targert(addon cxx/node.cc) + target_sources(addon PRIVATE ${NODE_COMMON_SOURCES}) + target_compile_definitions(addon PUBLIC BUILD_MAIN_WORD) + # MNN + if(EXISTS ${MNN_CMAKE_FILE}) + target_link_libraries(addon ${MNN_LIBS}) + target_compile_definitions(addon PUBLIC USE_MNN) + endif() + # OnnxRuntime + if(EXISTS ${ONNXRuntime_CMAKE_FILE}) + target_link_libraries(addon ${ONNXRuntime_LIBS}) + target_compile_definitions(addon PUBLIC USE_ONNXRUNTIME) + endif() + # OpenCV + if(EXISTS ${OpenCV_CMAKE_FILE}) + target_link_libraries(addon ${OpenCV_LIBS}) + target_compile_definitions(addon PUBLIC USE_OPENCV) + endif() +endif() -if(MSVC) +if(MSVC AND NODE_ADDON_FOUND) set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT") execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS}) endif() diff --git a/cxx/common/node.h b/cxx/common/node.h index 2cdaec2..30d600b 100644 --- a/cxx/common/node.h +++ b/cxx/common/node.h @@ -6,8 +6,7 @@ #include #define NODE_INIT_OBJECT(name, function) \ - do \ - { \ + do { \ auto obj = Napi::Object::New(env); \ function(env, obj); \ exports.Set(Napi::String::New(env, #name), obj); \ @@ -21,4 +20,13 @@ inline uint64_t __node_ptr_of__(Napi::Value value) #define NODE_PTR_OF(type, value) (reinterpret_cast(__node_ptr_of__(value))) + +inline void *dataFromTypedArray(const Napi::Value &val, size_t &bytes) +{ + auto arr = val.As(); + auto data = static_cast(arr.ArrayBuffer().Data()); + bytes = arr.ByteLength(); + return static_cast(data + arr.ByteOffset()); +} + #endif diff --git a/cxx/common/tensor.h b/cxx/common/tensor.h new file mode 100644 index 0000000..901f096 --- /dev/null +++ b/cxx/common/tensor.h @@ -0,0 +1,18 @@ +#ifndef __COMMON_TENSOR_H__ +#define __COMMON_TENSOR_H__ + +enum class TensorDataType { + Unknown, + Float32, + Float64, + Int32, + Uint32, + Int16, + Uint16, + Int8, + Uint8, + Int64, + Uint64, +}; + +#endif \ No newline at end of file diff --git a/cxx/cv/node.cc b/cxx/cv/node.cc index d49319f..daec886 100644 --- a/cxx/cv/node.cc +++ b/cxx/cv/node.cc @@ -134,7 +134,7 @@ void InstallOpenCVAPI(Env env, Object exports) CVMat::Init(env, exports); } -#ifdef USE_OPENCV +#if defined(USE_OPENCV) && not defined(BUILD_MAIN_WORD) static Object Init(Env env, Object exports) { InstallOpenCVAPI(env, exports); diff --git a/cxx/mnn/node.cc b/cxx/mnn/node.cc new file mode 100644 index 0000000..8612264 --- /dev/null +++ b/cxx/mnn/node.cc @@ -0,0 +1,233 @@ +#include +#include +#include +#include +#include +#include +#include "common/tensor.h" +#include "node.h" + +using namespace Napi; + +#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&MNNSession::method>(#method, static_cast(napi_writable | napi_configurable)) + +static const std::map DATA_TYPE_MAP = { + {TensorDataType::Float32, halide_type_of()}, + {TensorDataType::Float64, halide_type_of()}, + {TensorDataType::Int32, halide_type_of()}, + {TensorDataType::Uint32, halide_type_of()}, + {TensorDataType::Int16, halide_type_of()}, + {TensorDataType::Uint16, halide_type_of()}, + {TensorDataType::Int8, halide_type_of()}, + {TensorDataType::Uint8, halide_type_of()}, + {TensorDataType::Int64, halide_type_of()}, + {TensorDataType::Uint64, halide_type_of()}, +}; + +static size_t getShapeSize(const std::vector &shape) +{ + if (!shape.size()) return 0; + size_t sum = 1; + for (auto i : shape) { + if (i > 1) sum *= i; + }; + return sum; +} + +class MNNSessionRunWorker : public AsyncWorker { + public: + MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session) + : AsyncWorker(callback), interpreter_(interpreter), session_(session) {} + + ~MNNSessionRunWorker() + { + interpreter_->releaseSession(session_); + } + + void Execute() + { + interpreter_->resizeSession(session_); + if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) { + SetError(std::string("Run session failed")); + } + } + + void OnOK() + { + if (HasError()) { + Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()}); + } + else { + auto result = Object::New(Env()); + for (auto it : interpreter_->getSessionOutputAll(session_)) { + auto tensor = it.second; + auto buffer = ArrayBuffer::New(Env(), tensor->size()); + memcpy(buffer.Data(), tensor->host(), tensor->size()); + result.Set(it.first, buffer); + } + Callback().Call({Env().Undefined(), result}); + } + } + + void SetInput(const std::string &name, TensorDataType dataType, const std::vector &shape, void *data, size_t dataBytes) + { + auto tensor = interpreter_->getSessionInput(session_, name.c_str()); + if (!tensor) { + SetError(std::string("input name #" + name + " not exists")); + return; + } + + halide_type_t type = tensor->getType(); + if (dataType != TensorDataType::Unknown) { + auto it = DATA_TYPE_MAP.find(dataType); + if (it != DATA_TYPE_MAP.end()) type = it->second; + } + + if (shape.size()) interpreter_->resizeTensor(tensor, shape); + + auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8; + if (tensorBytes != dataBytes) { + SetError(std::string("input name #" + name + " data size not matched")); + return; + } + + auto hostTensor = MNN::Tensor::create(tensor->shape(), type, data, MNN::Tensor::CAFFE); + tensor->copyFromHostTensor(hostTensor); + delete hostTensor; + } + + inline void SetError(const std::string &what) { errorMessage_ = what; } + inline bool HasError() { return errorMessage_.size() > 0; } + + private: + MNN::Interpreter *interpreter_; + MNN::Session *session_; + std::string errorMessage_; +}; + +class MNNSession : public ObjectWrap { + public: + static Napi::Object Init(Napi::Env env, Napi::Object exports) + { + Function func = DefineClass(env, "MNNSession", { + SESSION_INSTANCE_METHOD(GetInputsInfo), + SESSION_INSTANCE_METHOD(GetOutputsInfo), + SESSION_INSTANCE_METHOD(Run), + }); + FunctionReference *constructor = new FunctionReference(); + *constructor = Napi::Persistent(func); + exports.Set("MNNSession", func); + env.SetInstanceData(constructor); + return exports; + } + + MNNSession(const CallbackInfo &info) + : ObjectWrap(info) + { + try { + if (info[0].IsString()) { + interpreter_ = MNN::Interpreter::createFromFile(info[0].As().Utf8Value().c_str()); + } + else if (info[0].IsTypedArray()) { + size_t bufferBytes; + auto buffer = dataFromTypedArray(info[0], bufferBytes); + interpreter_ = MNN::Interpreter::createFromBuffer(buffer, bufferBytes); + } + else interpreter_ = nullptr; + + if (interpreter_) { + backendConfig_.precision = MNN::BackendConfig::Precision_High; + backendConfig_.power = MNN::BackendConfig::Power_High; + scheduleConfig_.type = MNN_FORWARD_CPU; + scheduleConfig_.numThread = 1; + scheduleConfig_.backendConfig = &backendConfig_; + session_ = interpreter_->createSession(scheduleConfig_); + } + else session_ = nullptr; + } + catch (std::exception &e) { + Error::New(info.Env(), e.what()).ThrowAsJavaScriptException(); + } + } + + ~MNNSession() {} + + Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionInputAll(session_)); } + + Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionOutputAll(session_)); } + + Napi::Value Run(const Napi::CallbackInfo &info) + { + auto worker = new MNNSessionRunWorker(info[1].As(), interpreter_, interpreter_->createSession(scheduleConfig_)); + auto inputArgument = info[0].As(); + for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) { + auto name = (*it).first.As().Utf8Value(); + auto inputOption = static_cast((*it).second).As(); + auto type = inputOption.Has("type") ? static_cast(inputOption.Get("type").As().Int32Value()) : TensorDataType::Unknown; + size_t dataByteLen; + void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen); + auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As()) : std::vector(); + worker->SetInput(name, type, shape, data, dataByteLen); + } + worker->Queue(); + return info.Env().Undefined(); + } + + private: + Napi::Object BuildInputOutputInfo(Napi::Env env, const std::map &tensors) + { + auto result = Object::New(env); + for (auto it : tensors) { + auto item = Object::New(env); + auto name = it.first; + auto shape = it.second->shape(); + auto type = it.second->getType(); + TensorDataType dataType = TensorDataType::Unknown; + for (auto dt : DATA_TYPE_MAP) { + if (dt.second == type) { + dataType = dt.first; + break; + } + } + auto shapeArr = Array::New(env, shape.size()); + for (size_t i = 0; i < shape.size(); i++) { + shapeArr.Set(i, Number::New(env, shape[i])); + } + item.Set("name", String::New(env, name)); + item.Set("shape", shapeArr); + item.Set("type", Number::New(env, static_cast(dataType))); + result.Set(name, item); + } + return result; + } + + std::vector GetShapeFromJavascript(const Napi::Array &shape) + { + std::vector result; + for (size_t i = 0; i < shape.Length(); i++) { + result.push_back(shape.Get(i).As().Int32Value()); + } + return result; + } + + private: + MNN::Interpreter *interpreter_; + MNN::Session *session_; + MNN::BackendConfig backendConfig_; + MNN::ScheduleConfig scheduleConfig_; +}; + +void InstallMNNAPI(Napi::Env env, Napi::Object exports) +{ + MNNSession::Init(env, exports); +} + + +#if defined(USE_MNN) && not defined(BUILD_MAIN_WORD) +static Object Init(Env env, Object exports) +{ + InstallMNNAPI(env, exports); + return exports; +} +NODE_API_MODULE(addon, Init) +#endif diff --git a/cxx/mnn/node.h b/cxx/mnn/node.h new file mode 100644 index 0000000..b30158b --- /dev/null +++ b/cxx/mnn/node.h @@ -0,0 +1,8 @@ +#ifndef __MNN_NODE_H__ +#define __MNN_NODE_H__ + +#include "common/node.h" + +void InstallMNNAPI(Napi::Env env, Napi::Object exports); + +#endif diff --git a/cxx/mnn/session.h b/cxx/mnn/session.h deleted file mode 100644 index deb1fb0..0000000 --- a/cxx/mnn/session.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef __MNN_SESSION_H__ -#define __MNN_SESSION_H__ - -#include -#include - -#include "common/session.h" - -namespace ai -{ - class MNNSession : public Session - { - public: - MNNSession(const void *modelData, size_t size); - ~MNNSession(); - }; -} - -#endif diff --git a/cxx/node.cc b/cxx/node.cc index afded5d..b502d65 100644 --- a/cxx/node.cc +++ b/cxx/node.cc @@ -1,65 +1,28 @@ -// #include -// #include -// #include "cv/node.h" -// #ifdef USE_ORT -// #include "ort/node.h" -// #endif +#include "common/node.h" +#include "cv/node.h" +#include "mnn/node.h" +#include "ort/node.h" -// using namespace Napi; +using namespace Napi; -// class TestWork : public AsyncWorker -// { -// public: -// TestWork(const Napi::Function &callback, int value) : Napi::AsyncWorker(callback), val_(value) {} -// ~TestWork() {} +#if defined(BUILD_MAIN_WORD) +Object Init(Env env, Object exports) +{ +// OpenCV +#ifdef USE_OPENCV + printf("use opencv\n"); + InstallOpenCVAPI(env, exports); +#endif +// OnnxRuntime +#ifdef USE_ONNXRUNTIME + InstallOrtAPI(env, exports); +#endif +// MNN +#ifdef USE_MNN + InstallMNNAPI(env, exports); +#endif -// void Execute() -// { -// printf("the worker-thread doing! %d \n", val_); -// sleep(3); -// printf("the worker-thread done! %d \n", val_); -// } - -// void OnOK() -// { -// Callback().Call({Env().Undefined(), Number::New(Env(), 0)}); -// } - -// private: -// int val_; -// }; - -// Value test(const CallbackInfo &info) -// { -// // ai::ORTSession(nullptr, 0); - -// // Function callback = info[1].As(); -// // TestWork *work = new TestWork(callback, info[0].As().Int32Value()); -// // work->Queue(); -// return info.Env().Undefined(); -// } - -// Object Init(Env env, Object exports) -// { -// //OpenCV -// NODE_INIT_OBJECT(cv, InstallOpenCVAPI); -// //OnnxRuntime -// #ifdef USE_ORT -// NODE_INIT_OBJECT(ort, InstallOrtAPI); -// #endif - -// Napi::Number::New(env, 0); - -// #define ADD_FUNCTION(name) exports.Set(Napi::String::New(env, #name), Napi::Function::New(env, name)) -// // ADD_FUNCTION(facedetPredict); -// // ADD_FUNCTION(facedetRelease); - -// // ADD_FUNCTION(faceRecognitionCreate); -// // ADD_FUNCTION(faceRecognitionPredict); -// // ADD_FUNCTION(faceRecognitionRelease); - -// // ADD_FUNCTION(getDistance); -// #undef ADD_FUNCTION -// return exports; -// } -// NODE_API_MODULE(addon, Init) \ No newline at end of file + return exports; +} +NODE_API_MODULE(addon, Init) +#endif \ No newline at end of file diff --git a/cxx/ort/node.cc b/cxx/ort/node.cc index 07b02ac..3cb85c2 100644 --- a/cxx/ort/node.cc +++ b/cxx/ort/node.cc @@ -2,6 +2,7 @@ #include #include #include "node.h" +#include "common/tensor.h" #ifdef WIN32 #include @@ -13,51 +14,31 @@ using namespace Napi; #define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast(napi_writable | napi_configurable)) -static ONNXTensorElementDataType getDataTypeFromString(const std::string &name) -{ - static const std::map dataTypeNameMap = { - {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, - {"float", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, - {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, - {"double", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, - {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, - {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, - {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, - {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, - {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, - {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, - {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, - {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}, - }; - auto it = dataTypeNameMap.find(name); - return (it == dataTypeNameMap.end()) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED : it->second; -} +static const std::map DATA_TYPE_MAP = { + {TensorDataType::Float32, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, + {TensorDataType::Float64, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, + {TensorDataType::Int32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, + {TensorDataType::Uint32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, + {TensorDataType::Int16, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, + {TensorDataType::Uint16, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, + {TensorDataType::Int8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, + {TensorDataType::Uint8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, + {TensorDataType::Int64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, + {TensorDataType::Uint64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}, +}; -static size_t getDataTypeSize(ONNXTensorElementDataType type) -{ - static const std::map dataTypeSizeMap = { - {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8}, - {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8}, - }; - auto it = dataTypeSizeMap.find(type); - return (it == dataTypeSizeMap.end()) ? 0 : it->second; -} - -static void *dataFromTypedArray(const Napi::Value &val, size_t &bytes) -{ - auto arr = val.As(); - auto data = static_cast(arr.ArrayBuffer().Data()); - bytes = arr.ByteLength(); - return static_cast(data + arr.ByteOffset()); -} +static const std::map DATA_TYPE_SIZE_MAP = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8}, +}; class OrtSessionNodeInfo { public: @@ -67,7 +48,22 @@ class OrtSessionNodeInfo { inline const std::string &GetName() const { return name_; } inline const std::vector &GetShape() const { return shape_; } inline ONNXTensorElementDataType GetType() const { return type_; } - inline size_t GetElementSize() const { return getDataTypeSize(type_); } + inline size_t GetElementSize() const + { + auto it = DATA_TYPE_SIZE_MAP.find(type_); + return (it == DATA_TYPE_SIZE_MAP.end()) ? 0 : it->second; + } + TensorDataType GetDataType() const + { + auto datatype = TensorDataType::Unknown; + for (auto it : DATA_TYPE_MAP) { + if (it.second == type_) { + datatype = it.first; + break; + } + } + return datatype; + } size_t GetElementCount() const { if (!shape_.size()) return 0; @@ -115,7 +111,8 @@ class OrtSessionRunWorker : public AsyncWorker { for (int i = 0; i < outputNames_.size(); ++i) { size_t bytes = outputElementBytes_[i]; Ort::Value &value = outputValues_[i]; - auto buffer = ArrayBuffer::New(Env(), value.GetTensorMutableRawData(), bytes); + auto buffer = ArrayBuffer::New(Env(), bytes); + memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes); result.Set(String::New(Env(), outputNames_[i]), buffer); } Callback().Call({Env().Undefined(), result}); @@ -236,7 +233,12 @@ class OrtSession : public ObjectWrap { auto inputOption = static_cast((*it).second).As(); if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name))); else { - auto type = inputOption.Has("type") ? getDataTypeFromString(inputOption.Get("type").As().Utf8Value()) : input->GetType(); + auto type = input->GetType(); + if (inputOption.Has("type")) { + auto t = static_cast(inputOption.Get("type").As().Int32Value()); + auto it = DATA_TYPE_MAP.find(t); + if (it != DATA_TYPE_MAP.end()) type = it->second; + } size_t dataByteLen; void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen); auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As()) : input->GetShape(); @@ -279,7 +281,7 @@ class OrtSession : public ObjectWrap { auto &node = *it.second; auto item = Object::New(env); item.Set(String::New(env, "name"), String::New(env, node.GetName())); - item.Set(String::New(env, "type"), Number::New(env, node.GetType())); + item.Set(String::New(env, "type"), Number::New(env, static_cast(node.GetDataType()))); auto &shapeVec = node.GetShape(); auto shape = Array::New(env, shapeVec.size()); for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i])); @@ -303,7 +305,7 @@ void InstallOrtAPI(Napi::Env env, Napi::Object exports) OrtSession::Init(env, exports); } -#ifdef USE_ONNXRUNTIME +#if defined(USE_ONNXRUNTIME) && not defined(BUILD_MAIN_WORD) static Object Init(Env env, Object exports) { InstallOrtAPI(env, exports); diff --git a/package.json b/package.json index af8d005..1f3b27c 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,5 @@ { - "name": "ai-box", + "name": "@yizhi/ai", "version": "1.0.0", "main": "index.js", "scripts": { diff --git a/src/backend/common/session.ts b/src/backend/common/session.ts index 46d10a4..9c479ef 100644 --- a/src/backend/common/session.ts +++ b/src/backend/common/session.ts @@ -6,12 +6,26 @@ export interface SessionNodeInfo { shape: number[] } -export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64" +export type DataTypeString = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64" + +export enum DataType { + Unknown, + Float32, + Float64, + Int32, + Uint32, + Int16, + Uint16, + Int8, + Uint8, + Int64, + Uint64, +} export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array export interface SessionRunInputOption { - type?: SessionNodeType + type?: DataTypeString data: SessionNodeData shape?: number[] } @@ -26,3 +40,36 @@ export abstract class CommonSession { export function isTypedArray(val: any): val is SessionNodeData { return val?.buffer instanceof ArrayBuffer; } + +export function dataTypeFrom(str: DataTypeString) { + return { + "float32": DataType.Float32, + "float64": DataType.Float64, + "float": DataType.Float32, + "double": DataType.Float64, + "int32": DataType.Int32, + "uint32": DataType.Uint32, + "int16": DataType.Int16, + "uint16": DataType.Uint16, + "int8": DataType.Int8, + "uint8": DataType.Uint8, + "int64": DataType.Int64, + "uint64": DataType.Uint64, + }[str] ?? DataType.Unknown; +} + +export function dataTypeToString(type: DataType): DataTypeString | null { + switch (type) { + case DataType.Float32: return "float32"; + case DataType.Float64: return "float64"; + case DataType.Int32: return "int32"; + case DataType.Uint32: return "uint32"; + case DataType.Int16: return "int16"; + case DataType.Uint16: return "uint16"; + case DataType.Int8: return "int8"; + case DataType.Uint8: return "uint8"; + case DataType.Int64: return "int64"; + case DataType.Uint64: return "uint64"; + default: return null; + } +} \ No newline at end of file diff --git a/src/backend/index.ts b/src/backend/index.ts index a6d283b..76e42f0 100644 --- a/src/backend/index.ts +++ b/src/backend/index.ts @@ -1 +1 @@ -export * as backend from "./main"; \ No newline at end of file +export * as backend from "./main"; diff --git a/src/backend/main.ts b/src/backend/main.ts index cf9fde7..a573c1a 100644 --- a/src/backend/main.ts +++ b/src/backend/main.ts @@ -1,2 +1,3 @@ export * as common from "./common"; export * as ort from "./ort"; +export * as mnn from "./mnn"; \ No newline at end of file diff --git a/src/backend/mnn/index.ts b/src/backend/mnn/index.ts new file mode 100644 index 0000000..31dadca --- /dev/null +++ b/src/backend/mnn/index.ts @@ -0,0 +1 @@ +export { MNNSession as Session } from "./session"; \ No newline at end of file diff --git a/src/backend/mnn/session.ts b/src/backend/mnn/session.ts new file mode 100644 index 0000000..e523748 --- /dev/null +++ b/src/backend/mnn/session.ts @@ -0,0 +1,30 @@ +import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common"; + +export class MNNSession extends CommonSession { + #session: any + #inputs: Record | null = null; + #outputs: Record | null = null; + + public constructor(modelData: Uint8Array) { + super(); + const addon = require("../../../build/mnn.node") + this.#session = new addon.MNNSession(modelData); + } + + public run(inputs: Record): Promise> { + const inputArgs: Record = {}; + for (const [name, option] of Object.entries(inputs)) { + if (isTypedArray(option)) inputArgs[name] = { data: option } + else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined }; + } + return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => { + if (err) return reject(err); + const result: Record = {}; + for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer); + resolve(result); + })) + } + public get inputs(): Record { return this.#inputs ??= this.#session.GetInputsInfo(); } + public get outputs(): Record { return this.#outputs ??= this.#session.GetOutputsInfo(); } + +} \ No newline at end of file diff --git a/src/backend/ort/session.ts b/src/backend/ort/session.ts index c17af37..76c46d8 100644 --- a/src/backend/ort/session.ts +++ b/src/backend/ort/session.ts @@ -1,4 +1,4 @@ -import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common"; +import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common"; export class OrtSession extends CommonSession { #session: any; @@ -19,7 +19,7 @@ export class OrtSession extends CommonSession { const inputArgs: Record = {}; for (const [name, option] of Object.entries(inputs)) { if (isTypedArray(option)) inputArgs[name] = { data: option } - else inputArgs[name] = option; + else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined }; } return new Promise>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => { diff --git a/src/deploy/common/model.ts b/src/deploy/common/model.ts index c1621e8..1fe418a 100644 --- a/src/deploy/common/model.ts +++ b/src/deploy/common/model.ts @@ -45,6 +45,14 @@ export abstract class Model { return new this(new backend.ort.Session(modelData as Uint8Array)); } + public static async fromMNN(this: ModelConstructor, modelData: Uint8Array | string) { + if (typeof modelData === "string") { + if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer)); + else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string)); + } + return new this(new backend.mnn.Session(modelData as Uint8Array)); + } + protected static async cacheModel(this: ModelConstructor, url: string, option?: ModelCacheOption): Promise> { //初始化目录 const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]); @@ -111,6 +119,7 @@ export abstract class Model { let model: T | undefined = undefined; if (option?.createModel) { if (modelType === "onnx") model = (this as any).fromOnnx(modelPath); + else if (modelType == "mnn") model = (this as any).fromMNN(modelPath); } return { modelPath, modelType, model: model as any } diff --git a/src/deploy/facealign/landmark1000.ts b/src/deploy/facealign/landmark1000.ts index 3dde638..86a2e87 100644 --- a/src/deploy/facealign/landmark1000.ts +++ b/src/deploy/facealign/landmark1000.ts @@ -20,6 +20,7 @@ class FaceLandmark1000Result extends FaceAlignmentResult { const MODEL_URL_CONFIG = { FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`, + FACELANDMARK1000_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.mnn`, }; export class FaceLandmark1000 extends Model { diff --git a/src/deploy/facealign/pfld.ts b/src/deploy/facealign/pfld.ts index 0e34024..ab06eaf 100644 --- a/src/deploy/facealign/pfld.ts +++ b/src/deploy/facealign/pfld.ts @@ -21,6 +21,9 @@ const MODEL_URL_CONFIG = { PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`, PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`, PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`, + PFLD_106_LITE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.mnn`, + PFLD_106_V2_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.mnn`, + PFLD_106_V3_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.mnn`, }; export class PFLD extends Model { diff --git a/src/deploy/faceattr/gender-age.ts b/src/deploy/faceattr/gender-age.ts index c3b8c76..db1c281 100644 --- a/src/deploy/faceattr/gender-age.ts +++ b/src/deploy/faceattr/gender-age.ts @@ -12,6 +12,7 @@ export interface GenderAgePredictResult { const MODEL_URL_CONFIG = { INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`, + INSIGHT_GENDER_AGE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.mnn`, }; export class GenderAge extends Model { diff --git a/src/deploy/facedet/yolov5.ts b/src/deploy/facedet/yolov5.ts index 982a54f..6666a7d 100644 --- a/src/deploy/facedet/yolov5.ts +++ b/src/deploy/facedet/yolov5.ts @@ -4,6 +4,7 @@ import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common"; const MODEL_URL_CONFIG = { YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`, + YOLOV5S_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.mnn`, }; export class Yolov5Face extends FaceDetector { diff --git a/src/deploy/faceid/adaface.ts b/src/deploy/faceid/adaface.ts index fa0974b..7462929 100644 --- a/src/deploy/faceid/adaface.ts +++ b/src/deploy/faceid/adaface.ts @@ -4,6 +4,7 @@ import { FaceRecognition, FaceRecognitionPredictOption } from "./common"; const MODEL_URL_CONFIG = { MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`, + MOBILEFACENET_ADAFACE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.mnn`, }; export class AdaFace extends FaceRecognition { diff --git a/src/deploy/faceid/insightface.ts b/src/deploy/faceid/insightface.ts index e08b78d..ac19bd6 100644 --- a/src/deploy/faceid/insightface.ts +++ b/src/deploy/faceid/insightface.ts @@ -7,16 +7,23 @@ const MODEL_URL_CONFIG_ARC_FACE = { INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`, INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`, INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`, + INSIGHTFACE_ARCFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.mnn`, + INSIGHTFACE_ARCFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.mnn`, + INSIGHTFACE_ARCFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.mnn`, }; const MODEL_URL_CONFIG_COS_FACE = { INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`, INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`, INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`, INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`, + INSIGHTFACE_COSFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.mnn`, + INSIGHTFACE_COSFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.mnn`, + INSIGHTFACE_COSFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.mnn`, }; const MODEL_URL_CONFIG_PARTIAL_FC = { INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`, INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`, + INSIGHTFACE_PARTIALFC_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.mnn`, }; diff --git a/src/test.ts b/src/test.ts index 67ce90f..929825a 100644 --- a/src/test.ts +++ b/src/test.ts @@ -30,8 +30,8 @@ async function cacheImage(group: string, url: string) { } async function testGenderTest() { - const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX"); - const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_ONNX"); + const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN"); + const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_MNN"); const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg"); const boxes = await facedet.predict(image); @@ -44,9 +44,11 @@ async function testGenderTest() { } async function testFaceID() { - const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX"); - const faceid = await deploy.faceid.PartialFC.load(); - const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX"); + console.log("初始化模型") + const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN"); + const faceid = await deploy.faceid.CosFace.load("INSIGHTFACE_COSFACE_R50_MNN"); + const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN"); + console.log("初始化模型完成") const { basic, tests } = faceidTestData.stars; @@ -110,8 +112,8 @@ async function testFaceID() { } async function testFaceAlign() { - const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX"); - const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX"); + const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN"); + const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN"); // const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX"); let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712"); image = image.rotate(image.width / 2, image.height / 2, 0); @@ -142,5 +144,4 @@ async function test() { test().catch(err => { console.error(err); - debugger -}); \ No newline at end of file +});