[Relay][Training] Additional gradients #8307

Merged Jun 24, 2021 (3 commits)

altanh (Contributor) commented Jun 22, 2021

New gradients:

  • cast_like
  • not_equal
  • strided_slice
  • one_hot
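As a rough illustration of what a one_hot gradient has to compute, here is a minimal sketch in plain C++ (not the Relay gradient registered in this PR). It assumes the usual definition, where the output holds on_value at the indexed position and off_value everywhere else, so only the two scalar values receive gradients and the integer indices do not. The helper name `one_hot_grad` and the row-major layout are assumptions made for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper (not a TVM API): gradients of a one_hot over 1-D indices
// with respect to the scalars on_value and off_value.
// `grad` is the upstream gradient of shape [indices.size(), depth], row-major.
// Returns {grad_on_value, grad_off_value}.
std::pair<double, double> one_hot_grad(const std::vector<int>& indices, int depth,
                                       const std::vector<double>& grad) {
  assert(depth > 0);
  assert(grad.size() == indices.size() * static_cast<std::size_t>(depth));
  double grad_on = 0.0;
  double grad_off = 0.0;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    for (int j = 0; j < depth; ++j) {
      const double g = grad[i * static_cast<std::size_t>(depth) + j];
      // Positions selected by the index were produced by on_value,
      // all other positions by off_value.
      if (indices[i] == j) {
        grad_on += g;
      } else {
        grad_off += g;
      }
    }
  }
  return {grad_on, grad_off};
}
```

Every element of the upstream gradient flows to exactly one of the two scalars, so the two results always sum to the total of `grad`.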

Also simplified log_softmax gradient.
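For reference, the textbook log_softmax gradient can be written compactly in terms of the forward output: with y = log_softmax(x) and upstream gradient g, dx_i = g_i - exp(y_i) * sum_j g_j. The sketch below shows that identity on a 1-D vector in plain C++. Whether this is exactly the form the PR simplified to is an assumption, and the function names are made up for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable 1-D log_softmax: y_i = x_i - logsumexp(x).
std::vector<double> log_softmax(const std::vector<double>& x) {
  assert(!x.empty());
  const double m = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double v : x) sum += std::exp(v - m);
  const double lse = m + std::log(sum);
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] - lse;
  return y;
}

// Hypothetical helper: gradient with respect to x, given the forward output y
// and the upstream gradient g. Uses dx_i = g_i - exp(y_i) * sum_j g_j.
std::vector<double> log_softmax_grad(const std::vector<double>& y,
                                     const std::vector<double>& g) {
  assert(y.size() == g.size());
  double g_sum = 0.0;
  for (double v : g) g_sum += v;
  std::vector<double> dx(y.size());
  for (std::size_t i = 0; i < y.size(); ++i) {
    dx[i] = g[i] - std::exp(y[i]) * g_sum;  // exp(y_i) is softmax(x)_i
  }
  return dx;
}
```

Because the softmax probabilities sum to one, the entries of `dx` sum to zero, which makes a convenient sanity check.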

For strided_slice, I wasn't sure how to add support for the recently introduced axes argument (which, if I understand correctly, aims to circumvent limitations in the type system for dynamic shape inference), so I just added a check. In my mind, strided_slice (as opposed to dyn.strided_slice) should be used when everything is concrete, but I think it now allows limited shape dynamism, which at the moment isn't a good idea for training anyway. cc @masahi for confirmation
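To make the fully concrete case tangible, here is a minimal 1-D sketch in plain C++ of what a strided_slice gradient does: the forward op reads `x[begin + k * stride]`, so the backward pass writes the upstream gradient back to those positions in a zero-filled buffer of the input's size. The helper name and signature are hypothetical, `begin` is assumed to be already canonicalized (non-negative, in range), and this is not the Relay implementation from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: gradient of a 1-D strided slice with static arguments.
// Forward: out[k] = x[begin + k * stride] for k in [0, grad.size()).
// Backward: scatter grad[k] to the same positions; everything else stays zero.
std::vector<double> strided_slice_grad_1d(std::size_t input_size, std::int64_t begin,
                                          std::int64_t stride,
                                          const std::vector<double>& grad) {
  assert(stride != 0);
  std::vector<double> dx(input_size, 0.0);
  for (std::size_t k = 0; k < grad.size(); ++k) {
    const std::int64_t idx = begin + static_cast<std::int64_t>(k) * stride;
    assert(idx >= 0 && static_cast<std::size_t>(idx) < input_size);
    dx[static_cast<std::size_t>(idx)] = grad[k];
  }
  return dx;
}
```

The zero-filled buffer needs the input size up front, which is one reason a static-shape assumption is convenient here.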

masahi (Member) commented Jun 22, 2021

Yes, strided_slice with static begin, end, etc. does support dynamic input shapes. If some dims in the input shape are dynamic, the corresponding output dims will also be dynamic, regardless of begin, end, and stride.

if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
  const int64_t dim_i = GetConstInt(ishape[axes[i]]);
  ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
  int64_t begin_i = GetConstInt(begin_canonicalized[i]);
  int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
  int interval = std::abs(end_i - begin_i);
  int slice_size =
      static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i]));
  ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
      << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
  out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
} else if (use_any) {
  out_shape.Set(axes[i], tvm::tir::Any());
} else {
  out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype));
}
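The static branch above boils down to a ceiling division: the number of selected elements is ceil(|end - begin| / |stride|). The following standalone sketch reproduces just that arithmetic on plain integers so it can be checked by hand. The function name `slice_size` is borrowed from the local variable in the snippet, and `begin` and `end` are assumed to be already canonicalized.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Hypothetical standalone version of the slice-size arithmetic quoted above.
// `begin` and `end` are assumed to be canonicalized indices; `end` is exclusive.
std::int64_t slice_size(std::int64_t begin, std::int64_t end, std::int64_t stride) {
  assert(stride != 0);
  // Same direction check as the ICHECK in the quoted code.
  assert(stride < 0 ? (end <= begin) : (begin <= end));
  const std::int64_t interval = std::abs(end - begin);
  const std::int64_t step = std::abs(stride);
  return (interval + step - 1) / step;  // ceil(interval / step)
}
```

For example, begin=0, end=10, stride=3 selects indices 0, 3, 6, 9, and begin=9, end=0, stride=-2 selects 9, 7, 5, 3, 1.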

if (ishape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
  int64_t dim_i = GetConstInt(ishape[axes[i]]);
  int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
  begin_expr.push_back(make_const(dtype, begin_i));
} else {
  auto idim = ishape[axes[i]];
  auto b_expr = make_const(dtype, begin[i]);
  PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
  auto s = strides[i];
  if (s < 0) {
    b = tvm::min(b, idim - 1);
  } else {
    b = tvm::if_then_else(b < 0, 0, b);
  }
  begin_expr.push_back(b);
}
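The `else` branch above builds a symbolic expression because the dimension is not known at compile time. To see what that expression evaluates to, the sketch below replays the same steps on plain integers: wrap a negative begin by adding the dimension, then clamp to `dim - 1` for negative strides or to 0 for positive ones. The function name is hypothetical, and it mirrors only the dynamic branch; the body of `CanonicalizeIndex` used in the static branch is not shown in this thread, so no claim is made that the two agree in every case.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical integer replay of the dynamic-shape branch quoted above.
std::int64_t canonicalize_begin(std::int64_t begin, std::int64_t dim, std::int64_t stride) {
  assert(dim > 0 && stride != 0);
  // Negative indices count from the end of the axis.
  std::int64_t b = begin < 0 ? begin + dim : begin;
  if (stride < 0) {
    // Walking backwards: the first readable element is at most dim - 1.
    b = std::min(b, dim - 1);
  } else {
    // Walking forwards: clamp indices that are still negative to 0.
    b = b < 0 ? 0 : b;
  }
  return b;
}
```

With dim=10, a begin of -2 becomes 8, a begin of -20 with a positive stride clamps to 0, and a begin of 15 with a negative stride clamps to 9.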

manupak (Contributor) commented Jun 23, 2021

This looks interesting. Do we have an RFC or a tracking issue that indicates where the training support of TVM is going?

altanh (Contributor, Author) commented Jun 23, 2021

> This looks interesting. Do we have an RFC or a tracking issue that indicates where the training support of TVM is going?

There's no RFC up at the moment, but we've had some discussion threads (see https://discuss.tvm.apache.org/t/two-missing-pieces-for-training/10037 for example). I'm currently in the process of upstreaming general improvements (gradients, AD bug fixes / improvements, new training ops, other supporting changes like contrib library support) and longer term hoping to open source a proof-of-concept TVM training framework around Q3 (developed at OctoML). Would love to hear your thoughts and perhaps it would be worth opening a long term tracking/discussion thread for training related topics!

cc @tqchen

@tqchen tqchen merged commit b9d2899 into apache:main Jun 24, 2021
ylc pushed a commit to ylc/tvm that referenced this pull request Sep 29, 2021
zxy844288792 pushed a commit to zxy844288792/tvm that referenced this pull request Mar 4, 2022