|
5 | 5 | import topi |
6 | 6 | from relay.env import Environment |
7 | 7 | import relay.ir as ir |
8 | | -from relay.make import Operator, IntrinsicId, TypeId, TensorType, FloatType |
| 8 | +from relay.make import String, Operator, IntrinsicId, TypeId, TensorType, FloatType |
9 | 9 | from relay.make import TypeQuantifier, TypeArrow, ProductType |
| 10 | +from relay.make import ShapeAttr, ShapeBinaryOp, ShapeProjection, ShapeSingleton, ShapeSeq |
10 | 11 |
|
# TODO(@jroesch): Fix up my type
# Module-level registry of operators known to the system, keyed by operator
# name. Populated by register_op / initialize_operators below.
__operator_registry__: Dict[str, Any] = {}
@@ -128,6 +129,15 @@ def broadcast_mul_compiler(func_ty: ir.Type) -> Any: |
128 | 129 | module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="broadcast_mul_compiler") |
129 | 130 | return module.get_function("broadcast_mul_compiler") |
130 | 131 |
|
# TODO(@jroesch): ensure this interfaces correctly
# note that the type provided doesn't handle padding
# feel free to assume some default behavior
def conv2d_compiler(func_ty: ir.Type) -> Any:
    """Compile a conv2d operator specialized to the shapes in *func_ty*.

    :param func_ty: the Relay function type of the conv2d call; its argument
        types determine the TVM placeholder shapes.
    :returns: a packed TVM function implementing the convolution.
    """
    Inputs, ret_ty = func_ty_to_placeholders(func_ty)
    # FIX: topi.nn.conv2d requires explicit stride and padding arguments;
    # calling it with only the data/filter placeholders raises a TypeError.
    # The registered type assumes no padding (see the TODO above), so default
    # to stride 1 and padding 0.
    Output = topi.nn.conv2d(*Inputs, 1, 0)
    schedule = tvm.create_schedule(Output.op)
    module = tvm.build(schedule, Inputs + [Output], __tgt__, target_host=__tgt_host__, name="conv2d_compiler")
    return module.get_function("conv2d_compiler")
131 | 141 |
|
132 | 142 | def initialize_operators(env) -> None: |
133 | 143 | """Initialize the default set of operators for the system, this will populate |
@@ -175,3 +185,44 @@ def initialize_operators(env) -> None: |
175 | 185 | bmul_type = TypeQuantifier(shape, TypeArrow(ProductType([in_out_type, in_out_type]), in_out_type)) |
176 | 186 | # TODO: reverse mode |
177 | 187 | register_op(env, 'broadcast_mul', bmul_type, broadcast_mul_compiler) |
| 188 | + |
| 189 | + # Conv2d |
| 190 | + # input: [batch, in_channel, in_height, in_width] |
| 191 | + # filter: [filter_height, filter_width, in_channel, num_filter] |
| 192 | + # output shape: [out_height, out_width, num_filter, batch] |
| 193 | + # out_height = (in_height - filter_h)/stride_h + 1 |
| 194 | + # out_width = (in_width - filter_w)/stride_w + 1 |
| 195 | + stride_h = ShapeAttr(String("stride_h")) |
| 196 | + stride_w = ShapeAttr(String("stride_w")) |
| 197 | + btvar = TypeId("bt", Kind.BaseType) |
| 198 | + input_shape = TypeId("input_shape", Kind.Shape) |
| 199 | + filter_shape = TypeId("filter_shape", Kind.Shape) |
| 200 | + output_shape = ShapeSeq([ |
| 201 | + ShapeBinaryOp(ShapeOp.SHPLUS, |
| 202 | + ShapeBinaryOp(ShapeOp.SHDIV, |
| 203 | + ShapeBinaryOp(ShapeOp.SHSUB, |
| 204 | + ShapeProjection(input_shape, 2), |
| 205 | + ShapeProjection(filter_shape, 0)), |
| 206 | + stride_h), |
| 207 | + ShapeSingleton(1)), |
| 208 | + ShapeBinaryOp(ShapeOp.SHPLUS, |
| 209 | + ShapeBinaryOp(ShapeOp.SHDIV, |
| 210 | + ShapeBinaryOp(ShapeOp.SHSUB, |
| 211 | + ShapeProjection(input_shape, 3), |
| 212 | + ShapeProjection(filter_shape, 1)), |
| 213 | + stride_w), |
| 214 | + ShapeSingleton(1)), |
| 215 | + ShapeProjection(filter_shape, 3), |
| 216 | + ShapeProjection(input_shape, 0) |
| 217 | + ]) |
| 218 | + conv2d_type = TypeQuantifier( |
| 219 | + btvar, |
| 220 | + TypeQuantifier( |
| 221 | + input_shape, |
| 222 | + TypeQuantifier( |
| 223 | + filter_shape, |
| 224 | + TypeArrow(ProductType([TensorType(btvar, input_shape), TensorType(btvar, filter_shape)], |
| 225 | + TensorType(btvar, output_shape) |
| 226 | + ))))) |
| 227 | + # TODO: reverse mode |
| 228 | + register_op(env, 'conv2d', conv2d_type, conv2d_compiler) |
0 commit comments