Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 54 additions & 9 deletions PWGHF/Core/HfMlResponseLcToPKPi.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
#ifndef PWGHF_CORE_HFMLRESPONSELCTOPKPI_H_
#define PWGHF_CORE_HFMLRESPONSELCTOPKPI_H_

#include <map>
#include <string>
#include <vector>

#include "PWGHF/DataModel/CandidateReconstructionTables.h"

#include "PWGHF/Core/HfMlResponse.h"

// Fill the map of available input features
Expand Down Expand Up @@ -129,10 +133,22 @@ enum class InputFeaturesLcToPKPi : uint8_t {
tofNSigmaPrExpPr0,
tofNSigmaPiExpPi2,
tpcTofNSigmaPrExpPr0,
tpcTofNSigmaPiExpPi2
tpcTofNSigmaPiExpPi2,
kfChi2PrimProton,
kfChi2PrimKaon,
kfChi2PrimPion,
kfChi2GeoKaonPion,
kfChi2GeoProtonPion,
kfChi2GeoProtonKaon,
kfDcaKaonPion,
kfDcaProtonPion,
kfDcaProtonKaon,
kfChi2Geo,
kfChi2Topo,
kfDecayLengthNormalised
};

template <typename TypeOutputScore = float>
template <typename TypeOutputScore = float, aod::hf_cand::VertexerType reconstructionType = aod::hf_cand::VertexerType::DCAFitter>
class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
{
public:
Expand Down Expand Up @@ -179,8 +195,6 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcNSigmaPr2, nSigTpcPr2);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcNSigmaKa2, nSigTpcKa2);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcNSigmaPi2, nSigTpcPi2);
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong0, prong2, tpcNSigmaPrExpPr0, tpcNSigmaPr);
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong2, prong0, tpcNSigmaPiExpPi2, tpcNSigmaPi);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcNSigmaPrExpPr0, nSigTpcPr0, nSigTpcPr2);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcNSigmaPiExpPi2, nSigTpcPi2, nSigTpcPi0);
// TOF PID variables
Expand All @@ -193,8 +207,6 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tofNSigmaPr2, nSigTofPr2);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tofNSigmaKa2, nSigTofKa2);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tofNSigmaPi2, nSigTofPi2);
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong0, prong2, tofNSigmaPrExpPr0, tofNSigmaPr);
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong2, prong0, tofNSigmaPiExpPi2, tofNSigmaPi);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tofNSigmaPrExpPr0, nSigTofPr0, nSigTofPr2);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tofNSigmaPiExpPi2, nSigTofPi2, nSigTofPi0);
// Combined PID variables
Expand All @@ -207,13 +219,29 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcTofNSigmaPr0, tpcTofNSigmaPr0);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcTofNSigmaPr1, tpcTofNSigmaPr1);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, tpcTofNSigmaPr2, tpcTofNSigmaPr2);
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong0, prong2, tpcTofNSigmaPrExpPr0, tpcTofNSigmaPr);
// CHECK_AND_FILL_VEC_LCTOPKPI_OBJECT_SIGNED(prong2, prong0, tpcTofNSigmaPiExpPi2, tpcTofNSigmaPi);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcTofNSigmaPrExpPr0, tpcTofNSigmaPr0, tpcTofNSigmaPr2);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, tpcTofNSigmaPiExpPi2, tpcTofNSigmaPi2, tpcTofNSigmaPi0);
}
if constexpr (reconstructionType == aod::hf_cand::VertexerType::KfParticle) {
switch (idx) {
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2PrimProton, kfChi2PrimProng0, kfChi2PrimProng2);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, kfChi2PrimKaon, kfChi2PrimProng1);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2PrimPion, kfChi2PrimProng2, kfChi2PrimProng0);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2GeoKaonPion, kfChi2GeoProng1Prong2, kfChi2GeoProng0Prong1);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, kfChi2GeoProtonPion, kfChi2GeoProng0Prong2);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfChi2GeoProtonKaon, kfChi2GeoProng0Prong1, kfChi2GeoProng1Prong2);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfDcaKaonPion, kfDcaProng1Prong2, kfDcaProng0Prong1);
CHECK_AND_FILL_VEC_LCTOPKPI_FULL(candidate, kfDcaProtonPion, kfDcaProng0Prong2);
CHECK_AND_FILL_VEC_LCTOPKPI_SIGNED(candidate, kfDcaProtonKaon, kfDcaProng0Prong1, kfDcaProng1Prong2);
CHECK_AND_FILL_VEC_LCTOPKPI(kfChi2Geo);
CHECK_AND_FILL_VEC_LCTOPKPI(kfChi2Topo);
case static_cast<uint8_t>(InputFeaturesLcToPKPi::kfDecayLengthNormalised): {
inputFeatures.emplace_back(candidate.kfDecayLength() / candidate.kfDecayLengthError());
break;
}
}
}
}

return inputFeatures;
}

Expand Down Expand Up @@ -273,6 +301,23 @@ class HfMlResponseLcToPKPi : public HfMlResponse<TypeOutputScore>
FILL_MAP_LCTOPKPI(tpcTofNSigmaPr2),
FILL_MAP_LCTOPKPI(tpcTofNSigmaPrExpPr0),
FILL_MAP_LCTOPKPI(tpcTofNSigmaPiExpPi2)};
if constexpr (reconstructionType == aod::hf_cand::VertexerType::KfParticle) {
std::map<std::string, uint8_t> mapKfFeatures{
// KFParticle variables
FILL_MAP_LCTOPKPI(kfChi2PrimProton),
FILL_MAP_LCTOPKPI(kfChi2PrimKaon),
FILL_MAP_LCTOPKPI(kfChi2PrimPion),
FILL_MAP_LCTOPKPI(kfChi2GeoKaonPion),
FILL_MAP_LCTOPKPI(kfChi2GeoProtonPion),
FILL_MAP_LCTOPKPI(kfChi2GeoProtonKaon),
FILL_MAP_LCTOPKPI(kfDcaKaonPion),
FILL_MAP_LCTOPKPI(kfDcaProtonPion),
FILL_MAP_LCTOPKPI(kfDcaProtonKaon),
FILL_MAP_LCTOPKPI(kfChi2Geo),
FILL_MAP_LCTOPKPI(kfChi2Topo),
FILL_MAP_LCTOPKPI(kfDecayLengthNormalised)};
MlResponse<TypeOutputScore>::mAvailableInputFeatures.insert(mapKfFeatures.begin(), mapKfFeatures.end());
}
}
};

Expand Down
59 changes: 42 additions & 17 deletions PWGHF/TableProducer/candidateSelectorLc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ struct HfCandidateSelectorLc {
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};

HfHelper hfHelper;
o2::analysis::HfMlResponseLcToPKPi<float> hfMlResponse;
o2::analysis::HfMlResponseLcToPKPi<float, aod::hf_cand::VertexerType::DCAFitter> hfMlResponseDCA;
o2::analysis::HfMlResponseLcToPKPi<float, aod::hf_cand::VertexerType::KfParticle> hfMlResponseKF;
std::vector<float> outputMlLcToPKPi = {};
std::vector<float> outputMlLcToPiKP = {};
o2::ccdb::CcdbApi ccdbApi;
Expand Down Expand Up @@ -142,15 +143,28 @@ struct HfCandidateSelectorLc {
}

