27 #ifndef LBANN_LAYER_SLICE_IMPL_HPP_INCLUDED 28 #define LBANN_LAYER_SLICE_IMPL_HPP_INCLUDED 35 template <
typename TensorDataType, data_layout Layout, El::Device Device>
// Fragment of what appears to be the slice layer setup_dims() implementation
// (the function signature is not visible in this listing). The visible code
// optionally loads slice points from data reader metadata, checks the slice
// dimension and slice points against the input dimensions, and then sets the
// dimensions of every output tensor.
// NOTE(review): the leading integers on many lines look like listing line
// numbers fused into the text, and several statements between them are
// missing - confirm against the real header before relying on this text.

// Deprecated path: the slice points come from data reader metadata and
// replace whatever is stored in m_slice_points.
42 if (m_set_slice_points_from_data_reader) {
43 std::vector<size_t> slice_points;
44 std::string slice_point_method_name =
"'get_slice_points_from_reader'";
46 LBANN_WARNING(
"slice_points_from_reader is deprecated and will be removed " 47 "in a future version.");
// Copy the slice points stored under m_var_category into the local vector.
// NOTE(review): where dr_metadata is obtained is not visible here;
// presumably from get_dr_metadata() - confirm.
51 for (
auto& slice_point : dr_metadata.
slice_points.at(m_var_category)) {
52 slice_points.push_back(slice_point);
// Output i is sized from slice points i and i+1 (see the final loop), so
// fewer than two points is reported as an error.
55 if (slice_points.size() < 2u) {
56 LBANN_ERROR(slice_point_method_name,
" is not supported by the reader.");
// Adopt the collected points as the slice points of this layer.
59 m_slice_points = std::move(slice_points);
// Input tensor dimensions, and the number of child layers, which is used as
// the number of outputs.
63 const auto& input_dims = this->get_input_dims();
64 const size_t num_outputs = this->get_num_children();
// The slice dimension must be a valid index into the input dimensions.
// Otherwise a diagnostic is assembled that names this layer, the requested
// dimension, the input rank, and the first parent layer with its dimensions.
// NOTE(review): how err is reported after being built is not visible.
65 if (m_slice_dim >= input_dims.size()) {
66 std::ostringstream err;
67 err << this->get_type() <<
" layer \"" << this->get_name() <<
"\" " 68 <<
"is slicing along dimension " << m_slice_dim <<
", " 69 <<
"but it has a " << input_dims.size() <<
"-D input tensor " 70 <<
"(parent layer \"" << this->get_parent_layers()[0]->get_name()
72 <<
"outputs with dimensions ";
// Append the input dimensions separated by " x ".
73 for (
size_t d = 0; d < input_dims.size(); ++d) {
74 err << (d > 0 ?
" x " :
"") << input_dims[d];
// Branch taken when there are not more slice points than outputs; the final
// loop reads m_slice_points[i + 1] for every output i.
// NOTE(review): the body is mostly not visible; presumably it raises an
// error that includes m_slice_points.size() - confirm.
79 if (m_slice_points.size() <= num_outputs) {
88 m_slice_points.size(),
// Slice points must be in non-decreasing order.
91 if (!std::is_sorted(m_slice_points.begin(), m_slice_points.end())) {
96 "has unsorted slice points");
// The last slice point must not exceed the input size along m_slice_dim.
98 if (m_slice_points.back() >
static_cast<size_t>(input_dims[m_slice_dim])) {
103 "has a slice point of ",
104 m_slice_points.back(),
106 "which is outside the expected range " 108 input_dims[m_slice_dim],
// NOTE(review): the condition guarding this message is not visible; the
// message itself says the model-parallel slice layer only supports flat data.
118 "attempted to slice along dimension ",
121 "but model-parallel slice layer only supports flat data");
// Every output starts from a copy of the input dimensions; along m_slice_dim
// its size is the difference between consecutive slice points.
125 auto output_dims = input_dims;
126 for (
size_t i = 0; i < num_outputs; ++i) {
127 output_dims[m_slice_dim] = m_slice_points[i + 1] - m_slice_points[i];
128 this->set_output_dims(output_dims, i);
134 #endif // LBANN_LAYER_SLICE_IMPL_HPP_INCLUDED virtual void setup_dims()
Set up tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
DataReaderMetaData get_dr_metadata() const
void setup_dims() override
Set up tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
trainer const & get_const_trainer()
Get a const reference to the global trainer visible to this rank.
const data_coordinator & get_data_coordinator() const
#define LBANN_WARNING(...)