Commit 0d88ef17 authored by pfrl_rainbow's avatar pfrl_rainbow

pytorch with gpu

parent c1159605
......@@ -4,5 +4,5 @@
"authors": ["minerl_rainbow_baseline"],
"tags": "RL",
"description": "Test Model for MineRL Challenge",
"gpu": false
"gpu": true
}
......@@ -128,6 +128,11 @@ class MineRLMatrixAgent(MineRLAgentBase):
Args:
single_episode_env (Episode): The episode on which to run the agent.
"""
import torch
device = torch.device('cuda:0')
x = torch.randn(64, 1000, device=device, dtype=torch.float)
assert torch.cuda.is_available()
obs = single_episode_env.reset()
done = False
while not done:
......@@ -188,5 +193,3 @@ def main():
if __name__ == "__main__":
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment