Skip to content

Commit 4ea7bef

Browse files
author
Thomas Arildsen
committed
ENH: Added 2D type 1 function interfaces
Un-tested. Tests for the 2D case are not implemented yet.
1 parent ad2eb07 commit 4ea7bef

File tree

5 files changed

+70
-10
lines changed

5 files changed

+70
-10
lines changed

nufft/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
__NUFFT_SETUP__ = False
99

1010
if not __NUFFT_SETUP__:
11-
__all__ = ["nufft1freqs", "nufft1", "nufft2", "nufft3"]
12-
from .nufft import nufft1freqs, nufft1, nufft2, nufft3
11+
__all__ = ["nufft1d1freqs", "nufft1d1", "nufft1d2", "nufft1d3", "nufft2d1"]
12+
from .nufft import nufft1d1freqs, nufft1d1, nufft1d2, nufft1d3, nufft2d1

nufft/nufft.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22

33
from __future__ import division, print_function
44

5-
__all__ = ["nufft1freqs", "nufft1", "nufft2", "nufft3"]
5+
__all__ = ["nufft1d1freqs", "nufft1d1", "nufft1d2", "nufft1d3", "nufft2d1"]
66

77
import numpy as np
88
from ._nufft import (
99
dirft1d1, nufft1d1f90,
1010
dirft1d2, nufft1d2f90,
1111
dirft1d3, nufft1d3f90,
12+
dirft2d1, nufft2d1f90,
1213
)
1314

1415

