Skip to content

Commit

Permalink
refactored CGR standardizer.
Browse files Browse the repository at this point in the history
GNNFingerprint transformer added.
  • Loading branch information
stsouko committed Jul 6, 2020
1 parent f5381b1 commit 39f40ef
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 68 deletions.
5 changes: 5 additions & 0 deletions CIMtools/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
from .equation import *
from .fingerprint import *
from .fragmentor import *
from .graph_encoder import *
from .graph_to_matrix import *
from .solvent import *
from .standardize import *
from .standardize import __all__ as _standardize


# Public names of the preprocessing package: transformers that are always importable come first,
# followed by everything the standardize subpackage exports.
__all__ = ['Conditions', 'DictToConditions', 'ConditionsToDataFrame', 'SolventVectorizer', 'EquationTransformer',
           'CGR', 'MoleculesToMatrix', 'CGRToMatrix']
__all__ += _standardize

# Optional transformers are exported only when the star-imports above actually provided them.
if 'Fragmentor' in locals():
    __all__ += ['Fragmentor', 'FragmentorFingerprint']
if 'GNNFingerprint' in locals():
    __all__ += ['GNNFingerprint']
70 changes: 70 additions & 0 deletions CIMtools/preprocessing/graph_encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
#
# Copyright 2020 Ramil Nugmanov <nougmanoff@protonmail.com>
# Copyright 2020 Daniyar Mazitov <daniyarttt@gmail.com>
# This file is part of CIMtools.
#
# CIMtools is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <https://www.gnu.org/licenses/>.
#
from os.path import dirname, join
from sys import modules
from ..graph_to_matrix import MoleculesToMatrix
from ...base import CIMtoolsTransformerMixin


class GNNFingerprint(CIMtoolsTransformerMixin):
    """
    Transformer that encodes molecules into 50-dimensional, L2-normalized vectors
    using a graph neural network with weights loaded from the bundled ``weights.h5``.

    The Keras encoder is built lazily on the first instantiation and is shared by all
    instances (and subclasses) through a class-level cache.
    """
    # Shared Keras encoder. Filled by __new__ on the first instantiation.
    __encoder = None

    def __new__(cls, *args, **kwargs):
        """
        Create the instance, building the shared encoder first if it does not exist yet.

        TensorFlow is imported here rather than at module level, so importing this module
        does not itself import TensorFlow.
        """
        # The cache is checked and stored on GNNFingerprint itself, not on ``cls``.
        # Assigning through ``cls`` would put the encoder on a subclass when a subclass is
        # instantiated first, leaving this class's cache empty and forcing the model and
        # its weights to be rebuilt for the base class and for every other subclass.
        if GNNFingerprint.__encoder is None:  # load only once
            import tensorflow as tf
            import tensorflow.keras.backend as K
            from tensorflow.keras.layers import Input, Dense
            from tensorflow.keras.models import Model
            from .gnn import GNN

            # Enable memory growth on every visible GPU before any model is created.
            gpus = tf.config.experimental.list_physical_devices('GPU')
            if gpus:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)

            atoms = Input(shape=(None, 3))
            connections_m = Input(shape=(None, None))

            m = GNN(nodes_num=119, connections_num=5, selector_size=25, top_k=4, depth=2)([atoms, connections_m])
            m = Dense(50, activation=lambda x: K.l2_normalize(x, axis=-1), kernel_initializer='truncated_normal')(m)

            encoder = Model(inputs=[atoms, connections_m], outputs=m)
            # weights.h5 is located next to this package's __init__ file.
            path = join(dirname(modules[__package__].__file__), 'weights.h5')
            encoder.load_weights(path)
            GNNFingerprint.__encoder = encoder

        return super().__new__(cls)

    def __init__(self):
        """
        Molecules encoder
        """
        self.__m2m = MoleculesToMatrix(is_radical=True)

    def __getstate__(self):
        """
        Return the picklable state: only the molecule-to-matrix converter.

        The Keras encoder lives on the class, not on the instance, so it is not part of the state.
        """
        return {'_GNNFingerprint__m2m': self.__m2m}

    def transform(self, x):
        """
        Encode molecules.

        :param x: molecules accepted by MoleculesToMatrix.transform.
        :return: result of the encoder converted with ``.numpy()``.
        """
        # ``.data`` is passed straight to the two-input Keras model;
        # presumably it holds the atoms and connections arrays — TODO confirm against MoleculesToMatrix.
        x = self.__m2m.transform(x).data
        x = self.__encoder(x).numpy()
        return x


