Skip to content

Commit

Permalink
Provide direct support for Map (zio#142)
Browse files Browse the repository at this point in the history
* Provide direct support for Map
Unit tests for json serdes
Unit tests for protobuf serdes

* run scalafmt

* Rebased/reformatted

* Rebased/reformatted, and added a minimum contribution guide

* Run sbt `prepare test` instead of `fmt test`
  • Loading branch information
pierangeloc authored and landlockedsurfer committed May 28, 2022
1 parent 1a14c1a commit 1487338
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 31 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ libraryDependencies += "dev.zio" %% "zio-schema" % "<version>"
[link-discord]: https://discord.gg/2ccFBr4 "Discord"
[Stage]: https://img.shields.io/badge/Project%20Stage-Development-yellowgreen.svg
[Stage-Page]: https://github.com/zio/zio/wiki/Project-Stages

## Contributing

For the general guidelines, see ZIO [contributor's guide](https://github.com/zio/zio/blob/master/docs/about/contributing.md).

#### TL;DR

Before you submit a PR, make sure your tests are passing, and that the code is properly formatted

```
sbt prepare
sbt test
```
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ object JsonCodec extends Codec {
}

private[codec] def schemaEncoder[A](schema: Schema[A]): JsonEncoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Sequence(schema, _, g, _) => JsonEncoder.chunk(schemaEncoder(schema)).contramap(g)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Sequence(schema, _, g, _) => JsonEncoder.chunk(schemaEncoder(schema)).contramap(g)
case Schema.MapSchema(ks, vs, _) =>
JsonEncoder.chunk(schemaEncoder(ks).both(schemaEncoder(vs))).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, _) => transformEncoder(c, g)
case Schema.Tuple(l, r, _) => JsonEncoder.tuple2(schemaEncoder(l), schemaEncoder(r))
case Schema.Optional(schema, _) => JsonEncoder.option(schemaEncoder(schema))
Expand Down Expand Up @@ -188,11 +190,13 @@ object JsonCodec extends Codec {
schemaDecoder(schema).decodeJson(json)

private[codec] def schemaDecoder[A](schema: Schema[A]): JsonDecoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Optional(codec, _) => JsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple(left, right, _) => JsonDecoder.tuple2(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _) => JsonDecoder.chunk(schemaDecoder(codec)).map(f)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Optional(codec, _) => JsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple(left, right, _) => JsonDecoder.tuple2(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _) => JsonDecoder.chunk(schemaDecoder(codec)).map(f)
case Schema.MapSchema(ks, vs, _) =>
JsonDecoder.chunk(schemaDecoder(ks) <*> schemaDecoder(vs)).map(entries => entries.toList.toMap)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(structure, _) => recordDecoder(structure.toChunk)
case Schema.EitherSchema(left, right, _) => JsonDecoder.either(schemaDecoder(left), schemaDecoder(right))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ object JsonCodecSpec extends DefaultRunnableSpec {
)
}
),
suite("Map")(
testM("of complex keys and values") {

case class Key(name: String, index: Int)
case class Value(first: Int, second: Boolean)

assertEncodes(
Schema.map(
DeriveSchema.gen[Key],
DeriveSchema.gen[Value]
),
Map(Key("a", 0) -> Value(0, true), Key("b", 1) -> Value(1, false)),
JsonCodec.Encoder.charSequenceToByteChunk(
"""[[{"name":"a","index":0},{"first":0,"second":true}],[{"name":"b","index":1},{"first":1,"second":false}]]"""
)
)
}
),
suite("record")(
testM("of primitives") {
assertEncodes(
Expand Down Expand Up @@ -212,6 +230,21 @@ object JsonCodecSpec extends DefaultRunnableSpec {
case (schema, value) => assertEncodesThenDecodes(schema, value)
}
},
testM("Map of complex keys and values") {
case class Key(name: String, index: Int)
case class Value(first: Int, second: Boolean)

assertDecodes(
Schema.map(
DeriveSchema.gen[Key],
DeriveSchema.gen[Value]
),
Map(Key("a", 0) -> Value(0, true), Key("b", 1) -> Value(1, false)),
JsonCodec.Encoder.charSequenceToByteChunk(
"""[[{"name":"a","index":0},{"first":0,"second":true}],[{"name":"b","index":1},{"first":1,"second":false}]]"""
)
)
},
testM("of records") {
checkM(for {
(left, a) <- SchemaGen.anyRecordAndValue
Expand Down Expand Up @@ -464,6 +497,12 @@ object JsonCodecSpec extends DefaultRunnableSpec {
case _ => value
}

implicit def mapEncoder[K, V](
implicit keyEncoder: JsonEncoder[K],
valueEncoder: JsonEncoder[V]
): JsonEncoder[Map[K, V]] =
JsonEncoder.chunk(keyEncoder.both(valueEncoder)).contramap[Map[K, V]](m => Chunk.fromIterable(m))

private def jsonEncoded[A](value: A)(implicit enc: JsonEncoder[A]): Chunk[Byte] =
JsonCodec.Encoder.charSequenceToByteChunk(enc.encodeJson(value, None))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,21 @@ object ZioOpticsBuilder extends AccessorBuilder {
)

override def makeTraversal[S, A](
collection: Schema.Sequence[S, A],
collection: Schema.Collection[S, A],
element: Schema[A]
): Optic[S, S, Chunk[A], OpticFailure, OpticFailure, Chunk[A], S] =
ZTraversal(
ZioOpticsBuilder.makeTraversalGet(collection),
ZioOpticsBuilder.makeTraversalSet(collection)
)
collection match {
case seq @ Schema.Sequence(_, _, _, _) =>
ZTraversal(
ZioOpticsBuilder.makeSeqTraversalGet(seq),
ZioOpticsBuilder.makeSeqTraversalSet(seq)
)
case Schema.MapSchema(_, _, _) =>
ZTraversal(
ZioOpticsBuilder.makeMapTraversalGet,
ZioOpticsBuilder.makeMapTraversalSet
)
}

private[optics] def makeLensGet[S, A](
product: Schema.Record[S],
Expand Down Expand Up @@ -89,13 +97,13 @@ object ZioOpticsBuilder extends AccessorBuilder {
}
}

private[optics] def makeTraversalGet[S, A](
private[optics] def makeSeqTraversalGet[S, A](
collection: Schema.Sequence[S, A]
): S => Either[(OpticFailure, S), Chunk[A]] = { whole: S =>
Right(collection.toChunk(whole))
}

private[optics] def makeTraversalSet[S, A](
private[optics] def makeSeqTraversalSet[S, A](
collection: Schema.Sequence[S, A]
): Chunk[A] => S => Either[(OpticFailure, S), S] = { (piece: Chunk[A]) => (whole: S) =>
val builder = ChunkBuilder.make[A]()
Expand All @@ -111,6 +119,15 @@ object ZioOpticsBuilder extends AccessorBuilder {
Right(collection.fromChunk(builder.result()))
}

private[optics] def makeMapTraversalGet[K, V](whole: Map[K, V]): Either[(OpticFailure, Map[K, V]), Chunk[(K, V)]] =
Right(Chunk.fromIterable(whole))

private[optics] def makeMapTraversalSet[K, V]
: Chunk[(K, V)] => Map[K, V] => Either[(OpticFailure, Map[K, V]), Map[K, V]] = {
(piece: Chunk[(K, V)]) => (whole: Map[K, V]) =>
Right(whole ++ piece.toList)
}

private def spliceRecord(
fields: ListMap[String, DynamicValue],
label: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,21 @@ object ProtobufCodec extends Codec {
(schema, value) match {
case (Schema.GenericRecord(structure, _), v: Map[String, _]) => encodeRecord(fieldNumber, structure.toChunk, v)
case (Schema.Sequence(element, _, g, _), v) => encodeSequence(fieldNumber, element, g(v))
case (Schema.Transform(codec, _, g, _), _) => g(value).map(encode(fieldNumber, codec, _)).getOrElse(Chunk.empty)
case (Schema.Primitive(standardType, _), v) => encodePrimitive(fieldNumber, standardType, v)
case (Schema.Tuple(left, right, _), v @ (_, _)) => encodeTuple(fieldNumber, left, right, v)
case (Schema.Optional(codec, _), v: Option[_]) => encodeOptional(fieldNumber, codec, v)
case (Schema.EitherSchema(left, right, _), v: Either[_, _]) => encodeEither(fieldNumber, left, right, v)
case (lzy @ Schema.Lazy(_), v) => encode(fieldNumber, lzy.schema, v)
case (Schema.Meta(ast, _), _) => encode(fieldNumber, Schema[SchemaAst], ast)
case ProductEncoder(encode) => encode(fieldNumber)
case (Schema.Enum1(c, _), v) => encodeEnum(fieldNumber, v, c)
case (Schema.Enum2(c1, c2, _), v) => encodeEnum(fieldNumber, v, c1, c2)
case (Schema.Enum3(c1, c2, c3, _), v) => encodeEnum(fieldNumber, v, c1, c2, c3)
case (Schema.EnumN(cs, _), v) => encodeEnum(fieldNumber, v, cs.toSeq: _*)
case (_, _) => Chunk.empty
case (Schema.MapSchema(ks, vs, _), v: Map[k, v]) =>
encodeSequence(fieldNumber, ks <*> vs, Chunk.fromIterable(v))
case (Schema.Transform(codec, _, g, _), _) => g(value).map(encode(fieldNumber, codec, _)).getOrElse(Chunk.empty)
case (Schema.Primitive(standardType, _), v) => encodePrimitive(fieldNumber, standardType, v)
case (Schema.Tuple(left, right, _), v @ (_, _)) => encodeTuple(fieldNumber, left, right, v)
case (Schema.Optional(codec, _), v: Option[_]) => encodeOptional(fieldNumber, codec, v)
case (Schema.EitherSchema(left, right, _), v: Either[_, _]) => encodeEither(fieldNumber, left, right, v)
case (lzy @ Schema.Lazy(_), v) => encode(fieldNumber, lzy.schema, v)
case (Schema.Meta(ast, _), _) => encode(fieldNumber, Schema[SchemaAst], ast)
case ProductEncoder(encode) => encode(fieldNumber)
case (Schema.Enum1(c, _), v) => encodeEnum(fieldNumber, v, c)
case (Schema.Enum2(c1, c2, _), v) => encodeEnum(fieldNumber, v, c1, c2)
case (Schema.Enum3(c1, c2, c3, _), v) => encodeEnum(fieldNumber, v, c1, c2, c3)
case (Schema.EnumN(cs, _), v) => encodeEnum(fieldNumber, v, cs.toSeq: _*)
case (_, _) => Chunk.empty
}

private def encodeEnum[Z](fieldNumber: Option[Int], value: Z, cases: Schema.Case[_, Z]*): Chunk[Byte] = {
Expand Down Expand Up @@ -443,6 +445,15 @@ object ProtobufCodec extends Codec {
},
true
)
case Schema.MapSchema(ks: Schema[k], vs: Schema[v], _) =>
decoder(
Schema.Sequence(
ks <*> vs,
(c: Chunk[(k, v)]) => c.toList.toMap,
(m: Map[k, v]) => Chunk.fromIterable(m),
Chunk.empty
)
)
case Schema.Transform(codec, f, _, _) => transformDecoder(codec, f)
case Schema.Primitive(standardType, _) => primitiveDecoder(standardType)
case Schema.Tuple(left, right, _) => tupleDecoder(left, right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,18 @@ object ProtobufCodecSpec extends DefaultRunnableSpec {
ed2 <- encodeAndDecodeNS(sequenceOfSumSchema, richSequence)
} yield assert(ed)(equalTo(Chunk(richSequence))) && assert(ed2)(equalTo(richSequence))
},
testM("map of products") {
val m: Map[Record, MyRecord] = Map(
Record("AAA", 1) -> MyRecord(1),
Record("BBB", 2) -> MyRecord(2)
)

val mSchema = Schema.map(Record.schemaRecord, myRecord)
for {
ed <- encodeAndDecode(mSchema, m)
ed2 <- encodeAndDecodeNS(mSchema, m)
} yield assert(ed)(equalTo(Chunk.succeed(m))) && assert(ed2)(equalTo(m))
},
testM("recursive data types") {
checkM(SchemaGen.anyRecursiveTypeAndValue) {
case (schema, value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ trait AccessorBuilder {

def makePrism[S, A](sum: Schema.Enum[S], term: Schema.Case[A, S]): Prism[S, A]

def makeTraversal[S, A](collection: Schema.Sequence[S, A], element: Schema[A]): Traversal[S, A]
def makeTraversal[S, A](collection: Schema.Collection[S, A], element: Schema[A]): Traversal[S, A]
}
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,12 @@ object DynamicValue {
case Schema.Sequence(schema, _, toChunk, _) =>
DynamicValue.Sequence(toChunk(value).map(fromSchemaAndValue(schema, _)))

case Schema.MapSchema(ks: Schema[k], vs: Schema[v], _) =>
val entries = value.asInstanceOf[Map[k, v]].map {
case (key, value) => (fromSchemaAndValue(ks, key), fromSchemaAndValue(vs, value))
}
DynamicValue.Dictionary(Chunk.fromIterable(entries))

case Schema.EitherSchema(left, right, _) =>
value match {
case Left(a) => DynamicValue.LeftValue(fromSchemaAndValue(left, a))
Expand Down Expand Up @@ -1839,6 +1845,8 @@ object DynamicValue {

final case class Sequence(values: Chunk[DynamicValue]) extends DynamicValue

final case class Dictionary[K, V](entries: Chunk[(DynamicValue, DynamicValue)]) extends DynamicValue

sealed case class Primitive[A](value: A, standardType: StandardType[A]) extends DynamicValue

sealed case class Singleton[A](instance: A) extends DynamicValue
Expand Down
17 changes: 16 additions & 1 deletion zio-schema/shared/src/main/scala/zio/schema/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ object Schema extends TupleSchemas with RecordSchemas with EnumSchemas {
implicit def chunk[A](implicit schemaA: Schema[A]): Schema[Chunk[A]] =
Schema.Sequence[Chunk[A], A](schemaA, identity, identity, Chunk.empty)

implicit def map[K, V](implicit ks: Schema[K], vs: Schema[V]): Schema[Map[K, V]] =
Schema.MapSchema(ks, vs, Chunk.empty)

implicit def either[A, B](implicit left: Schema[A], right: Schema[B]): Schema[Either[A, B]] =
EitherSchema(left, right)

Expand Down Expand Up @@ -239,12 +242,14 @@ object Schema extends TupleSchemas with RecordSchemas with EnumSchemas {
def rawConstruct(values: Chunk[Any]): Either[String, R]
}

sealed trait Collection[Col, Elem] extends Schema[Col]

final case class Sequence[Col, Elem](
schemaA: Schema[Elem],
fromChunk: Chunk[Elem] => Col,
toChunk: Col => Chunk[Elem],
override val annotations: Chunk[Any]
) extends Schema[Col] { self =>
) extends Collection[Col, Elem] { self =>
override type Accessors[Lens[_, _], Prism[_, _], Traversal[_, _]] = Traversal[Col, Elem]

override def annotate(annotation: Any): Sequence[Col, Elem] = copy(annotations = annotations :+ annotation)
Expand Down Expand Up @@ -376,6 +381,16 @@ object Schema extends TupleSchemas with RecordSchemas with EnumSchemas {
override def makeAccessors(b: AccessorBuilder): Unit = ()

}

final case class MapSchema[K, V](ks: Schema[K], vs: Schema[V], override val annotations: Chunk[Any])
extends Collection[Map[K, V], (K, V)] { self =>
override type Accessors[Lens[_, _], Prism[_, _], Traversal[_, _]] = Traversal[Map[K, V], (K, V)]

override def annotate(annotation: Any): MapSchema[K, V] = copy(annotations = annotations :+ annotation)

override def makeAccessors(b: AccessorBuilder): b.Traversal[Map[K, V], (K, V)] =
b.makeTraversal(self, ks <*> vs)
}
}

//scalafmt: { maxColumn = 400 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ object SchemaAst {
.buildProduct()
case Schema.Sequence(schema, _, _, _) =>
subtree(NodePath.root, Chunk.empty, schema, dimensions = 1)
case Schema.MapSchema(ks, vs, _) =>
NodeBuilder(NodePath.root, Chunk.empty, optional = false, dimensions = 1)
.addLabelledSubtree("key", ks)
.addLabelledSubtree("value", vs)
.buildProduct()
case Schema.Transform(schema, _, _, _) => subtree(NodePath.root, Chunk.empty, schema)
case lzy @ Schema.Lazy(_) => fromSchema(lzy.schema)
case s: Schema.Record[A] =>
Expand Down Expand Up @@ -235,6 +240,8 @@ object SchemaAst {
.buildProduct()
case Schema.Sequence(schema, _, _, _) =>
subtree(path, lineage, schema, optional, dimensions + 1)
case Schema.MapSchema(ks, vs, _) =>
subtree(path, lineage, ks <*> vs, optional = false, dimensions + 1)
case Schema.Transform(schema, _, _, _) => subtree(path, lineage, schema, optional, dimensions)
case lzy @ Schema.Lazy(_) => subtree(path, lineage, lzy.schema, optional, dimensions)
case s: Schema.Record[_] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ object DynamicValueGen {
case Schema.Enum22(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22, _) => anyDynamicValueOfEnum(Chunk(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22))
case Schema.EnumN(cases, _) => anyDynamicValueOfEnum(Chunk.fromIterable(cases.toSeq))
case Schema.Sequence(schema, _, _, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.MapSchema(ks, vs, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.Optional(schema, _) => Gen.oneOf(anyDynamicSomeValueOfSchema(schema), Gen.const(DynamicValue.NoneValue))
case Schema.Tuple(left, right, _) => anyDynamicTupleValue(left, right)
case Schema.EitherSchema(left, right, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestAccessorBuilder extends AccessorBuilder {
override def makePrism[S, A](sum: Schema.Enum[S], term: schema.Schema.Case[A, S]): Prism[S, A] =
TestAccessorBuilder.Prism(sum, term)

override def makeTraversal[S, A](collection: Schema.Sequence[S, A], element: Schema[A]): Traversal[S, A] =
override def makeTraversal[S, A](collection: Schema.Collection[S, A], element: Schema[A]): Traversal[S, A] =
TestAccessorBuilder.Traversal(collection, element)
}

Expand All @@ -22,5 +22,5 @@ object TestAccessorBuilder {

case class Prism[S, A](s: Schema.Enum[S], a: Schema.Case[A, S])

case class Traversal[S, A](s: Schema.Sequence[S, A], a: Schema[A])
case class Traversal[S, A](s: Schema.Collection[S, A], a: Schema[A])
}

0 comments on commit 1487338

Please sign in to comment.