-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Import oneflow as torch #6076
Import oneflow as torch #6076
Conversation
有个png图片被提交了 |
@@ -0,0 +1,66 @@ | |||
import oneflow as flow | |||
import oneflow.nn as nn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/Oneflow-Inc/models/blob/main/scripts/compare_speed_with_pytorch.py#L67-L87 这里实现了把 pytorch 模型文件里的 import torch 覆盖为 import oneflow as torch,这样不需要维护两套模型文件了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不仅仅是import,里面有一些op torchvsion直接使用的torch.xxx,需要改成oneflow.xxx,这里可以做到吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为已经 import oneflow as torch 了,所以 torch.xxx 就是 oneflow.xxx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
哦,对,好的,我修改一下
@@ -0,0 +1,450 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件是放错目录了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有个png图片被提交了
那个图片在我本地是没有的了,而且我在pr里点删除是灰的,很奇怪
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没有放错,我开了一个models来存所有的模型文件
@@ -0,0 +1,228 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件太长,上个comment写错地方了。是这个文件好像放错地方了?test util 里面不应该放 unittest.main() 的东西吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那这个文件应该放在哪里呢?modules吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那这个文件应该放在哪里呢?modules吗
嗯,不然不会被测试的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我的理解 test util 下提供一下用例会用到的公共函数,自身不应该是直接运行的测试
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
和auotest类似调整了目录结构,将oneflow_pytorch_compatiblity_test需要的公共函数当成一个包导入,在modules下进行测试。
@@ -0,0 +1,30 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note 去掉oneflow export之后这种hack的文件可以不用写了,直接import oneflow.test_utils
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除。
new_parameters[k] = flow.tensor(w[k].detach().numpy()) | ||
|
||
try: | ||
shutil.rmtree("/dataset/imagenet/compatiblity_models") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note
不应该直接写 /dataset
目录,应该用 python tempfile 创建临时目录
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,已解决。
…Inc/oneflow into import_oneflow_as_torch
…low into import_oneflow_as_torch
with open("/tmp/tmp_model.py", "w") as new_f: | ||
new_f.write(buf) | ||
|
||
python_module = import_file("/tmp/tmp_model.py") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note
import_file 这个函数应该改成直接接受源码,在里面处理零时文件创建和回收的逻辑
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
|
||
|
||
def import_file(path): | ||
spec = importlib.util.spec_from_file_location("mod", path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note
记得加一下flush
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已加
CI failed, removing label automerge |
…low into import_oneflow_as_torch
CI failed, removing label automerge |
Speed stats:
|
此PR将TorchVison的常见模型加入CI测试,确保可以import oneflow as torch和 import torch as oneflow可以跑同一份模型代码(兼容性)。目前已经支持如下模型的测试,评价指标是100个iter的loss相似度: