Skip to content

Added Variance Threshold method in DataLoader #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tmva/tmva/inc/TMVA/DataLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ namespace TMVA {
// Register a dataset with this loader, either from an existing DataSetInfo
// object or by name; both overloads return a reference to a DataSetInfo.
// NOTE(review): the bodies are not fully visible here -- confirm details
// against DataLoader.cxx.
DataSetInfo& AddDataSet( DataSetInfo& );
DataSetInfo& AddDataSet( const TString& );

// Apply the variable transformation described by trafoDefinition
// (e.g. "VT" or "VT(1.5)" for the variance-threshold selection) and return
// a DataLoader pointer: a newly allocated loader holding the selected
// variables on success, otherwise this loader itself.
DataLoader* VarTransform(TString trafoDefinition);

// special case: signal/background

// Data input related
Expand Down Expand Up @@ -177,6 +179,7 @@ namespace TMVA {
// Access the input handler; dereferences fDataInputHandler without a null check.
DataInputHandler& DataInput() { return *fDataInputHandler; }
// Returns a DataSetInfo reference; presumably the one registered under this
// loader's own name -- TODO confirm against the definition.
DataSetInfo& DefaultDataSetInfo();
// NOTE(review): definition not visible in this hunk; behaviour undocumented here.
void SetInputTreesFromEventAssignTrees();
// Adds every signal and background tree registered in src's DataInput()
// to des (tree pointer, weight and tree type are forwarded).
void CopyDataLoader(DataLoader* des, DataLoader* src);


private:
Expand Down
130 changes: 130 additions & 0 deletions tmva/tmva/src/DataLoader.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,136 @@ TMVA::DataSetInfo& TMVA::DataLoader::AddDataSet( const TString& dsiName )
return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST
}

//_______________________________________________________________________
/// Forwards every signal and background tree registered in `src` to `des`.
///
/// \param[in,out] des  loader that receives the trees through
///                     AddSignalTree() / AddBackgroundTree()
/// \param[in]     src  loader whose DataInput() tree lists are walked
///
/// For each entry only the tree pointer, its weight and its tree type are
/// handed over. Neither pointer is checked for null.
void TMVA::DataLoader::CopyDataLoader(TMVA::DataLoader* des, TMVA::DataLoader* src)
{
   typedef std::vector<TreeInfo>::const_iterator TreeIter;

   // Signal trees first. The end iterator is re-read on every pass, exactly
   // as the loop condition of a plain for-loop would do.
   TreeIter sig = src->DataInput().Sbegin();
   while (sig != src->DataInput().Send()) {
      des->AddSignalTree(sig->GetTree(), sig->GetWeight(), sig->GetTreeType());
      ++sig;
   }

   // Then the background trees, handled the same way.
   TreeIter bkg = src->DataInput().Bbegin();
   while (bkg != src->DataInput().Bend()) {
      des->AddBackgroundTree(bkg->GetTree(), bkg->GetWeight(), bkg->GetTreeType());
      ++bkg;
   }
}

