29 #ifndef LBANN_DATA_READER_CSV_HPP 30 #define LBANN_DATA_READER_CSV_HPP 33 #include <unordered_map> 57 std::string
get_type()
const override {
return "csv_reader"; }
91 std::function<DataType(
const std::string&)> f)
115 void load()
override;
117 void setup(
int num_io_threads,
159 std::vector<DataType>
fetch_line(
int data_id);
162 void skip_rows(std::ifstream& s,
int rows);
211 std::unordered_map<int, std::function<DataType(const std::string&)>>
215 [](
const std::string& s) ->
int {
return std::stoi(s); };
218 [](
const std::string& s) -> DataType {
return std::stod(s); };
223 #endif // LBANN_DATA_READER_CSV_HPP std::vector< std::streampos > m_index
void set_has_header(bool b)
Set whether the CSV file has a header; default true.
std::vector< std::ifstream * > m_ifstreams
Input file streams (per-thread).
std::function< DataType(const std::string &)> m_response_transform
Response transform function that converts to a DataType.
int m_response_col
Column containing responses; functions the same as the label column.
std::map< data_field_type, bool > m_supported_input_types
Holds a true value for each input data type that is supported. Use an ordered map so that checkpoints...
int m_num_labels
Number of label classes.
csv_reader & operator=(const csv_reader &)
void set_column_transform(int col, std::function< DataType(const std::string &)> f)
std::function< int(const std::string &)> m_label_transform
Label transform function that converts to an int.
void enable_responses(bool b=false)
Enable fetching responses (disabled by default).
void set_response_transform(std::function< DataType(const std::string &)> f)
bool m_has_header
Whether the CSV file has a header.
std::vector< DataType > fetch_line(int data_id)
int get_linearized_label_size() const override
Get the linearized size (i.e. number of elements) in a label.
std::unordered_map< int, std::function< DataType(const std::string &)> > m_col_transforms
Per-column transformation functions.
int m_num_samples
Number of samples.
void set_separator(char sep)
Set the column separator (default is ',').
void setup(int num_io_threads, observer_ptr< thread_pool > io_thread_pool) override
bool fetch_label(CPUMat &Y, int data_id, int mb_idx) override
Fetch the label associated with data_id.
El::Matrix< DataType, El::Device::CPU > CPUMat
std::vector< DataType > fetch_line_label_response(int data_id)
typename std::add_pointer< T >::type observer_ptr
Creating an observer_ptr to complement the unique_ptr and shared_ptr.
void set_label_col(int col)
Set the label column.
bool m_disable_responses
Whether fetching responses is disabled.
csv_reader * copy() const override
std::string fetch_raw_line(int data_id)
int get_num_labels() const override
Return the number of labels (classes) in this dataset.
void skip_rows(std::ifstream &s, int rows)
Skip rows in an ifstream.
int m_skip_cols
Number of columns (from the left) to skip.
int get_linearized_data_size() const override
Get the linearized size (i.e. number of elements) in a sample.
void set_skip_cols(int cols)
Set the number of columns (from the left) to skip; default 0.
bool m_disable_labels
Whether fetching labels is disabled.
csv_reader(bool shuffle=true)
int m_num_cols
Number of columns (including the label column and skipped columns).
void setup_ifstreams()
Initialize the ifstreams vector.
bool fetch_response(CPUMat &Y, int data_id, int mb_idx) override
Fetch the response associated with data_id.
const std::vector< El::Int > get_data_dims() const override
Get the dimensions of the data.
std::vector< DataType > m_responses
Store responses.
void set_label_transform(std::function< int(const std::string &)> f)
void disable_labels(bool b=true)
Disable fetching labels.
void set_skip_rows(int rows)
Set the number of rows (from the top) to skip; default 0.
std::string get_type() const override
char m_separator
Character that separates columns.
std::vector< int > m_labels
Store labels.
int m_skip_rows
Number of rows to skip.
void set_response_col(int col)
Set the response column.
bool fetch_datum(CPUMat &X, int data_id, int mb_idx) override