Skip to content

Commit d2030a6

Browse files
authored
Handle inactive args from context (#41)
1 parent 3cea3b1 commit d2030a6

File tree

1 file changed

+68
-86
lines changed

1 file changed

+68
-86
lines changed

src/enzyme_ad/jax/primitives.py

Lines changed: 68 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,15 @@ def to_jax(ty):
470470
return {"f32": jnp.float32, "f64": jnp.float64}[tystr]
471471

472472

473+
def activity_from_pipeline(pass_pipeline):
    """Split a pass-pipeline string around its first ``argTys=`` list.

    The comma-separated list that follows the first ``argTys=`` is cut out
    so callers can edit it and re-assemble the pipeline as
    ``pre_act + ",".join(acts) + post_act``.

    Args:
        pass_pipeline: Pipeline string containing an ``argTys=`` option,
            e.g. ``"enzyme-wrap{... argTys=enzyme_dup,enzyme_const ...}"``.

    Returns:
        A ``(pre_act, acts, post_act)`` tuple: ``pre_act`` is everything up
        to and including ``argTys=``, ``acts`` is the list of activity
        strings, and ``post_act`` is the remainder, starting at the
        character that terminated the list.

    Raises:
        ValueError: If ``argTys=`` does not occur in ``pass_pipeline``.
    """
    key = "argTys="
    acts_start = pass_pipeline.index(key) + len(key)

    # The list ends at the next option separator (a space). When argTys is
    # the last option it ends at the closing brace instead, or at the end of
    # the string; searching only for a space would raise or run on into the
    # passes that follow.
    acts_end = len(pass_pipeline)
    for terminator in (" ", "}"):
        pos = pass_pipeline.find(terminator, acts_start)
        if pos != -1:
            acts_end = min(acts_end, pos)

    acts = pass_pipeline[acts_start:acts_end].split(",")
    return pass_pipeline[:acts_start], acts, pass_pipeline[acts_end:]
480+
481+
473482
def _enzyme_primal_lowering(
474483
ctx: jax_mlir.LoweringRuleContext,
475484
*args_flat: ir.Value,
@@ -523,12 +532,10 @@ def _enzyme_primal_lowering(
523532
if i not in in_idx_map or in_idx_map[i] in kept
524533
)
525534
if len(kept) != len(orig_shapes):
526-
post = ",".join(["enzyme_dup"] * len(kept))
527-
prev = ",".join(["enzyme_dup"] * len(orig_shapes))
528-
pass_pipeline = pass_pipeline.replace(prev, post)
529-
post = ",".join(["enzyme_out"] * len(kept))
530-
prev = ",".join(["enzyme_out"] * len(orig_shapes))
531-
pass_pipeline = pass_pipeline.replace(prev, post)
535+
if "argTys=" in pass_pipeline:
536+
pre_act, acts, post_act = activity_from_pipeline(pass_pipeline)
537+
acts2 = [act for (i, act) in enumerate(acts) if i in kept]
538+
pass_pipeline = pre_act + ",".join(acts2) + post_act
532539

533540
out_types = [
534541
shape
@@ -917,18 +924,36 @@ def cpp_call(
917924

918925

919926
def enzyme_jvp(arg_primals, arg_tangents, **kwargs):
927+
print("arg_tan", arg_tangents)
928+
print("kwargs", kwargs)
929+
920930
# TODO propagate activity info rather than make_zero
921931
def make_zero(tan, prim):
922932
return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan
923933

924-
arg_tangents = tuple(make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals))
925-
args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t)
926-
927934
pipeline_options = kwargs["pipeline_options"]
928935

929936
shadconv = None
930937
if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
931-
act_tup = ",".join(["enzyme_dup" for a in arg_primals])
938+
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
939+
act_tup = []
940+
args = []
941+
942+
avals = {}
943+
944+
for idx, (v, s) in enumerate(zip(arg_primals, arg_tangents)):
945+
avals[len(args)] = in_idx_map[idx]
946+
args.append(v)
947+
if type(s) is ad.Zero:
948+
act_tup.append("enzyme_const")
949+
else:
950+
act_tup.append("enzyme_dup")
951+
avals[len(args)] = in_idx_map[idx]
952+
args.append(s)
953+
954+
args = tuple(args)
955+
act_tup = ",".join(act_tup)
956+
932957
afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, cse, canonicalize"
933958
newpasses = (
934959
"inline{default-pipeline=canonicalize max-iterations=4},"
@@ -940,10 +965,10 @@ def make_zero(tan, prim):
940965
if pipeline_options.pass_pipeline() != "":
941966
oldpasses = pipeline_options.pass_pipeline()
942967
if "enzyme-wrap" in oldpasses:
943-
start = passes.rindex("enzyme-wrap{")
944-
end = passes.index("}", start)
945-
prev_passes = passes[:end]
946-
newpasses = prev_passes + afterad + newpasses + passes[end:]
968+
start = oldpasses.rindex("enzyme-wrap{")
969+
end = oldpasses.index("}", start)
970+
prev_passes = oldpasses[:end]
971+
newpasses = prev_passes + afterad + newpasses + oldpasses[end:]
947972
else:
948973
newpasses = newpasses + "," + oldpasses
949974
if pipeline_options.stablehlo_inject():
@@ -954,10 +979,6 @@ def make_zero(tan, prim):
954979
for o in kwargs["out_shapes"]:
955980
outshapes2.append(o)
956981
outshapes2.append(o)
957-
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
958-
avals = {2 * k: v for k, v in in_idx_map.items()} | {
959-
2 * k + 1: v for k, v in in_idx_map.items()
960-
}
961982
out_idx_map2 = {2 * k: v for k, v in out_idx_map.items()} | {
962983
2 * k + 1: v for k, v in out_idx_map.items()
963984
}
@@ -972,6 +993,10 @@ def make_zero(tan, prim):
972993
pipeline_options=pipeline_options
973994
)
974995
else:
996+
arg_tangents = tuple(
997+
make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals)
998+
)
999+
args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t)
9751000
shadconv = _enzyme_fwd_p.bind(
9761001
*args,
9771002
source=kwargs["source"],
@@ -1027,7 +1052,6 @@ def dejaxify(x):
10271052

10281053
def fwd_partial_eval(trace, *args, **kwargs):
10291054
assert len(args) % 2 == 0
1030-
nr_primals = len(args) // 2
10311055
primals, tangents = args[0::2], args[1::2]
10321056
all_primals_known = all(p.is_known() for p in primals)
10331057
some_tangents_unknown = any(not t.is_known() for t in tangents)
@@ -1050,6 +1074,9 @@ def fwd_partial_eval(trace, *args, **kwargs):
10501074

10511075

10521076
def primal_partial_eval(trace, *args, **kwargs):
1077+
print("trace ", trace)
1078+
print("args", args)
1079+
print("kwargs", kwargs)
10531080
pipeline_options = kwargs["pipeline_options"]
10541081
if (
10551082
not pipeline_options.mlir_ad()
@@ -1058,73 +1085,20 @@ def primal_partial_eval(trace, *args, **kwargs):
10581085
):
10591086
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)
10601087

