Skip to content

Commit

Permalink
Check size on transformer cache
Browse files — browse the repository at this point in the history
  • Loading branch information
graemenail committed Jun 8, 2022
1 parent e88c1aa commit ca0aeb2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/models/transformer.h
Original file line number · Diff line number · Diff line change
Expand Up @@ -281,7 +281,7 @@ class Transformer : public EncoderOrDecoderBase {
// memoization propagation (short-term)
if (cache // if caching
&& cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
&& cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
&& cache_[prefix + "_keys"]->shape() == keys->shape()) { // and the underlying shape did not change
kh = cache_[prefix + "_keys"]; // then return cached tensor
}
else {
Expand All @@ -296,7 +296,7 @@ class Transformer : public EncoderOrDecoderBase {
Expr vh;
if (cache
&& cache_.count(prefix + "_values") > 0
&& cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
&& cache_[prefix + "_values"]->shape() == values->shape()) {
vh = cache_[prefix + "_values"];
} else {
auto Wv = graph_->param(prefix + "_Wv", {dimModel, dimModel}, inits::glorotUniform());
Expand Down

0 comments on commit ca0aeb2

Please sign in to comment.