#include #include #include #include "node.h" #include "common/tensor.h" #ifdef WIN32 #include #include #include #endif using namespace Napi; #define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast(napi_writable | napi_configurable)) static const std::map DATA_TYPE_MAP = { {TensorDataType::Float32, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {TensorDataType::Float64, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {TensorDataType::Int32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {TensorDataType::Uint32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {TensorDataType::Int16, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {TensorDataType::Uint16, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {TensorDataType::Int8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {TensorDataType::Uint8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {TensorDataType::Int64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {TensorDataType::Uint64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}, }; static const std::map DATA_TYPE_SIZE_MAP = { {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8}, }; class OrtSessionNodeInfo { public: inline OrtSessionNodeInfo(const std::string &name, const Ort::TypeInfo &typeInfo) : name_(name), shape_(typeInfo.GetTensorTypeAndShapeInfo().GetShape()), type_(typeInfo.GetTensorTypeAndShapeInfo().GetElementType()) {} inline const std::string &GetName() const { return name_; } inline const std::vector &GetShape() const { return shape_; } inline ONNXTensorElementDataType GetType() const { return type_; } inline size_t GetElementSize() const { auto it = DATA_TYPE_SIZE_MAP.find(type_); return (it == DATA_TYPE_SIZE_MAP.end()) ? 0 : it->second; } TensorDataType GetDataType() const { auto datatype = TensorDataType::Unknown; for (auto it : DATA_TYPE_MAP) { if (it.second == type_) { datatype = it.first; break; } } return datatype; } size_t GetElementCount() const { if (!shape_.size()) return 0; size_t result = 1; for (auto dim : shape_) { if (dim <= 0) continue; result *= dim; } return result; } inline size_t GetElementBytes() const { return GetElementSize() * GetElementCount(); } private: std::string name_; std::vector shape_; ONNXTensorElementDataType type_; }; class OrtSessionRunWorker : public AsyncWorker { public: OrtSessionRunWorker(std::shared_ptr session, const Napi::Function &callback) : Napi::AsyncWorker(callback), session_(session) {} ~OrtSessionRunWorker() {} void Execute() { if (!HasError()) { try { Ort::RunOptions runOption; session_->Run(runOption, &inputNames_[0], &inputValues_[0], inputNames_.size(), &outputNames_[0], &outputValues_[0], outputNames_.size()); } catch (std::exception &err) { SetError(err.what()); } } } void OnOK() { if (HasError()) { Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()}); } else { auto result = Object::New(Env()); for (int i = 0; i < outputNames_.size(); ++i) { size_t bytes = outputElementBytes_[i]; Ort::Value &value = outputValues_[i]; auto buffer = ArrayBuffer::New(Env(), bytes); memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes); result.Set(String::New(Env(), outputNames_[i]), buffer); } Callback().Call({Env().Undefined(), result}); } } inline void AddInput(const Ort::MemoryInfo &memoryInfo, const std::string &name, ONNXTensorElementDataType type, const std::vector &shape, void *data, size_t dataByteLen) { try { inputNames_.push_back(name.c_str()); inputValues_.push_back(Ort::Value::CreateTensor(memoryInfo, data, dataByteLen, &shape[0], shape.size(), type)); } catch (std::exception &err) { SetError(err.what()); } } inline void AddOutput(const std::string &name, size_t elementBytes) { outputNames_.push_back(name.c_str()); outputValues_.push_back(Ort::Value{nullptr}); outputElementBytes_.push_back(elementBytes); } inline void SetError(const std::string &what) { errorMessage_ = what; } inline bool HasError() { return errorMessage_.size() > 0; } public: std::shared_ptr session_; std::vector inputNames_; std::vector inputValues_; std::vector outputNames_; std::vector outputValues_; std::vector outputElementBytes_; std::string errorMessage_; }; class OrtSession : public ObjectWrap { public: static Napi::Object Init(Napi::Env env, Napi::Object exports) { Function func = DefineClass(env, "OrtSession", { SESSION_INSTANCE_METHOD(GetInputsInfo), SESSION_INSTANCE_METHOD(GetOutputsInfo), SESSION_INSTANCE_METHOD(Run), }); FunctionReference *constructor = new FunctionReference(); *constructor = Napi::Persistent(func); exports.Set("OrtSession", func); env.SetInstanceData(constructor); return exports; } OrtSession(const CallbackInfo &info) : ObjectWrap(info) { try { if (info[0].IsString()) { #ifdef WIN32 std::string str = info[0].As().Utf8Value(); auto len = MultiByteToWideChar(CP_ACP, 0, str.c_str(), str.size(), NULL, 0); wchar_t *buffer = new wchar_t[len + 1]; MultiByteToWideChar(CP_ACP, 0, str.c_str(), str.size(), buffer, len); buffer[len] = '\0'; std::wstring filename(buffer); delete[] buffer; // std::wstring_convert> converter; // std::wstring filename = converter.from_bytes(info[0].As().Utf8Value()); session_ = std::make_shared(env, filename.c_str(), sessionOptions); #else session_ = std::make_shared(env, info[0].As().Utf8Value().c_str(), sessionOptions); #endif } else if (info[0].IsTypedArray()) { size_t bufferBytes; auto buffer = dataFromTypedArray(info[0], bufferBytes); session_ = std::make_shared(env, buffer, bufferBytes, sessionOptions); } else session_ = nullptr; if (session_ != nullptr) { Ort::AllocatorWithDefaultOptions allocator; for (int i = 0; i < session_->GetInputCount(); ++i) { std::string name = session_->GetInputNameAllocated(i, allocator).get(); auto typeInfo = session_->GetInputTypeInfo(i); inputs_.emplace(name, std::make_shared(name, typeInfo)); } for (int i = 0; i < session_->GetOutputCount(); ++i) { std::string name = session_->GetOutputNameAllocated(i, allocator).get(); auto typeInfo = session_->GetOutputTypeInfo(i); outputs_.emplace(name, std::make_shared(name, typeInfo)); } } } catch (std::exception &e) { Error::New(info.Env(), e.what()).ThrowAsJavaScriptException(); } } ~OrtSession() {} Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), inputs_); } Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), outputs_); } Napi::Value Run(const Napi::CallbackInfo &info) { auto worker = new OrtSessionRunWorker(session_, info[1].As()); auto inputArgument = info[0].As(); auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) { auto name = (*it).first.As().Utf8Value(); auto input = GetInput(name); if (!input) worker->SetError(std::string("input name #" + name + " not exists")); else { auto inputOption = static_cast((*it).second).As(); if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name))); else { auto type = input->GetType(); if (inputOption.Has("type")) { auto t = static_cast(inputOption.Get("type").As().Int32Value()); auto it = DATA_TYPE_MAP.find(t); if (it != DATA_TYPE_MAP.end()) type = it->second; } size_t dataByteLen; void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen); auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As()) : input->GetShape(); worker->AddInput(memoryInfo, input->GetName(), type, shape, data, dataByteLen); } } } for (auto &it : outputs_) { worker->AddOutput(it.second->GetName(), it.second->GetElementBytes()); } worker->Queue(); return info.Env().Undefined(); } private: static std::vector GetShapeFromJavascript(Napi::Array arr) { std::vector res(arr.Length()); for (int i = 0; i < res.size(); ++i) res[i] = arr.Get(i).As().Int32Value(); return res; } inline std::shared_ptr GetInput(const std::string &name) const { auto it = inputs_.find(name); return (it == inputs_.end()) ? nullptr : it->second; } inline std::shared_ptr GetOutput(const std::string &name) const { auto it = outputs_.find(name); return (it == outputs_.end()) ? nullptr : it->second; } Napi::Value BuildNodeInfoToJavascript(Napi::Env env, const std::map> &nodes) { auto result = Object::New(env); for (auto it : nodes) { auto &node = *it.second; auto item = Object::New(env); item.Set(String::New(env, "name"), String::New(env, node.GetName())); item.Set(String::New(env, "type"), Number::New(env, static_cast(node.GetDataType()))); auto &shapeVec = node.GetShape(); auto shape = Array::New(env, shapeVec.size()); for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i])); item.Set(String::New(env, "shape"), shape); result.Set(String::New(env, node.GetName()), item); } return result; } private: Ort::Env env; Ort::SessionOptions sessionOptions; std::shared_ptr session_; std::map> inputs_; std::map> outputs_; }; void InstallOrtAPI(Napi::Env env, Napi::Object exports) { OrtSession::Init(env, exports); } #if defined(USE_ONNXRUNTIME) && !defined(BUILD_MAIN_WORD) static Object Init(Env env, Object exports) { #ifdef RELEASE_VERSION exports.Set("__release__", Napi::String::New(env, RELEASE_VERSION)); #endif InstallOrtAPI(env, exports); return exports; } NODE_API_MODULE(addon, Init) #endif