Skip to content

Predicate join operation #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ public fun <A, B> DataFrame<A>.join(
other: DataFrame<B>,
type: JoinType = JoinType.Inner,
selector: JoinColumnsSelector<A, B>? = null
): DataFrame<A> = joinImpl(other, type, true, selector)
): DataFrame<A> {
return joinImpl(other, type, addNewColumns = type.addNewColumns, selector)
}

public fun <A, B> DataFrame<A>.join(
other: DataFrame<B>,
Expand Down Expand Up @@ -116,10 +118,17 @@ public enum class JoinType {
Left, // all data from left data frame, nulls for mismatches in right data frame
Right, // all data from right data frame, nulls for mismatches in left data frame
Inner, // only matched data from right and left data frame
Filter, // only matched data from left data frame
Full, // all data from left and from right data frame, nulls for any mismatches
Exclude // mismatched rows from left data frame
}

/**
 * Whether a join of this type appends the columns of the right data frame to the result.
 *
 * [JoinType.Filter] and [JoinType.Exclude] only select rows of the left data frame, so they add nothing;
 * every other join type does.
 */
internal val JoinType.addNewColumns: Boolean
    get() = when (this) {
        JoinType.Left, JoinType.Right, JoinType.Inner, JoinType.Full -> true
        JoinType.Filter, JoinType.Exclude -> false
    }

/**
 * `true` for [JoinType.Right] and [JoinType.Full], `false` for every other join type.
 */
public val JoinType.allowLeftNulls: Boolean
    get() = when (this) {
        JoinType.Right, JoinType.Full -> true
        JoinType.Left, JoinType.Inner, JoinType.Filter, JoinType.Exclude -> false
    }

/**
 * `true` for [JoinType.Left], [JoinType.Full] and [JoinType.Exclude], `false` for every other join type.
 */
public val JoinType.allowRightNulls: Boolean
    get() = when (this) {
        JoinType.Left, JoinType.Full, JoinType.Exclude -> true
        JoinType.Right, JoinType.Inner, JoinType.Filter -> false
    }
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.jetbrains.kotlinx.dataframe.api

import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.Selector
import org.jetbrains.kotlinx.dataframe.impl.api.joinWithImpl

/**
 * A pair of rows considered together by a predicate join.
 *
 * The object itself is a [DataRow] of the left schema [A]; the row of the right schema [B]
 * it is paired with is available through [right].
 */
public interface JoinedDataRow<out A, out B> : DataRow<A> {
/** The right-hand row of the pair. */
public val right: DataRow<B>
}

/**
 * Join predicate: a [Selector] evaluated on a [JoinedDataRow] that yields a [Boolean].
 *
 * The `joinWith` family treats a `true` result as "this left row and this right row match".
 */
public typealias JoinExpression<A, B> = Selector<JoinedDataRow<A, B>, Boolean>

/**
 * Joins this data frame with [right] using an arbitrary predicate instead of key columns.
 *
 * [joinExpression] decides, for a pair of left and right rows, whether they match.
 * Whether the columns of [right] are appended to the result is derived from [type]
 * (see `JoinType.addNewColumns`).
 *
 * @param right the data frame to join with.
 * @param type the kind of join to perform; defaults to [JoinType.Inner].
 * @param joinExpression predicate evaluated on pairs of rows.
 * @return the joined data frame.
 */
public fun <A, B> DataFrame<A>.joinWith(
    right: DataFrame<B>,
    type: JoinType = JoinType.Inner,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> = joinWithImpl(
    right = right,
    type = type,
    addNewColumns = type.addNewColumns,
    joinExpression = joinExpression
)

/**
 * Shorthand for [joinWith] with [JoinType.Inner].
 *
 * @param right the data frame to join with.
 * @param joinExpression predicate evaluated on pairs of rows.
 */
public fun <A, B> DataFrame<A>.innerJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> = joinWith(
    right = right,
    type = JoinType.Inner,
    joinExpression = joinExpression
)

/**
 * Shorthand for [joinWith] with [JoinType.Left].
 *
 * @param right the data frame to join with.
 * @param joinExpression predicate evaluated on pairs of rows.
 */
public fun <A, B> DataFrame<A>.leftJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> = joinWith(
    right = right,
    type = JoinType.Left,
    joinExpression = joinExpression
)

/**
 * Shorthand for [joinWith] with [JoinType.Right].
 *
 * @param right the data frame to join with.
 * @param joinExpression predicate evaluated on pairs of rows.
 */
public fun <A, B> DataFrame<A>.rightJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> = joinWith(
    right = right,
    type = JoinType.Right,
    joinExpression = joinExpression
)

/**
 * Shorthand for [joinWith] with [JoinType.Full].
 *
 * @param right the data frame to join with.
 * @param joinExpression predicate evaluated on pairs of rows.
 */
public fun <A, B> DataFrame<A>.fullJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> = joinWith(
    right = right,
    type = JoinType.Full,
    joinExpression = joinExpression
)

/**
 * Keeps the rows of this data frame for which [joinExpression] matches a row of [right].
 * No columns of [right] are added to the result.
 *
 * The join is performed as [JoinType.Filter], the join type dedicated to this operation, so the call is
 * equivalent to `joinWith(right, JoinType.Filter, joinExpression)`.
 *
 * Note that a left row is emitted once per matching right row, so a row that matches several rows of
 * [right] appears several times in the result.
 *
 * @param right the data frame whose rows are tested against the rows of this data frame.
 * @param joinExpression predicate evaluated on pairs of rows.
 */
public fun <A, B> DataFrame<A>.filterJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> = joinWithImpl(right, JoinType.Filter, addNewColumns = false, joinExpression)

/**
 * Keeps only the rows of this data frame for which [joinExpression] matches no row of [right].
 * No columns of [right] are added to the result.
 *
 * @param right the data frame whose rows are tested against the rows of this data frame.
 * @param joinExpression predicate evaluated on pairs of rows.
 */
public fun <A, B> DataFrame<A>.excludeJoinWith(
    right: DataFrame<B>,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> = joinWithImpl(
    right = right,
    type = JoinType.Exclude,
    addNewColumns = false,
    joinExpression = joinExpression
)
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ import org.jetbrains.kotlinx.dataframe.type
import kotlin.reflect.full.withNullability

internal fun <A, B> defaultJoinColumns(left: DataFrame<A>, right: DataFrame<B>): JoinColumnsSelector<A, B> =
{ left.columnNames().intersect(right.columnNames()).map { it.toColumnAccessor() }.let { ColumnsList(it) } }
{ left.columnNames().intersect(right.columnNames().toSet()).map { it.toColumnAccessor() }.let { ColumnsList(it) } }

internal fun <T> defaultJoinColumns(dataFrames: Iterable<DataFrame<T>>): JoinColumnsSelector<T, T> =
{
dataFrames.map { it.columnNames() }.fold<List<String>, Set<String>?>(null) { set, names ->
set?.intersect(names) ?: names.toSet()
set?.intersect(names.toSet()) ?: names.toSet()
}.orEmpty().map { it.toColumnAccessor() }.let { ColumnsList(it) }
}

Expand Down Expand Up @@ -114,7 +114,7 @@ internal fun <A, B> DataFrame<A>.joinImpl(

// group row indices by key from right data frame
val groupedRight = when (joinType) {
JoinType.Exclude -> rightJoinKeyToIndex.map { it.first to emptyList<Int>() }.toMap()
JoinType.Exclude -> rightJoinKeyToIndex.associate { it.first to emptyList<Int>() }
else -> rightJoinKeyToIndex.groupBy({ it.first }) { it.second }
}

Expand All @@ -129,7 +129,7 @@ internal fun <A, B> DataFrame<A>.joinImpl(
}

// for every row index in right data frame store a flag indicating whether this row was matched by some row in left data frame
val rightMatched = Array(other.nrow) { false }
val rightMatched = BooleanArray(other.nrow) { false }

// number of rows in right data frame that were not matched by any row in left data frame. Used for correct allocation of an output array
var rightUnmatchedCount = other.nrow
Expand Down Expand Up @@ -162,8 +162,8 @@ internal fun <A, B> DataFrame<A>.joinImpl(
val newRightColumnsCount = newRightColumns.size
val outputColumnsCount = leftColumnsCount + newRightColumnsCount

val outputData = Array<Array<Any?>>(outputColumnsCount) { arrayOfNulls(outputRowsCount) }
val hasNulls = Array(outputColumnsCount) { false }
val outputData = List<Array<Any?>>(outputColumnsCount) { arrayOfNulls(outputRowsCount) }
val hasNulls = BooleanArray(outputColumnsCount) { false }

var row = 0

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.jetbrains.kotlinx.dataframe.impl.api

import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.api.JoinExpression
import org.jetbrains.kotlinx.dataframe.api.JoinType
import org.jetbrains.kotlinx.dataframe.api.JoinedDataRow
import org.jetbrains.kotlinx.dataframe.api.allowLeftNulls
import org.jetbrains.kotlinx.dataframe.api.allowRightNulls
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.count
import org.jetbrains.kotlinx.dataframe.api.indices
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator
import org.jetbrains.kotlinx.dataframe.impl.DataRowImpl

/**
 * [JoinedDataRow] implementation pairing one row of a left data frame with one row of a right data frame.
 *
 * The instance itself acts as the left row: its [DataRowImpl] superclass is constructed from
 * [index] and [leftOwner]. The right row is a separate [DataRowImpl] exposed through [right].
 *
 * @param leftOwner data frame the left row belongs to.
 * @property index index of the left row in [leftOwner].
 * @param rightOwner data frame the right row belongs to.
 * @param index1 index of the right row in [rightOwner].
 */
internal class JoinedDataRowImpl<A, B>(
leftOwner: DataFrame<A>,
val index: Int,
rightOwner: DataFrame<B>,
index1: Int
) : JoinedDataRow<A, B>, DataRowImpl<A>(index, leftOwner) {
// Row `index1` of `rightOwner`, created once when the pair is constructed.
override val right: DataRow<B> = DataRowImpl(index1, rightOwner)
}

/**
 * Joins this (left) data frame with [right] by evaluating [joinExpression] on every pair of a left row
 * and a right row (nested loop).
 *
 * Output rows are produced as follows:
 *  - every matching pair yields one row holding the left values, followed by the right values when
 *    [addNewColumns] is `true`. For [JoinType.Exclude] a match produces no row; it only marks the
 *    left row as matched and stops scanning [right] for that row;
 *  - a left row without any match is kept when [type] has [allowRightNulls]; its right-hand columns
 *    (if present) are filled with nulls;
 *  - when [type] has [allowLeftNulls], every right row that was never matched is appended with nulls
 *    in the left-hand columns.
 *
 * Output column names are the left column names followed, when [addNewColumns] is `true`, by the right
 * column names made unique through [ColumnNameGenerator]. Column types are inferred from the collected values.
 *
 * Note: the last step writes into the right-hand output columns, which exist only when [addNewColumns]
 * is `true`; callers must therefore pass `addNewColumns = true` for join types that have [allowLeftNulls].
 *
 * @param right the right data frame.
 * @param type the kind of join; defaults to [JoinType.Inner].
 * @param addNewColumns whether the columns of [right] are appended to the result.
 * @param joinExpression predicate deciding whether a pair of rows matches.
 * @return the joined data frame, cast to the left schema.
 */
internal fun <A, B> DataFrame<A>.joinWithImpl(
    right: DataFrame<B>,
    type: JoinType = JoinType.Inner,
    addNewColumns: Boolean,
    joinExpression: JoinExpression<A, B>
): DataFrame<A> {
    val generator = ColumnNameGenerator(columnNames())
    if (addNewColumns) {
        right.columnNames().forEach { generator.addUnique(it) }
    }

    val leftColumnsCount = columnsCount()
    val rightColumnsCount = if (addNewColumns) right.columnsCount() else 0

    // One growing value list per output column: left columns first, then (optionally) right columns.
    val outputData = List(leftColumnsCount + rightColumnsCount) { mutableListOf<Any?>() }

    // rightMatched[r] becomes true once right row r has been matched by any left row.
    val rightMatched = BooleanArray(right.count())

    for (l in indices()) {
        // Values of the left row are needed at most once per left row, not once per match:
        // materialize them lazily and reuse them for every matching right row.
        val leftValues by lazy(LazyThreadSafetyMode.NONE) { get(l).values() }
        var leftMatched = false

        for (r in right.indices()) {
            val joined = JoinedDataRowImpl(this, l, right, r)
            if (!joinExpression(joined, joined)) continue

            leftMatched = true
            // Exclude only needs to know that a match exists; nothing is emitted for it.
            if (type == JoinType.Exclude) break

            rightMatched[r] = true
            leftValues.forEachIndexed { col, value -> outputData[col].add(value) }
            if (addNewColumns) {
                val offset = leftValues.size
                right[r].values().forEachIndexed { col, value -> outputData[offset + col].add(value) }
            }
        }

        if (!leftMatched && type.allowRightNulls) {
            leftValues.forEachIndexed { col, value -> outputData[col].add(value) }
            if (addNewColumns) {
                // No right row matched: pad all right-hand columns with nulls.
                for (col in leftValues.size..outputData.lastIndex) {
                    outputData[col].add(null)
                }
            }
        }
    }

    if (type.allowLeftNulls) {
        for (r in rightMatched.indices) {
            if (rightMatched[r]) continue

            // Right row never matched: nulls on the left side, its own values on the right side.
            repeat(leftColumnsCount) { col -> outputData[col].add(null) }
            right[r].values().forEachIndexed { col, value -> outputData[leftColumnsCount + col].add(value) }
        }
    }

    val df: DataFrame<*> = outputData.mapIndexed { index, values ->
        DataColumn.createWithTypeInference(generator.names[index], values)
    }.toDataFrame()

    return df.cast()
}
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public fun <T> DataFrame<T>.html(): String = toStandaloneHTML().toString()
public fun <T> DataFrame<T>.toStandaloneHTML(
configuration: DisplayConfiguration = DisplayConfiguration.DEFAULT,
cellRenderer: CellRenderer = org.jetbrains.kotlinx.dataframe.jupyter.DefaultCellRenderer,
getFooter: (DataFrame<T>) -> String = { "DataFrame [${it.size}]" },
getFooter: (DataFrame<T>) -> String? = { "DataFrame [${it.size}]" },
): DataFrameHtmlData = toHTML(configuration, cellRenderer, getFooter).withTableDefinitions()

/**
Expand All @@ -192,25 +192,31 @@ public fun <T> DataFrame<T>.toStandaloneHTML(
public fun <T> DataFrame<T>.toHTML(
configuration: DisplayConfiguration = DisplayConfiguration.DEFAULT,
cellRenderer: CellRenderer = org.jetbrains.kotlinx.dataframe.jupyter.DefaultCellRenderer,
getFooter: (DataFrame<T>) -> String = { "DataFrame [${it.size}]" },
getFooter: (DataFrame<T>) -> String? = { "DataFrame [${it.size}]" },
): DataFrameHtmlData {
val limit = configuration.rowsLimit ?: Int.MAX_VALUE

val footer = getFooter(this)
val bodyFooter = buildString {
val openPTag = "<p class=\"dataframe_description\">"
if (limit < nrow) {
val bodyFooter = footer?.let {
buildString {
val openPTag = "<p class=\"dataframe_description\">"
if (limit < nrow) {
append(openPTag)
append("... showing only top $limit of $nrow rows</p>")
}
append(openPTag)
append("... showing only top $limit of $nrow rows</p>")
append(footer)
append("</p>")
}
append(openPTag)
append(footer)
append("</p>")
}

val tableHtml = toHtmlData(configuration, cellRenderer)
var tableHtml = toHtmlData(configuration, cellRenderer)

if (bodyFooter != null) {
tableHtml += DataFrameHtmlData("", bodyFooter, "")
}

return tableHtml + DataFrameHtmlData("", bodyFooter, "")
return tableHtml
}

/**
Expand Down
Loading