Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fusion regression #2611

Merged
merged 34 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
1cb5ea1
add fusion
cadurosar Jan 26, 2024
fe7e529
Merge branch 'castorini:master' into master
cadurosar Jan 26, 2024
c626a7c
added cadurosar's code via a copy + paste and made changes to match p…
DanielKohn1208 May 1, 2024
4d702d0
merged cadurosar's code
DanielKohn1208 May 1, 2024
7db24a9
moved FuseRuns
DanielKohn1208 May 1, 2024
e7efd1b
merged cadurosar's code with modifications
DanielKohn1208 May 1, 2024
ebd8ed4
added run fusion to match pyserini implementation
DanielKohn1208 May 8, 2024
1b398af
added fusion feature
Stefan824 Sep 6, 2024
27b44df
modified arguments; added test cases
Stefan824 Sep 6, 2024
72e6e06
modified TrecRun class code style
Stefan824 Sep 6, 2024
5f7ec35
added comment
Stefan824 Sep 6, 2024
509049c
deleted test file from previous version
Stefan824 Sep 7, 2024
39f62a9
Added dependency for junit test
Stefan824 Sep 7, 2024
37e89fa
resolved formatting; merged trectools module to fusion
Stefan824 Sep 7, 2024
54c74b4
remove unused test cases
Stefan824 Sep 8, 2024
32e13c2
removed unused test files
Stefan824 Sep 8, 2024
6c648f7
Merge remote-tracking branch 'origin/master' into add-fusion
Stefan824 Sep 16, 2024
a9d7804
added fusion regression script paired with two yaml test files
Stefan824 Sep 23, 2024
e049e48
added md for test
Stefan824 Sep 23, 2024
bd0ce76
add cmd on test instruction
Stefan824 Sep 23, 2024
17ceb49
removed abundant dependency
Stefan824 Sep 23, 2024
6f550b1
revert unnecessary change
Stefan824 Sep 23, 2024
f4644e1
resolved a minor decoding issue
Stefan824 Sep 23, 2024
0ea8369
added a yaml that is based on regression test run results
Stefan824 Sep 23, 2024
ec57e96
added doc for test2
Stefan824 Sep 23, 2024
042b678
typo
Stefan824 Sep 23, 2024
f2b6f4c
changed name for test yamls
Stefan824 Sep 23, 2024
d94c0f9
second attempt to revert src/main/resources/regression/beir-v1.0.0-ro…
Stefan824 Sep 24, 2024
f5871b9
fixed precision and added run_origins for fusion yaml
Stefan824 Sep 25, 2024
b7961f3
removed two yamls that use runs not from current regression experiments
Stefan824 Sep 29, 2024
ab33853
modified test instructions according to last commit
Stefan824 Sep 29, 2024
db12c79
add yaml file
Stefan824 Sep 30, 2024
9a419aa
removed old yaml
Stefan824 Sep 30, 2024
d9cff54
changed output naming
Stefan824 Oct 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
modified arguments; added test cases
  • Loading branch information
Stefan824 committed Sep 6, 2024
commit 27b44dfa988642ab7b3827f797fd4b2bb0db1907
17 changes: 5 additions & 12 deletions src/main/java/io/anserini/fusion/FuseTrecRuns.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,38 @@

package io.anserini.fusion;

import org.apache.commons.lang3.NotImplementedException;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.ParserProperties;
import org.kohsuke.args4j.spi.StringArrayOptionHandler;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.IOException;
import java.util.Arrays;
import java.util.ArrayList;
import java.nio.file.Files;
import java.util.List;
import java.nio.file.Path;
import java.nio.file.Paths;

import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

import io.anserini.trectools.TrecRun;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
* Main entry point for Fusion.
*/
public class FuseTrecRuns {
private static final Logger LOG = LogManager.getLogger(FuseTrecRuns.class);

public static class Args extends TrecRunFuser.Args {
@Option(name = "-options", usage = "Print information about options.")
@Option(name = "-options", required = false, usage = "Print information about options.")
public Boolean options = false;

@Option(name = "-runs", handler = StringArrayOptionHandler.class, metaVar = "[file]", required = true,
usage = "Path to both run files to fuse")
public String[] runs;

@Option (name = "-resort", usage="We Resort the Trec run files or not")
@Option (name = "-resort", required = false, metaVar = "[flag]", usage="We Resort the Trec run files or not")
public boolean resort = false;
}

Expand Down
72 changes: 47 additions & 25 deletions src/main/java/io/anserini/fusion/TrecRunFuser.java
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/*
* Anserini: A Lucene toolkit for reproducible information retrieval research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.anserini.fusion;

import java.io.IOException;
Expand All @@ -10,10 +26,16 @@
import io.anserini.trectools.RescoreMethod;
import io.anserini.trectools.TrecRun;


/**
* Main logic class for Fusion
*/
public class TrecRunFuser {
private final Args args;

private static final String METHOD_RRF = "rrf";
private static final String METHOD_INTERPOLATION = "interpolation";
private static final String METHOD_AVERAGE = "average";

public static class Args {
@Option(name = "-output", metaVar = "[output]", required = true, usage = "Path to save the output")
public String output;
Expand All @@ -24,16 +46,16 @@ public static class Args {
@Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusion method")
public String method = "rrf";

@Option(name = "-rrf_k", metaVar = "[rrf_k]", required = false, usage = "Parameter k needed for reciprocal rank fusion.")
@Option(name = "-rrf_k", metaVar = "[number]", required = false, usage = "Parameter k needed for reciprocal rank fusion.")
public int rrf_k = 60;

@Option(name = "-alpha", required = false, usage = "Alpha value used for interpolation.")
@Option(name = "-alpha", metaVar = "[value]", required = false, usage = "Alpha value used for interpolation.")
public double alpha = 0.5;

@Option(name = "-k", required = false, usage = "number of documents to output for topic")
@Option(name = "-k", metaVar = "[number]", required = false, usage = "number of documents to output for topic")
public int k = 1000;

@Option(name = "-depth", required = false, usage = "Pool depth per topic.")
@Option(name = "-depth", metaVar = "[number]", required = false, usage = "Pool depth per topic.")
public int depth = 1000;
}

Expand All @@ -42,13 +64,13 @@ public TrecRunFuser(Args args) {
}

/**
* Perform fusion by averaging on a list of TrecRun objects.
*
* @param runs List of TrecRun objects.
* @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered.
* @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked.
* @return Output TrecRun that combines input runs via averaging.
*/
* Perform fusion by averaging on a list of TrecRun objects.
*
* @param runs List of TrecRun objects.
* @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered.
* @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked.
* @return Output TrecRun that combines input runs via averaging.
*/
public static TrecRun average(List<TrecRun> runs, int depth, int k) {

for (TrecRun run : runs) {
Expand Down Expand Up @@ -77,16 +99,16 @@ public static TrecRun reciprocalRankFusion(List<TrecRun> runs, int rrf_k, int de
return TrecRun.merge(runs, depth, k);
}

/**
* Perform fusion by interpolation on a list of exactly two TrecRun objects.
* new_score = first_run_score * alpha + (1 - alpha) * second_run_score.
*
* @param runs List of TrecRun objects. Exactly two runs.
* @param alpha Parameter alpha will be applied on the first run and (1 - alpha) will be applied on the second run.
* @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered.
* @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked.
* @return Output TrecRun that combines input runs via interpolation.
*/
/**
* Perform fusion by interpolation on a list of exactly two TrecRun objects.
* new_score = first_run_score * alpha + (1 - alpha) * second_run_score.
*
* @param runs List of TrecRun objects. Exactly two runs.
* @param alpha Parameter alpha will be applied on the first run and (1 - alpha) will be applied on the second run.
* @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered.
* @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked.
* @return Output TrecRun that combines input runs via interpolation.
*/
public static TrecRun interpolation(List<TrecRun> runs, double alpha, int depth, int k) {
// Ensure exactly 2 runs are provided, as interpolation requires 2 runs
if (runs.size() != 2) {
Expand Down Expand Up @@ -115,13 +137,13 @@ public void fuse(List<TrecRun> runs) throws IOException {

// Select fusion method
switch (args.method.toLowerCase()) {
case "rrf":
case METHOD_RRF:
fusedRun = reciprocalRankFusion(runs, args.rrf_k, args.depth, args.k);
break;
case "interpolation":
case METHOD_INTERPOLATION:
fusedRun = interpolation(runs, args.alpha, args.depth, args.k);
break;
case "average":
case METHOD_AVERAGE:
fusedRun = average(runs, args.depth, args.k);
break;
default:
Expand Down
42 changes: 42 additions & 0 deletions src/test/java/io/anserini/fusion/FuseTrecRunsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.anserini.fusion;

import java.io.IOException;

import static org.junit.Assert.fail;
import org.junit.Test;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.ParserProperties;

/**
 * Smoke test for the {@link FuseTrecRuns} command-line entry point: parses a fixed argument
 * list into {@link FuseTrecRuns.Args} and invokes {@code run()} on the resulting instance.
 */
public class FuseTrecRunsTest {

  /**
   * Fuses three run files using the default fusion method (no {@code -method} flag is passed).
   * The test currently only verifies that parsing and running complete without throwing;
   * no assertions are made on the produced output.
   *
   * @throws IOException propagated from {@code FuseTrecRuns.run()}
   */
  @Test
  public void testFuseTrecRunsRRF() throws IOException {
    String[] commandLine = new String[] {
      "-runs",
      "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc_title.txt",
      "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc.txt",
      "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-title.txt",
      "-output", "runs/testsrc/test/resources/fused_output.txt",
      "-rrf_k", "60",
      "-k", "1000",
      "-depth", "1000",
      "-resort"
    };

    FuseTrecRuns.Args parsedArgs = new FuseTrecRuns.Args();
    ParserProperties parserProperties = ParserProperties.defaults().withUsageWidth(120);
    CmdLineParser cmdLineParser = new CmdLineParser(parsedArgs, parserProperties);

    // A malformed argument list is reported as a test failure rather than as an error.
    try {
      cmdLineParser.parseArgument(commandLine);
    } catch (CmdLineException e) {
      fail("Argument parsing failed: " + e.getMessage());
    }

    new FuseTrecRuns(parsedArgs).run();

    // NOTE(review): the check that the output file exists is still disabled, and the
    // contents of the fused run are not validated yet.
  }
}
76 changes: 76 additions & 0 deletions src/test/java/io/anserini/trectools/TrecRunTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package io.anserini.trectools;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import org.junit.Before;
import org.junit.Test;

/**
 * Unit tests for {@link TrecRun}: loading a run file, per-topic lookup, rescoring,
 * merging, and saving.
 *
 * <p>All tests read the same sample run file from {@code runs/testlong}; the expected
 * counts below are specific to that file.
 */
public class TrecRunTest {
  /** Tolerance used when comparing floating-point scores. */
  private static final double SCORE_DELTA = 1e-9;

  private TrecRun trecRun;
  private Path sampleFilePath;

  /**
   * Loads a fresh copy of the sample run before every test so that rescoring in one
   * test cannot leak into another.
   */
  @Before
  public void setUp() throws IOException {
    sampleFilePath = Paths.get("runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc_title.txt");
    // NOTE(review): the boolean presumably controls re-sorting of the run on load -- confirm against TrecRun.
    trecRun = new TrecRun(sampleFilePath, false);
  }

  /** The sample run file is expected to contain 114 distinct topics. */
  @Test
  public void testReadRun() throws IOException {
    assertEquals(114, trecRun.getTopics().size());
  }

  /** Topic 101 is expected to have 1000 documents in the sample run file. */
  @Test
  public void testGetDocsByTopic() {
    List<Map<TrecRun.Column, Object>> docs = trecRun.getDocsByTopic("101", 0);
    assertNotNull(docs);
    assertEquals(1000, docs.size());
  }

  /**
   * After RRF rescoring with 60 as the second argument, the first returned document
   * is expected to score 1 / 61.
   */
  @Test
  public void testRescoreRRF() {
    trecRun.rescore(RescoreMethod.RRF, 60, 1.0);
    List<Map<TrecRun.Column, Object>> docs = trecRun.getDocsByTopic("101", 1);
    double topScore = (Double) docs.get(0).get(TrecRun.Column.SCORE);
    // Compare with a tolerance: exact equality on boxed doubles is brittle for computed scores.
    assertEquals(1.0 / 61, topScore, SCORE_DELTA);
  }

  /** After normalization the first document should score ~1.0 and the last ~0.0. */
  @Test
  public void testNormalizeScores() {
    trecRun.rescore(RescoreMethod.NORMALIZE, 0, 0);
    List<Map<TrecRun.Column, Object>> docs = trecRun.getDocsByTopic("101", 0);
    double maxScore = (Double) docs.get(0).get(TrecRun.Column.SCORE);
    double minScore = (Double) docs.get(docs.size() - 1).get(TrecRun.Column.SCORE);
    assertEquals(1.0, maxScore, 0.01);
    assertEquals(0.0, minScore, 0.01);
  }

  /** Merges the sample run with itself and writes the merged run to disk. */
  @Test
  public void testMergeRuns() throws IOException {
    TrecRun trecRun1 = new TrecRun(sampleFilePath);
    TrecRun trecRun2 = new TrecRun(sampleFilePath);
    TrecRun mergedRun = TrecRun.merge(Arrays.asList(trecRun1, trecRun2), null, 10);
    assertNotNull(mergedRun);

    Path outputPath = Paths.get("runs/testsrc/test/resources/output-merge.trec");
    mergedRun.saveToTxt(outputPath, "test_tag");

    // TODO: assert that the merged score of the top document for topic 101 is twice
    // the score of that document in a single input run.
  }

  /** A run written with saveToTxt and read back should contain the same number of topics. */
  @Test
  public void testSaveToTxt() throws IOException {
    Path outputPath = Paths.get("runs/testsrc/test/resources/output.trec");
    trecRun.saveToTxt(outputPath, "Anserini");

    // Re-load the saved run and compare against the in-memory one.
    TrecRun savedRun = new TrecRun(outputPath);
    assertEquals(trecRun.getTopics().size(), savedRun.getTopics().size());
  }
}