Skip to content

Commit e42fe7c

Browse files
authored
[RF][HS3] Patched ParamHistFuncs | Minor clean up
* Fix Constant Flag in Data axes * ParamHistFuncs accept custom modifiers * Store binning Information with ParamHistFunc * Add Warning to ParamHistFunc Compatibility * clang-format clean-up * Update warning message for histogram binnings version * Update warning message for default binning
1 parent 1e660dc commit e42fe7c

File tree

7 files changed

+161
-31
lines changed

7 files changed

+161
-31
lines changed

roofit/hs3/inc/RooFitHS3/RooJSONFactoryWSTool.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ class RooJSONFactoryWSTool {
229229
void importVariable(const RooFit::Detail::JSONNode &p);
230230
void importDependants(const RooFit::Detail::JSONNode &n);
231231

232-
void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &p);
233-
void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n);
232+
void exportVariable(const RooAbsArg *v, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins);
233+
void exportVariables(const RooArgSet &allElems, RooFit::Detail::JSONNode &n, bool storeConstant, bool storeBins);
234234

235235
void exportAllObjects(RooFit::Detail::JSONNode &n);
236236

roofit/hs3/src/JSONFactories_HistFactory.cxx

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,16 @@ void collectElements(RooArgSet &elems, RooAbsArg *arg)
823823
}
824824
}
825825

826+
bool allRooRealVar(const RooAbsCollection &list)
827+
{
828+
for (auto *var : list) {
829+
if (!dynamic_cast<RooRealVar *>(var)) {
830+
return false;
831+
}
832+
}
833+
return true;
834+
}
835+
826836
struct Sample {
827837
std::string name;
828838
std::vector<double> hist;
@@ -920,7 +930,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
920930
addNormFactor(par, sample, ws);
921931
} else if (auto hf = dynamic_cast<const RooHistFunc *>(e)) {
922932
updateObservables(hf->dataHist());
923-
} else if (auto phf = dynamic_cast<ParamHistFunc *>(e)) {
933+
} else if (ParamHistFunc *phf = dynamic_cast<ParamHistFunc *>(e); phf && allRooRealVar(phf->paramList())) {
924934
phfs.push_back(phf);
925935
} else if (auto fip = dynamic_cast<RooStats::HistFactory::FlexibleInterpVar *>(e)) {
926936
// some (modified) histfactory models have several instances of FlexibleInterpVar

roofit/hs3/src/JSONFactories_RooFitCore.cxx

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <RooAbsCachedPdf.h>
1616
#include <RooAddPdf.h>
1717
#include <RooAddModel.h>
18+
#include <RooBinning.h>
1819
#include <RooBinSamplingPdf.h>
1920
#include <RooBinWidthFunction.h>
2021
#include <RooCategory.h>
@@ -33,6 +34,7 @@
3334
#include <RooLegacyExpPoly.h>
3435
#include <RooLognormal.h>
3536
#include <RooMultiVarGaussian.h>
37+
#include <RooStats/HistFactory/ParamHistFunc.h>
3638
#include <RooPoisson.h>
3739
#include <RooPolynomial.h>
3840
#include <RooPolyVar.h>
@@ -532,6 +534,71 @@ class RooMultiVarGaussianFactory : public RooFit::JSONIO::Importer {
532534
}
533535
};
534536

