LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
slice_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_LAYER_SLICE_IMPL_HPP_INCLUDED
28 #define LBANN_LAYER_SLICE_IMPL_HPP_INCLUDED
29 
32 
33 namespace lbann {
34 
// NOTE(review): This is a Doxygen source-browser extract, not compilable
// source. Source lines 36, 38 and 49 are missing from this listing. Per the
// cross-reference notes at the end of the page, line 36 is the definition of
// `void setup_dims() override` for this layer. The `dc` object used at source
// line 50 is declared on the missing line 49. It is presumably the trainer's
// data coordinator: get_const_trainer() and get_data_coordinator() both appear
// in the cross-references. Confirm against the real file.
//
// setup_dims: configure the output tensor dimensions of this slice layer.
//  1. (Deprecated path) If m_set_slice_points_from_data_reader is set,
//     replace m_slice_points with the data reader's slice points for
//     m_var_category.
//  2. Validate m_slice_dim, the number of slice points, their ordering, and
//     that the last slice point lies within the input extent along
//     m_slice_dim.
//  3. Reject non-1-D input for the MODEL_PARALLEL layout.
//  4. Give child i an output tensor equal to the input dims, except along
//     m_slice_dim, where the extent is m_slice_points[i+1] - m_slice_points[i].
35 template <typename TensorDataType, data_layout Layout, El::Device Device>
37 {
39 
40  // Setup the slice points if they are to be established by the data reader
41  // TODO: Move this responsibility to another component (input layer)
42  if (m_set_slice_points_from_data_reader) {
43  std::vector<size_t> slice_points;
44  std::string slice_point_method_name = "'get_slice_points_from_reader'";
45 
46  LBANN_WARNING("slice_points_from_reader is deprecated and will be removed "
47  "in a future version.");
48 
// Copy the reader-provided slice points for this layer's variable category.
// NOTE(review): `.at()` presumably throws if m_var_category has no entry.
// SPModeSlicePoints is declared in metadata.hpp, which is not shown here.
50  const DataReaderMetaData& dr_metadata = dc.get_dr_metadata();
51  for (auto& slice_point : dr_metadata.slice_points.at(m_var_category)) {
52  slice_points.push_back(slice_point);
53  }
54 
// At least two points are needed to delimit even a single slice.
55  if (slice_points.size() < 2u) {
56  LBANN_ERROR(slice_point_method_name, " is not supported by the reader.");
// NOTE(review): LBANN_ERROR (exception.hpp, not shown) presumably throws;
// if so, this return is unreachable.
57  return;
58  }
59  m_slice_points = std::move(slice_points);
60  }
61 
62  // Check that slice parameters are valid
63  const auto& input_dims = this->get_input_dims();
64  const size_t num_outputs = this->get_num_children();
// The slice dimension must index an existing input dimension.
// NOTE(review): the error message reads get_parent_layers()[0], so it
// assumes at least one parent layer exists.
65  if (m_slice_dim >= input_dims.size()) {
66  std::ostringstream err;
67  err << this->get_type() << " layer \"" << this->get_name() << "\" "
68  << "is slicing along dimension " << m_slice_dim << ", "
69  << "but it has a " << input_dims.size() << "-D input tensor "
70  << "(parent layer \"" << this->get_parent_layers()[0]->get_name()
71  << "\" "
72  << "outputs with dimensions ";
73  for (size_t d = 0; d < input_dims.size(); ++d) {
74  err << (d > 0 ? " x " : "") << input_dims[d];
75  }
76  err << ")";
77  LBANN_ERROR(err.str());
78  }
// N children need at least N+1 slice points. Passing this check (assuming
// LBANN_ERROR does not return) also guarantees m_slice_points is non-empty,
// which keeps the .back() calls below safe.
79  if (m_slice_points.size() <= num_outputs) {
80  LBANN_ERROR(this->get_type(),
81  " layer \"",
82  this->get_name(),
83  "\" ",
84  "has ",
85  num_outputs,
86  " children, "
87  "but only ",
88  m_slice_points.size(),
89  " slice points");
90  }
// Slice points must be non-decreasing. std::is_sorted accepts equal
// neighbours, so zero-width slices pass this check.
91  if (!std::is_sorted(m_slice_points.begin(), m_slice_points.end())) {
92  LBANN_ERROR(this->get_type(),
93  " layer \"",
94  this->get_name(),
95  "\" ",
96  "has unsorted slice points");
97  }
// Since the points are sorted, only the last (largest) one is checked
// against the input extent along m_slice_dim.
// NOTE(review): the message renders the range as "[0 N]"; "[0, N]" was
// likely intended.
98  if (m_slice_points.back() > static_cast<size_t>(input_dims[m_slice_dim])) {
99  LBANN_ERROR(this->get_type(),
100  " layer \"",
101  this->get_name(),
102  "\" ",
103  "has a slice point of ",
104  m_slice_points.back(),
105  ", ",
106  "which is outside the expected range "
107  "[0 ",
108  input_dims[m_slice_dim],
109  "]");
110  }
111 
112  // Model-parallel implementation only supports flat data
113  if (Layout == data_layout::MODEL_PARALLEL && input_dims.size() != 1) {
114  LBANN_ERROR(this->get_type(),
115  " layer \"",
116  this->get_name(),
117  "\" ",
118  "attempted to slice along dimension ",
119  m_slice_dim,
120  ", ",
121  "but model-parallel slice layer only supports flat data");
122  }
123 
124  // Set output tensor dimensions
// Each child keeps the input dims except along m_slice_dim. Slice points
// past index num_outputs are not used by this loop.
125  auto output_dims = input_dims;
126  for (size_t i = 0; i < num_outputs; ++i) {
127  output_dims[m_slice_dim] = m_slice_points[i + 1] - m_slice_points[i];
128  this->set_output_dims(output_dims, i);
129  }
130 }
131 
132 } // namespace lbann
133 
134 #endif // LBANN_LAYER_SLICE_IMPL_HPP_INCLUDED
virtual void setup_dims()
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
DataReaderMetaData get_dr_metadata() const
void setup_dims() override
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
Definition: slice_impl.hpp:36
#define LBANN_ERROR(...)
Definition: exception.hpp:37
trainer const & get_const_trainer()
Get a const reference to the global trainer visible to this rank.
const data_coordinator & get_data_coordinator() const
Definition: trainer.hpp:165
SPModeSlicePoints slice_points
Definition: metadata.hpp:85
#define LBANN_WARNING(...)
Definition: exception.hpp:53
Data structure containing metadata from the data readers.
Definition: metadata.hpp:82