Skip to content
8 changes: 4 additions & 4 deletions napi-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5873,11 +5873,11 @@ inline void AsyncProgressWorker<T>::SendProgress_(const T* data, size_t count) {
}

template<class T>
inline void AsyncProgressWorker<T>::Signal() const {
inline void AsyncProgressWorker<T>::Signal() {
this->NonBlockingCall(static_cast<T*>(nullptr));
}

template<class T>
template <class T>
inline void AsyncProgressWorker<T>::ExecutionProgress::Signal() const {
_worker->Signal();
}
Expand Down Expand Up @@ -5985,7 +5985,7 @@ inline void AsyncProgressQueueWorker<T>::SendProgress_(const T* data, size_t cou
}

template<class T>
inline void AsyncProgressQueueWorker<T>::Signal() const {
inline void AsyncProgressQueueWorker<T>::Signal() {
this->NonBlockingCall(nullptr);
}

Expand All @@ -5995,7 +5995,7 @@ inline void AsyncProgressQueueWorker<T>::OnWorkComplete(Napi::Env env, napi_stat
AsyncProgressWorkerBase<std::pair<T*, size_t>>::OnWorkComplete(env, status);
}

template<class T>
template <class T>
inline void AsyncProgressQueueWorker<T>::ExecutionProgress::Signal() const {
_worker->Signal();
}
Expand Down
4 changes: 2 additions & 2 deletions napi.h
Original file line number Diff line number Diff line change
Expand Up @@ -2876,7 +2876,7 @@ namespace Napi {

private:
void Execute() override;
void Signal() const;
void Signal();
void SendProgress_(const T* data, size_t count);

std::mutex _mutex;
Expand Down Expand Up @@ -2934,7 +2934,7 @@ namespace Napi {

private:
void Execute() override;
void Signal() const;
void Signal();
void SendProgress_(const T* data, size_t count);
};
#endif // NAPI_VERSION > 3 && !defined(__wasm32__)
Expand Down
54 changes: 54 additions & 0 deletions test/async_progress_queue_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,66 @@ class TestWorker : public AsyncProgressQueueWorker<ProgressData> {
FunctionReference _js_progress_cb;
};

class SignalTestWorker : public AsyncProgressQueueWorker<ProgressData> {
public:
static Napi::Value CreateWork(const CallbackInfo& info) {
int32_t times = info[0].As<Number>().Int32Value();
Function cb = info[1].As<Function>();
Function progress = info[2].As<Function>();

SignalTestWorker* worker = new SignalTestWorker(
cb, progress, "TestResource", Object::New(info.Env()), times);

return Napi::External<SignalTestWorker>::New(info.Env(), worker);
}

static void QueueWork(const CallbackInfo& info) {
auto wrap = info[0].As<Napi::External<SignalTestWorker>>();
auto worker = wrap.Data();
worker->Queue();
}

protected:
void Execute(const ExecutionProgress& progress) override {
using namespace std::chrono_literals;
std::this_thread::sleep_for(1s);

for (int32_t idx = 0; idx < _times; idx++) {
// TODO: unlike AsyncProgressWorker, this signal does not trigger
// OnProgress() below, to run the JS callback. Investigate and fix.
progress.Signal();
}
}

void OnProgress(const ProgressData* /* data */, size_t /* count */) override {
if (!_js_progress_cb.IsEmpty()) {
_js_progress_cb.Call(Receiver().Value(), {});
}
}

private:
SignalTestWorker(Function cb,
Function progress,
const char* resource_name,
const Object& resource,
int32_t times)
: AsyncProgressQueueWorker(cb, resource_name, resource), _times(times) {
_js_progress_cb.Reset(progress, 1);
}

int32_t _times;
FunctionReference _js_progress_cb;
};

} // namespace

Object InitAsyncProgressQueueWorker(Env env) {
Object exports = Object::New(env);
exports["createWork"] = Function::New(env, TestWorker::CreateWork);
exports["queueWork"] = Function::New(env, TestWorker::QueueWork);
exports["createSignalWork"] =
Function::New(env, SignalTestWorker::CreateWork);
exports["queueSignalWork"] = Function::New(env, SignalTestWorker::QueueWork);
return exports;
}

Expand Down
33 changes: 28 additions & 5 deletions test/async_progress_queue_worker.js
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
'use strict';

const common = require('./common')
const common = require('./common');
const assert = require('assert');

module.exports = common.runTest(test);

async function test({ asyncprogressqueueworker }) {
async function test ({ asyncprogressqueueworker }) {
await success(asyncprogressqueueworker);
await fail(asyncprogressqueueworker);
await signalTest(asyncprogressqueueworker);
}

function success(binding) {
function success (binding) {
return new Promise((resolve, reject) => {
const expected = [0, 1, 2, 3];
const actual = [];
Expand All @@ -32,15 +33,37 @@ function success(binding) {
});
}

function fail(binding) {
function fail (binding) {
return new Promise((resolve, reject) => {
const worker = binding.createWork(-1,
common.mustCall((err) => {
assert.throws(() => { throw err }, /test error/);
assert.throws(() => { throw err; }, /test error/);
resolve();
}),
common.mustNotCall()
);
binding.queueWork(worker);
});
}

function signalTest (binding) {
return new Promise((resolve, reject) => {
const expectedCalls = 4;
let actualCalls = 0;
const worker = binding.createSignalWork(expectedCalls,
common.mustCall((err) => {
if (err) {
reject(err);
} else {
if (actualCalls === expectedCalls) {
resolve();
}
}
}),
common.mustCall((_progress) => {
actualCalls++;
}, expectedCalls)
);
binding.queueSignalWork(worker);
});
}
47 changes: 47 additions & 0 deletions test/async_progress_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,59 @@ class MalignWorker : public AsyncProgressWorker<ProgressData> {
std::mutex _cvm;
FunctionReference _progress;
};

class SignalTestWorker : public AsyncProgressWorker<ProgressData> {
public:
static void DoWork(const CallbackInfo& info) {
int32_t times = info[0].As<Number>().Int32Value();
Function cb = info[1].As<Function>();
Function progress = info[2].As<Function>();

SignalTestWorker* worker = new SignalTestWorker(
cb, progress, "TestResource", Object::New(info.Env()));
worker->_times = times;
worker->Queue();
}

protected:
void Execute(const ExecutionProgress& progress) override {
if (_times < 0) {
SetError("test error");
}
std::unique_lock<std::mutex> lock(_cvm);
for (int32_t idx = 0; idx < _times; idx++) {
progress.Signal();
_cv.wait(lock);
}
}

void OnProgress(const ProgressData* /* data */, size_t /* count */) override {
if (!_progress.IsEmpty()) {
_progress.MakeCallback(Receiver().Value(), {});
}
_cv.notify_one();
}

private:
SignalTestWorker(Function cb,
Function progress,
const char* resource_name,
const Object& resource)
: AsyncProgressWorker(cb, resource_name, resource) {
_progress.Reset(progress, 1);
}
std::condition_variable _cv;
std::mutex _cvm;
int32_t _times;
FunctionReference _progress;
};
} // namespace

Object InitAsyncProgressWorker(Env env) {
Object exports = Object::New(env);
exports["doWork"] = Function::New(env, TestWorker::DoWork);
exports["doMalignTest"] = Function::New(env, MalignWorker::DoWork);
exports["doSignalTest"] = Function::New(env, SignalTestWorker::DoWork);
return exports;
}

Expand Down
33 changes: 27 additions & 6 deletions test/async_progress_worker.js
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
'use strict';

const common = require('./common')
const common = require('./common');
const assert = require('assert');

module.exports = common.runTest(test);

async function test({ asyncprogressworker }) {
async function test ({ asyncprogressworker }) {
await success(asyncprogressworker);
await fail(asyncprogressworker);
await malignTest(asyncprogressworker);
await signalTest(asyncprogressworker);
}

function success(binding) {
function success (binding) {
return new Promise((resolve, reject) => {
const expected = [0, 1, 2, 3];
const actual = [];
Expand All @@ -32,19 +33,19 @@ function success(binding) {
});
}

function fail(binding) {
function fail (binding) {
return new Promise((resolve) => {
binding.doWork(-1,
common.mustCall((err) => {
assert.throws(() => { throw err }, /test error/)
assert.throws(() => { throw err; }, /test error/);
resolve();
}),
common.mustNotCall()
);
});
}

function malignTest(binding) {
function malignTest (binding) {
return new Promise((resolve, reject) => {
binding.doMalignTest(
common.mustCall((err) => {
Expand All @@ -59,3 +60,23 @@ function malignTest(binding) {
);
});
}

function signalTest (binding) {
return new Promise((resolve, reject) => {
const expectedCalls = 3;
let actualCalls = 0;
binding.doSignalTest(expectedCalls,
common.mustCall((err) => {
if (err) {
reject(err);
}
}),
common.mustCall((_progress) => {
actualCalls++;
if (expectedCalls === actualCalls) {
resolve();
}
}, expectedCalls)
);
});
}