[frontend] allow var_mean to be implemented in one pass #1285
Closed
Description
Currently PyTorch Inductor is forced to implement torch.var_mean as two passes over the input data, which causes a slowdown for batch norm. To allow single-pass computation, we need a new reduction operator, tl.welford(mean, m2, count), which implements the combination step of the parallel Welford algorithm.
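For reference, the combination step can be sketched in plain Python. This is only an illustration of the math, not Triton code and not the proposed API: the names welford_combine and var_mean_one_pass, the tuple layout (mean, m2, count), and the correction parameter are assumptions made for this sketch.

```python
from functools import reduce


def welford_combine(a, b):
    """Merge two partial Welford states, each a (mean, m2, count) tuple.

    m2 is the sum of squared deviations from the partial mean.
    Hypothetical helper for illustration only.
    """
    mean_a, m2_a, n_a = a
    mean_b, m2_b, n_b = b
    n = n_a + n_b
    if n == 0:
        # Merging two empty states yields an empty state.
        return (0.0, 0.0, 0)
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return (mean, m2, n)


def var_mean_one_pass(values, correction=1):
    """Return (variance, mean) while reading each value exactly once."""
    # Every element starts as the state (x, 0.0, 1); the empty state is the identity.
    states = ((float(x), 0.0, 1) for x in values)
    mean, m2, n = reduce(welford_combine, states, (0.0, 0.0, 0))
    return m2 / (n - correction), mean


var, mean = var_mean_one_pass([1.0, 2.0, 3.0, 4.0])
```

Because the merge is associative up to floating-point rounding, partial states computed by different threads or blocks can be combined in any grouping, which is what makes a single-pass parallel reduction possible.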
A more general solution might be to instead add a tl.reduce that takes a function acting on scalars, so users can write their own reductions without needing to change the Triton language.
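To make the idea concrete, here is a rough plain-Python sketch of a reduction that is parameterized by a user-supplied combine function. It is not Triton code, and it does not claim to match any actual tl.reduce signature: tree_reduce and welford are hypothetical names, and the pairwise tree only mimics the order in which a parallel reduction might merge elements.

```python
def tree_reduce(combine, items):
    """Reduce items pairwise, level by level, using a user-supplied combine.

    combine must be associative for the result to be independent of grouping.
    Hypothetical helper for illustration only.
    """
    level = list(items)
    if not level:
        raise ValueError("tree_reduce needs at least one element")
    while len(level) > 1:
        # Merge neighbours; an odd element at the end is carried up unchanged.
        merged = [combine(level[i], level[i + 1])
                  for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:
            merged.append(level[-1])
        level = merged
    return level[0]


def welford(a, b):
    """User-written combine on (mean, m2, count) states with count >= 1."""
    (mean_a, m2_a, n_a), (mean_b, m2_b, n_b) = a, b
    n = n_a + n_b
    delta = mean_b - mean_a
    return (mean_a + delta * n_b / n,
            m2_a + m2_b + delta * delta * n_a * n_b / n,
            n)


data = [1.0, 2.0, 3.0, 4.0, 5.0]
mean, m2, count = tree_reduce(welford, [(x, 0.0, 1) for x in data])
variance = m2 / (count - 1)  # sample variance
```

With this shape, the same reduction routine also covers sum, max, or any other associative operation simply by passing a different function, which is the appeal of the more general proposal.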