11// @(#)root/tmva/tmva/dnn:$Id$
2- // Author: Vladimir Ilievski, Saurav Shekhar
2+ // Author: Anushree Rankawat
33
44/* *********************************************************************************
55 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
1111 * Generative Adversarial Networks *
1212 * *
1313 * Authors (alphabetical): *
14- * Vladimir Ilievski <ilievski.vladimir@live.com> - CERN, Switzerland *
15- * Saurav Shekhar <sauravshekhar01@gmail.com> - ETH Zurich, Switzerland *
14+ * Anushree Rankawat <anushreerankawat110@gmail.com> *
1615 * *
1716 * Copyright (c) 2005-2015: *
1817 * CERN, Switzerland *
5251#include " TMVA/DNN/Architectures/Cuda.h"
5352#endif
5453
54+ #ifdef R__HAS_TMVACPU
55+ using ArchitectureImpl_t = TMVA::DNN::TCpu<Double_t>;
56+ #else
57+ using ArchitectureImpl_t = TMVA::DNN::TReference<Double_t>;
58+ #endif
59+ using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
60+
5561#include " TMVA/DNN/Architectures/Reference.h"
5662#include " TMVA/DNN/Functions.h"
5763#include " TMVA/DNN/DeepNet.h"
@@ -65,8 +71,6 @@ using namespace TMVA::DNN;
6571using Architecture_t = TCpu<Double_t>;
6672using Scalar_t = Architecture_t::Scalar_t;
6773using DeepNet_t = TMVA::DNN::TDeepNet<Architecture_t>;
68- // using Matrix_t = typename TCpu<double>::Matrix_t;
69- // using TensorInput = std::tuple<const std::vector<Matrix_t> &, const Matrix_t &, const Matrix_t &>;
7074using TensorDataLoader_t = TTensorDataLoader<TMVAInput_t, Architecture_t>;
7175
7276using TMVA::DNN::EActivationFunction;
@@ -106,13 +110,6 @@ class MethodGAN : public MethodBase {
106110private:
107111 // Key-Value vector type, contining the values for the training options
108112 using KeyValueVector_t = std::vector<std::map<TString, TString>>;
109- // using TensorInput = std::tuple<const std::vector<TMatrixT<Double_t>> &>;
110- #ifdef R__HAS_TMVACPU
111- using ArchitectureImpl_t = TMVA::DNN::TCpu<Double_t>;
112- #else
113- using ArchitectureImpl_t = TMVA::DNN::TReference<Double_t>;
114- #endif
115- using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
116113 std::unique_ptr<DeepNetImpl_t> generatorFNet, discriminatorFNet, combinedFNet;
117114 using Matrix_t = typename ArchitectureImpl_t::Matrix_t;
118115
@@ -133,7 +130,7 @@ class MethodGAN : public MethodBase {
133130 * a reference in the function. */
134131 template <typename Architecture_t, typename Layer_t>
135132 void CreateDeepNet (DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
136- std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, std::unique_ptr<DeepNetImpl_t> &fNet , TString layoutString);
133+ std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, std::unique_ptr<DeepNetImpl_t> &modelNet , TString layoutString);
137134
138135 size_t fGeneratorInputDepth ; // /< The depth of the input of the generator.
139136 size_t fGeneratorInputHeight ; // /< The height of the input of the generator.
@@ -197,20 +194,22 @@ class MethodGAN : public MethodBase {
197194 void Train ();
198195
199196 Double_t GetMvaValue (Double_t *err = 0 , Double_t *errUpper = 0 );
200- Double_t GetMvaValueGAN (std::unique_ptr<DeepNetImpl_t> & fNet , Double_t *err = 0 , Double_t *errUpper = 0 );
201-
197+ Double_t GetMvaValueGAN (std::unique_ptr<DeepNetImpl_t> & modelNet, Double_t *err = 0 , Double_t *errUpper = 0 );
202198 void CreateNoisyMatrices (std::vector<TMatrixT<Double_t>> &inputTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, DeepNet_t &DeepNet, size_t nSamples, size_t classLabel);
203199 Double_t ComputeLoss (TTensorDataLoader<TensorInput, Architecture_t> &generalDataloader, DeepNet_t &DeepNet);
204200 Double_t ComputeLoss (TTensorDataLoader<TMVAInput_t, Architecture_t> &generalDataloader, DeepNet_t &DeepNet);
205- void CreateDiscriminatorFakeData (std::vector<TMatrixT<Double_t>> &predTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, TTensorDataLoader<TensorInput, Architecture_t> &trainingData, DeepNet_t &genDeepNet, DeepNet_t &disDeepNet, size_t nSamples, size_t classLabel);
206- void CombineGAN (DeepNet_t &combinedDeepNet, DeepNet_t &generatorNet, DeepNet_t &discriminatorNet);
207-
208- // void AddWeightsXMLToGAN(std::unique_ptr<DeepNetImpl_t> & fNet, void * parent);
201+ void CreateDiscriminatorFakeData (std::vector<TMatrixT<Double_t>> &predTensor, TMatrixT<Double_t> &outputMatrix, TMatrixT<Double_t> &weights, TTensorDataLoader<TensorInput, Architecture_t> &trainingData, DeepNet_t &genDeepNet, DeepNet_t &disDeepNet, EOutputFunction outputFunction, size_t nSamples, size_t classLabel, size_t epochs);
202+ void CombineGAN (DeepNet_t &combinedDeepNet, DeepNet_t &generatorNet, DeepNet_t &discriminatorNet, std::unique_ptr<DeepNetImpl_t> & combinedNet);
203+ void SetDiscriminatorLayerTraining (DeepNet_t &discrimatorNet);
209204
210205 /* ! Methods for writing and reading weights */
211206 using MethodBase::ReadWeightsFromStream;
212207 void AddWeightsXMLTo (void *parent) const ;
208+ void AddWeightsXMLToGenerator (void *parent) const ;
209+ void AddWeightsXMLToDiscriminator (void *parent) const ;
213210 void ReadWeightsFromXML (void *wghtnode);
211+ void ReadWeightsFromXMLGenerator (void *rootXML);
212+ void ReadWeightsFromXMLDiscriminator (void *rootXML);
214213 void ReadWeightsFromStream (std::istream &);
215214
216215 /* Create ranking */
0 commit comments