Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nutkit/protocol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .requests import *
from .cypher import *
from .responses import *
from .feature import *
8 changes: 8 additions & 0 deletions nutkit/protocol/feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
Enumerate all the capabilities in the drivers
"""
from enum import Enum


class Feature(Enum):
AUTHORIZATION_EXPIRED_TREATMENT = 'AuthorizationExpiredTreatment'
7 changes: 7 additions & 0 deletions nutkit/protocol/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def __init__(self, test_name):
self.testName = test_name


class GetFeatures:
""" Request the backend to the list of features supported by the driver.
The backend should respond with FeatureList.
"""
pass


class NewDriver:
""" Request to create a new driver instance on the backend.
Backend should respond with a Driver response or an Error response.
Expand Down
9 changes: 9 additions & 0 deletions nutkit/protocol/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ class RunTest:
pass


class FeatureList:
""" Response to GetFeatures indication the features supported
by the driver"""
def __init__(self, features=None):
if features is None:
features = []
self.features = features


class SkipTest:
""" Response to StartTest indicating that the test should be skipped"""

Expand Down
58 changes: 58 additions & 0 deletions tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from nutkit import protocol
from nutkit.backend import Backend
import warnings


def get_backend_host_and_port():
Expand All @@ -30,6 +31,62 @@ def new_backend():
return Backend(host, port)


def driver_feature(feature):
if not isinstance(feature, protocol.Feature):
raise Exception('The argument should be instance of Feature')

def get_valid_test_case(*args, **kwargs):
if not args or not isinstance(args[0], TestkitTestCase):
raise Exception('Should only decorate TestkitTestCase methods')
return args[0]

def driver_feature_decorator(func):
def wrapper(*args, **kwargs):
test_case = get_valid_test_case(*args, **kwargs)
if (feature.value not in test_case._driver_features):
test_case.skipTest("Needs support for %s" % feature.value)
return func(*args, **kwargs)
return wrapper
return driver_feature_decorator


class MemoizedSupplier:
""" Momoize the function it annotates.
This way the decorated function will always return the
same value of the first interaction independent of the
supplied params.
"""

def __init__(self, func):
self._func = func
self._memo = None

def __call__(self, *args, **kwargs):
if self._memo is None:
self._memo = self._func(*args, **kwargs)
return self._memo


@MemoizedSupplier
def get_driver_features(backend):
# TODO Remove when dotnet implements the GetFeature message
if get_driver_name() in ['dotnet']:
warnings.warn("Driver does not implements GetFeatures.")
features = ()
return features

try:
response = backend.sendAndReceive(protocol.GetFeatures())
if not isinstance(response, protocol.FeatureList):
raise Exception("Response is not instance of FeatureList")
features = tuple(response.features)
except Exception as e:
warnings.warn("Could not fetch FeatureList: %s" % e)
features = ()

return features


def get_driver_name():
return os.environ['TEST_DRIVER_NAME']

Expand All @@ -40,6 +97,7 @@ def setUp(self):
id_ = re.sub(r"^([^\.]+\.)*?tests\.", "", self.id())
self._backend = new_backend()
self.addCleanup(self._backend.close)
self._driver_features = get_driver_features(self._backend)
response = self._backend.sendAndReceive(protocol.StartTest(id_))
if isinstance(response, protocol.SkipTest):
self.skipTest(response.reason)
Expand Down
163 changes: 140 additions & 23 deletions tests/stub/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tests.shared import (
get_driver_name,
TestkitTestCase,
driver_feature
)
from tests.stub.shared import StubServer

Expand All @@ -15,12 +16,34 @@ def get_extra_hello_props():
return ""


class BaseAuthorizationTests(TestkitTestCase):
# While there is no unified language agnostic error type mapping, a
# dedicated driver mapping is required to determine if the expected
# error is returned.
def assert_is_authorization_error(self, error):
driver = get_driver_name()
self.assertEqual("Neo.ClientError.Security.AuthorizationExpired",
error.code)
if driver in ['java']:
self.assertEqual(
'org.neo4j.driver.exceptions.AuthorizationExpiredException',
error.errorType)
elif driver in ['python']:
self.assertEqual(
"<class 'neo4j.exceptions.TransientError'>", error.errorType
)
elif driver in ['javascript']:
# only test for code
pass
else:
self.fail("no error mapping is defined for %s driver" % driver)


# TODO: find a way to test that driver ditches all open connection in the pool
# when encountering Neo.ClientError.Security.AuthorizationExpired
# TODO: re-write tests, where possible, to use only one server, utilizing
# on_send_RetryableNegative and potentially other hooks.

