Skip to content

Commit 06ec749

Browse files
committed
introduce a version of matmul which does not depend on first arg
Former-commit-id: 46c9a58
1 parent 150d4e4 commit 06ec749

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

dynet/dynet.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,8 +636,10 @@ struct Node {
636636

637637
Device* device; /**< pointer to the node, or null to inherit device from first input, or default when there is no input */
638638

639+
unsigned matmul_count; // how many matmul nodes am I an arg of?
640+
639641
protected:
640-
Node() : args(), device(default_device) {}
642+
Node() : args(), device(default_device), matmul_count(0) {}
641643
explicit Node(const std::initializer_list<VariableIndex>& a) : args(a), device(default_device) {}
642644
template <typename T>
643645
explicit Node(const T&c) : args(c.begin(), c.end()), device(default_device) {}

dynet/expr.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Expression operator+(const Expression& x, real y) { return y + x; }
5252
Expression operator-(const Expression& x, const Expression& y) { return x + (-y); }
5353
Expression operator-(real x, const Expression& y) { return Expression(y.pg, y.pg->add_function<ConstantMinusX>({y.i}, x)); }
5454
Expression operator-(const Expression& x, real y) { return -(y - x); }
55-
Expression operator*(const Expression& x, const Expression& y) { return Expression(x.pg, x.pg->add_function<MatrixMultiply>({x.i, y.i})); }
55+
Expression operator*(const Expression& x, const Expression& y) { x.pg->nodes[x.i]->matmul_count++; return Expression(x.pg, x.pg->add_function<MatrixMultiply>({x.i, y.i})); }
5656
Expression operator*(const Expression& x, float y) { return Expression(x.pg, x.pg->add_function<ConstScalarMultiply>({x.i}, y)); }
5757
Expression cmult(const Expression& x, const Expression& y) {
5858
if (x.dim().batch_size() == 1)

dynet/nodes-common.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -879,16 +879,25 @@ int MatrixMultiply::autobatch_sig(const ComputationGraph & cg, SigMap &sm) const
879879
// TODO do we want to treat different dimensions of first/second arg differently?
880880
if(dim.bd == 1) {
881881
Sig s(nt::matmul);
882-
s.add_node(args[0]);
882+
// if arg0 is likely to be shared, include it in the sig.
883+
// otherwise, include both args dims in the sig.
884+
if (cg.nodes[args[0]]->matmul_count > 2) { //TODO why 2? can we set a better number?
885+
s.add_node(args[0]); s.add_dim(cg.nodes[args[1]]->dim);
886+
} else {
887+
s.add_dim(cg.nodes[args[0]]->dim); s.add_dim(cg.nodes[args[1]]->dim);
888+
}
883889
return sm.get_idx(s);
884890
} else {
885891
return 0; // TODO handle the batched case as well? should it differ at all?
886892
}
887893
}
888894

889895
std::vector<int> MatrixMultiply::autobatch_concat(const ComputationGraph & cg) const {
890-
vector<int> ret(args.size(), 0);
891-
if (dim.bd == 1) { ret[1] = 1; }
896+
vector<int> ret(2, 0);
897+
if (dim.bd == 1) {
898+
ret[1] = 1;
899+
if (cg.nodes[args[0]]->matmul_count <= 2) { ret[0] = 1; }
900+
}
892901
return ret;
893902
}
894903

0 commit comments

Comments
 (0)