针对mnn模型处理,输出增加shape

This commit is contained in:
2025-03-10 17:49:53 +08:00
parent e362890c96
commit a966b82963
13 changed files with 51 additions and 29 deletions

View File

@ -34,6 +34,13 @@ static size_t getShapeSize(const std::vector<int> &shape)
return sum;
}
static Napi::Value mnnSizeToJavascript(Napi::Env env, const std::vector<int> &shape)
{
auto result = Napi::Array::New(env, shape.size());
for (int i = 0; i < shape.size(); ++i) result.Set(i, Napi::Number::New(env, shape[i]));
return result;
}
class MNNSessionRunWorker : public AsyncWorker {
public:
MNNSessionRunWorker(const Napi::Function &callback, MNN::Interpreter *interpreter, MNN::Session *session)
@ -46,7 +53,6 @@ class MNNSessionRunWorker : public AsyncWorker {
void Execute()
{
interpreter_->resizeSession(session_);
if (MNN::ErrorCode::NO_ERROR != interpreter_->runSession(session_)) {
SetError(std::string("Run session failed"));
}
@ -61,9 +67,12 @@ class MNNSessionRunWorker : public AsyncWorker {
auto result = Object::New(Env());
for (auto it : interpreter_->getSessionOutputAll(session_)) {
auto tensor = it.second;
auto item = Object::New(Env());
auto buffer = ArrayBuffer::New(Env(), tensor->size());
memcpy(buffer.Data(), tensor->host<float>(), tensor->size());
result.Set(it.first, buffer);
item.Set("data", buffer);
item.Set("shape", mnnSizeToJavascript(Env(), tensor->shape()));
result.Set(it.first, item);
}
Callback().Call({Env().Undefined(), result});
}
@ -83,8 +92,10 @@ class MNNSessionRunWorker : public AsyncWorker {
if (it != DATA_TYPE_MAP.end()) type = it->second;
}
if (shape.size()) interpreter_->resizeTensor(tensor, shape);
if (shape.size()) {
interpreter_->resizeTensor(tensor, shape);
interpreter_->resizeSession(session_);
}
auto tensorBytes = getShapeSize(tensor->shape()) * type.bits / 8;
if (tensorBytes != dataBytes) {
SetError(std::string("input name #" + name + " data size not matched"));