LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
data_reader_csv.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
25 //
26 // data_reader_csv.hpp / data_reader_csv.cpp - generic_data_reader class for CSV files
28 
29 #ifndef LBANN_DATA_READER_CSV_HPP
30 #define LBANN_DATA_READER_CSV_HPP
31 
32 #include "data_reader.hpp"
33 #include <unordered_map>
34 
35 namespace lbann {
36 
45 {
46 public:
/** Construct a CSV reader.
 *  NOTE(review): @p shuffle is presumably forwarded to the
 *  generic_data_reader base to control sample shuffling; confirm in
 *  data_reader_csv.cpp. */
50  csv_reader(bool shuffle = true);
/** Copy constructor, defined out of line.
 *  NOTE(review): presumably out of line because the per-thread stream
 *  pointers in m_ifstreams need special handling; confirm. */
51  csv_reader(const csv_reader&);
/** Destructor, defined out of line.
 *  NOTE(review): likely releases the raw std::ifstream pointers held in
 *  m_ifstreams; confirm in the .cpp. */
53  ~csv_reader() override;
54 
/** Polymorphic copy: returns a newly heap-allocated copy made with the
 *  copy constructor. The returned raw pointer must be deleted by the
 *  caller. */
55  csv_reader* copy() const override { return new csv_reader(*this); }
56 
/** Return this reader's type name, "csv_reader". */
57  std::string get_type() const override { return "csv_reader"; }
58 
60  void set_label_col(int col) { m_label_col = col; }
62  void set_response_col(int col) { m_response_col = col; }
64  void disable_labels(bool b = true)
65  {
67  m_disable_labels = b;
68  }
70  void enable_responses(bool b = false)
71  {
74  }
76  void set_separator(char sep) { m_separator = sep; }
78  void set_skip_cols(int cols) { m_skip_cols = cols; }
80  void set_skip_rows(int rows) { m_skip_rows = rows; }
82  void set_has_header(bool b) { m_has_header = b; }
83 
90  void set_column_transform(int col,
91  std::function<DataType(const std::string&)> f)
92  {
93  m_col_transforms[col] = f;
94  }
95 
100  void set_label_transform(std::function<int(const std::string&)> f)
101  {
102  m_label_transform = f;
103  }
107  void set_response_transform(std::function<DataType(const std::string&)> f)
108  {
110  }
111 
/** Load the dataset; implemented in data_reader_csv.cpp.
 *  NOTE(review): presumably scans the input to fill m_index, m_num_cols,
 *  m_num_samples and the stored labels/responses; confirm there. */
115  void load() override;
116 
/** Prepare the reader for @p num_io_threads I/O threads.
 *  @p io_thread_pool is an observer_ptr (a plain-pointer alias), so the
 *  pool is not owned by the reader.
 *  NOTE(review): presumably also initializes the per-thread m_ifstreams
 *  via setup_ifstreams(); confirm in the .cpp. */
117  void setup(int num_io_threads,
118  observer_ptr<thread_pool> io_thread_pool) override;
119 
120  int get_num_labels() const override { return m_num_labels; }
121  int get_linearized_data_size() const override
122  {
123  // Account for label and skipped columns.
124  if (m_label_col < m_skip_cols) {
125  return m_num_cols - m_skip_cols;
126  }
127  else {
128  return m_num_cols - 1 - m_skip_cols;
129  }
130  }
131  int get_linearized_label_size() const override { return m_num_labels; }
132  const std::vector<El::Int> get_data_dims() const override
133  {
134  return {get_linearized_data_size()};
135  }
136 
/** Fetch the parsed row for sample @p data_id as DataType values.
 *  NOTE(review): judging by the name, unlike fetch_line() the result
 *  presumably keeps the label/response columns; confirm in the .cpp. */
143  std::vector<DataType> fetch_line_label_response(int data_id);
144 
145 protected:
/** Fetch the features of sample @p data_id into the CPU matrix @p X at
 *  minibatch index @p mb_idx.
 *  NOTE(review): the meaning of the bool result is defined by the
 *  generic_data_reader base; confirm there. */
150  bool fetch_datum(CPUMat& X, int data_id, int mb_idx) override;
/** Fetch the label associated with @p data_id into @p Y at @p mb_idx. */
152  bool fetch_label(CPUMat& Y, int data_id, int mb_idx) override;
/** Fetch the response associated with @p data_id into @p Y at @p mb_idx. */
154  bool fetch_response(CPUMat& Y, int data_id, int mb_idx) override;
155 
/** Parse the row for sample @p data_id into DataType values.
 *  NOTE(review): presumably yields get_linearized_data_size() values
 *  (skipped and label columns excluded); confirm in the .cpp. */
159  std::vector<DataType> fetch_line(int data_id);
160 
/** Skip @p rows rows in the input stream @p s. */
162  void skip_rows(std::ifstream& s, int rows);
163 
/** Initialize the ifstreams vector (m_ifstreams). */
165  void setup_ifstreams();
166 
/** Return the unparsed text of the row for sample @p data_id.
 *  NOTE(review): presumably seeks using the offsets in m_index; confirm. */
170  std::string fetch_raw_line(int data_id);
171 
/** Character that separates columns (default ','). */
173 char m_separator = ',';
/** Number of columns (from the left) to skip. */
175 int m_skip_cols = 0;
/** Number of rows (from the top) to skip. */
177 int m_skip_rows = 0;
/** Whether the CSV file has a header. */
179 bool m_has_header = true;
/** Column containing labels; initially -1.
 *  NOTE(review): how -1 is interpreted is decided outside this header
 *  (likely in load()); confirm. */
185 int m_label_col = -1;
/** Column containing responses; functions the same as the label column. */
187 int m_response_col = -1;
/** Whether fetching labels is disabled (false by default). */
189 bool m_disable_labels = false;
/** Whether fetching responses is disabled (true by default). */
191 bool m_disable_responses = true;
/** Number of columns (including the label column and skipped columns). */
193 int m_num_cols = 0;
/** Number of samples. */
195 int m_num_samples = 0;
/** Number of label classes. */
197 int m_num_labels = 0;
/** Input file streams (per-thread), held as raw pointers.
 *  NOTE(review): ownership is manual; presumably they are freed in
 *  ~csv_reader. Confirm. */
199 std::vector<std::ifstream*> m_ifstreams;
/** Stream positions, presumably one per sample row, used to seek
 *  directly to a row. NOTE(review): confirm in the .cpp. */
205 std::vector<std::streampos> m_index;
/** Stored labels. */
207 std::vector<int> m_labels;
/** Stored responses. */
209 std::vector<DataType> m_responses;
211  std::unordered_map<int, std::function<DataType(const std::string&)>>
/** Label transform: converts a label cell's text to an int.
 *  The default uses std::stoi, which throws std::invalid_argument or
 *  std::out_of_range on text it cannot convert. */
214 std::function<int(const std::string&)> m_label_transform =
215  [](const std::string& s) -> int { return std::stoi(s); };
/** Response transform: converts a response cell's text to a DataType.
 *  The default parses with std::stod and then converts the resulting
 *  double to DataType. */
217 std::function<DataType(const std::string&)> m_response_transform =
218  [](const std::string& s) -> DataType { return std::stod(s); };
219 };
220 
221 } // namespace lbann
222 
223 #endif // LBANN_DATA_READER_CSV_HPP
std::vector< std::streampos > m_index
void set_has_header(bool b)
Set whether the CSV file has a header; default true.
std::vector< std::ifstream * > m_ifstreams
Input file streams (per-thread).
std::function< DataType(const std::string &)> m_response_transform
Response transform function that converts to a DataType.
int m_response_col
Column containing responses; functions the same as the label column.
std::map< data_field_type, bool > m_supported_input_types
Holds a true value for each input data type that is supported. Use an ordered map so that checkpoints...
int m_num_labels
Number of label classes.
csv_reader & operator=(const csv_reader &)
#define INPUT_DATA_TYPE_LABELS
void set_column_transform(int col, std::function< DataType(const std::string &)> f)
std::function< int(const std::string &)> m_label_transform
Label transform function that converts to an int.
void enable_responses(bool b=false)
Enable fetching responses (disabled by default).
void set_response_transform(std::function< DataType(const std::string &)> f)
bool m_has_header
Whether the CSV file has a header.
std::vector< DataType > fetch_line(int data_id)
int get_linearized_label_size() const override
Get the linearized size (i.e. number of elements) in a label.
std::unordered_map< int, std::function< DataType(const std::string &)> > m_col_transforms
Per-column transformation functions.
int m_num_samples
Number of samples.
void set_separator(char sep)
Set the column separator (default is ',').
void setup(int num_io_threads, observer_ptr< thread_pool > io_thread_pool) override
~csv_reader() override
bool fetch_label(CPUMat &Y, int data_id, int mb_idx) override
Fetch the label associated with data_id.
El::Matrix< DataType, El::Device::CPU > CPUMat
Definition: base.hpp:116
std::vector< DataType > fetch_line_label_response(int data_id)
typename std::add_pointer< T >::type observer_ptr
Creating an observer_ptr to complement the unique_ptr and shared_ptr.
Definition: base.hpp:54
void set_label_col(int col)
Set the label column.
bool m_disable_responses
Whether fetching responses is disabled (true by default).
csv_reader * copy() const override
std::string fetch_raw_line(int data_id)
int get_num_labels() const override
Return the number of labels (classes) in this dataset.
void skip_rows(std::ifstream &s, int rows)
Skip rows in an ifstream.
int m_skip_cols
Number of columns (from the left) to skip.
#define INPUT_DATA_TYPE_RESPONSES
int get_linearized_data_size() const override
Get the linearized size (i.e. number of elements) in a sample.
void set_skip_cols(int cols)
Set the number of columns (from the left) to skip; default 0.
bool m_disable_labels
Whether fetching labels is disabled (false by default).
csv_reader(bool shuffle=true)
int m_num_cols
Number of columns (including the label column and skipped columns).
void setup_ifstreams()
Initialize the ifstreams vector.
bool fetch_response(CPUMat &Y, int data_id, int mb_idx) override
Fetch the response associated with data_id.
const std::vector< El::Int > get_data_dims() const override
Get the dimensions of the data.
std::vector< DataType > m_responses
Store responses.
void set_label_transform(std::function< int(const std::string &)> f)
void disable_labels(bool b=true)
Disable fetching labels.
void set_skip_rows(int rows)
Set the number of rows (from the top) to skip; default 0.
std::string get_type() const override
char m_separator
Character that separates columns.
std::vector< int > m_labels
Store labels.
int m_skip_rows
Number of rows to skip.
void set_response_col(int col)
Set the response column.
bool fetch_datum(CPUMat &X, int data_id, int mb_idx) override
void load() override