-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add: test convert dependency #6023
Conversation
@flow.unittest.skip_unless_1n1d() | ||
class TestConvertDependency(flow.unittest.TestCase): | ||
def test_get_params(test_case): | ||
model_dir_path = "alexnet_oneflow_model" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个路径ci是默认有的吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个需要下载预训练的模型参数
|
||
p_size = re.compile(r"size=\(.*?\)", re.S) | ||
p_type = re.compile(r"dtype=.*?,", re.S) | ||
types = ["INPUT", "PARAMETER", "BUFFER", "OUTPUT"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nn.Graph的input和output有这些类型可能出现:Tensor、None、TensorTuple、List[Tensor]
这里只考虑了Tensor?
不过repr里面的确把TensorTuple、List[Tensor]展开成Tensor了,参考这个pr:#5803
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
另外,这里的“通过graph获取每一个节点的shape和dtype”,repr这里只有graph和module级别的,没有op级别的,也不影响?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里主要是获取到每一个节点的信息,在之前是可以通过job.helper获取的,但是现在helper好像是None。这里的input没有考虑None的情况,在转到tvm的过程中当input没有的时候在转换过程中会直接报错。关于op级别的节点信息在转换过程中会从repr(graph)解析出来的信息中提取,应该没有影响。
) | ||
) | ||
if not graph._is_compiled: | ||
_ = graph._compile(flow.rand(shape_input)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_compile后面比如我们转为public接口,改成 compile,怎么处理,提示要match oneflow的版本?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0.5.0及之前我们以这个测试作为接口约定。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_compile后面比如我们转为public接口,改成 compile,怎么处理,提示要match oneflow的版本?
请问一下_compile转为compile是在本周内完成的吗,如果比较快的话这个以及后面的部分(获取所有nodes)可以先省略,等graph开发完全了再提一个PR补回来
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
短期内不改,在graph各种训练功能稳定后,才考虑把这个改为public接口。你可以赖现在这个。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
if size_attr[-2] == ",": | ||
size_attr = size_attr.replace(",", "") | ||
if type_attr[-1] == ",": | ||
type_str = type_attr.replace(",", "") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个检查有点弱,只能保证哟内容,最好检查下内容是对的。
@@ -0,0 +1,105 @@ | |||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_xx_convert_dependency.py
xx最好明确下
或者叫 test_api_dependency_on_graph.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,这个脚本tvm转换和onnx转换都有用到所以一开始没有做区分
|
好的
是不是可以选取一个tensor,写死对它的检查就好
好的
可以自己构造一个module,里面注册一个:
参见:oneflow/python/oneflow/test/graph/test_graph.py
resnet50中的bn
|
谢谢你的建议 @strint
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
可以改一下名字:test_api_dependency_on_graph.py -> test_tvm_fronted_api_dependency_on_graph.py |
好的,改好了 |
CI failed, removing label automerge |
CI failed, removing label automerge |
CI failed, removing label automerge |
CI failed, removing label automerge |
Speed stats:
|
tvm 和 oneflow_convert_tool 需要通过graph获取每一个节点的shape和dtype
目前对repr(garph)的依赖:
其中input和output应该为计算图的i/o, 命名类似_OneFlowGraph0-input_0
其中buffer类似batchnorm算子中的running_mean/var
同时对flow.load的返回值有依赖: