248 lines
7.8 KiB
C++
248 lines
7.8 KiB
C++
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <MNN/Interpreter.hpp>
#include <MNN/ImageProcess.hpp>

#include "common/tensor.h"
#include "node.h"
|
|
|
|
using namespace Napi;
|
|
|
|
// Builds the InstanceMethod descriptor for MNNSession::<method>, exposed to JavaScript
// under the same name and flagged as a writable, configurable property.
#define SESSION_INSTANCE_METHOD(method) InstanceMethod<&MNNSession::method>(#method, static_cast<napi_property_attributes>(napi_writable | napi_configurable))
|
|
|
|
static const std::map<TensorDataType, halide_type_t> DATA_TYPE_MAP = {
|
|
{TensorDataType::Float32, halide_type_of<float>()},
|
|
{TensorDataType::Float64, halide_type_of<double>()},
|
|
{TensorDataType::Int32, halide_type_of<int32_t>()},
|
|
{TensorDataType::Uint32, halide_type_of<uint32_t>()},
|
|
{TensorDataType::Int16, halide_type_of<int16_t>()},
|
|
{TensorDataType::Uint16, halide_type_of<uint16_t>()},
|
|
{TensorDataType::Int8, halide_type_of<int8_t>()},
|
|
{TensorDataType::Uint8, halide_type_of<uint8_t>()},
|
|
{TensorDataType::Int64, halide_type_of<int64_t>()},
|
|
{TensorDataType::Uint64, halide_type_of<uint64_t>()},
|
|
};
|
|
|
|
/**
 * Compute the element count implied by a tensor shape.
 *
 * @param shape Dimension extents.
 * @return 0 for an empty shape; otherwise the product of every extent that is
 *         greater than 1. Extents of 1 or less (including 0 and negatives)
 *         are skipped and leave the product unchanged.
 */
static size_t getShapeSize(const std::vector<int> &shape)
{
    if (shape.empty()) {
        return 0;
    }

    size_t elementCount = 1;
    for (size_t axis = 0; axis < shape.size(); ++axis) {
        const int extent = shape[axis];
        if (extent <= 1) {
            continue;
        }
        elementCount *= extent;
    }
    return elementCount;
}
|
|
|
|
/**
 * Convert a tensor shape into a JavaScript array of numbers.
 *
 * @param env   Environment in which the JavaScript values are created.
 * @param shape Dimension extents, copied in order.
 * @return A JavaScript array with one number per dimension.
 */
static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector<int> &shape)
{
    Napi::Array dims = Napi::Array::New(env, shape.size());
    uint32_t slot = 0;
    for (const int extent : shape) {
        dims.Set(slot, Napi::Number::New(env, extent));
        ++slot;
    }
    return dims;
}
|
|
|
|
class MNNSessionRunWorker : public AsyncWorker {
|
|
public:
|
|
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
|
|
: AsyncWorker(callback), interpreter_(interpreter), session_(session) {}
|
|
|
|
~MNNSessionRunWorker()
|
|
{
|
|
interpreter_->releaseSession(session_);
|
|
}
|
|
|
|
void Execute()
|
|
{
|
|
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
|
|
SetError(std::string("Run session failed"));
|
|
}
|
|
}
|
|
|
|
void OnOK()
|
|
{
|
|
if (HasError()) {
|
|
Callback().Call({Error::New(Env(), errorMessage_.c_str()).Value(), Env().Undefined()});
|
|
}
|
|
else {
|
|
auto result = Object::New(Env());
|
|
for (auto it : interpreter_->getSessionOutputAll(session_)) {
|
|
auto tensor = it.second;
|
|
auto item = Object::New(Env());
|
|
auto buffer = ArrayBuffer::New(Env(), tensor->size());
|
|
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
|
|
item.Set("data", buffer);
|
|
item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape()));
|
|
result.Set(it.first, item);
|
|
}
|
|
Callback().Call({Env().Undefined(), result});
|
|
}
|
|
}
|
|
|
|
void SetInput(const std::string &name, TensorDataType dataType, const std::vector<int> &shape, void *data, size_t dataBytes)
|
|
{
|
|
auto tensor = interpreter_->getSessionInput(session_, name.c_str());
|
|
if (!tensor) {
|
|
SetError(std::string("input name #" + name + " not exists"));
|
|
return;
|
|
}
|
|
|
|
halide_type_t type = tensor->getType();
|
|
if (dataType != TensorDataType::Unknown) {
|
|
auto it = DATA_TYPE_MAP.find(dataType);
|
|
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
|
}
|
|
|
|
if (shape.size()) {
|
|
interpreter_->resizeTensor(tensor, shape);
|
|
interpreter_->resizeSession(session_);
|
|
}
|
|
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
|
|
if (tensorBytes != dataBytes) {
|
|
SetError(std::string("input name #" + name + " data size not matched"));
|
|
return;
|
|
}
|
|
|
|
auto hostTensor = MNN::Tensor::create(tensor->shape(), type, data, MNN::Tensor::CAFFE);
|
|
tensor->copyFromHostTensor(hostTensor);
|
|
delete hostTensor;
|
|
}
|
|
|
|
inline void SetError(const std::string &what) { errorMessage_ = what; }
|
|
inline bool HasError() { return errorMessage_.size() > 0; }
|
|
|
|
private:
|
|
MNN::Interpreter *interpreter_;
|
|
MNN::Session *session_;
|
|
std::string errorMessage_;
|
|
};
|
|
|
|
/**
 * JavaScript-facing wrapper around an MNN interpreter.
 *
 * Constructed from either a model file path (string) or a typed array holding
 * the model bytes. Exposes GetInputsInfo, GetOutputsInfo and Run.
 *
 * If the model cannot be loaded the object is still constructed, but every
 * method throws a JavaScript error instead of dereferencing a null interpreter.
 */
