Skip to content

Commit d952cc2

Browse files
jcai19tensorflower-gardener
authored andcommitted
Add a field to StreamZ metric /tensorflow/core/tf_mlir_bridge_first_phase_count
Add a filed for the type of TF2XLA Phase 1 Bridge , i.e. Replicated Bridge vs. Non-replicated Bridge. PiperOrigin-RevId: 620060074
1 parent 97fd9c1 commit d952cc2

File tree

14 files changed

+127
-56
lines changed

14 files changed

+127
-56
lines changed

tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ cc_library(
2626
deps = [
2727
":runtime_passes",
2828
"//tensorflow/compiler/jit:flags_headers",
29+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
2930
"//tensorflow/compiler/mlir/tensorflow:bridge_logger",
3031
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
3132
"//tensorflow/compiler/mlir/tensorflow:error_util",
32-
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
3333
"//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
3434
"//tensorflow/core:framework",
3535
"//tensorflow/core:lib_proto_parsing",
@@ -62,6 +62,7 @@ tf_cc_test(
6262
":lower_cluster_to_runtime_ops",
6363
"//tensorflow/compiler/mlir:register_common_dialects",
6464
"//tensorflow/compiler/mlir/tensorflow",
65+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
6566
"//tensorflow/compiler/tf2xla:xla_op_registry",
6667
"//tensorflow/core:framework",
6768
"//tensorflow/core:lib",

tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/compiler/jit/flags.h"
2929
#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h"
3030
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31+
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3132
#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h"
3233
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
3334
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
@@ -121,6 +122,7 @@ void CreateNonTPULowerClusterToRuntimeOpsPassPipeline(
121122
// TODO(b/306728216): Move this out of the Bridge component and into a Host
122123
// runtime component.
123124
tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
125+
std::string bridge_type,
124126
tsl::DeviceType device_type,
125127
absl::Status status) {
126128
if (status.ok()) {
@@ -129,11 +131,12 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
129131

130132
VLOG(2) << error_prefix << " " << status;
131133
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
132-
device_type.type_string(), /*bridge_version=*/"v2",
134+
bridge_type,
135+
/*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV2,
136+
device_type.type_string(),
133137
/*fallback_enabled=*/false,
134138
/*result=*/"failure");
135139

136-
constexpr char kBridgeComponent[] = "TFXLABridge";
137140
std::string bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_TPU_BRIDGE";
138141

139142
tsl::OkOrSetErrorCounterPayload(
@@ -144,7 +147,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
144147
bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE";
145148
}
146149

147-
tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent,
150+
tsl::error_logging::Log(mlir::TF::kBridgeComponent, bridge_subcomponent,
148151
status.ToString())
149152
.IgnoreError();
150153

@@ -194,10 +197,13 @@ absl::Status RunLowerClusterToRuntimeOpsPassPipeline(
194197
module, llvm::StringRef(), &runtime_lowering);
195198
}
196199

200+
std::string bridge_type = xla_device_type == DeviceType(DEVICE_TPU_XLA_JIT)
201+
? mlir::TF::kMlirPh1BridgeCounterReplicated
202+
: mlir::TF::kMlirPh1BridgeCounterNonReplicated;
197203
auto result_status = diag_handler.ConsumeStatus();
198204
TF_RETURN_IF_ERROR(
199205
RecordIfErrorStatus(/*error_prefix=*/"lower_cluster_to_runtime",
200-
xla_device_type, result_status));
206+
bridge_type, xla_device_type, result_status));
201207

202208
return absl::OkStatus();
203209
}

tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ limitations under the License.
3333
#include "mlir/Pass/PassManager.h" // from @llvm-project
3434
#include "tensorflow/compiler/mlir/register_common_dialects.h"
3535
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36+
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3637
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
3738
#include "tensorflow/core/lib/monitoring/cell_reader.h"
3839
#include "tensorflow/core/platform/env.h"
@@ -167,9 +168,11 @@ TEST_F(LowerClusterToRuntimeOpsTest, ErrorsWithBadCluster) {
167168
*mlir_module_, DeviceType(DEVICE_TPU_XLA_JIT))
168169
.ok());
169170

170-
EXPECT_EQ(compilation_status.Delta("XLA_TPU_JIT", "v2", "fallback_disabled",
171-
"failure"),
172-
1);
171+
EXPECT_EQ(
172+
compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated,
173+
mlir::TF::kMlirPh1BridgeCounterV2, "XLA_TPU_JIT",
174+
"fallback_disabled", "failure"),
175+
1);
173176
}
174177

