Commit 0d88ef17 authored by pfrl_rainbow's avatar pfrl_rainbow

pytorch with gpu

parent c1159605
......@@ -4,5 +4,5 @@
"authors": ["minerl_rainbow_baseline"],
"tags": "RL",
"description": "Test Model for MineRL Challenge",
"gpu": false
"gpu": true
}
......@@ -57,12 +57,12 @@ class MineRLAgentBase(abc.ABC):
"""
To compete in the competition, you are required to implement a
SUBCLASS to this class.
YOUR SUBMISSION WILL FAIL IF:
* Rename this class
* You do not implement a subclass to this class
* You do not implement a subclass to this class
This class enables the evaluator to run your agent in parallel,
This class enables the evaluator to run your agent in parallel,
so you should load your model only once in the 'load_agent' method.
"""
......@@ -83,9 +83,9 @@ class MineRLAgentBase(abc.ABC):
You should just implement the standard environment interaction loop here:
obs = env.reset()
while not done:
env.step(self.agent.act(obs))
env.step(self.agent.act(obs))
...
NOTE: This method will be called in PARALLEL during evaluation.
So, only store state in LOCAL variables.
For example, if using an LSTM, don't store the hidden state in the class
......@@ -103,7 +103,7 @@ class MineRLAgentBase(abc.ABC):
class MineRLMatrixAgent(MineRLAgentBase):
"""
An example random agent.
An example random agent.
Note, you MUST subclass MineRLAgentBase.
"""
......@@ -128,6 +128,11 @@ class MineRLMatrixAgent(MineRLAgentBase):
Args:
single_episode_env (Episode): The episode on which to run the agent.
"""
import torch
device = torch.device('cuda:0')
x = torch.randn(64, 1000, device=device, dtype=torch.float)
assert torch.cuda.is_available()
obs = single_episode_env.reset()
done = False
while not done:
......@@ -145,9 +150,9 @@ class MineRLRandomAgent(MineRLAgentBase):
while not done:
random_act = single_episode_env.action_space.sample()
single_episode_env.step(random_act)
#####################################################################
# IMPORTANT: SET THIS VARIABLE WITH THE AGENT CLASS YOU ARE USING #
# IMPORTANT: SET THIS VARIABLE WITH THE AGENT CLASS YOU ARE USING #
######################################################################
AGENT_TO_TEST = MineRLMatrixAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere
......@@ -177,7 +182,7 @@ def main():
except EpisodeDone:
print("[{}] Episode complete".format(i))
pass
evaluator_threads = [threading.Thread(target=evaluate, args=(i, envs[i])) for i in range(EVALUATION_THREAD_COUNT)]
for thread in evaluator_threads:
thread.start()
......@@ -188,5 +193,3 @@ def main():
if __name__ == "__main__":
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment