Skip to content

Commit 208dade

Browse files
Monica Song and tensorflower-gardener
authored and committed
[SavedModel Fingerprinting] Add hash 2 out of 5 for fingerprinting.
This CL adds another field (graph_def_program_hash) to the FingerprintDef protobuf.
RFC: tensorflow/community#415
PiperOrigin-RevId: 464154116
1 parent cf247a6 commit 208dade

File tree

5 files changed

+77
-11
lines changed

5 files changed

+77
-11
lines changed

tensorflow/cc/saved_model/fingerprinting.cc

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ limitations under the License.
2020
#include "tensorflow/core/framework/attr_value.pb.h"
2121
#include "tensorflow/core/framework/function.pb.h"
2222
#include "tensorflow/core/framework/op_def.pb.h"
23+
#include "tensorflow/core/framework/types.pb.h"
24+
#include "tensorflow/core/framework/versions.pb.h"
2325
#include "tensorflow/core/grappler/op_types.h"
2426
#include "tensorflow/core/lib/strings/proto_serialization.h"
2527
#include "tensorflow/core/platform/fingerprint.h"
@@ -31,7 +33,8 @@ namespace tensorflow::fingerprinting {
3133

3234
namespace {
3335

34-
// This function mutates the GraphDef, changing the names of the Function nodes.
36+
// This function mutates the GraphDef, changing the names and config_proto's
37+
// of the Function nodes.
3538
void CanonicalizeNodes(GraphDef* orig_graph_def) {
3639
for (NodeDef& node : *orig_graph_def->mutable_node()) {
3740
// Check if this is a function call.
@@ -42,6 +45,15 @@ void CanonicalizeNodes(GraphDef* orig_graph_def) {
4245
// and StatefulPartitionedCall ops.
4346
node.mutable_attr()->find("f")->second.mutable_func()->set_name(
4447
"FINGERPRINT_PASS");
48+
// Erase the "config_proto" attribute which contains device-specific
49+
// information.
50+
node.mutable_attr()->find("config_proto")->second.mutable_s()->erase();
51+
}
52+
// Erase the value of string constants, which can vary based on platform.
53+
if (grappler::IsConstant(node)) {
54+
if (node.attr().at("dtype").type() == DT_STRING) {
55+
node.mutable_attr()->find("value")->second.clear_value();
56+
}
4557
}
4658
}
4759
}
@@ -55,8 +67,15 @@ uint64 ComputeHash(const GraphDef& graph_def) {
5567
}
5668

5769
FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph) {
70+
// Create a copy of `metagraph` which will be used and mutated for fingerprint
71+
// computation.
72+
MetaGraphDef metagraph_copy = metagraph;
5873
FingerprintDef fingerprint_def;
59-
fingerprint_def.set_graph_def_hash(ComputeHash(metagraph.graph_def()));
74+
fingerprint_def.set_graph_def_checksum(
75+
ComputeHash(metagraph_copy.graph_def()));
76+
CanonicalizeGraphDef(*metagraph_copy.mutable_graph_def());
77+
fingerprint_def.set_graph_def_program_hash(
78+
ComputeHash(metagraph_copy.graph_def()));
6079
return fingerprint_def;
6180
}
6281

@@ -67,6 +86,7 @@ void CanonicalizeGraphDef(GraphDef& graph_def) {
6786
// TODO(b/240173815): Complete canonicalization of the FunctionDefLibrary.
6887
// For now, we just clear the FunctionDefLibrary.
6988
graph_def.mutable_library()->Clear();
89+
graph_def.mutable_versions()->Clear();
7090
}
7191

7292
} // namespace tensorflow::fingerprinting

tensorflow/cc/saved_model/fingerprinting_test.cc

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include "tensorflow/core/framework/graph.pb.h"
2222
#include "tensorflow/core/framework/node_def.pb.h"
2323
#include "tensorflow/core/lib/core/status_test_util.h"
24+
#include "tensorflow/core/platform/errors.h"
2425
#include "tensorflow/core/platform/path.h"
2526
#include "tensorflow/core/platform/test.h"
2627
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -65,10 +66,10 @@ TEST(FingerprintingTest, TestCreateFingerprint) {
6566

6667
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
6768
ReadSavedModel(export_dir));
68-
MetaGraphDef metagraph = saved_model_pb.meta_graphs(0);
69-
FingerprintDef fingerprint_def = CreateFingerprintDef(metagraph);
69+
FingerprintDef fingerprint_def =
70+
CreateFingerprintDef(saved_model_pb.meta_graphs(0));
7071

