Skip to content

Improve dataframe sorting in KTNB UI by handling non-comparable columns #836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.jupyter
import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.api.Convert
import org.jetbrains.kotlinx.dataframe.api.FormatClause
import org.jetbrains.kotlinx.dataframe.api.FormattedFrame
Expand All @@ -25,12 +26,13 @@ import org.jetbrains.kotlinx.dataframe.api.Update
import org.jetbrains.kotlinx.dataframe.api.at
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.frames
import org.jetbrains.kotlinx.dataframe.api.getColumn
import org.jetbrains.kotlinx.dataframe.api.into
import org.jetbrains.kotlinx.dataframe.api.sortBy
import org.jetbrains.kotlinx.dataframe.api.isComparable
import org.jetbrains.kotlinx.dataframe.api.sortWith
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.api.values
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator

/**
Expand Down Expand Up @@ -62,6 +64,7 @@ public object KotlinNotebookPluginUtils {

/**
* Sorts a dataframe-like object by multiple columns.
* If a column type is not comparable, sorting by string representation is applied instead.
*
* @param dataFrameLike The dataframe-like object to sort.
* @param columnPaths The list of columns to sort by. Each element in the list represents a column path
Expand All @@ -79,27 +82,78 @@ public object KotlinNotebookPluginUtils {
}

/**
* Sorts the given data frame by the specified columns.
* Sorts a dataframe by multiple columns with specified sorting order for each column.
* If a column type is not comparable, sorting by string representation is applied instead.
*
* @param df The data frame to be sorted.
* @param columnPaths The paths of the columns to be sorted. Each path is represented as a list of strings.
* @param isDesc A list of booleans indicating whether each column should be sorted in descending order.
* The size of this list must be equal to the size of the columnPaths list.
* @return The sorted data frame.
* @param df The dataframe to be sorted.
* @param columnPaths A list of column paths where each path is a list of strings representing the hierarchical path of the column.
* @param isDesc A list of boolean values indicating whether each column should be sorted in descending order;
* true for descending, false for ascending. The size of this list should match the size of `columnPaths`.
* @return The sorted dataframe.
*/
public fun sortByColumns(df: AnyFrame, columnPaths: List<List<String>>, isDesc: List<Boolean>): AnyFrame =
df.sortBy {
require(columnPaths.all { it.isNotEmpty() })
require(columnPaths.size == isDesc.size)
public fun sortByColumns(df: AnyFrame, columnPaths: List<List<String>>, isDesc: List<Boolean>): AnyFrame {
require(columnPaths.all { it.isNotEmpty() })
require(columnPaths.size == isDesc.size)

val sortKeys = columnPaths.map { path ->
ColumnPath(path)
}

val comparator = createComparator(sortKeys, isDesc)

val sortKeys = columnPaths.map { path ->
ColumnPath(path)
return df.sortWith(comparator)
}

private fun createComparator(sortKeys: List<ColumnPath>, isDesc: List<Boolean>): Comparator<DataRow<*>> {
return Comparator { row1, row2 ->
for ((key, desc) in sortKeys.zip(isDesc)) {
val comparisonResult = if (row1.df().getColumn(key).isComparable()) {
compareComparableValues(row1, row2, key, desc)
} else {
compareStringValues(row1, row2, key, desc)
}
// If a comparison result is non-zero, we have resolved the ordering
if (comparisonResult != 0) return@Comparator comparisonResult
}
// All comparisons are equal
0
}
}

(sortKeys zip isDesc).map { (key, desc) ->
if (desc) key.desc() else key
}.toColumnSet()
@Suppress("UNCHECKED_CAST")
private fun compareComparableValues(
row1: DataRow<*>,
row2: DataRow<*>,
key: ColumnPath,
desc: Boolean,
): Int {
val firstValue = row1.getValueOrNull(key) as Comparable<Any?>?
val secondValue = row2.getValueOrNull(key) as Comparable<Any?>?

return when {
firstValue == null && secondValue == null -> 0
firstValue == null -> if (desc) 1 else -1
secondValue == null -> if (desc) -1 else 1
desc -> secondValue.compareTo(firstValue)
else -> firstValue.compareTo(secondValue)
}
}

private fun compareStringValues(
row1: DataRow<*>,
row2: DataRow<*>,
key: ColumnPath,
desc: Boolean,
): Int {
val firstValue = (row1.getValueOrNull(key)?.toString() ?: "")
val secondValue = (row2.getValueOrNull(key)?.toString() ?: "")

return if (desc) {
secondValue.compareTo(firstValue)
} else {
firstValue.compareTo(secondValue)
}
}

internal fun isDataframeConvertable(dataframeLike: Any?): Boolean =
when (dataframeLike) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,206 @@ class RenderingTests : JupyterReplTestCase() {
}
}

@Test
fun `test sortByColumns by int column`() {
val json = executeScriptAndParseDataframeResult(
"""
val df = dataFrameOf("nums")(5, 4, 3, 2, 1)
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("nums")), listOf(false))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

val rows = json[KOTLIN_DATAFRAME]!!.jsonArray
json.extractColumn<Int>(0, "nums") shouldBe 1
json.extractColumn<Int>(rows.size - 1, "nums") shouldBe 5
}

internal inline fun <reified T> JsonObject.extractColumn(index: Int, fieldName: String): T {
val element = this[KOTLIN_DATAFRAME]!!.jsonArray[index].jsonObject[fieldName]!!.jsonPrimitive
return when (T::class) {
String::class -> element.content as T
Int::class -> element.int as T
else -> throw IllegalArgumentException("Unsupported type")
}
}

@Test
fun `test sortByColumns by multiple int columns`() {
val json = executeScriptAndParseDataframeResult(
"""
data class Row(val a: Int, val b: Int)
val df = listOf(Row(1, 1), Row(1, 2), Row(2, 3), Row(2, 4), Row(3, 5), Row(3, 6)).toDataFrame()
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("a"), listOf("b")), listOf(true, false))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<Int>(0, "a") shouldBe 3
json.extractColumn<Int>(0, "b") shouldBe 5
json.extractColumn<Int>(5, "a") shouldBe 1
json.extractColumn<Int>(5, "b") shouldBe 2
}

