Skip to content

Commit

Permalink
added tests for double precision and invalid dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
FrederikLizakJohansen committed Sep 16, 2024
1 parent db7226e commit 4be1f37
Showing 1 changed file with 55 additions and 15 deletions.
70 changes: 55 additions & 15 deletions debyecalculator/test_debye_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,59 @@ def debye_calc():
def normalize_counts(data, eps=1e-16):
return data / (np.max(data) + eps)

@pytest.mark.parametrize("filename, qstep, gr_qstep, radii, expected_iq, expected_sq, expected_fq, expected_gr", [
('AntiFluorite_Co2O.cif', 0.05, 0.05, 10.0,
'iq_AntiFluorite_Co2O_radius10.0.dat',
'sq_AntiFluorite_Co2O_radius10.0.dat',
'fq_AntiFluorite_Co2O_radius10.0.dat',
'gr_AntiFluorite_Co2O_radius10.0.dat'),
('icsd_001504_cc_r6_lc_2.85_6_tetragonal.xyz', 0.1, 0.1, None,
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Iq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Sq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Fq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Gr.dat')
@pytest.mark.parametrize("filename, qstep, gr_qstep, radii, dtype, expected_iq, expected_sq, expected_fq, expected_gr", [
(
'AntiFluorite_Co2O.cif',
0.05,
0.05,
10.0,
torch.float32,
'iq_AntiFluorite_Co2O_radius10.0.dat',
'sq_AntiFluorite_Co2O_radius10.0.dat',
'fq_AntiFluorite_Co2O_radius10.0.dat',
'gr_AntiFluorite_Co2O_radius10.0.dat'
),
(
'AntiFluorite_Co2O.cif',
0.05,
0.05,
10.0,
torch.float64,
'iq_AntiFluorite_Co2O_radius10.0.dat',
'sq_AntiFluorite_Co2O_radius10.0.dat',
'fq_AntiFluorite_Co2O_radius10.0.dat',
'gr_AntiFluorite_Co2O_radius10.0.dat'
),
(
'icsd_001504_cc_r6_lc_2.85_6_tetragonal.xyz',
0.1,
0.1,
None,
torch.float32,
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Iq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Sq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Fq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Gr.dat'
),
(
'icsd_001504_cc_r6_lc_2.85_6_tetragonal.xyz',
0.1,
0.1,
None,
torch.float64,
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Iq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Sq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Fq.dat',
'icsd_001504_cc_r6_lc_2.85_6_tetragonal_Gr.dat'
)
])
def test_scattering(debye_calc, filename, qstep, gr_qstep, radii, expected_iq, expected_sq, expected_fq, expected_gr):
def test_scattering(debye_calc, filename, qstep, gr_qstep, radii, dtype, expected_iq, expected_sq, expected_fq, expected_gr):

# Set dtype of calculator
debye_calc.update_parameters(dtype=dtype)

# Load the expected PDF
expected_data = np.genfromtxt(f'debyecalculator/unittests_files/{expected_gr}', delimiter=',', skip_header=1)
Expand Down Expand Up @@ -93,7 +132,8 @@ def test_invalid_input(debye_calc):
invalid_params = [
{'qmin': -1.0}, {'qmax': -1.0}, {'qstep': -1.0}, {'qdamp': -1.0},
{'rmin': -1.0}, {'rmax': -1.0}, {'rstep': -1.0}, {'rthres': -1.0},
{'biso': -1.0}, {'batch_size': -1}, {'device': 'x'}, {'radiation_type': 'x'}
{'biso': -1.0}, {'batch_size': -1}, {'device': 'x'}, {'radiation_type': 'x'},
{'dtype': torch.float16}
]

for param in invalid_params:
Expand Down Expand Up @@ -209,4 +249,4 @@ def test_optimal_qstep(debye_calc):
calc = debye_calc
# Optimal qstep is π / (rmax + rstep), with a small tolerance for computational uncertainty
optimal_qstep = math.pi / (calc.rmax + calc.rstep) + 1e-5
assert calc.qstep <= optimal_qstep, f"Expected qstep <= {optimal_qstep}, but got qstep = {calc.qstep}"
assert calc.qstep <= optimal_qstep, f"Expected qstep <= {optimal_qstep}, but got qstep = {calc.qstep}"

0 comments on commit 4be1f37

Please sign in to comment.