Adds PyTorch NN (except Pooling and Conv) (keras-team#111)
* Add NN for pytorch backend

* Add PyTorch nn functions

* Adds PyTorch Backend

* Adds PyTorch nn backend

* Add PyTorch backend functions

* Add Torch NN Backend
sampathweb authored May 19, 2023
1 parent 00330e3 commit 83f0797
Showing 2 changed files with 241 additions and 2 deletions.
240 changes: 240 additions & 0 deletions keras_core/backend/torch/nn.py
@@ -0,0 +1,240 @@
import torch
import torch.nn.functional as tnn

from keras_core.backend.config import epsilon
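# `epsilon()` returns the Keras fuzz factor, used below to clip probabilities
# away from 0 and 1 before taking logs.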


def relu(x):
    return tnn.relu(x)


def relu6(x):
    return tnn.relu6(x)


def sigmoid(x):
    return tnn.sigmoid(x)


def tanh(x):
    return tnn.tanh(x)


def softplus(x):
    return tnn.softplus(x)


def softsign(x):
    return tnn.softsign(x)


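# silu(x) = x * sigmoid(beta * x); `swish` below is the beta=1 special case.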
def silu(x, beta=1.0):
    return x * sigmoid(beta * x)


def swish(x):
    return silu(x, beta=1)


def log_sigmoid(x):
    return tnn.logsigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    return tnn.leaky_relu(x, negative_slope=negative_slope)


def hard_sigmoid(x):
    return tnn.hardsigmoid(x)


def elu(x):
    return tnn.elu(x)


def selu(x):
    return tnn.selu(x)


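# Torch's gelu takes a string `approximate` arg ("none" or "tanh"), so the
# boolean Keras-style flag is mapped onto it.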
def gelu(x, approximate=True):
    if approximate:
        return tnn.gelu(x, approximate="tanh")
    return tnn.gelu(x)


def softmax(x, axis=None):
    return tnn.softmax(x, dim=axis)


def log_softmax(x, axis=-1):
    return tnn.log_softmax(x, dim=axis)


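# Pooling and convolution ops are not covered by this commit (see the title)
# and raise NotImplementedError for now.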
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format="channels_last",
):
    raise NotImplementedError(
        "`max_pool` not yet implemented for PyTorch Backend"
    )


def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format="channels_last",
):
    raise NotImplementedError(
        "`average_pool` not yet implemented for PyTorch Backend"
    )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    raise NotImplementedError("`conv` not yet implemented for PyTorch Backend")


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    raise NotImplementedError(
        "`depthwise_conv` not yet implemented for PyTorch Backend"
    )


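# separable_conv is already expressed as a depthwise conv followed by a 1x1
# pointwise conv, so it will work as soon as the two ops above are implemented.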
def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format="channels_last",
    dilation_rate=1,
):
    raise NotImplementedError(
        "`conv_transpose` not yet implemented for PyTorch backend"
    )


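# tnn.one_hot always appends the class dimension last, hence the axis check.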
def one_hot(x, num_classes, axis=-1):
    if axis not in (-1, len(x.shape)):
        raise ValueError(
            "`one_hot` is only implemented for last axis for PyTorch backend. "
            f"`axis` arg value {axis} should be -1 or last axis of the input "
            f"tensor with shape {x.shape}."
        )
    return tnn.one_hot(x, num_classes)


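# For the cross-entropy ops below: with `from_logits=True` the output goes
# through log-softmax; otherwise it is renormalized to sum to 1 along `axis`
# and clipped to [epsilon, 1 - epsilon] to avoid log(0).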
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = torch.as_tensor(target)
    output = torch.as_tensor(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = tnn.log_softmax(output, dim=axis)
    else:
        output = output / torch.sum(output, dim=axis, keepdim=True)
        output = torch.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = torch.log(output)
    return -torch.sum(target * log_prob, dim=axis)


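# Same as categorical_crossentropy, but takes integer class targets and
# one-hot encodes them before the reduction.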
def sparse_categorical_crossentropy(
    target, output, from_logits=False, axis=-1
):
    target = torch.as_tensor(target, dtype=torch.long)
    output = torch.as_tensor(output)

    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
        target = torch.squeeze(target, dim=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if from_logits:
        log_prob = tnn.log_softmax(output, dim=axis)
    else:
        output = output / torch.sum(output, dim=axis, keepdim=True)
        output = torch.clip(output, epsilon(), 1.0 - epsilon())
        log_prob = torch.log(output)
    target = one_hot(target, output.shape[axis], axis=axis)
    return -torch.sum(target * log_prob, dim=axis)


def binary_crossentropy(target, output, from_logits=False):
    # TODO: `torch.as_tensor` has device arg. Need to think how to pass it.
    target = torch.as_tensor(target)
    output = torch.as_tensor(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        return tnn.binary_cross_entropy_with_logits(output, target)
    else:
        return tnn.binary_cross_entropy(output, target)
3 changes: 1 addition & 2 deletions requirements.txt
@@ -11,5 +11,4 @@ pytest
 pandas
 absl-py
 requests
-h5py
-torch
+h5py
