Skip to content
Snippets Groups Projects
Commit 6c8a30f9 authored by Egli Adrian (IT-SCI-API-PFI)
Browse files

PPO ?

parent 168d6728
No related branches found
No related tags found
No related merge requests found
File added
File added
File added
......@@ -9,7 +9,6 @@ from pprint import pprint
import numpy as np
import psutil
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
......@@ -78,6 +77,7 @@ def create_rail_env(env_params, tree_observation):
random_seed=seed
)
def train_agent(train_params, train_env_params, eval_env_params, obs_params):
# Environment parameters
n_agents = train_env_params.n_agents
......@@ -283,11 +283,23 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
agent_positions = get_agent_positions(train_env)
for agent_handle in train_env.get_agent_handles():
agent = train_env.agents[agent_handle]
act = action_dict.get(agent_handle, RailEnvActions.MOVE_FORWARD)
act = action_dict.get(agent_handle, RailEnvActions.DO_NOTHING)
if agent.status == RailAgentStatus.ACTIVE:
all_rewards[agent_handle] = 0.0
if done[agent_handle] == False:
if check_for_deadlock(agent_handle, train_env, agent_positions):
all_rewards[agent_handle] -= 1000.0
all_rewards[agent_handle] = -1.0
else:
pos = agent.position
possible_transitions = train_env.rail.get_transitions(*pos, agent.direction)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions < 2 and ((act != RailEnvActions.MOVE_FORWARD) or
(act != RailEnvActions.STOP_MOVING)):
all_rewards[agent_handle] = -0.5
else:
all_rewards[agent_handle] = -0.01
else:
all_rewards[agent_handle] = 1.0
step_timer.end()
......
......@@ -148,7 +148,10 @@ class PPOAgent(Policy):
reward_i = 1
else:
done_list.insert(0, 0)
reward_i = 0
if reward_i < -1:
reward_i = -1
else:
reward_i = 0
discounted_reward = reward_i + self.gamma * discounted_reward
reward_list.insert(0, discounted_reward)
state_next_list.insert(0, state_next_i)
......
......@@ -57,7 +57,7 @@ checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786
checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857
checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504
checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537
checkpoint = "./checkpoints/201212190452-6500.pth" # PPO: 13.944402986414723
checkpoint = "./checkpoints/201213181400-6800.pth" # PPO: 13.944402986414723
EPSILON = 0.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment