Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions stan/math/prim/fun/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ inline Eigen::Matrix<scalar_type_t<V>, -1, 1> inv_fft(const V& y) {
* @param[in] x matrix to transform
* @return discrete 2D Fourier transform of `x`
*/
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr>
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
require_not_var_t<base_type_t<value_type_t<M>>>* = nullptr>
inline Eigen::Matrix<scalar_type_t<M>, -1, -1> fft2(const M& x) {
Eigen::Matrix<scalar_type_t<M>, -1, -1> y(x.rows(), x.cols());
for (int i = 0; i < y.rows(); ++i)
Expand All @@ -103,7 +104,8 @@ inline Eigen::Matrix<scalar_type_t<M>, -1, -1> fft2(const M& x) {
* @param[in] y matrix to inverse trnasform
* @return inverse discrete 2D Fourier transform of `y`
*/
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr>
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
require_not_var_t<base_type_t<value_type_t<M>>>* = nullptr>
inline Eigen::Matrix<scalar_type_t<M>, -1, -1> inv_fft2(const M& y) {
Eigen::Matrix<scalar_type_t<M>, -1, -1> x(y.rows(), y.cols());
for (int j = 0; j < x.cols(); ++j)
Expand Down
64 changes: 64 additions & 0 deletions stan/math/rev/fun/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,70 @@ inline plain_type_t<V> inv_fft(const V& y) {
return plain_type_t<V>(res);
}

/**
* Return the two-dimensional discrete Fourier transform of the
* specified complex matrix. The 2D discrete Fourier transform first
* runs the discrete Fourier transform on the each row, then on each
* column of the result.
*
* The adjoint computation is given by
* ```
* adjoint(x) += size(y) * inv_fft2(adjoint(y))
* ```
*
* @tparam M type of complex matrix argument
* @param[in] x matrix to transform
* @return discrete 2D Fourier transform of `x`
*/
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
inline plain_type_t<M> fft2(const M& x) {
arena_t<M> arena_v = x;
arena_t<M> res = fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));

reverse_pass_callback([arena_v, res]() mutable {
auto adj_inv_fft = inv_fft2(to_complex(res.real().adj(), res.imag().adj()));
adj_inv_fft *= res.size();
arena_v.real().adj() += adj_inv_fft.real();
arena_v.imag().adj() += adj_inv_fft.imag();
});

return plain_type_t<M>(res);
}

/**
* Return the two-dimensional inverse discrete Fourier transform of
* the specified complex matrix. The 2D inverse discrete Fourier
* transform first runs the 1D inverse Fourier transform on the
* columns, and then on the resulting rows. The composition of the
* FFT and inverse FFT (or vice-versa) is the identity.
*
* The adjoint computation is given by
* ```
* adjoint(y) += (1 / size(x)) * fft2(adjoint(x))
* ```
*
* @tparam M type of complex matrix argument
* @param[in] y matrix to inverse trnasform
* @return inverse discrete 2D Fourier transform of `y`
*/
template <typename M, require_eigen_dense_dynamic_vt<is_complex, M>* = nullptr,
require_var_t<base_type_t<value_type_t<M>>>* = nullptr>
inline plain_type_t<M> inv_fft2(const M& y) {
arena_t<M> arena_v = y;
arena_t<M> res
= inv_fft2(to_complex(arena_v.real().val(), arena_v.imag().val()));

reverse_pass_callback([arena_v, res]() mutable {
auto adj_fft = fft2(to_complex(res.real().adj(), res.imag().adj()));
adj_fft /= res.size();

arena_v.real().adj() += adj_fft.real();
arena_v.imag().adj() += adj_fft.imag();
});
return plain_type_t<M>(res);
}

} // namespace math
} // namespace stan
#endif