Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide direct support for Map #142

Merged
merged 5 commits into from
Nov 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ libraryDependencies += "dev.zio" %% "zio-schema" % "<version>"
[link-discord]: https://discord.gg/2ccFBr4 "Discord"
[Stage]: https://img.shields.io/badge/Project%20Stage-Development-yellowgreen.svg
[Stage-Page]: https://github.com/zio/zio/wiki/Project-Stages

## Contributing

For the general guidelines, see ZIO [contributor's guide](https://github.com/zio/zio/blob/master/docs/about/contributing.md).

#### TL;DR

Before you submit a PR, make sure your tests are passing, and that the code is properly formatted

```
sbt prepare

sbt test
```
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ object JsonCodec extends Codec {
}

private[codec] def schemaEncoder[A](schema: Schema[A]): JsonEncoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Sequence(schema, _, g, _) => JsonEncoder.chunk(schemaEncoder(schema)).contramap(g)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Sequence(schema, _, g, _) => JsonEncoder.chunk(schemaEncoder(schema)).contramap(g)
case Schema.MapSchema(ks, vs, _) =>
JsonEncoder.chunk(schemaEncoder(ks).both(schemaEncoder(vs))).contramap(m => Chunk.fromIterable(m))
case Schema.Transform(c, _, g, _) => transformEncoder(c, g)
case Schema.Tuple(l, r, _) => JsonEncoder.tuple2(schemaEncoder(l), schemaEncoder(r))
case Schema.Optional(schema, _) => JsonEncoder.option(schemaEncoder(schema))
Expand Down Expand Up @@ -188,11 +190,13 @@ object JsonCodec extends Codec {
schemaDecoder(schema).decodeJson(json)

private[codec] def schemaDecoder[A](schema: Schema[A]): JsonDecoder[A] = schema match {
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Optional(codec, _) => JsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple(left, right, _) => JsonDecoder.tuple2(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _) => JsonDecoder.chunk(schemaDecoder(codec)).map(f)
case Schema.Primitive(standardType, _) => primitiveCodec(standardType)
case Schema.Optional(codec, _) => JsonDecoder.option(schemaDecoder(codec))
case Schema.Tuple(left, right, _) => JsonDecoder.tuple2(schemaDecoder(left), schemaDecoder(right))
case Schema.Transform(codec, f, _, _) => schemaDecoder(codec).mapOrFail(f)
case Schema.Sequence(codec, f, _, _) => JsonDecoder.chunk(schemaDecoder(codec)).map(f)
case Schema.MapSchema(ks, vs, _) =>
JsonDecoder.chunk(schemaDecoder(ks) <*> schemaDecoder(vs)).map(entries => entries.toList.toMap)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(structure, _) => recordDecoder(structure.toChunk)
case Schema.EitherSchema(left, right, _) => JsonDecoder.either(schemaDecoder(left), schemaDecoder(right))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ object JsonCodecSpec extends DefaultRunnableSpec {
)
}
),
suite("Map")(
testM("of complex keys and values") {

case class Key(name: String, index: Int)
case class Value(first: Int, second: Boolean)

assertEncodes(
Schema.map(
DeriveSchema.gen[Key],
DeriveSchema.gen[Value]
),
Map(Key("a", 0) -> Value(0, true), Key("b", 1) -> Value(1, false)),
JsonCodec.Encoder.charSequenceToByteChunk(
"""[[{"name":"a","index":0},{"first":0,"second":true}],[{"name":"b","index":1},{"first":1,"second":false}]]"""
)
)
}
),
suite("record")(
testM("of primitives") {
assertEncodes(
Expand Down Expand Up @@ -212,6 +230,21 @@ object JsonCodecSpec extends DefaultRunnableSpec {
case (schema, value) => assertEncodesThenDecodes(schema, value)
}
},
testM("Map of complex keys and values") {
case class Key(name: String, index: Int)
case class Value(first: Int, second: Boolean)

assertDecodes(
Schema.map(
DeriveSchema.gen[Key],
DeriveSchema.gen[Value]
),
Map(Key("a", 0) -> Value(0, true), Key("b", 1) -> Value(1, false)),
JsonCodec.Encoder.charSequenceToByteChunk(
"""[[{"name":"a","index":0},{"first":0,"second":true}],[{"name":"b","index":1},{"first":1,"second":false}]]"""
)
)
},
testM("of records") {
checkM(for {
(left, a) <- SchemaGen.anyRecordAndValue
Expand Down Expand Up @@ -464,6 +497,12 @@ object JsonCodecSpec extends DefaultRunnableSpec {
case _ => value
}

implicit def mapEncoder[K, V](
implicit keyEncoder: JsonEncoder[K],
valueEncoder: JsonEncoder[V]
): JsonEncoder[Map[K, V]] =
JsonEncoder.chunk(keyEncoder.both(valueEncoder)).contramap[Map[K, V]](m => Chunk.fromIterable(m))

private def jsonEncoded[A](value: A)(implicit enc: JsonEncoder[A]): Chunk[Byte] =
JsonCodec.Encoder.charSequenceToByteChunk(enc.encodeJson(value, None))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,21 @@ object ZioOpticsBuilder extends AccessorBuilder {
)

override def makeTraversal[S, A](
collection: Schema.Sequence[S, A],
collection: Schema.Collection[S, A],
element: Schema[A]
): Optic[S, S, Chunk[A], OpticFailure, OpticFailure, Chunk[A], S] =
ZTraversal(
ZioOpticsBuilder.makeTraversalGet(collection),
ZioOpticsBuilder.makeTraversalSet(collection)
)
collection match {
case seq @ Schema.Sequence(_, _, _, _) =>
ZTraversal(
ZioOpticsBuilder.makeSeqTraversalGet(seq),
ZioOpticsBuilder.makeSeqTraversalSet(seq)
)
case Schema.MapSchema(_, _, _) =>
ZTraversal(
ZioOpticsBuilder.makeMapTraversalGet,
ZioOpticsBuilder.makeMapTraversalSet
)
}

private[optics] def makeLensGet[S, A](
product: Schema.Record[S],
Expand Down Expand Up @@ -89,13 +97,13 @@ object ZioOpticsBuilder extends AccessorBuilder {
}
}

private[optics] def makeTraversalGet[S, A](
private[optics] def makeSeqTraversalGet[S, A](
collection: Schema.Sequence[S, A]
): S => Either[(OpticFailure, S), Chunk[A]] = { whole: S =>
Right(collection.toChunk(whole))
}

private[optics] def makeTraversalSet[S, A](
private[optics] def makeSeqTraversalSet[S, A](
collection: Schema.Sequence[S, A]
): Chunk[A] => S => Either[(OpticFailure, S), S] = { (piece: Chunk[A]) => (whole: S) =>
val builder = ChunkBuilder.make[A]()
Expand All @@ -111,6 +119,15 @@ object ZioOpticsBuilder extends AccessorBuilder {
Right(collection.fromChunk(builder.result()))
}

private[optics] def makeMapTraversalGet[K, V](whole: Map[K, V]): Either[(OpticFailure, Map[K, V]), Chunk[(K, V)]] =
Right(Chunk.fromIterable(whole))

private[optics] def makeMapTraversalSet[K, V]
: Chunk[(K, V)] => Map[K, V] => Either[(OpticFailure, Map[K, V]), Map[K, V]] = {
(piece: Chunk[(K, V)]) => (whole: Map[K, V]) =>
Right(whole ++ piece.toList)
}

private def spliceRecord(
fields: ListMap[String, DynamicValue],
label: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,21 @@ object ProtobufCodec extends Codec {
(schema, value) match {
case (Schema.GenericRecord(structure, _), v: Map[String, _]) => encodeRecord(fieldNumber, structure.toChunk, v)
case (Schema.Sequence(element, _, g, _), v) => encodeSequence(fieldNumber, element, g(v))
case (Schema.Transform(codec, _, g, _), _) => g(value).map(encode(fieldNumber, codec, _)).getOrElse(Chunk.empty)
case (Schema.Primitive(standardType, _), v) => encodePrimitive(fieldNumber, standardType, v)
case (Schema.Tuple(left, right, _), v @ (_, _)) => encodeTuple(fieldNumber, left, right, v)
case (Schema.Optional(codec, _), v: Option[_]) => encodeOptional(fieldNumber, codec, v)
case (Schema.EitherSchema(left, right, _), v: Either[_, _]) => encodeEither(fieldNumber, left, right, v)
case (lzy @ Schema.Lazy(_), v) => encode(fieldNumber, lzy.schema, v)
case (Schema.Meta(ast, _), _) => encode(fieldNumber, Schema[SchemaAst], ast)
case ProductEncoder(encode) => encode(fieldNumber)
case (Schema.Enum1(c, _), v) => encodeEnum(fieldNumber, v, c)
case (Schema.Enum2(c1, c2, _), v) => encodeEnum(fieldNumber, v, c1, c2)
case (Schema.Enum3(c1, c2, c3, _), v) => encodeEnum(fieldNumber, v, c1, c2, c3)
case (Schema.EnumN(cs, _), v) => encodeEnum(fieldNumber, v, cs.toSeq: _*)
case (_, _) => Chunk.empty
case (Schema.MapSchema(ks, vs, _), v: Map[k, v]) =>
encodeSequence(fieldNumber, ks <*> vs, Chunk.fromIterable(v))
case (Schema.Transform(codec, _, g, _), _) => g(value).map(encode(fieldNumber, codec, _)).getOrElse(Chunk.empty)
case (Schema.Primitive(standardType, _), v) => encodePrimitive(fieldNumber, standardType, v)
case (Schema.Tuple(left, right, _), v @ (_, _)) => encodeTuple(fieldNumber, left, right, v)
case (Schema.Optional(codec, _), v: Option[_]) => encodeOptional(fieldNumber, codec, v)
case (Schema.EitherSchema(left, right, _), v: Either[_, _]) => encodeEither(fieldNumber, left, right, v)
case (lzy @ Schema.Lazy(_), v) => encode(fieldNumber, lzy.schema, v)
case (Schema.Meta(ast, _), _) => encode(fieldNumber, Schema[SchemaAst], ast)
case ProductEncoder(encode) => encode(fieldNumber)
case (Schema.Enum1(c, _), v) => encodeEnum(fieldNumber, v, c)
case (Schema.Enum2(c1, c2, _), v) => encodeEnum(fieldNumber, v, c1, c2)
case (Schema.Enum3(c1, c2, c3, _), v) => encodeEnum(fieldNumber, v, c1, c2, c3)
case (Schema.EnumN(cs, _), v) => encodeEnum(fieldNumber, v, cs.toSeq: _*)
case (_, _) => Chunk.empty
}

private def encodeEnum[Z](fieldNumber: Option[Int], value: Z, cases: Schema.Case[_, Z]*): Chunk[Byte] = {
Expand Down Expand Up @@ -443,6 +445,15 @@ object ProtobufCodec extends Codec {
},
true
)
case Schema.MapSchema(ks: Schema[k], vs: Schema[v], _) =>
decoder(
Schema.Sequence(
ks <*> vs,
(c: Chunk[(k, v)]) => c.toList.toMap,
(m: Map[k, v]) => Chunk.fromIterable(m),
Chunk.empty
)
)
case Schema.Transform(codec, f, _, _) => transformDecoder(codec, f)
case Schema.Primitive(standardType, _) => primitiveDecoder(standardType)
case Schema.Tuple(left, right, _) => tupleDecoder(left, right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,18 @@ object ProtobufCodecSpec extends DefaultRunnableSpec {
ed2 <- encodeAndDecodeNS(sequenceOfSumSchema, richSequence)
} yield assert(ed)(equalTo(Chunk(richSequence))) && assert(ed2)(equalTo(richSequence))
},
testM("map of products") {
val m: Map[Record, MyRecord] = Map(
Record("AAA", 1) -> MyRecord(1),
Record("BBB", 2) -> MyRecord(2)
)

val mSchema = Schema.map(Record.schemaRecord, myRecord)
for {
ed <- encodeAndDecode(mSchema, m)
ed2 <- encodeAndDecodeNS(mSchema, m)
} yield assert(ed)(equalTo(Chunk.succeed(m))) && assert(ed2)(equalTo(m))
},
testM("recursive data types") {
checkM(SchemaGen.anyRecursiveTypeAndValue) {
case (schema, value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ trait AccessorBuilder {

def makePrism[S, A](sum: Schema.Enum[S], term: Schema.Case[A, S]): Prism[S, A]

def makeTraversal[S, A](collection: Schema.Sequence[S, A], element: Schema[A]): Traversal[S, A]
def makeTraversal[S, A](collection: Schema.Collection[S, A], element: Schema[A]): Traversal[S, A]
}
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,12 @@ object DynamicValue {
case Schema.Sequence(schema, _, toChunk, _) =>
DynamicValue.Sequence(toChunk(value).map(fromSchemaAndValue(schema, _)))

case Schema.MapSchema(ks: Schema[k], vs: Schema[v], _) =>
val entries = value.asInstanceOf[Map[k, v]].map {
case (key, value) => (fromSchemaAndValue(ks, key), fromSchemaAndValue(vs, value))
}
DynamicValue.Dictionary(Chunk.fromIterable(entries))

case Schema.EitherSchema(left, right, _) =>
value match {
case Left(a) => DynamicValue.LeftValue(fromSchemaAndValue(left, a))
Expand Down Expand Up @@ -1839,6 +1845,8 @@ object DynamicValue {

final case class Sequence(values: Chunk[DynamicValue]) extends DynamicValue

final case class Dictionary[K, V](entries: Chunk[(DynamicValue, DynamicValue)]) extends DynamicValue

sealed case class Primitive[A](value: A, standardType: StandardType[A]) extends DynamicValue

sealed case class Singleton[A](instance: A) extends DynamicValue
Expand Down
17 changes: 16 additions & 1 deletion zio-schema/shared/src/main/scala/zio/schema/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ object Schema extends TupleSchemas with RecordSchemas with EnumSchemas {
implicit def chunk[A](implicit schemaA: Schema[A]): Schema[Chunk[A]] =
Schema.Sequence[Chunk[A], A](schemaA, identity, identity, Chunk.empty)

implicit def map[K, V](implicit ks: Schema[K], vs: Schema[V]): Schema[Map[K, V]] =
Schema.MapSchema(ks, vs, Chunk.empty)

implicit def either[A, B](implicit left: Schema[A], right: Schema[B]): Schema[Either[A, B]] =
EitherSchema(left, right)

Expand Down Expand Up @@ -239,12 +242,14 @@ object Schema extends TupleSchemas with RecordSchemas with EnumSchemas {
def rawConstruct(values: Chunk[Any]): Either[String, R]
}

sealed trait Collection[Col, Elem] extends Schema[Col]

final case class Sequence[Col, Elem](
schemaA: Schema[Elem],
fromChunk: Chunk[Elem] => Col,
toChunk: Col => Chunk[Elem],
override val annotations: Chunk[Any]
) extends Schema[Col] { self =>
) extends Collection[Col, Elem] { self =>
override type Accessors[Lens[_, _], Prism[_, _], Traversal[_, _]] = Traversal[Col, Elem]

override def annotate(annotation: Any): Sequence[Col, Elem] = copy(annotations = annotations :+ annotation)
Expand Down Expand Up @@ -376,6 +381,16 @@ object Schema extends TupleSchemas with RecordSchemas with EnumSchemas {
override def makeAccessors(b: AccessorBuilder): Unit = ()

}

final case class MapSchema[K, V](ks: Schema[K], vs: Schema[V], override val annotations: Chunk[Any])
extends Collection[Map[K, V], (K, V)] { self =>
override type Accessors[Lens[_, _], Prism[_, _], Traversal[_, _]] = Traversal[Map[K, V], (K, V)]

override def annotate(annotation: Any): MapSchema[K, V] = copy(annotations = annotations :+ annotation)

override def makeAccessors(b: AccessorBuilder): b.Traversal[Map[K, V], (K, V)] =
b.makeTraversal(self, ks <*> vs)
}
}

//scalafmt: { maxColumn = 400 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ object SchemaAst {
.buildProduct()
case Schema.Sequence(schema, _, _, _) =>
subtree(NodePath.root, Chunk.empty, schema, dimensions = 1)
case Schema.MapSchema(ks, vs, _) =>
NodeBuilder(NodePath.root, Chunk.empty, optional = false, dimensions = 1)
.addLabelledSubtree("key", ks)
.addLabelledSubtree("value", vs)
.buildProduct()
case Schema.Transform(schema, _, _, _) => subtree(NodePath.root, Chunk.empty, schema)
case lzy @ Schema.Lazy(_) => fromSchema(lzy.schema)
case s: Schema.Record[A] =>
Expand Down Expand Up @@ -235,6 +240,8 @@ object SchemaAst {
.buildProduct()
case Schema.Sequence(schema, _, _, _) =>
subtree(path, lineage, schema, optional, dimensions + 1)
case Schema.MapSchema(ks, vs, _) =>
subtree(path, lineage, ks <*> vs, optional = false, dimensions + 1)
case Schema.Transform(schema, _, _, _) => subtree(path, lineage, schema, optional, dimensions)
case lzy @ Schema.Lazy(_) => subtree(path, lineage, lzy.schema, optional, dimensions)
case s: Schema.Record[_] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ object DynamicValueGen {
case Schema.Enum22(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22, _) => anyDynamicValueOfEnum(Chunk(case1, case2, case3, case4, case5, case6, case7, case8, case9, case10, case11, case12, case13, case14, case15, case16, case17, case18, case19, case20, case21, case22))
case Schema.EnumN(cases, _) => anyDynamicValueOfEnum(Chunk.fromIterable(cases.toSeq))
case Schema.Sequence(schema, _, _, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(schema)).map(DynamicValue.Sequence(_))
case Schema.MapSchema(ks, vs, _) => Gen.chunkOfBounded(0, 2)(anyDynamicValueOfSchema(ks).zip(anyDynamicValueOfSchema(vs))).map(DynamicValue.Dictionary(_))
case Schema.Optional(schema, _) => Gen.oneOf(anyDynamicSomeValueOfSchema(schema), Gen.const(DynamicValue.NoneValue))
case Schema.Tuple(left, right, _) => anyDynamicTupleValue(left, right)
case Schema.EitherSchema(left, right, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestAccessorBuilder extends AccessorBuilder {
override def makePrism[S, A](sum: Schema.Enum[S], term: schema.Schema.Case[A, S]): Prism[S, A] =
TestAccessorBuilder.Prism(sum, term)

override def makeTraversal[S, A](collection: Schema.Sequence[S, A], element: Schema[A]): Traversal[S, A] =
override def makeTraversal[S, A](collection: Schema.Collection[S, A], element: Schema[A]): Traversal[S, A] =
TestAccessorBuilder.Traversal(collection, element)
}

Expand All @@ -22,5 +22,5 @@ object TestAccessorBuilder {

case class Prism[S, A](s: Schema.Enum[S], a: Schema.Case[A, S])

case class Traversal[S, A](s: Schema.Sequence[S, A], a: Schema[A])
case class Traversal[S, A](s: Schema.Collection[S, A], a: Schema[A])
}