Skip to content

Commit 96cb8f8

Browse files
Started to open source Grappler. First application is the GPU layout optimizer.
Change: 149558284
1 parent b2a4a7d commit 96cb8f8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+4064
-0
lines changed

tensorflow/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,11 @@ filegroup(
213213
"//tensorflow/core/debug:all_files",
214214
"//tensorflow/core/distributed_runtime:all_files",
215215
"//tensorflow/core/distributed_runtime/rpc:all_files",
216+
"//tensorflow/core/grappler:all_files",
217+
"//tensorflow/core/grappler/clusters:all_files",
218+
"//tensorflow/core/grappler/costs:all_files",
219+
"//tensorflow/core/grappler/inputs:all_files",
220+
"//tensorflow/core/grappler/optimizers:all_files",
216221
"//tensorflow/core/kernels:all_files",
217222
"//tensorflow/core/kernels/cloud:all_files",
218223
"//tensorflow/core/kernels/hexagon:all_files",

tensorflow/core/grappler/BUILD

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
licenses(["notice"]) # Apache 2.0
2+
3+
exports_files(["LICENSE"])
4+
5+
filegroup(
6+
name = "all_files",
7+
srcs = glob(
8+
["**/*"],
9+
exclude = [
10+
"**/METADATA",
11+
"**/OWNERS",
12+
],
13+
),
14+
visibility = ["//tensorflow:__subpackages__"],
15+
)
16+
17+
cc_library(
18+
name = "utils",
19+
srcs = ["utils.cc"],
20+
hdrs = ["utils.h"],
21+
visibility = ["//visibility:public"],
22+
deps = [
23+
"//tensorflow/core:gpu_runtime",
24+
"//tensorflow/core:lib",
25+
"//tensorflow/core:lib_internal",
26+
"//tensorflow/core:stream_executor",
27+
],
28+
)
29+
30+
cc_test(
31+
name = "utils_test",
32+
srcs = ["utils_test.cc"],
33+
deps = [
34+
":utils",
35+
"//tensorflow/core:test",
36+
"//tensorflow/core:test_main",
37+
],
38+
)
39+
40+
cc_library(
41+
name = "grappler_item",
42+
srcs = ["grappler_item.cc"],
43+
hdrs = ["grappler_item.h"],
44+
visibility = ["//visibility:public"],
45+
deps = [
46+
":utils",
47+
"//tensorflow/cc:cc_ops",
48+
"//tensorflow/core:core_cpu_internal",
49+
"//tensorflow/core:framework",
50+
"//tensorflow/core:lib_internal",
51+
"//tensorflow/core:protos_all_cc",
52+
"//tensorflow/core:tensorflow",
53+
"//tensorflow/core/grappler/inputs:utils",
54+
],
55+
)
56+
57+
cc_test(
58+
name = "grappler_item_test",
59+
srcs = ["grappler_item_test.cc"],
60+
deps = [
61+
":grappler_item",
62+
"//tensorflow/core:protos_all_cc",
63+
"//tensorflow/core:test",
64+
"//tensorflow/core:test_main",
65+
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
66+
],
67+
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
licenses(["notice"]) # Apache 2.0
2+
3+
exports_files(["LICENSE"])
4+
5+
filegroup(
6+
name = "all_files",
7+
srcs = glob(
8+
["**/*"],
9+
exclude = [
10+
"**/METADATA",
11+
"**/OWNERS",
12+
],
13+
),
14+
visibility = ["//tensorflow:__subpackages__"],
15+
)
16+
17+
cc_library(
18+
name = "cluster",
19+
srcs = ["cluster.cc"],
20+
hdrs = [
21+
"cluster.h",
22+
],
23+
visibility = ["//visibility:public"],
24+
deps = [
25+
"//tensorflow/core:core_cpu",
26+
"//tensorflow/core:framework",
27+
"//tensorflow/core:lib",
28+
"//tensorflow/core:protos_all_cc",
29+
"//tensorflow/core/grappler:grappler_item",
30+
],
31+
)
32+
33+
cc_library(
34+
name = "single_machine",
35+
srcs = ["single_machine.cc"],
36+
hdrs = [
37+
"single_machine.h",
38+
],
39+
visibility = ["//visibility:public"],
40+
deps = [
41+
":cluster",
42+
"//tensorflow/cc:coordinator",
43+
"//tensorflow/cc:queue_runner",
44+
"//tensorflow/core:core_cpu",
45+
"//tensorflow/core:direct_session",
46+
"//tensorflow/core:lib",
47+
"//tensorflow/core/kernels:ops_util",
48+
],
49+
)
50+
51+
cc_test(
52+
name = "single_machine_test",
53+
srcs = ["single_machine_test.cc"],
54+
args = ["--heap_check=local"], # The GPU tracer leaks memory
55+
deps = [
56+
":single_machine",
57+
"//tensorflow/core:lib_proto_parsing",
58+
"//tensorflow/core:protos_all_cc",
59+
"//tensorflow/core:test",
60+
"//tensorflow/core:test_main",
61+
"//tensorflow/core/grappler:grappler_item",
62+
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
63+
],
64+
)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/grappler/clusters/cluster.h"
17+
#include <atomic>
18+
19+
namespace tensorflow {
20+
namespace grappler {
21+
22+
// Process-wide flag recording whether a Cluster instance currently exists.
// It is set by the Cluster constructor and cleared by the Cluster destructor.
static std::atomic<bool> already_created(false);
23+
24+
// Constructs a cluster, stores `timeout_s`, and claims the process-wide
// single-instance slot. CHECK-fails if another Cluster is already alive.
// Also enables cost model building in the session options and requests
// hardware traces in the run options.
Cluster::Cluster(int timeout_s) : timeout_s_(timeout_s) {
  // This is really ugly: to avoid leaking variables, we need to reset the tf
  // session every time we're done processing a grappler item. However,
  // variables are global, and therefore we can't have more than 1 session alive
  // at a time. This check detects when more than one cluster is created.
  //
  // The flag is tested and set in a single atomic step so that two threads
  // constructing clusters concurrently cannot both pass the check.
  const bool previously_created = already_created.exchange(true);
  CHECK(!previously_created);

  options_.config.mutable_graph_options()->set_build_cost_model(1);

  run_options_.set_trace_level(RunOptions::HARDWARE_TRACE);
}
36+
37+
// Releases the process-wide single-instance slot so that a new Cluster can be
// created afterwards. CHECK-fails if the slot was not marked as taken.
Cluster::~Cluster() {
  // Clear the flag and verify its previous value in one atomic step rather
  // than as a separate read followed by a write.
  const bool was_created = already_created.exchange(false);
  CHECK(was_created);
}
41+
42+
// Stores `num_steps` as the `build_cost_model_after` graph option of the
// session configuration held in `options_`.
void Cluster::SetNumWarmupSteps(int num_steps) {
  auto* graph_options = options_.config.mutable_graph_options();
  graph_options->set_build_cost_model_after(num_steps);
}
46+
47+
} // end namespace grappler
48+
} // end namespace tensorflow
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
17+
#define TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
18+
19+
#include <string>
20+
#include <utility>
21+
#include <vector>
22+
23+
#include "tensorflow/core/framework/device_attributes.pb.h"
24+
#include "tensorflow/core/framework/tensor.h"
25+
#include "tensorflow/core/grappler/grappler_item.h"
26+
#include "tensorflow/core/lib/core/status.h"
27+
#include "tensorflow/core/public/session.h"
28+
29+
namespace tensorflow {
30+
namespace grappler {
31+
32+
// A cluster represents a collection of hardware resources available to run
33+
// the TensorFlow model.
34+
// A process can only create a single cluster at a time.
35+
class Cluster {
36+
public:
37+
explicit Cluster(int timeout_s);
38+
virtual ~Cluster();
39+
40+
// Provision the hardware resources needed to run TensorFlow and start a
41+
// TensorFlow session that can take advantage of these resources.
42+
// The actual resources that are leveraged depend on the type of cluster
43+
// instantiated.
44+
// Returns OK iff all the requested resources could be reserved and a
45+
// TensorFlow session successfully created. Returns an error otherwise.
46+
// There is no graceful degradation to handle the case where only a subset
47+
// of the requested resources are available.
48+
virtual Status Provision() = 0;
49+
50+
// Set the number of steps required to warmup TensorFlow. Must be called
51+
// before Provision().
52+
void SetNumWarmupSteps(int num_steps);
53+
54+
// Return the list of TensorFlow devices that are available to execute a
55+
// graph. This is empty until provision() is called.
56+
const std::vector<DeviceAttributes>& GetDevices() const { return devices_; }
57+
58+
// Convenience method that returns the set of device names.
59+
const std::vector<string> GetDeviceNames() const {
60+
std::vector<string> device_names;
61+
device_names.reserve(devices_.size());
62+
for (const auto& device : devices_) {
63+
device_names.push_back(device.name());
64+
}
65+
return device_names;
66+
}
67+
68+
// Prepare the session to run the specified grappler item. This include
69+
// initializing all the model variables.
70+
virtual Status Initialize(const GrapplerItem& item) = 0;
71+
72+
// Run the specified graph_def and return the corresponding metadata.
73+
virtual Status Run(const GraphDef& graph_def,
74+
const std::vector<std::pair<string, Tensor>>& feed,
75+
const std::vector<string>& fetch,
76+
RunMetadata* metadata) = 0;
77+
78+
protected:
79+
std::vector<DeviceAttributes> devices_;
80+
const int timeout_s_;
81+
SessionOptions options_;
82+
RunOptions run_options_;
83+
};
84+
85+
} // end namespace grappler
86+
} // end namespace tensorflow
87+
88+
#endif // TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_

0 commit comments

Comments
 (0)