From 7cf5b8001db26d2e7bbda08942879b4184635869 Mon Sep 17 00:00:00 2001
From: Jonas Frey
+ Overview •
+ Results •
+ Installation •
+ Experiments •
+ Development •
+ Credits
+
+Code for ICML 2024 "Position Paper: Learning with 3D rotations, a hitchhiker’s guide to SO(3)" Paper.
+
+
((t4#yfM6ELYGZsRykwz+3N)`O= zpBNQ^5|OjQZ1jRxW!%Ti7*$+j?S7~)esYF9q2UNEaAr2N0eb>drT9hm8jsZLg7;wn z$5fwAmwEN;hFst2b(P?K9woOFxTL6@@Z5^S1gt!Efo$Q|PyxD)GEc?EEtcIgBLFrw zNXCO4o8nOn3w7nEfCipGR6m>>b89AsIN-*sJ;Y!au221{;=J<6%kwsgRUf1ioU_L{}WjCEKOiYl$2+srV&(Jz0)#^*(%$@9~ zXlA%G =^p>f@#N8ut5`PV(h{-xo40h zLcUY3lmUtni6eK(Nd|ZiSg(rc9wOK?sqO4oKO8&@39or-9**d3 c|Bb>^MZ&KwAZZ_HB6i14&0A4Eu |73Q8P)QoF@Q(vA!7 z_A{4&s>fsNaS>}Z%bz$a|B|#1t>ZzAp`4(_u+^*-%| X6ZW5)SE$+S0yiKh< z*a+1OGRCcQUWISbd#zq=?(!9YI}ybKCBAbLu$cq5Upe9VhT-V1|L&@ o@RpC9?{kfSv;&TSqcN+fqK|lD>NMT4T{P& zzIz5ax~(}I<5!3KayZxMV|QVZM*uRlJt>K)NUXO1$BYnb@mq{qYZdi2X!$lsDwdO+ zBmb^LLMf0XdbJ&c@HBMv#k|#uVbDHJ@J!|T&w$Bzp5ccvB+YVhR=5c+=ZlintfrL! zsA~`S-7KH=&sUsavw^6G;c2FGDf(X{X*xv4bRFakE%U5&=BRgj&Gqa8O>H@Qg?TS& z)ST-)t6P?X-D%*3mjQD7wSr6iaUU+ob_sj{o$3C;+#NN(cuhRLMXJ87+L8)TyxS5p zk9~2!n#Uzhd1-zJaXwSo{~!vzFs0t|Y3}jDg7dQm20Od1AcF96NOE^KHx`e?a56 nxcXS@F7vnv*L!({7{GltWoUvHact%$)-S>W`> zXEvKoA&| >Ap)7Wq;CB0d88B*D9h)Mk%0|<|T?NGVBt@N;Y&_DApg# zMLigMNXZys72EduV^ZXd+s@1LMKustW-b`Xt{RI(-f7#z4~|plehbhx^d!CO;|KMb zDv+N`vVtOJ$xq<`g*%KGJo{|yr-<#J6{7=bPMi2MJD=^pF73)8f1Op{6vy(f-*4-E zO8R7^qdaX<>n`Ne`-G`qPaS{1Eg^q$2Tyt4#Jr@hoAHOdJ_TSGit7+MAo;>+?Xun# zk_9TLAdMG2Gz#)B6zEP2co#USO%$BJS4bf1I@p~IG5Nx}94gN$cnq2t=Bqs@3Y|*b zb=D_W?yHxGBa9vtjJt+y$%CKDPF&;Q=&MMkxCW5Lg(CNj!9aA#@0+)jJ8dTPf-)lp z;(f<>#W1_XBejbiGLoWxY}BUrV_lxfZ)5g@Md7E{$=sT@VVVLGx@??d){Qpskj3HA zEyXuxSWzP^&99X4Zex6?T#i*0u3)dcdazz~Lzjl(;Cv|Ga18n&3R!LTKXt&rJzRHu z*OpqO+`JURLHE7c9&mQ%$HU3c>T#u>3>yvjtpGpWm4P APY`{Z0H0J`-N)M!2a~h#pDG^B3Cp zwa$m_n6teMUeEhxnqUv^ug!&DAbf!4I8d}uv7I5gQe&J9wO&1K;^Z)h^fcslm(UW= zK`0}R+amK8L84AMd1gQuGvH0PXeSc%!&P66;Adj>uV@17Aem 81AKEgsLVP1 z6%75GtE~1M?>usVm%*HZ;QZx&NvD*UsRA_|Ij!0^J_@FNfJt3?v<@uWM}!jbX+@4$ z?M+RevjgGlD8++sAdOejHx@2x!UFuMq c>!DWIznvyx1oDY&Ekv!K$-!j;+lsc6UNV=_pn*t_WP z>j!yHSo-fI7E3OU2()B>$Psx2qh8L-xdSE|E&YC$!O^4W_1Bl8sd^jz7$h`tO}rx_ z;A?vJzO&Wf7q6GyT*mN!Ir|)R?zvs&8(`7W1=x*63km;-L`}B+8QlU6&&?spIq*!m z*qBHQ@w}>IS|!`bY-nFDoXKh^bBPd?P9EkQaCH9ERbBlYiV(CsGnOA;W4F>Uyge&2 
zd+C0#lyz}S52X3+@fLoTSR%$= IW{Gp(Cu5AgL% zk%gZYNU*QA*4Uk1s0q)1@TM CB6jE8AM?~Y&g5@+1J^%v9!+F%9*c`qD(Ny- zgn>0zjTDi_3*cUp4R!KcY6gep+BW-D6X*W3{{h?i*Q)B^C}tHlPn>3-eqFwf)i!F| zsh{ddx^@T;0jvh>@WS3Jn%CJ2VVwkG-+O63rY2*#k<+m6^p>#w*&Ig9 zBO#0O2d^7b)q>QPx~IKr-kTjp!sMWVb8FKntD_bsO6wG=nDrP1&)3ZpLeR0&GFZ9V z#f57_uVxKO*w;4}y%4O7V1D`)TL;;wf-gPAo?ZWw;xd-^`j6j;rY&Rh-|0 SK%@ zVBxE730P HL5i8+zxJYc!R}Uj1 zhEW8n5>9ghHHUP){tvHi<9V@mKF?s3&+$E^8KK?dBN9O(B;7gF)x-^XeN l(l{|H{TL#32B{ynTeVO*_&`#ccx~?wro@t+ZnDPk050NH{d}|{Y1aHjweos zmk=RWxu6D7xX^-9e9`BC=(p#pq0$rzXBAg0c{p%xXGX%T>Lybvf(k2O2h}wO5<6^Y z9VuDhcO@$HqtkdUq4GyWN8|H%G^M}Gp5tY#3PH6Ay~?721Wwa-v-aw?cD<{UI+nBh zcwVY9yhm!Br6Gz3->q3YB0Y}+kTq>TJPA412@VeP0SGSMW*J`rk-z{vN+W!D(n)Gh zRB-^(mZc^1fi0B=68%(VYe6mG2ONqrb|4QQu49^;P2r;_vf`6>J)9o+;5C5o^h6jE z`(&$q|3VNVg$N!}hl`v-@{cBqhRV}HO6*7?3#{ A{?8 z7#q&=ZMkwlcT Ed!0uoOHkfa3S?CD!ho#HJ%-lop2Hy_{pxPJ-nF!?A4`D06 z Y_3ecy|z<0i8B|{PpovIX9_~L(*cziQULmS2oDK73i&67ZEC|ziw<2W42iQTpO z`A>!vp3v (15$=7Vs~^R><&RHSTEW6hQY=u@(0YJ?ON}nN`v_wI#}i~ zI8`!C8xK+fjD_BK-G>54Wf%|qi2|nl3Q@>kOthWT_DyCw+c^< Wzs>`D@<77Lvo8C(tj#^fIVEG@161;DLqxy3JI@v#eZqW2E$&qaHN!#+~S z9(p64{T+K6AO(dVH?%yRxVgj9i$+5duDvLTWx <1wkUJLan8s!R(m*?Sn_bP?=4e>sox+>KG`K#%t1UBl~UY$;}KX9 z-?76Kb^JpCY97D84nz$CI+YCL2j?7DMITg`wr$7gILm<0O8$BW9ywX3YztV?jn^5w z=3gU!Co_ex1Cxg@Ev2<}_Pv+I8CCL9Gk1o^1MF7#MJhe3<>n1*y>mN+fqaw@hZ)RY zs4A^?$e=Ix0)A&c%(B9Kh&%2Jn!pWtFYlOI2%f{?fzdbrnUtB8VvE=4xhD^PQANNn ze-IVSfubb#EEf2*NQq?M(P= {btl@ znN@}Ru`*N#z3_nq8oGJsEs5Q5MMA9`>qA-VGBOF`j;0RS{H#!%XP9ii#LnOH3j=DZ zQ>4P%f|+be*B`#pd;5 |t1Nn-_Dox9 zNB =`sw0D57E;iurEY{K5WYhpqNR)V-5vOd!asem&9rM3kSbigqGm3koX_edOI zzRR8HblSvraMA!eB@n&9O*#%w2glU9=1;y1dtVnJd8{o$9GUXY8Zw-WN#~(kEmXWI zea A`~ oVGaHEaRN;dj<0&p$HBE2ONC44A4ROqEcZ~L$w5$iI=+0QF z^&D@7yA6Lwu*SFF6Sfx)?0`S?itUDF82QPlrjm;0p3sGl$_bK52|#zg{d&jGQn1@$ zcsGnzO|G9<3U_!_(y#E5ZF6O~&E>y?8p2JiXfVVVzO-=UIDlu~<7_M*kEerSG_iS= zT{@(8Q;nqlnKWMTPbr&7BI&KzxU?R4cM=uAmST3kSMUHueqjJ_w+uGZ$i+&T9I!{J zYe4+OoP=JWZm+YQdlk|sU61+WyUhj`HAIQDNb21&VS9(ipFPF4px&)dXMyO2F;!RT 
z`!I7bf8k^F+%4RlYRIG@%yI-#y7$)h9V_kwK)(>TXh1SMLS2KV)UbA(wBw%G6mIaq zr;^tWZrR~5>2?ehUJbw8t*8k*{4D7!(FnqbMW(i>W{=<;=}<&cD3DAH6x#W_VEU(2 zQ!o7?xL|l@nOMRcwTXzuzs;f>vg=gAHA0!_y7dc1-Fun3i4?|j9*mrt6fz_rshISt z;;ft;&wgE$w|F}60piD)W|JXZzVj;wTga_X3|S 95=RLLpKh~u7IGM2?tA~m |j*k#0(R+?Oe2$ltwdK* (X+BiURi zNGJ;#TOO5b*s&pd!PmdhkjILN)n{C@_Y%7@2=YzS+zEuqTd!JZr{A_H+z}$aseBVJ zxx {Qh5-#0gTU|dYGj|m!Oacy@D+CH{yTcS%2Yp{{^ruW &DtwYqpJ+(g1T>;yMNYem^BKuo7npo}i^ z0~1W9#uuu0sjUcG*hFHPe%aYgCL6`xq>QrM1*(gL2XQa%LGrP-D!FGECylTj{Y>bx zr}zr*N%9foOU{lA@f)^ycJ(}mS<_S-eS~WqJ{ME+H`~7O^tF_y^$Y#RnOJ%CRZ}SL z*&M$)OE>A|0tD2Q8;C6ZOqG0()S?EmdBCehNfPeQho#+E;xR*q%gdvpX?8|( v40D>M`lE73olTIY;B|_%^w!Q808EJ`?XeEmT)L zmK4WZAk|~lQ_pz6_T|?eGpr{^(Y@kn)5iI3nx(TLL2>t!UL*%{82)LFy&F+6_i)fN z){7!{{gxQv;eqgzaV}ct`FBp6nE@kZzYE6$tYSb3wg|TO!B)UKwL+uL`#TQG(_ij0 zv&$1u6z8WOJU^wLgCcP$VGffZmMZ(dJ-$8DA)rfQQGUg2h71IfJo&*IcP~dH`85yW zdh2%}#G)l`)+?;Ro?U0ts|%zCnfv}^s?HCjsf@S}&{CIW-s#6>8e|m(ls)F}7Q^g& zTw%(;BWx>4&f{iP?M9&op!?s@+ZmV?5#QG8<);vru?{DoR4I>ufUZF6=N_yXM|o zD}r{dgkGfw$-o)xe_aS`L@ALKkdKJw|Ex}}UylY78FklJu^^ua8TJ >&?^}LI$GQfY6o}wCYq#ubpgenqyE-j^qNnIAlqZJp@8V(AO z`&dURMx8V;m2av3`Pq5FO;{6rJ0#48wOzko0R^^?`N{-1vG>83V$juwS8`W>TwG1G z0;kNjpUK^@O%En+I)*l*racsU#S}yy4c$^$Duz~l9KXOPL8K9Ny*-HfJjRLonZ6Vm zl-l;kL-=-UO)X2U `_K(d1BWnQkD(a49uLnf!5s)}EQw&9xeSUo*I|5A z&z$(TMncHrb>V}8HCyIqLG#Hzlwp@D&vD~F}CKt zaqQEB`JM@amwqGsN1Y`kj7F7M6ORFUGVnMeu>-gAah|TDvjpm9uWwaL0jm?GIP#{{9f1T0tN2sxrM)*h< zp JdO?7{&fG7Cjh`eW)oZzz z0i0a0-5*OCMx@KWd=!%~(!0~M-bV(`fYe?Nwzk~PrY65OZt#?s5j|I_F0>@+B8o_& zO)Y^oAUkZplRkjY^!z4p9rzI)sd96_R|qtl{@&EhM?*!$fRyk^r>n-}+ ^I3EnrrJ#%`q z(^36&*zFl;Q-_8Pr7uKxu^0uS96ol1g_qyI@ST60L#oV}cayQmtY{ZZp8B4aCuy$a zOu9NK$vVa-J%7MfDTc _Lp1&d48}m}!l>0E(Qj3Q4-c2gCRu zKb>31ZC1l24}U({(LC3cek)RU%ZZc3%z;pqf}aVBV)70~oGtlPVd9K=1vg6j0TPH- zw}Vf4>$zMap}Qztoov3w!aZ_g-9*|2A;bggL&5V(Uo$xOb@$cL^bK+0m7Z{N&! 
D}W07%}*Y8yBn543&vX$<1$l?pio5>#;aHe5-tP zDS&q1HFbq?n4yv^s7p=S&GX*knahj(i|_-}`)RfLIBroj%^gnt7B=)l7YaeCbIgEW z>+Cw6X)Hxx*O)St9sS3L8AQo$(xzVK(v6fjwW7i6Wq3!Wc`eJrI~3}dkdZadZ *gfCcChKaJPd^@u|DiW>H=!Juo z^x)ayad~kjTt?JAiJqJ#@rsfMu`%%iRLV5JIFC`{EjsW{`j_o+cE5)5{Yg=;g#7Eo zRktxs61bV+E?5XN$ 67TulprFO5g6dckD4yXkp`~IbJ(yUfw(ug z_*C?QTgFOV+(pGi!4`$>I#c@ozw2u3Bzf+TWzgr{s) u5I~etPi`Y5X#+c_MC#bc6Ha(o2gDGV!jQV1Dnb8Z+jPU1)8#x1^EF7#fns*U5O~ zP7P+Swl3-Dec(xrr3Xw+ZXN9AWzeUrJ=s1y9ahc>)3aHGD?-QlY(p~E-3~eh_?!=H z8K58#L^fb^%!2X+8-NY)4V4I|A*nSFkr!vgHNh @W$RdXRYMCBwOgw8$wz5SirDZR{w0- po63 zY_5hA08nzD($hjUxoTfij*MRi`{pKWU>puU`{RA@9k=4QK={gtAJ9<-q6A9;8RmeW3@)F_ik+t~6;J2nt0_(Ntt?kv z!+cI)&3@X8B%&@D^C2kyh#kYcIyT +vAd$Jyn}prfPKoy-1|8m5gJEh-BqXjYWI#us5L(YjOh>SIOk-|+4U zk$o|>7B+2M2P7S%fWX>-6b_OL#Zz^9&s#RlO#D`P_pksTqm5+xWG9*p6~<_KUN^C< zJNL}`3AV{Ap&|_3dKV~vHW#Wh?E>;ey3uW%_AyOJ^n}?fMVs}Z2N6mqw%^R14|M6? zt|R*#C|PT7P2w}kQ)Jp3u|W$kz{-Hb+g7K3)|*F+)YiZ;q`10 1wv)pDWB70w z2j$+xG_zIfb@+TrzIAlgA1npbgT36R4fZ+gqmMlrLK)RWApImD^F$|)$S1K3X7$tr zO>0`nqj7*R2S0R1E;|c?B9KP@2HZV)^E8B=M?9|JQA%xmopd)poE2WI8(?nlVAVsK z#R5rqFcMh)Nz)+6NEitK7t!X^!YiH3TQpE#25Bpo(9w22?64eLc|~-c`js3#ZF)uZ zLL{QPK=J>$N)1+uyf#ds-HZpa(5WjN!6}E(tk}`*bJV0I#WK)|t#!4c`_>B Ydyf$kycMC}r&8n!V{5i@(wXfj0 zwacJ0A4bM7=1$DF;$mmCKGOTPJZ&E1=(b0)O#!UDn=(!MVE6gY3tCJGaQbL1g1MXU zEIcanaE&kR&*OMr-W)KFNq2vWKTs;+XLUHM*e5?MeohuR;kmZz=Dd2r(yhg;q3-&c z=3VH$#Q5e9li!g*tRKo2_8BUzj|(zcysSD>^9tgW9;tac$7`q(bqTedO?d`A9{nQ+ z?MB!0Y-#@|Uj ^g`D6?|;9c z&VLV7q>^zJVFayh{``=p&Zq?voLzvz>bH8{^&;-0+`PqrD$kwmkF2!-wQRP|hhC@= zx4M8&^Qitg>85GtEp5Kj<+|>!Xd;@nEc~cE!^ -CPnG%nF_u54A8a;9!z|+Z^E#7UT?n$;4 z%@609NEA%{tp1Ur)*<%kk{b&$-0wT3`3QKn=E8Rx{N7EL5ASm_Eu%? 
c*3aks$HLxHZmTC*RgM)rwdUIgV}~ zclKK#cWRs9H+SMGes<|_IN@Ypac&Qj4xQ;%xT_d*L(wpPQ`BI#&GcLN*qMoCT@5#1 z*xLAi#l3Y@9ZS 9#qH4*Q|jCK})bII(wfimY=Q* zrX7p##i+GeB{1QxLNj=pF~PlBDNK uq62MUYd>zJZBT&=AGu59?eN9#TjrR^ 9cOS4I4ty1V2q_ zC`orOlbCCP^gEd<2wr~_zd62A4=kNsKBIaEac$4qY?%%FF^J97En_x}!sj)_UaA_v z4D2IFNvkrtSVt>^B66D;Pz{^ab8+4MdEzNJ5W D2dn@*!aPgqV!_srB}fR*a1@hsB84WloZBF z5zvg3aUp`3tK_F~g6bN$@>WZGL@B=<69!5#SKlTQU@?|R&`>i_RPCB)vsVs#jn^du ztO}e8J%to{^}qDK2@JQ1MuTDxLh5leoA34AsQbn5WVU?rYn>8WTwQMT`xv`MABDIr zuZ4)h)90T>dUlvNs!(GCX|tD`W=w-HrqH^8R=7|dLa=I=ZI2zRwbq6D;TtvsA+XTQ zyi)DO_$X?xMCW!kUB}b2nO@kP^J#lWlH} UJFPq(_>tTAA46=G>@pAdUKK3NpA% z<*0tn+ar3`(Hc2;ZdY&HF<+KyQ-VaH2pJL!9&fAAs}_&VE+!kuytj0U{sGkrI3JIi z_Gz4xVN;~Mj55@E4k@Q9kU!O6q~XJmiHDBJQkd|L;$gZ1fkn|oBLpsNz6X{9{TT?f zcim@i-|0IBv$Xz7$J7Q_QYgo^hYlZLL28v_^KVeD1>hmB-U0 z+cld#yf5a}_ob}(Wi|oU0Wf2`MkCWr?G$|CM|Oe!LNMb1bDJIk^R-RnW(I=@Er(WV zD6y_YPGF_C6c+@m<~)9*MEYjeaNMLJQ83Px{@$GHXXeu}J$#ow*SO8J-vV@Q2vROY zt?n#jV_cQ 5GEliX08HM$XX3BIP;N79o(u4s`On-10}{&JnTS zLt^5j8k7f0pMv(gznwpcPNNd^vQ*D-OHdQ>H{3Q{Dx++Zpl9_lfdESdZs)9dDjxKh zxyBYUl%3v2s|5Kreut?{5v2LL;2V0&1vK}G49WgCfcBRaeichK(K~#+>^^*Gq3_?> ziZ6J(SXR%mK;{l?>|Xy#*0GCJ?|NZ7R=TM@f3mYQNErnN!5=))#!Gg3Gl4A7td47} zm!%Jl=?zdz63J;aX>=jiy!F49PPC)*gv9h&Y`;!%M+~c&Oc#Z; Jx3_SJv8*z6ZS)`qk6pPa5w*g+^nXqyTeJ-6b9PgYd0`O}yrag6NUPmF+ zzjd*QQ0`V;yuDXX&Rv0GkOa)Rcb|XxL8ZA+OFiDc8)y0<)IG}yxUao;!ceqV)%%*u zlxfeysiJowmfiv#2Bv-5CGGgttqqpPbUe)-wL=QfhMh}2mGT;jFNmlOa3MXr#{0w> z09L;}l2RoZ)h?IU!H&ifTPIR?0uR3?c`^cXWny3tBK4q`8zF9VeB>sj7yl^j);l-D ziV_*Sme^)F%e~C8wg1hAq0+I2t!_>h=fRmkkVyVSSwr$1RYG1Os^uW1vCnlc0qL7p zC c-C`VBayMzk+Z&3U$ck?!3-{3s-#Vw#vs9w44mO;is z0gEKI+I Wb~)AEZQ}Lr_w$ zZvu(#^>YTUAe^YyrSvmRfB*}JpjecmIxtc6SMb6x6Xkm0V(=1yxlMi z~as?US210^@ zy9+E+we^KtBqg+@*K|-oDYNws)Or#Q$tDXCuM$mVPcLEkhFdKczRm=-?Rcey@4e0o zsT8w%8<}y@M4OAy`FYhUt&5#A+6>9E!ww(vc2T<7BiO?`z)GX4_B0U^AE#k }Tz#mo65dnlIK>^*0z#T~I?$R|WVx^pj-QOfz59hjUfsXKah z8WDC^^jR(x&ffAM`V|O%UigWqaTA!{GEHR9+Ioy$@qR{rZW=S#)SpoS$KgF-aLo*C 
zfsd7^HA`feK1S+5^Nk}0BD1R5x}%fR)71h8$@UEm>(jhi#9*a^T0)IUYKV%jEWLU` zVh<1q?|{WjjmiDWD+T?DqrV$4P|9g_X2?^N8kIz zuZDx-1l@ZQ$oqERBI{lr1lOr@V1Xvqj}I62>#&P+lu?7w&9eOAfp+$@s@IQ0ftJK# z|3&wT=r!F=0RHmb(<#=OR{|jvRmXzaai!n@!}^tnxa}RYnB#q!R{u)!SJsjm>}Jxy z96p#y-x$UOi}^!$lhe*0QemTT1j3QpLfqCrh&sF%9^zpR8O^jIzrH46#UU588R$VX zPF6}6@FZMA?VJHvQB>abMwPUO4o@+if6_RrHu#B=K)4;@L(oFUIYy+`he&t3(Ewa) zk-d?87$e~}BLNxO(*;YRC1PiGL4RW$?i~4wds2i#t_)clwBap2VLRcpPQe&BK5_!# z7~=@;HP_f^aZ5#qsjPN`x8i^b>697Sk| ^jy_ zktx#)#oxi+Cz+xF?4MM5d8{@+E);e&L@^nH8`T8~4J #nGwu5TrH6Ab~5NQu)heaUj!Oqy `R{^>VTPRJImRZPAP&{dnzV%Mb0$n*4*89kSPh6!jdKuX=;Mz z-dT$&LZH|n1}dSr_BbT*GFRxy!axRor+$bNEa2UBjBv>BvDdoA0xLq1wrr|toC_mq z$3jHj0sLaK#sO%rlZkvSW5ppyoXLogEu`eHfM-8M`ILDmM^j5zeLCF)M44O&FaixVGw5#TkQ-PxTN9u07JhpX+KS9pKQ2<=bheY8KMFG#6=+9 z9!i_WJk$t>LyT3bXUs%kfd36PGg+tMQn;G)fC{-Nu|IjZi!!mZqUik>+10I(G?|hk zY-4JJnSQ8Z-wrF%0JAYFyAUqC+w-Ga;e0_e#AhnIpZIhHkY-Yo) oOU%x7)tC>G+7*0XLL0^y(otInB%tFJ~{7=^)4!!#{I4&BhNq84Ht zwCU5>gC%odTM;1*aSCVvk2b3GNZZ-znwd|mT=I+<(>p=00P @H!{HXMW`L<8r#8Jay!XDHuK>DSn?pGCh*wggVxz|r`A|Y9G7BrZ3$uJo@jz( z-^{tKPycE;GZ7(U(?)s+h{w!zW!Spz*01lKI~&EYGk52Z{Vt1X(~bPSV^!r$dMl`7 zNm(p~xVhU4Se$Ev0I5;@nF<%;9<7H?{mlE<
FLA)=O}Js0jn*2oG=?A%6#_=JWtt3Ep-e~` zvMAS^pxWzI_*pr>GW-Cuj3Np<>=aMy5$?r@pCZHY0ZzOnNV=ZSSWA;C+sSY#F$$Xk z#?}1dxGiSly58+5cJUs}RQCSUIvxbWhG^#Ig4l%e(C%}oHCuGqoWIsX+7S+cX)p2s z 7i2lW4P>mG{0QUkWQpFZ=dS-sN|Gj`*$A z*fa+Y6f^k|l9%MRQh;{8U))mdyt2{zEleyhx_EaCh^#7~Fs3asMPe>sDazI#z9id3 z)`>qomo1fu^>-+ggk0_Xi3Qb60;x1K{+tV$?=*N2Rbeujge{-+blQa?0uO@^9*#SU z(h{L!L}CRBOUAD4vA35anDbR%KgeRw-D$Co?hZQSs#j09FbFp3qWMV6AC6FRlN|>u zboRFIjJ$@iHl~1djRN4DX&9>YP&yV+X-ZrGcMt1V=i&91aKal|9M<#wx6>R_ogCtv z Ck$ z`~XTSvbe4L63tnl@Sfo9H7Buwl7Rn>lplRj^Ip~hS&N_KQjTe2FYckSTK^}ZLcr3+ zJF)ns-o9x2=l;24M&afXUz;ydp2&E<%aXLkY1P1aftut#Z;r!FHe{h!G^ypyvecY1 zxtb6w5+IEk!h&M6EAY(biFbrZBS`5`0`MP01%zvDgcwRV$32ijfcLN<^QR(ArBhF` z?> 001zAg}gj@%W3mCtC| z&MJL}fTL9q$AU;NCHUTzlRtJlE_CHkbtAV&v3_ed)0sZ`TY{F?wJZ+imX;fV@@&lh zO@5j5B^#tDpcBJ1ON|i7uukcN<=>$IN+ttAzsptLWJJCb>1wyfkKV6QBmA$$gD0EL zxyVJ76LR8lDn2TXvKD?ugj1Y03NV+n#PLrn)6!H6LC*+9bTV+{g`sD+rRy!H4jait z6ro_VMyf>4z9i?o{@CKo&%XKO^2I+Bovr8=5CS;oG%3alu6mAy`6{C!;LtNbz@V2l zk$BV0)M*+>c|RrqA#p!=@DmV?j)vE{l Ztie_1A31yNK6R2&Q9f$?LjsRlqf$jW&f_uY`|eI@1&;Uv zfY(-O{BC*jxR1!hD~vSRX^>@Vy7w^;GQZCX=ESE!gz)1Y!tANH;HBIYQHbo(I{~@+ zRR~u%@|_|@kA764Eg4wx6Y+RCP2v?%E_XR$6uGX=!y!x!#Od;)nG7|n 6_*-lfAS3~{ zR49MFl5JL=_oR=c4Ht<5kR%6KHO73A$;$eidBn)1!(KB@Pyl=B8Am=v<>W6fcXdi2 zkZs}vnh4`C)szjj)||y;s|ZsN7# JtH(u9# z6{Nck5qNb;KrrZ1jGxLq0hJPciW`0>QAEgVguH$!;kV#?X1<41hz}dmp-b~pSZMv{ zA6W!Kv#LVPD$E66KDm Dt>Cz?f_&(ou{{-GS2D-xmAv`eCL+for;~8MjDmuXT1a-U(O8YH@_IqzBsQ8IUOw z-m%SZhf9LyShOr`!t-B|=_7=5kg%>!Crb0Bt3a>cGJH}sK~f?oki4&n*ot?ZnQ%s{ zUveiNXYPPS`BXC+^`)r!rc->#pG(3LH$WEX vy Z-F9Ax^lKK2pww^r!Lk z;6k73M8u==ci|%Xy7=foknh*0b8tlyZ|Vg~wJIlif zHbJqram~l~rmbBSAZ*FFTT -1^!j6tW{J9R>vs?;YxZFM olnFdxL?2<9@2%r92c0|2y3%H zb@!6#CK<4f)KWWs+E3a<9Sx&+<+0z;=nBJafV*=gci#_Hk^Y4PPeP;FXij0Q7FNH6 zFMQ~k5Lgtw?qYj=NFtfMxe~!1hu`=P83%gf7)L%b+3Ungt(urrI)3?^oT1V~Cy|3o zf{IKkh-5^hbWH1@qG1h-_*lSj6NZmHHXicL17@w3->Z$B$~j`JJz%onqOA~b7)&16 zE_{6k#e8}O>z1z4+Pz GpUMRKp=e{$7QyQO2i+PY@ qA@4)w~RU&=SrQ 6ic z 8*}~x sc}og)alSe_`sY)f!TduY$5z~RW 
LU(9e}!D43OyzaSu%Y-wqhBtU4xbPTv zSR{|U`%w?$b(-?>_!Pf%3*H$ZI3fW0zPYXqNiCMDc%N$aD)ScRfxkaV=}kPRVh~k# zfmLysMI~+gXIEGBU5{5iyust2M|V y6?_-i;kUz#LAl4xSELxCsc9ZMWk(`EX2sd~` zkg_i+-g}2ODB~V#O3Js~Z#8kDnFr^fSD+i>7D8C5ZaGy-obi2nM-Ze6G>lWFTvQ6X z*3pl2D+_As&@B zu8QzCD%Z-th%8bQr)V2a3=kuAkKDFyvcQ>V>vm+W(i307)B=h-iA3$kt)P~Q)ZkTJ zPMCNSH&W^S`BycowD+(cMa0Jh6dTLzsD^R8-m5xS))F0ei&XVnTeFqgvKMz8<7d3# zFG?KkP5gzNX3GqJL+(1K`mBldHUEA>!Usn!;!=9kUhY(4PJ6rzdU$z6g3PP3v$LI8 zt1gYA>&u+tX$7kLv;0Q$0+^YsSIf;{?t+e_v$8fM`DlA9O0LL+SLRZKMG7G8BjHWIBTCmL`NOqy{mR@@&oFLK^OQq^pU9%- zhaR9%uer3vraOzPJ>_FeW!K-KB^8wd`*e&)-|+BU7CqH9RL(rUKiF*v{1Fr(ixYZ~ zA*`&1@V5OD zLRLw`4H4o-foR)g6ztdNviM*n6QbQ#s#-zvFv3Vj#7G*Dn2+YL{?3M<3%QnkY}*XI zXAz7m+kPC#hAAN(H8+~kl7n!SnMRr;Dy7Uh{OI<$VqvDUVp-=-78E}7BuUa6_w;o3 zu41@k$a2|xyMK@&owc&g^$!ssWs0(CHtFHYYHJ?%*08c9xggm>)CdSSg0(w1$fH4= zCW1BN`Vn90ERXOes^s0aj}9OnL RZez)5Szf5bi2_?E-<*fC#m52FtRyN!mLIge4LJ(R%p_Q|p>9 z@v$bAQA%k*=>7_+(i=!qEw}}>%4FNvn5dyhTyWa01VydTECjh%D`v{z zAHt`Xad%;-+jgQ^8#QlP=iR(a(*-%eF!)|g uTW zdUEF0kLv_mT0gf^` 8Y?|Yl80XUt>Dwl3a zl`GZgx9aD_g=cT Nhao$8zEc^NtGBw2tzAz2?W~(iCkh-?0eY9XnLs8 z<*yrh)AYCLbA8L2RS=3*^ovrtXCf>$4 CIL*ABVsvg{6aJ=YCSuh=~3Zmb5S|BY50 z9j}FU<9?iMnrWTrl0JrRWjyo8q8E817246-Q?e%~Eo6l< lO zZ(4;&gd5}&vV}f6HaJ%1@MsW K7e<6oRB3b9}9rV@0~DE`)&q%s3e>IV>2uBaRVf0Q) gHa y=KEg<5%_M-~yNfm8k?h@IIPCql(YC(6U)@=vYGV;JRjy8p z`Ig!yzA>MQTbU#d_4IYM{1=#xS{gfNbM-eXc{%->_gDr!4vL$Q >lZq~;zS{c-b;O=djXt`FL)w-a}`l+p_)1_ zUb8crbOxOj7BCa!O?J* JylL5m^om zI*5V8MzFeM7#!xY(SJu7p36Z(HzHA<_Tk$lY%|uO(@1TSbV(0FIu`J9gKZy`v6F%$ zl GCpl7bghl0VI5UraG@q#I;2lMu z(IOPOWQqJJRu?BRgEW>25>M}&h|KgnkuUl*L-w8QhJBDUaz>^PR+3aX5!@XZN9nsI z1QK!=3!0;px7`T5biy#_^=A5%^yGTIgw$^S;xbd*Q4GNAH1^>$ 32-Cv7; zI*6Hz8~(hY^uhKbYE^QW-oOY$1xLPp|9!bcZMbv`xHtX6Tj!E|Sgf0}oC{N~qSn@* z?%fJKl4=}9HBa|yb6(8OXR MrCA NN?# zUp)c06_pz&U $kxZzyrpF5T>pB$7e#obv#KyurLWL1=nz z#$AW}ECK%;5UDh&zfXbZgt;*kJZ5xK>LppbR!hCii4k9HzOaZ>IrzZ3X%xn~N%= zf&~fZ)D5LWx`gZw&T8-gUa!$hH{=|$x$=s84?xlPSc=mW23%XqUxmCX)! 
^cHcpaLpsDBU7` z_T~1_HmVjFGbOq`Z|1euG1uG`(V?{RSny0jFjsPZt!JoOVNR=U4TCF24yuqnsSs?A zE*x(#7W;qM2eu1AS9A1ZRlIF21+Z&08Ir)oLy^8g%o|$M9%LGMk(aivxQB?twNL<4 z$Cmln7w;wNkN51jlYJ)YZRiLPLb2^5A@kH>jhuH6b_w+ C{6Gm&~o zdj3Td71+LFF4J B9)xc7= ks?Uon*pr5{4JIY zTi|VQB!1+WwObk2ZSn5DnAnhpY1Lz@?TFY8Y#wC=LcW2s03nJZ$}9!sMu$w!Sl`5H zNt3k8`-^74Yi4Rm_=J)Q*F&o}B4jg{h)wAM2f>t;Yh}H$qCSI2q0cW#2{u3TmnqEu zt=Ev4VOQ<9q1II5iGY9wKpt-wnzn~Wrh6|5yA}zk8rZW_T@}~8vy(l7pqAcluN6vB zMuM_RGF3cv1U%9be0Hcz1K&(7)*Aqh1(&+2a$^A&)2OY5UGpm-EUw)yK^^|LK((?) z@eA;jL_Hy*rbil-Kch+@`h^zFtPos@uDG)mJ|;#+_Bk#Rr+qLibT&Gb#j(#GN8P(n z@3pD`IhZ9+eJC+ce|6l w6JKz*cDafrA6)gwDW^Romv2P)5w-AC* zM-0@PB2q_ProuV@UPr1C2}VxBN{zy-S52u@ZQ9yWm)S8T#fdmNP#}t XMuWf4TeETG-QPU#;9(tl z3k}+jl7 K x Ppb%OI=)}zX!A3Zi)>+e5if*tWP|Ax7M!-^5!@D{jx+zb-Ai<+ z|L~c`ZG}vqx|eIN3%HHLdj(*fxVg3h73K`$Ja%TjAvCgpML0lnWG^7y4AyB0I`EcM zIY!)#Zs;iA(QD0YGdnzJbspVcq(XiE67!j_?<8~-TQTz2Qu(C+E4@Ks4r3Bz57nXB zNtve1HoUXoWE#-7-BzqKPIP=VGRHjg?u{(>?pp6OYj1~HgR3N{COZm1x$)wOh_a4c zROGrieT*^U*D2Jt2YeCyLiLi7o-5w-JH*c661Tt(_BH*PqI;Bf)v>Dz!2$8?H|H;E zTR6J%x!i!l%u$;;MowMDLOIz++kLpK#_AFE#t`f4%Naur)m%f0aSMj$WujVy@Gf?B zqaD`%A~S<6CdG4!m!()^j+s-=qW1l9(DVR0a`xBRYRD_UePoPVC(o(NMI9+zgxMz0 zqO+)rP0D$`o4?2?&g1NX4v#n{W9+gv_uvgJYg579*j3E4n)i`U>#48F0Y(=*VJ?y6 zqi*TI!L!ElnH_}>-_wF@#ZaKtY#7gX8fan9oM~nKq>>`OtVr{=Z~Kd1#AEwsH3!QN zC2qfe@AL6_w76B^Dd*|^{$(}^`6OW$Z^wu>-|4*}_Wdn=C3KDFFsFv`lWGH@?vBf$ z5(eWrjZWDE6=jePp2uyT*WKEp{`c$kC7RW+^$5J2DK=@m7WUVV37(&yJ- H#= ztkf?hGXzGl) >n|c{`UhvCN} JafYG~26NB6_*spEv%_NwnJGK-p=7;K`8gTB^k@jed@u09Waz?-4Z zWZ1c%#Aufn-E?k=A8(6jIPERAPMAcCoVL>SHbg4X1lXvsqtOSboRbrR887-o<{w(A ze$g%-?lt+z7PHkCYi)mQ@0+d2&uW*U%LzSnc;MOU_9+CRNkSqkvkWEwby-IvA*S8zn z(m{Ev`>n4_%w^}~TR)#EztSCt^&&gZqqhpOIDCz8{1h*@*Iv2%kaN+e8T2$PF>b@L zLDhvl*0eNa%AzymVZBmg3gtL6e)u48KO}Hg1Z}D9bPKniseIKw$o%?g)Mu>;RZsWm z)_^CsP8j``3tqmbwDJ1+n$~%K>&@J`$k2)HZ9n%sof$5Z5;jMfd~7;QYA O^n8LvhPM=V{sdd5^{zfV=j szqHZ*`C?R08Jju%{)j$lS|(`Vx&p*1?;8(3P>tyX;Pku*n*rN~#c z8}E~mdz?ISLgyvPu2lNKBtg=i@=fK{ySVpjM(|IfO;2Yw0+mz~`qt=KRbu{Of;NZe 
z=xBY@kcTQI&8h5m!C}c;2hJZT;bHOAv KXfNJJ&8Dvw8mg zyi~<( Fs$Xiuf3AXh>{)ko?8>jzqlB>661Qma1cs)!Qd9k{>Vj+#mqJAkyTh4 zDz3gqCzW*TwqeZY)(?C3A=;+OPT?F6MGAXe-`?}gU`>y)8nx#OBZu~DRrFfd)DxwM z@rEIJdfH|Q63p_=4tZIGt}epSQi36N_@(cL(O0@&v?bW9R~u^3IcEN=^x6t%C8654 z7e*Ayu^GpGi~9^o=!O{${W5AE&duEkhTy>&a`sIAcfIXI)TvE0$)B8>?afoIX~gpj zF2rxGK3>KQgoWEu@ibcuVY>~@&b$8l!2sQ%gXfO8b4%tepfAkLvGuAA@k}UBrx_K~ zN3vQo!8J@6do`QhWMBd@W$JsZKa~_0n;$KBo>5bIq&@oFC&VS{7^lu|tS`+ts45k6 zW(eO0IY6&d1GZrF-uY^T$*-7v<$d~2i>_J+d59HupwTiYW}G!vg|Hg^Wt}m;wNl{J zKS6iCMpa!$GiXf&dHcNkVsQ4|p!$RA*3VFR*2Lu5LY}P$bJOqkgpRMv+v|+vk;&J~ zk{D W<5TcT1Pn=3hrLltlv;j# zz2Av_wtUFM9(?$QS6}kILcJRlaiRHE-b>DaylkJ97(EuIo%5<_^MTXO!A*n2dPyss zE+`D8|7lGxM}G!1PHe9d@ I?%db-=V z%6poULEo7QG9F)X*ur7dK*6g7k7ZgTT=shYb(Fy^^fd>3m}|>Tx>*$IppcMRvL+O- zzSyAmnTCbZxixFhM-RQU*Q5VrF6>IT{gNTmU&UsPGH?U$8FZF4YabuyV^K8-s-^R6 zH^32mQ!QEUdEtoC+(&|i;#~dny>Y~IB! %Fe#kEUpSV*v)#mIU+2hyLIXK9B(Z9 zaohBjWD&oMTWpgM+W?JNPBfJFRM9!uUXYL&(^zn5><$-KVZbJ vm@o5IR=5qERi3^GOW5s6K5+tP;VV zI!zw*8Dg(YXVOv)kA~zp^a!sGmLK Z8 z%8K@(XIx6zbfX!x#6ETT2GlpI+^*GLG^=SWIi5k;7@MgrLX~4fXC8r(THoBzs`E jb0f*0D_F=x)fVza~-O~oY6 z)DVuIN*)iM1b?1mm~d298?%&YxIeuW#~zN`-&WmT{~0=MDAj{^_%VCpeWV9XOa3cz z9gc?(p2r!Tilt`b4CfEjb-@(dw!4qbI!fYI(j$AAbgj#gB z9|x1R>zgglsHBo0GQDR{S)TZ1i3?MO!Yr7En55bL%__UXcvP;ln`5ni3OmQ6>Ar3o zM3&=D KJ0@k7s=gZM=^+xmx9l+;t5x0eRX%%e+h zKw)}sKDCzwPd|QXxAzvw3l&tox}MD~iDH_Ko4olz=Bl3eWM;xptL?p3-7Bcr3hJut z;GR>Gr4ctsaGCu!`5+QTNeJhbOteBvN$2ddPHRef=0XZ9$nyTVJ_3Yh5^s9yf(q=V zj=)wZncy?}s%9aOR#d$oK9H@JyB5e+V4#QREhik>3L21&7dXrc$bwn(`FUv_Rx7{I zlg2ZD6L(0KQM5{wjR1>{p0e6xX7=vhza=I$=VPBM{dKHa1Yt$Bqy2*8Y2eCxH#Y2_ znVhFOtfei!OUXejvJ9g9T7yn<3gPWP)gx0wF%`4YEpZ2x@zz;<;Fd?#@;m2@xAnuQ zf}?c;&WukKsK$qgLNt!DKI3>(att#nB$cbobC=)5AYr7oP}uGEM$8U}poJ)0Ez{~{ zh2c!!&>Cm3^;Apo^Gv(E!E4Z&T~u0 |$8QB T}S#TY}KAr zah2gm yxJHCeVdk zy?pOi2Na|huZ0otY@ZxS-PenGtklsYk@<;1t -El0m&1`mcNacm-Vq6}dx@QoHZfkOC|EAz0SgTS9JfbGbE`W18RrB$ zb_O%NB2x(Ms~BD?v2^WRbqOsdi{DO;ZrJwg5p$1xZo9_nw=f6j7;pE)nDArPioXWO 
zxPI!>*yVKNhMnx;qdZA;uc&X|T(qBfzPtItyeCoi0nfpUR MYN%4{98 T@p!K9GnjR# z*{0 ;2q+elVPKl0AQ03Ypv|!=YS^owQM^ zp&2j1ZfwCX+?LV$Op C*yx|vI3w{^lX39{PW~LG#%&xoTeABrr#J$e FQ`v{VI~PXSsOT-*;xOEL}K{e3W7FPdS=$YfuQLBf_$QTIi+D^pl1dF6loY) znHcE-_$L|$Iu_vPU%BXkTmW$z#^0a+{LuPMR~o>c0!sNOKo$Kz5vu-IrC?-Z05O36 z`OLqdng0ERm#Y8S=|C5}07U)SIRB|DkPduk1#B#B>=kVF3;>v{f5A`veci_1&;h6% zD=Q5h3s7YcJ~bmN4GS9sGZPy=D>Ds HVFa- ;SWT#Kk>dmOf;;ljI022Q!{~R z*y!2lSb;&tLc_wsLJtgYY6f~5CMFQ;-(YBeBa^+bIPiJtgulhepK1TO@+Yj@Z$P&{ zz}>jH|G4`1um(>5;?w`S-NBTe5g+*bb4dU5PTGHht-TcDch&xt`kV6qD+=4om45-T zy(Ihxi|rq*0u&SY`x}pJQie%Zlveo%?I<0SGWqBl_$%TKNa_EUzrVZeF9FQ-w+Q$r zc-tT4mC|#t_$@2|AUH`UD>G{w2Q$atd`c^V{~~ez1^)J1`u!!z{vhzLwx?Bip@&cZ z8z1jC*_W38{o@5$@AnTrF$H{MJxd3pf6AtRFx?+yU%>Mme&?4k`z;B6v+G|n$j0df zOYffq{xh)vz =E`1CaNAR0P+a#Ke~TL*Sp0NxKsHL=&TH8nGE zps}$xp`iGqi1tQ$jyCrAoWerF!gO>@Y;<%WA>iPrqhnzNzF!V{;5Qxc%}htf0DLn7 z=L|rezrTURmvjbJAP;b`F#UM|X)iPw`G9Kx?Uw`e|3&}h*+4!PfFKBXC ^Ij%8w*M+0Fs1)@K4vyx(EU{wM|(Xp zOC$T=V2cW7?nW>D@ggeZZEOJkf9coX5}a1d+8AK-moe}gs?otw&))HO4>B@=m{{OQ KNQ7jB;r<`eM<)^h literal 0 HcmV?d00001 diff --git a/hitchhiking_rotations/__init__.py b/hitchhiking_rotations/__init__.py new file mode 100644 index 0000000..646fbf9 --- /dev/null +++ b/hitchhiking_rotations/__init__.py @@ -0,0 +1,4 @@ +import os + +HITCHHIKING_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +"""Absolute path to the hitchhiking repository.""" diff --git a/hitchhiking_rotations/datasets/__init__.py b/hitchhiking_rotations/datasets/__init__.py new file mode 100644 index 0000000..bf3db62 --- /dev/null +++ b/hitchhiking_rotations/datasets/__init__.py @@ -0,0 +1 @@ +from cube_dataset import CubeImageToPoseDataset, PoseToCubeImageDataset diff --git a/hitchhiking_rotations/datasets/cube_dataset.py b/hitchhiking_rotations/datasets/cube_dataset.py new file mode 100644 index 0000000..e1ddaff --- /dev/null +++ b/hitchhiking_rotations/datasets/cube_dataset.py @@ -0,0 +1,48 @@ +import os +import pickle +from 
scipy.spatial.transform import Rotation +import torch +import roma +from torch.utils.data import Dataset + + +class CubeImageToPoseDataset(Dataset): + def __init__(self, args, device, dataset_file, name): + rots = Rotation.random(args.dataset_size) + quats = rots.as_quat() + + self.quats = torch.from_numpy(quats) + self.imgs = [] + dataset_file = dataset_file + "_" + name + ".pkl" + + if os.path.exists(dataset_file): + dic = pickle.load(open(dataset_file, "rb")) + self.imgs, self.quats = dic["imgs"], dic["quats"] + print("Dataset file exists -> loaded") + else: + from .dataset_generation import DataGenerator + + dg = DataGenerator(height=args.height, width=args.width) + for i in range(args.dataset_size): + # TODO normalize data + self.imgs.append(torch.from_numpy(dg.render_img(quats[i]))) + dic = {"imgs": self.imgs, "quats": self.quats} + pickle.dump(dic, open(dataset_file, "wb")) + print("Dataset file was created and saved") + + self.imgs = [i.to(device) for i in self.imgs] + self.quats = self.quats.to(device) + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, idx): + return self.imgs[idx].type(torch.float32) / 255, roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32) + + +class PoseToCubeImageDataset(CubeImageToPoseDataset): + def __init__(self, args, device, dataset_file, name): + super(PoseToCubeImageDataset, self).__init__(args, device, dataset_file, name) + + def __getitem__(self, idx): + return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.imgs[idx].type(torch.float32) / 255 diff --git a/hitchhiking_rotations/datasets/data_generator.py b/hitchhiking_rotations/datasets/data_generator.py new file mode 100644 index 0000000..ea2af9d --- /dev/null +++ b/hitchhiking_rotations/datasets/data_generator.py @@ -0,0 +1,67 @@ +import mujoco +import torch +import numpy as np +from PIL import Image + + +class DataGenerator: + def __init__(self, height: int, width: int): + xml = """ + + + """ + # Make model, data, and 
renderer + self.mj_model = mujoco.MjModel.from_xml_string(xml) + self.mj_data = mujoco.MjData(self.mj_model) + self.renderer = mujoco.Renderer(self.mj_model, height=height, width=width) + + # enable joint visualization option: + self.scene_option = mujoco.MjvOption() + self.scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = False + + def render_img(self, quat: np.array) -> np.array: + """ + Returns image for the body with the specified rotation. + + Args: + quat (np.array, shape:=(4) ): scipy format x,y,z,w + """ + mujoco.mj_resetData(self.mj_model, self.mj_data) + + # mj_data.qpos = np.random.rand(4) + self.mj_data.qpos = quat + + mujoco.mj_forward(self.mj_model, self.mj_data) + self.renderer.update_scene(self.mj_data, scene_option=self.scene_option) + img = self.renderer.render() + + return img + + def __del__(self): + self.renderer.close() + + +if __name__ == "__main__": + dg = DataGenerator(64, 64) + img = dg.render_img(np.array([0, 0, 0, 1])) + + i1 = Image.fromarray(img) + i1.show() + + img = dg.render_img(np.array([0, 1, 0, 1])) + + i1 = Image.fromarray(img) + i1.show() diff --git a/hitchhiking_rotations/models/__init__.py b/hitchhiking_rotations/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hitchhiking_rotations/models/models.py b/hitchhiking_rotations/models/models.py new file mode 100644 index 0000000..81114b0 --- /dev/null +++ b/hitchhiking_rotations/models/models.py @@ -0,0 +1,53 @@ +import torch +from torch import nn + + +class MLP(nn.Module): + def __init__(self, input_dim, output_dim): + super(MLP, self).__init__() + self.model = nn.Sequential( + nn.Linear(input_dim, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, output_dim), + ) + + def forward(x): + return self.model(x) + + +class CNN(nn.Module): + def __init__(self, rotation_representation_dim, width, height): + super(CNN, self).__init__() + Z_DIM = rotation_representation_dim + IMAGE_CHANNEL = 3 + Z_DIM = 10 + G_HIDDEN = 64 + X_DIM = 64 + 
D_HIDDEN = 64 + + self.INP_SIZE = 5 + self.rotation_representation_dim = rotation_representation_dim + self.inp = nn.Linear(self.rotation_representation_dim, self.INP_SIZE * self.INP_SIZE * 10) + self.seq = nn.Sequential( + # input layer + nn.ConvTranspose2d(Z_DIM, G_HIDDEN * 8, 4, 1, 0, bias=False), + nn.BatchNorm2d(G_HIDDEN * 8), + nn.ReLU(True), + # 1st hidden layer + nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(G_HIDDEN * 4), + nn.ReLU(True), + # 2nd hidden layer + nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(G_HIDDEN * 2), + nn.ReLU(True), + # 3rd hidden layer + nn.ConvTranspose2d(G_HIDDEN * 2, IMAGE_CHANNEL, 4, 2, 1, bias=False), + ) + + def forward(self, x): + x = self.inp(x) + x = self.seq(x.reshape(-1, 10, self.INP_SIZE, self.INP_SIZE)) + return x.permute(0, 2, 3, 1) diff --git a/hitchhiking_rotations/utils/__init__.py b/hitchhiking_rotations/utils/__init__.py new file mode 100644 index 0000000..f9d48af --- /dev/null +++ b/hitchhiking_rotations/utils/__init__.py @@ -0,0 +1,3 @@ +from .conversions import get_rotation_representation_dim, to_rotmat, to_rotation_representation +from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles +from .metrics import chordal_distance, l2_dp_loss, cosine_similarity_loss, chordal_loss, mse_loss diff --git a/hitchhiking_rotations/utils/conversions.py b/hitchhiking_rotations/utils/conversions.py new file mode 100644 index 0000000..70a105c --- /dev/null +++ b/hitchhiking_rotations/utils/conversions.py @@ -0,0 +1,118 @@ +from pose_estimation import euler_angles_to_matrix, matrix_to_euler_angles +import roma +import torch + + +def get_rotation_representation_dim(rotation_representation: str) -> int: + """ + Return dimensionality of rotation representation + + Args: + rotation_representation (str): rotation representation identifier + + Returns: + int: dimensionality of rotation representation + """ + if rotation_representation == 
"euler": + rotation_representation_dim = 3 + elif rotation_representation == "rotvec": + rotation_representation_dim = 3 + elif ( + rotation_representation == "quaternion" + or rotation_representation == "quaternion_canonical" + or rotation_representation == "quaternion_rand_flip" + ): + rotation_representation_dim = 4 + + elif rotation_representation == "procrustes": + rotation_representation_dim = 9 + elif rotation_representation == "gramschmidt": + rotation_representation_dim = 6 + else: + raise ValueError("Unknown rotation representation" + rotation_representation) + + return rotation_representation_dim + + +def to_rotmat(inp: torch.Tensor, rotation_representation: str) -> torch.Tensor: + """ + Supported representations and shapes: + + quaternion: N,4 - comment: XYZW + quaternion_canonical: N,4 - comment: XYZW + gramschmidt: N,3,2 - + procrustes: N,3,3 - + rotvec: N,3 - + + Args: + inp (torch.tensor, shape=(N,..), dtype=torch.float32): specified rotation representation + rotation_representation (string): rotation representation identifier + + Returns: + (torch.tensor, shape=(N,...): SO3 Rotation Matrix + """ + + if rotation_representation == "euler": + base = euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY") + + elif ( + rotation_representation == "quaternion" + or rotation_representation == "quaternion_canonical" + or rotation_representation == "quaternion_rand_flip" + ): + inp = inp.reshape(-1, 4) + # normalize + inp = inp / torch.norm(inp, dim=1, keepdim=True) + base = roma.unitquat_to_rotmat(inp.reshape(-1, 4)) + + elif rotation_representation == "gramschmidt": + base = roma.special_gramschmidt(inp.reshape(-1, 3, 2)) + + elif rotation_representation == "procrustes": + base = roma.special_procrustes(inp.reshape(-1, 3, 3)) + + elif rotation_representation == "rotvec": + base = roma.rotvec_to_rotmat(inp.reshape(-1, 3)) + + return base + + +def to_rotation_representation(base: torch.Tensor, rotation_representation: str) -> torch.Tensor: + """ + 
# Adapted from pytorch3d.transforms.rotation_conversions — BSD License:
# https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    # Row-major flattening, so reshape yields (..., 3, 3) with rows as listed.
    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [_axis_angle_rotation(c, e) for c, e in zip(convention, torch.unbind(euler_angles, -1))]
    # Compose the three elemental rotations (intrinsic convention order).
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])


def _index_from_letter(letter: str) -> int:
    """Map an axis letter to its row/column index in a rotation matrix."""
    if letter == "X":
        return 0
    if letter == "Y":
        return 1
    if letter == "Z":
        return 2
    raise ValueError("letter must be either X, Y or Z.")


def _angle_from_tan(axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool) -> torch.Tensor:
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X" or "Y or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """
    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    # "even" permutations of the axis pair keep the sign of the tangent.
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])


def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).
    """
    # Validate the convention fully (length + distinct middle letter), matching
    # euler_angles_to_matrix; previously only the letters were checked.
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    i0 = _index_from_letter(convention[0])
    i2 = _index_from_letter(convention[2])
    tait_bryan = i0 != i2
    if tait_bryan:
        # Middle angle comes from a single matrix entry (asin branch).
        central_angle = torch.asin(matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0))
    else:
        central_angle = torch.acos(matrix[..., i0, i0])

    o = (
        _angle_from_tan(convention[0], convention[1], matrix[..., i2], False, tait_bryan),
        central_angle,
        _angle_from_tan(convention[2], convention[1], matrix[..., i0, :], True, tait_bryan),
    )
    return torch.stack(o, -1)


def chordal_distance(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Frobenius norm of the difference of two batches of (N, 3, 3) matrices."""
    return torch.norm(pred - target, p="fro", dim=[1, 2])


def l2_dp_loss(pred: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Returns distance picking l2 norm.

    For each sample the smaller of the two distances to the antipodal
    quaternion pair {q, -q} is picked, since both encode the same rotation.
    (The previous version only flipped targets with negative w, so for a
    target with w >= 0 the antipode was never considered.)

    Args:
        pred (torch.Tensor, shape=(N,4)): Prediction Quaternion XYZW
        target (torch.Tensor, shape=(N,4)): Target Quaternion XYZW

    Returns:
        (torch.Tensor, shape=()): mean picked distance over the batch
    """
    assert pred.shape[1] == 4
    assert target.shape[1] == 4

    with torch.no_grad():
        # Always consider the antipodal quaternion, which represents the
        # identical rotation.
        target_flipped = -target

    normal = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=1)
    flipped = torch.nn.functional.mse_loss(pred, target_flipped, reduction="none").mean(dim=1)

    m1 = normal < flipped
    return (normal[m1].sum() + flipped[~m1].sum()) / m1.numel()
def cosine_similarity_loss(pred: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Cosine-distance loss: mean of 1 - cos(pred, target) over the batch.

    Cosine *similarity* is maximal (1) for aligned vectors, so a minimizer
    needs the complement; returning the raw similarity (previous behavior)
    would reward anti-aligned predictions.
    """
    return (1.0 - torch.nn.functional.cosine_similarity(pred, target)).mean()


def chordal_loss(pred: torch.Tensor, target: torch.Tensor, rotation_representation: str, **kwargs) -> torch.Tensor:
    """
    Mean chordal distance between pred and target, both mapped to SO(3)
    via to_rotmat for the given rotation representation. Gradients flow
    only through the prediction branch.
    """
    base_pred = to_rotmat(pred, rotation_representation)

    with torch.no_grad():
        base_target = to_rotmat(target, rotation_representation)

    return chordal_distance(base_pred, base_target).mean()


def mse_loss(pred: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    """Mean squared error; **kwargs absorbs the unused rotation_representation."""
    # F.mse_loss already reduces to a scalar mean; the extra .mean() was a no-op.
    return torch.nn.functional.mse_loss(pred, target)
class Trainer:
    """
    Trains one MLP that regresses a rotation representation from a flattened image.

    Training uses the configured loss on the chosen representation; testing
    accumulates the chordal distance between predicted and target rotation
    matrices so different representations/losses are comparable.
    """

    def __init__(self, rotation_representation, device, args, metric="l2"):
        """
        Args:
            rotation_representation (str): identifier understood by
                to_rotation_representation / to_rotmat (e.g. "euler", "quaternion").
            device: torch device the model is moved to.
            args: parsed CLI args; width/height size the input layer.
            metric (str): one of "l2", "dp", "cosine_similarity", "chordal".
        """
        # Every loss is invoked as loss(pred, target, rotation_representation=...),
        # so each callable must tolerate that keyword (all metrics take **kwargs).
        # The previous code bound the undefined names `l2_dp` and
        # `cosine_similarity` (NameError), and a torch.nn.MSELoss instance that
        # failed when called with the extra argument.
        if metric == "l2":
            self.loss = mse_loss
        elif metric == "dp":
            self.loss = l2_dp_loss
        elif metric == "cosine_similarity":
            self.loss = cosine_similarity_loss
        elif metric == "chordal":
            self.loss = chordal_loss
        else:
            raise ValueError(f"Unknown metric: {metric}")

        self.rotation_representation = rotation_representation
        self.rotation_representation_dim = get_rotation_representation_dim(rotation_representation)

        self.input_dim = int(args.width * args.height * 3)
        self.model = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.rotation_representation_dim),
        )
        # self.model = resnet18(weights=None, num_classes=self.rotation_representation_dim)

        self.model.to(device)
        self.device = device

        # previously 0.01 worked kind of
        self.opt = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

        self.reset()

    def train_batch(self, x, target):
        """Run one optimizer step on a batch; returns the scalar loss tensor."""
        self.opt.zero_grad()
        pred = self.model(x.reshape(-1, self.input_dim))

        # Convert the target to the trained representation without tracking
        # gradients through the conversion.
        with torch.no_grad():
            target_rep = to_rotation_representation(target, self.rotation_representation)

        loss = self.loss(pred, target_rep, rotation_representation=self.rotation_representation)
        loss.backward()
        self.opt.step()

        self.loss_sum_train += loss.item()
        self.count_train += 1

        return loss

    @torch.no_grad()
    def test_batch(self, x, target):
        """Accumulate the chordal distance between prediction and target rotation.

        NOTE(review): assumes `target` is already an (N, 3, 3) rotation matrix,
        matching chordal_distance — confirm against the dataset.
        """
        pred = self.model(x.reshape(-1, self.input_dim))
        pred_base = to_rotmat(pred, self.rotation_representation)

        # Alternative you can use: roma.rotmat_geodesic_distance
        self.loss_sum_test += chordal_distance(pred_base, target).mean().item()
        self.count_test += 1

    def reset(self):
        """Clear the running train/test accumulators (call once per epoch)."""
        self.loss_sum_train = 0
        self.count_train = 0
        self.loss_sum_test = 0
        self.count_test = 0

    def get_epoch_summary(self, name, verbose):
        """Return (mean train loss, mean test chordal distance); print if verbose."""
        tr = self.loss_sum_train / self.count_train
        te = self.loss_sum_test / self.count_test

        if verbose:
            tr_str = str(round(tr, 6))
            te_str = str(round(te, 6))
            print(f"{name}".ljust(15) + f"-- Train loss (mse): {tr_str} -- Test average (chordal_distance) : {te_str}")

        return tr, te
# ---------------------------------------------------------------------------
# Data loading, joint training loop over all configured trainers, and export.
# ---------------------------------------------------------------------------
training_data = ToyDataset(args=args, device=device, dataset_file=args.dataset_file, name="train")
train_dataloader = DataLoader(training_data, num_workers=0, batch_size=args.batch_size, shuffle=True)

test_data = ToyDataset(args=args, device=device, dataset_file=args.dataset_file, name="test")
test_dataloader = DataLoader(test_data, num_workers=0, batch_size=args.batch_size, shuffle=True)

# One loss history per trainer, appended to once per epoch.
train_losses = {trainer_name: [] for trainer_name in trainers}
test_losses = {trainer_name: [] for trainer_name in trainers}

for epoch in range(args.epochs):
    # Every trainer sees exactly the same batches, so runs are comparable.
    for x, target in train_dataloader:
        for trainer in trainers.values():
            trainer.train_batch(x, target)

    for x, target in test_dataloader:
        for trainer in trainers.values():
            trainer.test_batch(x, target)

    print(f"Epoch {epoch}:")
    for trainer_name, trainer in trainers.items():
        epoch_train, epoch_test = trainer.get_epoch_summary(name=trainer_name, verbose=True)
        train_losses[trainer_name].append(epoch_train)
        test_losses[trainer_name].append(epoch_test)
        trainer.reset()
    print("")

for name, trainer in trainers.items():
    print(f"{name} -- Best Train loss (mse on given representation regression): ", np.array(train_losses[name]).min())
    print(f"{name} -- Best Test average (chordal_distance): ", np.array(test_losses[name]).min())

pf = args.prefix
np.save(os.path.join(args.out_dir, f"seed_{s}_train_losses{pf}.npy"), train_losses)
np.save(os.path.join(args.out_dir, f"seed_{s}_test_losses{pf}.npy"), test_losses)
def test_roma_quaternion():
    """Check roma.unitquat_to_rotmat against scipy on random rotations."""
    BS = 1
    NR_SAMPLES = 1000
    for _ in range(NR_SAMPLES):
        test_rotation = R.random(BS)
        # scipy's as_quat() returns XYZW, which is exactly roma's convention
        # (the old name `quat_wxyz` was misleading).
        quat_xyzw = torch.from_numpy(test_rotation.as_quat())
        out = roma.unitquat_to_rotmat(quat_xyzw)

        # Deliberately best-effort: report instead of raising.
        if np.abs(test_rotation.as_matrix() - out.numpy()).sum() > 0.00001:
            print("Something went wrong.")
            # raise ValueError("Something went wrong.")
    print("test_roma_quaternion - successfully working")


def test_special_gramschmidt():
    """Check roma.special_gramschmidt reconstructs R from its first two columns."""
    BS = 2
    NR_SAMPLES = 1000
    for _ in range(NR_SAMPLES):
        test_rotation = R.random(BS)
        mat = torch.from_numpy(test_rotation.as_matrix())
        out_mat = roma.special_gramschmidt(mat[:, :, :2])

        # Sum of *absolute* deviations: the previous signed sum let positive
        # and negative entries cancel, masking real reconstruction errors.
        error = (mat - out_mat).abs().sum()
        if error > 0.000001:
            raise ValueError("Something went wrong.")
    print("test_ortho6d - successfully working")


def test_special_procrustes():
    """Check roma.special_procrustes maps an exact rotation onto itself."""
    BS = 2
    NR_SAMPLES = 1000
    for _ in range(NR_SAMPLES):
        test_rotation = R.random(BS)
        mat = torch.from_numpy(test_rotation.as_matrix())

        out_mat = roma.special_procrustes(mat[:, :, :])

        # Absolute error, for the same cancellation reason as above.
        error = (mat - out_mat).abs().sum()
        if error > 0.000001:
            raise ValueError("Something went wrong.")
    print("test_special_procrustes - successfully working")


def test_euler_pytorch_3d():
    """Check our euler_angles_to_matrix against scipy for the XZY convention."""
    BS = 1
    NR_SAMPLES = 1000
    for _ in range(NR_SAMPLES):
        test_rotation = R.random(BS)
        mat = torch.from_numpy(test_rotation.as_matrix())
        # Uppercase "XZY" means *intrinsic* rotations in scipy (the old name
        # `euler_extrinsic` was wrong).
        euler_intrinsic = torch.from_numpy(test_rotation.as_euler("XZY", degrees=False).astype(np.float32))

        out_mat = euler_angles_to_matrix(euler_intrinsic, convention="XZY")
        error = np.abs(mat - out_mat).sum()
        if error > 0.0001:
            raise ValueError("Something went wrong. Intrinsic")

    print("test_euler - successfully working")


if __name__ == "__main__":
    test_roma_quaternion()
    test_special_gramschmidt()
    test_special_procrustes()
    test_euler_pytorch_3d()