Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions openhtf/output/callbacks/console_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import sys
from typing import TextIO

from openhtf.core import measurements
from openhtf.core import test_record
Expand All @@ -11,7 +12,9 @@ class ConsoleSummary():
"""Print test results with failure info on console."""

# pylint: disable=invalid-name
def __init__(self, indent=2, output_stream=sys.stdout):
def __init__(self,
indent: int = 2,
output_stream: TextIO = sys.stdout) -> None:
self.indent = ' ' * indent
if os.name == 'posix': # Linux and Mac.
self.RED = '\033[91m'
Expand All @@ -37,7 +40,7 @@ def __init__(self, indent=2, output_stream=sys.stdout):

# pylint: enable=invalid-name

def __call__(self, record):
def __call__(self, record: test_record.TestRecord) -> None:
output_lines = [
''.join((self.color_table[record.outcome], self.BOLD,
record.code_info.name, ':', record.outcome.name, self.RESET))
Expand Down
11 changes: 7 additions & 4 deletions openhtf/output/servers/station_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import threading
import time
import types
from typing import Optional, Union

import openhtf
from openhtf.output.servers import pub_sub
Expand Down Expand Up @@ -558,7 +559,9 @@ class StationServer(web_gui_server.WebGuiServer):
test.execute()
"""

def __init__(self, history_path=None):
def __init__(
self,
history_path: Optional[Union[str, bytes, os.PathLike]] = None) -> None:
# Disable tornado's logging.
# TODO(kenadia): Enable these logs if verbosity flag is at least -vvv.
# I think this will require changing how StoreRepsInModule works.
Expand Down Expand Up @@ -614,7 +617,7 @@ def _get_config(self):
'server_type': STATION_SERVER_TYPE,
}

def run(self):
def run(self) -> None:
_LOG.info('Announcing station server via multicast on %s:%s',
self.station_multicast.address, self.station_multicast.port)
self.station_multicast.start()
Expand All @@ -624,13 +627,13 @@ def run(self):
host=socket.gethostname(), port=self.port))
super(StationServer, self).run()

def stop(self):
def stop(self) -> None:
_LOG.info('Stopping station server.')
super(StationServer, self).stop()
_LOG.info('Stopping multicast.')
self.station_multicast.stop(timeout_s=0)

def publish_final_state(self, test_record):
def publish_final_state(self, test_record: openhtf.TestRecord) -> None:
"""Test output callback publishing a final state from the test record."""
StationPubSub.publish_test_record(test_record)

Expand Down
96 changes: 51 additions & 45 deletions openhtf/util/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from openhtf.util import measurements

class MyLessThanValidator(ValidatorBase):
def __init__(self, limit):
def __init__(self, limit) -> None:
self.limit = limit

# This will be invoked to test if the measurement is 'PASS' or 'FAIL'.
def __call__(self, value):
def __call__(self, value) -> bool:
return value < self.limit

# Name defaults to the validator's __name__ attribute unless overridden.
Expand All @@ -31,12 +31,12 @@ def MyPhase(test):
For simpler validators, you don't need to register them at all, you can
simply attach them to the Measurement with the .with_validator() method:

def LessThan4(value):
def LessThan4(value) -> bool:
return value < 4

@measurements.measures(
measurements.Measurement('my_measurement).with_validator(LessThan4))
def MyPhase(test):
def MyPhase(test: htf.TestApi) -> None:
test.measurements.my_measurement = 5 # Will also 'FAIL'

Notes:
Expand All @@ -58,37 +58,43 @@ def MyPhase(test):
import math
import numbers
import re
from typing import Callable, Dict, Optional, Type, TypeVar, Union

from openhtf import util

_VALIDATORS = {}

class ValidatorBase(abc.ABC):

@abc.abstractmethod
def __call__(self, value) -> bool:
"""Should validate value, returning a boolean result."""


_ValidatorT = TypeVar("_ValidatorT", bound=ValidatorBase)
_ValidatorFactoryT = Union[Type[_ValidatorT], Callable[..., _ValidatorT]]
_VALIDATORS: Dict[str, _ValidatorFactoryT] = {}

def register(validator, name=None):

def register(validator: _ValidatorFactoryT,
name: Optional[str] = None) -> _ValidatorFactoryT:
name = name or validator.__name__
if name in _VALIDATORS:
raise ValueError('Duplicate validator name', name)
_VALIDATORS[name] = validator
return validator


def has_validator(name):
def has_validator(name: str) -> bool:
return name in _VALIDATORS


def create_validator(name, *args, **kwargs):
def create_validator(name: str, *args, **kwargs) -> _ValidatorT:
return _VALIDATORS[name](*args, **kwargs)


_identity = lambda x: x


class ValidatorBase(abc.ABC):

@abc.abstractmethod
def __call__(self, value):
"""Should validate value, returning a boolean result."""


class RangeValidatorBase(ValidatorBase, abc.ABC):

@abc.abstractproperty
Expand Down Expand Up @@ -120,7 +126,7 @@ def __init__(self,
minimum,
maximum,
marginal_minimum=None,
marginal_maximum=None):
marginal_maximum=None) -> None:
super(AllInRangeValidator, self).__init__()
if minimum is None and maximum is None:
raise ValueError('Must specify minimum, maximum, or both')
Expand Down Expand Up @@ -168,7 +174,7 @@ def marginal_minimum(self):
def marginal_maximum(self):
return self._marginal_maximum

def __call__(self, values):
def __call__(self, values) -> bool:
within_maximum = self._maximum is None or all(
value <= self.maximum for value in values)
within_minimum = self._minimum is None or all(
Expand Down Expand Up @@ -204,18 +210,18 @@ def __str__(self):
class AllEqualsValidator(ValidatorBase):
"""Validator to verify a list of values are equal to the expected value."""

def __init__(self, spec):
def __init__(self, spec) -> None:
super(AllEqualsValidator, self).__init__()
self._spec = spec

@property
def spec(self):
return self._spec

def __call__(self, values):
def __call__(self, values) -> bool:
return all([value == self.spec for value in values])

def __str__(self):
def __str__(self) -> str:
return "'x' is equal to '%s'" % self._spec


Expand All @@ -242,7 +248,7 @@ def __init__(self,
maximum=None,
marginal_minimum=None,
marginal_maximum=None,
type=None): # pylint: disable=redefined-builtin
type=None) -> None: # pylint: disable=redefined-builtin
super(InRange, self).__init__()

if minimum is None and maximum is None:
Expand Down Expand Up @@ -292,7 +298,7 @@ def marginal_minimum(self):
return converter(self._marginal_minimum)

@property
def marginal_maximum(self):
def marginal_maximum(self) -> str:
converter = self._type if self._type is not None else _identity
return converter(self._marginal_maximum)

Expand All @@ -305,7 +311,7 @@ def with_args(self, **kwargs):
type=self._type,
)

def __call__(self, value):
def __call__(self, value) -> bool:
if value is None:
return False
if math.isnan(value):
Expand All @@ -329,7 +335,7 @@ def is_marginal(self, value) -> bool:
return True
return False

def __str__(self):
def __str__(self) -> str:
assert self._minimum is not None or self._maximum is not None
if (self._minimum is not None and self._maximum is not None and
self._minimum == self._maximum):
Expand All @@ -347,13 +353,13 @@ def __str__(self):
string_repr += ' <= {}'.format(self._maximum)
return string_repr

def __eq__(self, other):
def __eq__(self, other) -> bool:
return (isinstance(other, type(self)) and self.minimum == other.minimum and
self.maximum == other.maximum and
self.marginal_minimum == other.marginal_minimum and
self.marginal_maximum == other.marginal_maximum)

def __ne__(self, other):
def __ne__(self, other) -> bool:
return not self == other


Expand All @@ -373,10 +379,10 @@ def equals(value, type=None): # pylint: disable=redefined-builtin
return Equals(value, type=type)


class Equals(object):
class Equals(ValidatorBase):
"""Validator to verify an object is equal to the expected value."""

def __init__(self, expected, type=None): # pylint: disable=redefined-builtin
def __init__(self, expected, type=None) -> None: # pylint: disable=redefined-builtin
self._expected = expected
self._type = type

Expand All @@ -388,21 +394,21 @@ def expected(self):
def __call__(self, value):
return value == self.expected

def __str__(self):
def __str__(self) -> str:
return f"'x' is equal to '{self._expected}'"

def __eq__(self, other):
def __eq__(self, other) -> bool:
return isinstance(other, type(self)) and self.expected == other.expected


class RegexMatcher(object):
class RegexMatcher(ValidatorBase):
"""Validator to verify a string value matches a regex."""

def __init__(self, regex, compiled_regex):
def __init__(self, regex, compiled_regex) -> None:
self._compiled = compiled_regex
self.regex = regex

def __call__(self, value):
def __call__(self, value) -> bool:
return self._compiled.match(str(value)) is not None

def __deepcopy__(self, dummy_memo):
Expand All @@ -414,7 +420,7 @@ def __str__(self):
def __eq__(self, other):
return isinstance(other, type(self)) and self.regex == other.regex

def __ne__(self, other):
def __ne__(self, other) -> bool:
return not self == other


Expand All @@ -426,7 +432,7 @@ def matches_regex(regex):
class WithinPercent(RangeValidatorBase):
"""Validates that a number is within percent of a value."""

def __init__(self, expected, percent, marginal_percent=None):
def __init__(self, expected, percent, marginal_percent=None) -> None:
super(WithinPercent, self).__init__()
if percent < 0:
raise ValueError('percent argument is {}, must be >0'.format(percent))
Expand Down Expand Up @@ -465,7 +471,7 @@ def marginal_maximum(self):
return (self.expected -
self._applied_marginal_percent if self.marginal_percent else None)

def __call__(self, value):
def __call__(self, value) -> bool:
return self.minimum <= value <= self.maximum

def is_marginal(self, value) -> bool:
Expand All @@ -475,17 +481,17 @@ def is_marginal(self, value) -> bool:
return (self.minimum < value <= self.marginal_minimum or
self.marginal_maximum <= value < self.maximum)

def __str__(self):
def __str__(self) -> str:
return "'x' is within {}% of {}. Marginal: {}% of {}".format(
self.percent, self.expected, self.marginal_percent, self.expected)

def __eq__(self, other):
def __eq__(self, other) -> bool:
return (isinstance(other, type(self)) and
self.expected == other.expected and
self.percent == other.percent and
self.marginal_percent == other.marginal_percent)

def __ne__(self, other):
def __ne__(self, other) -> bool:
return not self == other


Expand All @@ -497,14 +503,14 @@ def within_percent(expected, percent):
class DimensionPivot(ValidatorBase):
"""Runs a validator on each actual value of a dimensioned measurement."""

def __init__(self, sub_validator):
def __init__(self, sub_validator) -> None:
super(DimensionPivot, self).__init__()
self._sub_validator = sub_validator

def __call__(self, dimensioned_value):
def __call__(self, dimensioned_value) -> bool:
return all(self._sub_validator(row[-1]) for row in dimensioned_value)

def __str__(self):
def __str__(self) -> str:
return 'All values pass: {}'.format(str(self._sub_validator))


Expand All @@ -516,11 +522,11 @@ def dimension_pivot_validate(sub_validator):
class ConsistentEndDimensionPivot(ValidatorBase):
"""If any rows validate, all following rows must also validate."""

def __init__(self, sub_validator):
def __init__(self, sub_validator) -> None:
super(ConsistentEndDimensionPivot, self).__init__()
self._sub_validator = sub_validator

def __call__(self, dimensioned_value):
def __call__(self, dimensioned_value) -> bool:
for index, row in enumerate(dimensioned_value):
if self._sub_validator(row[-1]):
i = index
Expand All @@ -529,7 +535,7 @@ def __call__(self, dimensioned_value):
return False
return all(self._sub_validator(rest[-1]) for rest in dimensioned_value[i:])

def __str__(self):
def __str__(self) -> str:
return 'Once pass, rest must also pass: {}'.format(str(self._sub_validator))


Expand Down