Skip to content
Snippets Groups Projects
Commit 1169fcb6 authored by Erik Nygren
Browse files

fixed formatting

parent 36714e2c
No related branches found
No related tags found
No related merge requests found
......@@ -115,6 +115,7 @@ for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset()
final_obs = obs.copy()
final_obs_next = obs.copy()
for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
data = norm_obs_clip(data)
......@@ -150,7 +151,8 @@ for trials in range(1, n_trials + 1):
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
if done[a]:
final_obs[a] = obs[a]
final_obs[a] = obs[a].copy()
final_obs_next[a] = next_obs[a].copy()
final_action_dict.update({a: action_dict[a]})
if not demo and not done[a]:
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
......@@ -159,7 +161,7 @@ for trials in range(1, n_trials + 1):
obs = next_obs.copy()
if done['__all__']:
env_done = 1
agent.step(final_obs[a], final_action_dict[a], all_rewards[a], next_obs[a], done[a])
agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
break
# Epsilon decay
eps = max(eps_end, eps_decay * eps) # decrease epsilon
......
......@@ -139,8 +139,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for neigh_direction in possible_directions:
new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
new_cell[1] >= 0 and new_cell[1] < self.env.width:
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
......
......@@ -143,7 +143,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99
if len(new_path) >= 2:
nr_created += 1
#print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
# print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
# print(start_goal)
agents_position = [sg[0] for sg in start_goal]
......
......@@ -300,7 +300,7 @@ class RailEnv(Environment):
# if num_agents_in_target_position == self.number_of_agents:
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
self.dones["__all__"] = True
self.rewards_dict = [r + global_reward for r in self.rewards_dict]
self.rewards_dict = [0*r+global_reward for r in self.rewards_dict]
# Reset the step actions (in case some agent doesn't 'register_action'
# on the next step)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment