LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
proto_common.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_PROTO_PROTO_COMMON_HPP_INCLUDED
28 #define LBANN_PROTO_PROTO_COMMON_HPP_INCLUDED
29 
31 
/** @brief Report an error if a message does not have the given field set.
 *
 *  Expands to a statement that calls `has_FIELD()` on the message and,
 *  when that returns false, invokes `LBANN_ERROR` with text naming the
 *  field and embedding the message's `DebugString()`.
 *
 *  @param MSG   Expression yielding the message object. It is
 *               parenthesized in the expansion so that expressions such
 *               as `*ptr` or `cond ? a : b` bind as a unit. It is
 *               evaluated a second time on the failure path.
 *  @param FIELD Bare field name; token-pasted onto `has_` and
 *               stringified for the error text.
 */
#define LBANN_ASSERT_MSG_HAS_FIELD(MSG, FIELD)                                 \
  do {                                                                         \
    if (!(MSG).has_##FIELD()) {                                                \
      LBANN_ERROR("No field \"" #FIELD "\" in the given message:\n{\n",        \
                  (MSG).DebugString(),                                         \
                  "\n}\n");                                                    \
    }                                                                          \
  } while (false)
40 
41 // Forward declarations of protobuf classes
42 namespace lbann_data {
43 class LbannPB;
44 class Trainer;
45 } // namespace lbann_data
46 
47 namespace lbann {
48 
61 void customize_data_readers_sample_list(const lbann_comm& comm,
62  ::lbann_data::LbannPB& p);
63 
68  lbann_comm* comm,
69  const ::lbann_data::LbannPB& p,
70  std::map<execution_mode, generic_data_reader*>& data_readers);
71 
73 void set_num_parallel_readers(const lbann_comm& comm, ::lbann_data::LbannPB& p);
74 
76 void get_cmdline_overrides(const lbann_comm& comm, ::lbann_data::LbannPB& p);
77 
79 void print_parameters(const lbann_comm& comm,
80  ::lbann_data::LbannPB& p,
81  std::vector<int>& root_random_seeds,
82  std::vector<int>& random_seeds,
83  std::vector<int>& data_seq_random_seeds);
84 
86 void save_session(const lbann_comm& comm,
87  const int argc,
88  char* const* argv,
89  ::lbann_data::LbannPB& p);
90 
92 void read_prototext_file(const std::string& fn,
93  ::lbann_data::LbannPB& pb,
94  const bool master);
95 
97 void read_prototext_string(const std::string& contents,
98  lbann_data::LbannPB& pb,
99  const bool master);
100 
102 bool write_prototext_file(const std::string& fn, ::lbann_data::LbannPB& pb);
103 
105 std::string trim(std::string const& str);
106 
// These functions are meant to be called with trimmed, nonempty strings
namespace details {

/** @brief Parse the whitespace-separated entries of a string into a vector.
 *
 *  Entries are read with stream extraction (`operator>>`) in the order
 *  they appear. Reading stops at the end of the string or at the first
 *  entry that cannot be extracted; only successfully extracted entries
 *  are stored.
 *
 *  @tparam T   Element type; must be default-constructible and
 *              stream-extractable.
 *  @param  str String to parse.
 *  @return The entries read, in order of appearance.
 */
template <typename T>
std::vector<T> parse_list_impl(std::string const& str)
{
#ifdef LBANN_HAS_GPU_FP16
  // fp16 entries are extracted as float and converted when stored.
  using ParseType =
    typename std::conditional<std::is_same<T, fp16>::value, float, T>::type;
#else
  using ParseType = T;
#endif
  ParseType entry{};
  std::vector<T> list;
  std::istringstream iss(str);
  // Test the extraction itself rather than iss.good() beforehand: an
  // entry is stored only if a value was actually read, so trailing
  // whitespace or a malformed token never appends a fabricated value.
  while (iss >> entry) {
    list.emplace_back(std::move(entry));
  }
  return list;
}

/** @brief Parse the whitespace-separated entries of a string into a set.
 *
 *  Entries are read with stream extraction (`operator>>`). Reading
 *  stops at the end of the string or at the first entry that cannot be
 *  extracted; only successfully extracted entries are stored.
 *
 *  @tparam T   Element type; must be default-constructible and
 *              stream-extractable.
 *  @param  str String to parse.
 *  @return The set of entries read.
 */
template <typename T>
std::set<T> parse_set_impl(std::string const& str)
{
#ifdef LBANN_HAS_GPU_FP16
  // fp16 entries are extracted as float and converted when stored.
  using ParseType =
    typename std::conditional<std::is_same<T, fp16>::value, float, T>::type;
#else
  using ParseType = T;
#endif
  ParseType entry{};
  std::set<T> set;
  std::istringstream iss(str);
  // Same loop condition as parse_list_impl: insert only what was
  // actually extracted.
  while (iss >> entry) {
    set.emplace(std::move(entry));
  }
  return set;
}

// TODO (trb 07/25/19): we should think about what to do about bad
// input. That is, if a user calls parse_list<int>("one two three"),
// parsing stops at the first entry that cannot be converted and the
// entries read up to that point are returned (here, an empty list)
// without any error being reported. In most cases in LBANN, I would
// guess that this will result in a logic error further down the
// codepath, but we shouldn't count on it.

} // namespace details
156 
158 template <typename T = std::string>
159 std::vector<T> parse_list(std::string const& str)
160 {
161  auto trim_str = trim(str);
162  if (trim_str.size())
163  return details::parse_list_impl<T>(trim_str);
164  return {};
165 }
166 
168 template <typename T = std::string>
169 std::set<T> parse_set(std::string const& str)
170 {
171  auto trim_str = trim(str);
172  if (trim_str.size())
173  return details::parse_set_impl<T>(trim_str);
174  return {};
175 }
176 
177 } // namespace lbann
178 
179 #endif // LBANN_PROTO_PROTO_COMMON_HPP_INCLUDED
std::set< T > parse_set(std::string const &str)
Parse a space-separated set.
std::set< T > parse_set_impl(std::string const &str)
std::vector< T > parse_list_impl(std::string const &str)
void save_session(const lbann_comm &comm, const int argc, char *const *argv, ::lbann_data::LbannPB &p)
prints prototext file, cmd line, etc to file
void get_cmdline_overrides(const lbann_comm &comm, ::lbann_data::LbannPB &p)
adjusts the values in p by querying the options db
void customize_data_readers_sample_list(const lbann_comm &comm, ::lbann_data::LbannPB &p)
Customize the name of the sample list.
void read_prototext_string(const std::string &contents, lbann_data::LbannPB &pb, const bool master)
Read prototext from a string into a protobuf message.
std::string trim(std::string const &str)
Trim leading and trailing whitespace from a string.
void print_parameters(const lbann_comm &comm, ::lbann_data::LbannPB &p, std::vector< int > &root_random_seeds, std::vector< int > &random_seeds, std::vector< int > &data_seq_random_seeds)
print various params (learn_rate, etc) to cout
bool write_prototext_file(const std::string &fn, ::lbann_data::LbannPB &pb)
Write a protobuf message into a prototext file.
void read_prototext_file(const std::string &fn, ::lbann_data::LbannPB &pb, const bool master)
Read prototext from a file into a protobuf message.
void init_data_readers(lbann_comm *comm, const ::lbann_data::LbannPB &p, std::map< execution_mode, generic_data_reader *> &data_readers)
instantiates one or more generic_data_readers and inserts them in &data_readers
void set_num_parallel_readers(const lbann_comm &comm, ::lbann_data::LbannPB &p)
adjusts the number of parallel data readers
std::vector< T > parse_list(std::string const &str)
Parse a space-separated list.