Skip to content

Commit

Permalink
[fwd port] Fix local contract upgrade (#20743)
Browse files Browse the repository at this point in the history
* manual port of the SBuiltinFun change of pull/20296

* Fix local contract upgrade (#20296)

* Add metadata test for local contracts, fix logic

* factorize the code

* fix tests but also uncovered a problem with interfaces

* use toInterfaceContractId in daml-script tests

* fix typo in pretty-printer

* add positive tests

* fix exercise by interface

* always perform the upgrade validation check in the engine

* simplify error reporting

* ensure package is loaded before importing global contract

* simplify check code even more

* allow losing observers that are also signatories

* Apply suggestion

Co-authored-by: Remy <remy.haemmerle@daml.com>

---------

Co-authored-by: Remy <remy.haemmerle@daml.com>

* fix UpgradeTest

* format

* disable some tests for now

---------

Co-authored-by: Remy <remy.haemmerle@daml.com>
  • Loading branch information
paulbrauner-da and remyhaemmerle-da authored Feb 11, 2025
1 parent 6bfca01 commit 3c30b11
Show file tree
Hide file tree
Showing 5 changed files with 1,166 additions and 189 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private[lf] object Pretty {
char('\'') + text(p) + char('\'')

def prettyParties(p: Set[Party]): Doc =
char('{') & intercalate(char(','), p.map(prettyParty)) & char('{')
char('{') & intercalate(char(','), p.map(prettyParty)) & char('}')

def prettyDamlException(error: interpretation.Error): Doc = {
import interpretation.Error._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,10 @@ private[lf] object SBuiltinFun {
coid: V.ContractId,
interfaceId: TypeConName,
)(k: SAny => Control[Question.Update]): Control[Question.Update] = {
hardFetchTemplate(machine, coid) { (pkgName, srcTplId, srcArg) =>
hardFetchTemplate(machine, coid) { srcContract =>
val pkgName = srcContract.packageName
val srcTplId = srcContract.templateId
val srcArg = srcContract.value.asInstanceOf[SRecord]
ensureTemplateImplementsInterface(machine, interfaceId, coid, srcTplId) {
viewInterface(machine, interfaceId, srcTplId, srcArg) { srcView =>
resolvePackageName(machine, pkgName) { pkgId =>
Expand All @@ -1277,13 +1280,13 @@ private[lf] object SBuiltinFun {
dstTplId,
dstArg,
allowCatchingContractInfoErrors = false,
) { contract =>
) { dstContract =>
// If the destination and src templates are the same, we skip the computation
// of the destination template's view and the validation of the contract info.
if (dstTplId == srcTplId)
k(SAny(Ast.TTyCon(dstTplId), dstArg))
else
validateContractInfo(machine, coid, dstTplId, contract) { () =>
checkContractUpgradable(coid, srcContract, dstContract) { () =>
executeExpression(machine, SEPreventCatch(dstView)) { dstViewValue =>
if (srcViewValue != dstViewValue) {
Control.Error(
Expand Down Expand Up @@ -2317,8 +2320,8 @@ private[lf] object SBuiltinFun {
dstTmplId: TypeConName,
coid: V.ContractId,
)(f: SValue => Control[Question.Update]): Control[Question.Update] = {
def importContract(coinst: V.ContractInstance) = {
val V.ContractInstance(_, _, srcTmplId, coinstArg) = coinst
def importContract(srcContract: ContractInfo) = {
val srcTmplId = srcContract.templateId
if (srcTmplId.qualifiedName != dstTmplId.qualifiedName)
Control.Error(
IE.WronglyTypedContract(coid, dstTmplId, srcTmplId)
Expand All @@ -2328,30 +2331,30 @@ private[lf] object SBuiltinFun {
dstTmplId.packageId,
language.Reference.Template(dstTmplId),
) { () =>
importValue(machine, dstTmplId, coinstArg) { templateArg =>
importValue(machine, dstTmplId, srcContract.arg) { templateArg =>
getContractInfo(
machine,
coid,
dstTmplId,
templateArg,
allowCatchingContractInfoErrors = false,
) { contract =>
ensureContractActive(machine, coid, contract.templateId) {
) { dstContract =>
ensureContractActive(machine, coid, dstContract.templateId) {

machine.checkContractVisibility(coid, contract)
machine.checkContractVisibility(coid, dstContract)
machine.enforceLimitAddInputContract()
machine.enforceLimitSignatoriesAndObservers(coid, contract)
machine.enforceLimitSignatoriesAndObservers(coid, dstContract)

// In Validation mode, we always call validateContractInfo
// In Submission mode, we only call validateContractInfo when src != dest
val needValidationCall: Boolean =
machine.validating || srcTmplId.packageId != dstTmplId.packageId
if (needValidationCall) {
validateContractInfo(machine, coid, srcTmplId, contract) { () =>
f(contract.value)
checkContractUpgradable(coid, srcContract, dstContract) { () =>
f(dstContract.value)
}
} else {
f(contract.value)
f(dstContract.value)
}
}
}
Expand All @@ -2369,20 +2372,74 @@ private[lf] object SBuiltinFun {
templateArg,
allowCatchingContractInfoErrors = false,
) { contract =>
if (srcTmplId == dstTmplId) f(templateArg)
// If the local contract has the same package ID as the target template ID, then we don't need to
// import its value and validate its contract info again.
if (srcTmplId == dstTmplId)
f(templateArg)
else
importContract(
V.ContractInstance(
contract.packageName,
contract.packageVersion,
srcTmplId,
contract.arg,
)
)
importContract(contract)
}
}
case None =>
machine.lookupContract(coid)(importContract)
machine.lookupContract(coid)(coinst =>
machine.ensurePackageIsLoaded(
coinst.template.packageId,
language.Reference.Template(coinst.template),
) { () =>
importValue(machine, coinst.template, coinst.arg) { templateArg =>
getContractInfo(
machine,
coid,
coinst.template,
templateArg,
allowCatchingContractInfoErrors = false,
)(importContract)
}
}
)
}
}

/** Checks that the metadata of [original] and [recomputed] are the same, fails with a [Control.Error] if not. */
private def checkContractUpgradable(
coid: V.ContractId,
original: ContractInfo,
recomputed: ContractInfo,
)(
k: () => Control[Question.Update]
): Control[Question.Update] = {

def check[T](getter: ContractInfo => T, desc: String): Option[String] =
Option.when(getter(recomputed) != getter(original))(
s"$desc mismatch: $original vs $recomputed"
)

List(
check(_.signatories, "signatories"),
// This definition of observers allows observers to lose parties that are signatories
check(_.stakeholders, "stakeholders"),
check(_.keyOpt.map(_.maintainers), "key maintainers"),
check(_.keyOpt.map(_.globalKey.key), "key value"),
).flatten match {
case Nil => k()
case errors =>
Control.Error(
IE.Dev(
NameOf.qualifiedNameOfCurrentFunc,
IE.Dev.Upgrade(
// TODO(https://github.com/digital-asset/daml/issues/20305): also include the original metadata
IE.Dev.Upgrade.ValidationFailed(
coid = coid,
srcTemplateId = original.templateId,
dstTemplateId = recomputed.templateId,
signatories = recomputed.signatories,
observers = recomputed.observers,
keyOpt = recomputed.keyOpt.map(_.globalKeyWithMaintainers),
msg = errors.mkString("['", "', '", "']"),
)
),
)
)
}
}

Expand All @@ -2392,9 +2449,7 @@ private[lf] object SBuiltinFun {
private def hardFetchTemplate(
machine: UpdateMachine,
coid: V.ContractId,
)(
k: (Ref.PackageName, Ref.TypeConName, SRecord) => Control[Question.Update]
): Control[Question.Update] = {
)(k: ContractInfo => Control[Question.Update]): Control[Question.Update] = {
machine.getIfLocalContract(coid) match {
case Some((templateId, templateArg)) =>
ensureContractActive(machine, coid, templateId) {
Expand All @@ -2404,9 +2459,7 @@ private[lf] object SBuiltinFun {
templateId,
templateArg,
allowCatchingContractInfoErrors = false,
) { contract =>
k(contract.packageName, templateId, templateArg.asInstanceOf[SRecord])
}
)(k)
}
case None =>
machine.lookupContract(coid) { case V.ContractInstance(_, _, srcTmplId, coinstArg) =>
Expand All @@ -2423,26 +2476,10 @@ private[lf] object SBuiltinFun {
allowCatchingContractInfoErrors = false,
) { contract =>
ensureContractActive(machine, coid, contract.templateId) {

machine.checkContractVisibility(coid, contract)
machine.enforceLimitAddInputContract()
machine.enforceLimitSignatoriesAndObservers(coid, contract)

if (machine.validating) {
validateContractInfo(machine, coid, srcTmplId, contract) { () =>
k(
contract.packageName,
contract.templateId,
contract.value.asInstanceOf[SRecord],
)
}
} else {
k(
contract.packageName,
contract.templateId,
contract.value.asInstanceOf[SRecord],
)
}
k(contract)
}
}
}
Expand All @@ -2451,50 +2488,6 @@ private[lf] object SBuiltinFun {
}
}

private def validateContractInfo(
machine: UpdateMachine,
coid: V.ContractId,
srcTemplateId: Ref.Identifier,
contract: ContractInfo,
)(
continue: () => Control[Question.Update]
): Control[Question.Update] = {

val keyOpt: Option[GlobalKeyWithMaintainers] = contract.keyOpt match {
case None => None
case Some(cachedKey) =>
Some(cachedKey.globalKeyWithMaintainers)
}
machine.needUpgradeVerification(
location = NameOf.qualifiedNameOfCurrentFunc,
coid = coid,
signatories = contract.signatories,
observers = contract.observers,
keyOpt = keyOpt,
continue = {
case None =>
continue()
case Some(msg) =>
Control.Error(
IE.Dev(
NameOf.qualifiedNameOfCurrentFunc,
IE.Dev.Upgrade(
IE.Dev.Upgrade.ValidationFailed(
coid = coid,
srcTemplateId = srcTemplateId,
dstTemplateId = contract.templateId,
signatories = contract.signatories,
observers = contract.observers,
keyOpt = keyOpt,
msg = msg,
)
),
)
)
},
)
}

private def importValue[Q](machine: Machine[Q], templateId: TypeConName, coinstArg: V)(
f: SValue => Control[Q]
): Control[Q] = {
Expand Down
Loading

0 comments on commit 3c30b11

Please sign in to comment.