Skip to content

Commit b03a72c

Browse files
rryan authored and tensorflower-gardener committed
Add tf.spectral, a module for spectral operations.
* Move existing FFT ops to tf.spectral.
* Add ops for computing 1D, 2D and 3D Fourier transforms of real signals.
* Define a gradient for the 1D and 2D transforms.

Change: 149504891
1 parent e83a041 commit b03a72c

23 files changed: 1390 additions and 309 deletions.

tensorflow/contrib/cmake/tf_core_ops.cmake

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ set(tf_op_lib_names
2121
"set_ops"
2222
"sendrecv_ops"
2323
"sparse_ops"
24+
"spectral_ops"
2425
"state_ops"
2526
"string_ops"
2627
"training_ops"

tensorflow/contrib/cmake/tf_python.cmake

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -520,6 +520,7 @@ GENERATE_PYTHON_OP_LIB("sdca_ops")
520520
GENERATE_PYTHON_OP_LIB("set_ops")
521521
GENERATE_PYTHON_OP_LIB("state_ops")
522522
GENERATE_PYTHON_OP_LIB("sparse_ops")
523+
GENERATE_PYTHON_OP_LIB("spectral_ops")
523524
GENERATE_PYTHON_OP_LIB("string_ops")
524525
GENERATE_PYTHON_OP_LIB("user_ops")
525526
GENERATE_PYTHON_OP_LIB("training_ops"

tensorflow/core/BUILD

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -496,6 +496,7 @@ tf_gen_op_libs(
496496
"script_ops",
497497
"sendrecv_ops",
498498
"sparse_ops",
499+
"spectral_ops",
499500
"state_ops",
500501
"string_ops",
501502
"training_ops",
@@ -557,6 +558,7 @@ cc_library(
557558
":sendrecv_ops_op_lib",
558559
":set_ops_op_lib",
559560
":sparse_ops_op_lib",
561+
":spectral_ops_op_lib",
560562
":state_ops_op_lib",
561563
":string_ops_op_lib",
562564
":training_ops_op_lib",
@@ -2498,6 +2500,7 @@ tf_cc_tests(
24982500
"ops/random_ops_test.cc",
24992501
"ops/set_ops_test.cc",
25002502
"ops/sparse_ops_test.cc",
2503+
"ops/spectral_ops_test.cc",
25012504
"ops/state_ops_test.cc",
25022505
"ops/string_ops_test.cc",
25032506
"ops/training_ops_test.cc",

tensorflow/core/kernels/BUILD

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2128,7 +2128,9 @@ tf_kernel_library(
21282128
tf_kernel_library(
21292129
name = "fft_ops",
21302130
prefix = "fft_ops",
2131-
deps = MATH_DEPS,
2131+
deps = MATH_DEPS + [
2132+
"//tensorflow/core:spectral_ops_op_lib",
2133+
],
21322134
)
21332135

21342136
tf_kernel_library(

tensorflow/core/kernels/fft_ops.cc

Lines changed: 145 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -49,72 +49,139 @@ class FFTGPUBase : public OpKernel {
4949
void Compute(OpKernelContext* ctx) override {
5050
const Tensor& in = ctx->input(0);
5151
const TensorShape& shape = in.shape();
52+
const int fft_rank = Rank();
5253
OP_REQUIRES(
53-
ctx, shape.dims() >= Rank(),
54-
errors::InvalidArgument("Input must have rank of at least ", Rank(),
54+
ctx, shape.dims() >= fft_rank,
55+
errors::InvalidArgument("Input must have rank of at least ", fft_rank,
5556
" but got: ", shape.DebugString()));
57+
5658
Tensor* out;
57-
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &out));
59+
TensorShape output_shape = shape;
60+
uint64 fft_shape[3] = {0, 0, 0};
61+
62+
// In R2C or C2R mode, we use a second input to specify the FFT length
63+
// instead of inferring it from the input shape.
64+
if (IsReal()) {
65+
const Tensor& fft_length = ctx->input(1);
66+
OP_REQUIRES(ctx,
67+
fft_length.shape().dims() == 1 &&
68+
fft_length.shape().dim_size(0) == fft_rank,
69+
errors::InvalidArgument("fft_length must have shape [",
70+
fft_rank, "]"));
71+
72+
auto fft_length_as_vec = fft_length.vec<int32>();
73+
for (int i = 0; i < fft_rank; ++i) {
74+
fft_shape[i] = fft_length_as_vec(i);
75+
uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0
76+
? fft_shape[i] / 2 + 1
77+
: fft_shape[i];
78+
output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
79+
}
80+
} else {
81+
for (int i = 0; i < fft_rank; ++i) {
82+
fft_shape[i] =
83+
output_shape.dim_size(output_shape.dims() - fft_rank + i);
84+
}
85+
}
86+
87+
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
5888
if (shape.num_elements() == 0) {
5989
return;
6090
}
61-
DoFFT(ctx, in, out);
91+
92+
DoFFT(ctx, in, fft_shape, out);
6293
}
6394

6495
protected:
6596
virtual int Rank() const = 0;
6697
virtual bool IsForward() const = 0;
98+
virtual bool IsReal() const = 0;
6799

68100
private:
69-
void DoFFT(OpKernelContext* ctx, const Tensor& in, Tensor* out) {
101+
void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
102+
Tensor* out) {
70103
auto* stream = ctx->op_device_context()->stream();
71104
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
72105

73-
const TensorShape& shape = in.shape();
74-
auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
75-
auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
106+
const TensorShape& input_shape = in.shape();
107+
const TensorShape& output_shape = out->shape();
76108

77-
const int rank = Rank();
109+
const int fft_rank = Rank();
78110
int batch_size = 1;
79-
for (int i = 0; i < shape.dims() - rank; ++i) {
80-
batch_size *= shape.dim_size(i);
111+
for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
112+
batch_size *= input_shape.dim_size(i);
81113
}
82-
uint64 data_length = 1;
83-
uint64 data_dims[3];
84-
for (int i = 0; i < rank; ++i) {
85-
auto dim = shape.dim_size(shape.dims() - rank + i);
86-
data_length *= dim;
87-
data_dims[i] = dim;
114+
uint64 input_embed[3];
115+
uint64 input_stride = 1;
116+
uint64 input_distance = 1;
117+
uint64 output_embed[3];
118+
uint64 output_stride = 1;
119+
uint64 output_distance = 1;
120+
121+
for (int i = 0; i < fft_rank; ++i) {
122+
auto dim_offset = input_shape.dims() - fft_rank + i;
123+
input_embed[i] = input_shape.dim_size(dim_offset);
124+
input_distance *= input_shape.dim_size(dim_offset);
125+
output_embed[i] = output_shape.dim_size(dim_offset);
126+
output_distance *= output_shape.dim_size(dim_offset);
88127
}
89128

90-
constexpr uint64* kInputEmbed = nullptr;
91-
constexpr uint64 kInputStride = 1;
92-
constexpr uint64 kInputDistance = 1;
93-
constexpr uint64* kOutputEmbed = nullptr;
94-
constexpr uint64 kOutputStride = 1;
95-
constexpr uint64 kOutputDistance = 1;
96129
constexpr bool kInPlaceFft = false;
130+
const auto kFftType =
131+
IsReal() ? (IsForward() ? perftools::gputools::fft::Type::kR2C
132+
: perftools::gputools::fft::Type::kC2R)
133+
: (IsForward() ? perftools::gputools::fft::Type::kC2CForward
134+
: perftools::gputools::fft::Type::kC2CInverse);
97135

98136
auto plan = stream->parent()->AsFft()->CreateBatchedPlan(
99-
stream, rank, data_dims, kInputEmbed, kInputStride, kInputDistance,
100-
kOutputEmbed, kOutputStride, kOutputDistance,
101-
IsForward() ? perftools::gputools::fft::Type::kC2CForward
102-
: perftools::gputools::fft::Type::kC2CInverse,
103-
kInPlaceFft, batch_size);
104-
105-
OP_REQUIRES(
106-
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
107-
errors::Internal("c2c fft failed : in.shape=", shape.DebugString()));
108-
if (!IsForward()) {
109-
auto alpha = complex64(1.f / data_length);
137+
stream, fft_rank, fft_shape, input_embed, input_stride, input_distance,
138+
output_embed, output_stride, output_distance, kFftType, kInPlaceFft,
139+
batch_size);
140+
141+
if (IsReal()) {
142+
if (IsForward()) {
143+
auto src = AsDeviceMemory<float>(in.flat<float>().data());
144+
auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
145+
OP_REQUIRES(
146+
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
147+
errors::Internal("fft failed : type=", static_cast<int>(kFftType),
148+
" in.shape=", input_shape.DebugString()));
149+
} else {
150+
auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
151+
auto dst = AsDeviceMemory<float>(out->flat<float>().data());
152+
OP_REQUIRES(
153+
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
154+
errors::Internal("fft failed : type=", static_cast<int>(kFftType),
155+
" in.shape=", input_shape.DebugString()));
156+
auto alpha = 1.f / output_distance;
157+
OP_REQUIRES(
158+
ctx,
159+
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
160+
.ok(),
161+
errors::Internal("BlasScal failed : in.shape=",
162+
input_shape.DebugString()));
163+
}
164+
} else {
165+
auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
166+
auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
110167
OP_REQUIRES(
111-
ctx, stream->ThenBlasScal(shape.num_elements(), alpha, &dst, 1).ok(),
112-
errors::Internal("BlasScal failed : in.shape=", shape.DebugString()));
168+
ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
169+
errors::Internal("fft failed : type=", static_cast<int>(kFftType),
170+
" in.shape=", input_shape.DebugString()));
171+
if (!IsForward()) {
172+
auto alpha = complex64(1.f / output_distance);
173+
OP_REQUIRES(
174+
ctx,
175+
stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
176+
.ok(),
177+
errors::Internal("BlasScal failed : in.shape=",
178+
input_shape.DebugString()));
179+
}
113180
}
114181
}
115182
};
116183

117-
template <bool Forward, int FFTRank>
184+
template <bool Forward, bool _Real, int FFTRank>
118185
class FFTGPU : public FFTGPUBase {
119186
public:
120187
static_assert(FFTRank >= 1 && FFTRank <= 3,
@@ -124,24 +191,53 @@ class FFTGPU : public FFTGPUBase {
124191
protected:
125192
int Rank() const override { return FFTRank; }
126193
bool IsForward() const override { return Forward; }
194+
bool IsReal() const override { return _Real; }
127195
};
128196

129-
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU), FFTGPU<true, 1>);
130-
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU), FFTGPU<false, 1>);
131-
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU), FFTGPU<true, 2>);
132-
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU), FFTGPU<false, 2>);
133-
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU), FFTGPU<true, 3>);
134-
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU), FFTGPU<false, 3>);
197+
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU), FFTGPU<true, false, 1>);
198+
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU),
199+
FFTGPU<false, false, 1>);
200+
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU),
201+
FFTGPU<true, false, 2>);
202+
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU),
203+
FFTGPU<false, false, 2>);
204+
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU),
205+
FFTGPU<true, false, 3>);
206+
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU),
207+
FFTGPU<false, false, 3>);
208+
209+
REGISTER_KERNEL_BUILDER(
210+
Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
211+
FFTGPU<true, true, 1>);
212+
REGISTER_KERNEL_BUILDER(
213+
Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
214+
FFTGPU<false, true, 1>);
215+
REGISTER_KERNEL_BUILDER(
216+
Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
217+
FFTGPU<true, true, 2>);
218+
REGISTER_KERNEL_BUILDER(
219+
Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
220+
FFTGPU<false, true, 2>);
221+
REGISTER_KERNEL_BUILDER(
222+
Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
223+
FFTGPU<true, true, 3>);
224+
REGISTER_KERNEL_BUILDER(
225+
Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
226+
FFTGPU<false, true, 3>);
135227

136228
// Deprecated kernels.
137-
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU), FFTGPU<true, 1>);
138-
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU), FFTGPU<false, 1>);
139-
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU), FFTGPU<true, 2>);
229+
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU),
230+
FFTGPU<true, false, 1>);
231+
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU),
232+
FFTGPU<false, false, 1>);
233+
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU),
234+
FFTGPU<true, false, 2>);
140235
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU),
141-
FFTGPU<false, 2>);
142-
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU), FFTGPU<true, 3>);
236+
FFTGPU<false, false, 2>);
237+
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU),
238+
FFTGPU<true, false, 3>);
143239
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU),
144-
FFTGPU<false, 3>);
240+
FFTGPU<false, false, 3>);
145241

146242
} // end namespace tensorflow
147243

0 commit comments

Comments
 (0)