LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
sample_list.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_DATA_READERS_SAMPLE_LIST_HPP
28 #define LBANN_DATA_READERS_SAMPLE_LIST_HPP
29 
30 #include <functional>
31 #include <iostream>
32 #include <string>
33 #include <vector>
34 
36 
37 namespace lbann {
38 
39 // Forward Declarations
40 class lbann_comm;
41 
// String constants naming the sample-list variants.
// NOTE(review): presumably these are the keywords compared against the first
// header line by set_sample_list_type() — confirm in the implementation.
// They are declared `static` at namespace scope in a header, so each
// translation unit that includes this file gets its own copy.
42 static const std::string multi_sample_exclusion = "MULTI-SAMPLE_EXCLUSION";
43 static const std::string multi_sample_inclusion = "MULTI-SAMPLE_INCLUSION";
44 static const std::string single_sample = "SINGLE-SAMPLE";
45 static const std::string multi_sample_inclusion_v2 =
46  "MULTI-SAMPLE_INCLUSION_V2";
47 static const std::string conduit_hdf5_exclusion = "CONDUIT_HDF5_EXCLUSION";
48 static const std::string conduit_hdf5_inclusion = "CONDUIT_HDF5_INCLUSION";
49 
// Body of the sample list header record.
// NOTE(review): the line naming this type (listing line 50) is missing from
// this extract; it is presumably `sample_list_header`, the type returned by
// sample_list::get_header() — confirm against the original header.
// NOTE(review): listing lines 52-63 are also missing. The reference entries
// for this file place m_is_multi_sample (53), m_is_exclusive (55),
// m_no_label_header (57), m_has_unused_sample_fields (59),
// m_included_sample_count (61) and m_excluded_sample_count (63) there.
51 {
// NOTE(review): no description survives for m_num_files; presumably the
// number of data files named by the list — confirm.
64  size_t m_num_files;
// Data file directory.
66  std::string m_file_dir;
// Name recorded via set_sample_list_name().
67  std::string m_sample_list_name;
68  std::string m_label_filename;
69 
71 
// Setters taking one text line each; the parameter names (line1..line4)
// suggest they consume the first four header lines in order — TODO confirm.
72  void set_sample_list_type(const std::string& line1);
73  void set_sample_count(const std::string& line2);
74  void set_data_file_dir(const std::string& line3);
75  void set_label_filename(const std::string& line4);
76 
// Read-only accessors; none of them modifies the header (all are const).
77  bool is_multi_sample() const;
78  bool is_exclusive() const;
79  bool use_label_header() const;
80  bool has_unused_sample_fields() const;
81  size_t get_sample_count() const;
82  size_t get_num_files() const;
83  const std::string& get_file_dir() const;
84  const std::string& get_sample_list_name() const;
// NOTE(review): the reference text "Save the filename or stream name of this
// sample list for debugging" is attached to a set_sample_list_name(n)
// signature; it may describe this member or the one in sample_list — confirm.
86  void set_sample_list_name(const std::string& n);
87  const std::string& get_label_filename() const;
// Serialization hook templated on the archive type.
88  template <class Archive>
89  void serialize(Archive& ar);
90 };
91 
// Sample list container, templated on the type used for sample names.
// NOTE(review): the class-head line (listing line 93) is missing from this
// extract; the constructors below name the class `sample_list`.
92 template <typename sample_name_t>
94 {
95 public:
// Alias exposing the sample-name template argument.
96  using name_t = sample_name_t;
// The type for the index assigned to each sample file.
98  using sample_file_id_t = std::size_t;
// A sample: the pair (sample file id, sample name).
101  using sample_t = std::template pair<sample_file_id_t, sample_name_t>;
// Type for the list of samples.
103  using samples_t = std::template vector<sample_t>;
// Type for the index into the sample list.
105  using sample_idx_t = typename samples_t::size_type;
// Type for the map from sample name to the sample list index.
107  using sample_map_t = std::unordered_map<sample_name_t, sample_idx_t>;
// Mapping of the file index to the filename.
109  using file_id_stats_v_t = std::vector<std::string>;
110 
// Construction, destruction and copying. The destructor is virtual, so the
// class is usable as a polymorphic base.
111  sample_list();
112  virtual ~sample_list();
113  sample_list(const sample_list& rhs);
114  sample_list& operator=(const sample_list& rhs);
115  sample_list& copy(const sample_list& rhs);
116 
117  void copy_members(const sample_list& rhs);
118 
// Load from a stream; stride and offset default to 1 and 0.
// NOTE(review): presumably stride/offset select every stride-th entry
// starting at offset — confirm in the implementation.
121  void load(std::istream& istrm, size_t stride = 1, size_t offset = 0);
122 
// Overloads that take the communicator and an `interleave` flag instead of
// an explicit stride/offset.
127  void load(const std::string& samplelist_file,
128  const lbann_comm& comm,
129  bool interleave);
130  void load(std::istream& istrm, const lbann_comm& comm, bool interleave);
// Overload that receives an already-constructed header along with the stream.
133  void load(const sample_list_header& header,
134  std::istream& istrm,
135  const lbann_comm& comm,
136  bool interleave);
137 
// Same as the comm/interleave overloads, but the sample list text is passed
// in as a string rather than a file name or stream.
139  void load_from_string(const std::string& samplelist,
140  const lbann_comm& comm,
141  bool interleave);
142 
144  virtual size_t size() const;
145 
147  virtual size_t get_num_files() const;
148 
150  bool empty() const;
151 
// Serialization hook templated on the archive type.
153  template <class Archive>
154  void serialize(Archive& ar);
155 
// Writes a textual form into sstr; the bool result is presumably a success
// flag — TODO confirm.
157  virtual bool to_string(std::string& sstr) const;
158 
// Note: filename is taken by value (a const copy), not by reference.
160  void write(const std::string filename) const;
161 
// Read-only access to the sample container.
163  const samples_t& get_list() const;
164 
// Read-only access to the header info of the sample list.
166  const sample_list_header& get_header() const;
167 
// Read-only access to one sample by index into the sample list.
169  const sample_t& operator[](size_t idx) const;
170 
171  virtual const std::string& get_samples_filename(sample_file_id_t id) const;
172 
173  const std::string& get_samples_dirname() const;
174  const std::string& get_label_filename() const;
175 
// Collective helpers taking a non-const communicator.
// NOTE(review): how all_gather_archive and all_gather_archive_new differ is
// not visible here — check the implementation before choosing one.
176  void all_gather_archive(const std::string& archive,
177  std::vector<std::string>& gathered_archive,
178  lbann_comm& comm);
179  void all_gather_archive_new(const std::string& archive,
180  std::vector<std::string>& gathered_archive,
181  lbann_comm& comm);
182 
183  template <typename T>
184  size_t
185  all_gather_field(T data, std::vector<T>& gathered_data, lbann_comm& comm);
186  virtual void all_gather_packed_lists(lbann_comm& comm);
187 
// NOTE(review): presumably sets m_keep_order ("maintain the original sample
// order as listed in the file") — confirm.
189  void keep_sample_order(bool keep);
190 
// NOTE(review): the reference text "Save the filename or stream name of this
// sample list for debugging" is attached to this signature, which also
// exists on the header type — confirm which one it documents.
193  void set_sample_list_name(const std::string& n);
194 
// NOTE(review): presumably these switch m_check_data_file ("whether to check
// the existence of data file") on and off — confirm.
196  void set_data_file_check();
198  void unset_data_file_check();
199 
// NOTE(review): presumably these populate and empty m_map_name_to_idx, the
// map from sample name to the corresponding index into the sample list.
201  void build_sample_map_from_name_to_index();
202 
204  void clear_sample_map_from_name_to_index();
205 
// Looks up the sample list index for a sample name. Non-const member.
207  sample_idx_t get_sample_index(const sample_name_t& sn);
208 
209 protected:
// Header parsing helpers.
212  std::string read_header_line(std::istream& ifs,
213  const std::string& listname,
214  const std::string& info);
215 
217  void read_header(std::istream& istrm);
218 
// Virtual hook for reading the list body; same stride/offset defaults as the
// public stream load().
221  virtual void
222  read_sample_list(std::istream& istrm, size_t stride = 1, size_t offset = 0);
223 
226  virtual void assign_samples_name();
227 
229  size_t get_samples_per_file(std::istream& istrm,
230  size_t stride = 1,
231  size_t offset = 0);
232 
234  void write_header(std::string& sstr, size_t num_files) const;
235 
// Reports counts through the three reference out-parameters.
237  virtual void
238  get_num_samples(size_t& total, size_t& included, size_t& excluded) const;
239 
240  virtual void set_samples_filename(sample_file_id_t id,
241  const std::string& filename);
242 
244  virtual void reorder();
245 
246 protected:
// NOTE(review): several data-member declarations are missing from this
// extract (gaps around listing lines 247-268). The reference entries for this
// file describe m_header (header info of sample list), m_sample_list (list of
// all samples with a file identifier and sample name for each sample),
// m_keep_order, m_check_data_file, m_file_id_stats_map and m_map_name_to_idx;
// their exact positions and access levels cannot be read from this listing.
249 
// The stride used in loading sample list file.
251  size_t m_stride;
252 
255 
258 
261 
264 
265 private:
268 };
269 
// Free function taking an MPI error code; defined elsewhere.
// NOTE(review): presumably reports or raises when ierr signals a failure —
// confirm against the definition.
270 void handle_mpi_error(int ierr);
271 
// Function template returning a value of type T; only the declaration is
// visible here.
// NOTE(review): presumably yields the placeholder value for a sample name
// that has not been assigned, with specializations per name type — confirm.
272 template <typename T>
273 inline T uninitialized_sample_name();
274 
275 } // namespace lbann
276 
277 #endif // LBANN_DATA_READERS_SAMPLE_LIST_HPP
bool m_is_multi_sample
Whether each data file includes multiple samples.
Definition: sample_list.hpp:53
void set_label_filename(const std::string &line4)
size_t m_stride
The stride used in loading sample list file.
std::size_t sample_file_id_t
The type for the index assigned to each sample file.
Definition: sample_list.hpp:98
const std::string & get_sample_list_name() const
sample_list_header m_header
header info of sample list
bool m_is_exclusive
Whether to list the IDs of samples to exclude or to include.
Definition: sample_list.hpp:55
std::template vector< sample_t > samples_t
Type for the list of samples.
static const std::string multi_sample_inclusion
Definition: sample_list.hpp:43
void write(std::ostream &os, google::protobuf::Message const &msg)
Write the protobuf message in prototext in a stream.
T uninitialized_sample_name()
std::unordered_map< long long, sample_idx_t > sample_map_t
Type for the map from sample name to the sample list index.
bool m_no_label_header
Whether to read the header line for a label file.
Definition: sample_list.hpp:57
T & data(const cnpy::NpyArray &na, const std::vector< size_t > indices)
Definition: cnpy_utils.hpp:75
sample_map_t m_map_name_to_idx
Map from sample name to the corresponding index into the sample list.
void load(std::string const &pbuf_filename, google::protobuf::Message &msg)
Fill the protobuf message from a binary file.
static const std::string conduit_hdf5_inclusion
Definition: sample_list.hpp:48
const std::string & get_file_dir() const
std::string to_string(El::Device const &d)
static const std::string conduit_hdf5_exclusion
Definition: sample_list.hpp:47
const std::string & get_label_filename() const
std::string m_sample_list_name
Definition: sample_list.hpp:67
static const std::string multi_sample_exclusion
Definition: sample_list.hpp:42
bool m_keep_order
maintain the original sample order as listed in the file
std::template pair< sample_file_id_t, long long > sample_t
size_t m_excluded_sample_count
Number of excluded samples.
Definition: sample_list.hpp:63
void set_sample_list_name(const std::string &n)
Save the filename or stream name of this sample list for debugging.
bool m_has_unused_sample_fields
Whether the sample list has fields to represent unused samples.
Definition: sample_list.hpp:59
samples_t m_sample_list
List of all samples with a file identifier and sample name for each sample.
static const std::string multi_sample_inclusion_v2
Definition: sample_list.hpp:45
std::string m_file_dir
Data file directory.
Definition: sample_list.hpp:66
void set_sample_count(const std::string &line2)
bool m_check_data_file
Whether to check the existence of data file.
void handle_mpi_error(int ierr)
file_id_stats_v_t m_file_id_stats_map
Maps sample's file id to file names, file descriptors, and use counts.
void set_data_file_dir(const std::string &line3)
std::vector< std::string > file_id_stats_v_t
Mapping of the file index to the filename.
static const std::string single_sample
Definition: sample_list.hpp:44
void set_sample_list_type(const std::string &line1)
size_t m_included_sample_count
Number of included samples.
Definition: sample_list.hpp:61
typename samples_t::size_type sample_idx_t
Type for the index into the sample list.