Commit c7a327e7 authored by dzorlu's avatar dzorlu

test.py. add dummy trained model

parent 011cc73a
{
"challenge_id": "aicrowd-neurips-2020-minerl-challenge",
"grader_id": "aicrowd-neurips-2020-minerl-challenge",
"authors": ["aicrowd-bot"],
"tags": "change-me",
"authors": ["dzorlu"],
"tags":["RL"],
"description": "Test Model for MineRL Challenge",
"gpu": false
"gpu": true
}
name: mrl
name: minerl
channels:
- conda-forge
- anaconda
- defaults
dependencies:
- _libgcc_mutex=0.1
- ca-certificates=2020.7.22
- _tflow_select=2.1.0
- absl-py=0.10.0
- argon2-cffi=20.1.0
- astunparse=1.6.3
- async_generator=1.10
- attrs=20.2.0
- backcall=0.2.0
- backports=1.0
- backports.functools_lru_cache=1.6.1
- blas=1.0
- bleach=3.2.1
- blinker=1.4
- brotlipy=0.7.0
- c-ares=1.15.0
- ca-certificates=2020.6.20
- cachetools=4.1.1
- certifi=2020.6.20
- cffi=1.14.2
- chardet=3.0.4
- click=7.1.2
- cryptography=3.1
- cudatoolkit=10.1.243
- cudnn=7.6.5
- cupti=10.1.168
- decorator=4.4.2
- defusedxml=0.6.0
- entrypoints=0.3
- expat=2.2.9
- gast=0.3.3
- git=2.23.0
- google-auth=1.21.2
- google-auth-oauthlib=0.4.1
- google-pasta=0.2.0
- grpcio=1.31.0
- h5py=2.10.0
- hdf5=1.10.6
- idna=2.10
- importlib-metadata=1.7.0
- importlib_metadata=1.7.0
- intel-openmp=2020.2
- ipykernel=5.3.4
- ipython=7.18.1
- ipython_genutils=0.2.0
- jinja2=2.11.2
- json5=0.9.4
- jsonschema=3.2.0
- jupyter_client=6.1.7
- jupyter_core=4.6.3
- jupyterlab=2.2.8
- jupyterlab_pygments=0.1.1
- jupyterlab_server=1.2.0
- krb5=1.18.2
- ld_impl_linux-64=2.33.1
- libcurl=7.71.1
- libedit=3.1.20191231
- libffi=3.3
- libgcc-ng=9.1.0
- libgfortran-ng=7.3.0
- libprotobuf=3.12.4
- libsodium=1.0.18
- libssh2=1.9.0
- libstdcxx-ng=9.1.0
- markdown=3.2.2
- markupsafe=1.1.1
- mistune=0.8.4
- mkl=2019.4
- mkl-service=2.3.0
- mkl_fft=1.1.0
- mkl_random=1.1.0
- nbclient=0.5.0
- nbconvert=6.0.3
- nbformat=5.0.7
- ncurses=6.2
- nest-asyncio=1.4.0
- notebook=6.1.4
- numpy=1.19.1
- numpy-base=1.19.1
- oauthlib=3.1.0
- openssl=1.1.1g
- opt_einsum=3.1.0
- packaging=20.4
- pandoc=2.10.1
- pandocfilters=1.4.2
- pcre=8.44
- perl=5.26.2
- pexpect=4.8.0
- pickleshare=0.7.5
- pip=20.2.2
- python=3.6.12
- prometheus_client=0.8.0
- prompt-toolkit=3.0.7
- protobuf=3.12.4
- ptyprocess=0.6.0
- pyasn1=0.4.8
- pyasn1-modules=0.2.7
- pycparser=2.20
- pygments=2.7.1
- pyjwt=1.7.1
- pyopenssl=19.1.0
- pyparsing=2.4.7
- pyrsistent=0.17.3
- pysocks=1.7.1
- python=3.8.5
- python-dateutil=2.8.1
- python_abi=3.8
- pyzmq=19.0.2
- readline=8.0
- requests=2.24.0
- requests-oauthlib=1.3.0
- rsa=4.6
- send2trash=1.5.0
- setuptools=49.6.0
- six=1.15.0
- sqlite=3.33.0
- tensorboard=2.2.1
- tensorboard-plugin-wit=1.6.0
- tensorflow=2.2.0
- tensorflow-base=2.2.0
- tensorflow-estimator=2.2.0
- tensorflow-gpu=2.2.0
- termcolor=1.1.0
- terminado=0.8.3
- testpath=0.4.4
- tk=8.6.10
- tornado=6.0.4
- traitlets=5.0.4
- urllib3=1.25.10
- wcwidth=0.2.5
- webencodings=0.5.1
- werkzeug=1.0.1
- wheel=0.35.1
- wrapt=1.12.1
- xz=5.2.5
- zeromq=4.3.2
- zipp=3.1.0
- zlib=1.2.11
- pip:
- absl-py==0.10.0
- argon2-cffi==20.1.0
- astunparse==1.6.3
- async-generator==1.10
- attrs==20.2.0
- backcall==0.2.0
- bleach==3.1.5
- bsuite==0.3.2
- cachetools==4.1.1
- cffi==1.14.2
- chardet==3.0.4
- chex==0.0.2
- cloudpickle==1.3.0
- coloredlogs==14.0
- crowdai-api==0.1.22
- cycler==0.10.0
- dataclasses==0.7
- decorator==4.4.2
- defusedxml==0.6.0
- dataclasses==0.6
- descartes==1.1.0
- dill==0.3.2
- dm-acme==0.1.8
- dm-control==0.0.322773188
- dm-env==1.2
- dm-haiku==0.0.2
- dm-reverb-nightly==0.1.0.dev20200708
- dm-sonnet==2.0.0
- dm-tree==0.1.5
- entrypoints==0.3
- frozendict==1.2
- future==0.18.2
- gast==0.3.3
- getch==1.0
- glfw==1.12.0
- google-auth==1.21.1
- google-auth-oauthlib==0.4.1
- google-pasta==0.2.0
- grpcio==1.32.0
- gym==0.17.2
- h5py==2.10.0
- humanfriendly==8.2
- idna==2.10
- imageio==2.9.0
- importlib-metadata==1.7.0
- ipykernel==5.3.4
- ipython==7.16.1
- ipython-genutils==0.2.0
- jax==0.1.76
- jax==0.1.77
- jaxlib==0.1.55
- jedi==0.17.2
- jinja2==2.11.2
- joblib==0.16.0
- json5==0.9.5
- jsonschema==3.2.0
- jupyter-client==6.1.7
- jupyter-core==4.6.3
- jupyterlab==2.2.8
- jupyterlab-pygments==0.1.1
- jupyterlab-server==1.2.0
- keras-preprocessing==1.1.2
- kiwisolver==1.2.0
- labmaze==1.0.3
- lxml==4.5.2
- markdown==3.2.2
- markupsafe==1.1.1
- matplotlib==3.0.3
- minerl==0.3.6
- mistune==0.8.4
- mizani==0.7.1
- nbclient==0.5.0
- nbconvert==6.0.2
- nbformat==5.0.7
- nest-asyncio==1.4.0
- networkx==2.5
- notebook==6.1.4
- numpy==1.18.5
- oauthlib==3.1.0
- opencv-python==4.4.0.42
- opt-einsum==3.3.0
- packaging==20.4
- palettable==3.3.0
- pandas==1.1.2
- pandocfilters==1.4.2
- parso==0.7.1
- patsy==0.5.1
- pexpect==4.8.0
- pickleshare==0.7.5
- pillow==7.2.0
- plotnine==0.7.1
- portpicker==1.3.1
- prometheus-client==0.8.0
- prompt-toolkit==3.0.7
- protobuf==3.13.0
- psutil==5.7.2
- ptyprocess==0.6.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycparser==2.20
- pyglet==1.5.0
- pygments==2.7.0
- pyopengl==3.1.5
- pyparsing==2.4.7
- pyro4==4.80
- pyrsistent==0.17.3
- python-dateutil==2.8.1
- python-gitlab==2.5.0
- pytz==2020.1
- pywavelets==1.1.1
- pyzmq==19.0.2
- redis==3.5.3
- requests==2.24.0
- requests-oauthlib==1.3.0
- rlax==0.0.2
- rsa==4.6
- scikit-image==0.17.2
- scikit-learn==0.23.2
- scipy==1.4.1
- send2trash==1.5.0
- serpent==1.30.2
- six==1.15.0
- sklearn==0.0
- statsmodels==0.12.0
- tabulate==0.8.7
- tb-nightly==2.3.0a20200722
- tensorboard==2.3.0
- tensorboard-plugin-wit==1.7.0
- tensorflow==2.3.0
- tensorflow-estimator==2.3.0
- termcolor==1.1.0
- terminado==0.8.3
- testpath==0.4.4
- tf-estimator-nightly==2.4.0.dev2020091401
- tf-estimator-nightly==2.4.0.dev2020091801
- tf-nightly==2.4.0.dev20200708
- tfp-nightly==0.12.0.dev20200717
- threadpoolctl==2.1.0
- tifffile==2020.9.3
- toolz==0.10.0
- tornado==6.0.4
- tqdm==4.49.0
- traitlets==4.3.3
- trfl==1.1.0
- typing==3.7.4.3
- urllib3==1.25.10
- wcwidth==0.2.5
- webencodings==0.5.1
- werkzeug==1.0.1
- wrapt==1.12.1
- zipp==3.1.0
prefix: /home/deniz/anaconda3/envs/mrl
prefix: /home/deniz/anaconda3/envs/minerl
absl-py @ file:///tmp/build/80754af9/absl-py_1600297518631/work
argon2-cffi @ file:///home/conda/feedstock_root/build_artifacts/argon2-cffi_1596629847793/work
astunparse==1.6.3
async-generator==1.10
attrs @ file:///home/conda/feedstock_root/build_artifacts/attrs_1599308529326/work
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache==1.6.1
bleach @ file:///home/conda/feedstock_root/build_artifacts/bleach_1600454382015/work
blinker==1.4
brotlipy==0.7.0
bsuite @ git+git://github.com/deepmind/bsuite.git@f2e6c2103247579498e2b84ca7581d8ae4632dac
cachetools @ file:///tmp/build/80754af9/cachetools_1596822027882/work
certifi==2020.6.20
cffi @ file:///tmp/build/80754af9/cffi_1598370769933/work
chardet==3.0.4
chex==0.0.2
click==7.1.2
cloudpickle==1.3.0
coloredlogs==14.0
crowdai-api==0.1.22
cryptography @ file:///tmp/build/80754af9/cryptography_1598892038851/work
cycler==0.10.0
dataclasses==0.6
decorator==4.4.2
defusedxml==0.6.0
descartes==1.1.0
dill==0.3.2
-e git+https://github.com/dzorlu/acme.git@880a001187dd4edb32f608a1b7e445b0f9899e71#egg=dm_acme
dm-control==0.0.322773188
dm-env==1.2
dm-haiku==0.0.2
dm-reverb-nightly==0.1.0.dev20200708
dm-sonnet==2.0.0
dm-tree==0.1.5
entrypoints==0.3
frozendict==1.2
future==0.18.2
gast==0.3.3
getch==1.0
glfw==1.12.0
google-auth @ file:///tmp/build/80754af9/google-auth_1600274525154/work
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
grpcio @ file:///tmp/build/80754af9/grpcio_1597424474635/work
gym==0.17.2
h5py @ file:///tmp/build/80754af9/h5py_1593454122442/work
humanfriendly==8.2
idna @ file:///tmp/build/80754af9/idna_1593446292537/work
imageio==2.9.0
importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1593446406207/work
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1595446881062/work/dist/ipykernel-5.3.4-py3-none-any.whl
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1598749948871/work
ipython-genutils==0.2.0
jax==0.1.77
jaxlib==0.1.55
jedi==0.17.2
Jinja2==2.11.2
joblib==0.16.0
json5 @ file:///home/conda/feedstock_root/build_artifacts/json5_1591810480056/work
jsonschema==3.2.0
jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1598486169312/work
jupyter-core==4.6.3
jupyterlab==2.2.8
jupyterlab-pygments==0.1.1
jupyterlab-server @ file:///home/conda/feedstock_root/build_artifacts/jupyterlab_server_1593951277307/work
Keras-Preprocessing==1.1.2
kiwisolver==1.2.0
labmaze==1.0.3
lxml==4.5.2
Markdown @ file:///tmp/build/80754af9/markdown_1597433240441/work
MarkupSafe==1.1.1
matplotlib==3.0.3
minerl==0.3.6
mistune==0.8.4
mizani==0.7.1
mkl-fft==1.1.0
mkl-random==1.1.0
mkl-service==2.3.0
nbclient @ file:///home/conda/feedstock_root/build_artifacts/nbclient_1598558657104/work
nbconvert @ file:///home/conda/feedstock_root/build_artifacts/nbconvert_1600286912556/work
nbformat @ file:///home/conda/feedstock_root/build_artifacts/nbformat_1594060262917/work
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1594996608835/work
networkx==2.5
notebook @ file:///home/conda/feedstock_root/build_artifacts/notebook_1599742234243/work
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1596233721170/work
oauthlib==3.1.0
opencv-python==4.4.0.42
opt-einsum==3.1.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1589925210001/work
palettable==3.3.0
pandas==1.1.2
pandocfilters==1.4.2
parso==0.7.1
patsy==0.5.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.2.0
plotnine==0.7.1
portpicker==1.3.1
prometheus-client @ file:///home/conda/feedstock_root/build_artifacts/prometheus_client_1590412252446/work
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1598885455507/work
protobuf==3.12.4
psutil==5.7.2
ptyprocess==0.6.0
pyasn1==0.4.8
pyasn1-modules==0.2.7
pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
pyglet==1.5.0
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1600347314331/work
PyJWT==1.7.1
PyOpenGL==3.1.5
pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1594392929924/work
pyparsing==2.4.7
Pyro4==4.80
pyrsistent @ file:///home/conda/feedstock_root/build_artifacts/pyrsistent_1599988181994/work
PySocks==1.7.1
python-dateutil==2.8.1
python-gitlab==2.5.0
pytz==2020.1
PyWavelets==1.1.1
pyzmq==19.0.2
redis==3.5.3
requests @ file:///tmp/build/80754af9/requests_1592841827918/work
requests-oauthlib==1.3.0
rlax @ git+git://github.com/deepmind/rlax.git@5e1a6bfd7271a150bc5c2568f7a2f7648863fc04
rsa @ file:///tmp/build/80754af9/rsa_1596998415516/work
scikit-image==0.17.2
scikit-learn==0.23.2
scipy==1.4.1
Send2Trash==1.5.0
serpent==1.30.2
six==1.15.0
sklearn==0.0
statsmodels==0.12.0
tabulate==0.8.7
tb-nightly==2.3.0a20200722
tensorboard==2.2.1
tensorboard-plugin-wit==1.6.0
tensorflow==2.2.0
tensorflow-estimator==2.2.0
termcolor==1.1.0
terminado==0.8.3
testpath==0.4.4
tf-estimator-nightly==2.4.0.dev2020091801
tf-nightly==2.4.0.dev20200708
tfp-nightly==0.12.0.dev20200717
threadpoolctl==2.1.0
tifffile==2020.9.3
toolz==0.10.0
tornado==6.0.4
tqdm==4.49.0
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1599471676085/work
trfl==1.1.0
typing==3.7.4.3
urllib3 @ file:///tmp/build/80754af9/urllib3_1597086586889/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1595859607677/work
webencodings==0.5.1
Werkzeug==1.0.1
wrapt==1.12.1
zipp==3.1.0
......@@ -14,9 +14,20 @@ import gym
import minerl
import abc
import numpy as np
from pathlib import Path
import coloredlogs
coloredlogs.install(logging.DEBUG)
from acme.agents.tf import actors
from acme.tf import savers as tf2_savers
from acme.tf import utils as tf2_utils
from acme import wrappers
from acme import specs
import sonnet as snt
import tensorflow as tf
import coloredlogs, logging
coloredlogs.install(logging.WARNING)
logger = logging.getLogger(__name__)
# All the evaluations will be evaluated on MineRLObtainDiamondVectorObf-v0 environment
MINERL_GYM_ENV = os.getenv('MINERL_GYM_ENV', 'MineRLObtainDiamondVectorObf-v0')
......@@ -29,7 +40,28 @@ EVALUATION_THREAD_COUNT = int(os.getenv('EPISODES_EVALUATION_THREAD_COUNT', 2))
class EpisodeDone(Exception):
pass
class Episode(gym.Env):
# class Episode(gym.Env):
# """A class for a single episode.
# """
# def __init__(self, env):
# self.env = env
# self.action_space = env.action_space
# self.observation_space = env.observation_space
# self._done = False
# def reset(self):
# if not self._done:
# return self.env.reset()
# def step(self, action):
# s,r,d,i = self.env.step(action)
# if d:
# self._done = True
# raise EpisodeDone()
# else:
# return s,r,d,i
class Episode:
"""A class for a single episode.
"""
def __init__(self, env):
......@@ -43,12 +75,12 @@ class Episode(gym.Env):
return self.env.reset()
def step(self, action):
s,r,d,i = self.env.step(action)
if d:
ts = self.env.step(action)
if ts.last():
self._done = True
raise EpisodeDone()
else:
return s,r,d,i
return ts
......@@ -100,6 +132,67 @@ class MineRLAgentBase(abc.ABC):
#######################
# YOUR CODE GOES HERE #
#######################
from acme.tf import networks
import dm_env
import functools
# Number of discrete actions handed to the network and the environment wrapper.
NUMBER_OF_DISCRETE_ACTIONS = 25 # make sure this matches train.py
# Directory part of this file's path, as reported by __file__.
rel_path = os.path.dirname(__file__)
# "train" sub-directory next to this file; used as the checkpoint directory
# and as the k-means path further down in this module.
model_dir = os.path.join(rel_path, "train")
# Create the directory (and any missing parents) at import time; no error if it exists.
Path(model_dir).mkdir(parents=True, exist_ok=True)
logger.info(model_dir)
def create_network(nb_actions: int = NUMBER_OF_DISCRETE_ACTIONS) -> networks.RNNCore:
    """Build the policy network.

    Args:
        nb_actions: Passed straight to ``networks.R2D2MineRLNetwork``;
            defaults to ``NUMBER_OF_DISCRETE_ACTIONS``.

    Returns:
        A newly constructed ``networks.R2D2MineRLNetwork``.
    """
    policy_core = networks.R2D2MineRLNetwork(nb_actions)
    return policy_core
