Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions openfold3/core/data/framework/single_datasets/base_of3.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,10 @@ def __init__(self, dataset_config) -> None:
)
self.datapoint_cache = {}

# CCD - only used if template structures are not preprocessed
if dataset_config.dataset_paths.template_structure_array_directory is not None:
self.ccd = None
else:
if dataset_config.dataset_paths.template_structures_directory is not None:
self.ccd = pdbx.CIFFile.read(dataset_config.dataset_paths.ccd_file)
else:
self.ccd = None

# Dataset configuration
# n_tokens can be set in the getitem method separately for each sample using
Expand Down
13 changes: 10 additions & 3 deletions openfold3/core/data/framework/single_datasets/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,16 @@ def create_datapoint_cache(self):
pdb_ids = list(self.dataset_cache.structure_data.keys())

def null_safe_token_count(x):
token_count = self.dataset_cache.structure_data[x].token_count
elem = self.dataset_cache.structure_data[x]
token_count = elem.token_count if hasattr(elem, "token_count") else None
return token_count if token_count is not None else 0

pdb_ids = sorted(
pdb_ids,
key=null_safe_token_count,
reverse=False,
)

_datapoint_cache = pd.DataFrame({"pdb_id": pdb_ids})
self.datapoint_cache = pad_to_world_size(_datapoint_cache, self.world_size)

Expand Down Expand Up @@ -186,15 +189,19 @@ def get_validation_homology_features(self, pdb_id: str, sample_data: dict) -> di

structure_entry = self.dataset_cache.structure_data[pdb_id]

def _use_metrics(x):
"""Check if the chain or interface should be used for metrics."""
return x.use_metrics if hasattr(x, "use_metrics") else True

chains_for_intra_metrics = [
int(cid)
for cid, cdata in structure_entry.chains.items()
if cdata.use_metrics
if _use_metrics(cdata)
]

interfaces_to_include = []
for interface_id, cluster_data in structure_entry.interfaces.items():
if cluster_data.use_metrics:
if _use_metrics(cluster_data):
interface_chains = tuple(int(ci) for ci in interface_id.split("_"))
interfaces_to_include.append(interface_chains)

Expand Down
5 changes: 4 additions & 1 deletion openfold3/core/data/primitives/structure/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def sample_templates(
dict[str, TemplateCacheEntry] | dict[None]:
The sampled template data per chain given chain.
"""
if not template_structure_array_directory and not template_cache_directory:
return {}

chain_data = assembly_data[chain_id]
template_ids = chain_data["template_ids"]
if not template_ids:
Expand Down Expand Up @@ -200,7 +203,7 @@ def sample_templates(
else:
k = np.min([np.random.randint(0, l), n_templates])

if k > 0:
if (k > 0) and (template_cache_directory is not None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the parentheses around the individual comparisons be removed? I believe the redundant parentheses will cause an issue with our ruff formatter.

# Load template cache entry numpy file
# From the representative ID during training
if "alignment_representative_id" in chain_data:
Expand Down
3 changes: 2 additions & 1 deletion openfold3/entry_points/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ def manual_load_checkpoint(self):
self.lightning_module.load_state_dict(
state_dict, strict=self.ckpt_load_settings.strict_loading
)
self.lightning_module.ema.load_state_dict(ckpt["ema"])
if "ema" in ckpt:
self.lightning_module.ema.load_state_dict(ckpt["ema"])

if self.ckpt_load_settings.restore_lr_scheduler:
last_global_step = int(ckpt["global_step"])
Expand Down
4 changes: 4 additions & 0 deletions openfold3/entry_points/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,12 @@ def _maybe_download_parameters(target_path: Path) -> None:
class CheckpointConfig(BaseModel):
"""Settings for training checkpoint writing."""

model_config = PydanticConfigDict(extra="allow")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is extra="allow" required to support backwards compatibility with old versions of the code?

We generally prefer to set extra="forbid" to help catch instances where a field might be ignored if it is not recognized by the Model (e.g. in the case of a typo or a new field).

monitor: str | None = None
every_n_epochs: int = 1
auto_insert_metric_name: bool = False
filename: str | None = None
enable_version_counter: bool = True
save_last: bool = True
save_top_k: int = -1

Expand Down
5 changes: 3 additions & 2 deletions openfold3/projects/of3_all_atom/config/dataset_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,13 @@ def _validate_exactly_one_path_exists(
group_name: str, path_values: list[Path | None]
):
which_paths_exist = [p is not None for p in path_values]
if sum(which_paths_exist) != 1:
if sum(which_paths_exist) > 1:
existing_paths = [
p for p, b in zip(path_values, which_paths_exist, strict=True) if b
]
raise ValueError(
f"Exactly one path in set of {group_name} should exist."
"If there is a template folder, "
f"exactly one path in set of {group_name} should exist."
f"Found {existing_paths} exist."
)

Expand Down