Skip to content

Commit

Permalink
simplified argument 'partial_pattern' to 'partial'
Browse files Browse the repository at this point in the history
  • Loading branch information
FrederikLizakJohansen committed Aug 22, 2024
1 parent f1627eb commit 57bf852
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 43 deletions.
46 changes: 23 additions & 23 deletions debyecalculator/debye_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def parse_elements(elements, size):
def generate_partial_masks(
self,
structure: StructureTuple,
partial_pattern: str,
partial: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generates masks for selecting specific atom pairs in a structure based on a given partial pattern.
Expand All @@ -458,7 +458,7 @@ def generate_partial_masks(
Parameters:
structure (StructureTuple): A tuple representing the atomic structure, containing atomic positions, elements, etc.
partial_pattern (str): A string in the form 'X-Y', where 'X' and 'Y' are element symbols. If provided, masks are created to isolate interactions between these elements. If None, masks select all elements.
partial (str): A string in the form 'X-Y', where 'X' and 'Y' are element symbols. If provided, masks are created to isolate interactions between these elements. If None, masks select all elements.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand All @@ -467,20 +467,20 @@ def generate_partial_masks(
- partial_mask_other (torch.Tensor): A boolean mask for selecting atoms of the second element 'Y' in the structure.
Raises:
ValueError: If the partial_pattern does not match the expected format 'X-Y', where 'X' and 'Y' are valid element symbols.
ValueError: If the partial pattern does not match the expected format 'X-Y', where 'X' and 'Y' are valid element symbols.
"""

# Generate the upper trianfular indices
N = structure.xyz.size(0)
triu_indices = torch.triu_indices(N, N, offset=1)

if partial_pattern is not None:
if partial is not None:

# Assert that string matches represetation
re_pattern = r'^([a-zA-Z]+)-([a-zA-Z]+)$'
match = re.match(re_pattern, partial_pattern)
match = re.match(re_pattern, partial)
if not match:
raise ValueError("'partial_pattern' does not match the pattern 'X-Y', of elements 'X' and 'Y'.")
raise ValueError("'partial' does not match the pattern 'X-Y', of elements 'X' and 'Y'.")

# Extract elements and make sure they are atoms in the structure
el1, el2 = match.groups()
Expand All @@ -489,9 +489,9 @@ def generate_partial_masks(
elements = np.array(structure.elements)

if not el1 in elements:
raise ValueError(f"element {el1} from 'partial_pattern' is not present in structure.")
raise ValueError(f"element {el1} from 'partial' is not present in structure.")
if not el2 in elements:
raise ValueError(f"element {el2} from 'partial_pattern' is not present in structure.")
raise ValueError(f"element {el2} from 'partial' is not present in structure.")


# Construct masks for struc and other using comparison
Expand All @@ -515,7 +515,7 @@ def iq(
self,
structure_source: StructureSourceType,
radii: Union[List[float], float, None] = None,
partial_pattern: str = None,
partial: str = None,
keep_on_device: bool = False,
include_self_scattering: bool = True,
) -> Union[IqTuple, List[IqTuple]]:
Expand All @@ -525,7 +525,7 @@ def iq(
Parameters:
structure_source (StructureSourceType): Atomic structure source in XYZ/CIF format, ASE Atoms object, or as a tuple of (atomic_identities, atomic_positions).
radii (Union[List[float], float, None]): List/float of radii/radius of particle(s) to generate with parsed CIF.
partial_pattern (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
partial (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
keep_on_device (bool): Flag to keep the results on the class device. Default is False, and will return numpy arrays on CPU.
include_self_scattering (bool): Flag to compute self-scattering contribution. Default is True.
Expand All @@ -551,7 +551,7 @@ def compute_iq(structure):
dists = torch.norm(structure.xyz[:,None] - structure.xyz, dim=2, p=2)[triu_indices[0], triu_indices[1]].split(self.batch_size)

# Generate partial masks
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial_pattern)
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial)

# Batch mask and other indices
partial_mask_sparse = partial_mask_sparse.to(device=self.device).split(self.batch_size)
Expand Down Expand Up @@ -628,7 +628,7 @@ def sq(
self,
structure_source: StructureSourceType,
radii: Union[List[float], float, None] = None,
partial_pattern: str = None,
partial: str = None,
keep_on_device: bool = False,
) -> Union[SqTuple, List[SqTuple]]:
"""
Expand All @@ -637,7 +637,7 @@ def sq(
Parameters:
structure_source (StructureSourceType): Atomic structure source in XYZ/CIF format, ASE Atoms object or as a tuple of (atomic_identities, atomic_positions)
radii (Union[List[float], float, None]): List/float of radii/radius of particle(s) to generate with parsed CIF.
partial_pattern (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
partial (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
keep_on_device (bool): Flag to keep the results on the class device. Default is False, and will return numpy arrays on CPU
Returns:
Expand All @@ -656,7 +656,7 @@ def compute_sq(structure):
dists = torch.norm(structure.xyz[:,None] - structure.xyz, dim=2, p=2)[triu_indices[0], triu_indices[1]].split(self.batch_size)

# Generate partial masks
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial_pattern)
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial)

# Batch mask and other indices
partial_mask_sparse = partial_mask_sparse.to(device=self.device).split(self.batch_size)
Expand Down Expand Up @@ -726,7 +726,7 @@ def fq(
self,
structure_source: StructureSourceType,
radii: Union[List[float], float, None] = None,
partial_pattern: str = None,
partial: str = None,
keep_on_device: bool = False,
) -> Union[FqTuple, List[FqTuple]]:
"""
Expand All @@ -735,7 +735,7 @@ def fq(
Parameters:
structure_source (StructureSourceType): Atomic structure source in XYZ/CIF format, ASE Atoms object, or as a tuple of (atomic_identities, atomic_positions).
radii (Union[List[float], float, None]): List/float of radii/radius of particle(s) to generate with parsed CIF.
partial_pattern (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
partial (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
keep_on_device (bool): Flag to keep the results on the class device. Default is False, and will return numpy arrays on CPU.
Returns:
Expand All @@ -760,7 +760,7 @@ def compute_fq(structure):
dists = torch.norm(structure.xyz[:,None] - structure.xyz, dim=2, p=2)[triu_indices[0], triu_indices[1]].split(self.batch_size)

# Generate partial masks
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial_pattern)
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial)

# Batch mask and other indices
partial_mask_sparse = partial_mask_sparse.to(device=self.device).split(self.batch_size)
Expand Down Expand Up @@ -831,7 +831,7 @@ def gr(
self,
structure_source: StructureSourceType,
radii: Union[List[float], float, None] = None,
partial_pattern: str = None,
partial: str = None,
keep_on_device: bool = False,
) -> Union[GrTuple, List[GrTuple]]:
"""
Expand All @@ -840,7 +840,7 @@ def gr(
Parameters:
structure_source (StructureSourceType): Atomic structure source in XYZ/CIF format, ASE Atoms object, or as a tuple of (atomic_identities, atomic_positions).
radii (Union[List[float], float, None]): List/float of radii/radius of particle(s) to generate with parsed CIF.
partial_pattern (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
partial (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
keep_on_device (bool): Flag to keep the results on the class device. Default is False, and will return numpy arrays on CPU.
Returns:
Expand All @@ -864,7 +864,7 @@ def compute_gr(structure):
dists = torch.norm(structure.xyz[:,None] - structure.xyz, dim=2, p=2)[triu_indices[0], triu_indices[1]].split(self.batch_size)

# Generate partial masks
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial_pattern)
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial)

# Batch mask and other indices
partial_mask_sparse = partial_mask_sparse.to(device=self.device).split(self.batch_size)
Expand Down Expand Up @@ -939,7 +939,7 @@ def _get_all(
self,
structure_source: StructureSourceType,
radii: Union[List[float], float, None] = None,
partial_pattern: str = None,
partial: str = None,
keep_on_device: bool = False,
include_self_scattering: bool = True,
) -> Union[AllTuple, List[AllTuple]]:
Expand All @@ -949,7 +949,7 @@ def _get_all(
Parameters:
structure_source (StructureSourceType): Atomic structure source in XYZ/CIF format, ASE Atoms object, or as a tuple of (atomic_identities, atomic_positions).
radii (Union[List[float], float, None]): List/float of radii/radius of particle(s) to generate with parsed CIF.
partial_pattern (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
partial (str): String on the form 'X-Y' where 'X' and 'Y' are elements in either structure. Used for calculating partial scattering patterns. Default is None.
keep_on_device (bool): Flag to keep the results on the class device. Default is False, and will return numpy arrays on CPU.
include_self_scattering (bool): Flag to compute self-scattering contribution. Default is True.
Expand All @@ -974,7 +974,7 @@ def compute_all(structure):
dists = torch.norm(structure.xyz[:,None] - structure.xyz, dim=2, p=2)[triu_indices[0], triu_indices[1]].split(self.batch_size)

# Generate partial masks
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial_pattern)
partial_mask_sparse, partial_mask_struc, partial_mask_other = self.generate_partial_masks(structure, partial)

# Batch mask and other indices
partial_mask_sparse = partial_mask_sparse.to(device=self.device).split(self.batch_size)
Expand Down
40 changes: 20 additions & 20 deletions debyecalculator/test_debye_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,38 +376,38 @@ def test_partials():
Testing whether the partials add up
"""
# I(Q)
_, iq = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern=None)
_, iq_o_o = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="O-O", include_self_scattering=True)
_, iq_co_o = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-O", include_self_scattering=False)
_, iq_co_co = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-Co", include_self_scattering=True)
_, iq = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial=None)
_, iq_o_o = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial="O-O", include_self_scattering=True)
_, iq_co_o = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-O", include_self_scattering=False)
_, iq_co_co = calc.iq("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-Co", include_self_scattering=True)
iq_added = iq_o_o + iq_co_o + iq_co_co
assert np.allclose(iq, iq_added, atol=1e-04, rtol=1e-03), f"Partials are not matching for I(Q) calculations"
# S(Q)
_, sq = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern=None)
_, sq_o_o = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="O-O", )
_, sq_co_o = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-O", )
_, sq_co_co = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-Co", )
_, sq = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial=None)
_, sq_o_o = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial="O-O", )
_, sq_co_o = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-O", )
_, sq_co_co = calc.sq("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-Co", )
sq_added = sq_o_o + sq_co_o + sq_co_co
assert np.allclose(sq, sq_added, atol=1e-04, rtol=1e-03), f"Partials are not matching for S(Q) calculations"
# F(Q)
_, fq = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern=None)
_, fq_o_o = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="O-O", )
_, fq_co_o = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-O", )
_, fq_co_co = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-Co", )
_, fq = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial=None)
_, fq_o_o = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial="O-O", )
_, fq_co_o = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-O", )
_, fq_co_co = calc.fq("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-Co", )
fq_added = fq_o_o + fq_co_o + fq_co_co
assert np.allclose(fq, fq_added, atol=1e-04, rtol=1e-03), f"Partials are not matching for F(Q) calculations"
# G(r)
_, gr = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern=None)
_, gr_o_o = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="O-O", )
_, gr_co_o = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-O", )
_, gr_co_co = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial_pattern="Co-Co", )
_, gr = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial=None)
_, gr_o_o = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial="O-O", )
_, gr_co_o = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-O", )
_, gr_co_co = calc.gr("data/AntiFluorite_Co2O.cif", radii=5, partial="Co-Co", )
gr_added = gr_o_o + gr_co_o + gr_co_co
assert np.allclose(gr, gr_added, atol=1e-04, rtol=1e-03), f"Partials are not matching for G(r) calculations"
# _get_all
_, _, iq, sq, fq, gr = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial_pattern=None)
_, _, iq_o_o, sq_o_o, fq_o_o, gr_o_o = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial_pattern="O-O", include_self_scattering=True)
_, _, iq_co_o, sq_co_o, fq_co_o, gr_co_o = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial_pattern="Co-O", include_self_scattering=False)
_, _, iq_co_co, sq_co_co, fq_co_co, gr_co_co = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial_pattern="Co-Co", include_self_scattering=True)
_, _, iq, sq, fq, gr = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial=None)
_, _, iq_o_o, sq_o_o, fq_o_o, gr_o_o = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial="O-O", include_self_scattering=True)
_, _, iq_co_o, sq_co_o, fq_co_o, gr_co_o = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial="Co-O", include_self_scattering=False)
_, _, iq_co_co, sq_co_co, fq_co_co, gr_co_co = calc._get_all('data/AntiFluorite_Co2O.cif', radii=5, partial="Co-Co", include_self_scattering=True)
iq_added = iq_o_o + iq_co_o + iq_co_co
assert np.allclose(iq, iq_added, atol=1e-04, rtol=1e-03), f"Partials are not matching for I(Q) calculations"
sq_added = sq_o_o + sq_co_o + sq_co_co
Expand Down

0 comments on commit 57bf852

Please sign in to comment.