
[SPARK-45926][SQL] Implement equals and hashCode taking pushed runtime filters into account in InMemoryTable-related scans #49153

Open · wants to merge 5 commits into master
@@ -463,14 +463,15 @@ abstract class InMemoryBaseTable(
tableSchema: StructType,
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {

private var allFilters: Set[Filter] = Set.empty
override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
.filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
}

override def filter(filters: Array[Filter]): Unit = {
allFilters = allFilters ++ filters.toSet
if (partitioning.length == 1 && partitioning.head.references().length == 1) {
val ref = partitioning.head.references().head
filters.foreach {
@@ -491,6 +492,16 @@ abstract class InMemoryBaseTable(
}
}
}

override def equals(other: Any): Boolean = other match {
  case imbs: InMemoryBatchScan => this.readSchema == imbs.readSchema &&
    this.tableSchema == imbs.tableSchema && this.allFilters == imbs.allFilters
  case _ => false
}

override def hashCode: Int = Objects.hashCode(this.readSchema, this.tableSchema,
  this.allFilters)
}

Review comment from @li-boxuan on the equals override:

Hi @ahshahid, I am not from Databricks, and unfortunately I am not familiar enough with the Spark code to review this PR, but I do run into problems with equals overrides in Spark in general, so I am dropping my point here to see if you have any ideas.

Is there a reason you don't add

case imbs: InMemoryBatchScan => this.getClass == imbs.getClass && ...

to the equals check? Would this have a negative implication for your use case?

The reason I want a stricter equals check is the project https://github.com/apache/incubator-gluten. In Gluten, which is a middle layer between Spark and native engines, (most) operators inherit Spark operators. If we don't do a class-equivalence check here, a Spark operator and a Gluten (native) operator would be regarded as equal.

Reply from the author (@ahshahid):

Hi @li-boxuan,
In most cases the Spark classes are Scala case classes, and it is usually not recommended to override equals and hashCode of case classes unless there is a specific requirement (for example, excluding certain member fields from the equality check, or doing special handling).
As for why there is no

case imbs: InMemoryBatchScan => this.getClass == imbs.getClass &&

it is because that is already accomplished: the "this" instance is an InMemoryBatchScan (we are in its equals method), and the pattern "case imbs: InMemoryBatchScan" ensures that "other" is also of type InMemoryBatchScan, which accomplishes what you are hinting at.

Regarding the issue you are hitting: that would be possible only where the Spark operator classes are not case classes (and there are only some exceptions), and case classes cannot be extended, right?

If you can provide more details on where the match happens incorrectly, that may shed more light on the issue.

Reply from @li-boxuan:

"that would be possible only where the Spark operator classes are not case classes (and there are only some exceptions), and case classes cannot be extended, right?"

Aha, thanks. I am new to Scala and did not notice the difference between case classes and normal classes. You are right: this is a case class, which presumably should not be extended, even though the Scala compiler does not prohibit you from doing that. Gluten, as a downstream application, uses a plugin to explicitly prohibit that behavior, so we are good; not an issue then. I did see an issue with some other classes, but it did not happen with a case class (thanks to the plugin).

That being said, it is probably safe to add such a check, because after all, Scala does not prohibit one from (wrongly) inheriting from a case class. But anyway, it is not an issue for my use case.

Reply from the author (@ahshahid):

I did not know that this rule of case classes not being inheritable could be circumvented. Thanks.
I suppose the committers can think it over; I am neither with Databricks nor a committer.
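The getClass-vs-pattern-match point raised above can be sketched outside of Spark. The classes below are hypothetical stand-ins, not Spark's: a type-test pattern in equals accepts subclass instances (the Gluten situation the reviewer describes), while a getClass comparison rejects them.

```scala
// Hypothetical stand-ins for a Spark scan and a subclassing native operator.
class Scan(val schema: String) {
  override def equals(other: Any): Boolean = other match {
    // Type test: any Scan, including subclasses, with the same schema matches.
    case s: Scan => this.schema == s.schema
    case _ => false
  }
  override def hashCode: Int = schema.hashCode
}

class NativeScan(schema: String) extends Scan(schema) // e.g. a Gluten operator

object EqualsDemo extends App {
  val spark = new Scan("a,b")
  val native = new NativeScan("a,b")

  // With the type-test equals, the subclass instance is considered equal:
  assert(spark == native)

  // A stricter check, as the reviewer suggests, rejects it:
  def strictEquals(x: Scan, y: Scan): Boolean =
    x.getClass == y.getClass && x.schema == y.schema
  assert(!strictEquals(spark, native))
}
```

For a Scala case class that is never subclassed, the two checks coincide, which is the author's point; the difference only surfaces when a plain class (or a wrongly inherited case class) is extended.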

abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo)
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.catalog

import java.util

import com.google.common.base.Objects
import org.scalatest.Assertions.assert

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform}
@@ -65,13 +66,16 @@ class InMemoryTableWithV2Filter(
options: CaseInsensitiveStringMap)
extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {

private var allFilters: Set[Predicate] = Set.empty

override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
partitioning.flatMap(_.references)
.filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
}

override def filter(filters: Array[Predicate]): Unit = {
allFilters = allFilters ++ filters.toSet
if (partitioning.length == 1 && partitioning.head.references().length == 1) {
val ref = partitioning.head.references().head
filters.foreach {
@@ -90,6 +94,16 @@
}
}
}

override def equals(other: Any): Boolean = other match {
  case imbs: InMemoryV2FilterBatchScan => this.readSchema == imbs.readSchema &&
    this.tableSchema == imbs.tableSchema && this.allFilters == imbs.allFilters
  case _ => false
}

override def hashCode: Int = Objects.hashCode(this.readSchema, this.tableSchema,
  this.allFilters)
}
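The reason both scans now accumulate pushed filters into allFilters and compare them in equals can be sketched with a simplified stand-in (a hypothetical class, not Spark's): two scans over the same schema are interchangeable only until different runtime filters are pushed into them.

```scala
// Hypothetical simplified scan: filters are strings instead of Spark's
// Filter/Predicate types, but the equals/hashCode shape mirrors the PR.
class RuntimeFilteredScan(val readSchema: String) {
  private var allFilters: Set[String] = Set.empty

  // Mirrors SupportsRuntimeFiltering.filter: accumulate pushed filters.
  def filter(filters: Array[String]): Unit =
    allFilters = allFilters ++ filters.toSet

  override def equals(other: Any): Boolean = other match {
    case s: RuntimeFilteredScan =>
      this.readSchema == s.readSchema && this.allFilters == s.allFilters
    case _ => false
  }
  override def hashCode: Int = (readSchema, allFilters).hashCode
}

object FilterEqualityDemo extends App {
  val a = new RuntimeFilteredScan("id,name")
  val b = new RuntimeFilteredScan("id,name")
  assert(a == b)                  // identical before any pushdown
  assert(a.hashCode == b.hashCode)

  a.filter(Array("id IN (1, 2)")) // runtime filter pushed into one scan only
  assert(a != b)                  // scans are no longer interchangeable
}
```

Without allFilters in equals, the two scans above would still compare equal after the pushdown, which is exactly the reuse bug the PR title describes.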

override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {