@@ -787,17 +787,17 @@ namespace {
787787
788788template <typename LeftSparseIndexType, typename RightSparseIndexType>
789789struct SparseTensorEqualsImpl {
790- static bool Compare (const SparseTensor <LeftSparseIndexType>& left,
791- const SparseTensor <RightSparseIndexType>& right) {
790+ static bool Compare (const SparseTensorImpl <LeftSparseIndexType>& left,
791+ const SparseTensorImpl <RightSparseIndexType>& right) {
792792 // TODO(mrkn): should we support the equality among different formats?
793793 return false ;
794794 }
795795};
796796
797797template <typename SparseIndexType>
798798struct SparseTensorEqualsImpl <SparseIndexType, SparseIndexType> {
799- static bool Compare (const SparseTensor <SparseIndexType>& left,
800- const SparseTensor <SparseIndexType>& right) {
799+ static bool Compare (const SparseTensorImpl <SparseIndexType>& left,
800+ const SparseTensorImpl <SparseIndexType>& right) {
801801 DCHECK (left.type ()->id () == right.type ()->id ());
802802 DCHECK (left.shape () == right.shape ());
803803 DCHECK (left.non_zero_length () == right.non_zero_length ());
@@ -821,19 +821,19 @@ struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
821821 }
822822};
823823
824- template <typename SparseTensorType >
825- inline bool SparseTensorEqualsImplDispatch (const SparseTensor<SparseTensorType >& left,
826- const SparseTensorBase & right) {
824+ template <typename SparseIndexType >
825+ inline bool SparseTensorEqualsImplDispatch (const SparseTensorImpl<SparseIndexType >& left,
826+ const SparseTensor & right) {
827827 switch (right.sparse_tensor_format_id ()) {
828828 case SparseTensorFormat::COO: {
829- const auto & right_coo = checked_cast<const SparseTensor <SparseCOOIndex>&>(right);
830- return SparseTensorEqualsImpl<SparseTensorType , SparseCOOIndex>::Compare (left,
829+ const auto & right_coo = checked_cast<const SparseTensorImpl <SparseCOOIndex>&>(right);
830+ return SparseTensorEqualsImpl<SparseIndexType , SparseCOOIndex>::Compare (left,
831831 right_coo);
832832 }
833833
834834 case SparseTensorFormat::CSR: {
835- const auto & right_csr = checked_cast<const SparseTensor <SparseCSRIndex>&>(right);
836- return SparseTensorEqualsImpl<SparseTensorType , SparseCSRIndex>::Compare (left,
835+ const auto & right_csr = checked_cast<const SparseTensorImpl <SparseCSRIndex>&>(right);
836+ return SparseTensorEqualsImpl<SparseIndexType , SparseCSRIndex>::Compare (left,
837837 right_csr);
838838 }
839839
@@ -844,7 +844,7 @@ inline bool SparseTensorEqualsImplDispatch(const SparseTensor<SparseTensorType>&
844844
845845} // namespace
846846
847- bool SparseTensorEquals (const SparseTensorBase & left, const SparseTensorBase & right) {
847+ bool SparseTensorEquals (const SparseTensor & left, const SparseTensor & right) {
848848 if (&left == &right) {
849849 return true ;
850850 } else if (left.type ()->id () != right.type ()->id ()) {
@@ -859,12 +859,12 @@ bool SparseTensorEquals(const SparseTensorBase& left, const SparseTensorBase& ri
859859
860860 switch (left.sparse_tensor_format_id ()) {
861861 case SparseTensorFormat::COO: {
862- const auto & left_coo = checked_cast<const SparseTensor <SparseCOOIndex>&>(left);
862+ const auto & left_coo = checked_cast<const SparseTensorImpl <SparseCOOIndex>&>(left);
863863 return SparseTensorEqualsImplDispatch (left_coo, right);
864864 }
865865
866866 case SparseTensorFormat::CSR: {
867- const auto & left_csr = checked_cast<const SparseTensor <SparseCSRIndex>&>(left);
867+ const auto & left_csr = checked_cast<const SparseTensorImpl <SparseCSRIndex>&>(left);
868868 return SparseTensorEqualsImplDispatch (left_csr, right);
869869 }
870870
0 commit comments