forked from opencv/opencv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for stateful kernel functionality
- Loading branch information
1 parent
a3d6994
commit bf54a37
Showing
2 changed files
with
97 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
modules/gapi/misc/python/test/test_gapi_stateful_kernel.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#!/usr/bin/env python | ||
|
||
import numpy as np | ||
import cv2 as cv | ||
import os | ||
import sys | ||
import unittest | ||
|
||
from tests_common import NewOpenCVTests | ||
|
||
|
||
try: | ||
|
||
if sys.version_info[:2] < (3, 0): | ||
raise unittest.SkipTest('Python 2.x is not supported') | ||
|
||
|
||
class CounterState: | ||
def __init__(self): | ||
self.counter = 0 | ||
|
||
|
||
@cv.gapi.op('stateful_counter', | ||
in_types=[cv.GOpaque.Int], | ||
out_types=[cv.GOpaque.Int]) | ||
class GStatefulCounter: | ||
"""Accumulate state counter on every call""" | ||
|
||
@staticmethod | ||
def outMeta(desc): | ||
return cv.empty_gopaque_desc() | ||
|
||
|
||
@cv.gapi.kernel(GStatefulCounter) | ||
class GStatefulCounterImpl: | ||
"""Implementation for GStatefulCounter operation.""" | ||
|
||
@staticmethod | ||
def setup(desc): | ||
return CounterState() | ||
|
||
@staticmethod | ||
def run(value, state): | ||
state.counter += value | ||
return state.counter | ||
|
||
|
||
class gapi_sample_pipelines(NewOpenCVTests): | ||
def test_stateful_kernel_single_instance(self): | ||
g_in = cv.GOpaque.Int() | ||
g_out = GStatefulCounter.on(g_in) | ||
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out)) | ||
pkg = cv.gapi.kernels(GStatefulCounterImpl) | ||
|
||
nums = [i for i in range(10)] | ||
acc = 0 | ||
for v in nums: | ||
acc = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg)) | ||
|
||
self.assertEqual(sum(nums), acc) | ||
|
||
|
||
def test_stateful_kernel_multiple_instances(self): | ||
# NB: Every counter has his own independent state. | ||
g_in = cv.GOpaque.Int() | ||
g_out0 = GStatefulCounter.on(g_in) | ||
g_out1 = GStatefulCounter.on(g_in) | ||
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out0, g_out1)) | ||
pkg = cv.gapi.kernels(GStatefulCounterImpl) | ||
|
||
nums = [i for i in range(10)] | ||
acc0 = acc1 = 0 | ||
for v in nums: | ||
acc0, acc1 = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg)) | ||
|
||
ref = sum(nums) | ||
self.assertEqual(ref, acc0) | ||
self.assertEqual(ref, acc1) | ||
|
||
|
||
except unittest.SkipTest as e: | ||
|
||
message = str(e) | ||
|
||
class TestSkip(unittest.TestCase): | ||
def setUp(self): | ||
self.skipTest('Skip tests: ' + message) | ||
|
||
def test_skip(): | ||
pass | ||
|
||
pass | ||
|
||
|
||
if __name__ == '__main__': | ||
NewOpenCVTests.bootstrap() |