Add MNN model support
cxx/ort/node.cc
@@ -2,6 +2,7 @@
 #include <vector>
 #include <onnxruntime_cxx_api.h>
 #include "node.h"
+#include "common/tensor.h"
 
 #ifdef WIN32
 #include <locale>
@@ -13,51 +14,31 @@ using namespace Napi;
 
 #define SESSION_INSTANCE_METHOD(method) InstanceMethod<&OrtSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
 
-static ONNXTensorElementDataType getDataTypeFromString(const std::string &name)
-{
-  static const std::map<std::string, ONNXTensorElementDataType> dataTypeNameMap = {
-    {"float32", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
-    {"float", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
-    {"float64", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
-    {"double", ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
-    {"int8", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
-    {"uint8", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
-    {"int16", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
-    {"uint16", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
-    {"int32", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
-    {"uint32", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
-    {"int64", ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
-    {"uint64", ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
-  };
-  auto it = dataTypeNameMap.find(name);
-  return (it == dataTypeNameMap.end()) ? ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED : it->second;
-}
+static const std::map<TensorDataType, ONNXTensorElementDataType> DATA_TYPE_MAP = {
+  {TensorDataType::Float32, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT},
+  {TensorDataType::Float64, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE},
+  {TensorDataType::Int32, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32},
+  {TensorDataType::Uint32, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32},
+  {TensorDataType::Int16, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16},
+  {TensorDataType::Uint16, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16},
+  {TensorDataType::Int8, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8},
+  {TensorDataType::Uint8, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8},
+  {TensorDataType::Int64, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64},
+  {TensorDataType::Uint64, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64},
+};
 
-static size_t getDataTypeSize(ONNXTensorElementDataType type)
-{
-  static const std::map<ONNXTensorElementDataType, size_t> dataTypeSizeMap = {
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
-    {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
-  };
-  auto it = dataTypeSizeMap.find(type);
-  return (it == dataTypeSizeMap.end()) ? 0 : it->second;
-}
-
 static void *dataFromTypedArray(const Napi::Value &val, size_t &bytes)
 {
   auto arr = val.As<TypedArray>();
   auto data = static_cast<uint8_t *>(arr.ArrayBuffer().Data());
   bytes = arr.ByteLength();
   return static_cast<void *>(data + arr.ByteOffset());
 }
+static const std::map<ONNXTensorElementDataType, size_t> DATA_TYPE_SIZE_MAP = {
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, 4},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, 8},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, 1},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 1},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, 2},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, 2},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, 4},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, 4},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, 8},
+  {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, 8},
+};
 
 class OrtSessionNodeInfo {
  public:
@@ -67,7 +48,22 @@ class OrtSessionNodeInfo {
   inline const std::string &GetName() const { return name_; }
   inline const std::vector<int64_t> &GetShape() const { return shape_; }
   inline ONNXTensorElementDataType GetType() const { return type_; }
-  inline size_t GetElementSize() const { return getDataTypeSize(type_); }
+  inline size_t GetElementSize() const
+  {
+    auto it = DATA_TYPE_SIZE_MAP.find(type_);
+    return (it == DATA_TYPE_SIZE_MAP.end()) ? 0 : it->second;
+  }
+  TensorDataType GetDataType() const
+  {
+    auto datatype = TensorDataType::Unknown;
+    for (auto it : DATA_TYPE_MAP) {
+      if (it.second == type_) {
+        datatype = it.first;
+        break;
+      }
+    }
+    return datatype;
+  }
   size_t GetElementCount() const
   {
     if (!shape_.size()) return 0;
@@ -115,7 +111,8 @@ class OrtSessionRunWorker : public AsyncWorker {
     for (int i = 0; i < outputNames_.size(); ++i) {
       size_t bytes = outputElementBytes_[i];
       Ort::Value &value = outputValues_[i];
-      auto buffer = ArrayBuffer::New(Env(), value.GetTensorMutableRawData(), bytes);
+      auto buffer = ArrayBuffer::New(Env(), bytes);
+      memcpy(buffer.Data(), value.GetTensorMutableRawData(), bytes);
       result.Set(String::New(Env(), outputNames_[i]), buffer);
     }
     Callback().Call({Env().Undefined(), result});
@@ -236,7 +233,12 @@ class OrtSession : public ObjectWrap<OrtSession> {
       auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
       if (!inputOption.Has("data") || !inputOption.Get("data").IsTypedArray()) worker->SetError((std::string("data is required in inputs #" + name)));
       else {
-        auto type = inputOption.Has("type") ? getDataTypeFromString(inputOption.Get("type").As<String>().Utf8Value()) : input->GetType();
+        auto type = input->GetType();
+        if (inputOption.Has("type")) {
+          auto t = static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value());
+          auto it = DATA_TYPE_MAP.find(t);
+          if (it != DATA_TYPE_MAP.end()) type = it->second;
+        }
         size_t dataByteLen;
         void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
         auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : input->GetShape();
@@ -279,7 +281,7 @@ class OrtSession : public ObjectWrap<OrtSession> {
       auto &node = *it.second;
       auto item = Object::New(env);
       item.Set(String::New(env, "name"), String::New(env, node.GetName()));
-      item.Set(String::New(env, "type"), Number::New(env, node.GetType()));
+      item.Set(String::New(env, "type"), Number::New(env, static_cast<int>(node.GetDataType())));
       auto &shapeVec = node.GetShape();
       auto shape = Array::New(env, shapeVec.size());
       for (int i = 0; i < shapeVec.size(); ++i) shape.Set(i, Number::New(env, shapeVec[i]));
@@ -303,7 +305,7 @@ void InstallOrtAPI(Napi::Env env, Napi::Object exports)
   OrtSession::Init(env, exports);
 }
 
-#ifdef USE_ONNXRUNTIME
+#if defined(USE_ONNXRUNTIME) && not defined(BUILD_MAIN_WORD)
 static Object Init(Env env, Object exports)
 {
   InstallOrtAPI(env, exports);
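
For reference, the JavaScript-facing effect of this diff: the optional per-input "type" is now read as a numeric TensorDataType value and matched against DATA_TYPE_MAP (previously it was a string such as "float32"), node info reports "type" as a TensorDataType, and each output comes back as a freshly allocated ArrayBuffer copy of the tensor data. The sketch below (TypeScript) is hypothetical: the binding path, export and method names, tensor names, and the TensorDataType numeric values are assumptions, since they live outside this diff (the enum is defined in common/tensor.h).

// Hypothetical usage sketch; only the shape of the input/output objects comes from this diff.
const { OrtSession } = require('./build/Release/node_addon'); // assumed binding path and export name

// Placeholder mirror of the C++ TensorDataType enum from common/tensor.h;
// the real numeric values are defined there, not here.
const TensorDataType = { Float32: 1 };

const session = new OrtSession('model.onnx'); // assumed constructor signature

session.run( // assumed method name
  {
    // "type" is now a numeric TensorDataType value rather than a string like "float32";
    // "data" must be a TypedArray, and "shape" falls back to the model's declared shape.
    input: {
      data: new Float32Array(1 * 3 * 224 * 224),
      type: TensorDataType.Float32,
      shape: [1, 3, 224, 224],
    },
  },
  (err, outputs) => {
    if (err) throw err;
    // Each output arrives as a standalone ArrayBuffer copy of the tensor data.
    const scores = new Float32Array(outputs.output); // "output" is an assumed tensor name
  }
);

The switch from wrapping value.GetTensorMutableRawData() in an external ArrayBuffer to an allocate-and-memcpy pattern presumably keeps the returned buffer valid after the Ort::Value that owns the tensor memory is released.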