Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace HistFactory{
void setHigh(RooAbsReal& param, double newHigh);

void printAllInterpCodes();
const std::vector<int>& interpolationCodes() const { return _interpCode; }

TObject* clone(const char* newname) const override { return new FlexibleInterpVar(*this, newname); }
~FlexibleInterpVar() override ;
Expand Down
48 changes: 48 additions & 0 deletions roofit/hs3/src/JSONFactories_HistFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ class FlexibleInterpVarStreamer : public RooFit::JSONIO::Exporter {
static_cast<const RooStats::HistFactory::FlexibleInterpVar *>(func);
elem["type"] << key();
auto &vars = elem["vars"];
elem["interpolationCodes"] << fip->interpolationCodes();
vars.set_seq();
for (const auto &v : fip->variables()) {
vars.append_child() << v->GetName();
Expand Down Expand Up @@ -668,6 +669,52 @@ class PiecewiseInterpolationFactory : public RooFit::JSONIO::Importer {
}
};

class FlexibleInterpVarFactory : public RooFit::JSONIO::Importer {
public:
bool importFunction(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
{
std::string name(RooJSONFactoryWSTool::name(p));
if (!p.has_child("vars")) {
RooJSONFactoryWSTool::error("no vars of '" + name + "'");
}
if (!p.has_child("high")) {
RooJSONFactoryWSTool::error("no high variations of '" + name + "'");
}
if (!p.has_child("low")) {
RooJSONFactoryWSTool::error("no low variations of '" + name + "'");
}
if (!p.has_child("nom")) {
RooJSONFactoryWSTool::error("no nominal variation of '" + name + "'");
}

double nom(p["nom"].val_float());

RooArgList vars;
for (const auto &d : p["vars"].children()) {
std::string objname(RooJSONFactoryWSTool::name(d));
RooRealVar *obj = tool->request<RooRealVar>(objname, name);
vars.add(*obj);
}

std::vector<double> high;
high << p["high"];

std::vector<double> low;
high << p["low"];

RooStats::HistFactory::FlexibleInterpVar fip(name.c_str(), name.c_str(), vars, nom, low, high);

if (p.has_child("interpolationCodes")) {
for (size_t i = 0; i < vars.size(); ++i) {
fip.setInterpCode(*static_cast<RooAbsReal *>(vars.at(i)), p["interpolationCodes"][i].val_int());
}
}

tool->workspace()->import(fip, RooFit::RecycleConflictNodes(true), RooFit::Silence(true));
return true;
}
};

class HistFactoryStreamer : public RooFit::JSONIO::Exporter {
public:
bool autoExportDependants() const override { return false; }
Expand Down Expand Up @@ -937,6 +984,7 @@ STATIC_EXECUTE(

registerImporter<RooRealSumPdfFactory>("histfactory", true);
registerImporter<PiecewiseInterpolationFactory>("interpolation", true);
registerImporter<FlexibleInterpVarFactory>("interpolation0d", true);
registerExporter<FlexibleInterpVarStreamer>(RooStats::HistFactory::FlexibleInterpVar::Class(), true);
registerExporter<PiecewiseInterpolationStreamer>(PiecewiseInterpolation::Class(), true);
registerExporter<HistFactoryStreamer>(RooProdPdf::Class(), true);
Expand Down
105 changes: 103 additions & 2 deletions roofit/hs3/src/JSONFactories_RooFitCore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
#include <RooFormulaVar.h>
#include <RooGenericPdf.h>
#include <RooHistFunc.h>
#include <RooHistPdf.h>
#include <RooProdPdf.h>
#include <RooRealSumFunc.h>
#include <RooRealSumPdf.h>
#include <RooRealVar.h>
#include <RooSimultaneous.h>
Expand Down Expand Up @@ -256,6 +258,34 @@ class RooRealSumPdfFactory : public RooFit::JSONIO::Importer {
}
};

class RooRealSumFuncFactory : public RooFit::JSONIO::Importer {
public:
bool importFunction(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
{
std::string name(RooJSONFactoryWSTool::name(p));
if (!p.has_child("samples")) {
RooJSONFactoryWSTool::error("no samples given in '" + name + "'");
}
if (!p.has_child("coefficients")) {
RooJSONFactoryWSTool::error("no coefficients given in '" + name + "'");
}
RooArgList samples;
for (const auto &sample : p["samples"].children()) {
RooAbsReal *s = tool->request<RooAbsReal>(sample.val(), name);
samples.add(*s);
}
RooArgList coefficients;
for (const auto &coef : p["coefficients"].children()) {
RooAbsReal *c = tool->request<RooAbsReal>(coef.val(), name);
coefficients.add(*c);
}

RooRealSumFunc thefunc(name.c_str(), name.c_str(), samples, coefficients);
tool->workspace()->import(thefunc, RooFit::RecycleConflictNodes(true), RooFit::Silence(true));
return true;
}
};

///////////////////////////////////////////////////////////////////////////////////////////////////////
// specialized exporter implementations
///////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -286,6 +316,31 @@ class RooRealSumPdfStreamer : public RooFit::JSONIO::Exporter {
}
};

class RooRealSumFuncStreamer : public RooFit::JSONIO::Exporter {
public:
std::string const &key() const override
{
const static std::string keystring = "sumfunc";
return keystring;
}
bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
{
const RooRealSumFunc *pdf = static_cast<const RooRealSumFunc *>(func);
elem["type"] << key();
auto &samples = elem["samples"];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this called samples? This is specific to HistFactory, but in general the functions can also represent things that are not "samples". Can it just be called "functions" without any compatibility issues?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opted for "samples" rather than "functions" because the structure of the sumpdf with "functions" and "coefficients" is clearly intended to work that way - if you're just adding random stuff, you could also use RooAddition or RooAddPdf

samples.set_seq();
auto &coefs = elem["coefficients"];
coefs.set_seq();
for (const auto &s : pdf->funcList()) {
samples.append_child() << s->GetName();
}
for (const auto &c : pdf->coefList()) {
coefs.append_child() << c->GetName();
}
return true;
}
};

class RooSimultaneousStreamer : public RooFit::JSONIO::Exporter {
public:
std::string const &key() const override
Expand Down Expand Up @@ -356,6 +411,48 @@ class RooHistFuncFactory : public RooFit::JSONIO::Importer {
}
};

class RooHistPdfStreamer : public RooFit::JSONIO::Exporter {
public:
std::string const &key() const override
{
static const std::string keystring = "histogramPdf";
return keystring;
}
bool exportObject(RooJSONFactoryWSTool *, const RooAbsArg *func, JSONNode &elem) const override
{
const RooHistPdf *hf = static_cast<const RooHistPdf *>(func);
const RooDataHist &dh = hf->dataHist();
elem["type"] << key();
RooArgList vars(*dh.get());
std::unique_ptr<TH1> hist{hf->createHistogram(RooJSONFactoryWSTool::concat(&vars).c_str())};
auto &data = elem["data"];
RooJSONFactoryWSTool::exportHistogram(*hist, data, RooJSONFactoryWSTool::names(&vars));
return true;
}
};

class RooHistPdfFactory : public RooFit::JSONIO::Importer {
public:
bool importPdf(RooJSONFactoryWSTool *tool, const JSONNode &p) const override
{
std::string name(RooJSONFactoryWSTool::name(p));
if (!p.has_child("data")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only see a data child here. But to fully specify a RooHistPdf, don't you also need the bin boundaries? Where are they stored?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bin boundaries are stored in the observables.

RooJSONFactoryWSTool::error("function '" + name + "' is of histogram type, but does not define a 'data' key");
}
RooArgSet varlist;
tool->getObservables(*tool->workspace(), p["data"], name, varlist);
RooDataHist *dh = dynamic_cast<RooDataHist *>(tool->workspace()->embeddedData(name));
if (!dh) {
auto dhForImport = tool->readBinnedData(*tool->workspace(), p["data"], name, varlist);
tool->workspace()->import(*dhForImport, RooFit::Silence(true), RooFit::Embedded());
dh = static_cast<RooDataHist *>(tool->workspace()->embeddedData(dhForImport->GetName()));
}
RooHistPdf hf(name.c_str(), name.c_str(), *(dh->get()), *dh);
tool->workspace()->import(hf, RooFit::RecycleConflictNodes(true), RooFit::Silence(true));
return true;
}
};

class RooBinSamplingPdfStreamer : public RooFit::JSONIO::Exporter {
public:
std::string const &key() const override
Expand Down Expand Up @@ -461,20 +558,24 @@ STATIC_EXECUTE(
registerImporter<RooProdPdfFactory>("pdfprod", false); registerImporter<RooGenericPdfFactory>("genericpdf", false);
registerImporter<RooFormulaVarFactory>("formulavar", false);
registerImporter<RooBinSamplingPdfFactory>("binsampling", false);
registerImporter<RooAddPdfFactory>("pdfsum", false); registerImporter<RooHistFuncFactory>("histogram", false);
registerImporter<RooAddPdfFactory>("pdfsum", false);
registerImporter<RooHistFuncFactory>("histogram", false);
registerImporter<RooHistFuncFactory>("histogramPdf", false);
registerImporter<RooSimultaneousFactory>("simultaneous", false);
registerImporter<RooBinWidthFunctionFactory>("binwidth", false);
registerImporter<RooRealSumPdfFactory>("sumpdf", false);
registerImporter<RooRealSumFuncFactory>("sumfunc", false);

registerExporter<RooBinWidthFunctionStreamer>(RooBinWidthFunction::Class(), false);
registerExporter<RooProdPdfStreamer>(RooProdPdf::Class(), false);
registerExporter<RooSimultaneousStreamer>(RooSimultaneous::Class(), false);
registerExporter<RooBinSamplingPdfStreamer>(RooBinSamplingPdf::Class(), false);
registerExporter<RooHistFuncStreamer>(RooHistFunc::Class(), false);
registerExporter<RooHistPdfStreamer>(RooHistPdf::Class(), false);
registerExporter<RooGenericPdfStreamer>(RooGenericPdf::Class(), false);
registerExporter<RooFormulaVarStreamer>(RooFormulaVar::Class(), false);
registerExporter<RooRealSumPdfStreamer>(RooRealSumPdf::Class(), false);

registerExporter<RooRealSumFuncStreamer>(RooRealSumFunc::Class(), false);
)

} // namespace
7 changes: 3 additions & 4 deletions roofit/hs3/test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
ROOT_ADD_PYUNITTEST(roofit-hs3-histfactory-json test_hs3_histfactory_json.py
COPY_TO_BUILDDIR ${CMAKE_CURRENT_SOURCE_DIR}/test_hs3_histfactory_json_input.root
)

ROOT_ADD_GTEST(testRooFitHS3 testRooFitHS3.cxx LIBRARIES RooFitCore RooFit RooFitHS3)
ROOT_ADD_GTEST(testHS3HistFactory testHS3HistFactory.cxx LIBRARIES RooFit RooFitHS3 HistFactory
COPY_TO_BUILDDIR ${CMAKE_CURRENT_SOURCE_DIR}/test_hs3_histfactory_json_input.root
)
Loading