Fix ops and r_ops in case of float and int (#88)
* Fix ops and r_ops in case of float and int

* Random input
alessandropalla authored Jul 4, 2024
1 parent a35dea0 commit 66c1205
Showing 2 changed files with 120 additions and 0 deletions.
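For context, a minimal usage sketch of what this change enables. It is a sketch only: the import paths below are assumed to match what the test module uses and are not shown in this diff. After the fix, a plain Python int or float on either side of +, -, * or / is lifted to a constant tensor of the Tensor's dtype before the element-wise op is generated.

import torch
from intel_npu_acceleration_library.backend import NNFactory  # assumed import path
from intel_npu_acceleration_library.dtypes import float16     # assumed import path

model = NNFactory()
t = model.parameter([12, 231], float16)
out = 3 * t + 1 - t / 2   # scalars on both sides now work (r_ops included)
model.compile()

x = torch.rand(12, 231).to(torch.float16)
result = model(x)          # compare against 3 * x + 1 - x / 2 computed in plain torch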
80 changes: 80 additions & 0 deletions intel_npu_acceleration_library/backend/tensor.py
@@ -166,6 +166,10 @@ def __add__(self, other) -> "Tensor":
        Returns:
            Tensor: The result of the addition.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([self, other], "eltwise_add")

    def __sub__(self, other) -> "Tensor":
@@ -178,6 +182,10 @@ def __sub__(self, other) -> "Tensor":
        Returns:
            Tensor: The result of the subtraction.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([self, -other], "eltwise_add")

    def __mul__(self, other) -> "Tensor":
@@ -190,6 +198,10 @@ def __mul__(self, other) -> "Tensor":
        Returns:
            Tensor: The result of the multiplication.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([self, other], "eltwise_mul")

    def __truediv__(self, other) -> "Tensor":
@@ -202,8 +214,76 @@ def __truediv__(self, other) -> "Tensor":
        Returns:
            Tensor: The result of the division.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([self, other], "eltwise_div")

    def __radd__(self, other) -> "Tensor":
        """
        Add two tensors element-wise.

        Args:
            other (Tensor): The tensor to be added.

        Returns:
            Tensor: The result of the addition.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([other, self], "eltwise_add")

    def __rsub__(self, other) -> "Tensor":
        """
        Subtract two tensors element-wise.

        Args:
            other (Tensor): The tensor to be subtracted.

        Returns:
            Tensor: The result of the subtraction.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([other, -self], "eltwise_add")

    def __rmul__(self, other) -> "Tensor":
        """
        Multiply two tensors element-wise.

        Args:
            other (Tensor): The tensor to be multiplied.

        Returns:
            Tensor: The result of the multiplication.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([other, self], "eltwise_mul")

    def __rtruediv__(self, other) -> "Tensor":
        """
        Divide two tensors element-wise.

        Args:
            other (Tensor): The tensor to be divided.

        Returns:
            Tensor: The result of the division.
        """
        if isinstance(other, (int, float)):
            other = self.factory.constant(
                torch.tensor([other], dtype=self.dtype.torch_dtype)
            )
        return generate_op([other, self], "eltwise_div")

    def __neg__(self) -> "Tensor":
        """
        Negate the tensor.
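The pattern is identical in every operator above: a Python int or float is wrapped in a one-element torch tensor of the Tensor's own dtype and registered as a graph constant through self.factory.constant, so generate_op always receives two tensor operands; subtraction reuses addition with a negated operand (both __sub__ and __rsub__ emit "eltwise_add"). A standalone sketch of that idea in plain PyTorch (the helper names below are illustrative, not part of the library):

import torch

def lift_scalar(value, dtype=torch.float16):
    # Illustrative stand-in for the constant-lifting done in the operators above
    if isinstance(value, (int, float)):
        return torch.tensor([value], dtype=dtype)
    return value

def rsub(other, tensor):
    # other - tensor expressed as other + (-tensor), mirroring __rsub__ above
    return lift_scalar(other) + (-tensor)

x = torch.ones(4, dtype=torch.float16)
print(rsub(3, x))  # tensor([2., 2., 2., 2.], dtype=torch.float16)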
40 changes: 40 additions & 0 deletions test/python/test_tensor.py
@@ -286,3 +286,43 @@ def test_reduce_operations(batch, hidden_dim, axis, op):
        assert 1 - r2_score([reference, 1], [result, 1]) < 0.01
    else:
        assert 1 - r2_score(reference, result) < 0.01


@pytest.mark.parametrize("shape", [[1, 128, 32, 64], [12, 231]])
@pytest.mark.parametrize("op", ["+", "-", "*", "/"])
@pytest.mark.parametrize("side", ["left", "right"])
@pytest.mark.parametrize("value", [3, -10])
def test_float_op(shape, op, side, value):
def op_func(a, b):
if op == "+":
return a + b
elif op == "-":
return a - b
elif op == "*":
return a * b
elif op == "/":
return a / b

def act(a, b):
if side == "left":
return op_func(b, a)
else:
return op_func(a, b)

x = torch.rand(shape).to(torch.float16) + 2
reference = act(x, value)

model = NNFactory()
t1 = model.parameter(shape, float16)
out = act(t1, value)
model.compile()

result = model(x)

assert (
1
- r2_score(
reference.flatten().detach().numpy(), result.flatten().detach().numpy()
)
< 0.001
)
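To run just the new cases locally, something along these lines should work (a sketch; the file path assumes the repository layout shown in this diff):

import pytest

# -k selects only the test added in this commit; -q keeps the output short
pytest.main(["-q", "-k", "test_float_op", "test/python/test_tensor.py"])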
