|
10 | 10 | from azureml.automl.runtime.shared.score import scoring, constants
|
11 | 11 | from azureml.core import Run
|
12 | 12 |
|
| 13 | +import torch |
| 14 | + |
13 | 15 |
|
14 | 16 | def align_outputs(y_predicted, X_trans, X_test, y_test,
|
15 | 17 | predicted_column_name='predicted',
|
@@ -221,6 +223,10 @@ def MAPE(actual, pred):
|
221 | 223 | return np.mean(APE(actual_safe, pred_safe))
|
222 | 224 |
|
223 | 225 |
|
| 226 | +def map_location_cuda(storage, loc): |
| 227 | + return storage.cuda() |
| 228 | + |
| 229 | + |
224 | 230 | parser = argparse.ArgumentParser()
|
225 | 231 | parser.add_argument(
|
226 | 232 | '--max_horizon', type=int, dest='max_horizon',
|
@@ -274,8 +280,13 @@ def MAPE(actual, pred):
|
274 | 280 | y_lookback_df = lookback_dataset.with_timestamp_columns(
|
275 | 281 | None).keep_columns(columns=[target_column_name])
|
276 | 282 |
|
277 |
| -fitted_model = joblib.load(model_path) |
278 |
| - |
| 283 | +# Load the trained model with torch. |
| 284 | +if torch.cuda.is_available(): |
| 285 | + map_location = map_location_cuda |
| 286 | +else: |
| 287 | + map_location = 'cpu' |
| 288 | +with open(model_path, 'rb') as fh: |
| 289 | + fitted_model = torch.load(fh, map_location=map_location) |
279 | 290 |
|
280 | 291 | if hasattr(fitted_model, 'get_lookback'):
|
281 | 292 | lookback = fitted_model.get_lookback()
|
|
0 commit comments