dataset_shuffle_data_loader.py
import argparse
import json
import os
import time

import numpy as np
import ray
import torch
from pyarrow import fs

PATHS = {
    "aws": [
        f"s3://shuffling-data-loader-benchmarks/data/input_data_{i}.parquet.snappy"
        for i in range(25)
    ],
    "gcp": [
        f"gcs://shuffling-data-loader-benchmarks/data/input_data_{i}.parquet.snappy"
        for i in range(25)
    ],
}


def create_parser():
    parser = argparse.ArgumentParser(description="Dataset shuffle")
    # Use .get() so a missing RAY_ADDRESS does not raise KeyError at parser
    # construction time; the address can still be passed on the command line.
    parser.add_argument("--address", type=str, default=os.environ.get("RAY_ADDRESS"))
    parser.add_argument(
        "--batch-size",
        type=int,
        default=250000,
        metavar="N",
        help="input batch size for training (default: 250000)",
    )
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--repeat-times", type=int, default=16)
    # Required: PATHS[args.cloud] would otherwise fail with an opaque KeyError.
    parser.add_argument("--cloud", type=str, choices=["aws", "gcp"], required=True)
    return parser
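

# For reference, a typical parsed-args result with the defaults above (an
# illustrative sketch; the address value comes from RAY_ADDRESS):
#
#   Namespace(address="auto", batch_size=250000, num_workers=4,
#             repeat_times=16, cloud="aws")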


def create_torch_iterator(split, batch_size, rank=None):
    print(
        f"Creating Torch shuffling dataset for worker {rank} with "
        f"{batch_size} batch size."
    )
    numpy_to_torch_dtype = {
        np.bool_: torch.bool,
        np.uint8: torch.uint8,
        np.int8: torch.int8,
        np.int16: torch.int16,
        np.int32: torch.int32,
        np.int64: torch.int64,
        np.float16: torch.float16,
        np.float32: torch.float32,
        np.float64: torch.float64,
        np.complex64: torch.complex64,
        np.complex128: torch.complex128,
    }
    # Only the dtype element of each spec is used below. "labels" must remain
    # the last entry so that the pop() calls peel it off as the label column.
    DATA_SPEC = {
        "embeddings_name0": (0, 2385, np.int64),
        "embeddings_name1": (0, 201, np.int64),
        "embeddings_name2": (0, 201, np.int64),
        "embeddings_name3": (0, 6, np.int64),
        "embeddings_name4": (0, 19, np.int64),
        "embeddings_name5": (0, 1441, np.int64),
        "embeddings_name6": (0, 201, np.int64),
        "embeddings_name7": (0, 22, np.int64),
        "embeddings_name8": (0, 156, np.int64),
        "embeddings_name9": (0, 1216, np.int64),
        "embeddings_name10": (0, 9216, np.int64),
        "embeddings_name11": (0, 88999, np.int64),
        "embeddings_name12": (0, 941792, np.int64),
        "embeddings_name13": (0, 9405, np.int64),
        "embeddings_name14": (0, 83332, np.int64),
        "embeddings_name15": (0, 828767, np.int64),
        "embeddings_name16": (0, 945195, np.int64),
        "one_hot0": (0, 3, np.int64),
        "one_hot1": (0, 50, np.int64),
        "labels": (0, 1, np.float64),
    }
    feature_columns = list(DATA_SPEC.keys())
    feature_types = [numpy_to_torch_dtype[dtype] for _, _, dtype in DATA_SPEC.values()]
    label_column = feature_columns.pop()
    label_type = feature_types.pop()

    torch_iterator = split.to_torch(
        label_column=label_column,
        feature_columns=feature_columns,
        label_column_dtype=label_type,
        # After popping the label, every remaining feature column is int64,
        # so a single dtype is passed for all feature columns.
        feature_column_dtypes=feature_types[0],
        batch_size=batch_size,
    )
    return torch_iterator
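

# A minimal consumption sketch for the iterator above (illustrative only;
# `split` stands for one shard of the pipeline, as produced in __main__).
# Each iteration yields one (features, labels) pair of tensors:
#
#   it = create_torch_iterator(split, batch_size=4, rank=0)
#   features, labels = next(iter(it))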


def create_dataset(filenames, repeat_times, cloud):
    # GCS needs an explicit Arrow filesystem; s3:// URIs are resolved
    # automatically by pyarrow, so no filesystem is passed for AWS.
    if cloud == "gcp":
        filesystem = fs.GcsFileSystem()
    else:
        filesystem = None
    pipeline = (
        ray.data.read_parquet(list(filenames), filesystem=filesystem)
        .repeat(times=repeat_times)
        .random_shuffle_each_window()
    )
    return pipeline
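

# A local smoke-test variant of create_dataset (a sketch, not part of the
# benchmark): it builds the same repeat + per-window-shuffle pipeline from a
# tiny synthetic in-memory dataset, so the plumbing can be exercised without
# cloud credentials. The column name, row count, and values are illustrative.
def create_synthetic_dataset(num_rows=100, repeat_times=2):
    items = [{"labels": float(i % 2)} for i in range(num_rows)]
    return (
        ray.data.from_items(items)
        .repeat(times=repeat_times)
        .random_shuffle_each_window()
    )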


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    print("Connecting to Ray cluster...")
    ray.init(address=args.address)

    start = time.time()
    pipeline = create_dataset(PATHS[args.cloud], args.repeat_times, args.cloud)
    splits = pipeline.split(args.num_workers)

    @ray.remote(num_gpus=1)
    def consume(split, rank=None, batch_size=None):
        torch_iterator = create_torch_iterator(split, batch_size=batch_size, rank=rank)
        for i, (x, y) in enumerate(torch_iterator):
            # Throttle consumption to stand in for per-batch training work.
            time.sleep(1)
            if i % 10 == 0:
                print(i)
        return

    tasks = [
        consume.remote(split, rank=idx, batch_size=args.batch_size)
        for idx, split in enumerate(splits)
    ]
    ray.get(tasks)

    delta = time.time() - start
    print(f"success! total time {delta}")
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"shuffle_time": delta, "success": 1}))
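
# Example invocation (assumes a reachable Ray cluster; paths and values are
# illustrative):
#
#   RAY_ADDRESS="auto" TEST_OUTPUT_JSON="/tmp/shuffle_result.json" \
#   python dataset_shuffle_data_loader.py --cloud aws --num-workers 4 \
#       --batch-size 250000 --repeat-times 16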