Skip to content

Commit

Permalink
Changed names of constant variables
Browse files Browse the repository at this point in the history
  • Loading branch information
gcasadesus committed Aug 2, 2021
1 parent 8e40e41 commit eb852e6
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions dislib/utils/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
cbor2 = None

# Dislib models with saving tested (model: str -> module: str)
_implemented_models = {
IMPLEMENTED_MODELS = {
"KMeans": "cluster",
"GaussianMixture": "cluster",
"CascadeSVM": "classification",
Expand All @@ -41,7 +41,7 @@
}

# Classes used by models
_dislib_classes = {
DISLIB_CLASSES = {
"KMeans": dislib.cluster.KMeans,
"DecisionTreeClassifier": DecisionTreeClassifier,
"_Node": _Node,
Expand All @@ -50,7 +50,7 @@
"_SkTreeWrapper": _SkTreeWrapper,
}

_sklearn_classes = {
SKLEARN_CLASSES = {
"SVC": SklearnSVC,
"DecisionTreeClassifier": SklearnDTClassifier,
}
Expand Down Expand Up @@ -99,10 +99,10 @@ def save_model(model, filepath, overwrite=True, save_format="json"):

# Check for dislib model
model_name = model.__class__.__name__
if model_name not in _implemented_models.keys():
if model_name not in IMPLEMENTED_MODELS.keys():
raise NotImplementedError(
"Saving has only been implemented for the following models:\n%s"
% _implemented_models.keys()
% IMPLEMENTED_MODELS.keys()
)

# Synchronize model
Expand Down Expand Up @@ -170,15 +170,15 @@ def load_model(filepath, load_format="json"):

# Check for dislib model
model_name = model_metadata["model_name"]
if model_name not in _implemented_models.keys():
if model_name not in IMPLEMENTED_MODELS.keys():
raise NotImplementedError(
"Saving has only been implemented for the following models:\n%s"
% _implemented_models.keys()
% IMPLEMENTED_MODELS.keys()
)
del model_metadata["model_name"]

# Create model
model_module = getattr(ds, _implemented_models[model_name])
model_module = getattr(ds, IMPLEMENTED_MODELS[model_name])
model_class = getattr(model_module, model_name)
model = model_class()
model.__dict__.update(model_metadata)
Expand Down Expand Up @@ -249,7 +249,7 @@ def _encode_helper(obj):
"items": obj.__getstate__(),
}
elif isinstance(
obj, tuple(_dislib_classes.values()) + tuple(_sklearn_classes.values())
obj, tuple(DISLIB_CLASSES.values()) + tuple(SKLEARN_CLASSES.values())
):
return {
"class_name": obj.__class__.__name__,
Expand Down Expand Up @@ -302,12 +302,12 @@ def _decode_helper(obj):
model.__setstate__(dict_)
return model
elif (
class_name in _dislib_classes.keys()
class_name in DISLIB_CLASSES.keys()
and "dislib" in obj["module_name"]
):
dict_ = _decode_helper(obj["items"])
if class_name == "DecisionTreeClassifier":
model = _dislib_classes[obj["class_name"]](
model = DISLIB_CLASSES[obj["class_name"]](
try_features=dict_.pop("try_features"),
max_depth=dict_.pop("max_depth"),
distr_depth=dict_.pop("distr_depth"),
Expand All @@ -317,17 +317,17 @@ def _decode_helper(obj):
)
elif class_name == "_SkTreeWrapper":
sk_tree = _decode_helper(dict_.pop("sk_tree"))
model = _dislib_classes[obj["class_name"]](sk_tree)
model = DISLIB_CLASSES[obj["class_name"]](sk_tree)
else:
model = _dislib_classes[obj["class_name"]]()
model = DISLIB_CLASSES[obj["class_name"]]()
model.__dict__.update(dict_)
return model
elif (
class_name in _sklearn_classes.keys()
class_name in SKLEARN_CLASSES.keys()
and "sklearn" in obj["module_name"]
):
dict_ = _decode_helper(obj["items"])
model = _sklearn_classes[obj["class_name"]]()
model = SKLEARN_CLASSES[obj["class_name"]]()
model.__dict__.update(dict_)
return model
elif class_name == "callable":
Expand Down

0 comments on commit eb852e6

Please sign in to comment.