diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt index 78d30de..ac3bdd4 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt @@ -233,6 +233,8 @@ internal class ClassDescriptorForWriterSchema( */ val encodingSteps: Array, ) { + val hasMissingWriterField by lazy { encodingSteps.any { it is EncodingStep.MissingWriterFieldFailure } } + companion object { val EMPTY = ClassDescriptorForWriterSchema( diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/AbstractPolymorphicDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/AbstractPolymorphicDecoder.kt new file mode 100644 index 0000000..431615e --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/AbstractPolymorphicDecoder.kt @@ -0,0 +1,65 @@ +package com.github.avrokotlin.avro4k.internal.decoder + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError +import com.github.avrokotlin.avro4k.internal.isNamedSchema +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.AbstractDecoder +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.modules.SerializersModule +import org.apache.avro.Schema + +internal abstract class AbstractPolymorphicDecoder( + protected val avro: Avro, + private val descriptor: SerialDescriptor, + private val schema: Schema, +) : AbstractDecoder() { + final override val serializersModule: SerializersModule + get() = avro.serializersModule + + private lateinit var chosenSchema: Schema + + final override fun decodeString(): String { + return tryFindSerialName()?.also { chosenSchema = it.second }?.first + ?: throw SerializationException("Unknown schema name '${schema.fullName}' for polymorphic type ${descriptor.serialName}. Full schema: $schema") + } + + private fun tryFindSerialName(): Pair? { + val namesAndAliasesToSerialName: Map = avro.polymorphicResolver.getFullNamesAndAliasesToSerialName(descriptor) + return tryFindSerialName(namesAndAliasesToSerialName, schema) + } + + protected abstract fun tryFindSerialNameForUnion( + namesAndAliasesToSerialName: Map, + schema: Schema, + ): Pair? + + protected fun tryFindSerialName( + namesAndAliasesToSerialName: Map, + schema: Schema, + ): Pair? { + if (schema.isUnion) { + return tryFindSerialNameForUnion(namesAndAliasesToSerialName, schema) + } + return ( + namesAndAliasesToSerialName[schema.fullName] + ?: schema.takeIf { it.isNamedSchema() }?.aliases?.firstNotNullOfOrNull { namesAndAliasesToSerialName[it] } + ) + ?.let { it to schema } + } + + final override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { + return newDecoder(chosenSchema) + .decodeSerializableValue(deserializer) + } + + abstract fun newDecoder(chosenSchema: Schema): Decoder + + final override fun decodeSequentially() = true + + final override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + throw IllegalIndexedAccessError() + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt index 925dfe3..db440f6 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt @@ -16,8 +16,8 @@ import com.github.avrokotlin.avro4k.decodeResolvingDouble import com.github.avrokotlin.avro4k.decodeResolvingFloat import com.github.avrokotlin.avro4k.decodeResolvingInt import com.github.avrokotlin.avro4k.decodeResolvingLong -import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError +import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder import com.github.avrokotlin.avro4k.internal.decoder.direct.AbstractAvroDirectDecoder.SizeGetter import com.github.avrokotlin.avro4k.internal.getElementIndexNullable import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch @@ -33,8 +33,8 @@ import kotlinx.serialization.builtins.ByteArraySerializer import kotlinx.serialization.descriptors.PolymorphicKind import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractDecoder import kotlinx.serialization.encoding.CompositeDecoder +import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.internal.AbstractCollectionSerializer import kotlinx.serialization.modules.SerializersModule import org.apache.avro.Schema @@ -523,42 +523,19 @@ internal abstract class AbstractAvroDirectDecoder( } private class PolymorphicDecoder( - private val avro: Avro, - private val descriptor: SerialDescriptor, - private val schema: Schema, + avro: Avro, + descriptor: SerialDescriptor, + schema: Schema, private val binaryDecoder: org.apache.avro.io.Decoder, -) : AbstractDecoder() { - override val serializersModule: SerializersModule - get() = avro.serializersModule - - private lateinit var chosenSchema: Schema - - override fun decodeString(): String { - chosenSchema = - if (schema.isUnion) { - schema.types[binaryDecoder.readIndex()] - } else { - schema - } - - return tryFindSerialName(chosenSchema) - ?: throw SerializationException("Unknown schema name ${schema.fullName} for polymorphic type ${descriptor.nonNullSerialName}") +) : AbstractPolymorphicDecoder(avro, descriptor, schema) { + override fun tryFindSerialNameForUnion( + namesAndAliasesToSerialName: Map, + schema: Schema, + ): Pair? { + return tryFindSerialName(namesAndAliasesToSerialName, schema.types[binaryDecoder.readIndex()]) } - private fun tryFindSerialName(schema: Schema): String? { - val namesAndAliasesToSerialName = avro.polymorphicResolver.getFullNamesAndAliasesToSerialName(descriptor) - return namesAndAliasesToSerialName[schema.fullName] - ?: schema.aliases.firstNotNullOfOrNull { namesAndAliasesToSerialName[it] } - } - - override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { + override fun newDecoder(chosenSchema: Schema): Decoder { return AvroValueDirectDecoder(chosenSchema, avro, binaryDecoder) - .decodeSerializableValue(deserializer) - } - - override fun decodeSequentially() = true - - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - throw IllegalIndexedAccessError() } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/CollectionsDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/CollectionsDirectDecoder.kt index 08d74cc..ffd71d7 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/CollectionsDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/CollectionsDirectDecoder.kt @@ -39,8 +39,6 @@ internal class ArrayBlockDirectDecoder( ) : AbstractAvroDirectDecoder(avro, binaryDecoder) { override lateinit var currentWriterSchema: Schema - override fun decodeSequentially() = true - override fun decodeCollectionSize(descriptor: SerialDescriptor): Int { return if (decodeFirstBlock) { binaryDecoder.readArrayStart().toInt() @@ -57,6 +55,8 @@ internal class ArrayBlockDirectDecoder( currentWriterSchema = arraySchema.elementType } + override fun decodeSequentially() = true + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { throw IllegalIndexedAccessError() } @@ -70,8 +70,6 @@ internal class MapBlockDirectDecoder( ) : AbstractAvroDirectDecoder(avro, binaryDecoder) { override lateinit var currentWriterSchema: Schema - override fun decodeSequentially() = true - override fun decodeCollectionSize(descriptor: SerialDescriptor): Int { return if (decodeFirstBlock) { binaryDecoder.readMapStart().toInt() @@ -88,6 +86,8 @@ internal class MapBlockDirectDecoder( currentWriterSchema = if (index % 2 == 0) KEY_SCHEMA else mapSchema.valueType } + override fun decodeSequentially() = true + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { throw IllegalIndexedAccessError() } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt index b368357..5b5446c 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt @@ -37,15 +37,18 @@ internal class RecordDirectDecoder( is DecodingStep.IgnoreOptionalElement -> { // loop again to ignore the optional element } + is DecodingStep.SkipWriterField -> binaryDecoder.skip(field.schema) is DecodingStep.MissingElementValueFailure -> { throw SerializationException("No writer schema field matching element index ${field.elementIndex} in descriptor $descriptor") } + is DecodingStep.DeserializeWriterField -> { currentDecodingStep = field currentWriterSchema = field.schema return field.elementIndex } + is DecodingStep.GetDefaultValue -> { currentDecodingStep = field currentWriterSchema = field.schema @@ -55,16 +58,12 @@ internal class RecordDirectDecoder( } } - private inline fun decodeDefaultIfMissing( + private fun decodeDefault( + element: DecodingStep.GetDefaultValue, deserializer: DeserializationStrategy, - block: () -> T, ): T { - return when (val element = currentDecodingStep) { - is DecodingStep.DeserializeWriterField -> block() - is DecodingStep.GetDefaultValue -> - AvroValueGenericDecoder(avro, element.defaultValue, currentWriterSchema) - .decodeSerializableValue(deserializer) - } + return AvroValueGenericDecoder(avro, element.defaultValue, currentWriterSchema) + .decodeSerializableValue(deserializer) } override fun decodeNotNullMark(): Boolean { @@ -95,62 +94,72 @@ internal class RecordDirectDecoder( } override fun decodeInt(): Int { - return decodeDefaultIfMissing(Int.serializer()) { - super.decodeInt() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeInt() + is DecodingStep.GetDefaultValue -> decodeDefault(element, Int.serializer()) } } override fun decodeLong(): Long { - return decodeDefaultIfMissing(Long.serializer()) { - super.decodeLong() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeLong() + is DecodingStep.GetDefaultValue -> decodeDefault(element, Long.serializer()) } } override fun decodeBoolean(): Boolean { - return decodeDefaultIfMissing(Boolean.serializer()) { - super.decodeBoolean() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeBoolean() + is DecodingStep.GetDefaultValue -> decodeDefault(element, Boolean.serializer()) } } override fun decodeChar(): Char { - return decodeDefaultIfMissing(Char.serializer()) { - super.decodeChar() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeChar() + is DecodingStep.GetDefaultValue -> decodeDefault(element, Char.serializer()) } } override fun decodeString(): String { - return decodeDefaultIfMissing(String.serializer()) { - super.decodeString() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeString() + is DecodingStep.GetDefaultValue -> decodeDefault(element, String.serializer()) } } override fun decodeDouble(): Double { - return decodeDefaultIfMissing(Double.serializer()) { - super.decodeDouble() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeDouble() + is DecodingStep.GetDefaultValue -> decodeDefault(element, Double.serializer()) } } override fun decodeFloat(): Float { - return decodeDefaultIfMissing(Float.serializer()) { - super.decodeFloat() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeFloat() + is DecodingStep.GetDefaultValue -> decodeDefault(element, Float.serializer()) } } override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { - return decodeDefaultIfMissing(deserializer) { - super.decodeSerializableValue(deserializer) + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeSerializableValue(deserializer) + is DecodingStep.GetDefaultValue -> decodeDefault(element, deserializer) } } override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - return decodeDefaultIfMissing(Int.serializer()) { - super.decodeEnum(enumDescriptor) + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeEnum(enumDescriptor) + is DecodingStep.GetDefaultValue -> decodeDefault(element, Int.serializer()) } } override fun decodeBytes(): ByteArray { - return decodeDefaultIfMissing(ByteArraySerializer()) { - super.decodeBytes() + return when (val element = currentDecodingStep) { + is DecodingStep.DeserializeWriterField -> super.decodeBytes() + is DecodingStep.GetDefaultValue -> decodeDefault(element, ByteArraySerializer()) } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt index 38528af..ede3dd5 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/AbstractAvroGenericDecoder.kt @@ -21,6 +21,7 @@ import kotlinx.serialization.encoding.AbstractDecoder import kotlinx.serialization.encoding.CompositeDecoder import kotlinx.serialization.modules.SerializersModule import org.apache.avro.generic.GenericArray +import org.apache.avro.generic.GenericContainer import org.apache.avro.generic.GenericEnumSymbol import org.apache.avro.generic.GenericFixed import org.apache.avro.generic.IndexedRecord @@ -92,8 +93,8 @@ internal abstract class AbstractAvroGenericDecoder : AbstractDecoder(), AvroDeco is PolymorphicKind -> when (val value = decodeValue()) { - is IndexedRecord -> PolymorphicGenericDecoder(avro, descriptor, value) - else -> throw BadDecodedValueError(value, descriptor.kind, IndexedRecord::class) + is GenericContainer -> PolymorphicGenericDecoder(avro, descriptor, value.schema, value) + else -> PolymorphicGenericDecoder(avro, descriptor, currentWriterSchema, value) } else -> throw SerializationException("Unsupported descriptor for structure decoding: $descriptor") diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/PolymorphicGenericDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/PolymorphicGenericDecoder.kt index bad86e1..213854f 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/PolymorphicGenericDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/PolymorphicGenericDecoder.kt @@ -1,42 +1,25 @@ package com.github.avrokotlin.avro4k.internal.decoder.generic import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError -import kotlinx.serialization.DeserializationStrategy -import kotlinx.serialization.SerializationException +import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractDecoder -import kotlinx.serialization.modules.SerializersModule +import kotlinx.serialization.encoding.Decoder import org.apache.avro.Schema -import org.apache.avro.generic.IndexedRecord internal class PolymorphicGenericDecoder( - private val avro: Avro, - private val descriptor: SerialDescriptor, - private val value: IndexedRecord, -) : AbstractDecoder() { - override val serializersModule: SerializersModule - get() = avro.serializersModule - - override fun decodeString(): String { - return tryFindSerialName(value.schema) - ?: throw SerializationException("Unknown schema name ${value.schema.fullName} for polymorphic type ${descriptor.serialName}") - } - - private fun tryFindSerialName(schema: Schema): String? { - val namesAndAliasesToSerialName = avro.polymorphicResolver.getFullNamesAndAliasesToSerialName(descriptor) - return namesAndAliasesToSerialName[schema.fullName] - ?: schema.aliases.firstNotNullOfOrNull { namesAndAliasesToSerialName[it] } + avro: Avro, + descriptor: SerialDescriptor, + schema: Schema, + private val value: Any?, +) : AbstractPolymorphicDecoder(avro, descriptor, schema) { + override fun tryFindSerialNameForUnion( + namesAndAliasesToSerialName: Map, + schema: Schema, + ): Pair? { + return schema.types.firstNotNullOfOrNull { tryFindSerialName(namesAndAliasesToSerialName, it) } } - override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { - return AvroValueGenericDecoder(avro, value, value.schema) - .decodeSerializableValue(deserializer) - } - - override fun decodeSequentially() = true - - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - throw IllegalIndexedAccessError() + override fun newDecoder(chosenSchema: Schema): Decoder { + return AvroValueGenericDecoder(avro, value, chosenSchema) } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt index bedcbad..5c0e1fc 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt @@ -34,7 +34,7 @@ internal fun RecordDirectEncoder( */ private class RecordSequentialDirectEncoder( private val classDescriptor: ClassDescriptorForWriterSchema, - protected val schema: Schema, + private val schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, ) : AbstractAvroDirectEncoder(avro, binaryEncoder) { @@ -63,7 +63,7 @@ private class RecordSequentialDirectEncoder( } override fun endStructure(descriptor: SerialDescriptor) { - if (descriptor.elementsCount < classDescriptor.encodingSteps.size) { + if (classDescriptor.hasMissingWriterField) { throw SerializationException("The descriptor is not writing all the expected fields of writer schema. Schema: $schema, descriptor: $descriptor") } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt index a35d7f9..f6ace38 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt @@ -31,6 +31,10 @@ internal inline fun SerialDescriptor.findElementAnnotat internal val SerialDescriptor.nonNullSerialName: String get() = nonNullOriginal.serialName +internal fun Schema.isNamedSchema(): Boolean { + return this.type == Schema.Type.RECORD || this.type == Schema.Type.ENUM || this.type == Schema.Type.FIXED +} + internal fun Schema.isFullNameOrAliasMatch(descriptor: SerialDescriptor): Boolean { return isFullNameMatch(descriptor.nonNullSerialName) || descriptor.findAnnotation()?.value?.any { isFullNameMatch(it) } == true }