Commit
Remove some uses of torch.is_tensor in favor of isinstance (pytorch#5473)
colesbury authored and soumith committed Mar 2, 2018
1 parent 5dedc64 commit 70ba50c
Showing 5 changed files with 10 additions and 26 deletions.
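
This commit lands as part of the Variable/Tensor merge: once Variable(...) simply returns a Tensor, a single isinstance(x, torch.Tensor) check covers both cases that previously needed isinstance(x, Variable) plus torch.is_tensor(x). A minimal sketch of the equivalence, assuming a post-merge build (PyTorch 0.4 or later):

    import torch
    from torch.autograd import Variable  # kept for comparison; deprecated after the merge

    t = torch.randn(3)
    v = Variable(torch.randn(3))  # post-merge, Variable(...) just returns a Tensor

    # Old style: two separate checks to cover the wrapper and the raw tensor.
    assert torch.is_tensor(t) or isinstance(t, Variable)

    # New style: one isinstance check handles both values.
    assert isinstance(t, torch.Tensor)
    assert isinstance(v, torch.Tensor)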
8 changes: 3 additions & 5 deletions torch/autograd/__init__.py
@@ -19,10 +19,8 @@
 def _make_grads(outputs, grads):
     new_grads = []
     for out, grad in zip(outputs, grads):
-        if isinstance(grad, Variable):
+        if isinstance(grad, torch.Tensor):
             new_grads.append(grad)
-        elif torch.is_tensor(grad):
-            new_grads.append(Variable(grad))
         elif grad is None:
             if out.requires_grad:
                 if out.numel() != 1:
@@ -70,7 +68,7 @@ def backward(variables, grad_variables=None, retain_graph=None, create_graph=Fal
 
     if grad_variables is None:
         grad_variables = [None] * len(variables)
-    elif isinstance(grad_variables, Variable) or torch.is_tensor(grad_variables):
+    elif isinstance(grad_variables, torch.Tensor):
         grad_variables = [grad_variables]
     else:
         grad_variables = list(grad_variables)
@@ -126,7 +124,7 @@ def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=Fal
     inputs = (inputs,) if isinstance(inputs, Variable) else tuple(inputs)
     if grad_outputs is None:
         grad_outputs = [None] * len(outputs)
-    elif isinstance(grad_outputs, Variable) or torch.is_tensor(grad_outputs):
+    elif isinstance(grad_outputs, torch.Tensor):
         grad_outputs = [grad_outputs]
     else:
         grad_outputs = list(grad_outputs)
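A usage sketch of the backward() branch changed above: a single Tensor passed as the gradient argument now satisfies isinstance(grad_variables, torch.Tensor) and is wrapped into a one-element list. Argument names follow the 0.3-era signature; on current PyTorch the parameter is called grad_tensors.

    import torch

    x = torch.randn(3, requires_grad=True)
    y = x * 2

    # Passing a bare Tensor (not a list) as the gradient exercises the
    # isinstance branch above, which wraps it as [grad] for the engine.
    torch.autograd.backward(y, torch.ones(3))
    print(x.grad)  # tensor([2., 2., 2.])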
10 changes: 2 additions & 8 deletions torch/autograd/gradcheck.py
@@ -43,23 +43,17 @@ def make_jacobian(input, num_out):
 
 
 def iter_tensors(x, only_requiring_grad=False):
-    if isinstance(x, Variable):
+    if isinstance(x, torch.Tensor):
         if x.requires_grad or not only_requiring_grad:
             yield x.data
-    elif torch.is_tensor(x):
-        if only_requiring_grad:
-            raise AssertionError("iter_tensors encountered Tensor with only_requiring_grad=True")
-        yield x
     elif isinstance(x, Iterable):
         for elem in x:
             for result in iter_tensors(elem, only_requiring_grad):
                 yield result
 
 
 def contiguous(input):
-    if torch.is_tensor(input):
-        return input.contiguous()
-    elif isinstance(input, Variable):
+    if isinstance(input, torch.Tensor):
         return input.contiguous()
     elif isinstance(input, Iterable):
         return type(input)(contiguous(e) for e in input)
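iter_tensors and contiguous are internal helpers behind torch.autograd.gradcheck; a minimal sketch of the public entry point they serve (double-precision input, which gradcheck expects for numerical accuracy):

    import torch
    from torch.autograd import gradcheck

    inp = torch.randn(4, dtype=torch.double, requires_grad=True)

    # gradcheck compares analytical Jacobians against finite differences,
    # walking its inputs with helpers such as iter_tensors above.
    assert gradcheck(torch.sigmoid, (inp,), eps=1e-6, atol=1e-4)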
5 changes: 0 additions & 5 deletions torch/autograd/variable.py
@@ -221,11 +221,6 @@ def retain_grad_hook(grad):
         self.register_hook(retain_grad_hook)
         self.retains_grad = True
 
-    def type_as(self, other):
-        if torch.is_tensor(other):
-            other = Variable(other)
-        return super(Variable, self).type_as(other)
-
     def is_pinned(self):
         r"""Returns true if this tensor resides in pinned memory"""
         storage = self.storage()
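The deleted type_as override existed only to wrap a raw Tensor argument in a Variable; post-merge, the base implementation accepts any Tensor directly, so the behavior is preserved without it. A small sketch, assuming a post-merge build:

    import torch

    a = torch.randn(3)                       # float32
    b = torch.zeros(3, dtype=torch.float64)  # float64

    # type_as casts `a` to b's dtype and device; no Variable wrapping needed.
    c = a.type_as(b)
    assert c.dtype == torch.float64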
4 changes: 2 additions & 2 deletions torch/distributions/distribution.py
@@ -171,8 +171,8 @@ def _validate_log_prob_arg(self, value):
             ValueError: when the rightmost dimensions of `value` do not match the
                 distribution's batch and event shapes.
         """
-        if not (torch.is_tensor(value) or isinstance(value, Variable)):
-            raise ValueError('The value argument to log_prob must be a Tensor or Variable instance.')
+        if not isinstance(value, torch.Tensor):
+            raise ValueError('The value argument to log_prob must be a Tensor')
 
         event_dim_start = len(value.size()) - len(self._event_shape)
         if value.size()[event_dim_start:] != self._event_shape:
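A sketch of what the tightened validation accepts, using Normal as a stand-in for any concrete Distribution (the rejected call is shown commented out):

    import torch
    from torch.distributions import Normal

    d = Normal(torch.tensor(0.0), torch.tensor(1.0))

    # A Tensor argument passes the isinstance(value, torch.Tensor) check.
    print(d.log_prob(torch.tensor(0.5)))

    # A plain Python float is not a torch.Tensor, so the check above rejects it:
    # d.log_prob(0.5)  # ValueError: The value argument to log_prob must be a Tensor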
9 changes: 3 additions & 6 deletions torch/distributions/utils.py
@@ -70,20 +70,17 @@ def broadcast_all(*values):
         `(1,)`.
     Args:
-        values (list of `numbers.Number`, `torch.autograd.Variable` or
-            `torch.Tensor`)
+        values (list of `numbers.Number` or `torch.Tensor`)
     Raises:
         ValueError: if any of the values is not a `numbers.Number`, `torch.Tensor`
             or `torch.autograd.Variable` instance
     """
     values = list(values)
     scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)]
-    tensor_idxs = [i for i in range(len(values)) if
-                   torch.is_tensor(values[i]) or isinstance(values[i], Variable)]
+    tensor_idxs = [i for i in range(len(values)) if isinstance(values[i], torch.Tensor)]
     if len(scalar_idxs) + len(tensor_idxs) != len(values):
-        raise ValueError('Input arguments must all be instances of numbers.Number, torch.Tensor or ' +
-                         'torch.autograd.Variable.')
+        raise ValueError('Input arguments must all be instances of numbers.Number or torch.Tensor.')
     if tensor_idxs:
         broadcast_shape = _broadcast_shape([values[i].size() for i in tensor_idxs])
         for idx in tensor_idxs:
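A usage sketch of broadcast_all after the change: numbers and Tensors are broadcast to a common shape, and anything else now fails the single isinstance test:

    import torch
    from torch.distributions.utils import broadcast_all

    # A Python number and a Tensor broadcast to the Tensor's shape.
    loc, scale = broadcast_all(0.0, torch.ones(2, 3))
    print(loc.shape, scale.shape)  # torch.Size([2, 3]) torch.Size([2, 3])

    # Anything that is neither a numbers.Number nor a torch.Tensor is rejected:
    # broadcast_all("oops")  # ValueError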
