Commit 993e6a9d authored by Erik Nygren's avatar Erik Nygren
Browse files

Bug fix in handling curve. Now you can pass a curve by going straigt!

parent 8de88772
Pipeline #431 passed with stage
in 2 minutes and 37 seconds
No preview for this file type
......@@ -34,7 +34,7 @@ env = RailEnv(width=20,
env = RailEnv(width=20,
height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
['../env-data/tests/train_simple.npy']),
['../env-data/tests/test1.npy']),
number_of_agents=1)
......@@ -109,7 +109,7 @@ for trials in range(1, n_trials + 1):
for a in range(env.number_of_agents):
if demo:
eps = 0
action = agent.act(np.array(obs[a]), eps=eps)
action = 2# agent.act(np.array(obs[a]), eps=eps)
action_prob[action] += 1
action_dict.update({a: action})
# Environment step
......
......@@ -1024,7 +1024,7 @@ class RailEnv(Environment):
is_deadend = False
if action == 2:
# TODO: Check if cell is curve --> Compute correct transition and change in orientation
# compute number of possible transitions in the current
# cell
nbits = 0
......@@ -1054,20 +1054,23 @@ class RailEnv(Environment):
movement = reverse_direction
is_deadend = True
if nbits == 2:
# straigt or curve
# Checking for curves
valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction),
movement)
reverse_direction = (direction + 2) % 4
curv_dir = (movement + 1) % 4
while not valid_transition:
if curv_dir != reverse_direction:
valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction),
curv_dir)
curv_dir = (curv_dir+1) % 4
if valid_transition:
movement = curv_dir
if curv_dir != reverse_direction:
valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction),
curv_dir)
if valid_transition:
print("Curve")
movement = curv_dir
curv_dir = (curv_dir + 1) % 4
new_position = self._new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
......
......@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
......@@ -79,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
......@@ -102,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
......@@ -116,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
......@@ -129,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
......@@ -138,7 +138,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
......@@ -147,7 +147,7 @@
"10"
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
......@@ -165,7 +165,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
......@@ -182,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"outputs": [
{
......@@ -222,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [
{
......@@ -249,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
......@@ -302,7 +302,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"metadata": {
"scrolled": false
},
......@@ -310,7 +310,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "15e10f8984c44a529cd86cc57c697c7f",
"model_id": "0fe2a192b21c4eebb7aac3af8f3cb02e",
"version_major": 2,
"version_minor": 0
},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment