@@ -65,16 +65,16 @@ def inference(cfg: DictConfig):
65
65
# The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000',
66
66
# 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500',
67
67
# 'v850', 'v1000', 'mslp', 'u10', 'v10', 't2m'.
68
- input_file = read_h5py (cfg .INFER .input_file )
69
- nwp_file = read_h5py (cfg .INFER .nwp_file )
70
- geo_file = read_h5py (cfg .INFER .geo_file )
68
+ input_data = read_h5py (cfg .INFER .input_file )
69
+ nwp_data = read_h5py (cfg .INFER .nwp_file )
70
+ geo_data = read_h5py (cfg .INFER .geo_file )
71
71
72
72
# input_data.shape: (1, 24, 440, 408)
73
- input_data = input_file [0 :1 ]
73
+ input_data_0 = input_data [0 :1 ]
74
74
# nwp_data.shape: # (num_timestamps, 24, 440, 408)
75
- nwp_data = nwp_file [0 :num_timestamps ]
75
+ nwp_data = nwp_data [0 :num_timestamps ]
76
76
# ground_truth.shape: (num_timestamps, 24, 440, 408)
77
- ground_truth = input_file [1 : num_timestamps + 1 ]
77
+ ground_truth = input_data [1 : num_timestamps + 1 ]
78
78
79
79
# create time stamps
80
80
cur_time = pd .to_datetime (cfg .INFER .init_time , format = "%Y/%m/%d/%H" )
@@ -84,7 +84,7 @@ def inference(cfg: DictConfig):
84
84
time_stamps .append ([cur_time ])
85
85
86
86
# run predictor
87
- pred_data = predictor .predict (input_data , time_stamps , nwp_data , geo_file )
87
+ pred_data = predictor .predict (input_data_0 , time_stamps , nwp_data , geo_data )
88
88
pred_data = pred_data .squeeze (axis = 1 ) # (num_timestamps, 24, 440, 408)
89
89
90
90
# save predict data
0 commit comments