Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@
"pytools": ("https://documen.tician.de/pytools", None),
"scipy": ("https://docs.scipy.org/doc/scipy", None),
"sumpy": ("https://documen.tician.de/sumpy", None),
"sympy": ("https://docs.sympy.org/latest/", None),
}
10 changes: 6 additions & 4 deletions pytential/linalg/direct_solver_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ def map_int_g(self, expr):
if name not in source_args
}

return expr.copy(target_kernel=target_kernel,
source_kernels=source_kernels,
densities=self.rec(expr.densities),
kernel_arguments=kernel_arguments)
from dataclasses import replace
return replace(expr,
target_kernel=target_kernel,
source_kernels=source_kernels,
densities=self.rec(expr.densities),
kernel_arguments=kernel_arguments)

# }}}

Expand Down
28 changes: 15 additions & 13 deletions pytential/linalg/skeletonization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from arraycontext import PyOpenCLArrayContext, Array

from pytential import GeometryCollection, sym
from pytential.symbolic.matrix import ClusterMatrixBuilderBase
from pytential.linalg.utils import IndexList, TargetAndSourceClusterList
from pytential.linalg.proxy import ProxyGeneratorBase, ProxyClusterGeometryData
from pytential.linalg.direct_solver_symbolic import (
Expand Down Expand Up @@ -136,7 +137,7 @@ def prg():
lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
)

return knl
return knl.executor(actx.context)

waa = bind(places, sym.weights_and_area_elements(
places.ambient_dim, dofdesc=domain))(actx)
Expand Down Expand Up @@ -253,17 +254,17 @@ class SkeletonizationWrangler:
domains: tuple[sym.DOFDescriptor, ...]
context: dict[str, Any]

neighbor_cluster_builder: Callable[..., np.ndarray]
neighbor_cluster_builder: type[ClusterMatrixBuilderBase]

# target skeletonization
weighted_targets: bool
target_proxy_exprs: np.ndarray
proxy_target_cluster_builder: Callable[..., np.ndarray]
proxy_target_cluster_builder: type[ClusterMatrixBuilderBase]

# source skeletonization
weighted_sources: bool
source_proxy_exprs: np.ndarray
proxy_source_cluster_builder: Callable[..., np.ndarray]
proxy_source_cluster_builder: type[ClusterMatrixBuilderBase]

@property
def nrows(self) -> int:
Expand Down Expand Up @@ -386,35 +387,36 @@ def make_skeletonization_wrangler(

# internal
_weighted_proxy: bool | tuple[bool, bool] | None = None,
_proxy_source_cluster_builder: Callable[..., np.ndarray] | None = None,
_proxy_target_cluster_builder: Callable[..., np.ndarray] | None = None,
_neighbor_cluster_builder: Callable[..., np.ndarray] | None = None,
_proxy_source_cluster_builder: type[ClusterMatrixBuilderBase] | None = None,
_proxy_target_cluster_builder: type[ClusterMatrixBuilderBase] | None = None,
_neighbor_cluster_builder: type[ClusterMatrixBuilderBase] | None = None,
) -> SkeletonizationWrangler:
if context is None:
context = {}

# {{{ setup expressions

try:
exprs = list(exprs)
lpot_exprs = list(exprs)
except TypeError:
exprs = [exprs]
lpot_exprs = [exprs]

try:
input_exprs = list(input_exprs)
except TypeError:
assert not isinstance(input_exprs, Sequence)
input_exprs = [input_exprs]

from pytential.symbolic.execution import _prepare_auto_where, _prepare_domains

auto_where = _prepare_auto_where(auto_where, places)
domains = _prepare_domains(len(input_exprs), places, domains, auto_where[0])

exprs = prepare_expr(places, exprs, auto_where)
prepared_lpot_exprs = prepare_expr(places, lpot_exprs, auto_where)
source_proxy_exprs = prepare_proxy_expr(
places, exprs, (auto_where[0], PROXY_SKELETONIZATION_TARGET))
places, prepared_lpot_exprs, (auto_where[0], PROXY_SKELETONIZATION_TARGET))
target_proxy_exprs = prepare_proxy_expr(
places, exprs, (PROXY_SKELETONIZATION_SOURCE, auto_where[1]))
places, prepared_lpot_exprs, (PROXY_SKELETONIZATION_SOURCE, auto_where[1]))

# }}}

Expand Down Expand Up @@ -449,7 +451,7 @@ def make_skeletonization_wrangler(

return SkeletonizationWrangler(
# operator
exprs=exprs,
exprs=prepared_lpot_exprs,
input_exprs=tuple(input_exprs),
domains=tuple(domains),
context=context,
Expand Down
2 changes: 1 addition & 1 deletion pytential/linalg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def prg():
lang_version=MOST_RECENT_LANGUAGE_VERSION)

knl = lp.split_iname(knl, "icluster", 128, outer_tag="g.0")
return knl
return knl.executor(actx.context)

@memoize_in(mindex, (make_index_cluster_cartesian_product, "index_product"))
def _product():
Expand Down
4 changes: 2 additions & 2 deletions pytential/qbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ def preprocess_optemplate(self, name, discretizations, expr):
def op_group_features(self, expr):
from pytential.utils import sort_arrays_together
result = (
expr.source, *sort_arrays_together(expr.source_kernels,
expr.densities, key=str)
expr.source,
*sort_arrays_together(expr.source_kernels, expr.densities, key=str)
)

return result
Expand Down
12 changes: 7 additions & 5 deletions pytential/qbx/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def element_prop_threshold_checker(self):
lang_version=MOST_RECENT_LANGUAGE_VERSION)

knl = lp.split_iname(knl, "ielement", 128, inner_tag="l.0", outer_tag="g.0")
return knl
return knl.executor(self.array_context.context)