__all__ = ['GNNFingerprint']
206 changes: 206 additions & 0 deletions CIMtools/preprocessing/graph_encoder/gnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-
#
# Copyright 2020 Daniyar Mazitov <daniyarttt@gmail.com>
# This file is part of CIMtools.
#
# CIMtools is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <https://www.gnu.org/licenses/>.
#
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Input, Lambda, Layer, Dense, Concatenate, BatchNormalization, Conv1D,
TimeDistributed, Reshape, Embedding)
from tensorflow.keras.models import Model


def FTSwish(threshold=-0.2):
    """
    Build a Lambda layer computing ``relu(x) * sigmoid(x) + threshold`` element-wise.

    :param threshold: constant added to the gated value; it is the output wherever relu(x) is zero.
    :return: Keras Lambda layer.
    """
    def _FTSwish(x):
        gated = K.relu(x) * K.sigmoid(x)
        return gated + threshold

    layer = Lambda(_FTSwish)
    return layer


def p_relu():
    """
    Build a Lambda layer computing ``relu(x) + 0.001``, so every output is strictly positive.

    :return: Keras Lambda layer.
    """
    def _p_relu(x):
        rectified = K.relu(x)
        return rectified + 0.001

    layer = Lambda(_p_relu)
    return layer


def get_matrix_all():
    """
    Build a Lambda layer that multiplies an adjacency matrix with node features by broadcasting.

    The layer takes ``[adj_m, atoms]``. As used in GraphConv these are rank-3 tensors
    (batch, nodes, nodes) and (batch, nodes, features); the product is rearranged
    with the axis pattern (0, 3, 1, 2).

    :return: Keras Lambda layer.
    """
    def _perm_dims(x):
        adjacency, node_features = x
        # Insert singleton axes so the two tensors broadcast against each other.
        adjacency = K.expand_dims(adjacency, axis=-2)
        node_features = K.expand_dims(node_features, axis=-1)
        product = adjacency * node_features
        return K.permute_dimensions(product, pattern=(0, 3, 1, 2))

    layer = Lambda(_perm_dims)
    return layer


def mask_by_adj():
    """
    Build a Lambda layer that multiplies a tensor by an adjacency matrix broadcast over the last axis.

    The layer takes ``[adj_m, x]`` and returns ``x * expand_dims(adj_m, -1)``, so entries
    with a zero adjacency value become zero.

    :return: Keras Lambda layer.
    """
    def _mask_by_adj(x):
        adj_m, values = x
        mask = K.expand_dims(adj_m, axis=-1)
        return values * mask

    layer = Lambda(_mask_by_adj)
    return layer


def get_top_k(k):
    """
    Build a Lambda layer keeping the k largest values along the second-to-last axis of a rank-4 tensor.

    The last two axes are swapped, ``tf.nn.top_k`` is applied to the (new) last axis,
    and the axes are swapped back.

    :param k: number of largest values to keep.
    :return: Keras Lambda layer.
    """
    swap_last_two = (0, 1, 3, 2)

    def _get_top_k(x):
        swapped = K.permute_dimensions(x, pattern=swap_last_two)
        largest = tf.nn.top_k(swapped, k=k).values
        return K.permute_dimensions(largest, pattern=swap_last_two)

    layer = Lambda(_get_top_k)
    return layer


def adj_m_pad_k(k):
    """
    Build a Lambda layer that pads a rank-3 adjacency tensor up to k along its last two axes.

    Padding is appended at the end of both axes and applied only when the
    second-to-last axis is shorter than k; otherwise the tensor is returned unchanged.

    :param k: minimal size of the two trailing axes.
    :return: Keras Lambda layer.
    """
    def _adj_m_pad_k(adj_m):
        size = K.shape(adj_m)[-2]

        def _padded():
            return tf.pad(adj_m, [[0, 0], [0, k - size], [0, k - size]])

        def _unchanged():
            return adj_m

        return tf.cond(size < k, _padded, _unchanged)

    layer = Lambda(_adj_m_pad_k)
    return layer


