Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix conv2d_grad error in auto parallel training #68586

Conversation

jeff41404
Copy link
Contributor

@jeff41404 jeff41404 commented Oct 9, 2024

PR Category

Auto Parallel

PR Types

Bug fixes

Description

pcard-87125
When training model in auto parallel, if there is paddle.nn.Conv2D in the network and its parameters need to be trained, an error occurs when calculating the gradient in backward.
image
The following simple program can reproduce this issue:

# -*- coding: UTF-8 -*-
import numpy as np
import paddle
# 导入动态图半自动接口
from paddle.distributed import ProcessMesh, shard_dataloader, shard_optimizer, ShardingStage1
# 导入数据加载和数据保存接口
from paddle.io import Dataset, BatchSampler, DataLoader

epoch = 5  #训练迭代次数
batch_num = 100 #每次迭代的 batch 数
batch_size = 32 #训练批次大小
class_dim = 10
in_channels = 4
out_channels = 8

# 设置数据读取器
class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([4, 32, 32]).astype('float32')
        label = np.random.randint(0, class_dim - 1, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

# 模型网络
class SimpleNet(paddle.nn.Layer):
    def __init__(self, input_size, inner_size, output_size):
        super().__init__()
        self.conv2d = paddle.nn.Conv2D(in_channels, out_channels, kernel_size=2, stride=1, bias_attr=False)
        self.linear = paddle.nn.Linear(out_channels*31*31, class_dim)
        self.relu = paddle.nn.ReLU()

    def forward(self, x):
        x = self.conv2d(x).flatten(1)
        x = self.linear(x)
        x = self.relu(x)
        return x

# 设置mesh
world_process_mesh = ProcessMesh([0, 1], dim_names=["dp"])

# 设置训练函数
def train_model():
    model = SimpleNet(input_size=256, inner_size=102400, output_size=class_dim)
    loss_func = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters())

    # 设置dataset和dataloader,用于并行训练
    dataset = RandomDataset(batch_num * batch_size)
    # 设置批采样器,用于数据并行训练
    sampler = BatchSampler(dataset,
                        batch_size=batch_size,shuffle=False, drop_last=True)
    train_loader = DataLoader(dataset,
                            batch_sampler=sampler,
                            num_workers=1)
    dist_dataloader = shard_dataloader(dataloader=train_loader, meshes=world_process_mesh, shard_dims="dp")
    dist_opt = shard_optimizer(optimizer, ShardingStage1(world_process_mesh))

    for eop in range(epoch):
        model.train()

        for batch_id, data in enumerate(dist_dataloader()):
            img, label = data
            label.stop_gradient = True

            out = model(img)
            avg_loss = loss_func(input=out, label=label)

            avg_loss.backward()   # will report error
            dist_opt.step()
            model.clear_gradients()

            if batch_id % 5 == 0:
                print("[Epoch %d, batch %d] loss: %.5f" % (eop, batch_id, np.array(avg_loss)))
# 启动训练
if __name__ == '__main__':
    train_model()

this PR will fix this issue.

Copy link

paddle-bot bot commented Oct 9, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jeff41404 jeff41404 merged commit b7624cf into PaddlePaddle:develop Oct 10, 2024
26 of 27 checks passed
@jeff41404 jeff41404 deleted the fix_conv2d_grad_error_in_auto_parallel_training branch October 10, 2024 08:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants