
Commit 0b14ae4

hawkinsp authored and tensorflower-gardener committed

[XLA] Remove LocalClient::ExecuteLocally() in favor of LocalClient::Compile() and LocalExecutable::Run().

Change: 149482633
1 parent b69dd29, commit 0b14ae4

13 files changed: +127, -361 lines
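
For callers of the removed API, the replacement flow is to compile the Computation once with LocalClient::Compile() and then invoke the returned LocalExecutable via Run(). The sketch below is illustrative only and is not part of this commit: the RunComputation helper, the argument_layout_ptrs/allocator/build_options names, and the ExecutableBuildOptions type for Compile()'s third parameter are assumptions; TF_ASSIGN_OR_RETURN is the usual XLA status macro.

// Hypothetical caller-side sketch of the new flow; not from this commit.
// Only LocalClient::Compile(), LocalExecutable::Run(), and
// ExecutableRunOptions::set_allocator() are taken from the local-client API.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/status_macros.h"

xla::StatusOr<std::unique_ptr<xla::ShapedBuffer>> RunComputation(
    xla::LocalClient* client, const xla::Computation& computation,
    tensorflow::gtl::ArraySlice<const xla::Shape*> argument_layout_ptrs,
    tensorflow::gtl::ArraySlice<const xla::ShapedBuffer*> arguments,
    xla::DeviceMemoryAllocator* allocator) {
  // Previously (removed by this commit):
  //   return client->ExecuteLocally(computation, arguments, local_options);

  // Now: compile once against the argument layouts, then run the executable.
  xla::ExecutableBuildOptions build_options;  // assumed options type
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::LocalExecutable> executable,
      client->Compile(computation, argument_layout_ptrs, build_options));

  xla::ExecutableRunOptions run_options;
  run_options.set_allocator(allocator);  // Run() now requires an allocator
  return executable->Run(arguments, run_options);
}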

tensorflow/compiler/jit/kernels/xla_local_launch_op.h

Lines changed: 3 additions & 2 deletions
@@ -31,8 +31,9 @@ namespace tensorflow {
 // Once all inputs are present, and their shapes are known, the op can
 // use a 'XlaCompilationCache' to compile and execute code which is specific
 // to the shapes of input Tensors.
-// XlaLocalLaunchOp uses xla::LocalClient::ExecuteLocally and passes
-// arguments into/out of XLA in device memory.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
 class XlaLocalLaunchOp : public OpKernel {
  public:
   explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);

tensorflow/compiler/xla/client/local_client.cc

Lines changed: 23 additions & 12 deletions
@@ -87,7 +87,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
 
 tensorflow::Status LocalExecutable::ValidateExecutionOptions(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    const ExecutableRunOptions& options) {
+    const ExecutableRunOptions& options, const Backend& backend) {
   const ComputationLayout& computation_layout =
       executable_->module_config().entry_computation_layout();
 
@@ -156,13 +156,24 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
         run_executor->GetDeviceDescription().name().c_str());
   }
 
+  if (!options.allocator()) {
+    return InvalidArgument("an allocator must be provided to ExecuteLocally");
+  }
+
+  if (options.allocator()->platform() != backend.platform()) {
+    return InvalidArgument(
+        "allocator platform (%s) does not match service platform (%s)",
+        options.allocator()->platform()->Name().c_str(),
+        backend.platform()->Name().c_str());
+  }
+
   return tensorflow::Status::OK();
 }
 
 StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
     const ExecutableRunOptions& options) {
-  TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options));
+  TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_));
 
   ExecutableRunOptions actual_options = options;
   Backend::StreamPtr stream;
@@ -180,8 +191,16 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run(
   if (executable_->dumping()) {
     return ExecuteAndDump(&service_options, arguments);
   }
-  return executable_->ExecuteOnStream(&service_options, arguments,
-                                      /*hlo_execution_profile=*/nullptr);
+  return Service::ExecuteOnStreamWrapper<
+      StatusOr<std::unique_ptr<ShapedBuffer>>>(
+      executable_.get(), &service_options, options.execution_profile(),
+      backend_,
+      [&arguments](Executable* executable,
+                   const ServiceExecutableRunOptions* run_options,
+                   HloExecutionProfile* hlo_execution_profile) {
+        return executable->ExecuteOnStream(run_options, arguments,
+                                           hlo_execution_profile);
+      });
 }
 
 StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::ExecuteAndDump(
@@ -242,14 +261,6 @@ tensorflow::Status LocalClient::ResolveArguments(
       argument_ptrs);
 }
 
-StatusOr<std::unique_ptr<ShapedBuffer>> LocalClient::ExecuteLocally(
-    const Computation& computation,
-    const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-    const LocalExecuteOptions& options) {
-  return local_service_->ExecuteLocally(computation.handle(), arguments,
-                                        options);
-}
-
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 LocalClient::CompileAheadOfTime(
     const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
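
Note the two checks added to ValidateExecutionOptions() above: Run() now fails unless the ExecutableRunOptions carry an allocator, and that allocator's platform must match the backend's platform. A minimal, hedged sketch of satisfying them (the allocator variable is assumed to be any DeviceMemoryAllocator created for the same platform as the client, for example the backend's own allocator; executable and arguments as in the earlier sketch):

// Sketch only; mirrors the checks added to ValidateExecutionOptions().
xla::ExecutableRunOptions run_options;
run_options.set_allocator(allocator);  // omitted -> "an allocator must be provided..."
// If allocator->platform() differs from the backend platform, Run() now returns
// InvalidArgument("allocator platform (...) does not match service platform (...)").
auto result_or = executable->Run(arguments, run_options);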

tensorflow/compiler/xla/client/local_client.h

Lines changed: 1 addition & 20 deletions
@@ -111,7 +111,7 @@ class LocalExecutable {
   // of the computation.
   tensorflow::Status ValidateExecutionOptions(
       const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-      const ExecutableRunOptions& options);
+      const ExecutableRunOptions& options, const Backend& backend);
 
   // Records the computation in a SessionModule proto with the arguments used to
   // invoke it, and the result. Enabled by flag: --tla_dump_executions_to.
@@ -175,25 +175,6 @@ class LocalClient : public Client {
       const Shape& shape, int device_ordinal,
       bool allocate_space_for_deep_copy);
 
-  // Executes the given computation with the given arguments and
-  // options. Arguments and result are "zero-copy", and are passed as pointers
-  // to device memory. See LocalExecuteOptions class comments for description of
-  // available options. The returned ShapedBuffer includes pointer(s) to device
-  // memory (DeviceMemoryBase) which are the caller's responsibility to
-  // deallocate. The layout of the result is chosen by the XLA service and
-  // should not be relied upon to be a specific value. If a specific result
-  // layout is needed, then the layout should be set in options.
-  //
-  // The arrays of arguments with different shapes or layouts are assumed not to
-  // alias.
-  //
-  // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use
-  // Compile and run the returned LocalExecutable.
-  StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteLocally(
-      const Computation& computation,
-      const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
-      const LocalExecuteOptions& options);
-
   // Build and return a LocalExecutable object. The executable is compiled using
   // the given argument layouts and options.
   StatusOr<std::unique_ptr<LocalExecutable>> Compile(

tensorflow/compiler/xla/executable_run_options.cc

Lines changed: 10 additions & 0 deletions
@@ -67,4 +67,14 @@ const Eigen::ThreadPoolDevice* ExecutableRunOptions::intra_op_thread_pool()
   return intra_op_thread_pool_;
 }
 
+ExecutableRunOptions& ExecutableRunOptions::set_execution_profile(
+    ExecutionProfile* profile) {
+  execution_profile_ = profile;
+  return *this;
+}
+
+ExecutionProfile* ExecutableRunOptions::execution_profile() const {
+  return execution_profile_;
+}
+
 }  // namespace xla

tensorflow/compiler/xla/executable_run_options.h

Lines changed: 6 additions & 0 deletions
@@ -40,6 +40,7 @@ struct ThreadPoolDevice;
 namespace xla {
 
 class DeviceMemoryAllocator;
+class ExecutionProfile;
 
 // Class containing options for running a LocalExecutable.
 class ExecutableRunOptions {
@@ -74,12 +75,17 @@ class ExecutableRunOptions {
       const Eigen::ThreadPoolDevice* intra_op_thread_pool);
   const Eigen::ThreadPoolDevice* intra_op_thread_pool() const;
 
+  // If set, profiling information is written to 'profile'.
+  ExecutionProfile* execution_profile() const;
+  ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
+
  private:
  DeviceMemoryAllocator* allocator_ = nullptr;
  int device_ordinal_ = -1;
  perftools::gputools::Stream* stream_ = nullptr;
  tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
  const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
+  ExecutionProfile* execution_profile_ = nullptr;
 };
 
 }  // namespace xla
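
With the new accessor wired through Service::ExecuteOnStreamWrapper in local_client.cc, a caller can ask Run() to fill in an ExecutionProfile. A hedged sketch, assuming ExecutionProfile is the xla.proto message exposing compute_time_ns() and that the ExecutableRunOptions setters chain as the header's return types suggest:

// Sketch only; executable, arguments, and allocator as in the earlier sketches.
xla::ExecutionProfile profile;
xla::ExecutableRunOptions run_options;
run_options.set_allocator(allocator).set_execution_profile(&profile);
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ShapedBuffer> result,
                    executable->Run(arguments, run_options));
// After Run() returns, ExecuteOnStreamWrapper has populated `profile`.
LOG(INFO) << "compute time (ns): " << profile.compute_time_ns();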
