|
1 | 1 | import math
|
2 | 2 | from typing import Optional
|
3 | 3 |
|
| 4 | +import numpy as np |
| 5 | +import tensorrt as trt |
4 | 6 | from torch.fx.node import Target
|
5 | 7 | from torch_tensorrt.dynamo._SourceIR import SourceIR
|
| 8 | +from torch_tensorrt.dynamo.conversion import impl |
6 | 9 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
|
7 |
| -from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim |
| 10 | +from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 11 | + get_positive_dim, |
| 12 | + get_trt_tensor, |
| 13 | +) |
8 | 14 | from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
|
9 | 15 | from torch_tensorrt.fx.converters.converter_utils import (
|
10 | 16 | has_dynamic_shape,
|
@@ -96,3 +102,98 @@ def expand(
|
96 | 102 | layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
|
97 | 103 | set_layer_name(layer, target, name, source_ir)
|
98 | 104 | return layer.get_output(0)
|
| 105 | + |
| 106 | + |
def chunk(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    chunks: int,
    dim: int,
) -> TRTTensor:
    """Split ``input`` along ``dim`` into consecutive slices.

    Each slice spans ``ceil(size / chunks)`` elements of ``dim``, except the
    last one, which is clipped to the end of the dimension. Fewer than
    ``chunks`` slices are produced when the dimension runs out first, and
    none at all when the dimension has size 0.

    Args:
        ctx: Conversion context, forwarded to ``slice_op``.
        target: FX target, forwarded to ``slice_op``.
        source_ir: Source IR tag, forwarded to ``slice_op``.
        name: Base name; slice ``i`` is named ``f"{name}_slice_{i}"``.
        input: Tensor to split.
        chunks: Requested number of chunks; must be greater than 0.
        dim: Dimension to split along; normalized with ``get_positive_dim``.

    Returns:
        A list with one ``slice_op`` result per chunk, in order along
        ``dim``. NOTE(review): the annotation says ``TRTTensor`` but a list
        is returned; the annotation is kept as-is to avoid a new import.

    Raises:
        RuntimeError: If ``chunks <= 0``, if ``dim`` is out of range, or if
            the size of ``dim`` is dynamic (``-1``).
    """
    if chunks <= 0:
        raise RuntimeError(
            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
        )

    shape = input.shape
    dim = get_positive_dim(dim, len(shape))

    if dim >= len(shape):
        raise RuntimeError(
            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
        )

    size_dim = shape[dim]
    # Raise instead of assert: asserts are stripped under `python -O`, and
    # the chunk boundaries below need a concrete size.
    if size_dim == -1:
        raise RuntimeError("Can't chunk on dynamic shape dimension!")

    chunk_size = math.ceil(size_dim / chunks)
    # A zero-length dimension gives chunk_size == 0; range() rejects a zero
    # step, so yield no chunks in that case.
    starts = range(0, size_dim, chunk_size) if chunk_size > 0 else range(0)

    return [
        slice_op(
            ctx,
            target,
            source_ir,
            f"{name}_slice_{cnt}",
            input,
            dim,
            start,
            min(start + chunk_size, size_dim),
            1,
        )
        for cnt, start in enumerate(starts)
    ]
| 160 | + |
| 161 | + |
def cumsum(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
) -> TRTTensor:
    """Build a TensorRT loop that computes a running sum of ``input`` on ``dim``.

    The loop runs ``input.shape[dim]`` times. Each iteration takes one slice
    of ``input`` along ``dim`` from an iterator, adds it to a recurrence that
    starts at zero, and emits the updated sum. The per-iteration sums are
    concatenated back along ``dim``.

    Args:
        ctx: Conversion context; layers are added to ``ctx.net``.
        target: FX target, used for layer naming.
        source_ir: Source IR tag, used for layer naming.
        name: Base name for the layers and constants created here.
        input: Tensor to accumulate.
        dim: Dimension to accumulate along; normalized with
            ``get_positive_dim``.

    Returns:
        Output 0 of the loop's concatenating output layer.

    Raises:
        RuntimeError: If the size of ``dim`` is dynamic (``-1``).
    """
    input_shape = input.shape
    dim = get_positive_dim(dim, len(input_shape))

    # The trip count is baked into a constant below, so a dynamic (-1) size
    # would give a meaningless loop. Fail before touching the network.
    if input_shape[dim] == -1:
        raise RuntimeError("Can't cumsum on dynamic shape dimension!")

    loop = ctx.net.add_loop()
    axis = np.array(input_shape[dim])
    trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
    loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)

    # One slice of `input` along `dim` per iteration.
    iterator = loop.add_iterator(input, dim, reverse=False)
    data = iterator.get_output(0)

    # Initial accumulator: zeros shaped like a single slice.
    # NOTE(review): np.zeros defaults to float64 regardless of the input's
    # dtype - confirm get_trt_tensor / elementwise.add give the intended
    # result type for non-float inputs.
    new_dims = tuple(data.shape)
    zeros = np.zeros(new_dims)
    zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")

    running_sum = loop.add_recurrence(zero_trttensor)
    set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
    running_sum_tensor = running_sum.get_output(0)

    current_sum = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_elementwise_add",
        data,
        running_sum_tensor,
    )
    # Feed this iteration's sum back as the recurrence's next value.
    running_sum.set_input(1, current_sum)

    loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
    set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
    # The trip count tensor is also wired in as the loop output's input 1.
    loop_output.set_input(1, trip_limit)
    return loop_output.get_output(0)
0 commit comments