diff --git a/env-data/tests/test1.npy b/env-data/tests/test1.npy index 77e0288589171b8b03d828423ca456f2ac8395e3..f0cff3c9a1260facf073b88702da3f0557ab32f0 100644 Binary files a/env-data/tests/test1.npy and b/env-data/tests/test1.npy differ diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 41dcf7795b10125f89f1719bcba873eb44f9f105..2704e84463279638c6e86ae8c3259c5653b9c508 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -34,7 +34,7 @@ env = RailEnv(width=20, env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - ['../env-data/tests/train_simple.npy']), + ['../env-data/tests/test1.npy']), number_of_agents=1) @@ -109,7 +109,7 @@ for trials in range(1, n_trials + 1): for a in range(env.number_of_agents): if demo: eps = 0 - action = agent.act(np.array(obs[a]), eps=eps) + action = 2# agent.act(np.array(obs[a]), eps=eps) action_prob[action] += 1 action_dict.update({a: action}) # Environment step diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index bb754ea9e21a8b785efa92c557569611ed13e504..9b6ee96ea11c8e8bed01aee20e8bccc9ced15728 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -1024,7 +1024,7 @@ class RailEnv(Environment): is_deadend = False if action == 2: - # TODO: Check if cell is curve --> Compute correct transition and change in orientation + # compute number of possible transitions in the current # cell nbits = 0 @@ -1054,20 +1054,23 @@ class RailEnv(Environment): movement = reverse_direction is_deadend = True if nbits == 2: - # straigt or curve + # Checking for curves + valid_transition = self.rail.get_transition( (pos[0], pos[1], direction), movement) reverse_direction = (direction + 2) % 4 curv_dir = (movement + 1) % 4 while not valid_transition: - if curv_dir != reverse_direction: - valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - curv_dir) - curv_dir = (curv_dir+1) % 4 - if valid_transition: - movement = curv_dir + if curv_dir != reverse_direction: + valid_transition = self.rail.get_transition( + (pos[0], pos[1], direction), + curv_dir) + if valid_transition: + print("Curve") + movement = curv_dir + curv_dir = (curv_dir + 1) % 4 + new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is diff --git a/notebooks/CanvasEditor.ipynb b/notebooks/CanvasEditor.ipynb index 159b85aa988e9897a8494e103145c39cb355808c..985671617e89b43130fd0cf1f13ba83edf842b45 100644 --- a/notebooks/CanvasEditor.ipynb +++ b/notebooks/CanvasEditor.ipynb @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -79,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -138,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -147,7 +147,7 @@ "10" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -165,7 +165,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -182,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -222,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -249,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -302,7 +302,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": { "scrolled": false }, @@ -310,7 +310,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "15e10f8984c44a529cd86cc57c697c7f", + "model_id": "0fe2a192b21c4eebb7aac3af8f3cb02e", "version_major": 2, "version_minor": 0 },