def atoms_pad_k(k):
    """
    Build a Lambda layer that pads a rank-3 node-feature tensor up to k along its second-to-last axis.

    Padding is appended at the end of that axis and applied only when it is shorter than k;
    the first and last axes are never padded.

    :param k: minimal size of the second-to-last axis.
    :return: Keras Lambda layer.
    """
    def _atoms_pad_k(atoms):
        size = K.shape(atoms)[-2]

        def _padded():
            return tf.pad(atoms, [[0, 0], [0, k - size], [0, 0]])

        def _unchanged():
            return atoms

        return tf.cond(size < k, _padded, _unchanged)

    layer = Lambda(_atoms_pad_k)
    return layer


def connections_m_pad_k(k):
    """
    Build a Lambda layer that pads a rank-4 connection tensor up to k along its two middle axes.

    Note that the size is read from the second-to-last axis and compared with k;
    when it is smaller, axes 1 and 2 are padded at their end, the first and last axes never are.

    :param k: minimal size of the padded axes.
    :return: Keras Lambda layer.
    """
    def _connections_m_pad_k(connections_m):
        size = K.shape(connections_m)[-2]

        def _padded():
            return tf.pad(connections_m, [[0, 0], [0, k - size], [0, k - size], [0, 0]])

        def _unchanged():
            return connections_m

        return tf.cond(size < k, _padded, _unchanged)

    layer = Lambda(_connections_m_pad_k)
    return layer


def mask_pad_by_adj():
    """
    Build a Lambda layer that zeroes positions whose adjacency total is zero.

    The layer takes ``[adj_m_pad, x]``. The adjacency tensor is summed over its last axis and
    then over axis 0; ``divide_no_nan(n, n)`` turns the totals into a 0/1 indicator, which is
    given leading and trailing singleton axes and multiplied with ``x``.

    :return: Keras Lambda layer.
    """
    def _mask_pad_by_adj(x):
        adj_m_pad, values = x
        per_row = K.sum(adj_m_pad, axis=-1)
        totals = K.sum(per_row, axis=0)
        # n / n is 1 for non-zero totals and 0 (instead of NaN) for zero totals.
        indicator = tf.math.divide_no_nan(totals, totals)
        indicator = K.expand_dims(indicator, axis=0)
        indicator = K.expand_dims(indicator, axis=-1)
        return values * indicator

    layer = Lambda(_mask_pad_by_adj)
    return layer


class Atom_Emb(Layer):
    """
    Layer that replaces the first channel of a 3-channel node description with a learned embedding.

    The input is split along its last axis into three equal parts, taken in the order
    atomic number, charge, radical flag. The output is the concatenation
    ``[charge, embedding(atomic number), is_radical]``.
    """

    def __init__(self, nodes_num, emb_size):
        """
        :param nodes_num: input dimension (vocabulary size) of the embedding.
        :param emb_size: output dimension of the embedding.
        """
        super().__init__()
        self.emb = Embedding(nodes_num, emb_size, mask_zero=True)

    def call(self, inputs):
        """
        :param inputs: tensor whose last axis is split into three equal parts.
        :return: concatenation of charge, atomic-number embedding and radical flag along the last axis.
        """
        atomic_number, charge, is_radical = tf.split(inputs, 3, axis=-1)
        embedded = self.emb(atomic_number)
        # The split keeps a singleton axis which the embedding lookup pushes to position -2; drop it.
        atomic_emb = K.squeeze(embedded, axis=-2)
        parts = [charge, atomic_emb, is_radical]
        return Concatenate(axis=-1)(parts)


def RMS():
    """
    Build a Lambda layer pooling over axis 1 with a root-mean-square over non-zero rows.

    A row counts towards the mean only when the sum of its last-axis values is non-zero;
    the squared values are summed over axis 1 and divided by that count before the square root.

    :return: Keras Lambda layer.
    """
    def _RMS(x):
        row_sums = K.sum(x, axis=-1)
        # 1 for rows with a non-zero sum, 0 otherwise.
        counted = tf.math.divide_no_nan(row_sums, row_sums)
        size = K.sum(counted, axis=-1)
        squared_total = K.sum(K.square(x), axis=1)
        # NOTE(review): a sample whose rows all sum to zero gives size == 0 and a division by zero — confirm
        # whether such input can occur.
        return K.sqrt(squared_total / K.expand_dims(size, axis=-1))

    layer = Lambda(_RMS)
    return layer


