Torch for Numpy users
Peter O'Connor edited this page Aug 25, 2018
·
17 revisions
torch equivalents of numpy functions
Numpy | Torch |
---|---|
np.ndarray | torch.Tensor |
np.float32 | torch.FloatTensor |
np.float64 | torch.DoubleTensor |
np.int8 | torch.CharTensor |
np.uint8 | torch.ByteTensor |
np.int16 | torch.ShortTensor |
np.int32 | torch.IntTensor |
np.int64 | torch.LongTensor |
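The type table above can be tried out with a minimal sketch like the following. It assumes a standard Torch7 install where the default tensor type has not been changed; the variable names are illustrative only.

```lua
require 'torch'

-- Allocate a 2x3 float tensor and zero it,
-- roughly np.zeros((2, 3), dtype=np.float32)
local x = torch.FloatTensor(2, 3):zero()
print(x:type())   -- torch.FloatTensor

-- Convert to another type, roughly x.astype(np.int64)
local y = x:long()
print(y:type())   -- torch.LongTensor

-- Generic form of the same conversion, using the type name as a string
local z = x:type('torch.DoubleTensor')
print(z:type())   -- torch.DoubleTensor
```

Note that `torch.Tensor` is an alias for the current default tensor type, so it is closer to "an ndarray with the default dtype" than to `np.ndarray` in general.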
Numpy | Torch |
---|---|
np.empty([2,2]) | torch.Tensor(2,2) |
np.empty_like(x) | x.new(x:size()) |
np.eye | torch.eye |
np.identity | torch.eye |
np.ones | torch.ones |
np.ones_like | torch.ones(x:size()) |
np.zeros | torch.zeros |
np.zeros_like | torch.zeros(x:size()) |
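One difference worth a short sketch: `torch.zeros(x:size())` and `torch.ones(x:size())` return the default tensor type, while NumPy's `*_like` functions keep the dtype of `x`. Assuming you want to keep the type of `x`, allocating through `x.new` is one way to do it:

```lua
require 'torch'

local x = torch.FloatTensor(2, 3):fill(7)

-- Same shape as x, but of the default tensor type
local z = torch.zeros(x:size())

-- Same shape and same type as x (closer to np.zeros_like / np.ones_like)
local z_like = x.new(x:size()):zero()
local o_like = x.new(x:size()):fill(1)

print(z_like:type())   -- torch.FloatTensor
```

As with `np.empty`, `torch.Tensor(2,2)` and `x.new(x:size())` return uninitialized memory until you fill it.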
Numpy | Torch |
---|---|
np.array([ [1,2],[3,4] ]) | torch.Tensor({{1,2},{3,4}}) |
np.ascontiguousarray(x) | x:contiguous() |
np.copy(x) | x:clone() |
np.fromfile(file) | torch.Tensor(torch.Storage(file)) |
np.frombuffer | ??? |
np.fromfunction | ??? |
np.fromiter | ??? |
np.fromstring | ??? |
np.loadtxt | ??? |
np.concatenate | torch.cat |
np.multiply | torch.cmul |
Numpy | Torch |
---|---|
np.arange(10) | torch.range(0,9) |
np.arange(2, 3, 0.1) | torch.linspace(2, 2.9, 10) |
np.linspace(1, 4, 6) | torch.linspace(1, 4, 6) |
np.logspace | torch.logspace |
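The main trap in the range table is that `torch.range` includes its end point, whereas `np.arange` excludes it. A small sketch with integer steps (chosen to avoid floating-point edge cases):

```lua
require 'torch'

-- np.arange(10) gives 0..9; torch.range needs the last value spelled out
local a = torch.range(0, 9)
print(a:nElement())   -- 10

-- np.arange(0, 8, 2) gives 0, 2, 4, 6; torch.range(0, 8, 2) also includes 8
local b = torch.range(0, 8, 2)
print(b:nElement())   -- 5

-- linspace takes the same (start, end, count) arguments as np.linspace
local c = torch.linspace(1, 4, 6)
```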
Numpy | Torch |
---|---|
np.diag | torch.diag |
np.tril | torch.tril |
np.triu | torch.triu |
Numpy | Torch |
---|---|
x.shape | x:size() |
x.strides | x:stride() |
x.ndim | x:dim() |
x.data | x:data() |
x.size | x:nElement() |
x.shape == y.shape | x:isSameSizeAs(y) |
x.dtype | x:type() |
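A short sketch of the attribute accessors above. Two things differ from NumPy: dimensions are numbered from 1, and strides are counted in elements rather than bytes.

```lua
require 'torch'

local x = torch.Tensor(3, 4):zero()
local y = torch.Tensor(3, 4):zero()

local shape = x:size()            -- a torch.LongStorage holding 3 and 4
print(x:size(1), x:size(2))       -- 3    4
print(x:dim())                    -- 2
print(x:nElement())               -- 12
print(x:stride(1), x:stride(2))   -- 4    1  (elements, not bytes)
print(x:isSameSizeAs(y))          -- true
```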
Numpy | Torch |
---|---|
x.reshape | x:reshape |
x.resize | x:resize |
? | x:resizeAs |
x.transpose | x:t() (2D only), x:transpose(dim1, dim2) |
x.flatten | x:view(x:nElement()) |
x.squeeze | x:squeeze |
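The shape operations above can be sketched as follows. The point to watch is that `view` needs contiguous memory, so a transposed tensor has to be made contiguous before it is flattened.

```lua
require 'torch'

-- 2x3 tensor holding 1..6 in row-major order
local x = torch.range(1, 6):view(2, 3)

-- Transpose: x:t() works for 2D tensors, x:transpose(1, 2) swaps two named dims.
-- Both return a view that shares storage with x.
local xt = x:t()
print(xt:isContiguous())   -- false

-- Flatten, like x.flatten()
local flat = x:view(x:nElement())

-- Flattening the transpose requires a contiguous copy first
local flat_t = xt:contiguous():view(xt:nElement())

-- Drop singleton dimensions, like x.squeeze()
local s = torch.Tensor(1, 3):zero():squeeze()
```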
Numpy | Torch |
---|---|
np.take(a, indices) | a[indices] |
x[:,0] | x[{{},1}] |
np.put | ???? |
np.tile | x:repeatTensor |
x.fill | x:fill |
np.choose | ??? |
np.sort | sorted, indices = torch.sort(x, [dim]) |
np.argsort | sorted, indices = torch.sort(x, [dim]) |
np.nonzero | torch.nonzero(x) |
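A minimal sketch of selection and sorting, with the NumPy counterpart in the comments. Indices are 1-based throughout, and `torch.sort` returns the sorted values and the indices in a single call.

```lua
require 'torch'

local x = torch.Tensor({{3, 1, 2}, {6, 5, 4}})

-- First column, x[:, 0] in NumPy
local col = x[{{}, 1}]

-- Rows 1..2, columns 2..3, x[0:2, 1:3] in NumPy
local sub = x[{{1, 2}, {2, 3}}]

-- Pick columns 3 and 1, np.take(x, [2, 0], axis=1) in NumPy
local picked = x:index(2, torch.LongTensor({3, 1}))

-- Sort along dimension 2; idx plays the role of np.argsort (1-based)
local sorted, idx = torch.sort(x, 2)
```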
Numpy | Torch |
---|---|
ndarray.min | mins, indices = torch.min(x, [dim]) |
ndarray.argmin | mins, indices = torch.min(x, [dim]) |
ndarray.max | maxs, indices = torch.max(x, [dim]) |
ndarray.argmax | maxs, indices = torch.max(x, [dim]) |
ndarray.clip | torch.clamp |
ndarray.round | torch.round |
ndarray.trace | torch.trace |
ndarray.sum | torch.sum |
ndarray.cumsum | torch.cumsum |
ndarray.mean | torch.mean |
ndarray.std | torch.std |
ndarray.prod | torch.prod |
ndarray.dot | torch.mm |
ndarray.cumprod | torch.cumprod |
ndarray.all | ??? |
ndarray.any | ??? |
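The reductions above behave slightly differently from NumPy when a dimension is given: the reduced dimension is kept with size 1, and `torch.max` / `torch.min` return both values and indices. A small sketch:

```lua
require 'torch'

local x = torch.Tensor({{1, 5, 2}, {7, 3, 4}})

-- Without a dimension, the reduction runs over all elements
print(torch.max(x))   -- 7
print(torch.sum(x))   -- 22

-- With a dimension, you get values and 1-based indices, each of size 2x1
local maxs, argmax = torch.max(x, 2)

-- Column sums, of size 1x3 (x.sum(axis=0, keepdims=True) in NumPy)
local colsum = torch.sum(x, 1)

-- Clip values into [2, 5], like x.clip(2, 5)
local clipped = torch.clamp(x, 2, 5)
```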
Numpy | Torch |
---|---|
ndarray.__lt__ | torch.lt |
ndarray.__le__ | torch.le |
ndarray.__gt__ | torch.gt |
ndarray.__ge__ | torch.ge |
ndarray.__eq__ | torch.eq |
ndarray.__ne__ | torch.ne |
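The comparison functions return a `torch.ByteTensor` of zeros and ones, which serves as a mask much like a NumPy boolean array. A minimal sketch, using the explicit `maskedSelect` and `maskedFill` methods:

```lua
require 'torch'

local x = torch.Tensor({1, 4, 2, 5})

-- Elementwise x > 2, as a ByteTensor mask
local mask = torch.gt(x, 2)
print(mask:type())   -- torch.ByteTensor
print(mask:sum())    -- 2

-- Select the masked elements, x[x > 2] in NumPy
local big = x:maskedSelect(mask)

-- Overwrite the masked elements in a copy, y[y > 2] = 0 in NumPy
local y = x:clone():maskedFill(mask, 0)
```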