Skip to content

Commit

Permalink
add an optional truncation operator when exporting (facebookresearch#425
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: facebookresearch#425

This diff adds an optional truncation operator for bi-transformer models if configured by user.

Reviewed By: bethebunny

Differential Revision: D14648704

fbshipit-source-id: f20576d7754a3718485db826a1ed77adb6bc432a
  • Loading branch information
Haoran Li authored and facebook-github-bot committed Apr 1, 2019
1 parent b21c407 commit 37c6468
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions pytext/config/config_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,41 @@


def v0_to_v1(json_config):
# migrate optimizer params
# migrate optimizer and random_seed params
[task] = json_config["task"].values()
if (
"optimizer" not in task
or "Adam" in task["optimizer"]
or "SGD" in task["optimizer"]
or "NAG" in task["optimizer"]
):
) and ("trainer" not in task or "random_seed" not in task["trainer"]):
return json_config
op_type = task["optimizer"].get("type", "adam")
if op_type == "adam":
op_config = {"Adam": {}}
for key in ["lr", "weight_decay"]:
if key in task["optimizer"]:
op_config["Adam"][key] = task["optimizer"][key]
elif op_type == "sgd":
op_config = {"SGD": {}}
for key in ["lr", "momentum"]:
if key in task["optimizer"]:
op_config["SGD"][key] = task["optimizer"][key]
elif op_type == "nag":
op_config = {"NAG": {}}
for key in ["lr", "weight_decay", "momentum"]:
if key in task["optimizer"]:
op_config["NAG"][key] = task["optimizer"][key]
else:
raise ValueError("Migration not supported for your optimizer")
task["optimizer"] = op_config

if "trainer" in task and "random_seed" in task["trainer"]:
json_config["random_seed"] = task["trainer"]["random_seed"]
del task["trainer"]["random_seed"]
if "optimizer" in task and not any(
opt in task["optimizer"] for opt in ["Adam", "SGD", "NAG"]
):
op_type = task["optimizer"].get("type", "adam")
if op_type == "adam":
op_config = {"Adam": {}}
for key in ["lr", "weight_decay"]:
if key in task["optimizer"]:
op_config["Adam"][key] = task["optimizer"][key]
elif op_type == "sgd":
op_config = {"SGD": {}}
for key in ["lr", "momentum"]:
if key in task["optimizer"]:
op_config["SGD"][key] = task["optimizer"][key]
elif op_type == "nag":
op_config = {"NAG": {}}
for key in ["lr", "weight_decay", "momentum"]:
if key in task["optimizer"]:
op_config["NAG"][key] = task["optimizer"][key]
else:
raise ValueError("Migration not supported for your optimizer")
task["optimizer"] = op_config
json_config["version"] = 1
return json_config

Expand Down

0 comments on commit 37c6468

Please sign in to comment.