Skip to content

Commit

Permalink
[PaddlePaddle] set get_dataloader_workers to 4 (d2l-ai#1203)
Browse files Browse the repository at this point in the history
* [PaddlePaddle] set get_dataloader_workers to 4

* Rerun lr scheduler.md

* Try to resolve Termination signal

* Remove extra comments
  • Loading branch information
吴高升 authored Sep 7, 2022
1 parent 0b8045d commit 8dc2236
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion chapter_linear-networks/image-classification-dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,12 @@ batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4 if ('cpu' in paddle.device.get_device()) else 0
return 4
train_iter = paddle.io.DataLoader(dataset=mnist_train,
batch_size=batch_size,
shuffle=True,
return_list=True,
num_workers=get_dataloader_workers())
```

Expand Down Expand Up @@ -348,9 +349,11 @@ def load_data_fashion_mnist(batch_size, resize=None):
return (paddle.io.DataLoader(dataset=mnist_train,
batch_size=batch_size,
shuffle=True,
return_list=True,
num_workers=get_dataloader_workers()),
paddle.io.DataLoader(dataset=mnist_test,
batch_size=batch_size,
return_list=True,
shuffle=True,
num_workers=get_dataloader_workers()))
```
Expand Down
2 changes: 1 addition & 1 deletion d2l/paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def get_dataloader_workers():
"""使用4个进程来读取数据
Defined in :numref:`sec_fashion_mnist`"""
return 4 if ('cpu' in paddle.device.get_device()) else 0
return 4

def load_data_fashion_mnist(batch_size, resize=None):
"""下载Fashion-MNIST数据集,然后将其加载到内存中
Expand Down

0 comments on commit 8dc2236

Please sign in to comment.