Skip to content

Commit 0adcfc2

Browse files
Move sparsecore passes under transforms/sparsecore.
PiperOrigin-RevId: 620072718
1 parent a9db13b commit 0adcfc2

File tree

15 files changed

+277
-84
lines changed

15 files changed

+277
-84
lines changed

tensorflow/compiler/mlir/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ cc_library(
5959
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", # buildcleaner:keep
6060
"//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops",
6161
"//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:runtime_passes",
62+
"//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes",
6263
"//tensorflow/compiler/mlir/tf2xla:compile_mlir_util",
6364
"//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes",
6465
"//tensorflow/compiler/mlir/tf2xla/internal/passes:mlir_to_graph_passes",
@@ -69,6 +70,7 @@ cc_library(
6970
"//tensorflow/compiler/mlir/tosa:tfl_passes",
7071
"@llvm-project//mlir:AllPassesAndDialects",
7172
"@llvm-project//mlir:MlirOptLib",
73+
"@llvm-project//mlir:Support",
7274
"@llvm-project//mlir:Transforms",
7375
"@local_xla//xla/mlir/framework/ir:xla_framework",
7476
"@local_xla//xla/mlir/framework/transforms:passes",

tensorflow/compiler/mlir/tensorflow/transforms/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,6 @@ cc_library(
460460
"device_index_selector.cc",
461461
"drop_while_shape_invariant.cc",
462462
"einsum.cc",
463-
"embedding_pipelining.cc",
464-
"embedding_program_key.cc",
465-
"embedding_sequencing.cc",
466463
"executor_island_coarsening.cc",
467464
"executor_tpuv1_inline_tpu_island.cc",
468465
"executor_tpuv1_island_coarsening.cc",

tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ cc_library(
3131
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
3232
"//tensorflow/compiler/mlir/tensorflow:error_util",
3333
"//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
34+
"//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes",
3435
"//tensorflow/core:framework",
3536
"//tensorflow/core:lib_proto_parsing",
3637
"//tensorflow/core/platform:error_payloads",

tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/compiler/jit/flags.h"
2929
#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h"
3030
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31+
#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h"
3132
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3233
#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h"
3334
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"

tensorflow/compiler/mlir/tensorflow/transforms/passes.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -446,13 +446,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> CreateReplicateToIslandPass(
446446
std::unique_ptr<OperationPass<func::FuncOp>>
447447
CreateReplicaIDToDeviceOrdinalPass();
448448

449-
// Creates a pass that adds pipelining to a graph that contains device
450-
// accelerated embeddings. The EmbeddingSequencingPass is a temporary fallback
451-
// while developing full pipelining capabilities.
452-
std::unique_ptr<OperationPass<ModuleOp>> CreateEmbeddingSequencingPass();
453-
std::unique_ptr<OperationPass<ModuleOp>> CreateEmbeddingPipeliningPass();
454-
std::unique_ptr<OperationPass<func::FuncOp>> CreateEmbeddingProgramKeyPass();
455-
456449
// Creates a pass that creates `tf_executor.island` from a single
457450
// `tf_device.parallel_execute` island.
458451
std::unique_ptr<OperationPass<func::FuncOp>> CreateParallelExecuteToIslandsPass(
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
2+
load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
3+
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
4+
5+
package(
6+
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
7+
default_visibility = [
8+
"//tensorflow/compiler/mlir:__pkg__",
9+
"//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__",
10+
"//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:__pkg__",
11+
"//tensorflow/compiler/mlir/tf2xla/api:__subpackages__",
12+
"//tensorflow/compiler/mlir/tf2xla/internal:__pkg__",
13+
],
14+
licenses = ["notice"],
15+
)
16+
17+
gentbl_cc_library(
18+
name = "sparsecore_passes_inc_gen",
19+
compatible_with = get_compatible_with_portable(),
20+
tbl_outs = [
21+
(
22+
[
23+
"-gen-pass-decls",
24+
"-name=SparseCore",
25+
],
26+
"sparsecore_passes.h.inc",
27+
),
28+
],
29+
tblgen = "@llvm-project//mlir:mlir-tblgen",
30+
td_file = "sparsecore_passes.td",
31+
deps = [
32+
"@llvm-project//mlir:PassBaseTdFiles",
33+
],
34+
)
35+
36+
cc_library(
37+
name = "sparsecore_passes",
38+
hdrs = [
39+
"sparsecore_passes.h",
40+
],
41+
textual_hdrs = [
42+
"sparsecore_passes.h.inc",
43+
],
44+
deps = [
45+
":embedding_pipelining",
46+
":embedding_program_key",
47+
":embedding_sequencing",
48+
":sparsecore_passes_inc_gen",
49+
"@llvm-project//llvm:Support",
50+
"@llvm-project//mlir:FuncDialect",
51+
"@llvm-project//mlir:IR",
52+
"@llvm-project//mlir:Pass",
53+
],
54+
)
55+
56+
cc_library(
57+
name = "embedding_pipelining",
58+
srcs = ["embedding_pipelining.cc"],
59+
hdrs = [
60+
"sparsecore_passes.h",
61+
],
62+
deps = [
63+
":sparsecore_passes_inc_gen",
64+
"//tensorflow/compiler/jit:flags_headers",
65+
"//tensorflow/compiler/mlir/tensorflow",
66+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
67+
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
68+
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
69+
"@com_google_absl//absl/log",
70+
"@com_google_absl//absl/strings",
71+
"@llvm-project//llvm:Support",
72+
"@llvm-project//mlir:FuncDialect",
73+
"@llvm-project//mlir:IR",
74+
"@llvm-project//mlir:InliningUtils",
75+
"@llvm-project//mlir:Pass",
76+
"@llvm-project//mlir:Support",
77+
"@llvm-project//mlir:TransformUtils",
78+
],
79+
)
80+
81+
cc_library(
82+
name = "embedding_sequencing",
83+
srcs = ["embedding_sequencing.cc"],
84+
hdrs = [
85+
"sparsecore_passes.h",
86+
],
87+
deps = [
88+
":sparsecore_passes_inc_gen",
89+
"//tensorflow/compiler/jit:flags_headers",
90+
"//tensorflow/compiler/mlir/tensorflow",
91+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
92+
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
93+
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
94+
"@com_google_absl//absl/log",
95+
"@com_google_absl//absl/strings",
96+
"@llvm-project//llvm:Support",
97+
"@llvm-project//mlir:FuncDialect",
98+
"@llvm-project//mlir:IR",
99+
"@llvm-project//mlir:InliningUtils",
100+
"@llvm-project//mlir:Pass",
101+
"@llvm-project//mlir:Support",
102+
"@llvm-project//mlir:TransformUtils",
103+
],
104+
)
105+
106+
cc_library(
107+
name = "embedding_program_key",
108+
srcs = ["embedding_program_key.cc"],
109+
hdrs = [
110+
"sparsecore_passes.h",
111+
],
112+
deps = [
113+
":sparsecore_passes_inc_gen",
114+
"//tensorflow/compiler/mlir/tensorflow",
115+
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
116+
"@llvm-project//llvm:Support",
117+
"@llvm-project//mlir:FuncDialect",
118+
"@llvm-project//mlir:IR",
119+
"@llvm-project//mlir:Pass",
120+
"@llvm-project//mlir:Support",
121+
"@local_xla//xla/mlir_hlo",
122+
],
123+
)

tensorflow/compiler/mlir/tensorflow/transforms/embedding_pipelining.cc renamed to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ return selected_results
157157
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
158158

159159
#define GEN_PASS_DEF_EMBEDDINGPIPELININGPASS
160-
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
160+
#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc"
161161

162162
static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining";
163163
static constexpr char kEmbeddingPipeliningInlineAttr[] =
@@ -1289,7 +1289,7 @@ LogicalResult StartStep0(OpBuilder& builder, Location& loc,
12891289
func::FuncOp orig_parent_func =
12901290
callers.backward->getParentOfType<func::FuncOp>();
12911291

1292-
std::vector<Value> operands = loop_operands_nm0;
1292+
const std::vector<Value>& operands = loop_operands_nm0;
12931293

12941294
// Input types will be the same as the original loop body.
12951295
std::vector<Type> input_types = GetValueTypes(operands);
@@ -1373,7 +1373,7 @@ LogicalResult StartStep1(OpBuilder& builder, Location& loc,
13731373
func::FuncOp orig_parent_func =
13741374
callers.backward->getParentOfType<func::FuncOp>();
13751375

1376-
std::vector<Value> operands = loop_operands_1;
1376+
const std::vector<Value>& operands = loop_operands_1;
13771377

13781378
// Input types will be the same as the original loop body.
13791379
std::vector<Type> input_types = GetValueTypes(operands);

tensorflow/compiler/mlir/tensorflow/transforms/embedding_program_key.cc renamed to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ limitations under the License.
3131
#include "mlir/Support/LogicalResult.h" // from @llvm-project
3232
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
3333
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34-
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
3534
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
3635

3736
namespace mlir {
@@ -42,7 +41,7 @@ constexpr char kMiniBatchSplitsAttr[] = "mini_batch_splits";
4241
constexpr char kMiniBatchCsrAttr[] = "mini_batch_in_csr";
4342

4443
#define GEN_PASS_DEF_EMBEDDINGPROGRAMKEYPASS
45-
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
44+
#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc"
4645

4746
struct EmbeddingProgramKeyPass
4847
: public impl::EmbeddingProgramKeyPassBase<EmbeddingProgramKeyPass> {

tensorflow/compiler/mlir/tensorflow/transforms/embedding_sequencing.cc renamed to tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ limitations under the License.
3232
#include <utility>
3333
#include <vector>
3434

35+
#include "absl/log/log.h"
36+
#include "absl/strings/str_cat.h"
3537
#include "llvm/ADT/STLExtras.h"
3638
#include "llvm/ADT/SetVector.h"
3739
#include "llvm/Support/Casting.h"
@@ -40,24 +42,28 @@ limitations under the License.
4042
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
4143
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
4244
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
45+
#include "mlir/IR/DialectRegistry.h" // from @llvm-project
4346
#include "mlir/IR/Location.h" // from @llvm-project
4447
#include "mlir/IR/Operation.h" // from @llvm-project
4548
#include "mlir/IR/Region.h" // from @llvm-project
4649
#include "mlir/IR/SymbolTable.h" // from @llvm-project
4750
#include "mlir/IR/Types.h" // from @llvm-project
4851
#include "mlir/IR/Value.h" // from @llvm-project
4952
#include "mlir/IR/Visitors.h" // from @llvm-project
53+
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
5054
#include "mlir/Pass/Pass.h" // from @llvm-project
5155
#include "mlir/Support/LLVM.h" // from @llvm-project
5256
#include "mlir/Support/LogicalResult.h" // from @llvm-project
5357
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
5458
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
5559
#include "tensorflow/compiler/jit/flags.h"
60+
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
5661
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
62+
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
5763
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
5864

5965
#define GEN_PASS_DEF_EMBEDDINGSEQUENCINGPASS
60-
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
66+
#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc"
6167

6268
static constexpr char kEmbeddingPipelining[] = "_embedding_pipelining";
6369
static constexpr char kEmbeddingForward[] = "forward";
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_
17+
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_
18+
19+
#include <memory>
20+
21+
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22+
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
23+
#include "mlir/Pass/Pass.h" // from @llvm-project
24+
25+
namespace mlir {
26+
namespace TFDevice {
27+
28+
// For architectures that support accelerated embedding lookups, this pass will
29+
// rewrite the graph to use pipelining for better device utilization.
30+
std::unique_ptr<OperationPass<ModuleOp>> CreateEmbeddingSequencingPass();
31+
32+
// This is a strictly sequential and formally correct fallback option for the
33+
// embedding pipelining pass intended for debugging during pipelining
34+
// development.
35+
std::unique_ptr<OperationPass<ModuleOp>> CreateEmbeddingPipeliningPass();
36+
37+
// Passes in the program key to embedding ops, by moving the embedding ops
38+
// after the _TPUCompileMlir op.
39+
std::unique_ptr<OperationPass<func::FuncOp>> CreateEmbeddingProgramKeyPass();
40+
41+
#define GEN_PASS_REGISTRATION
42+
#define GEN_PASS_DECL_EMBEDDINGSEQUENCINGPASS
43+
#define GEN_PASS_DECL_EMBEDDINGPIPELININGPASS
44+
#define GEN_PASS_DECL_EMBEDDINGPROGRAMKEYPASS
45+
#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc"
46+
47+
} // namespace TFDevice
48+
} // namespace mlir
49+
50+
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_

0 commit comments

Comments
 (0)