Skip to content

Commit

Permalink
fix linter actually
Browse files Browse the repository at this point in the history
  • Loading branch information
altanh committed Oct 27, 2020
1 parent f76e2fa commit 2c7e88a
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
equal,
shape_of,
log,
concatenate
concatenate,
)
from .transform import (
broadcast_to_like,
Expand All @@ -62,7 +62,7 @@
split,
squeeze,
strided_set,
arange
arange,
)


Expand Down Expand Up @@ -679,7 +679,7 @@ def cross_entropy_with_logits_grad(orig, grad):
def take_grad(orig, grad):
def make_scalar_tensor(v):
if isinstance(v, int):
v = const(v, dtype='int32')
v = const(v, dtype="int32")
return reshape(v, (1,))

# TODO(@altanh): we currently assume indices are in range
Expand All @@ -690,7 +690,7 @@ def make_scalar_tensor(v):
try:
data_shape = data.checked_type.concrete_shape
except TypeError:
raise OpError('currently take_grad only supports data with concrete shape')
raise OpError("currently take_grad only supports data with concrete shape")
if axis is None:
axis = 0
data_grad = reshape(data_grad, (-1,))
Expand All @@ -710,7 +710,7 @@ def make_scalar_tensor(v):
elif len(indices.checked_type.shape) == 1:
num_indices = take(shape_of(indices), zero, axis=0)
else:
raise OpError('take_grad only supports scalar or 1D indices')
raise OpError("take_grad only supports scalar or 1D indices")

def loop_cond(data_grad, i):
return squeeze(less(i, num_indices))
Expand All @@ -731,8 +731,8 @@ def loop_body(data_grad, i):
return (next_data_grad, i + one)

loop_vars = [
Var('data_grad', type_annotation=TensorType(data_shape, data.checked_type.dtype)),
Var('i', type_annotation=TensorType((1,), 'int32')),
Var("data_grad", type_annotation=TensorType(data_shape, data.checked_type.dtype)),
Var("i", type_annotation=TensorType((1,), "int32")),
]

loop = while_loop(loop_cond, loop_vars, loop_body)
Expand Down Expand Up @@ -777,11 +777,11 @@ def expand_dims_grad(orig, grad):
@register_gradient("arange")
def arange_grad(orig, grad):
start, stop, step = orig.args
length = take(shape_of(orig), const(0, dtype='int32'), axis=0)
length = take(shape_of(orig), const(0, dtype="int32"), axis=0)

grad_start = cast_like(_sum(grad), start)
grad_stop = zeros_like(stop)
grad_step = cast_like(arange(length, dtype='int32'), grad) * grad
grad_step = cast_like(arange(length, dtype="int32"), grad) * grad
grad_step = cast_like(_sum(grad_step), step)

return [grad_start, grad_stop, grad_step]

0 comments on commit 2c7e88a

Please sign in to comment.