Skip to content

Commit ec4ec4f

Browse files
committed
Make sortBy(ColumnReference) accept pathOf without extra cast
1 parent f5a06eb commit ec4ec4f

File tree

4 files changed

+671
-4
lines changed

4 files changed

+671
-4
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public fun <T, C> DataFrame<T>.sortBy(columns: SortColumnsSelector<T, C>): DataF
107107
UnresolvedColumnsPolicy.Fail, columns
108108
)
109109

110-
public fun <T> DataFrame<T>.sortBy(vararg cols: ColumnReference<Comparable<*>?>): DataFrame<T> =
110+
public fun <T> DataFrame<T>.sortBy(vararg cols: ColumnReference<*>): DataFrame<T> =
111111
sortBy { cols.toColumnSet() }
112112

113113
public fun <T> DataFrame<T>.sortBy(vararg cols: String): DataFrame<T> = sortBy { cols.toColumnSet() }
@@ -132,7 +132,7 @@ public fun <T, C> DataFrame<T>.sortByDesc(vararg columns: KProperty<Comparable<C
132132

133133
public fun <T> DataFrame<T>.sortByDesc(vararg columns: String): DataFrame<T> = sortByDesc { columns.toColumnSet() }
134134

135-
public fun <T, C> DataFrame<T>.sortByDesc(vararg columns: ColumnReference<Comparable<C>?>): DataFrame<T> =
135+
public fun <T> DataFrame<T>.sortByDesc(vararg columns: ColumnReference<*>): DataFrame<T> =
136136
sortByDesc { columns.toColumnSet() }
137137

138138
// endregion
@@ -141,7 +141,7 @@ public fun <T, C> DataFrame<T>.sortByDesc(vararg columns: ColumnReference<Compar
141141

142142
public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: String): GroupBy<T, G> = sortBy { cols.toColumnSet() }
143143

144-
public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: ColumnReference<Comparable<*>?>): GroupBy<T, G> =
144+
public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: ColumnReference<*>): GroupBy<T, G> =
145145
sortBy { cols.toColumnSet() }
146146

147147
public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: KProperty<Comparable<*>?>): GroupBy<T, G> =
@@ -151,7 +151,7 @@ public fun <T, G, C> GroupBy<T, G>.sortBy(selector: SortColumnsSelector<G, C>):
151151

152152
public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: String): GroupBy<T, G> = sortByDesc { cols.toColumnSet() }
153153

154-
public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: ColumnReference<Comparable<*>?>): GroupBy<T, G> =
154+
public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: ColumnReference<*>): GroupBy<T, G> =
155155
sortByDesc { cols.toColumnSet() }
156156

157157
public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: KProperty<Comparable<*>?>): GroupBy<T, G> =

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package org.jetbrains.kotlinx.dataframe.api
22

3+
import io.kotest.assertions.throwables.shouldThrowMessage
34
import io.kotest.matchers.shouldBe
45
import org.jetbrains.kotlinx.dataframe.DataColumn
6+
import org.jetbrains.kotlinx.dataframe.io.readDataFrame
57
import org.jetbrains.kotlinx.dataframe.nrow
8+
import org.jetbrains.kotlinx.dataframe.testResource
9+
import org.jetbrains.kotlinx.dataframe.testSets.*
10+
import org.jetbrains.kotlinx.dataframe.testSets.DsSalaries
611
import org.junit.Test
712

813
class SortDataColumn {
@@ -67,4 +72,27 @@ class SortDataColumn {
6772
col.sortWith { df1, df2 -> df1[a] - df2[a] } shouldBe sortedCol
6873
col.sortWith(compareBy { it[a] }) shouldBe sortedCol
6974
}
75+
76+
@Test
77+
fun `sort by nested column`() {
78+
val df = testResource("ds_salaries.csv").readDataFrame().cast<DsSalaries>()
79+
val aggregate = df.pivot(false) { companySize }.groupBy { companyLocation }.aggregate {
80+
maxOf { salaryInUsd } into "salary"
81+
maxBy { salaryInUsd } into "extra"
82+
}
83+
aggregate.sortBy(pathOf("L", "salary"))[0][pathOf("L", "salary")] shouldBe null
84+
aggregate.sortByDesc(pathOf("L", "salary"))[0][pathOf("L", "salary")] shouldBe 600_000
85+
}
86+
87+
@Test
88+
fun `sort by invalid nested column`() {
89+
val df = testResource("ds_salaries.csv").readDataFrame().cast<DsSalaries>()
90+
val aggregate = df.pivot(false) { companySize }.groupBy { companyLocation }.aggregate {
91+
maxOf { salaryInUsd } into "salary"
92+
maxBy { salaryInUsd } into "extra"
93+
}
94+
shouldThrowMessage("Can not use ColumnGroup as sort column") {
95+
aggregate.sortBy(pathOf("L", "extra"))
96+
}
97+
}
7098
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package org.jetbrains.kotlinx.dataframe.testSets
2+
3+
import org.jetbrains.kotlinx.dataframe.annotations.ColumnName
4+
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
5+
6+
@Suppress("unused")
7+
@DataSchema
8+
interface DsSalaries {
9+
@ColumnName("company_location")
10+
val companyLocation: String
11+
@ColumnName("company_size")
12+
val companySize: String
13+
@ColumnName("employee_residence")
14+
val employeeResidence: String
15+
@ColumnName("employment_type")
16+
val employmentType: String
17+
@ColumnName("experience_level")
18+
val experienceLevel: String
19+
@ColumnName("job_title")
20+
val jobTitle: String
21+
@ColumnName("remote_ratio")
22+
val remoteRatio: Int
23+
val salary: Int
24+
@ColumnName("salary_currency")
25+
val salaryCurrency: String
26+
@ColumnName("salary_in_usd")
27+
val salaryInUsd: Int
28+
val untitled: Int
29+
@ColumnName("work_year")
30+
val workYear: Int
31+
}

0 commit comments

Comments
 (0)