LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
argument_parser.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_UTILS_ARGUMENT_PARSER_HPP_INCLUDED
28 #define LBANN_UTILS_ARGUMENT_PARSER_HPP_INCLUDED
29 
32 
33 #include <clara.hpp>
34 
35 #include <any>
36 #include <initializer_list>
37 #include <iostream>
38 #include <sstream>
39 #include <stdexcept>
40 #include <string>
41 #include <unordered_map>
42 #include <unordered_set>
43 #include <utility>
44 
45 namespace lbann {
46 namespace utils {
47 
/** @brief Exception thrown when command-line arguments cannot be parsed.
 *
 *  Any argument accepted by a std::runtime_error constructor (a
 *  std::string, a C string, ...) is forwarded unchanged to the base.
 */
struct parse_error : public std::runtime_error
{
  /** @brief Build the exception around the message reported by what().
   *  @param message Text forwarded verbatim to std::runtime_error.
   */
  template <typename MessageT>
  parse_error(MessageT&& message)
    : std::runtime_error{std::forward<MessageT>(message)}
  {}
}; // parse_error
61 
68 {
69  void handle_error(clara::detail::InternalParseResult result,
70  clara::Parser& parser,
71  std::vector<char const*>& argv);
72 }; // struct strict_parsing
73 
81 {
82  void handle_error(clara::detail::InternalParseResult result,
83  clara::Parser& parser,
84  std::vector<char const*>& argv);
85 }; // struct allow_extra_parameters
86 
159 template <typename ErrorHandler>
160 class argument_parser : ErrorHandler
161 {
162 public:
164 
174  template <typename T>
176  {
177  public:
178  readonly_reference(T& val) noexcept : ref_(val) {}
179  T const& get() const noexcept { return ref_; }
180  operator T const&() const noexcept { return this->get(); }
181 
182  template <typename S>
183  bool operator==(S const& y) const noexcept
184  {
185  return this->get() == y;
186  }
187 
188  private:
189  T& ref_;
190  }; // class readonly_reference<T>
191 
196  struct parse_error : std::runtime_error
197  {
201  template <typename T>
202  parse_error(T&& what_arg) : std::runtime_error{std::forward<T>(what_arg)}
203  {}
204  };
205 
210  struct missing_required_arguments : std::runtime_error
211  {
218  template <typename Container>
219  missing_required_arguments(Container const& missing_args)
220  : std::runtime_error{build_what_string_(missing_args)}
221  {}
222 
223  private:
224  template <typename Container>
225  std::string build_what_string_(Container const& missing_args)
226  {
227  std::ostringstream oss;
228  oss << "The following required arguments are missing: {";
229  for (auto const& x : missing_args)
230  oss << " \"" << x << "\"";
231  oss << " }";
232  return oss.str();
233  }
234  };
235 
237 
238 public:
240 
243  argument_parser();
244 
257  argument_parser(argument_parser const&) = delete;
258 
260  argument_parser& operator=(argument_parser const&) = delete;
261 
263  argument_parser(argument_parser&&) = default;
264 
266  argument_parser& operator=(argument_parser&&) = default;
267 
269 
270 
291  add_flag(std::string const& name,
292  std::initializer_list<std::string> cli_flags,
293  std::string const& description);
294 
316  template <typename AccessPolicy>
318  add_flag(std::string const& name,
319  std::initializer_list<std::string> cli_flags,
321  std::string const& description)
322  {
323  if (env.exists() && env.template value<bool>())
324  return add_flag_impl_(name,
325  std::move(cli_flags),
326  description + "\nENV: {" + env.name() + "}",
327  true);
328  else
329  return add_flag(name,
330  std::move(cli_flags),
331  description + "\nENV: {" + env.name() + "}");
332  }
333 
356  template <typename T>
357  readonly_reference<T> add_option(std::string const& name,
358  std::initializer_list<std::string> cli_flags,
359  std::string const& description,
360  T default_value = T());
361 
388  template <typename T, typename AccessPolicy>
389  readonly_reference<T> add_option(std::string const& name,
390  std::initializer_list<std::string> cli_flags,
392  std::string const& description,
393  T default_value = T())
394  {
395  if (env.exists())
396  return add_option(name,
397  std::move(cli_flags),
398  description + "\nENV: {" + env.name() + "}",
399  env.template value<T>());
400  else
401  return add_option(name,
402  std::move(cli_flags),
403  description + "\nENV: {" + env.name() + "}",
404  std::move(default_value));
405  }
406 
425  add_option(std::string const& name,
426  std::initializer_list<std::string> cli_flags,
427  std::string const& description,
428  char const* default_value)
429  {
430  return add_option(name,
431  std::move(cli_flags),
432  description,
433  std::string(default_value));
434  }
435 
455  template <typename AccessPolicy>
457  add_option(std::string const& name,
458  std::initializer_list<std::string> cli_flags,
460  std::string const& description,
461  char const* default_value)
462  {
463  return add_option(name,
464  cli_flags,
465  std::move(env),
466  description + "\nENV: {" + env.name() + "}",
467  std::string(default_value));
468  }
469 
489  template <typename T>
490  readonly_reference<T> add_argument(std::string const& name,
491  std::string const& description,
492  T default_value = T());
493 
509  std::string const& description,
510  char const* default_value)
511  {
512  return add_argument(name, description, std::string(default_value));
513  }
514 
526  template <typename T>
527  readonly_reference<T> add_required_argument(std::string const& name,
528  std::string const& description);
529 
535  void clear() noexcept;
536 
538 
539 
552  void parse(int argc, char const* const argv[]);
553 
567  void parse_no_finalize(int argc, char const* const argv[]);
568 
578  void finalize() const;
579 
581 
582 
591  std::string get_exe_name() const noexcept;
592 
600  bool option_is_defined(std::string const& option_name) const;
601 
603  bool help_requested() const;
604 
610  template <typename T>
611  T const& get(std::string const& option_name) const;
612 
614 
615 
620  void print_help(std::ostream& stream) const;
621 
623 
624 private:
626  void init() noexcept;
627 
630  add_flag_impl_(std::string const& name,
631  std::initializer_list<std::string> cli_flags,
632  std::string const& description,
633  bool default_value);
634 
635 private:
637  std::unordered_map<std::string, std::any> params_;
639  std::unordered_set<std::string> required_;
641  clara::Parser parser_;
642 };
643 
644 template <typename ErrorHandler>
646  std::string const& option_name) const
647 {
648  return params_.count(option_name);
649 }
650 
651 template <typename ErrorHandler>
652 template <typename T>
653 inline T const&
654 argument_parser<ErrorHandler>::get(std::string const& option_name) const
655 {
656  if (!option_is_defined(option_name)) {
657  LBANN_ERROR("Invalid option: ", option_name);
658  }
659  return std::any_cast<T const&>(params_.at(option_name));
660 }
661 
662 template <typename ErrorHandler>
663 template <typename T>
665  std::string const& name,
666  std::initializer_list<std::string> cli_flags,
667  std::string const& description,
668  T default_value) -> readonly_reference<T>
669 {
670  params_[name] = std::move(default_value);
671  auto& param_ref = std::any_cast<T&>(params_[name]);
672  clara::Opt option(param_ref, name);
673  for (auto const& f : cli_flags)
674  option[f];
675  parser_ |= option(description).optional();
676  return param_ref;
677 }
678 
679 template <typename ErrorHandler>
680 template <typename T>
681 inline auto
683  std::string const& description,
684  T default_value)
686 {
687  params_[name] = std::move(default_value);
688  auto& param_ref = std::any_cast<T&>(params_[name]);
689  parser_ |= clara::Arg(param_ref, name)(description).optional();
690  return param_ref;
691 }
692 
693 template <typename ErrorHandler>
694 template <typename T>
696  std::string const& name,
697  std::string const& description) -> readonly_reference<T>
698 {
699  // Add the reference to bind to
700  params_[name] = T{};
701  auto& param_any = params_[name];
702  auto& param_ref = std::any_cast<T&>(param_any);
703 
704  required_.insert(name);
705 
706  // Make sure the required arguments are all grouped together.
707  auto iter = parser_.m_args.cbegin(), invalid = parser_.m_args.cend();
708  while (iter != invalid && !iter->isOptional())
709  ++iter;
710 
711  // Create the argument
712  auto ret = parser_.m_args.emplace(
713  iter,
714  [name, &param_ref, this](std::string const& value) {
715  auto result = clara::detail::convertInto(value, param_ref);
716  if (result)
717  required_.erase(name);
718  return result;
719  },
720  name);
721  ret->operator()(description).required();
722  return param_ref;
723 }
724 
725 template <typename ErrorHandler>
727 {
728  init();
729 }
730 
731 template <typename ErrorHandler>
733 {
734  std::unordered_map<std::string, std::any>{}.swap(params_);
735  std::unordered_set<std::string>{}.swap(required_);
736  parser_ = clara::Parser{};
737  init();
738 }
739 
740 template <typename ErrorHandler>
742 {
743  params_["print help"] = false;
744  parser_ |= clara::ExeName();
745  parser_ |= clara::Help(std::any_cast<bool&>(params_["print help"]));
746 }
747 
748 template <typename ErrorHandler>
749 void argument_parser<ErrorHandler>::parse(int argc, char const* const argv[])
750 {
751  parse_no_finalize(argc, argv);
752  finalize();
753 }
754 
755 template <typename ErrorHandler>
757  char const* const argv[])
758 {
759  std::vector<char const*> newargv(argv, argv + argc);
760  auto parse_result =
761  parser_.parse(clara::Args(newargv.size(), newargv.data()));
762 
763  if (!parse_result)
764  this->handle_error(parse_result, parser_, newargv);
765 }
766 
767 template <typename ErrorHandler>
769 {
770  if (!help_requested() && required_.size())
771  throw missing_required_arguments(required_);
772 }
773 
774 template <typename ErrorHandler>
776  std::string const& name,
777  std::initializer_list<std::string> cli_flags,
778  std::string const& description) -> readonly_reference<bool>
779 {
780  return add_flag_impl_(name, std::move(cli_flags), description, false);
781 }
782 
783 template <typename ErrorHandler>
785 {
786  return parser_.m_exeName.name();
787 }
788 
789 template <typename ErrorHandler>
791 {
792  return std::any_cast<bool>(params_.at("print help"));
793 }
794 
795 template <typename ErrorHandler>
796 void argument_parser<ErrorHandler>::print_help(std::ostream& out) const
797 {
798  out << parser_ << std::endl;
799 }
800 
801 template <typename ErrorHandler>
803  std::string const& name,
804  std::initializer_list<std::string> cli_flags,
805  std::string const& description,
806  bool default_value) -> readonly_reference<bool>
807 {
808  params_[name] = default_value;
809  auto& param_ref = std::any_cast<bool&>(params_[name]);
810  clara::Opt option(param_ref);
811  for (auto const& f : cli_flags)
812  option[f];
813  parser_ |= option(description).optional();
814  return param_ref;
815 }
816 
817 } // namespace utils
818 
820 
822 
823 } // namespace lbann
824 
826 template <typename ErrorHandler>
827 std::ostream&
828 operator<<(std::ostream& os,
830 {
831  parser.print_help(os);
832  return os;
833 }
834 
835 #endif /* LBANN_UTILS_ARGUMENT_PARSER_HPP_INCLUDED */
std::string get_exe_name() const noexcept
Get the executable name.
Basic argument parsing with automatic help messages.
bool exists() const
Test if the variable exists in the environment.
readonly_reference< bool > add_flag(std::string const &name, std::initializer_list< std::string > cli_flags, EnvVariable< AccessPolicy > env, std::string const &description)
Add a flag with environment variable override.
void parse_no_finalize(int argc, char const *const argv[])
Parse the command line arguments but do not finalize the parser.
argument_parser()
Create the parser.
readonly_reference< T > add_option(std::string const &name, std::initializer_list< std::string > cli_flags, std::string const &description, T default_value=T())
Add an additional named option.
std::exception subclass that is thrown if the parser cannot parse the arguments. ...
readonly_reference< T > add_option(std::string const &name, std::initializer_list< std::string > cli_flags, EnvVariable< AccessPolicy > env, std::string const &description, T default_value=T())
Add an additional named option.
An environment variable.
readonly_reference< bool > add_flag_impl_(std::string const &name, std::initializer_list< std::string > cli_flags, std::string const &description, bool default_value)
Implementation of add_flag.
#define LBANN_ERROR(...)
Definition: exception.hpp:37
readonly_reference< T > add_argument(std::string const &name, std::string const &description, T default_value=T())
Add an optional positional argument.
std::string build_what_string_(Container const &missing_args)
Generates nicely formatted description messages.
Definition: description.hpp:49
T const & get(std::string const &option_name) const
Get the requested value from the argument list.
std::string const & name() const noexcept
Get the name of the environment variable.
void parse(int argc, char const *const argv[])
Parse the command line arguments and finalize the arguments.
std::unordered_set< std::string > required_
Names of required arguments that have not yet been parsed; a patch around an in-progress clara limitation.
readonly_reference< bool > add_flag(std::string const &name, std::initializer_list< std::string > cli_flags, std::string const &description)
Add a flag (i.e. a boolean parameter that is "true" if given and "false" if not given).
parse_error(T &&what_arg)
Construct the exception with the string to be returned by what()
readonly_reference< std::string > add_option(std::string const &name, std::initializer_list< std::string > cli_flags, EnvVariable< AccessPolicy > env, std::string const &description, char const *default_value)
Add an additional named option; overloaded for "char const*" parameters.
std::exception subclass that is thrown if the parser cannot parse the arguments. ...
parse_error(T &&what_arg)
Construct the exception with the string to be returned by what()
void clear() noexcept
Clear all state in the parser.
bool help_requested() const
Test if help has been requested.
std::ostream & operator<<(std::ostream &os, const ParallelStrategy &ps)
Definition: layer.hpp:191
clara::Parser parser_
The underlying clara object.
readonly_reference< std::string > add_option(std::string const &name, std::initializer_list< std::string > cli_flags, std::string const &description, char const *default_value)
Add an additional named option; overloaded for "char const*" parameters.
std::exception subclass that is thrown if a required argument is not found.
void print_help(std::ostream &stream) const
Print a help string to a stream.
std::unordered_map< std::string, std::any > params_
Dictionary of arguments to their values.
void finalize() const
Assert that all required components are set properly.
bool option_is_defined(std::string const &option_name) const
Test if an option exists in the parser.
default_arg_parser_type & global_argument_parser()
void init() noexcept
Reinitialize the parser.
readonly_reference< std::string > add_argument(std::string const &name, std::string const &description, char const *default_value)
Add a positional argument; char const* overload.
void finalize(lbann_comm *comm=nullptr)
readonly_reference< T > add_required_argument(std::string const &name, std::string const &description)
Add a "required" positional argument.
A proxy class representing the current value associated with an option.
missing_required_arguments(Container const &missing_args)
Construct the exception with a list of the missing argument names.