Skip to content

Commit

Permalink
add: cart2sph conversion and evalf in integral generation
Browse files Browse the repository at this point in the history
c2s transformations seems to introduce some unnecessary multiplications
by 1.0 at several places
  • Loading branch information
Johannes Steinmetzer authored and sumpfaffe committed Dec 14, 2022
1 parent f9b042c commit cf10169
Showing 1 changed file with 64 additions and 9 deletions.
73 changes: 64 additions & 9 deletions pysisyphus/wavefunction/gen_ints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/env python3

# [1] https://doi.org/10.1063/1.450106
# Efficient recursive computation of molecular integrals over Cartesian
Expand All @@ -23,7 +23,9 @@
# Memory-Efficient Recursive Evaluation of 3-Center Gaussian Integrals
# Asadchev, Valeev, 2022


import argparse
from datetime import datetime
import functools
import itertools as it
import os
Expand All @@ -37,25 +39,35 @@
from jinja2 import Template
import numpy as np
from sympy import (
Array,
cse,
exp,
Expr,
flatten,
Function,
IndexedBase,
Matrix,
permutedims,
pi,
sqrt,
Symbol,
symbols,
tensorcontraction as tc,
tensorproduct as tp,
)


from sympy.codegen.ast import Assignment
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.c import C99CodePrinter

from pysisyphus.wavefunction.cart2sph import cart2sph_coeffs

try:
from pysisyphus.config import L_MAX, L_AUX_MAX
except ModuleNotFoundError:
L_MAX = 4
L_AUX_MAX = 5


L_MAP = {
Expand All @@ -78,6 +90,7 @@
"3c2e",
"3c2e_sph",
)
ONE_THRESH = 1e-14


def make_py_func(repls, reduced, args=None, name=None, doc_str=""):
Expand Down Expand Up @@ -973,11 +986,43 @@ def get_map(i, center_i):
return array, array_map


def cart2spherical(L_tots, exprs):
assert len(L_tots) > 0

# Coefficient matrices for Cartesian-to-spherical conversion
coeffs = [Array(CART2SPH[L]) for L in L_tots]
cart_shape = [(l + 1) * (l + 2) // 2 for l in L_tots]
cart = Array(exprs).reshape(*cart_shape)

sph = tc(tp(coeffs[0], cart), (1, 2))
if len(L_tots) == 2:
sph = tc(tp(sph, coeffs[1].transpose()), (1, 2))
elif len(L_tots) == 3:
_, Cb, Cc = coeffs
sph = tc(tp(permutedims(sph, (0, 2, 1)), Cb.transpose()), (2, 3))
sph = tc(tp(permutedims(sph, (0, 2, 1)), Cc.transpose()), (2, 3))
else:
raise Exception(
"Cartesian -> spherical transformation for 4-center integrals "
"is not implemented!"
)

# Cartesian-to-spherical transformation introduces quite a number of
# multiplications by 1.0, which are uneccessary. Here, we try to drop
# some of them by replacing numbers very close to +1.0 with 1.
sph = sph.replace(lambda n: n.is_Number and (abs(n - 1) <= ONE_THRESH), lambda n: 1)
# TODO: maybe something along the lines
# sph = map(lambda expr: expr.evalf(), flatten(sph))
# is faster?
return flatten(sph)


def gen_integral_exprs(
int_func,
L_maxs,
kind,
maps=None,
sph=False,
):
if maps is None:
maps = list()
Expand All @@ -986,30 +1031,36 @@ def gen_integral_exprs(

for L_tots in it.product(*ranges):
time_str = time.strftime("%H:%M:%S")
start = time.time()
start = datetime.now()
print(f"{time_str} - Generating {L_tots} {kind}")
sys.stdout.flush()
# Generate actual expressions
# Generate actual list of expressions.
exprs = int_func(*L_tots)
print("\t... generated expressions")
sys.stdout.flush()
if sph:
exprs = cart2spherical(L_tots, exprs)
print("\t... did Cartesian -> Spherical conversion")
sys.stdout.flush()

# Common subexpression elimination
repls, reduced = cse(list(exprs), order="none")
print("\t... did common subexpression elimination")

for i, red in enumerate(reduced):
red = red.evalf()
reduced[i] = functools.reduce(
lambda red, map_: red.xreplace(map_), maps, red
)

for i, (lhs, rhs) in enumerate(repls):
repls[i] = (
lhs,
functools.reduce(lambda rhs, map_: rhs.xreplace(map_), maps, rhs),
)
rhs = rhs.evalf()
# Replace occurences of Ax, Ay, Az, ... with A[0], A[1], A[2], ...
rhs = functools.reduce(lambda rhs, map_: rhs.xreplace(map_), maps, rhs)
repls[i] = (lhs, rhs)

dur = time.time() - start
print(f"\t... finished in {dur: >8.2f} s")
dur = datetime.now() - start
print(f"\t... finished in {str(dur)} h")
sys.stdout.flush()
yield (repls, reduced), L_tots

Expand Down Expand Up @@ -1186,6 +1237,9 @@ def run():
except FileExistsError:
pass

global CART2SPH
CART2SPH = cart2sph_coeffs(max(l_max, l_aux_max), zero_small=True)

# Cartesian basis function centers A and B.
center_A = get_center("A")
center_B = get_center("B")
Expand Down Expand Up @@ -1502,6 +1556,7 @@ def _3center2el_doc_func(L_tots):
(l_max, l_max, l_aux_max),
"_3center2el3d_sph",
(A_map, B_map, C_map),
sph=True,
)
write_render(
_3center2el_ints_Ls,
Expand Down

0 comments on commit cf10169

Please sign in to comment.