class MNNSession : public ObjectWrap<MNNSession> {
public:
    /** Defines the `MNNSession` class and attaches it to `exports`. */
    static Napi::Object Init(Napi::Env env, Napi::Object exports)
    {
        Function func = DefineClass(env, "MNNSession", {
            SESSION_INSTANCE_METHOD(GetInputsInfo),
            SESSION_INSTANCE_METHOD(GetOutputsInfo),
            SESSION_INSTANCE_METHOD(Run),
        });
        // Handed over to the environment as instance data right below.
        FunctionReference *constructor = new FunctionReference();
        *constructor = Napi::Persistent(func);
        exports.Set("MNNSession", func);
        env.SetInstanceData<FunctionReference>(constructor);
        return exports;
    }

    /**
     * @param info info[0] is a model path (string) or a typed array with the model bytes.
     *             Any other argument leaves the session unloaded.
     *
     * std::exception raised while loading is rethrown as a JavaScript Error.
     */
    MNNSession(const CallbackInfo &info)
        : ObjectWrap(info)
    {
        try {
            if (info[0].IsString()) {
                interpreter_ = MNN::Interpreter::createFromFile(info[0].As<String>().Utf8Value().c_str());
            }
            else if (info[0].IsTypedArray()) {
                size_t bufferBytes = 0;
                auto buffer = dataFromTypedArray(info[0], bufferBytes);
                interpreter_ = MNN::Interpreter::createFromBuffer(buffer, bufferBytes);
            }

            if (interpreter_) {
                backendConfig_.precision = MNN::BackendConfig::Precision_High;
                backendConfig_.power = MNN::BackendConfig::Power_High;
                scheduleConfig_.type = MNN_FORWARD_CPU;
                scheduleConfig_.numThread = 1;
                scheduleConfig_.backendConfig = &backendConfig_;
                // This session is only used to describe inputs/outputs; Run() creates its own.
                session_ = interpreter_->createSession(scheduleConfig_);
            }
        }
        catch (const std::exception &e) {
            Error::New(info.Env(), e.what()).ThrowAsJavaScriptException();
        }
    }

    // NOTE(review): interpreter_ and session_ are never freed. Freeing them here
    // looks unsafe as long as queued MNNSessionRunWorker instances keep a raw
    // pointer to the interpreter without keeping this wrapper alive - confirm and
    // fix both sides together.
    ~MNNSession() {}

    /** @return Object mapping each input name to `{ name, shape, type }`. */
    Napi::Value GetInputsInfo(const Napi::CallbackInfo &info)
    {
        if (!EnsureReady(info.Env())) return info.Env().Undefined();
        return BuildInputOutputInfo(info.Env(), interpreter_->getSessionInputAll(session_));
    }

    /** @return Object mapping each output name to `{ name, shape, type }`. */
    Napi::Value GetOutputsInfo(const Napi::CallbackInfo &info)
    {
        if (!EnsureReady(info.Env())) return info.Env().Undefined();
        return BuildInputOutputInfo(info.Env(), interpreter_->getSessionOutputAll(session_));
    }

