Skip to content
Snippets Groups Projects
Commit 0cb266c0 authored by nimishsantosh107's avatar nimishsantosh107
Browse files

added departure and arrival properties to agent

parent db58de1a
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@ from enum import IntEnum
from itertools import starmap
from typing import Tuple, Optional, NamedTuple
from attr import attrs, attrib, Factory
from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.schedule_utils import Schedule
......@@ -20,6 +20,8 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('direction', Grid4TransitionsEnum),
('target', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
('speed_data', dict),
('malfunction_data', dict),
('handle', int),
......@@ -37,6 +39,10 @@ class EnvAgent:
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# NEW - time scheduling
earliest_departure = attrib(default=None, type=int) # default None during _from_schedule()
latest_arrival = attrib(default=None, type=int) # default None during _from_schedule()
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
......@@ -82,8 +88,9 @@ class EnvAgent:
self.malfunction_data['moving_before_malfunction'] = False
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction,
direction=self.direction, target=self.target, moving=self.moving, speed_data=self.speed_data,
return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction,
direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival, speed_data=self.speed_data,
malfunction_data=self.malfunction_data, handle=self.handle, status=self.status,
position=self.position, old_direction=self.old_direction, old_position=self.old_position)
......@@ -109,8 +116,10 @@ class EnvAgent:
return list(starmap(EnvAgent, zip(schedule.agent_positions,
schedule.agent_directions,
schedule.agent_directions,
schedule.agent_targets,
[False] * len(schedule.agent_positions),
schedule.agent_targets,
[False] * len(schedule.agent_positions),
[None] * len(schedule.agent_positions), # earliest_departure
[None] * len(schedule.agent_positions), # latest_arrival
speed_datas,
malfunction_datas,
range(len(schedule.agent_positions)))))
......
......@@ -34,7 +34,8 @@ from gym.utils import seeding
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
# NEW
from flatland.envs.schedule_time_generators import schedule_time_generator
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
......@@ -390,6 +391,10 @@ class RailEnv(Environment):
# Reset agents to initial
self.reset_agents()
self.distance_map.reset(self.agents, self.rail)
# NEW - time window scheduling
schedule_time_generator(self.agents, self.distance_map, schedule, self.np_random, temp_info=optionals)
for agent in self.agents:
# Induce malfunctions
......@@ -412,7 +417,6 @@ class RailEnv(Environment):
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
self.distance_map.reset(self.agents, self.rail)
# Reset the malfunction generator
if "generate" in dir(self.malfunction_generator):
......
import os
import json
import itertools
import warnings
from typing import Tuple, List, Callable, Mapping, Optional, Any
import numpy as np
from numpy.core.fromnumeric import shape
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.schedule_utils import Schedule
from flatland.envs.distance_map import DistanceMap
# #### DATA COLLECTION *************************
# import termplotlib as tpl
# import matplotlib.pyplot as plt
# root_path = 'C:\\Users\\nimish\\Programs\\AIcrowd\\flatland\\flatland\\playground'
# dir_name = 'TEMP'
# os.mkdir(os.path.join(root_path, dir_name))
# # Histogram 1
# dist_resolution = 50
# schedule_dist = np.zeros(shape=(dist_resolution))
# # Volume dist
# route_dist = None
# # Dist - shortest path
# shortest_paths_len_dist = []
# # City positions
# city_positions = []
# #### DATA COLLECTION *************************
def schedule_time_generator(agents: List[EnvAgent], distance_map: DistanceMap, schedule: Schedule,
np_random: RandomState = None, temp_info=None) -> None:
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
shortest_paths = get_shortest_paths(distance_map)
max_episode_steps = int(schedule.max_episode_steps * 1.0) #needs to be increased due to fractional speeds taking way longer (best - be calculated here)
end_buffer = max_episode_steps // 20 #schedule.end_buffer
latest_arrival_max = max_episode_steps-end_buffer
travel_buffer_multiplier = 1.7
earliest_departures = []
latest_arrivals = []
# #### DATA COLLECTION *************************
# # Create info.txt
# with open(os.path.join(root_path, dir_name, 'INFO.txt'), 'w') as f:
# f.write('COPY FROM main.py')
# # Volume dist
# route_dist = np.zeros(shape=(max_episode_steps, distance_map.rail.width, distance_map.rail.height), dtype=np.int8)
# # City positions
# # Dummy distance map for shortest path pairs between cities
# city_positions = temp_info['agents_hints']['city_positions']
# d_rail = distance_map.rail
# d_dmap = DistanceMap([], d_rail.height, d_rail.width)
# d_city_permutations = list(itertools.permutations(city_positions, 2))
# d_positions = []
# d_targets = []
# for position, target in d_city_permutations:
# d_positions.append(position)
# d_targets.append(target)
# d_schedule = Schedule(d_positions,
# [0] * len(d_positions),
# d_targets,
# [1.0] * len(d_positions),
# [None] * len(d_positions),
# 1000)
# d_agents = EnvAgent.from_schedule(d_schedule)
# d_dmap.reset(d_agents, d_rail)
# d_map = d_dmap.get()
# d_data = {
# 'city_positions': city_positions,
# 'start': d_positions,
# 'end': d_targets,
# }
# with open(os.path.join(root_path, dir_name, 'city_data.json'), 'w') as f:
# json.dump(d_data, f)
# with open(os.path.join(root_path, dir_name, 'distance_map.npy'), 'wb') as f:
# np.save(f, d_map)
# #### DATA COLLECTION *************************
for agent in agents:
agent_speed = agent.speed_data['speed']
agent_shortest_path = shortest_paths[agent.handle]
agent_shortest_path_len = len(agent_shortest_path)
agent_shortest_path_time = int(np.ceil(agent_shortest_path_len / agent_speed)) # for fractional speeds 1/3 etc
agent_travel_time_max = min( int(np.ceil(agent_shortest_path_time * travel_buffer_multiplier)), latest_arrival_max) # min(this, latest_arrival_max), SHOULD NOT BE lesser than shortest path time
departure_window_max = latest_arrival_max - agent_travel_time_max
earliest_departure = np_random.randint(0, departure_window_max)
latest_arrival = earliest_departure + agent_travel_time_max
earliest_departures.append(earliest_departure)
latest_arrivals.append(latest_arrival)
agent.earliest_departure = earliest_departure
agent.latest_arrival = latest_arrival
# #### DATA COLLECTION *************************
# # Histogram 1
# dist_bounds = get_dist_window(earliest_departure, latest_arrival, latest_arrival_max)
# schedule_dist[dist_bounds[0]: dist_bounds[1]] += 1
# # Volume dist
# for waypoint in agent_shortest_path:
# pos = waypoint.position
# route_dist[earliest_departure:latest_arrival, pos[0], pos[1]] += 1
# # Dist - shortest path
# shortest_paths_len_dist.append(agent_shortest_path_len)
# np.save(os.path.join(root_path, dir_name, 'volume.npy'), route_dist)
# shortest_paths_len_dist.sort()
# save_sp_fig()
# #### DATA COLLECTION *************************
# #### DATA COLLECTION *************************
# # Histogram 1
# def get_dist_window(departure_t, arrival_t, latest_arrival_max):
# return (int(np.round(np.interp(departure_t, [0, latest_arrival_max], [0, dist_resolution]))),
# int(np.round(np.interp(arrival_t, [0, latest_arrival_max], [0, dist_resolution]))))
# def plot_dist():
# counts, bin_edges = schedule_dist, [i for i in range(0, dist_resolution+1)]
# fig = tpl.figure()
# fig.hist(counts, bin_edges, orientation="horizontal", force_ascii=False)
# fig.show()
# # Shortest path dist
# def save_sp_fig():
# fig = plt.figure(figsize=(15, 7))
# plt.bar(np.arange(len(shortest_paths_len_dist)), shortest_paths_len_dist)
# plt.savefig(os.path.join(root_path, dir_name, 'shortest_paths_sorted.png'))
# #### DATA COLLECTION *************************
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment