From 10d66a332f9cf58f45b5094642c4e39ed5cda8d5 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Sat, 20 Jun 2020 13:22:47 +0200
Subject: [PATCH] .

---
 .../dueling_double_dqn.cpython-36.pyc      | Bin 0 -> 16301 bytes
 src/agent/__pycache__/model.cpython-36.pyc | Bin 0 -> 2216 bytes
 src/agent/dueling_double_dqn.py            | 512 ++++++++++++
 src/agent/model.py                         |  61 ++
 src/observations.py                        | 731 ++++++++++++++++++
 5 files changed, 1304 insertions(+)
 create mode 100644 src/agent/__pycache__/dueling_double_dqn.cpython-36.pyc
 create mode 100644 src/agent/__pycache__/model.cpython-36.pyc
 create mode 100644 src/agent/dueling_double_dqn.py
 create mode 100644 src/agent/model.py
 create mode 100644 src/observations.py

diff --git a/src/agent/__pycache__/dueling_double_dqn.cpython-36.pyc b/src/agent/__pycache__/dueling_double_dqn.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..41afc8cad7a82889d0118d5fd836cf1f15af34de
GIT binary patch
literal 16301
[base85 payload of the compiled bytecode omitted]

literal 0
HcmV?d00001

diff --git a/src/agent/__pycache__/model.cpython-36.pyc b/src/agent/__pycache__/model.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a13378b2cbe5716b1bf04b774b5fce12ccef39e7
GIT binary patch
literal 2216
[base85 payload of the compiled bytecode omitted]

literal 0
HcmV?d00001

diff --git a/src/agent/dueling_double_dqn.py b/src/agent/dueling_double_dqn.py
new file mode 100644
index 0000000..6d82fde
--- /dev/null
+++ b/src/agent/dueling_double_dqn.py
@@ -0,0 +1,512 @@
+import torch
+import torch.optim as optim
+
+BUFFER_SIZE = int(1e5)  # replay buffer size
+BATCH_SIZE = 512  # minibatch size
+GAMMA = 0.99  # discount factor
0.99 +TAU = 0.5e-3 # for soft update of target parameters +LR = 0.5e-4 # learning rate 0.5e-4 works + +# how often to update the network +UPDATE_EVERY = 20 +UPDATE_EVERY_FINAL = 10 +UPDATE_EVERY_AGENT_CANT_CHOOSE = 200 + + +double_dqn = True # If using double dqn algorithm +input_channels = 5 # Number of Input channels + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") +print(device) + +USE_OPTIMIZER = optim.Adam +# USE_OPTIMIZER = optim.RMSprop +print(USE_OPTIMIZER) + + +class Agent: + """Interacts with and learns from the environment.""" + + def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + """Initialize an Agent object. + + Params + ====== + state_size (int): dimension of each state + action_size (int): dimension of each action + seed (int): random seed + """ + self.state_size = state_size + self.action_size = action_size + self.seed = random.seed(seed) + self.version = net_type + self.double_dqn = double_dqn + # Q-Network + if self.version == "Conv": + self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + else: + self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + # Replay memory + self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + + self.final_step = {} + + # Initialize time step (for updating every UPDATE_EVERY steps) + self.t_step = 0 + self.t_step_final = 0 + self.t_step_agent_can_not_choose = 0 + + def save(self, filename): + torch.save(self.qnetwork_local.state_dict(), filename + ".local") + torch.save(self.qnetwork_target.state_dict(), filename + ".target") + + def load(self, filename): + if os.path.exists(filename + ".local"): + self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) + print(filename + ".local -> ok") + if os.path.exists(filename + ".target"): + self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) + print(filename + ".target -> ok") + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + def _update_model(self, switch=0): + # Learn every UPDATE_EVERY time steps. 
+ # If enough samples are available in memory, get random subset and learn + if switch == 0: + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0: + if len(self.memory) > BATCH_SIZE: + experiences = self.memory.sample() + self.learn(experiences, GAMMA) + elif switch == 1: + self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL + if self.t_step_final == 0: + if len(self.memory_final) > BATCH_SIZE: + experiences = self.memory_final.sample() + self.learn(experiences, GAMMA) + else: + # If enough samples are available in memory_agent_can_not_choose, get random subset and learn + self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE + if self.t_step_agent_can_not_choose == 0: + if len(self.memory_agent_can_not_choose) > BATCH_SIZE: + experiences = self.memory_agent_can_not_choose.sample() + self.learn(experiences, GAMMA) + + def step(self, state, action, reward, next_state, done): + # Save experience in replay memory + self.memory.add(state, action, reward, next_state, done) + self._update_model(0) + + def step_agent_can_not_choose(self, state, action, reward, next_state, done): + # Save experience in replay memory_agent_can_not_choose + self.memory_agent_can_not_choose.add(state, action, reward, next_state, done) + self._update_model(2) + + def add_final_step(self, agent_handle, state, action, reward, next_state, done): + if self.final_step.get(agent_handle) is None: + self.final_step.update({agent_handle: [state, action, reward, next_state, done]}) + + def make_final_step(self, additional_reward=0): + for _, item in self.final_step.items(): + state = item[0] + action = item[1] + reward = item[2] + additional_reward + next_state = item[3] + done = item[4] + self.memory_final.add(state, action, reward, next_state, done) + self._update_model(1) + self._reset_final_step() + + def _reset_final_step(self): + self.final_step = {} + + def act(self, state, eps=0.): + """Returns actions for given state as per current policy. + + Params + ====== + state (array_like): current state + eps (float): epsilon, for epsilon-greedy action selection + """ + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + self.qnetwork_local.train() + + # Epsilon-greedy action selection + if random.random() > eps: + return np.argmax(action_values.cpu().data.numpy()) + else: + return random.choice(np.arange(self.action_size)) + + def learn(self, experiences, gamma): + + """Update value parameters using given batch of experience tuples. 
+ + Params + ====== + experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples + gamma (float): discount factor + """ + states, actions, rewards, next_states, dones = experiences + + # Get expected Q values from local model + Q_expected = self.qnetwork_local(states).gather(1, actions) + + if self.double_dqn: + # Double DQN + q_best_action = self.qnetwork_local(next_states).max(1)[1] + Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) + else: + # DQN + Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + + # Compute Q targets for current states + + Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) + + # Compute loss + loss = F.mse_loss(Q_expected, Q_targets) + # Minimize the loss + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # ------------------- update target network ------------------- # + self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) + + def soft_update(self, local_model, target_model, tau): + """Soft update model parameters. + θ_target = τ*θ_local + (1 - τ)*θ_target + + Params + ====== + local_model (PyTorch model): weights will be copied from + target_model (PyTorch model): weights will be copied to + tau (float): interpolation parameter + """ + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) + + +class ReplayBuffer: + """Fixed-size buffer to store experience tuples.""" + + def __init__(self, action_size, buffer_size, batch_size, seed): + """Initialize a ReplayBuffer object. + + Params + ====== + action_size (int): dimension of each action + buffer_size (int): maximum size of buffer + batch_size (int): size of each training batch + seed (int): random seed + """ + self.action_size = action_size + self.memory = deque(maxlen=buffer_size) + self.batch_size = batch_size + self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) + self.seed = random.seed(seed) + + def add(self, state, action, reward, next_state, done): + """Add a new experience to memory.""" + e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) + self.memory.append(e) + + def sample(self): + """Randomly sample a batch of experiences from memory.""" + experiences = random.sample(self.memory, k=self.batch_size) + + states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ + .float().to(device) + actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ + .long().to(device) + rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ + .float().to(device) + next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ + .float().to(device) + dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ + .float().to(device) + + return (states, actions, rewards, next_states, dones) + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) + + def __v_stack_impr(self, states): + sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 + np_states = np.reshape(np.array(states), (len(states), sub_dim)) + return np_states + + +import copy +import os +import random +from collections import 
namedtuple, deque, Iterable + +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim + +from src.agent.model import QNetwork2, QNetwork + +BUFFER_SIZE = int(1e5) # replay buffer size +BATCH_SIZE = 512 # minibatch size +GAMMA = 0.95 # discount factor 0.99 +TAU = 0.5e-4 # for soft update of target parameters +LR = 0.5e-3 # learning rate 0.5e-4 works + +# how often to update the network +UPDATE_EVERY = 40 +UPDATE_EVERY_FINAL = 1000 +UPDATE_EVERY_AGENT_CANT_CHOOSE = 200 + +double_dqn = True # If using double dqn algorithm +input_channels = 5 # Number of Input channels + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") +print(device) + +USE_OPTIMIZER = optim.Adam +# USE_OPTIMIZER = optim.RMSprop +print(USE_OPTIMIZER) + + +class Agent: + """Interacts with and learns from the environment.""" + + def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + """Initialize an Agent object. + + Params + ====== + state_size (int): dimension of each state + action_size (int): dimension of each action + seed (int): random seed + """ + self.state_size = state_size + self.action_size = action_size + self.seed = random.seed(seed) + self.version = net_type + self.double_dqn = double_dqn + # Q-Network + if self.version == "Conv": + self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + else: + self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + # Replay memory + self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + + self.final_step = {} + + # Initialize time step (for updating every UPDATE_EVERY steps) + self.t_step = 0 + self.t_step_final = 0 + self.t_step_agent_can_not_choose = 0 + + def save(self, filename): + torch.save(self.qnetwork_local.state_dict(), filename + ".local") + torch.save(self.qnetwork_target.state_dict(), filename + ".target") + + def load(self, filename): + print("try to load: " + filename) + if os.path.exists(filename + ".local"): + self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) + print(filename + ".local -> ok") + if os.path.exists(filename + ".target"): + self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) + print(filename + ".target -> ok") + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + def _update_model(self, switch=0): + # Learn every UPDATE_EVERY time steps. 
+ # If enough samples are available in memory, get random subset and learn + if switch == 0: + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0: + if len(self.memory) > BATCH_SIZE: + experiences = self.memory.sample() + self.learn(experiences, GAMMA) + elif switch == 1: + self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL + if self.t_step_final == 0: + if len(self.memory_final) > BATCH_SIZE: + experiences = self.memory_final.sample() + self.learn(experiences, GAMMA) + else: + # If enough samples are available in memory_agent_can_not_choose, get random subset and learn + self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE + if self.t_step_agent_can_not_choose == 0: + if len(self.memory_agent_can_not_choose) > BATCH_SIZE: + experiences = self.memory_agent_can_not_choose.sample() + self.learn(experiences, GAMMA) + + def step(self, state, action, reward, next_state, done): + # Save experience in replay memory + self.memory.add(state, action, reward, next_state, done) + self._update_model(0) + + def step_agent_can_not_choose(self, state, action, reward, next_state, done): + # Save experience in replay memory_agent_can_not_choose + self.memory_agent_can_not_choose.add(state, action, reward, next_state, done) + self._update_model(2) + + def add_final_step(self, agent_handle, state, action, reward, next_state, done): + if self.final_step.get(agent_handle) is None: + self.final_step.update({agent_handle: [state, action, reward, next_state, done]}) + return True + else: + return False + + def make_final_step(self, additional_reward=0): + for _, item in self.final_step.items(): + state = item[0] + action = item[1] + reward = item[2] + additional_reward + next_state = item[3] + done = item[4] + self.memory_final.add(state, action, reward, next_state, done) + self._update_model(1) + self._reset_final_step() + + def _reset_final_step(self): + self.final_step = {} + + def act(self, state, eps=0.): + """Returns actions for given state as per current policy. + + Params + ====== + state (array_like): current state + eps (float): epsilon, for epsilon-greedy action selection + """ + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + self.qnetwork_local.train() + + # Epsilon-greedy action selection + if random.random() > eps: + return np.argmax(action_values.cpu().data.numpy()), False + else: + return random.choice(np.arange(self.action_size)), True + + def learn(self, experiences, gamma): + + """Update value parameters using given batch of experience tuples. 
+ + Params + ====== + experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples + gamma (float): discount factor + """ + states, actions, rewards, next_states, dones = experiences + + # Get expected Q values from local model + Q_expected = self.qnetwork_local(states).gather(1, actions) + + if self.double_dqn: + # Double DQN + q_best_action = self.qnetwork_local(next_states).max(1)[1] + Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) + else: + # DQN + Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + + # Compute Q targets for current states + + Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) + + # Compute loss + loss = F.mse_loss(Q_expected, Q_targets) + # Minimize the loss + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # ------------------- update target network ------------------- # + self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) + + def soft_update(self, local_model, target_model, tau): + """Soft update model parameters. + θ_target = τ*θ_local + (1 - τ)*θ_target + + Params + ====== + local_model (PyTorch model): weights will be copied from + target_model (PyTorch model): weights will be copied to + tau (float): interpolation parameter + """ + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) + + +class ReplayBuffer: + """Fixed-size buffer to store experience tuples.""" + + def __init__(self, action_size, buffer_size, batch_size, seed): + """Initialize a ReplayBuffer object. + + Params + ====== + action_size (int): dimension of each action + buffer_size (int): maximum size of buffer + batch_size (int): size of each training batch + seed (int): random seed + """ + self.action_size = action_size + self.memory = deque(maxlen=buffer_size) + self.batch_size = batch_size + self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) + self.seed = random.seed(seed) + + def add(self, state, action, reward, next_state, done): + """Add a new experience to memory.""" + e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) + self.memory.append(e) + + def sample(self): + """Randomly sample a batch of experiences from memory.""" + experiences = random.sample(self.memory, k=self.batch_size) + + states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ + .float().to(device) + actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ + .long().to(device) + rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ + .float().to(device) + next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ + .float().to(device) + dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ + .float().to(device) + + return (states, actions, rewards, next_states, dones) + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) + + def __v_stack_impr(self, states): + sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 + np_states = np.reshape(np.array(states), (len(states), sub_dim)) + return np_states diff --git a/src/agent/model.py b/src/agent/model.py new file mode 100644 
index 0000000..70952e0 --- /dev/null +++ b/src/agent/model.py @@ -0,0 +1,61 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class QNetwork(nn.Module): + def __init__(self, state_size, action_size, seed, hidsize1=64, hidsize2=128): + super(QNetwork, self).__init__() + + self.fc1_val = nn.Linear(state_size, hidsize1) + self.fc2_val = nn.Linear(hidsize1, hidsize2) + self.fc3_val = nn.Linear(hidsize2, 1) + + self.fc1_adv = nn.Linear(state_size, hidsize1) + self.fc2_adv = nn.Linear(hidsize1, hidsize2) + self.fc3_adv = nn.Linear(hidsize2, action_size) + + def forward(self, x): + val = F.relu(self.fc1_val(x)) + val = F.relu(self.fc2_val(val)) + val = self.fc3_val(val) + + # advantage calculation + adv = F.relu(self.fc1_adv(x)) + adv = F.relu(self.fc2_adv(adv)) + adv = self.fc3_adv(adv) + return val + adv - adv.mean() + + +class QNetwork2(nn.Module): + def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64): + super(QNetwork2, self).__init__() + self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1) + self.bn1 = nn.BatchNorm2d(16) + self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3) + self.bn2 = nn.BatchNorm2d(32) + self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3) + self.bn3 = nn.BatchNorm2d(64) + + self.fc1_val = nn.Linear(6400, hidsize1) + self.fc2_val = nn.Linear(hidsize1, hidsize2) + self.fc3_val = nn.Linear(hidsize2, 1) + + self.fc1_adv = nn.Linear(6400, hidsize1) + self.fc2_adv = nn.Linear(hidsize1, hidsize2) + self.fc3_adv = nn.Linear(hidsize2, action_size) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + + # value function approximation + val = F.relu(self.fc1_val(x.view(x.size(0), -1))) + val = F.relu(self.fc2_val(val)) + val = self.fc3_val(val) + + # advantage calculation + adv = F.relu(self.fc1_adv(x.view(x.size(0), -1))) + adv = F.relu(self.fc2_adv(adv)) + adv = self.fc3_adv(adv) + return val + adv - adv.mean() diff --git a/src/observations.py b/src/observations.py new file mode 100644 index 0000000..cc89198 --- /dev/null +++ b/src/observations.py @@ -0,0 +1,731 @@ +""" +Collection of environment-specific ObservationBuilder. +""" +import collections +from typing import Optional, List, Dict, Tuple + +import numpy as np +from flatland.core.env import Environment +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid_utils import coordinate_to_position +from flatland.envs.agent_utils import RailAgentStatus, EnvAgent +from flatland.utils.ordered_set import OrderedSet + + +class MyTreeObsForRailEnv(ObservationBuilder): + """ + TreeObsForRailEnv object. + + This object returns observation vectors for agents in the RailEnv environment. + The information is local to each agent and exploits the graph structure of the rail + network to simplify the representation of the state of the environment for each agent. + + For details about the features in the tree observation see the get() function. 
+ """ + Node = collections.namedtuple('Node', 'dist_min_to_target ' + 'target_encountered ' + 'num_agents_same_direction ' + 'num_agents_opposite_direction ' + 'childs') + + tree_explored_actions_char = ['L', 'F', 'R', 'B'] + + def __init__(self, max_depth: int, predictor: PredictionBuilder = None): + super().__init__() + self.max_depth = max_depth + self.observation_dim = 2 + self.location_has_agent = {} + self.predictor = predictor + self.location_has_target = None + + self.switches_list = {} + self.switches_neighbours_list = [] + self.check_agent_descision = None + + def reset(self): + self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} + + def set_switch_and_pre_switch(self, switch_list, pre_switch_list, check_agent_descision): + self.switches_list = switch_list + self.switches_neighbours_list = pre_switch_list + self.check_agent_descision = check_agent_descision + + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]: + """ + Called whenever an observation has to be computed for the `env` environment, for each agent with handle + in the `handles` list. + """ + + if handles is None: + handles = [] + if self.predictor: + self.max_prediction_depth = 0 + self.predicted_pos = {} + self.predicted_dir = {} + self.predictions = self.predictor.get() + if self.predictions: + for t in range(self.predictor.max_depth + 1): + pos_list = [] + dir_list = [] + for a in handles: + if self.predictions[a] is None: + continue + pos_list.append(self.predictions[a][t][1:3]) + dir_list.append(self.predictions[a][t][3]) + self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) + self.predicted_dir.update({t: dir_list}) + self.max_prediction_depth = len(self.predicted_pos) + # Update local lookup table for all agents' positions + # ignore other agents not in the grid (only status active and done) + + self.location_has_agent = {} + self.location_has_agent_direction = {} + self.location_has_agent_speed = {} + self.location_has_agent_malfunction = {} + self.location_has_agent_ready_to_depart = {} + + for _agent in self.env.agents: + if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ + _agent.position: + self.location_has_agent[tuple(_agent.position)] = 1 + self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction + self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] + self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ + 'malfunction'] + + if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ + _agent.initial_position: + self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 + + observations = super().get_many(handles) + + return observations + + def get(self, handle: int = 0) -> Node: + """ + Computes the current observation for agent `handle` in env + + The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible + movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). + The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for + the transitions. The order is:: + + [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] + + Each branch data is organized as:: + + [root node information] + + [recursive branch data from 'left'] + + [... 
from 'forward']
+            [... from 'right']
+            [... from 'back']
+
+        Each node information is composed of 12 features:
+
+        #1:
+            if own target lies on the explored branch, the current distance from the agent in number of cells is stored.
+
+        #2:
+            if another agent's target is detected, the distance in number of cells from the agent's current location \
+            is stored
+
+        #3:
+            if another agent is detected, the distance in number of cells from the current agent position is stored.
+
+        #4:
+            possible conflict detected
+            tot_dist = another agent is predicted to pass through this cell at the same time as this agent; we store \
+            the distance in number of cells from the current agent position
+
+            0 = no other agent reserves the same cell at a similar time
+
+        #5:
+            if a switch that cannot be used by the agent is detected, we store the distance.
+
+        #6:
+            This feature stores the distance in number of cells to the next branching (current node)
+
+        #7:
+            minimum distance from node to the agent's target given the direction of the agent if this path is chosen
+
+        #8:
+            agent in the same direction
+            n = number of agents present in the same direction \
+                (possible future use: number of other agents in the same direction in this branch)
+            0 = no agent present in the same direction
+
+        #9:
+            agent in the opposite direction
+            n = number of agents present in the opposite direction to myself (so conflict) \
+                (possible future use: number of other agents in the other direction in this branch, i.e. number of conflicts)
+            0 = no agent present in the opposite direction to myself
+
+        #10:
+            malfunctioning/blocking agents
+            n = number of time steps the observed agent remains blocked
+
+        #11:
+            slowest observed speed of an agent in the same direction
+            1 if no agent is observed
+
+            min_fractional speed otherwise
+
+        #12:
+            number of agents ready to depart but not yet active
+
+        Note: this simplified observation only keeps a subset of these features per node, namely
+        dist_min_to_target, target_encountered, num_agents_same_direction and num_agents_opposite_direction.
+
+        Missing/padding nodes are filled in with -inf (truncated).
+        Missing values in present node are filled in with +inf (truncated).
+
+
+        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
+        In case the target node is reached, the values are [0, 0, 0, 0, 0].
+        """
+
+        if handle >= len(self.env.agents):
+            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        # Here information about the agent itself is stored
+        distance_map = self.env.distance_map.get()
+
+        root_node_observation = MyTreeObsForRailEnv.Node(dist_min_to_target=distance_map[
+                                                             (handle, *agent_virtual_position,
+                                                              agent.direction)],
+                                                         target_encountered=0,
+                                                         num_agents_same_direction=0,
+                                                         num_agents_opposite_direction=0,
+                                                         childs={})
+
+        visited = OrderedSet()
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
+ orientation = agent.direction + + if num_transitions == 1: + orientation = np.argmax(possible_transitions) + + for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): + if possible_transitions[branch_direction]: + new_cell = get_new_position(agent_virtual_position, branch_direction) + + branch_observation, branch_visited = \ + self._explore_branch(handle, new_cell, branch_direction, 1, 1) + root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation + + visited |= branch_visited + else: + # add cells filled with infinity if no transition is possible + root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf + self.env.dev_obs_dict[handle] = visited + + return root_node_observation + + def _explore_branch(self, handle, position, direction, tot_dist, depth): + """ + Utility function to compute tree-based observations. + We walk along the branch and collect the information documented in the get() function. + If there is a branching point a new node is created and each possible branch is explored. + """ + + # [Recursive branch opened] + if depth >= self.max_depth + 1: + return [], [] + + # Continue along direction until next switch or + # until no transitions are possible along the current direction (i.e., dead-ends) + # We treat dead-ends as nodes, instead of going back, to avoid loops + exploring = True + + visited = OrderedSet() + agent = self.env.agents[handle] + + other_agent_opposite_direction = 0 + other_agent_same_direction = 0 + + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] + + last_is_dead_end = False + last_is_a_decision_cell = False + target_encountered = 0 + + while exploring: + + dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1], + direction]) + + if agent.target == position: + target_encountered = 1 + + new_direction_me = direction + new_cell_me = position + a = self.env.agent_positions[new_cell_me] + if a != -1 and a != handle: + opp_agent = self.env.agents[a] + # look one step forward + # opp_possible_transitions = self.env.rail.get_transitions(*opp_agent.position, opp_agent.direction) + if opp_agent.direction != new_direction_me: # opp_possible_transitions[new_direction_me] == 0: + other_agent_opposite_direction += 1 + else: + other_agent_same_direction += 1 + + # ############################# + # ############################# + if (position[0], position[1], direction) in visited: + break + visited.add((position[0], position[1], direction)) + + # If the target node is encountered, pick that as node. Also, no further branching is possible. + if np.array_equal(position, self.env.agents[handle].target): + last_is_target = True + break + + exploring = False + + # Check number of possible transitions for agent and total number of transitions in cell (type) + possible_transitions = self.env.rail.get_transitions(*position, direction) + num_transitions = np.count_nonzero(possible_transitions) + # cell_transitions = self.env.rail.get_transitions(*position, direction) + transition_bit = bin(self.env.rail.get_full_transitions(*position)) + total_transitions = transition_bit.count("1") + + if num_transitions == 1: + # Check if dead-end, or if we can go forward along direction + nbits = total_transitions + if nbits == 1: + # Dead-end! 
+ last_is_dead_end = True + + if self.check_agent_descision is not None: + ret_agents_on_switch, ret_agents_near_to_switch, agents_near_to_switch_all = \ + self.check_agent_descision(position, + direction, + self.switches_list, + self.switches_neighbours_list) + if ret_agents_on_switch: + last_is_a_decision_cell = True + break + + exploring = True + # convert one-hot encoding to 0,1,2,3 + cell_transitions = self.env.rail.get_transitions(*position, direction) + direction = np.argmax(cell_transitions) + position = get_new_position(position, direction) + + # ############################# + # ############################# + # Modify here to append new / different features for each visited cell! + + node = MyTreeObsForRailEnv.Node(dist_min_to_target=dist_min_to_target, + target_encountered=target_encountered, + num_agents_opposite_direction=other_agent_opposite_direction, + num_agents_same_direction=other_agent_same_direction, + childs={}) + + # ############################# + # ############################# + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + # Get the possible transitions + possible_transitions = self.env.rail.get_transitions(*position, direction) + + for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]): + if last_is_dead_end and self.env.rail.get_transition((*position, direction), + (branch_direction + 2) % 4): + # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes + # it back + new_cell = get_new_position(position, (branch_direction + 2) % 4) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + (branch_direction + 2) % 4, + tot_dist + 1, + depth + 1) + node.childs[self.tree_explored_actions_char[i]] = branch_observation + if len(branch_visited) != 0: + visited |= branch_visited + elif last_is_a_decision_cell and possible_transitions[branch_direction]: + new_cell = get_new_position(position, branch_direction) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + branch_direction, + tot_dist + 1, + depth + 1) + node.childs[self.tree_explored_actions_char[i]] = branch_observation + if len(branch_visited) != 0: + visited |= branch_visited + else: + # no exploring possible, add just cells with infinity + node.childs[self.tree_explored_actions_char[i]] = -np.inf + + if depth == self.max_depth: + node.childs.clear() + return node, visited + + def util_print_obs_subtree(self, tree: Node): + """ + Utility function to print tree observations returned by this object. 
+ """ + self.print_node_features(tree, "root", "") + for direction in self.tree_explored_actions_char: + self.print_subtree(tree.childs[direction], direction, "\t") + + @staticmethod + def print_node_features(node: Node, label, indent): + print(indent, "Direction ", label, ": ", node.num_agents_same_direction, + ", ", node.num_agents_opposite_direction) + + def print_subtree(self, node, label, indent): + if node == -np.inf or not node: + print(indent, "Direction ", label, ": -np.inf") + return + + self.print_node_features(node, label, indent) + + if not node.childs: + return + + for direction in self.tree_explored_actions_char: + self.print_subtree(node.childs[direction], direction, indent + "\t") + + def set_env(self, env: Environment): + super().set_env(env) + if self.predictor: + self.predictor.set_env(self.env) + + def _reverse_dir(self, direction): + return int((direction + 2) % 4) + + +class GlobalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16),\ + assuming 16 bits encoding of transitions. + + - obs_agents_state: A 3D array (map_height, map_width, 5) with + - first channel containing the agents position and direction + - second channel containing the other agents positions and direction + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - fifth channel containing number of other agents ready to depart + + - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ + target and the positions of the other agents targets (flag only, no counter!). + """ + + def __init__(self): + super(GlobalObsForRailEnv, self).__init__() + + def set_env(self, env: Environment): + super().set_env(env) + + def reset(self): + self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) + for i in range(self.rail_obs.shape[0]): + for j in range(self.rail_obs.shape[1]): + bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] + bitlist = [0] * (16 - len(bitlist)) + bitlist + self.rail_obs[i, j] = np.array(bitlist) + + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): + + agent = self.env.agents[handle] + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + return None + + obs_targets = np.zeros((self.env.height, self.env.width, 2)) + obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1 + + # TODO can we do this more elegantly? 
+ # for r in range(self.env.height): + # for c in range(self.env.width): + # obs_agents_state[(r, c)][4] = 0 + obs_agents_state[:, :, 4] = 0 + + obs_agents_state[agent_virtual_position][0] = agent.direction + obs_targets[agent.target][0] = 1 + + for i in range(len(self.env.agents)): + other_agent: EnvAgent = self.env.agents[i] + + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + + obs_targets[other_agent.target][1] = 1 + + # second to fourth channel only if in the grid + if other_agent.position is not None: + # second channel only for other agents + if i != handle: + obs_agents_state[other_agent.position][1] = other_agent.direction + obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] + # fifth channel: all ready to depart on this position + if other_agent.status == RailAgentStatus.READY_TO_DEPART: + obs_agents_state[other_agent.initial_position][4] += 1 + return self.rail_obs, obs_agents_state, obs_targets + + +class LocalObsForRailEnv(ObservationBuilder): + """ + !!!!!!WARNING!!! THIS IS DEPRACTED AND NOT UPDATED TO FLATLAND 2.0!!!!! + Gives a local observation of the rail environment around the agent. + The observation is composed of the following elements: + + - transition map array of the local environment around the given agent, \ + with dimensions (view_height,2*view_width+1, 16), \ + assuming 16 bits encoding of transitions. + + - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \ + if they are in the agent's vision range, its target position, the positions of the other targets. + + - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \ + of the other agents at their position coordinates, if they are in the agent's vision range. + + - A 4 elements array with one hot encoding of the direction. + + Use the parameters view_width and view_height to define the rectangular view of the agent. + The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has + observation in front of it. + + .. deprecated:: 2.0.0 + """ + + def __init__(self, view_width, view_height, center): + + super(LocalObsForRailEnv, self).__init__() + self.view_width = view_width + self.view_height = view_height + self.center = center + self.max_padding = max(self.view_width, self.view_height - self.center) + + def reset(self): + # We build the transition map with a view_radius empty cells expansion on each side. + # This helps to collect the local transition map view when the agent is close to a border. 
+ self.max_padding = max(self.view_width, self.view_height) + self.rail_obs = np.zeros((self.env.height, + self.env.width, 16)) + for i in range(self.env.height): + for j in range(self.env.width): + bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] + bitlist = [0] * (16 - len(bitlist)) + bitlist + self.rail_obs[i, j] = np.array(bitlist) + + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray): + agents = self.env.agents + agent = agents[handle] + + # Correct agents position for padding + # agent_rel_pos[0] = agent.position[0] + self.max_padding + # agent_rel_pos[1] = agent.position[1] + self.max_padding + + # Collect visible cells as set to be plotted + visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) + local_rail_obs = None + + # Add the visible cells to the observed cells + self.env.dev_obs_dict[handle] = set(visited) + + # Locate observed agents and their coresponding targets + local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) + obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) + obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) + _idx = 0 + for pos in visited: + curr_rel_coord = rel_coords[_idx] + local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] + if pos == agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 + else: + for tmp_agent in agents: + if pos == tmp_agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 + if pos != agent.position: + for tmp_agent in agents: + if pos == tmp_agent.position: + obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ + tmp_agent.direction] + + _idx += 1 + + direction = np.identity(4)[agent.direction] + return local_rail_obs, obs_map_state, obs_other_agents_state, direction + + def get_many(self, handles: Optional[List[int]] = None) -> Dict[ + int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: + """ + Called whenever an observation has to be computed for the `env` environment, for each agent with handle + in the `handles` list. 
+ """ + + return super().get_many(handles) + + def field_of_view(self, position, direction, state=None): + # Compute the local field of view for an agent in the environment + data_collection = False + if state is not None: + temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) + data_collection = True + if direction == 0: + origin = (position[0] + self.center, position[1] - self.view_width) + elif direction == 1: + origin = (position[0] - self.view_width, position[1] - self.center) + elif direction == 2: + origin = (position[0] - self.center, position[1] + self.view_width) + else: + origin = (position[0] + self.view_width, position[1] + self.center) + visible = list() + rel_coords = list() + for h in range(self.view_height): + for w in range(2 * self.view_width + 1): + if direction == 0: + if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + visible.append((origin[0] - h, origin[1] + w)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] + elif direction == 1: + if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: + visible.append((origin[0] + w, origin[1] + h)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] + elif direction == 2: + if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: + visible.append((origin[0] + h, origin[1] - w)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] + else: + if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: + visible.append((origin[0] - w, origin[1] - h)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] + if data_collection: + return temp_visible_data + else: + return visible, rel_coords + + +def _split_node_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int) -> (np.ndarray, np.ndarray, + np.ndarray): + data = np.zeros(2) + + data[0] = 2.0 * int(node.num_agents_opposite_direction > 0) - 1.0 + # data[1] = 2.0 * int(node.num_agents_same_direction > 0) - 1.0 + data[1] = 2.0 * int(node.target_encountered > 0) - 1.0 + + return data + + +def _split_subtree_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int, + current_tree_depth: int, + max_tree_depth: int) -> ( + np.ndarray, np.ndarray, np.ndarray): + if node == -np.inf: + remaining_depth = max_tree_depth - current_tree_depth + # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure + num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) + return [0] * num_remaining_nodes * 2 + + data = _split_node_into_feature_groups(node, dist_min_to_target) + + if not node.childs: + return data + + for direction in MyTreeObsForRailEnv.tree_explored_actions_char: + sub_data = _split_subtree_into_feature_groups(node.childs[direction], + node.dist_min_to_target, + current_tree_depth + 1, + max_tree_depth) + data = np.concatenate((data, sub_data)) + return data + + +def split_tree_into_feature_groups(tree: MyTreeObsForRailEnv.Node, max_tree_depth: int) -> ( + np.ndarray, np.ndarray, np.ndarray): + """ + This function splits the tree into three difference arrays of values + """ + data = _split_node_into_feature_groups(tree, 1000000.0) + + for direction in 
MyTreeObsForRailEnv.tree_explored_actions_char: + sub_data = _split_subtree_into_feature_groups(tree.childs[direction], + 1000000.0, + 1, + max_tree_depth) + data = np.concatenate((data, sub_data)) + + return data + + +def normalize_observation(observation: MyTreeObsForRailEnv.Node, tree_depth: int): + """ + This function normalizes the observation used by the RL algorithm + """ + data = split_tree_into_feature_groups(observation, tree_depth) + normalized_obs = data + + # navigate_info + navigate_info = np.zeros(4) + action_info = np.zeros(4) + np.seterr(all='raise') + try: + dm = observation.dist_min_to_target + if observation.childs['L'] != -np.inf: + navigate_info[0] = dm - observation.childs['L'].dist_min_to_target + action_info[0] = 1 + if observation.childs['F'] != -np.inf: + navigate_info[1] = dm - observation.childs['F'].dist_min_to_target + action_info[1] = 1 + if observation.childs['R'] != -np.inf: + navigate_info[2] = dm - observation.childs['R'].dist_min_to_target + action_info[2] = 1 + if observation.childs['B'] != -np.inf: + navigate_info[3] = dm - observation.childs['B'].dist_min_to_target + action_info[3] = 1 + except: + navigate_info = np.ones(4) + normalized_obs = np.zeros(len(normalized_obs)) + + # navigate_info_2 = np.copy(navigate_info) + # max_v = np.max(navigate_info_2) + # navigate_info_2 = navigate_info_2 / max_v + # navigate_info_2[navigate_info_2 < 1] = -1 + + max_v = np.max(navigate_info) + navigate_info = navigate_info / max_v + navigate_info[navigate_info < 0] = -1 + # navigate_info[abs(navigate_info) < 1] = 0 + # normalized_obs = navigate_info + + # navigate_info = np.concatenate((navigate_info, action_info)) + normalized_obs = np.concatenate((navigate_info, normalized_obs)) + # normalized_obs = np.concatenate((navigate_info, navigate_info_2)) + # print(normalized_obs) + return normalized_obs -- GitLab
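
Usage sketch (not part of the patch). The commit adds a dueling Double-DQN agent, its networks and a custom tree observation, but no driver that ties them together. The snippet below is a minimal, hypothetical wiring under the following assumptions: a Flatland 2.x RailEnv named `env` has been built elsewhere with `obs_builder_object=MyTreeObsForRailEnv(max_depth=TREE_DEPTH)`, and the episode count, epsilon schedule, seed and the net_type string "FC" are illustrative (any value other than "Conv" selects the fully connected dueling QNetwork). Because dueling_double_dqn.py defines Agent twice, the second definition is the one that ends up bound, so act() returns an (action, was_random) tuple.

from src.agent.dueling_double_dqn import Agent
from src.observations import normalize_observation

TREE_DEPTH = 2   # assumed observation tree depth
ACTION_SIZE = 5  # Flatland rail actions: nothing / left / forward / right / stop


def obs_vector_length(tree_depth: int) -> int:
    # normalize_observation prepends a 4-entry navigate_info block to the
    # flattened tree, which carries 2 features for every node of a 4-ary tree.
    n_nodes = (4 ** (tree_depth + 1) - 1) // 3
    return 4 + 2 * n_nodes


def run_episode(env, agent: Agent, eps: float) -> float:
    obs, info = env.reset()
    done = {"__all__": False}
    score = 0.0
    last_obs = {a: normalize_observation(obs[a], TREE_DEPTH)
                for a in range(env.get_num_agents()) if obs[a] is not None}

    while not done["__all__"]:
        action_dict = {}
        for a, state in last_obs.items():
            action, _ = agent.act(state, eps)  # second Agent.act returns (action, was_random)
            action_dict[a] = action

        obs, rewards, done, info = env.step(action_dict)

        for a, state in list(last_obs.items()):
            if obs[a] is None:
                continue
            next_state = normalize_observation(obs[a], TREE_DEPTH)
            # Agent.step stores the transition in the replay buffer and triggers
            # learning every UPDATE_EVERY calls, so the loop never calls learn() itself.
            agent.step(state, action_dict[a], rewards[a], next_state, done[a])
            last_obs[a] = next_state
            score += rewards[a]
    return score


# agent = Agent(state_size=obs_vector_length(TREE_DEPTH), action_size=ACTION_SIZE,
#               net_type="FC", seed=42)
# for episode in range(500):
#     run_episode(env, agent, eps=max(0.05, 0.998 ** episode))

For TREE_DEPTH = 2 the flattened state vector has 4 + 2 * 21 = 46 entries. The replay buffer, the soft target update (TAU) and the Double-DQN target computation are all handled inside Agent, which is why the driver only needs act() and step().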