@@ -470,6 +470,15 @@ def to_jax(ty):
470
470
return {"f32" : jnp .float32 , "f64" : jnp .float64 }[tystr ]
471
471
472
472
473
def activity_from_pipeline(pass_pipeline):
    """Split a pass-pipeline string around its ``argTys=`` activity list.

    The first ``argTys=`` occurrence in ``pass_pipeline`` is located and the
    comma-separated entries that follow it are extracted.

    Args:
        pass_pipeline: Pipeline string containing an ``argTys=<a>,<b>,...``
            option.

    Returns:
        A ``(pre_act, acts, post_act)`` tuple where ``pre_act`` is everything
        up to and including ``argTys=``, ``acts`` is the list of activity
        strings, and ``post_act`` is the remainder of the pipeline starting
        at the character that terminated the list.  The pieces round-trip:
        ``pre_act + ",".join(acts) + post_act == pass_pipeline``.

    Raises:
        ValueError: If ``pass_pipeline`` contains no ``argTys=`` option.
    """
    marker = "argTys="
    try:
        start = pass_pipeline.index(marker)
    except ValueError:
        raise ValueError(
            f"no {marker!r} option found in pass pipeline: {pass_pipeline!r}"
        ) from None
    acts_begin = start + len(marker)

    # The list ends at the next option separator (a space) or at the closing
    # brace of the enclosing pass options, whichever comes first.  Looking
    # only for a space would fail, or swallow unrelated passes, when
    # ``argTys=`` is the last option inside ``{...}``.  If neither
    # terminator is present the list runs to the end of the string.
    stops = [
        pos
        for pos in (pass_pipeline.find(ch, acts_begin) for ch in (" ", "}"))
        if pos != -1
    ]
    end = min(stops, default=len(pass_pipeline))

    raw_acts = pass_pipeline[acts_begin:end]
    # "".split(",") would yield [""], i.e. one bogus activity; an empty
    # option value means there are no arguments at all.
    acts = raw_acts.split(",") if raw_acts else []

    pre_act = pass_pipeline[:acts_begin]
    post_act = pass_pipeline[end:]
    return pre_act, acts, post_act