    /**
     * Runs the model asynchronously.
     *
     * @param info info[0]: object mapping input name to `{ data, type?, shape? }`;
     *             info[1]: callback invoked as `(error, result)`.
     * @return undefined; results arrive through the callback.
     *
     * Throws a JavaScript Error when the session is not loaded, the arguments
     * have the wrong kind, or a per-run session cannot be created.
     */
    Napi::Value Run(const Napi::CallbackInfo &info)
    {
        Napi::Env env = info.Env();
        if (!EnsureReady(env)) return env.Undefined();

        if (info.Length() < 2 || !info[0].IsObject() || !info[1].IsFunction()) {
            Error::New(env, "Run expects an inputs object and a callback function").ThrowAsJavaScriptException();
            return env.Undefined();
        }

        MNN::Session *runSession = interpreter_->createSession(scheduleConfig_);
        if (!runSession) {
            Error::New(env, "Create session failed").ThrowAsJavaScriptException();
            return env.Undefined();
        }

        // Owned here until queued, so an exception while reading the inputs
        // destroys the worker (which in turn releases runSession).
        std::unique_ptr<MNNSessionRunWorker> worker(
            new MNNSessionRunWorker(info[1].As<Function>(), interpreter_, runSession));

        auto inputArgument = info[0].As<Object>();
        for (auto it = inputArgument.begin(); it != inputArgument.end(); ++it) {
            auto name = (*it).first.As<String>().Utf8Value();
            auto inputOption = static_cast<Napi::Value>((*it).second).As<Object>();
            auto type = inputOption.Has("type") ? static_cast<TensorDataType>(inputOption.Get("type").As<Number>().Int32Value()) : TensorDataType::Unknown;
            size_t dataByteLen = 0;
            void *data = dataFromTypedArray(inputOption.Get("data"), dataByteLen);
            auto shape = inputOption.Has("shape") ? GetShapeFromJavascript(inputOption.Get("shape").As<Array>()) : std::vector<int>();
            // Input errors are stored in the worker and surface through the callback.
            worker->SetInput(name, type, shape, data, dataByteLen);
        }

        // Ownership passes to the async work queue from here on.
        worker.release()->Queue();
        return env.Undefined();
    }

private:
    /**
     * Verifies the model and the descriptive session were created.
     * Throws a JavaScript Error and returns false when they were not.
     */
    bool EnsureReady(Napi::Env env) const
    {
        if (interpreter_ && session_) return true;
        Error::New(env, "MNN session is not initialized").ThrowAsJavaScriptException();
        return false;
    }

    /**
     * Describes a set of tensors as `{ <name>: { name, shape, type } }`, where
     * `type` is the numeric TensorDataType (Unknown when the tensor's element
     * type is not listed in DATA_TYPE_MAP).
     */
    Napi::Object BuildInputOutputInfo(Napi::Env env, const std::map<std::string, MNN::Tensor *> &tensors)
    {
        auto result = Object::New(env);
        for (const auto &entry : tensors) {
            const std::string &name = entry.first;
            const auto shape = entry.second->shape();
            const auto type = entry.second->getType();

            // Reverse lookup: find the tag whose halide type equals the tensor's type.
            TensorDataType dataType = TensorDataType::Unknown;
            for (const auto &dt : DATA_TYPE_MAP) {
                if (dt.second == type) {
                    dataType = dt.first;
                    break;
                }
            }

            auto shapeArr = Array::New(env, shape.size());
            for (size_t i = 0; i < shape.size(); i++) {
                shapeArr.Set(i, Number::New(env, shape[i]));
            }

            auto item = Object::New(env);
            item.Set("name", String::New(env, name));
            item.Set("shape", shapeArr);
            item.Set("type", Number::New(env, static_cast<int>(dataType)));
            result.Set(name, item);
        }
        return result;
    }

    /** Reads a JavaScript array of numbers into a shape vector (32-bit ints). */
    std::vector<int> GetShapeFromJavascript(const Napi::Array &shape)
    {
        std::vector<int> result;
        const size_t rank = shape.Length();
        result.reserve(rank);
        for (size_t i = 0; i < rank; i++) {
            result.push_back(shape.Get(i).As<Number>().Int32Value());
        }
        return result;
    }

private:
    // Null until the constructor succeeds, so a failed load is detectable.
    MNN::Interpreter *interpreter_ = nullptr;
    MNN::Session *session_ = nullptr;
    MNN::BackendConfig backendConfig_;
    MNN::ScheduleConfig scheduleConfig_;
};
|
|
|
|
/**
 * Registers the MNN bindings on `exports` by delegating to MNNSession::Init.
 *
 * @param env     Environment the class is defined in.
 * @param exports Module exports object that receives the bindings.
 */
void InstallMNNAPI(Napi::Env env, Napi::Object exports)
{
    // The returned exports object is the one passed in; nothing to keep.
    (void)MNNSession::Init(env, exports);
}
|
|
|
|
|
|
#if defined(USE_MNN) && !defined(BUILD_MAIN_WORD)
/**
 * Addon entry point, compiled only when USE_MNN is defined and BUILD_MAIN_WORD
 * is not. Optionally publishes the build's RELEASE_VERSION as `__release__`,
 * then installs the MNN bindings.
 */
static Object Init(Env env, Object exports)
{
#if defined(RELEASE_VERSION)
    exports.Set("__release__", Napi::String::New(env, RELEASE_VERSION));
#endif
    InstallMNNAPI(env, exports);
    return exports;
}

NODE_API_MODULE(addon, Init)
#endif
|