Commit f8ef8acc authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

Merge branch '264_fixed_examples_branch_2' into 'master'

264 fixed examples branch 2

Closes #264

See merge request !251
parents 4ed208c4 5688a388
Pipeline #2670 passed with stages
in 34 minutes and 20 seconds
This diff is collapsed.
......@@ -24,28 +24,11 @@ def run_benchmark():
action_dict = dict()
action_prob = [0] * 4
def max_lt(seq, val):
"""
Return greatest item in seq for which item < val applies.
None is returned if seq was empty or all items in seq were >= val.
"""
idx = len(seq) - 1
while idx >= 0:
if seq[idx] < val and seq[idx] >= 0:
return seq[idx]
idx -= 1
return None
for trials in range(1, n_trials + 1):
# Reset environment
obs, info = env.reset()
for a in range(env.get_num_agents()):
norm = max(1, max_lt(obs[a], np.inf))
obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
# Run episode
for step in range(100):
# Action
......@@ -56,9 +39,6 @@ def run_benchmark():
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()):
norm = max(1, max_lt(next_obs[a], np.inf))
next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
if done['__all__']:
break
......
......@@ -35,7 +35,11 @@ class SingleAgentNavigationObs(ObservationBuilder):
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
if agent.position:
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
else:
possible_transitions = self.env.rail.get_transitions(*agent.initial_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
......
......@@ -7,7 +7,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
from flatland.envs.schedule_generators import ScheduleGenerator, compute_max_episode_steps
from flatland.envs.schedule_generators import ScheduleGenerator
from flatland.envs.schedule_utils import Schedule
from flatland.utils.rendertools import RenderTool
......@@ -31,15 +31,14 @@ def custom_rail_generator() -> RailGenerator:
def custom_schedule_generator() -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> Schedule:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None,
num_resets: int = 0) -> Schedule:
agents_positions = []
agents_direction = []
agents_target = []
speeds = []
max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height)
return Schedule(agent_positions=agents_positions, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None,
max_episode_steps=max_episode_steps)
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
return generator
......
import random
import time
from typing import List
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
random.seed(1)
np.random.seed(1)
class SingleAgentNavigationObs(ObservationBuilder):
"""
We build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__()
def reset(self):
pass
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent.position, direction)
min_distances.append(self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
env = RailEnv(width=14,
height=14,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=1),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs())
obs, info = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=True, frames=True, show_observations=False)
for step in range(100):
actions = {}
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
if step % 5 == 0:
print("Agent halts")
actions[0] = 4 # Halt
obs, all_rewards, done, _ = env.step(actions)
if env.agents[0].malfunction_data['malfunction'] > 0:
print("Agent 0 broken-ness: ", env.agents[0].malfunction_data['malfunction'])
env_renderer.render_env(show=True, frames=True, show_observations=False)
time.sleep(0.5)
if done["__all__"]:
break
env_renderer.close_window()
......@@ -729,7 +729,6 @@ class RailEnv(Environment):
(*agent.position, agent.direction),
new_direction)
# only call cell_free() if new cell is inside the scene
if new_cell_valid:
# Check the new position is not the same as any of the existing agent positions
......
......@@ -114,7 +114,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
step, a, False, info_dict['action_required'][a])
if replay.set_malfunction is not None:
# As we force malfunctions on the agents we have to set a positive rate that the env
# recognizes the agent as potentially malfuncitoning
# We also set next malfunction to infitiy to avoid interference with our tests
agent.malfunction_data['malfunction'] = replay.set_malfunction
agent.malfunction_data['malfunction_rate'] = max(agent.malfunction_data['malfunction_rate'], 1)
agent.malfunction_data['next_malfunction'] = np.inf
agent.malfunction_data['moving_before_malfunction'] = agent.moving
_assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
print(step)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment