commit b0e71d0ebbbdf87b5098f9958844fffa668b9791
Author: Yizhi <946185759@qq.com>
Date:   Thu Mar 6 14:38:32 2025 +0800

    Initial commit

diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000..ecab9ed
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,45 @@
+# Configuration reference:
+# https://clang.llvm.org/docs/ClangFormatStyleOptions.html
+
+# Style to use as a base
+BasedOnStyle: Microsoft
+# Use tabs
+UseTab: Always
+# Tab width
+TabWidth: 4
+# Line length limit (0 = no limit)
+ColumnLimit: 0
+# Allow short blocks on a single line
+AllowShortBlocksOnASingleLine: Empty
+# Allow short if statements on a single line; if true, "if (a) return;" may stay on one line
+AllowShortIfStatementsOnASingleLine: AllIfsAndElse
+# Allow short enums on a single line
+AllowShortEnumsOnASingleLine: true
+# Allow short case labels on a single line
+AllowShortCaseLabelsOnASingleLine: true
+# Allow short loops on a single line
+AllowShortLoopsOnASingleLine: true
+# Maximum number of consecutive empty lines to keep
+MaxEmptyLinesToKeep: 2
+# Allow short functions on a single line
+AllowShortFunctionsOnASingleLine: InlineOnly
+# Whether to sort includes
+SortIncludes: Never
+NamespaceIndentation: All
+# Indent case labels
+IndentCaseLabels: true
+# Brace wrapping
+BraceWrapping:
+  AfterEnum: false
+  AfterStruct: false
+  SplitEmptyFunction: false
+  AfterClass: false
+  AfterControlStatement: Never
+  AfterFunction: true
+  AfterNamespace: true
+  AfterUnion: false
+  AfterExternBlock: true
+  BeforeCatch: true
+  BeforeElse: true
+  BeforeLambdaBody: false
+  BeforeWhile: false
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..14fe965
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,11 @@
+/node_modules
+/.vscode
+/build
+/cache
+/data
+/dist
+/testdata
+/thirdpart
+/typing
+/package-lock.json
+/models
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..7a074fd
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,148 @@
+cmake_minimum_required(VERSION 3.10.0)
+project(ai-box VERSION 0.1.0 LANGUAGES C CXX)
+set(CMAKE_CXX_STANDARD 17)
+
+if(NOT DEFINED CMAKE_BUILD_TYPE)
+    set(CMAKE_BUILD_TYPE Release)
+endif()
+
+set(TARGET_INCLUDE_DIRS cxx)
+set(TARGET_LINK_DIRS)
+set(TARGET_LINK_LIBS)
+set(COMMON_SOURCE_FILES)
+set(NODE_SOURCE_FILES)
+
+set(NODE_ADDON_FOUND OFF)
+
+# NodeJS
+execute_process(
+    COMMAND node ${CMAKE_SOURCE_DIR}/thirdpart/cmake-js-util.js --include
+    RESULT_VARIABLE CMAKE_JS_RESULT
+    OUTPUT_VARIABLE CMAKE_JS_INC
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+if(CMAKE_JS_RESULT EQUAL 0)
+    execute_process(
+        COMMAND node ${CMAKE_SOURCE_DIR}/thirdpart/cmake-js-util.js --src
+        RESULT_VARIABLE CMAKE_JS_RESULT
+        OUTPUT_VARIABLE CMAKE_JS_SRC
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+    if(CMAKE_JS_RESULT EQUAL 0)
+        execute_process(
+            COMMAND node ${CMAKE_SOURCE_DIR}/thirdpart/cmake-js-util.js --lib
+            RESULT_VARIABLE CMAKE_JS_RESULT
+            OUTPUT_VARIABLE CMAKE_JS_LIB
+            OUTPUT_STRIP_TRAILING_WHITESPACE
+        )
+    endif()
+
+    # NAPI
+    if(CMAKE_JS_RESULT EQUAL 0)
+        execute_process(
+            COMMAND node ${CMAKE_SOURCE_DIR}/thirdpart/cmake-js-util.js --napi
+            RESULT_VARIABLE CMAKE_JS_RESULT
+            OUTPUT_VARIABLE NODE_ADDON_API_DIR
+        )
+    endif()
+    if(CMAKE_JS_RESULT EQUAL 0)
+        message(STATUS "CMAKE_JS_INC: ${CMAKE_JS_INC}")
+        message(STATUS "CMAKE_JS_SRC: ${CMAKE_JS_SRC}")
+        message(STATUS "CMAKE_JS_LIB: ${CMAKE_JS_LIB}")
+        message(STATUS "NODE_ADDON_API_DIR: ${NODE_ADDON_API_DIR}")
+        list(APPEND TARGET_INCLUDE_DIRS ${CMAKE_JS_INC} ${NODE_ADDON_API_DIR})
+        list(APPEND TARGET_LINK_LIBS ${CMAKE_JS_LIB})
+        list(APPEND COMMON_SOURCE_FILES ${CMAKE_JS_SRC})
+        set(NODE_ADDON_FOUND ON)
+    endif()
+endif()
+if(NOT (CMAKE_JS_RESULT EQUAL 0))
+    message(FATAL_ERROR "cmake-js configuration failed")
+endif()
+
+function(add_node_target TARGET_NAME SOURCE_FILES)
+    add_library(${TARGET_NAME} SHARED ${CMAKE_JS_SRC} ${SOURCE_FILES})
+    target_link_libraries(${TARGET_NAME} ${CMAKE_JS_LIB})
+    set_target_properties(${TARGET_NAME} PROPERTIES PREFIX "" SUFFIX ".node")
+endfunction()
+
+
+
+# MNN
+set(MNN_CMAKE_FILE ${CMAKE_SOURCE_DIR}/thirdpart/MNN/${CMAKE_BUILD_TYPE}/config.cmake)
+if(EXISTS ${MNN_CMAKE_FILE})
+    include(${MNN_CMAKE_FILE})
+    message(STATUS "MNN_LIB_DIR: ${MNN_LIB_DIR}")
+    message(STATUS "MNN_INCLUDE_DIR: ${MNN_INCLUDE_DIR}")
+    message(STATUS "MNN_LIBS: ${MNN_LIBS}")
+    include_directories(${MNN_INCLUDE_DIR})
+    list(APPEND TARGET_INCLUDE_DIRS ${MNN_INCLUDE_DIR})
+    list(APPEND TARGET_LINK_DIRS ${MNN_LIB_DIR})
+    list(APPEND TARGET_LINK_LIBS ${MNN_LIBS})
+    add_compile_definitions(USE_MNN)
+    set(USE_MNN ON)
+else()
+    message(WARNING "MNN not found")
+endif()
+
+# OpenCV
+set(OpenCV_CMAKE_FILE ${CMAKE_SOURCE_DIR}/thirdpart/OpenCV/${CMAKE_BUILD_TYPE}/config.cmake)
+if(EXISTS ${OpenCV_CMAKE_FILE})
+    include(${OpenCV_CMAKE_FILE})
+    message(STATUS "OpenCV_LIB_DIR: ${OpenCV_LIB_DIR}")
+    message(STATUS "OpenCV_INCLUDE_DIR: ${OpenCV_INCLUDE_DIR}")
+    message(STATUS "OpenCV_LIBS: ${OpenCV_LIBS}")
+    include_directories(${OpenCV_INCLUDE_DIR})
+
+    if(NODE_ADDON_FOUND)
+        add_node_target(cv cxx/cv/node.cc)
+        target_link_libraries(cv ${OpenCV_LIBS})
+        target_link_directories(cv PUBLIC ${OpenCV_LIB_DIR})
+        target_compile_definitions(cv PUBLIC USE_OPENCV)
+    endif()
+
+    list(APPEND TARGET_INCLUDE_DIRS ${OpenCV_INCLUDE_DIR})
+    list(APPEND TARGET_LINK_DIRS ${OpenCV_LIB_DIR})
+    list(APPEND TARGET_LINK_LIBS ${OpenCV_LIBS})
+    list(APPEND NODE_SOURCE_FILES cxx/cv/node.cc)
+endif()
+
+# OnnxRuntime
+set(ONNXRuntime_CMAKE_FILE ${CMAKE_SOURCE_DIR}/thirdpart/ONNXRuntime/config.cmake)
+if(EXISTS ${ONNXRuntime_CMAKE_FILE})
+    include(${ONNXRuntime_CMAKE_FILE})
+    message(STATUS "ONNXRuntime_LIB_DIR: ${ONNXRuntime_LIB_DIR}")
+    message(STATUS "ONNXRuntime_INCLUDE_DIR: ${ONNXRuntime_INCLUDE_DIR}")
+    message(STATUS "ONNXRuntime_LIBS: ${ONNXRuntime_LIBS}")
+    list(APPEND TARGET_INCLUDE_DIRS ${ONNXRuntime_INCLUDE_DIR})
+    list(APPEND TARGET_LINK_DIRS ${ONNXRuntime_LIB_DIR})
+    list(APPEND TARGET_LINK_LIBS ${ONNXRuntime_LIBS})
+
+    if(NODE_ADDON_FOUND)
+        add_node_target(ort cxx/ort/node.cc)
+        target_link_libraries(ort ${ONNXRuntime_LIBS})
+        target_link_directories(ort PUBLIC ${ONNXRuntime_LIB_DIR})
+        target_compile_definitions(ort PUBLIC USE_ONNXRUNTIME)
+    endif()
+
+    list(APPEND NODE_SOURCE_FILES
+        cxx/ort/node.cc
+    )
+else()
+    message(WARNING "ONNXRuntime not found")
+endif()
+
+
+include_directories(${TARGET_INCLUDE_DIRS})
+link_directories(${TARGET_LINK_DIRS})
+add_library(addon SHARED ${COMMON_SOURCE_FILES} ${NODE_SOURCE_FILES} cxx/node.cc)
+target_link_libraries(addon ${TARGET_LINK_LIBS})
+set_target_properties(addon PROPERTIES PREFIX "" SUFFIX ".node")
+
+# add_executable(test ${COMMON_SOURCE_FILES} cxx/test.cc)
+# target_link_libraries(test ${TARGET_LINK_LIBS})
+
+if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
+    execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
+endif()
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/cxx/common/node.h b/cxx/common/node.h
new file mode 100644
index 0000000..2cdaec2
--- /dev/null
+++ b/cxx/common/node.h
@@ -0,0 +1,24 @@
+#ifndef __COMMON_NODE_H__
+#define __COMMON_NODE_H__
+
+#include <napi.h>
+#include <cstdint>
+#include <string>
+
+#define NODE_INIT_OBJECT(name, function)                 \
+    do                                                   \
+    {                                                    \
+        auto obj = Napi::Object::New(env);               \
+        function(env, obj);                              \
+        exports.Set(Napi::String::New(env, #name), obj); \
+    } while (0)
+
+inline uint64_t __node_ptr_of__(Napi::Value value)
+{
+    bool lossless;
+    return value.As<Napi::BigInt>().Uint64Value(&lossless);
+}
+
+#define NODE_PTR_OF(type, value) (reinterpret_cast<type *>(__node_ptr_of__(value)))
+
+#endif
diff --git a/cxx/common/session.h b/cxx/common/session.h
new file mode 100644
index 0000000..5367697
--- /dev/null
+++ b/cxx/common/session.h
@@ -0,0 +1,27 @@
+#ifndef __COMMON_SESSION_H__
+#define __COMMON_SESSION_H__
+#include <map>
+#include <string>
+#include <vector>
+#include <cstdint>
+
+namespace ai
+{
+
+    class Tensor
+    {
+    public:
+        virtual ~Tensor() {}
+    };
+
+    class Session
+    {
+    public:
+        virtual ~Session() {}
+
+        virtual const std::map<std::string, std::vector<int64_t>> &getInputShapes() const = 0;
+        virtual const std::map<std::string, std::vector<int64_t>> &getOutputShapes() const = 0;
+    };
+}
+
+#endif
diff --git a/cxx/cv/node.cc b/cxx/cv/node.cc
new file mode 100644
index 0000000..d49319f
--- /dev/null
+++ b/cxx/cv/node.cc
@@ -0,0 +1,145 @@
+#include <opencv2/opencv.hpp>
+#include <napi.h>
+#include <functional>
+#include "node.h"
+
+using namespace Napi;
+
+#define MAT_INSTANCE_METHOD(method) InstanceMethod<&CVMat::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
+
+static FunctionReference *constructor = nullptr;
+
+class CVMat : public ObjectWrap<CVMat> {
+public:
+    static Napi::Object Init(Napi::Env env, Napi::Object exports)
+    {
+        Function func = DefineClass(env, "Mat", {
+            MAT_INSTANCE_METHOD(IsEmpty),
+            MAT_INSTANCE_METHOD(GetCols),
+            MAT_INSTANCE_METHOD(GetRows),
+            MAT_INSTANCE_METHOD(GetChannels),
+            MAT_INSTANCE_METHOD(Resize),
+            MAT_INSTANCE_METHOD(Crop),
+            MAT_INSTANCE_METHOD(Rotate),
+            MAT_INSTANCE_METHOD(Clone),
+
+            MAT_INSTANCE_METHOD(DrawCircle),
+
+            MAT_INSTANCE_METHOD(Data),
+            MAT_INSTANCE_METHOD(Encode),
+        });
+        constructor = new FunctionReference();
+        *constructor = Napi::Persistent(func);
+        exports.Set("Mat", func);
+        env.SetInstanceData(constructor);
+        return exports;
+    }
+
+    CVMat(const CallbackInfo &info)
+        : ObjectWrap(info)
+    {
+        int mode = cv::IMREAD_COLOR_BGR;
+        if (info.Length() > 1 && info[1].IsObject()) {
+            Object options = info[1].As<Object>();
+            if (options.Has("mode") && options.Get("mode").IsNumber()) mode = options.Get("mode").As<Number>().Int32Value();
+        }
+
+        if (info[0].IsString()) im_ = cv::imread(info[0].As<String>().Utf8Value(), mode);
+        else if (info[0].IsTypedArray()) {
+            auto buffer = info[0].As<TypedArray>().ArrayBuffer();
+            uint8_t *bufferPtr = static_cast<uint8_t *>(buffer.Data());
+            std::vector<uint8_t> data(bufferPtr, bufferPtr + buffer.ByteLength());
+            im_ = cv::imdecode(data, mode);
+        }
+    }
+
+    ~CVMat() { im_.release(); }
+
+    Napi::Value IsEmpty(const Napi::CallbackInfo &info) { return Boolean::New(info.Env(), im_.empty()); }
+    Napi::Value GetCols(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.cols); }
+    Napi::Value GetRows(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.rows); }
+    Napi::Value GetChannels(const Napi::CallbackInfo &info) { return Number::New(info.Env(), im_.channels()); }
+    Napi::Value Resize(const Napi::CallbackInfo &info)
+    {
+        return CreateMat(info.Env(), [this, &info](auto &mat) { cv::resize(im_, mat.im_, cv::Size(info[0].As<Number>().Int32Value(), info[1].As<Number>().Int32Value())); });
+    }
+    Napi::Value Crop(const Napi::CallbackInfo &info)
+    {
+        return CreateMat(info.Env(), [this, &info](auto &mat) {
+            mat.im_ = im_(cv::Rect(
+                info[0].As<Number>().Int32Value(), info[1].As<Number>().Int32Value(),
+                info[2].As<Number>().Int32Value(), info[3].As<Number>().Int32Value()));
+        });
+    }
+    Napi::Value Rotate(const Napi::CallbackInfo &info)
+    {
+        return CreateMat(info.Env(), [this, &info](auto &mat) {
+            auto x = info[0].As<Number>().DoubleValue();
+            auto y = info[1].As<Number>().DoubleValue();
+            auto angle = info[2].As<Number>().DoubleValue();
+            cv::Mat rotation_matrix = cv::getRotationMatrix2D(cv::Point2f(x, y), angle, 1.0);
+            cv::warpAffine(im_, mat.im_, rotation_matrix, im_.size());
+        });
+    }
+
+    Napi::Value Clone(const Napi::CallbackInfo &info)
+    {
+        return CreateMat(info.Env(), [this, &info](auto &mat) { mat.im_ = im_.clone(); });
+    }
+
+    Napi::Value DrawCircle(const Napi::CallbackInfo &info)
+    {
+        int x = info[0].As<Number>().Int32Value();
+        int y = info[1].As<Number>().Int32Value();
+        int radius = info[2].As<Number>().Int32Value();
+        int b = info[3].As<Number>().Int32Value();
+        int g = info[4].As<Number>().Int32Value();
+        int r = info[5].As<Number>().Int32Value();
+        int thickness = info[6].As<Number>().Int32Value();
+        int lineType = info[7].As<Number>().Int32Value();
+        int shift = info[8].As<Number>().Int32Value();
+
+        cv::circle(im_, cv::Point(x, y), radius, cv::Scalar(b, g, r), thickness, lineType, shift);
+        return info.Env().Undefined();
+    }
+
+    Napi::Value Data(const Napi::CallbackInfo &info) { return ArrayBuffer::New(info.Env(), im_.ptr(), im_.elemSize() * im_.total()); }
+    Napi::Value Encode(const Napi::CallbackInfo &info)
+    {
+        auto options = info[0].As<Object>();
+        auto extname = options.Get("extname").As<String>().Utf8Value();
+        cv::imencode(extname, im_, encoded_);
+        return ArrayBuffer::New(info.Env(), encoded_.data(), encoded_.size());
+    }
+
+
+private:
+    inline Napi::Object EmptyMat(Napi::Env env) { return constructor->New({}).As<Napi::Object>(); }
+    inline CVMat &GetMat(Napi::Object obj) { return *ObjectWrap::Unwrap(obj); }
+    inline Napi::Object CreateMat(Napi::Env env, std::function<void(CVMat &mat)> callback)
+    {
+        auto obj = EmptyMat(env);
+        callback(GetMat(obj));
+        return obj;
+    }
+
+private:
+    cv::Mat im_;
+    std::vector<uint8_t> encoded_;
+};
+
+
+void InstallOpenCVAPI(Env env, Object exports)
+{
+    CVMat::Init(env, exports);
+}
+
+#ifdef USE_OPENCV
+static Object Init(Env env, Object exports)
+{
+    InstallOpenCVAPI(env, exports);
+    return exports;
+}
+NODE_API_MODULE(addon, Init)
+
+#endif
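
A minimal sketch of driving the class above from JavaScript once the cv target is built (the ./build/cv.node path is an assumption mirrored from the require in src/cv/mat.ts; input.jpg is a placeholder):

    // Raw N-API binding, no TypeScript wrapper.
    import { readFileSync, writeFileSync } from "fs";

    const { Mat } = require("./build/cv.node"); // assumed build output location

    const mat = new Mat(new Uint8Array(readFileSync("input.jpg")), { mode: 1 /* IMREAD_COLOR_BGR */ });
    if (!mat.IsEmpty()) {
        console.log(mat.GetCols(), mat.GetRows(), mat.GetChannels());
        const thumb = mat.Resize(128, 128);            // returns a new Mat object
        const jpg = thumb.Encode({ extname: ".jpg" }); // ArrayBuffer
        writeFileSync("thumb.jpg", Buffer.from(jpg));
    }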
%d \n", val_); + } + + void OnOK() + { + Callback().Call({Env().Undefined(), Number::New(Env(), 0)}); + } + +private: + int val_; +}; + +Value test(const CallbackInfo &info) +{ + // ai::ORTSession(nullptr, 0); + + // Function callback = info[1].As(); + // TestWork *work = new TestWork(callback, info[0].As().Int32Value()); + // work->Queue(); + return info.Env().Undefined(); +} + +Object Init(Env env, Object exports) +{ + //OpenCV + NODE_INIT_OBJECT(cv, InstallOpenCVAPI); + //OnnxRuntime + #ifdef USE_ORT + NODE_INIT_OBJECT(ort, InstallOrtAPI); + #endif + + Napi::Number::New(env, 0); + +#define ADD_FUNCTION(name) exports.Set(Napi::String::New(env, #name), Napi::Function::New(env, name)) + // ADD_FUNCTION(facedetPredict); + // ADD_FUNCTION(facedetRelease); + + // ADD_FUNCTION(faceRecognitionCreate); + // ADD_FUNCTION(faceRecognitionPredict); + // ADD_FUNCTION(faceRecognitionRelease); + + // ADD_FUNCTION(getDistance); +#undef ADD_FUNCTION + return exports; +} +NODE_API_MODULE(addon, Init) \ No newline at end of file diff --git a/cxx/ort/node.cc b/cxx/ort/node.cc new file mode 100644 index 0000000..3ee8b5e --- /dev/null +++ b/cxx/ort/node.cc @@ -0,0 +1,292 @@ +#include +#include +#include +#include "node.h" + +using namespace Napi; + +#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast(napi_writable | napi_configurable)) + +static ONNXTensorElementDataType getDataTypeFromString(const std::string &name) +{ + static const std::map dataTypeNameMap = { + {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, + {"float", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, + {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, + {"double", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, + {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, + {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, + {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, + {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, + {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, + {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, + {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, + {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}, + }; + auto it = dataTypeNameMap.find(name); + return (it == dataTypeNameMap.end()) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED : it->second; +} + +static size_t getDataTypeSize(ONNXTensorElementDataType type) +{ + static const std::map dataTypeSizeMap = { + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8}, + }; + auto it = dataTypeSizeMap.find(type); + return (it == dataTypeSizeMap.end()) ? 
+
+static void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
+{
+    auto arr = val.As<TypedArray>();
+    auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
+    bytes = arr.ByteLength();
+    return static_cast<void *>(data + arr.ByteOffset());
+}
+
+class OrtSessionNodeInfo {
+public:
+    inline OrtSessionNodeInfo(const std::string &name, const Ort::TypeInfo &typeInfo)
+        : name_(name), shape_(typeInfo.GetTensorTypeAndShapeInfo().GetShape()), type_(typeInfo.GetTensorTypeAndShapeInfo().GetElementType()) {}
+
+    inline const std::string &GetName() const { return name_; }
+    inline const std::vector<int64_t> &GetShape() const { return shape_; }
+    inline ONNXTensorElementDataType GetType() const { return type_; }
+    inline size_t GetElementSize() const { return getDataTypeSize(type_); }
+    size_t GetElementCount() const
+    {
+        if (!shape_.size()) return 0;
+        size_t result = 1;
+        for (auto dim : shape_) {
+            if (dim <= 0) continue; // dynamic dimensions are skipped
+            result *= dim;
+        }
+        return result;
+    }
+    inline size_t GetElementBytes() const { return GetElementSize() * GetElementCount(); }
+
+private:
+    std::string name_;
+    std::vector<int64_t> shape_;
+    ONNXTensorElementDataType type_;
+};
+
+class OrtSessionRunWorker : public AsyncWorker {
+public:
+    OrtSessionRunWorker(std::shared_ptr<Ort::Session> session, const Napi::Function &callback)
+        : Napi::AsyncWorker(callback), session_(session) {}
+    ~OrtSessionRunWorker() {}
+
+    void Execute()
+    {
+        if (!HasError()) {
+            try {
+                Ort::RunOptions runOption;
+                session_->Run(runOption, &inputNames_[0], &inputValues_[0], inputNames_.size(), &outputNames_[0], &outputValues_[0], outputNames_.size());
+            }
+            catch (std::exception &err) {
+                SetError(err.what());
+            }
+        }
+    }
+
+    void OnOK()
+    {
+        if (HasError()) {
+            Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()});
+        }
+        else {
+            auto result = Object::New(Env());
+            for (size_t i = 0; i < outputNames_.size(); ++i) {
+                size_t bytes = outputElementBytes_[i];
+                Ort::Value &value = outputValues_[i];
+                // Copy the tensor into a JS-owned buffer: the Ort::Value dies with this worker.
+                auto buffer = ArrayBuffer::New(Env(), bytes);
+                memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes);
+                result.Set(String::New(Env(), outputNames_[i]), buffer);
+            }
+            Callback().Call({Env().Undefined(), result});
+        }
+    }
+
+    inline void AddInput(const Ort::MemoryInfo &memoryInfo, const std::string &name, ONNXTensorElementDataType type, const std::vector<int64_t> &shape, void *data, size_t dataByteLen)
+    {
+        try {
+            inputNames_.push_back(name.c_str());
+            inputValues_.push_back(Ort::Value::CreateTensor(memoryInfo, data, dataByteLen, &shape[0], shape.size(), type));
+        }
+        catch (std::exception &err) {
+            SetError(err.what());
+        }
+    }
+
+    inline void AddOutput(const std::string &name, size_t elementBytes)
+    {
+        outputNames_.push_back(name.c_str());
+        outputValues_.push_back(Ort::Value{nullptr});
+        outputElementBytes_.push_back(elementBytes);
+    }
+
+    inline void SetError(const std::string &what) { errorMessage_ = what; }
+
+    inline bool HasError() { return errorMessage_.size() > 0; }
+
+public:
+    std::shared_ptr<Ort::Session> session_;
+    std::vector<const char *> inputNames_;
+    std::vector<Ort::Value> inputValues_;
+    std::vector<const char *> outputNames_;
+    std::vector<Ort::Value> outputValues_;
+    std::vector<size_t> outputElementBytes_;
+    std::string errorMessage_;
+};
+
+class OrtSession : public ObjectWrap<OrtSession> {
+public:
+    static Napi::Object Init(Napi::Env env, Napi::Object exports)
+    {
+        Function func = DefineClass(env, "OrtSession", {
+            SESSION_INSTANCE_METHOD(GetInputsInfo),
+            SESSION_INSTANCE_METHOD(GetOutputsInfo),
+            SESSION_INSTANCE_METHOD(Run),
+        });
+        FunctionReference *constructor = new FunctionReference();
+        *constructor = Napi::Persistent(func);
+        exports.Set("OrtSession", func);
+        env.SetInstanceData(constructor);
+        return exports;
+    }
+
+    OrtSession(const CallbackInfo &info)
+        : ObjectWrap(info)
+    {
+        try {
+            if (info[0].IsString()) session_ = std::make_shared<Ort::Session>(env, info[0].As<String>().Utf8Value().c_str(), sessionOptions);
+            else if (info[0].IsTypedArray()) {
+                size_t bufferBytes;
+                auto buffer = dataFromTypedArray(info[0], bufferBytes);
+                session_ = std::make_shared<Ort::Session>(env, buffer, bufferBytes, sessionOptions);
+            }
+            else session_ = nullptr;
+
+            if (session_ != nullptr) {
+                Ort::AllocatorWithDefaultOptions allocator;
+                for (size_t i = 0; i < session_->GetInputCount(); ++i) {
+                    std::string name = session_->GetInputNameAllocated(i, allocator).get();
+                    auto typeInfo = session_->GetInputTypeInfo(i);
+                    inputs_.emplace(name, std::make_shared<OrtSessionNodeInfo>(name, typeInfo));
+                }
+                for (size_t i = 0; i < session_->GetOutputCount(); ++i) {
+                    std::string name = session_->GetOutputNameAllocated(i, allocator).get();
+                    auto typeInfo = session_->GetOutputTypeInfo(i);
+                    outputs_.emplace(name, std::make_shared<OrtSessionNodeInfo>(name, typeInfo));
+                }
+            }
+        }
+        catch (std::exception &e) {
+            Error::New(info.Env(), e.what()).ThrowAsJavaScriptException();
+        }
+    }
+
+    ~OrtSession() {}
+
+    Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), inputs_); }
+
+    Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), outputs_); }
+
+    Napi::Value Run(const Napi::CallbackInfo &info)
+    {
+        auto worker = new OrtSessionRunWorker(session_, info[1].As<Function>());
+        auto inputArgument = info[0].As<Object>();
+        auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+        for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) {
+            auto name = (*it).first.As<String>().Utf8Value();
+            auto input = GetInput(name);
+            if (!input) worker->SetError(std::string("input name #" + name + " does not exist"));
+            else {
+                auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
+                if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError(std::string("data is required in inputs #" + name));
+                else {
+                    auto type = inputOption.Has("type") ? getDataTypeFromString(inputOption.Get("type").As<String>().Utf8Value()) : input->GetType();
+                    size_t dataByteLen;
+                    void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
+                    auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
+                    worker->AddInput(memoryInfo, input->GetName(), type, shape, data, dataByteLen);
+                }
+            }
+        }
+        for (auto &it : outputs_) {
+            worker->AddOutput(it.second->GetName(), it.second->GetElementBytes());
+        }
+        worker->Queue();
+
+        return info.Env().Undefined();
+    }
+
+private:
+    static std::vector<int64_t> GetShapeFromJavascript(Napi::Array arr)
+    {
+        std::vector<int64_t> res(arr.Length());
+        for (uint32_t i = 0; i < res.size(); ++i) res[i] = arr.Get(i).As<Number>().Int32Value();
+        return res;
+    }
+
+    inline std::shared_ptr<OrtSessionNodeInfo> GetInput(const std::string &name) const
+    {
+        auto it = inputs_.find(name);
+        return (it == inputs_.end()) ? nullptr : it->second;
+    }
+
+    inline std::shared_ptr<OrtSessionNodeInfo> GetOutput(const std::string &name) const
+    {
+        auto it = outputs_.find(name);
+        return (it == outputs_.end()) ? nullptr : it->second;
+    }
+
+    Napi::Value BuildNodeInfoToJavascript(Napi::Env env, const std::map<std::string, std::shared_ptr<OrtSessionNodeInfo>> &nodes)
+    {
+        auto result = Object::New(env);
+        for (auto it : nodes) {
+            auto &node = *it.second;
+            auto item = Object::New(env);
+            item.Set(String::New(env, "name"), String::New(env, node.GetName()));
+            item.Set(String::New(env, "type"), Number::New(env, node.GetType()));
+            auto &shapeVec = node.GetShape();
+            auto shape = Array::New(env, shapeVec.size());
+            for (size_t i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i]));
+            item.Set(String::New(env, "shape"), shape);
+            result.Set(String::New(env, node.GetName()), item);
+        }
+
+        return result;
+    }
+
+private:
+    Ort::Env env;
+    Ort::SessionOptions sessionOptions;
+    std::shared_ptr<Ort::Session> session_;
+    std::map<std::string, std::shared_ptr<OrtSessionNodeInfo>> inputs_;
+    std::map<std::string, std::shared_ptr<OrtSessionNodeInfo>> outputs_;
+};
+
+void InstallOrtAPI(Napi::Env env, Napi::Object exports)
+{
+    OrtSession::Init(env, exports);
+}
+
+#ifdef USE_ONNXRUNTIME
+static Object Init(Env env, Object exports)
+{
+    InstallOrtAPI(env, exports);
+    return exports;
+}
+NODE_API_MODULE(addon, Init)
+
+#endif
\ No newline at end of file
diff --git a/cxx/ort/node.h b/cxx/ort/node.h
new file mode 100644
index 0000000..39e0cfe
--- /dev/null
+++ b/cxx/ort/node.h
@@ -0,0 +1,8 @@
+#ifndef __ORT_NODE_H__
+#define __ORT_NODE_H__
+
+#include "common/node.h"
+
+void InstallOrtAPI(Napi::Env env, Napi::Object exports);
+
+#endif
diff --git a/cxx/test.cc b/cxx/test.cc
new file mode 100644
index 0000000..db338a0
--- /dev/null
+++ b/cxx/test.cc
@@ -0,0 +1,56 @@
+#include <onnxruntime_cxx_api.h>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+static void buildNames(const std::vector<std::string> &input, std::vector<const char *> &output)
+{
+    if (input.size() == output.size())
+        return;
+    for (auto &it : input)
+        output.push_back(it.c_str());
+}
+
+void test()
+{
+    static Ort::Env env;
+    static Ort::SessionOptions sessionOptions;
+    static Ort::Session *session_ = new Ort::Session(env, "/home/yizhi/Develop/ai-box/models/test.onnx", sessionOptions);
+    static std::vector<std::string> inputNames_;
+    static std::vector<const char *> inputNames_c_;
+    static std::vector<Ort::Value> inputValues_;
+    static std::vector<std::string> outputNames_;
+    static std::vector<const char *> outputNames_c_;
+    static std::vector<Ort::Value> outputValues_;
+
+    inputNames_.emplace_back("a");
+    inputNames_.emplace_back("b");
+    outputNames_.emplace_back("c");
+    outputValues_.emplace_back(Ort::Value{nullptr});
+    static std::vector<int64_t> shapeA = {3, 4};
+    static std::vector<int64_t> shapeB = {4, 3};
+    static float inputA[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+    static float inputB[] = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120};
+    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+    inputValues_.emplace_back(Ort::Value::CreateTensor(memoryInfo, inputA, sizeof(inputA), &shapeA[0], shapeA.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
+    inputValues_.emplace_back(Ort::Value::CreateTensor(memoryInfo, inputB, sizeof(inputB), &shapeB[0], shapeB.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
+    buildNames(inputNames_, inputNames_c_);
+    buildNames(outputNames_, outputNames_c_);
+    const char **ptr = &inputNames_c_[0];
+    for (int i = 0; i < inputNames_c_.size(); ++i)
+    {
+        printf("input [%d] = %s\n", i, ptr[i]);
+    }
+    Ort::RunOptions runOption;
+    printf("start run\n");
+    session_->Run(runOption, &inputNames_c_[0], &inputValues_[0], inputNames_.size(), &outputNames_c_[0], &outputValues_[0], outputNames_.size());
+    printf("end run\n");
+}
+
+int main(int argc, char **argv)
+{
+    Ort::Env env;
+    test();
+    // ai::ORTSession session(nullptr, 0);
+    return 0;
+}
\ No newline at end of file
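
For orientation, the same 3×4 · 4×3 test model exercised by cxx/test.cc can be driven through the OrtSession binding from JavaScript; a sketch (the ./build/ort.node and models/test.onnx paths are assumptions taken from this commit's own files):

    const { OrtSession } = require("./build/ort.node"); // assumed build output location

    const session = new OrtSession("models/test.onnx"); // same model as cxx/test.cc
    console.log(session.GetInputsInfo()); // { a: { name, type, shape }, b: { ... } }

    const a = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
    const b = new Float32Array(a.map(v => v * 10));
    session.Run(
        { a: { data: a, shape: [3, 4] }, b: { data: b, shape: [4, 3] } },
        (err: any, res: Record<string, ArrayBuffer>) => {
            if (err) throw err;
            console.log(new Float32Array(res.c)); // the 3x3 product, copied out of the worker
        },
    );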
diff --git a/package.json b/package.json
new file mode 100644
index 0000000..af8d005
--- /dev/null
+++ b/package.json
@@ -0,0 +1,19 @@
+{
+  "name": "ai-box",
+  "version": "1.0.0",
+  "main": "index.js",
+  "scripts": {
+    "watch": "tsc -w --inlineSourceMap"
+  },
+  "keywords": [],
+  "author": "",
+  "license": "ISC",
+  "description": "",
+  "devDependencies": {
+    "@types/node": "^22.13.5",
+    "cmake-js": "^7.3.0",
+    "compressing": "^1.10.1",
+    "node-addon-api": "^8.3.1",
+    "unbzip2-stream": "^1.4.3"
+  }
+}
diff --git a/src/backend/common/index.ts b/src/backend/common/index.ts
new file mode 100644
index 0000000..3a46e9e
--- /dev/null
+++ b/src/backend/common/index.ts
@@ -0,0 +1 @@
+export * from "./session";
diff --git a/src/backend/common/session.ts b/src/backend/common/session.ts
new file mode 100644
index 0000000..46d10a4
--- /dev/null
+++ b/src/backend/common/session.ts
@@ -0,0 +1,28 @@
+
+
+export interface SessionNodeInfo {
+    name: string
+    type: number
+    shape: number[]
+}
+
+export type SessionNodeType = "float32" | "float64" | "float" | "double" | "int32" | "uint32" | "int16" | "uint16" | "int8" | "uint8" | "int64" | "uint64"
+
+export type SessionNodeData = Float32Array | Float64Array | Int32Array | Uint32Array | Int16Array | Uint16Array | Int8Array | Uint8Array | BigInt64Array | BigUint64Array
+
+export interface SessionRunInputOption {
+    type?: SessionNodeType
+    data: SessionNodeData
+    shape?: number[]
+}
+
+export abstract class CommonSession {
+    public abstract run(inputs: Record<string, SessionNodeData | SessionRunInputOption>): Promise<Record<string, Float32Array>>
+
+    public abstract get inputs(): Record<string, SessionNodeInfo>;
+    public abstract get outputs(): Record<string, SessionNodeInfo>;
+}
+
+export function isTypedArray(val: any): val is SessionNodeData {
+    return val?.buffer instanceof ArrayBuffer;
+}
diff --git a/src/backend/index.ts b/src/backend/index.ts
new file mode 100644
index 0000000..a6d283b
--- /dev/null
+++ b/src/backend/index.ts
@@ -0,0 +1 @@
+export * as backend from "./main";
\ No newline at end of file
diff --git a/src/backend/main.ts b/src/backend/main.ts
new file mode 100644
index 0000000..cf9fde7
--- /dev/null
+++ b/src/backend/main.ts
@@ -0,0 +1,2 @@
+export * as common from "./common";
+export * as ort from "./ort";
diff --git a/src/backend/ort/index.ts b/src/backend/ort/index.ts
new file mode 100644
index 0000000..dfca203
--- /dev/null
+++ b/src/backend/ort/index.ts
@@ -0,0 +1 @@
+export { OrtSession as Session } from "./session";
diff --git a/src/backend/ort/session.ts b/src/backend/ort/session.ts
new file mode 100644
index 0000000..c17af37
--- /dev/null
+++ b/src/backend/ort/session.ts
@@ -0,0 +1,32 @@
+import { CommonSession, isTypedArray, SessionNodeData, SessionNodeInfo, SessionRunInputOption } from "../common";
+
+export class OrtSession extends CommonSession {
+    #session: any;
+    #inputs: Record<string, SessionNodeInfo> | null = null;
+    #outputs: Record<string, SessionNodeInfo> | null = null;
+
+    public constructor(modelData: Uint8Array) {
+        super();
+        const addon = require("../../../build/ort.node")
+        this.#session = new addon.OrtSession(modelData);
+    }
+
+    public get inputs(): Record<string, SessionNodeInfo> { return this.#inputs ??= this.#session.GetInputsInfo(); }
+
+    public get outputs(): Record<string, SessionNodeInfo> { return this.#outputs ??= this.#session.GetOutputsInfo(); }
+
+    public run(inputs: Record<string, SessionNodeData | SessionRunInputOption>) {
+        const inputArgs: Record<string, SessionRunInputOption> = {};
+        for (const [name, option] of Object.entries(inputs)) {
+            if (isTypedArray(option)) inputArgs[name] = { data: option }
+            else inputArgs[name] = option;
+        }
+
+        return new Promise<Record<string, Float32Array>>((resolve, reject) => this.#session.Run(inputArgs, (err: any, res: any) => {
+            if (err) return reject(err);
+            const result: Record<string, Float32Array> = {};
+            for (const [name, val] of Object.entries(res)) result[name] = new Float32Array(val as ArrayBuffer);
+            resolve(result);
+        }));
+    }
+}
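
At the TypeScript layer the callback becomes a Promise; a usage sketch (the model path and import site are assumptions):

    import { readFileSync } from "fs";
    import { backend } from "./src/backend";

    async function main() {
        const session = new backend.ort.Session(readFileSync("models/test.onnx"));
        console.log(session.inputs, session.outputs);

        // A bare typed array is shorthand for { data }; type and shape then default
        // to the model's declared input type and shape.
        const out = await session.run({
            a: new Float32Array(12),
            b: { data: new Float32Array(12), shape: [4, 3], type: "float32" },
        });
        console.log(out.c); // Float32Array
    }
    main();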
diff --git a/src/cv/index.ts b/src/cv/index.ts
new file mode 100644
index 0000000..69b587f
--- /dev/null
+++ b/src/cv/index.ts
@@ -0,0 +1 @@
+export * as cv from "./main";
diff --git a/src/cv/main.ts b/src/cv/main.ts
new file mode 100644
index 0000000..ca1ede5
--- /dev/null
+++ b/src/cv/main.ts
@@ -0,0 +1 @@
+export { Mat, ImreadModes } from "./mat";
diff --git a/src/cv/mat.ts b/src/cv/mat.ts
new file mode 100644
index 0000000..981e156
--- /dev/null
+++ b/src/cv/mat.ts
@@ -0,0 +1,79 @@
+
+export enum ImreadModes {
+    IMREAD_UNCHANGED = -1,
+    IMREAD_GRAYSCALE = 0,
+    IMREAD_COLOR_BGR = 1,
+    IMREAD_COLOR = 1,
+    IMREAD_ANYDEPTH = 2,
+    IMREAD_ANYCOLOR = 4,
+    IMREAD_LOAD_GDAL = 8,
+    IMREAD_REDUCED_GRAYSCALE_2 = 16,
+    IMREAD_REDUCED_COLOR_2 = 17,
+    IMREAD_REDUCED_GRAYSCALE_4 = 32,
+    IMREAD_REDUCED_COLOR_4 = 33,
+    IMREAD_REDUCED_GRAYSCALE_8 = 64,
+    IMREAD_REDUCED_COLOR_8 = 65,
+    IMREAD_IGNORE_ORIENTATION = 128,
+    IMREAD_COLOR_RGB = 256,
+};
+
+interface MatConstructorOption {
+    mode?: ImreadModes;
+}
+
+export class Mat {
+    #mat: any
+
+    public static async load(image: string, option?: MatConstructorOption) {
+        let buffer: Uint8Array
+        if (/^https?:\/\//.test(image)) buffer = await fetch(image).then(res => res.arrayBuffer()).then(res => new Uint8Array(res));
+        else buffer = await import("fs").then(fs => fs.promises.readFile(image));
+        return new Mat(buffer, option);
+    }
+
+    public constructor(imageData: Uint8Array, option?: MatConstructorOption) {
+        const addon = require("../../build/cv.node");
+        if ((imageData as any) instanceof addon.Mat) this.#mat = imageData;
+        else this.#mat = new addon.Mat(imageData, option);
+    }
+
+    public get empty(): boolean { return this.#mat.IsEmpty() }
+
+    public get cols(): number { return this.#mat.GetCols(); }
+
+    public get rows(): number { return this.#mat.GetRows(); }
+
+    public get width() { return this.cols; }
+
+    public get height() { return this.rows; }
+
+    public get channels() { return this.#mat.GetChannels(); }
+
+    public resize(width: number, height: number) { return new Mat(this.#mat.Resize(width, height)); }
+
+    public crop(sx: number, sy: number, sw: number, sh: number) { return new Mat(this.#mat.Crop(sx, sy, sw, sh)); }
+
+    public rotate(sx: number, sy: number, angleDeg: number) { return new Mat(this.#mat.Rotate(sx, sy, angleDeg)); }
+
+    public get data() { return new Uint8Array(this.#mat.Data()); }
+
+    public encode(extname: string) { return new Uint8Array(this.#mat.Encode({ extname })); }
+
+    public clone() { return new Mat(this.#mat.Clone()); }
+
+    public circle(x: number, y: number, radius: number, options?: {
+        color?: { r: number, g: number, b: number },
+        thickness?: number
+        lineType?: number
+    }) {
+        this.#mat.DrawCircle(
+            x, y, radius,
+            options?.color?.b ?? 0,
+            options?.color?.g ?? 0,
+            options?.color?.r ?? 0,
+            options?.thickness ?? 1,
+            options?.lineType ?? 8,
+            0,
+        );
+    }
+}
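
A sketch of the wrapper in use (file names are placeholders):

    import { writeFileSync } from "fs";
    import { cv } from "./src/cv";

    async function main() {
        const image = await cv.Mat.load("photo.jpg"); // local path or http(s) URL
        console.log(image.width, image.height, image.channels);

        const face = image.crop(100, 100, 200, 200).resize(112, 112);
        face.circle(56, 56, 3, { color: { r: 255, g: 0, b: 0 }, thickness: 2 });
        writeFileSync("face.jpg", face.encode(".jpg"));
    }
    main();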
diff --git a/src/deploy/common/model.ts b/src/deploy/common/model.ts
new file mode 100644
index 0000000..95e64b1
--- /dev/null
+++ b/src/deploy/common/model.ts
@@ -0,0 +1,37 @@
+import { backend } from "../../backend";
+import { cv } from "../../cv";
+
+export type ModelConstructor<T extends Model> = new (session: backend.common.CommonSession) => T;
+
+export type ImageSource = cv.Mat | Uint8Array | string;
+
+export interface ImageCropOption {
+    /** Image region to crop */
+    crop?: { sx: number, sy: number, sw: number, sh: number }
+}
+
+export abstract class Model {
+    protected session: backend.common.CommonSession;
+
+    protected static async resolveImage<R>(image: ImageSource, resolver: (image: cv.Mat) => R | Promise<R>): Promise<R> {
+        if (typeof image === "string") {
+            if (/^https?:\/\//.test(image)) image = await fetch(image).then(res => res.arrayBuffer()).then(buffer => new Uint8Array(buffer));
+            else image = await import("fs").then(fs => fs.promises.readFile(image as string));
+        }
+        if (image instanceof Uint8Array) image = new cv.Mat(image, { mode: cv.ImreadModes.IMREAD_COLOR_BGR })
+        if (image instanceof cv.Mat) return await resolver(image);
+        else throw new Error("Invalid image");
+    }
+
+    public static fromOnnx<T extends Model>(this: ModelConstructor<T>, modelData: Uint8Array) {
+        return new this(new backend.ort.Session(modelData));
+    }
+
+    public constructor(session: backend.common.CommonSession) { this.session = session; }
+
+
+    public get inputs() { return this.session.inputs; }
+    public get outputs() { return this.session.outputs; }
+    public get input() { return Object.entries(this.inputs)[0][1]; }
+    public get output() { return Object.entries(this.outputs)[0][1]; }
+}
diff --git a/src/deploy/common/processors.ts b/src/deploy/common/processors.ts
new file mode 100644
index 0000000..eba25de
--- /dev/null
+++ b/src/deploy/common/processors.ts
@@ -0,0 +1,85 @@
+
+
+interface IConvertImageOption {
+    sourceImageFormat: "rgba" | "rgb" | "bgr"
+    targetShapeFormat?: "nchw" | "nhwc"
+    targetColorFormat?: "rgb" | "bgr" | "gray"
+    targetNormalize?: { mean: number[], std: number[] }
+}
+
+export function convertImage(image: Uint8Array, option?: IConvertImageOption) {
+    const sourceImageFormat = option?.sourceImageFormat ?? "rgb";
+    const targetShapeFormat = option?.targetShapeFormat ?? "nchw";
+    const targetColorFormat = option?.targetColorFormat ?? "bgr";
+    const targetNormalize = option?.targetNormalize ?? { mean: [127.5], std: [127.5] };
+
+    let rgbReader: (pixel: number) => [number, number, number];
+    let pixelCount: number;
+    switch (sourceImageFormat) {
+        case "bgr":
+            rgbReader = pixel => [image[pixel * 3 + 2], image[pixel * 3 + 1], image[pixel * 3 + 0]];
+            pixelCount = image.length / 3;
+            break;
+        case "rgb":
+            rgbReader = pixel => [image[pixel * 3 + 0], image[pixel * 3 + 1], image[pixel * 3 + 2]];
+            pixelCount = image.length / 3;
+            break;
+        case "rgba":
+            rgbReader = pixel => [image[pixel * 4 + 0], image[pixel * 4 + 1], image[pixel * 4 + 2]];
+            pixelCount = image.length / 4;
+            break;
+    }
+
+    let targetChannelGetter: (stride: number, pixel: number, offset: number) => number;
+    switch (targetShapeFormat) {
+        case "nchw":
+            // Planar layout: each channel occupies a full H*W plane of pixelCount values.
+            targetChannelGetter = (stride, pixel, offset) => pixelCount * offset + pixel;
+            break;
+        case "nhwc":
+            targetChannelGetter = (stride, pixel, offset) => stride * pixel + offset;
+            break
+    }
+
+    let normIndex = 0;
+    const normValue = (val: number) => {
+        const mean = targetNormalize.mean[normIndex % targetNormalize.mean.length];
+        const std = targetNormalize.std[normIndex % targetNormalize.std.length];
+        const result = (val - mean) / std;
+        ++normIndex;
+        return result;
+    }
+
+    let outBuffer: Float32Array;
+    let pixelWriter: (pixel: number, r: number, g: number, b: number) => any
+    switch (targetColorFormat) {
+        case "rgb":
+            outBuffer = new Float32Array(pixelCount * 3);
+            pixelWriter = (pixel, r, g, b) => {
+                outBuffer[targetChannelGetter(3, pixel, 0)] = normValue(r);
+                outBuffer[targetChannelGetter(3, pixel, 1)] = normValue(g);
+                outBuffer[targetChannelGetter(3, pixel, 2)] = normValue(b);
+            }
+            break;
+        case "bgr":
+            outBuffer = new Float32Array(pixelCount * 3);
+            pixelWriter = (pixel, r, g, b) => {
+                outBuffer[targetChannelGetter(3, pixel, 0)] = normValue(b);
+                outBuffer[targetChannelGetter(3, pixel, 1)] = normValue(g);
+                outBuffer[targetChannelGetter(3, pixel, 2)] = normValue(r);
+            }
+            break;
+        case "gray":
+            outBuffer = new Float32Array(pixelCount);
+            pixelWriter = (pixel, r, g, b) => {
+                outBuffer[targetChannelGetter(1, pixel, 0)] = normValue(0.2126 * r + 0.7152 * g + 0.0722 * b);
+            }
+            break;
+    }
+
+    for (let i = 0; i < pixelCount; ++i) {
+        const [r, g, b] = rgbReader(i);
+        pixelWriter(i, r, g, b);
+    }
+    return outBuffer;
+}
\ No newline at end of file
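
A sketch of convertImage preparing a 1×3×H×W float input from a Mat's BGR bytes; note that mean/std are broadcast by channel index, so a single-element [127.5] applies to all three channels:

    import { cv } from "./src/cv"; // assumed import sites
    import { convertImage } from "./src/deploy/common/processors";

    async function preprocess(path: string) {
        const image = (await cv.Mat.load(path)).resize(224, 224);
        // Mat data is BGR; emit RGB planes in NCHW order, scaled to roughly [-1, 1].
        const data = convertImage(image.data, {
            sourceImageFormat: "bgr",
            targetShapeFormat: "nchw",
            targetColorFormat: "rgb",
            targetNormalize: { mean: [127.5], std: [127.5] },
        });
        return { data, shape: [1, 3, 224, 224] };
    }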
diff --git a/src/deploy/facealign/common.ts b/src/deploy/facealign/common.ts
new file mode 100644
index 0000000..c0b2631
--- /dev/null
+++ b/src/deploy/facealign/common.ts
@@ -0,0 +1,58 @@
+export interface FacePoint {
+    x: number
+    y: number
+}
+
+type PointType = "leftEye" | "rightEye" | "leftEyebrow" | "rightEyebrow" | "nose" | "mouth" | "contour"
+
+export abstract class FaceAlignmentResult {
+    #points: FacePoint[]
+
+    public constructor(points: FacePoint[]) { this.#points = points; }
+
+    /** Landmark points */
+    public get points() { return this.#points; }
+
+    /** Get the landmark points of the given type(s) */
+    public getPointsOf(type: PointType | PointType[]) {
+        if (typeof type == "string") type = [type];
+        const result: FacePoint[] = [];
+        for (const t of type) {
+            for (const idx of this[`${t}PointIndex` as const]()) {
+                result.push(this.points[idx]);
+            }
+        }
+        return result;
+    }
+
+    /** Direction (roll angle), in radians */
+    public get direction() {
+        const [{ x: x1, y: y1 }, { x: x2, y: y2 }] = this.directionPointIndex().map(idx => this.points[idx]);
+        return Math.atan2(y1 - y2, x2 - x1)
+    }
+
+    /** Indexes of the two points used to determine direction (eye-center points are recommended) */
+    protected abstract directionPointIndex(): [number, number];
+    /** Indexes of the left-eye points */
+    protected abstract leftEyePointIndex(): number[];
+    /** Indexes of the right-eye points */
+    protected abstract rightEyePointIndex(): number[];
+    /** Indexes of the left-eyebrow points */
+    protected abstract leftEyebrowPointIndex(): number[];
+    /** Indexes of the right-eyebrow points */
+    protected abstract rightEyebrowPointIndex(): number[];
+    /** Indexes of the mouth points */
+    protected abstract mouthPointIndex(): number[];
+    /** Indexes of the nose points */
+    protected abstract nosePointIndex(): number[];
+    /** Indexes of the contour points */
+    protected abstract contourPointIndex(): number[];
+
+    protected indexFromTo(from: number, to: number) {
+        const indexes: number[] = [];
+        for (let i = from; i <= to; i++) {
+            indexes.push(i);
+        }
+        return indexes;
+    }
+}
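
The direction getter is just the angle of the inter-eye vector in image coordinates (y grows downward, hence the sign flip on y); a worked sketch with two assumed eye-center points:

    const [p1, p2] = [{ x: 100, y: 120 }, { x: 180, y: 100 }]; // left eye, right eye (assumed)
    const direction = Math.atan2(p1.y - p2.y, p2.x - p1.x);    // atan2(20, 80) ≈ 0.245 rad ≈ 14°
    // Rotating the image by -direction (converted to degrees) around the face center
    // levels the eyes, as src/test.ts does with image.rotate(cx, cy, -direction * 180 / Math.PI).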
diff --git a/src/deploy/facealign/index.ts b/src/deploy/facealign/index.ts
new file mode 100644
index 0000000..6af4c87
--- /dev/null
+++ b/src/deploy/facealign/index.ts
@@ -0,0 +1,2 @@
+export { PFLD } from "./pfld";
+export { FaceLandmark1000 } from "./landmark1000";
\ No newline at end of file
diff --git a/src/deploy/facealign/landmark1000.ts b/src/deploy/facealign/landmark1000.ts
new file mode 100644
index 0000000..0826a7a
--- /dev/null
+++ b/src/deploy/facealign/landmark1000.ts
@@ -0,0 +1,51 @@
+import { cv } from "../../cv";
+import { ImageCropOption, ImageSource, Model } from "../common/model";
+import { convertImage } from "../common/processors";
+import { FaceAlignmentResult, FacePoint } from "./common";
+
+interface FaceLandmark1000PredictOption extends ImageCropOption { }
+
+class FaceLandmark1000Result extends FaceAlignmentResult {
+    protected directionPointIndex(): [number, number] { return [401, 529]; }
+    protected leftEyePointIndex(): number[] { return this.indexFromTo(401, 528); }
+    protected rightEyePointIndex(): number[] { return this.indexFromTo(529, 656); }
+    protected leftEyebrowPointIndex(): number[] { return this.indexFromTo(273, 336); }
+    protected rightEyebrowPointIndex(): number[] { return this.indexFromTo(337, 400); }
+    protected mouthPointIndex(): number[] { return this.indexFromTo(845, 972); }
+    protected nosePointIndex(): number[] { return this.indexFromTo(657, 844); }
+    protected contourPointIndex(): number[] { return this.indexFromTo(0, 272); }
+}
+
+export class FaceLandmark1000 extends Model {
+
+    public predict(image: ImageSource, option?: FaceLandmark1000PredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); }
+
+    public async doPredict(image: cv.Mat, option?: FaceLandmark1000PredictOption) {
+        const input = this.input;
+        if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
+        const ratioWidth = image.width / input.shape[3];
+        const ratioHeight = image.height / input.shape[2];
+        image = image.resize(input.shape[3], input.shape[2]);
+
+        const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "gray", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
+
+        const res = await this.session.run({
+            [input.name]: {
+                shape: [1, 1, input.shape[2], input.shape[3]],
+                data: nchwImageData,
+                type: "float32",
+            }
+        }).then(res => res[this.output.name]);
+
+        const points: FacePoint[] = [];
+        for (let i = 0; i < res.length; i += 2) {
+            const x = res[i] * image.width * ratioWidth;
+            const y = res[i + 1] * image.height * ratioHeight;
+            points.push({ x, y });
+        }
+
+        return new FaceLandmark1000Result(points);
+
+    }
+
+}
\ No newline at end of file
"fs"; +import { cv } from "../../cv"; +import { ImageCropOption, ImageSource, Model } from "../common/model"; +import { convertImage } from "../common/processors"; +import { FaceAlignmentResult, FacePoint } from "./common"; + +export interface PFLDPredictOption extends ImageCropOption { } + +class PFLDResult extends FaceAlignmentResult { + protected directionPointIndex(): [number, number] { return [36, 92]; } + protected leftEyePointIndex(): number[] { return [33, 34, 35, 36, 37, 38, 39, 40, 41, 42]; } + protected rightEyePointIndex(): number[] { return [87, 88, 89, 90, 91, 92, 93, 94, 95, 96]; } + protected leftEyebrowPointIndex(): number[] { return [43, 44, 45, 46, 47, 48, 49, 50, 51]; } + protected rightEyebrowPointIndex(): number[] { return [97, 98, 99, 100, 101, 102, 103, 104, 105]; } + protected mouthPointIndex(): number[] { return [52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]; } + protected nosePointIndex(): number[] { return [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86]; } + protected contourPointIndex(): number[] { return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]; } +} + +export class PFLD extends Model { + public predict(image: ImageSource, option?: PFLDPredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); } + + private async doPredict(image: cv.Mat, option?: PFLDPredictOption) { + const input = this.input; + if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh); + const ratioWidth = image.width / input.shape[3]; + const ratioHeight = image.height / input.shape[2]; + image = image.resize(input.shape[3], input.shape[2]); + + const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [255] } }) + + const pointsOutput = Object.entries(this.outputs).filter(([_, out]) => out.shape.length == 2)[0][1]; + + const res = await this.session.run({ + [input.name]: { + type: "float32", + data: nchwImageData, + shape: [1, 3, input.shape[2], input.shape[3]], + } + }); + const pointsBuffer = res[pointsOutput.name]; + + const points: FacePoint[] = []; + for (let i = 0; i < pointsBuffer.length; i += 2) { + const x = pointsBuffer[i] * image.width * ratioWidth; + const y = pointsBuffer[i + 1] * image.height * ratioHeight; + points.push({ x, y }); + } + + return new PFLDResult(points); + } + +} diff --git a/src/deploy/faceattr/gender-age.ts b/src/deploy/faceattr/gender-age.ts new file mode 100644 index 0000000..693a936 --- /dev/null +++ b/src/deploy/faceattr/gender-age.ts @@ -0,0 +1,40 @@ +import { cv } from "../../cv"; +import { ImageCropOption, ImageSource, Model } from "../common/model"; +import { convertImage } from "../common/processors"; + +interface GenderAgePredictOption extends ImageCropOption { +} + +export interface GenderAgePredictResult { + gender: "M" | "F" + age: number +} + + +export class GenderAge extends Model { + private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise { + const input = this.input; + const output = this.output; + if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh); + image = image.resize(input.shape[3], input.shape[2]); + + const nchwImage = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } 
diff --git a/src/deploy/faceattr/gender-age.ts b/src/deploy/faceattr/gender-age.ts
new file mode 100644
index 0000000..693a936
--- /dev/null
+++ b/src/deploy/faceattr/gender-age.ts
@@ -0,0 +1,40 @@
+import { cv } from "../../cv";
+import { ImageCropOption, ImageSource, Model } from "../common/model";
+import { convertImage } from "../common/processors";
+
+interface GenderAgePredictOption extends ImageCropOption {
+}
+
+export interface GenderAgePredictResult {
+    gender: "M" | "F"
+    age: number
+}
+
+
+export class GenderAge extends Model {
+    private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise<GenderAgePredictResult> {
+        const input = this.input;
+        const output = this.output;
+        if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
+        image = image.resize(input.shape[3], input.shape[2]);
+
+        const nchwImage = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });
+
+        const result = await this.session.run({
+            [input.name]: {
+                shape: [1, 3, input.shape[2], input.shape[3]],
+                data: nchwImage,
+                type: "float32",
+            }
+        }).then(res => res[output.name]);
+
+        return {
+            gender: result[0] > result[1] ? "F" : "M",
+            age: parseInt(result[2] * 100 as any),
+        }
+    }
+
+    public predict(image: ImageSource, option?: GenderAgePredictOption) {
+        return Model.resolveImage(image, im => this.doPredict(im, option));
+    }
+}
diff --git a/src/deploy/faceattr/index.ts b/src/deploy/faceattr/index.ts
new file mode 100644
index 0000000..7f8e37a
--- /dev/null
+++ b/src/deploy/faceattr/index.ts
@@ -0,0 +1 @@
+export { GenderAge as GenderAgeDetector } from "./gender-age";
diff --git a/src/deploy/facedet/common.ts b/src/deploy/facedet/common.ts
new file mode 100644
index 0000000..9493ab1
--- /dev/null
+++ b/src/deploy/facedet/common.ts
@@ -0,0 +1,102 @@
+import { cv } from "../../cv"
+import { ImageSource, Model } from "../common/model"
+
+interface IFaceBoxConstructorOption {
+    x1: number
+    y1: number
+    x2: number
+    y2: number
+    score: number
+    imw: number
+    imh: number
+}
+
+export class FaceBox {
+    #option: IFaceBoxConstructorOption;
+    constructor(option: IFaceBoxConstructorOption) { this.#option = option; }
+
+    public get x1() { return this.#option.x1; }
+    public get y1() { return this.#option.y1; }
+    public get x2() { return this.#option.x2; }
+    public get y2() { return this.#option.y2; }
+    public get centerX() { return this.x1 + this.width / 2; }
+    public get centerY() { return this.y1 + this.height / 2; }
+    public get score() { return this.#option.score; }
+    public get left() { return this.x1; }
+    public get top() { return this.y1; }
+    public get width() { return this.x2 - this.x1; }
+    public get height() { return this.y2 - this.y1; }
+
+    /** Round the coordinates to integers */
+    public toInt() {
+        return new FaceBox({
+            ...this.#option,
+            x1: parseInt(this.#option.x1 as any), y1: parseInt(this.#option.y1 as any),
+            x2: parseInt(this.#option.x2 as any), y2: parseInt(this.#option.y2 as any),
+        });
+    }
+
+    /** Convert to a square box, clamped to the image bounds */
+    public toSquare() {
+        const { imw, imh } = this.#option;
+        let size = Math.max(this.width, this.height) / 2;
+        const cx = this.centerX, cy = this.centerY;
+
+        if (cx - size < 0) size = cx;
+        if (cx + size > imw) size = imw - cx;
+        if (cy - size < 0) size = cy;
+        if (cy + size > imh) size = imh - cy;
+
+        return new FaceBox({
+            ...this.#option,
+            x1: this.centerX - size, y1: this.centerY - size,
+            x2: this.centerX + size, y2: this.centerY + size,
+        });
+    }
+}
+
+export interface FaceDetectOption {
+    /** Score threshold, default 0.5 */
+    threshold?: number
+    /** NMS threshold, default 0.3 */
+    nmsThreshold?: number
+}
+
+export abstract class FaceDetector extends Model {
+
+    protected abstract doPredict(image: cv.Mat, option?: FaceDetectOption): Promise<FaceBox[]>;
+
+    public async predict(image: ImageSource, option?: FaceDetectOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); }
+}
+
+
+export function nms(input: FaceBox[], threshold: number) {
+    if (!input.length) return [];
+    input = input.sort((a, b) => b.score - a.score);
+    const merged = input.map(() => 0);
+    const output: FaceBox[] = [];
+
+    for (let i = 0; i < input.length; i++) {
+        if (merged[i]) continue;
+        output.push(input[i]);
+
+        for (let j = i + 1; j < input.length; j++) {
+            if (merged[j]) continue;
+            const inner_x0 = input[i].x1 > input[j].x1 ? input[i].x1 : input[j].x1;
+            const inner_y0 = input[i].y1 > input[j].y1 ? input[i].y1 : input[j].y1;
+            const inner_x1 = input[i].x2 < input[j].x2 ? input[i].x2 : input[j].x2;
+            const inner_y1 = input[i].y2 < input[j].y2 ? input[i].y2 : input[j].y2;
+            const inner_h = inner_y1 - inner_y0 + 1;
+            const inner_w = inner_x1 - inner_x0 + 1;
+            if (inner_h <= 0 || inner_w <= 0) continue;
+            const inner_area = inner_h * inner_w;
+            const h1 = input[j].y2 - input[j].y1 + 1;
+            const w1 = input[j].x2 - input[j].x1 + 1;
+            const area1 = h1 * w1;
+            const score = inner_area / area1;
+            if (score > threshold) merged[j] = 1;
+        }
+    }
+
+    return output;
+}
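
A worked example of the suppression above: two heavily overlapping candidates collapse to the higher-scoring one (box values assumed):

    import { FaceBox, nms } from "./src/deploy/facedet/common";

    const a = new FaceBox({ x1: 10, y1: 10, x2: 110, y2: 110, score: 0.9, imw: 640, imh: 480 });
    const b = new FaceBox({ x1: 20, y1: 20, x2: 120, y2: 120, score: 0.8, imw: 640, imh: 480 });
    // Overlap (91 * 91) relative to b's area (101 * 101) ≈ 0.81 > 0.3, so b is dropped.
    console.log(nms([a, b], 0.3).length); // 1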
diff --git a/src/deploy/facedet/index.ts b/src/deploy/facedet/index.ts
new file mode 100644
index 0000000..181deab
--- /dev/null
+++ b/src/deploy/facedet/index.ts
@@ -0,0 +1,2 @@
+export { FaceBox } from "./common";
+export { Yolov5Face } from "./yolov5";
diff --git a/src/deploy/facedet/yolov5.ts b/src/deploy/facedet/yolov5.ts
new file mode 100644
index 0000000..6692c42
--- /dev/null
+++ b/src/deploy/facedet/yolov5.ts
@@ -0,0 +1,37 @@
+import { cv } from "../../cv";
+import { convertImage } from "../common/processors";
+import { FaceBox, FaceDetectOption, FaceDetector, nms } from "./common";
+
+export class Yolov5Face extends FaceDetector {
+
+    public async doPredict(image: cv.Mat, option?: FaceDetectOption): Promise<FaceBox[]> {
+        const input = this.input;
+        const resizedImage = image.resize(input.shape[3], input.shape[2]);
+        const ratioWidth = image.width / resizedImage.width;
+        const ratioHeight = image.height / resizedImage.height;
+        const nchwImageData = convertImage(resizedImage.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
+
+        const outputData = await this.session.run({ input: nchwImageData }).then(r => r.output);
+        const outShape = this.outputs["output"].shape;
+
+        const faces: FaceBox[] = [];
+        const threshold = option?.threshold ?? 0.5;
+        for (let i = 0; i < outShape[1]; i++) {
+            const beg = i * outShape[2];
+            const rectData = outputData.slice(beg, beg + outShape[2]);
+            const x = parseInt(rectData[0] * ratioWidth as any);
+            const y = parseInt(rectData[1] * ratioHeight as any);
+            const w = parseInt(rectData[2] * ratioWidth as any);
+            const h = parseInt(rectData[3] * ratioHeight as any);
+            const score = rectData[4] * rectData[15];
+            if (score < threshold) continue;
+            faces.push(new FaceBox({
+                x1: x - w / 2, y1: y - h / 2,
+                x2: x + w / 2, y2: y + h / 2,
+                score, imw: image.width, imh: image.height,
+            }))
+        }
+        return nms(faces, option?.nmsThreshold ?? 0.3).map(box => box.toInt());
+    }
+
+}
diff --git a/src/deploy/faceid/adaface.ts b/src/deploy/faceid/adaface.ts
new file mode 100644
index 0000000..9d61966
--- /dev/null
+++ b/src/deploy/faceid/adaface.ts
@@ -0,0 +1,27 @@
+import { Mat } from "../../cv/mat";
+import { convertImage } from "../common/processors";
+import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
+
+export class AdaFace extends FaceRecognition {
+
+    public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
+        const input = this.input;
+        const output = this.output;
+
+        if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
+        image = image.resize(input.shape[3], input.shape[2]);
+
+        const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
+
+        const embedding = await this.session.run({
+            [input.name]: {
+                type: "float32",
+                data: nchwImageData,
+                shape: [1, 3, input.shape[2], input.shape[3]],
+            }
+        }).then(res => res[output.name]);
+
+        return Array.from(embedding);
+    }
+
+}
diff --git a/src/deploy/faceid/common.ts b/src/deploy/faceid/common.ts
new file mode 100644
index 0000000..0320046
--- /dev/null
+++ b/src/deploy/faceid/common.ts
@@ -0,0 +1,31 @@
+import { cv } from "../../cv";
+import { ImageCropOption, ImageSource, Model } from "../common/model";
+
+export interface FaceRecognitionPredictOption extends ImageCropOption { }
+
+export abstract class FaceRecognition extends Model {
+    public abstract doPredict(image: cv.Mat, option?: FaceRecognitionPredictOption): Promise<number[]>;
+
+    public async predict(image: ImageSource, option?: FaceRecognitionPredictOption) { return Model.resolveImage(image, im => this.doPredict(im, option)); }
+}
+
+export function cosineDistance(lhs: number[], rhs: number[]) {
+    if (lhs.length !== rhs.length) throw new Error('length not match');
+
+    const getMod = (vec: number[]) => {
+        let sum = 0;
+        for (let i = 0; i < vec.length; ++i) sum += vec[i] * vec[i];
+        return sum ** 0.5;
+    }
+
+    let temp = 0;
+    for (let i = 0; i < lhs.length; ++i) temp += lhs[i] * rhs[i];
+    return temp / (getMod(lhs) * getMod(rhs));
+}
+
+export function euclideanDistance(lhs: number[], rhs: number[]) {
+    if (lhs.length !== rhs.length) throw new Error('length not match');
+    let sumDescriptor = 0;
+    for (let i = 0; i < lhs.length; i++) sumDescriptor += (lhs[i] - rhs[i]) ** 2;
+    return sumDescriptor ** 0.5;
+}
diff --git a/src/deploy/faceid/index.ts b/src/deploy/faceid/index.ts
new file mode 100644
index 0000000..c256c88
--- /dev/null
+++ b/src/deploy/faceid/index.ts
@@ -0,0 +1,3 @@
+export { cosineDistance, euclideanDistance } from "./common";
+export { ArcFace, CosFace, PartialFC } from "./insightface";
+export { AdaFace } from "./adaface";
\ No newline at end of file
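
A worked check of the two metrics (vectors assumed): cosine similarity is 1 for parallel vectors and 0 for orthogonal ones, which is why src/test.ts treats the maximum cosine score over the reference embeddings as the match:

    import { cosineDistance, euclideanDistance } from "./src/deploy/faceid/common";

    console.log(cosineDistance([1, 0], [2, 0]));    // 1 (same direction)
    console.log(cosineDistance([1, 0], [0, 3]));    // 0 (orthogonal)
    console.log(euclideanDistance([0, 0], [3, 4])); // 5 (3-4-5 triangle)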
diff --git a/src/deploy/faceid/insightface.ts b/src/deploy/faceid/insightface.ts
new file mode 100644
index 0000000..90dc06b
--- /dev/null
+++ b/src/deploy/faceid/insightface.ts
@@ -0,0 +1,32 @@
+import { Mat } from "../../cv/mat";
+import { convertImage } from "../common/processors";
+import { FaceRecognition, FaceRecognitionPredictOption } from "./common";
+
+export class Insightface extends FaceRecognition {
+
+    public async doPredict(image: Mat, option?: FaceRecognitionPredictOption): Promise<number[]> {
+        const input = this.input;
+        const output = this.output;
+        if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
+        image = image.resize(input.shape[3], input.shape[2]);
+        const nchwImageData = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "bgr", targetShapeFormat: "nchw", targetNormalize: { mean: [127.5], std: [127.5] } });
+
+        const embedding = await this.session.run({
+            [input.name]: {
+                type: "float32",
+                data: nchwImageData,
+                shape: [1, 3, input.shape[2], input.shape[3]],
+            }
+        }).then(res => res[output.name]);
+
+        return Array.from(embedding);
+    }
+
+}
+
+export class ArcFace extends Insightface { }
+
+export class CosFace extends Insightface { }
+
+export class PartialFC extends Insightface { }
+
diff --git a/src/deploy/index.ts b/src/deploy/index.ts
new file mode 100644
index 0000000..75ad87c
--- /dev/null
+++ b/src/deploy/index.ts
@@ -0,0 +1 @@
+export * as deploy from "./main";
diff --git a/src/deploy/main.ts b/src/deploy/main.ts
new file mode 100644
index 0000000..2472b41
--- /dev/null
+++ b/src/deploy/main.ts
@@ -0,0 +1,4 @@
+export * as facedet from "./facedet";
+export * as faceid from "./faceid";
+export * as faceattr from "./faceattr";
+export * as facealign from "./facealign";
diff --git a/src/index.ts b/src/index.ts
new file mode 100644
index 0000000..e69de29
diff --git a/src/main.ts b/src/main.ts
new file mode 100644
index 0000000..e69de29
diff --git a/src/test.ts b/src/test.ts
new file mode 100644
index 0000000..ad4cdec
--- /dev/null
+++ b/src/test.ts
@@ -0,0 +1,138 @@
+import fs from "fs";
+import { deploy } from "./deploy";
+import { cv } from "./cv";
+import { faceidTestData } from "./test_data/faceid";
+import path from "path";
+import crypto from "crypto";
+
+async function cacheImage(group: string, url: string) {
+    const _url = new URL(url);
+    const cacheDir = path.join(__dirname, "../cache/images", group);
+    fs.mkdirSync(cacheDir, { recursive: true });
+    const cacheJsonFile = path.join(cacheDir, "config.json");
+    let jsonData: Record<string, string> = {};
+    if (fs.existsSync(cacheJsonFile)) {
+        jsonData = JSON.parse(fs.readFileSync(cacheJsonFile, "utf-8"));
+        const filename = jsonData[url];
+        if (filename && fs.existsSync(path.join(cacheDir, filename))) return path.join(cacheDir, filename);
+    }
+    const data = await fetch(_url).then(res => res.arrayBuffer()).then(buf => new Uint8Array(buf));
+    const allowedExtnames = [".jpg", ".jpeg", ".png", ".webp"];
+    let extname = path.extname(_url.pathname);
+    if (!allowedExtnames.includes(extname)) extname = ".jpeg";
+    const md5 = crypto.hash("md5", data, "hex");
+    const savename = md5 + extname;
+    jsonData[url] = savename;
+    fs.writeFileSync(cacheJsonFile, JSON.stringify(jsonData));
+    const savepath = path.join(cacheDir, savename);
+    fs.writeFileSync(savepath, data);
+    return savepath;
+}
+
+async function testGenderTest() {
+    const facedet = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx"));
+    const detector = deploy.faceattr.GenderAgeDetector.fromOnnx(fs.readFileSync("models/faceattr/insight_gender_age.onnx"));
+
+    const image = await cv.Mat.load("https://b0.bdstatic.com/ugc/iHBWUj0XqytakT1ogBfBJwc7c305331d2cf904b9fb3d8dd3ed84f5.jpg");
+    const boxes = await facedet.predict(image);
+    if (!boxes.length) return console.error("No face detected");
+    for (const [idx, box] of boxes.entries()) {
+        const res = await detector.predict(image, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
+        console.log(`[${idx + 1}]`, res);
+    }
+
+}
+
+async function testFaceID() {
+    const facedet = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx"));
+
+async function testFaceID() {
+    const facedet = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx"));
+    const faceid = deploy.faceid.CosFace.fromOnnx(fs.readFileSync("models/faceid/insightface/glint360k_cosface_r100.onnx"));
+
+    const { basic, tests } = faceidTestData.stars;
+
+    console.log("Loading image assets");
+    const basicImage = await cv.Mat.load(await cacheImage("faceid", basic.image));
+    const testsImages: Record<string, cv.Mat[]> = {};
+    for (const [name, imgs] of Object.entries(tests)) {
+        testsImages[name] = await Promise.all(imgs.map(img => cacheImage("faceid", img).then(file => cv.Mat.load(file))));
+    }
+
+    console.log("Detecting faces in the reference image");
+    const basicDetectedFaces = await facedet.predict(basicImage);
+    const basicFaceIndex: Record<string, number> = {};
+    for (const [name, [x, y]] of Object.entries(basic.faces)) {
+        basicFaceIndex[name] = basicDetectedFaces.findIndex(box => box.x1 < x && box.x2 > x && box.y1 < y && box.y2 > y);
+    }
+
+    const basicEmbds: number[][] = [];
+    for (const box of basicDetectedFaces) {
+        const embd = await faceid.predict(basicImage, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
+        basicEmbds.push(embd);
+    }
+
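+    // Matching rule used below: each test embedding is scored against every
+    // reference embedding with cosine similarity, and the check passes only when
+    // the expected identity has the highest score. E.g. with scores
+    // [0.12, 0.61, 0.08] against three reference faces, the test face matches
+    // reference #2; it counts as PASS only if #2 is the expected person.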
+    console.log("Comparing faces");
+    for (const [name, images] of Object.entries(testsImages)) {
+        console.log(`[${name}] checking`);
+        const index = basicFaceIndex[name];
+        if (index < 0) {
+            console.error(`[${name}] not found in the reference image`);
+            continue;
+        }
+        const basicEmbd = basicEmbds[index];
+
+        for (const [idx, img] of images.entries()) {
+            const box = await facedet.predict(img).then(boxes => boxes[0]);
+            if (!box) {
+                console.error(`[${idx + 1}] no face detected`);
+                continue;
+            }
+
+            const embd = await faceid.predict(img, { crop: { sx: box.left, sy: box.top, sw: box.width, sh: box.height } });
+
+            const compareEmbds = basicEmbds.map(e => deploy.faceid.cosineDistance(e, embd));
+            const max = Math.max(...compareEmbds);
+            const min = Math.min(...compareEmbds);
+
+            const distance = deploy.faceid.cosineDistance(basicEmbd, embd);
+            console.log(`[${idx + 1}] [${(distance.toFixed(4) === max.toFixed(4)) ? '\x1b[102mPASS' : '\x1b[101mFAIL'}\x1b[0m] similarity: ${distance.toFixed(4)}, max: ${max.toFixed(4)}, min: ${min.toFixed(4)}`);
+        }
+    }
+
+}
+
+async function testFaceAlign() {
+    const fd = deploy.facedet.Yolov5Face.fromOnnx(fs.readFileSync("models/facedet/yolov5s.onnx"));
+    // const fa = deploy.facealign.PFLD.fromOnnx(fs.readFileSync("models/facealign/pfld-106-lite.onnx"));
+    const fa = deploy.facealign.FaceLandmark1000.fromOnnx(fs.readFileSync("models/facealign/FaceLandmark1000.onnx"));
+    let image = await cv.Mat.load("https://bkimg.cdn.bcebos.com/pic/d52a2834349b033b5bb5f183119c21d3d539b6001712");
+    image = image.rotate(image.width / 2, image.height / 2, 0);
+
+    const face = await fd.predict(image).then(res => res[0].toSquare());
+    const points = await fa.predict(image, { crop: { sx: face.left, sy: face.top, sw: face.width, sh: face.height } });
+
+    points.points.forEach(point => {
+        image.circle(face.left + point.x, face.top + point.y, 2);
+    });
+    // const point = points.getPointsOf("rightEye")[1];
+    // image.circle(face.left + point.x, face.top + point.y, 2);
+    fs.writeFileSync("testdata/xx.jpg", image.encode(".jpg"));
+
+    let faceImage = image.rotate(face.centerX, face.centerY, -points.direction * 180 / Math.PI);
+    faceImage = faceImage.crop(face.left, face.top, face.width, face.height);
+    fs.writeFileSync("testdata/face.jpg", faceImage.encode(".jpg"));
+
+    console.log(points);
+    console.log(points.direction * 180 / Math.PI);
+    debugger;
+}
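+
+// Alignment sketch: `points.direction` appears to be the in-plane (roll) angle of
+// the face in radians, so `direction * 180 / Math.PI` converts it to degrees, and
+// rotating by the negative of that around the box center levels the face before
+// cropping. For example, direction = 0.26 rad is about 0.26 * 180 / 3.14159 ≈ 14.9°,
+// so the face image above would be rotated by roughly -14.9°.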
"https://q9.itc.cn/images01/20240401/e59294e8d51140ef95a32ac15a18f1c5.jpeg", + "https://q2.itc.cn/images01/20240401/8c7266ed665140889054aa029b441948.jpeg", + "https://q7.itc.cn/images01/20240327/3b34d9fdcd5249489c91bff44e1b5f67.jpeg", + "https://p6.itc.cn/images01/20230111/766db970998e44baa0c52a4fe32c8d73.jpeg", + "https://bj.bcebos.com/bjh-pixel/1699070679633105560_0_ainote_new.jpg", + "https://p6.itc.cn/images01/20230217/556c66340432485ea9e1fe53109716ee.jpeg", + "https://img-nos.yiyouliao.com/e13ca727b210046d7dfe49943e167525.jpeg", + "https://p8.itc.cn/q_70/images03/20230614/d9c37a45350543db850e26b1712cb4b2.jpeg", + ], + "liuyifei": [ + "https://pic1.zhimg.com/v2-f8517e0d3d925640b6c8cc53068906f1_r.jpg", + "https://q8.itc.cn/q_70/images03/20240906/c296907327c44c3db0ea2232d98e5b41.png", + "https://q0.itc.cn/q_70/images03/20240706/dd0aded7edc94fbcad8175fcb7111d89.jpeg", + "https://q8.itc.cn/images01/20241002/561d2334b0e0451f97261292fcf66544.jpeg", + "https://q0.itc.cn/images01/20241109/8b02037be3c6403aa0dee761f7fc3841.jpeg", + "https://b0.bdstatic.com/ugc/M0gY3M_9OvzSG6xuTniCqQaa50a2e9a3f8d275158a90a8e448c237.jpg", + "https://q2.itc.cn/images01/20241002/cd77b7cb8f2044faaa46f75a1f392a50.jpeg", + "https://q0.itc.cn/images01/20241110/a22ba19b0a944eb28066b602378e2d45.jpeg", + ], + "dengjiajia":[ + "https://q4.itc.cn/q_70/images01/20240601/db61f410a20b4155b2e487b183372fc1.jpeg", + "https://b0.bdstatic.com/ugc/sR1WfaiiWi2KtStV_U3YGw49f741338ad3b10c10a1dec7aa5d8dd8.jpg", + "https://bkimg.cdn.bcebos.com/pic/b8389b504fc2d56285357c05334187ef76c6a6efc680", + "https://bkimg.cdn.bcebos.com/pic/b2de9c82d158ccbfc3fa73271ed8bc3eb135416d", + "https://bkimg.cdn.bcebos.com/pic/2cf5e0fe9925bc315c6099498a8f9ab1cb134854da9b", + "https://bkimg.cdn.bcebos.com/pic/738b4710b912c8fcc3ce2e773f578545d688d43f8b76" + ] + } +} + +export const faceidTestData = { stars } \ No newline at end of file diff --git a/src/utils/utils.ts b/src/utils/utils.ts new file mode 100644 index 0000000..c0795bf --- /dev/null +++ b/src/utils/utils.ts @@ -0,0 +1,33 @@ + +export namespace utils { + + export function rgba2rgb(data: T): T { + const pixelCount = data.length / 4; + const result = new (data.constructor as any)(pixelCount * 3) as T; + for (let i = 0; i < pixelCount; i++) { + result[i * 3 + 0] = data[i * 4 + 0]; + result[i * 3 + 1] = data[i * 4 + 1]; + result[i * 3 + 2] = data[i * 4 + 2]; + } + return result; + } + + export function rgb2bgr(data: T): T { + const pixelCount = data.length / 3; + const result = new (data.constructor as any)(pixelCount * 3) as T; + for (let i = 0; i < pixelCount; i++) { + result[i * 3 + 0] = data[i * 3 + 2]; + result[i * 3 + 1] = data[i * 3 + 1]; + result[i * 3 + 2] = data[i * 3 + 0]; + } + return result; + } + + export function normalize(data: Uint8Array | Float32Array, mean: number[], std: number[]): Float32Array { + const result = new Float32Array(data.length); + for (let i = 0; i < data.length; i++) { + result[i] = (data[i] - mean[i % mean.length]) / std[i % std.length]; + } + return result; + } +} diff --git a/tsconfig.json b/tsconfig.json new file mode 100644 index 0000000..37a7f90 --- /dev/null +++ b/tsconfig.json @@ -0,0 +1,111 @@ +{ + "compilerOptions": { + /* Visit https://aka.ms/tsconfig to read more about this file */ + + /* Projects */ + // "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */ + // "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. 
diff --git a/tsconfig.json b/tsconfig.json
new file mode 100644
index 0000000..37a7f90
--- /dev/null
+++ b/tsconfig.json
@@ -0,0 +1,111 @@
+{
+  "compilerOptions": {
+    /* Visit https://aka.ms/tsconfig to read more about this file */
+
+    /* Projects */
+    // "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */
+    // "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */
+    // "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */
+    // "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */
+    // "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */
+    // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */
+
+    /* Language and Environment */
+    "target": "esnext", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
+    // "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
+    // "jsx": "preserve", /* Specify what JSX code is generated. */
+    // "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
+    // "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
+    // "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
+    // "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
+    // "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */
+    // "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */
+    // "noLib": true, /* Disable including any library files, including the default lib.d.ts. */
+    // "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */
+    // "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */
+
+    /* Modules */
+    "module": "commonjs", /* Specify what module code is generated. */
+    "rootDir": "./src", /* Specify the root folder within your source files. */
+    // "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */
+    // "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
+    // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */
+    // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
+    // "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */
+    // "types": [], /* Specify type package names to be included without being referenced in a source file. */
+    // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
+    // "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */
+    // "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */
+    // "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */
+    // "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */
+    // "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */
+    // "resolveJsonModule": true, /* Enable importing .json files. */
+    // "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */
+    // "noResolve": true, /* Disallow 'import's, 'require's or '<reference>'s from expanding the number of files TypeScript should add to a project. */
+
+    /* JavaScript Support */
+    // "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */
+    // "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */
+    // "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */
+
+    /* Emit */
+    "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */
+    // "declarationMap": true, /* Create sourcemaps for d.ts files. */
+    // "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */
+    // "sourceMap": true, /* Create source map files for emitted JavaScript files. */
+    // "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
+    // "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */
+    "outDir": "./dist", /* Specify an output folder for all emitted files. */
+    // "removeComments": true, /* Disable emitting comments. */
+    // "noEmit": true, /* Disable emitting files from a compilation. */
+    // "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */
+    // "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
+    // "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
+    // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
+    // "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
+    // "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
+    // "newLine": "crlf", /* Set the newline character for emitting files. */
+    // "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */
+    // "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */
+    // "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */
+    // "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */
+    "declarationDir": "./typing", /* Specify the output directory for generated declaration files. */
+
+    /* Interop Constraints */
+    // "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
+    // "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */
+    // "isolatedDeclarations": true, /* Require sufficient annotation on exports so other tools can trivially generate declaration files. */
+    // "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
+    "esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
+    // "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
+    "forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */
+
+    /* Type Checking */
+    "strict": true, /* Enable all strict type-checking options. */
+    // "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
+    // "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
+    // "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
+    // "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */
+    // "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */
+    // "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */
+    // "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */
+    // "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */
+    // "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */
+    // "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */
+    // "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */
+    // "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */
+    // "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */
+    // "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */
+    // "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */
+    // "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */
+    // "allowUnusedLabels": true, /* Disable error reporting for unused labels. */
+    // "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */
+
+    /* Completeness */
+    // "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
+    "skipLibCheck": true /* Skip type checking all .d.ts files. */
+  },
+  "include": [
+    "src/**/*.ts"
+  ]
+}