support consistent_tensor.to(copy=True) #6122
0cb7244 to c361d60
python/oneflow/nn/modules/to.py
Outdated
@@ -185,7 +185,7 @@ def to_op(input, *args, **kwargs):
     )

     if copy is True:
-        raise TypeError("A consistent tensor do not support to(copy=True)")
+        return input.detach().clone()
Isn't it wrong to change this directly here? copy still needs to be passed into _consistent_tensor_to, because the device and dtype parsed above may have values.
copy should be passed into _consistent_tensor_to here, with the following change at line 55:
if device_type == input.placement.device_type and dtype == input.dtype:
    if copy:
        return input.detach().clone()
    else:
        return input
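The suggestion above can be sketched in isolation. This is a minimal illustration, not the OneFlow implementation: `FakeTensor` and `consistent_tensor_to_sketch` are hypothetical stand-ins, and the stand-in only models `clone()`, so it says nothing about whether `detach()` belongs in the fast path.

```python
from dataclasses import dataclass, replace


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a consistent tensor; not the OneFlow class."""

    device_type: str
    dtype: str
    data: tuple = ()

    def clone(self):
        # Return a new object carrying the same field values.
        return replace(self)


def consistent_tensor_to_sketch(input, device_type, dtype, copy=False):
    # Fast path: device type and dtype already match, so only `copy`
    # decides whether the caller gets a new object.
    if device_type == input.device_type and dtype == input.dtype:
        return input.clone() if copy else input
    # Any real conversion produces a new tensor regardless of `copy`.
    return FakeTensor(device_type=device_type, dtype=dtype, data=input.data)


t = FakeTensor("cuda", "float32", (1.0, 2.0))
same = consistent_tensor_to_sketch(t, "cuda", "float32")
copied = consistent_tensor_to_sketch(t, "cuda", "float32", copy=True)
moved = consistent_tensor_to_sketch(t, "cpu", "float32", copy=True)
```

Handling `copy` inside the helper means the device and dtype arguments are still honored when `copy=True`, which an early return in `to_op` would skip.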
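Shouldn't the copy logic go directly into _consistent_tensor_to, so that it lines up with the _tensor_to interface?
The to logic is currently a bit heavy on the Python side; let's have someone move it to C++ later.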
It looks like there actually shouldn't be a detach here, because requires_grad may have been True on the original tensor.
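A small sketch of the concern, assuming `detach()` and `clone()` behave as in PyTorch-style tensor APIs (detach yields a tensor that does not require grad, clone preserves the flag). `FakeTensor` is a hypothetical stand-in, not the OneFlow tensor class.

```python
class FakeTensor:
    """Hypothetical stand-in that only tracks a value and requires_grad."""

    def __init__(self, value, requires_grad=False):
        self.value = value
        self.requires_grad = requires_grad

    def detach(self):
        # Assumed semantics: the detached tensor never requires grad.
        return FakeTensor(self.value, requires_grad=False)

    def clone(self):
        # Assumed semantics: the clone keeps the source's requires_grad.
        return FakeTensor(self.value, requires_grad=self.requires_grad)


x = FakeTensor(3.0, requires_grad=True)
with_detach = x.detach().clone()   # requires_grad is lost
without_detach = x.clone()         # requires_grad is preserved
```

Under these assumptions, `to(copy=True)` built on `detach().clone()` would silently hand back a tensor with `requires_grad` set to False.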
CI failed, removing label automerge
@@ -53,7 +53,7 @@ def _consistent_tensor_to(input, device_type, dtype):
     assert isinstance(dtype, flow.dtype)

     if device_type == input.placement.device_type and dtype == input.dtype:
-        return input
+        return input if not copy else input.clone()
If this line is reverted, the failing test_graph_asymmetric_io.py test passes.
  cur_rank_phy_tensor = std::make_shared<MirroredTensor>(cur_rank_phy_tensor_impl);
} else {
  const auto& dtype_symbol = JUST(DType::Get(dtype));
  const auto& empty = JUST(functional::Empty(*cur_rank_phy_shape, dtype_symbol, device));
Although cur_shape_phy_shape is an all-zero shape, skipping the Empty op would leave the blob object of the eager_blob_object null.
BlockingCounter bc(1);
JUST(PhysicalRun([&bc](InstructionsBuilder* builder) -> Maybe<void> {
  JUST(builder->ComputeGlobalFrontSeqBarrier());
When NNGraph synchronizes the vm, it should not have to concern itself with the vm on other ranks.
@@ -40,9 +40,11 @@ Maybe<void> Run(vm::InstructionMsgList* instr_msg_list) {
   return Maybe<void>::Ok();
 }

-Maybe<void> SingleClientSync() {
+Maybe<void> ClusterSync() {
The reason for the rename is that this function now does double duty for the old SingleClientSync and MultiClientSync.