
Commit 970ac2d

blaine-rister authored and pytorchmergebot committed
[Inductor] Improve memory locality by iterating over y dimension before x (pytorch#149339)
# Feature

Fixes pytorch#148718 by reordering the tensor dims to `(z, y, x)`.

As a bonus refactor, block pointers no longer needed the `reorder=True` argument to `self.active_range_trees()`. Since this argument is no longer used anywhere, this PR simply deletes it as opposed to updating the logic for the new iteration order.

# Perf impact

It looks like there's a decent perf bump on A100, with cudagraphs enabled. Granted, perf runs seem to have some noise between commits. ([Workflow run](https://github.com/pytorch/pytorch/actions/runs/13914815576).)

Training (all neutral or positive):
![image](https://github.com/user-attachments/assets/57f1ef1d-60b4-446f-baf3-aca87a26b81b)

Inference (one positive, one very small negative):
![image](https://github.com/user-attachments/assets/679aa057-af23-47f1-8d8e-8520daf1bd92)

As reported in pytorch#148718, this PR makes consecutive threads access consecutive memory addresses. This should theoretically give the GPU more opportunities to coalesce loads and stores. From Nvidia's [kernel profiling guide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html):

> Local memory is private storage for an executing thread and is not visible outside of that thread. It is intended for thread-local data like thread stacks and register spills. Local memory addresses are translated to global virtual addresses by the AGU unit. Local memory has the same latency as global memory. One difference between global and local memory is that local memory is arranged such that consecutive 32-bit words are accessed by consecutive thread IDs. Accesses are therefore fully coalesced as long as all threads in a warp access the same relative address (e.g., same index in an array variable, same member in a structure variable, etc.).

I couldn't find any information on how coalescing works for other kinds of memory, but the guide mentions it is also supported for accesses to the L2 cache.

> The L2 Request Coalescer (LRC) processes incoming requests for L2 and tries to coalesce read requests before forwarding them to the L2 cache. It also serves programmatic multicast requests from the SM and supports compression for writes.

The [answer to this Stack Overflow post](https://stackoverflow.com/a/5044424) also explains coalescing in a straightforward way. Inductor's current iteration order corresponds to the first (uncoalesced) example in that answer, while the order after this PR corresponds to the second (coalesced) example.

Besides GPUs, this order of accessing data is highly advantageous for systems relying on DMAs, as those are designed to access contiguous spans of memory. This change improves the performance of an elementwise add kernel on an internal model, using internal hardware, by 1.76x. I will share the details with reviewers who are Meta employees via a private channel.

# Test plan

- Updated expected code on CI tests.
- Added a new test checking the {x, y, z} indices and block pointers on a 3D pointwise kernel.

Pull Request resolved: pytorch#149339
Approved by: https://github.com/jansel
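To make the coalescing argument concrete, here is a minimal, hypothetical sketch in plain Python (not Inductor's generated Triton code): it compares which flat memory offsets eight consecutive threads touch on a contiguous 8x8 tensor, depending on whether y or x varies fastest across thread IDs.

```python
# Hypothetical illustration only -- not Inductor codegen. For a contiguous
# 8x8 tensor the strides are (8, 1): stepping in y jumps 8 elements,
# stepping in x jumps 1 element.
stride_y, stride_x = 8, 1
threads = range(8)

# Old iteration order (x before y): adjacent thread IDs step along y,
# so neighboring threads are 8 elements apart -> poor coalescing.
old_offsets = [t * stride_y for t in threads]
print(old_offsets)  # [0, 8, 16, 24, 32, 40, 48, 56]

# New iteration order (z, y, x): adjacent thread IDs step along x,
# so neighboring threads touch adjacent elements -> loads/stores coalesce.
new_offsets = [t * stride_x for t in threads]
print(new_offsets)  # [0, 1, 2, 3, 4, 5, 6, 7]
```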
1 parent 3647711 · commit 970ac2d

File tree

3 files changed (+65, -20 lines)


test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 59 additions & 8 deletions
@@ -116,6 +116,9 @@ def _assert_reduction_ndims(self, code, num_dims: int) -> None:
         for unexpected_block in reduction_blocks[num_dims:]:
             self.assertNotIn(unexpected_block, code)

+    def _get_lines_containing_substr(self, code: str, substr: str) -> str:
+        return "\n".join(line for line in code.split("\n") if substr in line)
+

 @instantiate_parametrized_tests
 class CommonTemplate:
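For reference, a standalone re-creation of the helper added above (hypothetical usage; in the test suite it is a method on the test class), showing how it filters generated Triton code down to the lines of interest.

```python
# Standalone re-creation of the new test helper, for illustration only;
# the real helper is a method on the test class.
def get_lines_containing_substr(code: str, substr: str) -> str:
    return "\n".join(line for line in code.split("\n") if substr in line)

# Hypothetical snippet of generated Triton code.
triton_code = "xoffset = 0\ntmp0 = tl.load(block_ptr0)\ntl.store(block_ptr1, tmp0)"
print(get_lines_containing_substr(triton_code, "tl.load"))  # -> tmp0 = tl.load(block_ptr0)
```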
@@ -348,29 +351,29 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool):
         # Check the code for broadcasts.
         # We shouldn't see any strides of 0.
         load_lines, store_lines = tuple(
-            [line for line in triton_code.split("\n") if substr in line]
+            self._get_lines_containing_substr(triton_code, substr)
             for substr in ("tl.load", "tl.store")
         )
         if prefer_nd_tiling:
             self.assertExpectedInline(
-                "\n".join(load_lines),
+                load_lines,
                 """\
-tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
-tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]""", # noqa: B950
+tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1])
+tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[:, None]""", # noqa: B950
             )
             self.assertExpectedInline(
-                "\n".join(store_lines),
-                """ tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tl.broadcast_to(tmp2, [XBLOCK, YBLOCK]).to(tl.float32), boundary_check=[0, 1])""", # noqa: B950
+                store_lines,
+                """ tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp2, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])""", # noqa: B950
             )
         else:
             self.assertExpectedInline(
-                "\n".join(load_lines),
+                load_lines,
                 """\
 tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
 tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
             )
             self.assertExpectedInline(
-                "\n".join(store_lines),
+                store_lines,
                 """ tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])""", # noqa: B950
             )

@@ -952,6 +955,54 @@ def fn(a):
             rtol=0.06,
         )

+    def test_pointwise_index_order(self):
+        """
+        Test the order of indices in pointwise kernels. Expect Z to be the leading dim,
+        then Y, then X.
+        """
+
+        inps = [
+            self._discontiguous_tensor((5, 5, 5), device=self.device) for _ in range(2)
+        ]
+
+        result, (triton_code,) = run_and_compare(
+            self,
+            torch.add,
+            *inps,
+            expected_num_triton_kernels=1,
+            expected_num_block_pointers=3,
+            config_patches={
+                "triton.max_tiles": 3,
+                "triton.prefer_nd_tiling": True,
+            },
+        )
+
+        # Check the load and store for block pointer strides.
+        load_lines, store_lines, index_lines = tuple(
+            self._get_lines_containing_substr(triton_code, substr)
+            for substr in ("tl.load", "tl.store", "index =")
+        )
+        self.assertExpectedInline(
+            load_lines,
+            """\
+tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])
+tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""", # noqa: B950
+        )
+
+        self.assertExpectedInline(
+            store_lines,
+            """ tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])""", # noqa: B950
+        )
+
+        # Check the indices. These are used for non-block pointers.
+        self.assertExpectedInline(
+            index_lines,
+            """\
+zindex = zoffset + tl.arange(0, ZBLOCK)[:, None, None]
+yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None]
+xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""", # noqa: B950
+        )
+

     @unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
     @config.patch(cpu_backend="triton")
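As a companion to the expected index lines in the new test, here is a NumPy analogue (an illustration under assumed block sizes, not Triton code) of how the zindex/yindex/xindex broadcasting patterns combine into flat offsets in which x is the fastest-varying dimension.

```python
# NumPy analogue of the expected zindex/yindex/xindex lines: broadcasting
# these three 1-D ranges builds a (ZBLOCK, YBLOCK, XBLOCK) grid where the
# x index varies fastest.
import numpy as np

ZBLOCK = YBLOCK = XBLOCK = 2  # assumed small block sizes for the demo
zindex = np.arange(ZBLOCK)[:, None, None]
yindex = np.arange(YBLOCK)[None, :, None]
xindex = np.arange(XBLOCK)[None, None, :]

# Flat offsets for a contiguous (5, 5, 5) output with strides (25, 5, 1),
# matching the store in the expected code above.
offsets = zindex * 25 + yindex * 5 + xindex * 1
print(offsets.ravel())  # [ 0  1  5  6 25 26 30 31] -- consecutive along x
```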

torch/_inductor/codegen/simd.py

Lines changed: 5 additions & 11 deletions
@@ -424,13 +424,14 @@ def filtered_index_map(seq, mask) -> dict[Any, int]:
         }

         grid_dims = ["x", "y", "z"]
+        pointwise_tensor_dims = list(reversed(grid_dims))
         reduction_dims = ["r0_", "r1_"]
         if no_x_dim:
             tensor_dims = reduction_dims
         elif no_r_dim:
-            tensor_dims = grid_dims
+            tensor_dims = pointwise_tensor_dims
         else:
-            tensor_dims = grid_dims + reduction_dims
+            tensor_dims = pointwise_tensor_dims + reduction_dims

         # Filter out unused tensor dims.
         # Convert to dicts for O(1) index lookup.
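A tiny sketch of what the new lines above compute: the grid dims stay listed as (x, y, z), while the tensor dims used for codegen are the reverse of that list, so x ends up as the last (fastest-varying) tensor dim.

```python
# Evaluating the new dim ordering in isolation (illustration only).
grid_dims = ["x", "y", "z"]
pointwise_tensor_dims = list(reversed(grid_dims))
reduction_dims = ["r0_", "r1_"]

print(pointwise_tensor_dims)                   # ['z', 'y', 'x']
print(pointwise_tensor_dims + reduction_dims)  # ['z', 'y', 'x', 'r0_', 'r1_']
```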
@@ -814,17 +815,10 @@ def prepare_indexing(

         return self.codegen_indexing(simp_index)

-    def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]:
-        trees = [
+    def active_range_trees(self) -> list[IterationRangesRoot]:
+        return [
             t for t in self.range_trees if not t.is_reduction or self.inside_reduction
         ]
-        if reorder and len(trees) > 1:
-            count = sum(t.prefix in "xyz" for t in trees)
-            assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [
-                t.prefix for t in trees[:count]
-            ]
-            trees[:count] = reversed(trees[:count])
-        return trees

     def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
         expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())

torch/_inductor/codegen/triton.py

Lines changed: 1 addition & 1 deletion
@@ -1966,7 +1966,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]:
             index_relative_to_xyr_index = sympy_subs(
                 index, {v: t.expr for v, t in self.range_tree_nodes.items()}
             )
-            range_trees = self.active_range_trees(reorder=True)
+            range_trees = self.active_range_trees()

             # Partition the index into subexpressions pertaining to each range tree.
             # For example xindex * 5 + r0_index * 3 is partitioned to
