HeteroLinear/SEGMM: switch from heuristic to timing-cache #8615

Merged (24 commits, Jan 29, 2024)

Changes from 6 commits

Commits (24):
47209ba  rebasing stadlmax's timing heuristic for segmm (puririshi98, Dec 13, 2023)
fe4dc26  finishing rebase (puririshi98, Dec 13, 2023)
e352aa9  Merge branch 'master' into rebase-time-heuristic (puririshi98, Dec 14, 2023)
6814647  resolve conflicts (stadlmax, Jan 5, 2024)
edfb5ec  fix issues (stadlmax, Jan 5, 2024)
586e3e3  fix further issues (stadlmax, Jan 5, 2024)
a29943b  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 8, 2024)
32f64ea  accept review suggestions (puririshi98, Jan 9, 2024)
6e60787  global measure iters (puririshi98, Jan 9, 2024)
309dc53  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 9, 2024)
33d123d  applying suggestion (puririshi98, Jan 9, 2024)
03abe6d  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 10, 2024)
f795f54  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 11, 2024)
e768f70  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 12, 2024)
03c526f  MEASURE_ITER=1 for pytesting (puririshi98, Jan 17, 2024)
344947e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 17, 2024)
3dccb19  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 18, 2024)
39ea9dd  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 19, 2024)
28bda99  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 22, 2024)
f1d3c7f  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 23, 2024)
191d239  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 24, 2024)
c91f0f5  Merge branch 'master' into rebase-time-heuristic (puririshi98, Jan 25, 2024)
054e3f1  Merge branch 'master' into rebase-time-heuristic (rusty1s, Jan 29, 2024)
1c10dda  update (rusty1s, Jan 29, 2024)
1 change: 1 addition & 0 deletions CHANGELOG.md

@@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
 - Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
 - Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
+- Switch to selecting `use_segment_matmul` in `HeteroLinear` via a rudimentary timing cache instead of relying on a heuristic ([#8615](https://github.com/pyg-team/pytorch_geometric/pull/8615))

 ### Deprecated
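As context for the changelog entry: the `torch_geometric.backend.use_segment_matmul` flag still takes precedence, and the timing cache only kicks in when the flag is left at `None`. A minimal sketch of the resulting modes (the shapes and type counts below are made up for illustration):

```python
import torch
import torch_geometric
from torch_geometric.nn import HeteroLinear

x = torch.randn(1000, 16)                 # input features
type_vec = torch.randint(0, 3, (1000, ))  # maps each row to one of 3 types
lin = HeteroLinear(16, 32, num_types=3)

# Default (`None`): with this PR, the kernel choice is made per order of
# magnitude of the input size, by benchmarking both code paths once:
torch_geometric.backend.use_segment_matmul = None
out = lin(x, type_vec)  # First call for this magnitude triggers the timing.

# Explicit override: skip the timing cache and force the decision:
torch_geometric.backend.use_segment_matmul = True  # or False
out = lin(x, type_vec)
```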
3 changes: 2 additions & 1 deletion torch_geometric/backend.py

@@ -23,7 +23,8 @@ def use_segment_matmul_heuristic(
     :meth:`segment_matmul` can speed up computation.
     """
     # NOTE This heuristic was learned on an A100 via sklearn using a simple
-    # StandardScaler() -> LinearSVC() model.
+    # StandardScaler() -> LinearSVC() model. For now, it is only used
+    # in combination with RGCNConv.
     x = torch.tensor([
         num_segments,
         max_segment_size,
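The NOTE above is the only description of how the heuristic was obtained. A rough sketch of the kind of sklearn pipeline it describes follows; the feature rows and labels are hypothetical placeholders, not the data that was actually benchmarked on the A100:

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

# Feature rows: (num_segments, max_segment_size, in_channels, out_channels).
# Labels: 1 if `segment_matmul` was measured to be faster, else 0.
# These few rows are illustrative only:
X = np.array([
    [4, 128, 64, 64],
    [16, 4096, 256, 256],
    [64, 32, 16, 16],
    [8, 8192, 128, 128],
])
y = np.array([0, 1, 0, 1])

model = make_pipeline(StandardScaler(), LinearSVC())
model.fit(X, y)

# The fitted coefficients would then be baked into
# `use_segment_matmul_heuristic()` as hard-coded weights:
svc = model.named_steps['linearsvc']
print(svc.coef_, svc.intercept_)
```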
151 changes: 99 additions & 52 deletions torch_geometric/nn/dense/linear.py
@@ -1,5 +1,6 @@
 import copy
 import math
+import time
 from typing import Any, Dict, Optional, Union

 import torch
@@ -12,7 +13,7 @@
 from torch_geometric import is_compiling
 from torch_geometric.nn import inits
 from torch_geometric.typing import pyg_lib
-from torch_geometric.utils import index_sort, scatter
+from torch_geometric.utils import index_sort
 from torch_geometric.utils.sparse import index2ptr


@@ -236,7 +237,7 @@ def __init__(
         self.is_sorted = is_sorted
         self.kwargs = kwargs

-        self._use_segment_matmul_heuristic_output: Optional[bool] = None
+        self._timing_cache = {}

         if self.in_channels == -1:
             self.weight = torch.nn.parameter.UninitializedParameter()
@@ -258,64 +259,110 @@ def reset_parameters(self):
         reset_bias_(self.bias, self.in_channels,
                     self.kwargs.get('bias_initializer', None))

+    @torch.jit.unused
+    def forward_segmm(self, x: Tensor, type_vec_ptr: Tensor) -> Tensor:
+        assert self.weight is not None
+        out = pyg_lib.ops.segment_matmul(x, type_vec_ptr, self.weight)
+        return out
+
+    def forward_naive(self, x: Tensor, type_vec_ptr: Tensor) -> Tensor:
+        out = x.new_empty(x.size(0), self.out_channels)
+
+        for i in range(self.num_types):
+            off_start, off_end = type_vec_ptr[i], type_vec_ptr[i + 1]
+            subset_out = x[off_start:off_end] @ self.weight[i]
+            # The data type may have changed with mixed precision:
+            out[off_start:off_end] = subset_out.to(out.dtype)
+
+        return out
+
+    @torch.jit.unused
+    def _update_timing_cache(self, x: Tensor, type_vec_ptr: Tensor,
+                             num_rows: int) -> bool:
+        measure_iter = 3
+        with torch.no_grad():
+            # Only measure the forward pass for now:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            start = time.perf_counter()
+            for _ in range(measure_iter):
+                _ = self.forward_segmm(x, type_vec_ptr)
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            end = time.perf_counter()
+            time_segmm = end - start
+
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            start = time.perf_counter()
+            for _ in range(measure_iter):
+                _ = self.forward_naive(x, type_vec_ptr)
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            end = time.perf_counter()
+            time_naive = end - start
+
+        # The first entry holds the `segment_matmul` timing, the second the
+        # naive timing; if `segment_matmul` is faster, use it:
+        self._timing_cache[num_rows] = (time_segmm, time_naive)
+        use_segment_matmul = time_segmm < time_naive
+
+        return use_segment_matmul

     def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
         r"""Forward pass.

         Args:
             x (torch.Tensor): The input features.
             type_vec (torch.Tensor): A vector that maps each entry to a type.
         """
-        use_segment_matmul = torch_geometric.backend.use_segment_matmul
-        # If `use_segment_matmul` is not specified, use a simple heuristic to
-        # determine whether `segment_matmul` can speed up computation given the
-        # observed input sizes:
-        if use_segment_matmul is None:
-            if self._use_segment_matmul_heuristic_output is None:
-                segment_count = scatter(torch.ones_like(type_vec), type_vec,
-                                        dim_size=self.num_types, reduce='sum')
-
-                self._use_segment_matmul_heuristic_output = (
-                    torch_geometric.backend.use_segment_matmul_heuristic(
-                        num_segments=self.num_types,
-                        max_segment_size=int(segment_count.max()),
-                        in_channels=self.weight.size(1),
-                        out_channels=self.weight.size(2),
-                    ))
-
-            assert self._use_segment_matmul_heuristic_output is not None
-            use_segment_matmul = self._use_segment_matmul_heuristic_output
-
-        if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM
-                and not is_compiling()):
-            assert self.weight is not None
-
-            perm: Optional[Tensor] = None
-            if not self.is_sorted:
-                if (type_vec[1:] < type_vec[:-1]).any():
-                    type_vec, perm = index_sort(type_vec, self.num_types)
-                    x = x[perm]
-
-            type_vec_ptr = index2ptr(type_vec, self.num_types)
-            out = pyg_lib.ops.segment_matmul(x, type_vec_ptr, self.weight)
-            if self.bias is not None:
-                out += self.bias[type_vec]
-
-            if perm is not None:  # Restore original order (if necessary).
-                out_unsorted = torch.empty_like(out)
-                out_unsorted[perm] = out
-                out = out_unsorted
+        perm: Optional[Tensor] = None
+        if not self.is_sorted:
+            if (type_vec[1:] < type_vec[:-1]).any():
+                type_vec, perm = index_sort(type_vec, self.num_types)
+                x = x[perm]
+
+        type_vec_ptr = index2ptr(type_vec, self.num_types)
+
+        if torch_geometric.backend.use_segment_matmul is None:
+            use_segment_matmul = False
+
+            # TODO: Check cases of compiling and scripting properly.
+            if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
+                    and not torch.jit.is_scripting()):
+                # To avoid too many measurements for dynamic shapes, use the
+                # order of magnitude of the number of rows as cache key:
+                num_rows = math.floor(math.log10(x.size(0)))
+                if num_rows in self._timing_cache:
+                    # The first entry holds the `segment_matmul` timing, the
+                    # second the naive timing; if `segment_matmul` is faster,
+                    # use it:
+                    timings = self._timing_cache[num_rows]
+                    use_segment_matmul = timings[0] < timings[1]
+                else:
+                    use_segment_matmul = self._update_timing_cache(
+                        x, type_vec_ptr, num_rows)
         else:
-            out = x.new_empty(x.size(0), self.out_channels)
-            for i in range(self.num_types):
-                mask = type_vec == i
-                if mask.numel() == 0:
-                    continue
-                subset_out = F.linear(x[mask], self.weight[i].T)
-                # The data type may have changed with mixed precision:
-                out[mask] = subset_out.to(out.dtype)
-
-            if self.bias is not None:
-                out += self.bias[type_vec]
+            # TODO: Check cases of compiling and scripting properly.
+            use_segment_matmul = (torch_geometric.typing.WITH_SEGMM
                                  and torch_geometric.backend.use_segment_matmul
+                                  and not is_compiling()
+                                  and not torch.jit.is_scripting())
+
+        if use_segment_matmul:
+            out = self.forward_segmm(x, type_vec_ptr)
+        else:
+            out = self.forward_naive(x, type_vec_ptr)
+
+        if self.bias is not None:
+            out += self.bias[type_vec]

+        if perm is not None:  # Restore original order (if necessary).
+            out_unsorted = torch.empty_like(out)
+            out_unsorted[perm] = out
+            out = out_unsorted
+
         return out

     @torch.no_grad()
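Stripped of the module plumbing, the decision logic this diff adds to `forward()` reduces to the pattern below: a minimal, self-contained sketch (names such as `pick_faster` and `bucket` are stand-ins for illustration, not PyG API):

```python
import math
import time
from typing import Callable, Dict, Tuple

def bucket(num_rows: int) -> int:
    # The cache key is the order of magnitude of the row count, so e.g.
    # 950 rows -> 2, 1_000 rows -> 3, 25_000 rows -> 4:
    return math.floor(math.log10(num_rows))

_cache: Dict[int, Tuple[float, float]] = {}

def pick_faster(fn_a: Callable[[], None], fn_b: Callable[[], None],
                num_rows: int, measure_iter: int = 3) -> bool:
    """Returns `True` if `fn_a` is faster; timings are cached per bucket."""
    key = bucket(num_rows)
    if key not in _cache:
        times = []
        for fn in (fn_a, fn_b):
            start = time.perf_counter()
            for _ in range(measure_iter):
                fn()
            times.append(time.perf_counter() - start)
        _cache[key] = (times[0], times[1])
    time_a, time_b = _cache[key]
    return time_a < time_b
```

All inputs whose row counts share an order of magnitude therefore share one benchmark, which keeps the number of measurements bounded even under dynamic shapes.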