From 4c714395994ea17c49f24a807285ad9488ac9cec Mon Sep 17 00:00:00 2001
From: siranipour <si292@cam.ac.uk>
Date: Mon, 14 Sep 2020 13:39:57 +0100
Subject: [PATCH] Writing pseudodata to file

---
 nnpdfcpp/src/common/inc/nnpdfsettings.h  |  1 +
 nnpdfcpp/src/common/src/nnpdfsettings.cc |  8 +++++
 nnpdfcpp/src/nnfit/inc/nnfit.h           |  5 +--
 nnpdfcpp/src/nnfit/src/nnfit.cc          | 43 +++++++++++++++++++++---
 4 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/nnpdfcpp/src/common/inc/nnpdfsettings.h b/nnpdfcpp/src/common/inc/nnpdfsettings.h
index ad93d484a1..8f36deb2d1 100644
--- a/nnpdfcpp/src/common/inc/nnpdfsettings.h
+++ b/nnpdfcpp/src/common/inc/nnpdfsettings.h
@@ -137,6 +137,7 @@ public:
   FlMutProperty  const& GetFlMutProp(int i) const { return fFlMutProperty[i]; }
   vector<int> const& GetArch() const { return fArch; }
   map<string,string> const& GetTheoryMap() const { return fTheory; }
+  bool SavePseudodata() const;
   bool IsQED() const;
   bool IsIC()  const;
   bool IsThUncertainties() const { return fThUncertainties; }
diff --git a/nnpdfcpp/src/common/src/nnpdfsettings.cc b/nnpdfcpp/src/common/src/nnpdfsettings.cc
index 0ca0af96af..dedc6d8725 100644
--- a/nnpdfcpp/src/common/src/nnpdfsettings.cc
+++ b/nnpdfcpp/src/common/src/nnpdfsettings.cc
@@ -600,6 +600,14 @@ void NNPDFSettings::SetPlotFile(string const& plotfile)
   }
 }
 
+bool NNPDFSettings::SavePseudodata() const
+{
+  // Optional runcard key: when fitting::savepseudodata is absent the answer is false.
+  const bool requested = NNPDFSettings::Exists("fitting", "savepseudodata");
+  return requested ? Get("fitting", "savepseudodata").as<bool>() : false;
+}
+
+
 bool NNPDFSettings::IsQED() const
 {
   const basisType isqed = NNPDFSettings::getFitBasisType(Get("fitting","fitbasis").as<string>());
diff --git a/nnpdfcpp/src/nnfit/inc/nnfit.h b/nnpdfcpp/src/nnfit/inc/nnfit.h
index 799154e02f..c254b15897 100644
--- a/nnpdfcpp/src/nnfit/inc/nnfit.h
+++ b/nnpdfcpp/src/nnfit/inc/nnfit.h
@@ -48,9 +48,10 @@ void CreateResultsFolder(const NNPDFSettings &settings, const int replica)
 void LoadAllDataAndSplit(NNPDFSettings const& settings,
                          vector<Experiment*> & training,
                          vector<Experiment*> & validation,
-                         vector<PositivitySet> & pos);
+                         vector<PositivitySet> & pos,
+                         int const& replica);
 
-void TrainValidSplit(const NNPDFSettings &settings, Experiment* const& exp, Experiment *&tr, Experiment *&val);
+void TrainValidSplit(const NNPDFSettings &settings, Experiment* const& exp, Experiment *&tr, Experiment *&val, int const& replica);
 
 
 // Add chi^2 results to fit log
diff --git a/nnpdfcpp/src/nnfit/src/nnfit.cc b/nnpdfcpp/src/nnfit/src/nnfit.cc
index fb068b253e..5ca7b3cbc4 100644
--- a/nnpdfcpp/src/nnfit/src/nnfit.cc
+++ b/nnpdfcpp/src/nnfit/src/nnfit.cc
@@ -117,7 +117,7 @@ int main(int argc, char **argv)
       vector<Experiment*> training;
       vector<Experiment*> validation;
       vector<PositivitySet> pos;
-      LoadAllDataAndSplit(settings, training, validation, pos);
+      LoadAllDataAndSplit(settings, training, validation, pos, replica);
 
       // Fit Basis
       std::unique_ptr<FitBasis> fitbasis(getFitBasis(settings, NNPDFSettings::getFitBasisType(settings.Get("fitting","fitbasis").as<string>()), replica));