def make_environment(k_means_path: str,
                     num_actions: int = NUMBER_OF_DISCRETE_ACTIONS,
                     dat_loader: minerl.data.data_pipeline.DataPipeline = None,
                     train: bool = True,
                     minerl_gym_env: str = MINERL_GYM_ENV) -> dm_env.Environment:
    """Create the gym environment and pass it through ``wrappers.wrap_all``.

    The wrappers are handed to ``wrap_all`` in this order:
      1 - GymWrapper
      2 - MineRLWrapper
          (per the original notes: similar to OAR but adds proprioceptive
          features; kMeans maps the continuous action space to a discrete one)
      3 - SinglePrecisionWrapper

    Args:
        k_means_path: Forwarded to ``MineRLWrapper`` as ``k_means_path``.
        num_actions: Forwarded to ``MineRLWrapper`` as ``num_actions``.
        dat_loader: Forwarded to ``MineRLWrapper`` as ``dat_loader``.
        train: Accepted for signature compatibility but not used here;
            ``MineRLWrapper`` is always given ``train=False``.
        minerl_gym_env: Environment id given to ``gym.make``.

    Returns:
        The result of ``wrappers.wrap_all`` for the created environment.
    """
    env = gym.make(minerl_gym_env)

    # Pre-bind the MineRL-specific arguments so the wrapper can be applied
    # with the environment as its only remaining argument.
    minerl_wrapper = functools.partial(
        wrappers.MineRLWrapper,
        num_actions=num_actions,
        dat_loader=dat_loader,
        k_means_path=k_means_path,
        train=False,
    )
    wrapper_sequence = [
        wrappers.GymWrapper,
        minerl_wrapper,
        wrappers.SinglePrecisionWrapper,
    ]
    return wrappers.wrap_all(env, wrapper_sequence)
