|  | 
|  | 1 | +#define _GLIBCXX_USE_CXX11_ABI 1 | 
|  | 2 | +#define HL_PERMIT_FAILED_UNROLL 1 | 
|  | 3 | + | 
|  | 4 | +#include "mul.hpp" | 
|  | 5 | + | 
|  | 6 | +#include "Halide.h" | 
|  | 7 | +#include "HalideBuffer.h" | 
|  | 8 | + | 
|  | 9 | +#include <unordered_map> | 
|  | 10 | + | 
|  | 11 | +/* Estimates for some of the Halide parameters */ | 
|  | 12 | +static const int maxHalideRow = 1000000; | 
|  | 13 | +static const int featureCount = 32; | 
|  | 14 | +static const int activeRows = 60000; | 
|  | 15 | +static const int groups = 1; | 
|  | 16 | +static const int featureRowCount = 100000; | 
|  | 17 | + | 
|  | 18 | +template <typename Operation> | 
|  | 19 | +using MulStrategyMap = | 
|  | 20 | +    std::unordered_map<LayerDimensions, std::unique_ptr<Operation>, | 
|  | 21 | +                       LayerDimensionsHash>; | 
|  | 22 | + | 
|  | 23 | +template <typename Operation> | 
|  | 24 | +const Operation &getHalideMul(int inFeatureCount, int outFeatureCount, | 
|  | 25 | +                              int groups, bool cuda, | 
|  | 26 | +                              MulStrategyMap<Operation> &container) { | 
|  | 27 | +  const LayerDimensions dims = {inFeatureCount, outFeatureCount, groups, cuda}; | 
|  | 28 | +  auto it = container.find(dims); | 
|  | 29 | + | 
|  | 30 | +  if (it != container.end()) { | 
|  | 31 | +    return *(it->second); | 
|  | 32 | +  } | 
|  | 33 | + | 
|  | 34 | +  auto mul = | 
|  | 35 | +      container.insert(std::make_pair(dims, std::make_unique<Operation>(dims))) | 
|  | 36 | +          .first->second.get(); | 
|  | 37 | +  return *mul; | 
|  | 38 | +} | 
|  | 39 | + | 
|  | 40 | +struct HalideMulFactory::Impl { | 
|  | 41 | +  MulStrategyMap<HalideMulBackward> backward; | 
|  | 42 | +  MulStrategyMap<HalideMulForward> forward; | 
|  | 43 | +}; | 
|  | 44 | + | 
|  | 45 | +HalideMulFactory::HalideMulFactory() : pimpl(new Impl()) {} | 
|  | 46 | + | 
|  | 47 | +HalideMulFactory::~HalideMulFactory() = default; | 
|  | 48 | + | 
|  | 49 | +const HalideMulFactory &HalideMulFactory::getInstance() { | 
|  | 50 | +  static HalideMulFactory instance; | 
|  | 51 | +  return instance; | 
|  | 52 | +} | 
|  | 53 | + | 
|  | 54 | +const HalideMulForward & | 
|  | 55 | +HalideMulFactory::getHalideMulForward(int inFeatureCount, int outFeatureCount, | 
|  | 56 | +                                      int groups, bool cuda) const { | 
|  | 57 | +  return getHalideMul<HalideMulForward>(inFeatureCount, outFeatureCount, groups, | 
|  | 58 | +                                        cuda, pimpl->forward); | 
|  | 59 | +} | 
|  | 60 | + | 
|  | 61 | +const HalideMulBackward & | 
|  | 62 | +HalideMulFactory::getHalideMulBackward(int inFeatureCount, int outFeatureCount, | 
|  | 63 | +                                       int groups, bool cuda) const { | 
|  | 64 | +  return getHalideMul<HalideMulBackward>(inFeatureCount, outFeatureCount, | 
|  | 65 | +                                         groups, cuda, pimpl->backward); | 
|  | 66 | +} | 
|  | 67 | + | 
|  | 68 | +HalideMul::HalideMul(int inFeatureCount, int outFeatureCount, int groups) | 
|  | 69 | +    : dimensions({inFeatureCount, outFeatureCount, groups}) {} | 
|  | 70 | + | 
|  | 71 | +HalideMul::HalideMul(const LayerDimensions &dims) : dimensions(dims) {} | 
|  | 72 | + | 
|  | 73 | +HalideMul::~HalideMul() = default; | 
|  | 74 | + | 
|  | 75 | +/* Implementation of forward Halide matrix multiplication */ | 
|  | 76 | +struct HalideMulForward::Impl { | 
|  | 77 | +public: | 
|  | 78 | +  Impl(const LayerDimensions &dimensions, bool cuda) { | 
|  | 79 | +    Halide::Target target = Halide::get_host_target(); | 
|  | 80 | +    Halide::Func matmul = Halide::Func("matmul"); | 
|  | 81 | + | 
|  | 82 | +    /* Variables */ | 
|  | 83 | +    Halide::Var i, g, j; | 
|  | 84 | +    Halide::RDom k{0, dimensions.inFeatureCount / dimensions.groups}; | 
|  | 85 | + | 
|  | 86 | +    /* Algorithm */ | 
|  | 87 | +    Halide::Expr producer = clamp(rules(2 * i), 0, maxHalideRow - 1); | 
|  | 88 | +    matmul(j, i, g) = sum(inputFeatures(k, g, producer) * weights(j, k, g)); | 
|  | 89 | + | 
|  | 90 | +    /* Schedule */ | 
|  | 91 | +    matmul.estimate(j, 0, featureCount) | 
|  | 92 | +        .estimate(g, 0, groups) | 
|  | 93 | +        .estimate(i, 0, featureRowCount); | 
|  | 94 | + | 
|  | 95 | +    inputFeatures.dim(0).set_bounds_estimate(0, featureCount); | 
|  | 96 | +    inputFeatures.dim(1).set_bounds_estimate(0, groups); | 
|  | 97 | +    inputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); | 
|  | 98 | + | 
|  | 99 | +    weights.dim(0).set_bounds_estimate(0, featureCount); | 
|  | 100 | +    weights.dim(1).set_bounds_estimate(0, featureCount); | 
|  | 101 | +    weights.dim(2).set_bounds_estimate(0, groups); | 
|  | 102 | + | 
|  | 103 | +    rules.dim(0).set_bounds_estimate(0, activeRows); | 
|  | 104 | +    activeRowsParam.set_estimate(activeRows); | 
|  | 105 | + | 
|  | 106 | +    p = Halide::Pipeline({matmul}); | 
|  | 107 | + | 
|  | 108 | +    if (!cuda) { | 
|  | 109 | +      p.auto_schedule(target); | 
|  | 110 | +    } else { | 
|  | 111 | +      target.set_feature(Halide::Target::CUDA); | 
|  | 112 | +    } | 
|  | 113 | + | 
|  | 114 | +    p.compile_jit(target); | 
|  | 115 | +  }; | 
|  | 116 | + | 
|  | 117 | +  Halide::ImageParam inputFeatures = | 
|  | 118 | +      Halide::ImageParam(Halide::type_of<float>(), 3, "source"); | 
|  | 119 | +  Halide::ImageParam weights = | 
|  | 120 | +      Halide::ImageParam(Halide::type_of<float>(), 3, "weight"); | 
|  | 121 | +  Halide::ImageParam rules = | 
|  | 122 | +      Halide::ImageParam(Halide::type_of<int>(), 1, "rules"); | 
|  | 123 | + | 
|  | 124 | +  Halide::Param<int> activeRowsParam = Halide::Param<int>("row_count"); | 
|  | 125 | + | 
|  | 126 | +  Halide::Pipeline p; | 
|  | 127 | +}; | 
|  | 128 | + | 
|  | 129 | +HalideMulForward::HalideMulForward(int inFeatureCount, int outFeatureCount, | 
|  | 130 | +                                   int groups, bool cuda) | 
|  | 131 | +    : HalideMul(inFeatureCount, outFeatureCount, groups), | 
|  | 132 | +      pimpl(new Impl(dimensions, cuda)) {} | 
|  | 133 | + | 
|  | 134 | +HalideMulForward::HalideMulForward(const LayerDimensions &dims) | 
|  | 135 | +    : HalideMul(dims), pimpl(new Impl(dimensions, dims.cuda)) {} | 
|  | 136 | + | 
|  | 137 | +HalideMulForward::~HalideMulForward() = default; | 
|  | 138 | + | 
|  | 139 | +/* Executes the forward matrix multiplication created through the | 
|  | 140 | +   implementation object. */ | 
|  | 141 | +void HalideMulForward::execute(float *input, float *weight, int *rules, | 
|  | 142 | +                               float *output, int activeRowCount) const { | 
|  | 143 | + | 
|  | 144 | +  int inputPlanes = dimensions.inFeatureCount / dimensions.groups; | 
|  | 145 | +  int outputPlanes = dimensions.outFeatureCount / dimensions.groups; | 
|  | 146 | + | 
|  | 147 | +  pimpl->inputFeatures.set(Halide::Buffer<float>( | 
|  | 148 | +      input, inputPlanes, dimensions.groups, maxHalideRow)); | 
|  | 149 | +  pimpl->weights.set(Halide::Buffer<float>(weight, outputPlanes, inputPlanes, | 
|  | 150 | +                                           dimensions.groups)); | 
|  | 151 | +  pimpl->rules.set(Halide::Buffer<int>(rules, 2 * activeRowCount)); | 
|  | 152 | +  pimpl->activeRowsParam.set(activeRowCount); | 
|  | 153 | + | 
|  | 154 | +  auto out = Halide::Buffer<float>(output, outputPlanes, activeRowCount, | 
|  | 155 | +                                   dimensions.groups); | 
|  | 156 | +  pimpl->p.realize(out); | 
|  | 157 | +} | 
|  | 158 | + | 
|  | 159 | +/* Implementation of backward Halide matrix multiplication */ | 
|  | 160 | +struct HalideMulBackward::Impl { | 
|  | 161 | +public: | 
|  | 162 | +  Impl(const LayerDimensions &dimensions, bool cuda) { | 
|  | 163 | +    Halide::Target target = Halide::get_host_target(); | 
|  | 164 | + | 
|  | 165 | +    int outputPlanes = dimensions.outFeatureCount / dimensions.groups; | 
|  | 166 | + | 
|  | 167 | +    /* Variables */ | 
|  | 168 | +    Halide::Func o_matmul = Halide::Func("o_matmul"); | 
|  | 169 | +    Halide::Func o_weights = Halide::Func("o_weights"); | 
|  | 170 | +    Halide::Var i, g, k, j, gw, outp, inp; | 
|  | 171 | + | 
|  | 172 | +    Halide::RDom planes = Halide::RDom(0, outputPlanes); | 
|  | 173 | +    Halide::RDom nums = Halide::RDom(0, activeRowsParam); | 
|  | 174 | + | 
|  | 175 | +    /* Algorithm */ | 
|  | 176 | +    Halide::Expr producer = clamp(rules(2 * i + 1), 0, maxHalideRow - 1); | 
|  | 177 | + | 
|  | 178 | +    Halide::Expr orAccess_dom = clamp(rules(2 * nums + 1), 0, maxHalideRow - 1); | 
|  | 179 | +    Halide::Expr irAccess_dom = clamp(rules(2 * nums), 0, maxHalideRow - 1); | 
|  | 180 | + | 
|  | 181 | +    o_matmul(k, i, g) = | 
|  | 182 | +        sum(weights(planes, k, g) * outputFeatures(planes, g, producer)); | 
|  | 183 | + | 
|  | 184 | +    o_weights(outp, inp, gw) = sum(outputFeatures(outp, gw, orAccess_dom) * | 
|  | 185 | +                                   inputFeatures(inp, gw, irAccess_dom)); | 
|  | 186 | + | 
|  | 187 | +    /* Schedule */ | 
|  | 188 | +    o_matmul.estimate(k, 0, featureCount) | 
|  | 189 | +        .estimate(g, 0, groups) | 
|  | 190 | +        .estimate(i, 0, featureRowCount); | 
|  | 191 | +    o_weights.estimate(gw, 0, groups) | 
|  | 192 | +        .estimate(outp, 0, featureCount) | 
|  | 193 | +        .estimate(inp, 0, featureCount); | 
|  | 194 | + | 
|  | 195 | +    inputFeatures.dim(0).set_bounds_estimate(0, featureCount); | 
|  | 196 | +    inputFeatures.dim(1).set_bounds_estimate(0, groups); | 
|  | 197 | +    inputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); | 
|  | 198 | + | 
|  | 199 | +    outputFeatures.dim(0).set_bounds_estimate(0, featureCount); | 
|  | 200 | +    outputFeatures.dim(1).set_bounds_estimate(0, groups); | 
|  | 201 | +    outputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); | 
|  | 202 | + | 
|  | 203 | +    weights.dim(0).set_bounds_estimate(0, featureCount); | 
|  | 204 | +    weights.dim(1).set_bounds_estimate(0, featureCount); | 
|  | 205 | +    weights.dim(2).set_bounds_estimate(0, groups); | 
|  | 206 | + | 
|  | 207 | +    rules.dim(0).set_bounds_estimate(0, activeRows); | 
|  | 208 | +    activeRowsParam.set_estimate(activeRows); | 
|  | 209 | + | 
|  | 210 | +    p = Halide::Pipeline({o_matmul, o_weights}); | 
|  | 211 | + | 
|  | 212 | +    if (cuda) { | 
|  | 213 | +      target.set_feature(Halide::Target::CUDA); | 
|  | 214 | +    } else { | 
|  | 215 | +      p.auto_schedule(target); | 
|  | 216 | +    } | 
|  | 217 | + | 
|  | 218 | +    p.compile_jit(target); | 
|  | 219 | +  }; | 
|  | 220 | + | 
|  | 221 | +  Halide::ImageParam inputFeatures = | 
|  | 222 | +      Halide::ImageParam(Halide::type_of<float>(), 3, "input_features"); | 
|  | 223 | +  Halide::ImageParam outputFeatures = | 
|  | 224 | +      Halide::ImageParam(Halide::type_of<float>(), 3, "output_features"); | 
|  | 225 | +  Halide::ImageParam rules = | 
|  | 226 | +      Halide::ImageParam(Halide::type_of<int>(), 1, "rules"); | 
|  | 227 | +  Halide::ImageParam weights = | 
|  | 228 | +      Halide::ImageParam(Halide::type_of<float>(), 3, "weights"); | 
|  | 229 | + | 
|  | 230 | +  Halide::Param<int> activeRowsParam = Halide::Param<int>("row_count"); | 
|  | 231 | + | 
|  | 232 | +  Halide::Pipeline p; | 
|  | 233 | +}; | 
|  | 234 | + | 
|  | 235 | +HalideMulBackward::HalideMulBackward(int inFeatureCount, int outFeatureCount, | 
|  | 236 | +                                     int groups, bool cuda) | 
|  | 237 | +    : HalideMul(inFeatureCount, outFeatureCount, groups), | 
|  | 238 | +      pimpl(new Impl(dimensions, cuda)) {} | 
|  | 239 | + | 
|  | 240 | +HalideMulBackward::HalideMulBackward(const LayerDimensions &dims) | 
|  | 241 | +    : HalideMul(dims), pimpl(new Impl(dimensions, dims.cuda)) {} | 
|  | 242 | + | 
|  | 243 | +HalideMulBackward::~HalideMulBackward() = default; | 
|  | 244 | + | 
|  | 245 | +/* Executes the backward matrix multiplications created through the | 
|  | 246 | +   implementation object. */ | 
|  | 247 | +void HalideMulBackward::execute(float *inputFeatures, float *outputFeatures, | 
|  | 248 | +                                int *rules, float *weights, | 
|  | 249 | +                                float *dWeightsOutput, float *output, | 
|  | 250 | +                                int activeRowCount) const { | 
|  | 251 | + | 
|  | 252 | +  int inputPlanes = dimensions.inFeatureCount / dimensions.groups; | 
|  | 253 | +  int outputPlanes = dimensions.outFeatureCount / dimensions.groups; | 
|  | 254 | + | 
|  | 255 | +  pimpl->inputFeatures.set(Halide::Buffer<float>( | 
|  | 256 | +      inputFeatures, inputPlanes, dimensions.groups, maxHalideRow)); | 
|  | 257 | +  pimpl->outputFeatures.set(Halide::Buffer<float>( | 
|  | 258 | +      outputFeatures, outputPlanes, dimensions.groups, maxHalideRow)); | 
|  | 259 | +  pimpl->weights.set(Halide::Buffer<float>(weights, outputPlanes, inputPlanes, | 
|  | 260 | +                                           dimensions.groups)); | 
|  | 261 | +  pimpl->rules.set(Halide::Buffer<int>(rules, 2 * activeRowCount)); | 
|  | 262 | + | 
|  | 263 | +  pimpl->activeRowsParam.set(activeRowCount); | 
|  | 264 | + | 
|  | 265 | +  auto halideOutput = Halide::Buffer<float>(output, inputPlanes, activeRowCount, | 
|  | 266 | +                                            dimensions.groups); | 
|  | 267 | +  auto halideWOutput = Halide::Buffer<float>(dWeightsOutput, outputPlanes, | 
|  | 268 | +                                             inputPlanes, dimensions.groups); | 
|  | 269 | + | 
|  | 270 | +  pimpl->p.realize({halideOutput, halideWOutput}); | 
|  | 271 | +} | 
0 commit comments