Skip to content

Commit

Permalink
[ML] Pylint fixes for Sweep and Command jobs (Azure#26324)
Browse files Browse the repository at this point in the history
  • Loading branch information
TonyJ1 authored and mccoyp committed Sep 22, 2022
1 parent 0758439 commit fca5f94
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ def make(self, data, **kwargs):
def predump(self, data, **kwargs):
from azure.ai.ml.sweep import QLogNormal, QNormal

if not (isinstance(data, QNormal) or isinstance(data, QLogNormal)):
if not isinstance(data, (QNormal, QLogNormal)):
raise ValidationError("Cannot dump non-QNormal or non-QLogNormal object into QNormalSchema")
return data
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class UniformSchema(metaclass=PatchedSchemaMeta):
def predump(self, data, **kwargs):
from azure.ai.ml.sweep import LogUniform, Uniform

if not (isinstance(data, Uniform) or isinstance(data, LogUniform)):
if not isinstance(data, (Uniform, LogUniform)):
raise ValidationError("Cannot dump non-Uniform or non-LogUniform object into UniformSchema")
if data.type.lower() not in SearchSpace.UNIFORM_LOGUNIFORM:
raise ValidationError(BASE_ERROR_MESSAGE + str(SearchSpace.UNIFORM_LOGUNIFORM))
Expand Down Expand Up @@ -51,6 +51,6 @@ def make(self, data, **kwargs):
def predump(self, data, **kwargs):
from azure.ai.ml.sweep import QLogUniform, QUniform

if not (isinstance(data, QUniform) or isinstance(data, QLogUniform)):
if not isinstance(data, (QUniform, QLogUniform)):
raise ValidationError("Cannot dump non-QUniform or non-QLogUniform object into UniformSchema")
return data
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class ParameterizedCommandSchema(PathAwareSchema):
command = fields.Str(
metadata={
# pylint: disable=line-too-long
"description": "The command run and the parameters passed. This string may contain place holders of inputs in {}. "
},
required=True,
Expand Down
65 changes: 32 additions & 33 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())

# resolve normal dict to dict[str, JobService]
services = self._resolve_job_services(services)
services = _resolve_job_services(services)
kwargs.pop("type", None)
self._parameters = kwargs.pop("parameters", {})
BaseNode.__init__(
Expand Down Expand Up @@ -221,7 +221,7 @@ def services(self) -> Dict:

@services.setter
def services(self, value: Dict):
self._services = self._resolve_job_services(value)
self._services = _resolve_job_services(value)

@property
def component(self) -> Union[str, CommandComponent]:
Expand Down Expand Up @@ -550,37 +550,6 @@ def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema

return CommandSchema(context=context)

def _resolve_job_services(self, services: dict) -> dict:
"""Resolve normal dict to dict[str, JobService]"""
# pylint disable=no-self-use
if services is None:
return None
if not isinstance(services, dict):
msg = f"Services must be a dict, got {type(services)} instead."
raise ValidationException(
message=msg,
no_personal_data_message=msg,
target=ErrorTarget.COMMAND_JOB,
error_category=ErrorCategory.USER_ERROR,
)

result = {}
for name, service in services.items():
if isinstance(service, dict):
service = load_from_dict(JobServiceSchema, service, context={BASE_PATH_CONTEXT_KEY: "."})
elif not isinstance(service, JobService):
msg = (
f"Service value for key {name!r} must be a dict or JobService object, got {type(service)} instead."
)
raise ValidationException(
message=msg,
no_personal_data_message=msg,
target=ErrorTarget.COMMAND_JOB,
error_category=ErrorCategory.USER_ERROR,
)
result[name] = service
return result

def __call__(self, *args, **kwargs) -> "Command":
"""Call Command as a function will return a new instance each time."""
if isinstance(self._component, Component):
Expand Down Expand Up @@ -618,3 +587,33 @@ def __call__(self, *args, **kwargs) -> "Command":
target=ErrorTarget.COMMAND_JOB,
error_type=ValidationErrorType.INVALID_VALUE,
)


def _resolve_job_services(services: dict) -> dict:
"""Resolve normal dict to dict[str, JobService]"""
if services is None:
return None

if not isinstance(services, dict):
msg = f"Services must be a dict, got {type(services)} instead."
raise ValidationException(
message=msg,
no_personal_data_message=msg,
target=ErrorTarget.COMMAND_JOB,
error_category=ErrorCategory.USER_ERROR,
)

result = {}
for name, service in services.items():
if isinstance(service, dict):
service = load_from_dict(JobServiceSchema, service, context={BASE_PATH_CONTEXT_KEY: "."})
elif not isinstance(service, JobService):
msg = f"Service value for key {name!r} must be a dict or JobService object, got {type(service)} instead."
raise ValidationException(
message=msg,
no_personal_data_message=msg,
target=ErrorTarget.COMMAND_JOB,
error_category=ErrorCategory.USER_ERROR,
)
result[name] = service
return result
2 changes: 2 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CommandJobLimits(JobLimits):
"""

def __init__(self, *, timeout: int = None):
super().__init__()
self.type = JobType.COMMAND
self.timeout = timeout

Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
timeout: int = None,
trial_timeout: int = None,
):
super().__init__()
self.type = JobType.SWEEP
self.max_concurrent_trials = max_concurrent_trials
self.max_total_trials = max_total_trials
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
code: str = None,
environment_variables: Dict = None,
distribution: Union[dict, MpiDistribution, TensorFlowDistribution, PyTorchDistribution] = None,
environment: Union["Environment", str] = None,
environment: Union[Environment, str] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def _from_rest_object(cls, obj: RestEarlyTerminationPolicy) -> "EarlyTermination

policy = None
if obj.policy_type == EarlyTerminationPolicyType.BANDIT:
policy = BanditPolicy._from_rest_object(obj)
policy = BanditPolicy._from_rest_object(obj) # pylint: disable=protected-access

if obj.policy_type == EarlyTerminationPolicyType.MEDIAN_STOPPING:
policy = MedianStoppingPolicy._from_rest_object(obj)
policy = MedianStoppingPolicy._from_rest_object(obj) # pylint: disable=protected-access

if obj.policy_type == EarlyTerminationPolicyType.TRUNCATION_SELECTION:
policy = TruncationSelectionPolicy._from_rest_object(obj)
policy = TruncationSelectionPolicy._from_rest_object(obj) # pylint: disable=protected-access

return policy

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Callable, Dict, Optional, Union
from typing import Dict, Optional, Type, Union

from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException

Expand All @@ -27,16 +27,18 @@
)
from .search_space import SweepDistribution

SAMPLING_ALGORITHM_TO_REST_CONSTRUCTOR: Dict[SamplingAlgorithmType, Callable[[], RestSamplingAlgorithm]] = {
SamplingAlgorithmType.RANDOM: lambda: RestRandomSamplingAlgorithm(),
SamplingAlgorithmType.GRID: lambda: RestGridSamplingAlgorithm(),
SamplingAlgorithmType.BAYESIAN: lambda: RestBayesianSamplingAlgorithm(),
# pylint: disable=unnecessary-lambda
SAMPLING_ALGORITHM_TO_REST_CONSTRUCTOR: Dict[SamplingAlgorithmType, Type[RestSamplingAlgorithm]] = {
SamplingAlgorithmType.RANDOM: RestRandomSamplingAlgorithm,
SamplingAlgorithmType.GRID: RestGridSamplingAlgorithm,
SamplingAlgorithmType.BAYESIAN: RestBayesianSamplingAlgorithm,
}

SAMPLING_ALGORITHM_CONSTRUCTOR: Dict[SamplingAlgorithmType, Callable[[], SamplingAlgorithm]] = {
SamplingAlgorithmType.RANDOM: lambda: RandomSamplingAlgorithm(),
SamplingAlgorithmType.GRID: lambda: GridSamplingAlgorithm(),
SamplingAlgorithmType.BAYESIAN: lambda: BayesianSamplingAlgorithm(),
# pylint: disable=unnecessary-lambda
SAMPLING_ALGORITHM_CONSTRUCTOR: Dict[SamplingAlgorithmType, Type[SamplingAlgorithm]] = {
SamplingAlgorithmType.RANDOM: RandomSamplingAlgorithm,
SamplingAlgorithmType.GRID: GridSamplingAlgorithm,
SamplingAlgorithmType.BAYESIAN: BayesianSamplingAlgorithm,
}


Expand Down Expand Up @@ -158,11 +160,13 @@ def sampling_algorithm(self, value: Optional[Union[SamplingAlgorithm, str]] = No
def _get_rest_sampling_algorithm(self) -> RestSamplingAlgorithm:
# TODO: self.sampling_algorithm will always return SamplingAlgorithm
if isinstance(self.sampling_algorithm, SamplingAlgorithm):
return self.sampling_algorithm._to_rest_object()
elif isinstance(self.sampling_algorithm, str):
return SAMPLING_ALGORITHM_CONSTRUCTOR[
return self.sampling_algorithm._to_rest_object() # pylint: disable=protected-access

if isinstance(self.sampling_algorithm, str):
return SAMPLING_ALGORITHM_CONSTRUCTOR[ # pylint: disable=protected-access
SamplingAlgorithmType(self.sampling_algorithm.lower().capitalize())
]()._to_rest_object()

msg = f"Received unsupported value {self._sampling_algorithm} as the sampling algorithm"
raise ValidationException(
message=msg,
Expand Down Expand Up @@ -208,6 +212,3 @@ def early_termination(self, value: Union[EarlyTerminationPolicy, str]):
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.INVALID_VALUE,
)

def _override_missing_properties_from_trial(self):
return
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ def _from_rest_object(cls, obj: RestSamplingAlgorithm) -> "SamplingAlgorithm":

sampling_algorithm = None
if obj.sampling_algorithm_type == SamplingAlgorithmType.RANDOM:
sampling_algorithm = RandomSamplingAlgorithm._from_rest_object(obj)
sampling_algorithm = RandomSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access

if obj.sampling_algorithm_type == SamplingAlgorithmType.GRID:
sampling_algorithm = GridSamplingAlgorithm._from_rest_object(obj)
sampling_algorithm = GridSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access

if obj.sampling_algorithm_type == SamplingAlgorithmType.BAYESIAN:
sampling_algorithm = BayesianSamplingAlgorithm._from_rest_object(obj)
sampling_algorithm = BayesianSamplingAlgorithm._from_rest_object(obj) # pylint: disable=protected-access

return sampling_algorithm

Expand All @@ -42,6 +42,7 @@ def __init__(
rule=None,
seed=None,
) -> None:
super().__init__()
self.type = SamplingAlgorithmType.RANDOM.lower()
self.rule = rule
self.seed = seed
Expand All @@ -62,6 +63,7 @@ def _from_rest_object(cls, obj: RestRandomSamplingAlgorithm) -> "RandomSamplingA

class GridSamplingAlgorithm(SamplingAlgorithm):
def __init__(self) -> None:
super().__init__()
self.type = SamplingAlgorithmType.GRID.lower()

# pylint: disable=no-self-use
Expand All @@ -76,6 +78,7 @@ def _from_rest_object(cls, obj: RestGridSamplingAlgorithm) -> "GridSamplingAlgor

class BayesianSamplingAlgorithm(SamplingAlgorithm):
def __init__(self):
super().__init__()
self.type = SamplingAlgorithmType.BAYESIAN.lower()

# pylint: disable=no-self-use
Expand Down
50 changes: 25 additions & 25 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/sweep/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@


class SweepDistribution(ABC, RestTranslatableMixin):
def __init__(self, *, type: str = None):
def __init__(self, *, type: str = None): # pylint: disable=redefined-builtin
self.type = type

@classmethod
def _from_rest_object(cls, rest: List) -> "SweepDistribution":
def _from_rest_object(cls, obj: List) -> "SweepDistribution":

mapping = {
SearchSpace.CHOICE: Choice,
Expand All @@ -33,17 +33,17 @@ def _from_rest_object(cls, rest: List) -> "SweepDistribution":
SearchSpace.QLOGUNIFORM: QLogUniform,
}

ss_class = mapping.get(rest[0], None)
ss_class = mapping.get(obj[0], None)
if ss_class:
return ss_class._from_rest_object(rest)
else:
msg = f"Unknown search space type: {rest[0]}"
raise JobException(
message=msg,
no_personal_data_message=msg,
target=ErrorTarget.SWEEP_JOB,
error_category=ErrorCategory.SYSTEM_ERROR,
)
return ss_class._from_rest_object(obj)

msg = f"Unknown search space type: {obj[0]}"
raise JobException(
message=msg,
no_personal_data_message=msg,
target=ErrorTarget.SWEEP_JOB,
error_category=ErrorCategory.SYSTEM_ERROR,
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, SweepDistribution):
Expand Down Expand Up @@ -73,8 +73,8 @@ def _to_rest_object(self) -> List:
return [self.type, [items]]

@classmethod
def _from_rest_object(cls, rest: List) -> "Choice":
rest_values = rest[1][0]
def _from_rest_object(cls, obj: List) -> "Choice":
rest_values = obj[1][0]
from_rest_values = []
for rest_value in rest_values:
if isinstance(rest_value, dict):
Expand All @@ -84,7 +84,7 @@ def _from_rest_object(cls, rest: List) -> "Choice":
# first assume that any dictionary value is a valid distribution (i.e. normal, uniform, etc)
# and try to deserialize it into a the correct SDK distribution object
from_rest_dict[k] = SweepDistribution._from_rest_object(v)
except Exception:
except Exception: # pylint: disable=broad-except
# if an exception is raised, assume that the value was not a valid distribution and use the
# value as it is for deserialization
from_rest_dict[k] = v
Expand All @@ -105,8 +105,8 @@ def _to_rest_object(self) -> List:
return [self.type, [self.mu, self.sigma]]

@classmethod
def _from_rest_object(cls, rest: List) -> "Normal":
return cls(mu=rest[1][0], sigma=rest[1][1])
def _from_rest_object(cls, obj: List) -> "Normal":
return cls(mu=obj[1][0], sigma=obj[1][1])


class LogNormal(Normal):
Expand All @@ -125,8 +125,8 @@ def _to_rest_object(self) -> List:
return [self.type, [self.mu, self.sigma, self.q]]

@classmethod
def _from_rest_object(cls, rest: List) -> "QNormal":
return cls(mu=rest[1][0], sigma=rest[1][1], q=rest[1][2])
def _from_rest_object(cls, obj: List) -> "QNormal":
return cls(mu=obj[1][0], sigma=obj[1][1], q=obj[1][2])


class QLogNormal(QNormal):
Expand All @@ -145,8 +145,8 @@ def _to_rest_object(self) -> List:
return [self.type, [self.upper]]

@classmethod
def _from_rest_object(cls, rest: List) -> "Randint":
return cls(upper=rest[1][0])
def _from_rest_object(cls, obj: List) -> "Randint":
return cls(upper=obj[1][0])


class Uniform(SweepDistribution):
Expand All @@ -160,8 +160,8 @@ def _to_rest_object(self) -> List:
return [self.type, [self.min_value, self.max_value]]

@classmethod
def _from_rest_object(cls, rest: List) -> "Uniform":
return cls(min_value=rest[1][0], max_value=rest[1][1])
def _from_rest_object(cls, obj: List) -> "Uniform":
return cls(min_value=obj[1][0], max_value=obj[1][1])


class LogUniform(Uniform):
Expand All @@ -186,8 +186,8 @@ def _to_rest_object(self) -> List:
return [self.type, [self.min_value, self.max_value, self.q]]

@classmethod
def _from_rest_object(cls, rest: List) -> "QUniform":
return cls(min_value=rest[1][0], max_value=rest[1][1], q=rest[1][2])
def _from_rest_object(cls, obj: List) -> "QUniform":
return cls(min_value=obj[1][0], max_value=obj[1][1], q=obj[1][2])


class QLogUniform(QUniform):
Expand Down
Loading

0 comments on commit fca5f94

Please sign in to comment.