Skip to content

Commit

Permalink
fix: fix storage issue in torchtensor class (docarray#1833)
Browse files Browse the repository at this point in the history
Signed-off-by: Naymul Islam <naymul504@gmail.com>
  • Loading branch information
ai-naymul authored Dec 9, 2023
1 parent 82918fe commit 3cfa0b8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
)
return super().__torch_function__(func, types_, args, kwargs)

def __deepcopy__(self, memo):
"""
Custom implementation of deepcopy for TorchTensor to avoid storage sharing issues.
"""
# Create a new tensor with the same data and properties
new_tensor = self.clone()
# Set the class to the custom TorchTensor class
new_tensor.__class__ = self.__class__
return new_tensor

@classmethod
def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
"""Create a `tensor from a numpy array
Expand Down
18 changes: 18 additions & 0 deletions tests/integrations/typing/test_torch_tensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import torch
from docarray.typing.tensor.torch_tensor import TorchTensor
import copy

from docarray import BaseDoc
from docarray.typing import TorchEmbedding, TorchTensor
Expand All @@ -25,3 +27,19 @@ class MyDocument(BaseDoc):
assert isinstance(d.embedding, TorchEmbedding)
assert isinstance(d.embedding, torch.Tensor)
assert (d.embedding == torch.zeros((128,))).all()


def test_torchtensor_deepcopy():
# Setup
original_tensor_float = TorchTensor(torch.rand(10))
original_tensor_int = TorchTensor(torch.randint(0, 100, (10,)))

# Exercise
copied_tensor_float = copy.deepcopy(original_tensor_float)
copied_tensor_int = copy.deepcopy(original_tensor_int)

# Verify
assert torch.equal(original_tensor_float, copied_tensor_float)
assert original_tensor_float is not copied_tensor_float
assert torch.equal(original_tensor_int, copied_tensor_int)
assert original_tensor_int is not copied_tensor_int

0 comments on commit 3cfa0b8

Please sign in to comment.