Skip to content

Commit

Permalink
Add tests for stateful kernel functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
TolyaTalamanov committed Sep 4, 2022
1 parent a3d6994 commit bf54a37
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def test_invalid_op_input(self):
with self.assertRaises(Exception): create_op([cv.GMat, int], [cv.GMat]).on(cv.GMat())


def test_stateful_kernel(self):
def test_state_in_class(self):
@cv.gapi.op('custom.sum', in_types=[cv.GArray.Int], out_types=[cv.GOpaque.Int])
class GSum:
@staticmethod
Expand Down
96 changes: 96 additions & 0 deletions modules/gapi/misc/python/test/test_gapi_stateful_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python

import numpy as np
import cv2 as cv
import os
import sys
import unittest

from tests_common import NewOpenCVTests


try:

if sys.version_info[:2] < (3, 0):
raise unittest.SkipTest('Python 2.x is not supported')


class CounterState:
def __init__(self):
self.counter = 0


@cv.gapi.op('stateful_counter',
in_types=[cv.GOpaque.Int],
out_types=[cv.GOpaque.Int])
class GStatefulCounter:
"""Accumulate state counter on every call"""

@staticmethod
def outMeta(desc):
return cv.empty_gopaque_desc()


@cv.gapi.kernel(GStatefulCounter)
class GStatefulCounterImpl:
"""Implementation for GStatefulCounter operation."""

@staticmethod
def setup(desc):
return CounterState()

@staticmethod
def run(value, state):
state.counter += value
return state.counter


class gapi_sample_pipelines(NewOpenCVTests):
def test_stateful_kernel_single_instance(self):
g_in = cv.GOpaque.Int()
g_out = GStatefulCounter.on(g_in)
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out))
pkg = cv.gapi.kernels(GStatefulCounterImpl)

nums = [i for i in range(10)]
acc = 0
for v in nums:
acc = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))

self.assertEqual(sum(nums), acc)


def test_stateful_kernel_multiple_instances(self):
# NB: Every counter has his own independent state.
g_in = cv.GOpaque.Int()
g_out0 = GStatefulCounter.on(g_in)
g_out1 = GStatefulCounter.on(g_in)
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out0, g_out1))
pkg = cv.gapi.kernels(GStatefulCounterImpl)

nums = [i for i in range(10)]
acc0 = acc1 = 0
for v in nums:
acc0, acc1 = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))

ref = sum(nums)
self.assertEqual(ref, acc0)
self.assertEqual(ref, acc1)


except unittest.SkipTest as e:

message = str(e)

class TestSkip(unittest.TestCase):
def setUp(self):
self.skipTest('Skip tests: ' + message)

def test_skip():
pass

pass


if __name__ == '__main__':
NewOpenCVTests.bootstrap()

0 comments on commit bf54a37

Please sign in to comment.