Skip to content

Handle inactive args from context #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 68 additions & 86 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,15 @@ def to_jax(ty):
return {"f32": jnp.float32, "f64": jnp.float64}[tystr]


def activity_from_pipeline(pass_pipeline):
    """Split a pass-pipeline string around its ``argTys=`` activity list.

    The value of the first ``argTys=`` option is taken to be the
    comma-separated text that follows it, up to (but not including) the
    next space, the next closing brace ``}``, or the end of the string,
    whichever comes first.

    Args:
        pass_pipeline: Textual pass pipeline containing an ``argTys=`` option.

    Returns:
        A tuple ``(pre_act, acts, post_act)`` where ``pre_act`` is everything
        up to and including ``"argTys="``, ``acts`` is the list of
        comma-separated entries of the option value, and ``post_act`` is the
        remainder of the string starting at the terminating delimiter.
        ``pre_act + ",".join(acts) + post_act`` reproduces ``pass_pipeline``.

    Raises:
        ValueError: If ``pass_pipeline`` does not contain ``"argTys="``.
    """
    key = "argTys="
    # str.index raises ValueError when the option is absent.
    value_start = pass_pipeline.index(key) + len(key)

    # The option value ends at the first delimiter after it. Besides a space
    # (another option follows) also accept "}" (argTys is the last option of
    # its pass) and end-of-string, so the list never swallows later passes
    # and a missing trailing space no longer raises.
    value_end = len(pass_pipeline)
    for delimiter in (" ", "}"):
        pos = pass_pipeline.find(delimiter, value_start)
        if pos != -1:
            value_end = min(value_end, pos)

    acts = pass_pipeline[value_start:value_end].split(",")
    pre_act = pass_pipeline[:value_start]
    post_act = pass_pipeline[value_end:]
    return pre_act, acts, post_act


def _enzyme_primal_lowering(
ctx: jax_mlir.LoweringRuleContext,
*args_flat: ir.Value,
Expand Down Expand Up @@ -523,12 +532,10 @@ def _enzyme_primal_lowering(
if i not in in_idx_map or in_idx_map[i] in kept
)
if len(kept) != len(orig_shapes):
post = ",".join(["enzyme_dup"] * len(kept))
prev = ",".join(["enzyme_dup"] * len(orig_shapes))
pass_pipeline = pass_pipeline.replace(prev, post)
post = ",".join(["enzyme_out"] * len(kept))
prev = ",".join(["enzyme_out"] * len(orig_shapes))
pass_pipeline = pass_pipeline.replace(prev, post)
if "argTys=" in pass_pipeline:
pre_act, acts, post_act = activity_from_pipeline(pass_pipeline)
acts2 = [act for (i, act) in enumerate(acts) if i in kept]
pass_pipeline = pre_act + ",".join(acts2) + post_act

out_types = [
shape
Expand Down Expand Up @@ -917,18 +924,36 @@ def cpp_call(


def enzyme_jvp(arg_primals, arg_tangents, **kwargs):
print("arg_tan", arg_tangents)
print("kwargs", kwargs)
Comment on lines +927 to +928
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover debug spew


# TODO propagate activity info rather than make_zero
def make_zero(tan, prim):
return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan

arg_tangents = tuple(make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals))
args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t)

pipeline_options = kwargs["pipeline_options"]

shadconv = None
if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
act_tup = ",".join(["enzyme_dup" for a in arg_primals])
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
act_tup = []
args = []

avals = {}

for idx, (v, s) in enumerate(zip(arg_primals, arg_tangents)):
avals[len(args)] = in_idx_map[idx]
args.append(v)
if type(s) is ad.Zero:
act_tup.append("enzyme_const")
else:
act_tup.append("enzyme_dup")
avals[len(args)] = in_idx_map[idx]
args.append(s)

args = tuple(args)
act_tup = ",".join(act_tup)

afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, cse, canonicalize"
newpasses = (
"inline{default-pipeline=canonicalize max-iterations=4},"
Expand All @@ -940,10 +965,10 @@ def make_zero(tan, prim):
if pipeline_options.pass_pipeline() != "":
oldpasses = pipeline_options.pass_pipeline()
if "enzyme-wrap" in oldpasses:
start = passes.rindex("enzyme-wrap{")
end = passes.index("}", start)
prev_passes = passes[:end]
newpasses = prev_passes + afterad + newpasses + passes[end:]
start = oldpasses.rindex("enzyme-wrap{")
end = oldpasses.index("}", start)
prev_passes = oldpasses[:end]
newpasses = prev_passes + afterad + newpasses + oldpasses[end:]
else:
newpasses = newpasses + "," + oldpasses
if pipeline_options.stablehlo_inject():
Expand All @@ -954,10 +979,6 @@ def make_zero(tan, prim):
for o in kwargs["out_shapes"]:
outshapes2.append(o)
outshapes2.append(o)
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
avals = {2 * k: v for k, v in in_idx_map.items()} | {
2 * k + 1: v for k, v in in_idx_map.items()
}
out_idx_map2 = {2 * k: v for k, v in out_idx_map.items()} | {
2 * k + 1: v for k, v in out_idx_map.items()
}
Expand All @@ -972,6 +993,10 @@ def make_zero(tan, prim):
pipeline_options=pipeline_options
)
else:
arg_tangents = tuple(
make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals)
)
args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t)
shadconv = _enzyme_fwd_p.bind(
*args,
source=kwargs["source"],
Expand Down Expand Up @@ -1027,7 +1052,6 @@ def dejaxify(x):

def fwd_partial_eval(trace, *args, **kwargs):
assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = args[0::2], args[1::2]
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)
Expand All @@ -1050,6 +1074,9 @@ def fwd_partial_eval(trace, *args, **kwargs):


