针对mnn模型处理,输出增加shape
This commit is contained in:
@ -34,6 +34,13 @@ static size_t getShapeSize(const std::vector<int> &shape)
|
||||
return sum;
|
||||
}
|
||||
|
||||
static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector<int> &shape)
|
||||
{
|
||||
auto result = Napi::Array::New(env, shape.size());
|
||||
for (int i = 0; i < shape.size(); ++i) result.Set(i, Napi::Number::New(env, shape[i]));
|
||||
return result;
|
||||
}
|
||||
|
||||
class MNNSessionRunWorker : public AsyncWorker {
|
||||
public:
|
||||
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
|
||||
@ -46,7 +53,6 @@ class MNNSessionRunWorker : public AsyncWorker {
|
||||
|
||||
void Execute()
|
||||
{
|
||||
interpreter_->resizeSession(session_);
|
||||
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
|
||||
SetError(std::string("Run session failed"));
|
||||
}
|
||||
@ -61,9 +67,12 @@ class MNNSessionRunWorker : public AsyncWorker {
|
||||
auto result = Object::New(Env());
|
||||
for (auto it : interpreter_->getSessionOutputAll(session_)) {
|
||||
auto tensor = it.second;
|
||||
auto item = Object::New(Env());
|
||||
auto buffer = ArrayBuffer::New(Env(), tensor->size());
|
||||
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
|
||||
result.Set(it.first, buffer);
|
||||
item.Set("data", buffer);
|
||||
item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape()));
|
||||
result.Set(it.first, item);
|
||||
}
|
||||
Callback().Call({Env().Undefined(), result});
|
||||
}
|
||||
@ -83,8 +92,10 @@ class MNNSessionRunWorker : public AsyncWorker {
|
||||
if (it != DATA_TYPE_MAP.end()) type = it->second;
|
||||
}
|
||||
|
||||
if (shape.size()) interpreter_->resizeTensor(tensor, shape);
|
||||
|
||||
if (shape.size()) {
|
||||
interpreter_->resizeTensor(tensor, shape);
|
||||
interpreter_->resizeSession(session_);
|
||||
}
|
||||
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
|
||||
if (tensorBytes != dataBytes) {
|
||||
SetError(std::string("input name #" + name + " data size not matched"));
|
||||
|
Reference in New Issue
Block a user