Skip to content
Snippets Groups Projects
Commit 0b2b56f6 authored by hagrid67's avatar hagrid67
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland

parents 01fb10f8 72c69df1
No related branches found
No related tags found
No related merge requests found
...@@ -12,19 +12,23 @@ np.random.seed(1) ...@@ -12,19 +12,23 @@ np.random.seed(1)
# Example generate a rail given a manual specification, # Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation) # a map of tuples (cell_type, rotation)
transition_probability = [5, # empty cell - Case 0 transition_probability = [5, # empty cell - Case 0
15, # Case 1 - straight 1, # Case 1 - straight
5, # Case 2 - simple switch 5, # Case 2 - simple switch
1, # Case 3 - diamond crossing 1, # Case 3 - diamond crossing
1, # Case 4 - single slip 1, # Case 4 - single slip
1, # Case 5 - double slip 1, # Case 5 - double slip
1, # Case 6 - symmetrical 1, # Case 6 - symmetrical
0] # Case 7 - dead end 0, # Case 7 - dead end
15, # Case 1b (8) - simple turn right
15, # Case 1c (9) - simple turn left
15] # Case 2b (10) - simple switch mirrored
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=10, env = RailEnv(width=10,
height=10, height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1) number_of_agents=3)
""" """
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
...@@ -57,7 +61,7 @@ action_prob = [0] * 4 ...@@ -57,7 +61,7 @@ action_prob = [0] * 4
agent = Agent(state_size, action_size, "FC", 0) agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth')) agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth'))
demo = True demo = False
def max_lt(seq, val): def max_lt(seq, val):
......
...@@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
transitions_templates_ = [] transitions_templates_ = []
transition_probabilities = [] transition_probabilities = []
for i in range(len(t_utils.transitions) - 4): # don't include dead-ends for i in range(len(t_utils.transitions)-4): # don't include dead-ends
all_transitions = 0 all_transitions = 0
for dir_ in range(4): for dir_ in range(4):
trans = t_utils.get_transitions(t_utils.transitions[i], dir_) trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment