-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Support arbitrary outputs in TorchMD_Net #239
base: main
Are you sure you want to change the base?
Conversation
Why are we changing these things?
We agreed on them a while back.
g
…On Fri, Nov 3, 2023 at 1:45 PM Raul ***@***.***> wrote:
Following the discussion in #198
<#198> this PR attempts to
give TorchMD_Net the ability to return more than one output ("y") and its
derivative ("neg_dy").
This PR is still a draft as I am trying to figure out the final design.
This PR introduces user-facing breaking changes:
- It changes some names in the configuration file (for instance Scalar
is no longer a thing). Although a conversion could be made when processing
the configuration.
- The Datasets must provide "energy", "force" instead of "y", "neg_dy".
- TorchMD_Net is expected to compute always at least energy, instead
of a generic label called "y". Maybe I am missing some usecases here, so we
will see...
New design proposed for the outputs of the model:
- TorchMD_Net is composed of a representation model + an arbitrary
number of heads stacked sequentially.
- There is no distinction between a Prior and what used to be an
OutputModel, they are all Heads now.
- The EnergyHead is always the first one and the ForceHead the last
(if derivative=True)
- There is some level of customization akin to the Heads for computing
the loss of each output and reducing the total loss.
- The user provides a list of weights (like y_weight, neg_dy_weight
now) for each model output that should be considered for the loss
computation.
This is the BaseHead interface I propose:
class BaseHead(nn.Module):
def __init__(self, dtype=torch.float32):
super(BaseHead, self).__init__()
self.dtype = dtype
def reset_parameters(self):
pass
def per_point(self, point_features, results, z, pos, batch, extra_args):
return point_features, results
def per_sample(self, point_features, results, z, pos, batch, extra_args):
return point_features, results
Where the forward call of TorchMD_Net would go like this:
results = {}
point_features = self.representation_model(z, pos, batch, q=q, s=s)
for head in self.head_list:
point_features, results = head.per_point(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
for head in self.head_list:
point_features, results = head.per_sample(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
Each head is free to add a new key to result, modify the point_features or
the contents of result (i.e add to the energy). For instance, the
EnergyHead:
class EnergyHead(BaseHead):
def __init__(self,
hidden_channels,
activation="silu",
dtype=torch.float32):
super(EnergyHead, self).__init__(dtype=dtype)
act_class = act_class_mapping[activation]
self.output_network = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype),
act_class(),
nn.Linear(hidden_channels // 2, 1, dtype=dtype),
)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.output_network[0].weight)
self.output_network[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.output_network[2].weight)
self.output_network[2].bias.data.fill_(0)
def per_point(self, point_features, results, z, pos, batch, extra_args):
results["energy"] = self.output_network(point_features)
return point_features, results
def per_sample(self, point_features, results, z, pos, batch, extra_args):
results["energy"] = scatter(results["energy"], batch, dim=0)
return point_features, results
There are some challenges I have still to deal with:
- Not sure how happy TorchScript is going to be with this.
- Not sure ho the user should specify a list of predefined heads.
Perhaps something like an option
head_list: energy_head, coulomb_prior, some_other_prior, charge_head,
some_charge_prior, force_head
Tasks:
- Adapt TorchMD_Net
-
- Make Equivariant versions of the heads for ET.
- Adapt LNNP
- Adapt Datasets
- Make priors into heads
- Generalize the loss computation
- Handle user input
- Update tests
------------------------------
You can view, comment on, or merge this pull request online at:
#239
Commit Summary
- 13d5f52
<13d5f52>
First draft, work on TorchMD_Net
- e55d365
<e55d365>
Add charge head
- 1da9d61
<1da9d61>
Typo
- 5c17a00
<5c17a00>
Draft module
- a59f593
<a59f593>
Remove reduce
File Changes
(2 files <https://github.com/torchmd/torchmd-net/pull/239/files>)
- *M* torchmdnet/models/model.py
<https://github.com/torchmd/torchmd-net/pull/239/files#diff-c571c5ec1169ec77e8aae36aecef037f629916cba8bc342cac85edd42d801f8e>
(183)
- *M* torchmdnet/module.py
<https://github.com/torchmd/torchmd-net/pull/239/files#diff-fd3255a64c42e363ecb102409e22722c4ffe118f22076d0ebe54eaaa4ffa355c>
(101)
Patch Links:
- https://github.com/torchmd/torchmd-net/pull/239.patch
- https://github.com/torchmd/torchmd-net/pull/239.diff
—
Reply to this email directly, view it on GitHub
<#239>, or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AB3KUOXKN6LGQRLERLQG3DTYCTRNBAVCNFSM6AAAAAA64MBGAKVHI2DSMVQWIX3LMV43ASLTON2WKOZRHE3TMMJTGI4TMNY>
.
You are receiving this because you are subscribed to this thread.Message
ID: ***@***.***>
|
I was thinking of something a bit more generic than this. You can define an arbitrary set of output heads and loss terms. I imagine the description in the config file looking something like this. output_heads:
- scalar:
name: energy
- coulomb # the Coulomb head is hardcoded to output a scalar "energy" and a vector "charges"
losses:
- l2
output: energy # since multiple heads have "energy" outputs, they get summed before computing the loss
dataset_field: y
weight: 1.0
- gradient_l2
output: energy
dataset_field: neg_dy
weight: 0.1
- l2
output: charges
dataset_field: mbis_charges
weight: 0.1 The configuration for a totally different sort of model might look like this. output_heads:
- scalar
name: solubility
losses:
- l2
output: solubility
dataset_field: solubility
# if weight is omitted, it defaults to 1 |
Is it ok if I try implementing the design described above? |
Hi Peter, I am working on it but I have not had much time, sorry about that. |
We have already an implementation of what I think it's what you need, so
maybe wait that Raul finds out what is that we are already doing.
G
…On Wed, Nov 15, 2023, 07:46 Raul ***@***.***> wrote:
Hi Peter, I am working on it but I have not had much time, sorry about
that.
It is fine if you want to give it a try, feel free to open a new PR
if/when you have something and we can iterate. Would love to see your take.
I like your design very much, btw. Perhaps with the exception that I would
rather the gradient be a property of the heads instead of the losses.
Thinking about how an inference configuration should work, when reading it
I would not immediately look at the loss section.
—
Reply to this email directly, view it on GitHub
<#239 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AB3KUOT6NYV45ED7YZZLDGDYERQMLAVCNFSM6AAAAAA64MBGAKVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMYTQMJRHA4TGMRYGA>
.
You are receiving this because you commented.Message ID:
***@***.***>
|
Following the discussion in #198 this PR attempts to give TorchMD_Net the ability to return more than one output ("y") and its derivative ("neg_dy").
This PR is still a draft as I am trying to figure out the final design.
This PR introduces user-facing breaking changes:
New design proposed for the outputs of the model:
This is the BaseHead interface I propose:
Where the forward call of TorchMD_Net would go like this:
Each head is free to add a new key to result, modify the point_features or the contents of result (i.e add to the energy). For instance, the EnergyHead:
There are some challenges I have still to deal with:
head_list: energy_head, coulomb_prior, some_other_prior, charge_head, some_charge_prior, force_head
Tasks: