This repository was archived by the owner on Mar 14, 2024. It is now read-only.
Incremental PBG training #232
Open
howardchanth wants to merge 12 commits into facebookresearch:main from howardchanth:master
Changes from all commits (12 commits)
68a1935 Initial ideas of adding incremental training feature
affbd9e Initial version of incremental training
010ddfa Initial workable version of incremental training
7223618 Faster loading of pretrained embeddings in enlargements
0ea84d9 Workable version of recurrent training
3f0a9d6 Fix bugs in recurrent training
0ecff68 Develop workable v2, still investigating bugs
c337299 Delete update_plan.txt
b5969cd Bug fixed; recurrent training implemented
b44f6c3 Bug fixed; recurrent training implemented
b9e2cf6 bug fixes on recurrent training
38feaa7 fixed overfitting to sub-buckets in GPU Training
@@ -35,6 +35,7 @@
     Partition,
     Rank,
 )
+from torchbiggraph.graph_storages import AbstractEntityStorage
 from torchbiggraph.util import CouldNotLoadData
 
@@ -416,8 +417,9 @@ def switch_to_new_version(self) -> None:
 
     def remove_old_version(self, config: ConfigSchema) -> None:
         old_version = self.checkpoint_version - 1
-        # We never create a v0 checkpoint, so if there is one we leave it there.
-        if old_version == 0:
+        # We almost never create a v0 checkpoint, so if there is one we leave it there.
+        # A v0 checkpoint is created during incremental training, though, so in that case it does need to be removed.
+        if old_version == 0 and config.init_entity_path is None:
             return
         for entity, econf in config.entities.items():
             for part in range(self.rank, econf.num_partitions, self.num_machines):
@@ -448,6 +450,90 @@ def preserve_current_version(self, config: ConfigSchema, epoch_idx: int) -> None
        self.storage.copy_model_to_snapshot(version, epoch_idx)
        self.storage.copy_version_to_snapshot(version, epoch_idx)

    def enlarge(
        self,
        config: ConfigSchema,
        init_entity_storage: AbstractEntityStorage,
        entity_storage: AbstractEntityStorage,
        entity_counts: Dict[str, List[int]],
    ) -> None:
        """
        Enlarge a checkpoint to the new checkpoint path.

        * Read the new entity counts and offsets from the updated partitioned data.
        * Enlarge the previous N embeddings to N + M, where M is the number of new entities:
          - map the previous N embeddings according to the new offsets,
          - initialize the remaining M embeddings with random vectors.
        @param config: Config dictionary for the PBG run
        @param init_entity_storage: Entity storage holding the names and counts of the pretrained entities
        @param entity_storage: Entity storage holding the names of the enlarged entity set
        @param entity_counts: Per-partition entity counts of the enlarged entity set
        @return: None
        """
        logger.debug(f"Enlarging checkpoint from {config.init_path} to {config.checkpoint_path}")
        # Checkpoint exists, not going to enlarge
        if self.checkpoint_version > 0:
            logger.info(f"Checkpoint with version {self.checkpoint_version} found at {config.checkpoint_path}"
                        f", not enlarging")
            return
        init_entity_offsets: Dict[str, List[List[str]]] = {}
        init_entity_counts: Dict[str, List[int]] = {}
        init_checkpoint_storage: AbstractCheckpointStorage = CHECKPOINT_STORAGES.make_instance(config.init_path)
        init_version: int = init_checkpoint_storage.load_version()
        metadata = self.collect_metadata()
        # Load offsets from the initial entities
        for entity, econf in config.entities.items():
            init_entity_offsets[entity] = []
            init_entity_counts[entity] = []
            for part in range(econf.num_partitions):
                init_entity_offsets[entity].append(init_entity_storage.load_names(entity, part))
                init_entity_counts[entity].append(init_entity_storage.load_count(entity, part))

        # Enlarge embeddings to the new checkpoint
        for entity, econf in config.entities.items():
            for part in range(econf.num_partitions):

                embs, _ = init_checkpoint_storage.load_entity_partition(init_version, entity, part)

                new_count = entity_counts[entity][part]
                dimension = config.entity_dimension(entity)

                new_embs = torch.randn((new_count, dimension))

                logger.debug(f"Loading {entity} embeddings of shape {new_embs.shape}")

                # Initialize an (N + M) x (emb_dim) enlarged embeddings storage
                init_names: Dict = {init_name: j for (j, init_name) in enumerate(init_entity_offsets[entity][part])}
                new_names: List = entity_storage.load_names(entity, part)
                subset_idxs = {name: None for name in init_names.keys()}
                old_subset_idxs = {name: None for name in init_names.keys()}

                init_names_set = set(init_names.keys())

                for i, new_name in enumerate(new_names):
                    if new_name in init_names_set:
                        subset_idxs[new_name] = i
                        old_subset_idxs[new_name] = init_names[new_name]

                subset_idxs = [v for v in subset_idxs.values() if v is not None]
                old_subset_idxs = [v for v in old_subset_idxs.values() if v is not None]

                # Enlarge embeddings using the offsets obtained from previous training;
                # initialize the new embeddings with random numbers
                old_embs = embs[old_subset_idxs].clone()
Review comment: Cut this clone out unless in debug?
                new_embs[subset_idxs, :] = embs[old_subset_idxs].clone()
Review comment: I would do this in one line to save a memory allocation.
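A minimal sketch of how both review comments above might be addressed, assuming the `embs`, `new_embs`, `subset_idxs`, and `old_subset_idxs` variables defined in this diff: the intermediate `old_embs` tensor is dropped and the rows are copied with a single indexed assignment (the assertion below would then compare against `embs[old_subset_idxs]` directly).

```python
# Sketch: indexed assignment already copies the selected rows into new_embs,
# so neither the old_embs temporary nor the explicit .clone() is needed.
new_embs[subset_idxs, :] = embs[old_subset_idxs]
```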
                # Test case 1: Whether the embeddings are correctly mapped into the new embeddings
                assert torch.equal(new_embs[subset_idxs, :], old_embs)
Review comment: This assert could be quite expensive.
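One way the cost could be kept out of the common path, sketched under the assumption that `logging` is imported in this module (it already defines `logger`): run the full check only when debug logging is enabled.

```python
# Sketch: skip the O(N * dim) equality check unless debug logging is on.
if logger.isEnabledFor(logging.DEBUG):
    assert torch.equal(new_embs[subset_idxs, :], embs[old_subset_idxs])
```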
                embs = new_embs
                optim_state = None

                # Save the previous embeddings as the first version (v0)
                self.storage.save_entity_partition(0, entity, part, embs, optim_state, metadata)
    def close(self) -> None:
        self.join()
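For orientation, a hedged sketch of how a run on this branch might be configured. `init_path` and `init_entity_path` are inferred only from their usage in `remove_old_version` and `enlarge` above; the remaining keys are standard PBG config fields, and every concrete path and size here is made up for illustration.

```python
def get_torchbiggraph_config():
    return dict(
        # Standard PBG fields (illustrative values).
        entity_path="data/entities_new",        # entity names/counts after adding the new nodes
        edge_paths=["data/edges_new"],
        checkpoint_path="model/run_incremental",
        entities={"all": {"num_partitions": 1}},
        relations=[{"name": "all_edges", "lhs": "all", "rhs": "all", "operator": "none"}],
        dimension=200,
        # Fields used by the incremental-training code in this PR (assumed semantics).
        init_path="model/run_previous",         # checkpoint to enlarge from
        init_entity_path="data/entities_old",   # entity names/counts of the previous run
    )
```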
@@ -86,15 +86,29 @@ def read(self, path: Path):
                 "'pip install parquet'"
             )
 
-        with path.open("rb") as tf:
-            columns = [self.lhs_col, self.rhs_col]
-            if self.rel_col is not None:
-                columns.append(self.rel_col)
-            for row in parquet.reader(tf, columns=columns):
-                if self.rel_col is not None:
-                    yield row
-                else:
-                    yield row[0], row[1], None
+        if path.is_dir():
+            files = [p for p in path.glob('*.parquet')]
+            random.shuffle(files)
+            for pq in files:
+                with pq.open("rb") as tf:
+                    columns = [self.lhs_col, self.rhs_col]
Review comment: Add support for weights.
+                    if self.rel_col is not None:
+                        columns.append(self.rel_col)
+                    for row in parquet.reader(tf, columns=columns):
+                        if self.rel_col is not None:
+                            yield row
+                        else:
+                            yield row[0], row[1], None
+        else:
+            with path.open("rb") as tf:
+                columns = [self.lhs_col, self.rhs_col]
+                if self.rel_col is not None:
+                    columns.append(self.rel_col)
+                for row in parquet.reader(tf, columns=columns):
+                    if self.rel_col is not None:
+                        yield row
+                    else:
+                        yield row[0], row[1], None
 
 
     def collect_relation_types(
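The directory and single-file branches above duplicate the per-file reading loop. A hedged sketch of one way it could be factored out, assuming the same `self.lhs_col`, `self.rhs_col`, and `self.rel_col` attributes and the `parquet`, `random`, and `Path` imports used by this file; this is a refactoring idea, not code from the PR.

```python
def _read_rows(self, pq_path: Path):
    # Shared per-file reading loop used by both branches of read().
    # The "add support for weights" suggestion above would append an optional
    # weight column here and yield it as a fourth element.
    columns = [self.lhs_col, self.rhs_col]
    if self.rel_col is not None:
        columns.append(self.rel_col)
    with pq_path.open("rb") as tf:
        for row in parquet.reader(tf, columns=columns):
            if self.rel_col is not None:
                yield row
            else:
                yield row[0], row[1], None

def read(self, path: Path):
    if path.is_dir():
        files = list(path.glob("*.parquet"))
        random.shuffle(files)
        for pq_path in files:
            yield from self._read_rows(pq_path)
    else:
        yield from self._read_rows(path)
```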
Review comment: Could you add some tests for enlarge? It looks like it'll work correctly, but I'd rather be guaranteed that it will.
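A hedged sketch of what such a test could look like if the name-to-offset remapping inside enlarge were pulled into a small helper; `remap_embeddings` here is hypothetical, not part of this PR, and only mirrors the indexing logic shown in the diff.

```python
import torch


def remap_embeddings(old_embs, old_names, new_names, dimension):
    # Hypothetical helper mirroring the remapping step of enlarge(): rows of
    # old_embs move to the positions their names occupy in new_names, and the
    # remaining rows are initialized randomly.
    new_embs = torch.randn((len(new_names), dimension))
    old_idx = {name: i for i, name in enumerate(old_names)}
    for new_i, name in enumerate(new_names):
        if name in old_idx:
            new_embs[new_i] = old_embs[old_idx[name]]
    return new_embs


def test_remap_embeddings_preserves_old_rows():
    dimension = 4
    old_names = ["a", "b", "c"]
    new_names = ["x", "a", "c", "y", "b"]
    old_embs = torch.randn((len(old_names), dimension))
    new_embs = remap_embeddings(old_embs, old_names, new_names, dimension)
    assert new_embs.shape == (len(new_names), dimension)
    # Every pretrained entity keeps its embedding at its new offset.
    for old_i, name in enumerate(old_names):
        assert torch.equal(new_embs[new_names.index(name)], old_embs[old_i])
```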