class AuthorizationTests(TestkitTestCase):
class AuthorizationTests(BaseAuthorizationTests):
def setUp(self):
super().setUp()
self._routing_server1 = StubServer(9000)
Expand Down Expand Up @@ -175,27 +198,6 @@ def get_vars(self, host=None):
def get_db(self):
return "adb"

# While there is no unified language agnostic error type mapping, a
# dedicated driver mapping is required to determine if the expected
# error is returned.
def assert_is_authorization_error(self, error):
driver = get_driver_name()
self.assertEqual("Neo.ClientError.Security.AuthorizationExpired",
error.code)
if driver in ['java']:
self.assertEqual(
'org.neo4j.driver.exceptions.AuthorizationExpiredException',
error.errorType)
elif driver in ['python']:
self.assertEqual(
"<class 'neo4j.exceptions.TransientError'>", error.errorType
)
elif driver in ['javascript']:
# only test for code
pass
else:
self.fail("no error mapping is defined for %s driver" % driver)

@staticmethod
def collectRecords(result):
sequence = []
Expand Down Expand Up @@ -712,3 +714,118 @@ def get_vars(self, host=None):

def get_db(self):
return None


class NoRoutingAuthorizationTests(BaseAuthorizationTests):
def setUp(self):
super().setUp()
self._server = StubServer(9010)
self._uri = "bolt://%s:%d" % (self._server.host,
self._server.port)
self._auth = types.AuthorizationToken(
scheme="basic", principal="p", credentials="c")
self._userAgent = "007"

def tearDown(self):
self._server.reset()
super().tearDown()

def read_return_1_failure_return_2_and_3_succeed_script(self):
return """
!: BOLT 4
!: AUTO HELLO
!: AUTO GOODBYE
!: AUTO RESET
!: ALLOW CONCURRENT

{+
{{
C: RUN "RETURN 1 as n" {} {"mode": "r"}
C: PULL {"n": 1000}
S: SUCCESS {"fields": ["n"]}
FAILURE {"code": "Neo.ClientError.Security.AuthorizationExpired", "message": "Authorization expired"}
S: <EXIT>
----
C: RUN "RETURN 2 as n" {} {"mode": "r"}
S: SUCCESS {"fields": ["n"]}
C: PULL {"n": 1}
S: RECORD [1]
SUCCESS {"type": "r", "has_more": true}
{{
C: DISCARD {"n": -1}
S: SUCCESS {}
----
C: PULL {"n": "*"}
S: RECORD [1]
SUCCESS {"type": "r"}
}}
----
C: RUN "RETURN 3 as n" {} {"mode": "r"}
C: PULL {"n": "*"}
S: SUCCESS {"fields": ["n"]}
RECORD [1]
SUCCESS {"type": "r"}
}}

+}
"""

@driver_feature(types.Feature.AUTHORIZATION_EXPIRED_TREATMENT)
def test_should_drop_connection_after_AuthorizationExpired(self):
self._server.start(
script=self.read_return_1_failure_return_2_and_3_succeed_script()
)
driver = Driver(self._backend, self._uri, self._auth,
userAgent=self._userAgent)

session1 = driver.session('r', fetchSize=1)
session2 = driver.session('r')

session1.run('RETURN 2 as n').next()

try:
session2.run('RETURN 1 as n').next()
except types.DriverError as e:
self.assert_is_authorization_error(e)

session2.close()
session1.close()

accept_count = self._server.count_responses("<ACCEPT>")

# fetching another connection and run a query to force
# drivers which lazy close the connection do it
session3 = driver.session('r')
session3.run('RETURN 3 as n').next()
session3.close()

hangup_count = self._server.count_responses("<HANGUP>")

self.assertEqual(accept_count, hangup_count)
self.assertGreaterEqual(accept_count, 2)

driver.close()

@driver_feature(types.Feature.AUTHORIZATION_EXPIRED_TREATMENT)
def test_should_be_able_to_use_current_sessions_after_AuthorizationExpired(self):
self._server.start(
script=self.read_return_1_failure_return_2_and_3_succeed_script()
)

driver = Driver(self._backend, self._uri, self._auth,
userAgent=self._userAgent)

session1 = driver.session('r', fetchSize=1)
session2 = driver.session('r')

session1.run('RETURN 2 as n').next()

try:
session2.run('RETURN 1 as n').next()
except types.DriverError as e:
self.assert_is_authorization_error(e)

session2.close()

session1.run('RETURN 2 as n').next()
session1.close()