Skip to content

Commit 2021c19

Browse files
committed
feat(semantic_router): first cut of semantic router
1 parent e39b6bc commit 2021c19

16 files changed

+2693
-35
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.redis.vl.extensions.router;
2+
3+
/**
4+
* Enumeration for distance aggregation methods. Ported from Python:
5+
* redisvl/extensions/router/schema.py:50
6+
*/
7+
public enum DistanceAggregationMethod {
8+
/** Average aggregation method */
9+
AVG("avg"),
10+
/** Minimum aggregation method */
11+
MIN("min"),
12+
/** Sum aggregation method */
13+
SUM("sum");
14+
15+
private final String value;
16+
17+
DistanceAggregationMethod(String value) {
18+
this.value = value;
19+
}
20+
21+
/**
22+
* Get the string value of the aggregation method.
23+
*
24+
* @return the string value
25+
*/
26+
public String getValue() {
27+
return value;
28+
}
29+
30+
/**
31+
* Get the DistanceAggregationMethod from a string value.
32+
*
33+
* @param value the string value
34+
* @return the corresponding DistanceAggregationMethod
35+
* @throws IllegalArgumentException if the value is unknown
36+
*/
37+
public static DistanceAggregationMethod fromValue(String value) {
38+
for (DistanceAggregationMethod method : values()) {
39+
if (method.value.equalsIgnoreCase(value)) {
40+
return method;
41+
}
42+
}
43+
throw new IllegalArgumentException("Unknown aggregation method: " + value);
44+
}
45+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package com.redis.vl.extensions.router;
2+
3+
import com.google.gson.Gson;
4+
import com.google.gson.reflect.TypeToken;
5+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
6+
import java.lang.reflect.Type;
7+
import java.util.ArrayList;
8+
import java.util.HashMap;
9+
import java.util.List;
10+
import java.util.Map;
11+
import lombok.Builder;
12+
import lombok.Data;
13+
14+
/**
15+
* Model representing a routing path with associated metadata and thresholds. Ported from Python:
16+
* redisvl/extensions/router/schema.py:12
17+
*/
18+
@Data
19+
@Builder
20+
@lombok.NoArgsConstructor
21+
@lombok.AllArgsConstructor
22+
@SuppressFBWarnings(
23+
value = {"EI_EXPOSE_REP", "EI_EXPOSE_REP2"},
24+
justification =
25+
"Route is a data class with intentionally mutable fields for Python API compatibility")
26+
public class Route {
27+
private String name;
28+
private List<String> references;
29+
@Builder.Default private Map<String, Object> metadata = Map.of();
30+
@Builder.Default private double distanceThreshold = 0.5;
31+
32+
/** Custom builder to ensure references list is mutable. */
33+
public static class RouteBuilder {
34+
public RouteBuilder references(List<String> references) {
35+
// Convert to ArrayList to ensure mutability for addRouteReferences/deleteRouteReferences
36+
this.references = references != null ? new ArrayList<>(references) : new ArrayList<>();
37+
return this;
38+
}
39+
}
40+
41+
/**
42+
* Validate the route configuration.
43+
*
44+
* @throws IllegalArgumentException if validation fails
45+
*/
46+
public void validate() {
47+
if (name == null || name.trim().isEmpty()) {
48+
throw new IllegalArgumentException("Route name must not be empty");
49+
}
50+
if (references == null || references.isEmpty()) {
51+
throw new IllegalArgumentException("References must not be empty");
52+
}
53+
for (String ref : references) {
54+
if (ref == null || ref.trim().isEmpty()) {
55+
throw new IllegalArgumentException("All references must be non-empty strings");
56+
}
57+
}
58+
if (distanceThreshold <= 0 || distanceThreshold > 2) {
59+
throw new IllegalArgumentException(
60+
"Distance threshold must be greater than 0 and less than or equal to 2");
61+
}
62+
}
63+
64+
/**
65+
* Convert route to a map for JSON serialization. Ported from Python: model_to_dict(route)
66+
*
67+
* @return Map representation of the route
68+
*/
69+
public Map<String, Object> toDict() {
70+
Map<String, Object> dict = new HashMap<>();
71+
dict.put("name", name);
72+
dict.put("references", new ArrayList<>(references));
73+
dict.put("metadata", new HashMap<>(metadata));
74+
dict.put("distance_threshold", distanceThreshold);
75+
return dict;
76+
}
77+
78+
/**
79+
* Convert route to JSON string.
80+
*
81+
* @return JSON representation of the route
82+
*/
83+
public String toJson() {
84+
Gson gson = new Gson();
85+
return gson.toJson(toDict());
86+
}
87+
88+
/**
89+
* Create route from map representation. Ported from Python: Route(**dict)
90+
*
91+
* @param dict Map containing route data
92+
* @return Route instance
93+
*/
94+
public static Route fromDict(Map<String, Object> dict) {
95+
@SuppressWarnings("unchecked")
96+
List<String> refs = (List<String>) dict.get("references");
97+
@SuppressWarnings("unchecked")
98+
Map<String, Object> meta =
99+
dict.containsKey("metadata") ? (Map<String, Object>) dict.get("metadata") : Map.of();
100+
101+
return Route.builder()
102+
.name((String) dict.get("name"))
103+
.references(refs != null ? refs : List.of())
104+
.metadata(meta)
105+
.distanceThreshold(
106+
dict.containsKey("distance_threshold")
107+
? ((Number) dict.get("distance_threshold")).doubleValue()
108+
: 0.5)
109+
.build();
110+
}
111+
112+
/**
113+
* Create route from JSON string.
114+
*
115+
* @param json JSON string
116+
* @return Route instance
117+
*/
118+
public static Route fromJson(String json) {
119+
Gson gson = new Gson();
120+
Type type = new TypeToken<Map<String, Object>>() {}.getType();
121+
Map<String, Object> dict = gson.fromJson(json, type);
122+
return fromDict(dict);
123+
}
124+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.redis.vl.extensions.router;
2+
3+
import lombok.AllArgsConstructor;
4+
import lombok.Builder;
5+
import lombok.Data;
6+
import lombok.NoArgsConstructor;
7+
8+
/**
9+
* Model representing a matched route with distance information. Ported from Python:
10+
* redisvl/extensions/router/schema.py:41
11+
*/
12+
@Data
13+
@Builder
14+
@NoArgsConstructor
15+
@AllArgsConstructor
16+
@SuppressWarnings("javadoc")
17+
public class RouteMatch {
18+
private String name;
19+
private Double distance;
20+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.redis.vl.extensions.router;
2+
3+
import lombok.Builder;
4+
import lombok.Data;
5+
6+
/**
7+
* Configuration for routing behavior. Ported from Python: redisvl/extensions/router/schema.py:61
8+
*/
9+
@Data
10+
@Builder
11+
@SuppressWarnings("javadoc")
12+
public class RoutingConfig {
13+
@Builder.Default private int maxK = 1;
14+
15+
@Builder.Default
16+
private DistanceAggregationMethod aggregationMethod = DistanceAggregationMethod.AVG;
17+
18+
/**
19+
* Validate the routing configuration.
20+
*
21+
* @throws IllegalArgumentException if validation fails
22+
*/
23+
public void validate() {
24+
if (maxK <= 0) {
25+
throw new IllegalArgumentException("maxK must be greater than 0");
26+
}
27+
}
28+
}

0 commit comments

Comments
 (0)