增加MNN模型支持
This commit is contained in:
@ -10,7 +10,7 @@ if(NOT DEFINED CMAKE_BUILD_TYPE)
|
|||||||
set(CMAKE_BUILD_TYPE Release)
|
set(CMAKE_BUILD_TYPE Release)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
set(NODE_COMMON_SOURCES)
|
||||||
set(NODE_ADDON_FOUND OFF)
|
set(NODE_ADDON_FOUND OFF)
|
||||||
|
|
||||||
include_directories(cxx)
|
include_directories(cxx)
|
||||||
@ -74,10 +74,15 @@ if(EXISTS ${MNN_CMAKE_FILE})
|
|||||||
message(STATUS "MNN_LIB_DIR: ${MNN_LIB_DIR}")
|
message(STATUS "MNN_LIB_DIR: ${MNN_LIB_DIR}")
|
||||||
message(STATUS "MNN_INCLUDE_DIR: ${MNN_INCLUDE_DIR}")
|
message(STATUS "MNN_INCLUDE_DIR: ${MNN_INCLUDE_DIR}")
|
||||||
message(STATUS "MNN_LIBS: ${MNN_LIBS}")
|
message(STATUS "MNN_LIBS: ${MNN_LIBS}")
|
||||||
include_directories(${MNN_INCLUDE_DIRS})
|
include_directories(${MNN_INCLUDE_DIR})
|
||||||
link_directories(${MNN_LIB_DIR})
|
link_directories(${MNN_LIB_DIR})
|
||||||
add_compile_definitions(USE_MNN)
|
|
||||||
set(USE_MNN ON)
|
if(NODE_ADDON_FOUND)
|
||||||
|
add_node_targert(mnn cxx/mnn/node.cc)
|
||||||
|
target_link_libraries(mnn ${MNN_LIBS})
|
||||||
|
target_compile_definitions(mnn PUBLIC USE_MNN)
|
||||||
|
list(APPEND NODE_COMMON_SOURCES cxx/mnn/node.cc)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# OpenCV
|
# OpenCV
|
||||||
@ -94,6 +99,7 @@ if(EXISTS ${OpenCV_CMAKE_FILE})
|
|||||||
add_node_targert(cv cxx/cv/node.cc)
|
add_node_targert(cv cxx/cv/node.cc)
|
||||||
target_link_libraries(cv ${OpenCV_LIBS})
|
target_link_libraries(cv ${OpenCV_LIBS})
|
||||||
target_compile_definitions(cv PUBLIC USE_OPENCV)
|
target_compile_definitions(cv PUBLIC USE_OPENCV)
|
||||||
|
list(APPEND NODE_COMMON_SOURCES cxx/cv/node.cc)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@ -111,12 +117,34 @@ if(EXISTS ${ONNXRuntime_CMAKE_FILE})
|
|||||||
add_node_targert(ort cxx/ort/node.cc)
|
add_node_targert(ort cxx/ort/node.cc)
|
||||||
target_link_libraries(ort ${ONNXRuntime_LIBS})
|
target_link_libraries(ort ${ONNXRuntime_LIBS})
|
||||||
target_compile_definitions(ort PUBLIC USE_ONNXRUNTIME)
|
target_compile_definitions(ort PUBLIC USE_ONNXRUNTIME)
|
||||||
|
list(APPEND NODE_COMMON_SOURCES cxx/ort/node.cc)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# 统一的NodeJS插件
|
||||||
|
if(NODE_ADDON_FOUND)
|
||||||
|
add_node_targert(addon cxx/node.cc)
|
||||||
|
target_sources(addon PRIVATE ${NODE_COMMON_SOURCES})
|
||||||
|
target_compile_definitions(addon PUBLIC BUILD_MAIN_WORD)
|
||||||
|
# MNN
|
||||||
|
if(EXISTS ${MNN_CMAKE_FILE})
|
||||||
|
target_link_libraries(addon ${MNN_LIBS})
|
||||||
|
target_compile_definitions(addon PUBLIC USE_MNN)
|
||||||
|
endif()
|
||||||
|
# OnnxRuntime
|
||||||
|
if(EXISTS ${ONNXRuntime_CMAKE_FILE})
|
||||||
|
target_link_libraries(addon ${ONNXRuntime_LIBS})
|
||||||
|
target_compile_definitions(addon PUBLIC USE_ONNXRUNTIME)
|
||||||
|
endif()
|
||||||
|
# OpenCV
|
||||||
|
if(EXISTS ${OpenCV_CMAKE_FILE})
|
||||||
|
target_link_libraries(addon ${OpenCV_LIBS})
|
||||||
|
target_compile_definitions(addon PUBLIC USE_OPENCV)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MSVC)
|
if(MSVC AND NODE_ADDON_FOUND)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
|
||||||
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
|
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
|
||||||
endif()
|
endif()
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#include <napi.h>
|
#include <napi.h>
|
||||||
|
|
||||||
#define NODE_INIT_OBJECT(name, function) \
|
#define NODE_INIT_OBJECT(name, function) \
|
||||||
do \
|
do { \
|
||||||
{ \
|
|
||||||
auto obj = Napi::Object::New(env); \
|
auto obj = Napi::Object::New(env); \
|
||||||
function(env, obj); \
|
function(env, obj); \
|
||||||
exports.Set(Napi::String::New(env, #name), obj); \
|
exports.Set(Napi::String::New(env, #name), obj); \
|
||||||
@ -21,4 +20,13 @@ inline uint64_t __node_ptr_of__(Napi::Value value)
|
|||||||
|
|
||||||
#define NODE_PTR_OF(type, value) (reinterpret_cast<type *>(__node_ptr_of__(value)))
|
#define NODE_PTR_OF(type, value) (reinterpret_cast<type *>(__node_ptr_of__(value)))
|
||||||
|
|
||||||
|
|
||||||
|
inline void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
|
||||||
|
{
|
||||||
|
auto arr = val.As<Napi::TypedArray>();
|
||||||
|
auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
|
||||||
|
bytes = arr.ByteLength();
|
||||||
|
return static_cast<void *>(data + arr.ByteOffset());
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
18
cxx/common/tensor.h
Normal file
18
cxx/common/tensor.h
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#ifndef __COMMON_TENSOR_H__
|
||||||
|
#define __COMMON_TENSOR_H__
|
||||||
|
|
||||||
|
enum class TensorDataType {
|
||||||
|
Unknown,
|
||||||
|
Float32,
|
||||||
|
Float64,
|
||||||
|
Int32,
|
||||||
|
Uint32,
|
||||||
|
Int16,
|
||||||
|
Uint16,
|
||||||
|
Int8,
|
||||||
|
Uint8,
|
||||||
|
Int64,
|
||||||
|
Uint64,
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
@ -134,7 +134,7 @@ void InstallOpenCVAPI(Env env, Object exports)
|
|||||||
CVMat::Init(env, exports);
|
CVMat::Init(env, exports);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_OPENCV
|
#if defined(USE_OPENCV) && not defined(BUILD_MAIN_WORD)
|
||||||
static Object Init(Env env, Object exports)
|
static Object Init(Env env, Object exports)
|
||||||
{
|
{
|
||||||
InstallOpenCVAPI(env, exports);
|
InstallOpenCVAPI(env, exports);
|
||||||
|
233
cxx/mnn/node.cc
Normal file
233
cxx/mnn/node.cc
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <cstring>
|
||||||
|
#include <MNN/Interpreter.hpp>
|
||||||
|
#include <MNN/ImageProcess.hpp>
|
||||||
|
#include "common/tensor.h"
|
||||||
|
#include "node.h"
|
||||||
|
|
||||||
|
using namespace Napi;
|
||||||
|
|
||||||
|
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&MNNSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
||||||
|
|
||||||
|
static const std::map<TensorDataType, halide_type_t> DATA_TYPE_MAP = {
|
||||||
|
{TensorDataType::Float32, halide_type_of<float>()},
|
||||||
|
{TensorDataType::Float64, halide_type_of<double>()},
|
||||||
|
{TensorDataType::Int32, halide_type_of<int32_t>()},
|
||||||
|
{TensorDataType::Uint32, halide_type_of<uint32_t>()},
|
||||||
|
{TensorDataType::Int16, halide_type_of<int16_t>()},
|
||||||
|
{TensorDataType::Uint16, halide_type_of<uint16_t>()},
|
||||||
|
{TensorDataType::Int8, halide_type_of<int8_t>()},
|
||||||
|
{TensorDataType::Uint8, halide_type_of<uint8_t>()},
|
||||||
|
{TensorDataType::Int64, halide_type_of<int64_t>()},
|
||||||
|
{TensorDataType::Uint64, halide_type_of<uint64_t>()},
|
||||||
|
};
|
||||||
|
|
||||||
|
static size_t getShapeSize(const std::vector<int> &shape)
|
||||||
|
{
|
||||||
|
if (!shape.size()) return 0;
|
||||||
|
size_t sum = 1;
|
||||||
|
for (auto i : shape) {
|
||||||
|
if (i > 1) sum *= i;
|
||||||
|
};
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
class MNNSessionRunWorker : public AsyncWorker {
|
||||||
|
public:
|
||||||
|
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
|
||||||
|
: AsyncWorker(callback), interpreter_(interpreter), session_(session) {}
|
||||||
|
|
||||||
|
~MNNSessionRunWorker()
|
||||||
|
{
|
||||||
|
interpreter_->releaseSession(session_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Execute()
|
||||||
|
{
|
||||||
|
interpreter_->resizeSession(session_);
|
||||||
|
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
|
||||||
|
SetError(std::string("Run session failed"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnOK()
|
||||||
|
{
|
||||||
|
if (HasError()) {
|
||||||
|
Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()});
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto result = Object::New(Env());
|
||||||
|
for (auto it : interpreter_->getSessionOutputAll(session_)) {
|
||||||
|
auto tensor = it.second;
|
||||||
|
auto buffer = ArrayBuffer::New(Env(), tensor->size());
|
||||||
|
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
|
||||||
|
result.Set(it.first, buffer);
|
||||||
|
}
|
||||||
|
Callback().Call({Env().Undefined(), result});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetInput(const std::string &name, TensorDataType dataType, const std::vector<int> &shape, void *data, size_t dataBytes)
|
||||||
|
{
|
||||||
|
auto tensor = interpreter_->getSessionInput(session_, name.c_str());
|
||||||
|
if (!tensor) {
|
||||||
|
SetError(std::string("input name #" + name + " not exists"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
halide_type_t type = tensor->getType();
|
||||||
|
if (dataType != TensorDataType::Unknown) {
|
||||||
|
auto it = DATA_TYPE_MAP.find(dataType);
|
||||||
|
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shape.size()) interpreter_->resizeTensor(tensor, shape);
|
||||||
|
|
||||||
|
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
|
||||||
|
if (tensorBytes != dataBytes) {
|
||||||
|
SetError(std::string("input name #" + name + " data size not matched"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto hostTensor = MNN::Tensor::create(tensor->shape(), type, data, MNN::Tensor::CAFFE);
|
||||||
|
tensor->copyFromHostTensor(hostTensor);
|
||||||
|
delete hostTensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void SetError(const std::string &what) { errorMessage_ = what; }
|
||||||
|
inline bool HasError() { return errorMessage_.size() > 0; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
MNN::Interpreter *interpreter_;
|
||||||
|
MNN::Session *session_;
|
||||||
|
std::string errorMessage_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MNNSession : public ObjectWrap<MNNSession> {
|
||||||
|
public:
|
||||||
|
static Napi::Object Init(Napi::Env env, Napi::Object exports)
|
||||||
|
{
|
||||||
|
Function func = DefineClass(env, "MNNSession", {
|
||||||
|
SESSION_INSTANCE_METHOD(GetInputsInfo),
|
||||||
|
SESSION_INSTANCE_METHOD(GetOutputsInfo),
|
||||||
|
SESSION_INSTANCE_METHOD(Run),
|
||||||
|
});
|
||||||
|
FunctionReference *constructor = new FunctionReference();
|
||||||
|
*constructor = Napi::Persistent(func);
|
||||||
|
exports.Set("MNNSession", func);
|
||||||
|
env.SetInstanceData<FunctionReference>(constructor);
|
||||||
|
return exports;
|
||||||
|
}
|
||||||
|
|
||||||
|
MNNSession(const CallbackInfo &info)
|
||||||
|
: ObjectWrap(info)
|
||||||
|
{
|
||||||
|
try {
|
||||||
|
if (info[0].IsString()) {
|
||||||
|
interpreter_ = MNN::Interpreter::createFromFile(info[0].As<String>().Utf8Value().c_str());
|
||||||
|
}
|
||||||
|
else if (info[0].IsTypedArray()) {
|
||||||
|
size_t bufferBytes;
|
||||||
|
auto buffer = dataFromTypedArray(info[0], bufferBytes);
|
||||||
|
interpreter_ = MNN::Interpreter::createFromBuffer(buffer, bufferBytes);
|
||||||
|
}
|
||||||
|
else interpreter_ = nullptr;
|
||||||
|
|
||||||
|
if (interpreter_) {
|
||||||
|
backendConfig_.precision = MNN::BackendConfig::Precision_High;
|
||||||
|
backendConfig_.power = MNN::BackendConfig::Power_High;
|
||||||
|
scheduleConfig_.type = MNN_FORWARD_CPU;
|
||||||
|
scheduleConfig_.numThread = 1;
|
||||||
|
scheduleConfig_.backendConfig = &backendConfig_;
|
||||||
|
session_ = interpreter_->createSession(scheduleConfig_);
|
||||||
|
}
|
||||||
|
else session_ = nullptr;
|
||||||
|
}
|
||||||
|
catch (std::exception &e) {
|
||||||
|
Error::New(info.Env(), e.what()).ThrowAsJavaScriptException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~MNNSession() {}
|
||||||
|
|
||||||
|
Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionInputAll(session_)); }
|
||||||
|
|
||||||
|
Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionOutputAll(session_)); }
|
||||||
|
|
||||||
|
Napi::Value Run(const Napi::CallbackInfo &info)
|
||||||
|
{
|
||||||
|
auto worker = new MNNSessionRunWorker(info[1].As<Function>(), interpreter_, interpreter_->createSession(scheduleConfig_));
|
||||||
|
auto inputArgument = info[0].As<Object>();
|
||||||
|
for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) {
|
||||||
|
auto name = (*it).first.As<String>().Utf8Value();
|
||||||
|
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
|
||||||
|
auto type = inputOption.Has("type") ? static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value()) : TensorDataType::Unknown;
|
||||||
|
size_t dataByteLen;
|
||||||
|
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
|
||||||
|
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : std::vector<int>();
|
||||||
|
worker->SetInput(name, type, shape, data, dataByteLen);
|
||||||
|
}
|
||||||
|
worker->Queue();
|
||||||
|
return info.Env().Undefined();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Napi::Object BuildInputOutputInfo(Napi::Env env, const std::map<std::string, MNN::Tensor *> &tensors)
|
||||||
|
{
|
||||||
|
auto result = Object::New(env);
|
||||||
|
for (auto it : tensors) {
|
||||||
|
auto item = Object::New(env);
|
||||||
|
auto name = it.first;
|
||||||
|
auto shape = it.second->shape();
|
||||||
|
auto type = it.second->getType();
|
||||||
|
TensorDataType dataType = TensorDataType::Unknown;
|
||||||
|
for (auto dt : DATA_TYPE_MAP) {
|
||||||
|
if (dt.second == type) {
|
||||||
|
dataType = dt.first;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto shapeArr = Array::New(env, shape.size());
|
||||||
|
for (size_t i = 0; i < shape.size(); i++) {
|
||||||
|
shapeArr.Set(i, Number::New(env, shape[i]));
|
||||||
|
}
|
||||||
|
item.Set("name", String::New(env, name));
|
||||||
|
item.Set("shape", shapeArr);
|
||||||
|
item.Set("type", Number::New(env, static_cast<int>(dataType)));
|
||||||
|
result.Set(name, item);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> GetShapeFromJavascript(const Napi::Array &shape)
|
||||||
|
{
|
||||||
|
std::vector<int> result;
|
||||||
|
for (size_t i = 0; i < shape.Length(); i++) {
|
||||||
|
result.push_back(shape.Get(i).As<Number>().Int32Value());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
MNN::Interpreter *interpreter_;
|
||||||
|
MNN::Session *session_;
|
||||||
|
MNN::BackendConfig backendConfig_;
|
||||||
|
MNN::ScheduleConfig scheduleConfig_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void InstallMNNAPI(Napi::Env env, Napi::Object exports)
|
||||||
|
{
|
||||||
|
MNNSession::Init(env, exports);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if defined(USE_MNN) && not defined(BUILD_MAIN_WORD)
|
||||||
|
static Object Init(Env env, Object exports)
|
||||||
|
{
|
||||||
|
InstallMNNAPI(env, exports);
|
||||||
|
return exports;
|
||||||
|
}
|
||||||
|
NODE_API_MODULE(addon, Init)
|
||||||
|
#endif
|
8
cxx/mnn/node.h
Normal file
8
cxx/mnn/node.h
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
#ifndef __MNN_NODE_H__
|
||||||
|
#define __MNN_NODE_H__
|
||||||
|
|
||||||
|
#include "common/node.h"
|
||||||
|
|
||||||
|
void InstallMNNAPI(Napi::Env env, Napi::Object exports);
|
||||||
|
|
||||||
|
#endif
|
@ -1,19 +0,0 @@
|
|||||||
#ifndef __MNN_SESSION_H__
|
|
||||||
#define __MNN_SESSION_H__
|
|
||||||
|
|
||||||
#include <MNN/Interpreter.hpp>
|
|
||||||
#include <MNN/ImageProcess.hpp>
|
|
||||||
|
|
||||||
#include "common/session.h"
|
|
||||||
|
|
||||||
namespace ai
|
|
||||||
{
|
|
||||||
class MNNSession : public Session
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
MNNSession(const void *modelData, size_t size);
|
|
||||||
~MNNSession();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
87
cxx/node.cc
87
cxx/node.cc
@ -1,65 +1,28 @@
|
|||||||
// #include <unistd.h>
|
#include "common/node.h"
|
||||||
// #include <napi.h>
|
#include "cv/node.h"
|
||||||
// #include "cv/node.h"
|
#include "mnn/node.h"
|
||||||
// #ifdef USE_ORT
|
#include "ort/node.h"
|
||||||
// #include "ort/node.h"
|
|
||||||
// #endif
|
|
||||||
|
|
||||||
// using namespace Napi;
|
using namespace Napi;
|
||||||
|
|
||||||
// class TestWork : public AsyncWorker
|
#if defined(BUILD_MAIN_WORD)
|
||||||
// {
|
Object Init(Env env, Object exports)
|
||||||
// public:
|
{
|
||||||
// TestWork(const Napi::Function &callback, int value) : Napi::AsyncWorker(callback), val_(value) {}
|
// OpenCV
|
||||||
// ~TestWork() {}
|
#ifdef USE_OPENCV
|
||||||
|
printf("use opencv\n");
|
||||||
|
InstallOpenCVAPI(env, exports);
|
||||||
|
#endif
|
||||||
|
// OnnxRuntime
|
||||||
|
#ifdef USE_ONNXRUNTIME
|
||||||
|
InstallOrtAPI(env, exports);
|
||||||
|
#endif
|
||||||
|
// MNN
|
||||||
|
#ifdef USE_MNN
|
||||||
|
InstallMNNAPI(env, exports);
|
||||||
|
#endif
|
||||||
|
|
||||||
// void Execute()
|
return exports;
|
||||||
// {
|
}
|
||||||
// printf("the worker-thread doing! %d \n", val_);
|
NODE_API_MODULE(addon, Init)
|
||||||
// sleep(3);
|
#endif
|
||||||
// printf("the worker-thread done! %d \n", val_);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// void OnOK()
|
|
||||||
// {
|
|
||||||
// Callback().Call({Env().Undefined(), Number::New(Env(), 0)});
|
|
||||||
// }
|
|
||||||
|
|
||||||
// private:
|
|
||||||
// int val_;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// Value test(const CallbackInfo &info)
|
|
||||||
// {
|
|
||||||
// // ai::ORTSession(nullptr, 0);
|
|
||||||
|
|
||||||
// // Function callback = info[1].As<Function>();
|
|
||||||
// // TestWork *work = new TestWork(callback, info[0].As<Number>().Int32Value());
|
|
||||||
// // work->Queue();
|
|
||||||
// return info.Env().Undefined();
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Object Init(Env env, Object exports)
|
|
||||||
// {
|
|
||||||
// //OpenCV
|
|
||||||
// NODE_INIT_OBJECT(cv, InstallOpenCVAPI);
|
|
||||||
// //OnnxRuntime
|
|
||||||
// #ifdef USE_ORT
|
|
||||||
// NODE_INIT_OBJECT(ort, InstallOrtAPI);
|
|
||||||
// #endif
|
|
||||||
|
|
||||||
// Napi::Number::New(env, 0);
|
|
||||||
|
|
||||||
// #define ADD_FUNCTION(name) exports.Set(Napi::String::New(env, #name), Napi::Function::New(env, name))
|
|
||||||
// // ADD_FUNCTION(facedetPredict);
|
|
||||||
// // ADD_FUNCTION(facedetRelease);
|
|
||||||
|
|
||||||
// // ADD_FUNCTION(faceRecognitionCreate);
|
|
||||||
// // ADD_FUNCTION(faceRecognitionPredict);
|
|
||||||
// // ADD_FUNCTION(faceRecognitionRelease);
|
|
||||||
|
|
||||||
// // ADD_FUNCTION(getDistance);
|
|
||||||
// #undef ADD_FUNCTION
|
|
||||||
// return exports;
|
|
||||||
// }
|
|
||||||
// NODE_API_MODULE(addon, Init)
|
|
@ -2,6 +2,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <onnxruntime_cxx_api.h>
|
#include <onnxruntime_cxx_api.h>
|
||||||
#include "node.h"
|
#include "node.h"
|
||||||
|
#include "common/tensor.h"
|
||||||
|
|
||||||
#ifdef WIN32
|
#ifdef WIN32
|
||||||
#include <locale>
|
#include <locale>
|
||||||
@ -13,29 +14,20 @@ using namespace Napi;
|
|||||||
|
|
||||||
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
||||||
|
|
||||||
static ONNXTensorElementDataType getDataTypeFromString(const std::string &name)
|
static const std::map<TensorDataType, ONNXTensorElementDataType> DATA_TYPE_MAP = {
|
||||||
{
|
{TensorDataType::Float32, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
|
||||||
static const std::map<std::string, ONNXTensorElementDataType> dataTypeNameMap = {
|
{TensorDataType::Float64, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
|
||||||
{"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
|
{TensorDataType::Int32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
|
||||||
{"float", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
|
{TensorDataType::Uint32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
|
||||||
{"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
|
{TensorDataType::Int16, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
|
||||||
{"double", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
|
{TensorDataType::Uint16, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
|
||||||
{"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
|
{TensorDataType::Int8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
|
||||||
{"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
|
{TensorDataType::Uint8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
|
||||||
{"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
|
{TensorDataType::Int64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
|
||||||
{"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
|
{TensorDataType::Uint64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
|
||||||
{"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
|
};
|
||||||
{"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
|
|
||||||
{"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
|
|
||||||
{"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
|
|
||||||
};
|
|
||||||
auto it = dataTypeNameMap.find(name);
|
|
||||||
return (it == dataTypeNameMap.end()) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED : it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t getDataTypeSize(ONNXTensorElementDataType type)
|
static const std::map<ONNXTensorElementDataType, size_t> DATA_TYPE_SIZE_MAP = {
|
||||||
{
|
|
||||||
static const std::map<ONNXTensorElementDataType, size_t> dataTypeSizeMap = {
|
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
|
||||||
@ -46,18 +38,7 @@ static size_t getDataTypeSize(ONNXTensorElementDataType type)
|
|||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
|
||||||
};
|
};
|
||||||
auto it = dataTypeSizeMap.find(type);
|
|
||||||
return (it == dataTypeSizeMap.end()) ? 0 : it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
|
|
||||||
{
|
|
||||||
auto arr = val.As<TypedArray>();
|
|
||||||
auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
|
|
||||||
bytes = arr.ByteLength();
|
|
||||||
return static_cast<void *>(data + arr.ByteOffset());
|
|
||||||
}
|
|
||||||
|
|
||||||
class OrtSessionNodeInfo {
|
class OrtSessionNodeInfo {
|
||||||
public:
|
public:
|
||||||
@ -67,7 +48,22 @@ class OrtSessionNodeInfo {
|
|||||||
inline const std::string &GetName() const { return name_; }
|
inline const std::string &GetName() const { return name_; }
|
||||||
inline const std::vector<int64_t> &GetShape() const { return shape_; }
|
inline const std::vector<int64_t> &GetShape() const { return shape_; }
|
||||||
inline ONNXTensorElementDataType GetType() const { return type_; }
|
inline ONNXTensorElementDataType GetType() const { return type_; }
|
||||||
inline size_t GetElementSize() const { return getDataTypeSize(type_); }
|
inline size_t GetElementSize() const
|
||||||
|
{
|
||||||
|
auto it = DATA_TYPE_SIZE_MAP.find(type_);
|
||||||
|
return (it == DATA_TYPE_SIZE_MAP.end()) ? 0 : it->second;
|
||||||
|
}
|
||||||
|
TensorDataType GetDataType() const
|
||||||
|
{
|
||||||
|
auto datatype = TensorDataType::Unknown;
|
||||||
|
for (auto it : DATA_TYPE_MAP) {
|
||||||
|
if (it.second == type_) {
|
||||||
|
datatype = it.first;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return datatype;
|
||||||
|
}
|
||||||
size_t GetElementCount() const
|
size_t GetElementCount() const
|
||||||
{
|
{
|
||||||
if (!shape_.size()) return 0;
|
if (!shape_.size()) return 0;
|
||||||
@ -115,7 +111,8 @@ class OrtSessionRunWorker : public AsyncWorker {
|
|||||||
for (int i = 0; i < outputNames_.size(); ++i) {
|
for (int i = 0; i < outputNames_.size(); ++i) {
|
||||||
size_t bytes = outputElementBytes_[i];
|
size_t bytes = outputElementBytes_[i];
|
||||||
Ort::Value &value = outputValues_[i];
|
Ort::Value &value = outputValues_[i];
|
||||||
auto buffer = ArrayBuffer::New(Env(), value.GetTensorMutableRawData(), bytes);
|
auto buffer = ArrayBuffer::New(Env(), bytes);
|
||||||
|
memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes);
|
||||||
result.Set(String::New(Env(), outputNames_[i]), buffer);
|
result.Set(String::New(Env(), outputNames_[i]), buffer);
|
||||||
}
|
}
|
||||||
Callback().Call({Env().Undefined(), result});
|
Callback().Call({Env().Undefined(), result});
|
||||||
@ -236,7 +233,12 @@ class OrtSession : public ObjectWrap<OrtSession> {
|
|||||||
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
|
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
|
||||||
if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name)));
|
if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name)));
|
||||||
else {
|
else {
|
||||||
auto type = inputOption.Has("type") ? getDataTypeFromString(inputOption.Get("type").As<String>().Utf8Value()) : input->GetType();
|
auto type = input->GetType();
|
||||||
|
if (inputOption.Has("type")) {
|
||||||
|
auto t = static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value());
|
||||||
|
auto it = DATA_TYPE_MAP.find(t);
|
||||||
|
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
||||||
|
}
|
||||||
size_t dataByteLen;
|
size_t dataByteLen;
|
||||||
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
|
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
|
||||||
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
|
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
|
||||||
@ -279,7 +281,7 @@ class OrtSession : public ObjectWrap<OrtSession> {
|
|||||||
auto &node = *it.second;
|
auto &node = *it.second;
|
||||||
auto item = Object::New(env);
|
auto item = Object::New(env);
|
||||||
item.Set(String::New(env, "name"), String::New(env, node.GetName()));
|
item.Set(String::New(env, "name"), String::New(env, node.GetName()));
|
||||||
item.Set(String::New(env, "type"), Number::New(env, node.GetType()));
|
item.Set(String::New(env, "type"), Number::New(env, static_cast<int>(node.GetDataType())));
|
||||||
auto &shapeVec = node.GetShape();
|
auto &shapeVec = node.GetShape();
|
||||||
auto shape = Array::New(env, shapeVec.size());
|
auto shape = Array::New(env, shapeVec.size());
|
||||||
for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i]));
|
for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i]));
|
||||||
@ -303,7 +305,7 @@ void InstallOrtAPI(Napi::Env env, Napi::Object exports)
|
|||||||
OrtSession::Init(env, exports);
|
OrtSession::Init(env, exports);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_ONNXRUNTIME
|
#if defined(USE_ONNXRUNTIME) && not defined(BUILD_MAIN_WORD)
|
||||||
static Object Init(Env env, Object exports)
|
static Object Init(Env env, Object exports)
|
||||||
{
|
{
|
||||||
InstallOrtAPI(env, exports);
|
InstallOrtAPI(env, exports);
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "ai-box",
|
"name": "@yizhi/ai",
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
"main": "index.js",
|
"main": "index.js",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
@ -6,12 +6,26 @@ export interface SessionNodeInfo {
|
|||||||
shape: number[]
|
shape: number[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
|
export type DataTypeString = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
|
||||||
|
|
||||||
|
export enum DataType {
|
||||||
|
Unknown,
|
||||||
|
Float32,
|
||||||
|
Float64,
|
||||||
|
Int32,
|
||||||
|
Uint32,
|
||||||
|
Int16,
|
||||||
|
Uint16,
|
||||||
|
Int8,
|
||||||
|
Uint8,
|
||||||
|
Int64,
|
||||||
|
Uint64,
|
||||||
|
}
|
||||||
|
|
||||||
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
|
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
|
||||||
|
|
||||||
export interface SessionRunInputOption {
|
export interface SessionRunInputOption {
|
||||||
type?: SessionNodeType
|
type?: DataTypeString
|
||||||
data: SessionNodeData
|
data: SessionNodeData
|
||||||
shape?: number[]
|
shape?: number[]
|
||||||
}
|
}
|
||||||
@ -26,3 +40,36 @@ export abstract class CommonSession {
|
|||||||
export function isTypedArray(val: any): val is SessionNodeData {
|
export function isTypedArray(val: any): val is SessionNodeData {
|
||||||
return val?.buffer instanceof ArrayBuffer;
|
return val?.buffer instanceof ArrayBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function dataTypeFrom(str: DataTypeString) {
|
||||||
|
return {
|
||||||
|
"float32": DataType.Float32,
|
||||||
|
"float64": DataType.Float64,
|
||||||
|
"float": DataType.Float32,
|
||||||
|
"double": DataType.Float64,
|
||||||
|
"int32": DataType.Int32,
|
||||||
|
"uint32": DataType.Uint32,
|
||||||
|
"int16": DataType.Int16,
|
||||||
|
"uint16": DataType.Uint16,
|
||||||
|
"int8": DataType.Int8,
|
||||||
|
"uint8": DataType.Uint8,
|
||||||
|
"int64": DataType.Int64,
|
||||||
|
"uint64": DataType.Uint64,
|
||||||
|
}[str] ?? DataType.Unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function dataTypeToString(type: DataType): DataTypeString | null {
|
||||||
|
switch (type) {
|
||||||
|
case DataType.Float32: return "float32";
|
||||||
|
case DataType.Float64: return "float64";
|
||||||
|
case DataType.Int32: return "int32";
|
||||||
|
case DataType.Uint32: return "uint32";
|
||||||
|
case DataType.Int16: return "int16";
|
||||||
|
case DataType.Uint16: return "uint16";
|
||||||
|
case DataType.Int8: return "int8";
|
||||||
|
case DataType.Uint8: return "uint8";
|
||||||
|
case DataType.Int64: return "int64";
|
||||||
|
case DataType.Uint64: return "uint64";
|
||||||
|
default: return null;
|
||||||
|
}
|
||||||
|
}
|
@ -1,2 +1,3 @@
|
|||||||
export * as common from "./common";
|
export * as common from "./common";
|
||||||
export * as ort from "./ort";
|
export * as ort from "./ort";
|
||||||
|
export * as mnn from "./mnn";
|
1
src/backend/mnn/index.ts
Normal file
1
src/backend/mnn/index.ts
Normal file
@ -0,0 +1 @@
|
|||||||
|
export { MNNSession as Session } from "./session";
|
30
src/backend/mnn/session.ts
Normal file
30
src/backend/mnn/session.ts
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||||
|
|
||||||
|
export class MNNSession extends CommonSession {
|
||||||
|
#session: any
|
||||||
|
#inputs: Record<string, SessionNodeInfo> | null = null;
|
||||||
|
#outputs: Record<string, SessionNodeInfo> | null = null;
|
||||||
|
|
||||||
|
public constructor(modelData: Uint8Array) {
|
||||||
|
super();
|
||||||
|
const addon = require("../../../build/mnn.node")
|
||||||
|
this.#session = new addon.MNNSession(modelData);
|
||||||
|
}
|
||||||
|
|
||||||
|
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>> {
|
||||||
|
const inputArgs: Record<string, any> = {};
|
||||||
|
for (const [name, option] of Object.entries(inputs)) {
|
||||||
|
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||||
|
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||||
|
}
|
||||||
|
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
||||||
|
if (err) return reject(err);
|
||||||
|
const result: Record<string, Float32Array> = {};
|
||||||
|
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
||||||
|
resolve(result);
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
|
||||||
|
public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
|
||||||
|
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
||||||
|
|
||||||
export class OrtSession extends CommonSession {
|
export class OrtSession extends CommonSession {
|
||||||
#session: any;
|
#session: any;
|
||||||
@ -19,7 +19,7 @@ export class OrtSession extends CommonSession {
|
|||||||
const inputArgs: Record<string, any> = {};
|
const inputArgs: Record<string, any> = {};
|
||||||
for (const [name, option] of Object.entries(inputs)) {
|
for (const [name, option] of Object.entries(inputs)) {
|
||||||
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||||
else inputArgs[name] = option;
|
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
||||||
|
@ -45,6 +45,14 @@ export abstract class Model {
|
|||||||
return new this(new backend.ort.Session(modelData as Uint8Array));
|
return new this(new backend.ort.Session(modelData as Uint8Array));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static async fromMNN<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array | string) {
|
||||||
|
if (typeof modelData === "string") {
|
||||||
|
if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
|
||||||
|
else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string));
|
||||||
|
}
|
||||||
|
return new this(new backend.mnn.Session(modelData as Uint8Array));
|
||||||
|
}
|
||||||
|
|
||||||
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
|
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
|
||||||
//初始化目录
|
//初始化目录
|
||||||
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
||||||
@ -111,6 +119,7 @@ export abstract class Model {
|
|||||||
let model: T | undefined = undefined;
|
let model: T | undefined = undefined;
|
||||||
if (option?.createModel) {
|
if (option?.createModel) {
|
||||||
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
||||||
|
else if (modelType == "mnn") model = (this as any).fromMNN(modelPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
return { modelPath, modelType, model: model as any }
|
return { modelPath, modelType, model: model as any }
|
||||||
|
@ -20,6 +20,7 @@ class FaceLandmark1000Result extends FaceAlignmentResult {
|
|||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`,
|
FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`,
|
||||||
|
FACELANDMARK1000_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class FaceLandmark1000 extends Model {
|
export class FaceLandmark1000 extends Model {
|
||||||
|
@ -21,6 +21,9 @@ const MODEL_URL_CONFIG = {
|
|||||||
PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`,
|
PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`,
|
||||||
PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`,
|
PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`,
|
||||||
PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`,
|
PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`,
|
||||||
|
PFLD_106_LITE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.mnn`,
|
||||||
|
PFLD_106_V2_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.mnn`,
|
||||||
|
PFLD_106_V3_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.mnn`,
|
||||||
};
|
};
|
||||||
export class PFLD extends Model {
|
export class PFLD extends Model {
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ export interface GenderAgePredictResult {
|
|||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
|
INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
|
||||||
|
INSIGHT_GENDER_AGE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class GenderAge extends Model {
|
export class GenderAge extends Model {
|
||||||
|
@ -4,6 +4,7 @@ import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common";
|
|||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`,
|
YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`,
|
||||||
|
YOLOV5S_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class Yolov5Face extends FaceDetector {
|
export class Yolov5Face extends FaceDetector {
|
||||||
|
@ -4,6 +4,7 @@ import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
|||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`,
|
MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`,
|
||||||
|
MOBILEFACENET_ADAFACE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class AdaFace extends FaceRecognition {
|
export class AdaFace extends FaceRecognition {
|
||||||
|
@ -7,16 +7,23 @@ const MODEL_URL_CONFIG_ARC_FACE = {
|
|||||||
INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`,
|
INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`,
|
||||||
INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`,
|
INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`,
|
||||||
INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`,
|
INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`,
|
||||||
|
INSIGHTFACE_ARCFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.mnn`,
|
||||||
|
INSIGHTFACE_ARCFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.mnn`,
|
||||||
|
INSIGHTFACE_ARCFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.mnn`,
|
||||||
};
|
};
|
||||||
const MODEL_URL_CONFIG_COS_FACE = {
|
const MODEL_URL_CONFIG_COS_FACE = {
|
||||||
INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`,
|
INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`,
|
||||||
INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`,
|
INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`,
|
||||||
INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`,
|
INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`,
|
||||||
INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`,
|
INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`,
|
||||||
|
INSIGHTFACE_COSFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.mnn`,
|
||||||
|
INSIGHTFACE_COSFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.mnn`,
|
||||||
|
INSIGHTFACE_COSFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.mnn`,
|
||||||
};
|
};
|
||||||
const MODEL_URL_CONFIG_PARTIAL_FC = {
|
const MODEL_URL_CONFIG_PARTIAL_FC = {
|
||||||
INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`,
|
INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`,
|
||||||
INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`,
|
INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`,
|
||||||
|
INSIGHTFACE_PARTIALFC_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
17
src/test.ts
17
src/test.ts
@ -30,8 +30,8 @@ async function cacheImage(group: string, url: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function testGenderTest() {
|
async function testGenderTest() {
|
||||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||||
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_ONNX");
|
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_MNN");
|
||||||
|
|
||||||
const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
|
const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
|
||||||
const boxes = await facedet.predict(image);
|
const boxes = await facedet.predict(image);
|
||||||
@ -44,9 +44,11 @@ async function testGenderTest() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function testFaceID() {
|
async function testFaceID() {
|
||||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
console.log("初始化模型")
|
||||||
const faceid = await deploy.faceid.PartialFC.load();
|
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||||
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
const faceid = await deploy.faceid.CosFace.load("INSIGHTFACE_COSFACE_R50_MNN");
|
||||||
|
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
||||||
|
console.log("初始化模型完成")
|
||||||
|
|
||||||
const { basic, tests } = faceidTestData.stars;
|
const { basic, tests } = faceidTestData.stars;
|
||||||
|
|
||||||
@ -110,8 +112,8 @@ async function testFaceID() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function testFaceAlign() {
|
async function testFaceAlign() {
|
||||||
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||||
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
||||||
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
||||||
let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
|
let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
|
||||||
image = image.rotate(image.width / 2, image.height / 2, 0);
|
image = image.rotate(image.width / 2, image.height / 2, 0);
|
||||||
@ -142,5 +144,4 @@ async function test() {
|
|||||||
|
|
||||||
test().catch(err => {
|
test().catch(err => {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
debugger
|
|
||||||
});
|
});
|
Reference in New Issue
Block a user