@Test
fun `test sortByColumns by single string column`() {
val json = executeScriptAndParseDataframeResult(
"""
val df = dataFrameOf("letters")("e", "d", "c", "b", "a")
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("letters")), listOf(true))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<String>(0, "letters") shouldBe "e"
json.extractColumn<String>(4, "letters") shouldBe "a"
}

@Test
fun `test sortByColumns by multiple string columns`() {
val json = executeScriptAndParseDataframeResult(
"""
data class Row(val first: String, val second: String)
val df = listOf(Row("a", "b"), Row("a", "a"), Row("b", "b"), Row("b", "a")).toDataFrame()
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("first"), listOf("second")), listOf(false, true))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<String>(0, "first") shouldBe "a"
json.extractColumn<String>(0, "second") shouldBe "b"
json.extractColumn<String>(3, "first") shouldBe "b"
json.extractColumn<String>(3, "second") shouldBe "a"
}

@Test
fun `test sortByColumns by mix of int and string columns`() {
val json = executeScriptAndParseDataframeResult(
"""
data class Row(val num: Int, val letter: String)
val df = listOf(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b"), Row(3, "a")).toDataFrame()
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("num"), listOf("letter")), listOf(true, false))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<Int>(0, "num") shouldBe 3
json.extractColumn<String>(0, "letter") shouldBe "a"
json.extractColumn<Int>(4, "num") shouldBe 1
json.extractColumn<String>(4, "letter") shouldBe "b"
}

@Test
fun `test sortByColumns by multiple non-comparable column`() {
val json = executeScriptAndParseDataframeResult(
"""
data class Person(val name: String, val age: Int) {
override fun toString(): String {
return age.toString()
}
}
val df = dataFrameOf("urls", "person")(
URL("https://example.com/a"), Person("Alice", 10),
URL("https://example.com/b"), Person("Bob", 11),
URL("https://example.com/a"), Person("Nick", 12),
URL("https://example.com/b"), Person("Guy", 13),
)
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("urls"), listOf("person")), listOf(false, true))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<Int>(0, "person") shouldBe 12
json.extractColumn<Int>(3, "person") shouldBe 11
}

@Test
fun `test sortByColumns by mix of comparable and non-comparable columns`() {
val json = executeScriptAndParseDataframeResult(
"""
val df = dataFrameOf("urls", "id")(
URL("https://example.com/a"), 1,
URL("https://example.com/b"), 2,
URL("https://example.com/a"), 2,
URL("https://example.com/b"), 1,
)
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("id"), listOf("urls")), listOf(true, true))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<String>(0, "urls") shouldBe "https://example.com/b"
json.extractColumn<Int>(0, "id") shouldBe 2
json.extractColumn<String>(3, "urls") shouldBe "https://example.com/a"
json.extractColumn<Int>(3, "id") shouldBe 1
}

@Test
fun `test sortByColumns by url column`() {
val json = executeScriptAndParseDataframeResult(
"""
val df = dataFrameOf("urls")(
URL("https://example.com/a"),
URL("https://example.com/c"),
URL("https://example.com/b"),
URL("https://example.com/d")
)
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("urls")), listOf(false))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<String>(0, "urls") shouldBe "https://example.com/a"
json.extractColumn<String>(1, "urls") shouldBe "https://example.com/b"
json.extractColumn<String>(2, "urls") shouldBe "https://example.com/c"
json.extractColumn<String>(3, "urls") shouldBe "https://example.com/d"
}

@Test
fun `test sortByColumns by column group children`() {
val json = executeScriptAndParseDataframeResult(
"""
val df = dataFrameOf(
"a" to listOf(5, 4, 3, 2, 1),
"b" to listOf(1, 2, 3, 4, 5)
)
val res = KotlinNotebookPluginUtils.sortByColumns(df.group("a", "b").into("c"), listOf(listOf("c", "a")), listOf(false))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

fun JsonObject.extractBFields(): List<Int> {
val dataframe = this[KOTLIN_DATAFRAME]!!.jsonArray
return dataframe.map { it.jsonObject["c"]!!.jsonObject["data"]!!.jsonObject["b"]!!.jsonPrimitive.int }
}

val bFields = json.extractBFields()
bFields shouldBe listOf(5, 4, 3, 2, 1)
}

@Test
fun `test sortByColumns for column that contains string and int`() {
val json = executeScriptAndParseDataframeResult(
"""
val df = dataFrameOf("mixed")(
5,
"10",
2,
"4",
"1"
)
val res = KotlinNotebookPluginUtils.sortByColumns(df, listOf(listOf("mixed")), listOf(true))
KotlinNotebookPluginUtils.convertToDataFrame(res)
""".trimIndent(),
)

json.extractColumn<String>(0, "mixed") shouldBe "5"
json.extractColumn<String>(1, "mixed") shouldBe "4"
json.extractColumn<String>(2, "mixed") shouldBe "2"
json.extractColumn<String>(3, "mixed") shouldBe "10"
json.extractColumn<String>(4, "mixed") shouldBe "1"
}

companion object {
/**
* Set the system property for the IDE version needed for specific serialization testing purposes.
Expand Down
Loading
Loading