This notebook gives a brief introduction to DeepONets, introduced by Lu et al. in their paper found here. It also accompanies my blog post here.
Here we aim to create an operator network that can learn the following 1D ODE (antiderivative) problem given in the paper:

$$\frac{ds(x)}{dx} = u(x), \quad x \in [0, 1], \quad s(0) = 0$$

Where the target operator is $G(u)(y) = \int_0^y u(\tau)\, d\tau$.
We look at two different approaches: data-driven and physics-informed. The data-driven approach converges faster and is more straightforward, but it requires data that may not be available for more complicated operators. The physics-informed approach needs no data other than the governing equation and the initial condition.
Like regular deep learning tasks such as image classification or NLP, sampling too far out of the training distribution (i.e. a function $u(x)$ very different from those seen during training) leads to poor predictions.
We can also use physics-informed training to train our Onet. In this example, we only use the derivative information $\frac{ds}{dy} = u(y)$ and the initial condition $s(0) = 0$.
This process takes longer and doesn't make much sense for this simple example, but we can imagine problems where the sensors are sparse (e.g. only at the boundaries of domains), in which case PINN-style training is a good way to enforce physical laws at locations far from the sensors.
To take derivatives efficiently, we use the torch.func library, a JAX-like library for PyTorch. We only want the derivatives of the output with respect to the sampling points y, not u(x). We use vmap to iterate over the two batch dimensions and jacrev to calculate the derivatives.
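A minimal sketch of this pattern, using a toy branch/trunk pair (the layer sizes and the dot-product merge in `onet_single` are illustrative assumptions, not the notebook's exact model):

```python
import torch
import torch.nn as nn
from torch.func import jacrev, vmap

M, I, P = 100, 1, 40  # sensors, trunk input dim, latent size (illustrative)
branch = nn.Sequential(nn.Linear(M, P), nn.Tanh(), nn.Linear(P, P))
trunk = nn.Sequential(nn.Linear(I, P), nn.Tanh(), nn.Linear(P, P))

def onet_single(u, y):
    # u: [M] discretized input function, y: [I] single query point -> scalar
    return (branch(u) * trunk(y)).sum()

# jacrev differentiates the scalar output wrt y only (argnums=1), not u
dG_dy = jacrev(onet_single, argnums=1)

# vmap over the N query points (u is shared), then over the batch of functions B
batched = vmap(vmap(dG_dy, in_dims=(None, 0)), in_dims=(0, 0))

u = torch.randn(8, M)     # [B, M] input functions at the sensors
y = torch.rand(8, 50, I)  # [B, N, I] query points
dGdy = batched(u, y)      # [B, N, I] derivatives of the output wrt y
```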
Training takes quite a bit longer and is less stable, and we also need to weight the losses to improve convergence. But we require significantly less data/information about the operator.
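Continuing the sketch above (reusing `onet_single` and `batched`), a weighted physics-informed loss might look like this; the weight value and helper names are assumptions for illustration:

```python
# Batched forward pass: vmap over query points, then over the function batch
onet_batched = vmap(vmap(onet_single, in_dims=(None, 0)), in_dims=(0, 0))

def physics_loss(u, y, u_at_y, lambda_ic=10.0):  # loss weight is illustrative
    dGdy = batched(u, y)                         # [B, N, I]
    residual = ((dGdy - u_at_y) ** 2).mean()     # enforce ds/dy = u(y)
    y0 = torch.zeros(u.shape[0], 1, 1)           # query at y = 0
    ic = (onet_batched(u, y0) ** 2).mean()       # enforce s(0) = 0
    return residual + lambda_ic * ic
```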
The network
Here 'y' represents the points we want to query. Because the input $u(x)$ is a function, we discretize it at a fixed set of sensor locations so it can be fed to the network.
The DeepONet uses the architecture from Lu et al. (their 'unstacked' variant): two separate networks, called the branch and trunk nets, handle the input function and the sampling points respectively. They are merged at a latent representation through element-wise multiplication and then passed through a final linear layer.
For both the trunk and branch nets, we'll use standard MLPs with tanh activations.
Given how nn.Linear works, we fix the number of sensors and define the inputs to the Onet as:
- branch net input $u(x)$ has shape $[B, 1, M]$
- trunk net input has shape $[B, N, I]$
- output of the Onet has shape $[B, N, O]$
Where:
- B is the batch dimension for the branch net
- N is the 'batch' dimension for the trunk net
- M is the number of sensors that discretize u(x)/input dimension of the branch net
- I is the input dimension of the trunk net
- O is the output dimension size of the Onet
For this example: $M = 100$, $I = 1$, $O = 1$, with $B = 10000$ samples and $N = 100$ query points, as in the sketch below.
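As a concrete sketch of the network with these shapes (the layer widths are my assumptions, not necessarily the notebook's values):

```python
import torch
import torch.nn as nn

class DeepONet(nn.Module):
    def __init__(self, n_sensors=100, trunk_in=1, latent=40, out_dim=1):
        super().__init__()
        # branch net: encodes the discretized input function u(x)
        self.branch = nn.Sequential(nn.Linear(n_sensors, 40), nn.Tanh(),
                                    nn.Linear(40, latent))
        # trunk net: encodes the query points y
        self.trunk = nn.Sequential(nn.Linear(trunk_in, 40), nn.Tanh(),
                                   nn.Linear(40, latent))
        self.head = nn.Linear(latent, out_dim)  # final linear merge layer

    def forward(self, u, y):
        b = self.branch(u)        # [B, 1, latent]
        t = self.trunk(y)         # [B, N, latent]
        return self.head(b * t)   # broadcasted element-wise product -> [B, N, O]

model = DeepONet()
out = model(torch.randn(4, 1, 100), torch.rand(4, 100, 1))  # -> [4, 100, 1]
```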
We'll generate our derivatives $u(x)$ and their integrals $G(u)(y)$ numerically. We use a mix of numpy and scipy integration to get the derivative and the integral function at each point, and then discretize both functions: the derivative at the $M$ fixed sensor locations, and the integral at the $N$ query points.
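A hedged sketch of this step; the notebook's exact sampling scheme for $u(x)$ isn't reproduced here, so a random sine series stands in for whatever it actually draws:

```python
import numpy as np
from scipy.integrate import cumulative_trapezoid

B, M = 10000, 100
x = np.linspace(0, 1, M)                     # fixed sensor/query grid on [0, 1]

# assumption: u(x) as a random 5-term sine series (stand-in for the real sampler)
k = np.arange(1, 6)
coeffs = np.random.randn(B, 5)
u = coeffs @ np.sin(np.outer(k, np.pi * x))  # [B, M] derivative samples u(x)

# antiderivative via the cumulative trapezoid rule, with s(0) = 0
s = cumulative_trapezoid(u, x, axis=1, initial=0.0)  # [B, M]

y = np.broadcast_to(x[None, :, None], (B, M, 1)).copy()  # [10000, 100, 1]
Guy = s[:, :, None]                          # [10000, 100, 1] targets G(u)(y)
u_branch = u[:, None, :]                     # [10000, 1, 100] branch input
```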
The tuple of data has the shape (y, u, Guy), where:
- y is the sampling points, of shape $[10000, 100, 1]$
- u is the discretized derivative function, of shape $[10000, 1, 100]$
- $G(u)(y)$ is the target output of the Onet, of shape $[10000, 100, 1]$
We can then create a very straightforward training pipeline for our data-driven Onet. This is identical to other training pipelines such as image classification, and we can simply use a standard dataloader:
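A minimal sketch of that pipeline, assuming the (y, u_branch, Guy) arrays from the data-generation step and the DeepONet `model` sketched earlier:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.tensor(y, dtype=torch.float32),
                        torch.tensor(u_branch, dtype=torch.float32),
                        torch.tensor(Guy, dtype=torch.float32))
loader = DataLoader(dataset, batch_size=128, shuffle=True)

optim = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

for epoch in range(20):                       # epoch count is illustrative
    for y_b, u_b, G_b in loader:
        optim.zero_grad()
        loss = loss_fn(model(u_b, y_b), G_b)  # MSE against the true integral
        loss.backward()
        optim.step()
```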

