-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Training] Additional gradients #8307
Conversation
Yes, tvm/include/tvm/topi/detail/strided_slice.h Lines 132 to 147 in cbe3dca
tvm/include/tvm/topi/detail/strided_slice.h Lines 98 to 113 in cbe3dca
|
This looks interesting. Do we have a RFC or a tracking issue that indicates where the Training support of TVM is going ? |
There's no RFC up at the moment, but we've had some discussion threads (see https://discuss.tvm.apache.org/t/two-missing-pieces-for-training/10037 for example). I'm currently in the process of upstreaming general improvements (gradients, AD bug fixes / improvements, new training ops, other supporting changes like contrib library support) and longer term hoping to open source a proof-of-concept TVM training framework around Q3 (developed at OctoML). Would love to hear your thoughts and perhaps it would be worth opening a long term tracking/discussion thread for training related topics! cc @tqchen |
New gradients:
cast_like
not_equal
strided_slice
one_hot
Also simplified
log_softmax
gradient.For
strided_slice
, I wasn't sure how to add support for the recently introducedaxes
argument (which if I understand correctly aims to circumvent limitations in the type system for dynamic shape inference), so I just added a check. In my mind,strided_slice
(as opposed todyn.strided_slice
) should be used when everything is concrete, but I think now it allows for limited shape dynamism which atm isn't a good idea for training anyway. cc @masahi for confirmation