Skip to content

Commit

Permalink
fix pytest error on non-linux system (thu-ml#638)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored May 12, 2022
1 parent bf8f63f commit a03f19a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
14 changes: 7 additions & 7 deletions test/continuous/test_sac_with_il.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import argparse
import os
import sys

try:
import envpool
except ImportError:
envpool = None

import numpy as np
import pytest
Expand All @@ -19,6 +13,11 @@
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, ActorProb, Critic

try:
import envpool
except ImportError:
envpool = None


def get_args():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -57,8 +56,9 @@ def get_args():
return args


@pytest.mark.skipif(sys.platform != "linux", reason="envpool only support linux now")
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_sac_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed
)
Expand Down
9 changes: 8 additions & 1 deletion test/discrete/test_a2c_with_il.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import os
import pprint

import envpool
import gym
import numpy as np
import pytest
import torch
from torch.utils.tensorboard import SummaryWriter

Expand All @@ -15,6 +15,11 @@
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

try:
import envpool
except ImportError:
envpool = None


def get_args():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -52,7 +57,9 @@ def get_args():
return args


@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_a2c_with_il(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed
)
Expand Down
9 changes: 8 additions & 1 deletion test/modelbased/test_psrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import os
import pprint

import envpool
import numpy as np
import pytest
import torch
from torch.utils.tensorboard import SummaryWriter

Expand All @@ -12,6 +12,11 @@
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger

try:
import envpool
except ImportError:
envpool = None


def get_args():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -40,7 +45,9 @@ def get_args():
return parser.parse_known_args()[0]


@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_psrl(args=get_args()):
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make_gym(
args.task, num_envs=args.training_num, seed=args.seed
)
Expand Down

0 comments on commit a03f19a

Please sign in to comment.