Skip to content
This repository has been archived by the owner on Sep 17, 2022. It is now read-only.

refactored pattern matching and added constructor case #53

Merged
merged 1 commit into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ import zio.morphir.ir.Pattern
import zio.morphir.ir.NativeFunction.*

import zio.Chunk
import zio.prelude._

import scala.collection.immutable.ListMap
import zio.morphir.ir.ValueModule.Value
import zio.morphir.ir.TypeModule.Specification.TypeAliasSpecification
import zio.morphir.IRModule
object Interpreter {

final case class Variables(map: Map[Name, Result])

sealed trait Result

object Result {
Expand Down Expand Up @@ -151,16 +150,16 @@ object Interpreter {
var newVariables: Map[Name, Result] = Map.empty
while (i < length) {
matchPattern(evaluatedBody, casesChunk(i)._1) match {
case MatchResult.Success(variables) =>
case Right(variables) =>
rightHandSide = casesChunk(i)._2
newVariables = variables.map { case (key, value) => key -> Result.Strict(value) }
i = length
case MatchResult.Failure(_, _) =>
case Left(_) =>
i += 1
}
}

if (rightHandSide eq null) throw new InterpretationError.MatchError("didn't match")
if (rightHandSide eq null) throw new InterpretationError.MatchError(s"could not match $evaluatedBody")
else loop(rightHandSide, variables ++ newVariables, references)

case RecordCase(fields) =>
Expand Down Expand Up @@ -235,13 +234,13 @@ object Interpreter {
case LambdaCase(argumentPattern, body) =>
(input: Any) =>
matchPattern(input, argumentPattern) match {
case MatchResult.Success(newVariables) =>
case Right(newVariables) =>
loop(
body,
variables ++ newVariables.map { case (key, value) => key -> Result.Strict(value) },
references
)
case MatchResult.Failure(pattern, input) =>
case Left(MatchFailure(pattern, input)) =>
throw new InterpretationError.MatchError(
s"Pattern $pattern didn't match input $input"
)
Expand All @@ -250,13 +249,13 @@ object Interpreter {
case DestructureCase(pattern, valueToDestruct, inValue) =>
val evaluatedValueToDestruct = loop(valueToDestruct, variables, references)
matchPattern(evaluatedValueToDestruct, pattern) match {
case MatchResult.Success(newVariables) =>
case Right(newVariables) =>
loop(
inValue,
variables ++ newVariables.map { case (key, value) => key -> Result.Strict(value) },
references
)
case MatchResult.Failure(pattern, input) =>
case Left(MatchFailure(pattern, input)) =>
throw new InterpretationError.MatchError(
s"Pattern $pattern didn't match input $input"
)
Expand All @@ -271,63 +270,70 @@ object Interpreter {
}
}

sealed trait MatchResult
object MatchResult {
final case class Failure(body: Any, caseStatement: Pattern[Any]) extends MatchResult
final case class Success(variables: Map[Name, Any]) extends MatchResult
}
case class MatchFailure(body: Any, caseStatement: Pattern[Any])
type Variables = Map[Name, Any]
type MatchResult = Either[MatchFailure, Variables]

def matchPattern(body: Any, caseStatement: Pattern[Any]): MatchResult = {
val unitValue: Unit = ()
val err = MatchResult.Failure(body, caseStatement)
import Pattern._
val noMatch = Left(MatchFailure(body, caseStatement))
val empty: Variables = Map.empty
def helper(bodies: Chunk[Any], caseStatements: Chunk[Pattern[Any]]): MatchResult = {
if (bodies.length != caseStatements.length) noMatch
else
bodies
.zip(caseStatements)
.forEach((body, caseStatement) => matchPattern(body, caseStatement))
.map(_.foldLeft(empty)(_ ++ _))
}
println(s"Attempting to match $body vs $caseStatement")
caseStatement match {
case Pattern.AsPattern(pattern, name, _) =>
matchPattern(body, pattern) match {
case MatchResult.Success(_) =>
MatchResult.Success(Map.empty + (name -> body))
case x: MatchResult.Failure => x
case WildcardPattern(_) => Right(empty)
case AsPattern(pattern, name, _) =>
val result = matchPattern(body, pattern)
result match {
case Right(vars) => println("As matched properly"); Right(vars + (name -> body))
case Left(blah) => println(s"As failed to match: $result"); Left(blah)
}
case TuplePattern(patterns, _) =>
try {
helper(tupleToChunk(body), patterns)
} catch {
case _ => noMatch
}

case ConstructorPattern(patternName, patternArgs, _) =>
body match {
case GenericCaseClass(fqName, args) =>
if (fqName != GenericCaseClass.fqNameToGenericCaseClassName(patternName)) noMatch
else
helper(args.values.toChunk, patternArgs)

case _ => noMatch
}
case Pattern.ConstructorPattern(_, _, _) =>
???
case Pattern.EmptyListPattern(_) =>
if (body == Nil) MatchResult.Success(Map.empty) else err
case Pattern.HeadTailPattern(headPattern, tailPattern, _) =>
case EmptyListPattern(_) =>
body match {
case Nil => Right(empty)
case _ => noMatch
}
case HeadTailPattern(headPattern, tailPattern, _) =>
body match {
case head :: tail =>
val headMatchResult = matchPattern(head, headPattern)
val tailMatchResult = matchPattern(tail, tailPattern)
(headMatchResult, tailMatchResult) match {
case (MatchResult.Success(headVariables), MatchResult.Success(tailVariables)) =>
MatchResult.Success(headVariables ++ tailVariables)
case (failure: MatchResult.Failure, _) => failure
case (_, failure: MatchResult.Failure) => failure
}
case _ => err
for {
headVars <- matchPattern(head, headPattern)
tailVars <- matchPattern(tail, tailPattern)
} yield headVars ++ tailVars
case _ =>
println(s"I do not recognize $body as a list"); noMatch
}
case Pattern.LiteralPattern(literal, _) =>
if (body == literal.value) MatchResult.Success(Map.empty) else err
case Pattern.UnitPattern(_) =>
if (body == unitValue) MatchResult.Success(Map.empty) else err
case Pattern.TuplePattern(patterns, _) =>
def helper(remainingBody: List[Any], remainingPattern: List[Pattern[_]]): MatchResult =
(remainingBody, remainingPattern) match {
case (Nil, Nil) => MatchResult.Success(Map.empty)
case (_, Nil) => err
case (Nil, _) => err
case (b :: bs, t :: ts) =>
(helper(bs, ts), matchPattern(b, t)) match {
case (MatchResult.Success(m1), MatchResult.Success(m2)) => MatchResult.Success(m1 ++ m2)
case _ => err
}
}

try {
helper(tupleToChunk(body).toList, patterns.toList)
} catch {
case _: InterpretationError.MatchError => err
case LiteralPattern(patternLiteral, _) =>
if (body == patternLiteral.value) Right(empty) else noMatch
case UnitPattern(_) =>
println("Checking unit case")
body match {
case () => Right(empty)
case _ => noMatch
}
case Pattern.WildcardPattern(_) =>
MatchResult.Success(Map.empty)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ object InterpreterSpec extends MorphirBaseSpec {
assertTrue(evaluate(patternTupleOneCaseCounterExample) == Right("right"))
}
),
suite("constructor")(
test("Should evaluate correctly") {
assertTrue(evaluate(patternConstructorCaseExample) == Right(new BigInteger("10000")))
}
),
suite("head tail list")(
test("Should evaluate correctly") {
assertTrue(evaluate(patternHeadTailCaseExample) == Right(List("world")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ object CaseExample extends ValueSyntax with TypeSyntax {
) -> Dsl.wholeNumber(new java.math.BigInteger("107"))
)

lazy val patternConstructorCaseExample =
Dsl.patternMatch(
checkingAccountConstructorExample,
constructorPattern(
checkingAccountTypeName,
Chunk(wildcardPattern, asPattern(wildcardPattern, Name("x")))
) -> variable(Name("x"))
)

val patternUnitCaseExample =
Dsl.patternMatch(
unit,
Expand Down