Skip to content
Merged
18 changes: 12 additions & 6 deletions roofit/roofitcore/inc/RooSimultaneous.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,22 @@ class RooSimultaneous : public RooAbsPdf {
std::map<std::string, double> const& precisions,
bool useCategoryNames=false);

RooAbsGenContext* autoGenContext(const RooArgSet &vars, const RooDataSet* prototype=nullptr, const RooArgSet* auxProto=nullptr,
bool verbose=false, bool autoBinned=true, const char* binnedTag="") const override ;
RooAbsGenContext* genContext(const RooArgSet &vars, const RooDataSet *prototype=nullptr,
const RooArgSet* auxProto=nullptr, bool verbose= false) const override ;

protected:

void initialize(RooAbsCategoryLValue& inIndexCat, std::map<std::string,RooAbsPdf*> pdfMap) ;

void selectNormalization(const RooArgSet* depSet=nullptr, bool force=false) override ;
void selectNormalizationRange(const char* rangeName=nullptr, bool force=false) override ;

RooArgSet const& flattenedCatList() const;

mutable RooSetProxy _plotCoefNormSet ;
const TNamed* _plotCoefNormRange ;
const TNamed* _plotCoefNormRange = nullptr;

class CacheElem : public RooAbsCacheElement {
public:
Expand All @@ -109,14 +117,12 @@ class RooSimultaneous : public RooAbsPdf {

friend class RooSimGenContext ;
friend class RooSimSplitGenContext ;
RooAbsGenContext* autoGenContext(const RooArgSet &vars, const RooDataSet* prototype=nullptr, const RooArgSet* auxProto=nullptr,
bool verbose=false, bool autoBinned=true, const char* binnedTag="") const override ;
RooAbsGenContext* genContext(const RooArgSet &vars, const RooDataSet *prototype=nullptr,
const RooArgSet* auxProto=nullptr, bool verbose= false) const override ;

RooCategoryProxy _indexCat ; ///< Index category
TList _pdfProxyList ; ///< List of PDF proxies (named after applicable category state)
Int_t _numPdf ; ///< Number of registered PDFs
Int_t _numPdf = 0; ///< Number of registered PDFs
private:
mutable std::unique_ptr<RooArgSet> _indexCatSet ; ///<! Index category wrapped in a RooArgSet if needed internally

ClassDefOverride(RooSimultaneous,3) // Simultaneous operator p.d.f, functions like C++ 'switch()' on input p.d.fs operating on index category5A
};
Expand Down
9 changes: 6 additions & 3 deletions roofit/roofitcore/src/RooAbsArg.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1061,9 +1061,12 @@ bool RooAbsArg::redirectServers(const RooAbsCollection& newSetOrig, bool mustRep

if (!newServer) {
if (mustReplaceAll) {
coutE(LinkStateMgmt) << "RooAbsArg::redirectServers(" << (void*)this << "," << GetName() << "): server " << oldServer->GetName()
<< " (" << (void*)oldServer << ") not redirected" << (nameChange?"[nameChange]":"") << endl ;
ret = true ;
std::stringstream ss;
ss << "RooAbsArg::redirectServers(" << (void*)this << "," << GetName() << "): server " << oldServer->GetName()
<< " (" << (void*)oldServer << ") not redirected" << (nameChange?"[nameChange]":"");
const std::string errorMsg = ss.str();
coutE(LinkStateMgmt) << errorMsg << std::endl;
throw std::runtime_error(errorMsg);
}
continue ;
}
Expand Down
43 changes: 11 additions & 32 deletions roofit/roofitcore/src/RooSimGenContext.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -61,49 +61,28 @@ RooSimGenContext::RooSimGenContext(const RooSimultaneous &model, const RooArgSet
RooAbsGenContext(model,vars,prototype,auxProto,verbose), _pdf(&model), _protoData(0)
{
// Determine if we are requested to generate the index category
RooAbsCategory *idxCat = (RooAbsCategory*) model._indexCat.absArg() ;
RooAbsCategoryLValue const& idxCat = model.indexCat();
RooArgSet pdfVars(vars) ;

RooArgSet allPdfVars(pdfVars) ;
if (prototype) allPdfVars.add(*prototype->get(),true) ;

if (!idxCat->isDerived()) {
pdfVars.remove(*idxCat,true,true) ;
bool doGenIdx = allPdfVars.find(idxCat->GetName())?true:false ;
RooArgSet catsAmongAllVars;
allPdfVars.selectCommon(model.flattenedCatList(), catsAmongAllVars);

if (!doGenIdx) {
if(catsAmongAllVars.size() != model.flattenedCatList().size()) {
oocoutE(_pdf,Generation) << "RooSimGenContext::ctor(" << GetName() << ") ERROR: This context must"
<< " generate the index category" << endl ;
<< " generate all components of the index category" << endl ;
_isValid = false ;
_numPdf = 0 ;
_haveIdxProto = false ;
return ;
}
} else {
bool anyServer(false), allServers(true) ;
for(RooAbsArg* server : idxCat->servers()) {
if (vars.find(server->GetName())) {
anyServer=true ;
pdfVars.remove(*server,true,true) ;
} else {
allServers=false ;
}
}

if (anyServer && !allServers) {
oocoutE(_pdf,Generation) << "RooSimGenContext::ctor(" << GetName() << ") ERROR: This context must"
<< " generate all components of a derived index category" << endl ;
_isValid = false ;
_numPdf = 0 ;
_haveIdxProto = false ;
return ;
}
}

// We must either have the prototype or extended likelihood to determined
// the relative fractions of the components
_haveIdxProto = prototype ? true : false ;
_idxCatName = idxCat->GetName() ;
_idxCatName = idxCat.GetName() ;
if (!_haveIdxProto && !model.canBeExtended()) {
oocoutE(_pdf,Generation) << "RooSimGenContext::ctor(" << GetName() << ") ERROR: Need either extended mode"
<< " or prototype data to calculate number of events per category" << endl ;
Expand All @@ -129,7 +108,7 @@ RooSimGenContext::RooSimGenContext(const RooSimultaneous &model, const RooArgSet
// Name the context after the associated state and add to list
cx->SetName(proxy->name()) ;
_gcList.push_back(cx) ;
_gcIndex.push_back(idxCat->lookupIndex(proxy->name()));
_gcIndex.push_back(idxCat.lookupIndex(proxy->name()));

// Fill fraction threshold array
_fracThresh[i] = _fracThresh[i-1] + (_haveIdxProto?0:pdf->expectedEvents(&allPdfVars)) ;
Expand All @@ -145,13 +124,13 @@ RooSimGenContext::RooSimGenContext(const RooSimultaneous &model, const RooArgSet

// Clone the index category
_idxCatSet = new RooArgSet;
RooArgSet(model._indexCat.arg()).snapshot(*_idxCatSet, true);
RooArgSet(model.indexCat()).snapshot(*_idxCatSet, true);
if (!_idxCatSet) {
oocoutE(_pdf,Generation) << "RooSimGenContext::RooSimGenContext(" << GetName() << ") Couldn't deep-clone index category, abort," << endl ;
throw std::string("RooSimGenContext::RooSimGenContext() Couldn't deep-clone index category, abort") ;
}

_idxCat = (RooAbsCategoryLValue*) _idxCatSet->find(model._indexCat.arg().GetName()) ;
_idxCat = static_cast<RooAbsCategoryLValue*>(_idxCatSet->find(model.indexCat().GetName()));
}


Expand All @@ -177,7 +156,7 @@ RooSimGenContext::~RooSimGenContext()
void RooSimGenContext::attach(const RooArgSet& args)
{
if (_idxCat->isDerived()) {
_idxCat->recursiveRedirectServers(args,true) ;
_idxCat->recursiveRedirectServers(args) ;
}

// Forward initGenerator call to all components
Expand All @@ -195,7 +174,7 @@ void RooSimGenContext::initGenerator(const RooArgSet &theEvent)
{
// Attach the index category clone to the event
if (_idxCat->isDerived()) {
_idxCat->recursiveRedirectServers(theEvent,true) ;
_idxCat->recursiveRedirectServers(theEvent) ;
} else {
_idxCat = (RooAbsCategoryLValue*) theEvent.find(_idxCat->GetName()) ;
}
Expand Down
43 changes: 11 additions & 32 deletions roofit/roofitcore/src/RooSimSplitGenContext.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -53,46 +53,25 @@ RooSimSplitGenContext::RooSimSplitGenContext(const RooSimultaneous &model, const
RooAbsGenContext(model,vars,0,0,verbose), _pdf(&model)
{
// Determine if we are requested to generate the index category
RooAbsCategory *idxCat = (RooAbsCategory*) model._indexCat.absArg() ;
RooAbsCategoryLValue const& idxCat = model.indexCat();
RooArgSet pdfVars(vars) ;

RooArgSet allPdfVars(pdfVars) ;

if (!idxCat->isDerived()) {
pdfVars.remove(*idxCat,true,true) ;
bool doGenIdx = allPdfVars.find(idxCat->GetName())?true:false ;
RooArgSet catsAmongAllVars;
allPdfVars.selectCommon(model.flattenedCatList(), catsAmongAllVars);

if (!doGenIdx) {
if(catsAmongAllVars.size() != model.flattenedCatList().size()) {
oocoutE(_pdf,Generation) << "RooSimSplitGenContext::ctor(" << GetName() << ") ERROR: This context must"
<< " generate the index category" << endl ;
<< " generate all components of the index category" << endl ;
_isValid = false ;
_numPdf = 0 ;
// coverity[UNINIT_CTOR]
return ;
}
} else {
bool anyServer(false), allServers(true) ;
for(RooAbsArg* server : idxCat->servers()) {
if (vars.find(server->GetName())) {
anyServer=true ;
pdfVars.remove(*server,true,true) ;
} else {
allServers=false ;
}
}

if (anyServer && !allServers) {
oocoutE(_pdf,Generation) << "RooSimSplitGenContext::ctor(" << GetName() << ") ERROR: This context must"
<< " generate all components of a derived index category" << endl ;
_isValid = false ;
_numPdf = 0 ;
// coverity[UNINIT_CTOR]
return ;
}
}

// We must extended likelihood to determine the relative fractions of the components
_idxCatName = idxCat->GetName() ;
_idxCatName = idxCat.GetName() ;
if (!model.canBeExtended()) {
oocoutE(_pdf,Generation) << "RooSimSplitGenContext::RooSimSplitGenContext(" << GetName() << "): All components of the simultaneous PDF "
<< "must be extended PDFs. Otherwise, it is impossible to calculate the number of events to be generated per component." << endl ;
Expand All @@ -118,7 +97,7 @@ RooSimSplitGenContext::RooSimSplitGenContext(const RooSimultaneous &model, const
RooAbsGenContext* cx = pdf->autoGenContext(*compVars,0,0,verbose,autoBinned,binnedTag) ;
delete compVars ;

const auto state = idxCat->lookupIndex(proxy->name());
const auto state = idxCat.lookupIndex(proxy->name());

cx->SetName(proxy->name()) ;
_gcList.push_back(cx) ;
Expand All @@ -134,11 +113,11 @@ RooSimSplitGenContext::RooSimSplitGenContext(const RooSimultaneous &model, const
}

// Clone the index category
if(RooArgSet(model._indexCat.arg()).snapshot(_idxCatSet, true)) {
if(RooArgSet(model.indexCat()).snapshot(_idxCatSet, true)) {
oocoutE(_pdf,Generation) << "RooSimSplitGenContext::RooSimSplitGenContext(" << GetName() << ") Couldn't deep-clone index category, abort," << endl ;
throw std::string("RooSimSplitGenContext::RooSimSplitGenContext() Couldn't deep-clone index category, abort") ;
}
_idxCat = (RooAbsCategoryLValue*) _idxCatSet.find(model._indexCat.arg().GetName()) ;
_idxCat = static_cast<RooAbsCategoryLValue*>(_idxCatSet.find(model.indexCat().GetName()));
}


Expand All @@ -162,7 +141,7 @@ RooSimSplitGenContext::~RooSimSplitGenContext()
void RooSimSplitGenContext::attach(const RooArgSet& args)
{
if (_idxCat->isDerived()) {
_idxCat->recursiveRedirectServers(args,true) ;
_idxCat->recursiveRedirectServers(args) ;
}

// Forward initGenerator call to all components
Expand All @@ -180,7 +159,7 @@ void RooSimSplitGenContext::initGenerator(const RooArgSet &theEvent)
{
// Attach the index category clone to the event
if (_idxCat->isDerived()) {
_idxCat->recursiveRedirectServers(theEvent,true) ;
_idxCat->recursiveRedirectServers(theEvent) ;
} else {
_idxCat = (RooAbsCategoryLValue*) theEvent.find(_idxCat->GetName()) ;
}
Expand Down
Loading