import logging
import random
from typing import NamedTuple

import numpy as np

from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import (MalfunctionParameters, ParamMalfunctionGen,
                                                  malfunction_from_params, no_malfunction_generator)
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator

# NOTE(review): this deliberately shadows the MalfunctionParameters imported above,
# presumably as a compatibility shim across flatland versions — TODO confirm which
# definition downstream callers actually expect.
MalfunctionParameters = NamedTuple('MalfunctionParameters',
                                   [('malfunction_rate', float),
                                    ('min_duration', int),
                                    ('max_duration', int)])


def get_shortest_path_action(env, handle):
    """Return the greedy shortest-path action for agent `handle`.

    Actions follow flatland's relative-move convention: 1 = turn left,
    2 = move forward, 3 = turn right.

    Parameters
    ----------
    env : RailEnv
        Environment providing `distance_map`, `rail` and `agents`.
    handle : int
        Index of the agent in `env.agents`.

    Returns
    -------
    int or None
        The action whose next cell is closest to the agent's target according
        to the precomputed distance map, or None when the agent has no usable
        position (e.g. DONE_REMOVED) or no outgoing transition.
    """
    distance_map = env.distance_map.get()
    agent = env.agents[handle]

    # Resolve the cell the agent (virtually) occupies for planning purposes.
    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    if agent.position:
        possible_transitions = env.rail.get_transitions(*agent.position, agent.direction)
    else:
        possible_transitions = env.rail.get_transitions(*agent.initial_position, agent.direction)

    # Preserve the original contract: no outgoing transition at all -> None.
    if np.count_nonzero(possible_transitions) == 0:
        return None

    # Distance-to-target for each relative move (left, forward, right);
    # blocked directions get +inf so they can never be selected.
    min_distances = []
    for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
        if possible_transitions[direction]:
            new_position = get_new_position(agent_virtual_position, direction)
            min_distances.append(distance_map[handle, new_position[0], new_position[1], direction])
        else:
            min_distances.append(np.inf)

    # BUGFIX: the original (a) never returned when exactly one transition was
    # possible (implicitly returning None) and (b) used np.argpartition(..., 2)[0]
    # for two transitions, which does NOT guarantee the minimum — argpartition
    # with kth=2 only pins position 2. np.argmin handles 1, 2 or 3 candidate
    # moves uniformly and correctly.
    return int(np.argmin(min_distances)) + 1


def small_v0(random_seed, observation_builder, max_width=35, max_height=35):
    """Build a small sparse RailEnv: 5 trains, up to 4 cities, no malfunctions.

    Starts at 25x25; whenever flatland's sparse generator cannot fit the
    requested layout (raises ValueError), the grid is grown by 5 in each
    dimension and generation is retried, up to (max_width, max_height).

    Parameters
    ----------
    random_seed : int
        Seed for both `random` and the rail generator.
    observation_builder : ObservationBuilder
        Observation builder passed through to RailEnv.
    max_width, max_height : int, optional
        Upper bounds for the retry loop.

    Returns
    -------
    RailEnv or None
        The generated environment, or None if no size up to the maximum worked.
    """
    random.seed(random_seed)
    width = 25
    height = 25
    nr_trains = 5
    max_num_cities = 4
    max_rails_between_cities = 2
    max_rails_in_city = 3

    # Malfunctions are disabled for this env; the rates are kept as named
    # values only so the summary line below stays explicit.
    malfunction_rate = 0
    malfunction_min_duration = 0
    malfunction_max_duration = 0

    rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed,
                                           grid_mode=False,
                                           max_rails_between_cities=max_rails_between_cities,
                                           max_rails_in_city=max_rails_in_city)

    # speed_ratio_map=None -> flatland's default (all trains at speed 1).
    schedule_generator = sparse_schedule_generator(None)

    malfunction_generator = no_malfunction_generator()

    while width <= max_width and height <= max_height:
        try:
            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
                          schedule_generator=schedule_generator, number_of_agents=nr_trains,
                          malfunction_generator_and_process_data=malfunction_generator,
                          obs_builder_object=observation_builder, remove_agents_at_target=False)

            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, "
                  "max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
                      random_seed, width, height, max_num_cities, nr_trains,
                      max_rails_between_cities, max_rails_in_city, malfunction_rate,
                      malfunction_min_duration, malfunction_max_duration))

            return env
        except ValueError as e:
            logging.error("Error: %s", e)
            width += 5
            height += 5
            # BUGFIX: the original passed width/height as extra positional args
            # with no %s placeholders, which makes the logging call itself fail.
            logging.info("Trying again with a larger env: (w, h) = (%s, %s)", width, height)

    # BUGFIX: the original f-string interpolated max_height into the
    # max_width= field of this message.
    logging.error("Unable to generate env with seed=%s, max_width=%s, max_height=%s",
                  random_seed, max_width, max_height)
    return None
def perc_completion(env):
    """Return the percentage (0.0-100.0) of agents that reached their target.

    Works for a raw RailEnv (agents live in ``env.agents``) as well as for
    wrapper environments that expose ``env.agents_data``.

    Parameters
    ----------
    env : RailEnv or wrapper
        Environment whose agents are inspected.

    Returns
    -------
    float
        100 * finished_agents / total_agents (0.0 when there are no agents).
    """
    agent_data = env.agents if isinstance(env, RailEnv) else env.agents_data

    # ROBUSTNESS: depending on remove_agents_at_target, a finished agent is in
    # DONE or DONE_REMOVED — count both so completion is not under-reported.
    tasks_finished = sum(
        1 for agent in agent_data
        if agent.status in (RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED)
    )

    # BUGFIX/IDIOM: np.mean of a scalar was a no-op; plain division suffices.
    # max(1, n) guards the empty-agent-list case.
    return 100.0 * tasks_finished / max(1, len(agent_data))
self.action_dict[get_agent_handle(agent)] = action - if self._reset_next_step: - return self.reset() if self.dones[agent]: self.agents.remove(agent) - if not self.env_done(): - self.agent_selection = self._agent_selector.next() - return self.last() + # self.agent_selection = self._agent_selector.next() + # self.agents.remove(agent) + # return self.last() if self._agent_selector.is_last(): observations, rewards, dones, infos = self._environment.step(self.action_dict) @@ -185,10 +183,12 @@ class raw_env(AECEnv, gym.Env): # self._cumulative_rewards[agent] = 0 self._accumulate_rewards() + + obs, cumulative_reward, done, info = self.last() self.agent_selection = self._agent_selector.next() - return self.last() + return obs, cumulative_reward, done, info # if self._agent_selector.is_last(): # self._agent_selector.reinit(self.agents)