316 lines
11 KiB
C++
316 lines
11 KiB
C++
#include <iostream>
|
|
#include <vector>
|
|
#include <onnxruntime_cxx_api.h>
|
|
#include "node.h"
|
|
#include "common/tensor.h"
|
|
|
|
#ifdef WIN32
|
|
#include <locale>
|
|
#include <codecvt>
|
|
#include <Windows.h>
|
|
#endif
|
|
|
|
// Unqualified Function, Object, String, Number, etc. in this file resolve to node-addon-api (Napi) types.
using namespace Napi;

// Expands to an ObjectWrap instance-method descriptor that exposes OrtSession::<method> to
// JavaScript under the same name, with the property marked writable and configurable.
#define SESSION_INSTANCE_METHOD(method)   \
    InstanceMethod<&OrtSession::method>(  \
        #method,                          \
        static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
|
|
|
static const std::map<TensorDataType, ONNXTensorElementDataType> DATA_TYPE_MAP = {
|
|
{TensorDataType::Float32, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
|
|
{TensorDataType::Float64, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
|
|
{TensorDataType::Int32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
|
|
{TensorDataType::Uint32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
|
|
{TensorDataType::Int16, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
|
|
{TensorDataType::Uint16, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
|
|
{TensorDataType::Int8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
|
|
{TensorDataType::Uint8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
|
|
{TensorDataType::Int64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
|
|
{TensorDataType::Uint64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
|
|
};
|
|
|
|
static const std::map<ONNXTensorElementDataType, size_t> DATA_TYPE_SIZE_MAP = {
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
|
|
{ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
|
|
};
|
|
|
|
class OrtSessionNodeInfo {
|
|
public:
|
|
inline OrtSessionNodeInfo(const std::string &name, const Ort::TypeInfo &typeInfo)
|
|
: name_(name), shape_(typeInfo.GetTensorTypeAndShapeInfo().GetShape()), type_(typeInfo.GetTensorTypeAndShapeInfo().GetElementType()) {}
|
|
|
|
inline const std::string &GetName() const { return name_; }
|
|
inline const std::vector<int64_t> &GetShape() const { return shape_; }
|
|
inline ONNXTensorElementDataType GetType() const { return type_; }
|
|
inline size_t GetElementSize() const
|
|
{
|
|
auto it = DATA_TYPE_SIZE_MAP.find(type_);
|
|
return (it == DATA_TYPE_SIZE_MAP.end()) ? 0 : it->second;
|
|
}
|
|
TensorDataType GetDataType() const
|
|
{
|
|
auto datatype = TensorDataType::Unknown;
|
|
for (auto it : DATA_TYPE_MAP) {
|
|
if (it.second == type_) {
|
|
datatype = it.first;
|
|
break;
|
|
}
|
|
}
|
|
return datatype;
|
|
}
|
|
size_t GetElementCount() const
|
|
{
|
|
if (!shape_.size()) return 0;
|
|
size_t result = 1;
|
|
for (auto dim : shape_) {
|
|
if (dim <= 0) continue;
|
|
result *= dim;
|
|
}
|
|
return result;
|
|
}
|
|
inline size_t GetElementBytes() const { return GetElementSize() * GetElementCount(); }
|
|
|
|
private:
|
|
std::string name_;
|
|
std::vector<int64_t> shape_;
|
|
ONNXTensorElementDataType type_;
|
|
};
|
|
|
|
class OrtSessionRunWorker : public AsyncWorker {
|
|
public:
|
|
OrtSessionRunWorker(std::shared_ptr<Ort::Session> session, const Napi::Function &callback)
|
|
: Napi::AsyncWorker(callback), session_(session) {}
|
|
~OrtSessionRunWorker() {}
|
|
|
|
void Execute()
|
|
{
|
|
if (!HasError()) {
|
|
try {
|
|
Ort::RunOptions runOption;
|
|
session_->Run(runOption, &inputNames_[0], &inputValues_[0], inputNames_.size(), &outputNames_[0], &outputValues_[0], outputNames_.size());
|
|
}
|
|
catch (std::exception &err) {
|
|
SetError(err.what());
|
|
}
|
|
}
|
|
}
|
|
|
|
void OnOK()
|
|
{
|
|
if (HasError()) {
|
|
Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()});
|
|
}
|
|
else {
|
|
auto result = Object::New(Env());
|
|
for (int i = 0; i < outputNames_.size(); ++i) {
|
|
size_t bytes = outputElementBytes_[i];
|
|
Ort::Value &value = outputValues_[i];
|
|
auto buffer = ArrayBuffer::New(Env(), bytes);
|
|
memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes);
|
|
result.Set(String::New(Env(), outputNames_[i]), buffer);
|
|
}
|
|
Callback().Call({Env().Undefined(), result});
|
|
}
|
|
}
|
|
|
|
inline void AddInput(const Ort::MemoryInfo &memoryInfo, const std::string &name, ONNXTensorElementDataType type, const std::vector<int64_t> &shape, void *data, size_t dataByteLen)
|
|
{
|
|
try {
|
|
inputNames_.push_back(name.c_str());
|
|
inputValues_.push_back(Ort::Value::CreateTensor(memoryInfo, data, dataByteLen, &shape[0], shape.size(), type));
|
|
}
|
|
catch (std::exception &err) {
|
|
SetError(err.what());
|
|
}
|
|
}
|
|
|
|
inline void AddOutput(const std::string &name, size_t elementBytes)
|
|
{
|
|
outputNames_.push_back(name.c_str());
|
|
outputValues_.push_back(Ort::Value{nullptr});
|
|
outputElementBytes_.push_back(elementBytes);
|
|
}
|
|
|
|
inline void SetError(const std::string &what) { errorMessage_ = what; }
|
|
|
|
inline bool HasError() { return errorMessage_.size() > 0; }
|
|
|
|
public:
|
|
std::shared_ptr<Ort::Session> session_;
|
|
std::vector<const char *> inputNames_;
|
|
std::vector<Ort::Value> inputValues_;
|
|
std::vector<const char *> outputNames_;
|
|
std::vector<Ort::Value> outputValues_;
|
|
std::vector<size_t> outputElementBytes_;
|
|
std::string errorMessage_;
|
|
};
|
|
|
|
/**
 * JavaScript class "OrtSession": wraps an Ort::Session created from a model file path or from
 * an in-memory model. Exposes GetInputsInfo(), GetOutputsInfo() and an asynchronous
 * Run(inputs, callback).
 */
class OrtSession : public ObjectWrap<OrtSession> {
public:
    /**
     * Defines the OrtSession class, sets it as `exports.OrtSession`, and stores a persistent
     * reference to its constructor as the environment's instance data.
     * @return the same `exports` object.
     */
    static Napi::Object Init(Napi::Env env, Napi::Object exports)
    {
        Function func = DefineClass(env, "OrtSession", {
            SESSION_INSTANCE_METHOD(GetInputsInfo),
            SESSION_INSTANCE_METHOD(GetOutputsInfo),
            SESSION_INSTANCE_METHOD(Run),
        });
        // Handed to the environment through SetInstanceData, which takes ownership.
        FunctionReference *constructor = new FunctionReference();
        *constructor = Napi::Persistent(func);
        exports.Set("OrtSession", func);
        env.SetInstanceData<FunctionReference>(constructor);
        return exports;
    }

    /**
     * new OrtSession(model)
     *
     * @param info[0] a model file path (string) or the serialized model bytes (typed array).
     *                Any other value leaves the session uninitialized; Run() then reports an error.
     * Throws a JS Error if ONNX Runtime fails to load the model or read its metadata.
     */
    OrtSession(const CallbackInfo &info)
        : ObjectWrap(info)
    {
        try {
            if (info[0].IsString()) {
#ifdef WIN32
                // ORT takes a wide-char path on Windows. Utf8Value() yields UTF-8, so the path is
                // decoded as UTF-8 (decoding with CP_ACP mangled non-ASCII paths).
                std::wstring filename = Utf8ToWide(info[0].As<String>().Utf8Value());
                session_ = std::make_shared<Ort::Session>(env, filename.c_str(), sessionOptions);
#else
                session_ = std::make_shared<Ort::Session>(env, info[0].As<String>().Utf8Value().c_str(), sessionOptions);
#endif
            }
            else if (info[0].IsTypedArray()) {
                size_t bufferBytes;
                auto buffer = dataFromTypedArray(info[0], bufferBytes);
                session_ = std::make_shared<Ort::Session>(env, buffer, bufferBytes, sessionOptions);
            }
            else session_ = nullptr;

            if (session_ != nullptr) {
                Ort::AllocatorWithDefaultOptions allocator;
                const size_t inputCount = session_->GetInputCount();
                for (size_t i = 0; i < inputCount; ++i) {
                    std::string name = session_->GetInputNameAllocated(i, allocator).get();
                    auto typeInfo = session_->GetInputTypeInfo(i);
                    inputs_.emplace(name, std::make_shared<OrtSessionNodeInfo>(name, typeInfo));
                }
                const size_t outputCount = session_->GetOutputCount();
                for (size_t i = 0; i < outputCount; ++i) {
                    std::string name = session_->GetOutputNameAllocated(i, allocator).get();
                    auto typeInfo = session_->GetOutputTypeInfo(i);
                    outputs_.emplace(name, std::make_shared<OrtSessionNodeInfo>(name, typeInfo));
                }
            }
        }
        catch (const std::exception &e) {
            Error::New(info.Env(), e.what()).ThrowAsJavaScriptException();
        }
    }

    ~OrtSession() {}

    /// Returns { <name>: { name, type, shape } } for every model input.
    Napi::Value GetInputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), inputs_); }

    /// Returns { <name>: { name, type, shape } } for every model output.
    Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info) { return BuildNodeInfoToJavascript(info.Env(), outputs_); }

    /**
     * Run(inputs, callback)
     *
     * @param info[0] object mapping input names to { data: TypedArray, type?: number, shape?: number[] }.
     * @param info[1] callback(err, outputs); `outputs` maps output names to ArrayBuffers.
     *
     * Inference runs on a worker thread. Request problems (unknown input, missing data,
     * uninitialized session) are reported through the callback. A TypeError is thrown
     * synchronously only when the arguments are not an object and a function.
     */
    Napi::Value Run(const Napi::CallbackInfo &info)
    {
        Napi::Env jsEnv = info.Env();
        if (!info[0].IsObject() || !info[1].IsFunction()) {
            TypeError::New(jsEnv, "Run(inputs, callback) expects an object and a function").ThrowAsJavaScriptException();
            return jsEnv.Undefined();
        }
        auto inputArgument = info[0].As<Object>();

        // Owned here until Queue() succeeds, so an exception while building the request cannot
        // leak the worker. Once queued, the AsyncWorker deletes itself after completion.
        std::unique_ptr<OrtSessionRunWorker> worker(new OrtSessionRunWorker(session_, info[1].As<Function>()));

        // Keep alive, until the worker finishes, the JS objects it depends on:
        //  - the input typed arrays, whose memory is read on the worker thread;
        //  - this OrtSession, which owns the name strings the worker points at and the Ort::Env.
        worker->Receiver().Set("inputs", inputArgument);
        worker->Receiver().Set("session", info.This());

        if (!session_) {
            // Delivered via the callback; Execute() skips Run() when an error is already set.
            worker->SetError("session is not initialized");
        }
        else {
            AddRequestInputs(*worker, inputArgument);
            for (auto &entry : outputs_) {
                worker->AddOutput(entry.second->GetName(), entry.second->GetElementBytes());
            }
        }

        worker->Queue();
        static_cast<void>(worker.release());  // ownership now belongs to the queued async work
        return jsEnv.Undefined();
    }

private:
    /// Parses a JS array of dimensions. Int64Value keeps dimensions beyond the int32 range intact.
    static std::vector<int64_t> GetShapeFromJavascript(Napi::Array arr)
    {
        const uint32_t rank = arr.Length();
        std::vector<int64_t> res(rank);
        for (uint32_t i = 0; i < rank; ++i) res[i] = arr.Get(i).As<Number>().Int64Value();
        return res;
    }

    inline std::shared_ptr<OrtSessionNodeInfo> GetInput(const std::string &name) const
    {
        auto it = inputs_.find(name);
        return (it == inputs_.end()) ? nullptr : it->second;
    }

    inline std::shared_ptr<OrtSessionNodeInfo> GetOutput(const std::string &name) const
    {
        auto it = outputs_.find(name);
        return (it == outputs_.end()) ? nullptr : it->second;
    }

    /**
     * Converts each entry of the JS `inputs` object into an input tensor on `worker`.
     * Validation problems are recorded with worker.SetError; the entry is then skipped.
     */
    void AddRequestInputs(OrtSessionRunWorker &worker, Napi::Object inputArgument) const
    {
        auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
        for (auto entry = inputArgument.begin(); entry != inputArgument.end(); ++entry) {
            auto name = (*entry).first.As<String>().Utf8Value();
            auto input = GetInput(name);
            if (!input) {
                worker.SetError(std::string("input name #" + name + " not exists"));
                continue;
            }

            auto optionValue = static_cast<Napi::Value>((*entry).second);
            if (!optionValue.IsObject()) {
                worker.SetError((std::string("data is required in inputs #" + name)));
                continue;
            }
            auto inputOption = optionValue.As<Object>();
            if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) {
                worker.SetError((std::string("data is required in inputs #" + name)));
                continue;
            }

            // Element type: the model's declared type, unless the caller supplies a known TensorDataType.
            auto type = input->GetType();
            if (inputOption.Has("type") && inputOption.Get("type").IsNumber()) {
                auto requested = static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value());
                auto mapped = DATA_TYPE_MAP.find(requested);
                if (mapped != DATA_TYPE_MAP.end()) type = mapped->second;
            }

            size_t dataByteLen;
            void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
            auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
            worker.AddInput(memoryInfo, input->GetName(), type, shape, data, dataByteLen);
        }
    }

    /// Builds { <name>: { name, type, shape } } from the given node metadata.
    Napi::Value BuildNodeInfoToJavascript(Napi::Env env, const std::map<std::string, std::shared_ptr<OrtSessionNodeInfo>> &nodes) const
    {
        auto result = Object::New(env);
        for (const auto &entry : nodes) {
            const OrtSessionNodeInfo &node = *entry.second;
            auto item = Object::New(env);
            item.Set(String::New(env, "name"), String::New(env, node.GetName()));
            item.Set(String::New(env, "type"), Number::New(env, static_cast<int>(node.GetDataType())));
            const auto &shapeVec = node.GetShape();
            auto shape = Array::New(env, shapeVec.size());
            for (uint32_t i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, static_cast<double>(shapeVec[i])));
            item.Set(String::New(env, "shape"), shape);
            result.Set(String::New(env, node.GetName()), item);
        }
        return result;
    }

#ifdef WIN32
    /// Converts a UTF-8 string to UTF-16 for Windows wide-character APIs. Returns "" on failure.
    static std::wstring Utf8ToWide(const std::string &utf8)
    {
        if (utf8.empty()) return std::wstring();
        const int srcLen = static_cast<int>(utf8.size());
        const int wideLen = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), srcLen, nullptr, 0);
        if (wideLen <= 0) return std::wstring();
        std::wstring wide(static_cast<size_t>(wideLen), L'\0');
        MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), srcLen, &wide[0], wideLen);
        return wide;
    }
#endif

private:
    Ort::Env env;
    Ort::SessionOptions sessionOptions;
    std::shared_ptr<Ort::Session> session_;
    std::map<std::string, std::shared_ptr<OrtSessionNodeInfo>> inputs_;
    std::map<std::string, std::shared_ptr<OrtSessionNodeInfo>> outputs_;
};
|
|
|
|
/**
 * Installs the OrtSession JavaScript class onto `exports`.
 */
void InstallOrtAPI(Napi::Env env, Napi::Object exports)
{
    // OrtSession::Init adds its property to `exports` in place and hands back that same
    // object, so its return value carries nothing new and is deliberately discarded.
    static_cast<void>(OrtSession::Init(env, exports));
}
|
|
|
|
#if defined(USE_ONNXRUNTIME) && !defined(BUILD_MAIN_WORD)
// Module entry point, compiled only when USE_ONNXRUNTIME is defined and BUILD_MAIN_WORD is not.
// Registers the ONNX Runtime bindings on the module's exports object and returns it.
static Napi::Object Init(Napi::Env moduleEnv, Napi::Object moduleExports)
{
    InstallOrtAPI(moduleEnv, moduleExports);
    return moduleExports;
}

NODE_API_MODULE(addon, Init)
#endif