Fix checkpoint_model (#85)
* rename single-letter variables

* ruff auto format

* fix checkpoint_model call in train_model

mypy v1.11.0 caught the following issues (see the minimal sketch after this list):
aviary/train.py:320: error: Unexpected keyword argument "model" for "checkpoint_model"  [call-arg]
aviary/train.py:320: error: Unexpected keyword argument "epoch" for "checkpoint_model"; did you mean "epochs"?  [call-arg]
aviary/train.py:328: error: Argument "timestamp" to "checkpoint_model" has incompatible type "str | None"; expected "str"  [arg-type]

* fix get_formula_from_protostructure_label on Python < 3.10

* refactor
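
A minimal sketch of the mismatches mypy flagged (hypothetical, simplified signatures, not the real aviary API):

    def checkpoint_model(inference_model: object, epochs: int, timestamp: str) -> None:
        ...

    ts: str | None = None
    checkpoint_model(model=object(), epoch=1, timestamp=ts)
    # error: Unexpected keyword argument "model" for "checkpoint_model"  [call-arg]
    # error: Unexpected keyword argument "epoch"; did you mean "epochs"?  [call-arg]
    # error: Argument "timestamp" has incompatible type "str | None"; expected "str"  [arg-type]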
janosh authored Jul 27, 2024
1 parent 3289c4d commit d1bbfe1
Showing 17 changed files with 57 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ ci:
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.5.3
+    rev: v0.5.5
     hooks:
       - id: ruff
         args: [--fix]
4 changes: 1 addition & 3 deletions aviary/cgcnn/data.py
@@ -148,9 +148,7 @@ def __getitem__(self, idx: int):
         site_atoms = [atom.species.as_dict() for atom in struct]
         atom_features = np.vstack(
             [
-                np.sum(
-                    [self.elem_features[el] * amt for el, amt in site.items()], axis=0
-                )
+                np.sum([self.elem_features[el] * amt for el, amt in site.items()], axis=0)
                 for site in site_atoms
             ]
         )
4 changes: 1 addition & 3 deletions aviary/core.py
@@ -261,9 +261,7 @@ def evaluate(
         # *_ discards identifiers like material_id and formula which we don't need when
         # training. tqdm(disable=None) means suppress output in non-tty (e.g. CI/log
         # files) but keep in terminal (i.e. tty mode) https://git.io/JnBOi
-        for inputs, targets_list, *_ in tqdm(
-            data_loader, disable=None if pbar else True
-        ):
+        for inputs, targets_list, *_ in tqdm(data_loader, disable=None if pbar else True):
             inputs = [  # noqa: PLW2901
                 tensor.to(self.device) if hasattr(tensor, "to") else tensor
                 for tensor in inputs
18 changes: 9 additions & 9 deletions aviary/networks.py
@@ -34,15 +34,15 @@ def __init__(
         dims = [input_dim, *list(hidden_layer_dims)]
 
         self.fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
+            nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1)
         )
 
         if batch_norm:
             self.bns = nn.ModuleList(
-                nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)
+                nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1)
             )
         else:
-            self.bns = nn.ModuleList(nn.Identity() for i in range(len(dims) - 1))
+            self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1))
 
         self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1))
 
@@ -95,21 +95,21 @@ def __init__(
         dims = [input_dim, *list(hidden_layer_dims)]
 
         self.fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
+            nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1)
         )
 
         if batch_norm:
             self.bns = nn.ModuleList(
-                nn.BatchNorm1d(dims[i + 1]) for i in range(len(dims) - 1)
+                nn.BatchNorm1d(dims[idx + 1]) for idx in range(len(dims) - 1)
             )
         else:
-            self.bns = nn.ModuleList(nn.Identity() for i in range(len(dims) - 1))
+            self.bns = nn.ModuleList(nn.Identity() for _ in range(len(dims) - 1))
 
         self.res_fcs = nn.ModuleList(
-            nn.Linear(dims[i], dims[i + 1], bias=False)
-            if (dims[i] != dims[i + 1])
+            nn.Linear(dims[idx], dims[idx + 1], bias=False)
+            if (dims[idx] != dims[idx + 1])
             else nn.Identity()
-            for i in range(len(dims) - 1)
+            for idx in range(len(dims) - 1)
         )
         self.acts = nn.ModuleList(activation() for _ in range(len(dims) - 1))
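A note on the "for i" to "for _" renames above: nn.ModuleList accepts any iterable of modules, and "_" is the conventional name for a loop variable that is never used. A standalone sketch (hypothetical dims, standard PyTorch API):

    import torch.nn as nn

    dims = [64, 128, 256]
    # index is used to pair consecutive dims, so it keeps a real name
    fcs = nn.ModuleList(nn.Linear(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 1))
    # index is unused here, hence the underscore
    acts = nn.ModuleList(nn.ReLU() for _ in range(len(dims) - 1))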
4 changes: 2 additions & 2 deletions aviary/predict.py
@@ -69,7 +69,7 @@ def make_ensemble_predictions(
     # (i.e. tty mode) https://git.io/JnBOi
     print(f"Pytorch running on {device=}")
     for idx, checkpoint_path in tqdm(
-        enumerate(tqdm(checkpoint_paths), 1), disable=None if pbar else True
+        enumerate(tqdm(checkpoint_paths), start=1), disable=None if pbar else True
     ):
         try:
             checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -189,7 +189,7 @@ def predict_from_wandb_checkpoints(
 
     checkpoint_paths: list[str] = []
 
-    for idx, run in enumerate(runs, 1):
+    for idx, run in enumerate(runs, start=1):
         run_path = "/".join(run.path)
         out_dir = f"{cache_dir}/{run_path}"
         os.makedirs(out_dir, exist_ok=True)
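A note on the enumerate changes repeated throughout this commit: the second positional argument of enumerate is the start index, so writing start=1 changes nothing at runtime. For illustration:

    list(enumerate(["a", "b"], start=1))  # [(1, 'a'), (2, 'b')], identical to enumerate(..., 1)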
4 changes: 2 additions & 2 deletions aviary/roost/data.py
@@ -176,7 +176,7 @@ def collate_batch(
     batch_cry_ids = []
 
     cry_base_idx = 0
-    for i, (inputs, target, *cry_ids) in enumerate(samples):
+    for idx, (inputs, target, *cry_ids) in enumerate(samples):
         elem_weights, elem_fea, self_idx, nbr_idx = inputs
 
         n_sites = elem_fea.shape[0]  # number of atoms for this crystal
@@ -190,7 +190,7 @@ def collate_batch(
         batch_nbr_idx.append(nbr_idx + cry_base_idx)
 
         # mapping from atoms to crystals
-        crystal_elem_idx.append(torch.tensor([i] * n_sites))
+        crystal_elem_idx.append(torch.tensor([idx] * n_sites))
 
         # batch the targets and ids
         batch_targets.append(target)
12 changes: 7 additions & 5 deletions aviary/train.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import os
+from datetime import datetime
 from typing import TYPE_CHECKING, Any, Literal
 
 import numpy as np
@@ -20,7 +21,7 @@
 try:
     import wandb
 except ImportError:
-    wandb = None
+    wandb = None  # type: ignore[assignment]
 
 if TYPE_CHECKING:
     from torch import nn
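The added type: ignore[assignment] is needed because the except branch rebinds a name mypy types as a module. A minimal sketch of this optional-dependency pattern (illustrative; the downstream None-guard is assumed, not shown in this diff):

    try:
        import wandb
    except ImportError:
        wandb = None  # type: ignore[assignment]  # ModuleType vs None

    if wandb is not None:  # check before touching wandb APIs
        ...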
@@ -319,13 +320,14 @@ def train_model(
     if checkpoint is not None:
         checkpoint_model(
             checkpoint_endpoint=checkpoint,
-            model=inference_model,
             model_params=model_params,
+            inference_model=inference_model,
             optimizer_instance=optimizer_instance,
             lr_scheduler=lr_scheduler,
             loss_dict=loss_dict,
-            epoch=epochs,
+            epochs=epochs,
             test_metrics=test_metrics,
-            timestamp=timestamp,
+            timestamp=timestamp or datetime.now().astimezone().strftime("%Y%m%d-%H%M%S"),
             run_name=run_name,
             normalizer_dict=normalizer_dict,
             run_params=run_params,
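The timestamp change resolves the arg-type error by always passing a str: "x or y" yields the fallback whenever x is None (or empty). A standalone sketch of the idiom (not aviary-specific):

    from datetime import datetime

    maybe_ts: str | None = None
    # `or` substitutes the fallback, narrowing the result from str | None to str
    ts: str = maybe_ts or datetime.now().astimezone().strftime("%Y%m%d-%H%M%S")
    print(ts)  # e.g. 20240727-153045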
@@ -364,7 +366,7 @@ def train_model(
 
 def checkpoint_model(
     checkpoint_endpoint: str,
-    model_params: dict,
+    model_params: dict | None,
     inference_model: nn.Module,
     optimizer_instance: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
16 changes: 4 additions & 12 deletions aviary/utils.py
@@ -174,13 +174,9 @@ def initialize_optim(
             momentum=momentum,
         )
     elif optim == "Adam":
-        optimizer = Adam(
-            model.parameters(), lr=learning_rate, weight_decay=weight_decay
-        )
+        optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     elif optim == "AdamW":
-        optimizer = AdamW(
-            model.parameters(), lr=learning_rate, weight_decay=weight_decay
-        )
+        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
     else:
         raise NameError("Only SGD, Adam or AdamW are allowed as --optim")
@@ -361,9 +357,7 @@ def train_ensemble(
         sample_target = Tensor(train_set.df[target].values)
         if not restart_params["resume"]:
             normalizer.fit(sample_target)
-            print(
-                f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}"
-            )
+            print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")
 
         if log:
             writer = SummaryWriter(
@@ -528,9 +522,7 @@ def results_multitask(
         elif task_type == "classification":
             if model.robust:
                 mean, log_std = output.chunk(2, dim=1)
-                logits = (
-                    sampled_softmax(mean, log_std, samples=10).data.cpu().numpy()
-                )
+                logits = sampled_softmax(mean, log_std, samples=10).data.cpu().numpy()
                 pre_logits = mean.data.cpu().numpy()
                 pre_logits_std = torch.exp(log_std).data.cpu().numpy()
                 res_dict["pre-logits_ale"].append(pre_logits_std)  # type: ignore[union-attr]
14 changes: 6 additions & 8 deletions aviary/wren/data.py
@@ -139,16 +139,16 @@ def __getitem__(self, idx: int):
         n_wyks = len(elements)
         self_idx = []
         nbr_idx = []
-        for i in range(n_wyks):
-            self_idx += [i] * n_wyks
+        for wyk_idx in range(n_wyks):
+            self_idx += [wyk_idx] * n_wyks
             nbr_idx += list(range(n_wyks))
 
         self_aug_fea_idx = []
         nbr_aug_fea_idx = []
         n_aug = len(augmented_wyks)
-        for i in range(n_aug):
-            self_aug_fea_idx += [x + i * n_wyks for x in self_idx]
-            nbr_aug_fea_idx += [x + i * n_wyks for x in nbr_idx]
+        for aug_idx in range(n_aug):
+            self_aug_fea_idx += [x + aug_idx * n_wyks for x in self_idx]
+            nbr_aug_fea_idx += [x + aug_idx * n_wyks for x in nbr_idx]
 
         # convert all data to tensors
         wyckoff_weights = Tensor(wyk_site_multiplcities)
@@ -291,9 +291,7 @@ def parse_protostructure_label(
     )
 
     # Separate out pairs of Wyckoff letters and their number of occurrences
-    sep_n_wyks = [
-        "".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)
-    ]
+    sep_n_wyks = ["".join(g) for _, g in groupby(wyk_letters_normalized, str.isalpha)]
 
     # Process Wyckoff letters and multiplicities
     mults = map(int, sep_n_wyks[0::2])
6 changes: 2 additions & 4 deletions aviary/wren/model.py
@@ -205,7 +205,7 @@ def __init__(
                 msg_gate_layers=elem_gate,
                 msg_net_layers=elem_msg,
             )
-            for i in range(n_graph)
+            for _ in range(n_graph)
         )
 
         # define a global pooling function for materials
@@ -259,9 +259,7 @@ def forward(
             for attnhead in self.cry_pool
         ]
 
-        return scatter_mean(
-            torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0
-        )
+        return scatter_mean(torch.mean(torch.stack(head_fea), dim=0), aug_cry_idx, dim=0)
 
     def __repr__(self) -> str:
         return (
35 changes: 12 additions & 23 deletions aviary/wren/utils.py
@@ -361,13 +361,11 @@ def sort_and_score_element_wyckoffs(element_wyckoffs: str) -> tuple[str, int]:
         wp_counts = split_alpha_numeric(el_wyks)
         sorted_element_wyckoffs.append(
             "".join(
-                [
-                    f"{count}{wyckoff_letter}" if count != "1" else wyckoff_letter
-                    for count, wyckoff_letter in sorted(
-                        zip(wp_counts["numeric"], wp_counts["alpha"]),
-                        key=lambda x: x[1],
-                    )
-                ]
+                f"{count}{wyckoff_letter}" if count != "1" else wyckoff_letter
+                for count, wyckoff_letter in sorted(
+                    zip(wp_counts["numeric"], wp_counts["alpha"]),
+                    key=lambda x: x[1],
+                )
             )
         )
         score += sum(
@@ -391,19 +389,19 @@ def get_prototype_formula_from_composition(composition: Composition) -> str:
     """
     reduced = composition.element_composition
     if all(x == int(x) for x in composition.values()):
-        reduced /= gcd(*(int(i) for i in composition.values()))
+        reduced /= gcd(*(int(amt) for amt in composition.values()))
 
     amounts = [reduced[key] for key in sorted(reduced, key=str)]
 
     anon = ""
-    for e, amt in zip(ascii_uppercase, amounts):
+    for elem, amt in zip(ascii_uppercase, amounts):
         if amt == 1:
             amt_str = ""
         elif abs(amt % 1) < 1e-8:
             amt_str = str(int(amt))
         else:
             amt_str = str(amt)
-        anon += f"{e}{amt_str}"
+        anon += f"{elem}{amt_str}"
     return anon
@@ -415,13 +413,8 @@ def get_anonymous_formula_from_prototype_formula(prototype_formula: str) -> str:
     anom_list = split_alpha_numeric(prototype_formula)
 
     return "".join(
-        [
-            f"{el}{num}" if num != 1 else el
-            for el, num in zip(
-                anom_list["alpha"],
-                sorted(map(int, anom_list["numeric"])),
-            )
-        ]
+        f"{el}{num}" if num != 1 else el
+        for el, num in zip(anom_list["alpha"], sorted(map(int, anom_list["numeric"])))
     )
@@ -435,12 +428,8 @@ def get_formula_from_protostructure_label(protostructure_label: str) -> str:
     anom_list = split_alpha_numeric(prototype_formula)
 
     return "".join(
-        [
-            f"{el}{num}" if num != 1 else el
-            for el, num in zip(
-                chemsys.split("-"), map(int, anom_list["numeric"]), strict=True
-            )
-        ]
+        f"{el}{num}" if num != 1 else el
+        for el, num in zip(chemsys.split("-"), map(int, anom_list["numeric"]))
     )
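The last hunk above is the "Python < 3.10" fix from the commit message: zip() only accepts strict= since Python 3.10. A standard-library illustration (not aviary code):

    list(zip("AB", [1, 2, 3]))  # [('A', 1), ('B', 2)], silently truncates
    # zip("AB", [1, 2, 3], strict=True) raises ValueError on 3.10+,
    # and TypeError ("takes no keyword arguments") on 3.9 and below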
8 changes: 2 additions & 6 deletions aviary/wrenformer/data.py
@@ -66,9 +66,7 @@ def collate_batch(
 
 
 @cache
-def get_wyckoff_features(
-    equivalent_wyckoff_set: list[tuple], spg_num: int
-) -> np.ndarray:
+def get_wyckoff_features(equivalent_wyckoff_set: list[tuple], spg_num: int) -> np.ndarray:
     """Get Wyckoff set features from the precomputed dictionary. The output of this
     function is cached for speed.
@@ -204,6 +202,4 @@ def df_to_in_mem_dataloader(
             inputs[idx] = tensor.to(device)
 
     ids = df.get(id_col, df.index).to_numpy()
-    return InMemoryDataLoader(
-        [inputs, targets, ids], collate_fn=collate_batch, **kwargs
-    )
+    return InMemoryDataLoader([inputs, targets, ids], collate_fn=collate_batch, **kwargs)
4 changes: 2 additions & 2 deletions examples/wrenformer/mat_bench/make_plots.py
@@ -40,7 +40,7 @@
 
 # %% --- Load other's scores ---
 # load benchmark data for models with existing Matbench submission
-for idx, dirname in enumerate(glob(f"{bench_dir}/*"), 1):
+for idx, dirname in enumerate(glob(f"{bench_dir}/*"), start=1):
     model_name = dirname.split("/matbench_v0.1_")[-1]
     print(f"{idx}. {model_name}")
    mbbm = MatbenchBenchmark.from_file(f"{dirname}/results.json.gz")
@@ -62,7 +62,7 @@
 # %% --- Load our scores ---
 our_score_files = sorted(glob("model_scores/*.json"), key=lambda s: s.split("@")[0])
 
-for idx, filename in enumerate(our_score_files, 1):
+for idx, filename in enumerate(our_score_files, start=1):
     date, model_name = re.split(r"@\d\d-\d\d-", filename.split("/")[-1])
 
     print(f"{idx}. {date} {model_name}")
@@ -17,7 +17,7 @@
 
 benchmark = MatbenchBenchmark()
 
-for idx, task in enumerate(benchmark.tasks, 1):
+for idx, task in enumerate(benchmark.tasks, start=1):
     print(f"\n\n{idx}/{len(benchmark.tasks)}")
     task.load()
     df: pd.DataFrame = task.df
4 changes: 1 addition & 3 deletions examples/wrenformer/mat_bench/utils.py
@@ -54,7 +54,5 @@ def non_serializable_handler(obj: object) -> str:
         return f"<not serializable: {type(obj).__qualname__}>"
 
     with open(file_path, "w") as file:
-        default = (
-            non_serializable_handler if on_non_serializable == "annotate" else None
-        )
+        default = non_serializable_handler if on_non_serializable == "annotate" else None
         json.dump(dct, file, default=default, indent=2)
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -19,7 +19,7 @@ def df_matbench_phonons():
     """Returns the dataframe for the Matbench phonon DOS peak task."""
 
     df = load_dataset("matbench_phonons")
-    df["material_id"] = [f"mb_phdos_{i + 1}" for i in range(len(df))]
+    df["material_id"] = [f"mb_phdos_{idx + 1}" for idx in range(len(df))]
     df = df.set_index("material_id", drop=False)
     df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]
@@ -33,7 +33,7 @@ def df_matbench_jdft2d():
     """Returns Matbench experimental band gap task dataframe. Currently unused."""
 
     df = load_dataset("matbench_jdft2d")
-    df["material_id"] = [f"mb_jdft2d_{i + 1}" for i in range(len(df))]
+    df["material_id"] = [f"mb_jdft2d_{idx + 1}" for idx in range(len(df))]
     df = df.set_index("material_id", drop=False)
     df["composition"] = [x.composition.formula.replace(" ", "") for x in df.structure]