Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

nicer gen_config_impl #944

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pytext/config/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Union
from typing import Dict, List, Union


def is_component_class(obj):
Expand Down Expand Up @@ -38,6 +38,8 @@ def resolve_optional(type_v):


def cast_str(to_type, value):
if type(value) != str:
return value
if to_type == int:
return int(value)
elif to_type == float:
Expand All @@ -51,9 +53,9 @@ def cast_str(to_type, value):
return False
else:
raise Exception(f'Not a boolean value: "{value}"')
elif getattr(to_type, "__origin__", None) == list:
elif getattr(to_type, "__origin__", None) in (list, List):
return [cast_str(to_type.__args__[0], v.strip()) for v in value.split(",")]
elif getattr(to_type, "__origin__", None) == dict:
elif getattr(to_type, "__origin__", None) in (dict, Dict):
key_type, value_type = to_type.__args__
ret = {}
for entry in value.split(","):
Expand Down
72 changes: 37 additions & 35 deletions pytext/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def run_single(
)


def gen_config_impl(task_name, options):
def gen_config_impl(task_name, *args, **kwargs):
# import the classes required by parameters
requested_classes = [locate(opt) for opt in options] + [locate(task_name)]
requested_classes = [locate(opt) for opt in args] + [locate(task_name)]
register_tasks(requested_classes)

task_class_set = find_config_class(task_name)
Expand All @@ -132,42 +132,44 @@ def gen_config_impl(task_name, options):
root = PyTextConfig(task=task_config(), version=LATEST_VERSION)
eprint("INFO - Applying task option:", task_class.__name__)

# Use components listed in options instead of defaults
for opt in options:
# Use components in args instead of defaults
for opt in args:
if "=" in opt:
param_path, value = opt.split("=", 1)
found = find_param(root, "." + param_path)
if len(found) == 1:
eprint("INFO - Applying parameter option to", found[0], ":", opt)
replace_param(root, found[0].split("."), value)
elif not found:
raise Exception(f"Unknown parameter option: {opt}")
kwargs[param_path] = value
continue
replace_class_set = find_config_class(opt)
if not replace_class_set:
raise Exception(f"Not a component class: {opt}")
elif len(replace_class_set) > 1:
raise Exception(f"Multiple component named {opt}: {replace_class_set}")
replace_class = next(iter(replace_class_set))
found = replace_components(root, opt, get_subclasses(replace_class))
if found:
eprint("INFO - Applying class option:", ".".join(reversed(found)), "=", opt)
obj = root
for k in reversed(found[1:]):
obj = getattr(obj, k)
if hasattr(replace_class, "Config"):
setattr(obj, found[0], replace_class.Config())
else:
raise Exception(f"Multiple possibilities for {opt}: {', '.join(found)}")
setattr(obj, found[0], replace_class())
else:
replace_class_set = find_config_class(opt)
if not replace_class_set:
raise Exception(f"Not a component class: {opt}")
elif len(replace_class_set) > 1:
raise Exception(f"Multiple component named {opt}: {replace_class_set}")
replace_class = next(iter(replace_class_set))
found = replace_components(root, opt, get_subclasses(replace_class))
if found:
eprint(
"INFO - Applying class option:",
"->".join(reversed(found)),
"=",
opt,
)
obj = root
for k in reversed(found[1:]):
obj = getattr(obj, k)
if hasattr(replace_class, "Config"):
setattr(obj, found[0], replace_class.Config())
else:
setattr(obj, found[0], replace_class())
else:
raise Exception(f"Unknown class option: {opt}")
raise Exception(f"Unknown class option: {opt}")

# Use parameters in kwargs instead of defaults
for param_path, value in kwargs.items():
found = find_param(root, "." + param_path)
if len(found) == 1:
eprint("INFO - Applying parameter option to", found[0], "=", value)
replace_param(root, found[0].split("."), value)
elif not found:
raise Exception(f"Unknown parameter option: {param_path}")
else:
raise Exception(
f"Multiple possibilities for {param_path}: {', '.join(found)}"
)

return root


Expand Down Expand Up @@ -244,7 +246,7 @@ def gen_default_config(context, task_name, options):
components as `options`.
"""
try:
cfg = gen_config_impl(task_name, options)
cfg = gen_config_impl(task_name, *options)
except TypeError as ex:
eprint(
"ERROR - Cannot create this config",
Expand Down