////////////////////////////////////////////////////////////////////////////////
/// Computes the event-weighted variance of all the variables and
/// returns a new DataLoader with the selected variables whose variance is
/// strictly above a specific threshold.
/// Threshold can be provided by user otherwise default value is 0 i.e. remove
/// the variables which have same value in all the events.
///
/// \param[in] trafoDefinition Transformation Definition String
///
/// \return A DataLoader allocated with `new` that holds the surviving
///         variables, the signal/background trees of this loader and the same
///         cuts and split options. Nothing in this function deletes it, so
///         the caller is presumably responsible for it (NOTE(review): confirm
///         the intended ownership). On any error path `this` is returned
///         after a kFATAL message has been logged.
///
/// Transformation Definition String Format: "VT(optional float value)"
///
/// Usage examples:
///
/// String    | Description
/// -------   |----------------------------------------
/// "VT"      | Select variables whose variance is above threshold value = 0 (Default)
/// "VT(1.5)" | Select variables whose variance is above threshold value = 1.5
TMVA::DataLoader* TMVA::DataLoader::VarTransform(TString trafoDefinition)
{
   // Split "NAME(options)" into the transformation name and its option text.
   // Without parentheses the whole string is the name and the option text
   // keeps its default "0".
   TString trOptions = "0";
   TString trName = "None";
   if (trafoDefinition.Contains("(")) {

      // contains transformation parameters
      const Ssiz_t parStart = trafoDefinition.Index( "(" );
      const Ssiz_t parEnd   = trafoDefinition.Index( ")", parStart );

      // Index() reports "not found" with kNPOS, which lies below any valid
      // position; without this guard a negative length would be passed on to
      // the substring extraction.
      if (parEnd < parStart) {
         Log() << kFATAL << "Transformation string \"" << trafoDefinition
               << "\" has no closing parenthesis, please check" << Endl;
         return this;
      }

      trName    = trafoDefinition(0, parStart);
      // text strictly between the parentheses (empty for "VT()")
      trOptions = trafoDefinition(parStart+1, parEnd-parStart-1);
   }
   else {
      trName = trafoDefinition;
   }

   // variance threshold is the only transformation handled here
   if (trName != "VT") {
      Log() << kFATAL << "Incorrect transformation string provided, please check" << Endl;
      return this;
   }

   // find threshold value from given input
   if (!trOptions.IsFloat()) {
      Log() << kFATAL << " VT transformation must be passed a floating threshold value" << Endl;
      return this;
   }
   const Double_t threshold = trOptions.Atof();
   Log() << kINFO << "Transformation: " << trName << Endl;
   Log() << kINFO << "Threshold value: " << threshold << Endl;

   // get events
   const std::vector<Event*>& events = DefaultDataSetInfo().GetDataSet()->GetEventCollection();
   const UInt_t nevts = events.size();
   Log() << kINFO << "Number of events: " << nevts << Endl;
   const UInt_t nvars = DefaultDataSetInfo().GetNVariables();
   Log() << kINFO << "Number of variables before transformation: " << nvars << Endl;
   std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();

   // first pass: weighted mean of every variable
   Double_t sumOfWeights = 0;
   std::vector<Double_t> varMean(nvars);   // value-initialised to 0
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      const Event* ev = events[ievt];
      const Double_t weight = ev->GetWeight();
      sumOfWeights += weight;
      for (UInt_t ivar=0; ivar<nvars; ivar++) {
         varMean[ivar] += ev->GetValue(ivar)*weight;
      }
   }
   if (sumOfWeights <= 0) {
      Log() << kFATAL << " the sum of event weights calculated for your input is <= 0"
            << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
      // never divide by a non-positive sum, even if the logger returns
      return this;
   }
   for (UInt_t ivar=0; ivar<nvars; ivar++) {
      varMean[ivar] = varMean[ivar]/sumOfWeights;
   }

   // second pass: weighted sum of squared deviations from the mean
   std::vector<Double_t> sqDevSum(nvars);  // value-initialised to 0
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      const Event* ev = events[ievt];
      const Double_t weight = ev->GetWeight();
      for (UInt_t ivar=0; ivar<nvars; ivar++) {
         const Double_t dev = ev->GetValue(ivar) - varMean[ivar];
         sqDevSum[ivar] += weight*dev*dev;
      }
   }

   // return a new dataloader
   // iterate over all variables, ignore the ones whose variance is not above the threshold
   TMVA::DataLoader *transformedloader = new TMVA::DataLoader(DefaultDataSetInfo().GetName());
   for (UInt_t ivar=0; ivar<nvars; ivar++) {
      const Double_t variance = sqDevSum[ivar]/sumOfWeights;
      Log() << kINFO << "Variable " << vars[ivar].GetExpression() <<" variance = " << variance << Endl;
      if (variance > threshold)
         transformedloader->AddVariable(vars[ivar].GetExpression(), vars[ivar].GetVarType());
   }

   // hand over the input trees, then reuse this loader's cuts and split options
   CopyDataLoader(transformedloader,this);
   transformedloader->PrepareTrainingAndTestTree(this->DefaultDataSetInfo().GetCut("Signal"),
                                                 this->DefaultDataSetInfo().GetCut("Background"),
                                                 this->DefaultDataSetInfo().GetSplitOptions());

   Log() << kINFO << "Number of variables after transformation: " << transformedloader->DefaultDataSetInfo().GetNVariables() << Endl;

   return transformedloader;
}

// ________________________________________________
// the next functions are to assign events directly

Expand Down