-
Notifications
You must be signed in to change notification settings - Fork 73
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate model inference to build query (#20)
Integrates ml-commons model inference capabilities to transform NeuralQueryBuilder into a KNNQueryBuilder. Minor changes to parsing logic and build.gradle to fix bugs. Minor enhancement to MLCommonsCLientAccessor to add single sentence inference. Signed-off-by: John Mazanec <jmazane@amazon.com>
- Loading branch information
1 parent
393141d
commit 272d803
Showing
8 changed files
with
240 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
src/main/java/org/opensearch/neuralsearch/common/VectorUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.neuralsearch.common; | ||
|
||
import java.util.List; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.NoArgsConstructor; | ||
|
||
/** | ||
* Utility class for working with vectors | ||
*/ | ||
@NoArgsConstructor(access = AccessLevel.PRIVATE) | ||
public class VectorUtil { | ||
|
||
/** | ||
* Converts a vector represented as a list to an array | ||
* | ||
* @param vectorAsList {@link List} of {@link Float}'s representing the vector | ||
* @return array of floats produced from input list | ||
*/ | ||
public static float[] vectorAsListToArray(List<Float> vectorAsList) { | ||
float[] vector = new float[vectorAsList.size()]; | ||
for (int i = 0; i < vectorAsList.size(); i++) { | ||
vector[i] = vectorAsList.get(i); | ||
} | ||
return vector; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
src/test/java/org/opensearch/neuralsearch/common/VectorUtilTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.neuralsearch.common; | ||
|
||
import java.util.Collections; | ||
import java.util.List; | ||
|
||
import org.opensearch.test.OpenSearchTestCase; | ||
|
||
public class VectorUtilTests extends OpenSearchTestCase { | ||
|
||
public void testVectorAsListToArray() { | ||
List<Float> vectorAsList_withThreeElements = List.of(1.3f, 2.5f, 3.5f); | ||
float[] vectorAsArray_withThreeElements = VectorUtil.vectorAsListToArray(vectorAsList_withThreeElements); | ||
|
||
assertEquals(vectorAsList_withThreeElements.size(), vectorAsArray_withThreeElements.length); | ||
for (int i = 0; i < vectorAsList_withThreeElements.size(); i++) { | ||
assertEquals(vectorAsList_withThreeElements.get(i), vectorAsArray_withThreeElements[i], 0.0f); | ||
} | ||
|
||
List<Float> vectorAsList_withNoElements = Collections.emptyList(); | ||
float[] vectorAsArray_withNoElements = VectorUtil.vectorAsListToArray(vectorAsList_withNoElements); | ||
assertEquals(0, vectorAsArray_withNoElements.length); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.