3838import java .lang .reflect .InvocationTargetException ;
3939import java .lang .reflect .ParameterizedType ;
4040import java .lang .reflect .RecordComponent ;
41+ import java .lang .reflect .Type ;
42+ import java .lang .reflect .TypeVariable ;
4143import java .util .ArrayList ;
4244import java .util .Arrays ;
4345import java .util .List ;
4850import static java .lang .String .format ;
4951import static org .bson .assertions .Assertions .notNull ;
5052
51- final class RecordCodec <T extends Record > implements Codec <T > {
53+ final class RecordCodec <T extends Record > implements Codec <T >, Parameterizable {
5254 private static final Logger LOGGER = Loggers .getLogger ("RecordCodec" );
5355 private final Class <T > clazz ;
56+ private final boolean requiresParameterization ;
5457 private final Constructor <?> canonicalConstructor ;
5558 private final List <ComponentModel > componentModels ;
5659 private final ComponentModel componentModelForId ;
@@ -62,10 +65,11 @@ private static final class ComponentModel {
6265 private final int index ;
6366 private final String fieldName ;
6467
65- private ComponentModel (final RecordComponent component , final CodecRegistry codecRegistry , final int index ) {
68+ private ComponentModel (final List <Type > typeParameters , final RecordComponent component , final CodecRegistry codecRegistry ,
69+ final int index ) {
6670 validateAnnotations (component , index );
6771 this .component = component ;
68- this .codec = computeCodec (component , codecRegistry );
72+ this .codec = computeCodec (typeParameters , component , codecRegistry );
6973 this .index = index ;
7074 this .fieldName = computeFieldName (component );
7175 }
@@ -83,11 +87,13 @@ Object getValue(final Record record) throws InvocationTargetException, IllegalAc
8387 }
8488
8589 @ SuppressWarnings ("deprecation" )
86- private static Codec <?> computeCodec (final RecordComponent component , final CodecRegistry codecRegistry ) {
87- var codec = codecRegistry .get (toWrapper (component .getType ()));
90+ private static Codec <?> computeCodec (final List <Type > typeParameters , final RecordComponent component ,
91+ final CodecRegistry codecRegistry ) {
92+ var codec = codecRegistry .get (toWrapper (resolveComponentType (typeParameters , component )));
8893 if (codec instanceof Parameterizable parameterizableCodec
8994 && component .getGenericType () instanceof ParameterizedType parameterizedType ) {
90- codec = parameterizableCodec .parameterize (codecRegistry , Arrays .asList (parameterizedType .getActualTypeArguments ()));
95+ codec = parameterizableCodec .parameterize (codecRegistry ,
96+ resolveActualTypeArguments (typeParameters , component .getDeclaringRecord (), parameterizedType ));
9197 }
9298 BsonType bsonRepresentationType = null ;
9399
@@ -109,6 +115,36 @@ private static Codec<?> computeCodec(final RecordComponent component, final Code
109115 return codec ;
110116 }
111117
118+ private static Class <?> resolveComponentType (final List <Type > typeParameters , final RecordComponent component ) {
119+ Type resolvedType = resolveType (component .getGenericType (), typeParameters , component .getDeclaringRecord ());
120+ return resolvedType instanceof Class <?> clazz ? clazz : component .getType ();
121+ }
122+
123+ private static List <Type > resolveActualTypeArguments (final List <Type > typeParameters , final Class <?> recordClass ,
124+ final ParameterizedType parameterizedType ) {
125+ return Arrays .stream (parameterizedType .getActualTypeArguments ())
126+ .map (type -> resolveType (type , typeParameters , recordClass ))
127+ .toList ();
128+ }
129+
130+ private static Type resolveType (final Type type , final List <Type > typeParameters , final Class <?> recordClass ) {
131+ return type instanceof TypeVariable <?> typeVariable
132+ ? typeParameters .get (getIndexOfTypeParameter (typeVariable .getName (), recordClass ))
133+ : type ;
134+ }
135+
136+ // Get
137+ private static int getIndexOfTypeParameter (final String typeParameterName , final Class <?> recordClass ) {
138+ var typeParameters = recordClass .getTypeParameters ();
139+ for (int i = 0 ; i < typeParameters .length ; i ++) {
140+ if (typeParameters [i ].getName ().equals (typeParameterName )) {
141+ return i ;
142+ }
143+ }
144+ throw new CodecConfigurationException (String .format ("Could not find type parameter on record %s with name %s" ,
145+ recordClass .getName (), typeParameterName ));
146+ }
147+
112148 @ SuppressWarnings ("deprecation" )
113149 private static String computeFieldName (final RecordComponent component ) {
114150 if (component .isAnnotationPresent (BsonId .class )) {
@@ -218,16 +254,47 @@ private static <T extends Annotation> void validateAnnotationOnlyOnField(final R
218254
219255 RecordCodec (final Class <T > clazz , final CodecRegistry codecRegistry ) {
220256 this .clazz = notNull ("class" , clazz );
257+ if (clazz .getTypeParameters ().length > 0 ) {
258+ requiresParameterization = true ;
259+ canonicalConstructor = null ;
260+ componentModels = null ;
261+ fieldNameToComponentModel = null ;
262+ componentModelForId = null ;
263+ } else {
264+ requiresParameterization = false ;
265+ canonicalConstructor = notNull ("canonicalConstructor" , getCanonicalConstructor (clazz ));
266+ componentModels = getComponentModels (clazz , codecRegistry , List .of ());
267+ fieldNameToComponentModel = componentModels .stream ()
268+ .collect (Collectors .toMap (ComponentModel ::getFieldName , Function .identity ()));
269+ componentModelForId = getComponentModelForId (clazz , componentModels );
270+ }
271+ }
272+
273+ RecordCodec (final Class <T > clazz , final CodecRegistry codecRegistry , final List <Type > types ) {
274+ if (types .size () != clazz .getTypeParameters ().length ) {
275+ throw new CodecConfigurationException ("Unexpected number of type parameters for record class " + clazz );
276+ }
277+ this .clazz = notNull ("class" , clazz );
278+ requiresParameterization = false ;
221279 canonicalConstructor = notNull ("canonicalConstructor" , getCanonicalConstructor (clazz ));
222- componentModels = getComponentModels (clazz , codecRegistry );
280+ componentModels = getComponentModels (clazz , codecRegistry , types );
223281 fieldNameToComponentModel = componentModels .stream ()
224282 .collect (Collectors .toMap (ComponentModel ::getFieldName , Function .identity ()));
225283 componentModelForId = getComponentModelForId (clazz , componentModels );
226284 }
227285
286+ @ Override
287+ public Codec <?> parameterize (final CodecRegistry codecRegistry , final List <Type > types ) {
288+ return new RecordCodec <>(clazz , codecRegistry , types );
289+ }
290+
228291 @ SuppressWarnings ("unchecked" )
229292 @ Override
230293 public T decode (final BsonReader reader , final DecoderContext decoderContext ) {
294+ if (requiresParameterization ) {
295+ throw new CodecConfigurationException ("Can not decode to a record with type parameters that has not been parameterized" );
296+ }
297+
231298 reader .readStartDocument ();
232299
233300 Object [] constructorArguments = new Object [componentModels .size ()];
@@ -254,6 +321,10 @@ public T decode(final BsonReader reader, final DecoderContext decoderContext) {
254321
255322 @ Override
256323 public void encode (final BsonWriter writer , final T record , final EncoderContext encoderContext ) {
324+ if (requiresParameterization ) {
325+ throw new CodecConfigurationException ("Can not decode to a record with type parameters that has not been parameterized" );
326+ }
327+
257328 writer .writeStartDocument ();
258329 if (componentModelForId != null ) {
259330 writeComponent (writer , record , componentModelForId );
@@ -287,11 +358,12 @@ private void writeComponent(final BsonWriter writer, final T record, final Compo
287358 }
288359 }
289360
290- private static <T > List <ComponentModel > getComponentModels (final Class <T > clazz , final CodecRegistry codecRegistry ) {
361+ private static <T > List <ComponentModel > getComponentModels (final Class <T > clazz , final CodecRegistry codecRegistry ,
362+ final List <Type > typeParameters ) {
291363 var recordComponents = clazz .getRecordComponents ();
292364 var componentModels = new ArrayList <ComponentModel >(recordComponents .length );
293365 for (int i = 0 ; i < recordComponents .length ; i ++) {
294- componentModels .add (new ComponentModel (recordComponents [i ], codecRegistry , i ));
366+ componentModels .add (new ComponentModel (typeParameters , recordComponents [i ], codecRegistry , i ));
295367 }
296368 return componentModels ;
297369 }
0 commit comments