LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
TestHelpers.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2019, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #ifndef LBANN_UNIT_TEST_UTILITIES_TEST_HELPERS_HPP_INCLUDED
27 #define LBANN_UNIT_TEST_UTILITIES_TEST_HELPERS_HPP_INCLUDED
28 
32 #include <lbann/utils/options.hpp>
33 
34 #include <memory>
35 
36 namespace unit_test {
37 namespace utilities {
38 
/// @brief Report whether a std::unique_ptr currently owns an object.
/// @param ptr The smart pointer to inspect (not modified).
/// @return true when @p ptr holds a non-null pointer, false otherwise.
template <typename Pointee>
bool IsValidPtr(std::unique_ptr<Pointee> const& ptr) noexcept
{
  return ptr != nullptr;
}
44 
/// @brief Report whether a std::shared_ptr currently points at an object.
/// @param ptr The smart pointer to inspect (not modified).
/// @return true when the stored pointer of @p ptr is non-null,
///         false otherwise.
template <typename Pointee>
bool IsValidPtr(std::shared_ptr<Pointee> const& ptr) noexcept
{
  return ptr != nullptr;
}
50 
/// @brief Report whether a raw (non-owning) pointer is non-null.
/// @param ptr The pointer to inspect; the pointee is never accessed.
/// @return true when @p ptr is not nullptr, false otherwise.
template <typename Pointee>
bool IsValidPtr(Pointee const* ptr) noexcept
{
  return ptr != nullptr;
}
56 
64 {
 // Reset the process-wide LBANN argument parser: fetch the global
 // instance and clear its current state.
 // NOTE(review): listing line 67 is absent from this view. The
 // cross-reference index mentions construct_all_options(), which may be
 // called at that point to restore the default option set -- confirm
 // against the actual header before relying on "default condition".
65  auto& arg_parser = lbann::global_argument_parser();
66  arg_parser.clear();
 // Hand back the same global instance by reference (no copy is made).
68  return arg_parser;
69 }
70 
71 inline void mock_data_reader(lbann::trainer& trainer,
72  const std::vector<El::Int>& sample_size,
73  int num_classes)
74 {
76  auto& md_dims = md.data_dims;
77  md_dims[lbann::data_reader_target_mode::CLASSIFICATION] = {num_classes};
78  md_dims[lbann::data_reader_target_mode::INPUT] = sample_size;
79 
80  // Set up the data reader in the data coordinator
81  // TODO: This is a bit awkward and can be better refactored
82  auto& dc = trainer.get_data_coordinator();
83  dc.set_mock_dr_metadata(md);
84 }
85 
86 } // namespace utilities
87 } // namespace unit_test
88 #endif // LBANN_UNIT_TEST_UTILITIES_TEST_HELPERS_HPP_INCLUDED
Basic argument parsing with automatic help messages.
bool IsValidPtr(std::unique_ptr< T > const &ptr) noexcept
Definition: TestHelpers.hpp:40
lbann::default_arg_parser_type & reset_global_argument_parser()
Return the global LBANN argument parser reset to its default condition.
Definition: TestHelpers.hpp:63
TargetModeDimMap data_dims
Definition: metadata.hpp:84
void construct_all_options()
void mock_data_reader(lbann::trainer &trainer, const std::vector< El::Int > &sample_size, int num_classes)
Definition: TestHelpers.hpp:71
const data_coordinator & get_data_coordinator() const
Definition: trainer.hpp:165
User-facing class that represents a set of compute resources.
Definition: trainer.hpp:60
default_arg_parser_type & global_argument_parser()
void set_mock_dr_metadata(const DataReaderMetaData &drm)
Data structure containing metadata from the data readers.
Definition: metadata.hpp:82