@@ -49,72 +49,139 @@ class FFTGPUBase : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     const Tensor& in = ctx->input(0);
     const TensorShape& shape = in.shape();
+    const int fft_rank = Rank();
     OP_REQUIRES(
-        ctx, shape.dims() >= Rank(),
-        errors::InvalidArgument("Input must have rank of at least ", Rank(),
+        ctx, shape.dims() >= fft_rank,
+        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                 " but got: ", shape.DebugString()));
+
     Tensor* out;
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &out));
+    TensorShape output_shape = shape;
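+    // FFT dimensions; only the innermost fft_rank (at most 3) entries are
+    // used.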
+    uint64 fft_shape[3] = {0, 0, 0};
+
+    // In R2C or C2R mode, we use a second input to specify the FFT length
+    // instead of inferring it from the input shape.
+    if (IsReal()) {
+      const Tensor& fft_length = ctx->input(1);
+      OP_REQUIRES(ctx,
+                  fft_length.shape().dims() == 1 &&
+                      fft_length.shape().dim_size(0) == fft_rank,
+                  errors::InvalidArgument("fft_length must have shape [",
+                                          fft_rank, "]"));
+
+      auto fft_length_as_vec = fft_length.vec<int32>();
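+      // A forward R2C transform of length N produces only the N / 2 + 1
+      // non-redundant complex coefficients (the rest follow from Hermitian
+      // symmetry), so the innermost output dimension shrinks accordingly.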
+      for (int i = 0; i < fft_rank; ++i) {
+        fft_shape[i] = fft_length_as_vec(i);
+        uint64 dim = IsForward() && i == fft_rank - 1 && fft_shape[i] != 0
+                         ? fft_shape[i] / 2 + 1
+                         : fft_shape[i];
+        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
+      }
+    } else {
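+      // In C2C mode the FFT shape is simply the innermost fft_rank
+      // dimensions, which the input and output share.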
+      for (int i = 0; i < fft_rank; ++i) {
+        fft_shape[i] =
+            output_shape.dim_size(output_shape.dims() - fft_rank + i);
+      }
+    }
+
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
     if (shape.num_elements() == 0) {
       return;
     }
-    DoFFT(ctx, in, out);
+
+    DoFFT(ctx, in, fft_shape, out);
   }

  protected:
   virtual int Rank() const = 0;
   virtual bool IsForward() const = 0;
+  virtual bool IsReal() const = 0;

  private:
-  void DoFFT(OpKernelContext* ctx, const Tensor& in, Tensor* out) {
+  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
+             Tensor* out) {
     auto* stream = ctx->op_device_context()->stream();
     OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

-    const TensorShape& shape = in.shape();
-    auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
-    auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
+    const TensorShape& input_shape = in.shape();
+    const TensorShape& output_shape = out->shape();

-    const int rank = Rank();
+    const int fft_rank = Rank();
     int batch_size = 1;
-    for (int i = 0; i < shape.dims() - rank; ++i) {
-      batch_size *= shape.dim_size(i);
+    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
+      batch_size *= input_shape.dim_size(i);
     }
-    uint64 data_length = 1;
-    uint64 data_dims[3];
-    for (int i = 0; i < rank; ++i) {
-      auto dim = shape.dim_size(shape.dims() - rank + i);
-      data_length *= dim;
-      data_dims[i] = dim;
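+    // Describe the dense batched layout in cuFFT's advanced-layout terms:
+    // both buffers are contiguous, so the stride stays 1 and the distance
+    // (element count between consecutive transforms of the batch) is the
+    // product of the transformed dimensions.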
+    uint64 input_embed[3];
+    uint64 input_stride = 1;
+    uint64 input_distance = 1;
+    uint64 output_embed[3];
+    uint64 output_stride = 1;
+    uint64 output_distance = 1;
+
+    for (int i = 0; i < fft_rank; ++i) {
+      auto dim_offset = input_shape.dims() - fft_rank + i;
+      input_embed[i] = input_shape.dim_size(dim_offset);
+      input_distance *= input_shape.dim_size(dim_offset);
+      output_embed[i] = output_shape.dim_size(dim_offset);
+      output_distance *= output_shape.dim_size(dim_offset);
     }

-    constexpr uint64* kInputEmbed = nullptr;
-    constexpr uint64 kInputStride = 1;
-    constexpr uint64 kInputDistance = 1;
-    constexpr uint64* kOutputEmbed = nullptr;
-    constexpr uint64 kOutputStride = 1;
-    constexpr uint64 kOutputDistance = 1;
     constexpr bool kInPlaceFft = false;
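+    // Select the transform type from the (real, forward) flag pair; the
+    // same batched plan call covers all four variants.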
+    const auto kFftType =
+        IsReal() ? (IsForward() ? perftools::gputools::fft::Type::kR2C
+                                : perftools::gputools::fft::Type::kC2R)
+                 : (IsForward() ? perftools::gputools::fft::Type::kC2CForward
+                                : perftools::gputools::fft::Type::kC2CInverse);

     auto plan = stream->parent()->AsFft()->CreateBatchedPlan(
-        stream, rank, data_dims, kInputEmbed, kInputStride, kInputDistance,
-        kOutputEmbed, kOutputStride, kOutputDistance,
-        IsForward() ? perftools::gputools::fft::Type::kC2CForward
-                    : perftools::gputools::fft::Type::kC2CInverse,
-        kInPlaceFft, batch_size);
-
-    OP_REQUIRES(
-        ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
-        errors::Internal("c2c fft failed : in.shape=", shape.DebugString()));
-    if (!IsForward()) {
-      auto alpha = complex64(1.f / data_length);
+        stream, fft_rank, fft_shape, input_embed, input_stride, input_distance,
+        output_embed, output_stride, output_distance, kFftType, kInPlaceFft,
+        batch_size);
+
+    if (IsReal()) {
+      if (IsForward()) {
+        auto src = AsDeviceMemory<float>(in.flat<float>().data());
+        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
+        OP_REQUIRES(
+            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+                             " in.shape=", input_shape.DebugString()));
+      } else {
+        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
+        auto dst = AsDeviceMemory<float>(out->flat<float>().data());
+        OP_REQUIRES(
+            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+                             " in.shape=", input_shape.DebugString()));
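+        // cuFFT leaves inverse transforms unnormalized; scale by 1 / N,
+        // where N (== output_distance) is the number of output samples per
+        // transform.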
+        auto alpha = 1.f / output_distance;
+        OP_REQUIRES(
+            ctx,
+            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+                .ok(),
+            errors::Internal("BlasScal failed : in.shape=",
+                             input_shape.DebugString()));
+      }
+    } else {
+      auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
+      auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
       OP_REQUIRES(
-          ctx, stream->ThenBlasScal(shape.num_elements(), alpha, &dst, 1).ok(),
-          errors::Internal("BlasScal failed : in.shape=", shape.DebugString()));
+          ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
+          errors::Internal("fft failed : type=", static_cast<int>(kFftType),
+                           " in.shape=", input_shape.DebugString()));
+      if (!IsForward()) {
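+        // Same normalization for the unnormalized inverse C2C transform.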
+        auto alpha = complex64(1.f / output_distance);
+        OP_REQUIRES(
+            ctx,
+            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
+                .ok(),
+            errors::Internal("BlasScal failed : in.shape=",
+                             input_shape.DebugString()));
+      }
     }
   }
 };

-template <bool Forward, int FFTRank>
+template <bool Forward, bool _Real, int FFTRank>
 class FFTGPU : public FFTGPUBase {
  public:
   static_assert(FFTRank >= 1 && FFTRank <= 3,
@@ -124,24 +191,53 @@ class FFTGPU : public FFTGPUBase {
  protected:
   int Rank() const override { return FFTRank; }
   bool IsForward() const override { return Forward; }
+  bool IsReal() const override { return _Real; }
 };

-REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU), FFTGPU<true, 1>);
-REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU), FFTGPU<false, 1>);
-REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU), FFTGPU<true, 2>);
-REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU), FFTGPU<false, 2>);
-REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU), FFTGPU<true, 3>);
-REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU), FFTGPU<false, 3>);
+REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU), FFTGPU<true, false, 1>);
+REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU),
+                        FFTGPU<false, false, 1>);
+REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU),
+                        FFTGPU<true, false, 2>);
+REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU),
+                        FFTGPU<false, false, 2>);
+REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU),
+                        FFTGPU<true, false, 3>);
+REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU),
+                        FFTGPU<false, false, 3>);
+
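+// fft_length is read on the host while the output shape is computed, so the
+// registrations pin it to host memory.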
+REGISTER_KERNEL_BUILDER(
+    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
+    FFTGPU<true, true, 1>);
+REGISTER_KERNEL_BUILDER(
+    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
+    FFTGPU<false, true, 1>);
+REGISTER_KERNEL_BUILDER(
+    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
+    FFTGPU<true, true, 2>);
+REGISTER_KERNEL_BUILDER(
+    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
+    FFTGPU<false, true, 2>);
+REGISTER_KERNEL_BUILDER(
+    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
+    FFTGPU<true, true, 3>);
+REGISTER_KERNEL_BUILDER(
+    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
+    FFTGPU<false, true, 3>);

 // Deprecated kernels.
-REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU), FFTGPU<true, 1>);
-REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU), FFTGPU<false, 1>);
-REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU), FFTGPU<true, 2>);
+REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU),
+                        FFTGPU<true, false, 1>);
+REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU),
+                        FFTGPU<false, false, 1>);
+REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU),
+                        FFTGPU<true, false, 2>);
 REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU),
-                        FFTGPU<false, 2>);
-REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU), FFTGPU<true, 3>);
+                        FFTGPU<false, false, 2>);
+REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU),
+                        FFTGPU<true, false, 3>);
 REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU),
-                        FFTGPU<false, 3>);
+                        FFTGPU<false, false, 3>);

 }  // end namespace tensorflow