Skip to content

Commit 77fb321

Browse files
committed
Support loading retrieved preds for evaluation as well
1 parent 704f0a1 commit 77fb321

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

eval.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import hydra
55
import nltk
66
import pytorch_lightning as pl
7+
import wandb
78
from omegaconf import OmegaConf
89

910
from conf import EvalConfig
@@ -34,8 +35,6 @@ def main(cfg: EvalConfig) -> None:
3435
shift_labels=cfg.model.configuration != "decoder",
3536
process_retrieved=cfg.model.configuration == "race",
3637
)
37-
dm.prepare_data(stage="test")
38-
dm.setup(stage=cfg.stage)
3938

4039
if cfg.logger.use_wandb:
4140
if cfg.logger.use_api_key:
@@ -48,6 +47,26 @@ def main(cfg: EvalConfig) -> None:
4847
job_type="eval",
4948
)
5049

50+
if cfg.model.configuration == "race":
51+
# download retrieved examples
52+
artifact = wandb.use_artifact(
53+
"codet5"
54+
+ ("_with-history" if cfg.input.train_with_history else "_without-history")
55+
+ "_retrieval:latest",
56+
type="retrieval",
57+
)
58+
59+
for part in ["train", "val", "test"]:
60+
artifact.get_path(f"{part}_predictions.jsonl").download(
61+
root=os.path.join(
62+
hydra.utils.to_absolute_path(dm.get_root_dir_for_part(cfg.dataset.dataset_root, part)),
63+
"retrieval" + ("_with_history" if cfg.input.train_with_history else "_without_history"),
64+
)
65+
)
66+
67+
dm.prepare_data(stage="test")
68+
dm.setup(stage=cfg.stage)
69+
5170
run_name = WandbOrganizer.get_run_name(
5271
cfg.model,
5372
encoder_input_type=cfg.input.encoder_input_type,

0 commit comments

Comments
 (0)