LBANN
0.103.0
Livermore Big Artificial Neural Network Toolkit
weights_helpers.hpp
Go to the documentation of this file.
1
// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3
// Produced at the Lawrence Livermore National Laboratory.
4
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5
// the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6
//
7
// LLNL-CODE-697807.
8
// All rights reserved.
9
//
10
// This file is part of LBANN: Livermore Big Artificial Neural Network
11
// Toolkit. For details, see http://software.llnl.gov/LBANN or
12
// https://github.com/LLNL/LBANN.
13
//
14
// Licensed under the Apache License, Version 2.0 (the "License"); you
15
// may not use this file except in compliance with the License. You may
16
// obtain a copy of the License at:
17
//
18
// http://www.apache.org/licenses/LICENSE-2.0
19
//
20
// Unless required by applicable law or agreed to in writing, software
21
// distributed under the License is distributed on an "AS IS" BASIS,
22
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23
// implied. See the License for the specific language governing
24
// permissions and limitations under the license.
26
#ifndef LBANN_WEIGHTS_WEIGHTS_HELPERS_HPP_INCLUDED
27
#define LBANN_WEIGHTS_WEIGHTS_HELPERS_HPP_INCLUDED
28
29
#include "
lbann/utils/exception.hpp
"
30
#include "
lbann/utils/typename.hpp
"
31
#include "
lbann/weights/data_type_weights.hpp
"
32
#include "
lbann/weights/weights.hpp
"
33
45
namespace
lbann
{
46
namespace
weights_details {
47
51
template
<
typename
TensorDataType>
52
struct
SafeWeightsAccessor
53
{
54
using
ValuesType
= El::AbstractDistMatrix<TensorDataType>;
55
using
DataTypeWeights
=
data_type_weights<TensorDataType>
;
56
57
static
ValuesType
&
mutable_values
(
weights
& w)
58
{
59
auto
* dtw =
dynamic_cast<
DataTypeWeights
*
>
(&w);
60
if
(!dtw)
61
LBANN_ERROR
(
"Weights object named \""
,
62
w.
get_name
(),
63
"\" does not have weights of dynamic type \""
,
64
TypeName<TensorDataType>(),
65
"\"."
);
66
return
dtw->get_values();
67
}
68
};
// class SafeWeightsAccessor
69
70
}
// namespace weights_details
71
}
// namespace lbann
72
#endif // LBANN_WEIGHTS_WEIGHTS_HELPERS_HPP_INCLUDED
lbann::data_type_weights
Definition:
l2.hpp:41
weights.hpp
LBANN_ERROR
#define LBANN_ERROR(...)
Definition:
exception.hpp:37
lbann::weights_details::SafeWeightsAccessor::mutable_values
static ValuesType & mutable_values(weights &w)
Definition:
weights_helpers.hpp:57
typename.hpp
lbann::weights
Definition:
weights/weights.hpp:100
lbann::weights_details::SafeWeightsAccessor
Ensure safe access to weights objects' data.
Definition:
weights_helpers.hpp:52
data_type_weights.hpp
exception.hpp
lbann::weights_details::SafeWeightsAccessor::ValuesType
El::AbstractDistMatrix< TensorDataType > ValuesType
Definition:
weights_helpers.hpp:54
lbann::weights::get_name
std::string get_name() const
Definition:
weights/weights.hpp:121
lbann
Definition:
callback_helpers.hpp:32
include
lbann
weights
weights_helpers.hpp
Generated on Wed Oct 11 2023 20:49:38 for LBANN by
1.8.13