Skip to content

Commit

Permalink
[FutureWarning] Addressing the warning related to use of a deprecated…
Browse files Browse the repository at this point in the history
… function. (#9674)

This PR resolves a FutureWarning that appears in several tests:
```
test/nn/models/test_g_retriever.py::test_g_retriever
test/nn/models/test_g_retriever.py::test_g_retriever_many_tokens
test/nn/nlp/test_llm.py::test_llm
  /usr/local/lib/python3.10/dist-packages/torch_geometric/nn/nlp/llm.py:97: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
    self.autocast_context = torch.cuda.amp.autocast(dtype=dtype)
```
  • Loading branch information
drivanov authored Sep 24, 2024
1 parent 49fe936 commit b1c198e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch_geometric/nn/nlp/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
self.autocast_context = nullcontext()
else:
self.device = self.llm.device
self.autocast_context = torch.cuda.amp.autocast(dtype=dtype)
self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)

def _encode_inputs(
self,
Expand Down

0 comments on commit b1c198e

Please sign in to comment.