175178
TEST_F(LowerClusterToRuntimeOpsTest, DumpsPipelinePasses) {

tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,18 @@ inline constexpr llvm::StringRef kDynamicArgIndexAttr = "_dynamic_arg_index";
121121
inline constexpr llvm::StringRef kParallelExecAnnotation =
122122
"_parallel_execution_ids";
123123

124+
// Logging
125+
126+
// Name of component for error logging. This name is fixed and required to
127+
// enable logging.
128+
inline const char kBridgeComponent[] = "TFXLABridge";
129+
inline const char kMlirPh1BridgeCounterReplicated[] = "replicated";
130+
inline const char kMlirPh1BridgeCounterNonReplicated[] = "nonreplicated";
131+
inline const char kMlirPh1BridgeCounterV1[] = "v1";
132+
inline const char kMlirPh1BridgeCounterV2[] = "v2";
133+
inline const char kMlirPh1BridgeCounterTpu[] = "tpu";
134+
inline const char kMlirPh1BridgeCounterNonTpu[] = "cpu/gpu";
135+
124136
// Copies attributes that satisfy the given predicate from `from` to `to`.
125137
template <typename Predicate>
126138
void CopyAttributes(Operation *from, Operation *to, Predicate P) {

tensorflow/compiler/mlir/tf2xla/api/v1/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ cc_library(
194194
],
195195
deps = [
196196
":tf_dialect_to_executor",
197+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
197198
"//tensorflow/compiler/mlir/tensorflow:bridge_logger",
198199
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
199200
"//tensorflow/compiler/mlir/tensorflow:error_util",
@@ -232,6 +233,7 @@ tf_cc_test(
232233
deps = [
233234
":cluster_tf",
234235
"//tensorflow/compiler/mlir:register_common_dialects",
236+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
235237
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
236238
"//tensorflow/core/lib/monitoring:cell_reader",
237239
"//tensorflow/core/platform:resource_loader",
@@ -241,7 +243,6 @@ tf_cc_test(
241243
"@llvm-project//mlir:IR",
242244
"@llvm-project//mlir:Parser",
243245
"@local_tsl//tsl/lib/core:status_test_util",
244-
"@local_tsl//tsl/lib/monitoring:test_utils",
245246
"@local_tsl//tsl/platform:status",
246247
],
247248
)

tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
3232
#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h"
3333
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
34+
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3435
#include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h"
3536
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
3637
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
@@ -60,10 +61,6 @@ using mlir::func::FuncOp;
6061

6162
namespace {
6263

63-
// Name of component for error logging. This name is fixed and required to
64-
// enable logging.
65-
constexpr char kBridgeComponent[] = "TFXLABridge";
66-
6764
void CreateReplicatedBridgePipelineV1(OpPassManager &pm) {
6865
pm.addPass(mlir::tf2xla::internal::CreateInferenceMetricsPass());
6966

@@ -152,10 +149,12 @@ tensorflow::Status RecordStatusIfError(const std::string error_prefix,
152149
}
153150

154151
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
155-
/*device_type=*/"tpu", /*bridge_version=*/"v1",
152+
/*bridge_type=*/mlir::TF::kMlirPh1BridgeCounterReplicated,
153+
/*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV1,
154+
/*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu,
156155
/*fallback_enabled=*/is_in_fallback_enabled_mode,
157156
/*result=*/"failure");
158-
tsl::error_logging::Log(kBridgeComponent,
157+
tsl::error_logging::Log(mlir::TF::kBridgeComponent,
159158
"TFXLA_PHASE_ONE_MLIR_TPU_V1_COMPAT_BRIDGE",
160159
status.ToString())
161160
.IgnoreError();
@@ -221,7 +220,9 @@ tensorflow::Status RunSessionTf2xlaClusteringBridge(
221220
RunClusteringPipelineOnSubmodule(module, is_in_fallback_enabled_mode));
222221

223222
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
224-
/*device_type=*/"tpu", /*bridge_version=*/"v1",
223+
/*bridge_type=*/mlir::TF::kMlirPh1BridgeCounterReplicated,
224+
/*bridge_version=*/mlir::TF::kMlirPh1BridgeCounterV1,
225+
/*device_type*/ mlir::TF::kMlirPh1BridgeCounterTpu,
225226
/*n_fallback_enabled*/ is_in_fallback_enabled_mode,
226227
/*result=*/"success");
227228

tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
2828
#include "mlir/Parser/Parser.h" // from @llvm-project
2929
#include "tensorflow/compiler/mlir/register_common_dialects.h"
30+
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3031
#include "tensorflow/core/lib/monitoring/cell_reader.h"
3132
#include "tensorflow/core/platform/resource_loader.h"
3233
#include "tsl/lib/core/status_test_util.h"
@@ -84,8 +85,11 @@ TEST_F(SessionClusterTensorflowDialectTest, ClustersTf) {
8485
TF_EXPECT_OK(
8586
RunSessionTf2xlaClusteringBridge(*mlir_module_,
8687
/*is_in_fallback_enabled_mode=*/false));
87-
EXPECT_EQ(
88-
compilation_status.Delta("tpu", "v1", "fallback_disabled", "success"), 1);
88+
EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated,
89+
mlir::TF::kMlirPh1BridgeCounterV1,
90+
mlir::TF::kMlirPh1BridgeCounterTpu,
91+
"fallback_disabled", "success"),
92+
1);
8993
}
9094

9195
TEST_F(SessionClusterTensorflowDialectTest, FailsWithMultipleSubmodules) {
@@ -98,8 +102,11 @@ TEST_F(SessionClusterTensorflowDialectTest, FailsWithMultipleSubmodules) {
98102
/*is_in_fallback_enabled_mode=*/false)
99103
.ok());
100104

101-
EXPECT_EQ(
102-
compilation_status.Delta("tpu", "v1", "fallback_disabled", "failure"), 1);
105+
EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated,
106+
mlir::TF::kMlirPh1BridgeCounterV1,
107+
mlir::TF::kMlirPh1BridgeCounterTpu,
108+
"fallback_disabled", "failure"),
109+
1);
103110
}
104111

