Skip to content

Commit

Permalink
little refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuckame committed Jun 25, 2024
1 parent 0f4fcaa commit 386dccb
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ internal class ClassDescriptorForWriterSchema(
*/
val encodingSteps: Array<EncodingStep>,
) {
val hasMissingWriterField by lazy { encodingSteps.any { it is EncodingStep.MissingWriterFieldFailure } }

companion object {
val EMPTY =
ClassDescriptorForWriterSchema(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.github.avrokotlin.avro4k.internal.decoder

import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError
import com.github.avrokotlin.avro4k.internal.isNamedSchema
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

internal abstract class AbstractPolymorphicDecoder(
protected val avro: Avro,
private val descriptor: SerialDescriptor,
private val schema: Schema,
) : AbstractDecoder() {
final override val serializersModule: SerializersModule
get() = avro.serializersModule

private lateinit var chosenSchema: Schema

final override fun decodeString(): String {
return tryFindSerialName()?.also { chosenSchema = it.second }?.first
?: throw SerializationException("Unknown schema name '${schema.fullName}' for polymorphic type ${descriptor.serialName}. Full schema: $schema")
}

private fun tryFindSerialName(): Pair<String, Schema>? {
val namesAndAliasesToSerialName: Map<String, String> = avro.polymorphicResolver.getFullNamesAndAliasesToSerialName(descriptor)
return tryFindSerialName(namesAndAliasesToSerialName, schema)
}

protected abstract fun tryFindSerialNameForUnion(
namesAndAliasesToSerialName: Map<String, String>,
schema: Schema,
): Pair<String, Schema>?

protected fun tryFindSerialName(
namesAndAliasesToSerialName: Map<String, String>,
schema: Schema,
): Pair<String, Schema>? {
if (schema.isUnion) {
return tryFindSerialNameForUnion(namesAndAliasesToSerialName, schema)
}
return (
namesAndAliasesToSerialName[schema.fullName]
?: schema.takeIf { it.isNamedSchema() }?.aliases?.firstNotNullOfOrNull { namesAndAliasesToSerialName[it] }
)
?.let { it to schema }
}

final override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
return newDecoder(chosenSchema)
.decodeSerializableValue(deserializer)
}

abstract fun newDecoder(chosenSchema: Schema): Decoder

final override fun decodeSequentially() = true

final override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import com.github.avrokotlin.avro4k.decodeResolvingDouble
import com.github.avrokotlin.avro4k.decodeResolvingFloat
import com.github.avrokotlin.avro4k.decodeResolvingInt
import com.github.avrokotlin.avro4k.decodeResolvingLong
import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError
import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError
import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder
import com.github.avrokotlin.avro4k.internal.decoder.direct.AbstractAvroDirectDecoder.SizeGetter
import com.github.avrokotlin.avro4k.internal.getElementIndexNullable
import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch
Expand All @@ -33,8 +33,8 @@ import kotlinx.serialization.builtins.ByteArraySerializer
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.encoding.CompositeDecoder
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.internal.AbstractCollectionSerializer
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema
Expand Down Expand Up @@ -523,42 +523,19 @@ internal abstract class AbstractAvroDirectDecoder(
}

private class PolymorphicDecoder(
private val avro: Avro,
private val descriptor: SerialDescriptor,
private val schema: Schema,
avro: Avro,
descriptor: SerialDescriptor,
schema: Schema,
private val binaryDecoder: org.apache.avro.io.Decoder,
) : AbstractDecoder() {
override val serializersModule: SerializersModule
get() = avro.serializersModule

private lateinit var chosenSchema: Schema

override fun decodeString(): String {
chosenSchema =
if (schema.isUnion) {
schema.types[binaryDecoder.readIndex()]
} else {
schema
}

return tryFindSerialName(chosenSchema)
?: throw SerializationException("Unknown schema name ${schema.fullName} for polymorphic type ${descriptor.nonNullSerialName}")
) : AbstractPolymorphicDecoder(avro, descriptor, schema) {
override fun tryFindSerialNameForUnion(
namesAndAliasesToSerialName: Map<String, String>,
schema: Schema,
): Pair<String, Schema>? {
return tryFindSerialName(namesAndAliasesToSerialName, schema.types[binaryDecoder.readIndex()])
}

private fun tryFindSerialName(schema: Schema): String? {
val namesAndAliasesToSerialName = avro.polymorphicResolver.getFullNamesAndAliasesToSerialName(descriptor)
return namesAndAliasesToSerialName[schema.fullName]
?: schema.aliases.firstNotNullOfOrNull { namesAndAliasesToSerialName[it] }
}

override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
override fun newDecoder(chosenSchema: Schema): Decoder {
return AvroValueDirectDecoder(chosenSchema, avro, binaryDecoder)
.decodeSerializableValue(deserializer)
}

override fun decodeSequentially() = true

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ internal class ArrayBlockDirectDecoder(
) : AbstractAvroDirectDecoder(avro, binaryDecoder) {
override lateinit var currentWriterSchema: Schema

override fun decodeSequentially() = true

override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
return if (decodeFirstBlock) {
binaryDecoder.readArrayStart().toInt()
Expand All @@ -57,6 +55,8 @@ internal class ArrayBlockDirectDecoder(
currentWriterSchema = arraySchema.elementType
}

override fun decodeSequentially() = true

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
}
Expand All @@ -70,8 +70,6 @@ internal class MapBlockDirectDecoder(
) : AbstractAvroDirectDecoder(avro, binaryDecoder) {
override lateinit var currentWriterSchema: Schema

override fun decodeSequentially() = true

override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
return if (decodeFirstBlock) {
binaryDecoder.readMapStart().toInt()
Expand All @@ -88,6 +86,8 @@ internal class MapBlockDirectDecoder(
currentWriterSchema = if (index % 2 == 0) KEY_SCHEMA else mapSchema.valueType
}

override fun decodeSequentially() = true

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@ internal class RecordDirectDecoder(
is DecodingStep.IgnoreOptionalElement -> {
// loop again to ignore the optional element
}

is DecodingStep.SkipWriterField -> binaryDecoder.skip(field.schema)
is DecodingStep.MissingElementValueFailure -> {
throw SerializationException("No writer schema field matching element index ${field.elementIndex} in descriptor $descriptor")
}

is DecodingStep.DeserializeWriterField -> {
currentDecodingStep = field
currentWriterSchema = field.schema
return field.elementIndex
}

is DecodingStep.GetDefaultValue -> {
currentDecodingStep = field
currentWriterSchema = field.schema
Expand All @@ -55,16 +58,12 @@ internal class RecordDirectDecoder(
}
}

private inline fun <T> decodeDefaultIfMissing(
private fun <T> decodeDefault(
element: DecodingStep.GetDefaultValue,
deserializer: DeserializationStrategy<T>,
block: () -> T,
): T {
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> block()
is DecodingStep.GetDefaultValue ->
AvroValueGenericDecoder(avro, element.defaultValue, currentWriterSchema)
.decodeSerializableValue(deserializer)
}
return AvroValueGenericDecoder(avro, element.defaultValue, currentWriterSchema)
.decodeSerializableValue(deserializer)
}

override fun decodeNotNullMark(): Boolean {
Expand Down Expand Up @@ -95,62 +94,72 @@ internal class RecordDirectDecoder(
}

override fun decodeInt(): Int {
return decodeDefaultIfMissing(Int.serializer()) {
super.decodeInt()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeInt()
is DecodingStep.GetDefaultValue -> decodeDefault(element, Int.serializer())
}
}

override fun decodeLong(): Long {
return decodeDefaultIfMissing(Long.serializer()) {
super.decodeLong()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeLong()
is DecodingStep.GetDefaultValue -> decodeDefault(element, Long.serializer())
}
}

override fun decodeBoolean(): Boolean {
return decodeDefaultIfMissing(Boolean.serializer()) {
super.decodeBoolean()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeBoolean()
is DecodingStep.GetDefaultValue -> decodeDefault(element, Boolean.serializer())
}
}

override fun decodeChar(): Char {
return decodeDefaultIfMissing(Char.serializer()) {
super.decodeChar()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeChar()
is DecodingStep.GetDefaultValue -> decodeDefault(element, Char.serializer())
}
}

override fun decodeString(): String {
return decodeDefaultIfMissing(String.serializer()) {
super.decodeString()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeString()
is DecodingStep.GetDefaultValue -> decodeDefault(element, String.serializer())
}
}

override fun decodeDouble(): Double {
return decodeDefaultIfMissing(Double.serializer()) {
super.decodeDouble()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeDouble()
is DecodingStep.GetDefaultValue -> decodeDefault(element, Double.serializer())
}
}

override fun decodeFloat(): Float {
return decodeDefaultIfMissing(Float.serializer()) {
super.decodeFloat()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeFloat()
is DecodingStep.GetDefaultValue -> decodeDefault(element, Float.serializer())
}
}

override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
return decodeDefaultIfMissing(deserializer) {
super.decodeSerializableValue(deserializer)
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeSerializableValue(deserializer)
is DecodingStep.GetDefaultValue -> decodeDefault(element, deserializer)
}
}

override fun decodeEnum(enumDescriptor: SerialDescriptor): Int {
return decodeDefaultIfMissing(Int.serializer()) {
super.decodeEnum(enumDescriptor)
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeEnum(enumDescriptor)
is DecodingStep.GetDefaultValue -> decodeDefault(element, Int.serializer())
}
}

override fun decodeBytes(): ByteArray {
return decodeDefaultIfMissing(ByteArraySerializer()) {
super.decodeBytes()
return when (val element = currentDecodingStep) {
is DecodingStep.DeserializeWriterField -> super.decodeBytes()
is DecodingStep.GetDefaultValue -> decodeDefault(element, ByteArraySerializer())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.encoding.CompositeDecoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.generic.GenericArray
import org.apache.avro.generic.GenericContainer
import org.apache.avro.generic.GenericEnumSymbol
import org.apache.avro.generic.GenericFixed
import org.apache.avro.generic.IndexedRecord
Expand Down Expand Up @@ -92,8 +93,8 @@ internal abstract class AbstractAvroGenericDecoder : AbstractDecoder(), AvroDeco

is PolymorphicKind ->
when (val value = decodeValue()) {
is IndexedRecord -> PolymorphicGenericDecoder(avro, descriptor, value)
else -> throw BadDecodedValueError(value, descriptor.kind, IndexedRecord::class)
is GenericContainer -> PolymorphicGenericDecoder(avro, descriptor, value.schema, value)
else -> PolymorphicGenericDecoder(avro, descriptor, currentWriterSchema, value)
}

else -> throw SerializationException("Unsupported descriptor for structure decoding: $descriptor")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,42 +1,25 @@
package com.github.avrokotlin.avro4k.internal.decoder.generic

import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.SerializationException
import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.encoding.Decoder
import org.apache.avro.Schema
import org.apache.avro.generic.IndexedRecord

internal class PolymorphicGenericDecoder(
private val avro: Avro,
private val descriptor: SerialDescriptor,
private val value: IndexedRecord,
) : AbstractDecoder() {
override val serializersModule: SerializersModule
get() = avro.serializersModule

override fun decodeString(): String {
return tryFindSerialName(value.schema)
?: throw SerializationException("Unknown schema name ${value.schema.fullName} for polymorphic type ${descriptor.serialName}")
}

private fun tryFindSerialName(schema: Schema): String? {
val namesAndAliasesToSerialName = avro.polymorphicResolver.getFullNamesAndAliasesToSerialName(descriptor)
return namesAndAliasesToSerialName[schema.fullName]
?: schema.aliases.firstNotNullOfOrNull { namesAndAliasesToSerialName[it] }
avro: Avro,
descriptor: SerialDescriptor,
schema: Schema,
private val value: Any?,
) : AbstractPolymorphicDecoder(avro, descriptor, schema) {
override fun tryFindSerialNameForUnion(
namesAndAliasesToSerialName: Map<String, String>,
schema: Schema,
): Pair<String, Schema>? {
return schema.types.firstNotNullOfOrNull { tryFindSerialName(namesAndAliasesToSerialName, it) }
}

override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
return AvroValueGenericDecoder(avro, value, value.schema)
.decodeSerializableValue(deserializer)
}

override fun decodeSequentially() = true

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
throw IllegalIndexedAccessError()
override fun newDecoder(chosenSchema: Schema): Decoder {
return AvroValueGenericDecoder(avro, value, chosenSchema)
}
}
Loading

0 comments on commit 386dccb

Please sign in to comment.