Skip to content

Commit

Permalink
Test complex math functions with real-valued arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Jul 11, 2024
1 parent c61d5a4 commit d2788a2
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions test/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,81 @@ def test_np_bool_handling(ctx_factory):
assert out.get().item() is True


@pytest.mark.parametrize("target", [lp.PyOpenCLTarget, lp.ExecutableCTarget])
def test_complex_functions_with_real_args(ctx_factory, target):
# Reported by David Ham. See <https://github.com/inducer/loopy/issues/851>
t_unit = lp.make_kernel(
"{[i]: 0<=i<10}",
"""
y1[i] = abs(c64[i])
y2[i] = real(c64[i])
y3[i] = imag(c64[i])
y4[i] = conj(c64[i])
y5[i] = abs(c128[i])
y6[i] = real(c128[i])
y7[i] = imag(c128[i])
y8[i] = conj(c128[i])
y9[i] = abs(f32[i])
y10[i] = real(f32[i])
y11[i] = imag(f32[i])
y12[i] = conj(f32[i])
y13[i] = abs(f64[i])
y14[i] = real(f64[i])
y15[i] = imag(f64[i])
y16[i] = conj(f64[i])
""",
target=target())

t_unit = lp.add_dtypes(t_unit,
{"y9,y10,y11,y12": np.complex64,
"y13,y14,y15,y16": np.complex128,
"c64": np.complex64,
"c128": np.complex128,
"f64": np.float64,
"f32": np.float32})
t_unit = lp.set_options(t_unit, return_dict=True)

from numpy.random import default_rng
rng = default_rng(0)
c64 = (rng.random(10, dtype=np.float32)
+ np.csingle(1j)*rng.random(10, dtype=np.float32))
c128 = (rng.random(10, dtype=np.float64)
+ np.cdouble(1j)*rng.random(10, dtype=np.float64))
f32 = rng.random(10, dtype=np.float32)
f64 = rng.random(10, dtype=np.float64)

if target == lp.PyOpenCLTarget:
cl_ctx = ctx_factory()
with cl.CommandQueue(cl_ctx) as queue:
evt, out = t_unit(queue, c64=c64, c128=c128, f32=f32, f64=f64)
elif target == lp.ExecutableCTarget:
t_unit = lp.set_options(t_unit, build_options=["-Werror"])
evt, out = t_unit(c64=c64, c128=c128, f32=f32, f64=f64)
else:
raise NotImplementedError("unsupported target")

np.testing.assert_allclose(out["y1"], np.abs(c64), rtol=1e-6)
np.testing.assert_allclose(out["y2"], np.real(c64), rtol=1e-6)
np.testing.assert_allclose(out["y3"], np.imag(c64), rtol=1e-6)
np.testing.assert_allclose(out["y4"], np.conj(c64), rtol=1e-6)
np.testing.assert_allclose(out["y5"], np.abs(c128), rtol=1e-6)
np.testing.assert_allclose(out["y6"], np.real(c128), rtol=1e-6)
np.testing.assert_allclose(out["y7"], np.imag(c128), rtol=1e-6)
np.testing.assert_allclose(out["y8"], np.conj(c128), rtol=1e-6)
np.testing.assert_allclose(out["y9"], np.abs(f32), rtol=1e-6)
np.testing.assert_allclose(out["y10"], np.real(f32), rtol=1e-6)
np.testing.assert_allclose(out["y11"], np.imag(f32), rtol=1e-6)
np.testing.assert_allclose(out["y12"], np.conj(f32), rtol=1e-6)
np.testing.assert_allclose(out["y13"], np.abs(f64), rtol=1e-6)
np.testing.assert_allclose(out["y14"], np.real(f64), rtol=1e-6)
np.testing.assert_allclose(out["y15"], np.imag(f64), rtol=1e-6)
np.testing.assert_allclose(out["y16"], np.conj(f64), rtol=1e-6)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit d2788a2

Please sign in to comment.