Skip to content

Commit

Permalink
Refactor abstract combination techq class to util class
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Aug 1, 2023
1 parent e960476 commit 3767707
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
/**
* Abstracts combination of scores based on arithmetic mean method
*/
public class ArithmeticMeanScoreCombinationTechnique extends AbstractScoreCombinationTechnique implements ScoreCombinationTechnique {
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
private final ScoreCombinationUtil scoreCombinationUtil;

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
scoreCombinationUtil = combinationUtil;
scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
weights = scoreCombinationUtil.getWeights(params);
}

/**
Expand All @@ -38,7 +40,7 @@ public float combine(final float[] scores) {
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = getWeightForSubQuery(weights, indexOfSubQuery);
float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
score = score * weight;
combinedScore += score;
sumOfWeights += weight;
Expand All @@ -49,9 +51,4 @@ public float combine(final float[] scores) {
}
return combinedScore / sumOfWeights;
}

@Override
Set<String> getSupportedParams() {
return SUPPORTED_PARAMS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
/**
* Abstracts combination of scores based on harmonic mean method
*/
public class HarmonicMeanScoreCombinationTechnique extends AbstractScoreCombinationTechnique implements ScoreCombinationTechnique {
public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "harmonic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
private final ScoreCombinationUtil scoreCombinationUtil;

public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
scoreCombinationUtil = combinationUtil;
scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
weights = scoreCombinationUtil.getWeights(params);
}

/**
Expand All @@ -40,15 +42,10 @@ public float combine(final float[] scores) {
if (score <= 0) {
continue;
}
float weightOfSubQuery = getWeightForSubQuery(weights, indexOfSubQuery);
float weightOfSubQuery = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
sumOfWeights += weightOfSubQuery;
sumOfHarmonics += weightOfSubQuery / score;
}
return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE;
}

@Override
Set<String> getSupportedParams() {
return SUPPORTED_PARAMS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
* Abstracts creation of exact score combination method based on technique name
*/
public class ScoreCombinationFactory {
private static final ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil();

public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of());
public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(
Map.of(),
scoreCombinationUtil
);

private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
ArithmeticMeanScoreCombinationTechnique::new,
params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil),
HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME,
HarmonicMeanScoreCombinationTechnique::new
params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.processor.combination;

public interface ScoreCombinationTechnique {

/**
* Defines combination function specific to this technique
* @param scores array of collected original scores
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,11 @@
import java.util.stream.Collectors;

/**
* Base class for score normalization technique
* Collection of utility methods for score combination technique classes
*/
public abstract class AbstractScoreCombinationTechnique {
public class ScoreCombinationUtil {
private static final String PARAM_NAME_WEIGHTS = "weights";

/**
* Each technique must provide collection of supported parameters
* @return set of supported parameter names
*/
abstract Set<String> getSupportedParams();

/**
* Get collection of weights based on user provided config
* @param params map of named parameters and their values
Expand All @@ -42,30 +36,31 @@ protected List<Float> getWeights(final Map<String, Object> params) {

/**
* Validate config parameters for this technique
* @param params map of parameters in form of name-value
* @param actualParams map of parameters in form of name-value
* @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique
*/
protected void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
protected void validateParams(final Map<String, Object> actualParams, final Set<String> supportedParams) {
if (Objects.isNull(actualParams) || actualParams.isEmpty()) {
return;
}
// check if only supported params are passed
Optional<String> optionalNotSupportedParam = params.keySet()
Optional<String> optionalNotSupportedParam = actualParams.keySet()
.stream()
.filter(paramName -> !getSupportedParams().contains(paramName))
.filter(paramName -> !supportedParams.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
getSupportedParams().stream().collect(Collectors.joining(","))
supportedParams.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,30 @@ public ArithmeticMeanScoreCombinationTechniqueTests() {
}

public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of());
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights));
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
);
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights));
ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
);
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,30 @@ public HarmonicMeanScoreCombinationTechniqueTests() {
}

public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of());
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() {
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of());
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil());
testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique);
}

public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights));
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
);
testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}

public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() {
List<Double> weights = List.of(0.9, 0.2, 0.7);
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights));
ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(
Map.of(PARAM_NAME_WEIGHTS, weights),
new ScoreCombinationUtil()
);
testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights);
}

Expand Down

0 comments on commit 3767707

Please sign in to comment.