
Commit 93846ed

narendasan authored and apbose committed
refactor: Moving elementwise and unary core to impl
Signed-off-by: Naren Dasan <naren@narendasan.com>

new file: ../converters/impl/unary/base.py
1 parent 74e17b5 commit 93846ed

File tree

10 files changed: +882 -463 lines

py/torch_tensorrt/fx/converters/acc_ops_converters.py (+379 -130)

(Large diff not rendered by default.)

py/torch_tensorrt/fx/converters/aten_ops_converters.py (+2 -3)

@@ -23,6 +23,7 @@
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl.elementwise import trunc_div

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -161,9 +162,7 @@ def aten_ops_div(
             network, target, None, kwargs_new, name
         )
     elif rounding_mode == "trunc":
-        return acc_ops_converters.acc_ops_trunc_div(
-            network, target, None, kwargs_new, name
-        )
+        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
     else:
         raise RuntimeError(
             f"Target {target} does not support rounding mode {rounding_mode}"

py/torch_tensorrt/fx/converters/converter_utils.py (+3 -330)

@@ -28,6 +28,7 @@ class SourceIR(Enum):
     ACC = auto()
     ATEN = auto()
     PRIM = auto()
+    TORCHTRT_LOWERED = auto()
     UNKNOWN = auto()

     def __str__(self):

@@ -39,6 +40,8 @@ def __str__(self):
             return "aten"
         elif self == SourceIR.PRIM:
             return "prim"
+        elif self == SourceIR.TORCHTRT_LOWERED:
+            return "torchtrt_lowered"
         else:
             return "unknown_ir"
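The new member stringifies like the existing ones; a quick check of its string form (assuming torch_tensorrt is importable):

    from torch_tensorrt.fx.converters.converter_utils import SourceIR

    assert str(SourceIR.ATEN) == "aten"
    assert str(SourceIR.TORCHTRT_LOWERED) == "torchtrt_lowered"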

@@ -383,171 +386,6 @@ def broadcast(
     return a, b


-def get_shape_with_dynamic_shape(
-    network: TRTNetwork,
-    shape: Union[list, tuple, torch.Tensor],
-    input_val: TRTTensor,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Prepare the real output tensor shape for a dynamic-shape-mode tensor input.
-    How this function works:
-    Assuming input_val has actual shape [2048, 256, 512] and the expected reduce-op
-    output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the
-    actual reduce-op output shape. The steps of the calculation are:
-    1. get the actual tensor shape of input_val via an add_shape layer;
-    2. create an all-zero tensor [0, 0, 0];
-    3. run an elementwise comparison of the [0, 0, 0] and [-1, 128, 256] tensors, yielding a condition tensor [True, False, False];
-    4. use the condition tensor [True, False, False] to select between [2048, 256, 512] and [-1, 128, 256], replacing
-       all -1 dynamic-shape dimensions with the actual batch_size value;
-    5. output the shape with the actual batch_size, i.e. [2048, 128, 256]
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        shape: calculated shape of the expected output tensor
-        input_val (TRTTensor): A TensorRT ITensor.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-    Returns:
-        A TensorRT ITensor that represents the actual shape of input_val.
-    """
-    # Get real shape info for input_val
-    input_shape = network.add_shape(input_val).get_output(0)
-
-    scale_layer = network.add_constant(
-        input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32)
-    )
-    set_layer_name(scale_layer, target, f"{name}_scale")
-    scale_res = scale_layer.get_output(0)
-
-    length = input_shape.shape[0]
-    zero_layer = network.add_constant(
-        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
-    )
-    set_layer_name(zero_layer, target, f"{name}_zeros")
-
-    condition_val = add_binary_elementwise_layer(
-        network,
-        scale_res,
-        zero_layer.get_output(0),
-        trt.ElementWiseOperation.LESS,
-        target,
-        f"{name}_shape",
-    )
-    select_layer = network.add_select(condition_val, input_shape, scale_res)
-    set_layer_name(select_layer, target, f"{name}_select")
-    return select_layer.get_output(0)
-
-
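The removed helper's select trick is easy to sanity-check outside TensorRT. A minimal numpy rendering of the five docstring steps, using the docstring's example values (np.where plays the role of the add_select layer):

    import numpy as np

    actual = np.array([2048, 256, 512], dtype=np.int32)  # step 1: runtime shape
    expected = np.array([-1, 128, 256], dtype=np.int32)  # requested output shape
    zeros = np.zeros_like(expected)                      # step 2
    condition = expected < zeros                         # step 3: [True, False, False]
    resolved = np.where(condition, actual, expected)     # step 4: select
    print(resolved)                                      # step 5: [2048  128  256]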
-def add_binary_elementwise_layer(
-    network: TRTNetwork,
-    lhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    rhs_val: Union[int, float, TRTTensor, torch.Tensor],
-    op_type: trt.ElementWiseOperation,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    This function adds a TensorRT elementwise layer. We allow both operands to be
-    constant (not a trt tensor) because in implicit batch dimension mode, we could
-    introduce constants via the .size() op. Other scenarios should be const-folded first.
-    If any operand is not a trt tensor, we make it a trt constant layer while preserving
-    its dtype. Then we broadcast these two inputs to have the same number of dimensions.
-
-    Limitation:
-        If we are using implicit batch dim mode, an operand that is not a trt
-        tensor is not allowed to have a larger rank than the trt tensor operand.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        lhs_val (TRTTensor): Left operand of the binary operation. Could
-            be a TensorRT tensor, a PyTorch tensor or a simple value.
-        rhs_val (TRTTensor): Right operand of the binary operation. Similar
-            to lhs_val.
-        op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-
-    Returns:
-        The output of the TensorRT Elementwise layer.
-    """
-    lhs_dtype = None
-    rhs_dtype = None
-    is_lhs_trt_tensor = False
-    is_rhs_trt_tensor = False
-
-    if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
-        is_lhs_trt_tensor = True
-    if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
-        is_rhs_trt_tensor = True
-
-    if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
-        warnings.warn(
-            f"Both operands of the binary elementwise op {name} "
-            "are constant. In this case, please consider constant folding the model first."
-        )
-        return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val)
-
-    # If the following conditions are true:
-    # 1. the network has an implicit batch dimension,
-    # 2. one operand has shape [] (real shape is [batch_size]),
-    # 3. the other operand is a scalar,
-    # then the result should also have shape [] (real shape is [batch_size]).
-    #
-    # In such a case, we need to convert the scalar operand to a tensor, because
-    # this way the shape will become [1], and then will be properly squeezed
-    # into [], meaning that the result will have shape [], which is what we
-    # expect.
-    #
-    # Note that the dtype here is supposed to be the same as the scalar
-    # dtype, but we don't have a way to detect whether it makes sense for the
-    # scalar to be float or half. Hence we go with the lhs dtype.
-    if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
-    if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
-
-    # When lhs is a scalar and rhs has shape [1,], the assert below would
-    # fail because the lhs shape has fewer dimensions than the rhs shape. This
-    # happens when using implicit batch dimension, where we removed the 1st
-    # dimension from the input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left an input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make the lhs and rhs shapes compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
-            rhs_val = squeeze_left(rhs_val)
-
-    lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
-    rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
-
-    # Check the limitation in the docstring.
-    if network.has_implicit_batch_dimension:
-        if is_lhs_trt_tensor and not is_rhs_trt_tensor:
-            assert len(lhs_val.shape) >= len(
-                rhs_val.shape
-            ), f"{lhs_val.shape} >= {rhs_val.shape}"
-        elif not is_lhs_trt_tensor and is_rhs_trt_tensor:
-            assert len(rhs_val.shape) >= len(
-                lhs_val.shape
-            ), f"{rhs_val.shape} >= {lhs_val.shape}"
-
-    lhs_val, rhs_val = broadcast(
-        network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs"
-    )
-    layer = network.add_elementwise(lhs_val, rhs_val, op_type)
-    set_layer_name(layer, target, name)
-    output = layer.get_output(0)
-    output.name = output.name + "_" + target.__name__
-    return output
-

 def squeeze_left(const: torch.Tensor):
     """
     Squeeze the size-1 dimensions on the left side of the shape tuple.
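Two of the fix-ups described in the comments above are easy to see in plain torch: a scalar operand is promoted to a 1-element tensor carrying the TRT operand's dtype, and a constant tensor is squeezed from the left so implicit-batch ranks line up:

    import torch

    lhs = torch.ones(4, 8, dtype=torch.float16)  # stands in for the TRT tensor operand
    rhs = torch.tensor([2], dtype=lhs.dtype)     # scalar promoted, keeping lhs dtype

    const = torch.ones(1, 1, 8)                  # constant operand with leading 1-dims
    while len(const.shape) > 0 and const.shape[0] == 1:
        const = const.squeeze(0)                 # what squeeze_left does
    print(rhs.dtype, const.shape)                # torch.float16 torch.Size([8])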
@@ -559,38 +397,6 @@ def squeeze_left(const: torch.Tensor):
     return const


-def add_unary_layer(
-    network: TRTNetwork,
-    input_val: TRTTensor,
-    operation_type: trt.UnaryOperation,
-    target: Target,
-    name: str,
-) -> TRTTensor:
-    """
-    Add a TensorRT Unary layer to `network`.
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor.
-        operation_type (trt.UnaryOperation): Type of the TensorRT unary operation.
-        target (Target): Target of fx node.
-        name (str): The name we want to assign to the created TensorRT layer.
-
-    Returns:
-        The output of the TensorRT Unary layer.
-    """
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"{operation_type} received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-    layer = network.add_unary(input_val, operation_type)
-    set_layer_name(layer, target, name)
-    output = layer.get_output(0)
-    output.name = output.name + "_" + target.__name__
-    return layer.get_output(0)
-

 def add_reduce_layer(
     network: TRTNetwork,
     target: Target,
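Note the asymmetry with the binary helper above: the unary helper never promotes constant operands (it raises instead), which is why the trunc_div code further down converts its inputs with get_trt_tensor first. The unary ops it wraps correspond elementwise to plain torch ops:

    import torch

    x = torch.tensor([-1.5, 0.0, 2.0])
    print(torch.exp(x))  # trt.UnaryOperation.EXP
    print(torch.abs(x))  # trt.UnaryOperation.ABS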
@@ -695,139 +501,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names):
     return inputs


-def sign(
-    network: TRTNetwork, input_val: TRTTensor, target: Target, name: str
-) -> TRTTensor:
-    """
-    Sign is calculated as below:
-        x = input
-        sign = (exp(x) // exp(abs(x))) * 2 - 1
-    For positive numbers and 0, (exp(x) // exp(abs(x))) yields 1; for negative numbers it yields 0.
-    Multiplying by 2, the value becomes 2 (for pos and 0) or 0 (for neg).
-    Finally, subtracting 1, the value becomes 1 (for pos and 0) or -1 (for neg).
-
-    Args:
-        network (TRTNetwork): TensorRT network object.
-        input_val (TRTTensor): The input tensor.
-        target (Target): fx node target.
-        name (str): Name of the fx node with optional suffix.
-
-    Returns:
-        A TensorRT tensor representing the result of the sign operator.
-    """
-    input_exp_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp"
-    )
-    input_abs_output = add_unary_layer(
-        network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs"
-    )
-    input_abs_exp_output = add_unary_layer(
-        network,
-        input_abs_output,
-        trt.UnaryOperation.EXP,
-        target,
-        f"{name}_prod_abs_exp",
-    )
-    floor_div_output = add_binary_elementwise_layer(
-        network,
-        input_exp_output,
-        input_abs_exp_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_exp_floor_div",
-    )
-    double_floor_div_output = add_binary_elementwise_layer(
-        network,
-        floor_div_output,
-        2,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_floor_div*2",
-    )
-    return add_binary_elementwise_layer(
-        network,
-        double_floor_div_output,
-        1,
-        trt.ElementWiseOperation.SUB,
-        target,
-        f"{name}_sign",
-    )
-
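The identity is easy to verify numerically: exp(x)/exp(|x|) is exactly 1 for x >= 0 and strictly less than 1 for x < 0, so the floor-divide snaps it to 1 or 0.

    import math

    for x in (3.0, 0.0, -2.0):
        s = (math.exp(x) // math.exp(abs(x))) * 2 - 1
        print(x, "->", s)  # 3.0 -> 1.0, 0.0 -> 1.0, -2.0 -> -1.0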
-def trunc_div(
-    input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str
-) -> TRTTensor:
-    """
-    Perform trunc division on a Tensor; the result is rounded toward zero.
-    This means positive quotients are floor-rounded and negative quotients are
-    ceil-rounded. Example: [2.1, 0.8, -3.2] -> [2, 0, -3].
-
-    Args:
-        input: dividend.
-        other: divisor.
-        network: INetworkDefinition.
-        target: node target.
-        name: namespace for the op
-
-    Returns:
-        A TensorRT tensor representing the result of trunc division.
-    """
-    prod_output = add_binary_elementwise_layer(
-        network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod"
-    )
-    sign_output = sign(network, prod_output, target, name)
-
-    # Convert constant input into ITensor for UnaryOperation
-    if not isinstance(input, trt.tensorrt.ITensor):
-        input = get_trt_tensor(network, input, f"{name}_input")
-    if not isinstance(other, trt.tensorrt.ITensor):
-        other = get_trt_tensor(
-            network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
-        )
-
-    abs_input_output = add_unary_layer(
-        network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input"
-    )
-    abs_other_output = add_unary_layer(
-        network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other"
-    )
-    abs_floor_output = add_binary_elementwise_layer(
-        network,
-        abs_input_output,
-        abs_other_output,
-        trt.ElementWiseOperation.FLOOR_DIV,
-        target,
-        f"{name}_floor_div",
-    )
-    output = add_binary_elementwise_layer(
-        network,
-        abs_floor_output,
-        sign_output,
-        trt.ElementWiseOperation.PROD,
-        target,
-        f"{name}_output",
-    )
-
-    return output
-
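The decomposition used above (floor-divide of absolute values, re-signed by the sign of the product) can be checked against torch's own trunc division:

    import torch

    a = torch.tensor([7.0, -7.0, 5.5, -5.5])   # dividends
    b = torch.tensor([2.0, 2.0, -2.0, -2.0])   # divisors
    recomposed = (a.abs() // b.abs()) * torch.sign(a * b)
    print(recomposed)                              # tensor([ 3., -3., -2.,  2.])
    print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3., -3., -2.,  2.])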
-def get_python_op_from_trt_elementwise_op(
-    trt_op: TRTElementWiseOp,
-) -> Callable[[Any, Any], Any]:
-    if trt_op == trt.ElementWiseOperation.SUM:
-        return operator.add
-    elif trt_op == trt.ElementWiseOperation.PROD:
-        return operator.mul
-    elif trt_op == trt.ElementWiseOperation.SUB:
-        return operator.sub
-    elif trt_op == trt.ElementWiseOperation.DIV:
-        return operator.truediv
-    elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
-        return operator.floordiv
-    else:
-        raise RuntimeError(f"{trt_op} is not supported yet!")
-

 def dtype_uniform(
     network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor
 ):
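This mapping backs the constant-only branch of add_binary_elementwise_layer, letting an op on two non-TRT operands be folded eagerly in Python. A TensorRT-free sketch of the same table (string keys stand in for the trt.ElementWiseOperation members):

    import operator

    OP_TABLE = {
        "SUM": operator.add,        # trt.ElementWiseOperation.SUM
        "PROD": operator.mul,       # trt.ElementWiseOperation.PROD
        "SUB": operator.sub,        # trt.ElementWiseOperation.SUB
        "DIV": operator.truediv,    # trt.ElementWiseOperation.DIV
        "FLOOR_DIV": operator.floordiv,
    }
    print(OP_TABLE["FLOOR_DIV"](7, 2))  # 3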
[new file; path not shown in this capture]

@@ -0,0 +1 @@
+from .ops import *