def GraphConv(n_atoms, n_connections, k=4, selector_size=20):
    """
    Build a Keras model implementing one graph-convolution step.

    The model takes ``[adj_m, atoms, connections_m]`` and returns a tensor reshaped to
    ``(-1, 100)`` per sample, masked by the (padded) adjacency matrix.

    :param n_atoms: size of the last axis of the atoms input.
    :param n_connections: size of the last axis of the connections input.
    :param k: number of largest selector responses kept per node; inputs are also zero-padded to at least k nodes.
    :param selector_size: number of units of the Dense selector shared by neighbour and self features.
    :return: Keras Model.
    """
    def _conv_block(tensor, filters, kernel_size):
        # TimeDistributed Conv1D -> BatchNormalization -> FTSwish, created in exactly this order.
        conv = TimeDistributed(Conv1D(filters, kernel_size=kernel_size, kernel_initializer='he_normal'))
        tensor = conv(tensor)
        tensor = BatchNormalization()(tensor)
        return FTSwish()(tensor)

    adj_m = Input(shape=(None, None))
    atoms = Input(shape=(None, n_atoms))
    connections_m = Input(shape=(None, None, n_connections))

    # Zero-pad all inputs so that at least k nodes are present.
    adj_m_pad = adj_m_pad_k(k=k)(adj_m)
    atoms_pad = atoms_pad_k(k=k)(atoms)
    connections_m_pad = connections_m_pad_k(k=k)(connections_m)

    # Neighbour features joined with the connection features of each node pair.
    pairs = get_matrix_all()([adj_m_pad, atoms_pad])
    pairs = Concatenate(axis=-1)([pairs, connections_m_pad])

    # The same Dense layer scores both the neighbour features and the node's own features.
    selector = Dense(selector_size, kernel_initializer='he_normal')

    activation = p_relu()
    normalization = BatchNormalization()
    selected = activation(normalization(selector(pairs)))
    selected = mask_by_adj()([adj_m_pad, selected])

    selected = get_top_k(k=k)(selected)

    # Own features are zero-padded on the last axis by the connection-feature width so they fit the selector.
    ext_atoms = Lambda(lambda t: tf.pad(t, [[0, 0], [0, 0], [0, K.int_shape(connections_m_pad)[-1]]]))(atoms_pad)
    ext_atoms = Lambda(lambda t: K.expand_dims(t, axis=-2))(ext_atoms)
    activation = p_relu()
    normalization = BatchNormalization()
    ext_atoms = activation(normalization(selector(ext_atoms)))

    x = Concatenate(axis=-2)([ext_atoms, selected])  # Nx(k+1)x(selector_size)

    # NOTE(review): with Conv1D's default 'valid' padding the (k+1)-long axis appears to shrink to 1 for
    # even k and for k == 3, but to fall below the final kernel size for other odd k — confirm supported k.
    if k % 2:
        x = _conv_block(x, 200, 2)

    if k >= 4:
        x = _conv_block(x, 200, k - 1)

    x = _conv_block(x, 100, 3)

    x = Reshape([-1, 100])(x)  # Nx100

    x = mask_pad_by_adj()([adj_m_pad, x])

    return Model(inputs=[adj_m, atoms, connections_m], outputs=x)


