Commit dae5e483 authored by Erik Nygren

minor updates

parent a13d54e8
@@ -13,15 +13,15 @@ np.random.seed(1)
 transition_probability = [0.5,  # empty cell - Case 0
                           1.0,  # Case 1 - straight
                           1.0,  # Case 2 - simple switch
-                          0.3,  # Case 3 - diamond drossing
+                          0.3,  # Case 3 - diamond crossing
                           0.5,  # Case 4 - single slip
                           0.5,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
                           0.0]  # Case 7 - dead end

 # Example generate a random rail
-env = RailEnv(width=7,
-              height=7,
+env = RailEnv(width=20,
+              height=20,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=1)
 env_renderer = RenderTool(env)
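Note: the first hunk fixes the "diamond crossing" comment typo and scales the random-rail example from a 7x7 to a 20x20 grid, keeping the same cell-type weights. The list passed as cell_type_relative_proportion is a set of relative weights rather than normalized probabilities; the sketch below is illustrative only (not the generator's actual code) and shows one way such weights map to a sampling distribution over the eight cell cases.

import numpy as np

transition_probability = [0.5, 1.0, 1.0, 0.3, 0.5, 0.5, 0.2, 0.0]
p = np.asarray(transition_probability, dtype=float)
p = p / p.sum()                                       # normalize weights to probabilities
cases = np.random.choice(len(p), size=20 * 20, p=p)   # one cell case per grid cell
print(np.bincount(cases, minlength=len(p)))           # rough frequency of each case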
@@ -29,7 +29,7 @@ handle = env.get_agent_handles()
 state_size = 105
 action_size = 4
-n_trials = 9999
+n_trials = 15000
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.998
@@ -40,19 +40,34 @@ scores = []
 dones_list = []
 action_prob = [0]*4
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+demo = True


 def max_lt(seq, val):
     """
     Return greatest item in seq for which item < val applies.
     None is returned if seq was empty or all items in seq were >= val.
     """
+    max = 0
     idx = len(seq)-1
     while idx >= 0:
-        if seq[idx] < val and seq[idx] >= 0:
-            return seq[idx]
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
         idx -= 1
-    return None
+    return max
+
+
+def min_lt(seq, val):
+    """
+    Return smallest item in seq for which item > val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    min = np.inf
+    idx = len(seq)-1
+    while idx >= 0:
+        if seq[idx] > val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min


 for trials in range(1, n_trials + 1):
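Note: with this change max_lt scans the whole sequence for the largest non-negative item strictly below val (falling back to 0 when nothing qualifies, despite what the carried-over docstring says), and the new min_lt returns the smallest item strictly above val (falling back to np.inf). A quick check of the new behaviour, assuming the definitions above and numpy imported as np:

seq = [3, -1, 7, 4, 9]

print(max_lt(seq, 8))    # 7   - largest non-negative item below 8
print(max_lt(seq, 0))    # 0   - nothing qualifies, falls back to 0
print(min_lt(seq, 4))    # 7   - smallest item above 4
print(min_lt(seq, 100))  # inf - nothing qualifies, falls back to np.inf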
@@ -69,12 +84,14 @@ for trials in range(1, n_trials + 1):

     # Run episode
     for step in range(50):
-        #if trials > 114:
-        env_renderer.renderEnv(show=True)
+        if demo:
+            env_renderer.renderEnv(show=True)
         #print(step)

         # Action
         for a in range(env.number_of_agents):
-            action = agent.act(np.array(obs[a]), eps=0)
+            if demo:
+                eps = 0
+            action = agent.act(np.array(obs[a]), eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
...
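Note: this hunk replaces the hard-coded eps=0 and the commented-out render guard with a single demo flag: in demo mode the environment is rendered every step and exploration is switched off, while in training mode the decaying eps from the header is used. The snippet below is an illustrative stand-in for the epsilon-greedy rule that agent.act(obs, eps=...) is expected to apply; the Agent class itself lives in the flatland baselines and its internals are not part of this commit.

import numpy as np

def epsilon_greedy(q_values, eps):
    # With probability eps pick a uniformly random action, otherwise the greedy one.
    if np.random.rand() < eps:
        return np.random.randint(len(q_values))
    return int(np.argmax(q_values))

# demo = True forces eps = 0 (always the greedy, learned action); during training
# eps starts at 1.0 and is typically decayed as eps = max(eps_end, eps * eps_decay).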
@@ -649,7 +649,8 @@ class RailEnv(Environment):
             # if agent is not in target position, add step penalty
             if self.agents_position[i][0] == self.agents_target[i][0] and \
-                    self.agents_position[i][1] == self.agents_target[i][1]:
+                    self.agents_position[i][1] == self.agents_target[i][1] and \
+                    action_dict[handle] == 0:
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty
...
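Note: the final hunk, inside RailEnv in the environment code, tightens the done check: an agent standing on its target cell is now only marked done if it also issues action 0 (presumably the "do nothing"/stop action in this 4-action setup); otherwise it keeps accumulating the step penalty. A minimal sketch of the resulting control flow follows; variable names mirror the diff, but the wrapper function and the step_penalty default are placeholders, not the class's actual method.

def update_done_and_reward(position, target, action, dones, rewards, handle,
                           step_penalty=-1.0):
    # position/target are (row, col) tuples; the step_penalty value is a placeholder.
    at_target = position[0] == target[0] and position[1] == target[1]
    if at_target and action == 0:
        # reached the target and explicitly stopped: episode finished for this agent
        dones[handle] = True
    else:
        # anything else keeps collecting the per-step penalty
        rewards[handle] += step_penalty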