Commit f4fca1d5, authored 4 years ago by Egli Adrian (IT-SCI-API-PFI)
Parent: c12f806e

clean up code - simplified

Showing 2 changed files, with 12 additions and 7 deletions:

  reinforcement_learning/multi_agent_training.py  (+5, -5)
  run.py                                          (+7, -2)
reinforcement_learning/multi_agent_training.py (+5, -5)

@@ -208,7 +208,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
+    if True:
         policy = PPOAgent(state_size, action_size, n_agents)

     # Load existing policy
     if train_params.load_policy is not "":
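Flipping the guard from "if False:" to "if True:" means the freshly constructed DDDQNPolicy is now always discarded and training proceeds with a PPOAgent instead. A minimal sketch of the resulting selection logic, assuming the imports used elsewhere in this repository and treating state_size, action_size, n_agents and train_params as values computed earlier in train_agent (the USE_PPO name is a hypothetical stand-in for the hard-coded literal):

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo.ppo_agent import PPOAgent

USE_PPO = True  # hypothetical flag; the commit hard-codes `if True:`

def pick_policy(state_size, action_size, n_agents, train_params):
    # DDDQN is built first, exactly as in the diff ...
    policy = DDDQNPolicy(state_size, action_size, train_params)
    if USE_PPO:
        # ... and immediately replaced by PPO when the guard is on.
        policy = PPOAgent(state_size, action_size, n_agents)
    return policy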
@@ -546,10 +546,10 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=10000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
@@ -573,7 +573,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
     training_params = parser.parse_args()
     env_params = [
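Taken together, the argparse changes raise the default episode count from 2000 to 10000, switch both environment configs from Test_1 to Test_0, and deepen the tree observation from 1 to 2. Launching the trainer with no flags is now equivalent to this explicit invocation (a usage sketch assembled from the defaults in the diff; options not shown keep their existing defaults):

python reinforcement_learning/multi_agent_training.py -n 10000 -t 0 -e 0 --n_evaluation_episodes 5 --checkpoint_interval 100 --max_depth 2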
run.py (+7, -2)

@@ -30,6 +30,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
 from utils.fast_tree_obs import FastTreeObs
@@ -46,12 +47,14 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
+USE_PPO_AGENT = True

 # Checkpoint to use (remember to push it!)
 checkpoint = "./checkpoints/201124171810-7800.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126150143-5200.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 # checkpoint = "./checkpoints/201126160144-2000.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
 checkpoint = "./checkpoints/201127160352-2000.pth"
+checkpoint = "./checkpoints/201130083154-2000.pth"

 EPSILON = 0.005
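Note that checkpoint is rebound several times in a row, so only the last assignment matters: the added line makes ./checkpoints/201130083154-2000.pth the checkpoint that actually gets loaded, silently shadowing the 201127160352 path on the line above it. A minimal, self-contained illustration of the pattern:

# Successive rebindings: only the final path is used downstream.
checkpoint = "./checkpoints/201127160352-2000.pth"  # shadowed by the next line
checkpoint = "./checkpoints/201130083154-2000.pth"  # this one wins
print(checkpoint)  # ./checkpoints/201130083154-2000.pth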
@@ -99,8 +102,10 @@ else:
 action_size = 5

 # Creates the policy. No GPU on evaluation server.
-policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
+if not USE_PPO_AGENT:
+    # policy = PPOAgent(state_size, action_size, 10)
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
+else:
+    policy = PPOAgent(state_size, action_size, 10)
 policy.load(checkpoint)

 #####################################################################
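With this hunk the evaluation client chooses its policy from the USE_PPO_AGENT flag defined above, then loads the selected checkpoint; both branches therefore need a class whose load() accepts the .pth file being shipped. A hypothetical smoke test, not part of this commit, built only from constructors and calls that appear in the diff (the helper name and default arguments are assumptions):

from argparse import Namespace
from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo.ppo_agent import PPOAgent

def smoke_test_load(checkpoint_path, state_size, action_size, use_ppo=True):
    """Fail fast if the chosen policy class cannot read the checkpoint."""
    if use_ppo:
        policy = PPOAgent(state_size, action_size, 10)  # 10 agents, as hard-coded in run.py
    else:
        # No GPU on the evaluation server, hence use_gpu=False.
        policy = DDDQNPolicy(state_size, action_size,
                             Namespace(**{'use_gpu': False}), evaluation_mode=True)
    policy.load(checkpoint_path)  # raises if the file or format does not match
    return policy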