LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
data_reader_image.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
25 //
26 // data_reader_image .hpp .cpp - generic data reader class for image dataset
28 
29 #ifndef IMAGE_DATA_READER_HPP
30 #define IMAGE_DATA_READER_HPP
31 
35 
36 namespace lbann {
38 {
39 public:
40  using img_src_t = std::string;
41  using label_t = int;
42  using sample_t = std::pair<img_src_t, label_t>;
46  using labels_t = std::vector<label_t>;
47 
48  image_data_reader(bool shuffle = true);
51 
57  virtual void set_input_params(const int width = 0,
58  const int height = 0,
59  const int num_ch = 0,
60  const int num_labels = 0);
61 
62  // dataset specific functions
63  void load() override;
64 
65  void setup(int num_io_threads,
66  observer_ptr<thread_pool> io_thread_pool) override;
67 
68  int get_num_labels() const override { return m_num_labels; }
69  virtual int get_image_width() const { return m_image_width; }
70  virtual int get_image_height() const { return m_image_height; }
71  virtual int get_image_num_channels() const { return m_image_num_channels; }
73  int get_linearized_data_size() const override
74  {
76  }
77  int get_linearized_label_size() const override { return m_num_labels; }
78  const std::vector<El::Int> get_data_dims() const override
79  {
81  }
82 
84  const sample_list_t& get_sample_list() const { return m_sample_list; }
85 
91  sample_t get_sample(const size_t idx) const;
92 
93  void do_preload_data_store() override;
94 
95  void load_conduit_node_from_file(int data_id, conduit::Node& node);
96 
97 protected:
98  void copy_members(const image_data_reader& rhs);
99 
102  virtual void set_defaults();
103  bool fetch_label(Mat& Y, int data_id, int mb_idx) override;
105 
108  void dump_sample_label_list(const std::string& dump_file_name);
110  void load_list_of_samples(const std::string filename);
112  void
113  load_list_of_samples_from_archive(const std::string& sample_list_archive);
116  void gen_list_of_samples();
118  void load_labels(std::vector<char>& preloaded_buffer);
120  void read_labels(std::istream& istrm);
122  size_t determine_num_of_samples(std::istream& istrm) const;
123 
124  std::string m_image_dir;
130 
133 
134  bool load_conduit_nodes_from_file(const std::unordered_set<int>& data_ids);
135 };
136 
137 } // namespace lbann
138 
139 #endif // IMAGE_DATA_READER_HPP
image_data_reader & operator=(const image_data_reader &)
bool load_conduit_nodes_from_file(const std::unordered_set< int > &data_ids)
int m_num_labels
number of labels
size_t determine_num_of_samples(std::istream &istrm) const
Return the number of lines in the input stream.
int get_linearized_data_size() const override
Get the total number of channel values in a sample of image(s).
sample_list_t::sample_idx_t sample_idx_t
virtual void set_defaults()
void load_conduit_node_from_file(int data_id, conduit::Node &node)
const std::vector< El::Int > get_data_dims() const override
Get the dimensions of the data.
virtual int get_image_height() const
image_data_reader(bool shuffle=true)
std::vector< label_t > labels_t
void load() override
typename std::add_pointer< T >::type observer_ptr
Creating an observer_ptr to complement the unique_ptr and shared_ptr.
Definition: base.hpp:54
void copy_members(const image_data_reader &rhs)
void do_preload_data_store() override
sample_t get_sample(const size_t idx) const
int m_image_linearized_size
linearized image size
El::Matrix< DataType, El::Device::CPU > Mat
Definition: base.hpp:185
virtual int get_image_num_channels() const
virtual void set_input_params(const int width=0, const int height=0, const int num_ch=0, const int num_labels=0)
const sample_list_t & get_sample_list() const
Allow read-only access to the entire sample list.
std::string m_image_dir
where images are stored
void load_labels(std::vector< char > &preloaded_buffer)
Load the labels for samples.
int get_num_labels() const override
Return the number of labels (classes) in this dataset.
std::pair< img_src_t, label_t > sample_t
void setup(int num_io_threads, observer_ptr< thread_pool > io_thread_pool) override
void load_list_of_samples_from_archive(const std::string &sample_list_archive)
Load the sample list from a serialized archive from another rank.
void load_list_of_samples(const std::string filename)
Rely on pre-determined list of samples.
int get_linearized_label_size() const override
Get the linearized size (i.e. number of elements) in a label.
virtual int get_image_width() const
bool fetch_label(Mat &Y, int data_id, int mb_idx) override
int m_image_num_channels
number of image channels
void dump_sample_label_list(const std::string &dump_file_name)
void read_labels(std::istream &istrm)
Read the labels from an open input stream.
typename samples_t::size_type sample_idx_t
Type for the index into the sample list.