Skip to content

Commit

Permalink
Validate transformer unit tests (#302)
Browse files Browse the repository at this point in the history
* Implement validate transformer unit tests

* Fix lint

* Add success message

* Fix export folder.

* Remove white lines
  • Loading branch information
pvk-developer authored Oct 21, 2021
1 parent c395ee7 commit 08d9f28
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
15 changes: 11 additions & 4 deletions tests/code_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@ def validate_transformer_module(transformer):

def _validate_config_json(config_path, transformer_name):
with open(config_path, 'r', encoding='utf-8') as config_file:
config_dict = json.load(config_file)
try:
config_dict = json.load(config_file)
except json.JSONDecodeError:
config_dict = None

config_error_msg = f'{config_path} does not have valid json format.'
assert config_dict is not None, config_error_msg

transformers_not_found = 'The key ``transformers`` was not found in the config.json file.'
assert 'transformers' in config_dict, transformers_not_found
Expand Down Expand Up @@ -92,7 +98,8 @@ def validate_transformer_importable_from_parent_module(transformer):
assert imported_transformer is not None, f'Could not import {name} from {module}'


def _get_test_location(transformer):
def get_test_location(transformer):
"""Return the expected unit test location of a transformer."""
transformer_file = Path(inspect.getfile(transformer))
transformer_folder = transformer_file.parent
rdt_unit_test_path = Path(__file__).parent / 'unit'
Expand All @@ -119,7 +126,7 @@ def _get_test_location(transformer):

def validate_test_location(transformer):
"""Validate if the test file exists in the expected location."""
test_location = _get_test_location(transformer)
test_location = get_test_location(transformer)
if test_location is None:
return False, 'The expected test location was not found.'

Expand Down Expand Up @@ -152,7 +159,7 @@ def _load_module_from_path(path):

def validate_test_names(transformer):
"""Validate if the test methods are properly specified."""
test_file = _get_test_location(transformer)
test_file = get_test_location(transformer)
module = _load_module_from_path(test_file)

test_class = getattr(module, f'Test{transformer.__name__}', None)
Expand Down
68 changes: 63 additions & 5 deletions tests/contributing.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
"""Validation methods for contributing to RDT."""


import importlib
import inspect
import subprocess
import traceback
from pathlib import Path

import coverage
import numpy as np
import pandas as pd
import pytest
from tabulate import tabulate

from rdt.transformers import get_transformers_by_type
from tests.code_style import (
load_transformer, validate_test_location, validate_test_names, validate_transformer_addon,
validate_transformer_importable_from_parent_module, validate_transformer_module,
validate_transformer_name, validate_transformer_subclass)
get_test_location, load_transformer, validate_test_location, validate_test_names,
validate_transformer_addon, validate_transformer_importable_from_parent_module,
validate_transformer_module, validate_transformer_name, validate_transformer_subclass)
from tests.datasets import get_dataset_generators_by_type
from tests.integration.test_transformers import validate_transformer
from tests.performance import evaluate_transformer_performance, validate_performance
Expand Down Expand Up @@ -275,7 +277,8 @@ def validate_transformer_code_style(transformer):
Args:
transformer (string or rdt.transformers.BaseTransformer):
The transformer to validate.
Output:
Returns:
bool:
Whether or not the transformer passes all code style checks.
"""
Expand Down Expand Up @@ -308,6 +311,61 @@ def validate_transformer_code_style(transformer):
return not bool(errors)


def validate_transformer_unit_tests(transformer):
"""Validate the unit tests of a transformer.
This function finds the module where the unit tests of the transformer
have been implemented and runs them using ``pytest``, capturing the code
coverage of the tests (how many lines of the source code are executed
during the tests).
Args:
transformer (string or rdt.transformers.BaseTransformer):
The transformer to validate.
Returns:
float:
A ``float`` value representing the test coverage where 1.0 is 100%.
"""
if not inspect.isclass(transformer):
transformer = load_transformer(transformer)

source_location = inspect.getfile(transformer)
test_location = get_test_location(transformer)
module_name = getattr(transformer, '__module__', None)

print(f'Validating source file {source_location}\n')

pytest_run = f'-v --disable-warnings --no-header {test_location}'
pytest_run = pytest_run.split(' ')

cov = coverage.Coverage(source=[module_name])

cov.start()
pytest_output = pytest.main(pytest_run)
cov.stop()

if pytest_output is pytest.ExitCode.OK:
print('\nSUCCESS: The unit tests passed.')
else:
print('\nERROR: The unit tests failed.')

score = cov.report(show_missing=True)
rounded_score = round(score / 100, 3)
if rounded_score < 1.0:
print(f'\nERROR: The unit tests only cover {round(score, 3)}% of your code.')
else:
print(f'\nSUCCESS: The unit tests cover {round(score, 3)}% of your code.')

cov.html_report()
print('\nFull coverage report here:\n')
coverage_name = module_name.replace('.', '_')
export_dir = Path('htmlcov') / f'{coverage_name}_py.html'
print(export_dir.absolute().as_uri())

return rounded_score


def validate_transformer_quality(transformer):
"""Validate quality tests for a transformer.
Expand Down

0 comments on commit 08d9f28

Please sign in to comment.