480
+
481
+
473
482
def _enzyme_primal_lowering (
474
483
ctx : jax_mlir .LoweringRuleContext ,
475
484
* args_flat : ir .Value ,
@@ -523,12 +532,10 @@ def _enzyme_primal_lowering(
523
532
if i not in in_idx_map or in_idx_map [i ] in kept
524
533
)
525
534
if len (kept ) != len (orig_shapes ):
526
- post = "," .join (["enzyme_dup" ] * len (kept ))
527
- prev = "," .join (["enzyme_dup" ] * len (orig_shapes ))
528
- pass_pipeline = pass_pipeline .replace (prev , post )
529
- post = "," .join (["enzyme_out" ] * len (kept ))
530
- prev = "," .join (["enzyme_out" ] * len (orig_shapes ))
531
- pass_pipeline = pass_pipeline .replace (prev , post )
535
+ if "argTys=" in pass_pipeline :
536
+ pre_act , acts , post_act = activity_from_pipeline (pass_pipeline )
537
+ acts2 = [act for (i , act ) in enumerate (acts ) if i in kept ]
538
+ pass_pipeline = pre_act + "," .join (acts2 ) + post_act
532
539
533
540
out_types = [
534
541
shape
@@ -917,18 +924,36 @@ def cpp_call(
917
924
918
925
919
926
def enzyme_jvp (arg_primals , arg_tangents , ** kwargs ):
927
+ print ("arg_tan" , arg_tangents )
928
+ print ("kwargs" , kwargs )
929
+
920
930
# TODO propagate activity info rather than make_zero
921
931
def make_zero (tan , prim ):
922
932
return lax .zeros_like_array (prim ) if type (tan ) is ad .Zero else tan
923
933
924
- arg_tangents = tuple (make_zero (t , p ) for (t , p ) in zip (arg_tangents , arg_primals ))
925
- args = tuple (v for t in zip (arg_primals , arg_tangents ) for v in t )
926
-
927
934
pipeline_options = kwargs ["pipeline_options" ]
928
935
929
936
shadconv = None
930
937
if pipeline_options .mlir_ad () and kwargs ["lang" ] == LANG_MHLO :
931
- act_tup = "," .join (["enzyme_dup" for a in arg_primals ])
938
+ (in_tree , in_idx_map , out_idx_map , mfunc ) = kwargs ["source" ]
939
+ act_tup = []
940
+ args = []
941
+
942
+ avals = {}
943
+
944
+ for idx , (v , s ) in enumerate (zip (arg_primals , arg_tangents )):
945
+ avals [len (args )] = in_idx_map [idx ]
946
+ args .append (v )
947
+ if type (s ) is ad .Zero :
948
+ act_tup .append ("enzyme_const" )
949
+ else :
950
+ act_tup .append ("enzyme_dup" )
951
+ avals [len (args )] = in_idx_map [idx ]
952
+ args .append (s )
953
+
954
+ args = tuple (args )
955
+ act_tup = "," .join (act_tup )
956
+
932
957
afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, cse, canonicalize"
933
958
newpasses = (
934
959
"inline{default-pipeline=canonicalize max-iterations=4},"
@@ -940,10 +965,10 @@ def make_zero(tan, prim):
940
965
if pipeline_options .pass_pipeline () != "" :
941
966
oldpasses = pipeline_options .pass_pipeline ()
942
967
if "enzyme-wrap" in oldpasses :
943
- start = passes .rindex ("enzyme-wrap{" )
944
- end = passes .index ("}" , start )
945
- prev_passes = passes [:end ]
946
- newpasses = prev_passes + afterad + newpasses + passes [end :]
968
+ start = oldpasses .rindex ("enzyme-wrap{" )
969
+ end = oldpasses .index ("}" , start )
970
+ prev_passes = oldpasses [:end ]
971
+ newpasses = prev_passes + afterad + newpasses + oldpasses [end :]
947
972
else :
948
973
newpasses = newpasses + "," + oldpasses
949
974
if pipeline_options .stablehlo_inject ():
@@ -954,10 +979,6 @@ def make_zero(tan, prim):
954
979
for o in kwargs ["out_shapes" ]:
955
980
outshapes2 .append (o )
956
981
outshapes2 .append (o )
957
- (in_tree , in_idx_map , out_idx_map , mfunc ) = kwargs ["source" ]
958
- avals = {2 * k : v for k , v in in_idx_map .items ()} | {
959
- 2 * k + 1 : v for k , v in in_idx_map .items ()
960
- }
961
982
out_idx_map2 = {2 * k : v for k , v in out_idx_map .items ()} | {
962
983
2 * k + 1 : v for k , v in out_idx_map .items ()
963
984
}
@@ -972,6 +993,10 @@ def make_zero(tan, prim):
972
993
pipeline_options = pipeline_options
973
994
)
974
995
else :
996
+ arg_tangents = tuple (
997
+ make_zero (t , p ) for (t , p ) in zip (arg_tangents , arg_primals )
998
+ )
999
+ args = tuple (v for t in zip (arg_primals , arg_tangents ) for v in t )
975
1000
shadconv = _enzyme_fwd_p .bind (
976
1001
* args ,
977
1002
source = kwargs ["source" ],
@@ -1027,7 +1052,6 @@ def dejaxify(x):
1027
1052
1028
1053
def fwd_partial_eval (trace , * args , ** kwargs ):
1029
1054
assert len (args ) % 2 == 0
1030
- nr_primals = len (args ) // 2
1031
1055
primals , tangents = args [0 ::2 ], args [1 ::2 ]
1032
1056
all_primals_known = all (p .is_known () for p in primals )
1033
1057
some_tangents_unknown = any (not t .is_known () for t in tangents )
@@ -1050,6 +1074,9 @@ def fwd_partial_eval(trace, *args, **kwargs):
1050
1074
1051
1075
1052
1076
def primal_partial_eval (trace , * args , ** kwargs ):
1077
+ print ("trace " , trace )
1078
+ print ("args" , args )
1079
+ print ("kwargs" , kwargs )
1053
1080
pipeline_options = kwargs ["pipeline_options" ]
1054
1081
if (
1055
1082
not pipeline_options .mlir_ad ()
@@ -1058,73 +1085,20 @@ def primal_partial_eval(trace, *args, **kwargs):
1058
1085
):
1059
1086
return trace .default_process_primitive (_enzyme_primal_p , args , kwargs )
1060
1087
1061
- assert len (args ) % 2 == 0
1062
- nr_primals = len (args ) // 2
1063
- primals , tangents = args [0 ::2 ], args [1 ::2 ]
1064
- all_primals_known = all (p .is_known () for p in primals )
1065
- some_tangents_unknown = any (not t .is_known () for t in tangents )
1066
-
1067
- if not (all_primals_known and some_tangents_unknown ):
1068
- return trace .default_process_primitive (_enzyme_primal_p , args , kwargs )
1069
-
1070
- shadow_aug_args = primals + tangents
1071
-
1072
- out_shapes = kwargs ["out_shapes" ]
1073
- out_shapes2 = out_shapes [: len (out_shapes ) // 2 ]
1074
- del kwargs ["out_shapes" ]
1075
-
1076
- shadows_known = trace .default_process_primitive (
1077
- _enzyme_shadow_aug_p , shadow_aug_args , kwargs | {"out_shapes" : out_shapes2 }
1078
- )
1079
-
1080
- passes = pipeline_options .pass_pipeline ()
1081
- start = passes .rindex ("enzyme-wrap{" )
1082
- prev_passes = passes [:start ]
1083
- end = passes .index ("}" , start )
1084
- post_passes = passes [end + 1 :]
1085
- newpasses = prev_passes + post_passes [1 :]
1086
-
1087
- if pipeline_options .stablehlo_inject ():
1088
- pipeline_options = JaXPipeline (newpasses )
1089
- else :
1090
- pipeline_options = NewXLAPipeline (newpasses , pipeline_options .mlir_ad ())
1088
+ _ , acts , _ = activity_from_pipeline (pipeline_options .pass_pipeline ())
1091
1089
1092
1090
(in_tree , in_idx_map , out_idx_map , mfunc ) = kwargs ["source" ]
1093
1091
1094
- avals = {k // 2 : v for k , v in in_idx_map .items () if k % 2 == 0 }
1095
- outmap2 = {k // 2 : v for k , v in out_idx_map .items () if k % 2 == 0 }
1096
- source = (in_tree , avals , outmap2 , mfunc )
1097
-
1098
- primalret = trace .default_process_primitive (
1099
- _enzyme_primal_p ,
1100
- primals ,
1101
- {
1102
- "out_shapes" : out_shapes2 ,
1103
- "source" : source ,
1104
- "fn" : kwargs ["fn" ],
1105
- "argv" : kwargs ["argv" ],
1106
- "lang" : kwargs ["lang" ],
1107
- "pipeline_options" : pipeline_options ,
1108
- },
1109
- )
1110
- return primalret + shadows_known
1092
+ primals = []
1093
+ tangents = []
1094
+ avals = {}
1111
1095
1096
+ for idx , v in enumerate (acts ):
1097
+ avals [idx ] = in_idx_map [len (primals ) + len (tangents )]
1098
+ primals .append (args [len (primals ) + len (tangents )])
1099
+ if v == "enzyme_dup" :
1100
+ tangents .append (args [len (primals ) + len (tangents )])
1112
1101
1113
- pe .custom_partial_eval_rules [_enzyme_primal_p ] = primal_partial_eval
1114
-
1115
-
1116
- def primal_partial_eval (trace , * args , ** kwargs ):
1117
- pipeline_options = kwargs ["pipeline_options" ]
1118
- if (
1119
- not pipeline_options .mlir_ad ()
1120
- or kwargs ["lang" ] != LANG_MHLO
1121
- or pipeline_options .ad_level () == 0
1122
- ):
1123
- return trace .default_process_primitive (_enzyme_primal_p , args , kwargs )
1124
-
1125
- assert len (args ) % 2 == 0
1126
- nr_primals = len (args ) // 2
1127
- primals , tangents = args [0 ::2 ], args [1 ::2 ]
1128
1102
all_primals_known = all (p .is_known () for p in primals )
1129
1103
some_tangents_unknown = any (not t .is_known () for t in tangents )
1130
1104
@@ -1153,9 +1127,6 @@ def primal_partial_eval(trace, *args, **kwargs):
1153
1127
else :
1154
1128
pipeline_options = NewXLAPipeline (newpasses , pipeline_options .mlir_ad ())
1155
1129
1156
- (in_tree , in_idx_map , out_idx_map , mfunc ) = kwargs ["source" ]
1157
-
1158
- avals = {k // 2 : v for k , v in in_idx_map .items () if k % 2 == 0 }
1159
1130
outmap2 = {k // 2 : v for k , v in out_idx_map .items () if k % 2 == 0 }
1160
1131
source = (in_tree , avals , outmap2 , mfunc )
1161
1132
@@ -1180,14 +1151,16 @@ def primal_partial_eval(trace, *args, **kwargs):
1180
1151
def enzyme_vjp (shadow_rets , * prim_args , ** kwargs ):
1181
1152
pipeline_options = kwargs ["pipeline_options" ]
1182
1153
if pipeline_options .mlir_ad () and kwargs ["lang" ] == LANG_MHLO :
1183
- prim_args = prim_args [0 : len (prim_args ) // 2 ]
1184
1154
1185
1155
passes = pipeline_options .pass_pipeline ()
1186
1156
start = passes .rindex ("enzyme-wrap{" )
1187
1157
prev_passes = passes [:start ]
1188
1158
end = passes .index ("}" , start )
1189
1159
post_passes = passes [end + 1 :]
1190
1160
ad_pass = passes [start : end + 1 ]
1161
+
1162
+ _ , acts , _ = activity_from_pipeline (ad_pass )
1163
+
1191
1164
ad_pass = ad_pass .replace ("enzyme_dup" , "enzyme_out" )
1192
1165
ad_pass = ad_pass .replace ("ForwardMode" , "ReverseModeCombined" )
1193
1166
newpasses = (
@@ -1204,7 +1177,16 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
1204
1177
1205
1178
(in_tree , in_idx_map , out_idx_map , mfunc ) = kwargs ["source" ]
1206
1179
1207
- avals = {k // 2 : v for k , v in in_idx_map .items () if k % 2 == 0 }
1180
+ prim_args = prim_args [: len (acts )]
1181
+
1182
+ avals = {}
1183
+ argidx = 0
1184
+ for idx , v in enumerate (acts ):
1185
+ avals [idx ] = in_idx_map [argidx ]
1186
+ argidx += 1
1187
+ if v == "enzyme_dup" :
1188
+ argidx += 1
1189
+
1208
1190
outmap = avals
1209
1191
1210
1192
primal_in_shapes = tuple ((a .shape , jaxify (a .dtype )) for a in prim_args )
0 commit comments