def load_actor(environment_spec):
    """Build the network, attach it to a checkpointer and wrap it in an actor.

    Args:
        environment_spec: Object whose ``observations`` attribute is used to
            create the network variables.

    Returns:
        An ``actors.RecurrentActor`` built from the network followed by an
        argmax over its output; ``None`` is passed as the second argument.
    """
    q_network = create_network(NUMBER_OF_DISCRETE_ACTIONS)
    tf2_utils.create_variables(q_network, [environment_spec.observations])

    # Restores the model (per the original comment). Only the network is
    # registered, and the Checkpointer object itself is not kept.
    tf2_savers.Checkpointer(
        directory=model_dir,
        subdirectory='r2d2_learner_v1',
        time_delta_minutes=15,
        objects_to_save={'network': q_network},
    )

    def greedy_action(qs):
        # This is different at inference time: take the argmax over the last axis.
        return tf.math.argmax(qs, axis=-1)

    greedy_policy = snt.DeepRNN([q_network, greedy_action])
    return actors.RecurrentActor(greedy_policy, None)
class MineRLMatrixAgent(MineRLAgentBase):
"""
......@@ -133,7 +226,6 @@ class MineRLMatrixAgent(MineRLAgentBase):
while not done:
obs,reward,done,_ = single_episode_env.step(self.act(self.flatten_obs(obs)))
class MineRLRandomAgent(MineRLAgentBase):
"""A random agent"""
def load_agent(self):
......@@ -145,11 +237,22 @@ class MineRLRandomAgent(MineRLAgentBase):
while not done:
random_act = single_episode_env.action_space.sample()
single_episode_env.step(random_act)
class R2D3Agent(MineRLAgentBase):
    """Agent that picks every action with an actor supplied by the caller."""

    def load_agent(self, actor):
        """Keep a reference to ``actor`` for use in ``run_agent_on_episode``."""
        self.actor = actor

    def run_agent_on_episode(self, single_episode_env: Episode):
        """Step the episode with the stored actor until a last timestep.

        Each action comes from ``self.actor.select_action`` applied to the
        current timestep's ``observation``.
        """
        timestep = single_episode_env.reset()
        while not timestep.last():
            chosen_action = self.actor.select_action(timestep.observation)
            timestep = single_episode_env.step(chosen_action)
#####################################################################
# IMPORTANT: SET THIS VARIABLE WITH THE AGENT CLASS YOU ARE USING #
######################################################################
AGENT_TO_TEST = MineRLMatrixAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere
AGENT_TO_TEST = R2D3Agent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere
......@@ -157,15 +260,27 @@ AGENT_TO_TEST = MineRLMatrixAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAg
# EVALUATION CODE #
####################
def main():
#
environment = make_environment(num_actions=NUMBER_OF_DISCRETE_ACTIONS, k_means_path=model_dir)
spec = specs.make_environment_spec(environment)
actor = load_actor(spec) #initiate here to keep the state shared across threads
environment.close()
agent = AGENT_TO_TEST()
assert isinstance(agent, MineRLAgentBase)
agent.load_agent()
agent.load_agent(actor)
assert MINERL_MAX_EVALUATION_EPISODES > 0
assert EVALUATION_THREAD_COUNT > 0
# Create the parallel envs (sequentially to prevent issues!)
envs = [gym.make(MINERL_GYM_ENV) for _ in range(EVALUATION_THREAD_COUNT)]
envs = list()
for _ in range(EVALUATION_THREAD_COUNT):
environment = make_environment(num_actions=NUMBER_OF_DISCRETE_ACTIONS,
k_means_path=model_dir)
envs.append(environment)
# Create the parallel envs (sequentially to prevent issues!)
#envs = [gym.make(MINERL_GYM_ENV) for _ in range(EVALUATION_THREAD_COUNT)]
episodes_per_thread = [MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT for _ in range(EVALUATION_THREAD_COUNT)]
episodes_per_thread[-1] += MINERL_MAX_EVALUATION_EPISODES - EVALUATION_THREAD_COUNT *(MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT)
# A simple function to evaluate on episodes!
......
......@@ -59,6 +59,10 @@ MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT', '/hdd/minerl')
# Optional: You can view best effort status of your instances with the help of parser.py
# This will give you the current state, such as the number of steps completed, instances launched and so on. Make sure you keep tabs on the numbers to avoid breaching any limits.
# Directory part of this file's path, as reported by __file__.
rel_path = os.path.dirname(__file__)
# "performance" sub-directory next to this file.
performance_dir = os.path.join(rel_path, "performance")
# Create the directory (and any missing parents) at import time; no error if it exists.
# NOTE(review): the Parser call that follows is given the literal 'performance/'
# rather than performance_dir; the two differ when the working directory is not
# this file's directory. Confirm which one is intended.
Path(performance_dir).mkdir(parents=True, exist_ok=True)
parser = Parser('performance/',
allowed_environment=MINERL_GYM_ENV,
maximum_instances=MINERL_TRAINING_MAX_INSTANCES,
......@@ -73,7 +77,11 @@ def create_network(nb_actions: int = NUMBER_OF_DISCRETE_ACTIONS) -> networks.RNN
return networks.R2D2MineRLNetwork(nb_actions)
def make_environment(num_actions: int, dat_loader: minerl.data.data_pipeline.DataPipeline,) -> dm_env.Environment:
def make_environment(k_means_path: str,
num_actions: int = NUMBER_OF_DISCRETE_ACTIONS,
dat_loader: minerl.data.data_pipeline.DataPipeline = None,
train: bool = True,
minerl_gym_env: str = MINERL_GYM_ENV) -> dm_env.Environment:
"""
Wrap the environment in:
1 - MineRLWrapper
......@@ -83,7 +91,7 @@ def make_environment(num_actions: int, dat_loader: minerl.data.data_pipeline.Dat
3 - GymWrapper
"""
env = gym.make(MINERL_GYM_ENV)
env = gym.make(minerl_gym_env)
return wrappers.wrap_all(env, [
wrappers.GymWrapper,
......@@ -91,6 +99,7 @@ def make_environment(num_actions: int, dat_loader: minerl.data.data_pipeline.Dat
wrappers.MineRLWrapper,
num_actions=num_actions,
dat_loader=dat_loader,
k_means_path=k_means_path,
),
wrappers.SinglePrecisionWrapper,
])
......@@ -206,7 +215,6 @@ def main():
burn_in_length = 40