Skip to content

Commit 483920c

Browse files
refine
1 parent 6506806 commit 483920c

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

examples/yinglong/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ tar -xvf inference.tar
5555

5656
### 2. Run the code
5757

58-
The following code runs the YingLong model, and the model output will be saved in 'outputs_yinglong_eastern(western)/result.npy'.
58+
The following code runs the YingLong model, and the model output will be saved in `outputs_yinglong_eastern(western)/result.npy`.
5959

6060
``` shell
6161
# yinglong_eastern

examples/yinglong/predict.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,16 @@ def inference(cfg: DictConfig):
6565
# The variable name is 'z50', 'z500', 'z850', 'z1000', 't50', 't500', 't850', 'z1000',
6666
# 's50', 's500', 's850', 's1000', 'u50', 'u500', 'u850', 'u1000', 'v50', 'v500',
6767
# 'v850', 'v1000', 'mslp', 'u10', 'v10', 't2m'.
68-
input_file = read_h5py(cfg.INFER.input_file)
69-
nwp_file = read_h5py(cfg.INFER.nwp_file)
70-
geo_file = read_h5py(cfg.INFER.geo_file)
68+
input_data = read_h5py(cfg.INFER.input_file)
69+
nwp_data = read_h5py(cfg.INFER.nwp_file)
70+
geo_data = read_h5py(cfg.INFER.geo_file)
7171

7272
# input_data.shape: (1, 24, 440, 408)
73-
input_data = input_file[0:1]
73+
input_data_0 = input_data[0:1]
7474
# nwp_data.shape: # (num_timestamps, 24, 440, 408)
75-
nwp_data = nwp_file[0:num_timestamps]
75+
nwp_data = nwp_data[0:num_timestamps]
7676
# ground_truth.shape: (num_timestamps, 24, 440, 408)
77-
ground_truth = input_file[1 : num_timestamps + 1]
77+
ground_truth = input_data[1 : num_timestamps + 1]
7878

7979
# create time stamps
8080
cur_time = pd.to_datetime(cfg.INFER.init_time, format="%Y/%m/%d/%H")
@@ -84,7 +84,7 @@ def inference(cfg: DictConfig):
8484
time_stamps.append([cur_time])
8585

8686
# run predictor
87-
pred_data = predictor.predict(input_data, time_stamps, nwp_data, geo_file)
87+
pred_data = predictor.predict(input_data_0, time_stamps, nwp_data, geo_data)
8888
pred_data = pred_data.squeeze(axis=1) # (num_timestamps, 24, 440, 408)
8989

9090
# save predict data

examples/yinglong/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def predict(
105105
input_data (np.ndarray): Input data of shape (N, T, H, W).
106106
time_stamps (List[List[pd.Timestamp]]): Timestamps data.
107107
nwp_data (np.ndarray): NWP data.
108-
geo_data (np.ndarray): Geo data.
108+
geo_data (np.ndarray): Geographic data.
109109
batch_size (int, optional): Batch size, now only support 1. Defaults to 1.
110110
111111
Returns:

0 commit comments

Comments
 (0)