 
 import pickle
 import re
+import shutil
 import sys
 
 from annoy import AnnoyIndex
@@ -88,12 +89,14 @@ def query_model(query, model, indices, language, topk=100):
             sys.exit(1)
         wandb_api = wandb.Api()
         # retrieve saved model from W&B for this run
+        print("Fetching run from W&B...")
         try:
             run = wandb_api.run(args_wandb_run_id)
         except wandb.CommError as e:
             print("ERROR: Problem querying W&B for wandb_run_id: %s" % args_wandb_run_id, file=sys.stderr)
             sys.exit(1)
 
+        print("Fetching run files from W&B...")
         gz_run_files = [f for f in run.files() if f.name.endswith('gz')]
         if not gz_run_files:
             print("ERROR: Run contains no model-like files")
@@ -129,10 +132,18 @@ def query_model(query, model, indices, language, topk=100):
     df = pd.DataFrame(predictions, columns=['query', 'language', 'identifier', 'url'])
     df.to_csv(predictions_csv, index=False)
 
+
     if run_id:
+        print('Uploading predictions to W&B')
         # upload model predictions CSV file to W&B
 
         # we checked that there are three path components above
         entity, project, name = args_wandb_run_id.split('/')
+
+        # make sure the file is in our cwd, with the correct name
+        predictions_base_csv = "model_predictions.csv"
+        shutil.copyfile(predictions_csv, predictions_base_csv)
+
+        # Using internal wandb API. TODO: Update when available as a public API
         internal_api = InternalApi()
-        internal_api.push([predictions_csv], run=name, entity=entity, project=project)
+        internal_api.push([predictions_base_csv], run=name, entity=entity, project=project)
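Note on the TODO above: the upload still goes through wandb's InternalApi. As a hedged sketch only (not part of this commit), the same upload could be written against wandb's public API, assuming a recent wandb client where Run.upload_file is available and args_wandb_run_id keeps the entity/project/run_id form validated earlier:

import wandb

# Sketch of a public-API upload; assumes wandb.apis.public.Run.upload_file
# exists in the installed wandb version.
api = wandb.Api()
run = api.run(args_wandb_run_id)          # "entity/project/run_id"
run.upload_file("model_predictions.csv")  # uploads the CSV from the cwd to the run's files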