This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Incremental PBG training #232

Open · wants to merge 12 commits into base: main
1 change: 1 addition & 0 deletions test/test_checkpoint_manager.py
@@ -58,6 +58,7 @@ def test_basic(self):
entity_path="foo",
edge_paths=["bar"],
checkpoint_path="baz",
init_entity_path="foo"
)
metadata = ConfigMetadataProvider(config).get_checkpoint_metadata()
self.assertIsInstance(metadata, dict)
90 changes: 88 additions & 2 deletions torchbiggraph/checkpoint_manager.py
@@ -35,6 +35,7 @@
Partition,
Rank,
)
from torchbiggraph.graph_storages import AbstractEntityStorage
from torchbiggraph.util import CouldNotLoadData


@@ -416,8 +417,9 @@ def switch_to_new_version(self) -> None:

def remove_old_version(self, config: ConfigSchema) -> None:
old_version = self.checkpoint_version - 1
# We almost never create a v0 checkpoint, so if there is one we leave it there.
# However, incremental training does create a v0 checkpoint, and in that case it must be removed.
if old_version == 0 and config.init_entity_path is None:
return
for entity, econf in config.entities.items():
for part in range(self.rank, econf.num_partitions, self.num_machines):
@@ -448,6 +450,90 @@ def preserve_current_version(self, config: ConfigSchema, epoch_idx: int) -> None
self.storage.copy_model_to_snapshot(version, epoch_idx)
self.storage.copy_version_to_snapshot(version, epoch_idx)

def enlarge(
self,
config: ConfigSchema,
init_entity_storage: AbstractEntityStorage,
entity_storage: AbstractEntityStorage,
entity_counts: Dict[str, List[int]],
) -> None:
Contributor: Could you add some tests for enlarge? It looks like it'll work correctly, but I'd rather be guaranteed that it will.

"""
Enlarge a checkpoint to the new checkpoint path

* Read new entity counts and offsets from the updated partitioned data
* Enlarge previous N embeddings to N + M with M new entities
- Map the previous N embeddings to according to the new offsets
- Initialize the rest M embeddings to with random vectors
@param config: Config dictionary for the PBG run
@param init_entity_storage:
@param entity_storage:
@param entity_counts:
@return: None
"""
logger.debug(f"Enlarging checkpoint from {config.init_path} to {config.checkpoint_path}")
# A checkpoint already exists, so we are not going to enlarge
if self.checkpoint_version > 0:
logger.info(f"Checkpoint with version {self.checkpoint_version} found at {config.checkpoint_path}"
f", not enlarging")
return
init_entity_offsets: Dict[str, List[List[str]]] = {}
init_entity_counts: Dict[str, List[int]] = {}
init_checkpoint_storage: AbstractCheckpointStorage = CHECKPOINT_STORAGES.make_instance(config.init_path)
init_version: int = init_checkpoint_storage.load_version()
metadata = self.collect_metadata()
# Load offsets from initial entities
for entity, econf in config.entities.items():
init_entity_offsets[entity] = []
init_entity_counts[entity] = []
for part in range(econf.num_partitions):
init_entity_offsets[entity].append(init_entity_storage.load_names(entity, part))
init_entity_counts[entity].append(init_entity_storage.load_count(entity, part))

# Enlarge embeddings into the new checkpoint
for entity, econf in config.entities.items():
for part in range(econf.num_partitions):

embs, _ = init_checkpoint_storage.load_entity_partition(init_version, entity, part)

new_count = entity_counts[entity][part]
dimension = config.entity_dimension(entity)

new_embs = torch.randn((new_count, dimension))

logger.debug(f"Loading {entity} embeddings of shape {new_embs.shape}")

# Map each old entity name to its row index in the previous embeddings
init_names: Dict[str, int] = {init_name: j for (j, init_name) in enumerate(init_entity_offsets[entity][part])}
new_names: List[str] = entity_storage.load_names(entity, part)
subset_idxs = {name: None for name in init_names.keys()}
old_subset_idxs = {name: None for name in init_names.keys()}

init_names_set = set(init_names.keys())

for i, new_name in enumerate(new_names):
if new_name in init_names_set:
subset_idxs[new_name] = i
old_subset_idxs[new_name] = init_names[new_name]

subset_idxs = [v for v in subset_idxs.values() if v is not None]
old_subset_idxs = [v for v in old_subset_idxs.values() if v is not None]

# Enlarged embeddings with the offsets obtained from previous training
# Initialize new embeddings with random numbers
old_embs = embs[old_subset_idxs].clone()
Contributor: Cut this clone out unless in debug?

new_embs[subset_idxs, :] = embs[old_subset_idxs].clone()
Contributor: I would do this in one line to save a memory allocation.


# Test case 1: Whether the embeddings are correctly mapped into the new embeddings
assert torch.equal(new_embs[subset_idxs, :], old_embs)
Contributor: This assert could be quite expensive.


embs = new_embs
optim_state = None

# Save the enlarged embeddings as the initial version (v0)
self.storage.save_entity_partition(0, entity, part, embs, optim_state, metadata)

def close(self) -> None:
self.join()

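The remapping that enlarge performs for each (entity, partition) pair can be summarized with a small standalone sketch. This is not part of the PR; the function and variable names are hypothetical, and it folds in the reviewers' suggestion of doing the copy in a single indexing assignment without an intermediate clone.

import torch

def remap_embeddings(old_names, old_embs, new_names):
    """Carry embeddings from a previous run over to the new entity ordering.

    old_names: entity names of the previous run, in old row order
    old_embs:  tensor of shape (len(old_names), dim) from the old checkpoint
    new_names: entity names after re-partitioning, in new row order
    """
    dim = old_embs.shape[1]
    new_embs = torch.randn(len(new_names), dim)  # entities unseen before stay random
    old_index = {name: i for i, name in enumerate(old_names)}
    new_rows, old_rows = [], []
    for j, name in enumerate(new_names):
        if name in old_index:
            new_rows.append(j)
            old_rows.append(old_index[name])
    new_embs[new_rows] = old_embs[old_rows]  # single advanced-indexing copy
    return new_embs

# Example: "c" is new; "a" and "b" keep their old vectors at their new positions.
old = torch.arange(6.0).reshape(3, 2)  # rows for ["a", "b", "d"]
out = remap_embeddings(["a", "b", "d"], old, ["b", "c", "a"])
assert torch.equal(out[0], old[1]) and torch.equal(out[2], old[0])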
7 changes: 7 additions & 0 deletions torchbiggraph/config.py
@@ -256,6 +256,13 @@ class ConfigSchema(Schema):
"the entities of some types."
},
)
init_entity_path: Optional[str] = attr.ib(
default=None,
metadata={
"help": "If set, it must be a path to a directory that "
"contains initial values of the entities and their offsets "
},
)
checkpoint_preservation_interval: Optional[int] = attr.ib(
default=None,
metadata={
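For context, a config for an incremental run would combine the new init_entity_path with the existing init_path (which the enlarge code treats as the previous run's checkpoint directory). The sketch below is purely illustrative; all paths and hyperparameters are hypothetical.

def get_torchbiggraph_config():
    return dict(
        # newly partitioned data for the incremental run
        entity_path="data/new/entities",
        edge_paths=["data/new/edges"],
        checkpoint_path="model/new",
        # previous run: checkpoint to enlarge from, and its entity names/counts
        init_path="model/old",
        init_entity_path="data/old/entities",
        entities={"all": {"num_partitions": 1}},
        relations=[{"name": "link", "lhs": "all", "rhs": "all"}],
        dimension=128,
        num_epochs=5,
    )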
30 changes: 22 additions & 8 deletions torchbiggraph/converters/importers.py
@@ -86,15 +86,29 @@ def read(self, path: Path):
"'pip install parquet'"
)

with path.open("rb") as tf:
columns = [self.lhs_col, self.rhs_col]
if self.rel_col is not None:
columns.append(self.rel_col)
for row in parquet.reader(tf, columns=columns):
if path.is_dir():
files = [p for p in path.glob('*.parquet')]
random.shuffle(files)
for pq in files:
with pq.open("rb") as tf:
columns = [self.lhs_col, self.rhs_col]
Contributor: Add support for weights.

if self.rel_col is not None:
columns.append(self.rel_col)
for row in parquet.reader(tf, columns=columns):
if self.rel_col is not None:
yield row
else:
yield row[0], row[1], None
else:
with path.open("rb") as tf:
columns = [self.lhs_col, self.rhs_col]
if self.rel_col is not None:
columns.append(self.rel_col)
for row in parquet.reader(tf, columns=columns):
if self.rel_col is not None:
yield row
else:
yield row[0], row[1], None


def collect_relation_types(
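The directory-handling branch added to ParquetEdgelistReader.read can be reduced to the following standalone helper, illustrative only (iter_edge_files is a hypothetical name): a path may now be a directory of *.parquet shards, which are read in shuffled order, or a single file as before. Per the review comment above, weight columns are not yet handled by this branch.

import random
from pathlib import Path
from typing import Iterator

def iter_edge_files(path: Path) -> Iterator[Path]:
    if path.is_dir():
        files = list(path.glob("*.parquet"))
        random.shuffle(files)  # avoid reading shards in a fixed order
        yield from files
    else:
        yield path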