Skip to content
Snippets Groups Projects
Commit c528e652 authored by Erik Nygren's avatar Erik Nygren
Browse files

minor bugfixes in training script

parent 8112097a
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv ...@@ -3,7 +3,7 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import * from flatland.utils.rendertools import *
from flatland.baselines.dueling_double_dqn import Agent from flatland.baselines.dueling_double_dqn import Agent
from collections import deque from collections import deque
import torch import torch,random
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
...@@ -18,22 +18,22 @@ transition_probability = [1.0, # empty cell - Case 0 ...@@ -18,22 +18,22 @@ transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 6 - symmetrical 1.0, # Case 6 - symmetrical
1.0] # Case 7 - dead end 1.0] # Case 7 - dead end
""" """
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
transition_probability = [1.0, # empty cell - Case 0 transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 1 - straight 1.0, # Case 1 - straight
0.5, # Case 2 - simple switch 1.0, # Case 2 - simple switch
0.2, # Case 3 - diamond drossing 0.3, # Case 3 - diamond drossing
0.5, # Case 4 - single slip 0.5, # Case 4 - single slip
0.1, # Case 5 - double slip 0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical 0.2, # Case 6 - symmetrical
0.01] # Case 7 - dead end 0.0] # Case 7 - dead end
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=20, env = RailEnv(width=7,
height=20, height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=10) number_of_agents=1)
env.reset()
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
handle = env.get_agent_handles() handle = env.get_agent_handles()
...@@ -51,28 +51,10 @@ dones_list = [] ...@@ -51,28 +51,10 @@ dones_list = []
agent = Agent(state_size, action_size, "FC", 0) agent = Agent(state_size, action_size, "FC", 0)
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
env = RailEnv(width=6,
height=2,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
env.agents_position[0] = [1, 4]
env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
# Reset environment # Reset environment
obs, all_rewards, done, _ = env.step({0: 0}) obs = env.reset()
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment