Skip to content

Commit

Permalink
Merge pull request #313 from barakugav/comparator-key-extractor
Browse files Browse the repository at this point in the history
Primitive Comparator.comparing(keyExtractor), like standrad Comparator
  • Loading branch information
vigna authored Jul 21, 2024
2 parents aa3f51a + c031bae commit 49a738d
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
110 changes: 110 additions & 0 deletions drv/Comparator.drv
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package PACKAGE;

import java.util.Comparator;
import java.util.Objects;
import java.io.Serializable;

/** A type-specific {@link Comparator}; provides methods to compare two primitive types both as objects
* and as primitive types.
Expand Down Expand Up @@ -76,4 +78,112 @@ public interface KEY_COMPARATOR KEY_GENERIC extends Comparator<KEY_GENERIC_CLASS
return Comparator.super.thenComparing(second);
}
#endif

#define CONCAT_(A, B) A ## B
#define CONCAT(A, B) CONCAT_(A, B)
#define KEY_TO_OBJ_FUNCTION CONCAT(KEY_TYPE_CAP, 2ObjectFunction)
#define KEY_TO_INT_FUNCTION CONCAT(KEY_TYPE_CAP, 2IntFunction)
#define KEY_TO_LONG_FUNCTION CONCAT(KEY_TYPE_CAP, 2LongFunction)
#define KEY_TO_DOUBLE_FUNCTION CONCAT(KEY_TYPE_CAP, 2DoubleFunction)


/**
* Accepts a function that extracts a {@link java.lang.Comparable Comparable} sort key from
* a primitive key, and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function is also serializable.
*
* @param keyExtractor the function used to extract the {@link Comparable} sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
#if KEYS_PRIMITIVE
static <U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? extends U> keyExtractor) {
#else
static <K, U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? super K, ? extends U> keyExtractor) {
#endif
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> keyExtractor.get(k1).compareTo(keyExtractor.get(k2));
}

/**
* Accepts a function that extracts a sort key from a primitive key, and returns a
* comparator that compares by that sort key using the specified {@link Comparator}.
*
* <p>
* The returned comparator is serializable if the specified function and comparator are
* both serializable.
*
* @param keyExtractor the function used to extract the sort key
* @param keyComparator the {@code Comparator} used to compare the sort key
* @return a comparator that compares by an extracted key using the specified {@code Comparator}
* @throws NullPointerException if {@code keyExtractor} or {@code keyComparator} are {@code null}
*/
#if KEYS_PRIMITIVE
static <U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? extends U> keyExtractor, Comparator<? super U> keyComparator) {
#else
static <K, U extends Comparable<? super U>> KEY_COMPARATOR KEY_GENERIC comparing(KEY_TO_OBJ_FUNCTION <? super K, ? extends U> keyExtractor, Comparator<? super U> keyComparator) {
#endif
Objects.requireNonNull(keyExtractor);
Objects.requireNonNull(keyComparator);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> keyComparator.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

/**
* Accepts a function that extracts an {@code int} sort key from a primitive key,
* and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function
* is also serializable.
*
* @param keyExtractor the function used to extract the integer sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
static KEY_GENERIC KEY_COMPARATOR KEY_GENERIC comparingInt(KEY_TO_INT_FUNCTION KEY_SUPER_GENERIC keyExtractor) {
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> Integer.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

/**
* Accepts a function that extracts an {@code long} sort key from a primitive key,
* and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function
* is also serializable.
*
* @param keyExtractor the function used to extract the long sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
static KEY_GENERIC KEY_COMPARATOR KEY_GENERIC comparingLong(KEY_TO_LONG_FUNCTION KEY_SUPER_GENERIC keyExtractor) {
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> Long.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

/**
* Accepts a function that extracts an {@code double} sort key from a primitive key,
* and returns a comparator that compares by that sort key.
*
* <p>
* The returned comparator is serializable if the specified function
* is also serializable.
*
* @param keyExtractor the function used to extract the double sort key
* @return a comparator that compares by an extracted key
* @throws NullPointerException if {@code keyExtractor} is {@code null}
*/
static KEY_GENERIC KEY_COMPARATOR KEY_GENERIC comparingDouble(KEY_TO_DOUBLE_FUNCTION KEY_SUPER_GENERIC keyExtractor) {
Objects.requireNonNull(keyExtractor);
return (KEY_COMPARATOR KEY_GENERIC & Serializable)
(k1, k2) -> Double.compare(keyExtractor.get(k1), keyExtractor.get(k2));
}

}
69 changes: 69 additions & 0 deletions test/it/unimi/dsi/fastutil/ints/IntComparatorTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (C) 2003-2024 Barak Ugav and Sebastiano Vigna
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package it.unimi.dsi.fastutil.ints;

import static org.junit.Assert.assertEquals;

import org.junit.Test;

public class IntComparatorTest {

@Test
public void comparing() {
String[] array = new String[] { "68", "98", "30", "62", "81", "61", "80", "63", "62", "77", "10", "95", "40",
"73", "55", "45", "16", "10", "86", "28", "79", "44", "52", "92", "98", "28", "88", "70", "70", "10" };
IntComparator c = IntComparator.comparing(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 29) * 1337) % array.length;
assertEquals(c.compare(i, j), array[i].compareTo(array[j]));
}
}

@Test
public void comparingInt() {
int[] array = new int[] { 81, 87, 70, 54, 40, 79, 16, 8, 84, 39, 37, 84, 64, 60, 31, 44, 95, 15, 52, 48, 19, 20,
75, 31, 46, 61, 38, 27, 32, 84 };
IntComparator c = IntComparator.comparingInt(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 17) * 1337) % array.length;
assertEquals(c.compare(i, j), Integer.compare(array[i], array[j]));
}
}

@Test
public void comparingLong() {
long[] array = new long[] { 26, 49, 49, 24, 15, 71, 10, 88, 78, 4, 42, 79, 75, 69, 63, 16, 71, 47, 54, 39, 89,
10, 64, 37, 38, 59, 81, 59, 58, 33 };
IntComparator c = IntComparator.comparingLong(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 19) * 1337) % array.length;
assertEquals(c.compare(i, j), Long.compare(array[i], array[j]));
}
}

@Test
public void comparingDouble() {
double[] array = new double[] { 0.61, 0.97, 0.97, 0.75, 0.73, 0.36, 0.72, 0.14, 0.93, 0.18, 0.45, 0.03, 0.62,
0.05, 0.04, 0.05, 0.38, 0.89, 0., 0.93, 0.83, 0.14, 0.21, 0.79, 0.5, 0.17, 0.46, 0.74, 0.88, 0.94 };
IntComparator c = IntComparator.comparingDouble(i -> array[i]);
for (int i = 0; i < array.length; i++) {
int j = ((i + 23) * 1337) % array.length;
assertEquals(c.compare(i, j), Double.compare(array[i], array[j]));
}
}

}

0 comments on commit 49a738d

Please sign in to comment.