def GNN(nodes_num, connections_num, selector_size, top_k, depth):
    """
    Build the graph encoder: embeddings, ``depth`` stacked GraphConv steps and RMS pooling.

    The model takes ``[nodes, connections_matrix]`` where nodes carry 3 channels per node,
    and returns one 300-channel vector per sample.

    :param nodes_num: input dimension of the atom embedding.
    :param connections_num: input dimension of the connection embedding.
    :param selector_size: selector size passed to every GraphConv step.
    :param top_k: ``k`` passed to every GraphConv step.
    :param depth: number of GraphConv steps.
    :return: Keras Model.
    """
    nodes = Input(shape=(None, 3))
    connections_matrix = Input(shape=(None, None))

    # x / x gives 1 for every non-zero connection and 0 elsewhere.
    adj_m = Lambda(lambda t: tf.math.divide_no_nan(t, t))(connections_matrix)

    atoms_emb = Atom_Emb(nodes_num, 20)(nodes)

    connections_m_emb = Embedding(connections_num, 10, mask_zero=True)(connections_matrix)

    vectors = []
    for level in range(depth):
        # Every earlier step contributed 100 channels; 22 is the width of the embedded atoms.
        step = GraphConv(n_atoms=level * 100 + 22, n_connections=10, k=top_k, selector_size=selector_size)
        if level == 0:
            vectors = step([adj_m, atoms_emb, connections_m_emb])
        else:
            # Later steps see the atom embedding together with the outputs of all previous steps.
            step_input = Concatenate(axis=-1)([atoms_emb, vectors])
            step_output = step([adj_m, step_input, connections_m_emb])
            vectors = Concatenate(axis=-1)([vectors, step_output])

    concat_vectors = Dense(300, kernel_initializer='he_normal')(vectors)
    concat_vectors = BatchNormalization()(concat_vectors)
    concat_vectors = FTSwish()(concat_vectors)

    mols = RMS()(concat_vectors)

    return Model(inputs=[nodes, connections_matrix], outputs=mols)


__all__ = ['GNN']
Binary file added CIMtools/preprocessing/graph_encoder/weights.h5
Binary file not shown.
2 changes: 1 addition & 1 deletion CIMtools/preprocessing/standardize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .rdtool import *


__all__ = ['StandardizeCGR', 'StandardizeReaction']
__all__ = ['StandardizeCGR']
if 'RDTool' in locals():
__all__.append('RDTool')
if 'StandardizeChemAxon' in locals():
Expand Down
71 changes: 5 additions & 66 deletions CIMtools/preprocessing/standardize/cgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,74 +16,15 @@
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <https://www.gnu.org/licenses/>.
#
from CGRtools.reactor import CGRReactor
from CGRtools.containers import MoleculeContainer, CGRContainer, ReactionContainer
from CGRtools.containers import MoleculeContainer, ReactionContainer
from pandas import DataFrame
from ...base import CIMtoolsTransformerMixin
from ...exceptions import ConfigurationError


class StandardizeCGR(CIMtoolsTransformerMixin):
def __init__(self, templates=(), delete_atoms=False):
"""
Molecule and CGR standardization
For molecules kekule/thiele and groups standardization procedures will be applied.
:param templates: CGRTemplates. list of rules for graph modifications.
:param delete_atoms: if True atoms exists in templates reactants but not exists in products will be removed
"""
self.templates = templates
self.delete_atoms = delete_atoms
self.__init()

def __init(self):
try:
self.__fixes = [CGRReactor(x, delete_atoms=self.delete_atoms) for x in self.templates]
except Exception as e:
raise ConfigurationError from e

def __getstate__(self):
return {k: v for k, v in super().__getstate__().items() if not k.startswith('_StandardizeCGR__')}

def __setstate__(self, state):
super().__setstate__(state)
self.__init()

def set_params(self, **params):
if params:
super().set_params(**params)
self.__init()
return self

def transform(self, x):
return DataFrame([[self.__prepare(g)] for g in super().transform(x)], columns=['standardized'])

def __prepare(self, g):
if isinstance(g, MoleculeContainer):
g = g.copy()
g.standardize()
g.kekule()
g.thiele()

for fix in self.__fixes:
while True:
try:
p = next(fix(g, False))
except StopIteration:
break
else:
p.meta.update(g.meta)
g = p
return g

_dtype = (MoleculeContainer, CGRContainer)


class StandardizeReaction(CIMtoolsTransformerMixin):
def __init__(self):
"""
Reactions standardization
Reactions and Molecules standardization
For molecules kekule/thiele and groups standardization procedures will be applied.
"""
Expand All @@ -94,12 +35,10 @@ def transform(self, x):
@staticmethod
def __prepare(r):
r = r.copy()
r.standardize()
r.kekule()
r.thiele()
r.canonicalize()
return r

_dtype = ReactionContainer
_dtype = (MoleculeContainer, ReactionContainer)


__all__ = ['StandardizeCGR', 'StandardizeReaction']
__all__ = ['StandardizeCGR']
Loading

0 comments on commit 39f40ef

Please sign in to comment.