#include #include #include #include "node.h" using namespace Napi; #define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast(napi_writable | napi_configurable)) static ONNXTensorElementDataType getDataTypeFromString(const std::string &name) { static const std::map dataTypeNameMap = { {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"float", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT}, {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"double", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE}, {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8}, {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8}, {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16}, {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16}, {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32}, {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32}, {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}, {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64}, }; auto it = dataTypeNameMap.find(name); return (it == dataTypeNameMap.end()) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED : it->second; } static size_t getDataTypeSize(ONNXTensorElementDataType type) { static const std::map dataTypeSizeMap = { {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8}, {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8}, }; auto it = dataTypeSizeMap.find(type); return (it == dataTypeSizeMap.end()) ? 0 : it->second; } static void *dataFromTypedArray(const Napi::Value &val, size_t &bytes) { auto arr = val.As(); auto data = static_cast(arr.ArrayBuffer().Data()); bytes = arr.ByteLength(); return static_cast(data + arr.ByteOffset()); } class OrtSessionNodeInfo { public: inline OrtSessionNodeInfo(const std::string &name, const Ort::TypeInfo &typeInfo) : name_(name), shape_(typeInfo.GetTensorTypeAndShapeInfo().GetShape()), type_(typeInfo.GetTensorTypeAndShapeInfo().GetElementType()) {} inline const std::string &GetName() const { return name_; } inline const std::vector &GetShape() const { return shape_; } inline ONNXTensorElementDataType GetType() const { return type_; } inline size_t GetElementSize() const { return getDataTypeSize(type_); } size_t GetElementCount() const { if (!shape_.size()) return 0; size_t result = 1; for (auto dim : shape_) { if (dim <= 0) continue; result *= dim; } return result; } inline size_t GetElementBytes() const { return GetElementSize() * GetElementCount(); } private: std::string name_; std::vector shape_; ONNXTensorElementDataType type_; }; class OrtSessionRunWorker : public AsyncWorker { public: OrtSessionRunWorker(std::shared_ptr session, const Napi::Function &callback) : Napi::AsyncWorker(callback), session_(session) {} ~OrtSessionRunWorker() {} void Execute() { if (!HasError()) { try { Ort::RunOptions runOption; session_->Run(runOption, &inputNames_[0], &inputValues_[0], inputNames_.size(), &outputNames_[0], &outputValues_[0], outputNames_.size()); } catch (std::exception &err) { SetError(err.what()); } } } void OnOK() { if (HasError()) { Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()}); } else { auto result = Object::New(Env()); for (int i = 0; i < outputNames_.size(); ++i) { size_t bytes = outputElementBytes_[i]; Ort::Value &value = outputValues_[i]; auto buffer = ArrayBuffer::New(Env(), value.GetTensorMutableRawData(), bytes); result.Set(String::New(Env(), outputNames_[i]), buffer); } Callback().Call({Env().Undefined(), result}); } } inline void AddInput(const Ort::MemoryInfo &memoryInfo, const std::string &name, ONNXTensorElementDataType type, const std::vector &shape, void *data, size_t dataByteLen) { try { inputNames_.push_back(name.c_str()); inputValues_.push_back(Ort::Value::CreateTensor(memoryInfo, data, dataByteLen, &shape[0], shape.size(), type)); } catch (std::exception &err) { SetError(err.what()); } } inline void AddOutput(const std::string &name, size_t elementBytes) { outputNames_.push_back(name.c_str()); outputValues_.push_back(Ort::Value{nullptr}); outputElementBytes_.push_back(elementBytes); } inline void SetError(const std::string &what) { errorMessage_ = what; } inline bool HasError() { return errorMessage_.size() > 0; } public: std::shared_ptr session_; std::vector inputNames_; std::vector inputValues_; std::vector outputNames_; std::vector outputValues_; std::vector outputElementBytes_; std::string errorMessage_; }; class OrtSession : public ObjectWrap { public: static Napi::Object Init(Napi::Env env, Napi::Object exports) { Function func = DefineClass(env, "OrtSession", { SESSION_INSTANCE_METHOD(GetInputsInfo), SESSION_INSTANCE_METHOD(GetOutputsInfo), SESSION_INSTANCE_METHOD(Run), }); FunctionReference *constructor = new FunctionReference(); *constructor = Napi::Persistent(func); exports.Set("OrtSession", func); env.SetInstanceData(constructor); return exports; } OrtSession(const CallbackInfo &info) : ObjectWrap(info) { try { if (info[0].IsString()) session_ = std::make_shared(env, info[0].As().Utf8Value().c_str(), sessionOptions); else if (info[0].IsTypedArray()) { size_t bufferBytes; auto buffer = dataFromTypedArray(info[0], bufferBytes); session_ = std::make_shared(env, buffer, bufferBytes, sessionOptions); } else session_ = nullptr; if (session_ != nullptr) { Ort::AllocatorWithDefaultOptions allocator; for (int i = 0; i < session_->GetInputCount(); ++i) { std::string name = session_->GetInputNameAllocated(i, allocator).get(); auto typeInfo = session_->GetInputTypeInfo(i); inputs_.emplace(name, std::make_shared(name, typeInfo)); } for (int i = 0; i < session_->GetOutputCount(); ++i) { std::string name = session_->GetOutputNameAllocated(i, allocator).get(); auto typeInfo = session_->GetOutputTypeInfo(i); outputs_.emplace(name, std::make_shared(name, typeInfo)); } } } catch (std::exception &e) { Error::New(info.Env(), e.what()).ThrowAsJavaScriptException(); } } ~OrtSession() {} Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), inputs_); } Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), outputs_); } Napi::Value Run(const Napi::CallbackInfo &info) { auto worker = new OrtSessionRunWorker(session_, info[1].As()); auto inputArgument = info[0].As(); auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) { auto name = (*it).first.As().Utf8Value(); auto input = GetInput(name); if (!input) worker->SetError(std::string("input name #" + name + " not exists")); else { auto inputOption = static_cast((*it).second).As(); if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name))); else { auto type = inputOption.Has("type") ? getDataTypeFromString(inputOption.Get("type").As().Utf8Value()) : input->GetType(); size_t dataByteLen; void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen); auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As()) : input->GetShape(); worker->AddInput(memoryInfo, input->GetName(), type, shape, data, dataByteLen); } } } for (auto &it : outputs_) { worker->AddOutput(it.second->GetName(), it.second->GetElementBytes()); } worker->Queue(); return info.Env().Undefined(); } private: static std::vector GetShapeFromJavascript(Napi::Array arr) { std::vector res(arr.Length()); for (int i = 0; i < res.size(); ++i) res[i] = arr.Get(i).As().Int32Value(); return res; } inline std::shared_ptr GetInput(const std::string &name) const { auto it = inputs_.find(name); return (it == inputs_.end()) ? nullptr : it->second; } inline std::shared_ptr GetOutput(const std::string &name) const { auto it = outputs_.find(name); return (it == outputs_.end()) ? nullptr : it->second; } Napi::Value BuildNodeInfoToJavascript(Napi::Env env, const std::map> &nodes) { auto result = Object::New(env); for (auto it : nodes) { auto &node = *it.second; auto item = Object::New(env); item.Set(String::New(env, "name"), String::New(env, node.GetName())); item.Set(String::New(env, "type"), Number::New(env, node.GetType())); auto &shapeVec = node.GetShape(); auto shape = Array::New(env, shapeVec.size()); for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i])); item.Set(String::New(env, "shape"), shape); result.Set(String::New(env, node.GetName()), item); } return result; } private: Ort::Env env; Ort::SessionOptions sessionOptions; std::shared_ptr session_; std::map> inputs_; std::map> outputs_; }; void InstallOrtAPI(Napi::Env env, Napi::Object exports) { OrtSession::Init(env, exports); } #ifdef USE_ONNXRUNTIME static Object Init(Env env, Object exports) { InstallOrtAPI(env, exports); return exports; } NODE_API_MODULE(addon, Init) #endif