Commit 3165b6dd authored by nilabha

Merge branch 'cherry-pick-3b3f0c1d' into 'flatland-paper-baselines'

add explore to apex runs and update results

See merge request !20
parents 8abfda27 7838b8ae
Pipeline #5831 failed in 3 minutes and 2 seconds
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
......@@ -8,6 +8,7 @@ import os
import pickle
import shelve
from pathlib import Path
import random
import gym
import numpy as np
......@@ -47,6 +48,11 @@ load_algorithms(CUSTOM_ALGORITHMS) # Load algorithms
from collections.abc import Mapping
from copy import deepcopy
# Default final exploration epsilon (the value epsilon-greedy exploration anneals to in RLlib's DQN)
# https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn.py
final_epsilon = 0.02
random.seed(1)
def val_replace(mapping):
obj = deepcopy(mapping)
if isinstance(mapping, Mapping):
......@@ -287,6 +293,7 @@ def run(args, parser):
if args.eager:
from tensorflow.python.framework.ops import enable_eager_execution
enable_eager_execution()
config['eager'] = True
cls = get_trainable_cls(args.run)
......@@ -422,6 +429,13 @@ def rollout(agent,
policy_id=policy_id)
a_action = flatten_to_single_ndarray(a_action) # ray 0.8.5
# a_action = _flatten_action(a_action) # tuple actions # ray 0.8.4
# Epsilon-greedy action selection for APEX
if hasattr(agent, '_name'):
if agent._name == "APEX":
if random.random() <= final_epsilon:
a_action = random.choice(np.arange(env.action_space.n))
action_dict[agent_id] = a_action
prev_actions[agent_id] = a_action
action = action_dict
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment