Add flow xxx and tensor xxx autotest #5386
Conversation
    device: str = "cuda",
    training: bool = True,
    backward: bool = True,
    rtol=1e-4,
    atol=1e-5,
    n=20,
    pytorch_module_class_name=None,
    api_flag: int = 0,
There are still some 0/1/2 values that haven't been changed to constants like TEST_MODULE.
OK.
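For illustration, a minimal sketch of how the bare 0/1/2 flags could be replaced with named constants. TEST_MODULE and TEST_TENSOR appear elsewhere in this review; TEST_FLOW and the concrete numeric values are assumptions.

```python
# Hypothetical module-level constants replacing bare 0/1/2 api_flag values.
# TEST_MODULE and TEST_TENSOR are mentioned in the review; TEST_FLOW and the
# exact numeric values are assumptions.
TEST_MODULE = 0   # compare flow.nn.Xxx modules with torch.nn.Xxx
TEST_FLOW = 1     # compare flow.xxx functions with torch.xxx
TEST_TENSOR = 2   # compare flow.Tensor.xxx methods with torch.Tensor.xxx


def describe(api_flag: int = TEST_MODULE) -> str:
    # Dispatching on named constants keeps the intent readable.
    if api_flag == TEST_MODULE:
        return "module test"
    if api_flag == TEST_FLOW:
        return "flow.xxx test"
    return "flow.Tensor.xxx test"


print(describe(TEST_TENSOR))  # -> flow.Tensor.xxx test
```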
def has_full_args_spec(args):
    if args == set():
        return False
    return True
Ah, no, it should be like this:
def has_full_args_spec(callable):
    try:
        spec = inspect.getfullargspec(callable)
        return True
    except Exception:
        return False
Then lines 164-169 can be changed together to:
torch_module_class = eval(f"torch.{pytorch_module_class_name}")
if has_full_args_spec(torch_module_class):
    spec = inspect.getfullargspec(torch_module_class)
else:
    spec = xxxx
OK.
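As background, a hedged demonstration of why the try/except helper is needed: inspect.getfullargspec works on Python-level callables such as torch module classes, but raises for many C-implemented torch builtins. The specific torch names below are only examples, not the APIs tested by this PR.

```python
import inspect

import torch


def has_full_args_spec(callable_obj):
    # Report whether a full argument spec can be extracted at all.
    try:
        inspect.getfullargspec(callable_obj)
        return True
    except Exception:
        return False


# torch.nn.Conv2d.__init__ is a plain Python function, so a spec is available.
print(has_full_args_spec(torch.nn.Conv2d))  # True
# torch.abs is implemented in C and typically exposes no full arg spec.
print(has_full_args_spec(torch.abs))        # usually False
```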
try:
    torch_module_class = eval(f"torch.{pytorch_module_class_name}")
    spec = inspect.getfullargspec(torch_module_class)
except Exception as e:
The "as e" here can be removed.
OK.
args = (set(spec.args) | set(spec.kwonlyargs)) - {"self"}
if has_full_args_spec(args) == False:
Suggested change:
-if has_full_args_spec(args) == False:
+if not has_full_args_spec(args):
OK.
if has_default(name):
    if rng.random() < 1 / 3:
        continue
if api_flag == TEST_MODULE:
Then isn't this if unnecessary as well? Defaults need to be handled no matter what kind of API is being tested.
OK, deleted.
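A hedged sketch of the resulting loop with that api_flag check removed, so defaulted arguments are randomly skipped no matter which API kind is under test. The helper bodies here are simplified stand-ins, not the PR's actual code.

```python
import random

rng = random.Random(0)

# Simplified stand-ins for the autotest helpers used in the diff above.
defaults = {"alpha": 1.0}            # arguments that carry a default value
args = {"input_shape", "alpha"}      # full argument set of the tested API


def has_default(name: str) -> bool:
    return name in defaults


flow_attr_dict = {}
torch_attr_dict = {}
for name in args:
    # Defaulted arguments are skipped roughly 1/3 of the time, regardless of
    # whether a module, flow.xxx, or flow.Tensor.xxx API is being tested.
    if has_default(name) and rng.random() < 1 / 3:
        continue
    flow_attr_dict[name] = f"generated value for {name}"
    torch_attr_dict[name] = f"generated value for {name}"

print(sorted(flow_attr_dict))
```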
    test_case,
    module_class_name,
    extra_annotations: Optional[Dict[str, Any]] = None,
    extra_generators: Optional[Dict[str, Any]] = None,
    extra_defaults: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
    training: bool = True,
    backward: bool = True,
    rtol=1e-4,
    atol=1e-5,
    n=20,
    pytorch_module_class_name=None,
The variable names here and on line 165 need to be changed.
OK.
        return True
    except Exception:
        return False


torch_module_class = eval(f"torch.{pytorch_module_class_name}")
When api_flag is TEST_TENSOR, shouldn't this be changed? First create a torch tensor, then get the function object from it and take its spec.
OK.
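A hedged sketch of that TEST_TENSOR branch: build a concrete torch tensor first, fetch the method from it, then take the spec. The names sample_tensor and pytorch_callable_name, and the constant's value, are illustrative assumptions rather than the PR's actual identifiers.

```python
import inspect

import torch

TEST_TENSOR = 2                 # value assumed; the constant name comes from the review
api_flag = TEST_TENSOR
pytorch_callable_name = "add"   # example API under test

if api_flag == TEST_TENSOR:
    # Create a torch tensor first, then look the method up on that instance.
    sample_tensor = torch.tensor([0.0])
    torch_callable = getattr(sample_tensor, pytorch_callable_name)
else:
    torch_callable = eval(f"torch.{pytorch_callable_name}")

# Many tensor methods are C-implemented, so guard the spec lookup as above.
try:
    spec = inspect.getfullargspec(torch_callable)
except Exception:
    spec = None

print(type(torch_callable).__name__, spec is not None)
```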
@@ -183,9 +209,6 @@ def generate(name):
     flow_attr_dict = {}
     torch_attr_dict = {}
     for name in args:
-        if has_default(name):
-            if rng.random() < 1 / 3:
-                continue
Why was this deleted? 😂
@@ -127,18 +132,20 @@ def generator(_):
     return generator


-def test_module_against_pytorch(
+def test_against_pytorch(
     test_case,
     module_class_name,
This parameter name also needs to be changed; it shouldn't contain the word "module".
OK.
 assert args == set(
-    annotations.keys()
+    annotations.keys() | extra_defaults.keys()
This change isn't needed.
Add automatic PyTorch comparison tests for flow.xxx and flow.Tensor.xxx.
test_module_against_pytorch tests modules
test_flow_against_pytorch tests flow.xxx
test_tensor_against_pytorch tests flow.Tensor.xxx
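For orientation, a hedged sketch of how these three entry points might be called from a unit test, based on the parameters visible in the diff above. The import path, the tested op names, and the exact positional arguments are assumptions, not confirmed by this PR.

```python
import unittest

# Import path is an assumption; the helpers live in OneFlow's automated test
# utilities extended by this PR.
from automated_test_util import (
    test_flow_against_pytorch,
    test_module_against_pytorch,
    test_tensor_against_pytorch,
)


class TestReluAutotest(unittest.TestCase):
    def test_relu_module(self):
        # Compare flow.nn.ReLU against torch.nn.ReLU on n random inputs.
        test_module_against_pytorch(self, "nn.ReLU", device="cuda", n=20)

    def test_relu_functional(self):
        # Compare flow.relu against torch.relu.
        test_flow_against_pytorch(self, "relu", device="cuda")

    def test_tensor_relu(self):
        # Compare flow.Tensor.relu against torch.Tensor.relu.
        test_tensor_against_pytorch(self, "relu", device="cuda")


if __name__ == "__main__":
    unittest.main()
```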