Skip to content

Commit

Permalink
Opaque types support
Browse files Browse the repository at this point in the history
  • Loading branch information
OndrejSpanel committed Dec 4, 2024
1 parent b7506f5 commit d013799
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,20 @@ object QuicklensMacros {
def symbolAccessorByNameOrError(obj: Term, name: String): Term = {
val objTpe = obj.tpe.widenAll
val objSymbol = objTpe.matchingTypeSymbol
val mem = objSymbol.fieldMember(name)
// opaque types could find members of underlying types - do not ask them (see https://github.com/scala/scala3/issues/22143)
val mem = if !objSymbol.flags.is(Flags.Deferred) then objSymbol.fieldMember(name) else Symbol.noSymbol
if (mem != Symbol.noSymbol)
Select(obj, mem)
else
//Select(obj, mem)
objSymbol.methodMember(name) match
case List(m) =>
Select(obj, m)
case Nil =>
findExtensionMethod(objSymbol, name) match {
findExtensionMethod(objSymbol, name) match
case List((owner, extension)) =>
Apply(Select(owner, extension), List(obj))
case syms =>
reportMethodError(objSymbol, name, syms.map(_._2))
}
case lst =>
report.errorAndAbort(multipleMatchingMethods(objSymbol.name, name, lst))
}
Expand Down Expand Up @@ -226,8 +225,10 @@ object QuicklensMacros {
}

def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Option[Symbol] = {
val memberMethods = sym.methodMember(name)
filterMethodsByNameAndArgs(memberMethods, argsMap)
if !sym.flags.is(Flags.Deferred) then
val memberMethods = sym.methodMember(name)
filterMethodsByNameAndArgs(memberMethods, argsMap)
else None
}

/**
Expand All @@ -253,7 +254,6 @@ object QuicklensMacros {
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
// default values in extensions are obtained by calling a method receiving the extension parameter
val defaultMethodArgs = argsMap.dropRight(1).headOption.toList.flatMap(_.values)
//println(s"defaultMethodArgs ${obj.show} ${methodSymbol.name} $defaultMethodArgs")
if defaultMethodArgs.nonEmpty then
Apply(Select(obj, methodSymbol), defaultMethodArgs)
else
Expand Down Expand Up @@ -295,12 +295,20 @@ object QuicklensMacros {
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def findCompanionLikeObject(objSymbol: Symbol): Option[Symbol] = {
def optSymbol(objSymbol: Symbol) = Option.when(!objSymbol.isNoSymbol)(objSymbol)
optSymbol(objSymbol.companionModule).orElse {
// for opaque types, the companion type is not found by objSymbol.companionModule
// try to find an object by name in the owner scope
optSymbol(objSymbol.owner.fieldMember(objSymbol.name)).filter(_.flags.is(Flags.Module))
}
}
def findExtensionMethod(using Quotes)(sym: Symbol, methodName: String): List[(Term, Symbol)] = {
// TODO: can we check parameter types somehow?
def isExtensionMethod(sym: Symbol): Boolean = sym.isDefDef && sym.paramSymss.headOption.exists(_.sizeIs == 1)

// TODO: try to search in symbol parent object as well
val symbols = Seq(sym.companionModule).filter(_ != Symbol.noSymbol)
// TODO: try to search in symbol parent scope as well, as extension methods could be located there as well
val symbols = findCompanionLikeObject(sym).filter(_ != Symbol.noSymbol).toList

symbols.flatMap(s => s.declaredMethods.map(Ref(s) -> _)).filter((_, m) => m.name == methodName && isExtensionMethod(m)).toList
}
Expand Down Expand Up @@ -356,17 +364,17 @@ object QuicklensMacros {
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap).filter(m => m.owner == objSymbol) match
case Some(copy) =>
callMethod(obj, copy, List(argsMap))
case None =>
val objCompanion = objSymbol.companionModule
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap) match
val objCompanion = findCompanionLikeObject(objSymbol)
objCompanion.flatMap(methodSymbolByNameAndArgs(_, "copy", argsMap)) match
case Some(copy) =>
// now try to call the extension as a method, assume the object is its first parameter
val extensionParameter = copy.paramSymss.headOption.map(_.headOption).flatten
val argsWithObj = List(extensionParameter.map(name => name.name -> obj).toMap, argsMap)
callMethod(Ref(objCompanion), copy, argsWithObj)
callMethod(Ref(objCompanion.get), copy, argsWithObj)
case None => report.errorAndAbort(noSuchMember(objSymbol.name, "copy"))
} else
report.errorAndAbort(s"Unsupported source object: must be a case class, sealed trait or class with copy method, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
Expand Down Expand Up @@ -429,6 +437,7 @@ object QuicklensMacros {
def mapToCopy(owner: Symbol, mod: Expr[A => A], objTerm: Term, pathTree: PathTree): Term = pathTree match {
case PathTree.Empty =>
val apply = termMethodByNameUnsafe(mod.asTerm, "apply")
// TODO: calling extension may be necessary here
Apply(Select(mod.asTerm, apply), List(objTerm))
case PathTree.Node(children) =>
accumulateToCopy(owner, mod, objTerm, children)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package com.softwaremill.quicklens
package test

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -90,5 +91,4 @@ class ExplicitCopyTest extends AnyFlatSpec with Matchers {
// val f = Frozen("A", 0)
// f.modify(_.state).setTo('B')
// }

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

object ExtensionCopyTest {
case class V(x: Double, y: Double)
case class V(x: Double, y: Double, z: Double)

opaque type Vec = V

object Vec {
def apply(x: Double, y: Double): Vec = V(x, y)
def apply(x: Double, y: Double): Vec = V(x, y, 0)

extension (v: Vec) {
def x: Double = v.x
def y: Double = v.y
def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y)
def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y, 0)
}
}
}
Expand Down Expand Up @@ -55,35 +55,35 @@ class ExtensionCopyTest extends AnyFlatSpec with Matchers {
}

val a = VecCompanion(1, 2)
val b = a.modify(_.x).using(_ + 1)
println(b)
val b = a.modify(_.x).using(_ + 10)
assert(b.x == 11)
}

it should "modify a class with extension methods in companion" in {
case class V(x: Double, y: Double)
case class V(xm: Double, ym: Double)

class VecClass(val v: V)

object VecClass {
def apply(x: Double, y: Double): VecClass = new VecClass(V(x, y))

extension (v: VecClass) {
def x: Double = v.v.x
def y: Double = v.v.y
def x: Double = v.v.xm
def y: Double = v.v.ym
def copy(x: Double = v.x, y: Double = v.y): VecClass = new VecClass(V(x, y))
}
}

val a = VecClass(1, 2)
val b = a.modify(_.x).using(_ + 1)
println(b)
val b = a.modify(_.x).using(_ + 10)
assert(b.x == 11)
}
/*

it should "modify an opaque type with extension methods" in {
import ExtensionCopyTest.*

val a = Vec(1, 2)
val b = a.modify(_.x).using(_ + 1)
println(b)
val b = a.modify(_.x).using(_ + 10)
assert(b.x == 11)
}
*/
}

0 comments on commit d013799

Please sign in to comment.