#include #include #include #include #include #include #include "common/tensor.h" #include "node.h" using namespace Napi; #define SESSION_INSTANCE_METHOD(method) InstanceMethod<&MNNSession::method>(#method, static_cast(napi_writable | napi_configurable)) static const std::map DATA_TYPE_MAP = { {TensorDataType::Float32, halide_type_of()}, {TensorDataType::Float64, halide_type_of()}, {TensorDataType::Int32, halide_type_of()}, {TensorDataType::Uint32, halide_type_of()}, {TensorDataType::Int16, halide_type_of()}, {TensorDataType::Uint16, halide_type_of()}, {TensorDataType::Int8, halide_type_of()}, {TensorDataType::Uint8, halide_type_of()}, {TensorDataType::Int64, halide_type_of()}, {TensorDataType::Uint64, halide_type_of()}, }; static size_t getShapeSize(const std::vector &shape) { if (!shape.size()) return 0; size_t sum = 1; for (auto i : shape) { if (i > 1) sum *= i; }; return sum; } static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector &shape) { auto result = Napi::Array::New(env, shape.size()); for (int i = 0; i < shape.size(); ++i) result.Set(i, Napi::Number::New(env, shape[i])); return result; } class MNNSessionRunWorker : public AsyncWorker { public: MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session) : AsyncWorker(callback), interpreter_(interpreter), session_(session) {} ~MNNSessionRunWorker() { interpreter_->releaseSession(session_); } void Execute() { if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) { SetError(std::string("Run session failed")); } } void OnOK() { if (HasError()) { Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()}); } else { auto result = Object::New(Env()); for (auto it : interpreter_->getSessionOutputAll(session_)) { auto tensor = it.second; auto item = Object::New(Env()); auto buffer = ArrayBuffer::New(Env(), tensor->size()); memcpy(buffer.Data(), tensor->host(), tensor->size()); item.Set("data", buffer); item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape())); result.Set(it.first, item); } Callback().Call({Env().Undefined(), result}); } } void SetInput(const std::string &name, TensorDataType dataType, const std::vector &shape, void *data, size_t dataBytes) { auto tensor = interpreter_->getSessionInput(session_, name.c_str()); if (!tensor) { SetError(std::string("input name #" + name + " not exists")); return; } halide_type_t type = tensor->getType(); if (dataType != TensorDataType::Unknown) { auto it = DATA_TYPE_MAP.find(dataType); if (it != DATA_TYPE_MAP.end()) type = it->second; } if (shape.size()) { interpreter_->resizeTensor(tensor, shape); interpreter_->resizeSession(session_); } auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8; if (tensorBytes != dataBytes) { SetError(std::string("input name #" + name + " data size not matched")); return; } auto hostTensor = MNN::Tensor::create(tensor->shape(), type, data, MNN::Tensor::CAFFE); tensor->copyFromHostTensor(hostTensor); delete hostTensor; } inline void SetError(const std::string &what) { errorMessage_ = what; } inline bool HasError() { return errorMessage_.size() > 0; } private: MNN::Interpreter *interpreter_; MNN::Session *session_; std::string errorMessage_; }; class MNNSession : public ObjectWrap { public: static Napi::Object Init(Napi::Env env, Napi::Object exports) { Function func = DefineClass(env, "MNNSession", { SESSION_INSTANCE_METHOD(GetInputsInfo), SESSION_INSTANCE_METHOD(GetOutputsInfo), SESSION_INSTANCE_METHOD(Run), }); FunctionReference *constructor = new FunctionReference(); *constructor = Napi::Persistent(func); exports.Set("MNNSession", func); env.SetInstanceData(constructor); return exports; } MNNSession(const CallbackInfo &info) : ObjectWrap(info) { try { if (info[0].IsString()) { interpreter_ = MNN::Interpreter::createFromFile(info[0].As().Utf8Value().c_str()); } else if (info[0].IsTypedArray()) { size_t bufferBytes; auto buffer = dataFromTypedArray(info[0], bufferBytes); interpreter_ = MNN::Interpreter::createFromBuffer(buffer, bufferBytes); } else interpreter_ = nullptr; if (interpreter_) { backendConfig_.precision = MNN::BackendConfig::Precision_High; backendConfig_.power = MNN::BackendConfig::Power_High; scheduleConfig_.type = MNN_FORWARD_CPU; scheduleConfig_.numThread = 1; scheduleConfig_.backendConfig = &backendConfig_; session_ = interpreter_->createSession(scheduleConfig_); } else session_ = nullptr; } catch (std::exception &e) { Error::New(info.Env(), e.what()).ThrowAsJavaScriptException(); } } ~MNNSession() {} Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionInputAll(session_)); } Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildInputOutputInfo(info.Env(), interpreter_->getSessionOutputAll(session_)); } Napi::Value Run(const Napi::CallbackInfo &info) { auto worker = new MNNSessionRunWorker(info[1].As(), interpreter_, interpreter_->createSession(scheduleConfig_)); auto inputArgument = info[0].As(); for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) { auto name = (*it).first.As().Utf8Value(); auto inputOption = static_cast((*it).second).As(); auto type = inputOption.Has("type") ? static_cast(inputOption.Get("type").As().Int32Value()) : TensorDataType::Unknown; size_t dataByteLen; void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen); auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As()) : std::vector(); worker->SetInput(name, type, shape, data, dataByteLen); } worker->Queue(); return info.Env().Undefined(); } private: Napi::Object BuildInputOutputInfo(Napi::Env env, const std::map &tensors) { auto result = Object::New(env); for (auto it : tensors) { auto item = Object::New(env); auto name = it.first; auto shape = it.second->shape(); auto type = it.second->getType(); TensorDataType dataType = TensorDataType::Unknown; for (auto dt : DATA_TYPE_MAP) { if (dt.second == type) { dataType = dt.first; break; } } auto shapeArr = Array::New(env, shape.size()); for (size_t i = 0; i < shape.size(); i++) { shapeArr.Set(i, Number::New(env, shape[i])); } item.Set("name", String::New(env, name)); item.Set("shape", shapeArr); item.Set("type", Number::New(env, static_cast(dataType))); result.Set(name, item); } return result; } std::vector GetShapeFromJavascript(const Napi::Array &shape) { std::vector result; for (size_t i = 0; i < shape.Length(); i++) { result.push_back(shape.Get(i).As().Int32Value()); } return result; } private: MNN::Interpreter *interpreter_; MNN::Session *session_; MNN::BackendConfig backendConfig_; MNN::ScheduleConfig scheduleConfig_; }; void InstallMNNAPI(Napi::Env env, Napi::Object exports) { MNNSession::Init(env, exports); } #if defined(USE_MNN) && !defined(BUILD_MAIN_WORD) static Object Init(Env env, Object exports) { InstallMNNAPI(env, exports); return exports; } NODE_API_MODULE(addon, Init) #endif