537+
class ParamHistFuncFactory : public RooFit::JSONIO::Importer {
538+
public:
539+
bool importArg(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
540+
{
541+
std::string name(RooJSONFactoryWSTool::name(p));
542+
RooArgList varList = tool->requestArgList<RooRealVar>(p, "variables");
543+
if (!p.has_child("axes")) {
544+
std::stringstream ss;
545+
ss << "No axes given in '" << name << "'"
546+
<< ". Using default binning (uniform; nbins=100). If needed, export the Workspace to JSON with a newer "
547+
<< "Root version that supports custom ParamHistFunc binnings(>=6.38.00)." << std::endl;
548+
RooJSONFactoryWSTool::warning(ss.str());
549+
tool->wsEmplace<ParamHistFunc>(name, varList, tool->requestArgList<RooAbsReal>(p, "parameters"));
550+
return true;
551+
}
552+
tool->wsEmplace<ParamHistFunc>(name, readBinning(p, varList), tool->requestArgList<RooAbsReal>(p, "parameters"));
553+
return true;
554+
}
555+
556+
private:
557+
RooArgList readBinning(const JSONNode &topNode, const RooArgList &varList) const
558+
{
559+
// Temporary map from variable name → RooRealVar
560+
std::map<std::string, std::unique_ptr<RooRealVar>> varMap;
561+
562+
// Build variables from JSON
563+
for (const JSONNode &node : topNode["axes"].children()) {
564+
const std::string name = node["name"].val();
565+
std::unique_ptr<RooRealVar> obs;
566+
567+
if (node.has_child("edges")) {
568+
std::vector<double> edges;
569+
for (const auto &bound : node["edges"].children()) {
570+
edges.push_back(bound.val_double());
571+
}
572+
obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), edges.front(), edges.back());
573+
RooBinning bins(obs->getMin(), obs->getMax());
574+
for (auto b : edges)
575+
bins.addBoundary(b);
576+
obs->setBinning(bins);
577+
} else {
578+
obs = std::make_unique<RooRealVar>(name.c_str(), name.c_str(), node["min"].val_double(),
579+
node["max"].val_double());
580+
obs->setBins(node["nbins"].val_int());
581+
}
582+
583+
varMap[name] = std::move(obs);
584+
}
585+
586+
// Now build the final list following the order in varList
587+
RooArgList vars;
588+
for (int i = 0; i < varList.getSize(); ++i) {
589+
const auto *refVar = dynamic_cast<RooRealVar *>(varList.at(i));
590+
if (!refVar)
591+
continue;
592+
593+
auto it = varMap.find(refVar->GetName());
594+
if (it != varMap.end()) {
595+
vars.addOwned(std::move(it->second)); // preserve ownership
596+
}
597+
}
598+
return vars;
599+
}
600+
};
601+
535602
///////////////////////////////////////////////////////////////////////////////////////////////////////
536603
// specialized exporter implementations
537604
///////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -696,6 +763,7 @@ class RooFormulaArgStreamer : public RooFit::JSONIO::Exporter {
696763
expr.ReplaceAll("TMath::Sin", "sin");
697764
expr.ReplaceAll("TMath::Sqrt", "sqrt");
698765
expr.ReplaceAll("TMath::Power", "pow");
766+
expr.ReplaceAll("TMath::Erf", "erf");
699767
}
700768
};
701769
template <class RooArg_t>
@@ -952,6 +1020,47 @@ class RooExtendPdfStreamer : public RooFit::JSONIO::Exporter {
9521020
}
9531021
};
9541022