@@ -385,7 +385,8 @@ int main(int argc, char **argv)
 void LoadAllDataAndSplit(NNPDFSettings const& settings,
                          vector<Experiment*> & training,
                          vector<Experiment*> & validation,
-                         vector<PositivitySet> & pos)
+                         vector<PositivitySet> & pos,
+                         int const& replica)
 {
   auto T0Set = std::make_unique<LHAPDFSet>(settings.Get("datacuts","t0pdfset").as<string>(), PDFSet::erType::ER_MCT0);
   for (int i = 0; i < settings.GetNExp(); i++)
@@ -418,7 +419,7 @@ void LoadAllDataAndSplit(NNPDFSettings const& settings,
       training.push_back(NULL);
       validation.push_back(NULL);
 
-      TrainValidSplit(settings, exp.get(), training.back(), validation.back());
+      TrainValidSplit(settings, exp.get(), training.back(), validation.back(), replica);
     }
 
   // Read Positivity Sets
@@ -505,13 +506,16 @@ void LogChi2(const FitPDFSet* pdf,
 }
 
 void TrainValidSplit(NNPDFSettings const& settings,
-                     Experiment* const& exp, Experiment* &tr, Experiment* &val)
+                     Experiment* const& exp, Experiment* &tr, Experiment* &val,
+                     int const& replica)
 {
   vector<DataSet> trainingSets;
   vector<DataSet> validationSets;
 
   vector<int> trCovMatMask(0);
   vector<int> valCovMatMask(0);
+  // Vectors collecting the training and validation masks, one entry per set
+  vector<vector<int>> trMasks, valMasks;
   int AccumulatedData = 0;
 
   int expValSize = 0; // size of validation experiment
@@ -535,6 +539,10 @@ void TrainValidSplit(NNPDFSettings const& settings,
       std::sort(trMaskset.begin(), trMaskset.end());
       std::sort(valMaskset.begin(), valMaskset.end());
 
+
+      trMasks.push_back(trMaskset);
+      valMasks.push_back(valMaskset);
+
       if (settings.IsThUncertainties())
       {
         /*
@@ -574,6 +582,33 @@ void TrainValidSplit(NNPDFSettings const& settings,
   if (expValSize != 0)
       val = new Experiment(*exp, validationSets);
 
+
+  // Save the pseudodata if requested in the runcard: one tab-separated row
+  // (experiment, dataset, mask entry, data value) is appended per data point.
+  if(settings.SavePseudodata())
+  {
+    const std::string replica_dir = settings.GetResultsDirectory() + "/nnfit/replica_" + std::to_string(replica);
+    std::ofstream training_file(replica_dir + "/training.dat", std::ios_base::app);
+    std::ofstream validation_file(replica_dir + "/validation.dat", std::ios_base::app);
+    // val is only allocated above when expValSize != 0 (and the caller passes both
+    // pointers in as NULL), so guard each pointer before dereferencing it.
+    for(int i=0; tr != nullptr && i < tr->GetNSet(); ++i)
+    {
+      auto ds = tr->GetSet(i);
+      // NOTE(review): assumes trMasks[i] is the mask of the i-th set kept in tr -- confirm
+      vector<int> const& tr_mask = trMasks[i];
+      for(int j=0; j < ds.GetNData(); ++j)
+        training_file << tr->GetExpName() << "\t" << ds.GetSetName() << "\t" << tr_mask[j] << "\t" << ds.GetData(j) << "\n";
+    }
+    for(int i=0; val != nullptr && i < val->GetNSet(); ++i)
+    {
+      auto ds = val->GetSet(i);
+      // NOTE(review): same assumption for valMasks[i] and the i-th set kept in val -- confirm
+      vector<int> const& val_mask = valMasks[i];
+      for(int j=0; j < ds.GetNData(); ++j)
+        validation_file << val->GetExpName() << "\t" << ds.GetSetName() << "\t" << val_mask[j] << "\t" << ds.GetData(j) << "\n";
+    }
+  }
+
   // read covmat from file if specified in the runcard
   if (settings.IsThUncertainties())
   {
-- 
GitLab