|
22 | 22 | from typing import Any, TypeVar
|
23 | 23 | import weakref
|
24 | 24 |
|
| 25 | +import jax |
25 | 26 | from jax._src import ad_checkpoint
|
26 | 27 | from jax._src import ad_util
|
27 | 28 | from jax._src import api
|
@@ -809,98 +810,61 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
|
809 | 810 | carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
|
810 | 811 | else:
|
811 | 812 | assert False, "Fixpoint not reached"
|
812 |
| - num_res = len(res_avals) |
| 813 | + num_res_in = len(res_avals) # number of res inputs to jaxpr_unknown |
813 | 814 | del res_avals, carry_uk_out
|
814 | 815 |
|
815 | 816 | # Instantiate those inputs which must be treated as unknown from the fixpoint.
|
816 | 817 | tracers = tuple(trace.instantiate_const(t) if uk else t
|
817 | 818 | for t, uk in zip(tracers, unknowns))
|
818 | 819 |
|
819 |
| - # The residual inputs and outputs of the jaxprs produced haven't yet been |
820 |
| - # adapted to the scan calling convention; in particular, jaxpr_known has its |
821 |
| - # residual outputs all at the end, meaning they're extensive outputs (which is |
822 |
| - # fully general but may be wasteful for residuals which are loop-invariant) |
823 |
| - # while jaxpr_unknown has its corresponding residual inputs at the front (just |
824 |
| - # as a convention with partial_eval_jaxpr_nounits), making them constant |
825 |
| - # inputs. To make them consistent, we move the residual inputs on |
826 |
| - # jaxpr_unknown to the end, even though we may move some back in the sequel. |
827 |
| - jaxpr_unknown = pe.move_binders_to_back( |
828 |
| - jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) |
829 |
| - |
830 |
| - # At this point, all residuals are treated as extensive outputs of jaxpr_known |
831 |
| - # (and extensive inputs to jaxpr_unknown). But residuals that are loop- |
832 |
| - # invariant can be hoisted out of the scan, rather than letting them get |
833 |
| - # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't |
834 |
| - # want to broadcast the matrix!). So, outside the loop we perform a partial |
835 |
| - # evaluation with known 'const' inputs (but all other inputs unknown). |
836 |
| - const_pvals = [pe.PartialVal.known(t.pval.get_known()) |
837 |
| - for t in tracers[:num_consts] if t.pval.is_known()] |
838 |
| - other_pvals = [pe.PartialVal.unknown(aval) |
839 |
| - for aval in jaxpr_known.in_avals[len(const_pvals):]] |
840 |
| - with source_info_util.reset_name_stack(): |
841 |
| - jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( |
842 |
| - lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), |
843 |
| - debug_info=jaxpr_known.jaxpr.debug_info), |
844 |
| - const_pvals + other_pvals, |
845 |
| - instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) |
846 |
| - jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) |
847 |
| - # The above trace_to_jaxpr_nounits call computed loop-invariant residuals |
848 |
| - # (known values in invar_pvals_out) and also computed loop-invariant values |
849 |
| - # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the |
850 |
| - # previous consts). We need to collect the computed intensive residuals, and |
851 |
| - # move corresponding intensive residual binders in jaxpr_unknown to the front. |
852 |
| - res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] |
853 |
| - intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] |
854 |
| - jaxpr_unknown = pe.move_binders_to_front( |
855 |
| - jaxpr_unknown, |
856 |
| - [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) |
857 |
| - del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals |
858 |
| - # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and |
859 |
| - # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. |
860 |
| - |
861 |
| - # As another optimization, for any extensive inputs that are just forwarded to |
862 |
| - # extensive outputs, to avoid a copy (which would be looping over |
863 |
| - # dynamic-update-slice) we'd rather forward the input tracer/value. That means |
864 |
| - # pruning some outputs from jaxpr_known here, and updating `out_flat` below. |
865 |
| - fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) |
866 |
| - # Prune fwds_known to include only extensive input to extensive output. |
867 |
| - fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and |
868 |
| - in_idx is not None and |
869 |
| - in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) |
870 |
| - else None for out_idx, in_idx in enumerate(fwds_known)] |
871 |
| - # Drop any extensive output we can instead get by forwarding an input. |
872 |
| - # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint |
873 |
| - jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts |
874 |
| - jaxpr_known_ = jaxpr_known_.replace( |
875 |
| - outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None]) |
876 |
| - jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) |
877 |
| - del jaxpr_known_ |
878 |
| - # We use `fwds_known` below when forming the output of scanning jaxpr_known. |
| 820 | + known_inputs = [t.pval.get_known() for t in tracers if t.pval.is_known()] |
| 821 | + |
| 822 | + # Residuals that are just forwarded from constants or scanned-over inputs |
| 823 | + # can be passed directly to the unknown scan (as constants or scanned-over |
| 824 | + # inputs, respectively). So we prune them as outputs from jaxpr_known. |
| 825 | + # Recall the outputs of jaxpr_known currently look like [*knowns, *res]. |
| 826 | + num_knowns_out = len(jaxpr_known.out_avals) - num_res_in |
| 827 | + num_consts_known = num_consts - sum(const_uk) |
| 828 | + num_carry_known = num_carry - sum(carry_uk) |
| 829 | + del num_consts, num_carry |
| 830 | + in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr_known.jaxpr) |
| 831 | + in_fwd = [f if out_idx >= num_knowns_out and f is not None and |
| 832 | + (f < num_consts_known or f >= num_consts_known + num_carry_known) |
| 833 | + and isinstance(known_inputs[f], jax.Array) # no np.ndarrays |
| 834 | + else None for out_idx, f in enumerate(in_fwd)] |
| 835 | + jaxpr_known = pe.prune_closed_jaxpr_outputs(jaxpr_known, [f is None for f in in_fwd]) |
| 836 | + |
| 837 | + # All jaxpr_unknown residual binders are at the front like [*res, *unknowns], |
| 838 | + # but to match the scan body calling convention of [*consts, *carry, *ext_in], |
| 839 | + # we move the binders that correspond to extensive residuals (ie not forwarded |
| 840 | + # from jaxpr_known's consts) to the extensive slots. |
| 841 | + num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in |
| 842 | + res_to_move = [f is None or f >= num_consts_known + num_carry_known |
| 843 | + for f in in_fwd[num_knowns_out:]] |
| 844 | + jaxpr_unknown = pe.move_binders_to_back(jaxpr_unknown, |
| 845 | + res_to_move + [False] * num_unk_in) |
879 | 846 |
|
880 | 847 | # Run the known part of the scan (if it has any outputs or effects).
|
881 |
| - known_inputs = (list(jaxpr_known_consts) + |
882 |
| - [t.pval.get_known() for t in tracers[num_consts:] |
883 |
| - if t.pval.is_known()]) |
884 | 848 | if not jaxpr_known.out_avals and not jaxpr_known.effects:
|
885 | 849 | out_known = []
|
886 | 850 | else:
|
887 |
| - linear_known = [False] * len(known_inputs) # conservative! |
| 851 | + linear_known = [l for l, uk in zip(linear, unknowns) if not uk] |
888 | 852 | out_known = scan_p.bind(
|
889 | 853 | *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
|
890 |
| - num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), |
| 854 | + num_consts=num_consts_known, num_carry=num_carry_known, |
891 | 855 | linear=tuple(linear_known), unroll=unroll,
|
892 | 856 | _split_transpose=_split_transpose)
|
893 | 857 | del linear_known
|
894 |
| - # Complete the known output by filling in forwarded values using fwds_known. |
895 |
| - out_known_iter = iter(out_known) |
896 |
| - out_known = [next(out_known_iter) if f is None |
897 |
| - else _maybe_put(known_inputs[f]) for f in fwds_known] |
898 |
| - assert next(out_known_iter, None) is None |
899 |
| - del known_inputs, out_known_iter |
900 | 858 |
|
901 |
| - # Split known outputs from residuals. |
902 |
| - out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) |
903 |
| - assert len(intensive_res) + len(extensive_res) == num_res |
| 859 | + # Complete out_known by filling in forwards from known inputs. |
| 860 | + out_known_ = iter(out_known) |
| 861 | + out_known = [next(out_known_) if f is None else known_inputs[f] for f in in_fwd] |
| 862 | + assert next(out_known_, None) is None |
| 863 | + |
| 864 | + # Split known outputs from residuals, and const-forwarded (intensive) |
| 865 | + # residuals from other (extensive) residuals. |
| 866 | + out_known, all_res = split_list(out_known, [num_knowns_out]) |
| 867 | + intensive_res, extensive_res = partition_list(res_to_move, all_res) |
904 | 868 |
|
905 | 869 | # Create input tracers for jaxpr_unknown bind.
|
906 | 870 | unknown_inputs = [t for t in tracers if not t.pval.is_known()]
|
@@ -1254,87 +1218,58 @@ def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn):
|
1254 | 1218 | assert False, "Fixpoint not reached"
|
1255 | 1219 | jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts)
|
1256 | 1220 | jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)
|
1257 |
| - |
1258 |
| - # Move all residual binders to the back of jaxpr_staged so they're extensive. |
1259 |
| - # TODO(mattjj): make jaxpr_staged only take instantiated inputs |
1260 | 1221 | res_avals = jaxpr_staged.in_avals[:num_res]
|
1261 |
| - jaxpr_staged = pe.move_binders_to_back( |
1262 |
| - jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals)) |
1263 | 1222 |
|
1264 | 1223 | # Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
|
1265 | 1224 | # passing in_inst argument to partial_eval_jaxpr_custom above).
|
1266 | 1225 | new_inst = [x for x, inst in zip(eqn.invars, inst_in)
|
1267 | 1226 | if type(x) is core.Var and not inst]
|
1268 | 1227 | inst_in = [True] * len(inst_in)
|
1269 | 1228 |
|
1270 |
| - # As an optimization, hoist loop-invariant residuals out of the loop rather |
1271 |
| - # than using extensive outputs for them. See _scan_partial_eval for comments. |
1272 |
| - num_const_known = len(const_uk) - sum(const_uk) |
1273 |
| - num_carry_known = len(carry_uk) - sum(carry_uk) |
1274 |
| - num_xs_known = len( xs_uk) - sum( xs_uk) |
1275 |
| - jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \ |
1276 |
| - pe.partial_eval_jaxpr_nounits( |
1277 |
| - jaxpr_known, |
1278 |
| - [False] * num_const_known + [True] * (num_carry_known + num_xs_known), |
1279 |
| - [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res) |
1280 |
| - # jaxpr_known_hoist produces intensive residuals followed by the constants for |
1281 |
| - # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts. |
1282 |
| - _, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res]) |
1283 |
| - jaxpr_staged = pe.move_binders_to_front( |
1284 |
| - jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res)) |
1285 |
| - num_intensive_res = len(loop_dep_res) - sum(loop_dep_res) |
1286 |
| - del loop_dep, num_carry_known, num_xs_known, const_uk |
1287 |
| - |
1288 |
| - # Create residual variables. |
1289 |
| - intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals) |
1290 |
| - ext_avals = [core.unmapped_aval(eqn.params['length'], 0, a) |
1291 |
| - for a in ext_avals_mapped] |
| 1229 | + # No res forwards should be possible because jaxpr_unknown already has all |
| 1230 | + # inputs instantiated. |
| 1231 | + num_knowns_out = len(jaxpr_known.out_avals) - num_res |
| 1232 | + num_consts_known = len(const_uk) - sum(const_uk) |
| 1233 | + num_carry_known = len(carry_uk) - sum(carry_uk) |
| 1234 | + in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr_known.jaxpr) |
| 1235 | + in_fwd = [f if out_idx >= num_knowns_out and f is not None and |
| 1236 | + (f < num_consts_known or f >= num_consts_known + num_carry_known) |
| 1237 | + else None for out_idx, f in enumerate(in_fwd)] |
| 1238 | + assert all(f is None for f in in_fwd) |
| 1239 | + del in_fwd |
| 1240 | + |
| 1241 | + # Move all residual binders to the back of jaxpr_staged so they're extensive. |
| 1242 | + jaxpr_staged = pe.move_binders_to_back( |
| 1243 | + jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals)) |
| 1244 | + |
| 1245 | + # Create variables for extensive residuals output by the known eqn. |
1292 | 1246 | newvar = core.gensym()
|
1293 |
| - intensive_res = _map(newvar, intensive_avals) |
1294 |
| - extensive_res = _map(newvar, ext_avals) |
| 1247 | + ext_res_avals = _map(partial(core.unmapped_aval, eqn.params['length'], 0), res_avals) |
| 1248 | + ext_res_out_binders = _map(newvar, ext_res_avals) |
1295 | 1249 |
|
1296 |
| - # Create known eqn, which is a call_p combining evaluation of |
1297 |
| - # jaxpr_known_hoist and a scan of jaxpr_known_loop. |
| 1250 | + # Create known eqn. |
1298 | 1251 | ins_known, _ = partition_list(unks_in, eqn.invars)
|
1299 | 1252 | out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
1300 |
| - # jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist |
1301 |
| - # (corresponding to consts_known_lp_avals) followed by known carry and xs. |
1302 |
| - linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk] |
1303 |
| - _, linear_known_ = split_list(linear_known_, [num_const_known]) |
1304 |
| - linear_known = [False] * len(consts_known_lp_avals) + linear_known_ |
1305 |
| - params_known = dict(eqn.params, jaxpr=jaxpr_known_loop, |
1306 |
| - num_consts=len(consts_known_lp_avals), |
1307 |
| - num_carry=len(carry_uk)-sum(carry_uk), |
| 1253 | + linear_known = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk] |
| 1254 | + params_known = dict(eqn.params, jaxpr=jaxpr_known, |
| 1255 | + num_consts=num_consts_known, num_carry=num_carry_known, |
1308 | 1256 | linear=tuple(linear_known))
|
1309 |
| - |
1310 |
| - def known(*ins_known): |
1311 |
| - consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known]) |
1312 |
| - out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist) |
1313 |
| - intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res]) |
1314 |
| - out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known) |
1315 |
| - return [*intensive_res, *out_loop] |
1316 |
| - call_jaxpr_, _, call_jaxpr_consts, () = pe.trace_to_jaxpr_dynamic( |
1317 |
| - lu.wrap_init(known, debug_info=jaxpr_known_hoist.jaxpr.debug_info), |
1318 |
| - [v.aval for v in ins_known]) |
1319 |
| - call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) |
1320 | 1257 | eqn_known = pe.new_jaxpr_eqn(
|
1321 |
| - ins_known, [*intensive_res, *out_binders_known, *extensive_res], |
1322 |
| - core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects, |
1323 |
| - eqn.source_info, eqn.ctx) |
| 1258 | + ins_known, [*out_binders_known, *ext_res_out_binders], |
| 1259 | + scan_p, params_known, jaxpr_known.effects, eqn.source_info, eqn.ctx) |
1324 | 1260 |
|
1325 | 1261 | # Create the staged eqn.
|
1326 | 1262 | _, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
1327 |
| - linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) + |
1328 |
| - [False] * len(extensive_res)) |
| 1263 | + linear_staged = list(eqn.params['linear']) + [False] * len(ext_res_out_binders) |
1329 | 1264 | params_staged = dict(eqn.params, jaxpr=jaxpr_staged,
|
1330 |
| - num_consts=len(intensive_res) + eqn.params['num_consts'], |
| 1265 | + num_consts=eqn.params['num_consts'], |
1331 | 1266 | linear=tuple(linear_staged))
|
1332 |
| - eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res], |
| 1267 | + eqn_staged = pe.new_jaxpr_eqn([*eqn.invars, *ext_res_out_binders], |
1333 | 1268 | out_binders_staged, eqn.primitive,
|
1334 | 1269 | params_staged, jaxpr_staged.effects,
|
1335 | 1270 | eqn.source_info, eqn.ctx)
|
1336 | 1271 |
|
1337 |
| - new_vars = [*new_inst, *intensive_res, *extensive_res] |
| 1272 | + new_vars = [*new_inst, *ext_res_out_binders] |
1338 | 1273 | return eqn_known, eqn_staged, unks_out, inst_out, new_vars
|
1339 | 1274 |
|
1340 | 1275 | def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
|
|
0 commit comments