1023+
class ParamHistFuncStreamer : public RooFit::JSONIO::Exporter {
1024+
public:
1025+
std::string const &key() const override;
1026+
bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
1027+
{
1028+
auto *pdf = static_cast<const ParamHistFunc *>(func);
1029+
elem["type"] << key();
1030+
RooJSONFactoryWSTool::fillSeq(elem["variables"], pdf->dataVars());
1031+
RooJSONFactoryWSTool::fillSeq(elem["parameters"], pdf->paramList());
1032+
writeBinningInfo(pdf, elem);
1033+
return true;
1034+
}
1035+
1036+
private:
1037+
void writeBinningInfo(const ParamHistFunc *pdf, JSONNode &elem) const
1038+
{
1039+
auto &observablesNode = elem["axes"].set_seq();
1040+
// axes have to be ordered to get consistent bin indices
1041+
for (auto *var : static_range_cast<RooRealVar *>(pdf->dataVars())) {
1042+
std::string name = var->GetName();
1043+
RooJSONFactoryWSTool::testValidName(name, false);
1044+
JSONNode &obsNode = observablesNode.append_child().set_map();
1045+
obsNode["name"] << name;
1046+
if (var->getBinning().isUniform()) {
1047+
obsNode["min"] << var->getMin();
1048+
obsNode["max"] << var->getMax();
1049+
obsNode["nbins"] << var->getBins();
1050+
} else {
1051+
auto &edges = obsNode["edges"];
1052+
edges.set_seq();
1053+
double val = var->getBinning().binLow(0);
1054+
edges.append_child() << val;
1055+
for (int i = 0; i < var->getBinning().numBins(); ++i) {
1056+
val = var->getBinning().binHigh(i);
1057+
edges.append_child() << val;
1058+
}
1059+
}
1060+
}
1061+
}
1062+
};
1063+
9551064
#define DEFINE_EXPORTER_KEY(class_name, name) \
9561065
std::string const &class_name::key() const \
9571066
{ \
@@ -989,6 +1098,7 @@ DEFINE_EXPORTER_KEY(RooRealIntegralStreamer, "integral");
9891098
DEFINE_EXPORTER_KEY(RooDerivativeStreamer, "derivative");
9901099
DEFINE_EXPORTER_KEY(RooFFTConvPdfStreamer, "fft_conv_pdf");
9911100
DEFINE_EXPORTER_KEY(RooExtendPdfStreamer, "extend_pdf");
1101+
DEFINE_EXPORTER_KEY(ParamHistFuncStreamer, "step");
9921102

9931103
///////////////////////////////////////////////////////////////////////////////////////////////////////
9941104
// instantiate all importers and exporters
@@ -1021,6 +1131,7 @@ STATIC_EXECUTE([]() {
10211131
registerImporter<RooDerivativeFactory>("derivative", false);
10221132
registerImporter<RooFFTConvPdfFactory>("fft_conv_pdf", false);
10231133
registerImporter<RooExtendPdfFactory>("extend_pdf", false);
1134+
registerImporter<ParamHistFuncFactory>("step", false);
10241135

10251136
registerExporter<RooAddPdfStreamer<RooAddPdf>>(RooAddPdf::Class(), false);
10261137
registerExporter<RooAddPdfStreamer<RooAddModel>>(RooAddModel::Class(), false);
@@ -1047,6 +1158,7 @@ STATIC_EXECUTE([]() {
10471158
registerExporter<RooDerivativeStreamer>(RooDerivative::Class(), false);
10481159
registerExporter<RooFFTConvPdfStreamer>(RooFFTConvPdf::Class(), false);
10491160
registerExporter<RooExtendPdfStreamer>(RooExtendPdf::Class(), false);
1161+
registerExporter<ParamHistFuncStreamer>(ParamHistFunc::Class(), false);
10501162
});
10511163

10521164
} // namespace

roofit/hs3/src/RooFitHS3_wsexportkeys.cxx

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ auto RooFitHS3_wsexportkeys = R"({
6262
"sigmaR": "sigma_R"
6363
}
6464
},
65+
"RooEffProd": {
66+
"type": "efficiency_product_pdf_dist",
67+
"proxies": {
68+
"pdf": "pdf",
69+
"eff": "eff"
70+
}
71+
},
6572
"RooGamma": {
6673
"type": "gamma_dist",
6774
"proxies": {
@@ -79,13 +86,6 @@ auto RooFitHS3_wsexportkeys = R"({
7986
"sigma": "sigma"
8087
}
8188
},
82-
"ParamHistFunc": {
83-
"type": "step",
84-
"proxies": {
85-
"dataVars": "variables",
86-
"paramSet": "parameters"
87-
}
88-
},
8989
"RooLandau": {
9090
"type": "landau_dist",
9191
"proxies": {

roofit/hs3/src/RooFitHS3_wsfactoryexpressions.cxx

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ auto RooFitHS3_wsfactoryexpressions = R"({
4343
"coefficients"
4444
]
4545
},
46+
"efficiency_product_pdf_dist": {
47+
"class": "RooEffProd",
48+
"arguments": [
49+
"pdf",
50+
"eff"
51+
]
52+
},
4653
"gamma_dist": {
4754
"class": "RooGamma",
4855
"arguments": [
@@ -112,13 +119,6 @@ auto RooFitHS3_wsfactoryexpressions = R"({
112119
"observables"
113120
]
114121
},
115-
"step": {
116-
"class": "ParamHistFunc",
117-
"arguments": [
118-
"variables",
119-
"parameters"
120-
]
121-
},
122122
"sum": {
123123
"class": "RooAddition",
124124
"arguments": [

roofit/hs3/src/RooJSONFactoryWSTool.cxx

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ RooAbsReal *RooJSONFactoryWSTool::requestImpl<RooAbsReal>(const std::string &obj
965965
* @param node The JSONNode to which the variable will be exported.
966966
* @return void
967967
*/
968-
void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node)
968+
void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node, bool storeConstant, bool storeBins)
969969
{
970970
auto *cv = dynamic_cast<const RooConstVar *>(v);
971971
auto *rrv = dynamic_cast<const RooRealVar *>(v);
@@ -984,10 +984,10 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node)
984984
var["const"] << true;
985985
} else if (rrv) {
986986
var["value"] << rrv->getVal();
987-
if (rrv->isConstant()) {
987+
if (rrv->isConstant() && storeConstant) {
988988
var["const"] << rrv->isConstant();
989989
}
990-
if (rrv->getBins() != 100) {
990+
if (rrv->getBins() != 100 && storeBins) {
991991
var["nbins"] << rrv->getBins();
992992
}
993993
_domains->readVariable(*rrv);
@@ -1004,12 +1004,12 @@ void RooJSONFactoryWSTool::exportVariable(const RooAbsArg *v, JSONNode &node)
10041004
* @param n The JSONNode to which the variables will be exported.
10051005
* @return void
10061006
*/
1007-
void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n)
1007+
void RooJSONFactoryWSTool::exportVariables(const RooArgSet &allElems, JSONNode &n, bool storeConstant, bool storeBins)
10081008
{
10091009
// export a list of RooRealVar objects
10101010
n.set_seq();
10111011
for (RooAbsArg *arg : allElems) {
1012-
exportVariable(arg, n);
1012+
exportVariable(arg, n, storeConstant, storeBins);
10131013
}
10141014
}
10151015

@@ -1070,7 +1070,7 @@ void RooJSONFactoryWSTool::exportObject(RooAbsArg const &func, std::set<std::str
10701070
// categories are created by the respective RooSimultaneous, so we're skipping the export here
10711071
return;
10721072
} else if (dynamic_cast<RooRealVar const *>(&func) || dynamic_cast<RooConstVar const *>(&func)) {
1073-
exportVariable(&func, *_varsNode);
1073+
exportVariable(&func, *_varsNode, true, false);
10741074
return;
10751075
}
10761076

@@ -1554,18 +1554,14 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data)
15541554

15551555
// this really is an unbinned dataset
15561556
output["type"] << "unbinned";
1557-
exportVariables(variables, output["axes"]);
1557+
exportVariables(variables, output["axes"], false, true);
15581558
auto &coords = output["entries"].set_seq();
15591559
std::vector<double> weightVals;
15601560
bool hasNonUnityWeights = false;
15611561
for (int i = 0; i < data.numEntries(); ++i) {
15621562
data.get(i);
15631563
coords.append_child().fill_seq(variables, [](auto x) { return static_cast<RooRealVar *>(x)->getVal(); });
15641564
std::string datasetName = data.GetName();
1565-
/*if (datasetName.find("combData_ZvvH126.5") != std::string::npos) {
1566-
file << dynamic_cast<RooAbsReal *>(data.get(i)->find("atlas_invMass_PttEtaConvVBFCat1"))->getVal() <<
1567-
std::endl;
1568-
}*/
15691565
if (data.isWeighted()) {
15701566
weightVals.push_back(data.weight());
15711567
if (data.weight() != 1.)
@@ -1575,7 +1571,6 @@ void RooJSONFactoryWSTool::exportData(RooAbsData const &data)
15751571
if (data.isWeighted() && hasNonUnityWeights) {
15761572
output["weights"].fill_seq(weightVals);
15771573
}
1578-
// file.close();
15791574
}
15801575

15811576
/**
@@ -1960,7 +1955,8 @@ void RooJSONFactoryWSTool::exportAllObjects(JSONNode &n)
19601955
snapshotSorted.sort();
19611956
std::string name(snsh->GetName());
19621957
if (name != "default_values") {
1963-
this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"]);
1958+
this->exportVariables(snapshotSorted, appendNamedChild(n["parameter_points"], name)["parameters"], true,
1959+
false);
19641960
}
19651961
}
19661962
_varsNode = nullptr;
@@ -2240,8 +2236,14 @@ void RooJSONFactoryWSTool::importAllNodes(const JSONNode &n)
22402236
combineDatasets(*_rootnodeInput, datasets);
22412237

22422238
for (auto const &d : datasets) {
2243-
if (d)
2239+
if (d) {
22442240
_workspace.import(*d);
2241+
for (auto const &obs : *d->get()) {
2242+
if (auto *rrv = dynamic_cast<RooRealVar *>(obs)) {
2243+
_workspace.var(rrv->GetName())->setBinning(rrv->getBinning());
2244+
}
2245+
}
2246+
}
22452247
}
22462248

22472249
_rootnodeInput = nullptr;

roofit/jsoninterface/inc/RooFit/Detail/JSONInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,12 @@ inline RooFit::Detail::JSONNode &operator<<(RooFit::Detail::JSONNode &n, std::sp
265265
return n;
266266
}
267267

268+
inline RooFit::Detail::JSONNode &operator<<(RooFit::Detail::JSONNode &n, std::span<const int> v)
269+
{
270+
n.fill_seq(v);
271+
return n;
272+
}
273+
268274
template <class Key, class T, class Hash, class KeyEqual, class Allocator>
269275
RooFit::Detail::JSONNode &
270276
operator<<(RooFit::Detail::JSONNode &n, const std::unordered_map<Key, T, Hash, KeyEqual, Allocator> &m)

0 commit comments

Comments
 (0)