@@ -103,6 +103,70 @@ inline plain_type_t<V> inv_fft(const V& y) {
103103 return plain_type_t <V>(res);
104104}
105105
106+ /* *
107+ * Return the two-dimensional discrete Fourier transform of the
108+ * specified complex matrix. The 2D discrete Fourier transform first
109+ * runs the discrete Fourier transform on the each row, then on each
110+ * column of the result.
111+ *
112+ * The adjoint computation is given by
113+ * ```
114+ * adjoint(x) += size(y) * inv_fft2(adjoint(y))
115+ * ```
116+ *
117+ * @tparam M type of complex matrix argument
118+ * @param[in] x matrix to transform
119+ * @return discrete 2D Fourier transform of `x`
120+ */
121+ template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr ,
122+ require_var_t <base_type_t <value_type_t <M>>>* = nullptr >
123+ inline plain_type_t <M> fft2 (const M& x) {
124+ arena_t <M> arena_v = x;
125+ arena_t <M> res = fft2 (to_complex (arena_v.real ().val (), arena_v.imag ().val ()));
126+
127+ reverse_pass_callback ([arena_v, res]() mutable {
128+ auto adj_inv_fft = inv_fft2 (to_complex (res.real ().adj (), res.imag ().adj ()));
129+ adj_inv_fft *= res.size ();
130+ arena_v.real ().adj () += adj_inv_fft.real ();
131+ arena_v.imag ().adj () += adj_inv_fft.imag ();
132+ });
133+
134+ return plain_type_t <M>(res);
135+ }
136+
137+ /* *
138+ * Return the two-dimensional inverse discrete Fourier transform of
139+ * the specified complex matrix. The 2D inverse discrete Fourier
140+ * transform first runs the 1D inverse Fourier transform on the
141+ * columns, and then on the resulting rows. The composition of the
142+ * FFT and inverse FFT (or vice-versa) is the identity.
143+ *
144+ * The adjoint computation is given by
145+ * ```
146+ * adjoint(y) += (1 / size(x)) * fft2(adjoint(x))
147+ * ```
148+ *
149+ * @tparam M type of complex matrix argument
150+ * @param[in] y matrix to inverse trnasform
151+ * @return inverse discrete 2D Fourier transform of `y`
152+ */
153+ template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr ,
154+ require_var_t <base_type_t <value_type_t <M>>>* = nullptr >
155+ inline plain_type_t <M> inv_fft2 (const M& y) {
156+ arena_t <M> arena_v = y;
157+ arena_t <M> res
158+ = inv_fft2 (to_complex (arena_v.real ().val (), arena_v.imag ().val ()));
159+
160+ reverse_pass_callback ([arena_v, res]() mutable {
161+ auto adj_fft = fft2 (to_complex (res.real ().adj (), res.imag ().adj ()));
162+ adj_fft /= res.size ();
163+
164+ arena_v.real ().adj () += adj_fft.real ();
165+ arena_v.imag ().adj () += adj_fft.imag ();
166+ });
167+ return plain_type_t <M>(res);
168+ }
169+
106170} // namespace math
107171} // namespace stan
108172#endif
0 commit comments