Cugraph Examples Updates #9953

Open: 15 commits to be merged into base: master

guanxingithub
Contributor

This PR consolidates the cugraph training examples:
Models: GAT, GCN, GraphSAGE
Datasets: ogbn-arxiv, ogbn-products, ogbn-papers100M
examples/ogbn_train_cugraph.py is the script for single-GPU training.
examples/multi_gpu/ogbn_train_cugraph_multigpu.py is the script for multi-GPU training.

@guanxingithub guanxingithub requested a review from wsad1 as a code owner January 15, 2025 19:21
Update cugraph examples
@puririshi98
Contributor

@guanxingithub, as mentioned before, please delete the old cugraph examples that these new files are based on; otherwise there will be duplicates.

@puririshi98
Contributor

Also, please attach logs of these scripts running successfully on your setup for review.

@guanxingithub
Contributor Author

examples/ogbn_train_cugraph.py
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
loss = F.cross_entropy(out, y)
loss.backward()
optimizer.step()
Member

Can we call optimizer.zero_grad at the end? Otherwise, the gradients are still left on the device during subsequent evaluation loops, I believe.

Contributor Author

Good suggestion. However, at the beginning of the loop, especially on the first iteration, we had better clear any gradients left on the device. We can also add optimizer.zero_grad after the whole loop so that the optimizer is cleaned up and ready for further training.

Member

"However, at the beginning of the loop, especially on the first iteration, we had better clear any gradients left on the device."

In which case are gradients left on the device in the current script?

Contributor Author

When the optimizer is created, the default gradients may not be zero, so we had better reset them before the first use.
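
For reference, the placement discussed in this thread can be sketched as follows. This is a minimal illustration, not the code from examples/ogbn_train_cugraph.py: the linear model, the Adam optimizer, and the random batches are hypothetical stand-ins. It assumes stock PyTorch, where a parameter's .grad is None until the first backward pass, and where optimizer.zero_grad(set_to_none=True) resets it to None again.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the example's model, optimizer, and mini-batches.
model = torch.nn.Linear(8, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
batches = [(torch.randn(16, 8), torch.randint(0, 3, (16,))) for _ in range(4)]

# Before any backward pass, the parameters hold no gradient tensors.
grads_before = [p.grad for p in model.parameters()]

model.train()
for x, y in batches:
    out = model(x)
    loss = F.cross_entropy(out, y)
    loss.backward()
    optimizer.step()
    # Clearing at the end of the step drops the gradient tensors, so none
    # are held while a later evaluation loop runs.
    optimizer.zero_grad(set_to_none=True)

# After the loop, the gradients are released again.
grads_after = [p.grad for p in model.parameters()]
```

With this placement, a call at the top of the loop is only needed if something outside the loop has already run a backward pass on the same parameters.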

Member

@akihironitta akihironitta left a comment

Great work @guanxingithub! This does not block merging, but before merging it would be nice to profile some parts of the scripts and make sure there is no obviously inefficient code :)
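
One way to do the kind of profiling requested here is torch.profiler. The sketch below is only an illustration under stated assumptions: the linear model and random input are hypothetical placeholders for the example's training step, and it is not the profile that was later attached to this PR.

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Hypothetical placeholders for the model and one mini-batch.
model = torch.nn.Linear(8, 3)
x = torch.randn(16, 8)

# Always record CPU activity; add CUDA kernels when a GPU is present.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    model, x = model.cuda(), x.cuda()

with profile(activities=activities) as prof:
    for _ in range(5):
        # Forward and backward pass of the placeholder step.
        model(x).sum().backward()

# Aggregate by operator and list the most expensive ones first.
summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(summary)
```

In the real scripts, the profiled region would wrap a few training iterations, including data loading, so that sampling and host-to-device transfers show up next to the model operators.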

akihironitta previously approved these changes Jan 21, 2025
Member

@akihironitta akihironitta left a comment

Nice work :)

@akihironitta akihironitta self-requested a review January 21, 2025 17:49
@akihironitta akihironitta dismissed their stale review January 21, 2025 17:50

accidentally pressed approve

Member

@akihironitta akihironitta left a comment

Nice work! I'd like these cugraph examples to be as performant as possible :) Let's merge this once comments are addressed and this PR is cleaned up :)

@guanxingithub
Contributor Author

@puririshi98
Contributor

@akihironitta, reviewing the profile by @guanxingithub, there do not seem to be any bottlenecks. I will leave it to you to decide whether we can merge, if this investigation satisfies your concerns.

3 participants