71-
EXPECT_GT(fingerprint_def.graph_def_hash(), 0);
72+
EXPECT_GT(fingerprint_def.graph_def_checksum(), 0);
7273
}
7374

7475
// Test that canonicalization returns the same hash for two models saved by
@@ -98,5 +99,44 @@ TEST(FingerprintingTest, TestCanonicalizeGraphDeforModelSavedTwice) {
9899
EXPECT_EQ(hash1, hash2);
99100
}
100101

102+
// Compare the fingerprints of two models saved by calling
103+
// `tf.saved_model.save` twice in a row in the same program.
104+
TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) {
105+
const std::string export_dir =
106+
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
107+
"bert1", "saved_model.pb");
108+
109+
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
110+
ReadSavedModel(export_dir));
111+
FingerprintDef fingerprint_def =
112+
CreateFingerprintDef(saved_model_pb.meta_graphs(0));
113+
114+
const std::string export_dir2 =
115+
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
116+
"bert2", "saved_model.pb");
117+
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb2,
118+
ReadSavedModel(export_dir2));
119+
FingerprintDef fingerprint_def2 =
120+
CreateFingerprintDef(saved_model_pb2.meta_graphs(0));
121+
122+
EXPECT_EQ(fingerprint_def.graph_def_program_hash(),
123+
fingerprint_def2.graph_def_program_hash());
124+
}
125+
126+
TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) {
127+
const std::string export_dir =
128+
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
129+
"bert1", "saved_model.pb");
130+
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
131+
ReadSavedModel(export_dir));
132+
FingerprintDef fingerprint_def =
133+
CreateFingerprintDef(saved_model_pb.meta_graphs(0));
134+
FingerprintDef fingerprint_def2 =
135+
CreateFingerprintDef(saved_model_pb.meta_graphs(0));
136+
137+
EXPECT_EQ(fingerprint_def.graph_def_checksum(),
138+
fingerprint_def2.graph_def_checksum());
139+
}
140+
101141
} // namespace
102142
} // namespace tensorflow::fingerprinting

tensorflow/core/protobuf/fingerprint.proto

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
1313
// If there are multiple MetaGraphDefs in the SavedModel, the FingerprintDef
1414
// corresponds to the first one.
1515
message FingerprintDef {
16-
// Hash of the graph_def.
17-
uint64 graph_def_hash = 1;
16+
// Hash of the graph_def, referred to as a "checksum".
17+
uint64 graph_def_checksum = 1;
18+
// Hash of regularized graph_def.
19+
uint64 graph_def_program_hash = 2;
1820
}

tensorflow/python/saved_model/fingerprinting_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def test_basic_module(self):
5555
fingerprint_def = self._read_fingerprint(
5656
file_io.join(save_dir, constants.FINGERPRINT_FILENAME))
5757
# We cannot check the value due to non-determinism in serialization.
58-
self.assertGreater(fingerprint_def.graph_def_hash, 0)
58+
self.assertGreater(fingerprint_def.graph_def_checksum, 0)
59+
self.assertEqual(fingerprint_def.graph_def_program_hash,
60+
16358308617800096964)
5961

6062

6163
if __name__ == "__main__":

tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ def test_graphdef_basic(self):
3131
fingerprint_def = fingerprint_pb2.FingerprintDef()
3232
fingerprint_def.ParseFromString(
3333
fingerprinting.CreateFingerprintDef(file_content))
34-
# We cannot check the value of the graph_def_hash due to non-determinism in
35-
# serialization.
36-
self.assertGreater(fingerprint_def.graph_def_hash, 0)
34+
# We cannot check the value of the graph_def_checksum due to non-determinism
35+
# in serialization.
36+
self.assertGreater(fingerprint_def.graph_def_checksum, 0)
37+
self.assertEqual(fingerprint_def.graph_def_program_hash,
38+
13188891313422428336)
3739

3840

3941
if __name__ == "__main__":

0 commit comments

Comments
 (0)