From a4223acc01285b41162c005d7c1926ea2de8f10f Mon Sep 17 00:00:00 2001 From: Chuckame Date: Fri, 3 May 2024 14:25:10 +0200 Subject: [PATCH] feat: Support everything at root level Also: - support nullable array items - remove TypeNameStrategy - support optional fields without avro default #198 --- .../avrokotlin/avro4k/AnnotationExtractor.kt | 38 -- .../com/github/avrokotlin/avro4k/Avro.kt | 30 +- .../avrokotlin/avro4k/AvroConfiguration.kt | 10 +- .../avro4k/AvroObjectContainerFile.kt | 2 +- .../avrokotlin/avro4k/SerialDescriptor.kt | 25 -- .../avrokotlin/avro4k/decoder/ArrayDecoder.kt | 46 +++ .../avrokotlin/avro4k/decoder/AvroDecoder.kt | 15 + .../avro4k/decoder/AvroTaggedDecoder.kt | 233 ++++++++++++ .../avro4k/decoder/AvroValueDecoder.kt | 30 ++ .../avro4k/decoder/ByteArrayDecoder.kt | 28 +- .../avro4k/decoder/FromAvroValue.kt | 17 - .../avrokotlin/avro4k/decoder/ListDecoder.kt | 100 ----- .../avrokotlin/avro4k/decoder/MapDecoder.kt | 166 +++------ .../avro4k/decoder/PolymorphicDecoder.kt | 56 +++ .../avro4k/decoder/RecordDecoder.kt | 212 ++--------- .../avro4k/decoder/RootRecordDecoder.kt | 42 --- .../avrokotlin/avro4k/decoder/UnionDecoder.kt | 62 ---- .../avrokotlin/avro4k/encoder/ArrayEncoder.kt | 41 +++ .../avrokotlin/avro4k/encoder/AvroEncoder.kt | 81 ++++ .../avro4k/encoder/AvroTaggedEncoder.kt | 348 ++++++++++++++++++ .../avro4k/encoder/AvroValueEncoder.kt | 34 ++ .../avro4k/encoder/ByteArrayEncoder.kt | 41 --- .../avrokotlin/avro4k/encoder/BytesEncoder.kt | 36 ++ .../avrokotlin/avro4k/encoder/FieldEncoder.kt | 18 - .../avrokotlin/avro4k/encoder/FixedEncoder.kt | 45 +++ .../avrokotlin/avro4k/encoder/ListEncoder.kt | 84 ----- .../avrokotlin/avro4k/encoder/MapEncoder.kt | 162 ++------ .../avro4k/encoder/PolymorphicEncoder.kt | 36 ++ .../avro4k/encoder/RecordEncoder.kt | 148 +++----- .../avro4k/encoder/RootRecordEncoder.kt | 32 -- .../avrokotlin/avro4k/encoder/ToAvroValue.kt | 45 --- .../avrokotlin/avro4k/encoder/UnionEncoder.kt | 44 --- 
.../avrokotlin/avro4k/internal/NumberUtils.kt | 88 +++++ .../avro4k/internal/RecordResolver.kt | 190 ++++++++++ .../avro4k/internal/UnionResolver.kt | 39 ++ .../avrokotlin/avro4k/internal/exceptions.kt | 62 ++++ .../schema/AvroSchemaGenerationException.kt | 5 - .../avrokotlin/avro4k/schema/ClassVisitor.kt | 13 +- .../avro4k/schema/FieldNamingStrategy.kt | 54 ++- .../avro4k/schema/InlineClassVisitor.kt | 4 +- .../avrokotlin/avro4k/schema/ListVisitor.kt | 4 +- .../avrokotlin/avro4k/schema/MapVisitor.kt | 11 +- .../avro4k/schema/PolymorphicVisitor.kt | 5 +- .../avro4k/schema/SerialDescriptorVisitor.kt | 2 +- .../avro4k/schema/TypeNamingStrategy.kt | 33 -- .../avrokotlin/avro4k/schema/ValueVisitor.kt | 34 +- .../avro4k/schema/VisitorContext.kt | 56 +-- .../avrokotlin/avro4k/schema/helpers.kt | 52 ++- .../avro4k/serializer/AvroSerializer.kt | 66 ++-- .../avro4k/serializer/BigDecimalSerializer.kt | 116 +++--- .../avro4k/serializer/BigIntegerSerializer.kt | 48 +-- .../avro4k/serializer/URLSerializer.kt | 31 +- .../avro4k/serializer/UUIDSerializer.kt | 22 +- .../avrokotlin/avro4k/serializer/date.kt | 240 ++++++++---- .../avrokotlin/avro4k/serializer/helpers.kt | 20 - .../avrokotlin/avro4k/AvroAssertions.kt | 11 +- .../avro4k/AvroObjectContainerFileTest.kt | 36 +- .../avrokotlin/avro4k/RecordBuilderForTest.kt | 9 +- .../avro4k/encoding/ArrayEncodingTest.kt | 52 ++- .../avro4k/encoding/AvroAliasEncodingTest.kt | 71 +++- .../encoding/AvroDefaultEncodingTest.kt | 1 + .../avro4k/encoding/AvroFixedEncodingTest.kt | 107 ++++++ .../avro4k/encoding/BytesEncodingTest.kt | 42 +-- .../avro4k/encoding/EnumEncodingTest.kt | 57 ++- .../encoding/InlineClassEncodingTest.kt | 31 -- .../encoding/LogicalTypesEncodingTest.kt | 15 + .../MapEncodingTest.kt} | 3 +- .../encoding/NestedClassEncodingTest.kt | 42 ++- .../avro4k/encoding/PrimitiveEncodingTest.kt | 52 +++ .../avro4k/encoding/RecordEncodingTest.kt | 20 + .../encoding/SealedClassEncodingTest.kt | 7 + 
.../avro4k/schema/AvroFixedSchemaTest.kt | 52 --- .../avro4k/schema/UnionSchemaTest.kt | 1 + src/test/resources/class_of_list_of_maps.json | 2 +- src/test/resources/list_of_maps.json | 2 +- src/test/resources/map_boolean_null.json | 2 +- src/test/resources/map_int.json | 2 +- src/test/resources/map_record.json | 2 +- src/test/resources/map_set_nested.json | 2 +- src/test/resources/set_of_maps.json | 2 +- 80 files changed, 2411 insertions(+), 1714 deletions(-) delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/AnnotationExtractor.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/SerialDescriptor.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ArrayDecoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDecoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroTaggedDecoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroValueDecoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/PolymorphicDecoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ArrayEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroTaggedEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroValueEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ByteArrayEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/BytesEncoder.kt delete mode 
100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/FieldEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/FixedEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/PolymorphicEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RootRecordEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ToAvroValue.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/UnionResolver.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/schema/AvroSchemaGenerationException.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/schema/TypeNamingStrategy.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/serializer/helpers.kt create mode 100644 src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroFixedEncodingTest.kt delete mode 100644 src/test/kotlin/com/github/avrokotlin/avro4k/encoding/InlineClassEncodingTest.kt rename src/test/kotlin/com/github/avrokotlin/avro4k/{schema/MapSchemaTest.kt => encoding/MapEncodingTest.kt} (98%) delete mode 100644 src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroFixedSchemaTest.kt diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AnnotationExtractor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AnnotationExtractor.kt deleted file mode 100644 index e47e11a1..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AnnotationExtractor.kt +++ /dev/null @@ -1,38 +0,0 @@ -package 
com.github.avrokotlin.avro4k - -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.descriptors.SerialDescriptor - -@ExperimentalSerializationApi -class AnnotationExtractor(private val annotations: List) { - companion object { - fun entity(descriptor: SerialDescriptor) = - AnnotationExtractor( - descriptor.annotations - ) - - operator fun invoke( - descriptor: SerialDescriptor, - index: Int, - ): AnnotationExtractor = AnnotationExtractor(descriptor.getElementAnnotations(index)) - } - - fun fixed(): Int? = annotations.filterIsInstance().firstOrNull()?.size - - fun doc(): String? = annotations.filterIsInstance().firstOrNull()?.value - - fun aliases(): List = - ( - annotations.firstNotNullOfOrNull { - it as? AvroAlias - }?.value ?: emptyArray() - ).asList() - - fun props(): List> = annotations.filterIsInstance().map { it.key to it.value } - - fun jsonProps(): List> = annotations.filterIsInstance().map { it.key to it.jsonValue } - - fun default(): String? = annotations.filterIsInstance().firstOrNull()?.value - - fun enumDefault(): String? 
= annotations.filterIsInstance().firstOrNull()?.value -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt index 8c5ab12f..9b0e1b14 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt @@ -1,9 +1,10 @@ package com.github.avrokotlin.avro4k -import com.github.avrokotlin.avro4k.decoder.RootRecordDecoder -import com.github.avrokotlin.avro4k.encoder.RootRecordEncoder +import com.github.avrokotlin.avro4k.decoder.AvroValueDecoder +import com.github.avrokotlin.avro4k.encoder.AvroValueEncoder +import com.github.avrokotlin.avro4k.internal.RecordResolver +import com.github.avrokotlin.avro4k.internal.UnionResolver import com.github.avrokotlin.avro4k.schema.FieldNamingStrategy -import com.github.avrokotlin.avro4k.schema.TypeNamingStrategy import com.github.avrokotlin.avro4k.schema.ValueVisitor import com.github.avrokotlin.avro4k.serializer.BigDecimalSerializer import com.github.avrokotlin.avro4k.serializer.BigIntegerSerializer @@ -16,7 +17,6 @@ import com.github.avrokotlin.avro4k.serializer.URLSerializer import com.github.avrokotlin.avro4k.serializer.UUIDSerializer import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerializationException import kotlinx.serialization.SerializationStrategy import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.modules.EmptySerializersModule @@ -27,7 +27,6 @@ import kotlinx.serialization.serializer import org.apache.avro.Schema import org.apache.avro.generic.GenericContainer import org.apache.avro.generic.GenericDatumReader -import org.apache.avro.generic.GenericRecord import org.apache.avro.io.DecoderFactory import org.apache.avro.io.EncoderFactory import org.apache.avro.reflect.ReflectDatumWriter @@ -45,6 +44,8 @@ sealed class Avro( val serializersModule: SerializersModule, 
) { private val schemaCache: MutableMap = ConcurrentHashMap() + internal val recordResolver = RecordResolver(this) + internal val unionResolver = UnionResolver() companion object Default : Avro( AvroConfiguration(), @@ -102,7 +103,7 @@ sealed class Avro( value: T, ): Any? { var result: Any? = null - RootRecordEncoder(writerSchema, serializersModule, configuration) { + AvroValueEncoder(this, writerSchema) { result = it }.encodeSerializableValue(serializer, value) return result @@ -118,8 +119,7 @@ sealed class Avro( EncodedAs.BINARY -> DecoderFactory.get().binaryDecoder(inputStream, null) EncodedAs.JSON_COMPACT, EncodedAs.JSON_PRETTY -> DecoderFactory.get().jsonDecoder(writerSchema, inputStream) } - val readerSchema = schema(deserializer.descriptor) - val genericData = GenericDatumReader(writerSchema, readerSchema).read(null, avroDecoder) + val genericData = GenericDatumReader(writerSchema).read(null, avroDecoder) return decodeFromGenericData(writerSchema, deserializer, genericData) } @@ -136,11 +136,7 @@ sealed class Avro( deserializer: DeserializationStrategy, value: Any?, ): T { - return RootRecordDecoder( - (value as? GenericRecord?) 
?: throw SerializationException("Expected a GenericRecord, actual: ${value?.let { it::class.qualifiedName }}"), - serializersModule, - configuration - ) + return AvroValueDecoder(this, value, writerSchema) .decodeSerializableValue(deserializer) } } @@ -155,7 +151,6 @@ fun Avro( } class AvroBuilder internal constructor(avro: Avro) { - var typeNamingStrategy: TypeNamingStrategy = avro.configuration.typeNamingStrategy var fieldNamingStrategy: FieldNamingStrategy = avro.configuration.fieldNamingStrategy var implicitNulls: Boolean = avro.configuration.implicitNulls var encodedAs: EncodedAs = avro.configuration.encodedAs @@ -163,10 +158,9 @@ class AvroBuilder internal constructor(avro: Avro) { fun build() = AvroConfiguration( - typeNamingStrategy = this.typeNamingStrategy, - fieldNamingStrategy = this.fieldNamingStrategy, - implicitNulls = this.implicitNulls, - encodedAs = this.encodedAs + fieldNamingStrategy = fieldNamingStrategy, + implicitNulls = implicitNulls, + encodedAs = encodedAs ) } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt index b07446fa..2bd8ff58 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt @@ -1,15 +1,9 @@ package com.github.avrokotlin.avro4k import com.github.avrokotlin.avro4k.schema.FieldNamingStrategy -import com.github.avrokotlin.avro4k.schema.TypeNamingStrategy +import kotlinx.serialization.ExperimentalSerializationApi data class AvroConfiguration( - /** - * The naming strategy to use for complex types (record, enum and fixed types). - * - * Default: [TypeNamingStrategy.Builtins.FullyQualified] - */ - val typeNamingStrategy: TypeNamingStrategy = TypeNamingStrategy.Builtins.FullyQualified, /** * The naming strategy to use for records' fields name. 
* @@ -26,9 +20,11 @@ data class AvroConfiguration( * * @see EncodedAs */ + @ExperimentalSerializationApi val encodedAs: EncodedAs = EncodedAs.BINARY, ) +@ExperimentalSerializationApi enum class EncodedAs { BINARY, JSON_COMPACT, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFile.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFile.kt index 266dcf27..21b192e5 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFile.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFile.kt @@ -134,7 +134,7 @@ private class KotlinxSerializationDatumWriter( } } -internal class KotlinxSerializationDatumReader( +private class KotlinxSerializationDatumReader( private val deserializer: DeserializationStrategy, private val avro: Avro, ) : DatumReader { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/SerialDescriptor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/SerialDescriptor.kt deleted file mode 100644 index 1ac735f0..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/SerialDescriptor.kt +++ /dev/null @@ -1,25 +0,0 @@ -package com.github.avrokotlin.avro4k - -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.descriptors.PolymorphicKind -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.SerialKind -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.descriptors.elementDescriptors -import kotlinx.serialization.descriptors.getPolymorphicDescriptors -import kotlinx.serialization.modules.SerializersModule - -@ExperimentalSerializationApi -fun SerialDescriptor.possibleSerializationSubclasses(serializersModule: SerializersModule): List { - return when (this.kind) { - StructureKind.CLASS, StructureKind.OBJECT -> listOf(this) - PolymorphicKind.SEALED -> - elementDescriptors.filter { it.kind == SerialKind.CONTEXTUAL } - .flatMap { it.elementDescriptors } - 
.flatMap { it.possibleSerializationSubclasses(serializersModule) } - PolymorphicKind.OPEN -> - serializersModule.getPolymorphicDescriptors(this) - .flatMap { it.possibleSerializationSubclasses(serializersModule) } - else -> throw UnsupportedOperationException("Can't get possible serialization subclasses for the SerialDescriptor of kind ${this.kind}.") - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ArrayDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ArrayDecoder.kt new file mode 100644 index 00000000..7af49f40 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ArrayDecoder.kt @@ -0,0 +1,46 @@ +package com.github.avrokotlin.avro4k.decoder + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.internal.DecodedNullError +import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.Schema + +internal class ArrayDecoder( + private val collection: Collection, + private val writerSchema: Schema, + override val avro: Avro, +) : AvroTaggedDecoder() { + private val iterator = collection.iterator() + private val elementType = if (writerSchema.type == Schema.Type.BYTES) writerSchema else writerSchema.elementType + + private var currentItem: Any? 
= null + private var decodedNullMark = false + + override val Schema.writerSchema: Schema + get() = this@ArrayDecoder.elementType + + override fun SerialDescriptor.getTag(index: Int): Schema { + return elementType + } + + override fun decodeTaggedNotNullMark(tag: Schema): Boolean { + decodedNullMark = true + currentItem = iterator.next() + return currentItem != null + } + + override fun decodeTaggedValue(tag: Schema): Any { + val value = if (decodedNullMark) currentItem else iterator.next() + decodedNullMark = false + return value ?: throw DecodedNullError() + } + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + throw IllegalIndexedAccessError() + } + + override fun decodeCollectionSize(descriptor: SerialDescriptor) = collection.size + + override fun decodeSequentially() = true +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDecoder.kt new file mode 100644 index 00000000..3f59ed47 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDecoder.kt @@ -0,0 +1,15 @@ +package com.github.avrokotlin.avro4k.decoder + +import kotlinx.serialization.encoding.Decoder +import org.apache.avro.Schema +import org.apache.avro.generic.GenericFixed + +interface AvroDecoder : Decoder { + val currentWriterSchema: Schema + + fun decodeBytes(): ByteArray + + fun decodeFixed(): GenericFixed + + fun decodeValue(): Any +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroTaggedDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroTaggedDecoder.kt new file mode 100644 index 00000000..a4b66d87 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroTaggedDecoder.kt @@ -0,0 +1,233 @@ +package com.github.avrokotlin.avro4k.decoder + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.AvroEnumDefault +import 
com.github.avrokotlin.avro4k.internal.BadDecodedValueError +import com.github.avrokotlin.avro4k.internal.toByteExact +import com.github.avrokotlin.avro4k.internal.toDoubleExact +import com.github.avrokotlin.avro4k.internal.toFloatExact +import com.github.avrokotlin.avro4k.internal.toIntExact +import com.github.avrokotlin.avro4k.internal.toLongExact +import com.github.avrokotlin.avro4k.internal.toShortExact +import com.github.avrokotlin.avro4k.schema.findAnnotation +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.SerializationException +import kotlinx.serialization.builtins.ByteArraySerializer +import kotlinx.serialization.descriptors.PolymorphicKind +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.SerialKind +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.encoding.CompositeDecoder +import kotlinx.serialization.internal.TaggedDecoder +import kotlinx.serialization.modules.SerializersModule +import org.apache.avro.Schema +import org.apache.avro.generic.GenericArray +import org.apache.avro.generic.GenericEnumSymbol +import org.apache.avro.generic.GenericFixed +import org.apache.avro.generic.IndexedRecord +import java.math.BigDecimal +import java.nio.ByteBuffer + +@OptIn(InternalSerializationApi::class) +internal abstract class AvroTaggedDecoder : TaggedDecoder(), AvroDecoder { + protected abstract val avro: Avro + + protected abstract val Tag.writerSchema: Schema + + abstract override fun decodeTaggedNotNullMark(tag: Tag): Boolean + + abstract override fun decodeTaggedValue(tag: Tag): Any + + override val currentWriterSchema: Schema + get() = currentTag.writerSchema + + override val serializersModule: SerializersModule + get() = avro.serializersModule + + override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { + 
if (deserializer.descriptor == ByteArraySerializer().descriptor) { + // fast-path for ByteArray fields, to avoid slow-path with ArrayDecoder + @Suppress("UNCHECKED_CAST") + return decodeBytes() as T + } + return super.decodeSerializableValue(deserializer) + } + + @Suppress("UNCHECKED_CAST") + override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { + return when (descriptor.kind) { + StructureKind.LIST -> + when (val value = decodeValue()) { + is GenericArray<*> -> + ArrayDecoder( + collection = value, + writerSchema = value.schema, + avro = avro + ) + + is Collection<*> -> + ArrayDecoder( + collection = value, + writerSchema = currentTag.writerSchema, + avro = avro + ) + + // TODO should be removed as byte arrays are handled by fast-path in decodeSerializableValue + // and collection of bytes should be handled as normal arrays of byte and not as native bytes + is ByteBuffer -> ByteArrayDecoder(avro, value.array()) + + else -> throw BadDecodedValueError(value, StructureKind.LIST, GenericArray::class, Collection::class, ByteBuffer::class) + } + + StructureKind.MAP -> + when (val value = decodeValue()) { + is Map<*, *> -> + MapDecoder( + value as Map, + currentTag.writerSchema, + avro + ) + + else -> throw BadDecodedValueError(value, StructureKind.MAP, Map::class) + } + + StructureKind.CLASS, StructureKind.OBJECT -> + when (val value = decodeValue()) { + is IndexedRecord -> RecordDecoder(value, descriptor, avro) + else -> throw BadDecodedValueError(value, descriptor.kind, IndexedRecord::class) + } + + is PolymorphicKind -> + when (val value = decodeValue()) { + is IndexedRecord -> PolymorphicDecoder(avro, descriptor, value) + else -> throw BadDecodedValueError(value, descriptor.kind, IndexedRecord::class) + } + + else -> throw SerializationException("Unsupported descriptor for structure decoding: $descriptor") + } + } + + override fun decodeValue() = decodeTaggedValue(currentTag) + + override fun decodeTaggedBoolean(tag: Tag): Boolean { + return 
when (val value = decodeTaggedValue(tag)) { + is Boolean -> value + 1 -> true + 0 -> false + is CharSequence -> value.toString().toBoolean() + else -> throw BadDecodedValueError(value, PrimitiveKind.BOOLEAN, Boolean::class, Int::class, CharSequence::class) + } + } + + override fun decodeTaggedByte(tag: Tag): Byte { + return when (val value = decodeTaggedValue(tag)) { + is Int -> value.toByteExact() + is Long -> value.toByteExact() + is BigDecimal -> value.toByteExact() + is CharSequence -> value.toString().toByte() + else -> throw BadDecodedValueError(value, PrimitiveKind.BYTE, Int::class, Long::class, BigDecimal::class, CharSequence::class) + } + } + + override fun decodeTaggedShort(tag: Tag): Short { + return when (val value = decodeTaggedValue(tag)) { + is Int -> value.toShortExact() + is Long -> value.toShortExact() + is BigDecimal -> value.toShortExact() + is CharSequence -> value.toString().toShort() + else -> throw BadDecodedValueError(value, PrimitiveKind.SHORT, Int::class, Long::class, BigDecimal::class, CharSequence::class) + } + } + + override fun decodeTaggedInt(tag: Tag): Int { + return when (val value = decodeTaggedValue(tag)) { + is Int -> value + is Long -> value.toIntExact() + is BigDecimal -> value.toIntExact() + is CharSequence -> value.toString().toInt() + else -> throw BadDecodedValueError(value, PrimitiveKind.INT, Int::class, Long::class, BigDecimal::class, CharSequence::class) + } + } + + override fun decodeTaggedLong(tag: Tag): Long { + return when (val value = decodeTaggedValue(tag)) { + is Long -> value + is Int -> value.toLong() + is BigDecimal -> value.toLongExact() + is CharSequence -> value.toString().toLong() + else -> throw BadDecodedValueError(value, PrimitiveKind.LONG, Int::class, Long::class, BigDecimal::class, CharSequence::class) + } + } + + override fun decodeTaggedFloat(tag: Tag): Float { + return when (val value = decodeTaggedValue(tag)) { + is Float -> value + is Double -> value.toFloatExact() + is BigDecimal -> 
value.toFloatExact() + is CharSequence -> value.toString().toFloat() + else -> throw BadDecodedValueError(value, PrimitiveKind.FLOAT, Float::class, Double::class, BigDecimal::class, CharSequence::class) + } + } + + override fun decodeTaggedDouble(tag: Tag): Double { + return when (val value = decodeTaggedValue(tag)) { + is Double -> value + is Float -> value.toDouble() + is BigDecimal -> value.toDoubleExact() + is CharSequence -> value.toString().toDouble() + else -> throw BadDecodedValueError(value, PrimitiveKind.DOUBLE, Float::class, Double::class, BigDecimal::class, CharSequence::class) + } + } + + override fun decodeTaggedChar(tag: Tag): Char { + val value = decodeTaggedValue(tag) + return when { + value is Int -> value.toChar() + value is CharSequence && value.length == 1 -> value[0] + else -> throw BadDecodedValueError(value, PrimitiveKind.CHAR, Int::class, CharSequence::class) + } + } + + override fun decodeTaggedString(tag: Tag): String { + return when (val value = decodeTaggedValue(tag)) { + is CharSequence -> value.toString() + is ByteArray -> value.decodeToString() + is GenericFixed -> value.bytes().decodeToString() + else -> throw BadDecodedValueError(value, PrimitiveKind.STRING, CharSequence::class, ByteArray::class, GenericFixed::class) + } + } + + override fun decodeTaggedEnum( + tag: Tag, + enumDescriptor: SerialDescriptor, + ): Int { + return when (val value = decodeTaggedValue(tag)) { + is GenericEnumSymbol<*>, is CharSequence -> { + enumDescriptor.getElementIndex(value.toString()).takeIf { it >= 0 } + ?: enumDescriptor.findAnnotation()?.value?.let { enumDescriptor.getElementIndex(it) }?.takeIf { it >= 0 } + ?: throw SerializationException("Unknown enum symbol '$value' for Enum '${enumDescriptor.serialName}'") + } + + else -> throw BadDecodedValueError(value, SerialKind.ENUM, GenericEnumSymbol::class, CharSequence::class) + } + } + + override fun decodeBytes(): ByteArray { + return when (val value = decodeTaggedValue(currentTag)) { + is ByteArray 
-> value + is ByteBuffer -> value.array() + is GenericFixed -> value.bytes() + is CharSequence -> value.toString().toByteArray() + else -> throw BadDecodedValueError(value, ByteArray::class, GenericFixed::class, CharSequence::class) + } + } + + override fun decodeFixed(): GenericFixed { + return when (val value = decodeTaggedValue(currentTag)) { + is GenericFixed -> value + else -> throw BadDecodedValueError(value, GenericFixed::class) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroValueDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroValueDecoder.kt new file mode 100644 index 00000000..e52dd72c --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/AvroValueDecoder.kt @@ -0,0 +1,30 @@ +package com.github.avrokotlin.avro4k.decoder + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.internal.DecodedNullError +import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.Schema + +internal class AvroValueDecoder( + override val avro: Avro, + val value: Any?, + val writerSchema: Schema, +) : AvroTaggedDecoder() { + init { + pushTag(writerSchema) + } + + override val Schema.writerSchema: Schema + get() = this@AvroValueDecoder.writerSchema + + override fun decodeTaggedNotNullMark(tag: Schema) = value != null + + override fun decodeTaggedValue(tag: Schema) = value ?: throw DecodedNullError() + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + throw IllegalIndexedAccessError() + } + + override fun SerialDescriptor.getTag(index: Int) = this@AvroValueDecoder.writerSchema +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ByteArrayDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ByteArrayDecoder.kt index 4d573748..71883126 100644 --- 
a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ByteArrayDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ByteArrayDecoder.kt @@ -1,23 +1,29 @@ package com.github.avrokotlin.avro4k.decoder -import kotlinx.serialization.ExperimentalSerializationApi +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.encoding.AbstractDecoder -import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE import kotlinx.serialization.modules.SerializersModule -@ExperimentalSerializationApi -class ByteArrayDecoder(val data: ByteArray, override val serializersModule: SerializersModule) : AbstractDecoder() { - private var index = -1 +internal class ByteArrayDecoder( + private val avro: Avro, + private val bytes: ByteArray, +) : AbstractDecoder() { + override val serializersModule: SerializersModule + get() = avro.serializersModule - override fun decodeCollectionSize(descriptor: SerialDescriptor): Int = data.size + private val iterator = bytes.iterator() - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - index++ - return if (index < data.size) index else DECODE_DONE + override fun decodeByte() = iterator.nextByte() + + override fun decodeCollectionSize(descriptor: SerialDescriptor): Int { + return bytes.size } - override fun decodeByte(): Byte { - return data[index] + override fun decodeSequentially() = true + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + throw IllegalIndexedAccessError() } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt deleted file mode 100644 index fc4915ae..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt +++ /dev/null @@ -1,17 +0,0 @@ -package 
com.github.avrokotlin.avro4k.decoder - -import kotlinx.serialization.SerializationException -import org.apache.avro.generic.GenericData -import java.nio.ByteBuffer - -object StringFromAvroValue { - fun fromValue(value: Any): String { - return when (value) { - is CharSequence -> value.toString() - is GenericData.Fixed -> String(value.bytes()) - is ByteArray -> String(value) - is ByteBuffer -> String(value.array()) - else -> throw SerializationException("Unsupported type for String [is ${value::class.qualifiedName}]") - } - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt deleted file mode 100644 index c011af33..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt +++ /dev/null @@ -1,100 +0,0 @@ -package com.github.avrokotlin.avro4k.decoder - -import com.github.avrokotlin.avro4k.AvroConfiguration -import kotlinx.serialization.DeserializationStrategy -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.PolymorphicKind -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractDecoder -import kotlinx.serialization.encoding.CompositeDecoder -import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema -import org.apache.avro.generic.GenericArray -import org.apache.avro.generic.GenericRecord - -@ExperimentalSerializationApi -class ListDecoder( - private val schema: Schema, - private val array: List, - override val serializersModule: SerializersModule, - private val configuration: AvroConfiguration, -) : AbstractDecoder(), FieldDecoder { - init { - require(schema.type == Schema.Type.ARRAY) - } - - private var 
index = -1 - - override fun decodeBoolean(): Boolean { - return decodeAnyNotNull() as Boolean - } - - override fun decodeLong(): Long { - return decodeAnyNotNull() as Long - } - - override fun decodeString(): String { - val raw = decodeAnyNotNull() - return StringFromAvroValue.fromValue(raw) - } - - override fun decodeDouble(): Double { - return decodeAnyNotNull() as Double - } - - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - index++ - return if (index < array.size) index else DECODE_DONE - } - - override fun decodeFloat(): Float { - return decodeAnyNotNull() as Float - } - - override fun decodeByte(): Byte { - return decodeAnyNotNull() as Byte - } - - override fun decodeInt(): Int { - return decodeAnyNotNull() as Int - } - - override fun decodeChar(): Char { - return decodeAnyNotNull() as Char - } - - override fun decodeAny(): Any? { - return array[index] - } - - private fun decodeAnyNotNull(): Any { - return array[index] ?: throw SerializationException("Item at index $index must not be null") - } - - override fun fieldSchema(): Schema = schema.elementType - - override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - val symbol = decodeAnyNotNull().toString() - return (0 until enumDescriptor.elementsCount).find { enumDescriptor.getElementName(it) == symbol } ?: -1 - } - - override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { - return deserializer.deserialize(this) - } - - @Suppress("UNCHECKED_CAST") - override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { - return when (descriptor.kind) { - StructureKind.CLASS -> RecordDecoder(descriptor, decodeAnyNotNull() as GenericRecord, serializersModule, configuration) - StructureKind.LIST -> ListDecoder(schema.elementType, decodeAnyNotNull() as GenericArray<*>, serializersModule, configuration) - StructureKind.MAP -> MapDecoder(schema.elementType, decodeAnyNotNull() as Map, serializersModule, configuration) - PolymorphicKind.SEALED, 
PolymorphicKind.OPEN -> UnionDecoder(descriptor, decodeAnyNotNull() as GenericRecord, serializersModule, configuration) - else -> throw UnsupportedOperationException("Kind ${descriptor.kind} is currently not supported.") - } - } - - override fun decodeCollectionSize(descriptor: SerialDescriptor): Int = array.size -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt index dd6d9d2f..035d9afc 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt @@ -1,146 +1,68 @@ package com.github.avrokotlin.avro4k.decoder -import com.github.avrokotlin.avro4k.AvroConfiguration -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.PolymorphicKind -import kotlinx.serialization.descriptors.PrimitiveKind +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.internal.DecodedNullError +import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractDecoder -import kotlinx.serialization.encoding.CompositeDecoder -import kotlinx.serialization.modules.SerializersModule import org.apache.avro.Schema -import org.apache.avro.generic.GenericArray -import org.apache.avro.generic.GenericRecord -import java.nio.ByteBuffer -@ExperimentalSerializationApi -class MapDecoder( - private val schema: Schema, - map: Map<*, *>, - override val serializersModule: SerializersModule, - private val configuration: AvroConfiguration, -) : AbstractDecoder(), CompositeDecoder { - init { - require(schema.type == Schema.Type.MAP) - } - - private val entries = map.toList() - private var index = -1 - - override 
fun decodeString(): String { - val value = keyOrValueNotNull() - return StringFromAvroValue.fromValue(value) - } - - private fun keyOrValue() = if (expectKey()) key() else value() - - private fun keyOrValueNotNull(): Any = keyOrValue() ?: throw SerializationException("Cannot decode as a key or value") - - private fun expectKey() = index % 2 == 0 - - private fun key(): Any? = entries[index / 2].first - - private fun value(): Any? = entries[index / 2].second - - override fun decodeNotNullMark(): Boolean { - return keyOrValue() != null - } - - override fun decodeFloat(): Float { - return when (val v = keyOrValueNotNull()) { - is Float -> v - is CharSequence -> v.toString().toFloat() - else -> throw SerializationException("Unsupported type for Float ${v::class.qualifiedName}") +internal class MapDecoder( + private val map: Map, + private val writerSchema: Schema, + override val avro: Avro, +) : AvroTaggedDecoder() { + private val iterator = map.iterator() + private lateinit var currentEntry: Map.Entry + private var polledEntry = false + + override val MapTag.writerSchema: Schema + get() = this@MapDecoder.writerSchema.valueType + + override fun SerialDescriptor.getTag(index: Int): MapTag { + return if (index % 2 == 0) { + MapTag.key() + } else { + MapTag.value(writerSchema.valueType) } } - override fun decodeInt(): Int { - return when (val v = keyOrValueNotNull()) { - is Int -> v - is CharSequence -> v.toString().toInt() - else -> throw SerializationException("Unsupported type for Int ${v::class.qualifiedName}") + override fun decodeTaggedNotNullMark(tag: MapTag): Boolean { + if (tag.isKey) { + polledEntry = true + currentEntry = iterator.next() + return true // key never null } + return currentEntry.value != null } - override fun decodeLong(): Long { - return when (val v = keyOrValueNotNull()) { - is Long -> v - is Int -> v.toLong() - is CharSequence -> v.toString().toLong() - else -> throw SerializationException("Unsupported type for Long ${v::class.qualifiedName}") + 
override fun decodeTaggedValue(tag: MapTag): Any { + if (tag.isKey) { + if (!polledEntry) { + currentEntry = iterator.next() + } + return currentEntry.key } + polledEntry = false + return currentEntry.value ?: throw DecodedNullError() } - override fun decodeDouble(): Double { - return when (val v = keyOrValueNotNull()) { - is Double -> v - is Float -> v.toDouble() - is CharSequence -> v.toString().toDouble() - else -> throw SerializationException("Unsupported type for Double ${v::class.qualifiedName}") - } - } - - override fun decodeByte(): Byte { - return when (val v = keyOrValueNotNull()) { - is Byte -> v - is Int -> v.toByte() - is CharSequence -> v.toString().toByte() - else -> throw SerializationException("Unsupported type for Byte ${v::class.qualifiedName}") - } + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + throw IllegalIndexedAccessError() } - override fun decodeChar(): Char { - return when (val v = keyOrValueNotNull()) { - is Char -> v - is Int -> v.toChar() - is CharSequence -> v.first() - else -> throw SerializationException("Unsupported type for Char ${v::class.qualifiedName}") - } - } + override fun decodeCollectionSize(descriptor: SerialDescriptor) = map.size - override fun decodeShort(): Short { - return when (val v = keyOrValueNotNull()) { - is Short -> v - is Int -> v.toShort() - is CharSequence -> v.toString().toShort() - else -> throw SerializationException("Unsupported type for Byte ${v::class.qualifiedName}") - } - } + override fun decodeSequentially() = true - override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - return when (val v = keyOrValueNotNull()) { - is CharSequence -> enumDescriptor.getElementIndex(v.toString()) - else -> throw SerializationException("Unsupported type for $enumDescriptor: ${v::class.qualifiedName}") - } - } + data class MapTag(val isKey: Boolean, val schema: Schema) { + companion object { + fun key() = MapTag(true, STRING_SCHEMA) - override fun decodeBoolean(): Boolean { - return 
when (val v = keyOrValueNotNull()) { - is Boolean -> v - is CharSequence -> v.toString().toBooleanStrict() - else -> throw SerializationException("Unsupported type for Boolean. Actual: ${v::class.qualifiedName}") + fun value(schema: Schema) = MapTag(false, schema) } } - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - index++ - return if (index == entries.size * 2) CompositeDecoder.DECODE_DONE else index - } - - @Suppress("UNCHECKED_CAST") - override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { - return when (descriptor.kind) { - StructureKind.CLASS -> RecordDecoder(descriptor, value() as GenericRecord, serializersModule, configuration) - StructureKind.LIST -> - when (descriptor.getElementDescriptor(0).kind) { - PrimitiveKind.BYTE -> ByteArrayDecoder((value() as ByteBuffer).array(), serializersModule) - else -> ListDecoder(schema.valueType, value() as GenericArray<*>, serializersModule, configuration) - } - StructureKind.MAP -> MapDecoder(schema.valueType, value() as Map, serializersModule, configuration) - PolymorphicKind.SEALED, PolymorphicKind.OPEN -> UnionDecoder(descriptor, value() as GenericRecord, serializersModule, configuration) - else -> throw UnsupportedOperationException("Kind ${descriptor.kind} is currently not supported.") - } + companion object { + private val STRING_SCHEMA = Schema.create(Schema.Type.STRING) } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/PolymorphicDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/PolymorphicDecoder.kt new file mode 100644 index 00000000..27d83ccd --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/PolymorphicDecoder.kt @@ -0,0 +1,56 @@ +package com.github.avrokotlin.avro4k.decoder + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.AvroAlias +import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError +import 
com.github.avrokotlin.avro4k.schema.findAnnotation +import com.github.avrokotlin.avro4k.schema.nonNullSerialName +import com.github.avrokotlin.avro4k.schema.possibleSerializationSubclasses +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.AbstractDecoder +import kotlinx.serialization.modules.SerializersModule +import org.apache.avro.Schema +import org.apache.avro.generic.IndexedRecord + +internal class PolymorphicDecoder( + private val avro: Avro, + private val descriptor: SerialDescriptor, + private val value: IndexedRecord, +) : AbstractDecoder() { + override val serializersModule: SerializersModule + get() = avro.serializersModule + + private val namesAndAliasesToSerialName: Map = + descriptor.possibleSerializationSubclasses(serializersModule) + .flatMap { + sequence { + yield(it.nonNullSerialName to it.nonNullSerialName) + it.findAnnotation()?.value?.forEach { alias -> + yield(alias to it.nonNullSerialName) + } + } + }.toMap() + + override fun decodeString(): String { + return tryFindSerialName(value.schema) + ?: throw SerializationException("Unknown schema name ${value.schema.fullName} for polymorphic type ${descriptor.serialName}") + } + + private fun tryFindSerialName(schema: Schema): String? 
{ + return namesAndAliasesToSerialName[schema.fullName] + ?: schema.aliases.firstNotNullOfOrNull { namesAndAliasesToSerialName[it] } + } + + override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { + return AvroValueDecoder(avro, value, value.schema) + .decodeSerializableValue(deserializer) + } + + override fun decodeSequentially() = true + + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + throw IllegalIndexedAccessError() + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt index dbbf5739..fd503792 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt @@ -1,203 +1,45 @@ package com.github.avrokotlin.avro4k.decoder -import com.github.avrokotlin.avro4k.AvroConfiguration -import com.github.avrokotlin.avro4k.AvroEnumDefault -import com.github.avrokotlin.avro4k.schema.extractNonNull -import com.github.avrokotlin.avro4k.schema.findAnnotation -import kotlinx.serialization.ExperimentalSerializationApi +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.internal.DecodedNullError +import com.github.avrokotlin.avro4k.internal.ElementDescriptor import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.PolymorphicKind -import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractDecoder import kotlinx.serialization.encoding.CompositeDecoder -import kotlinx.serialization.encoding.Decoder -import kotlinx.serialization.modules.SerializersModule import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed -import org.apache.avro.generic.GenericRecord -import 
java.nio.ByteBuffer +import org.apache.avro.generic.IndexedRecord -interface ExtendedDecoder : Decoder { - fun decodeAny(): Any? -} +internal class RecordDecoder( + private val record: IndexedRecord, + private val descriptor: SerialDescriptor, + override val avro: Avro, +) : AvroTaggedDecoder() { + // from descriptor element index to schema field + private val fields = avro.recordResolver.resolveFields(record.schema, descriptor) + private var currentIndex = 0 -interface FieldDecoder : ExtendedDecoder { - fun fieldSchema(): Schema -} + override val ElementDescriptor.writerSchema: Schema + get() = writerFieldSchema -@ExperimentalSerializationApi -class RecordDecoder( - private val desc: SerialDescriptor, - private val record: GenericRecord, - override val serializersModule: SerializersModule, - private val configuration: AvroConfiguration, -) : AbstractDecoder(), FieldDecoder { - private var currentIndex = -1 + override fun decodeTaggedNotNullMark(tag: ElementDescriptor) = decodeTaggedNullableValue(tag) != null - @Suppress("UNCHECKED_CAST") - override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { - val value = fieldValueNonNull() - return when (descriptor.kind) { - StructureKind.CLASS -> RecordDecoder(descriptor, value as GenericRecord, serializersModule, configuration) - StructureKind.MAP -> - MapDecoder( - fieldSchema(), - value as Map, - serializersModule, - configuration - ) - - StructureKind.LIST -> { - val decoder: CompositeDecoder = - if (descriptor.getElementDescriptor(0).kind == PrimitiveKind.BYTE) { - when (value) { - is List<*> -> ByteArrayDecoder((value as List).toByteArray(), serializersModule) - is Array<*> -> ByteArrayDecoder((value as Array).toByteArray(), serializersModule) - is ByteArray -> ByteArrayDecoder(value, serializersModule) - is ByteBuffer -> ByteArrayDecoder(value.array(), serializersModule) - is GenericFixed -> ByteArrayDecoder(value.bytes(), serializersModule) - else -> this - } - } else { - when (value) { - is 
List<*> -> ListDecoder(fieldSchema(), value, serializersModule, configuration) - is Array<*> -> ListDecoder(fieldSchema(), value.asList(), serializersModule, configuration) - else -> this - } - } - decoder - } - - PolymorphicKind.SEALED, PolymorphicKind.OPEN -> - UnionDecoder( - descriptor, - value as GenericRecord, - serializersModule, - configuration - ) - - else -> throw UnsupportedOperationException("Decoding descriptor of kind ${descriptor.kind} is currently not supported") - } - } - - private fun fieldValue(): Any? { - if (record.hasField(resolvedFieldName())) { - return record.get(resolvedFieldName()) - } - - return null - } - - private fun fieldValueNonNull(): Any { - val resolvedFieldName = resolvedFieldName() - if (record.hasField(resolvedFieldName)) { - return record.get(resolvedFieldName) - ?: throw SerializationException("Field $resolvedFieldName must not be null") - } - - throw SerializationException("Missing field $resolvedFieldName in record ${record.schema}") - } - - private fun resolvedFieldName(): String = configuration.fieldNamingStrategy.resolve(desc, currentIndex, desc.getElementName(currentIndex)) - - private fun field(): Schema.Field = record.schema.getField(resolvedFieldName()) - - override fun fieldSchema(): Schema { - // if the element is nullable, then we should have a union schema which we can extract the non-null schema from - val schema = field().schema() - return if (schema.isNullable) { - schema.extractNonNull() - } else { - schema - } - } - - override fun decodeString(): String = StringFromAvroValue.fromValue(fieldValueNonNull()) - - override fun decodeBoolean(): Boolean { - return when (val v = fieldValueNonNull()) { - is Boolean -> v - else -> throw SerializationException("Unsupported type for Boolean ${v::class.qualifiedName}") - } - } - - override fun decodeAny(): Any? 
= fieldValue() - - override fun decodeByte(): Byte { - return when (val v = fieldValueNonNull()) { - is Byte -> v - is Int -> if (v < 255) v.toByte() else throw SerializationException("Out of bound integer cannot be converted to byte [$v]") - else -> throw SerializationException("Unsupported type for Byte ${v::class.qualifiedName}") - } + override fun decodeTaggedValue(tag: ElementDescriptor): Any { + return decodeTaggedNullableValue(tag) ?: throw DecodedNullError(descriptor, tag.elementIndex) } - override fun decodeNotNullMark(): Boolean { - return fieldValue() != null + private fun decodeTaggedNullableValue(tag: ElementDescriptor): Any? { + return tag.writerFieldIndex?.let { record.get(it) } ?: tag.readerDefaultValue } - override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - val symbol = fieldValueNonNull().toString() - val enumIndex = enumDescriptor.getElementIndex(symbol) - - if (enumIndex != CompositeDecoder.UNKNOWN_NAME) { - return enumIndex - } - - return enumDescriptor.findAnnotation()?.value - ?.let { enumDescriptor.getElementIndex(it) } ?: -1 - } - - override fun decodeFloat(): Float { - return when (val v = fieldValueNonNull()) { - is Float -> v - else -> throw SerializationException("Unsupported type for Float ${v::class.qualifiedName}") - } - } - - override fun decodeInt(): Int { - return when (val v = fieldValueNonNull()) { - is Int -> v - else -> throw SerializationException("Unsupported type for Int ${v::class.qualifiedName}") - } - } - - override fun decodeShort(): Short { - return when (val v = fieldValueNonNull()) { - is Short -> v - is Int -> v.toShort() - else -> throw SerializationException("Unsupported type for Short ${v.javaClass}") - } - } - - override fun decodeLong(): Long { - return when (val v = fieldValueNonNull()) { - is Long -> v - is Int -> v.toLong() - else -> throw SerializationException("Unsupported type for Long [is ${v::class.qualifiedName}]") - } - } - - override fun decodeDouble(): Double { - return when (val v = 
fieldValueNonNull()) { - is Double -> v - is Float -> v.toDouble() - else -> throw SerializationException("Unsupported type for Double ${v::class.qualifiedName}") - } - } - - override fun decodeChar(): Char { - return when (val v = fieldValueNonNull()) { - is Int -> v.toChar() - is Char -> v - is CharSequence -> v.single() - else -> throw SerializationException("Unsupported type for Char ${v::class.qualifiedName}") + override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + while (currentIndex < fields.size) { + val field = fields[currentIndex++] + if (field != null) { + return field.elementIndex + } } + return CompositeDecoder.DECODE_DONE } - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - currentIndex++ - return if (currentIndex < descriptor.elementsCount) currentIndex else CompositeDecoder.DECODE_DONE - } + override fun SerialDescriptor.getTag(index: Int) = fields[index] ?: throw SerializationException("An optional field should not be decoded") } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt deleted file mode 100644 index 9c53b9c4..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt +++ /dev/null @@ -1,42 +0,0 @@ -package com.github.avrokotlin.avro4k.decoder - -import com.github.avrokotlin.avro4k.AvroConfiguration -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.PolymorphicKind -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractDecoder -import kotlinx.serialization.encoding.CompositeDecoder -import kotlinx.serialization.encoding.CompositeDecoder.Companion.DECODE_DONE -import kotlinx.serialization.modules.SerializersModule -import 
org.apache.avro.generic.GenericRecord - -@ExperimentalSerializationApi -class RootRecordDecoder( - private val record: GenericRecord, - override val serializersModule: SerializersModule, - private val configuration: AvroConfiguration, -) : AbstractDecoder() { - var decoded = false - - override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { - return when (descriptor.kind) { - StructureKind.CLASS, StructureKind.OBJECT -> - RecordDecoder( - descriptor, - record, - serializersModule, - configuration - ) - PolymorphicKind.SEALED -> UnionDecoder(descriptor, record, serializersModule, configuration) - else -> throw SerializationException("Non-class structure passed to root record decoder") - } - } - - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - val index = if (decoded) DECODE_DONE else 0 - decoded = true - return index - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt deleted file mode 100644 index 754d4937..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt +++ /dev/null @@ -1,62 +0,0 @@ -package com.github.avrokotlin.avro4k.decoder - -import com.github.avrokotlin.avro4k.AvroConfiguration -import com.github.avrokotlin.avro4k.possibleSerializationSubclasses -import com.github.avrokotlin.avro4k.schema.TypeName -import kotlinx.serialization.DeserializationStrategy -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractDecoder -import kotlinx.serialization.encoding.CompositeDecoder -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema -import org.apache.avro.generic.GenericRecord - -@ExperimentalSerializationApi -class UnionDecoder( - descriptor: SerialDescriptor, - 
private val value: GenericRecord, - override val serializersModule: SerializersModule, - private val configuration: AvroConfiguration, -) : AbstractDecoder(), FieldDecoder { - private enum class DecoderState(val index: Int) { - BEFORE(0), - READ_CLASS_NAME(1), - READ_DONE(CompositeDecoder.DECODE_DONE), - ; - - fun next() = values().firstOrNull { it.ordinal > this.ordinal } ?: READ_DONE - } - - private var currentState = DecoderState.BEFORE - - private var leafDescriptor: SerialDescriptor = - descriptor.possibleSerializationSubclasses(serializersModule).firstOrNull { - val schemaName = TypeName(name = value.schema.name, namespace = value.schema.namespace) - val serialName = configuration.typeNamingStrategy.resolve(it, it.serialName) - serialName == schemaName - } ?: throw SerializationException("Cannot find a subtype of ${descriptor.serialName} that can be used to deserialize a record of schema ${value.schema}.") - - override fun decodeElementIndex(descriptor: SerialDescriptor): Int { - val currentIndex = currentState.index - currentState = currentState.next() - return currentIndex - } - - override fun fieldSchema(): Schema = value.schema - - /** - * Decode string needs to return the class name of the actual decoded class. 
- */ - override fun decodeString(): String { - return leafDescriptor.serialName - } - - override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { - val recordDecoder = RootRecordDecoder(value, serializersModule, configuration) - return recordDecoder.decodeSerializableValue(deserializer) - } - - override fun decodeAny(): Any = UnsupportedOperationException() -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ArrayEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ArrayEncoder.kt new file mode 100644 index 00000000..24d8f890 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ArrayEncoder.kt @@ -0,0 +1,41 @@ +package com.github.avrokotlin.avro4k.encoder + +import com.github.avrokotlin.avro4k.Avro +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.Schema +import org.apache.avro.generic.GenericArray +import org.apache.avro.generic.GenericData + +internal class ArrayEncoder( + override val avro: Avro, + arraySize: Int, + private val schema: Schema, + private val onEncoded: (GenericArray<*>) -> Unit, +) : AvroTaggedEncoder() { + init { + schema.ensureTypeOf(Schema.Type.ARRAY) + } + + private val values: Array = Array(arraySize) { null } + + override fun endEncode(descriptor: SerialDescriptor) { + onEncoded(GenericData.Array(schema, values.asList())) + } + + override fun SerialDescriptor.getTag(index: Int) = index + + override val Int.writerSchema: Schema + get() = this@ArrayEncoder.schema.elementType + + override fun encodeTaggedValue( + tag: Int, + value: Any, + ) { + values[tag] = value + } + + override fun encodeTaggedNull(tag: Int) { + require(tag.writerSchema.isNullable) + values[tag] = null + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroEncoder.kt new file mode 100644 index 00000000..f69d7464 --- 
/dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroEncoder.kt @@ -0,0 +1,81 @@ +package com.github.avrokotlin.avro4k.encoder + +import kotlinx.serialization.SerializationException +import kotlinx.serialization.encoding.Encoder +import org.apache.avro.Schema +import org.apache.avro.generic.GenericFixed +import java.nio.ByteBuffer + +interface AvroEncoder : Encoder { + val currentWriterSchema: Schema + + fun encodeBytes(value: ByteBuffer) + + fun encodeBytes(value: ByteArray) + + fun encodeFixed(value: ByteArray) + + fun encodeFixed(value: GenericFixed) + + /** + * Helps to encode a value in different ways depending on the type of the writer schema. + * Each encoder have to return the encoded value for the matched schema. + * + * @param kotlinTypeName represents the kotlin type name of the encoded value for debugging purposes as it's used in exceptions. This is not the written avro type name. + */ + fun encodeValueResolved( + vararg encoders: Pair Any>, + kotlinTypeName: String, + ) +} + +inline fun AvroEncoder.encodeValueResolved(vararg encoders: Pair Any>) { + encodeValueResolved(*encoders, kotlinTypeName = T::class.qualifiedName!!) +} + +sealed class SchemaTypeMatcher { + sealed class Scalar : SchemaTypeMatcher() { + object BOOLEAN : Scalar() + + object INT : Scalar() + + object LONG : Scalar() + + object FLOAT : Scalar() + + object DOUBLE : Scalar() + + object STRING : Scalar() + + object BYTES : Scalar() + + object NULL : Scalar() + } + + object FirstArray : SchemaTypeMatcher() + + object FirstMap : SchemaTypeMatcher() + + sealed class Named : SchemaTypeMatcher() { + object FirstFixed : Named() + + object FirstEnum : Named() + + data class Fixed(val fullName: String) : Named() + + data class Enum(val fullName: String) : Named() + + data class Record(val fullName: String) : Named() + } + + override fun toString(): String { + return this::class.simpleName!! 
// (tail of the preceding file in this patch; the definition it closes starts before this chunk)
    }
}

/**
 * Fails with a [SerializationException] when the receiver schema is not of the expected [type].
 *
 * The message names the encoder in whose context the check runs.
 */
context(Encoder)
internal fun Schema.ensureTypeOf(type: Schema.Type) {
    if (this.type != type) {
        // `this@ensureTypeOf` is the Schema receiver; the message is about the encoder, hence `this@Encoder`.
        throw SerializationException("Schema $this must be of type $type to be used with ${this@Encoder::class}")
    }
}

// File: src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroTaggedEncoder.kt
package com.github.avrokotlin.avro4k.encoder

import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.internal.toIntExact
import com.github.avrokotlin.avro4k.schema.nonNull
import com.github.avrokotlin.avro4k.schema.nonNullSerialName
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.CompositeEncoder
import kotlinx.serialization.internal.TaggedEncoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.generic.GenericFixed
import java.nio.ByteBuffer

/**
 * Base of the tag-driven encoders: every tag knows the writer [Schema] of the value being encoded,
 * and each primitive is converted to whichever Avro representation that schema accepts.
 *
 * Subclasses provide the tag type, the tag -> schema mapping ([writerSchema]) and the sink for
 * encoded values (`encodeTaggedValue` / [encodeTaggedNull]).
 *
 * NOTE(review): the type arguments in this file were reconstructed from usage — confirm against the
 * upstream sources.
 */
@OptIn(InternalSerializationApi::class)
internal abstract class AvroTaggedEncoder<Tag> : TaggedEncoder<Tag>(), AvroEncoder {
    abstract val avro: Avro

    /** Writer schema of the value identified by this tag. */
    abstract val Tag.writerSchema: Schema

    abstract override fun encodeTaggedNull(tag: Tag)

    override val serializersModule: SerializersModule
        get() = avro.serializersModule

    override val currentWriterSchema: Schema
        get() = currentTag.writerSchema

    /**
     * Opens a record (class/object) or a polymorphic value.
     *
     * @throws SerializationException when the writer schema cannot hold the given structure kind
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
        val schema = currentTag.writerSchema.nonNull
        return when (descriptor.kind) {
            StructureKind.CLASS,
            StructureKind.OBJECT,
            ->
                avro.unionResolver.tryResolveUnion(schema, descriptor.nonNullSerialName)
                    ?.takeIf { it.type == Schema.Type.RECORD }
                    ?.let { recordSchema -> RecordEncoder(avro, descriptor, recordSchema) { encodeTaggedValue(currentTag, it) } }
                    ?: throwUnsupportedSchemaType(schema, descriptor)

            is PolymorphicKind ->
                PolymorphicEncoder(avro, schema) {
                    encodeTaggedValue(currentTag, it)
                }

            else -> throw SerializationException("Unsupported structure kind: $descriptor")
        }
    }

    /**
     * Opens a list (written as ARRAY, BYTES or FIXED depending on the writer schema) or a map.
     *
     * @throws SerializationException when the writer schema cannot hold the given collection kind
     */
    override fun beginCollection(
        descriptor: SerialDescriptor,
        collectionSize: Int,
    ): CompositeEncoder {
        val schema = currentTag.writerSchema.nonNull
        return when (descriptor.kind) {
            StructureKind.LIST ->
                when (schema.type) {
                    Schema.Type.ARRAY -> ArrayEncoder(avro, collectionSize, schema) { encodeTaggedValue(currentTag, it) }
                    Schema.Type.BYTES -> BytesEncoder(avro, collectionSize) { encodeTaggedValue(currentTag, it) }
                    Schema.Type.FIXED -> FixedEncoder(avro, collectionSize, schema) { encodeTaggedValue(currentTag, it) }
                    else -> throwUnsupportedSchemaType(schema, descriptor)
                }

            StructureKind.MAP ->
                when (schema.type) {
                    Schema.Type.MAP -> MapEncoder(avro, collectionSize, schema) { encodeTaggedValue(currentTag, it) }
                    else -> throwUnsupportedSchemaType(schema, descriptor)
                }

            else -> throw SerializationException("Unsupported collection kind: $descriptor")
        }
    }

    private fun throwUnsupportedSchemaType(
        schema: Schema,
        descriptor: SerialDescriptor,
    ): Nothing {
        throw SerializationException("Unsupported schema $schema for ${descriptor.kind} $descriptor")
    }

    /** Encodes a buffer as BYTES (the buffer itself), else as the first FIXED (start-padded), else as STRING. */
    override fun encodeBytes(value: ByteBuffer) {
        encodeTaggedValueResolved<ByteBuffer>(
            currentTag,
            SchemaTypeMatcher.Scalar.BYTES to { value },
            // remainingBytes() instead of array(): array() ignores position/limit/arrayOffset
            // and throws for read-only or direct buffers.
            SchemaTypeMatcher.Named.FirstFixed to { value.remainingBytes().toPaddedGenericFixed(it, endPadded = false) },
            SchemaTypeMatcher.Scalar.STRING to { value.remainingBytes().decodeToString() }
        )
    }

    /** Encodes a byte array as BYTES, else as the first FIXED (start-padded), else as STRING. */
    override fun encodeBytes(value: ByteArray) {
        encodeTaggedValueResolved<ByteArray>(
            currentTag,
            SchemaTypeMatcher.Scalar.BYTES to { ByteBuffer.wrap(value) },
            SchemaTypeMatcher.Named.FirstFixed to { value.toPaddedGenericFixed(it, endPadded = false) },
            SchemaTypeMatcher.Scalar.STRING to { value.decodeToString() }
        )
    }

    /** Encodes a fixed as the FIXED carrying the same full name, else as BYTES, else as STRING. */
    override fun encodeFixed(value: GenericFixed) {
        encodeTaggedValueResolved<GenericFixed>(
            currentTag,
            SchemaTypeMatcher.Named.Fixed(value.schema.fullName) to { value },
            SchemaTypeMatcher.Scalar.BYTES to { ByteBuffer.wrap(value.bytes()) },
            SchemaTypeMatcher.Scalar.STRING to { value.bytes().decodeToString() }
        )
    }

    /** Encodes a byte array as the first FIXED (start-padded), else as BYTES, else as STRING. */
    override fun encodeFixed(value: ByteArray) {
        encodeTaggedValueResolved<ByteArray>(
            currentTag,
            SchemaTypeMatcher.Named.FirstFixed to { value.toPaddedGenericFixed(it, endPadded = false) },
            SchemaTypeMatcher.Scalar.BYTES to { ByteBuffer.wrap(value) },
            SchemaTypeMatcher.Scalar.STRING to { value.decodeToString() }
        )
    }

    override fun encodeTaggedBoolean(
        tag: Tag,
        value: Boolean,
    ) {
        encodeTaggedValueResolved<Boolean>(
            tag,
            SchemaTypeMatcher.Scalar.BOOLEAN to { value },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() }
        )
    }

    override fun encodeTaggedByte(
        tag: Tag,
        value: Byte,
    ) {
        encodeTaggedValueResolved<Byte>(
            tag,
            SchemaTypeMatcher.Scalar.INT to { value.toInt() },
            SchemaTypeMatcher.Scalar.LONG to { value.toLong() },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() },
            SchemaTypeMatcher.Scalar.DOUBLE to { value.toDouble() },
            SchemaTypeMatcher.Scalar.FLOAT to { value.toFloat() }
        )
    }

    override fun encodeTaggedShort(
        tag: Tag,
        value: Short,
    ) {
        encodeTaggedValueResolved<Short>(
            tag,
            SchemaTypeMatcher.Scalar.INT to { value.toInt() },
            SchemaTypeMatcher.Scalar.LONG to { value.toLong() },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() },
            SchemaTypeMatcher.Scalar.DOUBLE to { value.toDouble() },
            SchemaTypeMatcher.Scalar.FLOAT to { value.toFloat() }
        )
    }

    override fun encodeTaggedInt(
        tag: Tag,
        value: Int,
    ) {
        encodeTaggedValueResolved<Int>(
            tag,
            SchemaTypeMatcher.Scalar.INT to { value },
            SchemaTypeMatcher.Scalar.LONG to { value.toLong() },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() },
            SchemaTypeMatcher.Scalar.DOUBLE to { value.toDouble() },
            SchemaTypeMatcher.Scalar.FLOAT to { value.toFloat() }
        )
    }

    override fun encodeTaggedLong(
        tag: Tag,
        value: Long,
    ) {
        encodeTaggedValueResolved<Long>(
            tag,
            SchemaTypeMatcher.Scalar.LONG to { value },
            // narrowing is checked: toIntExact comes from internal/NumberUtils
            SchemaTypeMatcher.Scalar.INT to { value.toIntExact() },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() },
            SchemaTypeMatcher.Scalar.DOUBLE to { value.toDouble() },
            SchemaTypeMatcher.Scalar.FLOAT to { value.toFloat() }
        )
    }

    override fun encodeTaggedFloat(
        tag: Tag,
        value: Float,
    ) {
        encodeTaggedValueResolved<Float>(
            tag,
            SchemaTypeMatcher.Scalar.FLOAT to { value },
            SchemaTypeMatcher.Scalar.DOUBLE to { value.toDouble() },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() },
            // NOTE(review): unchecked toInt() truncates the fractional part — confirm this is intended
            SchemaTypeMatcher.Scalar.INT to { value.toInt() }
        )
    }

    override fun encodeTaggedDouble(
        tag: Tag,
        value: Double,
    ) {
        encodeTaggedValueResolved<Double>(
            tag,
            SchemaTypeMatcher.Scalar.DOUBLE to { value },
            SchemaTypeMatcher.Scalar.FLOAT to { value.toFloat() },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() },
            // NOTE(review): unchecked toInt() truncates the fractional part — confirm this is intended
            SchemaTypeMatcher.Scalar.INT to { value.toInt() }
        )
    }

    override fun encodeTaggedChar(
        tag: Tag,
        value: Char,
    ) {
        encodeTaggedValueResolved<Char>(
            tag,
            SchemaTypeMatcher.Scalar.INT to { value.code },
            SchemaTypeMatcher.Scalar.STRING to { value.toString() }
        )
    }

    override fun encodeTaggedString(
        tag: Tag,
        value: String,
    ) {
        encodeTaggedValueResolved<String>(
            tag,
            SchemaTypeMatcher.Scalar.STRING to { value },
            SchemaTypeMatcher.Scalar.BYTES to { value.encodeToByteArray() },
            // strings are end-padded, unlike byte arrays which are start-padded
            SchemaTypeMatcher.Named.FirstFixed to { value.encodeToByteArray().toPaddedGenericFixed(it, endPadded = true) },
            SchemaTypeMatcher.Named.FirstEnum to { GenericData.EnumSymbol(it, value) }
        )
    }

    override fun encodeTaggedEnum(
        tag: Tag,
        enumDescriptor: SerialDescriptor,
        ordinal: Int,
    ) {
        /*
        We allow enums as ENUM (must match the descriptor's full name), STRING or UNION.
        For UNION, we look for an enum with the descriptor's full name, otherwise a string.
         */
        val value = enumDescriptor.getElementName(ordinal)

        encodeTaggedValueResolved(
            tag,
            SchemaTypeMatcher.Named.Enum(enumDescriptor.nonNullSerialName) to { GenericData.EnumSymbol(it, value) },
            SchemaTypeMatcher.Scalar.STRING to { value },
            kotlinTypeName = enumDescriptor.serialName
        )
    }

    /** Same as the non-inline overload, using the qualified name of [T] in error messages. */
    private inline fun <reified T : Any> encodeTaggedValueResolved(
        tag: Tag,
        vararg encoders: Pair<SchemaTypeMatcher, (Schema) -> Any>,
    ) = encodeTaggedValueResolved(tag, *encoders, kotlinTypeName = T::class.qualifiedName!!)

    override fun encodeValueResolved(
        vararg encoders: Pair<SchemaTypeMatcher, (Schema) -> Any>,
        kotlinTypeName: String,
    ) = encodeTaggedValueResolved(currentTag, *encoders, kotlinTypeName = kotlinTypeName)

    /**
     * Picks the conversion matching the writer schema of [tag] and emits the converted value.
     *
     * Candidates are tried in the order of the matchers produced by the schema (for a union: its
     * branch order), not in the order of [encoders]; the first schema matcher that has an entry in
     * [encoders] wins and its conversion receives the matched (branch) schema.
     *
     * @param kotlinTypeName only used in error messages
     * @throws SerializationException when no conversion is compatible with the writer schema
     */
    private fun encodeTaggedValueResolved(
        tag: Tag,
        vararg encoders: Pair<SchemaTypeMatcher, (Schema) -> Any>,
        kotlinTypeName: String,
    ) {
        // TODO cache the resolved type from the elementIndex
        //  We have to retrieve the SerialDescriptor and the elementIndex from the tag
        //  to cache the resolved schema given SerialDescriptor, elementIndex and the non-resolved writer schema.
        val schema = tag.writerSchema
        val encoder = hashMapOf(*encoders)

        val valueEncoder =
            schema.toTypeMatchers()
                .firstNotNullOfOrNull { typeMatcher ->
                    encoder[typeMatcher.first]?.let { typeMatcher.second to it }
                }
        if (valueEncoder == null) {
            if (schema.type == Schema.Type.UNION) {
                throw SerializationException(
                    "Expected one of schema types ${encoder.keys} but no compatible schema type found " +
                        "for encoded kotlin type $kotlinTypeName in union $schema"
                )
            } else {
                throw SerializationException(
                    "The kotlin type $kotlinTypeName expected to be encoded as ${encoder.keys} but was $schema"
                )
            }
        }
        encodeTaggedValue(tag, valueEncoder.second(valueEncoder.first))
    }
}

/**
 * Lists, in priority order, the matchers under which this schema can be looked up, each paired with
 * the concrete schema it stands for. Named types yield their full name, then their aliases, then
 * (FIXED and ENUM only) the generic "first of that kind" matcher; a union flattens its branches in order.
 */
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA")
private fun Schema.toTypeMatchers(): Sequence<Pair<SchemaTypeMatcher, Schema>> =
    when (type) {
        Schema.Type.BOOLEAN -> sequenceOf(SchemaTypeMatcher.Scalar.BOOLEAN to this)
        Schema.Type.INT -> sequenceOf(SchemaTypeMatcher.Scalar.INT to this)
        Schema.Type.LONG -> sequenceOf(SchemaTypeMatcher.Scalar.LONG to this)
        Schema.Type.FLOAT -> sequenceOf(SchemaTypeMatcher.Scalar.FLOAT to this)
        Schema.Type.DOUBLE -> sequenceOf(SchemaTypeMatcher.Scalar.DOUBLE to this)
        Schema.Type.STRING -> sequenceOf(SchemaTypeMatcher.Scalar.STRING to this)
        Schema.Type.BYTES -> sequenceOf(SchemaTypeMatcher.Scalar.BYTES to this)
        Schema.Type.NULL -> sequenceOf(SchemaTypeMatcher.Scalar.NULL to this)
        Schema.Type.FIXED ->
            sequenceOf(SchemaTypeMatcher.Named.Fixed(fullName) to this) +
                aliases.map { SchemaTypeMatcher.Named.Fixed(it) to this } +
                sequenceOf(SchemaTypeMatcher.Named.FirstFixed to this)

        Schema.Type.ENUM ->
            sequenceOf(SchemaTypeMatcher.Named.Enum(fullName) to this) +
                aliases.map { SchemaTypeMatcher.Named.Enum(it) to this } +
                sequenceOf(SchemaTypeMatcher.Named.FirstEnum to this)

        Schema.Type.RECORD ->
            sequenceOf(SchemaTypeMatcher.Named.Record(fullName) to this) +
                aliases.map { SchemaTypeMatcher.Named.Record(it) to this }

        Schema.Type.ARRAY -> sequenceOf(SchemaTypeMatcher.FirstArray to this)
        Schema.Type.MAP -> sequenceOf(SchemaTypeMatcher.FirstMap to this)
        Schema.Type.UNION -> types.asSequence().flatMap { it.toTypeMatchers() }
    }

/**
 * Copies the bytes between position and limit without moving the buffer's position.
 * Works for heap, direct and read-only buffers alike.
 */
private fun ByteBuffer.remainingBytes(): ByteArray {
    val copy = ByteArray(remaining())
    // duplicate() shares the content but has its own position, so the caller's buffer is left untouched
    duplicate().get(copy)
    return copy
}

/**
 * Wraps the bytes in a [GenericFixed] of the given FIXED [schema], zero-padding up to the fixed size.
 *
 * @param endPadded when true the padding follows the bytes, otherwise it precedes them
 * @throws SerializationException when the array is longer than the schema's fixed size
 */
private fun ByteArray.toPaddedGenericFixed(
    schema: Schema,
    endPadded: Boolean,
): GenericFixed {
    if (size > schema.fixedSize) {
        throw SerializationException("Actual byte array size $size is greater than schema fixed size $schema")
    }
    val padSize = schema.fixedSize - size
    return GenericData.Fixed(
        schema,
        if (padSize > 0) {
            if (endPadded) {
                this + ByteArray(padSize)
            } else {
                ByteArray(padSize) + this
            }
        } else {
            this
        }
    )
}

// File: src/main/kotlin/com/github/avrokotlin/avro4k/encoder/AvroValueEncoder.kt
package com.github.avrokotlin.avro4k.encoder

import com.github.avrokotlin.avro4k.Avro
import kotlinx.serialization.descriptors.SerialDescriptor
import org.apache.avro.Schema

/**
 * Encoder for a single value: the tag is the writer schema itself, pushed once at construction,
 * and the encoded value (or null) is handed to [onEncoded].
 */
internal class AvroValueEncoder(
    override val avro: Avro,
    schema: Schema,
    private val onEncoded: (Any?) -> Unit,
) : AvroTaggedEncoder<Schema>() {
    override val Schema.writerSchema: Schema
        get() = this

    init {
        pushTag(schema)
    }

    override fun SerialDescriptor.getTag(index: Int): Schema {
        // a bare `this` would be the SerialDescriptor receiver; the message is about the encoder
        throw UnsupportedOperationException("${this@AvroValueEncoder::class} does not support element encoding")
    }

    override fun encodeTaggedValue(
        tag: Schema,
        value: Any,
    ) {
        onEncoded(value)
    }

    override fun encodeTaggedNull(tag: Schema) {
        require(tag.writerSchema.isNullable) { "Cannot encode null: writer schema $tag is not nullable" }
        onEncoded(null)
    }
}
// File: src/main/kotlin/com/github/avrokotlin/avro4k/encoder/BytesEncoder.kt
package com.github.avrokotlin.avro4k.encoder

import com.github.avrokotlin.avro4k.Avro
import kotlinx.serialization.descriptors.SerialDescriptor
import org.apache.avro.Schema
import java.nio.ByteBuffer

private val BYTES_SCHEMA = Schema.create(Schema.Type.BYTES)

/**
 * Collects the bytes of a collection into a [ByteBuffer] of [arraySize] bytes and hands it to
 * [onEncoded] once the collection is complete. The tag of an element is its index.
 */
internal class BytesEncoder(
    override val avro: Avro,
    arraySize: Int,
    private val onEncoded: (ByteBuffer) -> Unit,
) : AvroTaggedEncoder<Int>() {
    private val output: ByteBuffer = ByteBuffer.allocate(arraySize)

    override fun encodeTaggedNull(tag: Int) {
        throw UnsupportedOperationException("nulls are not supported for schema type ${Schema.Type.BYTES}")
    }

    override fun endEncode(descriptor: SerialDescriptor) {
        // Only absolute puts are used, so the position is still 0: no rewind() needed.
        // Skipping it also avoids the Buffer/ByteBuffer return-type difference between JDK 8 and 9+.
        onEncoded(output)
    }

    override fun SerialDescriptor.getTag(index: Int) = index

    override val Int.writerSchema: Schema
        get() = BYTES_SCHEMA

    override fun encodeTaggedByte(
        tag: Int,
        value: Byte,
    ) {
        // absolute put: writes at index `tag` without moving the buffer's position
        output.put(tag, value)
    }
}
// File: src/main/kotlin/com/github/avrokotlin/avro4k/encoder/FixedEncoder.kt
package com.github.avrokotlin.avro4k.encoder

import com.github.avrokotlin.avro4k.Avro
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.generic.GenericFixed

/**
 * Collects the bytes of a collection into a [GenericFixed] of the given FIXED [schema].
 * When the collection is shorter than the fixed size, the bytes are placed at the end of the
 * array and the leading bytes stay zero.
 *
 * @throws SerializationException when [schema] is not FIXED or [arraySize] exceeds its fixed size
 */
internal class FixedEncoder(
    override val avro: Avro,
    arraySize: Int,
    private val schema: Schema,
    private val onEncoded: (GenericFixed) -> Unit,
) : AvroTaggedEncoder<Int>() {
    // Initializers run in textual order: validate before the properties below read schema.fixedSize,
    // otherwise a non-FIXED schema fails inside Avro before reaching the intended SerializationException.
    init {
        schema.ensureTypeOf(Schema.Type.FIXED)
        if (arraySize > schema.fixedSize) {
            throw SerializationException("Actual collection size $arraySize is greater than schema fixed size $schema")
        }
    }

    private val padSize = schema.fixedSize - arraySize
    private val output: ByteArray = ByteArray(schema.fixedSize)

    override fun encodeTaggedNull(tag: Int) {
        throw UnsupportedOperationException("nulls are not supported for schema type ${Schema.Type.FIXED}")
    }

    override fun endEncode(descriptor: SerialDescriptor) {
        onEncoded(GenericData.Fixed(schema, output))
    }

    // the tag is the target position in the output array: element index shifted by the leading padding
    override fun SerialDescriptor.getTag(index: Int) = padSize + index

    override val Int.writerSchema: Schema
        get() = this@FixedEncoder.schema

    override fun encodeTaggedByte(
        tag: Int,
        value: Byte,
    ) {
        output[tag] = value
    }
}
No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt deleted file mode 100644 index 3cac4e51..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt +++ /dev/null @@ -1,84 +0,0 @@ -package com.github.avrokotlin.avro4k.encoder - -import com.github.avrokotlin.avro4k.AvroConfiguration -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractEncoder -import kotlinx.serialization.encoding.CompositeEncoder -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema -import org.apache.avro.generic.GenericData -import org.apache.avro.generic.GenericFixed -import java.nio.ByteBuffer - -@ExperimentalSerializationApi -class ListEncoder( - private val schema: Schema, - override val serializersModule: SerializersModule, - override val configuration: AvroConfiguration, - private val callback: (GenericData.Array) -> Unit, -) : AbstractEncoder(), StructureEncoder { - private val list = mutableListOf() - - override fun endStructure(descriptor: SerialDescriptor) { - val generic = GenericData.Array(schema, list.toList()) - callback(generic) - } - - override fun fieldSchema(): Schema = schema.elementType - - override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { - return super.beginStructure(descriptor) - } - - override fun addValue(value: Any) { - list.add(value) - } - - override fun encodeString(value: String) { - list.add(StringToAvroValue.toValue(schema, value)) - } - - override fun encodeLong(value: Long) { - list.add(value) - } - - override fun encodeDouble(value: Double) { - list.add(value) - } - - override fun encodeBoolean(value: Boolean) { - list.add(value) - } - - override fun encodeShort(value: Short) { - list.add(value) - } - - override fun encodeByteArray(buffer: 
// File: src/main/kotlin/com/github/avrokotlin/avro4k/encoder/MapEncoder.kt
package com.github.avrokotlin.avro4k.encoder

import com.github.avrokotlin.avro4k.Avro
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import org.apache.avro.Schema

private val STRING_SCHEMA = Schema.create(Schema.Type.STRING)

/**
 * Encodes a kotlin map against a MAP [schema]. Elements arrive alternately as key (even index)
 * and value (odd index); keys are stored through `toString()`, and the collected entries are
 * handed to [onEncoded] as a map once the structure ends.
 *
 * NOTE(review): type arguments were reconstructed from usage — confirm against upstream.
 */
internal class MapEncoder(
    override val avro: Avro,
    mapSize: Int,
    private val schema: Schema,
    private val onEncoded: (Map<String, Any?>) -> Unit,
) : AvroTaggedEncoder<MapEncoder.MapTag>() {
    init {
        schema.ensureTypeOf(Schema.Type.MAP)
    }

    private val entries: MutableList<Pair<String, Any?>> = ArrayList(mapSize)
    private lateinit var currentKey: String

    override fun endEncode(descriptor: SerialDescriptor) {
        onEncoded(entries.toMap())
    }

    override fun SerialDescriptor.getTag(index: Int) =
        if (index % 2 == 0) {
            MapTag.key()
        } else {
            MapTag.value(schema.valueType)
        }

    override val MapTag.writerSchema: Schema
        // explicitly the tag's own schema (string for keys, the map's value type for values),
        // not the enclosing encoder's MAP schema
        get() = this.schema

    override fun encodeTaggedValue(
        tag: MapTag,
        value: Any,
    ) {
        if (tag.isKey) {
            currentKey = value.toString()
        } else {
            entries.add(currentKey to value)
        }
    }

    override fun encodeTaggedNull(tag: MapTag) {
        if (tag.isKey) {
            throw SerializationException("Map key cannot be null")
        }
        entries.add(currentKey to null)
    }

    /** Marks an element as key or value and carries the schema it is written with. */
    data class MapTag(val isKey: Boolean, val schema: Schema) {
        companion object {
            fun key() = MapTag(true, STRING_SCHEMA)

            fun value(schema: Schema) = MapTag(false, schema)
        }
    }
}

// File: src/main/kotlin/com/github/avrokotlin/avro4k/encoder/PolymorphicEncoder.kt
package com.github.avrokotlin.avro4k.encoder

import com.github.avrokotlin.avro4k.Avro
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractEncoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

/**
 * Encodes a polymorphic value: the type discriminator element is skipped and the concrete value
 * is encoded against [schema] through an [AvroValueEncoder].
 */
internal class PolymorphicEncoder(
    private val avro: Avro,
    private val schema: Schema,
    private val onEncoded: (Any) -> Unit,
) : AbstractEncoder() {
    override val serializersModule: SerializersModule
        get() = avro.serializersModule

    override fun encodeElement(
        descriptor: SerialDescriptor,
        index: Int,
    ): Boolean {
        // index 0 is the type discriminator, index 1 is the value itself
        // we don't need the type discriminator here
        return index == 1
    }

    override fun <T> encodeSerializableValue(
        serializer: SerializationStrategy<T>,
        value: T,
    ) {
        // Here we don't need to resolve the union, as it is already resolved inside AvroTaggedEncoder.beginStructure
        AvroValueEncoder(avro, schema) {
            onEncoded(it ?: throw UnsupportedOperationException("Polymorphic types cannot encode null values"))
        }.encodeSerializableValue(serializer, value)
    }
}

// File: src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RecordEncoder.kt
package com.github.avrokotlin.avro4k.encoder

import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.ListRecord
import com.github.avrokotlin.avro4k.internal.ElementDescriptor
import kotlinx.serialization.SerializationException
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.descriptors.SerialDescriptor
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord

/**
 * Encodes a class against a RECORD [schema]. Each kotlin element is stored at the position of the
 * writer field it was resolved to; elements without a writer field are skipped. The resulting
 * record is handed to [onEncoded] once the structure ends.
 */
internal class RecordEncoder(
    override val avro: Avro,
    descriptor: SerialDescriptor,
    private val schema: Schema,
    private val onEncoded: (GenericRecord) -> Unit,
) : AvroTaggedEncoder<ElementDescriptor>() {
    init {
        schema.ensureTypeOf(Schema.Type.RECORD)
    }

    // one slot per writer field, in schema order
    private val fieldValues: Array<Any?> = arrayOfNulls(schema.fields.size)

    // from descriptor element index to schema field
    private val fields = avro.recordResolver.resolveFields(schema, descriptor)

    override fun <T : Any> encodeNullableSerializableElement(
        descriptor: SerialDescriptor,
        index: Int,
        serializer: SerializationStrategy<T>,
        value: T?,
    ) {
        // Skip data class fields that are not present in the schema
        if (fields[index]?.writerFieldIndex != null) {
            super.encodeNullableSerializableElement(descriptor, index, serializer, value)
        }
    }

    override fun <T> encodeSerializableElement(
        descriptor: SerialDescriptor,
        index: Int,
        serializer: SerializationStrategy<T>,
        value: T,
    ) {
        // Skip data class fields that are not present in the schema
        if (fields[index]?.writerFieldIndex != null) {
            super.encodeSerializableElement(descriptor, index, serializer, value)
        }
    }

    override fun endEncode(descriptor: SerialDescriptor) {
        onEncoded(ListRecord(schema, fieldValues.asList()))
    }

    override fun SerialDescriptor.getTag(index: Int) =
        fields[index] ?: throw SerializationException("An optional kotlin field without corresponding writer field should not be encoded")

    override val ElementDescriptor.writerSchema: Schema
        get() = writerFieldSchema

    override fun encodeTaggedValue(
        tag: ElementDescriptor,
        value: Any,
    ) {
        if (tag.writerFieldIndex != null) {
            fieldValues[tag.writerFieldIndex] = value
        }
    }

    override fun encodeTaggedNull(tag: ElementDescriptor) {
        if (tag.writerFieldIndex != null) {
            fieldValues[tag.writerFieldIndex] = null
        }
    }
}
-> UnionEncoder(schema, serializersModule, configuration, callback) - else -> throw SerializationException("Unsupported root element passed to root record encoder") - } - } - - override fun endStructure(descriptor: SerialDescriptor) { - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ToAvroValue.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ToAvroValue.kt deleted file mode 100644 index a63baa35..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ToAvroValue.kt +++ /dev/null @@ -1,45 +0,0 @@ -package com.github.avrokotlin.avro4k.encoder - -import com.github.avrokotlin.avro4k.schema.extractNonNull -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.SerialDescriptor -import org.apache.avro.Schema -import org.apache.avro.generic.GenericData -import org.apache.avro.util.Utf8 -import java.nio.ByteBuffer - -object StringToAvroValue { - fun toValue( - schema: Schema, - t: String, - ): Any { - return when (schema.type) { - Schema.Type.FIXED -> { - val size = t.toByteArray().size - if (size > schema.fixedSize) { - throw SerializationException("Cannot write string with $size bytes to fixed type of size ${schema.fixedSize}") - } - // the array passed in must be padded to size - val bytes = ByteBuffer.allocate(schema.fixedSize).put(t.toByteArray()).array() - GenericData.get().createFixed(null, bytes, schema) - } - Schema.Type.BYTES -> ByteBuffer.wrap(t.toByteArray()) - else -> Utf8(t) - } - } -} - -@ExperimentalSerializationApi -object ValueToEnum { - fun toValue( - schema: Schema, - enumDescription: SerialDescriptor, - ordinal: Int, - ): GenericData.EnumSymbol { - // the schema provided will be a union, so we should extract the correct schema - val symbol = enumDescription.getElementName(ordinal) - val nonNullSchema = schema.extractNonNull() - return GenericData.get().createEnum(symbol, 
nonNullSchema) as GenericData.EnumSymbol - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt deleted file mode 100644 index 27ccb8c2..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt +++ /dev/null @@ -1,44 +0,0 @@ -package com.github.avrokotlin.avro4k.encoder - -import com.github.avrokotlin.avro4k.AvroConfiguration -import com.github.avrokotlin.avro4k.Record -import com.github.avrokotlin.avro4k.schema.TypeName -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractEncoder -import kotlinx.serialization.encoding.CompositeEncoder -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema - -@ExperimentalSerializationApi -class UnionEncoder( - private val unionSchema: Schema, - override val serializersModule: SerializersModule, - private val configuration: AvroConfiguration, - private val callback: (Record) -> Unit, -) : AbstractEncoder() { - override fun encodeString(value: String) { - // No need to encode the name of the concrete type. The name will never be encoded in the avro schema. - } - - override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { - return when (descriptor.kind) { - is StructureKind.CLASS, is StructureKind.OBJECT -> { - // Hand in the concrete schema for the specified SerialDescriptor so that fields can be correctly decoded. 
- val leafSchema = - unionSchema.types.first { - val schemaName = TypeName(name = it.name, namespace = it.namespace) - val serialName = configuration.typeNamingStrategy.resolve(descriptor, descriptor.serialName) - serialName == schemaName - } - RecordEncoder(leafSchema, serializersModule, configuration, callback) - } - else -> throw SerializationException("Unsupported root element passed to root record encoder") - } - } - - override fun endStructure(descriptor: SerialDescriptor) { - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt new file mode 100644 index 00000000..480385eb --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/NumberUtils.kt @@ -0,0 +1,88 @@ +package com.github.avrokotlin.avro4k.internal + +import kotlinx.serialization.SerializationException +import java.math.BigDecimal + +internal fun BigDecimal.toLongExact(): Long { + if (this.toLong().toBigDecimal() != this) { + throw SerializationException("Value $this is not a valid Long") + } + return this.toLong() +} + +internal fun Int.toByteExact(): Byte { + if (this.toByte().toInt() != this) { + throw SerializationException("Value $this is not a valid Byte") + } + return this.toByte() +} + +internal fun Long.toByteExact(): Byte { + if (this.toByte().toLong() != this) { + throw SerializationException("Value $this is not a valid Byte") + } + return this.toByte() +} + +internal fun BigDecimal.toByteExact(): Byte { + if (this.toInt().toByte().toInt().toBigDecimal() != this) { + throw SerializationException("Value $this is not a valid Byte") + } + return this.toInt().toByte() +} + +internal fun Int.toShortExact(): Short { + if (this.toShort().toInt() != this) { + throw SerializationException("Value $this is not a valid Short") + } + return this.toShort() +} + +internal fun Long.toShortExact(): Short { + if (this.toShort().toLong() != this) { + throw 
SerializationException("Value $this is not a valid Short") + } + return this.toShort() +} + +internal fun BigDecimal.toShortExact(): Short { + if (this.toInt().toShort().toInt().toBigDecimal() != this) { + throw SerializationException("Value $this is not a valid Short") + } + return this.toInt().toShort() +} + +internal fun Long.toIntExact(): Int { + if (this.toInt().toLong() != this) { + throw SerializationException("Value $this is not a valid Int") + } + return this.toInt() +} + +internal fun BigDecimal.toIntExact(): Int { + if (this.toInt().toBigDecimal() != this) { + throw SerializationException("Value $this is not a valid Int") + } + return this.toInt() +} + +internal fun BigDecimal.toFloatExact(): Float { + if (this.toFloat().toBigDecimal() != this) { + throw SerializationException("Value $this is not a valid Float") + } + return this.toFloat() +} + +internal fun Double.toFloatExact(): Float { + if (this.toFloat().toDouble() != this) { + throw SerializationException("Value $this is not a valid Float") + } + return this.toFloat() +} + +internal fun BigDecimal.toDoubleExact(): Double { + if (this.toDouble().toBigDecimal() != this) { + throw SerializationException("Value $this is not a valid Double") + } + return this.toDouble() +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt new file mode 100644 index 00000000..8beddfcc --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt @@ -0,0 +1,190 @@ +package com.github.avrokotlin.avro4k.internal + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.AvroAlias +import com.github.avrokotlin.avro4k.AvroDefault +import com.github.avrokotlin.avro4k.schema.findElementAnnotation +import com.github.avrokotlin.avro4k.schema.isStartingAsJson +import kotlinx.serialization.SerializationException +import 
kotlinx.serialization.Transient +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.elementNames +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.boolean +import org.apache.avro.Schema +import org.apache.avro.generic.GenericData +import java.util.concurrent.ConcurrentHashMap + +internal class RecordResolver( + private val avro: Avro, +) { + /** + * For a class descriptor + writerSchema, it returns a map of the field index to the schema field. + * + * Note: We use the descriptor in the key as we could have multiple descriptors for the same record schema, and multiple record schemas for the same descriptor. + */ + private val fieldCache: MutableMap, List> = ConcurrentHashMap() + + /** + * @return a list of fields for the writer schema, in the same order as the class descriptor. If a field is not found in the schema, the array item is null. 
+ */ + fun resolveFields( + writerSchema: Schema, + classDescriptor: SerialDescriptor, + ): List { + if (classDescriptor.elementsCount == 0) { + return emptyList() + } + return fieldCache.getOrPut(classDescriptor to writerSchema) { loadCache(classDescriptor, writerSchema) } + } + + /** + * Here the different steps to get the schema field corresponding to the serial descriptor element: + * - class field name -> schema field name + * - class field name -> schema field aliases + * - class field aliases -> schema field name + * - class field aliases -> schema field aliases + * - if field is optional, returns null + * - if still not found, [SerializationException] thrown + */ + private fun loadCache( + classDescriptor: SerialDescriptor, + writerSchema: Schema, + ): List { + val readerSchema = avro.schema(classDescriptor) + return classDescriptor.elementNames.mapIndexed { elementIndex, _ -> + val avroFieldName = avro.configuration.fieldNamingStrategy.resolve(classDescriptor, elementIndex) + + val writerField = writerSchema.tryGetField(avroFieldName, classDescriptor, elementIndex) + val writerFieldSchema = writerField?.schema() ?: readerSchema.getField(avroFieldName).schema() + + // not using the default from reader schema field to simplify the default value parsing + val readerDefaultAnnotation = classDescriptor.findElementAnnotation(elementIndex) + + if (writerField == null && readerDefaultAnnotation == null) { + if (classDescriptor.isElementOptional(elementIndex)) { + // default kotlin values are managed natively by kotlinx.serialization, so we can safely skip the field + return@mapIndexed null + } else { + throw SerializationException( + "Field '$avroFieldName' at index $elementIndex from descriptor '${classDescriptor.serialName}' not found in schema $writerSchema. 
" + + "Consider removing the field, " + + "adding a default value, " + + "or annotating it with @${Transient::class.qualifiedName}" + ) + } + } else { + ElementDescriptor( + elementIndex = elementIndex, + writerFieldIndex = writerField?.pos(), + writerFieldSchema = writerFieldSchema, + readerDefaultValue = readerDefaultAnnotation?.parseValue(writerFieldSchema), + readerHasDefaultValue = readerDefaultAnnotation != null + ) + } + } + } + + private fun Schema.tryGetField( + avroFieldName: String, + classDescriptor: SerialDescriptor, + elementIndex: Int, + ): Schema.Field? = + getField(avroFieldName) + ?: fields.firstOrNull { avroFieldName in it.aliases() } + ?: classDescriptor.findElementAnnotation(elementIndex)?.value?.let { aliases -> + fields.firstOrNull { schemaField -> + schemaField.name() in aliases || schemaField.aliases().any { alias -> alias in aliases } + } + } +} + +internal data class ElementDescriptor( + val elementIndex: Int, + val writerFieldIndex: Int?, + val writerFieldSchema: Schema, + val readerDefaultValue: Any?, + val readerHasDefaultValue: Boolean, +) + +private fun AvroDefault.parseValue(schema: Schema): Any? { + if (value.isStartingAsJson()) { + return Json.parseToJsonElement(value).convertObject(schema) + } + return JsonPrimitive(value).convertObject(schema) +} + +private fun JsonElement.convertObject(schema: Schema): Any? 
{ + return when (this) { + is JsonArray -> + when (schema.type) { + Schema.Type.ARRAY -> this.map { it.convertObject(schema.elementType) } + Schema.Type.UNION -> this.convertObject(schema.resolveUnion(this, Schema.Type.ARRAY)) + else -> throw SerializationException("Not a valid array value for schema $schema: $this") + } + + is JsonNull -> null + is JsonObject -> + when (schema.type) { + Schema.Type.RECORD -> { + GenericData.Record(schema).apply { + entries.forEach { (fieldName, value) -> + val schemaField = schema.getField(fieldName) + put(schemaField.pos(), value.convertObject(schemaField.schema())) + } + } + } + + Schema.Type.MAP -> entries.associate { (key, value) -> key to value.convertObject(schema.valueType) } + Schema.Type.UNION -> this.convertObject(schema.resolveUnion(this, Schema.Type.RECORD, Schema.Type.MAP)) + else -> throw SerializationException("Not a valid record value for schema $schema: $this") + } + + is JsonPrimitive -> + when (schema.type) { + Schema.Type.BYTES -> this.content.toByteArray() + Schema.Type.FIXED -> GenericData.Fixed(schema, this.content.toByteArray()) + Schema.Type.STRING -> this.content + Schema.Type.BOOLEAN -> this.boolean + + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + -> this.content.toBigDecimal() + + Schema.Type.UNION -> + this.convertObject( + schema.resolveUnion( + this, + Schema.Type.BYTES, + Schema.Type.FIXED, + Schema.Type.STRING, + Schema.Type.BOOLEAN, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE + ) + ) + + else -> throw SerializationException("Not a valid primitive value for schema $schema: $this") + } + } +} + +private fun Schema.resolveUnion( + value: JsonElement?, + vararg expectedTypes: Schema.Type, +): Schema { + val index = types.indexOfFirst { it.type in expectedTypes } + if (index < 0) { + throw SerializationException("Union type does not contain one of ${expectedTypes.asList()}, unable to convert default value: $value") + } + return 
types[index] +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/UnionResolver.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/UnionResolver.kt new file mode 100644 index 00000000..b742fa7a --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/UnionResolver.kt @@ -0,0 +1,39 @@ +package com.github.avrokotlin.avro4k.internal + +import org.apache.avro.Schema +import java.util.concurrent.ConcurrentHashMap + +internal class UnionResolver { + /** + * For a given union schema, we cache the possible schemas by fullName or alias. + */ + private val cache: MutableMap> = ConcurrentHashMap() + + fun tryResolveUnion( + schema: Schema, + typeName: String, + ): Schema? { + if (schema.type != Schema.Type.UNION) { + if (schema.fullName == typeName || schema.isNamedType() && typeName in schema.aliases) { + return schema + } + return null + } + return cache.getOrPut(schema) { loadCache(schema) }[typeName] + } + + private fun loadCache(unionSchema: Schema): MutableMap { + val possibleSchemasByNameOrAlias = mutableMapOf() + for (type in unionSchema.types) { + possibleSchemasByNameOrAlias[type.fullName] = type + if (type.isNamedType()) { + for (alias in type.aliases) { + possibleSchemasByNameOrAlias[alias] = type + } + } + } + return possibleSchemasByNameOrAlias + } +} + +private fun Schema.isNamedType() = type == Schema.Type.FIXED || type == Schema.Type.ENUM || type == Schema.Type.RECORD \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt new file mode 100644 index 00000000..c08cd0e1 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt @@ -0,0 +1,62 @@ +@file:Suppress("FunctionName") + +package com.github.avrokotlin.avro4k.internal + +import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor +import 
kotlinx.serialization.descriptors.SerialKind +import kotlinx.serialization.encoding.Decoder +import kotlin.reflect.KClass + +class AvroSchemaGenerationException(message: String) : SerializationException(message) + +context(Decoder) +internal fun DecodedNullError() = SerializationException("Unexpected null value, Decoder.decodeTaggedNotNullMark should be called first") + +context(Decoder) +internal fun DecodedNullError( + descriptor: SerialDescriptor, + elementIndex: Int, +) = SerializationException( + "Unexpected null value for field '${descriptor.getElementName(elementIndex)}' for type '${descriptor.serialName}', Decoder.decodeTaggedNotNullMark should be called first" +) + +internal fun Decoder.IllegalIndexedAccessError() = UnsupportedOperationException("${this::class.qualifiedName} does not support indexed access") + +context(Decoder) +internal inline fun BadDecodedValueError( + value: Any?, + firstExpectedType: KClass<*>, + vararg expectedTypes: KClass<*>, +): SerializationException { + val allExpectedTypes = listOf(firstExpectedType) + expectedTypes + return if (value == null) { + SerializationException( + "Decoded null value for ${ExpectedType::class.qualifiedName} kind, expected one of [${allExpectedTypes.joinToString { it.qualifiedName!! }}]" + ) + } else { + SerializationException( + "Decoded value '$value' of type ${value::class.qualifiedName} for " + + "${ExpectedType::class.qualifiedName} kind, expected one of [${allExpectedTypes.joinToString { it.qualifiedName!! }}]" + ) + } +} + +context(Decoder) +internal fun BadDecodedValueError( + value: Any?, + expectedKind: SerialKind, + firstExpectedType: KClass<*>, + vararg expectedTypes: KClass<*>, +): SerializationException { + val allExpectedTypes = listOf(firstExpectedType) + expectedTypes + return if (value == null) { + SerializationException( + "Decoded null value for $expectedKind kind, expected one of [${allExpectedTypes.joinToString { it.qualifiedName!! 
}}]" + ) + } else { + SerializationException( + "Decoded value '$value' of type ${value::class.qualifiedName} for $expectedKind kind, expected one of [${allExpectedTypes.joinToString { it.qualifiedName!! }}]" + ) + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/AvroSchemaGenerationException.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/AvroSchemaGenerationException.kt deleted file mode 100644 index 7e5f8b6f..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/AvroSchemaGenerationException.kt +++ /dev/null @@ -1,5 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import kotlinx.serialization.SerializationException - -class AvroSchemaGenerationException(message: String) : SerializationException(message) \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassVisitor.kt index 71d367f4..82f2ac8f 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassVisitor.kt @@ -7,29 +7,28 @@ import org.apache.avro.Schema internal class ClassVisitor( descriptor: SerialDescriptor, - override val context: VisitorContext, + private val context: VisitorContext, private val onSchemaBuilt: (Schema) -> Unit, -) : SerialDescriptorClassVisitor, AvroVisitorContextAware { +) : SerialDescriptorClassVisitor { private val fields = mutableListOf() private val schemaAlreadyResolved: Boolean private val schema: Schema init { - val recordName = descriptor.getAvroName() var schemaAlreadyResolved = true schema = - context.resolvedSchemas.getOrPut(recordName) { + context.resolvedSchemas.getOrPut(descriptor.nonNullSerialName) { schemaAlreadyResolved = false val annotations = TypeAnnotations(descriptor) val schema = Schema.createRecord( // name = - recordName.name, + descriptor.nonNullSerialName, // doc = 
annotations.doc?.value, // namespace = - recordName.namespace, + null, // isError = false ) @@ -55,7 +54,7 @@ internal class ClassVisitor( ) { fields.add( createField( - descriptor.getElementAvroName(elementIndex), + context.avro.configuration.fieldNamingStrategy.resolve(descriptor, elementIndex), FieldAnnotations(descriptor, elementIndex), it ) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/FieldNamingStrategy.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/FieldNamingStrategy.kt index 851424bd..fd1bfe8d 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/FieldNamingStrategy.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/FieldNamingStrategy.kt @@ -6,7 +6,6 @@ interface FieldNamingStrategy { fun resolve( descriptor: SerialDescriptor, elementIndex: Int, - serialName: String, ): String companion object Builtins { @@ -17,8 +16,7 @@ interface FieldNamingStrategy { override fun resolve( descriptor: SerialDescriptor, elementIndex: Int, - serialName: String, - ) = serialName + ) = descriptor.getElementName(elementIndex) } /** @@ -28,37 +26,38 @@ interface FieldNamingStrategy { override fun resolve( descriptor: SerialDescriptor, elementIndex: Int, - serialName: String, ): String = - buildString(serialName.length * 2) { - var bufferedChar: Char? = null - var previousUpperCharsCount = 0 + descriptor.getElementName(elementIndex).let { serialName -> + buildString(serialName.length * 2) { + var bufferedChar: Char? 
= null + var previousUpperCharsCount = 0 - serialName.forEach { c -> - if (c.isUpperCase()) { - if (previousUpperCharsCount == 0 && isNotEmpty() && last() != '_') { - append('_') - } + serialName.forEach { c -> + if (c.isUpperCase()) { + if (previousUpperCharsCount == 0 && isNotEmpty() && last() != '_') { + append('_') + } - bufferedChar?.let(::append) + bufferedChar?.let(::append) - previousUpperCharsCount++ - bufferedChar = c.lowercaseChar() - } else { - if (bufferedChar != null) { - if (previousUpperCharsCount > 1 && c.isLetter()) { - append('_') + previousUpperCharsCount++ + bufferedChar = c.lowercaseChar() + } else { + if (bufferedChar != null) { + if (previousUpperCharsCount > 1 && c.isLetter()) { + append('_') + } + append(bufferedChar) + previousUpperCharsCount = 0 + bufferedChar = null } - append(bufferedChar) - previousUpperCharsCount = 0 - bufferedChar = null + append(c) } - append(c) } - } - if (bufferedChar != null) { - append(bufferedChar) + if (bufferedChar != null) { + append(bufferedChar) + } } } } @@ -70,8 +69,7 @@ interface FieldNamingStrategy { override fun resolve( descriptor: SerialDescriptor, elementIndex: Int, - serialName: String, - ): String = serialName.replaceFirstChar { it.uppercaseChar() } + ): String = descriptor.getElementName(elementIndex).replaceFirstChar { it.uppercaseChar() } } } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/InlineClassVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/InlineClassVisitor.kt index a66b74b7..d447959d 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/InlineClassVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/InlineClassVisitor.kt @@ -4,9 +4,9 @@ import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema internal class InlineClassVisitor( - override val context: VisitorContext, + private val context: VisitorContext, private val onSchemaBuilt: (Schema) -> Unit, -) : 
SerialDescriptorInlineClassVisitor, AvroVisitorContextAware { +) : SerialDescriptorInlineClassVisitor { override fun visitInlineClassElement( inlineClassDescriptor: SerialDescriptor, inlineElementIndex: Int, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ListVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ListVisitor.kt index 2cc0d87f..eb0a7572 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ListVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ListVisitor.kt @@ -4,9 +4,9 @@ import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema internal class ListVisitor( - override val context: VisitorContext, + private val context: VisitorContext, private val onSchemaBuilt: (Schema) -> Unit, -) : SerialDescriptorListVisitor, AvroVisitorContextAware { +) : SerialDescriptorListVisitor { private lateinit var itemSchema: Schema override fun visitListItem( diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/MapVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/MapVisitor.kt index cb3af283..803fc30f 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/MapVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/MapVisitor.kt @@ -1,12 +1,13 @@ package com.github.avrokotlin.avro4k.schema +import com.github.avrokotlin.avro4k.internal.AvroSchemaGenerationException import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema internal class MapVisitor( - override val context: VisitorContext, + private val context: VisitorContext, private val onSchemaBuilt: (Schema) -> Unit, -) : SerialDescriptorMapVisitor, AvroVisitorContextAware { +) : SerialDescriptorMapVisitor { private lateinit var valueSchema: Schema override fun visitMapKey( @@ -51,8 +52,10 @@ private fun Schema.isStringable(): Boolean = -> true Schema.Type.NULL, - Schema.Type.BYTES, // bytes could be stringified, but it's not a good idea as it can produce 
unreadable strings. - Schema.Type.FIXED, // same, just bytes. Btw, if the user wants to stringify it, he can use @Contextual or custom @Serializable serializer. + // bytes could be stringified, but it's not a good idea as it can produce unreadable strings. + Schema.Type.BYTES, + // same, just bytes. Btw, if the user wants to stringify it, he can use @Contextual or custom @Serializable serializer. + Schema.Type.FIXED, Schema.Type.ARRAY, Schema.Type.MAP, Schema.Type.RECORD, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicVisitor.kt index 553559de..c7fb8520 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicVisitor.kt @@ -1,12 +1,13 @@ package com.github.avrokotlin.avro4k.schema +import com.github.avrokotlin.avro4k.internal.AvroSchemaGenerationException import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema internal class PolymorphicVisitor( - override val context: VisitorContext, + private val context: VisitorContext, private val onSchemaBuilt: (Schema) -> Unit, -) : SerialDescriptorPolymorphicVisitor, AvroVisitorContextAware { +) : SerialDescriptorPolymorphicVisitor { private val possibleSchemas = mutableListOf() override fun visitPolymorphicFoundDescriptor(descriptor: SerialDescriptor): SerialDescriptorValueVisitor { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SerialDescriptorVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SerialDescriptorVisitor.kt index 60551f7f..fead2db8 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SerialDescriptorVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SerialDescriptorVisitor.kt @@ -187,7 +187,7 @@ private fun SerialDescriptor.getNonNullContextualDescriptor(serializersModule: S } @ExperimentalSerializationApi -private fun 
SerialDescriptor.possibleSerializationSubclasses(serializersModule: SerializersModule): Sequence { +internal fun SerialDescriptor.possibleSerializationSubclasses(serializersModule: SerializersModule): Sequence { return when (this.kind) { PolymorphicKind.SEALED -> elementDescriptors.asSequence() diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/TypeNamingStrategy.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/TypeNamingStrategy.kt deleted file mode 100644 index f4105916..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/TypeNamingStrategy.kt +++ /dev/null @@ -1,33 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import kotlinx.serialization.descriptors.SerialDescriptor - -interface TypeNamingStrategy { - fun resolve( - descriptor: SerialDescriptor, - serialName: String, - ): TypeName - - companion object Builtins { - /** - * Extract the record name from the fully qualified class name by taking the last part of the class name as the record name and the rest as the namespace. - * - * If there is no dot, then the namespace is null. - */ - object FullyQualified : TypeNamingStrategy { - override fun resolve( - descriptor: SerialDescriptor, - serialName: String, - ): TypeName { - val lastDot = serialName.lastIndexOf('.').takeIf { it >= 0 && it + 1 < serialName.length } - val lastIndex = if (serialName.endsWith('?')) serialName.length - 1 else serialName.length - return TypeName( - name = lastDot?.let { serialName.substring(lastDot + 1, lastIndex) } ?: serialName.substring(0, lastIndex), - namespace = lastDot?.let { serialName.substring(0, lastDot) }?.takeIf { it.isNotEmpty() } - ) - } - } - } -} - -data class TypeName(val name: String, val namespace: String?) 
\ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ValueVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ValueVisitor.kt index 13f2e1a0..05db8322 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ValueVisitor.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ValueVisitor.kt @@ -4,11 +4,12 @@ import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroFixed import com.github.avrokotlin.avro4k.AvroLogicalType import com.github.avrokotlin.avro4k.AvroSchema +import com.github.avrokotlin.avro4k.internal.AvroSchemaGenerationException import kotlinx.serialization.descriptors.PolymorphicKind import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.SerialKind -import kotlinx.serialization.json.Json +import kotlinx.serialization.descriptors.StructureKind import kotlinx.serialization.modules.SerializersModule import org.apache.avro.LogicalType import org.apache.avro.Schema @@ -16,9 +17,9 @@ import org.apache.avro.SchemaBuilder import kotlin.reflect.KClass internal class ValueVisitor internal constructor( - override val context: VisitorContext, + private val context: VisitorContext, private val onSchemaBuilt: (Schema) -> Unit, -) : SerialDescriptorValueVisitor, AvroVisitorContextAware { +) : SerialDescriptorValueVisitor { private var isNullable: Boolean = false private var logicalType: LogicalType? 
= null @@ -28,8 +29,7 @@ internal class ValueVisitor internal constructor( constructor(avro: Avro, onSchemaBuilt: (Schema) -> Unit) : this( VisitorContext( avro, - mutableMapOf(), - Json { serializersModule = avro.serializersModule } + mutableMapOf() ), onSchemaBuilt = onSchemaBuilt ) @@ -40,12 +40,9 @@ internal class ValueVisitor internal constructor( ) = setSchema(Schema.create(kind.toAvroType())) override fun visitEnum(descriptor: SerialDescriptor) { - val enumName = descriptor.getAvroName() - val annotations = TypeAnnotations(descriptor) val schema = - SchemaBuilder.enumeration(enumName.name) - .namespace(enumName.namespace) + SchemaBuilder.enumeration(descriptor.nonNullSerialName) .doc(annotations.doc?.value) .defaultSymbol(annotations.enumDefault?.value) .symbols(*descriptor.elementNamesArray) @@ -95,7 +92,7 @@ internal class ValueVisitor internal constructor( val parentFieldName = fixed.elementIndex?.let { fixed.descriptor.getElementName(it) } ?: throw AvroSchemaGenerationException("@AvroFixed must be used on a field") - val parentNamespace = fixed.descriptor.getAvroName().namespace + val parentNamespace = fixed.descriptor.serialName.substringBeforeLast('.', "").takeIf { it.isNotEmpty() } setSchema( SchemaBuilder.fixed(parentFieldName) @@ -144,4 +141,19 @@ private fun Schema.toNullableSchema(): Schema { } else { Schema.createUnion(Schema.create(Schema.Type.NULL), this) } -} \ No newline at end of file +} + +private fun PrimitiveKind.toAvroType() = + when (this) { + PrimitiveKind.BOOLEAN -> Schema.Type.BOOLEAN + PrimitiveKind.CHAR -> Schema.Type.INT + PrimitiveKind.BYTE -> Schema.Type.INT + PrimitiveKind.SHORT -> Schema.Type.INT + PrimitiveKind.INT -> Schema.Type.INT + PrimitiveKind.LONG -> Schema.Type.LONG + PrimitiveKind.FLOAT -> Schema.Type.FLOAT + PrimitiveKind.DOUBLE -> Schema.Type.DOUBLE + PrimitiveKind.STRING -> Schema.Type.STRING + } + +private fun SerialDescriptor.isByteArray(): Boolean = kind == StructureKind.LIST && getElementDescriptor(0).let { 
!it.isNullable && it.kind == PrimitiveKind.BYTE } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/VisitorContext.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/VisitorContext.kt index 984a6d50..c83efa72 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/VisitorContext.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/VisitorContext.kt @@ -1,10 +1,7 @@ package com.github.avrokotlin.avro4k.schema import com.fasterxml.jackson.databind.JsonNode -import com.fasterxml.jackson.databind.node.ArrayNode -import com.fasterxml.jackson.databind.node.JsonNodeFactory -import com.fasterxml.jackson.databind.node.NullNode -import com.fasterxml.jackson.databind.node.ObjectNode +import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.TextNode import com.github.avrokotlin.avro4k.AnnotatedLocation import com.github.avrokotlin.avro4k.Avro @@ -34,17 +31,12 @@ import org.apache.avro.Schema internal data class VisitorContext( val avro: Avro, - val resolvedSchemas: MutableMap, - val json: Json, + val resolvedSchemas: MutableMap, val inlinedAnnotations: ValueAnnotations? = null, ) internal fun VisitorContext.resetNesting() = copy(inlinedAnnotations = null) -internal interface AvroVisitorContextAware { - val context: VisitorContext -} - /** * Contains all the annotations for a field of a class (kind == CLASS && isInline == true). 
*/ @@ -169,20 +161,20 @@ internal fun ValueAnnotations?.appendAnnotations(other: ValueAnnotations) = stack = (this?.stack ?: emptyList()) + other.stack ) -context(AvroVisitorContextAware) +private val objectMapper = ObjectMapper() + internal val AvroJsonProp.jsonNode: JsonNode get() { if (jsonValue.isStartingAsJson()) { - return context.json.parseToJsonElement(jsonValue).toJacksonNode() + return objectMapper.readTree(jsonValue) } return TextNode.valueOf(jsonValue) } -context(AvroVisitorContextAware) internal val AvroDefault.jsonValue: Any get() { if (value.isStartingAsJson()) { - return context.json.parseToJsonElement(value).toAvroObject() + return Json.parseToJsonElement(value).toAvroObject() } return value } @@ -212,36 +204,12 @@ private fun JsonElement.toAvroObject(): Any = this.booleanOrNull != null -> this.boolean else -> { this.content.toBigDecimal().stripTrailingZeros().let { - if (it.scale() <= 0) it.toBigInteger() else it + if (it.scale() <= 0) { + it.toBigInteger() + } else { + it + } } } } - } - -private fun JsonElement.toJacksonNode(): JsonNode = - when (this) { - is JsonNull -> NullNode.instance - is JsonObject -> ObjectNode(JsonNodeFactory.instance, this.entries.associate { it.key to it.value.toJacksonNode() }) - is JsonArray -> ArrayNode(JsonNodeFactory.instance, this.map { it.toJacksonNode() }) - is JsonPrimitive -> - when { - this.isString -> JsonNodeFactory.instance.textNode(this.content) - this.booleanOrNull != null -> JsonNodeFactory.instance.booleanNode(this.boolean) - else -> - this.content.toBigDecimal().let { - if (it.scale() <= 0) JsonNodeFactory.instance.numberNode(it.toBigInteger()) else JsonNodeFactory.instance.numberNode(it) - } - } - } - -/** - * Get the record/enum name using the configured record naming strategy. 
- */ -context(AvroVisitorContextAware) -internal fun SerialDescriptor.getAvroName() = context.avro.configuration.typeNamingStrategy.resolve(this, serialName) - -/** - * Get the field name using the configured field naming strategy. - */ -context(AvroVisitorContextAware) -internal fun SerialDescriptor.getElementAvroName(elementIndex: Int) = context.avro.configuration.fieldNamingStrategy.resolve(this, elementIndex, getElementName(elementIndex)) \ No newline at end of file + } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/helpers.kt index 1a641704..5b468d87 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/helpers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/helpers.kt @@ -1,23 +1,31 @@ package com.github.avrokotlin.avro4k.schema -import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind import org.apache.avro.Schema -inline fun SerialDescriptor.findAnnotation() = annotations.asSequence().filterIsInstance().firstOrNull() +inline fun SerialDescriptor.findAnnotation() = annotations.firstNotNullOfOrNull { it as? T } inline fun SerialDescriptor.findAnnotations() = annotations.filterIsInstance() -inline fun SerialDescriptor.findElementAnnotation(elementIndex: Int) = getElementAnnotations(elementIndex).asSequence().filterIsInstance().firstOrNull() +inline fun SerialDescriptor.findElementAnnotation(elementIndex: Int) = getElementAnnotations(elementIndex).firstNotNullOfOrNull { it as? 
T } inline fun SerialDescriptor.findElementAnnotations(elementIndex: Int) = getElementAnnotations(elementIndex).filterIsInstance() -internal fun Schema.extractNonNull(): Schema = - when (this.type) { - Schema.Type.UNION -> this.types.filter { it.type != Schema.Type.NULL }.let { if (it.size > 1) Schema.createUnion(it) else it[0] } - else -> this +internal val SerialDescriptor.nonNullSerialName: String get() = serialName.removeSuffix('?') + +private fun String.removeSuffix(suffix: Char): String { + if (lastOrNull() == suffix) { + return substring(0, length - 1) } + return this +} + +internal val Schema.nonNull: Schema + get() = + when { + type == Schema.Type.UNION && isNullable -> this.types.filter { it.type != Schema.Type.NULL }.let { if (it.size > 1) Schema.createUnion(it) else it[0] } + else -> this + } /** * Overrides the namespace of a [Schema] with the given namespace. @@ -37,29 +45,19 @@ internal fun Schema.overrideNamespace(namespaceOverride: String): Schema { } val copy = Schema.createRecord(name, doc, namespaceOverride, isError, fields) aliases.forEach { copy.addAlias(it) } - this.objectProps.forEach { copy.addProp(it.key, it.value) } copy } + Schema.Type.UNION -> Schema.createUnion(types.map { it.overrideNamespace(namespaceOverride) }) - Schema.Type.ENUM -> Schema.createEnum(name, doc, namespaceOverride, enumSymbols, enumDefault) - Schema.Type.FIXED -> Schema.createFixed(name, doc, namespaceOverride, fixedSize) + Schema.Type.ENUM -> + Schema.createEnum(name, doc, namespaceOverride, enumSymbols, enumDefault) + .also { aliases.forEach { alias -> it.addAlias(alias) } } + Schema.Type.FIXED -> + Schema.createFixed(name, doc, namespaceOverride, fixedSize) + .also { aliases.forEach { alias -> it.addAlias(alias) } } Schema.Type.MAP -> Schema.createMap(valueType.overrideNamespace(namespaceOverride)) Schema.Type.ARRAY -> Schema.createArray(elementType.overrideNamespace(namespaceOverride)) else -> this } -} - -internal fun SerialDescriptor.isByteArray(): Boolean = 
kind == StructureKind.LIST && getElementDescriptor(0).let { !it.isNullable && it.kind == PrimitiveKind.BYTE } - -internal fun PrimitiveKind.toAvroType() = - when (this) { - PrimitiveKind.BOOLEAN -> Schema.Type.BOOLEAN - PrimitiveKind.CHAR -> Schema.Type.INT - PrimitiveKind.BYTE -> Schema.Type.INT - PrimitiveKind.SHORT -> Schema.Type.INT - PrimitiveKind.INT -> Schema.Type.INT - PrimitiveKind.LONG -> Schema.Type.LONG - PrimitiveKind.FLOAT -> Schema.Type.FLOAT - PrimitiveKind.DOUBLE -> Schema.Type.DOUBLE - PrimitiveKind.STRING -> Schema.Type.STRING - } \ No newline at end of file + .also { objectProps.forEach { prop -> it.addProp(prop.key, prop.value) } } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt index 58a5469e..a4b88094 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt @@ -1,51 +1,51 @@ package com.github.avrokotlin.avro4k.serializer -import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder -import com.github.avrokotlin.avro4k.decoder.FieldDecoder -import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder -import com.github.avrokotlin.avro4k.encoder.FieldEncoder -import com.github.avrokotlin.avro4k.schema.extractNonNull +import com.github.avrokotlin.avro4k.decoder.AvroDecoder +import com.github.avrokotlin.avro4k.encoder.AvroEncoder import kotlinx.serialization.KSerializer import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder -import org.apache.avro.Schema abstract class AvroSerializer : KSerializer { final override fun serialize( encoder: Encoder, value: T, ) { - val schema = - (encoder as FieldEncoder).fieldSchema().let { - if (!this.descriptor.isNullable && it.isNullable) { - it.extractNonNull() - } else { - it - } - } - encodeAvroValue(schema, encoder, 
value) + if (encoder is AvroEncoder) { + serializeAvro(encoder, value) + return + } + serializeGeneric(encoder, value) } - final override fun deserialize(decoder: Decoder): T { - val schema = - (decoder as FieldDecoder).fieldSchema().let { - if (!this.descriptor.isNullable && it.isNullable) { - it.extractNonNull() - } else { - it - } - } - return decodeAvroValue(schema, decoder) + /** + * This method is called when the serializer is used outside Avro serialization. + * By default, it throws an exception. + * + * Implement it to provide a generic serialization logic with the standard [Encoder]. + */ + open fun serializeGeneric( + encoder: Encoder, + value: T, + ) { + throw UnsupportedOperationException("The serializer ${this::class.qualifiedName} is not usable outside of Avro serialization.") } - abstract fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: T, + abstract fun serializeAvro( + encoder: AvroEncoder, + value: T, ) - abstract fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): T + final override fun deserialize(decoder: Decoder): T { + if (decoder !is AvroDecoder) { + return deserializeGeneric(decoder) + } + return deserializeAvro(decoder) + } + + open fun deserializeGeneric(decoder: Decoder): T { + throw UnsupportedOperationException("The serializer ${this::class.qualifiedName} is not usable outside of Avro serialization.") + } + + abstract fun deserializeAvro(decoder: AvroDecoder): T } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt index 209e42c9..6e4ba02f 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt @@ -4,12 +4,19 @@ import com.github.avrokotlin.avro4k.AnnotatedLocation import com.github.avrokotlin.avro4k.AvroDecimal import 
com.github.avrokotlin.avro4k.AvroLogicalType import com.github.avrokotlin.avro4k.AvroLogicalTypeSupplier -import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder -import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder +import com.github.avrokotlin.avro4k.decoder.AvroDecoder +import com.github.avrokotlin.avro4k.encoder.AvroEncoder +import com.github.avrokotlin.avro4k.encoder.SchemaTypeMatcher +import com.github.avrokotlin.avro4k.encoder.encodeValueResolved import com.github.avrokotlin.avro4k.schema.findElementAnnotation +import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder import org.apache.avro.Conversions import org.apache.avro.LogicalType import org.apache.avro.LogicalTypes @@ -28,22 +35,28 @@ object BigDecimalSerializer : AvroSerializer(), AvroLogicalTypeSuppl } ?: defaultAnnotation.logicalType } + @OptIn(InternalSerializationApi::class) override val descriptor = - buildByteArraySerialDescriptor( - BigDecimal::class.qualifiedName!!, - AvroLogicalType(BigDecimalSerializer::class) - ) - - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: BigDecimal, - ) = encodeBigDecimal(schema, encoder, obj) - - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ) = decodeBigDecimal(decoder, schema) + buildSerialDescriptor(BigDecimal::class.qualifiedName!!, StructureKind.LIST) { + element("item", buildSerialDescriptor("item", PrimitiveKind.BYTE)) + this.annotations = listOf(AvroLogicalType(BigDecimalSerializer::class)) + } + + override fun serializeAvro( + encoder: AvroEncoder, + value: BigDecimal, + ) = encodeBigDecimal(encoder, 
value) + + override fun serializeGeneric( + encoder: Encoder, + value: BigDecimal, + ) = encoder.encodeString(value.toString()) + + override fun deserializeAvro(decoder: AvroDecoder) = decodeBigDecimal(decoder) + + override fun deserializeGeneric(decoder: Decoder): BigDecimal { + return decoder.decodeString().toBigDecimal() + } private val AvroDecimal.logicalType: LogicalType get() { @@ -54,58 +67,49 @@ object BigDecimalSerializer : AvroSerializer(), AvroLogicalTypeSuppl object BigDecimalAsStringSerializer : AvroSerializer() { override val descriptor = PrimitiveSerialDescriptor(BigDecimal::class.qualifiedName!!, PrimitiveKind.STRING) - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: BigDecimal, - ) = encodeBigDecimal(schema, encoder, obj) + override fun serializeAvro( + encoder: AvroEncoder, + value: BigDecimal, + ) = encodeBigDecimal(encoder, value) + + override fun serializeGeneric( + encoder: Encoder, + value: BigDecimal, + ) = encoder.encodeString(value.toString()) + + override fun deserializeAvro(decoder: AvroDecoder) = decodeBigDecimal(decoder) - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ) = decodeBigDecimal(decoder, schema) + override fun deserializeGeneric(decoder: Decoder): BigDecimal { + return decoder.decodeString().toBigDecimal() + } } private fun encodeBigDecimal( - schema: Schema, - encoder: ExtendedEncoder, + encoder: AvroEncoder, value: BigDecimal, ) { - when (schema.type) { - Schema.Type.STRING -> encoder.encodeString(value.toString()) - Schema.Type.BYTES -> { - encoder.encodeByteArray(converter.toBytes(value, schema, schema.getDecimalLogicalType())) - } - - Schema.Type.FIXED -> { - encoder.encodeFixed(converter.toFixed(value, schema, schema.getDecimalLogicalType())) - } - - Schema.Type.INT -> encoder.encodeInt(value.intValueExact()) - Schema.Type.LONG -> encoder.encodeLong(value.longValueExact()) - Schema.Type.FLOAT -> encoder.encodeFloat(value.toFloat()) - Schema.Type.DOUBLE 
-> encoder.encodeDouble(value.toDouble()) - - else -> throw SerializationException("Cannot encode BigDecimal as ${schema.type}") - } + encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.BYTES to { converter.toBytes(value, it, it.getDecimalLogicalType()) }, + SchemaTypeMatcher.Named.FirstFixed to { converter.toFixed(value, it, it.getDecimalLogicalType()) }, + SchemaTypeMatcher.Scalar.STRING to { value.toString() }, + SchemaTypeMatcher.Scalar.INT to { value.intValueExact() }, + SchemaTypeMatcher.Scalar.LONG to { value.longValueExact() }, + SchemaTypeMatcher.Scalar.FLOAT to { value.toFloat() }, + SchemaTypeMatcher.Scalar.DOUBLE to { value.toDouble() } + ) } -private fun decodeBigDecimal( - decoder: ExtendedDecoder, - schema: Schema, -): BigDecimal = - // TODO we should use the schema instead of this generic decodeAny() - when (val v = decoder.decodeAny()) { +private fun decodeBigDecimal(decoder: AvroDecoder): BigDecimal = + when (val v = decoder.decodeValue()) { is CharSequence -> BigDecimal(v.toString()) - is ByteArray -> converter.fromBytes(ByteBuffer.wrap(v), schema, schema.getDecimalLogicalType()) - is ByteBuffer -> converter.fromBytes(v, schema, schema.getDecimalLogicalType()) - is GenericFixed -> converter.fromFixed(v, schema, schema.getDecimalLogicalType()) + is ByteArray -> converter.fromBytes(ByteBuffer.wrap(v), decoder.currentWriterSchema, decoder.currentWriterSchema.getDecimalLogicalType()) + is ByteBuffer -> converter.fromBytes(v, decoder.currentWriterSchema, decoder.currentWriterSchema.getDecimalLogicalType()) + is GenericFixed -> converter.fromFixed(v, decoder.currentWriterSchema, decoder.currentWriterSchema.getDecimalLogicalType()) else -> throw SerializationException("Unsupported BigDecimal type [$v]") } private fun Schema.getDecimalLogicalType(): LogicalTypes.Decimal { - val l = logicalType - return when (l) { + return when (val l = logicalType) { is LogicalTypes.Decimal -> l else -> throw SerializationException("Expected to find a decimal 
logical type for BigDecimal but found $l") } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt index 15d3ade5..1ae78fe6 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt @@ -1,40 +1,46 @@ package com.github.avrokotlin.avro4k.serializer -import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder -import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder +import com.github.avrokotlin.avro4k.decoder.AvroDecoder +import com.github.avrokotlin.avro4k.encoder.AvroEncoder +import com.github.avrokotlin.avro4k.encoder.SchemaTypeMatcher +import com.github.avrokotlin.avro4k.encoder.encodeValueResolved import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder import org.apache.avro.Schema import java.math.BigInteger object BigIntegerSerializer : AvroSerializer() { override val descriptor = PrimitiveSerialDescriptor(BigInteger::class.qualifiedName!!, PrimitiveKind.STRING) - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: BigInteger, - ) = when (schema.type) { - Schema.Type.STRING -> encoder.encodeString(obj.toString()) - Schema.Type.INT -> encoder.encodeInt(obj.intValueExact()) - Schema.Type.LONG -> encoder.encodeLong(obj.longValueExact()) - Schema.Type.FLOAT -> encoder.encodeFloat(obj.toFloat()) - Schema.Type.DOUBLE -> encoder.encodeDouble(obj.toDouble()) + override fun serializeAvro( + encoder: AvroEncoder, + value: BigInteger, + ) = encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.STRING to { value.toString() }, + SchemaTypeMatcher.Scalar.INT to { value.intValueExact() }, + SchemaTypeMatcher.Scalar.LONG to { 
value.longValueExact() }, + SchemaTypeMatcher.Scalar.FLOAT to { value.toFloat() }, + SchemaTypeMatcher.Scalar.DOUBLE to { value.toDouble() } + ) - else -> throw UnsupportedOperationException("Unsupported schema type: $schema") - } + override fun serializeGeneric( + encoder: Encoder, + value: BigInteger, + ) = encoder.encodeString(value.toString()) - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): BigInteger = - when (schema.type) { + override fun deserializeAvro(decoder: AvroDecoder): BigInteger = + when (decoder.currentWriterSchema.type) { Schema.Type.STRING -> BigInteger(decoder.decodeString()) Schema.Type.INT -> BigInteger.valueOf(decoder.decodeInt().toLong()) Schema.Type.LONG -> BigInteger.valueOf(decoder.decodeLong()) Schema.Type.FLOAT -> BigInteger.valueOf(decoder.decodeFloat().toLong()) Schema.Type.DOUBLE -> BigInteger.valueOf(decoder.decodeDouble().toLong()) - - else -> throw UnsupportedOperationException("Unsupported schema type for BigInteger: $schema") + else -> throw UnsupportedOperationException("Unsupported schema type for BigInteger: ${decoder.currentWriterSchema}") } + + override fun deserializeGeneric(decoder: Decoder): BigInteger { + return decoder.decodeString().toBigInteger() + } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt index d9941623..f8ee41d5 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt @@ -1,32 +1,19 @@ package com.github.avrokotlin.avro4k.serializer -import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder -import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder -import kotlinx.serialization.SerializationException +import kotlinx.serialization.KSerializer import kotlinx.serialization.descriptors.PrimitiveKind import 
kotlinx.serialization.descriptors.PrimitiveSerialDescriptor -import org.apache.avro.Schema +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder import java.net.URL -object URLSerializer : AvroSerializer() { +object URLSerializer : KSerializer { override val descriptor = PrimitiveSerialDescriptor(URL::class.qualifiedName!!, PrimitiveKind.STRING) - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: URL, - ) { - encoder.encodeString(obj.toString()) - } + override fun serialize( + encoder: Encoder, + value: URL, + ) = encoder.encodeString(value.toString()) - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): URL { - return when (val v = decoder.decodeAny()) { - is CharSequence -> URL(v.toString()) - null -> throw SerializationException("Cannot decode as URL") - else -> throw SerializationException("Unsupported URL type [$v : ${v::class.qualifiedName}]") - } - } + override fun deserialize(decoder: Decoder): URL = URL(decoder.decodeString()) } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt index 0dd3bb9c..f9cb1b36 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt @@ -3,17 +3,17 @@ package com.github.avrokotlin.avro4k.serializer import com.github.avrokotlin.avro4k.AnnotatedLocation import com.github.avrokotlin.avro4k.AvroLogicalType import com.github.avrokotlin.avro4k.AvroLogicalTypeSupplier -import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder -import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.KSerializer import kotlinx.serialization.descriptors.PrimitiveKind import 
kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder import org.apache.avro.LogicalType import org.apache.avro.LogicalTypes -import org.apache.avro.Schema import java.util.UUID -object UUIDSerializer : AvroSerializer(), AvroLogicalTypeSupplier { +object UUIDSerializer : KSerializer, AvroLogicalTypeSupplier { @OptIn(InternalSerializationApi::class) override val descriptor = buildSerialDescriptor("uuid", PrimitiveKind.STRING) { @@ -24,14 +24,10 @@ object UUIDSerializer : AvroSerializer(), AvroLogicalTypeSupplier { return LogicalTypes.uuid() } - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: UUID, - ) = encoder.encodeString(obj.toString()) + override fun serialize( + encoder: Encoder, + value: UUID, + ) = encoder.encodeString(value.toString()) - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): UUID = UUID.fromString(decoder.decodeString()) + override fun deserialize(decoder: Decoder): UUID = UUID.fromString(decoder.decodeString()) } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt index c7348496..116f8369 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt @@ -3,16 +3,19 @@ package com.github.avrokotlin.avro4k.serializer import com.github.avrokotlin.avro4k.AnnotatedLocation import com.github.avrokotlin.avro4k.AvroLogicalType import com.github.avrokotlin.avro4k.AvroLogicalTypeSupplier -import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder -import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder +import com.github.avrokotlin.avro4k.decoder.AvroDecoder +import com.github.avrokotlin.avro4k.encoder.AvroEncoder +import com.github.avrokotlin.avro4k.encoder.SchemaTypeMatcher +import 
com.github.avrokotlin.avro4k.encoder.encodeValueResolved import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder import org.apache.avro.LogicalType import org.apache.avro.LogicalTypes -import org.apache.avro.Schema import java.sql.Timestamp import java.time.Instant import java.time.LocalDate @@ -27,16 +30,34 @@ object LocalDateSerializer : AvroTimeSerializer(LocalDate::class, Pri return LogicalTypes.date() } - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: LocalDate, - ) = encoder.encodeInt(obj.toEpochDay().toInt()) - - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): LocalDate = LocalDate.ofEpochDay(decoder.decodeInt().toLong()) + override fun serializeAvro( + encoder: AvroEncoder, + value: LocalDate, + ) = encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.INT to { + when (it.logicalType) { + is LogicalTypes.Date, null -> value.toEpochDay().toInt() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.LONG to { + when (it.logicalType) { + // Date is not compatible with LONG, so we require a null logical type to encode the timestamp + null -> value.toEpochDay() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.STRING to { value.toString() } + ) + + override fun serializeGeneric( + encoder: Encoder, + value: LocalDate, + ) = encoder.encodeInt(value.toEpochDay().toInt()) + + override fun deserializeAvro(decoder: AvroDecoder): LocalDate = deserializeGeneric(decoder) + + override fun deserializeGeneric(decoder: Decoder) = LocalDate.ofEpochDay(decoder.decodeInt().toLong()) } object LocalTimeSerializer : 
AvroTimeSerializer(LocalTime::class, PrimitiveKind.INT) { @@ -44,23 +65,48 @@ object LocalTimeSerializer : AvroTimeSerializer(LocalTime::class, Pri return LogicalTypes.timeMillis() } - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: LocalTime, - ) = encoder.encodeInt(obj.toSecondOfDay() * 1000 + obj.nano / 1000) - - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): LocalTime { + override fun serializeAvro( + encoder: AvroEncoder, + value: LocalTime, + ) = encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.INT to { + when (it.logicalType) { + is LogicalTypes.TimeMillis, null -> value.toMillisOfDay() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.LONG to { + when (it.logicalType) { + // TimeMillis is not compatible with LONG, so we require a null logical type to encode the timestamp + null -> value.toMillisOfDay().toLong() + is LogicalTypes.TimeMicros -> value.toMicroOfDay() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.STRING to { value.truncatedTo(ChronoUnit.MILLIS).toString() } + ) + + override fun serializeGeneric( + encoder: Encoder, + value: LocalTime, + ) = encoder.encodeInt(value.toMillisOfDay()) + + private fun LocalTime.toMillisOfDay() = (toNanoOfDay() / 1000000).toInt() + + private fun LocalTime.toMicroOfDay() = toNanoOfDay() / 1000 + + override fun deserializeAvro(decoder: AvroDecoder): LocalTime { // avro stores times as either millis since midnight or micros since midnight - return when (schema.logicalType) { + return when (decoder.currentWriterSchema.logicalType) { is LogicalTypes.TimeMicros -> LocalTime.ofNanoOfDay(decoder.decodeInt() * 1000L) is LogicalTypes.TimeMillis -> LocalTime.ofNanoOfDay(decoder.decodeInt() * 1000000L) - else -> throw SerializationException("Unsupported logical type for LocalTime [${schema.logicalType}]") + else -> throw SerializationException("Unsupported logical type for LocalTime 
[${decoder.currentWriterSchema.logicalType}]") } } + + override fun deserializeGeneric(decoder: Decoder): LocalTime { + return LocalTime.ofNanoOfDay(decoder.decodeInt() * 1000000L) + } } object LocalDateTimeSerializer : AvroTimeSerializer(LocalDateTime::class, PrimitiveKind.LONG) { @@ -68,16 +114,30 @@ object LocalDateTimeSerializer : AvroTimeSerializer(LocalDateTime return LogicalTypes.timestampMillis() } - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: LocalDateTime, - ) = InstantSerializer.encodeAvroValue(schema, encoder, obj.toInstant(ZoneOffset.UTC)) - - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): LocalDateTime = LocalDateTime.ofInstant(Instant.ofEpochMilli(decoder.decodeLong()), ZoneOffset.UTC) + override fun serializeAvro( + encoder: AvroEncoder, + value: LocalDateTime, + ) = encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.LONG to { + when (it.logicalType) { + is LogicalTypes.TimestampMillis, null -> value.toInstant(ZoneOffset.UTC).toEpochMilli() + is LogicalTypes.TimestampMicros -> value.toInstant(ZoneOffset.UTC).toEpochMicros() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.STRING to { value.truncatedTo(ChronoUnit.MILLIS).toString() } + ) + + override fun serializeGeneric( + encoder: Encoder, + value: LocalDateTime, + ) = encoder.encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMilli()) + + override fun deserializeAvro(decoder: AvroDecoder): LocalDateTime = deserializeGeneric(decoder) + + override fun deserializeGeneric(decoder: Decoder): LocalDateTime { + return LocalDateTime.ofInstant(Instant.ofEpochMilli(decoder.decodeLong()), ZoneOffset.UTC) + } } object TimestampSerializer : AvroTimeSerializer(Timestamp::class, PrimitiveKind.LONG) { @@ -85,16 +145,30 @@ object TimestampSerializer : AvroTimeSerializer(Timestamp::class, Pri return LogicalTypes.timestampMillis() } - override fun encodeAvroValue( - schema: Schema, - encoder: 
ExtendedEncoder, - obj: Timestamp, - ) = InstantSerializer.encodeAvroValue(schema, encoder, obj.toInstant()) - - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): Timestamp = Timestamp(decoder.decodeLong()) + override fun serializeAvro( + encoder: AvroEncoder, + value: Timestamp, + ) = encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.LONG to { + when (it.logicalType) { + is LogicalTypes.TimestampMillis, null -> value.toInstant().toEpochMilli() + is LogicalTypes.TimestampMicros -> value.toInstant().toEpochMicros() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.STRING to { value.toInstant().toString() } + ) + + override fun serializeGeneric( + encoder: Encoder, + value: Timestamp, + ) = encoder.encodeLong(value.toInstant().toEpochMilli()) + + override fun deserializeAvro(decoder: AvroDecoder): Timestamp = deserializeGeneric(decoder) + + override fun deserializeGeneric(decoder: Decoder): Timestamp { + return Timestamp(decoder.decodeLong()) + } } object InstantSerializer : AvroTimeSerializer(Instant::class, PrimitiveKind.LONG) { @@ -102,16 +176,30 @@ object InstantSerializer : AvroTimeSerializer(Instant::class, Primitive return LogicalTypes.timestampMillis() } - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: Instant, - ) = encoder.encodeLong(obj.toEpochMilli()) - - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): Instant = Instant.ofEpochMilli(decoder.decodeLong()) + override fun serializeAvro( + encoder: AvroEncoder, + value: Instant, + ) = encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.LONG to { + when (it.logicalType) { + is LogicalTypes.TimestampMillis, null -> value.toEpochMilli() + is LogicalTypes.TimestampMicros -> value.toEpochMicros() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.STRING to { value.toString() } + ) + + override fun serializeGeneric( + encoder: Encoder, + 
value: Instant, + ) = encoder.encodeLong(value.toEpochMilli()) + + override fun deserializeAvro(decoder: AvroDecoder): Instant = deserializeGeneric(decoder) + + override fun deserializeGeneric(decoder: Decoder): Instant { + return Instant.ofEpochMilli(decoder.decodeLong()) + } } object InstantToMicroSerializer : AvroTimeSerializer(Instant::class, PrimitiveKind.LONG) { @@ -119,16 +207,30 @@ object InstantToMicroSerializer : AvroTimeSerializer(Instant::class, Pr return LogicalTypes.timestampMicros() } - override fun encodeAvroValue( - schema: Schema, - encoder: ExtendedEncoder, - obj: Instant, - ) = encoder.encodeLong(ChronoUnit.MICROS.between(Instant.EPOCH, obj)) - - override fun decodeAvroValue( - schema: Schema, - decoder: ExtendedDecoder, - ): Instant = Instant.EPOCH.plus(decoder.decodeLong(), ChronoUnit.MICROS) + override fun serializeAvro( + encoder: AvroEncoder, + value: Instant, + ) = encoder.encodeValueResolved( + SchemaTypeMatcher.Scalar.LONG to { + when (it.logicalType) { + is LogicalTypes.TimestampMicros, null -> value.toEpochMicros() + is LogicalTypes.TimestampMillis -> value.toEpochMilli() + else -> it.logicalType.throwUnsupportedWith() + } + }, + SchemaTypeMatcher.Scalar.STRING to { value.toString() } + ) + + override fun serializeGeneric( + encoder: Encoder, + value: Instant, + ) = encoder.encodeLong(value.toEpochMicros()) + + override fun deserializeAvro(decoder: AvroDecoder): Instant = deserializeGeneric(decoder) + + override fun deserializeGeneric(decoder: Decoder): Instant { + return Instant.EPOCH.plus(decoder.decodeLong(), ChronoUnit.MICROS) + } } @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) @@ -140,4 +242,10 @@ abstract class AvroTimeSerializer( buildSerialDescriptor(klass.qualifiedName!!, kind) { annotations = listOf(AvroLogicalType(this@AvroTimeSerializer::class)) } -} \ No newline at end of file +} + +private inline fun LogicalType?.throwUnsupportedWith(): Nothing { + throw SerializationException("Unsupported 
logical type $this for kotlin type ${T::class.qualifiedName}") +} + +private fun Instant.toEpochMicros() = ChronoUnit.MICROS.between(Instant.EPOCH, this) \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/helpers.kt deleted file mode 100644 index 982db7af..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/helpers.kt +++ /dev/null @@ -1,20 +0,0 @@ -package com.github.avrokotlin.avro4k.serializer - -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.InternalSerializationApi -import kotlinx.serialization.descriptors.PrimitiveKind -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.descriptors.buildSerialDescriptor - -@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) -fun buildByteArraySerialDescriptor( - serialName: String, - vararg annotations: Annotation, -) = buildSerialDescriptor(serialName, StructureKind.LIST) { - element("item", buildSerialDescriptor("item", PrimitiveKind.BYTE)) - this.annotations = listOf(*annotations) -} - -fun Long.toIntExact(): Int { - return Math.toIntExact(this) -} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt index 8ccf26f9..499f86ce 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroAssertions.kt @@ -31,11 +31,10 @@ class AvroEncodingAssertions( fun isEncodedAs( expectedEncodedGenericValue: Any?, expectedDecodedValue: T = valueToEncode, + writerSchema: Schema = avro.schema(serializer), ): AvroEncodingAssertions { - val writerSchema: Schema = avro.schema(serializer) - val apacheEncodedBytes = avroApacheEncode(expectedEncodedGenericValue, writerSchema) - val actualEncodedBytes = avro4kEncode(valueToEncode, 
writerSchema) + val apacheEncodedBytes = avroApacheEncode(expectedEncodedGenericValue, writerSchema) withClue("Encoded bytes are not the same as apache avro library.") { if (!actualEncodedBytes.contentEquals(apacheEncodedBytes)) { val expectedAvroJson = bytesToAvroJson(apacheEncodedBytes, writerSchema) @@ -57,7 +56,7 @@ class AvroEncodingAssertions( return this } - inline fun isDecodedAs(expected: R) = isDecodedAs(expected, serializer()) + inline fun isDecodedAs(expected: R) = isDecodedAs(expected, Avro.serializersModule.serializer()) fun isDecodedAs( expected: R, @@ -161,7 +160,7 @@ open class AvroSchemaAssertions( object AvroAssertions { inline fun assertThat(): AvroSchemaAssertions { - return AvroSchemaAssertions(serializer()) + return AvroSchemaAssertions(Avro.serializersModule.serializer()) } fun assertThat(serializer: KSerializer): AvroSchemaAssertions { @@ -169,7 +168,7 @@ object AvroAssertions { } inline fun assertThat(value: T): AvroEncodingAssertions { - return AvroEncodingAssertions(value, serializer()) + return AvroEncodingAssertions(value, Avro.serializersModule.serializer()) } @Suppress("UNCHECKED_CAST") diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFileTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFileTest.kt index bb3a17e3..ffa84e98 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFileTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroObjectContainerFileTest.kt @@ -6,6 +6,7 @@ import kotlinx.serialization.Contextual import kotlinx.serialization.Serializable import org.apache.avro.file.DataFileStream import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.GenericData import org.apache.avro.generic.GenericDatumReader import org.apache.avro.generic.GenericDatumWriter import org.apache.avro.generic.GenericRecord @@ -26,7 +27,7 @@ class AvroObjectContainerFileTest : StringSpec({ firstProfile.id.value.toString(), "John Doe", 30, - "Male", 
+ GenericData.EnumSymbol(Avro.schema(), "Male"), null ) val secondProfile = @@ -42,7 +43,7 @@ class AvroObjectContainerFileTest : StringSpec({ secondProfile.id.value.toString(), "Jane Doe", 25, - "Female", + GenericData.EnumSymbol(Avro.schema(), "Female"), record( "New York", "USA" @@ -61,15 +62,16 @@ class AvroObjectContainerFileTest : StringSpec({ it.toByteArray() } // read with apache avro lib - bytes.inputStream().use { - val dataFile = DataFileStream(it, GenericDatumReader(Avro.schema())) - dataFile.getMetaString("meta-string") shouldBe "awesome string" - dataFile.getMetaLong("meta-long") shouldBe 42 - dataFile.getMeta("bytes") shouldBe byteArrayOf(1, 3, 2, 42) - normalizeGenericData(dataFile.next()) shouldBe firstProfileGenericData - normalizeGenericData(dataFile.next()) shouldBe secondProfileGenericData - dataFile.hasNext() shouldBe false - } + val dataFile = + bytes.inputStream().use { + DataFileStream(it, GenericDatumReader(Avro.schema())) + } + dataFile.getMetaString("meta-string") shouldBe "awesome string" + dataFile.getMetaLong("meta-long") shouldBe 42 + dataFile.getMeta("bytes") shouldBe byteArrayOf(1, 3, 2, 42) + normalizeGenericData(dataFile.next()) shouldBe firstProfileGenericData + normalizeGenericData(dataFile.next()) shouldBe secondProfileGenericData + dataFile.hasNext() shouldBe false } "support reading avro object container file with metadata" { // write with apache avro lib @@ -86,17 +88,17 @@ class AvroObjectContainerFileTest : StringSpec({ it.toByteArray() } // read with avro4k - bytes.inputStream().use { - val profiles = + val profiles = + bytes.inputStream().use { AvroObjectContainerFile().decodeFromStream(it) { metadata("meta-string")?.asString() shouldBe "awesome string" metadata("meta-long")?.asLong() shouldBe 42 metadata("bytes")?.asBytes() shouldBe byteArrayOf(1, 3, 2, 42) }.toList() - profiles.size shouldBe 2 - profiles[0] shouldBe firstProfile - profiles[1] shouldBe secondProfile - } + } + profiles.size shouldBe 2 + profiles[0] 
shouldBe firstProfile + profiles[1] shouldBe secondProfile } }) { @Serializable diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/RecordBuilderForTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/RecordBuilderForTest.kt index 5606749d..a6defe0e 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/RecordBuilderForTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/RecordBuilderForTest.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k -import com.github.avrokotlin.avro4k.schema.extractNonNull +import com.github.avrokotlin.avro4k.schema.nonNull import org.apache.avro.Schema import org.apache.avro.generic.GenericData import org.apache.avro.generic.GenericEnumSymbol @@ -27,7 +27,7 @@ fun convertToAvroGenericValue( value: Any?, schema: Schema, ): Any? { - val schema = if (schema.isNullable) schema.extractNonNull() else schema + val schema = schema.nonNull return when (value) { is RecordBuilderForTest -> value.createRecord(schema) is Map<*, *> -> createMap(schema, value) @@ -52,13 +52,14 @@ fun normalizeGenericData(value: Any?): Any? 
{ is ByteArray -> value.toList() is ByteBuffer -> value.array().toList() is GenericFixed -> value.bytes().toList() - is GenericEnumSymbol<*> -> value.toString() is CharSequence -> value.toString() is RecordBuilderForTest -> RecordBuilderForTest(value.fields.map { normalizeGenericData(it) }) is Byte -> value.toInt() is Short -> value.toInt() - is Boolean, is Char, is Number, null -> value + is GenericEnumSymbol<*>, + is Boolean, is Char, is Number, null, + -> value else -> TODO("Not implemented for ${value.javaClass}") } diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/ArrayEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/ArrayEncodingTest.kt index e0a63d33..0d7d9392 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/ArrayEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/ArrayEncodingTest.kt @@ -7,14 +7,27 @@ import kotlinx.serialization.Serializable class ArrayEncodingTest : StringSpec({ "support array of booleans" { + @Serializable + data class TestArrayBooleans(val booleans: List) + AvroAssertions.assertThat(TestArrayBooleans(listOf(true, false, true))) .isEncodedAs(record(listOf(true, false, true))) } + "support array of nullable booleans" { + @Serializable + data class TestArrayBooleans(val booleans: List) + + AvroAssertions.assertThat(TestArrayBooleans(listOf(true, null, false))) + .isEncodedAs(record(listOf(true, null, false))) + } "support List of doubles" { AvroAssertions.assertThat(TestListDoubles(listOf(12.54, 23.5, 9123.2314))) .isEncodedAs(record(listOf(12.54, 23.5, 9123.2314))) } "support List of records" { + @Serializable + data class TestListRecords(val records: List) + AvroAssertions.assertThat( TestListRecords( listOf( @@ -31,6 +44,28 @@ class ArrayEncodingTest : StringSpec({ ) ) } + "support List of nullable records" { + @Serializable + data class TestListNullableRecords(val records: List) + + AvroAssertions.assertThat( + TestListNullableRecords( + listOf( + 
Record("qwe", 123.4), + null, + Record("wer", 8234.324) + ) + ) + ).isEncodedAs( + record( + listOf( + record("qwe", 123.4), + null, + record("wer", 8234.324) + ) + ) + ) + } "support Set of records" { AvroAssertions.assertThat( TestSetRecords( @@ -55,23 +90,14 @@ class ArrayEncodingTest : StringSpec({ } }) { @Serializable - data class TestArrayBooleans(val booleans: List) - - @Serializable - data class TestListDoubles(val doubles: List) - - @Serializable - data class TestSetString(val strings: Set) - - @Serializable - data class TestArrayRecords(val records: Array) + private data class TestListDoubles(val doubles: List) @Serializable - data class TestListRecords(val records: List) + private data class TestSetString(val strings: Set) @Serializable - data class TestSetRecords(val records: Set) + private data class TestSetRecords(val records: Set) @Serializable - data class Record(val str: String, val double: Double) + private data class Record(val str: String, val double: Double) } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt index 8d6978b7..562d61fe 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroAliasEncodingTest.kt @@ -2,17 +2,21 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.AvroAlias import com.github.avrokotlin.avro4k.AvroAssertions -import com.github.avrokotlin.avro4k.AvroDefault +import com.github.avrokotlin.avro4k.SomeEnum import com.github.avrokotlin.avro4k.record +import com.github.avrokotlin.avro4k.recordWithSchema import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import org.apache.avro.Schema +import org.apache.avro.SchemaBuilder +import org.apache.avro.generic.GenericData class 
AvroAliasEncodingTest : StringSpec({ "support alias on field" { AvroAssertions.assertThat(EncodedField("hello")) .isEncodedAs(record("hello")) - .isDecodedAs(DecodedFieldWithAlias(5, "hello")) + .isDecodedAs(DecodedFieldWithAlias(3, "hello")) } "support alias on record" { @@ -20,31 +24,80 @@ class AvroAliasEncodingTest : StringSpec({ .isEncodedAs(record("hello")) .isDecodedAs(DecodedRecordWithAlias("hello")) } + + "support alias on record inside an union" { + val writerSchema = + Schema.createUnion( + SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"), + SchemaBuilder.record("UnknownRecord").aliases("RecordA") + .fields().name("field").type().stringType().noDefault() + .endRecord() + ) + AvroAssertions.assertThat(EncodedRecord("hello")) + .isEncodedAs(recordWithSchema(writerSchema.types[1], "hello"), writerSchema = writerSchema) + .isDecodedAs(DecodedRecordWithAlias("hello")) + } + + "support alias on enum" { + val writerSchema = + SchemaBuilder.record("EnumWrapperRecord").fields() + .name("value") + .type( + SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C") + ) + .noDefault() + .endRecord() + AvroAssertions.assertThat(EnumWrapperRecord(SomeEnum.A)) + .isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema(), "A")), writerSchema = writerSchema) + } + + "support alias on enum inside an union" { + val writerSchema = + SchemaBuilder.record("EnumWrapperRecord").fields() + .name("value") + .type( + Schema.createUnion( + SchemaBuilder.enumeration("OtherEnum").symbols("OTHER"), + SchemaBuilder.record("UnknownRecord").aliases("RecordA") + .fields().name("field").type().stringType().noDefault() + .endRecord(), + SchemaBuilder.enumeration("UnknownEnum").aliases("com.github.avrokotlin.avro4k.SomeEnum").symbols("A", "B", "C") + ) + ) + .noDefault() + .endRecord() + AvroAssertions.assertThat(EnumWrapperRecord(SomeEnum.A)) + 
.isEncodedAs(record(GenericData.EnumSymbol(writerSchema.fields[0].schema().types[2], "A")), writerSchema = writerSchema) + } }) { @Serializable @SerialName("Record") - data class EncodedField( + private data class EncodedField( val s: String, ) @Serializable @SerialName("Record") - data class DecodedFieldWithAlias( - @AvroDefault("5") - val newField: Int, + private data class DecodedFieldWithAlias( + val newField: Int = 3, @AvroAlias("s") val str: String, ) @Serializable @SerialName("RecordA") - data class EncodedRecord( + private data class EncodedRecord( val field: String, ) @Serializable - @SerialName("RecordB") @AvroAlias("RecordA") - data class DecodedRecordWithAlias( + private data class DecodedRecordWithAlias( val field: String, ) + + @Serializable + @SerialName("EnumWrapperRecord") + private data class EnumWrapperRecord( + val value: SomeEnum, + ) } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroDefaultEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroDefaultEncodingTest.kt index fdb507c6..f11d6516 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroDefaultEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroDefaultEncodingTest.kt @@ -68,5 +68,6 @@ class AvroDefaultEncodingTest : StringSpec({ @AvroDecimal(0, 10) @Serializable(BigDecimalSerializer::class) val bigDecimal: BigDecimal, + val kotlinDefault: Int = 42, ) } \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroFixedEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroFixedEncodingTest.kt new file mode 100644 index 00000000..86b4a25d --- /dev/null +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/AvroFixedEncodingTest.kt @@ -0,0 +1,107 @@ +package com.github.avrokotlin.avro4k.encoding + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.AvroAssertions +import 
com.github.avrokotlin.avro4k.AvroFixed +import com.github.avrokotlin.avro4k.encodeToByteArray +import com.github.avrokotlin.avro4k.record +import com.github.avrokotlin.avro4k.schema +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.core.spec.style.StringSpec +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import org.apache.avro.generic.GenericData +import kotlin.io.path.Path + +class AvroFixedEncodingTest : StringSpec({ + "support fixed on data class fields" { + AvroAssertions.assertThat() + .generatesSchema(Path("/fixed_string.json")) + + val schema = Avro.schema().fields[0].schema() + AvroAssertions.assertThat(FixedStringField("1234567")) + .isEncodedAs(record(GenericData.Fixed(schema, "1234567".toByteArray()))) + } + + "support fixed on value classes" { + AvroAssertions.assertThat() + .generatesSchema(Path("/fixed_string.json")) + + val schema = Avro.schema().fields[0].schema() + AvroAssertions.assertThat(FixedNestedStringField(FixedStringValueClass("1234567"))) + .isEncodedAs(record(GenericData.Fixed(schema, "1234567".toByteArray()))) + + AvroAssertions.assertThat(FixedStringValueClass("1234567")) + .isEncodedAs(GenericData.Fixed(Avro.schema(), "1234567".toByteArray())) + } + + "top-est @AvroFixed annotation takes precedence over nested @AvroFixed annotations" { + AvroAssertions.assertThat() + .generatesSchema(Path("/fixed_string_5.json")) + + // Not 5 chars fixed + shouldThrow { + Avro.encodeToByteArray(FieldPriorToValueClass(FixedStringValueClass("1234567"))) + } + + val schema = Avro.schema().fields[0].schema() + AvroAssertions.assertThat(FieldPriorToValueClass(FixedStringValueClass("12345"))) + .isEncodedAs(record(GenericData.Fixed(schema, "12345".toByteArray()))) + } + + "encode/decode ByteArray as FIXED when schema is Type.Fixed" { + AvroAssertions.assertThat(ByteArrayFixedTest(byteArrayOf(1, 4, 9))) + .isEncodedAs( + record(byteArrayOf(0, 0, 0, 
0, 0, 1, 4, 9)), + expectedDecodedValue = ByteArrayFixedTest(byteArrayOf(0, 0, 0, 0, 0, 1, 4, 9)) + ) + } + +// "Handle FIXED in unions with the good and bad fullNames and aliases" { +// fail("TODO") +// } +}) { + @Serializable + @SerialName("Fixed") + private data class FixedStringField( + @AvroFixed(7) val mystring: String, + ) + + @Serializable + @SerialName("Fixed") + private data class FixedNestedStringField( + val mystring: FixedStringValueClass, + ) + + @Serializable + @SerialName("Fixed") + private data class FieldPriorToValueClass( + @AvroFixed(5) val mystring: FixedStringValueClass, + ) + + @JvmInline + @Serializable + @SerialName("FixedString") + private value class FixedStringValueClass( + @AvroFixed(7) val mystring: String, + ) + + @Serializable + private data class ByteArrayFixedTest( + @AvroFixed(8) val z: ByteArray, + ) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as ByteArrayFixedTest + + return z.contentEquals(other.z) + } + + override fun hashCode(): Int { + return z.contentHashCode() + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/BytesEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/BytesEncodingTest.kt index 24af1b2d..be1c9298 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/BytesEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/BytesEncodingTest.kt @@ -1,32 +1,40 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.AvroAssertions -import com.github.avrokotlin.avro4k.AvroFixed import com.github.avrokotlin.avro4k.record import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.Serializable +import org.apache.avro.Schema class BytesEncodingTest : StringSpec({ "encode/decode ByteArray" { AvroAssertions.assertThat(ByteArrayTest(byteArrayOf(1, 4, 9))) 
.isEncodedAs(record(byteArrayOf(1, 4, 9))) + + AvroAssertions.assertThat() + .generatesSchema(Schema.create(Schema.Type.BYTES)) + AvroAssertions.assertThat(byteArrayOf(1, 4, 9)) + .isEncodedAs(byteArrayOf(1, 4, 9)) } + "encode/decode List" { AvroAssertions.assertThat(ListByteTest(listOf(1, 4, 9))) .isEncodedAs(record(byteArrayOf(1, 4, 9))) + + AvroAssertions.assertThat>() + .generatesSchema(Schema.create(Schema.Type.BYTES)) + AvroAssertions.assertThat(listOf(1, 4, 9)) + .isEncodedAs(byteArrayOf(1, 4, 9)) } "encode/decode Array to ByteBuffer" { AvroAssertions.assertThat(ArrayByteTest(arrayOf(1, 4, 9))) .isEncodedAs(record(byteArrayOf(1, 4, 9))) - } - "encode/decode ByteArray as FIXED when schema is Type.Fixed" { - AvroAssertions.assertThat(ByteArrayFixedTest(byteArrayOf(1, 4, 9))) - .isEncodedAs( - record(byteArrayOf(0, 0, 0, 0, 0, 1, 4, 9)), - ByteArrayFixedTest(byteArrayOf(0, 0, 0, 0, 0, 1, 4, 9)) - ) + AvroAssertions.assertThat>() + .generatesSchema(Schema.create(Schema.Type.BYTES)) + AvroAssertions.assertThat(arrayOf(1, 4, 9)) + .isEncodedAs(byteArrayOf(1, 4, 9)) } }) { @Serializable @@ -45,24 +53,6 @@ class BytesEncodingTest : StringSpec({ } } - @Serializable - data class ByteArrayFixedTest( - @AvroFixed(8) val z: ByteArray, - ) { - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false - - other as ByteArrayFixedTest - - return z.contentEquals(other.z) - } - - override fun hashCode(): Int { - return z.contentHashCode() - } - } - @Serializable data class ArrayByteTest(val z: Array) { override fun equals(other: Any?): Boolean { diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt index 27e89deb..cf4f3cea 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/EnumEncodingTest.kt @@ -2,56 +2,83 
@@ package com.github.avrokotlin.avro4k.encoding +import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAssertions import com.github.avrokotlin.avro4k.AvroEnumDefault import com.github.avrokotlin.avro4k.record +import com.github.avrokotlin.avro4k.schema import com.github.avrokotlin.avro4k.serializer.UUIDSerializer import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.UseSerializers +import org.apache.avro.generic.GenericData class EnumEncodingTest : StringSpec({ "read / write enums" { AvroAssertions.assertThat(EnumTest(Cream.Bruce, BBM.Moore)) - .isEncodedAs(record("Bruce", "Moore")) + .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Moore"))) + + AvroAssertions.assertThat(Cream.Bruce) + .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) + AvroAssertions.assertThat(CreamValueClass(Cream.Bruce)) + .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) } "read / write list of enums" { AvroAssertions.assertThat(EnumListTest(listOf(Cream.Bruce, Cream.Clapton))) - .isEncodedAs(record(listOf("Bruce", "Clapton"))) + .isEncodedAs(record(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton")))) + + AvroAssertions.assertThat(listOf(Cream.Bruce, Cream.Clapton)) + .isEncodedAs(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton"))) + AvroAssertions.assertThat(listOf(CreamValueClass(Cream.Bruce), CreamValueClass(Cream.Clapton))) + .isEncodedAs(listOf(GenericData.EnumSymbol(Avro.schema(), "Bruce"), GenericData.EnumSymbol(Avro.schema(), "Clapton"))) } "read / write nullable enums" { AvroAssertions.assertThat(NullableEnumTest(null)) .isEncodedAs(record(null)) AvroAssertions.assertThat(NullableEnumTest(Cream.Bruce)) - .isEncodedAs(record("Bruce")) + 
.isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "Bruce"))) + + AvroAssertions.assertThat(Cream.Bruce) + .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) + AvroAssertions.assertThat(null) + .isEncodedAs(null) + + AvroAssertions.assertThat(CreamValueClass(Cream.Bruce)) + .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "Bruce")) + AvroAssertions.assertThat(null) + .isEncodedAs(null) } "Decoding enum with an unknown uses @AvroEnumDefault value" { AvroAssertions.assertThat(EnumV2WrapperRecord(EnumV2.B)) - .isEncodedAs(record("B")) + .isEncodedAs(record(GenericData.EnumSymbol(Avro.schema(), "B"))) .isDecodedAs(EnumV1WrapperRecord(EnumV1.UNKNOWN)) + + AvroAssertions.assertThat(EnumV2.B) + .isEncodedAs(GenericData.EnumSymbol(Avro.schema(), "B")) + .isDecodedAs(EnumV1.UNKNOWN) } }) { @Serializable @SerialName("EnumWrapper") - data class EnumV1WrapperRecord( + private data class EnumV1WrapperRecord( val value: EnumV1, ) @Serializable @SerialName("EnumWrapper") - data class EnumV2WrapperRecord( + private data class EnumV2WrapperRecord( val value: EnumV2, ) @Serializable @SerialName("Enum") @AvroEnumDefault("UNKNOWN") - enum class EnumV1 { + private enum class EnumV1 { UNKNOWN, A, } @@ -59,28 +86,32 @@ class EnumEncodingTest : StringSpec({ @Serializable @SerialName("Enum") @AvroEnumDefault("UNKNOWN") - enum class EnumV2 { + private enum class EnumV2 { UNKNOWN, A, B, } @Serializable - data class EnumTest(val a: Cream, val b: BBM) + private data class EnumTest(val a: Cream, val b: BBM) + + @JvmInline + @Serializable + private value class CreamValueClass(val a: Cream) @Serializable - data class EnumListTest(val a: List) + private data class EnumListTest(val a: List) @Serializable - data class NullableEnumTest(val a: Cream?) + private data class NullableEnumTest(val a: Cream?) 
- enum class Cream { + private enum class Cream { Bruce, Baker, Clapton, } - enum class BBM { + private enum class BBM { Bruce, Baker, Moore, diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/InlineClassEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/InlineClassEncodingTest.kt deleted file mode 100644 index 8e9c7312..00000000 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/InlineClassEncodingTest.kt +++ /dev/null @@ -1,31 +0,0 @@ -package com.github.avrokotlin.avro4k.encoding - -import com.github.avrokotlin.avro4k.AvroAssertions -import com.github.avrokotlin.avro4k.record -import io.kotest.core.spec.style.StringSpec -import kotlinx.serialization.Serializable - -class InlineClassEncodingTest : StringSpec({ - "encode/decode @AvroInline" { - AvroAssertions.assertThat(Product("123", Name("sneakers"))) - .isEncodedAs(record("123", "sneakers")) - } - "encode/decode @AvroInline at root" { - AvroAssertions.assertThat(ValueClass(NestedValue("sneakers"))) - .isEncodedAs(record("sneakers")) - } -}) { - @Serializable - @JvmInline - private value class ValueClass(val value: NestedValue) - - @Serializable - private data class NestedValue(val field: String) - - @Serializable - @JvmInline - private value class Name(val value: String) - - @Serializable - private data class Product(val id: String, val name: Name) -} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/LogicalTypesEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/LogicalTypesEncodingTest.kt index fea4c160..f90d08c4 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/LogicalTypesEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/LogicalTypesEncodingTest.kt @@ -1,8 +1,10 @@ package com.github.avrokotlin.avro4k.encoding +import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAssertions import com.github.avrokotlin.avro4k.AvroFixed import 
com.github.avrokotlin.avro4k.record +import com.github.avrokotlin.avro4k.schema import com.github.avrokotlin.avro4k.serializer.BigDecimalAsStringSerializer import com.github.avrokotlin.avro4k.serializer.InstantToMicroSerializer import io.kotest.core.spec.style.StringSpec @@ -21,6 +23,19 @@ import java.time.LocalTime import java.util.UUID class LogicalTypesEncodingTest : StringSpec({ + "support logical types at root level" { + val schema = Avro.schema().fields[0].schema() + AvroAssertions.assertThat(BigDecimal("123.45")) + .isEncodedAs( + Conversions.DecimalConversion().toBytes( + BigDecimal("123.45"), + null, + org.apache.avro.LogicalTypes.decimal(8, 2) + ), + writerSchema = schema + ) + } + "should encode and decode logical types" { val logicalTypes = LogicalTypes( diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/MapSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/MapEncodingTest.kt similarity index 98% rename from src/test/kotlin/com/github/avrokotlin/avro4k/schema/MapSchemaTest.kt rename to src/test/kotlin/com/github/avrokotlin/avro4k/encoding/MapEncodingTest.kt index c68e2ac9..2d94d97b 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/MapSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/MapEncodingTest.kt @@ -1,4 +1,4 @@ -package com.github.avrokotlin.avro4k.schema +package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAssertions @@ -13,6 +13,7 @@ import com.github.avrokotlin.avro4k.WrappedInt import com.github.avrokotlin.avro4k.WrappedLong import com.github.avrokotlin.avro4k.WrappedShort import com.github.avrokotlin.avro4k.WrappedString +import com.github.avrokotlin.avro4k.internal.AvroSchemaGenerationException import com.github.avrokotlin.avro4k.record import com.github.avrokotlin.avro4k.schema import io.kotest.assertions.throwables.shouldThrow diff --git 
a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/NestedClassEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/NestedClassEncodingTest.kt index 5a9e7102..d71c5867 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/NestedClassEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/NestedClassEncodingTest.kt @@ -6,23 +6,6 @@ import com.github.avrokotlin.avro4k.record import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.Serializable -@Serializable -private data class County(val name: String, val towns: List, val ceremonial: Boolean, val lat: Double, val long: Double) - -@Serializable -private data class Town(val name: String, val population: Int) - -@Serializable -private data class Birthplace(val person: String, val town: Town) - -@Serializable -private data class PersonV2( - val name: String, - val hasChickenPoxVaccine: Boolean, - @AvroDefault("null") - val hasCovidVaccine: Boolean? = null, -) - class NestedClassEncodingTest : StringSpec({ "decode nested class" { AvroAssertions.assertThat(Birthplace(person = "Sammy Sam", town = Town(name = "Hardwick", population = 123))) @@ -64,4 +47,27 @@ class NestedClassEncodingTest : StringSpec({ record("Ryan", true, null) ) } -}) \ No newline at end of file +}) { + @Serializable + private data class County( + val name: String, + val towns: List, + val ceremonial: Boolean, + val lat: Double, + val long: Double, + ) + + @Serializable + private data class Town(val name: String, val population: Int) + + @Serializable + private data class Birthplace(val person: String, val town: Town) + + @Serializable + private data class PersonV2( + val name: String, + val hasChickenPoxVaccine: Boolean, + @AvroDefault("null") + val hasCovidVaccine: Boolean? 
= null, + ) +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt index 5d3bb047..0fd3f789 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/PrimitiveEncodingTest.kt @@ -1,6 +1,15 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.AvroAssertions +import com.github.avrokotlin.avro4k.WrappedBoolean +import com.github.avrokotlin.avro4k.WrappedByte +import com.github.avrokotlin.avro4k.WrappedChar +import com.github.avrokotlin.avro4k.WrappedDouble +import com.github.avrokotlin.avro4k.WrappedFloat +import com.github.avrokotlin.avro4k.WrappedInt +import com.github.avrokotlin.avro4k.WrappedLong +import com.github.avrokotlin.avro4k.WrappedShort +import com.github.avrokotlin.avro4k.WrappedString import com.github.avrokotlin.avro4k.record import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.Serializable @@ -12,50 +21,93 @@ class PrimitiveEncodingTest : StringSpec({ .isEncodedAs(record(true)) AvroAssertions.assertThat(BooleanTest(false)) .isEncodedAs(record(false)) + AvroAssertions.assertThat(true) + .isEncodedAs(true) + AvroAssertions.assertThat(false) + .isEncodedAs(false) + AvroAssertions.assertThat(WrappedBoolean(true)) + .isEncodedAs(true) + AvroAssertions.assertThat(WrappedBoolean(false)) + .isEncodedAs(false) } "read write out bytes" { AvroAssertions.assertThat(ByteTest(3)) .isEncodedAs(record(3)) + AvroAssertions.assertThat(3.toByte()) + .isEncodedAs(3) + AvroAssertions.assertThat(WrappedByte(3)) + .isEncodedAs(3) } "read write out shorts" { AvroAssertions.assertThat(ShortTest(3)) .isEncodedAs(record(3)) + AvroAssertions.assertThat(3.toShort()) + .isEncodedAs(3) + AvroAssertions.assertThat(WrappedShort(3)) + .isEncodedAs(3) } "read write out chars" { 
AvroAssertions.assertThat(CharTest('A')) .isEncodedAs(record('A'.code)) + AvroAssertions.assertThat('A') + .isEncodedAs('A'.code) + AvroAssertions.assertThat(WrappedChar('A')) + .isEncodedAs('A'.code) } "read write out strings" { AvroAssertions.assertThat(StringTest("Hello world")) .isEncodedAs(record("Hello world")) + AvroAssertions.assertThat("Hello world") + .isEncodedAs("Hello world") + AvroAssertions.assertThat(WrappedString("Hello world")) + .isEncodedAs("Hello world") } "read write out longs" { AvroAssertions.assertThat(LongTest(65653L)) .isEncodedAs(record(65653L)) + AvroAssertions.assertThat(65653L) + .isEncodedAs(65653L) + AvroAssertions.assertThat(WrappedLong(65653)) + .isEncodedAs(65653L) } "read write out ints" { AvroAssertions.assertThat(IntTest(44)) .isEncodedAs(record(44)) + AvroAssertions.assertThat(44) + .isEncodedAs(44) + AvroAssertions.assertThat(WrappedInt(44)) + .isEncodedAs(44) } "read write out doubles" { AvroAssertions.assertThat(DoubleTest(3.235)) .isEncodedAs(record(3.235)) + AvroAssertions.assertThat(3.235) + .isEncodedAs(3.235) + AvroAssertions.assertThat(WrappedDouble(3.235)) + .isEncodedAs(3.235) } "read write out floats" { AvroAssertions.assertThat(FloatTest(3.4F)) .isEncodedAs(record(3.4F)) + AvroAssertions.assertThat(3.4F) + .isEncodedAs(3.4F) + AvroAssertions.assertThat(WrappedFloat(3.4F)) + .isEncodedAs(3.4F) } + "read write out byte arrays" { AvroAssertions.assertThat(ByteArrayTest("ABC".toByteArray())) .isEncodedAs(record(ByteBuffer.wrap("ABC".toByteArray()))) + AvroAssertions.assertThat("ABC".toByteArray()) + .isEncodedAs(ByteBuffer.wrap("ABC".toByteArray())) } }) { @Serializable diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt index 51cc29a8..2c20c429 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt +++ 
b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/RecordEncodingTest.kt @@ -3,9 +3,11 @@ package com.github.avrokotlin.avro4k.encoding import com.github.avrokotlin.avro4k.AvroAssertions import com.github.avrokotlin.avro4k.AvroFixed import com.github.avrokotlin.avro4k.record +import io.kotest.assertions.throwables.shouldThrow import io.kotest.core.spec.style.StringSpec import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException import org.apache.avro.SchemaBuilder class RecordEncodingTest : StringSpec({ @@ -121,7 +123,25 @@ class RecordEncodingTest : StringSpec({ AvroAssertions.assertThat(ByteFoo(123)) .isEncodedAs(record(123)) } + "should not encode records with a different name" { + @Serializable + data class TheRecord(val v: Int) + shouldThrow { + val wrongRecordSchema = SchemaBuilder.record("AnotherRecord").fields().name("v").type().intType().noDefault().endRecord() + AvroAssertions.assertThat(TheRecord(1)) + .isEncodedAs(record(1), writerSchema = wrongRecordSchema) + } + } + "support objects" { + AvroAssertions.assertThat(ObjectClass) + .isEncodedAs(record()) + } }) { + @Serializable + object ObjectClass { + val field1 = "ignored" + } + @Serializable private data class StringFoo(val s: String) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/SealedClassEncodingTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/SealedClassEncodingTest.kt index bf6d8f2b..fa8b1227 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/SealedClassEncodingTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/encoding/SealedClassEncodingTest.kt @@ -12,12 +12,19 @@ class SealedClassEncodingTest : StringSpec({ "encode/decode sealed classes" { AvroAssertions.assertThat(ReferencingSealedClass(Operation.Binary.Add(1, 2))) .isEncodedAs(record(recordWithSchema(Avro.schema(), 1, 2))) + AvroAssertions.assertThat(Operation.Binary.Add(1, 2)) + 
.isEncodedAs(recordWithSchema(Avro.schema(), 1, 2)) } "encode/decode nullable sealed classes" { AvroAssertions.assertThat(ReferencingNullableSealedClass(Operation.Binary.Add(1, 2))) .isEncodedAs(record(recordWithSchema(Avro.schema(), 1, 2))) AvroAssertions.assertThat(ReferencingNullableSealedClass(null)) .isEncodedAs(record(null)) + + AvroAssertions.assertThat(Operation.Binary.Add(1, 2)) + .isEncodedAs(recordWithSchema(Avro.schema(), 1, 2)) + AvroAssertions.assertThat(null) + .isEncodedAs(null) } }) { @Serializable diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroFixedSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroFixedSchemaTest.kt deleted file mode 100644 index 7c3345e7..00000000 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroFixedSchemaTest.kt +++ /dev/null @@ -1,52 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import com.github.avrokotlin.avro4k.AvroAssertions -import com.github.avrokotlin.avro4k.AvroFixed -import io.kotest.core.spec.style.WordSpec -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlin.io.path.Path - -class AvroFixedSchemaTest : WordSpec({ - "@AvroFixed" should { - "generated fixed field schema when used on a field" { - AvroAssertions.assertThat() - .generatesSchema(Path("/fixed_string.json")) - } - - "generated fixed field schema when used on a value class' field" { - AvroAssertions.assertThat() - .generatesSchema(Path("/fixed_string.json")) - } - - "generated fixed field schema with @AvroFixed from class field instead of value class' field" { - AvroAssertions.assertThat() - .generatesSchema(Path("/fixed_string_5.json")) - } - } -}) { - @Serializable - @SerialName("Fixed") - private data class FixedStringField( - @AvroFixed(7) val mystring: String, - ) - - @Serializable - @SerialName("Fixed") - private data class FixedNestedStringField( - val mystring: FixedString, - ) - - @Serializable - @SerialName("Fixed") - private data class 
FieldPriorToValueClass( - @AvroFixed(5) val mystring: FixedString, - ) - - @JvmInline - @Serializable - @SerialName("FixedString") - private value class FixedString( - @AvroFixed(7) val mystring: String, - ) -} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/UnionSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/UnionSchemaTest.kt index 1285ad3a..ac082011 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/UnionSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/UnionSchemaTest.kt @@ -5,6 +5,7 @@ package com.github.avrokotlin.avro4k.schema import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAssertions import com.github.avrokotlin.avro4k.AvroDefault +import com.github.avrokotlin.avro4k.internal.AvroSchemaGenerationException import com.github.avrokotlin.avro4k.schema import io.kotest.assertions.throwables.shouldThrow import io.kotest.core.spec.style.StringSpec diff --git a/src/test/resources/class_of_list_of_maps.json b/src/test/resources/class_of_list_of_maps.json index ec509ca2..136b6c84 100644 --- a/src/test/resources/class_of_list_of_maps.json +++ b/src/test/resources/class_of_list_of_maps.json @@ -1,7 +1,7 @@ { "type": "record", "name": "List2Test", - "namespace": "com.github.avrokotlin.avro4k.schema.MapSchemaTest", + "namespace": "com.github.avrokotlin.avro4k.encoding.MapSchemaTest", "fields": [ { "name": "ship", diff --git a/src/test/resources/list_of_maps.json b/src/test/resources/list_of_maps.json index c9fae030..87101774 100644 --- a/src/test/resources/list_of_maps.json +++ b/src/test/resources/list_of_maps.json @@ -1,7 +1,7 @@ { "type": "record", "name": "ListTest", - "namespace": "com.github.avrokotlin.avro4k.schema.MapSchemaTest", + "namespace": "com.github.avrokotlin.avro4k.encoding.MapSchemaTest", "fields": [ { "name": "list", diff --git a/src/test/resources/map_boolean_null.json b/src/test/resources/map_boolean_null.json index 
d53c49d2..115c6341 100644 --- a/src/test/resources/map_boolean_null.json +++ b/src/test/resources/map_boolean_null.json @@ -1,7 +1,7 @@ { "type": "record", "name": "StringBooleanTest", - "namespace": "com.github.avrokotlin.avro4k.schema.MapSchemaTest", + "namespace": "com.github.avrokotlin.avro4k.encoding.MapSchemaTest", "fields": [ { "name": "map", diff --git a/src/test/resources/map_int.json b/src/test/resources/map_int.json index 03ccba4a..f0e80a0c 100644 --- a/src/test/resources/map_int.json +++ b/src/test/resources/map_int.json @@ -1,7 +1,7 @@ { "type": "record", "name": "StringIntTest", - "namespace": "com.github.avrokotlin.avro4k.schema.MapSchemaTest", + "namespace": "com.github.avrokotlin.avro4k.encoding.MapSchemaTest", "fields": [ { "name": "map", diff --git a/src/test/resources/map_record.json b/src/test/resources/map_record.json index 406c16c6..80957549 100644 --- a/src/test/resources/map_record.json +++ b/src/test/resources/map_record.json @@ -1,7 +1,7 @@ { "type": "record", "name": "StringNestedTest", - "namespace": "com.github.avrokotlin.avro4k.schema.MapSchemaTest", + "namespace": "com.github.avrokotlin.avro4k.encoding.MapSchemaTest", "fields": [ { "name": "map", diff --git a/src/test/resources/map_set_nested.json b/src/test/resources/map_set_nested.json index ecaebd43..1c2d6b44 100644 --- a/src/test/resources/map_set_nested.json +++ b/src/test/resources/map_set_nested.json @@ -1,7 +1,7 @@ { "type": "record", "name": "StringSetNestedTest", - "namespace": "com.github.avrokotlin.avro4k.schema.MapSchemaTest", + "namespace": "com.github.avrokotlin.avro4k.encoding.MapSchemaTest", "fields": [ { "name": "map", diff --git a/src/test/resources/set_of_maps.json b/src/test/resources/set_of_maps.json index 141788a9..83df9641 100644 --- a/src/test/resources/set_of_maps.json +++ b/src/test/resources/set_of_maps.json @@ -1,7 +1,7 @@ { "type": "record", "name": "SetTest", - "namespace": "com.github.avrokotlin.avro4k.schema.MapSchemaTest", + "namespace": 
"com.github.avrokotlin.avro4k.encoding.MapSchemaTest", "fields": [ { "name": "set",