Skip to content

Commit 322087a

Browse files
committed
update samples from Release-73 as a part of SDK release
1 parent e255c00 commit 322087a

File tree

1 file changed

+25
-20
lines changed
  • how-to-use-azureml/automated-machine-learning/forecasting-beer-remote

1 file changed

+25
-20
lines changed

how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/infer.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import os
23

34
import numpy as np
45
import pandas as pd
@@ -10,7 +11,12 @@
1011
from azureml.automl.runtime.shared.score import scoring, constants
1112
from azureml.core import Run
1213

13-
import torch
14+
try:
15+
import torch
16+
17+
_torch_present = True
18+
except ImportError:
19+
_torch_present = False
1420

1521

1622
def align_outputs(y_predicted, X_trans, X_test, y_test,
@@ -50,7 +56,7 @@ def align_outputs(y_predicted, X_trans, X_test, y_test,
5056
# or at edges of time due to lags/rolling windows
5157
clean = together[together[[target_column_name,
5258
predicted_column_name]].notnull().all(axis=1)]
53-
return(clean)
59+
return (clean)
5460

5561

5662
def do_rolling_forecast_with_lookback(fitted_model, X_test, y_test,
@@ -85,8 +91,7 @@ def do_rolling_forecast_with_lookback(fitted_model, X_test, y_test,
8591
if origin_time != X[time_column_name].min():
8692
# Set the context by including actuals up-to the origin time
8793
test_context_expand_wind = (X[time_column_name] < origin_time)
88-
context_expand_wind = (
89-
X_test_expand[time_column_name] < origin_time)
94+
context_expand_wind = (X_test_expand[time_column_name] < origin_time)
9095
y_query_expand[context_expand_wind] = y[test_context_expand_wind]
9196

9297
# Print some debug info
@@ -117,8 +122,7 @@ def do_rolling_forecast_with_lookback(fitted_model, X_test, y_test,
117122
# Align forecast with test set for dates within
118123
# the current rolling window
119124
trans_tindex = X_trans.index.get_level_values(time_column_name)
120-
trans_roll_wind = (trans_tindex >= origin_time) & (
121-
trans_tindex < horizon_time)
125+
trans_roll_wind = (trans_tindex >= origin_time) & (trans_tindex < horizon_time)
122126
test_roll_wind = expand_wind & (X[time_column_name] >= origin_time)
123127
df_list.append(align_outputs(
124128
y_fcst[trans_roll_wind], X_trans[trans_roll_wind],
@@ -157,8 +161,7 @@ def do_rolling_forecast(fitted_model, X_test, y_test, max_horizon, freq='D'):
157161
if origin_time != X_test[time_column_name].min():
158162
# Set the context by including actuals up-to the origin time
159163
test_context_expand_wind = (X_test[time_column_name] < origin_time)
160-
context_expand_wind = (
161-
X_test_expand[time_column_name] < origin_time)
164+
context_expand_wind = (X_test_expand[time_column_name] < origin_time)
162165
y_query_expand[context_expand_wind] = y_test[
163166
test_context_expand_wind]
164167

@@ -188,10 +191,8 @@ def do_rolling_forecast(fitted_model, X_test, y_test, max_horizon, freq='D'):
188191
# Align forecast with test set for dates within the
189192
# current rolling window
190193
trans_tindex = X_trans.index.get_level_values(time_column_name)
191-
trans_roll_wind = (trans_tindex >= origin_time) & (
192-
trans_tindex < horizon_time)
193-
test_roll_wind = expand_wind & (
194-
X_test[time_column_name] >= origin_time)
194+
trans_roll_wind = (trans_tindex >= origin_time) & (trans_tindex < horizon_time)
195+
test_roll_wind = expand_wind & (X_test[time_column_name] >= origin_time)
195196
df_list.append(align_outputs(y_fcst[trans_roll_wind],
196197
X_trans[trans_roll_wind],
197198
X_test[test_roll_wind],
@@ -244,15 +245,13 @@ def map_location_cuda(storage, loc):
244245
'--model_path', type=str, dest='model_path',
245246
default='model.pkl', help='Filename of model to be loaded')
246247

247-
248248
args = parser.parse_args()
249249
max_horizon = args.max_horizon
250250
target_column_name = args.target_column_name
251251
time_column_name = args.time_column_name
252252
freq = args.freq
253253
model_path = args.model_path
254254

255-
256255
print('args passed are: ')
257256
print(max_horizon)
258257
print(target_column_name)
@@ -280,13 +279,19 @@ def map_location_cuda(storage, loc):
280279
y_lookback_df = lookback_dataset.with_timestamp_columns(
281280
None).keep_columns(columns=[target_column_name])
282281

283-
# Load the trained model with torch.
284-
if torch.cuda.is_available():
285-
map_location = map_location_cuda
282+
_, ext = os.path.splitext(model_path)
283+
if ext == '.pt':
284+
# Load the fc-tcn torch model.
285+
assert _torch_present
286+
if torch.cuda.is_available():
287+
map_location = map_location_cuda
288+
else:
289+
map_location = 'cpu'
290+
with open(model_path, 'rb') as fh:
291+
fitted_model = torch.load(fh, map_location=map_location)
286292
else:
287-
map_location = 'cpu'
288-
with open(model_path, 'rb') as fh:
289-
fitted_model = torch.load(fh, map_location=map_location)
293+
# Load the sklearn pipeline.
294+
fitted_model = joblib.load(model_path)
290295

291296
if hasattr(fitted_model, 'get_lookback'):
292297
lookback = fitted_model.get_lookback()

0 commit comments

Comments
 (0)