Skip to content

Commit eb49ef9

Browse files
[Layout] Layout API: Shard (apache#35)
* dev * dev * dev
1 parent 355518c commit eb49ef9

File tree

6 files changed

+346
-54
lines changed

6 files changed

+346
-54
lines changed

include/tvm/tir/exec_scope.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ bool Higher(const ExecScope& lhs, const ExecScope& rhs);
270270

271271
bool Higher(const String& lhs, const String& rhs);
272272

273+
bool Equal(const ExecScope& lhs, const ExecScope& rhs);
274+
275+
bool Equal(const String& lhs, const String& rhs);
276+
273277
bool ValideScope(const ExecScope& scope);
274278

275279
bool ValideScope(const String& scope);

include/tvm/tir/layout.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,18 @@ class TileLayout : public TLayout {
350350
/*! \brief Construct a new layout by tiling the ouer layout over the inner layout */
351351
TileLayout Tile(TileLayout outer, TileLayout inner);
352352

353+
/*!
354+
* \brief Construct a new layout to express the sharding strategy of a tensor.
355+
* \param shape The shape of the tensor.
356+
* \param mesh The device mesh
357+
* \param strategy The sharding strategy of the tensor.
358+
* \param inner The layout of the sharded partition of the tensor.
359+
* \param from The source scope of the layout.
360+
* \param to The target scope of the layout.
361+
*/
362+
TileLayout Shard(Array<PrimExpr> shape, IterTree mesh, String strategy, TileLayout inner,
363+
ExecScope from, ExecScope to);
364+
353365
/*! \brief Layout normalization
354366
1. Deduplicate the split nodes in the tree, such that no two split nodes share the same child
355367
node.

python/tvm/tir/layout.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,9 @@ def _construct_data_iter_tree(
176176

177177
@staticmethod
178178
def from_nested_tuple(
179-
data: Tuple,
180-
strides: Tuple,
181-
device: Optional[Tuple] = None,
179+
data: Union[Tuple, int],
180+
strides: Union[Tuple, int],
181+
device: Optional[Union[Tuple, int]] = None,
182182
exclusive: Optional[Tuple] = None,
183183
from_to: Optional[Tuple[str]] = None,
184184
) -> "TileLayout":
@@ -189,6 +189,11 @@ def inc_leaf_cnt():
189189
leaf_cnt += 1
190190
return leaf_cnt - 1
191191

192+
if not isinstance(data, tuple):
193+
data = (data,)
194+
if not isinstance(strides, tuple):
195+
strides = (strides,)
196+
192197
if device is None:
193198
assert exclusive is None, "exclusive must be None if device is None"
194199
assert from_to is None, "from_to must be None if device is None"
@@ -199,6 +204,9 @@ def inc_leaf_cnt():
199204
return TileLayout(data_tree=data_tree)
200205

201206
else:
207+
if not isinstance(device, tuple):
208+
device = (device,)
209+
202210
assert from_to is not None, "from_to must be provided if device is provided"
203211
assert isinstance(from_to, tuple) and len(from_to) == 2, "from_to must be a tuple of 2"
204212

@@ -227,6 +235,28 @@ def inc_leaf_cnt():
227235
def tile(outer: "TileLayout", inner: "TileLayout") -> "TileLayout":
228236
return get_global_func("tir.TileLayoutTile")(outer, inner)
229237

238+
@staticmethod
239+
def shard(
240+
shape: Tuple[PrimExpr, int],
241+
mesh: Tuple,
242+
strategy: str,
243+
inner: "TileLayout",
244+
from_to: Optional[Tuple[str]] = None,
245+
) -> "TileLayout":
246+
assert from_to is not None, "from_to must be provided if device is provided"
247+
assert isinstance(from_to, tuple) and len(from_to) == 2, "from_to must be a tuple of 2"
248+
249+
f = get_global_func("tir.IterTreeFromTuple")
250+
iter_tree, _ = f(convert_to_object(mesh))
251+
return get_global_func("tir.TileLayoutShard")(
252+
shape,
253+
iter_tree,
254+
strategy,
255+
inner,
256+
ExecScope.create(from_to[0]) if from_to else None,
257+
ExecScope.create(from_to[1]) if from_to else None,
258+
)
259+
230260
@staticmethod
231261
def normalize(layout: "TileLayout") -> "TileLayout":
232262
return get_global_func("tir.NormalizeTileLayout")(layout)

src/tir/ir/exec_scope.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ ExecScope ExecScope::Create(String name) {
112112

113113
TVM_REGISTER_NODE_TYPE(ExecScopeNode);
114114

115-
TVM_REGISTER_GLOBAL("tir.ExecScope").set_body_typed([](String name) {
116-
return ExecScope(name);
117-
});
115+
TVM_REGISTER_GLOBAL("tir.ExecScope").set_body_typed([](String name) { return ExecScope(name); });
118116

119117
TVM_REGISTER_GLOBAL("tir.ExecScopeCreate").set_body_typed([](String name) {
120118
return ExecScope::Create(name);
@@ -178,6 +176,10 @@ bool Higher(const String& lhs, const String& rhs) {
178176
return ScopeOrder.at(lhs) < ScopeOrder.at(rhs);
179177
}
180178

179+
bool Equal(const ExecScope& lhs, const ExecScope& rhs) { return lhs->name == rhs->name; }
180+
181+
bool Equal(const String& lhs, const String& rhs) { return lhs == rhs; }
182+
181183
bool ValideScope(const ExecScope& scope) { return ValideScope(scope->name); }
182184

183185
bool ValideScope(const String& scope) { return ScopeOrder.count(scope) > 0; }

0 commit comments

Comments
 (0)