Compare commits
14 Commits
8af5846490
...
1.0.6
Author | SHA1 | Date | |
---|---|---|---|
6bf7db1f4c | |||
b48d2daffb | |||
eea943bba5 | |||
232d480ca0 | |||
524dcaecbd | |||
a966b82963 | |||
e362890c96 | |||
92e46c2c33 | |||
2dfe063049 | |||
358d21b2bd | |||
e831b8e862 | |||
bd90f2f6f6 | |||
4a6d092de1 | |||
a6fd117736 |
12
.npmignore
Normal file
12
.npmignore
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
/build
|
||||||
|
/cache
|
||||||
|
/cxx
|
||||||
|
/models
|
||||||
|
/node_modules
|
||||||
|
/src
|
||||||
|
/testdata
|
||||||
|
/thirdpart
|
||||||
|
/tool
|
||||||
|
/.clang-format
|
||||||
|
/CMakeLists.txt
|
||||||
|
/tsconfig.json
|
@ -10,7 +10,7 @@ if(NOT DEFINED CMAKE_BUILD_TYPE)
|
|||||||
set(CMAKE_BUILD_TYPE Release)
|
set(CMAKE_BUILD_TYPE Release)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
set(NODE_COMMON_SOURCES)
|
||||||
set(NODE_ADDON_FOUND OFF)
|
set(NODE_ADDON_FOUND OFF)
|
||||||
|
|
||||||
include_directories(cxx)
|
include_directories(cxx)
|
||||||
@ -74,26 +74,14 @@ if(EXISTS ${MNN_CMAKE_FILE})
|
|||||||
message(STATUS "MNN_LIB_DIR: ${MNN_LIB_DIR}")
|
message(STATUS "MNN_LIB_DIR: ${MNN_LIB_DIR}")
|
||||||
message(STATUS "MNN_INCLUDE_DIR: ${MNN_INCLUDE_DIR}")
|
message(STATUS "MNN_INCLUDE_DIR: ${MNN_INCLUDE_DIR}")
|
||||||
message(STATUS "MNN_LIBS: ${MNN_LIBS}")
|
message(STATUS "MNN_LIBS: ${MNN_LIBS}")
|
||||||
include_directories(${MNN_INCLUDE_DIRS})
|
include_directories(${MNN_INCLUDE_DIR})
|
||||||
link_directories(${MNN_LIB_DIR})
|
link_directories(${MNN_LIB_DIR})
|
||||||
add_compile_definitions(USE_MNN)
|
|
||||||
set(USE_MNN ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# OpenCV
|
|
||||||
set(OpenCV_CMAKE_FILE ${CMAKE_SOURCE_DIR}/thirdpart/OpenCV/${CMAKE_BUILD_TYPE}/config.cmake)
|
|
||||||
if(EXISTS ${OpenCV_CMAKE_FILE})
|
|
||||||
include(${OpenCV_CMAKE_FILE})
|
|
||||||
message(STATUS "OpenCV_LIB_DIR: ${OpenCV_LIB_DIR}")
|
|
||||||
message(STATUS "OpenCV_INCLUDE_DIR: ${OpenCV_INCLUDE_DIR}")
|
|
||||||
message(STATUS "OpenCV_LIBS: ${OpenCV_LIBS}")
|
|
||||||
include_directories(${OpenCV_INCLUDE_DIRS})
|
|
||||||
link_directories(${OpenCV_LIB_DIR})
|
|
||||||
|
|
||||||
if(NODE_ADDON_FOUND)
|
if(NODE_ADDON_FOUND)
|
||||||
add_node_targert(cv cxx/cv/node.cc)
|
add_node_targert(mnn cxx/mnn/node.cc)
|
||||||
target_link_libraries(cv ${OpenCV_LIBS})
|
target_link_libraries(mnn ${MNN_LIBS})
|
||||||
target_compile_definitions(cv PUBLIC USE_OPENCV)
|
target_compile_definitions(mnn PUBLIC USE_MNN)
|
||||||
|
list(APPEND NODE_COMMON_SOURCES cxx/mnn/node.cc)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@ -111,12 +99,29 @@ if(EXISTS ${ONNXRuntime_CMAKE_FILE})
|
|||||||
add_node_targert(ort cxx/ort/node.cc)
|
add_node_targert(ort cxx/ort/node.cc)
|
||||||
target_link_libraries(ort ${ONNXRuntime_LIBS})
|
target_link_libraries(ort ${ONNXRuntime_LIBS})
|
||||||
target_compile_definitions(ort PUBLIC USE_ONNXRUNTIME)
|
target_compile_definitions(ort PUBLIC USE_ONNXRUNTIME)
|
||||||
|
list(APPEND NODE_COMMON_SOURCES cxx/ort/node.cc)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# 统一的NodeJS插件
|
||||||
|
if(NODE_ADDON_FOUND)
|
||||||
|
add_node_targert(addon cxx/node.cc)
|
||||||
|
target_sources(addon PRIVATE ${NODE_COMMON_SOURCES})
|
||||||
|
target_compile_definitions(addon PUBLIC BUILD_MAIN_WORD)
|
||||||
|
# MNN
|
||||||
|
if(EXISTS ${MNN_CMAKE_FILE})
|
||||||
|
target_link_libraries(addon ${MNN_LIBS})
|
||||||
|
target_compile_definitions(addon PUBLIC USE_MNN)
|
||||||
|
endif()
|
||||||
|
# OnnxRuntime
|
||||||
|
if(EXISTS ${ONNXRuntime_CMAKE_FILE})
|
||||||
|
target_link_libraries(addon ${ONNXRuntime_LIBS})
|
||||||
|
target_compile_definitions(addon PUBLIC USE_ONNXRUNTIME)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MSVC)
|
if(MSVC AND NODE_ADDON_FOUND)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
|
||||||
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
|
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
|
||||||
endif()
|
endif()
|
||||||
|
47
README.md
47
README.md
@ -1 +1,46 @@
|
|||||||
## AI工具箱
|
#AI工具箱
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```
|
||||||
|
npm install @yizhi/ai
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import ai from "@yizhi/ai";
|
||||||
|
|
||||||
|
// 配置相应的node插件,插件需自行编译
|
||||||
|
ai.config("CV_ADDON_FILE", "/path/to/cv.node");
|
||||||
|
ai.config("MNN_ADDON_FILE", "/path/to/mnn.node");
|
||||||
|
ai.config("ORT_ADDON_FILE", "/path/to/onnxruntime.node");
|
||||||
|
|
||||||
|
//直接推理
|
||||||
|
const facedet = await ai.deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||||
|
const boxes = await facedet.predict("/path/to/image");
|
||||||
|
|
||||||
|
//使用自己的模型
|
||||||
|
const session = new ai.backend.ort.Session(modelBuffer);
|
||||||
|
const outputs = session.run(inputs);
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## 插件编译
|
||||||
|
|
||||||
|
1. 依赖
|
||||||
|
1. cmake
|
||||||
|
1. ninja
|
||||||
|
1. c++编译器(gcc,clang,Visual Studio ...)
|
||||||
|
|
||||||
|
1. 编译第三方库
|
||||||
|
```
|
||||||
|
node thirdpart/install.js --with-mnn --with-onnx
|
||||||
|
```
|
||||||
|
1. 编译插件
|
||||||
|
```
|
||||||
|
cmake -B build -G Ninja . -DCMAKE_BUILD_TYPE=Release
|
||||||
|
cmake --build build --config Release
|
||||||
|
```
|
||||||
|
|
||||||
|
注意:注意:在Windows下编译时,需要打开Visual Studio命令行
|
||||||
|
@ -6,8 +6,7 @@
|
|||||||
#include <napi.h>
|
#include <napi.h>
|
||||||
|
|
||||||
#define NODE_INIT_OBJECT(name, function) \
|
#define NODE_INIT_OBJECT(name, function) \
|
||||||
do \
|
do { \
|
||||||
{ \
|
|
||||||
auto obj = Napi::Object::New(env); \
|
auto obj = Napi::Object::New(env); \
|
||||||
function(env, obj); \
|
function(env, obj); \
|
||||||
exports.Set(Napi::String::New(env, #name), obj); \
|
exports.Set(Napi::String::New(env, #name), obj); \
|
||||||
@ -21,4 +20,13 @@ inline uint64_t __node_ptr_of__(Napi::Value value)
|
|||||||
|
|
||||||
#define NODE_PTR_OF(type, value) (reinterpret_cast<type *>(__node_ptr_of__(value)))
|
#define NODE_PTR_OF(type, value) (reinterpret_cast<type *>(__node_ptr_of__(value)))
|
||||||
|
|
||||||
|
|
||||||
|
inline void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
|
||||||
|
{
|
||||||
|
auto arr = val.As<Napi::TypedArray>();
|
||||||
|
auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
|
||||||
|
bytes = arr.ByteLength();
|
||||||
|
return static_cast<void *>(data + arr.ByteOffset());
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
18
cxx/common/tensor.h
Normal file
18
cxx/common/tensor.h
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#ifndef __COMMON_TENSOR_H__
|
||||||
|
#define __COMMON_TENSOR_H__
|
||||||
|
|
||||||
|
enum class TensorDataType {
|
||||||
|
Unknown,
|
||||||
|
Float32,
|
||||||
|
Float64,
|
||||||
|
Int32,
|
||||||
|
Uint32,
|
||||||
|
Int16,
|
||||||
|
Uint16,
|
||||||
|
Int8,
|
||||||
|
Uint8,
|
||||||
|
Int64,
|
||||||
|
Uint64,
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
145
cxx/cv/node.cc
145
cxx/cv/node.cc
@ -1,145 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <functional>
|
|
||||||
#include <opencv2/opencv.hpp>
|
|
||||||
#include "node.h"
|
|
||||||
|
|
||||||
using namespace Napi;
|
|
||||||
|
|
||||||
#define MAT_INSTANCE_METHOD(method) InstanceMethod<&CVMat::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
|
||||||
|
|
||||||
static FunctionReference *constructor = nullptr;
|
|
||||||
|
|
||||||
class CVMat : public ObjectWrap<CVMat> {
|
|
||||||
public:
|
|
||||||
static Napi::Object Init(Napi::Env env, Napi::Object exports)
|
|
||||||
{
|
|
||||||
Function func = DefineClass(env, "Mat", {
|
|
||||||
MAT_INSTANCE_METHOD(IsEmpty),
|
|
||||||
MAT_INSTANCE_METHOD(GetCols),
|
|
||||||
MAT_INSTANCE_METHOD(GetRows),
|
|
||||||
MAT_INSTANCE_METHOD(GetChannels),
|
|
||||||
MAT_INSTANCE_METHOD(Resize),
|
|
||||||
MAT_INSTANCE_METHOD(Crop),
|
|
||||||
MAT_INSTANCE_METHOD(Rotate),
|
|
||||||
MAT_INSTANCE_METHOD(Clone),
|
|
||||||
|
|
||||||
MAT_INSTANCE_METHOD(DrawCircle),
|
|
||||||
|
|
||||||
MAT_INSTANCE_METHOD(Data),
|
|
||||||
MAT_INSTANCE_METHOD(Encode),
|
|
||||||
});
|
|
||||||
constructor = new FunctionReference();
|
|
||||||
*constructor = Napi::Persistent(func);
|
|
||||||
exports.Set("Mat", func);
|
|
||||||
env.SetInstanceData<FunctionReference>(constructor);
|
|
||||||
return exports;
|
|
||||||
}
|
|
||||||
|
|
||||||
CVMat(const CallbackInfo &info)
|
|
||||||
: ObjectWrap<CVMat>(info)
|
|
||||||
{
|
|
||||||
int mode = cv::IMREAD_COLOR_BGR;
|
|
||||||
if (info.Length() > 1 && info[1].IsObject()) {
|
|
||||||
Object options = info[1].As<Object>();
|
|
||||||
if (options.Has("mode") && options.Get("mode").IsNumber()) mode = options.Get("mode").As<Number>().Int32Value();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (info[0].IsString()) im_ = cv::imread(info[0].As<String>().Utf8Value(), mode);
|
|
||||||
else if (info[0].IsTypedArray()) {
|
|
||||||
auto buffer = info[0].As<TypedArray>().ArrayBuffer();
|
|
||||||
uint8_t *bufferPtr = static_cast<uint8_t *>(buffer.Data());
|
|
||||||
std::vector<uint8_t> data(bufferPtr, bufferPtr + buffer.ByteLength());
|
|
||||||
im_ = cv::imdecode(data, mode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
~CVMat() { im_.release(); }
|
|
||||||
|
|
||||||
Napi::Value IsEmpty(const Napi::CallbackInfo &info) { return Boolean::New(info.Env(), im_.empty()); }
|
|
||||||
Napi::Value GetCols(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.cols); }
|
|
||||||
Napi::Value GetRows(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.rows); }
|
|
||||||
Napi::Value GetChannels(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.channels()); }
|
|
||||||
Napi::Value Resize(const Napi::CallbackInfo &info)
|
|
||||||
{
|
|
||||||
return CreateMat(info.Env(), [this, &info](auto &mat) { cv::resize(im_, mat.im_, cv::Size(info[0].As<Number>().Int32Value(), info[1].As<Number>().Int32Value())); });
|
|
||||||
}
|
|
||||||
Napi::Value Crop(const Napi::CallbackInfo &info)
|
|
||||||
{
|
|
||||||
return CreateMat(info.Env(), [this, &info](auto &mat) {
|
|
||||||
mat.im_ = im_(cv::Rect(
|
|
||||||
info[0].As<Number>().Int32Value(), info[1].As<Number>().Int32Value(),
|
|
||||||
info[2].As<Number>().Int32Value(), info[3].As<Number>().Int32Value()));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Napi::Value Rotate(const Napi::CallbackInfo &info)
|
|
||||||
{
|
|
||||||
return CreateMat(info.Env(), [this, &info](auto &mat) {
|
|
||||||
auto x = info[0].As<Number>().DoubleValue();
|
|
||||||
auto y = info[1].As<Number>().DoubleValue();
|
|
||||||
auto angle = info[2].As<Number>().DoubleValue();
|
|
||||||
cv::Mat rotation_matix = cv::getRotationMatrix2D(cv::Point2f(x, y), angle, 1.0);
|
|
||||||
cv::warpAffine(im_, mat.im_, rotation_matix, im_.size());
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
Napi::Value Clone(const Napi::CallbackInfo &info)
|
|
||||||
{
|
|
||||||
return CreateMat(info.Env(), [this, &info](auto &mat) { mat.im_ = im_.clone(); });
|
|
||||||
}
|
|
||||||
|
|
||||||
Napi::Value DrawCircle(const Napi::CallbackInfo &info)
|
|
||||||
{
|
|
||||||
int x = info[0].As<Number>().Int32Value();
|
|
||||||
int y = info[1].As<Number>().Int32Value();
|
|
||||||
int radius = info[2].As<Number>().Int32Value();
|
|
||||||
int b = info[3].As<Number>().Int32Value();
|
|
||||||
int g = info[4].As<Number>().Int32Value();
|
|
||||||
int r = info[5].As<Number>().Int32Value();
|
|
||||||
int thickness = info[6].As<Number>().Int32Value();
|
|
||||||
int lineType = info[7].As<Number>().Int32Value();
|
|
||||||
int shift = info[8].As<Number>().Int32Value();
|
|
||||||
|
|
||||||
cv::circle(im_, cv::Point(x, y), radius, cv::Scalar(b, g, r), thickness, lineType, shift);
|
|
||||||
return info.Env().Undefined();
|
|
||||||
}
|
|
||||||
|
|
||||||
Napi::Value Data(const Napi::CallbackInfo &info) { return ArrayBuffer::New(info.Env(), im_.ptr(), im_.elemSize() * im_.total()); }
|
|
||||||
Napi::Value Encode(const Napi::CallbackInfo &info)
|
|
||||||
{
|
|
||||||
auto options = info[0].As<Object>();
|
|
||||||
auto extname = options.Get("extname").As<String>().Utf8Value();
|
|
||||||
cv::imencode(extname, im_, encoded_);
|
|
||||||
return ArrayBuffer::New(info.Env(), encoded_.data(), encoded_.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
|
||||||
inline Napi::Object EmptyMat(Napi::Env env) { return constructor->New({}).As<Object>(); }
|
|
||||||
inline CVMat &GetMat(Napi::Object obj) { return *ObjectWrap<CVMat>::Unwrap(obj); }
|
|
||||||
inline Napi::Object CreateMat(Napi::Env env, std::function<void(CVMat &mat)> callback)
|
|
||||||
{
|
|
||||||
auto obj = EmptyMat(env);
|
|
||||||
callback(GetMat(obj));
|
|
||||||
return obj;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
cv::Mat im_;
|
|
||||||
std::vector<uint8_t> encoded_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
void InstallOpenCVAPI(Env env, Object exports)
|
|
||||||
{
|
|
||||||
CVMat::Init(env, exports);
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef USE_OPENCV
|
|
||||||
static Object Init(Env env, Object exports)
|
|
||||||
{
|
|
||||||
InstallOpenCVAPI(env, exports);
|
|
||||||
return exports;
|
|
||||||
}
|
|
||||||
NODE_API_MODULE(addon, Init)
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,8 +0,0 @@
|
|||||||
#ifndef __CV_NODE_H__
|
|
||||||
#define __CV_NODE_H__
|
|
||||||
|
|
||||||
#include "common/node.h"
|
|
||||||
|
|
||||||
void InstallOpenCVAPI(Napi::Env env, Napi::Object exports);
|
|
||||||
|
|
||||||
#endif
|
|
244
cxx/mnn/node.cc
Normal file
244
cxx/mnn/node.cc
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <cstring>
|
||||||
|
#include <MNN/Interpreter.hpp>
|
||||||
|
#include <MNN/ImageProcess.hpp>
|
||||||
|
#include "common/tensor.h"
|
||||||
|
#include "node.h"
|
||||||
|
|
||||||
|
using namespace Napi;
|
||||||
|
|
||||||
|
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&MNNSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
||||||
|
|
||||||
|
static const std::map<TensorDataType, halide_type_t> DATA_TYPE_MAP = {
|
||||||
|
{TensorDataType::Float32, halide_type_of<float>()},
|
||||||
|
{TensorDataType::Float64, halide_type_of<double>()},
|
||||||
|
{TensorDataType::Int32, halide_type_of<int32_t>()},
|
||||||
|
{TensorDataType::Uint32, halide_type_of<uint32_t>()},
|
||||||
|
{TensorDataType::Int16, halide_type_of<int16_t>()},
|
||||||
|
{TensorDataType::Uint16, halide_type_of<uint16_t>()},
|
||||||
|
{TensorDataType::Int8, halide_type_of<int8_t>()},
|
||||||
|
{TensorDataType::Uint8, halide_type_of<uint8_t>()},
|
||||||
|
{TensorDataType::Int64, halide_type_of<int64_t>()},
|
||||||
|
{TensorDataType::Uint64, halide_type_of<uint64_t>()},
|
||||||
|
};
|
||||||
|
|
||||||
|
static size_t getShapeSize(const std::vector<int> &shape)
|
||||||
|
{
|
||||||
|
if (!shape.size()) return 0;
|
||||||
|
size_t sum = 1;
|
||||||
|
for (auto i : shape) {
|
||||||
|
if (i > 1) sum *= i;
|
||||||
|
};
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector<int> &shape)
|
||||||
|
{
|
||||||
|
auto result = Napi::Array::New(env, shape.size());
|
||||||
|
for (int i = 0; i < shape.size(); ++i) result.Set(i, Napi::Number::New(env, shape[i]));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
class MNNSessionRunWorker : public AsyncWorker {
|
||||||
|
public:
|
||||||
|
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
|
||||||
|
: AsyncWorker(callback), interpreter_(interpreter), session_(session) {}
|
||||||
|
|
||||||
|
~MNNSessionRunWorker()
|
||||||
|
{
|
||||||
|
interpreter_->releaseSession(session_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Execute()
|
||||||
|
{
|
||||||
|
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
|
||||||
|
SetError(std::string("Run session failed"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnOK()
|
||||||
|
{
|
||||||
|
if (HasError()) {
|
||||||
|
Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()});
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto result = Object::New(Env());
|
||||||
|
for (auto it : interpreter_->getSessionOutputAll(session_)) {
|
||||||
|
auto tensor = it.second;
|
||||||
|
auto item = Object::New(Env());
|
||||||
|
auto buffer = ArrayBuffer::New(Env(), tensor->size());
|
||||||
|
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
|
||||||
|
item.Set("data", buffer);
|
||||||
|
item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape()));
|
||||||
|
result.Set(it.first, item);
|
||||||
|
}
|
||||||
|
Callback().Call({Env().Undefined(), result});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetInput(const std::string &name, TensorDataType dataType, const std::vector<int> &shape, void *data, size_t dataBytes)
|
||||||
|
{
|
||||||
|
auto tensor = interpreter_->getSessionInput(session_, name.c_str());
|
||||||
|
if (!tensor) {
|
||||||
|
SetError(std::string("input name #" + name + " not exists"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
halide_type_t type = tensor->getType();
|
||||||
|
if (dataType != TensorDataType::Unknown) {
|
||||||
|
auto it = DATA_TYPE_MAP.find(dataType);
|
||||||
|
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shape.size()) {
|
||||||
|
interpreter_->resizeTensor(tensor, shape);
|
||||||
|
interpreter_->resizeSession(session_);
|
||||||
|
}
|
||||||
|
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
|
||||||
|
if (tensorBytes != dataBytes) {
|
||||||
|
SetError(std::string("input name #" + name + " data size not matched"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto hostTensor = MNN::Tensor::create(tensor->shape(), type, data, MNN::Tensor::CAFFE);
|
||||||
|
tensor->copyFromHostTensor(hostTensor);
|
||||||
|
delete hostTensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void SetError(const std::string &what) { errorMessage_ = what; }
|
||||||
|
inline bool HasError() { return errorMessage_.size() > 0; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
MNN::Interpreter *interpreter_;
|
||||||
|
MNN::Session *session_;
|
||||||
|
std::string errorMessage_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MNNSession : public ObjectWrap<MNNSession> {
|
||||||
|
public:
|
||||||
|
static Napi::Object Init(Napi::Env env, Napi::Object exports)
|
||||||
|
{
|
||||||
|
Function func = DefineClass(env, "MNNSession", {
|
||||||
|
SESSION_INSTANCE_METHOD(GetInputsInfo),
|
||||||
|
SESSION_INSTANCE_METHOD(GetOutputsInfo),
|
||||||
|
SESSION_INSTANCE_METHOD(Run),
|
||||||
|
});
|
||||||
|
FunctionReference *constructor = new FunctionReference();
|
||||||
|
*constructor = Napi::Persistent(func);
|
||||||
|
exports.Set("MNNSession", func);
|
||||||
|
env.SetInstanceData<FunctionReference>(constructor);
|
||||||
|
return exports;
|
||||||
|
}
|
||||||
|
|
||||||
|
MNNSession(const CallbackInfo &info)
|
||||||
|
: ObjectWrap(info)
|
||||||
|
{
|
||||||
|
try {
|
||||||
|
if (info[0].IsString()) {
|
||||||
|
interpreter_ = MNN::Interpreter::createFromFile(info[0].As<String>().Utf8Value().c_str());
|
||||||
|
}
|
||||||
|
else if (info[0].IsTypedArray()) {
|
||||||
|
size_t bufferBytes;
|
||||||
|
auto buffer = dataFromTypedArray(info[0], bufferBytes);
|
||||||
|
interpreter_ = MNN::Interpreter::createFromBuffer(buffer, bufferBytes);
|
||||||
|
}
|
||||||
|
else interpreter_ = nullptr;
|
||||||
|
|
||||||
|
if (interpreter_) {
|
||||||
|
backendConfig_.precision = MNN::BackendConfig::Precision_High;
|
||||||
|
backendConfig_.power = MNN::BackendConfig::Power_High;
|
||||||
|
scheduleConfig_.type = MNN_FORWARD_CPU;
|
||||||
|
scheduleConfig_.numThread = 1;
|
||||||
|
scheduleConfig_.backendConfig = &backendConfig_;
|
||||||
|
session_ = interpreter_->createSession(scheduleConfig_);
|
||||||
|
}
|
||||||
|
else session_ = nullptr;
|
||||||
|
}
|
||||||
|
catch (std::exception &e) {
|
||||||
|
Error::New(info.Env(), e.what()).ThrowAsJavaScriptException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~MNNSession() {}
|
||||||
|
|
||||||
|
Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionInputAll(session_)); }
|
||||||
|
|
||||||
|
Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionOutputAll(session_)); }
|
||||||
|
|
||||||
|
Napi::Value Run(const Napi::CallbackInfo &info)
|
||||||
|
{
|
||||||
|
auto worker = new MNNSessionRunWorker(info[1].As<Function>(), interpreter_, interpreter_->createSession(scheduleConfig_));
|
||||||
|
auto inputArgument = info[0].As<Object>();
|
||||||
|
for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) {
|
||||||
|
auto name = (*it).first.As<String>().Utf8Value();
|
||||||
|
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
|
||||||
|
auto type = inputOption.Has("type") ? static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value()) : TensorDataType::Unknown;
|
||||||
|
size_t dataByteLen;
|
||||||
|
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
|
||||||
|
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : std::vector<int>();
|
||||||
|
worker->SetInput(name, type, shape, data, dataByteLen);
|
||||||
|
}
|
||||||
|
worker->Queue();
|
||||||
|
return info.Env().Undefined();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Napi::Object BuildInputOutputInfo(Napi::Env env, const std::map<std::string, MNN::Tensor *> &tensors)
|
||||||
|
{
|
||||||
|
auto result = Object::New(env);
|
||||||
|
for (auto it : tensors) {
|
||||||
|
auto item = Object::New(env);
|
||||||
|
auto name = it.first;
|
||||||
|
auto shape = it.second->shape();
|
||||||
|
auto type = it.second->getType();
|
||||||
|
TensorDataType dataType = TensorDataType::Unknown;
|
||||||
|
for (auto dt : DATA_TYPE_MAP) {
|
||||||
|
if (dt.second == type) {
|
||||||
|
dataType = dt.first;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto shapeArr = Array::New(env, shape.size());
|
||||||
|
for (size_t i = 0; i < shape.size(); i++) {
|
||||||
|
shapeArr.Set(i, Number::New(env, shape[i]));
|
||||||
|
}
|
||||||
|
item.Set("name", String::New(env, name));
|
||||||
|
item.Set("shape", shapeArr);
|
||||||
|
item.Set("type", Number::New(env, static_cast<int>(dataType)));
|
||||||
|
result.Set(name, item);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> GetShapeFromJavascript(const Napi::Array &shape)
|
||||||
|
{
|
||||||
|
std::vector<int> result;
|
||||||
|
for (size_t i = 0; i < shape.Length(); i++) {
|
||||||
|
result.push_back(shape.Get(i).As<Number>().Int32Value());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
MNN::Interpreter *interpreter_;
|
||||||
|
MNN::Session *session_;
|
||||||
|
MNN::BackendConfig backendConfig_;
|
||||||
|
MNN::ScheduleConfig scheduleConfig_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void InstallMNNAPI(Napi::Env env, Napi::Object exports)
|
||||||
|
{
|
||||||
|
MNNSession::Init(env, exports);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if defined(USE_MNN) && !defined(BUILD_MAIN_WORD)
|
||||||
|
static Object Init(Env env, Object exports)
|
||||||
|
{
|
||||||
|
InstallMNNAPI(env, exports);
|
||||||
|
return exports;
|
||||||
|
}
|
||||||
|
NODE_API_MODULE(addon, Init)
|
||||||
|
#endif
|
8
cxx/mnn/node.h
Normal file
8
cxx/mnn/node.h
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
#ifndef __MNN_NODE_H__
|
||||||
|
#define __MNN_NODE_H__
|
||||||
|
|
||||||
|
#include "common/node.h"
|
||||||
|
|
||||||
|
void InstallMNNAPI(Napi::Env env, Napi::Object exports);
|
||||||
|
|
||||||
|
#endif
|
@ -1,19 +0,0 @@
|
|||||||
#ifndef __MNN_SESSION_H__
|
|
||||||
#define __MNN_SESSION_H__
|
|
||||||
|
|
||||||
#include <MNN/Interpreter.hpp>
|
|
||||||
#include <MNN/ImageProcess.hpp>
|
|
||||||
|
|
||||||
#include "common/session.h"
|
|
||||||
|
|
||||||
namespace ai
|
|
||||||
{
|
|
||||||
class MNNSession : public Session
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
MNNSession(const void *modelData, size_t size);
|
|
||||||
~MNNSession();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
81
cxx/node.cc
81
cxx/node.cc
@ -1,65 +1,22 @@
|
|||||||
// #include <unistd.h>
|
#include "common/node.h"
|
||||||
// #include <napi.h>
|
#include "mnn/node.h"
|
||||||
// #include "cv/node.h"
|
#include "ort/node.h"
|
||||||
// #ifdef USE_ORT
|
|
||||||
// #include "ort/node.h"
|
|
||||||
// #endif
|
|
||||||
|
|
||||||
// using namespace Napi;
|
using namespace Napi;
|
||||||
|
|
||||||
// class TestWork : public AsyncWorker
|
#if defined(BUILD_MAIN_WORD)
|
||||||
// {
|
Object Init(Env env, Object exports)
|
||||||
// public:
|
{
|
||||||
// TestWork(const Napi::Function &callback, int value) : Napi::AsyncWorker(callback), val_(value) {}
|
// OnnxRuntime
|
||||||
// ~TestWork() {}
|
#ifdef USE_ONNXRUNTIME
|
||||||
|
InstallOrtAPI(env, exports);
|
||||||
|
#endif
|
||||||
|
// MNN
|
||||||
|
#ifdef USE_MNN
|
||||||
|
InstallMNNAPI(env, exports);
|
||||||
|
#endif
|
||||||
|
|
||||||
// void Execute()
|
return exports;
|
||||||
// {
|
}
|
||||||
// printf("the worker-thread doing! %d \n", val_);
|
NODE_API_MODULE(addon, Init)
|
||||||
// sleep(3);
|
#endif
|
||||||
// printf("the worker-thread done! %d \n", val_);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// void OnOK()
|
|
||||||
// {
|
|
||||||
// Callback().Call({Env().Undefined(), Number::New(Env(), 0)});
|
|
||||||
// }
|
|
||||||
|
|
||||||
// private:
|
|
||||||
// int val_;
|
|
||||||
// };
|
|
||||||
|
|
||||||
// Value test(const CallbackInfo &info)
|
|
||||||
// {
|
|
||||||
// // ai::ORTSession(nullptr, 0);
|
|
||||||
|
|
||||||
// // Function callback = info[1].As<Function>();
|
|
||||||
// // TestWork *work = new TestWork(callback, info[0].As<Number>().Int32Value());
|
|
||||||
// // work->Queue();
|
|
||||||
// return info.Env().Undefined();
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Object Init(Env env, Object exports)
|
|
||||||
// {
|
|
||||||
// //OpenCV
|
|
||||||
// NODE_INIT_OBJECT(cv, InstallOpenCVAPI);
|
|
||||||
// //OnnxRuntime
|
|
||||||
// #ifdef USE_ORT
|
|
||||||
// NODE_INIT_OBJECT(ort, InstallOrtAPI);
|
|
||||||
// #endif
|
|
||||||
|
|
||||||
// Napi::Number::New(env, 0);
|
|
||||||
|
|
||||||
// #define ADD_FUNCTION(name) exports.Set(Napi::String::New(env, #name), Napi::Function::New(env, name))
|
|
||||||
// // ADD_FUNCTION(facedetPredict);
|
|
||||||
// // ADD_FUNCTION(facedetRelease);
|
|
||||||
|
|
||||||
// // ADD_FUNCTION(faceRecognitionCreate);
|
|
||||||
// // ADD_FUNCTION(faceRecognitionPredict);
|
|
||||||
// // ADD_FUNCTION(faceRecognitionRelease);
|
|
||||||
|
|
||||||
// // ADD_FUNCTION(getDistance);
|
|
||||||
// #undef ADD_FUNCTION
|
|
||||||
// return exports;
|
|
||||||
// }
|
|
||||||
// NODE_API_MODULE(addon, Init)
|
|
100
cxx/ort/node.cc
100
cxx/ort/node.cc
@ -2,6 +2,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <onnxruntime_cxx_api.h>
|
#include <onnxruntime_cxx_api.h>
|
||||||
#include "node.h"
|
#include "node.h"
|
||||||
|
#include "common/tensor.h"
|
||||||
|
|
||||||
#ifdef WIN32
|
#ifdef WIN32
|
||||||
#include <locale>
|
#include <locale>
|
||||||
@ -13,51 +14,31 @@ using namespace Napi;
|
|||||||
|
|
||||||
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
||||||
|
|
||||||
static ONNXTensorElementDataType getDataTypeFromString(const std::string &name)
|
static const std::map<TensorDataType, ONNXTensorElementDataType> DATA_TYPE_MAP = {
|
||||||
{
|
{TensorDataType::Float32, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
|
||||||
static const std::map<std::string, ONNXTensorElementDataType> dataTypeNameMap = {
|
{TensorDataType::Float64, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
|
||||||
{"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
|
{TensorDataType::Int32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
|
||||||
{"float", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
|
{TensorDataType::Uint32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
|
||||||
{"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
|
{TensorDataType::Int16, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
|
||||||
{"double", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
|
{TensorDataType::Uint16, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
|
||||||
{"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
|
{TensorDataType::Int8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
|
||||||
{"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
|
{TensorDataType::Uint8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
|
||||||
{"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
|
{TensorDataType::Int64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
|
||||||
{"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
|
{TensorDataType::Uint64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
|
||||||
{"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
|
};
|
||||||
{"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
|
|
||||||
{"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
|
|
||||||
{"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
|
|
||||||
};
|
|
||||||
auto it = dataTypeNameMap.find(name);
|
|
||||||
return (it == dataTypeNameMap.end()) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED : it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t getDataTypeSize(ONNXTensorElementDataType type)
|
static const std::map<ONNXTensorElementDataType, size_t> DATA_TYPE_SIZE_MAP = {
|
||||||
{
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
|
||||||
static const std::map<ONNXTensorElementDataType, size_t> dataTypeSizeMap = {
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
|
};
|
||||||
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
|
|
||||||
};
|
|
||||||
auto it = dataTypeSizeMap.find(type);
|
|
||||||
return (it == dataTypeSizeMap.end()) ? 0 : it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
|
|
||||||
{
|
|
||||||
auto arr = val.As<TypedArray>();
|
|
||||||
auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
|
|
||||||
bytes = arr.ByteLength();
|
|
||||||
return static_cast<void *>(data + arr.ByteOffset());
|
|
||||||
}
|
|
||||||
|
|
||||||
class OrtSessionNodeInfo {
|
class OrtSessionNodeInfo {
|
||||||
public:
|
public:
|
||||||
@ -67,7 +48,22 @@ class OrtSessionNodeInfo {
|
|||||||
inline const std::string &GetName() const { return name_; }
|
inline const std::string &GetName() const { return name_; }
|
||||||
inline const std::vector<int64_t> &GetShape() const { return shape_; }
|
inline const std::vector<int64_t> &GetShape() const { return shape_; }
|
||||||
inline ONNXTensorElementDataType GetType() const { return type_; }
|
inline ONNXTensorElementDataType GetType() const { return type_; }
|
||||||
inline size_t GetElementSize() const { return getDataTypeSize(type_); }
|
inline size_t GetElementSize() const
|
||||||
|
{
|
||||||
|
auto it = DATA_TYPE_SIZE_MAP.find(type_);
|
||||||
|
return (it == DATA_TYPE_SIZE_MAP.end()) ? 0 : it->second;
|
||||||
|
}
|
||||||
|
TensorDataType GetDataType() const
|
||||||
|
{
|
||||||
|
auto datatype = TensorDataType::Unknown;
|
||||||
|
for (auto it : DATA_TYPE_MAP) {
|
||||||
|
if (it.second == type_) {
|
||||||
|
datatype = it.first;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return datatype;
|
||||||
|
}
|
||||||
size_t GetElementCount() const
|
size_t GetElementCount() const
|
||||||
{
|
{
|
||||||
if (!shape_.size()) return 0;
|
if (!shape_.size()) return 0;
|
||||||
@ -115,7 +111,8 @@ class OrtSessionRunWorker : public AsyncWorker {
|
|||||||
for (int i = 0; i < outputNames_.size(); ++i) {
|
for (int i = 0; i < outputNames_.size(); ++i) {
|
||||||
size_t bytes = outputElementBytes_[i];
|
size_t bytes = outputElementBytes_[i];
|
||||||
Ort::Value &value = outputValues_[i];
|
Ort::Value &value = outputValues_[i];
|
||||||
auto buffer = ArrayBuffer::New(Env(), value.GetTensorMutableRawData(), bytes);
|
auto buffer = ArrayBuffer::New(Env(), bytes);
|
||||||
|
memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes);
|
||||||
result.Set(String::New(Env(), outputNames_[i]), buffer);
|
result.Set(String::New(Env(), outputNames_[i]), buffer);
|
||||||
}
|
}
|
||||||
Callback().Call({Env().Undefined(), result});
|
Callback().Call({Env().Undefined(), result});
|
||||||
@ -236,7 +233,12 @@ class OrtSession : public ObjectWrap<OrtSession> {
|
|||||||
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
|
auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
|
||||||
if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name)));
|
if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name)));
|
||||||
else {
|
else {
|
||||||
auto type = inputOption.Has("type") ? getDataTypeFromString(inputOption.Get("type").As<String>().Utf8Value()) : input->GetType();
|
auto type = input->GetType();
|
||||||
|
if (inputOption.Has("type")) {
|
||||||
|
auto t = static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value());
|
||||||
|
auto it = DATA_TYPE_MAP.find(t);
|
||||||
|
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
||||||
|
}
|
||||||
size_t dataByteLen;
|
size_t dataByteLen;
|
||||||
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
|
void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
|
||||||
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
|
auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
|
||||||
@ -279,7 +281,7 @@ class OrtSession : public ObjectWrap<OrtSession> {
|
|||||||
auto &node = *it.second;
|
auto &node = *it.second;
|
||||||
auto item = Object::New(env);
|
auto item = Object::New(env);
|
||||||
item.Set(String::New(env, "name"), String::New(env, node.GetName()));
|
item.Set(String::New(env, "name"), String::New(env, node.GetName()));
|
||||||
item.Set(String::New(env, "type"), Number::New(env, node.GetType()));
|
item.Set(String::New(env, "type"), Number::New(env, static_cast<int>(node.GetDataType())));
|
||||||
auto &shapeVec = node.GetShape();
|
auto &shapeVec = node.GetShape();
|
||||||
auto shape = Array::New(env, shapeVec.size());
|
auto shape = Array::New(env, shapeVec.size());
|
||||||
for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i]));
|
for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i]));
|
||||||
@ -303,7 +305,7 @@ void InstallOrtAPI(Napi::Env env, Napi::Object exports)
|
|||||||
OrtSession::Init(env, exports);
|
OrtSession::Init(env, exports);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_ONNXRUNTIME
|
#if defined(USE_ONNXRUNTIME) && !defined(BUILD_MAIN_WORD)
|
||||||
static Object Init(Env env, Object exports)
|
static Object Init(Env env, Object exports)
|
||||||
{
|
{
|
||||||
InstallOrtAPI(env, exports);
|
InstallOrtAPI(env, exports);
|
||||||
|
14
package.json
14
package.json
@ -1,8 +1,11 @@
|
|||||||
{
|
{
|
||||||
"name": "ai-box",
|
"name": "@yizhi/ai",
|
||||||
"version": "1.0.0",
|
"version": "1.0.6",
|
||||||
"main": "index.js",
|
"releaseVersion": "1.0.6",
|
||||||
|
"main": "dist/index.js",
|
||||||
|
"types": "typing/index.d.ts",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
"build": "rm -rf dist typing && tsc",
|
||||||
"watch": "tsc -w --inlineSourceMap"
|
"watch": "tsc -w --inlineSourceMap"
|
||||||
},
|
},
|
||||||
"keywords": [],
|
"keywords": [],
|
||||||
@ -15,5 +18,8 @@
|
|||||||
"compressing": "^1.10.1",
|
"compressing": "^1.10.1",
|
||||||
"node-addon-api": "^8.3.1",
|
"node-addon-api": "^8.3.1",
|
||||||
"unbzip2-stream": "^1.4.3"
|
"unbzip2-stream": "^1.4.3"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@yizhi/cv": "^1.0.2"
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -2,22 +2,41 @@
|
|||||||
|
|
||||||
export interface SessionNodeInfo {
|
export interface SessionNodeInfo {
|
||||||
name: string
|
name: string
|
||||||
type: number
|
type: DataType
|
||||||
shape: number[]
|
shape: number[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
|
export type DataTypeString = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
|
||||||
|
|
||||||
|
export enum DataType {
|
||||||
|
Unknown,
|
||||||
|
Float32,
|
||||||
|
Float64,
|
||||||
|
Int32,
|
||||||
|
Uint32,
|
||||||
|
Int16,
|
||||||
|
Uint16,
|
||||||
|
Int8,
|
||||||
|
Uint8,
|
||||||
|
Int64,
|
||||||
|
Uint64,
|
||||||
|
}
|
||||||
|
|
||||||
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
|
export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
|
||||||
|
|
||||||
export interface SessionRunInputOption {
|
export interface SessionRunInputOption {
|
||||||
type?: SessionNodeType
|
type?: DataTypeString
|
||||||
data: SessionNodeData
|
data: SessionNodeData
|
||||||
shape?: number[]
|
shape?: number[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface SessionRunOutput {
|
||||||
|
shape: number[]
|
||||||
|
data: Float32Array
|
||||||
|
}
|
||||||
|
|
||||||
export abstract class CommonSession {
|
export abstract class CommonSession {
|
||||||
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>>
|
public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>>
|
||||||
|
|
||||||
public abstract get inputs(): Record<string, SessionNodeInfo>;
|
public abstract get inputs(): Record<string, SessionNodeInfo>;
|
||||||
public abstract get outputs(): Record<string, SessionNodeInfo>;
|
public abstract get outputs(): Record<string, SessionNodeInfo>;
|
||||||
@ -26,3 +45,36 @@ export abstract class CommonSession {
|
|||||||
export function isTypedArray(val: any): val is SessionNodeData {
|
export function isTypedArray(val: any): val is SessionNodeData {
|
||||||
return val?.buffer instanceof ArrayBuffer;
|
return val?.buffer instanceof ArrayBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function dataTypeFrom(str: DataTypeString) {
|
||||||
|
return {
|
||||||
|
"float32": DataType.Float32,
|
||||||
|
"float64": DataType.Float64,
|
||||||
|
"float": DataType.Float32,
|
||||||
|
"double": DataType.Float64,
|
||||||
|
"int32": DataType.Int32,
|
||||||
|
"uint32": DataType.Uint32,
|
||||||
|
"int16": DataType.Int16,
|
||||||
|
"uint16": DataType.Uint16,
|
||||||
|
"int8": DataType.Int8,
|
||||||
|
"uint8": DataType.Uint8,
|
||||||
|
"int64": DataType.Int64,
|
||||||
|
"uint64": DataType.Uint64,
|
||||||
|
}[str] ?? DataType.Unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function dataTypeToString(type: DataType): DataTypeString | null {
|
||||||
|
switch (type) {
|
||||||
|
case DataType.Float32: return "float32";
|
||||||
|
case DataType.Float64: return "float64";
|
||||||
|
case DataType.Int32: return "int32";
|
||||||
|
case DataType.Uint32: return "uint32";
|
||||||
|
case DataType.Int16: return "int16";
|
||||||
|
case DataType.Uint16: return "uint16";
|
||||||
|
case DataType.Int8: return "int8";
|
||||||
|
case DataType.Uint8: return "uint8";
|
||||||
|
case DataType.Int64: return "int64";
|
||||||
|
case DataType.Uint64: return "uint64";
|
||||||
|
default: return null;
|
||||||
|
}
|
||||||
|
}
|
@ -1 +1 @@
|
|||||||
export * as backend from "./main";
|
export * as backend from "./main";
|
||||||
|
@ -1,2 +1,3 @@
|
|||||||
export * as common from "./common";
|
export { SessionNodeInfo, DataTypeString, DataType, SessionNodeData, SessionRunInputOption, SessionRunOutput, CommonSession } from "./common";
|
||||||
export * as ort from "./ort";
|
export * as ort from "./ort";
|
||||||
|
export * as mnn from "./mnn";
|
1
src/backend/mnn/index.ts
Normal file
1
src/backend/mnn/index.ts
Normal file
@ -0,0 +1 @@
|
|||||||
|
export { MNNSession as Session } from "./session";
|
34
src/backend/mnn/session.ts
Normal file
34
src/backend/mnn/session.ts
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import { getConfig } from "../../config";
|
||||||
|
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
|
||||||
|
|
||||||
|
export class MNNSession extends CommonSession {
|
||||||
|
#session: any
|
||||||
|
#inputs: Record<string, SessionNodeInfo> | null = null;
|
||||||
|
#outputs: Record<string, SessionNodeInfo> | null = null;
|
||||||
|
|
||||||
|
public constructor(modelData: Uint8Array) {
|
||||||
|
super();
|
||||||
|
const addon = require(getConfig("MNN_ADDON_FILE"));
|
||||||
|
this.#session = new addon.MNNSession(modelData);
|
||||||
|
}
|
||||||
|
|
||||||
|
public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, SessionRunOutput>> {
|
||||||
|
const inputArgs: Record<string, any> = {};
|
||||||
|
for (const [name, option] of Object.entries(inputs)) {
|
||||||
|
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||||
|
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||||
|
}
|
||||||
|
return new Promise((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, { data: ArrayBuffer, shape: number[] }>) => {
|
||||||
|
if (err) return reject(err);
|
||||||
|
const result: Record<string, SessionRunOutput> = {};
|
||||||
|
for (const [name, val] of Object.entries(res)) result[name] = {
|
||||||
|
shape: val.shape,
|
||||||
|
data: new Float32Array(val.data),
|
||||||
|
}
|
||||||
|
resolve(result);
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
|
||||||
|
public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
|
||||||
|
|
||||||
|
}
|
@ -1,4 +1,5 @@
|
|||||||
import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
|
import { getConfig } from "../../config";
|
||||||
|
import { CommonSession, dataTypeFrom, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption, SessionRunOutput } from "../common";
|
||||||
|
|
||||||
export class OrtSession extends CommonSession {
|
export class OrtSession extends CommonSession {
|
||||||
#session: any;
|
#session: any;
|
||||||
@ -7,7 +8,7 @@ export class OrtSession extends CommonSession {
|
|||||||
|
|
||||||
public constructor(modelData: Uint8Array) {
|
public constructor(modelData: Uint8Array) {
|
||||||
super();
|
super();
|
||||||
const addon = require("../../../build/ort.node")
|
const addon = require(getConfig("ORT_ADDON_FILE"));
|
||||||
this.#session = new addon.OrtSession(modelData);
|
this.#session = new addon.OrtSession(modelData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -19,13 +20,13 @@ export class OrtSession extends CommonSession {
|
|||||||
const inputArgs: Record<string, any> = {};
|
const inputArgs: Record<string, any> = {};
|
||||||
for (const [name, option] of Object.entries(inputs)) {
|
for (const [name, option] of Object.entries(inputs)) {
|
||||||
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
if (isTypedArray(option)) inputArgs[name] = { data: option }
|
||||||
else inputArgs[name] = option;
|
else inputArgs[name] = { ...option, type: option.type ? dataTypeFrom(option.type) : undefined };
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
|
return new Promise<Record<string, SessionRunOutput>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: Record<string, ArrayBuffer>) => {
|
||||||
if (err) return reject(err);
|
if (err) return reject(err);
|
||||||
const result: Record<string, Float32Array> = {};
|
const result: Record<string, SessionRunOutput> = {};
|
||||||
for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
|
for (const [name, val] of Object.entries(res)) result[name] = { data: new Float32Array(val), shape: this.outputs[name].shape };
|
||||||
resolve(result);
|
resolve(result);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
12
src/config.ts
Normal file
12
src/config.ts
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import path from "path";
|
||||||
|
|
||||||
|
const defaultAddonDir = path.join(__dirname, "../build")
|
||||||
|
|
||||||
|
const aiConfig = {
|
||||||
|
"MNN_ADDON_FILE": path.join(defaultAddonDir, "mnn.node"),
|
||||||
|
"ORT_ADDON_FILE": path.join(defaultAddonDir, "ort.node"),
|
||||||
|
};
|
||||||
|
|
||||||
|
export function setConfig<K extends keyof typeof aiConfig>(key: K, value: typeof aiConfig[K]) { aiConfig[key] = value; }
|
||||||
|
|
||||||
|
export function getConfig<K extends keyof typeof aiConfig>(key: K): typeof aiConfig[K] { return aiConfig[key]; }
|
@ -1 +0,0 @@
|
|||||||
export * as cv from "./main";
|
|
@ -1 +0,0 @@
|
|||||||
export { Mat, ImreadModes } from "./mat";
|
|
@ -1,79 +0,0 @@
|
|||||||
|
|
||||||
export enum ImreadModes {
|
|
||||||
IMREAD_UNCHANGED = -1,
|
|
||||||
IMREAD_GRAYSCALE = 0,
|
|
||||||
IMREAD_COLOR_BGR = 1,
|
|
||||||
IMREAD_COLOR = 1,
|
|
||||||
IMREAD_ANYDEPTH = 2,
|
|
||||||
IMREAD_ANYCOLOR = 4,
|
|
||||||
IMREAD_LOAD_GDAL = 8,
|
|
||||||
IMREAD_REDUCED_GRAYSCALE_2 = 16,
|
|
||||||
IMREAD_REDUCED_COLOR_2 = 17,
|
|
||||||
IMREAD_REDUCED_GRAYSCALE_4 = 32,
|
|
||||||
IMREAD_REDUCED_COLOR_4 = 33,
|
|
||||||
IMREAD_REDUCED_GRAYSCALE_8 = 64,
|
|
||||||
IMREAD_REDUCED_COLOR_8 = 65,
|
|
||||||
IMREAD_IGNORE_ORIENTATION = 128,
|
|
||||||
IMREAD_COLOR_RGB = 256,
|
|
||||||
};
|
|
||||||
|
|
||||||
interface MatConstructorOption {
|
|
||||||
mode?: ImreadModes;
|
|
||||||
}
|
|
||||||
|
|
||||||
export class Mat {
|
|
||||||
#mat: any
|
|
||||||
|
|
||||||
public static async load(image: string, option?: MatConstructorOption) {
|
|
||||||
let buffer: Uint8Array
|
|
||||||
if (/^https?:\/\//.test(image)) buffer = await fetch(image).then(res => res.arrayBuffer()).then(res => new Uint8Array(res));
|
|
||||||
else buffer = await import("fs").then(fs => fs.promises.readFile(image));
|
|
||||||
return new Mat(buffer, option);
|
|
||||||
}
|
|
||||||
|
|
||||||
public constructor(imageData: Uint8Array, option?: MatConstructorOption) {
|
|
||||||
const addon = require("../../build/cv.node");
|
|
||||||
if ((imageData as any) instanceof addon.Mat) this.#mat = imageData;
|
|
||||||
else this.#mat = new addon.Mat(imageData, option);
|
|
||||||
}
|
|
||||||
|
|
||||||
public get empty(): boolean { return this.#mat.IsEmpty() }
|
|
||||||
|
|
||||||
public get cols(): number { return this.#mat.GetCols(); }
|
|
||||||
|
|
||||||
public get rows(): number { return this.#mat.GetRows(); }
|
|
||||||
|
|
||||||
public get width() { return this.cols; }
|
|
||||||
|
|
||||||
public get height() { return this.rows; }
|
|
||||||
|
|
||||||
public get channels() { return this.#mat.GetChannels(); }
|
|
||||||
|
|
||||||
public resize(width: number, height: number) { return new Mat(this.#mat.Resize.bind(this.#mat)(width, height)); }
|
|
||||||
|
|
||||||
public crop(sx: number, sy: number, sw: number, sh: number) { return new Mat(this.#mat.Crop(sx, sy, sw, sh)); }
|
|
||||||
|
|
||||||
public rotate(sx: number, sy: number, angleDeg: number) { return new Mat(this.#mat.Rotate(sx, sy, angleDeg)); }
|
|
||||||
|
|
||||||
public get data() { return new Uint8Array(this.#mat.Data()); }
|
|
||||||
|
|
||||||
public encode(extname: string) { return new Uint8Array(this.#mat.Encode({ extname })); }
|
|
||||||
|
|
||||||
public clone() { return new Mat(this.#mat.Clone()); }
|
|
||||||
|
|
||||||
public circle(x: number, y: number, radius: number, options?: {
|
|
||||||
color?: { r: number, g: number, b: number },
|
|
||||||
thickness?: number
|
|
||||||
lineType?: number
|
|
||||||
}) {
|
|
||||||
this.#mat.DrawCircle(
|
|
||||||
x, y, radius,
|
|
||||||
options?.color?.b ?? 0,
|
|
||||||
options?.color?.g ?? 0,
|
|
||||||
options?.color?.r ?? 0,
|
|
||||||
options?.thickness ?? 1,
|
|
||||||
options?.lineType ?? 8,
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,13 +1,13 @@
|
|||||||
|
import { cv } from "@yizhi/cv";
|
||||||
import { backend } from "../../backend";
|
import { backend } from "../../backend";
|
||||||
import { cv } from "../../cv";
|
|
||||||
|
|
||||||
export type ModelConstructor<T> = new (session: backend.common.CommonSession) => T;
|
export type ModelConstructor<T> = new (session: backend.CommonSession) => T;
|
||||||
|
|
||||||
export type ImageSource = cv.Mat | Uint8Array | string;
|
export type ImageSource = cv.Mat | Uint8Array | string;
|
||||||
|
|
||||||
export interface ImageCropOption {
|
export interface ImageCropOption {
|
||||||
/** 图片裁剪区域 */
|
/** 图片裁剪区域 */
|
||||||
crop?: { sx: number, sy: number, sw: number, sh: number }
|
crop?: { x: number, y: number, width: number, height: number }
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ModelType = "onnx" | "mnn"
|
export type ModelType = "onnx" | "mnn"
|
||||||
@ -24,15 +24,17 @@ export interface ModelCacheResult<T, Create extends boolean> {
|
|||||||
model: Create extends true ? T : never
|
model: Create extends true ? T : never
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const cacheDownloadTasks: Array<{ url: string, cacheDir: string, modelPath: Promise<string> }> = [];
|
||||||
|
|
||||||
export abstract class Model {
|
export abstract class Model {
|
||||||
protected session: backend.common.CommonSession;
|
protected session: backend.CommonSession;
|
||||||
|
|
||||||
protected static async resolveImage<R>(image: ImageSource, resolver: (image: cv.Mat) => R | Promise<R>): Promise<R> {
|
protected static async resolveImage<R>(image: ImageSource, resolver: (image: cv.Mat) => R | Promise<R>): Promise<R> {
|
||||||
if (typeof image === "string") {
|
if (typeof image === "string") {
|
||||||
if (/^https?:\/\//.test(image)) image = await fetch(image).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
|
if (/^https?:\/\//.test(image)) image = await fetch(image).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
|
||||||
else image = await import("fs").then(fs => fs.promises.readFile(image as string));
|
else image = await import("fs").then(fs => fs.promises.readFile(image as string));
|
||||||
}
|
}
|
||||||
if (image instanceof Uint8Array) image = new cv.Mat(image, { mode: cv.ImreadModes.IMREAD_COLOR_BGR })
|
if (image instanceof Uint8Array) image = cv.imdecode(image, cv.IMREAD_COLOR_BGR);
|
||||||
if (image instanceof cv.Mat) return await resolver(image);
|
if (image instanceof cv.Mat) return await resolver(image);
|
||||||
else throw new Error("Invalid image");
|
else throw new Error("Invalid image");
|
||||||
}
|
}
|
||||||
@ -45,78 +47,108 @@ export abstract class Model {
|
|||||||
return new this(new backend.ort.Session(modelData as Uint8Array));
|
return new this(new backend.ort.Session(modelData as Uint8Array));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static async fromMNN<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array | string) {
|
||||||
|
if (typeof modelData === "string") {
|
||||||
|
if (/^https?:\/\//.test(modelData)) modelData = await fetch(modelData).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
|
||||||
|
else modelData = await import("fs").then(fs => fs.promises.readFile(modelData as string));
|
||||||
|
}
|
||||||
|
return new this(new backend.mnn.Session(modelData as Uint8Array));
|
||||||
|
}
|
||||||
|
|
||||||
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
|
protected static async cacheModel<T extends Model, Create extends boolean = false>(this: ModelConstructor<T>, url: string, option?: ModelCacheOption<Create>): Promise<ModelCacheResult<T, Create>> {
|
||||||
//初始化目录
|
//初始化目录
|
||||||
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
const [fs, path, os, crypto] = await Promise.all([import("fs"), import("path"), import("os"), import("crypto")]);
|
||||||
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
|
const cacheDir = option?.cacheDir ?? path.join(os.homedir(), ".aibox_cache/models");
|
||||||
await fs.promises.mkdir(cacheDir, { recursive: true });
|
await fs.promises.mkdir(cacheDir, { recursive: true });
|
||||||
//加载模型配置
|
|
||||||
const cacheJsonFile = path.join(cacheDir, "config.json");
|
|
||||||
let cacheJsonData: Array<{ url: string, filename: string }> = [];
|
|
||||||
if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) {
|
|
||||||
try {
|
|
||||||
cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8"));
|
|
||||||
} catch (e) { console.error(e); }
|
|
||||||
}
|
|
||||||
//不存在则下载
|
|
||||||
let cache = cacheJsonData.find(c => c.url === url);
|
|
||||||
if (!cache) {
|
|
||||||
let saveType = option?.saveType ?? null;
|
|
||||||
const saveTypeDict: Record<string, ModelType> = {
|
|
||||||
".onnx": "onnx",
|
|
||||||
".mnn": "mnn",
|
|
||||||
};
|
|
||||||
const _url = new URL(url);
|
|
||||||
const res = await fetch(_url).then(res => {
|
|
||||||
const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1];
|
|
||||||
if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType;
|
|
||||||
if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx";
|
|
||||||
if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`);
|
|
||||||
return res.blob();
|
|
||||||
}).then(blob => blob.stream()).then(async stream => {
|
|
||||||
const cacheFilename = path.join(cacheDir, Date.now().toString());
|
|
||||||
let fsStream!: ReturnType<typeof fs.createWriteStream>;
|
|
||||||
let hashStream!: ReturnType<typeof crypto.createHash>;
|
|
||||||
let hash!: string;
|
|
||||||
await stream.pipeTo(new WritableStream({
|
|
||||||
start(controller) {
|
|
||||||
fsStream = fs.createWriteStream(cacheFilename);
|
|
||||||
hashStream = crypto.createHash("md5");
|
|
||||||
},
|
|
||||||
async write(chunk, controller) {
|
|
||||||
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
|
|
||||||
await new Promise<void>((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve()));
|
|
||||||
},
|
|
||||||
close() {
|
|
||||||
fsStream.end();
|
|
||||||
hashStream.end();
|
|
||||||
hash = hashStream.digest("hex")
|
|
||||||
},
|
|
||||||
abort() { }
|
|
||||||
}));
|
|
||||||
return { filename: cacheFilename, hash };
|
|
||||||
});
|
|
||||||
//重命名
|
|
||||||
const filename = `${res.hash}.${saveType}`;
|
|
||||||
fs.promises.rename(res.filename, path.join(cacheDir, filename));
|
|
||||||
//保存缓存
|
|
||||||
cache = { url, filename };
|
|
||||||
cacheJsonData.push(cache);
|
|
||||||
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
|
|
||||||
}
|
|
||||||
//返回模型数据
|
|
||||||
const modelPath = path.join(cacheDir, cache.filename);
|
|
||||||
|
|
||||||
const modelType = path.extname(cache.filename).substring(1) as ModelType;
|
//定义函数用于加载模型信息
|
||||||
|
async function resolveModel() {
|
||||||
|
//加载模型配置
|
||||||
|
const cacheJsonFile = path.join(cacheDir, "config.json");
|
||||||
|
let cacheJsonData: Array<{ url: string, filename: string }> = [];
|
||||||
|
if (fs.existsSync(cacheJsonFile) && await fs.promises.stat(cacheJsonFile).then(s => s.isFile())) {
|
||||||
|
try {
|
||||||
|
cacheJsonData = JSON.parse(await fs.promises.readFile(cacheJsonFile, "utf-8"));
|
||||||
|
} catch (e) { console.error(e); }
|
||||||
|
}
|
||||||
|
//不存在则下载
|
||||||
|
let cache = cacheJsonData.find(c => c.url === url);
|
||||||
|
if (!cache || !fs.existsSync(path.join(cacheDir, cache.filename))) {
|
||||||
|
let saveType = option?.saveType ?? null;
|
||||||
|
const saveTypeDict: Record<string, ModelType> = {
|
||||||
|
".onnx": "onnx",
|
||||||
|
".mnn": "mnn",
|
||||||
|
};
|
||||||
|
const _url = new URL(url);
|
||||||
|
const res = await fetch(_url).then(res => {
|
||||||
|
const filename = res.headers.get("content-disposition")?.match(/filename="(.+?)"/)?.[1];
|
||||||
|
if (filename) saveType = saveTypeDict[path.extname(filename)] ?? saveType;
|
||||||
|
if (!saveType) saveType = saveTypeDict[path.extname(_url.pathname)] ?? "onnx";
|
||||||
|
if (res.status !== 200) throw new Error(`HTTP ${res.status} ${res.statusText}`);
|
||||||
|
return res.blob();
|
||||||
|
}).then(blob => blob.stream()).then(async stream => {
|
||||||
|
const cacheFilename = path.join(cacheDir, Date.now().toString());
|
||||||
|
const hash = await new Promise<string>((resolve, reject) => {
|
||||||
|
let fsStream!: ReturnType<typeof fs.createWriteStream>;
|
||||||
|
let hashStream!: ReturnType<typeof crypto.createHash>;
|
||||||
|
stream.pipeTo(new WritableStream({
|
||||||
|
start(controller) {
|
||||||
|
fsStream = fs.createWriteStream(cacheFilename);
|
||||||
|
hashStream = crypto.createHash("md5");
|
||||||
|
},
|
||||||
|
async write(chunk, controller) {
|
||||||
|
await new Promise<void>((resolve, reject) => fsStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||||
|
await new Promise<void>((resolve, reject) => hashStream.write(chunk, err => err ? reject(err) : resolve()));
|
||||||
|
},
|
||||||
|
close() {
|
||||||
|
fsStream.end();
|
||||||
|
hashStream.end();
|
||||||
|
resolve(hashStream.digest("hex"));
|
||||||
|
},
|
||||||
|
abort() { }
|
||||||
|
})).catch(reject);
|
||||||
|
})
|
||||||
|
return { filename: cacheFilename, hash };
|
||||||
|
});
|
||||||
|
//重命名
|
||||||
|
const filename = `${res.hash}.${saveType}`;
|
||||||
|
fs.promises.rename(res.filename, path.join(cacheDir, filename));
|
||||||
|
//保存缓存
|
||||||
|
if (!cache) {
|
||||||
|
cache = { url, filename };
|
||||||
|
cacheJsonData.push(cache);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
cache.filename = filename;
|
||||||
|
cache.url = url;
|
||||||
|
}
|
||||||
|
fs.promises.writeFile(cacheJsonFile, JSON.stringify(cacheJsonData, null, 4));
|
||||||
|
}
|
||||||
|
//返回模型数据
|
||||||
|
return path.join(cacheDir, cache.filename);
|
||||||
|
}
|
||||||
|
|
||||||
|
//查找任务
|
||||||
|
let cache = cacheDownloadTasks.find(c => c.url === url && c.cacheDir === cacheDir);
|
||||||
|
if (!cache) {
|
||||||
|
cache = { url, cacheDir, modelPath: resolveModel() }
|
||||||
|
cacheDownloadTasks.push(cache);
|
||||||
|
}
|
||||||
|
|
||||||
|
//获取模型数据
|
||||||
|
const modelPath = await cache.modelPath;
|
||||||
|
const modelType = path.extname(modelPath).substring(1) as ModelType;
|
||||||
let model: T | undefined = undefined;
|
let model: T | undefined = undefined;
|
||||||
if (option?.createModel) {
|
if (option?.createModel) {
|
||||||
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
if (modelType === "onnx") model = (this as any).fromOnnx(modelPath);
|
||||||
|
else if (modelType == "mnn") model = (this as any).fromMNN(modelPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//返回结果
|
||||||
return { modelPath, modelType, model: model as any }
|
return { modelPath, modelType, model: model as any }
|
||||||
}
|
}
|
||||||
|
|
||||||
public constructor(session: backend.common.CommonSession) { this.session = session; }
|
public constructor(session: backend.CommonSession) { this.session = session; }
|
||||||
|
|
||||||
public get inputs() { return this.session.inputs; }
|
public get inputs() { return this.session.inputs; }
|
||||||
public get outputs() { return this.session.outputs; }
|
public get outputs() { return this.session.outputs; }
|
||||||
|
@ -5,10 +5,41 @@ export interface FacePoint {
|
|||||||
|
|
||||||
type PointType = "leftEye" | "rightEye" | "leftEyebrow" | "rightEyebrow" | "nose" | "mouth" | "contour"
|
type PointType = "leftEye" | "rightEye" | "leftEyebrow" | "rightEyebrow" | "nose" | "mouth" | "contour"
|
||||||
|
|
||||||
|
export function indexFromTo(from: number, to: number) {
|
||||||
|
const indexes: number[] = [];
|
||||||
|
for (let i = from; i <= to; i++) {
|
||||||
|
indexes.push(i);
|
||||||
|
}
|
||||||
|
return indexes;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface PointGroup {
|
||||||
|
/** 用于判断方向的两个点的索引(建议选取眼球中间的点) */
|
||||||
|
directionPointIndex: [number, number];
|
||||||
|
/** 左眼点的索引 */
|
||||||
|
leftEyePointIndex: number[];
|
||||||
|
/** 右眼点的索引 */
|
||||||
|
rightEyePointIndex: number[];
|
||||||
|
/** 左眉点的索引 */
|
||||||
|
leftEyebrowPointIndex: number[];
|
||||||
|
/** 右眉点的索引 */
|
||||||
|
rightEyebrowPointIndex: number[];
|
||||||
|
/** 嘴巴点的索引 */
|
||||||
|
mouthPointIndex: number[];
|
||||||
|
/** 鼻子的索引 */
|
||||||
|
nosePointIndex: number[];
|
||||||
|
/** 轮廓点的索引 */
|
||||||
|
contourPointIndex: number[];
|
||||||
|
}
|
||||||
|
|
||||||
export abstract class FaceAlignmentResult {
|
export abstract class FaceAlignmentResult {
|
||||||
#points: FacePoint[]
|
#points: FacePoint[]
|
||||||
|
#group: PointGroup
|
||||||
|
|
||||||
public constructor(points: FacePoint[]) { this.#points = points; }
|
public constructor(points: FacePoint[], group: PointGroup) {
|
||||||
|
this.#points = points;
|
||||||
|
this.#group = group;
|
||||||
|
}
|
||||||
|
|
||||||
/** 关键点 */
|
/** 关键点 */
|
||||||
public get points() { return this.#points; }
|
public get points() { return this.#points; }
|
||||||
@ -18,7 +49,7 @@ export abstract class FaceAlignmentResult {
|
|||||||
if (typeof type == "string") type = [type];
|
if (typeof type == "string") type = [type];
|
||||||
const result: FacePoint[] = [];
|
const result: FacePoint[] = [];
|
||||||
for (const t of type) {
|
for (const t of type) {
|
||||||
for (const idx of this[`${t}PointIndex` as const]()) {
|
for (const idx of this.#group[`${t}PointIndex` as const]) {
|
||||||
result.push(this.points[idx]);
|
result.push(this.points[idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -27,32 +58,7 @@ export abstract class FaceAlignmentResult {
|
|||||||
|
|
||||||
/** 方向 */
|
/** 方向 */
|
||||||
public get direction() {
|
public get direction() {
|
||||||
const [{ x: x1, y: y1 }, { x: x2, y: y2 }] = this.directionPointIndex().map(idx => this.points[idx]);
|
const [{ x: x1, y: y1 }, { x: x2, y: y2 }] = this.#group.directionPointIndex.map(idx => this.points[idx]);
|
||||||
return Math.atan2(y1 - y2, x2 - x1)
|
return Math.atan2(y1 - y2, x2 - x1)
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 用于判断方向的两个点的索引(建议选取眼球中间的点) */
|
|
||||||
protected abstract directionPointIndex(): [number, number];
|
|
||||||
/** 左眼点的索引 */
|
|
||||||
protected abstract leftEyePointIndex(): number[];
|
|
||||||
/** 右眼点的索引 */
|
|
||||||
protected abstract rightEyePointIndex(): number[];
|
|
||||||
/** 左眉点的索引 */
|
|
||||||
protected abstract leftEyebrowPointIndex(): number[];
|
|
||||||
/** 右眉点的索引 */
|
|
||||||
protected abstract rightEyebrowPointIndex(): number[];
|
|
||||||
/** 嘴巴点的索引 */
|
|
||||||
protected abstract mouthPointIndex(): number[];
|
|
||||||
/** 鼻子的索引 */
|
|
||||||
protected abstract nosePointIndex(): number[];
|
|
||||||
/** 轮廓点的索引 */
|
|
||||||
protected abstract contourPointIndex(): number[];
|
|
||||||
|
|
||||||
protected indexFromTo(from: number, to: number) {
|
|
||||||
const indexes: number[] = [];
|
|
||||||
for (let i = from; i <= to; i++) {
|
|
||||||
indexes.push(i);
|
|
||||||
}
|
|
||||||
return indexes;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -1,25 +1,30 @@
|
|||||||
import { writeFileSync } from "fs";
|
import cv from "@yizhi/cv";
|
||||||
import { cv } from "../../cv";
|
|
||||||
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
||||||
import { convertImage } from "../common/processors";
|
import { convertImage } from "../common/processors";
|
||||||
import { FaceAlignmentResult, FacePoint } from "./common";
|
import { FaceAlignmentResult, FacePoint, indexFromTo } from "./common";
|
||||||
|
|
||||||
interface FaceLandmark1000PredictOption extends ImageCropOption { }
|
interface FaceLandmark1000PredictOption extends ImageCropOption { }
|
||||||
|
|
||||||
class FaceLandmark1000Result extends FaceAlignmentResult {
|
class FaceLandmark1000Result extends FaceAlignmentResult {
|
||||||
protected directionPointIndex(): [number, number] { return [401, 529]; }
|
public constructor(points: FacePoint[]) {
|
||||||
protected leftEyePointIndex(): number[] { return this.indexFromTo(401, 528); }
|
super(points, {
|
||||||
protected rightEyePointIndex(): number[] { return this.indexFromTo(529, 656); }
|
directionPointIndex: [401, 529],
|
||||||
protected leftEyebrowPointIndex(): number[] { return this.indexFromTo(273, 336); }
|
leftEyePointIndex: indexFromTo(401, 528),
|
||||||
protected rightEyebrowPointIndex(): number[] { return this.indexFromTo(337, 400); }
|
rightEyePointIndex: indexFromTo(529, 656),
|
||||||
protected mouthPointIndex(): number[] { return this.indexFromTo(845, 972); }
|
leftEyebrowPointIndex: indexFromTo(273, 336),
|
||||||
protected nosePointIndex(): number[] { return this.indexFromTo(657, 844); }
|
rightEyebrowPointIndex: indexFromTo(337, 400),
|
||||||
protected contourPointIndex(): number[] { return this.indexFromTo(0, 272); }
|
mouthPointIndex: indexFromTo(845, 972),
|
||||||
|
nosePointIndex: indexFromTo(657, 844),
|
||||||
|
contourPointIndex: indexFromTo(0, 272),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`,
|
FACELANDMARK1000_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.onnx`,
|
||||||
|
FACELANDMARK1000_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/FaceLandmark1000.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class FaceLandmark1000 extends Model {
|
export class FaceLandmark1000 extends Model {
|
||||||
@ -32,10 +37,10 @@ export class FaceLandmark1000 extends Model {
|
|||||||
|
|
||||||
public async doPredict(image: cv.Mat, option?: FaceLandmark1000PredictOption) {
|
public async doPredict(image: cv.Mat, option?: FaceLandmark1000PredictOption) {
|
||||||
const input = this.input;
|
const input = this.input;
|
||||||
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
|
if (option?.crop) image = cv.crop(image, option?.crop);
|
||||||
const ratioWidth = image.width / input.shape[3];
|
const ratioWidth = image.cols / input.shape[3];
|
||||||
const ratioHeight = image.height / input.shape[2];
|
const ratioHeight = image.rows / input.shape[2];
|
||||||
image = image.resize(input.shape[3], input.shape[2]);
|
image = cv.resize(image, input.shape[3], input.shape[2]);
|
||||||
|
|
||||||
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "gray", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
|
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "gray", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
|
||||||
|
|
||||||
@ -45,12 +50,12 @@ export class FaceLandmark1000 extends Model {
|
|||||||
data: nchwImageData,
|
data: nchwImageData,
|
||||||
type: "float32",
|
type: "float32",
|
||||||
}
|
}
|
||||||
}).then(res => res[this.output.name]);
|
}).then(res => res[this.output.name].data);
|
||||||
|
|
||||||
const points: FacePoint[] = [];
|
const points: FacePoint[] = [];
|
||||||
for (let i = 0; i < res.length; i += 2) {
|
for (let i = 0; i < res.length; i += 2) {
|
||||||
const x = res[i] * image.width * ratioWidth;
|
const x = res[i] * image.cols * ratioWidth;
|
||||||
const y = res[i + 1] * image.height * ratioHeight;
|
const y = res[i + 1] * image.rows * ratioHeight;
|
||||||
points.push({ x, y });
|
points.push({ x, y });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { writeFileSync } from "fs";
|
import cv from "@yizhi/cv";
|
||||||
import { cv } from "../../cv";
|
|
||||||
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
||||||
import { convertImage } from "../common/processors";
|
import { convertImage } from "../common/processors";
|
||||||
import { FaceAlignmentResult, FacePoint } from "./common";
|
import { FaceAlignmentResult, FacePoint } from "./common";
|
||||||
@ -7,20 +6,27 @@ import { FaceAlignmentResult, FacePoint } from "./common";
|
|||||||
export interface PFLDPredictOption extends ImageCropOption { }
|
export interface PFLDPredictOption extends ImageCropOption { }
|
||||||
|
|
||||||
class PFLDResult extends FaceAlignmentResult {
|
class PFLDResult extends FaceAlignmentResult {
|
||||||
protected directionPointIndex(): [number, number] { return [36, 92]; }
|
public constructor(points: FacePoint[]) {
|
||||||
protected leftEyePointIndex(): number[] { return [33, 34, 35, 36, 37, 38, 39, 40, 41, 42]; }
|
super(points, {
|
||||||
protected rightEyePointIndex(): number[] { return [87, 88, 89, 90, 91, 92, 93, 94, 95, 96]; }
|
directionPointIndex: [36, 92],
|
||||||
protected leftEyebrowPointIndex(): number[] { return [43, 44, 45, 46, 47, 48, 49, 50, 51]; }
|
leftEyePointIndex: [33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
|
||||||
protected rightEyebrowPointIndex(): number[] { return [97, 98, 99, 100, 101, 102, 103, 104, 105]; }
|
rightEyePointIndex: [87, 88, 89, 90, 91, 92, 93, 94, 95, 96],
|
||||||
protected mouthPointIndex(): number[] { return [52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]; }
|
leftEyebrowPointIndex: [43, 44, 45, 46, 47, 48, 49, 50, 51],
|
||||||
protected nosePointIndex(): number[] { return [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86]; }
|
rightEyebrowPointIndex: [97, 98, 99, 100, 101, 102, 103, 104, 105],
|
||||||
protected contourPointIndex(): number[] { return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]; }
|
mouthPointIndex: [52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71],
|
||||||
|
nosePointIndex: [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86],
|
||||||
|
contourPointIndex: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`,
|
PFLD_106_LITE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.onnx`,
|
||||||
PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`,
|
PFLD_106_V2_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.onnx`,
|
||||||
PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`,
|
PFLD_106_V3_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.onnx`,
|
||||||
|
PFLD_106_LITE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-lite.mnn`,
|
||||||
|
PFLD_106_V2_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v2.mnn`,
|
||||||
|
PFLD_106_V3_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facealign/pfld-106-v3.mnn`,
|
||||||
};
|
};
|
||||||
export class PFLD extends Model {
|
export class PFLD extends Model {
|
||||||
|
|
||||||
@ -32,10 +38,10 @@ export class PFLD extends Model {
|
|||||||
|
|
||||||
private async doPredict(image: cv.Mat, option?: PFLDPredictOption) {
|
private async doPredict(image: cv.Mat, option?: PFLDPredictOption) {
|
||||||
const input = this.input;
|
const input = this.input;
|
||||||
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
|
if (option?.crop) image = cv.crop(image, option.crop);
|
||||||
const ratioWidth = image.width / input.shape[3];
|
const ratioWidth = image.cols / input.shape[3];
|
||||||
const ratioHeight = image.height / input.shape[2];
|
const ratioHeight = image.rows / input.shape[2];
|
||||||
image = image.resize(input.shape[3], input.shape[2]);
|
image = cv.resize(image, input.shape[3], input.shape[2]);
|
||||||
|
|
||||||
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [255] } })
|
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [255] } })
|
||||||
|
|
||||||
@ -48,12 +54,12 @@ export class PFLD extends Model {
|
|||||||
shape: [1, 3, input.shape[2], input.shape[3]],
|
shape: [1, 3, input.shape[2], input.shape[3]],
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
const pointsBuffer = res[pointsOutput.name];
|
const pointsBuffer = res[pointsOutput.name].data;
|
||||||
|
|
||||||
const points: FacePoint[] = [];
|
const points: FacePoint[] = [];
|
||||||
for (let i = 0; i < pointsBuffer.length; i += 2) {
|
for (let i = 0; i < pointsBuffer.length; i += 2) {
|
||||||
const x = pointsBuffer[i] * image.width * ratioWidth;
|
const x = pointsBuffer[i] * image.cols * ratioWidth;
|
||||||
const y = pointsBuffer[i + 1] * image.height * ratioHeight;
|
const y = pointsBuffer[i + 1] * image.rows * ratioHeight;
|
||||||
points.push({ x, y });
|
points.push({ x, y });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { cv } from "../../cv";
|
import cv from "@yizhi/cv";
|
||||||
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
||||||
import { convertImage } from "../common/processors";
|
import { convertImage } from "../common/processors";
|
||||||
|
|
||||||
@ -12,6 +12,7 @@ export interface GenderAgePredictResult {
|
|||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
|
INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
|
||||||
|
INSIGHT_GENDER_AGE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class GenderAge extends Model {
|
export class GenderAge extends Model {
|
||||||
@ -24,8 +25,8 @@ export class GenderAge extends Model {
|
|||||||
private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise<GenderAgePredictResult> {
|
private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise<GenderAgePredictResult> {
|
||||||
const input = this.input;
|
const input = this.input;
|
||||||
const output = this.output;
|
const output = this.output;
|
||||||
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
|
if (option?.crop) image = cv.crop(image, option.crop);
|
||||||
image = image.resize(input.shape[3], input.shape[2]);
|
image = cv.resize(image, input.shape[3], input.shape[2]);
|
||||||
|
|
||||||
const nchwImage = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
|
const nchwImage = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
|
||||||
|
|
||||||
@ -35,7 +36,7 @@ export class GenderAge extends Model {
|
|||||||
data: nchwImage,
|
data: nchwImage,
|
||||||
type: "float32",
|
type: "float32",
|
||||||
}
|
}
|
||||||
}).then(res => res[output.name]);
|
}).then(res => res[output.name].data);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
gender: result[0] > result[1] ? "F" : "M",
|
gender: result[0] > result[1] ? "F" : "M",
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { cv } from "../../cv"
|
import cv from "@yizhi/cv"
|
||||||
import { ImageSource, Model } from "../common/model"
|
import { ImageSource, Model } from "../common/model"
|
||||||
|
|
||||||
interface IFaceBoxConstructorOption {
|
interface IFaceBoxConstructorOption {
|
||||||
@ -49,8 +49,8 @@ export class FaceBox {
|
|||||||
|
|
||||||
return new FaceBox({
|
return new FaceBox({
|
||||||
...this.#option,
|
...this.#option,
|
||||||
x1: this.centerX - size, y1: this.centerY - size,
|
x1: cx - size, y1: cy - size,
|
||||||
x2: this.centerX + size, y2: this.centerY + size,
|
x2: cx + size, y2: cy + size,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import { cv } from "../../cv";
|
import cv from "@yizhi/cv";
|
||||||
import { convertImage } from "../common/processors";
|
import { convertImage } from "../common/processors";
|
||||||
import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common";
|
import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common";
|
||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`,
|
YOLOV5S_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.onnx`,
|
||||||
|
YOLOV5S_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/facedet/yolov5s.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class Yolov5Face extends FaceDetector {
|
export class Yolov5Face extends FaceDetector {
|
||||||
@ -14,9 +15,9 @@ export class Yolov5Face extends FaceDetector {
|
|||||||
|
|
||||||
public async doPredict(image: cv.Mat, option?: FaceDetectOption): Promise<FaceBox[]> {
|
public async doPredict(image: cv.Mat, option?: FaceDetectOption): Promise<FaceBox[]> {
|
||||||
const input = this.input;
|
const input = this.input;
|
||||||
const resizedImage = image.resize(input.shape[2], input.shape[3]);
|
const resizedImage = cv.resize(image, input.shape[2], input.shape[3]);
|
||||||
const ratioWidth = image.width / resizedImage.width;
|
const ratioWidth = image.cols / resizedImage.cols;
|
||||||
const ratioHeight = image.height / resizedImage.height;
|
const ratioHeight = image.rows / resizedImage.rows;
|
||||||
const nchwImageData = convertImage(resizedImage.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
|
const nchwImageData = convertImage(resizedImage.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
|
||||||
|
|
||||||
const outputData = await this.session.run({ input: nchwImageData }).then(r => r.output);
|
const outputData = await this.session.run({ input: nchwImageData }).then(r => r.output);
|
||||||
@ -26,7 +27,7 @@ export class Yolov5Face extends FaceDetector {
|
|||||||
const threshold = option?.threshold ?? 0.5;
|
const threshold = option?.threshold ?? 0.5;
|
||||||
for (let i = 0; i < outShape[1]; i++) {
|
for (let i = 0; i < outShape[1]; i++) {
|
||||||
const beg = i * outShape[2];
|
const beg = i * outShape[2];
|
||||||
const rectData = outputData.slice(beg, beg + outShape[2]);
|
const rectData = outputData.data.slice(beg, beg + outShape[2]);
|
||||||
const x = parseInt(rectData[0] * ratioWidth as any);
|
const x = parseInt(rectData[0] * ratioWidth as any);
|
||||||
const y = parseInt(rectData[1] * ratioHeight as any);
|
const y = parseInt(rectData[1] * ratioHeight as any);
|
||||||
const w = parseInt(rectData[2] * ratioWidth as any);
|
const w = parseInt(rectData[2] * ratioWidth as any);
|
||||||
@ -36,7 +37,7 @@ export class Yolov5Face extends FaceDetector {
|
|||||||
faces.push(new FaceBox({
|
faces.push(new FaceBox({
|
||||||
x1: x - w / 2, y1: y - h / 2,
|
x1: x - w / 2, y1: y - h / 2,
|
||||||
x2: x + w / 2, y2: y + h / 2,
|
x2: x + w / 2, y2: y + h / 2,
|
||||||
score, imw: image.width, imh: image.height,
|
score, imw: image.cols, imh: image.rows,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
return nms(faces, option?.mnsThreshold ?? 0.3).map(box => box.toInt());
|
return nms(faces, option?.mnsThreshold ?? 0.3).map(box => box.toInt());
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import { Mat } from "../../cv/mat";
|
import cv from "@yizhi/cv";
|
||||||
import { convertImage } from "../common/processors";
|
import { convertImage } from "../common/processors";
|
||||||
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
||||||
|
|
||||||
const MODEL_URL_CONFIG = {
|
const MODEL_URL_CONFIG = {
|
||||||
MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`,
|
MOBILEFACENET_ADAFACE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.onnx`,
|
||||||
|
MOBILEFACENET_ADAFACE_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/mobilefacenet_adaface.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export class AdaFace extends FaceRecognition {
|
export class AdaFace extends FaceRecognition {
|
||||||
@ -12,12 +13,12 @@ export class AdaFace extends FaceRecognition {
|
|||||||
return this.cacheModel(MODEL_URL_CONFIG[type ?? "MOBILEFACENET_ADAFACE_ONNX"], { createModel: true }).then(r => r.model);
|
return this.cacheModel(MODEL_URL_CONFIG[type ?? "MOBILEFACENET_ADAFACE_ONNX"], { createModel: true }).then(r => r.model);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
|
public async doPredict(image: cv.Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
|
||||||
const input = this.input;
|
const input = this.input;
|
||||||
const output = this.output;
|
const output = this.output;
|
||||||
|
|
||||||
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
|
if (option?.crop) image = cv.crop(image, option.crop);
|
||||||
image = image.resize(input.shape[3], input.shape[2]);
|
image = cv.resize(image, input.shape[3], input.shape[2]);
|
||||||
|
|
||||||
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
|
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
|
||||||
|
|
||||||
@ -27,7 +28,7 @@ export class AdaFace extends FaceRecognition {
|
|||||||
data: nchwImageData,
|
data: nchwImageData,
|
||||||
shape: [1, 3, input.shape[2], input.shape[3]],
|
shape: [1, 3, input.shape[2], input.shape[3]],
|
||||||
}
|
}
|
||||||
}).then(res => res[output.name]);
|
}).then(res => res[output.name].data);
|
||||||
|
|
||||||
return new Array(...embedding);
|
return new Array(...embedding);
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { cv } from "../../cv";
|
import cv from "@yizhi/cv";
|
||||||
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
import { ImageCropOption, ImageSource, Model } from "../common/model";
|
||||||
|
|
||||||
export interface FaceRecognitionPredictOption extends ImageCropOption { }
|
export interface FaceRecognitionPredictOption extends ImageCropOption { }
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Mat } from "../../cv/mat";
|
import cv from "@yizhi/cv";
|
||||||
import { convertImage } from "../common/processors";
|
import { convertImage } from "../common/processors";
|
||||||
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
|
||||||
|
|
||||||
@ -7,26 +7,33 @@ const MODEL_URL_CONFIG_ARC_FACE = {
|
|||||||
INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`,
|
INSIGHTFACE_ARCFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.onnx`,
|
||||||
INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`,
|
INSIGHTFACE_ARCFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.onnx`,
|
||||||
INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`,
|
INSIGHTFACE_ARCFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.onnx`,
|
||||||
|
INSIGHTFACE_ARCFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r50.mnn`,
|
||||||
|
INSIGHTFACE_ARCFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r34.mnn`,
|
||||||
|
INSIGHTFACE_ARCFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/ms1mv3_arcface_r18.mnn`,
|
||||||
};
|
};
|
||||||
const MODEL_URL_CONFIG_COS_FACE = {
|
const MODEL_URL_CONFIG_COS_FACE = {
|
||||||
INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`,
|
INSIGHTFACE_COSFACE_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r100.onnx`,
|
||||||
INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`,
|
INSIGHTFACE_COSFACE_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.onnx`,
|
||||||
INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`,
|
INSIGHTFACE_COSFACE_R34_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.onnx`,
|
||||||
INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`,
|
INSIGHTFACE_COSFACE_R18_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.onnx`,
|
||||||
|
INSIGHTFACE_COSFACE_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r50.mnn`,
|
||||||
|
INSIGHTFACE_COSFACE_R34_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r34.mnn`,
|
||||||
|
INSIGHTFACE_COSFACE_R18_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/glint360k_cosface_r18.mnn`,
|
||||||
};
|
};
|
||||||
const MODEL_URL_CONFIG_PARTIAL_FC = {
|
const MODEL_URL_CONFIG_PARTIAL_FC = {
|
||||||
INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`,
|
INSIGHTFACE_PARTIALFC_R100_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r100.onnx`,
|
||||||
INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`,
|
INSIGHTFACE_PARTIALFC_R50_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.onnx`,
|
||||||
|
INSIGHTFACE_PARTIALFC_R50_MNN: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceid/insightface/partial_fc_glint360k_r50.mnn`,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
export class Insightface extends FaceRecognition {
|
export class Insightface extends FaceRecognition {
|
||||||
|
|
||||||
public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
|
public async doPredict(image: cv.Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
|
||||||
const input = this.input;
|
const input = this.input;
|
||||||
const output = this.output;
|
const output = this.output;
|
||||||
if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
|
if (option?.crop) image = cv.crop(image, option.crop);
|
||||||
image = image.resize(input.shape[3], input.shape[2]);
|
image = cv.resize(image, input.shape[3], input.shape[2]);
|
||||||
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
|
const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
|
||||||
|
|
||||||
const embedding = await this.session.run({
|
const embedding = await this.session.run({
|
||||||
@ -37,7 +44,7 @@ export class Insightface extends FaceRecognition {
|
|||||||
}
|
}
|
||||||
}).then(res => res[output.name]);
|
}).then(res => res[output.name]);
|
||||||
|
|
||||||
return new Array(...embedding);
|
return new Array(...embedding.data);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
import * as ai from "./main";
|
||||||
|
|
||||||
|
export default ai;
|
||||||
|
export { ai };
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
export { deploy } from "./deploy";
|
||||||
|
export { backend } from "./backend";
|
||||||
|
export { setConfig as config } from "./config";
|
||||||
|
66
src/test.ts
66
src/test.ts
@ -1,10 +1,21 @@
|
|||||||
import fs from "fs";
|
import fs from "fs";
|
||||||
|
import cv from "@yizhi/cv";
|
||||||
import { deploy } from "./deploy";
|
import { deploy } from "./deploy";
|
||||||
import { cv } from "./cv";
|
|
||||||
import { faceidTestData } from "./test_data/faceid";
|
import { faceidTestData } from "./test_data/faceid";
|
||||||
import path from "path";
|
import path from "path";
|
||||||
import crypto from "crypto";
|
import crypto from "crypto";
|
||||||
|
|
||||||
|
cv.config("ADDON_PATH", path.join(__dirname, "../build/cv.node"));
|
||||||
|
|
||||||
|
function loadImage(url: string) {
|
||||||
|
if (/https?:\/\//.test(url)) return fetch(url).then(res => res.arrayBuffer()).then(data => cv.imdecode(new Uint8Array(data)));
|
||||||
|
else return import("fs").then(fs => fs.promises.readFile(url)).then(buffer => cv.imdecode(buffer));
|
||||||
|
}
|
||||||
|
|
||||||
|
function rotate(im: cv.Mat, x: number, y: number, angle: number) {
|
||||||
|
return cv.warpAffine(im, cv.getRotationMatrix2D(x, y, angle, 1), im.cols, im.rows);
|
||||||
|
}
|
||||||
|
|
||||||
async function cacheImage(group: string, url: string) {
|
async function cacheImage(group: string, url: string) {
|
||||||
const _url = new URL(url);
|
const _url = new URL(url);
|
||||||
const cacheDir = path.join(__dirname, "../cache/images", group);
|
const cacheDir = path.join(__dirname, "../cache/images", group);
|
||||||
@ -30,31 +41,33 @@ async function cacheImage(group: string, url: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function testGenderTest() {
|
async function testGenderTest() {
|
||||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||||
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_ONNX");
|
const detector = await deploy.faceattr.GenderAgeDetector.load("INSIGHT_GENDER_AGE_MNN");
|
||||||
|
|
||||||
const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
|
const image = await loadImage("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
|
||||||
const boxes = await facedet.predict(image);
|
const boxes = await facedet.predict(image);
|
||||||
if (!boxes.length) return console.error("未检测到人脸");
|
if (!boxes.length) return console.error("未检测到人脸");
|
||||||
for (const [idx, box] of boxes.entries()) {
|
for (const [idx, box] of boxes.entries()) {
|
||||||
const res = await detector.predict(image, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
const res = await detector.predict(image, { crop: { x: box.left, y: box.top, width: box.width, height: box.height } });
|
||||||
console.log(`[${idx + 1}]`, res);
|
console.log(`[${idx + 1}]`, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function testFaceID() {
|
async function testFaceID() {
|
||||||
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
console.log("初始化模型")
|
||||||
const faceid = await deploy.faceid.PartialFC.load();
|
const facedet = await deploy.facedet.Yolov5Face.load("YOLOV5S_MNN");
|
||||||
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
const faceid = await deploy.faceid.CosFace.load("INSIGHTFACE_COSFACE_R50_MNN");
|
||||||
|
const facealign = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
||||||
|
console.log("初始化模型完成")
|
||||||
|
|
||||||
const { basic, tests } = faceidTestData.stars;
|
const { basic, tests } = faceidTestData.stars;
|
||||||
|
|
||||||
console.log("正在加载图片资源");
|
console.log("正在加载图片资源");
|
||||||
const basicImage = await cv.Mat.load(await cacheImage("faceid", basic.image));
|
const basicImage = await loadImage(await cacheImage("faceid", basic.image));
|
||||||
const testsImages: Record<string, cv.Mat[]> = {};
|
const testsImages: Record<string, cv.Mat[]> = {};
|
||||||
for (const [name, imgs] of Object.entries(tests)) {
|
for (const [name, imgs] of Object.entries(tests)) {
|
||||||
testsImages[name] = await Promise.all(imgs.map(img => cacheImage("faceid", img).then(img => cv.Mat.load(img))));
|
testsImages[name] = await Promise.all(imgs.map(img => cacheImage("faceid", img).then(img => loadImage(img))));
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log("正在检测基本数据");
|
console.log("正在检测基本数据");
|
||||||
@ -66,9 +79,9 @@ async function testFaceID() {
|
|||||||
|
|
||||||
async function getEmbd(image: cv.Mat, box: deploy.facedet.FaceBox) {
|
async function getEmbd(image: cv.Mat, box: deploy.facedet.FaceBox) {
|
||||||
box = box.toSquare();
|
box = box.toSquare();
|
||||||
const alignResult = await facealign.predict(image, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
const alignResult = await facealign.predict(image, { crop: { x: box.left, y: box.top, width: box.width, height: box.height } });
|
||||||
let faceImage = image.rotate(box.centerX, box.centerY, -alignResult.direction * 180 / Math.PI);
|
let faceImage = rotate(image, box.centerX, box.centerY, -alignResult.direction * 180 / Math.PI);
|
||||||
return faceid.predict(faceImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
|
return faceid.predict(faceImage, { crop: { x: box.left, y: box.top, width: box.width, height: box.height } });
|
||||||
}
|
}
|
||||||
|
|
||||||
const basicEmbds: number[][] = [];
|
const basicEmbds: number[][] = [];
|
||||||
@ -111,24 +124,28 @@ async function testFaceID() {
|
|||||||
|
|
||||||
async function testFaceAlign() {
|
async function testFaceAlign() {
|
||||||
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
const fd = await deploy.facedet.Yolov5Face.load("YOLOV5S_ONNX");
|
||||||
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_ONNX");
|
const fa = await deploy.facealign.PFLD.load("PFLD_106_LITE_MNN");
|
||||||
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
// const fa = await deploy.facealign.FaceLandmark1000.load("FACELANDMARK1000_ONNX");
|
||||||
let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
|
let image = await loadImage("https://i0.hdslb.com/bfs/archive/64e47ec9fdac9e24bc2b49b5aaad5560da1bfe3e.jpg");
|
||||||
image = image.rotate(image.width / 2, image.height / 2, 0);
|
image = rotate(image, image.cols / 2, image.rows / 2, 0);
|
||||||
|
|
||||||
const face = await fd.predict(image).then(res => res[0].toSquare());
|
const face = await fd.predict(image).then(res => {
|
||||||
const points = await fa.predict(image, { crop: { sx: face.left, sy: face.top, sw: face.width, sh: face.height } });
|
console.log(res);
|
||||||
|
return res[0].toSquare()
|
||||||
|
});
|
||||||
|
const points = await fa.predict(image, { crop: { x: face.left, y: face.top, width: face.width, height: face.height } });
|
||||||
|
|
||||||
points.points.forEach((point, idx) => {
|
points.points.forEach((point, idx) => {
|
||||||
image.circle(face.left + point.x, face.top + point.y, 2);
|
cv.circle(image, face.left + point.x, face.top + point.y, 2, 0, 0, 0);
|
||||||
})
|
})
|
||||||
// const point = points.getPointsOf("rightEye")[1];
|
// const point = points.getPointsOf("rightEye")[1];
|
||||||
// image.circle(face.left + point.x, face.top + point.y, 2);
|
// image.circle(face.left + point.x, face.top + point.y, 2);
|
||||||
fs.writeFileSync("testdata/xx.jpg", image.encode(".jpg"));
|
// fs.writeFileSync("testdata/xx.jpg", image.encode(".jpg"));
|
||||||
|
cv.imwrite("testdata/xx.jpg", image);
|
||||||
|
|
||||||
let faceImage = image.rotate(face.centerX, face.centerY, -points.direction * 180 / Math.PI);
|
let faceImage = rotate(image, face.centerX, face.centerY, -points.direction * 180 / Math.PI);
|
||||||
faceImage = faceImage.crop(face.left, face.top, face.width, face.height);
|
faceImage = cv.crop(faceImage, { x: face.left, y: face.top, width: face.width, height: face.height });
|
||||||
fs.writeFileSync("testdata/face.jpg", faceImage.encode(".jpg"));
|
fs.writeFileSync("testdata/face.jpg", cv.imencode(".jpg", faceImage)!);
|
||||||
|
|
||||||
console.log(points);
|
console.log(points);
|
||||||
console.log(points.direction * 180 / Math.PI);
|
console.log(points.direction * 180 / Math.PI);
|
||||||
@ -142,5 +159,4 @@ async function test() {
|
|||||||
|
|
||||||
test().catch(err => {
|
test().catch(err => {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
debugger
|
});
|
||||||
});
|
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
|
|
||||||
export namespace utils {
|
|
||||||
|
|
||||||
export function rgba2rgb<T extends Uint8Array | Float32Array>(data: T): T {
|
|
||||||
const pixelCount = data.length / 4;
|
|
||||||
const result = new (data.constructor as any)(pixelCount * 3) as T;
|
|
||||||
for (let i = 0; i < pixelCount; i++) {
|
|
||||||
result[i * 3 + 0] = data[i * 4 + 0];
|
|
||||||
result[i * 3 + 1] = data[i * 4 + 1];
|
|
||||||
result[i * 3 + 2] = data[i * 4 + 2];
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function rgb2bgr<T extends Uint8Array | Float32Array>(data: T): T {
|
|
||||||
const pixelCount = data.length / 3;
|
|
||||||
const result = new (data.constructor as any)(pixelCount * 3) as T;
|
|
||||||
for (let i = 0; i < pixelCount; i++) {
|
|
||||||
result[i * 3 + 0] = data[i * 3 + 2];
|
|
||||||
result[i * 3 + 1] = data[i * 3 + 1];
|
|
||||||
result[i * 3 + 2] = data[i * 3 + 0];
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function normalize(data: Uint8Array | Float32Array, mean: number[], std: number[]): Float32Array {
|
|
||||||
const result = new Float32Array(data.length);
|
|
||||||
for (let i = 0; i < data.length; i++) {
|
|
||||||
result[i] = (data[i] - mean[i % mean.length]) / std[i % std.length];
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
@ -28,7 +28,6 @@ function assert(cond, message) {
|
|||||||
|
|
||||||
const buildOptions = {
|
const buildOptions = {
|
||||||
withMNN: findArg("with-mnn", true) ?? false,
|
withMNN: findArg("with-mnn", true) ?? false,
|
||||||
withOpenCV: findArg("with-opencv", true) ?? false,
|
|
||||||
withONNX: findArg("with-onnx", true) ?? false,
|
withONNX: findArg("with-onnx", true) ?? false,
|
||||||
buildType: findArg("build-type", false) ?? "Release",
|
buildType: findArg("build-type", false) ?? "Release",
|
||||||
proxy: findArg("proxy"),
|
proxy: findArg("proxy"),
|
||||||
@ -44,6 +43,8 @@ const spawnOption = {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
function P(path) { return path.replace(/\\/g, "/"); }
|
||||||
|
|
||||||
function checkFile(...items) {
|
function checkFile(...items) {
|
||||||
return fs.existsSync(path.resolve(...items));
|
return fs.existsSync(path.resolve(...items));
|
||||||
}
|
}
|
||||||
@ -105,9 +106,6 @@ async function downloadFromURL(name, url, resolver) {
|
|||||||
|
|
||||||
if (!checkFile(saveName)) {
|
if (!checkFile(saveName)) {
|
||||||
console.log(`开始下载${name}, 地址:${url}`);
|
console.log(`开始下载${name}, 地址:${url}`);
|
||||||
await fetch(url).then(res => {
|
|
||||||
console.log(res.status)
|
|
||||||
})
|
|
||||||
const result = spawnSync("curl", ["-o", saveName + ".cache", "-L", url, "-s", "-w", "%{http_code}"], { ...spawnOption, stdio: "pipe" });
|
const result = spawnSync("curl", ["-o", saveName + ".cache", "-L", url, "-s", "-w", "%{http_code}"], { ...spawnOption, stdio: "pipe" });
|
||||||
assert(result.status == 0 && result.stdout.toString() == "200", `下载${name}失败`);
|
assert(result.status == 0 && result.stdout.toString() == "200", `下载${name}失败`);
|
||||||
fs.renameSync(saveName + ".cache", saveName);
|
fs.renameSync(saveName + ".cache", saveName);
|
||||||
@ -151,37 +149,16 @@ async function main() {
|
|||||||
"-DMNN_AVX512=ON",
|
"-DMNN_AVX512=ON",
|
||||||
"-DMNN_BUILD_TOOLS=ON",
|
"-DMNN_BUILD_TOOLS=ON",
|
||||||
"-DMNN_BUILD_CONVERTER=OFF",
|
"-DMNN_BUILD_CONVERTER=OFF",
|
||||||
"-DMNN_WIN_RUNTIME_MT=OFF",
|
"-DMNN_WIN_RUNTIME_MT=ON",
|
||||||
], (root) => [
|
], (root) => [
|
||||||
`set(MNN_INCLUDE_DIR ${JSON.stringify(path.join(root, "include"))})`,
|
`set(MNN_INCLUDE_DIR ${JSON.stringify(P(path.join(root, "include")))})`,
|
||||||
`set(MNN_LIB_DIR ${JSON.stringify(path.join(root, "lib"))})`,
|
`set(MNN_LIB_DIR ${JSON.stringify(P(path.join(root, "lib")))})`,
|
||||||
`set(MNN_LIBS MNN)`,
|
`if(WIN32)`,
|
||||||
|
` set(MNN_LIBS \${MNN_LIB_DIR}/MNN.lib)`,
|
||||||
|
`else()`,
|
||||||
|
` set(MNN_LIBS \${MNN_LIB_DIR}/libMNN.a)`,
|
||||||
|
`endif()`,
|
||||||
].join("\n"));
|
].join("\n"));
|
||||||
//OpenCV
|
|
||||||
if (buildOptions.withOpenCV) cmakeBuildFromSource("OpenCV", "https://github.com/opencv/opencv.git", "4.11.0", null, [
|
|
||||||
"-DBUILD_SHARED_LIBS=OFF",
|
|
||||||
"-DBUILD_opencv_apps=OFF",
|
|
||||||
"-DBUILD_opencv_js=OFF",
|
|
||||||
"-DBUILD_opencv_python2=OFF",
|
|
||||||
"-DBUILD_opencv_python3=OFF",
|
|
||||||
"-DBUILD_ANDROID_PROJECTS=OFF",
|
|
||||||
"-DBUILD_ANDROID_EXAMPLES=OFF",
|
|
||||||
"-DBUILD_TESTS=OFF",
|
|
||||||
"-DBUILD_FAT_JAVA_LIB=OFF",
|
|
||||||
"-DBUILD_ANDROID_SERVICE=OFF",
|
|
||||||
"-DBUILD_JAVA=OFF",
|
|
||||||
"-DBUILD_PERF_TESTS=OFF"
|
|
||||||
], (root) => [
|
|
||||||
`set(OpenCV_STATIC ON)`,
|
|
||||||
os.platform() == "win32" ?
|
|
||||||
`include(${JSON.stringify(path.join(root, "OpenCVConfig.cmake"))})` :
|
|
||||||
`include(${JSON.stringify(path.join(root, "lib/cmake/opencv4/OpenCVConfig.cmake"))})`,
|
|
||||||
`set(OpenCV_INCLUDE_DIR \${OpenCV_INCLUDE_DIRS})`,
|
|
||||||
os.platform() == "win32" ?
|
|
||||||
"set(OpenCV_LIB_DIR ${OpenCV_LIB_PATH})" :
|
|
||||||
`set(OpenCV_LIB_DIR ${JSON.stringify(path.join(root, "lib"))})`,
|
|
||||||
// `set(OpenCV_LIBS OpenCV_LIBS)`,
|
|
||||||
].join("\n"))
|
|
||||||
//ONNXRuntime
|
//ONNXRuntime
|
||||||
if (buildOptions.withONNX && !checkFile(THIRDPARTY_DIR, "ONNXRuntime/config.cmake")) {
|
if (buildOptions.withONNX && !checkFile(THIRDPARTY_DIR, "ONNXRuntime/config.cmake")) {
|
||||||
let url = "";
|
let url = "";
|
||||||
@ -229,9 +206,13 @@ async function main() {
|
|||||||
fs.cpSync(path.join(savedir, dirname), path.join(THIRDPARTY_DIR, "ONNXRuntime"), { recursive: true });
|
fs.cpSync(path.join(savedir, dirname), path.join(THIRDPARTY_DIR, "ONNXRuntime"), { recursive: true });
|
||||||
});
|
});
|
||||||
fs.writeFileSync(path.join(THIRDPARTY_DIR, "ONNXRuntime/config.cmake"), [
|
fs.writeFileSync(path.join(THIRDPARTY_DIR, "ONNXRuntime/config.cmake"), [
|
||||||
`set(ONNXRuntime_INCLUDE_DIR ${JSON.stringify(path.join(THIRDPARTY_DIR, "ONNXRuntime/include"))})`,
|
`set(ONNXRuntime_INCLUDE_DIR ${JSON.stringify(P(path.join(THIRDPARTY_DIR, "ONNXRuntime/include")))})`,
|
||||||
`set(ONNXRuntime_LIB_DIR ${JSON.stringify(path.join(THIRDPARTY_DIR, "ONNXRuntime/lib"))})`,
|
`set(ONNXRuntime_LIB_DIR ${JSON.stringify(P(path.join(THIRDPARTY_DIR, "ONNXRuntime/lib")))})`,
|
||||||
`set(ONNXRuntime_LIBS onnxruntime)`,
|
`if(WIN32)`,
|
||||||
|
` set(ONNXRuntime_LIBS \${ONNXRuntime_LIB_DIR}/onnxruntime.lib)`,
|
||||||
|
`else()`,
|
||||||
|
` set(ONNXRuntime_LIBS \${ONNXRuntime_LIB_DIR}/libonnxruntime.a)`,
|
||||||
|
`endif()`,
|
||||||
].join("\n"));
|
].join("\n"));
|
||||||
}
|
}
|
||||||
// if (buildOptions.withONNX) cmakeBuildFromSource("ONNXRuntime", "https://github.com/csukuangfj/onnxruntime-build.git", "main", (name, repo, branch) => {
|
// if (buildOptions.withONNX) cmakeBuildFromSource("ONNXRuntime", "https://github.com/csukuangfj/onnxruntime-build.git", "main", (name, repo, branch) => {
|
||||||
|
36
tool/convert_mnn.js
Normal file
36
tool/convert_mnn.js
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
const fs = require("fs");
|
||||||
|
const path = require("path");
|
||||||
|
const { spawnSync } = require("child_process");
|
||||||
|
|
||||||
|
const CONVERTER_PATH = "C:/Develop/Libs/mnn/bin/MNNConvert.exe";
|
||||||
|
const MODEL_PATH = path.join(__dirname, "../models");
|
||||||
|
|
||||||
|
function onnx2mnn(model) {
|
||||||
|
const name = path.basename(model, ".onnx");
|
||||||
|
const outputFile = path.join(path.dirname(model), name + ".mnn");
|
||||||
|
if (fs.existsSync(outputFile)) console.log(`Skip ${name}`);
|
||||||
|
else {
|
||||||
|
const result = spawnSync(CONVERTER_PATH, [
|
||||||
|
"-f", "ONNX",
|
||||||
|
"--modelFile", model,
|
||||||
|
"--MNNModel", outputFile,
|
||||||
|
"--bizCode", "biz"
|
||||||
|
]);
|
||||||
|
if (result.status !== 0) {
|
||||||
|
console.error(`Failed to convert ${name}`);
|
||||||
|
console.log(result.stdout.toString());
|
||||||
|
}
|
||||||
|
else console.log(`Convert ${name} success`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveDir(dir) {
|
||||||
|
const files = fs.readdirSync(dir);
|
||||||
|
for (const file of files) {
|
||||||
|
const filepath = path.join(dir, file);
|
||||||
|
if (fs.statSync(filepath).isDirectory()) resolveDir(filepath);
|
||||||
|
else if (path.extname(file) == ".onnx") onnx2mnn(filepath);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resolveDir(MODEL_PATH)
|
Reference in New Issue
Block a user