def primal_partial_eval(trace, *args, **kwargs):
print("trace ", trace)
print("args", args)
print("kwargs", kwargs)
Comment on lines +1077 to +1079
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto. We could use the `logging` module instead of `print`, and potentially even attach to the JAX logger if desired.

pipeline_options = kwargs["pipeline_options"]
if (
not pipeline_options.mlir_ad()
Expand All @@ -1058,73 +1085,20 @@ def primal_partial_eval(trace, *args, **kwargs):
):
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)

assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = args[0::2], args[1::2]
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)

if not (all_primals_known and some_tangents_unknown):
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)

shadow_aug_args = primals + tangents

out_shapes = kwargs["out_shapes"]
out_shapes2 = out_shapes[: len(out_shapes) // 2]
del kwargs["out_shapes"]

shadows_known = trace.default_process_primitive(
_enzyme_shadow_aug_p, shadow_aug_args, kwargs | {"out_shapes": out_shapes2}
)

passes = pipeline_options.pass_pipeline()
start = passes.rindex("enzyme-wrap{")
prev_passes = passes[:start]
end = passes.index("}", start)
post_passes = passes[end + 1 :]
newpasses = prev_passes + post_passes[1:]

if pipeline_options.stablehlo_inject():
pipeline_options = JaXPipeline(newpasses)
else:
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
_, acts, _ = activity_from_pipeline(pipeline_options.pass_pipeline())

(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]

avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
outmap2 = {k // 2: v for k, v in out_idx_map.items() if k % 2 == 0}
source = (in_tree, avals, outmap2, mfunc)

primalret = trace.default_process_primitive(
_enzyme_primal_p,
primals,
{
"out_shapes": out_shapes2,
"source": source,
"fn": kwargs["fn"],
"argv": kwargs["argv"],
"lang": kwargs["lang"],
"pipeline_options": pipeline_options,
},
)
return primalret + shadows_known
primals = []
tangents = []
avals = {}

for idx, v in enumerate(acts):
avals[idx] = in_idx_map[len(primals) + len(tangents)]
primals.append(args[len(primals) + len(tangents)])
if v == "enzyme_dup":
tangents.append(args[len(primals) + len(tangents)])

pe.custom_partial_eval_rules[_enzyme_primal_p] = primal_partial_eval


def primal_partial_eval(trace, *args, **kwargs):
pipeline_options = kwargs["pipeline_options"]
if (
not pipeline_options.mlir_ad()
or kwargs["lang"] != LANG_MHLO
or pipeline_options.ad_level() == 0
):
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)

assert len(args) % 2 == 0
nr_primals = len(args) // 2
primals, tangents = args[0::2], args[1::2]
all_primals_known = all(p.is_known() for p in primals)
some_tangents_unknown = any(not t.is_known() for t in tangents)

Expand Down Expand Up @@ -1153,9 +1127,6 @@ def primal_partial_eval(trace, *args, **kwargs):
else:
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())

(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]

avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
outmap2 = {k // 2: v for k, v in out_idx_map.items() if k % 2 == 0}
source = (in_tree, avals, outmap2, mfunc)

Expand All @@ -1180,14 +1151,16 @@ def primal_partial_eval(trace, *args, **kwargs):
def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
pipeline_options = kwargs["pipeline_options"]
if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
prim_args = prim_args[0 : len(prim_args) // 2]

passes = pipeline_options.pass_pipeline()
start = passes.rindex("enzyme-wrap{")
prev_passes = passes[:start]
end = passes.index("}", start)
post_passes = passes[end + 1 :]
ad_pass = passes[start : end + 1]

_, acts, _ = activity_from_pipeline(ad_pass)

ad_pass = ad_pass.replace("enzyme_dup", "enzyme_out")
ad_pass = ad_pass.replace("ForwardMode", "ReverseModeCombined")
newpasses = (
Expand All @@ -1204,7 +1177,16 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs):

(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]

avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
prim_args = prim_args[: len(acts)]

avals = {}
argidx = 0
for idx, v in enumerate(acts):
avals[idx] = in_idx_map[argidx]
argidx += 1
if v == "enzyme_dup":
argidx += 1

outmap = avals

primal_in_shapes = tuple((a.shape, jaxify(a.dtype)) for a in prim_args)
Expand Down