105112
} // namespace

tensorflow/compiler/mlir/tf2xla/api/v2/BUILD

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,11 @@ cc_library(
119119
],
120120
deps = [
121121
":device_type_proto_cc",
122-
":tf_dialect_to_executor",
122+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
123123
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
124124
"//tensorflow/compiler/mlir/tensorflow:error_util",
125125
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
126126
"//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass",
127-
"//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops",
128127
"//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes",
129128
"//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks",
130129
"//tensorflow/core:framework",
@@ -133,7 +132,6 @@ cc_library(
133132
"//tensorflow/core/platform:errors",
134133
"//tensorflow/core/platform:stacktrace",
135134
"//tensorflow/core/platform:status",
136-
"//tensorflow/core/tpu:tpu_defs",
137135
"@com_google_absl//absl/log",
138136
"@com_google_absl//absl/status",
139137
"@llvm-project//llvm:Support",
@@ -143,7 +141,6 @@ cc_library(
143141
"@llvm-project//mlir:Support",
144142
"@local_tsl//tsl/platform:error_logging",
145143
"@local_tsl//tsl/platform:errors",
146-
"@local_tsl//tsl/platform:status",
147144
],
148145
)
149146

@@ -159,6 +156,7 @@ tf_cc_test(
159156
":cluster_tf",
160157
"//tensorflow/compiler/mlir:register_common_dialects",
161158
"//tensorflow/compiler/mlir/tensorflow",
159+
"//tensorflow/compiler/mlir/tensorflow:attribute_utils",
162160
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib",
163161
"//tensorflow/core/lib/monitoring:cell_reader",
164162
"//tensorflow/core/platform:resource_loader",

tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "mlir/Support/LogicalResult.h" // from @llvm-project
2929
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
3030
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31+
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3132
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
3233
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
3334
#include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h"
@@ -52,8 +53,6 @@ using mlir::OpPassManager;
5253
using mlir::PassManager;
5354
using mlir::func::FuncOp;
5455

55-
constexpr char kBridgeComponent[] = "TFXLABridge";
56-
5756
// Run the TF XLA Bridge based on the input pipeline, which can be either TPU
5857
// bridge pipeline or non TPU bridge pipeline.
5958
tensorflow::Status RunTFXLABridge(
@@ -114,6 +113,7 @@ tensorflow::Status RunTFXLABridge(
114113

115114
tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
116115
bool fallback_enabled,
116+
std::string bridge_type,
117117
std::string device_type,
118118
absl::Status status) {
119119
if (status.ok()) {
@@ -122,7 +122,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
122122

123123
VLOG(2) << error_prefix << " " << status;
124124
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
125-
device_type, /*bridge_version=*/"v2",
125+
/*bridge_type*/ bridge_type, /*bridge_version=*/"v2", device_type,
126126
/*fallback_enabled=*/fallback_enabled,
127127
/*result=*/"failure");
128128

@@ -135,7 +135,7 @@ tensorflow::Status RecordIfErrorStatus(const std::string error_prefix,
135135
bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE";
136136
}
137137

138-
tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent,
138+
tsl::error_logging::Log(mlir::TF::kBridgeComponent, bridge_subcomponent,
139139
status.ToString())
140140
.IgnoreError();
141141

@@ -162,8 +162,9 @@ void CreateReplicatedClusteringPipelineV2(OpPassManager &pm) {
162162
tensorflow::Status RunFunctionTf2xlaClusteringBridge(
163163
ModuleOp module, bool is_supported_by_replicated_brige,
164164
bool is_in_fallback_enabled_mode, llvm::StringRef module_name) {
165-
std::string device_type_filter =
166-
is_supported_by_replicated_brige ? "tpu" : "cpu/gpu";
165+
std::string device_type = is_supported_by_replicated_brige
166+
? mlir::TF::kMlirPh1BridgeCounterTpu
167+
: mlir::TF::kMlirPh1BridgeCounterNonTpu;
167168

168169
VLOG(2)
169170
<< (is_supported_by_replicated_brige ? "Replicated" : "NonReplicated")
@@ -186,14 +187,17 @@ tensorflow::Status RunFunctionTf2xlaClusteringBridge(
186187
},
187188
module_name, /*dump_prefix=*/"tf_xla_bridge_v2_nonreplicated");
188189

190+
std::string bridge_type = is_supported_by_replicated_brige
191+
? mlir::TF::kMlirPh1BridgeCounterReplicated
192+
: mlir::TF::kMlirPh1BridgeCounterNonReplicated;
189193
// TODO(b/317798386): add is_supported_by_replicated_brige as a filter.
190194
TF_RETURN_IF_ERROR(RecordIfErrorStatus(
191195
/*error_prefix=*/"clustering_v2", is_in_fallback_enabled_mode,
192-
device_type_filter, clustering_status));
196+
bridge_type, device_type, clustering_status));
193197

194198
// TODO(b/317798386): add is_supported_by_replicated_brige as a filter.
195199
tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter(
196-
device_type_filter, /*bridge_version=*/"v2",
200+
bridge_type, /*bridge_version=*/"v2", device_type,
197201
/*fallback_enabled=*/is_in_fallback_enabled_mode,
198202
/*result=*/"success");
199203

tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "tensorflow/compiler/mlir/register_common_dialects.h"
3232
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
3333
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
34+
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
3435
#include "tensorflow/core/lib/monitoring/cell_reader.h"
3536
#include "tensorflow/core/platform/resource_loader.h"
3637
#include "tsl/lib/core/status_test_util.h"
@@ -94,8 +95,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTfReplicatedBridge) {
9495
FuncOp main = mlir_module_->lookupSymbol<mlir::func::FuncOp>("main");
9596
ASSERT_TRUE(main);
9697

97-
EXPECT_EQ(
98-
compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1);
98+
EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated,
99+
mlir::TF::kMlirPh1BridgeCounterV2,
100+
mlir::TF::kMlirPh1BridgeCounterTpu,
101+
"fallback_disabled", "success"),
102+
1);
99103
}
100104

