Skip to content

Commit

Permalink
device type
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Jul 29, 2024
1 parent 1d17283 commit 6a78633
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 230 deletions.
12 changes: 6 additions & 6 deletions src/beignet/_lennard_jones_neighbor_list_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def lennard_jones_neighbor_list_potential(
displacement_fn: Callable,
box_size: Tensor,
kinds: Optional[Tensor] = None,
sigma: Tensor = torch.tensor(1.0),
epsilon: Tensor = torch.tensor(1.0),
alpha: Tensor = torch.tensor(2.0),
r_onset: Tensor = torch.tensor(2.0),
r_cutoff: Tensor = torch.tensor(2.5),
dr_threshold: Tensor = torch.tensor(0.5),
sigma: Tensor = 1.0,
epsilon: Tensor = 1.0,
alpha: Tensor = 2.0,
r_onset: Tensor = 2.0,
r_cutoff: Tensor = 2.5,
dr_threshold: Tensor = 0.5,
per_particle: bool = False,
normalized: bool = False,
neighbor_list_format: _NeighborListFormat = _NeighborListFormat.ORDERED_SPARSE,
Expand Down
8 changes: 4 additions & 4 deletions src/beignet/_lennard_jones_pair_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
def lennard_jones_pair_potential(
displacement_fn: Callable,
kinds: Optional[Tensor] = None,
sigma: Tensor = torch.tensor(1.0),
epsilon: Tensor = torch.tensor(1.0),
r_onset: Tensor = torch.tensor(2.0),
r_cutoff: Tensor = torch.tensor(2.5),
sigma: Tensor = 1.0,
epsilon: Tensor = 1.0,
r_onset: Tensor = 2.0,
r_cutoff: Tensor = 2.5,
per_particle: bool = False,
) -> Callable[[Tensor], Tensor]:
r"""Convenience wrapper to compute Lennard-Jones energy over a system.
Expand Down
4 changes: 2 additions & 2 deletions src/beignet/_periodic_displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def periodic_displacement(box: float | Tensor, dR: Tensor) -> Tensor:
Matrix of wrapped displacements with shape (..., spatial_dim).
"""
distances = (
torch.remainder(dR + box * torch.tensor(0.5, dtype=torch.float32), box)
- torch.tensor(0.5, dtype=torch.float32) * box
torch.remainder(dR + box * 0.5, box)
- 0.5 * box
)
return distances
Binary file added src/beignet/examples/models/sand.webp
Binary file not shown.
6 changes: 3 additions & 3 deletions src/beignet/func/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def mapped_fn(
[*mask.shape, *([1] * (out.ndim - mask.ndim))],
)

out = torch.where(mask, out, torch.tensor(0.0))
out = torch.where(mask, out, 0.0)

if dim is None:
return torch.divide(_safe_sum(out), normalization)
Expand Down Expand Up @@ -554,7 +554,7 @@ def mapped_fn(_position: Tensor, **_dynamic_kwargs) -> Tensor:
raise ValueError

def mapped_fn(_position: Tensor, **_dynamic_kwargs):
u = torch.tensor(0.0, dtype=torch.float32)
u = 0.0

distance_fn = functools.partial(displacement_fn, **_dynamic_kwargs)

Expand Down Expand Up @@ -595,7 +595,7 @@ def mapped_fn(_position: Tensor, _kinds: Tensor, **_dynamic_kwargs):
if not isinstance(_kinds, Tensor) or _kinds.is_floating_point():
raise ValueError

u = torch.tensor(0.0, dtype=torch.float32)
u = 0.0

num_particles = _position.shape[0]

Expand Down
33 changes: 18 additions & 15 deletions src/beignet/func/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from beignet.func.__dataclass import _dataclass
from beignet.func._static_field import static_field

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PartitionErrorCode(IntEnum):
"""An enum specifying different error codes.
Expand Down Expand Up @@ -81,7 +83,7 @@ def update(self, bit: bytes, predicate: Tensor) -> "_PartitionError":
"""
zero = torch.zeros([], dtype=torch.uint8)

bit = torch.tensor(bit, dtype=torch.uint8)
bit = bit

return _PartitionError(code=self.code | torch.where(predicate, bit, zero))

Expand Down Expand Up @@ -405,10 +407,10 @@ def _hash_constants(spatial_dimensions: int, cells_per_side: Tensor) -> Tensor:
"""
if cells_per_side.numel() == 1:
constants = [[cells_per_side**dim for dim in range(spatial_dimensions)]]
return torch.tensor(constants, dtype=torch.int32)
return torch.tensor(constants, device=device, dtype=torch.int32)

elif cells_per_side.numel() == spatial_dimensions:
one = torch.tensor([[1]], dtype=torch.int32)
one = torch.tensor([[1]], device=device, dtype=torch.int32)
cells_per_side = torch.cat((one, cells_per_side[:, :-1]), dim=1)
return torch.cumprod(cells_per_side, dim=1).squeeze()

Expand Down Expand Up @@ -523,9 +525,9 @@ def safe_mask(
Tensor
A tensor with the function applied to the masked elements and the placeholder value elsewhere.
"""
masked = torch.where(mask, operand, torch.tensor(0, dtype=operand.dtype))
masked = torch.where(mask, operand, 0)

return torch.where(mask, fn(masked), torch.tensor(placeholder, dtype=operand.dtype))
return torch.where(mask, fn(masked), placeholder)


def _segment_sum(
Expand Down Expand Up @@ -899,12 +901,12 @@ def _is_space_valid(space: Tensor) -> Tensor:
If the space tensor has more than 2 dimensions.
"""
if space.ndim == 0 or space.ndim == 1:
return torch.tensor([True])
return torch.tensor([True], device=device)

if space.ndim == 2:
return torch.tensor([torch.all(torch.triu(space) == space)])
return torch.tensor([torch.all(torch.triu(space) == space)], device=device)

return torch.tensor([False])
return torch.tensor([False], device=device)


def neighbor_list_mask(neighbor: _NeighborList, mask_self: bool = False) -> Tensor:
Expand All @@ -924,6 +926,7 @@ def neighbor_list_mask(neighbor: _NeighborList, mask_self: bool = False) -> Tens
"""
if is_neighbor_list_sparse(neighbor.format):
mask = neighbor.indexes[0] < len(neighbor.reference_positions)
torch.set_printoptions(profile="full")
if mask_self:
mask = mask & (neighbor.indexes[0] != neighbor.indexes[1])

Expand Down Expand Up @@ -991,7 +994,7 @@ def _normalize_cell_size(box: Tensor, cutoff: float) -> Tensor:
nx = xx / torch.sqrt(1 + xy**2)
ny = yy

nmin = torch.floor(torch.min(torch.tensor([nx, ny])) / cutoff)
nmin = torch.floor(torch.min(torch.tensor([nx, ny], device=device)) / cutoff)

return 1 / torch.where(nmin == 0, 1, nmin)

Expand All @@ -1007,7 +1010,7 @@ def _normalize_cell_size(box: Tensor, cutoff: float) -> Tensor:
ny = yy / torch.sqrt(1 + yz**2)
nz = zz

nmin = torch.floor(torch.min(torch.tensor([nx, ny, nz])) / cutoff)
nmin = torch.floor(torch.min(torch.tensor([nx, ny, nz], device=device)) / cutoff)
return 1 / torch.where(nmin == 0, 1, nmin)
else:
raise ValueError
Expand Down Expand Up @@ -1050,7 +1053,7 @@ def _particles_per_cell(

hash_multipliers = _hash_constants(dim, per_side)

particle_index = torch.tensor(positions / unit_size, dtype=torch.int32)
particle_index = torch.tensor(positions / unit_size, dtype=torch.int32, device=device)

particle_hash = torch.sum(particle_index * hash_multipliers, dim=1)

Expand Down Expand Up @@ -1083,7 +1086,7 @@ def cell_list(
An object containing `setup_fn` and `update_fn` functions to create and update the cell list.
"""
if not isinstance(size, Tensor):
size = torch.tensor(size, dtype=torch.float32)
size = torch.tensor(size, device=device, dtype=torch.float32)

if size.ndim == 1:
size = torch.reshape(size, [1, -1])
Expand Down Expand Up @@ -1347,7 +1350,7 @@ def neighbor_list(

squared_cutoff = cutoff**2

squared_maximum_distance = (maximum_distance / torch.tensor(2.0)) ** 2
squared_maximum_distance = (maximum_distance / 2.0) ** 2

metric_sq = _to_square_metric_fn(displacement_fn)

Expand Down Expand Up @@ -1623,12 +1626,12 @@ def _fn(position_and_error, maximum_size=None):
raise ValueError

current_unit_size = _cell_size(
torch.tensor(1.0),
1.0,
updated_neighbors.item_size,
)

updated_unit_size = _cell_size(
torch.tensor(1.0),
1.0,
_normalize_cell_size(space, cutoff),
)

Expand Down
6 changes: 4 additions & 2 deletions src/beignet/func/_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from beignet.func._interact import _safe_sum

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


CUSTOM_SIMULATION_TYPE = []

Expand Down Expand Up @@ -183,7 +185,7 @@ def volume(dimension: int, box: Union[float, Tensor]) -> Tensor:
If the box is not a scalar, vector, or matrix.
"""
if isinstance(box, (int, float)) or not box.ndim:
return torch.tensor(box**dimension)
return torch.tensor(box**dimension, device=device)
elif box.ndim == 1:
return torch.prod(box)
elif box.ndim == 2:
Expand Down Expand Up @@ -231,7 +233,7 @@ def U(eps):
return energy_fn(position, perturbation=(1 + eps), **kwargs)

def grad_U(eps):
eps_tensor = torch.tensor([eps], requires_grad=True)
eps_tensor = torch.tensor([eps], requires_grad=True, device=device)
energy = U(eps_tensor)
grad_eps = torch.autograd.grad(energy, eps_tensor, create_graph=True)[0]
return grad_eps
Expand Down
36 changes: 19 additions & 17 deletions src/beignet/func/_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from beignet.func._interact import _force, _safe_sum


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


SUZUKI_YOSHIDA_WEIGHTS = {
1: [
+1.0000000000000000,
Expand Down Expand Up @@ -482,10 +485,10 @@ def half_step_fn(system_momentums, state, temperature):

delta = step_size / steps

weights = torch.tensor(SUZUKI_YOSHIDA_WEIGHTS[system_steps])
weights = torch.tensor(SUZUKI_YOSHIDA_WEIGHTS[system_steps], device=device)

def body_fn(cs, i):
d = torch.tensor(delta * weights[i % system_steps], dtype=torch.float32)
d = torch.tensor(delta * weights[i % system_steps], device=device, dtype=torch.float32)
return substep_fn(d, *cs), 0

(system_momentums, state, _), _ = _scan(
Expand Down Expand Up @@ -574,7 +577,7 @@ def _npt_nose_hoover(
barostat_kwargs: dict | None = None,
thermostat_kwargs: dict | None = None,
) -> (Callable[..., T], Callable[[T], T]):
step_size_2 = torch.tensor(step_size / 2, dtype=torch.float32)
step_size_2 = step_size / 2

force_fn = _force(fn)

Expand All @@ -599,7 +602,7 @@ def setup_fn(
**kwargs,
):
if not masses:
masses = torch.tensor(1.0, dtype=torch.float32)
masses = 1.0

particles, spatial_dimension = positions.shape

Expand Down Expand Up @@ -673,14 +676,12 @@ def update_box_mass(
) -> _NPTNoseHooverChainState:
particles, spatial_dimension = state.positions.shape

current_box_masses = torch.tensor(
current_box_masses = (
spatial_dimension
* (particles + 1)
* _temperature
* state.barostat.oscillations**2,
dtype=state.positions.dtype,
* state.barostat.oscillations**2
)

return state.set(
current_box_masses=current_box_masses,
)
Expand Down Expand Up @@ -717,7 +718,7 @@ def u(eps):
),
),
),
torch.func.grad(u)(torch.tensor(0.0)),
torch.func.grad(u)(0.0),
),
torch.multiply(
torch.multiply(
Expand Down Expand Up @@ -995,7 +996,7 @@ def setup_fn(
**kwargs,
) -> _NVTNoseHooverChainState:
if masses is None:
masses = torch.tensor(1.0, dtype=positions.dtype)
masses = 1.0

if "temperature" not in kwargs:
_temperature = temperature
Expand Down Expand Up @@ -1142,7 +1143,7 @@ def _scan(fn: Callable, carry: Any, indexes: Tensor):

ys.append(y)

return carry, torch.tensor(ys)
return carry, torch.tensor(ys, device=device)


@_DispatchByState
Expand Down Expand Up @@ -1189,7 +1190,7 @@ def _stochastic_step(
c1 = torch.exp(torch.multiply(torch.negative(friction), step_size))

c2 = torch.sqrt(
torch.multiply(temperature, torch.subtract(torch.tensor(1.0), torch.square(c1)))
torch.multiply(temperature, torch.subtract(1.0, torch.square(c1)))
)

momentum_dist = _Normal(c1 * state.momentums, c2**2 * state.masses)
Expand Down Expand Up @@ -1227,7 +1228,7 @@ def _velocity_verlet(


def _volume_metric(dimension: int, box: Tensor) -> Tensor:
if torch.tensor(box).shape == torch.Size([]) or not box.ndim:
if box.shape == torch.Size([]) or not box.ndim:
return box**dimension

match box.ndim:
Expand Down Expand Up @@ -1297,10 +1298,10 @@ def ensemble(
"""
if friction is None:
friction = torch.tensor(1.0)
friction = 1.0

if not isinstance(friction, Tensor):
friction = torch.tensor(friction)
friction = torch.tensor(friction, device=device)

if barostat_kwargs is None:
barostat_kwargs = {}
Expand Down Expand Up @@ -1333,7 +1334,7 @@ def setup_fn(
**kwargs,
):
if masses is None:
masses = torch.tensor(1.0, dtype=positions.dtype)
masses = torch.tensor(1.0, device=device, dtype=positions.dtype)

state = _NVEState(
forces=force_fn(positions, **kwargs),
Expand Down Expand Up @@ -1368,7 +1369,7 @@ def step_fn(state, **kwargs):
raise ValueError

if not isinstance(temperature, Tensor):
temperature = torch.tensor(temperature)
temperature = torch.tensor(temperature, device=device)

match thermostat:
case "Langevin":
Expand All @@ -1382,6 +1383,7 @@ def setup_fn(
if masses is None:
masses = torch.tensor(
1.0,
device=device,
dtype=positions.dtype,
)

Expand Down
4 changes: 3 additions & 1 deletion src/beignet/func/_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

T = TypeVar("T")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def canonicalize_displacement_or_metric(displacement_fn: Callable) -> Callable:
r"""Checks whether or not a displacement or metric was provided.
Expand Down Expand Up @@ -146,7 +148,7 @@ def space(
)
"""
if isinstance(box, (int, float)):
box = torch.tensor([box])
box = torch.tensor([box], device=device)

if box is None:

Expand Down
4 changes: 3 additions & 1 deletion src/beignet/func/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ def maybe_downcast(x):
if isinstance(x, Tensor) and x.dtype is torch.float64:
return x

return x.to(dtype=torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

return torch.tensor(x).to(device=device, dtype=torch.float32)
Loading

0 comments on commit 6a78633

Please sign in to comment.