LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
compound_data_reader.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 
27 #ifndef LBANN_GENERIC_COMPOUND_DATA_READER_HPP
28 #define LBANN_GENERIC_COMPOUND_DATA_READER_HPP
29 
30 #include "data_reader.hpp"
31 
32 #include <utility>
33 
34 namespace lbann {
35 
41 {
42 public:
43  generic_compound_data_reader(std::vector<generic_data_reader*> data_readers,
44  bool shuffle = true)
45  : generic_data_reader(shuffle), m_data_readers(std::move(data_readers))
46  {
47  if (m_data_readers.empty()) {
48  throw lbann_exception(
49  "generic_compound_data_reader: data reader list empty");
50  }
51  }
52 
54  : generic_data_reader(other)
55  {
56  for (auto&& reader : other.m_data_readers) {
57  m_data_readers.push_back(reader->copy());
58  }
59  }
// Copy-assignment body: frees the readers currently held, then stores a
// copy() of each reader held by `other`.
// NOTE(review): the visible body has no self-assignment guard. If `other`
// refers to this same object, the loop below deletes its readers and
// clear() empties the list before the copy loop runs, so the object ends
// up with no readers. Confirm against the full source and consider an
// early `if (this == &other) return *this;`.
62  {
// Release the readers currently held by this object.
64  for (auto&& reader : m_data_readers) {
65  delete reader;
66  }
67  m_data_readers.clear();
// Repopulate with copies of other's readers.
68  for (auto&& reader : other.m_data_readers) {
69  m_data_readers.push_back(reader->copy());
70  }
71  return *this;
72  }
74  {
75  for (auto&& reader : m_data_readers) {
76  delete reader;
77  }
78  }
// Pure virtual: each concrete compound reader must supply its own copy().
79  generic_compound_data_reader* copy() const override = 0;
80 
81  //************************************************************************
83  //************************************************************************
85  {
89  for (auto&& reader : m_data_readers) {
90  reader->set_execution_mode_split_fraction(m, 0);
91  }
92  }
93 
94  void set_role(std::string role) override
95  {
97  for (auto&& reader : m_data_readers) {
98  reader->set_role(role);
99  }
100  }
101 
103  std::vector<generic_data_reader*>& get_data_readers()
104  {
105  return m_data_readers;
106  }
107 
108  bool has_labels() const override
109  {
110  for (auto&& reader : m_data_readers) {
111  return reader->has_labels();
112  }
113  return false;
114  }
115 
116  bool has_responses() const override
117  {
118  for (auto&& reader : m_data_readers) {
119  return reader->has_responses();
120  }
121  return false;
122  }
123 
124  void set_has_labels(const bool b) override
125  {
126  for (auto&& reader : m_data_readers) {
127  reader->set_has_labels(b);
128  }
129  }
131  void set_has_responses(const bool b) override
132  {
133  for (auto&& reader : m_data_readers) {
134  reader->set_has_responses(b);
135  }
136  }
137 
138  //************************************************************************
139 
140 protected:
// List of readers providing data. Held as raw pointers; the delete loops
// visible in this class free them, so they are treated as owned here.
142  std::vector<generic_data_reader*> m_data_readers;
143 };
144 
145 } // namespace lbann
146 
147 #endif // LBANN_GENERIC_COMPOUND_DATA_READER_HPP
void set_has_responses(const bool b) override
Whether or not a data reader has a response field.
generic_compound_data_reader(const generic_compound_data_reader &other)
virtual void set_role(std::string role)
generic_compound_data_reader & operator=(const generic_compound_data_reader &other)
virtual void set_execution_mode_split_fraction(execution_mode m, double s)
execution_mode
Neural network execution mode.
Definition: base.hpp:229
exception lbann_exception
Definition: exception.hpp:145
void set_has_labels(const bool b) override
Whether or not a data reader has labels.
generic_data_reader & operator=(const generic_data_reader &)=default
std::vector< generic_data_reader * > & get_data_readers()
needed to support data_store_merge_samples
generic_compound_data_reader * copy() const override=0
generic_compound_data_reader(std::vector< generic_data_reader *> data_readers, bool shuffle=true)
std::vector< generic_data_reader * > m_data_readers
List of readers providing data.
void set_role(std::string role) override
void set_execution_mode_split_fraction(execution_mode m, double s) override
Apply operations to subsidiary data readers.