4
4
import hydra
5
5
import nltk
6
6
import pytorch_lightning as pl
7
+ import wandb
7
8
from omegaconf import OmegaConf
8
9
9
10
from conf import EvalConfig
@@ -34,8 +35,6 @@ def main(cfg: EvalConfig) -> None:
34
35
shift_labels = cfg .model .configuration != "decoder" ,
35
36
process_retrieved = cfg .model .configuration == "race" ,
36
37
)
37
- dm .prepare_data (stage = "test" )
38
- dm .setup (stage = cfg .stage )
39
38
40
39
if cfg .logger .use_wandb :
41
40
if cfg .logger .use_api_key :
@@ -48,6 +47,26 @@ def main(cfg: EvalConfig) -> None:
48
47
job_type = "eval" ,
49
48
)
50
49
50
+ if cfg .model .configuration == "race" :
51
+ # download retrieved examples
52
+ artifact = wandb .use_artifact (
53
+ "codet5"
54
+ + ("_with-history" if cfg .input .train_with_history else "_without-history" )
55
+ + "_retrieval:latest" ,
56
+ type = "retrieval" ,
57
+ )
58
+
59
+ for part in ["train" , "val" , "test" ]:
60
+ artifact .get_path (f"{ part } _predictions.jsonl" ).download (
61
+ root = os .path .join (
62
+ hydra .utils .to_absolute_path (dm .get_root_dir_for_part (cfg .dataset .dataset_root , part )),
63
+ "retrieval" + ("_with_history" if cfg .input .train_with_history else "_without_history" ),
64
+ )
65
+ )
66
+
67
+ dm .prepare_data (stage = "test" )
68
+ dm .setup (stage = cfg .stage )
69
+
51
70
run_name = WandbOrganizer .get_run_name (
52
71
cfg .model ,
53
72
encoder_input_type = cfg .input .encoder_input_type ,
0 commit comments