Commit

cache lm packed data (facebookresearch#447)
Summary:
Pull Request resolved: facebookresearch#447

cache lm packed data
On every GPU node, we cache the packed, numberized LM data in JSON format (pickle is slower in theory).
During the first epoch, we generate the numberized rows and write them to the cache.
For each following epoch, we read from the cache directly.
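
The first-epoch-writes, later-epochs-read scheme described above can be sketched as a small helper. This is an illustrative sketch, not PyText's actual code: `cache_path` and `build_fn` are hypothetical names standing in for the numberized-data directory and the numberization step.

```python
import json
import os


def load_or_build_numberized_rows(cache_path, build_fn):
    """Return numberized rows, caching them as JSON on first use.

    On the first epoch the rows are produced by ``build_fn`` and written
    to ``cache_path``; subsequent epochs read the cached JSON directly.
    """
    if os.path.isfile(cache_path):
        # Cache hit: skip numberization entirely.
        with open(cache_path) as f:
            return json.load(f)
    # Cache miss (first epoch): build the rows and persist them.
    rows = build_fn()
    with open(cache_path, "w") as f:
        json.dump(rows, f)
    return rows
```

JSON restricts the cache to JSON-serializable row structures (lists of ints, strings, etc.), which is the trade-off the summary alludes to when comparing it with pickle.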

Reviewed By: borguz

Differential Revision: D14651814

fbshipit-source-id: 05b0880e6d0cf979e3cc5bcdf0ea60f3d2702320
chenyangyu1988 authored and facebook-github-bot committed Apr 4, 2019
1 parent 4e7a3b1 commit d6054c1
Showing 2 changed files with 11 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pytext/main.py
@@ -36,6 +36,7 @@
     export_saved_model_to_torchscript,
     get_logits as workflow_get_logits,
     prepare_task_metadata,
+    preprocess_task,
     test_model_from_snapshot_path,
     train_model,
 )
@@ -313,6 +314,7 @@ def train(context):
     if config.use_tensorboard:
         metric_channels.append(TensorBoardChannel())
     try:
+        preprocess_task(config)
         if config.distributed_world_size == 1:
             train_model(config, metric_channels=metric_channels)
         else:
9 changes: 9 additions & 0 deletions pytext/workflow.py
@@ -121,6 +121,15 @@ def prepare_task(
     return task
 
 
+def preprocess_task(config: PyTextConfig):
+    if hasattr(config.task, "data") and hasattr(config.task.data, "numberized_dir"):
+        if config.load_snapshot_path and os.path.isfile(config.load_snapshot_path):
+            task = load(config.load_snapshot_path)
+        else:
+            task = create_task(config.task)
+        task.data.initialize_numberized_data()
+
+
 def save_and_export(
     config: PyTextConfig, task: Task, metric_channels: Optional[List[Channel]] = None
 ) -> None:
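
The `hasattr` chain in `preprocess_task` makes cache initialization opt-in: only tasks whose config exposes a `data` component with a `numberized_dir` setting trigger the one-time build, and all other tasks pass through untouched. A minimal sketch of that guard pattern, using stand-in classes rather than real PyText types:

```python
class Data:
    """Stand-in for a PyText data component that supports a numberized cache."""

    def __init__(self, numberized_dir):
        self.numberized_dir = numberized_dir
        self.initialized = False

    def initialize_numberized_data(self):
        # In PyText this would numberize the corpus and write the cache;
        # here we only record that the build ran.
        self.initialized = True


class Task:
    """Stand-in task; only some tasks carry a `data` attribute."""

    def __init__(self, data=None):
        if data is not None:
            self.data = data


def preprocess(task):
    # Same guard as the diff: skip tasks without cacheable LM data.
    if hasattr(task, "data") and hasattr(task.data, "numberized_dir"):
        task.data.initialize_numberized_data()
```

Because the guard checks both attributes before touching them, calling `preprocess` on a task with no `data` component is a harmless no-op rather than an `AttributeError`.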
