Commit 3165b6dd authored by nilabha

Merge branch 'cherry-pick-3b3f0c1d' into 'flatland-paper-baselines'

add explore to apex runs and update results

See merge request !20
parents 8abfda27 7838b8ae
Pipeline #5831 failed in 3 minutes and 2 seconds
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
...@@ -8,6 +8,7 @@ import os ...@@ -8,6 +8,7 @@ import os
import pickle import pickle
import shelve import shelve
from pathlib import Path from pathlib import Path
import random
import gym import gym
import numpy as np import numpy as np
...@@ -47,6 +48,11 @@ load_algorithms(CUSTOM_ALGORITHMS) # Load algorithms ...@@ -47,6 +48,11 @@ load_algorithms(CUSTOM_ALGORITHMS) # Load algorithms
from collections.abc import Mapping from collections.abc import Mapping
from copy import deepcopy from copy import deepcopy
# Final epsilon of the epsilon-greedy exploration schedule (RLlib DQN default)
# https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn.py
final_epsilon = 0.02
random.seed(1)
def val_replace(mapping): def val_replace(mapping):
obj = deepcopy(mapping) obj = deepcopy(mapping)
if isinstance(mapping, Mapping): if isinstance(mapping, Mapping):
...@@ -287,6 +293,7 @@ def run(args, parser): ...@@ -287,6 +293,7 @@ def run(args, parser):
if args.eager: if args.eager:
from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import enable_eager_execution
enable_eager_execution() enable_eager_execution()
config['eager'] = True
cls = get_trainable_cls(args.run) cls = get_trainable_cls(args.run)
...@@ -422,6 +429,13 @@ def rollout(agent, ...@@ -422,6 +429,13 @@ def rollout(agent,
policy_id=policy_id) policy_id=policy_id)
a_action = flatten_to_single_ndarray(a_action) # ray 0.8.5 a_action = flatten_to_single_ndarray(a_action) # ray 0.8.5
# a_action = _flatten_action(a_action) # tuple actions # ray 0.8.4 # a_action = _flatten_action(a_action) # tuple actions # ray 0.8.4
# Epsilon-greedy action selection for APEX
if hasattr(agent, '_name'):
if agent._name == "APEX":
if random.random() <= final_epsilon:
a_action = random.choice(np.arange(env.action_space.n))
action_dict[agent_id] = a_action action_dict[agent_id] = a_action
prev_actions[agent_id] = a_action prev_actions[agent_id] = a_action
action = action_dict action = action_dict
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment