Skip to content

Commit

Permalink
Address review comments
Browse files — browse the repository at this point in the history
- Update variable/function names to improve consistency
- Update final log message to improve clarity
  • Loading branch information
MattConley authored and 2sin18 committed Jun 24, 2021
1 parent ea7a246 commit b940066
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -959,10 +959,11 @@ class AutoMixedPrecisionImpl {
const absl::flat_hash_map<string, TypeAttrId>& write_ops,
const absl::flat_hash_map<string, TypeAttrId>& read_ops,
DataStructureOpsMap* object_clients_map) const;
void RecognizeNodes(int* processable_nodes, int* num_recognized_nodes) const;
void CountRecognizeNodes(
int* num_processable_nodes, int* num_recognized_nodes) const;
void AddWhitelistOps(absl::flat_hash_set<int>* white_set) const;
void PropagateBlackFwdThroughClearAndGray(
absl::flat_hash_set<int>* black_set, int* blacklist_nodes) const;
absl::flat_hash_set<int>* black_set, int* num_blacklist_nodes) const;
void ForceColorMatchBetweenDataStructureOps(
const DataStructureOpsMap& object_clients_map,
absl::flat_hash_set<int>* white_set,
Expand Down Expand Up @@ -1010,17 +1011,15 @@ bool AutoMixedPrecisionImpl::NodeHasFP16KernelForTypeAttr(
Status AutoMixedPrecisionImpl::PrintSummaryInfo(int num_processable_nodes,
int num_recognized_nodes, int num_whitelist_nodes, int num_blacklist_nodes,
int num_blacklist_affected_nodes) {
LOG(INFO) << "*******************************************************\n"
<< "Automatic Mixed Precision Grappler Pass Summary:\n\n"
LOG(INFO) << "Automatic Mixed Precision Grappler Pass Summary:\n\n"
<< "Total processable nodes: " << num_processable_nodes << "\n"
<< "Listed nodes available for conversion: " << num_recognized_nodes << " ("
<< (int)(100*(double)num_recognized_nodes/(double)num_processable_nodes)
<< " %)\nWhitelisted nodes converted: " << num_whitelist_nodes << "\n"
<< "Recognized nodes available for conversion: " << num_recognized_nodes
<< "\nWhitelisted nodes converted: " << num_whitelist_nodes << "\n"
<< "Blacklisted nodes blocking conversion: " << num_blacklist_nodes << "\n"
<< "Nodes blocked from conversion by blacklisted nodes: "
<< num_blacklist_affected_nodes << "\n\n"
<< "For more information regarding automatic mixed precision, "
<< "including how to augment the op lists to improve casting performance, "
<< "For more information regarding mixed precision training, including "
<< "how to make automatic mixed precision aware of a custom op type, "
<< "please see the documentation available here:\n"
<< "https://docs.nvidia.com/deeplearning/frameworks/"
<< "tensorflow-user-guide/index.html#tfamp\n\n";
Expand Down Expand Up @@ -1359,7 +1358,7 @@ Status AutoMixedPrecisionImpl::Optimize() {
VLOG(2) << "Counting recognized nodes";
int num_processable_nodes = 0;
int num_recognized_nodes = 0;
RecognizeNodes(&num_processable_nodes, &num_recognized_nodes);
CountRecognizeNodes(&num_processable_nodes, &num_recognized_nodes);

absl::flat_hash_set<int> white_set;
VLOG(2) << "Beginning pass 1 to add whitelist ops";
Expand Down Expand Up @@ -1448,13 +1447,13 @@ Status AutoMixedPrecisionImpl::AddDataStructureOpsToMap(
return Status::OK();
}

void AutoMixedPrecisionImpl::RecognizeNodes(
int* processable_nodes, int* num_recognized_nodes) const {
// Add whitelisted ops to white_set.
void AutoMixedPrecisionImpl::CountRecognizeNodes(
int* num_processable_nodes, int* num_recognized_nodes) const {
// Count the number of nodes viable for conversion
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (!ShouldProcess(*root.node)) continue;
++*processable_nodes;
++*num_processable_nodes;
if (fp16_whitelist_.count(root.node->op()) ||
fp16_blacklist_.count(root.node->op()) ||
fp16_graylist_.count(root.node->op()) ||
Expand Down Expand Up @@ -1490,7 +1489,7 @@ void AutoMixedPrecisionImpl::AddWhitelistOps(
// E.g., black -> gray -> clear -> gray -> clear -> white -> gray
// becomes: black -> black -> black -> black -> clear -> white -> gray.
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
absl::flat_hash_set<int>* black_set, int* blacklist_nodes) const {
absl::flat_hash_set<int>* black_set, int* num_blacklist_nodes) const {
if (force_all_fp16_) return;

// Find clear nodes that are upstream of black or gray.
Expand All @@ -1501,7 +1500,7 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
fp16_graylist_.count(root.node->op()))) {
continue;
}
if (fp16_blacklist_.count(root.node->op())) ++*blacklist_nodes;
if (fp16_blacklist_.count(root.node->op())) ++*num_blacklist_nodes;
DfsTypeTraversal(graph_type_view_, {&root},
TypeTraversalDirection::kFollowInputs,
DfsTypePredicates::Enter([&](int idx) -> bool {
Expand Down

0 comments on commit b940066

Please sign in to comment.