15-
def nufft1freqs(ms, df=1.0):
16+
def nufft1d1freqs(ms, df=1.0):
1617
return df * (np.arange(-ms // 2, ms // 2) + ms % 2)
1718

1819

19-
def nufft1(x, y, ms, df=1.0, eps=1e-15, iflag=1, direct=False):
20+
def nufft1d1(x, y, ms, df=1.0, eps=1e-15, iflag=1, direct=False):
2021
# Make sure that the data are properly formatted.
2122
x = np.ascontiguousarray(x, dtype=np.float64)
2223
y = np.ascontiguousarray(y, dtype=np.complex128)
@@ -34,7 +35,7 @@ def nufft1(x, y, ms, df=1.0, eps=1e-15, iflag=1, direct=False):
3435
return p
3536

3637

37-
def nufft2(x, p, df=1.0, eps=1e-15, iflag=1, direct=False):
38+
def nufft1d2(x, p, df=1.0, eps=1e-15, iflag=1, direct=False):
3839
# Make sure that the data are properly formatted.
3940
x = np.ascontiguousarray(x, dtype=np.float64)
4041
p = np.ascontiguousarray(p, dtype=np.complex128)
@@ -50,7 +51,7 @@ def nufft2(x, p, df=1.0, eps=1e-15, iflag=1, direct=False):
5051
return y
5152

5253

53-
def nufft3(x, y, f, eps=1e-15, iflag=1, direct=False):
54+
def nufft1d3(x, y, f, eps=1e-15, iflag=1, direct=False):
5455
# Make sure that the data are properly formatted.
5556
x = np.ascontiguousarray(x, dtype=np.float64)
5657
y = np.ascontiguousarray(y, dtype=np.complex128)
@@ -69,3 +70,22 @@ def nufft3(x, y, f, eps=1e-15, iflag=1, direct=False):
6970
if flag:
7071
raise RuntimeError("nufft1d3 failed with code {0}".format(flag))
7172
return p / len(x)
73+
74+
75+
def nufft2d1(x, y, z, ms, mt, df=1.0, eps=1e-15, iflag=1, direct=False):
76+
# Make sure that the data are properly formatted.
77+
x = np.ascontiguousarray(x, dtype=np.float64)
78+
y = np.ascontiguousarray(x, dtype=np.float64)
79+
z = np.ascontiguousarray(y, dtype=np.complex128)
80+
if len(x) != len(y) or len(y) != len(z):
81+
raise ValueError("Dimension mismatch")
82+
83+
# Run the Fortran code.
84+
if direct:
85+
p = dirft2d1(x * df, y * df, z, iflag, ms, mt)
86+
else:
87+
p, flag = nufft2d1f90(x * df, y * df, z, iflag, eps, ms, mt)
88+
# Check the output and return.
89+
if flag:
90+
raise RuntimeError("nufft2d1 failed with code {0}".format(flag))
91+
return p

nufft/nufft1d.pyf

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,29 @@ python module _nufft ! in
6565
complex*16 dimension(nk),intent(out),depend(nk) :: fk
6666
end subroutine dirft1d3
6767

68-
end interface
68+
subroutine nufft2d1f90(nj,xj,yj,cj, iflag,eps, ms,mt,fk,ier) ! in :_nufft:src/nufft1d/nufft2df90.f
69+
integer, optional,check(len(xj)>=nj),depend(xj) :: nj=len(xj)
70+
real*8 dimension(nj) :: xj
71+
real*8 dimension(nj) :: yj
72+
complex*16 dimension(nj),depend(nj) :: cj
73+
integer :: iflag
74+
real*8 :: eps
75+
integer :: ms
76+
integer :: mt
77+
complex*16 dimension(ms, mt),intent(out),depend(ms, mt) :: fk
78+
integer,intent(out) :: ier
79+
end subroutine nufft2d1f90
80+
81+
subroutine dirft2d1(nj,xj,yj,cj, iflag, ms,mt,fk) ! in :_nufft:src/nufft1d/dirft2d.f
82+
integer, optional,check(len(xj)>=nj),depend(xj) :: nj=len(xj)
83+
real*8 dimension(nj) :: xj
84+
real*8 dimension(nj) :: yj
85+
complex*16 dimension(nj),depend(nj) :: cj
86+
integer :: iflag
87+
integer :: ms
88+
integer :: mt
89+
complex*16 dimension(ms, mt),intent(out),depend(ms, mt) :: fk
90+
end subroutine dirft2d1
91+
92+
end interface
6993
end python module _nufft

nufft/test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
__all__ = ["test_type_1", "test_type_2", "test_type_3", "test_1_and_3"]
66

77
import numpy as np
8-
from .nufft import nufft1freqs, nufft1, nufft2, nufft3
8+
from .nufft import nufft1d1freqs, nufft1d1, nufft1d2, nufft1d3, nufft2d1
99

1010

1111
def _get_data():
@@ -20,6 +20,21 @@ def _get_data():
2020
f = 48 * np.cos((np.arange(ms) + 1) * np.pi / ms)
2121
return x, y, f
2222

23+
def _get_data_2d():
24+
ms = 20
25+
mt = 20
26+
nj = 128
27+
k1 = np.arange(-0.5 * nj, 0.5 * nj)
28+
j = k1 + 0.5 * nj + 1
29+
x = np.pi * np.cos(-np.pi * j / nj)
30+
y = np.pi * np.sin(-np.pi * j / nj)
31+
z = np.empty_like(x, dtype=np.complex128)
32+
z.real = np.sin(np.pi * j / nj)
33+
z.imag = np.cos(np.pi * j / nj)
34+
ms_val, mt_val = np.meshgrid(np.arange(ms)/ms, np.arange(mt)/mt, sparse=True)
35+
f = 48 * np.cos((ms_val + mt_val + 1) * np.pi)
36+
return x, y, f
37+
2338

2439
def _get_data_roundtrip():
2540
ms = 512

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
# Set up the compiled extension.
3030
sources = list(map(os.path.join("src", "nufft1d", "{0}").format,
31-
["dfftpack.f", "dirft1d.f", "next235.f", "nufft1df90.f"]))
31+
["dfftpack.f", "dirft1d.f", "dirft2d.f",
32+
"next235.f", "nufft1df90.f", "nufft2df90.f"]))
3233
sources += [os.path.join("nufft", "nufft1d.pyf")]
3334
extensions = [Extension("nufft._nufft", sources=sources)]
3435

0 commit comments

Comments
 (0)