Skip to content

Commit

Permalink
error functions on complex domain (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
kleinhenz authored Jun 6, 2024
1 parent 3759fbd commit ca15b80
Show file tree
Hide file tree
Showing 13 changed files with 620 additions and 0 deletions.
12 changes: 12 additions & 0 deletions docs/beignet.special.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# beignet.special

## Error and Related Functions

::: beignet.special.error_erf
::: beignet.special.error_erfc
::: beignet.special.error_erfi

## Dawson and Fresnel Integrals

::: beignet.special.dawson_integral_f
::: beignet.special.faddeeva_w
1 change: 1 addition & 0 deletions src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
rotation_vector_to_rotation_matrix,
)
from ._translation_identity import translation_identity
from .special import error_erf, error_erfc

__all__ = [
"apply_euler_angle",
Expand Down
13 changes: 13 additions & 0 deletions src/beignet/special/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ._dawson_integral_f import dawson_integral_f
from ._error_erf import error_erf
from ._error_erfc import error_erfc
from ._error_erfi import error_erfi
from ._faddeeva_w import faddeeva_w

__all__ = [
"dawson_integral_f",
"error_erf",
"error_erfc",
"error_erfi",
"faddeeva_w",
]
32 changes: 32 additions & 0 deletions src/beignet/special/_dawson_integral_f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import math

import torch
from torch import Tensor

from ._error_erfi import error_erfi


def dawson_integral_f(input: Tensor, *, out: Tensor | None = None) -> Tensor:
r"""
Dawson’s integral.
Parameters
----------
input : Tensor
Input tensor.
out : Tensor, optional
Output tensor.
Returns
-------
Tensor
"""
output = math.sqrt(torch.pi) / 2.0 * torch.exp(-(input**2)) * error_erfi(input)

if out is not None:
out.copy_(output)

return out

return output
29 changes: 29 additions & 0 deletions src/beignet/special/_error_erf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from torch import Tensor

from ._error_erfc import error_erfc


def error_erf(input: Tensor, *, out: Tensor | None = None) -> Tensor:
r"""
Error function.
Parameters
----------
input : Tensor
Input tensor.
out : Tensor, optional
Output tensor.
Returns
-------
Tensor
"""
output = 1.0 - error_erfc(input)

if out is not None:
out.copy_(output)

return out

return output
30 changes: 30 additions & 0 deletions src/beignet/special/_error_erfc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
from torch import Tensor

from ._faddeeva_w import faddeeva_w


def error_erfc(input: Tensor, *, out: Tensor | None = None) -> Tensor:
r"""
Complementary error function.
Parameters
----------
input : Tensor
Input tensor.
out : Tensor, optional
Output tensor.
Returns
-------
Tensor
"""
output = torch.exp(-(input**2)) * faddeeva_w(1.0j * input)

if out is not None:
out.copy_(output)

return out

return output
29 changes: 29 additions & 0 deletions src/beignet/special/_error_erfi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from torch import Tensor

from ._error_erf import error_erf


def error_erfi(input: Tensor, *, out: Tensor | None = None) -> Tensor:
r"""
Imaginary error function.
Parameters
----------
input : Tensor
Input tensor.
out : Tensor, optional
Output tensor.
Returns
-------
Tensor
"""
output = -1.0j * error_erf(1.0j * input)

if out is not None:
out.copy_(output)

return out

return output
199 changes: 199 additions & 0 deletions src/beignet/special/_faddeeva_w.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import torch
from torch import Tensor


def _voigt_v(x, y):
# assumes x >= 0, y >= 0

N = 11

# h = math.sqrt(math.pi / (N + 1))
h = 0.5116633539732443

phi = (x / h) - (x / h).floor()

k = torch.arange(N + 1, dtype=x.dtype, device=x.device)
t = (k + 0.5) * h
tau = k[1:] * h

# equation 12
w_m = (2 * h * y / torch.pi) * (
torch.exp(-t.pow(2))
* (t.pow(2) + x[..., None].pow(2) + y[..., None].pow(2))
/ (
((t - x[..., None]).pow(2) + y[..., None].pow(2))
* ((t + x[..., None]).pow(2) + y[..., None].pow(2))
)
).sum(dim=-1)

# equation 13
w_mm = (
(
2
* torch.exp(-x.pow(2) + y.pow(2))
* (
torch.cos(2 * x * y)
+ torch.exp(2 * torch.pi * y / h)
* torch.cos(2 * torch.pi * x / h - 2 * x * y)
)
)
/ (
1
+ torch.exp(4 * torch.pi * y / h)
+ 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
)
) + w_m

w_mt_1 = (
2
* torch.exp(-x.pow(2) + y.pow(2))
* (
torch.cos(2 * x * y)
- torch.exp(2 * torch.pi * y / h)
* torch.cos(2 * torch.pi * x / h - 2 * x * y)
)
) / (
1
+ torch.exp(4 * torch.pi * y / h)
- 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
)

w_mt_2 = (h * y) / (torch.pi * (x.pow(2) + y.pow(2)))

w_mt_3 = (2 * h * y / torch.pi) * (
torch.exp(-tau.pow(2))
* (tau.pow(2) + x[..., None].pow(2) + y[..., None].pow(2))
/ (
((tau - x[..., None]).pow(2) + y[..., None].pow(2))
* ((tau + x[..., None]).pow(2) + y[..., None].pow(2))
)
).sum(dim=-1)

# equation 14
w_mt = w_mt_1 + w_mt_2 + w_mt_3

return torch.where(
y >= torch.maximum(x, torch.tensor(torch.pi / h)),
w_m,
torch.where((y < x) & (1 / 4 <= phi) & (phi <= 3 / 4), w_mt, w_mm),
)


def _voigt_l(x, y):
# assumes x >= 0, y >= 0

N = 11

# h = math.sqrt(math.pi / (N + 1))
h = 0.5116633539732443

phi = (x / h) - (x / h).floor()

k = torch.arange(N + 1, dtype=x.dtype, device=x.device)
t = (k + 0.5) * h
tau = k[1:] * h

w_m = (2 * h * x / torch.pi) * (
torch.exp(-t.pow(2))
* (-t.pow(2) + x[..., None].pow(2) + y[..., None].pow(2))
/ (
((t - x[..., None]).pow(2) + y[..., None].pow(2))
* ((t + x[..., None]).pow(2) + y[..., None].pow(2))
)
).sum(dim=-1)

# equation 13
w_mm = (
(
-2
* torch.exp(-x.pow(2) + y.pow(2))
* (
torch.sin(2 * x * y)
- torch.exp(2 * torch.pi * y / h)
* torch.sin(2 * torch.pi * x / h - 2 * x * y)
)
)
/ (
1
+ torch.exp(4 * torch.pi * y / h)
+ 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
)
) + w_m

w_mt_1 = (
-2
* torch.exp(-x.pow(2) + y.pow(2))
* (
torch.sin(2 * x * y)
+ torch.exp(2 * torch.pi * y / h)
* torch.sin(2 * torch.pi * x / h - 2 * x * y)
)
) / (
1
+ torch.exp(4 * torch.pi * y / h)
- 2 * torch.exp(2 * torch.pi * y / h) * torch.cos(2 * torch.pi * x / h)
)

w_mt_2 = (h * x) / (torch.pi * (x.pow(2) + y.pow(2)))

w_mt_3 = (2 * h * x / torch.pi) * (
torch.exp(-tau.pow(2))
* (-tau.pow(2) + x[..., None].pow(2) + y[..., None].pow(2))
/ (
((tau - x[..., None]).pow(2) + y[..., None].pow(2))
* ((tau + x[..., None]).pow(2) + y[..., None].pow(2))
)
).sum(dim=-1)

# equation 14
w_mt = w_mt_1 + w_mt_2 + w_mt_3

return torch.where(
y >= torch.maximum(x, torch.tensor(torch.pi / h)),
w_m,
torch.where((y < x) & (1 / 4 <= phi) & (phi <= 3 / 4), w_mt, w_mm),
)


def _faddeeva_w_impl(z):
return _voigt_v(z.real, z.imag) + 1j * _voigt_l(z.real, z.imag)


def faddeeva_w(input: Tensor, *, out: Tensor | None = None) -> Tensor:
r"""
Faddeeva function.
Parameters
----------
input : Tensor
Input tensor.
out : Tensor, optional
Output tensor.
Returns
-------
Tensor
"""
# use symmetries to map to upper right quadrant of complex plane
imag_negative = input.imag < 0.0
input = torch.where(input.imag < 0.0, -input, input)
real_negative = input.real < 0.0
input = torch.where(input.real < 0.0, -input.conj(), input)

a = input.real
b = input.imag

assert (a >= 0.0).all()
assert (b >= 0.0).all()

output = _voigt_v(a, b) + 1j * _voigt_l(a, b)

output = torch.where(imag_negative, 2 * torch.exp(-input.pow(2)) - output, output)

if out is not None:
out.copy_(output)

return out

return torch.where(real_negative, output.conj(), output, out=out)
Loading

0 comments on commit ca15b80

Please sign in to comment.