def get_wrangler(self):
return RefinerWrangler(self.array_context, self)
Expand Down Expand Up @@ -388,8 +388,8 @@ def check_sufficient_source_quadrature_resolution(self,
sym.ElementwiseMax(
sym._source_danger_zone_radii(
stage2_density_discr.ambient_dim,
dofdesc=sym.QBX_SOURCE_STAGE2),
dofdesc=sym.GRANULARITY_ELEMENT)
dofdesc=sym.as_dofdesc(sym.QBX_SOURCE_STAGE2)),
dofdesc=sym.as_dofdesc(sym.GRANULARITY_ELEMENT))
)(self.array_context), self.array_context)
unwrap_args = AreaQueryElementwiseTemplate.unwrap_args

Expand Down Expand Up @@ -633,7 +633,8 @@ def _refine_qbx_stage1(lpot_source, density_discr,
quad_resolution_by_element = bind(stage1_density_discr,
sym.ElementwiseMax(
sym._quad_resolution(stage1_density_discr.ambient_dim),
dofdesc=sym.GRANULARITY_ELEMENT))(actx)
dofdesc=sym.as_dofdesc(sym.GRANULARITY_ELEMENT)
))(actx)

violates_kernel_length_scale = \
wrangler.check_element_prop_threshold(
Expand All @@ -653,7 +654,8 @@ def _refine_qbx_stage1(lpot_source, density_discr,
scaled_max_curvature_by_element = bind(stage1_density_discr,
sym.ElementwiseMax(
sym._scaled_max_curvature(stage1_density_discr.ambient_dim),
dofdesc=sym.GRANULARITY_ELEMENT))(actx)
dofdesc=sym.as_dofdesc(sym.GRANULARITY_ELEMENT)
))(actx)

violates_scaled_max_curv = \
wrangler.check_element_prop_threshold(
Expand Down
2 changes: 1 addition & 1 deletion pytential/qbx/target_assoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def make_target_flags(self, target_discrs_and_qbx_sides):
return target_flags

def make_default_target_association(self, ntargets):
target_to_center = self.array_context.zeros(ntargets, dtype=np.int32)
target_to_center = self.array_context.np.zeros(ntargets, dtype=np.int32)
target_to_center.fill(-1)
target_to_center.finish()

Expand Down
41 changes: 24 additions & 17 deletions pytential/symbolic/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import numpy as np

from pymbolic.primitives import cse_scope, Expression, Variable
from pymbolic.primitives import cse_scope, Expression, Variable, Subscript
from sumpy.kernel import Kernel

from pytential.symbolic.primitives import (
Expand All @@ -44,6 +44,7 @@ class Statement:
.. attribute:: exprs
.. attribute:: priority
"""

names: list[str]
exprs: list[Expression]
priority: int
Expand All @@ -52,7 +53,7 @@ def get_assignees(self) -> set[str]:
raise NotImplementedError(
f"get_assignees for '{self.__class__.__name__}'")

def get_dependencies(self, dep_mapper: DependencyMapper) -> set[Expression]:
def get_dependencies(self, dep_mapper: DependencyMapper) -> set[Variable]:
raise NotImplementedError(
f"get_dependencies for '{self.__class__.__name__}'")

Expand Down Expand Up @@ -80,14 +81,19 @@ def __post_init__(self):
def get_assignees(self):
return set(self.names)

def get_dependencies(self, dep_mapper: DependencyMapper) -> set[Expression]:
def get_dependencies(self, dep_mapper: DependencyMapper) -> set[Variable]:
from operator import or_
deps = reduce(or_, (dep_mapper(expr) for expr in self.exprs))
all_deps = reduce(or_, (dep_mapper(expr) for expr in self.exprs))

deps: set[Variable] = set()
for dep in all_deps:
if isinstance(dep, Variable):
if dep.name not in self.names:
deps.add(dep)
else:
raise TypeError(f"Unsupported dependency type: {type(dep)}")

return {
dep
for dep in deps
if dep.name not in self.names}
return deps

def __str__(self):
comment = self.comment
Expand Down Expand Up @@ -189,13 +195,16 @@ class ComputePotential(Statement):
def get_assignees(self):
return {o.name for o in self.outputs}

def get_dependencies(self, dep_mapper: DependencyMapper) -> set[Expression]:
result = set(dep_mapper(self.densities[0]))
for density in self.densities[1:]:
result.update(dep_mapper(density))
def get_dependencies(self, dep_mapper: DependencyMapper) -> set[Variable]:
from itertools import chain

for arg_expr in self.kernel_arguments.values():
result.update(dep_mapper(arg_expr))
result: set[Variable] = set()
for expr in chain(self.densities, self.kernel_arguments.values()):
for dep in dep_mapper(expr):
if isinstance(dep, Variable):
result.add(dep)
else:
raise TypeError(f"Unsupported dependency type: {type(dep)}")

return result

Expand Down Expand Up @@ -546,9 +555,7 @@ def make_assign(

def assign_to_new_var(
self, expr: Expression, priority: int = 0, prefix: str | None = None,
) -> Variable:
from pymbolic.primitives import Subscript

) -> Variable | Subscript:
# Observe that the only things that can be legally subscripted
# are variables. All other expressions are broken down into
# their scalar components.
Expand Down
2 changes: 2 additions & 0 deletions pytential/symbolic/dof_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ def as_dofdesc(desc: DOFDescriptorLike) -> DOFDescriptor:

# {{{ type annotations

DEFAULT_DOFDESC = DOFDescriptor()

DiscretizationStages = (
type[QBX_SOURCE_STAGE1]
| type[QBX_SOURCE_STAGE2]
Expand Down
Loading
Loading