LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
data_reader_mnist.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
25 //
26 // mnist_reader .hpp .cpp - data reader class for MNIST dataset
28 
29 #ifndef LBANN_DATA_READER_MNIST_HPP
30 #define LBANN_DATA_READER_MNIST_HPP
31 
32 #include "data_reader_image.hpp"
33 
34 namespace lbann {
35 
// NOTE(review): the class-head line of this listing (listing line 36) is
// missing from this extraction -- the numbering jumps from 35 to 37. Since the
// only include is "data_reader_image.hpp", it is presumably
// `class mnist_reader : public image_data_reader` -- confirm against the
// actual header.
//
// mnist_reader: data reader class for the MNIST dataset (per the file comment).
37 {
38 public:
 // NOTE(review): these two declarations are both default constructors, so any
 // construction with no arguments (e.g. `mnist_reader r;` or
 // `new mnist_reader()`) is ambiguous between `mnist_reader(bool shuffle = true)`
 // and `mnist_reader()` and fails to compile. Consider dropping the default
 // argument or one of the declarations (check which ones the .cpp defines).
39  mnist_reader(bool shuffle = true);
40  mnist_reader();
 // Compiler-generated member-wise copy construction and copy assignment.
41  mnist_reader(const mnist_reader&) = default;
42  mnist_reader& operator=(const mnist_reader&) = default;
 // Empty destructor. NOTE(review): `= default` would express this directly.
43  ~mnist_reader() override {}
 // Returns a copy made with the defaulted copy constructor and allocated with
 // `new`; presumably the caller takes ownership of the returned pointer.
44  mnist_reader* copy() const override { return new mnist_reader(*this); }
45 
 // Returns the fixed type-name string "mnist_reader".
46  std::string get_type() const override { return "mnist_reader"; }
47 
 // All four arguments are ignored; this only calls set_defaults().
 // NOTE(review): the parameter meanings come from the base-class declaration,
 // which is not shown in this listing.
48  void set_input_params(const int, const int, const int, const int) override
49  {
50  set_defaults();
51  }
52 
53  // MNIST-specific functions
 // Loads the dataset; defined out of line (not in this header).
54  void load() override;
55 
56 protected:
 // Defined out of line; presumably restores this reader's default settings --
 // confirm in the .cpp.
57  void set_defaults() override;
 // CPUMat is El::Matrix<DataType, El::Device::CPU>. These presumably write the
 // data / label of sample `data_id` into minibatch column `mb_idx` of X / Y and
 // return whether that succeeded -- confirm in the .cpp.
58  bool fetch_datum(CPUMat& X, int data_id, int mb_idx) override;
59  bool fetch_label(CPUMat& Y, int data_id, int mb_idx) override;
60 
 // NOTE(review): redundant second `protected:` label (harmless).
61 protected:
 // Image bytes held as a vector of byte vectors; presumably one inner vector
 // per sample -- confirm against load().
62  std::vector<std::vector<unsigned char>> m_image_data;
63 };
64 
65 } // namespace lbann
66 
67 #endif // LBANN_DATA_READER_MNIST_HPP
mnist_reader & operator=(const mnist_reader &)=default
void load() override
bool fetch_label(CPUMat &Y, int data_id, int mb_idx) override
void set_defaults() override
mnist_reader * copy() const override
El::Matrix< DataType, El::Device::CPU > CPUMat
Definition: base.hpp:116
void set_input_params(const int, const int, const int, const int) override
bool fetch_datum(CPUMat &X, int data_id, int mb_idx) override
std::vector< std::vector< unsigned char > > m_image_data
std::string get_type() const override