From 6e79a3d9de69af1cb282a731454c0b3a8076c280 Mon Sep 17 00:00:00 2001 From: maskomic Date: Tue, 1 Feb 2022 15:33:15 +0100 Subject: [PATCH] evaluation update - minor fixes --- MGMM.png | Bin 15900 -> 0 bytes data/results/toy/toy_results_collection.bson | Bin 0 -> 44780 bytes .../results/toy/toy_results_names_scores.bson | Bin 0 -> 632 bytes scripts/PoolAE/PoolAE_script.jl | 126 ++++++++ scripts/evaluation/MIL/mill_results.jl | 11 +- scripts/evaluation/MIL/mill_results_table.jl | 7 + scripts/evaluation/toy/test.png | Bin 8367 -> 0 bytes scripts/evaluation/toy/toy_results.jl | 4 +- scripts/evaluation/toy/toy_summary.jl | 63 ++++ src/evaluation/plotting.jl | 8 + src/models/PoolAE.jl | 304 ++++++++++++++++++ src/models/PoolModel.jl | 6 +- 12 files changed, 522 insertions(+), 7 deletions(-) delete mode 100644 MGMM.png create mode 100644 data/results/toy/toy_results_collection.bson create mode 100644 data/results/toy/toy_results_names_scores.bson create mode 100644 scripts/PoolAE/PoolAE_script.jl delete mode 100644 scripts/evaluation/toy/test.png create mode 100644 scripts/evaluation/toy/toy_summary.jl create mode 100644 src/models/PoolAE.jl diff --git a/MGMM.png b/MGMM.png deleted file mode 100644 index 6a6913f06140f5c120e7353fdec876679d801457..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 15900 zcmd73byQVtyEnS@Q9+OtX=#uKC8Q)4-O?$bbfeNOAtfLpASEH)-Jzg>v~(j9BHbO| zwS3=w&i>B%&OUpbKfZD1Gln{N7Hi(~zOVSjHA5aNNn>GROq_Re)s3Wh zWU+D$fe^*2^K;mks#VS1yEs3^#>STO=j7*iJKCDF9xb{P_U-A@r|^5&=;*1rByeMm zWsH{j_Q>1F*Ep0#c}DYX;V+lE<9VzJsp+`5xOjP;+avCB^ZbKAq=wEM;2G3Av(eE> z9Mt$+UdYPIo}g|x_NIsj86FxMN`CjQY2@9zcQG+BH8jjjO});x+dYL1dB-ZPM?ZWZ zX3$j^b5;?+4*JwiK_EkiPfJVd<+6jl>GV^(oJiLYC_@8mMcvC$}Pc($fT zU#m0oTt_%|J{d%|hM9)m4wZWfi21mDNgx<)F!R zI1jgjHJ#ZeFAw<3z(Dxrl0t$21kGdrBE?UbMC>+6wuQQtGy!{?vm$J4io9)KzEJb< z@bL4;x6bXYjSr?CD8NB`^xp|k`g>u}_H=hQG&U~!GcYiCogQT7zKXRtKiM}mH6`t8 zX>Bc?LLeL^?a{h?{Te+Ey^mA8T2f`OoJdB)x(ulxKpn4(S%@k$o&|~%SuqMUhrdOe 
zP*|U;eF>XttYc(kwCm)8KpcqTHpL%oOrL&FS1#1I9g1sYEYRil_)=!t_2CZ&;xmF8 zeO5tIQu1K>NdrTLmSODqsj=+{=@rC`5!+gYG%R@fKc_@VWJzWe_f(y?=}FKaY2@v& zZ?N|8_NiKrxEEO;)^8zFC6TyQ+$6FMWkRzySkg86ogJ|PwVZ=^2+^?DR5orC6M9Xn zn`uo9+dAKk&FZcoEaplgaS;e@EHQKhqTn?ZoQ02e>3IN3N8eq|N6DG8?yzoTQ83$r?N`3MOkBd`L^#FvqAm9 zrbg#kUHxO!tw9uk40&H_22oc1-+5R*udAOQ3;A4}!C{hi)#%%~y4K|8GX2ALrR0hq z?&rr6J`c`M&P})kG51Zsx$N}OCH)yMzsL3DO2amktk>x|B^}QLvXh&mPs&b)oyu0v zuJ|e$QKAYM=?MB23$4eGammQ?IaTTKnbv-HadcGE+ub_2jtGFDBIaCe#Zf5RsCFr_ z5qdXUP|%QZO$`0B5R|{$uPdD>eoe}G^(f=)##`P%kis$fUc?X(aXhIIeQF{u?qB*m zdX^D8h--eZR!mInR@DuAVfe-6*+a6XuI2>dU+XzdYmwQRqjiT8xPEp}WIt0itlR7s zH3yv}>@ibw^&s+5fo(9o#8#T!mo%ciG}&f8_*(JS-ao|%h>xw%4Anmk%Ekho0~+G- z+U{njB-hfupyIRP2KMun+_?-QSub&=Gg`BWS6ZT@F{V~lR>BuY^AUFiu_=W3`1xzy z#Js%D-_gta{k;e|4D<~RmDTP+m|H+KcxK06vKH4efeJNWetGte)FMa!VZ z;prum@x#q#*X>`OF)Wtmilh|L%n1n&9)s#t^0bwYTuD^)b4L`dhGJn_ z+SB!kN|E!uXd8#mzE=_-+Gl;Ty;D&)izO-JDcn$pF=lMXm63DAA4=MZ=?`VG=aZ9;THv<&jYx3W)-t=+~|$U#{!^ zt=l0IG{?lb>Ury7Ora4~`=?**l$c{M0+IeF8;OaD)@SO2_yMrq%F4=}FYxsq9v*=S zIXOAb>l5e4zf;Ju??LHcW?otT@u^dfn>QjN;^Tee*EaoAQ&XLt@P!&bEp+x0o@^g` zVoT?{{UMBOA3WZ9FGGh+gOBJ}M|&xDd)p*%J{YmAbKrCI#C-GS4fK^Jzqg}%Qsm6f z(Np?HNA;L6kra{lD0lJ2I_CO^swhOpdSo+IZ{NNxe!5{T(Uqmb!BIRhG11b39RK-A zI&z}I^3&s&`zyoujoTb*^&ynax<6!O(DB@^sIERdP7Djfy!!~6u;34TR1p_{ee3w* zbn_-P97U7ETuX**G&;3%itv*<$J_!2DypAbbFI+Rm}4#uChadSE*u>JT%=Mw$yPqw zYQ-_{O?vL^Ou^@1WMfkh6of`iCnDn0_wiAg-{skF5hf<4qQBQe`5eTYqHzsKR~NF?&b%a^8+f{x2?-@m8BMMtOh^xpM&5BK7HH6vf3qohF$0=wAB#mALRXHrl&BuG4rHMtl7};WWHIk+ij&Q)NX0X z?6-iR;A)YWEefjryi=>a(`eH@G%7>J*mro=oy0a2r_t+GRGfeN--YkL)ihk6q(~{N z)%fI=mX-zv2DZq^zuYY=d*JmD6OoF3Z>V)u;N`*8IFlET!IShuj_(lJfa|&)uOm|F z3-JE0xN%Gx9kEunmslBueRn{6q`&{OD3=jB6zKkE+%maezkY?B2#0v zYcBS_G=KKA*c+qm(7jnMtp6%JTk5wO6?ILTI(K<@L$C3$NmF)q_RZ$YrcR8$>R=6) zyARRW2;tdu@q?&{s!K*ztY26SlKiU&Q@Zja80d&OIt(sP~nL(|?`2|&&PKH-iXR@b^Mcw^#+rnq?Pq4nrg z&GS#&G9eNzi!b=emFA+cXzkEqQa!4m=FhgO#juSM;^RZXMEButs*mlT=Jh z%v~YZt@Fd#VjJhnwwJ$qd*Zpwp;ZC!baQi?u59v3 
zN)dxJb|`0bbSA%hlT}nyNbl+D=!_rrQfX>xE-o%&@O=CBO*)c7Q9+@itSt1+8~qX^ zF@61{fm|ML?gC&JyGuP~UrI_iHDdNsD_GroK6ziedPUNbhhrKaBd~b9R-JiRUjkiH z1@iGkfHySjQUrd{LL+wIEl;KHZlV!zk-iXtsaZ8SCrU@My8d^TvHM6s`Gcn0%%96#k8!Msa=B{wVz(=|t zZalJ`r~n>leH+yXG&D4JzJ!KiP*SEW@lbjnr9q>hrek0*gNDKk&CQ_6^Z3`dFrMBx z3i2NM00;EavNHGN@vZ-}w|VsR*6J5lu{iLNV`7GlCdKjY?C#!^degzdY9*v9r!r+0 z=CW7Qft6@%Xl~HfHZk2SS#8f6WNc@5As|06FreDiM;g>>`jd?|U`s*KiJu>PZudA( z;P>*;c|Ug0?uiJVuJ=0y1x15H`9dasi!ao%v9aBrGX!V5Jh5__a4uot8prjVmIE2K zlT~*J2~kWgB_}1t(9$xuVP~XJk5=l<Q!PfPiBFUqoac6PXxhQ56>UFe7ct^`qazTHmQ z)!7LZADx?f;iXByo8LBqCz-uONGo|_N_bwJl*=jfT#8{(*x z6cMKt^;~sQ3W_GjKLex3fZY)WJtATjc6Kvaiiv93N`XN^ENqXQRtQr14)=#0d`w8t z(1IR$95BAQg^eAX%t6g%-1dfq%dEi@C07UxE^r39%zC*1FDGKxw=jtN__4`-e+7Vp zMXeN2nSL5>{w3SFmOxY@x3;vvO$hPu=H};H7ol6fygXSgF4U{m)z*f={^ zNziO>Z5@p^U)&)e$cVdu^;u+|IsU}HD*zhBt2s^h!G}>g^!&WM_Xg};0H2&M{y~%k z$06qbMfstGAfx!x_7Kk&` zuGmV;LD+BJ3SOwjpEDTHx;(eP5dp8pxPwgN6cGV%mD|*GF>2^D-1TJWN@`$dcQ-k& zncu+e_hk9Z>}Mr~^jvIL7%lJv0KBWYQ=~;ltI5c8v;<-T1L%8fh}e<}j*RrNxg#2= zpk-*bvl|o;O!5a^R3)0KyR&mZjSR81RQ8;H=E-oJgDYXcPf?-%$T&W2U}nHD21jl- z%rVm=?d|O&7_x*6Y;6mh60ai&RlI-LUOzrQ-rn98blLDc*<+GaaXh?cPnRY-_3|ZS zpIB>nvc0#bCL;j^4d>t;gl`=c0s1wFgK=oTe+ILeLNTzh4N8oz7$-yd<`i|q&C1AF z19Ab91r%vW_^hlfkf&gap4ITfe%0}eOKYpBlmbO5dQuNegbkR?*I>MmkdQ=P+XuY7 zQW${Q$a24vK7IW7v7OzZqrqeAjYG!HU~6}p(sARW4z zs%n+ZxT>}`-_w?BL_~2pIgH~^ADmcWLxkHwgpZ6kpX~jCegKzRmID9tUlQ8dV-WOY zq@*8(+|r8)QO)PmS-=S_>H=^w>0Ex=RQ zrzPO#Z3YEKM#iyvmrZDFP|pO-d&%1p{z3T0?WflX(9((<#{(-_3zq3$>3jf1Jq&72 zg~bmLdzS5T0@E2+_Wlf#c9D~l*E#-a5fwpvR*H{tLQ}=QEBI2?OUKJ;I=zpO&9Esc zzUWqxrWj-EoP)#R&CIX(xnDnc)@Zlx?`L8SO!yHkW-g)tA*;GRi{r}t((Pd#FP&aW z;ICSJ=Nx5JNI?MC;=934{E0-^bWVXqS48r%RO{vWF2qS#ou zc^|YurZ)$ya^x!?`7p5I;kb~1t4Ru2R8UgH#l=mV3Yi6?mT}s$*#A{e?GzcqgYIOJR>@GG789Z~LOzO`odwRCECj;fgZtVx*z2 zBX9Pu_CJ7405CH-Hr8=~4%G@GTL4fGC+F7BFaF-%-fwu(Mae#5&E{q&o{;9fqA~5R z_y=*4wveL|nyy)-&m^L1_QmfyKR-Wgp0GIPk1S5{~eM@g|Zz zj_mH;M<6#sUHsbK?z8tpnjAX__r?B-21vnBHrf(V_YK?mB-MPnQ4hU&w$90VXW@H_ 
z=T3;tLTeCCwsK1E{5G)-WTdY_r+@$07#{aXr zhB`X6uG`O$GTz?8k?k)X9ZgCe!CKsza9V!%r9_hRKVV`CX`QR7t^Mw|HEG}62rv`! z=HE}&Y2V5XVD#+ljLW>2j*hNFS}B414`?KS5*7BLoQn$5i)m|VxvY(qLQTdZ=PTLX zmxT7l5hVkikmO&hl=CYAfi^8IO*{x&&hRm)GB4#}4c#jE0iX<+nVIG0=3>kPsOE%ML}U_Zr*Ma1fSzp^&VP=hHwLi zVk%#YLr}0DxR;B|p`YJX5s_vT+x}M~uq=vMn47yScE$jXND^?$tEm}lYs&(W9q>VA zNr_R72z;Avu29;{&=5B>>I(gC2R|atIuk;|S1l}<1V%lyn<}ZU9`zoer>BSg8+Dz` z0<{T3X8j%}oSS&3e@R|mURs*i(~QTAp4-0yV+S9(ZW(1OB{!WPOn(0SnN6=Mhq>P8 zvZ=V3V?LIL=xj{$4b%wxP#VL zcZBND29nV{Exx|K(2Q;)y}@<}4n_yv{#Xrq&y_1z=H})m6u2Kg9G#rh^TvJ#Q19i7 z7X((YxCS5<3kwU-gmP0zB~_F^3dt%cyvGs{5C9}%HTe2UI(n2`W@kz|C>l&wi6}* zf1g{yCY~Mdl8}&;eF2DCWj(qvOgbnMIy5}Y-jAa~{Xe%d2r9 zHULST8Of0Y&PQc19e#*rbieH=0X+dmOqIH!VR76nXK1=Q8%?Qpw*yxIE)q{8EEqRE zP8m<7F}MOg?s&SCcf29*be=hak-1ZUDFAqx$1H8fD`0iB*e98!zNY4VOw4z*6cL{p zpmJ^3Z{EDAqoV^k0+&{?K@T8r&<0`AF-koh9gG43y3Q;)u@SA*d1Wbmzq>}$;!Uiq z@_Kxvhkm}`z6@H;)YjH!wEf=Qt-ul1G1j3yO!I(u#Ezrhu71e@c;ovjI>IogM_g=z zzjAaX07e9$lWYVzfX1dZj4w?Y7Mx=sq5Eswwa)A1u;*c=PT0Vsh|W#nX^Ug)dgpa$ z_k+KFJwpbIQpqSOZ9xB?zY0t#Qsm-8gbv_`hWBxNR_Y!U6ckZWQJYPt8`&w|buODT zQ2$Wj7Vs9BCWOY5+85vY`jneq6OobWX=!C!IiV`6X2msXY80^q%?8jKCOZ0uhxQ+B zOh6O`IgWt2-f0y$oDyI%1Y+gK9TZylkuHfr3+B#`PtrJnQBb=eUU3v{A&qPH_Vy^g zP!F5Zld`lsA3RH?Bur-kfk~7M(D>5P{ynvMo}u-N;5?yr8R>#gk>dXFt+#Icm*SGc z$+a9*_*7JIhnpht=C$5kL-d1yk;CnFk_|p07jLzuAp^UYQZNpHPQ)^~AFg_5Z@24x z=i8QpPrTyy^~uamt{_tXIrzUxO#jK6{NLw_{Qt_L`PX8&vg`+-V+Y|-0z|~a!~6L0 zBSxPpffKl>pn;Q;knC^H&i3_1Hlg%)^)&17l8TDG_3GIs=z8w%?nXvGL8l3WyhYJ1B6&|OLAlts>mz(PGbaz#Sm8&=y*_laCi)le}rTLJS#wP<`| z;v@K4rH?}KS>X2kz17xGt{N1+vUd0bvJ~Sww6ih-x&Uom-A+mfe$J30o|dySPrwu7 z#t$D>zkB!Y_3InP#`NX+xw$&!W<3x(8j5nvF{sJ|O@nu=%61ZzpGNPX^PtEk=Yx#4 z2SrM^6KEtD37Ea0t-MNSyxS`;kA;Pmr#nzbdAFJkRdpbL_e%qU)tNOagELhB)}u)R zdcv@ns*F74Iu(M*x) zt%_v~Bz6$(C5gta2QvW%P~b@r&v8V~9`_7gyDDaCnz{BIXpNFUgBNXx4LkKx73oUW zuHyqicw7a?>uLG<*3X}pu001WC||QMp%^<y} z*F{oGo}1HEJ#)|(8Es`XQXJhK&ffxVb!*V7)E$#5E zSCRLDMtnlT$|9G-sBjAiiUQS? 
zh#}zETEzS&9k}EF`24+2maq|9E9{q3RGAa@{Dq3^-wq=z)v}&upAH-HD9yPby?x&1 zyD2Hl&g9+-*Rbwb;mOy^_Y`pRZDUf$d*C>$(Q}Ub!uUeRwsYY(DdRAZul7hvYp&$D zKx>fhxz!ALKx@FbCH4IIa}A9&{8I;X$8)uJciVs)1=`l?KxUZtCHp_Gs6fKL($dl*6-lwauuuU`5d0RX zsPqgB35kgux)mKDvcZfBsL1UROa>PoU<5`pC<4+xU5Hk;o2n5*f9~Y8_@WkWBP=XT zNJw~cq;?A%4m0(-#dKZe+ACO5dHG8N)b%kkGO{5R6hUAs7{%{q`@VqE=&a^w@h=U$ zLmcFr_pz}c7Q(f(bn#LvYi(_9srtUE`{q6NIC&KUFBs5)HswqBbK!e5pkbmA2^OuQ zwTa67>S`AY3k!goKr#c7bKtIm9L@LaE4nQMJWdQGZd8({=>Tr*G_EqbEY4IF`U-1(H@!GEQJ(CC|$h{QL2R37z!YPK2AEQR4;GwWc9+-fb z4@9CC*iA^tci!Bn}RavT`52 zd@PXZhaWq(=iA{vl@>o}SXfvfMv~+hQ735e6v+C*=+C7WwGjBAIhBbWXPUrx4kiW$ zX)`mhmeIHne=l8fI1EwD{v)~Zjg5_+f6OyfEkS?{3JS8}6cp4ufQ=cPo^ISGfD#UB zR?FjhHjMwst$o0mD9c4fzrkGKJ*6az$*RKM zQ@44PkIQ{(IjO97U{kqWVQgf;FH7Q7smNYenCa__Kz4vup<8WN=+Ke#s-w3Tc<;uK zPeDy&mRC_}Mk1-gJDuDh)p+uRsv<5WY5FXEazd~4ndE=;2=H53^{Nz{iB7f3Uo3Wl zLD1at=sy1K+bAJ^*86ll##|ANoCgD>-DBXk?crqESwDZL_<#kksrEK5?hy7dmx<4D z`SH`Ib1ZwjPAfw&4$|C$X&DH?!|4b(^I|2x7*&-wHiT`)N+5N|k18rEUj2jS?&`X? zI=TtPzAOsDKYH5H5;Qi{Fj}@^;s^*!!d@p};`!q4{K-xMtLhCoe;H_g&_L7#>g(zR zj)xLXVU;j6lBJw71FiwI5L~2E65o$T4;LJqNMPB=@fVjTqs>q}OjsW*F{1NLY- z+v!S5N`fNV$qYuuGh16YcZ%D$Kb1)CfqnwyA?1N-D0Fjwe{uL~$L5ZTe=U4L`qhrB z5DYeg+y$D@F}>7HYUtHHhm#;hey8{l9}h>6n-U1W3VtSBuwcmC{aHTphy*Rr@!j0r zJF|bTtT5Y)&y)uRNdv^1*Ft4hF1RfGBi>eb-j?(y#L@J*&8*hf5C zDGM$p(16DudE7ZVMP0X%puU`e5yEEBpqNli5Z6%|k}R6`hPz52=s3f1-N*MW^gt@^9@bhfwu zhC3+A%H|dncvhN0c*0CpBxFqN`}ZgZ`5x#!kXQx=Gr)DA=L$Hj(6K;`9)>7WRaNDM zE&^t2o7xA?r>GfOl(q%+3D_R!t=q!_KW_dl$*9l3qHxG~6&MLR52``F2Ay&m?id&Sywu$Lo*cKgQt5vZ}l+&`C=Mm!C|kB^Q}V?;YBO$wxv3O60d8wRDM zF;@iuRu&dH4Q_7oC#dq&5B2Ejw?N5+RmU$=Q{+U`X86!SL|L4j74$9CbD=dZTaOg4 zZA{mLF>uQAtGmXPQf-SyFg&yY3ojkhyN@H?Q%3kZ1W`+dj89VIkU`m)de+?-1rkVP!Q^@L?CAAWV_uY${ zv-0;#?wF8|m_m;pC8mfn9a#fJ-%=RbTqUnaH++Skx7?F!eG>lS!u3Jv z^oQYmO(rki)Zfh;SVsvJ>-pX^y!Wfff}EGvOfs9gUD|6u>)0pq>*a9BCwt66qCs<7 z>00(hZO{KSKmMOLPij=<6Ml#Z%0A{W0NLqmt-LQOIXQ7uzs6zi@?!I{Mt>uS(i>G% z|6|ISr{xZd=Goq$GQhgRoE&0KBT*Q-M1BELLJC|!s728Kfn)ALl>B>8gv~ktTq7VP 
zvNAHuGmWP(?hIAd2cTPUaBv$B*+aV|E^{SEM{eqc@p5xWT3~mM+#pM&-hf|)TmV4@ zhDc}w9Qcx8$PqA{V%ickGD!Hcb{KlB_dXZ6`*$*}J&jTB_iaTPnc#?s-tXU|u-FY6 zK({-B9%`p=Y-Z*RJO)ZEF0#m=5nOl!P0cJvD^N2HYhTbuwnt-4jEyZX@M?`xnuJaU z;TF|(N#P<3>^CUaCPN381>{vj0&rVw9q*>z$lbHB3o+Khs!q+IQ(vF&WWPJ3d13V342|yo;IwP1o9!z$fPPf=7uB zkNc^<6FD8IT>M~TWOOzA_3inVx(MFR(F;R^z(mDpoJ4I_%^H={%(F6gWz44J*YXL& zi(2)Gb$kNMK8EiG%O#q=e3faf?6gtE4ybzm>&39wVJ5-;>>5ormCJr~dWO29yX0uZ zllQ`U1_l?YvC>oL!^#rxJPvKYlw&K6giG@Ve``~0K6MFY-dR^*ohnFFOuZ@Qmh9D7 zS)aJmyJ=@WBbNN7~z^ zv(NJ$swAdIE^qvz)ZK4oEz~{nU1GSEc!fRv&jdkFc72!+@lTq8yi7|+c=5sW&@ga7 zoaatIegO0^pq^kNmR(jXEi5>0&OAwL``2p-z=F!I0+WbgBP9cp9ULJHR-whX^z!QJ z>gHy%{j4xjMojDrXh1+h<38FNt3b@(2dKV1&X_NbZ9jyjy}&ubiEEQ~IFW zFZ*4`R8Ztey7l=D%=SRRfR&#{D+~XyzXz`nfM^kg0|+ae_&zf7yn~LO9-t~a5U1&B zFDMnj6-Vv4pyJpFnL#Z$IkyRuNVji3KLAZ6;H{Yd4X=yi-=0mc>#GOX;j4kq`kWm- z2T32cetmKAngm-3sfW+i*VcON_B;fKsr&Fi6*qum4sZl) z3rNB@UYz*ABf-6YcgN4KtdJ5BJ%q6|Fs*X=@dB?9NQP62%=GmYgE-H?kX2N)0bJ~d zvB(QdLRPRUG~L~4T7iY4PQ-eYeLU5hR+<5DsgTgwYH>5l$k5b`ZJMP1mvd3p4%KY7 z(W5#$J8KZvvI-yvY%om0Jg=Ps!6{0n!mNk%(E&{Sp+E^l3efMfI74qhvLN-<4jysq zm}O*V&(=I|?e0dwlW_m7@;>3N@84m-!v%?K0bd)i7-&Vr8|#aYx8B8a7z$es(0RF} z$nj6;gQorfoZ-MCKk5z=Vyn1`3%T#XdO(@doejrLdkiwp^FITbFiEA5lyj-( zYi(@}%%cB|E9iLhzkX$?=6WZlF*5V>k9T!hgWdwb7e+t&?M1?2 z!tO5r3lsR+iCqAUVJ!j3hS4+VP+)*H&nO?}6&J_ZpUye(0kM7ka@G%oxCHP-kB`6K z0G|5%ckkWvs*I?MYJxAT`T6-_p`j&cevpZ!9wLne=JcJpv4f+-XQqAU&$Hjkrh|*h>59It_h&a$;gW2!?as zmtbE*Qb{6iId>u5#!?z zz}N{a-Q3t%O-t(@C5*RPCWXDeX&u0pC__gINTaZD9UvywioOD<=Ri$2$IGRgZKt3- z!Wgczs2mul>dam{3$k)@q2BfHU`j5(xR?kd&|T=gZ{KfV!V{C1o4@8^=9X135KH&V=J(w&kO08k+QdWxJ)Yb09ktC0F^|;!=YWxsTezJiKoNO%nfTF_! 
z5`AYUjQac?)`$7olBUFg&%WLQbEdL%@8I}|$;jH<+sQ~QYsv(*k75X;)flz=`gAi%y?{l;6eCER(5Ps@>po z>3w|uW6v6}9<&2^S8h;Wt{L7*`N;Pi>1^~3Y)Fu|9O8&?EG#bKA>kDlpu@muAtv^w ze_Yno*B`Eq+IPjW1-u2D9Folv#+^>=opOg-Vi%P%Y}^3CVmRiJ%9l&;icTEQ2B_zR^A{`}|0E!UI1+$OGtgPNUa;O@axQ+8|lgAASf#aXSt zzyBHGyHrs+{FvBS0Y1Lv^Nh7NIHMzYkwGx)Rnl5B`s?!c{ECYAFwt2hN=h1=WV@g&&48aCETq^xB0C$r( z-H|pFi*9yUhTaOCP*Tqj!A86xe8ng@G%O5w=fXq7m_)GfTHuu#$RjIb~9-cPeKlqIx7KV2ch~ENJe%ksPL<3-n0FMsZDSh;-Y;@-R zz-WZGF2?5pgTw1dPoby*qvn0^Ko^8Zh+YE~6>1K4R@V1-O?SVbQ4$amLhc^IGn(%7 zU9rGJZIlTxT>wiyXWP=i5z(p3JOqS&F7z$8p*li#VN(jT)_$Mwk&7$3G&ZEyIjJlyEg`C-GE=*u>1o`77W$BzTix5x5Y!z=@4*?G&tzy~ z=tiP2xu~zNPZ|jZhK7c|y84i7Bo+17@83g;dJ!}FKwAI#wYFAJLS2!OA)!%PU+;m^ zM(%^3;plW@YYE;HBgO^W2gqmJ9O_AOnXxf2W(q~^ zA{-|a4~st(sivrVqej9|!}{PY!@FXjxxvP@YZ-cF=GA~^1>-+70Z>q$gLebW6#8@- zWW~TQ5aq9962ee3Op7(4=J{7~5Eifa2qgd1qaiGiRJc$-`yf?dGGdM?SG^NDAbG?9 z5o|uF#%Hke69oVG-a>^Jc1j=0*WLy5*`ao-?Mv+2w^Wns(kklu`ct4p^SJL>XnaGY z62ND5IYBYs*?9qPM}bTQ-;S!*FjzP=q(19)rGgF}=}TAu$|TIu=HzsMF$ymhpfq|H iX#^Y2|3gYvUZQ31UeX*Z$v>w z5mZo6yk$|=b5-zIk0>N0AwXOeug~NAy{hi2>SlTpT$#z|(tk|9u6|vws$SKrSFh^5 z*|>;OXS!zksrlCn6<$YEK4k1d{z-^`ppC?ktlPfaW@R!~gDh;aW{ISBg& ze0%{aio#LF*!qQl`mjRz*>1Q^jKxQfiIo@R)3=l`tfZtcuViezkTTfQ!YC*%uZR`r zBh!G$G;d&D9?J-d_jprLP+n1xUl2n9LL!q3ii@pWSym1!0aY~N)V5w67*6FO-X!FJ z(YQSbu@BoFx8qTt0pL))+%CCF!CP6~HlT>=F5Fw(#v|R=hwX$LUA`O*`ZXBPU{Hf0 zmLb1lzOYV>XfUb}sKyp4Vmw_Sr9i(11FW@t&wN1*hFBXrHLO!38jLECW%=1mAfrh{ z%N4}Nl@!Ooxg6KC-}M}DJqKOSAvsF&V}-FYmJxQnjkunp&S#btU_sOuET^8Ac+d^8 z-!#SYY*{~ZJW;C+C@hIp1pI(niJmRmxBaqrqcq0SLe`%ivVruF4W@@|DAi<3d{g_f z%y4?hM$$tznijGwD`?A28uN5PZe|MrpkIRlR%xdOb!v$9vs1$wjA$^bK$aC^Wk9y( z>9Tm?1bga~u$($29AG;GwfwII<>P6}BS9ozganLv5$RzAju>!MvTElFl5t>mE)s~!_ef zT**KF?tl%vb7zSUOD3Vmqwk^z*iJb*}Q77;83VV zRLIRCHudwq3&E+f;s3sN)kF0>8t+!Iu%VtGe(>t^9(uE$4}bCJ?A*OK^5q#>m6z_m zk)L$T9fu704HcqUNx&qB{bos`_D)PsR(JiMqB*n&!D~pel6z#F@ ztB99ZAWthtQYPr~NlI2KnQrJpYMD}Ad_&nDGMV|NP_dP^WM!#}scaC1WrkXr! 
z)LO_zvuJYBm=ahWg#C+3F0v+0CPxzuNAq|$&)m=xV4{Mqo*7_WJ+VkPLpfPykoELL zBIA~BV1`-qP)}Uy0S2pS0eFZfGO^F*I?IeSQ3TC)hWV+XpwmNTNG12Knn@kFFPJru2qX-#Np%6e+@Tbc-|i(4HJ_yKmfXQtF!T|)w2 z1X*v-FKEO#dpte}u_HV`puX$w&q%_BF8ScTX3sV4z{7#iEa4*9kVv>t7u(tj8309C zAw4QepNvhCEc1|u4bDv+#?7XsNbAT?#kCEZlwB)8*A9BIdu zu)~g<%bu1rNi!j(el*OUqw`~B!t6)V-A-iKLv=m}Sb~p&i4sF1v3x?&Xxd@5vJWRp zj3i2oCP;)8PbnaM^FIpY`rm^A^<7LIPk)I>_ z0=*fWce0%*P4fBUV@2o1vB?`?USG@J>H1o#qU!6g=hw1xy1tfM*!I^#|8L2Lv$~=) z7v#UPk)O7yerUT_H}S)V_-n^U>iMt}_nDXZcmrSj;n{=k9Nx&!>iSgMV}EKSJ5$Rr z%gAx3Er7NG?ZEw7y)bpcCVt6pch71U-OTTN`PI$)!@l&!vYTeCJ$Vy<+xKLrv+ii* z!>4!banCYQp>T_^Gr22dKR@r^^Q0|TY~-0M_TTHy%Qo>oYhV4mXI&k?@wQ^$RJM`# zyL0%3-*su^^L_1>)c+_d6lswavN?2ab^Fo}2XEqSYSu0r#5VJW5wAb<@*xd89RJ(5 z^*3$eTk3|Lyu7lJKR@NXF_lYEAzIt6!b(wFWz4nhNoc5Mt!)n?#6)Y`tqo^qf*MRg z{UOqQr6uJRWKU8`LMe_Bou}<)(y*p%Qp!}JtWk!wg(gh8e8N7YCJfueLg%;Wvdmpu zTdcH&MH%w6-KVwunE9}&Sfp3hrDft-%|lLm!#?&$4hx%McPkUyyrNX%@EFbU@=03i zR3Ux#3z(~@Ma6{!Lwlg%nl!6f_OcvyFh0|KHmxz#I96TAu$J6tWiyS8Yl80y4Lmg6 zi^d`IS~pM1=nR&nF4P=L-5+OA)4U~v+BVqqa+|P`$u5j>)n^+_dN8O7^H9R3q@Jz1 zJ~H+Uy0FFA^@4iK zEKVdgnJR%E(ALwS2Djlfr7u^122N;l@+=F+F4nQe@lA|MMI4;auyIlcoFKtSmk*rK zc}5L5SAPpmsC>W)^({D|(+#@dgvt|~&^QKNa6+f+!=uCrl@3m5*r+c!VV8jpn;w8} z*bIXx%|hkY;nAj#nvzu1=2dy-$NpNL`LUnmnIHQp*KL^H%2JyncHwM%@XkpJ3^Yp$ z9=hI#+zD=5Em-RM#B~-n1H;UT0S=o)!DQEacPTOOTo^%T4(?K76uLfemlEUAh2btG z`hyb#!_$@EDkq)#Er&`3Eha6;Z~t3P{mKKWJM7YbM!U4+ZOZM^l5EQD(voa$n#4uA z#8jH(6ehp)h~<<^D%q_tnWv||DfdvYoLGGzasT!i9NE!v;NQnUzdlK3G>Y zjzJPzqZ%h{z=o?H43+(4HmFXn#=$J8(lIk-+c+?fRVvt^3-g#xH_ny7M55D;lO-^W z=yZcFdUH0Dmk*W@4QsmOQ@CU&VII@ROr$%$JVN3({RpK4oYz?tnI1Mjh?o!ZG>Y^* z%+(%-8DNV+TN0O8V2n8SE4>)O8uQ6K3=~$FyUREZ3!w{{lpBO(1 z!RuI5mfb}hPMPYnT%7|LrIKOjS092+8aB=VV>P4kz0J$|TkQ)<<42`~H5xYSX`KH8 zTNEvfAB`U}(w*jIjc;r)fE_9w%uxA&7aG>}C7)G`69+GOT6_sfra%<%!+(iQ5O6}J z1~nMcAhuj=Tx>$vAdV&-^|H)2SL1>cDmQTAxmQBeQG}uYhMynDcqTZ&%BCD?r3cw} z9b>W}7%LhlY`_t=^8OnwoTvfkvK4RrkIWOCcyL8~(GWgJvfR%euwXwcK6H_!1x~EH zcxMYg$Yv~=N5#Pjx)8hiiGwXT%-)@Ut@Ulhz>gaF{t2Fp}M4- 
zU!p`*kY;RtDO^Xa43%nZLV*)1ImzaiC^`|xn9#7X3FX@S(m2K@6vl)~2PZUaY`S4i zQT^97CUm-UOz3py=9f-)jtP~HF`;2im)x2c;{~={lKBSbBz@~o`nD;r<79M0Npo(> z0XbRD-pAcCay?JN(e&3AJ;U(G)(7a&qv*Vb@s51BG9Yl1l>;N+l_U9hOD}HpKks*2 z9RM`GOWrV;WRsRCn*4Q==0Qhy6<9Y9M3DEp#Qy3S6RuBm9Wifszf0_|4wcONldVoh z8l;jp2xg&V%?WYnqN>{jgVOt*_mNR?$hI3?fqdFk9y8fk^ggz8o7pJ%W=^dk)P%P40|MUh$@xe%<6Ht@KPnF^Pb}_exnEc^?o{&Bp;Zdt;*zqWSYa=j>K6y;RVq2_&|xeq z*w$h07+Q49zO)2ElS+!M3}hflD~)AH`xIQuSVq*-aVIyC(-4#|Iwxp*?qpzDWfWd6 z6jV7a;!skc4eLswmhz;XEoWPr3V&9~$}eyW2GMnpw@S|M=Ta|2ASRzoOeM?|!nln%DfZ>nA_{+Q?`9Jmw!q z%v!7NRjuRSE`79yhu(PY&~Lld@LP|YR#^Kig6kBNKGOc_b^OpHhmI;n5S|WCPrLce z+!g%Zg_*7Q^B27cSfAL-R^bEth8?==00%-xQ^kQ0@k*RJCk zr|s2y@jX?18kK8b5p5{IS^`7d!Fg;SqvqcCV;67gOm1x!L+tPN*>{hl_o=MqS55l< z!p0{XdDjE4o!+K>JwI&J6Ej{OvYvl?#bfKX++EAFA9}6&spXA)$t7i%-x;V=_ueN* z&h58P9dB6GcBiGGI{wY1*;9sOZQysGe%aIUpT6LCy*agQPXrO#&;MSxd$0eHuNj^3 z%6pfs;Y-h7GW@yA*6~e?Zv9};qDH=Z?iIJ+`S|<%)cm76pND(#TN{2mXKpQj{i*JK zvL;vaew%#@hOV^2KeYsg*bVu6?<>ZAv|sl3njv;RCx0B_s)HF?D^~Xkhxt)WIoo&vnsB;>WaDV^X@~3 zuc*ZRtoQaF$u6knvtA$n!Ss@9KKI2zC!O1Rlk_4Sf+-!*9_vOP9~^5~K-^F6kwxEci`%2`bM#?IlQ)C|!xdZ#n${NiU$QIgz%6 zOO|qP9IAtg{l2(cOHjim9Rzp8-O~Nm^opb&rTI;6f<2Iu7?=85xgGb7f^K#9L>0Tu zb|9@j%?YH{R75n}kMGYDa4W!;V0`JOHpR*I^5g3VityvBgfgxJZjqi%7fkU=SQzeY z@+^WtkjEheD3>JyG+gS8%{_Stf)L12Z9HzU`qLtSE|5~7Tw#b*V_)CcUN?629i&ta zdP|=?u`E*0bUIpW9D>xOoXdswjSU6`2H5_O*U8kN216PQvjq!hil)H^!-1R2MAP`9 zEO6NyazhecJr*fjwFfB$YCk>f1*>K?HfD`|SuaOONe3;1_}@3+hV?KFEUfg3Mbr+G zp%o>i1^KYyQ}}^-UrIMGfES_FKMx z=_}2jHplj5Gryk2r!Cq0-q`8HpJ{Jgi?TE`AnS%U1=`upTYN40IZh|;PvAY06 zx0R{BkYGJgv}bahQ-I^&hFArG=Bq)$W~xDp*td?DsMywHiKbYscVOin0TE>7q?ot#kV;Dm;C&m%|wq@2(nX>NOs359F)PNPYTk~Q;qsLBvy zLiHm^Ck-2Y6Vgeidmj(gEok&jNGJ6zIH7Tjv4}CDahew!MU@|8Lc>l@*tBTH)EWeu z5s`xw7#(B79Z|v^_`)4gV!OuVFr#5$dT~5DBVB-vjQc9QOp@E-;(XRlQIwn<;yl)H#ZvMKivm1I-yAS!vAa{o|CHcjrK zDfbSQyv1Z*VGmCs#moiPb{V)9-IRNWT69zH9V&Yj=F=3nvqd*eq6!*9dg_}dVFjHh z#r3uNbIQe)MK|U0O7>@{iz#kri*Cy0l|?t@@=Ed-3#s(9bDHE87KkaXuhq_Jl2}bx zMRcA`&zVZ`8VE*k5F;ju5gH`I6t4cy8J`gs1e{P9 
z;Dm;CTo8&L(2vQBeAYFN@rUIw59@S;Dh%~H-H5*eeMF_ha;0IvrUg#ubb~JDoNOj9 z9}HC*Ht2#Al?unGFF4W9P6sD6tm#tp0E(MnSyW`y4o46WaVSQnH!lnd*=i)d8Gyu$ zKhjhX6sr`D@FU=hIKX9vL~#W{(KNvu6VEIH*kj;1!jFJM##i8x`6LemS!_BC7>v(MMKxW>bJ&`VIJHJeW;EVFoOVepODVEMhPQj{^N{#bc8xkpe!b z)Sw1K8Vs|+RkLLd5e-J!s26S)R%;9%l>-KkhI825;ma}v5s0x(Pvls|2(WH-2UEFG zq@aNlVv)n|msx^A8b4y-L=8BXb$WQc%oD7+dXEM2TQEj_3o#ab^bspPz$QI9%EAw_ z8+#8F5{p7_JY725g2Qagk}(z>F>s=6_qnfF`Q)M{vNhF2s5_}_JIFH1!B_JW%V$3*$f7oF>{(IwXBx*sd&^((86HT&_fB}~5HO8DP z=RYFUhM(Pr)7`!up*HXzfs;$)EUPqD7Aq$`>r<(n%l7#{k&M|h9K zf6t7e(yZ_vutUM5BhIvuVetQHc#p*|(0}!62P4!3*H^sav1euk<2C$&Ip_U&#>z(C zxEa@?YCh%Sj>EEttmVV$&z4K8`EEa~U-S2Gwhr$x{i3(7npa=Lf112%-5aciZ@QR= z&Y4onHwgUWB3 zh@c4z)^0?6g)PU(axpfT{yI2T&BOU;cYW%XD&8MaL;m;Sdif{MErB8bX?Ty3FFk*K z0~%CT_$zJ)FSvE$L$};~U@iarfR0(2(;NBnDTlnj-$C^}b9PCuR|l`>@nJvy<9|og z@(yo2yy#m58E}X9SbWO~Gd_Q;j!#=XXUnBs>v)GYuRS~9O$0A^e$5kYJABTkbp2&* zdjv1oyVXgTJTd1(o>jB+Gw)rvhELwT^L-Z`fOsrN{p-4W=QZ+6k9n!rcz6KXo`A=^ zea3`0r`7V6H5n7;l~?l#@m&fx4{IQXgj?hYS^6I0X`tuX4eY*uCc0!J&!Z?Ez*~AX z#qHP{|JAI$*I(DShJSPXFZg7e)@Sa933Kd<>k%?+=p9{; zTDG1KSvhajUuHD&&0iFLU;T6)|Ecv2_m?cI;a{)2W726oYx(&{&-`TG0_{U-`%=!T z>vYsNxpn-&hBx=hLyvx8@Pt9x%j)?_U2e@eruS$3yVE*vz7m0Rew{Sv<7f8zkdN4- z=IvfR*YMeQk2?L;FV^yp?l?Mo;7t6hx#c5k-_JjMO!WAXwY>8^|2U+qsG8q+y0G#x zhAhl!rVlOlKf8ElN}Y8m`2zp<{Uvyabw`uxzHN%G?G;L7F+>HNa@#}S)}chQ!2`Om zNuDhLfPM`+&iX*^iOXj=1BQi zWYwc|Q7VeV>kB4-={aRTNe5kpDt%?uqja3IaE(|aswxf-Se5z*1ON%JB4h_y*PUu# z5gRAntV|(B3!ky^LfEA-2AaSd7b_|vJbx2-7sN_S2`|tDo=vwYXLy?~I&9K<*mT29 z;MsH|P2kydqfOvhbXjJwMMRuH8{yJ%9_{DiGZ*AvaZMfH_u4UYhn%7Bp`-s=edX;9 z{F9#Z&pc*U1Mf6)?2%Q|N&jyd5nJ5armbDhd2qm-AImrJb`MRfc$c69kJ(8q$t7L?;mSU&cI z)f6vfbuUieJIfgEG6cduXblH0AaifJjkKDKD((1f%f`*kydENyG zx|K|MPWsFRq9uEI;?Ro0EILW_9-i1PwX3NhC#sc593JMGE5-741d<_XbGRou9j%b) h!-iW8MZ>wb=Lci}bgnpPxE$e$PmwF!oI{el{ttT4A5Q=P literal 0 HcmV?d00001 diff --git a/data/results/toy/toy_results_names_scores.bson b/data/results/toy/toy_results_names_scores.bson new file mode 100644 index 
0000000000000000000000000000000000000000..6b981b7b4373e37cb7a21a7f8a5521161feafd3a GIT binary patch literal 632 zcmbtS%SyyR5G*HAQSckQDCjNO#}B9=9s+_NAideJ3?`E$Gm85$exYaYR<{99D*;c_ zRM%8>*UY;ET#^lr0mL-C)(KhKZ5wKc8ZEC`jh2=8pq6TXBa!wqtM5G5C12M=d%&uI zD)8DVk2h(aR}y`#cNKW-`?l)WO{)TDBg2Q7B-9k0n`m`kRd5UIvy*&GzV^ZamL~EN z!l;V8#1~Z<aclg=FS#n40#9A}ckS_M}bT@k size(x, 2), val_data), + map(x -> size(x ,2), data3), + map(x -> size(x ,2), data4) +) +scatter(E_all[1,:], E_all[2,:], zcolor=card, color=:jet) +savefig("enc_card.png") + + +model = pm_constructor(;idim=2, hdim=8, zdim=2, poolf=mean_max) +opt = ADAM() +ps = Flux.params(model) +loss(x) = pm_variational_loss(model, x; β=10) + +for i in 1:200 + Flux.train!(loss, ps, train_data, opt) + @info i mean(loss.(val_data)) +end + +X = hcat(val_data...) +Y = hcat([reconstruct(model, x) for x in val_data]...) + +scatter(X[1,:],X[2,:], markersize=2, markerstrokewidth=0) +scatter!(Y[1,:],Y[2,:], markersize=2, markerstrokewidth=0) +savefig("val_data_card_β=10.png") + +E = hcat([encoding(model, x) for x in val_data]...) +scatter(E[1,:],E[2,:],zcolor=vcat(zeros(Int, 100),ones(Int, 100))) +savefig("enc_card_β=10.png") + +E_an1 = hcat([encoding(model, x) for x in data3]...) +E_an2 = hcat([encoding(model, x) for x in data4]...) 
+scatter(E[1,:],E[2,:];label="normal", legend=:bottomright) +scatter!(E_an1[1,:],E_an1[2,:],label="anomalous 1") +scatter!(E_an2[1,:],E_an2[2,:],label="anomalous 2") +savefig("enc_anomaly_card_β=10.png") + +E_all = hcat(E, E_an1, E_an2) +card = vcat( + map(x -> size(x, 2), val_data), + map(x -> size(x ,2), data3), + map(x -> size(x ,2), data4) +) +scatter(E_all[1,:], E_all[2,:], zcolor=card, color=:jet) +savefig("enc_card_β=10.png") \ No newline at end of file diff --git a/scripts/evaluation/MIL/mill_results.jl b/scripts/evaluation/MIL/mill_results.jl index bb1f95d..8804541 100644 --- a/scripts/evaluation/MIL/mill_results.jl +++ b/scripts/evaluation/MIL/mill_results.jl @@ -11,8 +11,7 @@ using Plots using StatsPlots ENV["GKSwstype"] = "100" -#include(scriptsdir("evaluation", "MIL", "workflow.jl")) - +# names mill_datasets = [ "BrownCreeper", "CorelBeach", "CorelAfrican", "Elephant", "Fox", "Musk1", "Musk2", "Mutagenesis1", "Mutagenesis2", "Newsgroups1", "Newsgroups2", "Newsgroups3", "Protein", @@ -27,6 +26,10 @@ mill_names = [ modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"] modelscores = [:distance, :score, :type, :type, :type, :score] +####################################### +### First time results calculations ### +####################################### + # MIL results - finding the best model # if calculated for the first time mill_results_collection = Dict() @@ -58,10 +61,14 @@ for (modelname, score) in map((x, y) -> (x, y), modelnames, modelscores_agg) end save(datadir("dataframes", "mill_results_scores_agg.bson"), mill_results_scores_agg) +############################################# +### Load results from existing data files ### +############################################# # if already calculated, just load the data mill_results_collection = load(datadir("results", "MIL", "mill_results_collection.bson")) mill_results_scores_agg = load(datadir("results", "MIL", "mill_results_scores_agg.bson")) 
+mill_results_scores = load(datadir("results", "MIL", "mill_results_scores.bson")) ################################################### diff --git a/scripts/evaluation/MIL/mill_results_table.jl b/scripts/evaluation/MIL/mill_results_table.jl index d26d9b4..eab674a 100644 --- a/scripts/evaluation/MIL/mill_results_table.jl +++ b/scripts/evaluation/MIL/mill_results_table.jl @@ -11,6 +11,13 @@ using Statistics using EvalMetrics using BSON +# Milldata sets names +mill_datasets = [ + "BrownCreeper", "CorelBeach", "CorelAfrican", "Elephant", "Fox", "Musk1", "Musk2", + "Mutagenesis1", "Mutagenesis2", "Newsgroups1", "Newsgroups2", "Newsgroups3", "Protein", + "Tiger", "UCSBBreastCancer", "Web1", "Web2", "Web3", "Web4", "WinterWren" +] + # load results dataframes modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"] mill_results_collection = load(datadir("results", "MIL", "mill_results_collection.bson")) diff --git a/scripts/evaluation/toy/test.png b/scripts/evaluation/toy/test.png deleted file mode 100644 index 969ba04030df38467a4433594ff26bb83ff4fdf9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8367 zcmd6tc|6p8+xLHI;XFHKDk*z%o~4M2>{}xdB1P6Hd-i?bDk4IKWG@nC>|8p;S6&~}-Ps%6-OAGNuf1a@7zteQ3Xw_nJ%|X$k<&KNutVL#wiC$lx z?ESJ5z4|=M>!+4aFd3Xvh@lca`v<@2cXHUpLL2P0Z_Of{iQUWD{M93ILd0G-5eJj+ zXVlbBAxM2=+B!Vc^h*>K^63`M0puR1AA+PGMK#;BX%R6z3QFHDpi|hxf|BzQv z>GrOKXEVw-YIs|jso5%!;y?I&r{24KZf+B||3Q&?7v@f^;6ua1>hQYG#O!V`yE4;J zWsMacGaq*LtkE{^U7na9sSeOzO#S%rV^b8L(3w9dJ=Zx7`678<+iMJ$wc&a>@(rq{ zz~h|ncpsVzZl`F$(o~At+I=~Xt5*`v<-B4MY4>5CdYJqVj7DZU?CV=-H&*+`uolBK zpX=pWC$sHrQ7jV=CZ;j_>#x7&vkqckeD5~%QC}PLQSLI+fBF|rNL>v6-4>%i+Z|y0H0kq78 zhIOIuF6$^Kh`V)R)X_BR>f7t&y)~juT}{ozzyf0bF#QrL&GGe1y9 zd!%|}aqRy6`%;_Z;Y6Fthh(Wy3^?0|V?&lXZ?E0Tx@RNzWw0AyD17dIE#dP1o&i@pB9pLA; z^0RBUOGAHmGuz_Icy_G_8j)tTuGpMxZ*M0^?dj5}>giR$@mP4Osy%q{<42|r_oYkc zIQL(f53osGeJ$EPzxMm_Nylf!cRY-uj%Er9jaAR7$zXuy=4NbWf2K~6X`W5a^B_QN2OqjEqIpWrpAygLFJ21H)Qyh5Jvdh 
zx6pc^bkUF`?b;@W#}Qr5w`|Iu1Cozx1cbyP*7&jKgr5rQR)>F=)Z)`hb2 zM=LEZF800p_!p#ur@Zw><^~0w_yoQ)HIajtIKPe71X~Fb9S9PsEoZdP`60fd<5N@Q zUR<3)Zu!F3qpV;!VxiTOCr>K9cl}V=l@xaagHfxZf%ATv#kL~{>BL*}pLD>f2o1W9 z&Gq7n`t6(W>x0`ud!lD`f+c4wFh8egoEt*NhwcOC0=B z7gQ4BBs`p5T{miiPv5_H@AsoeZORvGOB|*sBXmBk9y_87uTaQca4=EFnf4kJ+?_aK zRT|ZY591RP!811@W!WzbS0#FHrB{zYbP-8KL)O=#A@cju@2AQJGyKLeUpyA-=jX@B zZ_=rkDuBRQNbeq^rgk9@dVz(+T;`h_8!;If{Sb}t{cG2*`JwiG8Uj?pqPDiS(6V=_ zC)e;P3wl5KxlT;!bH9J{CfAB6sImhW!?DD~#?JosSW{Ee*w~m96B!x#`0-=WSF`&W zI8DdMXdE_UZewC(RK7mwp`6MuC@9FuIkwuZFT@lg)NlKIN)CemG@l+lIzzv5(7eET zzF(`T1ph@b8SyP0a2fJCNPWz5sA8AZr9aoQHxJwZ;)Ry++D)1_5fPs-KjU2nXS5tz zKinV|*^t4gXf*o9jjC9i3Pu)|jM+J-vwjGwHRHkTe8*|ZM*1y{cj>DAqY4@W;GXpa zLUvY`bBQjRCcMFMX}saXhqD}hRgZ)r8VGy>rH`tfY2i#FBO)B6u z(DE5V#;UJ#3b*;Zh(kBes6n5y6jiymONKn&@%fIJ$J(4CAAY>M0YTW3<-&e=_|-NO zg_=Amn^WsUK0ZSgUL74BI2_JpzJI1W7cL@~FNfaHxoUUSeXlKMX^rtQ(T}X9BgrqvR*0X zJeSs@<~?CWBlqgeVih%ezD4(UqiZRjKPx9n?lN&_!*#PiG7~1?TQgl6**mOZcj;?J zY>Nl1@zgiIsM5ogw^Zrz)K3Xn(gIY)&`lXm$bIDv;@Hdj(dz z7b%5wPFYnRhTs;H4iD5J78M#9GYkR%S0%iZk~Da!e^J}f>?BeO+(btyR5nvPje ziAMFogSBRULd3gwgC1mjf-9L=lnp|_AlI@l{|RJy3HOyAIMd0N*cOvzC_it*!zFaM zxwtg(?oLmh%z)orbzOMGXG(MQ==Ld&l@~?KyBlL+ri292?Y?g>PeQK8a0N9Om5w&5 zs;WXRm5e=(F!`Gl4v(h=T4ZQ1E8Y`z=qlFhj1qckgYB3HQZ|S|#==`gK&uq2f=U{v@EA z-d$EyRFr6HXlO`qus?kGa93B?dcQ;16eajHpL>FW_wH&pybG0TvEOIDU!-h+GQ=*^ z`L>8dNkKu>bHmooF1KUmRdDdp9qz*j;yF8YdVm4lS@B!Vqbr@1dWTF#j+-}c`qdcp z7FcxixsSLt2yk;Z?Jc`w*-j4#GKJXL*+CId1AWzzp@n<<_5`^*EI3$EK>>qb3F?>; zcU$^|GJqt?hhMp;t{yo>dG?GdF-%+EbNdkT!J6~dtNr>7b)YVco!qi@~dk8+rNkL!gX2g~c>-D~RVJSt}N^mQGX5i0GQH@~GArle|kV_W9? zd5fxS{iozf#3Djd6Bh3_^3}g@;_y)hSB5+YZye5otRgSJF&fO**w`q3fd)!C>YT8! 
z>r8vfv4G^_V&O}d6ozom+4e(|4D)N*w~m?any_b zg;|Q^^T3s|bCO0qU0G2T!kOnBqHM$2T)Z zeE84&apZZGO4)f>P`Y;L%r3F`Gt=C3FCldDx=`I<_wDCiZe#sf9T#6PPZ8o>Y7$@g z`E8K7!}g^F3fG3}z8TDCbgbaLQq|;en)IfDt)lDQVvcf9OMT=RhyPN)|5LgBe`NmO zUDoVQhaeRD#ZlQJV&*!%RH^4?9ttyffB1JA4&T}*;P_g*WM}Ntb~aO|iB{&M)dVUH zX_^B9%A1>;wVhS!oO@a8XQi|#AI$gE$Q-tZ<3L6JTR-u4H~L3y$G_+@2mB9S9-f`8 z*=(JBv$jkfA#B&r)Mx~0Rgyppdu^?p;(p*r9>kZIm;3Bd_)VG@hM>hWt0tl*$H&W; zzl%}Ru|Z``L474TZ#!H<5P?yW$X6jy0c~>Z*O@{<+?9dl+Ugr*5oyC=68~3-5Ls17 z=!~qaUgPy|oMwNlPPdJjbF)JcfqJ~iBV;pJu97H8%bu@#&pT=TERC?FxS-_DLiK*j z;we_~HZIPynpVgxhf}Y(bRLT8a-B?ao7t5!S@_T*4v>lEkX)V~q;57@-q~D+s8}Zv zT?!(@*d$7wN!kQ^H2UlHw--fZRauS&G#It-7Xw*Ub8~YhN_$hO)9iV|`x9-2=_;Si ziyIx*DkSps_O|cQu1(mNsL?v#=Q|*D_mN%MsetMcqSX;<>TMpsOznck=4OG|7Ia|Q zeCcd@YQjDG3G=XFsFQI*HhOrI$nYs2nb?-k`&$W0dfzhn`t|E0)^#JcH=V7I3q1%n zeKfd12|pL!l-6>Uot|E;ZY*()%KF0>gKN=FZsIFzL$P@G)k~KyE%79ZyXCnpPl|^f zKm@1;B{}27-Rz+`gPN&bWG!L(J?=t+>(^|<1IPzjo5bTF$<~7TD$>%@HsE;my*DSI zJZ?GhAV|Jy7+t5!T#wy+e+h*d5+iBRPWqp;!2cp&{M)|!*G>6nC6Ly%9@3hWlJ|;% zo7=&nWDJH#5{y=&?f?EV>0C-W_9^oLCLuJ%xa#-O3><$dVKcwO=h?@z)OM;y(v*6H^GFzZO=uww~$qd?D_>Jgg!N{L~L;h{58?4h)v zfdqamMi8(lhEu05Yvv3;r<3I3;SqP4kE1A^@c?4tckZ)uey1BWjliby4^*=xYb(^*8qC^kM?67N4eA2ynv?b1KX8w@g zCLWxZy5-x3J+0@z&>YSG0P9r$ngbme6chwP5c*Uo4*q_Ax^`aJE+$6CnWn2xEpiV1 zrt^ix=JQ$n1>~+2eD8;jjfD43vFT=J(O`^}4~4S3y( zoH}?P@zoHP`x_$muin=$4ELYi)P$@tQ0cP=kmZ!@Ocz$e-ac>71YDjUa-(jE{Y2|W z8R$hlK*|!W%dFaxWW8CBB8Zji_NyGIX34igu0<=M1A#M1Q8vbUdwVG?e;~+ZSLUxK zEiqlZ0WV%Kg~aI|LIjSi3d>E8IEp5E%$GETU(h#_o)WlqTTwVPkZ&Z!JlVRsI@=!LH!a<=?9qXb|TVaTqril}nBim02T-HShf zS8GTChMZhn^^dOTdc<$}-&KuhOO}(9m4%*<#Do=L3GI^yiP~v#y4xROW`~U9CkF<1 zdc0=$a#$+I8Q#aIBEGZtB+%2&WzZ350N;6BeFV3uUD<=x$cAEZ-L>$Gv0!Ytj?lIt zA=I62DLK!4>Vi9A;pna82J-^*98O#>&(Do+HmcwTv(CAm;nw1cGX{t2D3Gt z9LB)F;Jve0S5;Nz9se7Gw19npS{N>l*0>BfbwiL(mzV-VEB(ZbAnZ$mels&O7TsC9 zn-U3OfP`%yE(`j-{Ppr?{U2Uxd^Hpl6z18k7G|tG+sCVV&L8DY7C{Igp2J%-BS5+6 z^shav@OT3sjz6#DQukP4qe6xccFWQDqiS>CD`A)*AA7$BP5dbn4FeY@2`wsUCn$gDd4AlV3RTJP^e8Y%k`k`B+k2f 
z?9P|0i<*Wpe?+)NFj=YmDP}=6-E-HRYOd$*fbw6yp|k<#l=?`@PKVobM=*9r+H+Nt zi`7%bY(r^_WR6#D&Givc?;<1p{r%nE0UZRc8K<0DV^A{M2)a=p+DcQA%a;KKJJH5n zJ9iE(QU2GsijaVw8nvY<~%E=oTbIvB4L%qu*XapWItyqYsQre{ZY%x0B~LeyPC`knFEsy#h!R z0wVeIXA=NYmD>a4o&W6!=)l_mAYud>e3Iqst&8+a9WV0oDwA)ZZ0DTI2CTzELb^ak z!@z>~VTET2WG;8Qc30a>RSAIM6cEJi9={(-@}Vpz`fPP1t|U-BH3Ka|5MO>ZcwgYs z-+FZmk*?E&y#IJIbP|pQ!QY3cFj#E7a;ggNjeuJ(3=e8l=c@9NEBCzP**(|2AnqHct~FBVA3K>m~P&}Dgl6jdqMYp6Da8QQx8Bfv6uu=0F&BDI@u>h{&|H)yTs(JUhNM$@eyAC@@-J9ku7ePiuV zO~(_X#&4noutd+NbM>Zw;_7TS`6)GR2rHrwFWCL`sZ){f-a&H@qX*IW?~qN4o^;fboTDNmdo2qAS>)yA0ZFS| z>$6BzNrai3v!`cy^+@YBpD&?UH%mj`s{q&*ppdv`E;E-UOPuOpwaP znCs~!beXeCdGF*J)`1p;`OdqyZ{J2mx%})*2Y}}S)DDAzdUi`z)~dgFahqJJUu^pr zWJA+L`+D|KPruem<#>p8<*=lk=F70u0I>;ljkz9y54$Rc3Y3eB8&6QnG5an~FS2E;H!a7tcYyMvFty}hz> z=+uY-wBuK=Ud3P}VV(rZ!Io4l4>TTPK6_K&AwIqC^zx_%7cZ|h=pA5NQzo9CJMbhD zOjf$a;E|-{ATFlr(R5Gl81#{jtPI}kLv+W{foW+lpPST@v>W}#4>7&7Oz^NSH1BF8 zIVTF+8FXho)c0DAoca#=qle^dMO{fjq4$D9Ln~}Y(0~Tz|M&0TgM2duW>&$TyExi%jnb9%undl- zB~GMc#CC2FzYUE5lZgFz*8{dIkG{-wq@L#0owgMQCx#@Tr>U7Xk>9FmX*u_XQ!bQE z;-{N?R6>F`FbweR3l}b=rli0iYaNED;2h{c7*E230qYxd5RC^~Sx6+6Yq!>atF2|A zr)L4rfyQllNrTO8{0+=kSj3!xHtXt^gLvck2E)$>U=X)ulc*-+A|lvu)>e2&$a(Aw zKe|!}%mcH7gx9Z8r%n~PE*cQVv@I=qX0(ae!njp*#` z1Q@JJXd5ww+Oj=|Bb--`>nV3zj*DZ7ZE@b)T^F+W(fsHNz=_AzFG>ci2|_Wlva((q zBLTpI$HyDO!{5FYv>DXh-a5k+0INYgkXY+63`&Ah!6_+OvBq%@CORA(H)tN?dVzvl z_!u`vTx19~{uX!~xWvoQP&P5=4Adyh4fbcr85av~zkxx#=KlWzXyG8EnBscdJ31_e qUIbxBh!>!}8cR|A|836u_NWieAsVCKi~`^u$Zc8Wn_1{b&;A#cpO|R? 
diff --git a/scripts/evaluation/toy/toy_results.jl b/scripts/evaluation/toy/toy_results.jl index 7ceb93c..7dd639b 100644 --- a/scripts/evaluation/toy/toy_results.jl +++ b/scripts/evaluation/toy/toy_results.jl @@ -43,8 +43,8 @@ toy_results_names_scores = Dict(map((x, y) -> x => y, modelnames, modelscores)) safesave(datadir("dataframes", "toy_results_names_scores.bson"), toy_results_names_scores) # load results collection -toy_results_collection = load(datadir("dataframes", "toy_results_collection.bson")) -toy_results_names_scores = load(datadir("dataframes", "toy_results_names_scores.bson")) +toy_results_collection = load(datadir("results/toy", "toy_results_collection.bson")) +toy_results_names_scores = load(datadir("results/toy", "toy_results_names_scores.bson")) ### BARPLOTS diff --git a/scripts/evaluation/toy/toy_summary.jl b/scripts/evaluation/toy/toy_summary.jl new file mode 100644 index 0000000..d64ce17 --- /dev/null +++ b/scripts/evaluation/toy/toy_summary.jl @@ -0,0 +1,63 @@ +using DrWatson +@quickactivate +using GroupAD +using GroupAD: Evaluation +using DataFrames +using Statistics +using EvalMetrics +using PrettyTables + +using Plots +using StatsPlots +#using PlotlyJS +ENV["GKSwstype"] = "100" + +modelnames = ["knn_basic", "vae_basic", "vae_instance", "statistician", "PoolModel", "MGMM"] +modelscores = [:distance, :score, :type, :type, :type, :score] + +# load results collection +toy_results_collection = load(datadir("results/toy", "toy_results_collection.bson")) + +df_vec = map(name -> toy_results_collection[name], modelnames) +df_vec2 = map(name -> insertcols!(toy_results_collection[name], :model => name), modelnames) +df_full = vcat(df_vec2..., cols=:union) +sort!(df_full, :val_AUC_mean, rev=true) +g = groupby(df_full, [:model, :scenario]) +df_best = map(df -> DataFrame(df[1,[:model, :scenario, :test_AUC_mean]]), g) +df_red = vcat(df_best...) 
+ +s1 = filter(:scenario => scenario -> scenario == 1, df_red)[:, [:model, :test_AUC_mean]] +s2 = filter(:scenario => scenario -> scenario == 2, df_red)[:, [:model, :test_AUC_mean]] +s3 = filter(:scenario => scenario -> scenario == 3, df_red)[:, [:model, :test_AUC_mean]] + +H = [] +for modelname in modelnames + v1 = s1[s1[:, :model] .== modelname, :test_AUC_mean] + v2 = s2[s2[:, :model] .== modelname, :test_AUC_mean] + v3 = s3[s3[:, :model] .== modelname, :test_AUC_mean] + V = vcat(v1,v2,v3) + push!(H, V) +end + +H2 = hcat(H...) +H3 = vcat(H2, mean(H2, dims=1)) +_final = DataFrame(hcat(["1","2","3","Average"],H3)) +nice_modelnames = ["scenario", "kNNagg", "VAEagg", "VAE", "NS", "PoolModel", "MGMM"] +final = rename(_final, nice_modelnames) + + +l_max = LatexHighlighter( + (data, i, j) -> (data[i,j] == maximum(final[i, 2:7])) && typeof(data[i,j])!==String, + ["textbf", "textcolor{blue}"] +) +l_min = LatexHighlighter( + (data, i, j) -> (data[i,j] == minimum(final[i, 2:7])) && typeof(data[i,j])!==String, + ["textcolor{red}"] +) + +t = pretty_table( + final, + highlighters = (l_max, l_min), + formatters = ft_printf("%5.3f"), + backend=:latex, tf=tf_latex_booktabs, nosubheader=true +) \ No newline at end of file diff --git a/src/evaluation/plotting.jl b/src/evaluation/plotting.jl index 8113053..fa7e2cd 100644 --- a/src/evaluation/plotting.jl +++ b/src/evaluation/plotting.jl @@ -1,3 +1,11 @@ +using StatsPlots + +mill_names = [ + "BrownCreeper", "CorelAfrican", "CorelBeach", "Elephant", "Fox", "Musk1", "Musk2", + "Mut1", "Mut2", "News1", "News2", "News3", "Protein", + "Tiger", "UCSB-BC", "Web1", "Web2", "Web3", "Web4", "WinterWren" +] + """ groupedbar_matrix(df::DataFrame; group::Symbol, cols::Symbol, value::Symbol, groupnamefull=true) diff --git a/src/models/PoolAE.jl b/src/models/PoolAE.jl new file mode 100644 index 0000000..07fa316 --- /dev/null +++ b/src/models/PoolAE.jl @@ -0,0 +1,304 @@ +using Flux +using Flux3D: chamfer_distance +using ConditionalDists, 
Distributions, DistributionsAD
+using MLDataPattern: RandomBatches
+using StatsBase
+using Random
+using Mill
+
+"""
+PoolAE is a generative model which reconstructs and generates
+output from a single vector summary of the input set.
+
+PoolAE has 6 components:
+- prepool_net
+- poolf
+- prior
+- encoder
+- generator
+- decoder
+
+Pre-pool net is a neural network which transforms all vectors in given set.
+A summary is created with a pooling function which has to be permutation invariant.
+Possible functions include: mean, sum, maximum, etc.
+"""
+struct PoolAE{pre <: Chain, fun <: Function, e <: ConditionalMvNormal, p <: ContinuousMultivariateDistribution, g <: ConditionalMvNormal, d <: Chain}
+    prepool_net::pre
+    poolf::fun
+    encoder::e
+    prior::p
+    generator::g
+    decoder::d
+end
+
+Flux.@functor PoolAE
+
+# only the networks are trainable; `poolf` and `prior` are left out
+function Flux.trainable(m::PoolAE)
+    (prepool_net = m.prepool_net, encoder = m.encoder, generator = m.generator, decoder = m.decoder)
+end
+
+# outer constructor: the prior is a TuringMvNormal built from a vector of zeros and a
+# vector of ones of length `plength`, allocated like the first parameter of `enc`
+function PoolAE(pre, fun, enc::ConditionalMvNormal, gen, dec, plength::Int)
+    W = first(Flux.params(enc))
+    μ = fill!(similar(W, plength), 0)
+    σ = fill!(similar(W, plength), 1)
+    prior = DistributionsAD.TuringMvNormal(μ, σ)
+    PoolAE(pre, fun, enc, prior, gen, dec)
+end
+
+function Base.show(io::IO, pm::PoolAE)
+    nm = "PoolAE($(pm.poolf))"
+    print(io, nm)
+end
+
+# NOTE(review): PoolModel.jl defines `pm_constructor` with the same keyword-only
+# signature; if both files are loaded into one module, the later definition wins.
+"""
+    pm_constructor(;idim, hdim=32, predim=8, zdim=8, activation="swish", nlayers=3,
+        var="scalar", poolf=bag_mean, init_seed=nothing, kwargs...)
+
+Constructs a PoolAE. The input dimension of the encoder is calculated automatically
+from the output length of the chosen pooling function.
+
+Keyword arguments:
+- idim: input dimension
+- hdim: hidden dimension in all networks
+- predim: output dimension of the pre-pool network, input dimension of pooling function
+- zdim: first output dimension of encoder and generator, input dimension of decoder
+- activation, nlayers: passed on to `build_mlp` for every network
+- var: "scalar" gives the second SplitLayer output size 1, any other value size zdim
+- poolf: pooling function, or the name of a function available in this module
+- init_seed: if given, seeds the global RNG before the networks are created
+"""
+function pm_constructor(;idim, hdim=32, predim=8, zdim=8, activation="swish", nlayers=3, var="scalar",
+    poolf=bag_mean, init_seed=nothing, kwargs...)
+
+    # a function is used as is, anything else is looked up by name in this module
+    fun = poolf isa Function ? poolf : getfield(@__MODULE__, Symbol(poolf))
+
+    # if seed is given, set it
+    isnothing(init_seed) || Random.seed!(init_seed)
+
+    # pre-pool network
+    pre = Chain(
+        build_mlp(idim,hdim,hdim,nlayers-1,activation=activation)...,
+        Dense(hdim,predim)
+    )
+    # dimension after pooling
+    pooldim = length(fun(randn(predim)))
+
+    # size of the second SplitLayer output of encoder and generator
+    vdim = var == "scalar" ? 1 : zdim
+
+    # encoder
+    enc = Chain(
+        build_mlp(pooldim,hdim,hdim,nlayers-1,activation=activation)...,
+        SplitLayer(hdim,[zdim,vdim])
+    )
+    enc_dist = ConditionalMvNormal(enc)
+
+    # generator
+    gen = Chain(
+        build_mlp(zdim,hdim,hdim,nlayers-1,activation=activation)...,
+        SplitLayer(hdim,[zdim,vdim])
+    )
+    gen_dist = ConditionalMvNormal(gen)
+
+    # decoder
+    dec = Chain(
+        build_mlp(zdim,hdim,hdim,nlayers-1,activation=activation)...,
+        Dense(hdim,idim)
+    )
+
+    pm = PoolAE(pre, fun, enc_dist, gen_dist, dec, zdim)
+    return pm
+end
+
+#################################
+### Special pooling functions ###
+#################################
+
+bag_mean(x) = mean(x, dims=2)
+bag_maximum(x) = maximum(x, dims=2)
+
+"""
+    mean_max(x)
+
+Concatenates mean and maximum.
+"""
+function mean_max(x)
+    m1 = mean(x, dims=2)
+    m2 = maximum(x, dims=2)
+    return vcat(m1,m2)
+end
+
+"""
+    mean_max_card(x)
+
+Concatenates mean, maximum and set cardinality.
+"""
+function mean_max_card(x)
+    m1 = mean(x, dims=2)
+    m2 = maximum(x, dims=2)
+    return vcat(m1,m2,size(x,2))
+end
+
+"""
+    sum_stat(x)
+
+Calculates a summary vector as a concatenation of mean, maximum, minimum, and var pooling.
+If the variance contains a NaN, it is replaced by zeros.
+"""
+function sum_stat(x)
+    m1 = mean(x, dims=2)
+    m2 = maximum(x, dims=2)
+    m3 = minimum(x, dims=2)
+    m4 = var(x, dims=2)
+    if any(isnan.(m4))
+        # zeros with the element type and shape of the other statistics (no promotion)
+        m4 = zeros(eltype(m1), size(m1))
+    end
+    return vcat(m1,m2,m3,m4)
+end
+
+"""
+    sum_stat_card(x)
+
+Calculates the same summary vector as `sum_stat` and appends the set cardinality.
+"""
+function sum_stat_card(x)
+    return vcat(sum_stat(x), size(x,2))
+end
+
+"""
+    pm_variational_loss(m::PoolAE, x; β=1)
+
+Loss function for the PoolAE which mirrors ELBO for VAE and
+should create a latent space mapped to standard Gaussian. Uses
+Chamfer distance and KL divergence, the latter weighted by `β`.
+"""
+function pm_variational_loss(m::PoolAE, x; β=1)
+    # pre-pool network transformation of X
+    v = m.prepool_net(x)
+    # pooling
+    p = m.poolf(v)
+    # pool encoder
+    z = rand(m.encoder, p)
+    kld = mean(kl_divergence(condition(m.encoder, p), m.prior))
+
+    # one generator sample per column of the input bag
+    Z = hcat([rand(m.generator, z) for i in 1:size(x, 2)]...)
+    dz = m.decoder(Z)
+
+    return chamfer_distance(x, dz) + β*kld
+end
+
+"""
+    StatsBase.fit!(model::PoolAE, data::Tuple, loss::Function; max_iters=10000,
+        max_train_time=82800, lr=0.001, batchsize=64, patience=30,
+        check_interval::Int=10, kwargs...)
+
+Function to fit PoolAE model. `data[1]` is used for training, the bags of `data[2]`
+with label 0 for validation. Training stops when the validation loss has not improved
+for `patience` checks, or after `max_iters` iterations or `max_train_time` seconds.
+Returns a named tuple `(history, iterations, model, npars)`.
+"""
+function StatsBase.fit!(model::PoolAE, data::Tuple, loss::Function;
+    max_iters=10000, max_train_time=82800, lr=0.001, batchsize=64, patience=30,
+    check_interval::Int=10, kwargs...)
+
+    history = MVHistory()
+    opt = ADAM(lr)
+
+    tr_model = deepcopy(model)
+    ps = Flux.params(tr_model)
+    _patience = patience
+
+    # prepare data for bag model
+    tr_x, tr_l = unpack_mill(data[1])
+    vx, vl = unpack_mill(data[2])
+    val_x = vx[vl .== 0]
+
+    best_val_loss = Inf
+    i = 1
+    start_time = time()
+
+    lossf(x) = loss(tr_model, x)
+
+    # infinite for loop via RandomBatches
+    for _ in RandomBatches(tr_x, 10)
+        # classic training
+        bag_batch = RandomBagBatches(tr_x,batchsize=batchsize,randomize=true)
+        Flux.train!(lossf, ps, bag_batch, opt)
+        # only batch training loss
+        batch_loss = mean(lossf.(bag_batch))
+
+        push!(history, :training_loss, i, batch_loss)
+        if mod(i, check_interval) == 0
+
+            # validation/early stopping
+            val_loss = mean(lossf.(val_x))
+
+            @info "$i - loss: $(batch_loss) (batch) | $(val_loss) (validation)"
+
+            if isnan(val_loss) || isnan(batch_loss)
+                error("Encountered invalid values in loss function.")
+            end
+
+            push!(history, :validation_likelihood, i, val_loss)
+
+            if val_loss < best_val_loss
+                best_val_loss = val_loss
+                _patience = patience
+
+                # this should save the model at least once
+                # when the validation loss is decreasing
+                model = deepcopy(tr_model)
+            else # else stop if the model has not improved for `patience` iterations
+                _patience -= 1
+                if _patience == 0
+                    @info "Stopped training after $(i) iterations."
+                    break
+                end
+            end
+        end
+        if (time() - start_time > max_train_time) || (i >= max_iters) # stop early if time or iterations run out
+            model = deepcopy(tr_model)
+            @info "Stopped training after $(i) iterations, $((time() - start_time) / 3600) hours."
+            break
+        end
+        i += 1
+    end
+    # again, this is not optimal, the model should be passed by reference and only the reference should be edited
+    (history = history, iterations = i, model = model, npars = sum(length, Flux.params(model)))
+end
+
+######################################
+### Score functions and evaluation ###
+######################################
+
+"""
+    reconstruct(m::PoolAE, x)
+
+Reconstructs the input bag.
+"""
+function reconstruct(m::PoolAE, x)
+    v = m.prepool_net(x)
+    p = m.poolf(v)
+    z = mean(m.encoder, p)
+    Z = hcat([rand(m.generator, z) for i in 1:size(x, 2)]...)
+    m.decoder(Z)
+end
+
+"""
+    encoding(m::PoolAE, x)
+
+Returns the one-vector summary encoding for a bag (mean of the encoder on the pooled bag).
+"""
+function encoding(m::PoolAE, x)
+    v = m.prepool_net(x)
+    p = m.poolf(v)
+    return mean(m.encoder, p)
+end
\ No newline at end of file
diff --git a/src/models/PoolModel.jl b/src/models/PoolModel.jl
index 16eda84..e5420f3 100644
--- a/src/models/PoolModel.jl
+++ b/src/models/PoolModel.jl
@@ -123,9 +123,9 @@ function pm_constructor(;idim, hdim, predim, postdim, edim, activation="swish",
     return pm
 end
 
-#################
-### Functions ###
-#################
+#################################
+### Special pooling functions ###
+#################################
 
 bag_mean(x) = mean(x, dims=2)
 bag_maximum(x) = maximum(x, dims=2)