101105
TEST_F(FunctionClusterTensorflowDialectTest,
@@ -118,8 +122,11 @@ TEST_F(FunctionClusterTensorflowDialectTest,
118122
});
119123

120124
EXPECT_TRUE(has_cluster_op);
121-
EXPECT_EQ(
122-
compilation_status.Delta("tpu", "v2", "fallback_disabled", "success"), 1);
125+
EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated,
126+
mlir::TF::kMlirPh1BridgeCounterV2,
127+
mlir::TF::kMlirPh1BridgeCounterTpu,
128+
"fallback_disabled", "success"),
129+
1);
123130
}
124131

125132
TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) {
@@ -135,7 +142,10 @@ TEST_F(FunctionClusterTensorflowDialectTest, ClustersTFNonReplicatedBridge) {
135142
ASSERT_TRUE(main);
136143

137144
EXPECT_EQ(
138-
compilation_status.Delta("cpu/gpu", "v2", "fallback_disabled", "success"),
145+
compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterNonReplicated,
146+
mlir::TF::kMlirPh1BridgeCounterV2,
147+
mlir::TF::kMlirPh1BridgeCounterNonTpu,
148+
"fallback_disabled", "success"),
139149
1);
140150
}
141151

@@ -148,8 +158,11 @@ TEST_F(FunctionClusterTensorflowDialectTest, LogsFallbackMode) {
148158
*mlir_module_, /*is_supported_by_replicated_brige*/ true,
149159
/*is_in_fallback_enabled_mode=*/true));
150160

151-
EXPECT_EQ(
152-
compilation_status.Delta("tpu", "v2", "fallback_enabled", "success"), 1);
161+
EXPECT_EQ(compilation_status.Delta(mlir::TF::kMlirPh1BridgeCounterReplicated,
162+
mlir::TF::kMlirPh1BridgeCounterV2,
163+
mlir::TF::kMlirPh1BridgeCounterTpu,
164+
"fallback_enabled", "success"),
165+
1);
153166
}
154167

155168
} // namespace

0 commit comments

Comments
 (0)