if (applyMl) {
hfMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
if (loadModelsFromCCDB) {
ccdbApi.init(ccdbUrl);
hfMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
} else {
hfMlResponse.setModelPathsLocal(onnxFileNames);
if (doprocessNoBayesPidWithDCAFitterN || doprocessBayesPidWithDCAFitterN) {
hfMlResponseDCA.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
if (loadModelsFromCCDB) {
ccdbApi.init(ccdbUrl);
hfMlResponseDCA.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
} else {
hfMlResponseDCA.setModelPathsLocal(onnxFileNames);
}
hfMlResponseDCA.cacheInputFeaturesIndices(namesInputFeatures);
hfMlResponseDCA.init();
}
if (doprocessNoBayesPidWithKFParticle || doprocessBayesPidWithKFParticle) {
hfMlResponseKF.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
if (loadModelsFromCCDB) {
ccdbApi.init(ccdbUrl);
hfMlResponseKF.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB);
} else {
hfMlResponseKF.setModelPathsLocal(onnxFileNames);
}
hfMlResponseKF.cacheInputFeaturesIndices(namesInputFeatures);
hfMlResponseKF.init();
}
hfMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
hfMlResponse.init();
}

massK0Star892 = o2::constants::physics::MassK0Star892;
Expand Down Expand Up @@ -273,7 +287,7 @@ struct HfCandidateSelectorLc {
return false;
}

float massLc, massKPi;
float massLc{0.f}, massKPi{0.f};
if constexpr (reconstructionType == aod::hf_cand::VertexerType::DCAFitter) {
if (trackProton.globalIndex() == candidate.prong0Id()) {
massLc = hfHelper.invMassLcToPKPi(candidate);
Expand Down Expand Up @@ -553,13 +567,24 @@ struct HfCandidateSelectorLc {
isSelectedMlLcToPKPi = false;
isSelectedMlLcToPiKP = false;

if (pidLcToPKPi == 1 && pidBayesLcToPKPi == 1 && topolLcToPKPi) {
std::vector<float> inputFeaturesLcToPKPi = hfMlResponse.getInputFeatures(candidate, true);
isSelectedMlLcToPKPi = hfMlResponse.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlLcToPKPi);
}
if (pidLcToPiKP == 1 && pidBayesLcToPiKP == 1 && topolLcToPiKP) {
std::vector<float> inputFeaturesLcToPiKP = hfMlResponse.getInputFeatures(candidate, false);
isSelectedMlLcToPiKP = hfMlResponse.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlLcToPiKP);
if constexpr (reconstructionType == aod::hf_cand::VertexerType::DCAFitter) {
if (pidLcToPKPi == 1 && pidBayesLcToPKPi == 1 && topolLcToPKPi) {
std::vector<float> inputFeaturesLcToPKPi = hfMlResponseDCA.getInputFeatures(candidate, true);
isSelectedMlLcToPKPi = hfMlResponseDCA.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlLcToPKPi);
}
if (pidLcToPiKP == 1 && pidBayesLcToPiKP == 1 && topolLcToPiKP) {
std::vector<float> inputFeaturesLcToPiKP = hfMlResponseDCA.getInputFeatures(candidate, false);
isSelectedMlLcToPiKP = hfMlResponseDCA.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlLcToPiKP);
}
} else {
if (pidLcToPKPi == 1 && pidBayesLcToPKPi == 1 && topolLcToPKPi) {
std::vector<float> inputFeaturesLcToPKPi = hfMlResponseKF.getInputFeatures(candidate, true);
isSelectedMlLcToPKPi = hfMlResponseKF.isSelectedMl(inputFeaturesLcToPKPi, candidate.pt(), outputMlLcToPKPi);
}
if (pidLcToPiKP == 1 && pidBayesLcToPiKP == 1 && topolLcToPiKP) {
std::vector<float> inputFeaturesLcToPiKP = hfMlResponseKF.getInputFeatures(candidate, false);
isSelectedMlLcToPiKP = hfMlResponseKF.isSelectedMl(inputFeaturesLcToPiKP, candidate.pt(), outputMlLcToPiKP);
}
}

hfMlLcToPKPiCandidate(outputMlLcToPKPi, outputMlLcToPiKP);
Expand Down
Loading
Loading