LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
data_reader_cifar10.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
25 //
26 // data_reader_cifar10 .hpp .cpp - Data reader for CIFAR-10/100
28 
29 #ifndef LBANN_DATA_READER_CIFAR10_HPP
30 #define LBANN_DATA_READER_CIFAR10_HPP
31 
32 #include "data_reader_image.hpp"
33 
34 namespace lbann {
35 
48 {
49 public:
50  cifar10_reader(bool shuffle = true);
51  cifar10_reader(const cifar10_reader&) = default;
52  cifar10_reader& operator=(const cifar10_reader&) = default;
53 
54  ~cifar10_reader() override;
55 
56  cifar10_reader* copy() const override { return new cifar10_reader(*this); }
57 
58  std::string get_type() const override { return "cifar10_reader"; }
59 
60  void set_input_params(const int, const int, const int, const int) override
61  {
62  set_defaults();
63  }
64  void load() override;
65 
66 protected:
67  void set_defaults() override;
68  bool fetch_datum(CPUMat& X, int data_id, int mb_idx) override;
69  bool fetch_label(CPUMat& Y, int data_id, int mb_idx) override;
70 
71 private:
76  std::vector<std::vector<unsigned char>> m_images;
78  std::vector<uint8_t> m_labels;
79 };
80 
81 } // namespace lbann
82 
83 #endif // LBANN_DATA_READER_CIFAR10_HPP
void set_defaults() override
cifar10_reader & operator=(const cifar10_reader &)=default
void load() override
bool fetch_datum(CPUMat &X, int data_id, int mb_idx) override
cifar10_reader(bool shuffle=true)
std::vector< uint8_t > m_labels
bool fetch_label(CPUMat &Y, int data_id, int mb_idx) override
El::Matrix< DataType, El::Device::CPU > CPUMat
Definition: base.hpp:116
std::vector< std::vector< unsigned char > > m_images
void set_input_params(const int, const int, const int, const int) override
~cifar10_reader() override
std::string get_type() const override
cifar10_reader * copy() const override