1061-
assert len(args) % 2 == 0
1062-
nr_primals = len(args) // 2
1063-
primals, tangents = args[0::2], args[1::2]
1064-
all_primals_known = all(p.is_known() for p in primals)
1065-
some_tangents_unknown = any(not t.is_known() for t in tangents)
1066-
1067-
if not (all_primals_known and some_tangents_unknown):
1068-
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)
1069-
1070-
shadow_aug_args = primals + tangents
1071-
1072-
out_shapes = kwargs["out_shapes"]
1073-
out_shapes2 = out_shapes[: len(out_shapes) // 2]
1074-
del kwargs["out_shapes"]
1075-
1076-
shadows_known = trace.default_process_primitive(
1077-
_enzyme_shadow_aug_p, shadow_aug_args, kwargs | {"out_shapes": out_shapes2}
1078-
)
1079-
1080-
passes = pipeline_options.pass_pipeline()
1081-
start = passes.rindex("enzyme-wrap{")
1082-
prev_passes = passes[:start]
1083-
end = passes.index("}", start)
1084-
post_passes = passes[end + 1 :]
1085-
newpasses = prev_passes + post_passes[1:]
1086-
1087-
if pipeline_options.stablehlo_inject():
1088-
pipeline_options = JaXPipeline(newpasses)
1089-
else:
1090-
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
1088+
_, acts, _ = activity_from_pipeline(pipeline_options.pass_pipeline())
10911089

10921090
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
10931091

1094-
avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
1095-
outmap2 = {k // 2: v for k, v in out_idx_map.items() if k % 2 == 0}
1096-
source = (in_tree, avals, outmap2, mfunc)
1097-
1098-
primalret = trace.default_process_primitive(
1099-
_enzyme_primal_p,
1100-
primals,
1101-
{
1102-
"out_shapes": out_shapes2,
1103-
"source": source,
1104-
"fn": kwargs["fn"],
1105-
"argv": kwargs["argv"],
1106-
"lang": kwargs["lang"],
1107-
"pipeline_options": pipeline_options,
1108-
},
1109-
)
1110-
return primalret + shadows_known
1092+
primals = []
1093+
tangents = []
1094+
avals = {}
11111095

1096+
for idx, v in enumerate(acts):
1097+
avals[idx] = in_idx_map[len(primals) + len(tangents)]
1098+
primals.append(args[len(primals) + len(tangents)])
1099+
if v == "enzyme_dup":
1100+
tangents.append(args[len(primals) + len(tangents)])
11121101

1113-
pe.custom_partial_eval_rules[_enzyme_primal_p] = primal_partial_eval
1114-
1115-
1116-
def primal_partial_eval(trace, *args, **kwargs):
1117-
pipeline_options = kwargs["pipeline_options"]
1118-
if (
1119-
not pipeline_options.mlir_ad()
1120-
or kwargs["lang"] != LANG_MHLO
1121-
or pipeline_options.ad_level() == 0
1122-
):
1123-
return trace.default_process_primitive(_enzyme_primal_p, args, kwargs)
1124-
1125-
assert len(args) % 2 == 0
1126-
nr_primals = len(args) // 2
1127-
primals, tangents = args[0::2], args[1::2]
11281102
all_primals_known = all(p.is_known() for p in primals)
11291103
some_tangents_unknown = any(not t.is_known() for t in tangents)
11301104

@@ -1153,9 +1127,6 @@ def primal_partial_eval(trace, *args, **kwargs):
11531127
else:
11541128
pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
11551129

1156-
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
1157-
1158-
avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
11591130
outmap2 = {k // 2: v for k, v in out_idx_map.items() if k % 2 == 0}
11601131
source = (in_tree, avals, outmap2, mfunc)
11611132

@@ -1180,14 +1151,16 @@ def primal_partial_eval(trace, *args, **kwargs):
11801151
def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
11811152
pipeline_options = kwargs["pipeline_options"]
11821153
if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
1183-
prim_args = prim_args[0 : len(prim_args) // 2]
11841154

11851155
passes = pipeline_options.pass_pipeline()
11861156
start = passes.rindex("enzyme-wrap{")
11871157
prev_passes = passes[:start]
11881158
end = passes.index("}", start)
11891159
post_passes = passes[end + 1 :]
11901160
ad_pass = passes[start : end + 1]
1161+
1162+
_, acts, _ = activity_from_pipeline(ad_pass)
1163+
11911164
ad_pass = ad_pass.replace("enzyme_dup", "enzyme_out")
11921165
ad_pass = ad_pass.replace("ForwardMode", "ReverseModeCombined")
11931166
newpasses = (
@@ -1204,7 +1177,16 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
12041177

12051178
(in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"]
12061179

1207-
avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0}
1180+
prim_args = prim_args[: len(acts)]
1181+
1182+
avals = {}
1183+
argidx = 0
1184+
for idx, v in enumerate(acts):
1185+
avals[idx] = in_idx_map[argidx]
1186+
argidx += 1
1187+
if v == "enzyme_dup":
1188+
argidx += 1
1189+
12081190
outmap = avals
12091191

12101192
primal_in_shapes = tuple((a.shape, jaxify(a.dtype)) for a in prim_args)

0 commit comments

Comments
 (0)