
Commit 794f9fa

s-maddrellmander, jayniep-gc, and anjleeg-gcai authored

Graphium (#32)

Co-authored-by: JaynieP <92803120+jayniep-gc@users.noreply.github.com>
Co-authored-by: anjleeg-gcai <anjleeg@graphcore.ai>
Co-authored-by: Jaynie Padayachee <jayniep@graphcore.ai>
1 parent 4c48c64 commit 794f9fa

11 files changed (+2102, -0 lines)

.gradient/notebook-tests.yaml

Lines changed: 10 additions & 0 deletions
@@ -94,3 +94,13 @@ molfeat:
   notebook:
     file: pytorch_geometric_molfeat.ipynb
     timeout: 1200
+    requirements_file: requirements.txt
+
+# Graphium
+graphium:
+  location: ../graphium
+  generated: true
+  notebook:
+    file: running-multitask-ipu.ipynb
+    timeout: 1200
+    requirements_file: requirements.txt
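As a quick local check of the new entry, the file can be loaded with PyYAML and the keys compared against the snippet above. This is a minimal sketch assuming only PyYAML; the actual CI runner that consumes this file is separate tooling and is not part of this commit.

# Sanity-check the new "graphium" entry in notebook-tests.yaml.
# Assumes only PyYAML; field names are taken from the diff above.
import yaml

with open(".gradient/notebook-tests.yaml") as f:
    tests = yaml.safe_load(f)

entry = tests["graphium"]
assert entry["location"] == "../graphium"
assert entry["generated"] is True
assert entry["notebook"]["file"] == "running-multitask-ipu.ipynb"
assert entry["notebook"]["timeout"] == 1200
assert entry["notebook"]["requirements_file"] == "requirements.txt"
print("graphium notebook-test entry looks well-formed")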

graphium/ToyMix.png

334 KB

graphium/UMAP.png

451 KB

graphium/config_small_gcn.yaml

Lines changed: 343 additions & 0 deletions
@@ -0,0 +1,343 @@
# Testing the gcn model with the ToyMix dataset on IPU.
constants:
  name: &name neurips2023_small_data_gcn
  seed: &seed 42
  raise_train_error: true # whether the code should raise an error if it crashes during training

accelerator:
  type: ipu # cpu, ipu, or gpu
  config_override:
    datamodule:
      args:
        ipu_dataloader_training_opts:
          mode: async
          max_num_nodes_per_graph: 44 # train max nodes: 20, max edges: 54
          max_num_edges_per_graph: 80
        ipu_dataloader_inference_opts:
          mode: async
          max_num_nodes_per_graph: 44 # valid max nodes: 51, max edges: 118
          max_num_edges_per_graph: 80
        # Data handling-related
        batch_size_training: 50
        batch_size_inference: 50
    predictor:
      optim_kwargs:
        loss_scaling: 1024
    trainer:
      trainer:
        precision: 16
        accumulate_grad_batches: 4

  ipu_config:
    - deviceIterations(5) # the IPU requires large batches to be ready for the model
    - replicationFactor(16)
    # - enableProfiling("graph_analyser") # the folder where the profile will be stored
    # - enableExecutableCaching("pop_compiler_cache")
    - TensorLocations.numIOTiles(128)
    - _Popart.set("defaultBufferingDepth", 128)
    - Precision.enableStochasticRounding(True)
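  # Effective batch size, assuming standard PopTorch semantics (an illustrative
  # note, not part of the original config): each optimizer step accumulates
  # batch_size_training (50) x replicationFactor (16) x accumulate_grad_batches (4)
  # = 3200 graphs, and deviceIterations(5) makes each host-side step stream
  # 5 x 3200 = 16000 graphs to the device.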
# accelerator:
#   type: cpu # cpu, ipu, or gpu
#   config_override:
#     datamodule:
#       batch_size_training: 64
#       batch_size_inference: 256
#     trainer:
#       trainer:
#         precision: 32
#         accumulate_grad_batches: 1

datamodule:
  module_type: "MultitaskFromSmilesDataModule"
  # module_type: "FakeDataModule" # option to use generated data
  args: # matches that in the test_multitask_datamodule.py case
    task_specific_args: # to be replaced by a new class "DatasetParams"
      qm9:
        df: null
        df_path: data/neurips2023/small-dataset/qm9.csv.gz
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz
        # or set the path as the URL directly
        smiles_col: "smiles"
        label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"]
        # sample_size: 2000 # use sample_size for testing
        splits_path: data/neurips2023/small-dataset/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt`
        seed: *seed
        task_level: graph
        label_normalization:
          normalize_val_test: True
          method: "normal"

      tox21:
        df: null
        df_path: data/neurips2023/small-dataset/Tox21-7k-12-labels.csv.gz
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
        # or set the path as the URL directly
        smiles_col: "smiles"
        label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"]
        # sample_size: 2000 # use sample_size for testing
        splits_path: data/neurips2023/small-dataset/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt`
        seed: *seed
        task_level: graph

      zinc:
        df: null
        df_path: data/neurips2023/small-dataset/ZINC12k.csv.gz
        # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
        # or set the path as the URL directly
        smiles_col: "smiles"
        label_cols: ["SA", "logp", "score"]
        # sample_size: 2000 # use sample_size for testing
        splits_path: data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt`
        seed: *seed
        task_level: graph
        label_normalization:
          normalize_val_test: True
          method: "normal"

    # Featurization
    prepare_dict_or_graph: pyg:graph
    featurization_n_jobs: 30
    featurization_progress: True
    featurization_backend: "loky"
    processed_graph_data_path: "../datacache/neurips2023-small/"
    featurization:
      # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
      #       'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
      #       'num_chiral_centers (not included yet)']
      atom_property_list_onehot: [atomic-number, group, period, total-valence]
      atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
      # OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
      edge_property_list: [bond-type-onehot, stereo, in-ring]
      add_self_loop: False
      explicit_H: False # whether explicit hydrogens are included
      use_bonds_weights: False
      pos_encoding_as_features: # encoder dropout 0.18
        pos_types:
          lap_eigvec:
            pos_level: node
            pos_type: laplacian_eigvec
            num_pos: 8
            normalization: "none" # normalization is already applied to the eigenvectors
            disconnected_comp: True # whether to include eigenvalues/vectors for disconnected graphs
          lap_eigval:
            pos_level: node
            pos_type: laplacian_eigval
            num_pos: 8
            normalization: "none" # normalization is already applied to the eigenvectors
            disconnected_comp: True # whether to include eigenvalues/vectors for disconnected graphs
          rw_pos: # use the same name as the pe_encoder
            pos_level: node
            pos_type: rw_return_probs
            ksteps: 16
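      # Resulting raw positional features per node (an interpretation added here,
      # not stated in the original file): 8 Laplacian eigenvector components,
      # 8 matching eigenvalues, and 16 random-walk return probabilities, which
      # the pe_encoders in the architecture section compress into a 32-dim "feat".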
    # cache_data_path: .
    num_workers: 30 # -1 to use all
    persistent_workers: False # whether dataloader workers stay alive between epochs;
                              # setting this to False can make the start of each epoch very slow

architecture:
  model_type: FullGraphMultiTaskNetwork
  mup_base_path: null
  pre_nn: # set as null to avoid a pre-nn network
    out_dim: 64
    hidden_dims: 256
    depth: 2
    activation: relu
    last_activation: none
    dropout: &dropout 0.18
    normalization: &normalization layer_norm
    last_normalization: *normalization
    residual_type: none

  pre_nn_edges: null # set as null to avoid a pre-nn edge network

  pe_encoders:
    out_dim: 32
    pool: "sum" # or "mean", "max"
    last_norm: None # or "batch_norm", "layer_norm"
    encoders: # la_pos | rw_pos
      la_pos: # set as null to avoid this encoder
        encoder_type: "laplacian_pe"
        input_keys: ["laplacian_eigvec", "laplacian_eigval"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        model_type: 'DeepSet' # 'Transformer' or 'DeepSet'
        num_layers: 2
        num_layers_post: 1 # number of layers to apply after pooling
        dropout: 0.1
        first_normalization: "none" # "batch_norm" or "layer_norm"
      rw_pos:
        encoder_type: "mlp"
        input_keys: ["rw_return_probs"]
        output_keys: ["feat"]
        hidden_dim: 64
        out_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: "layer_norm" # "batch_norm" or "layer_norm"
        first_normalization: "layer_norm" # "batch_norm" or "layer_norm"

  gnn: # set as null to avoid a post-nn network
    in_dim: 64 # or otherwise the correct value
    out_dim: &gnn_dim 96
    hidden_dims: *gnn_dim
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.1
    normalization: "layer_norm"
    last_normalization: *normalization
    residual_type: simple
    virtual_node: 'none'
    layer_type: 'pyg:gcn' # other options: 'pyg:gine', 'pyg:gps', 'pyg:gated-gcn'
    layer_kwargs: null # parameters for the layer itself, e.g. dropout_attn: 0.1

  graph_output_nn:
    graph:
      pooling: [sum]
      out_dim: *gnn_dim
      hidden_dims: *gnn_dim
      depth: 1
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none

  task_heads:
    qm9:
      task_level: graph
      out_dim: 19
      hidden_dims: 128
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    tox21:
      task_level: graph
      out_dim: 12
      hidden_dims: 64
      depth: 2
      activation: relu
      last_activation: sigmoid
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
    zinc:
      task_level: graph
      out_dim: 3
      hidden_dims: 32
      depth: 2
      activation: relu
      last_activation: none
      dropout: *dropout
      normalization: *normalization
      last_normalization: "none"
      residual_type: none
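  # Consistency note (inferred from the dataset sections above, not stated in
  # the original file): each head's out_dim equals its task's label count:
  # 19 QM9 label_cols, 12 Tox21 assays, 3 ZINC targets.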
# Task-specific
predictor:
  metrics_on_progress_bar:
    qm9: ["mae"]
    tox21: ["auroc"]
    zinc: ["mae"]
  loss_fun:
    qm9: mae_ipu
    tox21: bce_ipu
    zinc: mae_ipu
  random_seed: *seed
  optim_kwargs:
    lr: 4.e-5 # warmup can be scheduled using torch_scheduler_kwargs
    # weight_decay: 1.e-7
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: &max_epochs 100
    warmup_epochs: 10
    verbose: False
  scheduler_kwargs:
    # monitor: &monitor qm9/mae/train
    # mode: min
    # frequency: 1
  target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label
  multitask_handling: flatten # flatten, mean-per-label

# Task-specific
metrics:
  qm9: &qm9_metrics
    - name: mae
      metric: mae_ipu
      target_nan_mask: null
      multitask_handling: flatten
      threshold_kwargs: null
    - name: pearsonr
      metric: pearsonr_ipu
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: r2_score
      metric: r2_score_ipu
      target_nan_mask: null
      multitask_handling: mean-per-label
      threshold_kwargs: null
  tox21:
    - name: auroc
      metric: auroc_ipu
      task: binary
      multitask_handling: mean-per-label
      threshold_kwargs: null
    - name: avpr
      metric: average_precision_ipu
      task: binary
      multitask_handling: mean-per-label
      threshold_kwargs: null
    - name: f1 > 0.5
      metric: f1
      multitask_handling: mean-per-label
      target_to_int: True
      num_classes: 2
      average: micro
      threshold_kwargs: &threshold_05
        operator: greater
        threshold: 0.5
        th_on_preds: True
        th_on_target: True
    - name: precision > 0.5
      metric: precision
      multitask_handling: mean-per-label
      average: micro
      threshold_kwargs: *threshold_05
  zinc: *qm9_metrics

trainer:
  seed: *seed
  logger:
    save_dir: logs/neurips2023-small/
    name: *name
    project: *name
  # early_stopping:
  #   monitor: *monitor
  #   min_delta: 0
  #   patience: 10
  #   mode: &mode min
  model_checkpoint:
    dirpath: models_checkpoints/neurips2023-small-gcn/
    filename: *name
    # monitor: *monitor
    # mode: *mode
    # save_top_k: 1
    save_last: True
  trainer:
    max_epochs: *max_epochs
    min_epochs: 1
    check_val_every_n_epoch: 20
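Since the file leans heavily on YAML anchors and aliases (&seed/*seed, &dropout, &gnn_dim, &qm9_metrics, &threshold_05), one way to confirm they resolve as intended is to load it with PyYAML, which expands aliases at load time. A minimal sketch, assuming only PyYAML and the file path used in this commit; Graphium's own config loader is not needed for this check:

# Load the config and confirm the YAML anchors/aliases resolved consistently.
import yaml

with open("graphium/config_small_gcn.yaml") as f:
    cfg = yaml.safe_load(f)

# *seed propagates from constants into every dataset and the trainer.
assert cfg["constants"]["seed"] == 42
assert cfg["datamodule"]["args"]["task_specific_args"]["qm9"]["seed"] == 42
assert cfg["trainer"]["seed"] == 42

# *gnn_dim ties the GNN output width to the graph-output MLP width.
arch = cfg["architecture"]
assert arch["gnn"]["out_dim"] == arch["graph_output_nn"]["graph"]["out_dim"] == 96

# *qm9_metrics reuses the full QM9 regression metric list for ZINC.
assert cfg["metrics"]["zinc"] == cfg["metrics"]["qm9"]
print("anchors resolved:", len(cfg["metrics"]["zinc"]), "shared regression metrics")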
