From 62b385a72b17ee0fc3c5c18e49900e6b69331640 Mon Sep 17 00:00:00 2001 From: Andrew Quirke Date: Tue, 30 Jul 2024 09:22:46 -0600 Subject: [PATCH 1/7] Updating dep versions for bh2 install --- env.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/env.yml b/env.yml index 25a6af0cf..62bf95a3c 100644 --- a/env.yml +++ b/env.yml @@ -12,7 +12,7 @@ dependencies: - platformdirs # scientific - - numpy + - numpy == 1.26.4 - scipy >=1.4 - pandas >=1.0 - scikit-learn @@ -28,7 +28,7 @@ dependencies: - gcsfs >=2021.6 # ML packages - - cuda-version == 11.2 # works also with CPU-only system. + - cuda-version # works also with CPU-only system. - pytorch >=1.12 - lightning >=2.0 - torchmetrics >=0.7.0,<0.11 @@ -41,7 +41,7 @@ dependencies: - pytorch_scatter >=2.0 # chemistry - - rdkit + - rdkit == 2024.03.4 - datamol >=0.10 - boost # needed by rdkit From 7f9112a4c2daaf4efc61bd095551cb158babd231 Mon Sep 17 00:00:00 2001 From: wenkelf Date: Thu, 8 Aug 2024 13:55:50 -0600 Subject: [PATCH 2/7] Fix lightning backend issue; add predict_step for inference --- graphium/trainer/predictor.py | 18 ++++++++++++++---- graphium/trainer/predictor_summaries.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py index 8cfb1ad28..b0c144c2b 100644 --- a/graphium/trainer/predictor.py +++ b/graphium/trainer/predictor.py @@ -562,6 +562,13 @@ def get_gradient_norm(self): total_norm = total_norm**0.5 return total_norm + + def predict_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]: + preds = self.forward(batch) # The dictionary of predictions + targets_dict = batch.get("labels") + + return preds, targets_dict + def validation_step(self, batch: Dict[str, Tensor], to_cpu: bool = True) -> Dict[str, Any]: return self._general_step(batch=batch, step_name="val", to_cpu=to_cpu) @@ -601,7 +608,10 @@ def _general_epoch_end(self, outputs: Dict[str, Any], step_name: str, device: st n_epochs=self.current_epoch, ) metrics_logs = self.task_epoch_summary.get_metrics_logs() - self.task_epoch_summary.set_results(task_metrics=metrics_logs) + + for task in metrics_logs.keys(): + for key, val in metrics_logs[task].items(): + metrics_logs[task][key] = val.to(self.device) return metrics_logs # Consider returning concatenated dict for logging @@ -614,7 +624,7 @@ def on_train_epoch_end(self) -> None: else: epoch_time = time.time() - self.epoch_start_time self.epoch_start_time = None - self.log("epoch_time", torch.tensor(epoch_time), sync_dist=True) + self.log("epoch_time", torch.tensor(epoch_time).to(self.device), sync_dist=True) def on_validation_epoch_start(self) -> None: self.mean_val_time_tracker.reset() @@ -641,8 +651,8 @@ def on_validation_epoch_end(self) -> None: ) self.validation_step_outputs.clear() concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs) - concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value) - concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value + concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value).to(self.device) + concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value.to(self.device) self.log_dict(concatenated_metrics_logs, sync_dist=True) # Save yaml file with the per-task metrics summaries diff --git a/graphium/trainer/predictor_summaries.py b/graphium/trainer/predictor_summaries.py index 4cec79377..4f8c3b032 100644 --- a/graphium/trainer/predictor_summaries.py +++ b/graphium/trainer/predictor_summaries.py @@ -525,7 +525,7 @@ def concatenate_metrics_logs( concatenated_metrics_logs = {} for task in list(self.tasks) + ["_global"]: concatenated_metrics_logs.update(metrics_logs[task]) - concatenated_metrics_logs[f"loss/{self.step_name}"] = self.weighted_loss.detach().cpu() + concatenated_metrics_logs[f"loss/{self.step_name}"] = self.weighted_loss.detach().to(self.device) return concatenated_metrics_logs def metric_log_name( From 47b7d1cdb573be16f50062f7674a05b8e85510a3 Mon Sep 17 00:00:00 2001 From: wenkelf Date: Thu, 8 Aug 2024 21:38:29 -0600 Subject: [PATCH 3/7] Fixing device issue in metrics calculation --- graphium/trainer/predictor.py | 4 ++-- graphium/trainer/predictor_summaries.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py index b0c144c2b..95d2c3166 100644 --- a/graphium/trainer/predictor.py +++ b/graphium/trainer/predictor.py @@ -650,7 +650,7 @@ def on_validation_epoch_end(self) -> None: outputs=self.validation_step_outputs, step_name="val", device="cpu" ) self.validation_step_outputs.clear() - concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs) + concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs, device=self.device) concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value).to(self.device) concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value.to(self.device) self.log_dict(concatenated_metrics_logs, sync_dist=True) @@ -665,7 +665,7 @@ def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader def on_test_epoch_end(self) -> None: metrics_logs = self._general_epoch_end(outputs=self.test_step_outputs, step_name="test", device="cpu") self.test_step_outputs.clear() - concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs) + concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs, device=self.device) self.log_dict(concatenated_metrics_logs, sync_dist=True) diff --git a/graphium/trainer/predictor_summaries.py b/graphium/trainer/predictor_summaries.py index 4f8c3b032..7a4d9292f 100644 --- a/graphium/trainer/predictor_summaries.py +++ b/graphium/trainer/predictor_summaries.py @@ -514,6 +514,7 @@ def get_metrics_logs( def concatenate_metrics_logs( self, metrics_logs: Dict[str, Dict[str, Tensor]], + device: str, ) -> Dict[str, Tensor]: r""" concatenate the metrics logs @@ -525,7 +526,7 @@ def concatenate_metrics_logs( concatenated_metrics_logs = {} for task in list(self.tasks) + ["_global"]: concatenated_metrics_logs.update(metrics_logs[task]) - concatenated_metrics_logs[f"loss/{self.step_name}"] = self.weighted_loss.detach().to(self.device) + concatenated_metrics_logs[f"loss/{self.step_name}"] = self.weighted_loss.detach().to(device) return concatenated_metrics_logs def metric_log_name( From 1ec4969f7865a98ee6abeef8bce0e28b0a423438 Mon Sep 17 00:00:00 2001 From: sft-managed Date: Mon, 19 Aug 2024 12:19:41 -0600 Subject: [PATCH 4/7] Disabled caching model checkpoint through WandbLogger --- graphium/config/_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py index b27aa9b4e..2b9055a83 100644 --- a/graphium/config/_loader.py +++ b/graphium/config/_loader.py @@ -447,7 +447,7 @@ def load_trainer( name = wandb_cfg.pop("name", "main") if len(date_time_suffix) > 0: name += f"_{date_time_suffix}" - trainer_kwargs["logger"] = WandbLogger(name=name, log_model=True, **wandb_cfg) + trainer_kwargs["logger"] = WandbLogger(name=name, **wandb_cfg) trainer_kwargs["callbacks"] = callbacks trainer = Trainer( From 9ba5a16da047461162b4c770e429059193bcb0cc Mon Sep 17 00:00:00 2001 From: wenkelf Date: Mon, 19 Aug 2024 17:39:44 -0600 Subject: [PATCH 5/7] Drafting unit test for node ordering --- graphium/data/dataset.py | 1 + .../data/dummy_node_label_order_data.parquet | Bin 0 -> 38550 bytes tests/test_node_label_order.py | 103 ++++++++++++++++++ 3 files changed, 104 insertions(+) create mode 100644 tests/data/dummy_node_label_order_data.parquet create mode 100644 tests/test_node_label_order.py diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py index 498515fc3..6c5f7a1a0 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -195,6 +195,7 @@ def __getitem__(self, idx): datum = {"features": self.featurize_smiles(smiles_str)} else: datum = { + "smiles": smiles_str, "labels": self.load_graph_from_index(idx), "features": self.featurize_smiles(smiles_str), } diff --git a/tests/data/dummy_node_label_order_data.parquet b/tests/data/dummy_node_label_order_data.parquet new file mode 100644 index 0000000000000000000000000000000000000000..a9a165d8257a3168f91b95091329763bd3005d2a GIT binary patch literal 38550 zcmeHw34Bvk_HbTFA*CUN(3Y@D3q+Pu-oB(w+O(ycv_O<@P%&vkTj)kts4Nym*#$&l zK+q~6Dj;s41HGYCs)8)zii(Ulg5okVioz(ucW%-)X_``;v3366&-~;i_q}uPJ^MZ9 z-23VzlbH&X2QDcO{5)ntU_hWuBa_{y4>fm_$)?66(0Ve9@>R&GEZ-h7#;=ErgMR|i zW29cnGE9xpQ%u%K-(fO}7!#_e=&WH=Q}h&@HOenYMiQY4S}y`P5C2%d9v7F>S=7*Z zBgc|qG|34Z8#45+%nShT;T@g-5JNQ6htTajeAgp;Q=8GmhS*Kr=KY5C!R!M3VwM zktk9Wp$|v#BrAp!BeXn6k}MTU)GAQISG`&yf7o~w5Ln;=P54dYS(c*0=`fO)C=MQwB+pA6%@B{V97mHR2hV7Z=S1ReHAl11 zXP}xTfed1nngnF&;7TRjm3-gZATTt~bE85n<;8`QN{dVMWLRNwfeBj4mXef;LaVJb zjMP(cfieZ9r|}n~XYd!NXYrSy=kS-L=iy69F?t~>2!o3Fi`Gl{i`7#}AIQ`cj6=nR z$rK{AKaweU`1;uS^t4P#&s##+QU`%v4Fmn56Mrp~+ca`LXPF>N$j@aNJ(UGYf+Z-D z!ie>1k*7If#DM^k6@bO2-f9pI#RZ=T0J<4Y3ck)tJg|=rXMjx%DH4O!G{vwCF^Hr& zief2vAkczniVL6xMq((Fb~wqi3R2@oUwbC`mlQC|I6LJR|40@nJt&b4aciclqt|8Cy4MubQIZ>EcVuk%p-gYyANcS}!a; z6+nl9=%iq+l4425_DBHDu_DdY)CKU20GeI%i;ORv7XY-g0$sC2%QG}65ml2dV@K4f zC13(wGfTlxZ0)EB&^3Vp#Rl%OU?>tXa1cpStVGdL@MTOEFZk-!JPG_G?og5x#V_rl z6lszoMv$Z^2%^A)nu3Za`zs|uV2Mlap_(;xE+a6I168MG0V`>qbE^%1S z(Y1?$eQG{ZNE}6ICx9l?ELfq)fi(k05={%fht(X(vOE!@7Ab*bj@<2IepL=Kx!Xrh zkty?hWD>rx_#&pj6DH*qxp|))t7pnyxr{RJmCGp_?#L;|{EA#5%zM%jbm58;++4!s zFY>hta9@^NOyA106VCn$ml^~YOCY}Vnr*a@L@zZs!>KX0+hj3|Ps&EwhAOzMQN-Y( zAn2(R6Z8~c^Mx90nPSeGLFrg9<(i2j8K8g#GZuJ$)aOc?;+WuT!v%0~Sf4q5>Ca%$ zydatKNt$N`E}R8zrNIuC%YnHDkqiTWATShYDMPc&;;nLysW}d}1@Kc858wiD-qgL7SiVO#V}HMz2csm+N_pZT56ak%F2UQ{VXJO4TBNSxWsn^+>RKvHs)Pdu6oc zNwolK5~q63UpC6pAZNfkD7h?p6cHn%dc|1w$R=8LsGjl%6`?qaQH3j%B$yDhF#WI5k~YFM7&_Ed$Eg{8za zRRL=#F;PwPu>M-sDJ4>X#Y#zuJbz9KrD#TwSl}eMNlS>D1*a(yv(*d()@Ioagp#~U z^jESxO&)$)%Y&l<@ztVvfeG!wMXe>lmBAWad( zO7v5Le({SLEiBL=H<9Qwl%!zgq$QeKoI61jK#puU1u+35NpN2jq@W!CK@#|SQVIv! zhDsbPs3b4aOOLBL2~)CWI4{s50JEH=(kWCI{S=%?29t70IJ(H~5(e*I_=I%l`Sh z$|C9d&6Q&t4uv%8O&d(6|+W7ov4=z=PPKD(wy^Gh!U-u zk_yNR;ay1^i@OH#7p&AAby6H1Y)piyUG@$reVFHRex83s07NCA}1M5 zbx+VI;5%nBDBz*xVb$w?vED- z4Z$N&JS&=>03SfnL=Ko5tP`t22Ebfc78wXy#QHl&g4TjEg)xG_YeIUfN#JofG3&x& zB@3R%azaUREO~s53LNkoKIYq%V0mxA0&0HUC#AeQSR`Dyl*8CbANEm5^Q!CKQ!P|# zHYlH7BNJ<%Q6>9C65KvZ%mDR;)m8$dV8IOP`^yBC!B5rb7b;M&M2AZd z>@re)rIP0;L9@#*00L5m5^En-aJ)pk2J0fplBN}*V3H8uhEuSdFi)>j(;zL9Fa|Ij z2M<*7e)ZMC{h$j~`;LsPzcxtZ8Su3r zPpK&&n>T%+q#(q$>6H{s^291F39AaM>Xryt%vsG<4UA1{PW7>EIDQGQua|txi~E4! zF2Rz^n1A#!|Kwwy=?7NylTRi*;LLTtU?6Z2%m;jBT+ao{X9@qV_4WzvAZD>aRM)UbR>f~G+w zXZXavGES54AACgt0x|ul7C2y-=^Pk?z_Oq<3Sewd9LS9XrB;!>6|l%?Ug)dfK&%7Q zSeHqHfhZIF7Z1yvKm*-xtG4*bBvtobS_&Kv2ND?yauzk?6lz$YASgKq9wA5{AmdaA z6{_cyN}2(2b?>F1IqC2na)~k>g`PaaQMBr7rILk3P_DB17b7?T@t}93nWdcR%zrk?K@!*f?&+RN$e(D{Q!1=@tZIIMoEB6C{*X;($~n9hX$3Q84rVo_Lpx4hbG0V@xrcO^U2^mHz-m zf?iOqQK>W<)go1GkgRBYhAdk%T_ej@{TL*lT3j2g$qx*m^?VjRRCPHBqFITPELY_W zGvL==kj_z@=JNoxz*0P~YUrt8V4*u~B9#&#t~wi_fPhjn1JoSi0P(C9z6C+j@PQx$ zi4Dr`1@Qs5G#$HQQndth43=@#J3ZB)UJ|9s@E-)8M1ZI}T(z{nj1D^AO96>R%|3|9 zK=~*L#voFMOhaS^QZb-z!w4n|ti=MQ+3g<+*#*R~z-a~ywntwzEa$AG*{TSDrCZ_{ zVzC@n_hAE+5VX?Arpaj2juE3Fm;kk@Qz-!pvG8Rzcx#C|-dAP5%^#NO7+)D{PFJMt z^)VmpW8Moc3g04A_JPC0zc{#LYE*K5UZ5;+XyEWrsDG#fy9ah3#LGsm?LIX-0hd4U z7X|5Cg5D#E!?z%~9|Z}5UIaK?E@849U$pa!$6t^QLt+QD5?H4owZkB#x;m^kT!x(g)5IPY`J=brd6a0`HN@&}`{ko`G|5IFyHRBId= zpaV^u1LE2REEf(~XBz03m8T%V1Ub7Uf`X$@t6UJ0*f4;^^ElE$Tp@u{hohJSL_k>u z58zL14n=_$wgskW&`K}2p$g-LGGU(f?d}fE#a4ke7|g;= z>A^glY7Jj*^+6{j26ZnHI`5LsE~4|xVJ(5f^E33DIfOH&;aTqEOzf2f~K7kWg-HdlQO}ZEr#;vh7V=nDlZN z>XcsYdbJ-4s+isk2p2WmjsYc3FZKgA=!G;G*b8m+Jef?Uj_`qN03r8P(|1&EfZw{^ zcMb|v_UB~X)%~@7i%WBCrM8^u*>lPY^9yWc)0vz=pP+!C##Kbn_B(YeVs(l?m#jaR ztk&iHKbNf1us@e9Z(Ce{E?J)YUTr>m8LJQ%xBa-{w_K=kj6NdTCZrMHb)AqW>cPx}miCkHN-y${MJiQE)%3ICO8(?mWi|U3Yp)-S{i4Ydd@FN?I%$xI2Acl@*!NiHO84!CG1R;NnpOsk{9Ah;6}x zeIJlarkw*vN1ntQ{(v57aMi+Zfo3oc~#@V_{zr5xU9x2!R3u@wGa2w$}v3F z4p+;&_8r>jMw7-#ldZ^Bnmc!Td2zX=z)b;P`wO9acLS!&HK@lF9V_5UTV~{d1BsDe z><}V}K64`*?q3@D7i(nXLt`I}^f_4+`Cf5Rmskp>7ubq&%fWfLQFrN4LU-Szm^#-JZt@ z-NDB&b*`a3ChC~F=}#UrjvTwu_}=dG##zN%jA2Q*^CY3W`UIxVHMGY>9aGo$(dQ!_FPe!Vv~t`RPwHxUr*8WWH|heMnwbw4 zSDtT~UT7&RYo(Pto+5OwY{!(jhV_`FW6C~FnIAbU<$!VBWHK`4+r^Q)hkX|LVC7k3 zz@Q%@|NTvV4$_r%}LMz98@uaSncgj9`wgbxCHS+W`gzlTCF=eh{Jtpax zvSVKih_vU-GW?kOvhnJ8I`T@|PUGXVBxB5rRmN|Y*BCXMBaFstcY4uC46T*pzIaks z%R6POp6h@zcWqp~o6tSJ3sdGA)?<>6DJ%24o#*?mJ?^K?vH(gqd(?a6z%)l1&yekls(|7xDdf06M5_oj9Y-+$>a{>bM;`15Bz^+Ftm z(8_UNJgKYYow%zncTAkSPF{P7(5-q26XzP)W1fzQo3i*~^0hufcWgiBZ! z%4KJN%@)3s$*nmsl{@!X1h@3G!3%M?r&f;p;z?aC@5JqWwPWJk)$)~B2;Cd|FmbMt zJ?80{xFe%C88rVs5i#WS90T{*V8eISrTg!CEZwj+c%@8le${oiMx2fjkuw$)iOKJQkn~8`ofBWg8W&wqE`C({eB22_hTwuQ+Q0+F_nww zZEDEgGW*a63&{rSxfdEXrd2mgi_K^V*rPwRea+@WXNpfA8W6s+zAcsbu2zow;7MIA z?^GT*cymS}qX^4-_ns2t*5V+)F_a`Id2>-&v_?)+g)qiX_>={lzItEc-lth#sLp>ko@q3;_5 z8a|C3-w-qXgF_|K*7|pngATp+Ovs_HU-I{&ulTN3j^Xj7u9kNiA3b{0G`ef+mLr62 z_nVkT*90EZbxh;%o&6e4EFXAihJM$fo;MOF{0)~fbBu$gE;kn4a?VRQgL`P@7!FVBYI&z|=i6>HHWgCJO7d)_ z#o77gbKL^#-G3u==p?4fHKfM`9aDwkb0fFE{iSi#!l{wVvc5H*KtCJj)8it49lY82 z(xW{iSMg-zPXoNAe7L7pj^Xg6u9kPIPQB+wRa0PHQeY{vO}E<0%BRoBFL$Htqjw42 zSMOlTT*G=y;+3+5zzGuqHI5g#hOaoDe=G_rDpA2xUWj^U)rP1Q?7LCihKHi=`RmJ3 zXMcS^YW=P~xI7dRsB;~D(#r7@@GAfp_o@}p(B1!`8(D!) z$5Nh?Go!qz_SIZ`_>>6xlF;3MiqLKO0Mpg(H2))X?LF;A*Fbl=n(7a&N-Iw4_I?cO z-beo>b*aBDY}|DAi^G@M%ErL-7aRL6xUKQRl~s-Xk}4ZVidl`H40DuKJfwzuXyq6l zOU>14M{3{y-Hp1Y`h!zyZK^9y(!M`K=$4(qq`AiSn8~Z)b|?a0P)rD1=Ht-Gb@EGn zeH5QBLczWdpg79|XvcuX=t-4C>QS_RzzHamW&FJI%2{b{Jqm+H?kiOf! zXqzSl4SYNneazp7uKWHKZTm0+=??cW4jg^n@YcKcp&1juG8%_nK&nf7Q0ZYm`wQ=% zJaDJF4~p3>M{&A0P{~Ony6?RNWbU~UU7Po@QU6@9ad)qm4V%`68tz-P*1)W+G+q_D z*=@^4*`u@M_T>i^_P1Y~ihk-j6&=k{pxY~ZqMYZC8M4>EVT?Jq-e}%;uko+r`bGY0 zQ&QxSrH`OnY{TuFf4Yi3KmAX1PpI6kpL+p~n0p2#C!Rt6+0UZ8-s&E?b>QI0I%0;Q z|Ej;E#`<$;Px=q&yK4&jx;$U|j|ZlrCFBq^^MOcIn9iW{)AVTUh}+Ph@50baWj&3v zHZ3#WbJuZ0!uAH^s{?9`z2_hHmz4|NaNs@BcY^#c&pl`6L<9RrAq3o5Rt(7%Fmu z{gC0>ydTin1zP*%wdaxl=(T8L^}|SWT2aSr3q&)23`B{!_ZXsmZ$;O?I9YdQ2#cZ) z$s&ouuhD=3c66lZ3<5XwsZZeNKA~7soX0S=a?B`L8G*gn^|@N0(|ZVVAbj<>w$QF6 zx?R)w$}evAz4l80;d|}xYPj5j9n~_QZi-Ss^028n!W*nLH)G- z2MioEc*re7gF}W5AECQ-RdeQ6-#u^s zf`#|gEV}nTKt^}!TbId+fknkRw&|X!g1jKVth|S+GIs8t76|2fVdd)@# z)r)9Y-9_~M^Zw}C#!$4}tU(pV73dpHAX=`~A@pH)bn8c)k=VZCz~y7}BW^js*ZuTT zLtV(>PYqX#SE9l<522!~XOTVUA86BTJz{MFB5i7PvUgwfT-Ei64a`8}(V8I!{|yHt zURiys@v(|5n-3$|$WVHGFbW}fREc%%I3~6sYkI1@}ND-_?EB6jSBlh^AHTxpap^4K_-Wmh? zsrvGPo--yGF2wrOEf{>t@T@V-@cw|&D84ZXo!xK_tq=bPTD6+P zduxrx14l9=w(oloP5tT>v}g*4wyfWX4peVOLk?-trokEL$AApv^TWwHRG)_4nErF! zS6>W9dDBiAlI}F1uewQS`1haGd9^;wC@8LhXt1WaTA<7S2d(ejU)-#(Z!7C_rakT2 z-s&ra?!b50_FO#h;D}e-d%Q>BGd(5*en~i1zLVW9ogn6iOhoTIIT!7--eMoH$mGBkbv+kPu-9Ex8OJZYh++$0MDJQpf(L&Ag^PpiPhR*Y{BP?vAJ|km*gm){ z#2zLj7>_JFyMNNPe;OWY{4!!mnAW&pP#xN>EV4fqzTCd(!C-V>uUk;&zk1aTI`Mkl zuKIV3JGhmGfIFk>rf9!5_S-xt^7M}>k?XJTLH56_Mt@tQwEu8~u>1A<2%XwE);{^b z&*=UB{&w{@pQ6})%E)z0LS*fhX#1pw9>}1+fWDYqi)xOS+p~Kb?WZq3i@2=k>hgEx zq1)GOLZ){KyFM!jy>ubguC98|_|0Oo@%bnHB42v;q+!^Wmy8b_G@*sL4G7hbv9mv_ z>{|=Q+8-}?1`XDqLzn7%+W+~&7bvmzEyM1d+{oM-o$>hgCnKgWr0vf?`6jySNPzul z?*RMDpZVLLpZN;f^JNg4eSVz%lGYGO&DntV-*ph39yQ87aN!f^)+a6_pPIcWrZgI{ zKfR2yv>&2f>;;1+P;alL9@@X*ft_`~-k)zA8kdLqoVC|o;g;c5YpYM-Gq7rXMYOD1 z7!lAFU@MM~TrIGhwM0}u_-b0UuHx0|7sB`AJm->y!|8->-L76=OLXgT%RTqp6*W?; zT5_MYK2IF7dPz?7)9GOwR?KL)Fvs}#>b%&|e@WS~Zsy^cuim+5LqXiL9~QmxXwmCm z|9io~$4e4I2R!)3*3vf%^pBm~QJ%Cb^SM(yE02C#ZU1!7oC#r@PJH=N^|7M;f4}(3 zJQYl-TeiffW#3vH=r9duDA>{R^c4}**fZ$;p1OM$`{-W!6`P2QTo2-J)FDD#$5G!cZ=mN+OhX?( znTD?IwWGS*>(RN`d(p&)KCIITN9rb4EkGY^^*27Xs}@BqQP{6sRoHv4pMXAFmx#VA zzZLzubu`lMxv%bekCDdiVZ9@k4*xmg&B{BCkKFwiWB+qCXz?CBvX#DodXIVpCFd_j zyIBHx*RGYm{DHSYYiuI?|7-;eyBUXCVKtwLMR_}dp=?r9&q z<|*_*zgV!$CBR*A z488fe+TQP9fp%Gm+CJyS?I^$eLFE6P8C{-wH!42$h#@`0A5D2;r=e=I0nG^PY0n%7 zyod?1KXiLP`;S|K>~zIr=<&}^p?jBYMC+DLHNLv70&SJmqdvEPg{IU#kK)Dm(Cdrd zMlUWHh7z8B6xoPPX!m2NuE)j)kudOH6iY5Zi>|LU{P0;hdh4Ax(b{j8dN&cwDA>q^ zX|a~LTK}^qa&n!UiS+kkBHq`K76WNqg|sSg;4i&ziRy%|XVt0Os?wtHE2o~_Ul zXp{viNBDR>5I#aSVGuCS@!aDeQ1c@NbZv7eMHYx}d!FiOb+_eMg!8V)xdP`+FB|jD z2kq``-}tfh#*ed|kG*bAw*RPUP;b4O25fJC$vI{_Pn(8q=V{Z(9c~dgr?>%_v&P;b zn6pUj1nZ#6dow^y#NBwLCK}rm>YNjJt#mmbd9M~ZpWUe9;)H0oY9Zr%wBGZpK%cI)=c^;X?QO@lL9Xu4L+n|Umse_ZoqL*G z*Qv+4b)9;?TGyEeopqgk;#k);M}Ku)a|Tz}HHTq!oqsA-*ZIdob)A1sQ`aR2CUsqM za#7bMM+kLYdzMevwTJ3-U3n!BKK(m6cmc%b`g+eixU( zPs;h*=J-873m!NJobDLi)--PGVP@pp3Ua2mzURbidimTEJYPsfQGPZYP;sD8QCI>m z+ciBXw$8L=m*e||@VaDf}$ zRKz0pM**A@p1F}<)YW>D$dkD3K>U?C{66r1KQRZ$Z-4qf3x%epBvdQFMpzN(4JJ-ezrVLa_5~*ki&DP1zfKr*P2qf?0b) zZ4J>WNVI+*tUW~9LUjr<9&d5Ahm9AoPDID!<*OSbs-|{jc1$9l7(2eXUEplTN$I!a zF`DKokV>qD<^o%MYBioGjVa5w7RHrl-VS}^Im#MeRhAf&RA@;{Ez69PDl*cj0t{y@ zN-eX-z&J(GfEUaS|KA2%dbHzNVo_>^B`u{gGk!`w#sy%@vl9wt~=X zDac?ZJD;bek{M~{yqx&)m?q#yYDESu%>vqs9JCjtPBG_`Y+@1B8h00;TJTCNl;#32 z=VqpvXJjyW)z-9vvdr|98OgMBuFS%M0&Bc^MhlIAOE%D-U6?ux_|In3REZUMkzSoL z%faj4pFW4oW`GPpzFzTe#=$A;%#zHsIe=TU%z>Xw(yYw%qymsf6Me}JK4!L#l~Dc@3fQVTqT@t#cjy+vrJRVb0u)^yuPrI<5@*xn~79@v6cy>5$d3v83!rnu{Ip*3F0 z%ZW{%+hmuNH9eYh=f4L%zm30zU>B)Uv6ZGO+tY;>S$M*F)7wH1-nF$$4|=GAocK8K zRpz`Fn@=Fi5)<5fL`+gqc41tVB~Hr6x|2?~@aqP;kc7vbkx9o@wQFOYi)(?k5PSp! zJ}F)T-&5t#B`?0AnKrNR-t5go-d^m@Lk6vV-;MQ`iJl8Sai%35d~TdnmI?kb2iu|3 zw%h0Tjr(=EUS}sH!J3m&0{kmWr{@4~b4aW|9qITD`*pp(S7pPxok14>|K`kqHKQVv zF@v9T`r;et(hd7{xlYj#d!#3kmbADT*0i{4SjY1^k^|^qL3WY31bkC@MtVwhIz9dm z^5<=E?@))@;_h5W{s;Iz^SAT8{dK=n{g9i^ZvM)XHpmSzZGXJ(wV}-^x8I_Np0qW|?SD!q%dAWe z&Ph9tx7_@r8*e=OcX{09U6}7e@Ml7 zU<#Q|XO=Y86mFE4ltC>FVt}Lqs9(*1n3?>8a3?e4&4oE}Ql2#-x)|cxqKtG?M?7!0 zPuJ^;c~)k8Ch#0`5YT_9@7#PY&UDzf%l%YNJm6FW`hfF{!W>*rsf2n#Sx52l4f}PS zU#hUuX3%LjUz85D()fbX%rv$=o$9cUH=Tkpt8?PYTWU7xzhjNhS_n16RPZM)@l=za zch=>+*fao1Dt;P+fBlkMPJ4y^6jbU~#xANC_+66d7kvsDvjn(}e|+->iH2ESKrNs}rZ zW6VsfNX}>HBo>jd$6}g+FIK!46O#*jF|ozbuuW5jw|)|G6JY-Zc4Y7s+R3JP_$9&~ zkZB^`tC^GwdqFX5H0uoq;?#cc{Yn-Z&Hq5pW;{xOZi z+ds*s80hcV2AYi9ld1qThQ(d)z}rX@@upF<14%rFDJGU3H%^2fW zb5kl4H`~|ICOgGjPtnydo|HRg(ztQBWtxL~IyRy3{#5hFF+Ri%1y($VDK@4BJ`wlt z5Fab<)jmEhdov}@gz?5Dw~T*N_?Rlj%xJ=Y67KJYkCPv6{K1$2YvI$b|H_Rm!}BLk zDjGKqZ`Vyuc2Et1#|_67UmX02FCa719O5y}s%(aj&Thsh7Apo;rDk}isAl1I9DWqv z#wWV+S}dQqJGgP<#$+^)ca!qrTi}y%e+NHe9ZZP91U2ak{)H9U6y03IZuV!MwQLM& znpj!d#1GT>F_o?K#f?4(s6*b(`wq>%sAa8&SgyIYZ%QcToA5K4CpC}nBZXmVna3G%BL14aW~qD`?gg;w%*_YZi`Xck|Z z<=sJE1O8~jgy!)~GbT0T@01VTopzT`0#>DF`a8;Rfo+nDAEp-ko%ZhXcaAxno8&j< zhs%dRKIiB)sj|9RUlYf*_HX~`_|5ec$M|_I<3~^aoyKpjMLEW=ZW%u&!(;qn4B)0O z>A3BMKkl_VQ%oh(tbbGf`|%P@=9ckHlgBmtU#I;31N_FDa$ClWPVQhlSZKY{+qBo< zlz(N*crl6p-FRRJ@EOxGUhMe)Zan95lu`+%%ZJIRG13%o;ZM}o@J8u5Zl>@`w43I&S zJjz%0q-Qf`5Lm~6nc64L|xfVdM#$-ptO!_c5}Gf~meX2*D*_V0|hTxU{d+7g<^rNn9o zRvkk)`D-x9v_nfQrL!w+ Date: Thu, 22 Aug 2024 13:44:26 -0600 Subject: [PATCH 6/7] Fixing unit tests --- graphium/data/dataset.py | 7 ++++++- tests/test_node_label_order.py | 24 ++++++++++++++---------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/graphium/data/dataset.py b/graphium/data/dataset.py index 6c5f7a1a0..bf55e0418 100644 --- a/graphium/data/dataset.py +++ b/graphium/data/dataset.py @@ -45,6 +45,7 @@ def __init__( num_edges_tensor=None, about: str = "", data_path: Optional[Union[str, os.PathLike]] = None, + return_smiles: bool = False, ): r""" This class holds the information for the multitask dataset. @@ -75,6 +76,8 @@ def __init__( self.num_edges_tensor = num_edges_tensor self.dataset_length = num_nodes_tensor.size(dim=0) + self.return_smiles = return_smiles + logger.info(f"Dataloading from DISK") def __len__(self): @@ -195,11 +198,13 @@ def __getitem__(self, idx): datum = {"features": self.featurize_smiles(smiles_str)} else: datum = { - "smiles": smiles_str, "labels": self.load_graph_from_index(idx), "features": self.featurize_smiles(smiles_str), } + if self.return_smiles: + datum["smiles"] = smiles_str + # One of the featurization error handling options returns a string on error, # instead of throwing an exception, so assume that the intention is to just skip, # instead of crashing. diff --git a/tests/test_node_label_order.py b/tests/test_node_label_order.py index 974834405..1baede2b6 100644 --- a/tests/test_node_label_order.py +++ b/tests/test_node_label_order.py @@ -43,13 +43,15 @@ def test_node_label_ordering(self): "task": {"task_level": "node", "label_cols": ["node_charges_mulliken", "node_charges_lowdin"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, } - ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) - ds.prepare_data() - ds.setup() + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) + dm.prepare_data() + dm.setup() - self.assertEqual(len(ds.train_ds), 10) + dm.train_ds.return_smiles = True - dl = ds.train_dataloader() + self.assertEqual(len(dm.train_ds), 10) + + dl = dm.train_dataloader() batch = next(iter(dl)) @@ -73,13 +75,15 @@ def test_node_label_ordering(self): "task_2": {"task_level": "node", "label_cols": ["node_charges_lowdin"], "smiles_col": "ordered_smiles", "seed": 43, **task_kwargs}, } - ds = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) - ds.prepare_data() - ds.setup() + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True - self.assertEqual(len(ds.train_ds), 10) + self.assertEqual(len(dm.train_ds), 10) - dl = ds.train_dataloader() + dl = dm.train_dataloader() batch = next(iter(dl)) From c23dc023343beaeb1ea94c7095cbb1b17596bcab Mon Sep 17 00:00:00 2001 From: wenkelf Date: Thu, 5 Sep 2024 09:12:39 -0600 Subject: [PATCH 7/7] Partial fix of node label ordering --- .gitignore | 1 + graphium/graphium_cpp/features.h | 2 +- tests/test_node_label_order.py | 262 +++++++++++++++++++++++++++---- 3 files changed, 233 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index 289f10a4d..41dbc0e45 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ draft/ scripts-expts/ sweeps/ mup/ +loc-* # Data and predictions graphium/data/ZINC_bench_gnn/ diff --git a/graphium/graphium_cpp/features.h b/graphium/graphium_cpp/features.h index 4bbcde001..034e85630 100644 --- a/graphium/graphium_cpp/features.h +++ b/graphium/graphium_cpp/features.h @@ -274,4 +274,4 @@ std::tuple, int64_t, int64_t> featurize_smiles( std::unique_ptr parse_mol( const std::string& smiles_string, bool explicit_H, - bool ordered = false); + bool ordered = true); diff --git a/tests/test_node_label_order.py b/tests/test_node_label_order.py index 1baede2b6..35411cc33 100644 --- a/tests/test_node_label_order.py +++ b/tests/test_node_label_order.py @@ -28,75 +28,275 @@ class Test_NodeLabelOrdering(ut.TestCase): def test_node_label_ordering(self): - # Import node labels from parquet fole - parquet_file = "tests/data/dummy_node_label_order_data.parquet" - task_kwargs = {"df_path": parquet_file, "split_val": 0.0, "split_test": 0.0} - - # Look at raw data - raw_data = pd.read_parquet("tests/data/dummy_node_label_order_data.parquet") - raw_labels = { - smiles: torch.from_numpy(np.stack([label_1, label_2])).T for (smiles, label_1, label_2) in zip(raw_data["ordered_smiles"], raw_data["node_charges_mulliken"], raw_data["node_charges_lowdin"]) - } + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ################################################################################################################### + ### Test I: Test if atom labels are ordered correctly for a single dataset that contains only a single molecule ### + ################################################################################################################### + + # Import node labels from parquet file + df = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "node_labels": [[0., 0., 2.]], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} # Check datamodule with single task and two labels task_specific_args = { - "task": {"task_level": "node", "label_cols": ["node_charges_mulliken", "node_charges_lowdin"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, + "task": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, } - dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) dm.prepare_data() dm.setup() dm.train_ds.return_smiles = True - self.assertEqual(len(dm.train_ds), 10) + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ################################################################################### + ### Test II: Two ordered SMILES representing the same molecule in same dataset ### + ################################################################################### + + # Create input data + df = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]", "[O:0][C:1][C:2]"], + "node_labels": [[0., 0., 2.], [2., 0., 0.]], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True dl = dm.train_dataloader() batch = next(iter(dl)) + + atom_types = batch["labels"].node_task.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) - smiles = batch["smiles"] - unbatched_node_labels = unbatch(batch["labels"].node_task, batch["labels"].batch) + np.testing.assert_array_equal(atom_types_from_features, atom_types) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ############################################################################################# + ### Test III: Merging two node-level tasks each with different ordering of ordered SMILES ### + ### TODO: Will currently fail ### + ############################################################################################# + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "node_labels": [[0., 0., 2.]], + } + ) - processed_labels = { - smiles[idx]: unbatched_node_labels[idx] for idx in range(len(smiles)) + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, } - for key in raw_labels.keys(): - assert torch.abs(raw_labels[key] - processed_labels[key]).max() < 1e-3 + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + unbatched_node_labels1 = unbatch(batch["labels"].node_task1, batch["labels"].batch) + unbatched_node_labels2 = unbatch(batch["labels"].node_task2, batch["labels"].batch) + unbatched_node_features = unbatch(batch["features"].feat, batch["features"].batch) + + atom_types1 = unbatched_node_labels1[0].squeeze() + atom_types2 = unbatched_node_labels2[0].squeeze() + atom_types_from_features = unbatched_node_features[0].argmax(1) + + np.testing.assert_array_equal(atom_types_from_features, atom_types1) + np.testing.assert_array_equal(atom_types_from_features, atom_types2) # Delete the cache if already exist if exists(TEMP_CACHE_DATA_PATH): rm(TEMP_CACHE_DATA_PATH, recursive=True) - # Check datamodule with two tasks with each one label + ############################################################################### + ### Test IV: Merging node-level task on graph-level task with no node order ### + ### NOTE: Works as rdkit does not merge ordered_smiles vs. unordered smiles ### + ############################################################################### + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["CCO"], + "graph_labels": [1.], + } + ) + + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels task_specific_args = { - "task_1": {"task_level": "node", "label_cols": ["node_charges_mulliken"], "smiles_col": "ordered_smiles", "seed": 41, **task_kwargs}, - "task_2": {"task_level": "node", "label_cols": ["node_charges_lowdin"], "smiles_col": "ordered_smiles", "seed": 43, **task_kwargs}, + "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, } - dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH) + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) dm.prepare_data() dm.setup() dm.train_ds.return_smiles = True - self.assertEqual(len(dm.train_ds), 10) - dl = dm.train_dataloader() batch = next(iter(dl)) + + atom_types = batch["labels"].node_task2.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) + + # Ignore NaNs + nan_indices = atom_types.isnan() + atom_types_from_features[nan_indices] = 333 + atom_types[nan_indices] = 333 + + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ##################################################################################### + ### Test V: Merging node-level task on graph-level task with different node order ### + ### TODO: Will currently fail ### + ##################################################################################### + + # Create input data + df1 = pd.DataFrame( + { + "ordered_smiles": ["[C:0][C:1][O:2]"], + "graph_labels": [1.], + } + ) - smiles = batch["smiles"] - unbatched_node_labels_1 = unbatch(batch["labels"].node_task_1, batch["labels"].batch) - unbatched_node_labels_2 = unbatch(batch["labels"].node_task_2, batch["labels"].batch) + df2 = pd.DataFrame( + { + "ordered_smiles": ["[O:0][C:1][C:2]"], + "node_labels": [[2., 0., 0.]], + } + ) + + task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0} + task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs}, + "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs}, + } + + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + atom_types = batch["labels"].node_task2.squeeze() + atom_types_from_features = batch["features"].feat.argmax(1) - processed_labels = { - smiles[idx]: torch.cat([unbatched_node_labels_1[idx], unbatched_node_labels_2[idx]], dim=-1) for idx in range(len(smiles)) + np.testing.assert_array_equal(atom_types, atom_types_from_features) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) + + ############################ + ### Test VI: ... ### + ### TODO: To be finished ### + ############################ + + # Create input data + df = pd.DataFrame( + { + "smiles": ["CCO", "OCC", "COC", "[C:0][C:1][O:2]", "[O:0][C:1][C:2]", "[C:0][O:1][C:2]"], + "graph_labels": [0., 0., 1., 0., 0., 1.], + } + ) + + task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0} + + # Check datamodule with single task and two labels + task_specific_args = { + "task": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "smiles", "seed": 42, **task_kwargs}, } - for key in raw_labels.keys(): - assert torch.abs(raw_labels[key] - processed_labels[key]).max() < 1e-3 + dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]}) + dm.prepare_data() + dm.setup() + + dm.train_ds.return_smiles = True + + dl = dm.train_dataloader() + + batch = next(iter(dl)) + + # Delete the cache if already exist + if exists(TEMP_CACHE_DATA_PATH): + rm(TEMP_CACHE_DATA_PATH, recursive=True) if __name__ == "__main__":