From 3767707dd2d48628cec2cc9adcce24cb3f48c07b Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 1 Aug 2023 09:58:49 -0700 Subject: [PATCH] Refactor abstract combination techq class to util class Signed-off-by: Martin Gaievski --- ...ithmeticMeanScoreCombinationTechnique.java | 17 +++++------- ...HarmonicMeanScoreCombinationTechnique.java | 17 +++++------- .../combination/ScoreCombinationFactory.java | 10 ++++--- .../ScoreCombinationTechnique.java | 1 + ...chnique.java => ScoreCombinationUtil.java} | 27 ++++++++----------- ...ticMeanScoreCombinationTechniqueTests.java | 14 +++++++--- ...nicMeanScoreCombinationTechniqueTests.java | 14 +++++++--- 7 files changed, 53 insertions(+), 47 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/combination/{AbstractScoreCombinationTechnique.java => ScoreCombinationUtil.java} (72%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index df262d4a5..57040d2a1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -12,17 +12,19 @@ /** * Abstracts combination of scores based on arithmetic mean method */ -public class ArithmeticMeanScoreCombinationTechnique extends AbstractScoreCombinationTechnique implements ScoreCombinationTechnique { +public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique { public static final String TECHNIQUE_NAME = "arithmetic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; - public ArithmeticMeanScoreCombinationTechnique(final Map params) { - validateParams(params); - weights = getWeights(params); + public ArithmeticMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); } /** @@ -38,7 +40,7 @@ public float combine(final float[] scores) { for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { float score = scores[indexOfSubQuery]; if (score >= 0.0) { - float weight = getWeightForSubQuery(weights, indexOfSubQuery); + float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); score = score * weight; combinedScore += score; sumOfWeights += weight; @@ -49,9 +51,4 @@ public float combine(final float[] scores) { } return combinedScore / sumOfWeights; } - - @Override - Set getSupportedParams() { - return SUPPORTED_PARAMS; - } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index 69fd12d77..cb44e030a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -12,17 +12,19 @@ /** * Abstracts combination of scores based on harmonic mean method */ -public class HarmonicMeanScoreCombinationTechnique extends AbstractScoreCombinationTechnique implements ScoreCombinationTechnique { +public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique { public static final String TECHNIQUE_NAME = "harmonic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; + private final ScoreCombinationUtil scoreCombinationUtil; - public HarmonicMeanScoreCombinationTechnique(final Map params) { - validateParams(params); - weights = getWeights(params); + public HarmonicMeanScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) { + scoreCombinationUtil = combinationUtil; + scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS); + weights = scoreCombinationUtil.getWeights(params); } /** @@ -40,15 +42,10 @@ public float combine(final float[] scores) { if (score <= 0) { continue; } - float weightOfSubQuery = getWeightForSubQuery(weights, indexOfSubQuery); + float weightOfSubQuery = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery); sumOfWeights += weightOfSubQuery; sumOfHarmonics += weightOfSubQuery / score; } return sumOfHarmonics > 0 ? sumOfWeights / sumOfHarmonics : ZERO_SCORE; } - - @Override - Set getSupportedParams() { - return SUPPORTED_PARAMS; - } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 1195b7004..d034ede16 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -13,14 +13,18 @@ * Abstracts creation of exact score combination method based on technique name */ public class ScoreCombinationFactory { + private static final ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); - public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of()); + public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique( + Map.of(), + scoreCombinationUtil + ); private final Map, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of( ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, - ArithmeticMeanScoreCombinationTechnique::new, + params -> new ArithmeticMeanScoreCombinationTechnique(params, scoreCombinationUtil), HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, - HarmonicMeanScoreCombinationTechnique::new + params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil) ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java index 21090b1ce..6e0a5db65 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.processor.combination; public interface ScoreCombinationTechnique { + /** * Defines combination function specific to this technique * @param scores array of collected original scores diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/AbstractScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java similarity index 72% rename from src/main/java/org/opensearch/neuralsearch/processor/combination/AbstractScoreCombinationTechnique.java rename to src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index 16850f04f..9458c6d35 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/AbstractScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -14,17 +14,11 @@ import java.util.stream.Collectors; /** - * Base class for score normalization technique + * Collection of utility methods for score combination technique classes */ -public abstract class AbstractScoreCombinationTechnique { +public class ScoreCombinationUtil { private static final String PARAM_NAME_WEIGHTS = "weights"; - /** - * Each technique must provide collection of supported parameters - * @return set of supported parameter names - */ - abstract Set getSupportedParams(); - /** * Get collection of weights based on user provided config * @param params map of named parameters and their values @@ -42,30 +36,31 @@ protected List getWeights(final Map params) { /** * Validate config parameters for this technique - * @param params map of parameters in form of name-value + * @param actualParams map of parameters in form of name-value + * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique */ - protected void validateParams(final Map params) { - if (Objects.isNull(params) || params.isEmpty()) { + protected void validateParams(final Map actualParams, final Set supportedParams) { + if (Objects.isNull(actualParams) || actualParams.isEmpty()) { return; } // check if only supported params are passed - Optional optionalNotSupportedParam = params.keySet() + Optional optionalNotSupportedParam = actualParams.keySet() .stream() - .filter(paramName -> !getSupportedParams().contains(paramName)) + .filter(paramName -> !supportedParams.contains(paramName)) .findFirst(); if (optionalNotSupportedParam.isPresent()) { throw new IllegalArgumentException( String.format( Locale.ROOT, "provided parameter for combination technique is not supported. supported parameters are [%s]", - getSupportedParams().stream().collect(Collectors.joining(",")) + supportedParams.stream().collect(Collectors.joining(",")) ) ); } // check param types - if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { - if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) { + if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { throw new IllegalArgumentException( String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 79f036fd8..3c3ca3776 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -17,24 +17,30 @@ public ArithmeticMeanScoreCombinationTechniqueTests() { } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of()); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List weights = List.of(0.9, 0.2, 0.7); - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() + ); testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { List weights = List.of(0.9, 0.2, 0.7); - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() + ); testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 425c46443..02a8084ef 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -17,24 +17,30 @@ public HarmonicMeanScoreCombinationTechniqueTests() { } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of()); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of()); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(), new ScoreCombinationUtil()); testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List weights = List.of(0.9, 0.2, 0.7); - ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() + ); testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { List weights = List.of(0.9, 0.2, 0.7); - ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique(Map.of(PARAM_NAME_WEIGHTS, weights)); + ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( + Map.of(PARAM_NAME_WEIGHTS, weights), + new ScoreCombinationUtil() + ); testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); }