How to avoid recomputation #6245
-
I used `map` for this preprocessing. I expected it to be faster with `with_transform`, but that didn't work. Also, how do I guarantee that the result is not recomputed? I found that just changing the `num_proc` parameter triggers a recomputation.

```python
from torchvision import transforms

# args, accelerator, dataset, and image_column are defined elsewhere
# in the surrounding training script.
image_transforms = transforms.Compose(
    [
        transforms.Resize(
            args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop(args.resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def preprocess_train(examples):
    images = [image.convert("RGB") for image in examples[image_column]]
    images = [image_transforms(image) for image in images]
    return {"pixel_values": images}

with accelerator.main_process_first():
    if args.max_train_samples is not None:
        dataset["train"] = (
            dataset["train"]
            .shuffle(seed=args.seed)
            .select(range(args.max_train_samples))
        )
    # Set the training transforms
    if args.load_dataset_streaming:
        train_dataset = dataset["train"].map(
            preprocess_train,
            batched=True,
        )
        train_dataset = train_dataset.shuffle(seed=args.seed)
    else:
        if args.dataset_map:
            train_dataset = dataset["train"].map(
                preprocess_train,
                batch_size=args.train_batch_size,
                batched=True,
                num_proc=args.load_dataset_num_proc,
            )
        else:
            train_dataset = dataset["train"].with_transform(preprocess_train)

print(type(train_dataset[0]["pixel_values"]))
```

The output is a plain Python `list`, not a tensor.
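One way to make the reuse explicit, instead of relying on the automatic fingerprint cache, is to materialize the mapped dataset once and reload it on later runs. A minimal sketch, assuming a local directory path of your choice (`preprocessed_dir` is hypothetical, not part of the script above):

```python
import os
from datasets import load_from_disk

preprocessed_dir = "preprocessed_train"  # hypothetical local path

if os.path.isdir(preprocessed_dir):
    # Reuse the preprocessed data regardless of num_proc or other map() arguments.
    train_dataset = load_from_disk(preprocessed_dir)
else:
    train_dataset = dataset["train"].map(
        preprocess_train,
        batched=True,
        num_proc=args.load_dataset_num_proc,
    )
    train_dataset.save_to_disk(preprocessed_dir)
```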
-
I noticed that I need to use `dataset.map` and then convert back with `torch.tensor` in the dataloader's `collate_fn`. Now I want to know how to control whether recomputation happens: merely changing `num_proc` automatically triggers a recomputation. Runs with different `num_proc` values use the same cache, right?
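For reference, a minimal sketch of the kind of `collate_fn` conversion this refers to, matching the `preprocess_train` output above (the dataloader wiring is illustrative, not taken from the original script):

```python
import torch

# After .map(), "pixel_values" comes back as nested Python lists,
# so convert each example to a tensor and stack the batch.
def collate_fn(examples):
    pixel_values = torch.stack(
        [torch.tensor(example["pixel_values"]) for example in examples]
    )
    return {"pixel_values": pixel_values}

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
```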
-
The default formatting returns built-in types as values (lists, dictionaries, etc.). To get `torch` tensors, use `.set_format("pt")` on the dataset object. Changing `num_proc` in many scenarios leads to a slightly different result (e.g., tokenization with truncation in batched mode), which is why it requires recomputation: we cannot be sure the result will be the same.
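A minimal illustration of this suggestion, reusing the names from the question above:

```python
# After map(), pixel_values are stored as nested Python lists in the
# Arrow cache; set_format("pt") converts them to torch tensors on access.
train_dataset = dataset["train"].map(
    preprocess_train,
    batched=True,
    num_proc=args.load_dataset_num_proc,  # keep this fixed across runs to hit the cache
)
train_dataset.set_format("pt", columns=["pixel_values"])

print(type(train_dataset[0]["pixel_values"]))  # <class 'torch.Tensor'>
```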
-
@mariosasko thanks!