LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
weights_proxy.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 #ifndef LBANN_WEIGHTS_WEIGHTS_PROXY_HPP_INCLUDED
27 #define LBANN_WEIGHTS_WEIGHTS_PROXY_HPP_INCLUDED
28 
29 #include "lbann_config.hpp"
30 
31 #include "lbann/base.hpp"
34 
35 #if defined LBANN_DEBUG
/** @brief Debug-mode check that a pointer-like expression is non-null.
 *
 *  If the expression evaluates to null, LBANN_ERROR is invoked with a
 *  message that names the stringized expression. The argument is
 *  parenthesized so that compound expressions (e.g. `a || b`, `p + n`,
 *  `c ? x : y`) are negated as a whole rather than only their leftmost
 *  operand.
 */
#define LBANN_DEBUG_ASSERT_POINTER(ptr)                                        \
  do {                                                                         \
    if (!(ptr))                                                                \
      LBANN_ERROR("Pointer \"" #ptr "\" is null.");                            \
  } while (0)
// Debug builds: functions declared noexcept(!LBANN_IN_DEBUG_MODE) become
// noexcept(false), i.e. they are permitted to throw.
42 #define LBANN_IN_DEBUG_MODE true
43 #else
// Release builds: LBANN_DEBUG_ASSERT_POINTER expands to nothing, and
// noexcept(!LBANN_IN_DEBUG_MODE) becomes noexcept(true).
// LBANN_IN_DEBUG_MODE is #undef'd again at the end of this header.
44 #define LBANN_DEBUG_ASSERT_POINTER(ptr)
45 #define LBANN_IN_DEBUG_MODE false
46 #endif
47 
48 namespace lbann {
49 
76 template <typename TensorDataType>
78 {
80  using ValuesType = El::AbstractDistMatrix<TensorDataType>;
82  using ValuesPtrType = std::unique_ptr<ValuesType>;
83 
84 public:
86 
/** @brief Construct an empty proxy that references no master weights. */
89  WeightsProxy() = default;
90 
97  {
98  if (!w.expired()) {
99  this->setup(w);
100  }
101  }
102 
108  {
109  if (!other.master_weights_.expired()) {
110  this->setup(other.master_weights_);
111  }
112  }
113 
121  template <typename T>
123  {
124  auto ptr = other.master_weights_pointer();
125  if (!ptr.expired()) {
126  this->setup(ptr);
127  }
128  }
129 
135  WeightsProxy(WeightsProxy&& other) noexcept
136  : master_weights_{std::move(other.master_weights_)},
137  values_{std::move(other.values_)}
138  {
139  other.clear();
140  }
141 
143  ~WeightsProxy() noexcept { this->clear(); }
144 
146 
148 
152  {
153  WeightsProxy<TensorDataType>(other).swap(*this);
154  return *this;
155  }
156 
165  template <typename T>
167  {
168  WeightsProxy<TensorDataType>(other).swap(*this);
169  return *this;
170  }
171 
174  {
175  // "Move-and-swap" idiom
176  WeightsProxy<TensorDataType>(std::move(other)).swap(*this);
177  return *this;
178  }
179 
181 
182 
188  void clear() noexcept
189  {
190  master_weights_.reset();
191  values_.reset();
192  }
193 
200  void setup(ViewingWeightsPtr const& w)
201  {
202  master_weights_ = w;
203  if (master_weights_.expired()) {
204  values_.reset();
205  }
206  else {
208  }
209  }
210 
218  {
219  if (!empty()) {
220  const auto& master_values = master_weights_.lock()->get_values();
221  if (values_->Viewing()) {
222  El::LockedView(*values_,
223  dynamic_cast<const ValuesType&>(master_values));
224  }
225  else {
226  El::Copy(master_values, *values_);
227  }
228  }
229  }
230 
232 
233 
236  bool empty() const noexcept { return values_ == nullptr; }
237 
244  ValuesType const& values() const noexcept(!LBANN_IN_DEBUG_MODE)
245  {
247  return *values_;
248  }
249 
256  weights const& master_weights() const
257  {
259  return *master_weights_.lock();
260  }
261 
263  noexcept(!LBANN_IN_DEBUG_MODE)
264  {
266  return master_weights_;
267  }
268 
270 
271 
275  {
276  std::swap(master_weights_, other.master_weights_);
277  std::swap(values_, other.values_);
278  }
279 
281 
282 private:
284 
289  {
290  auto const& vals = dtw.get_values();
291  ValuesPtrType ret(vals.Construct(vals.Grid(), vals.Root()));
292  El::LockedView(*ret, vals);
293  return ret;
294  }
295 
305  template <typename OtherT>
307  {
308  return setup_values_as_copy_(w);
309  }
310 
313  {
314  if (auto dtw = dynamic_cast<data_type_weights<TensorDataType> const*>(&w))
315  return setup_values_(*dtw);
316  return setup_values_as_copy_(w);
317  }
318 
321  {
322  // In this case, w has some other dynamic type. So we need to
323  // deep-copy every time. Thus, we allocate a target for this deep
324  // copy here.
325  ValuesPtrType ret{ValuesType::Instantiate(w.get_matrix_distribution())};
326  ret->Resize(w.get_matrix_height(), w.get_matrix_width());
327  return ret;
328  }
330 private:
331  // These members should never observably differ in nullity.
333 
337 
340 
342 };
343 
344 // Conform to LBANN's scheme
345 template <typename TensorDataType>
347 
348 } // namespace lbann
349 #undef LBANN_IN_DEBUG_MODE
350 #endif // LBANN_WEIGHTS_WEIGHTS_PROXY_HPP_INCLUDED
void clear() noexcept
Restore the default state of the proxy.
WeightsProxy(WeightsProxy &&other) noexcept
Move a WeightsProxy object.
std::unique_ptr< ValuesType > ValuesPtrType
Convenience typedef for pointers to weights values.
AbsDistMatrixType & get_values() override
WeightsProxy(WeightsProxy< T > const &other)
Copy a WeightsProxy object.
bool empty() const noexcept
Check if the proxy is referencing a weights object.
void synchronize_with_master()
Synchronize the held values with the master set.
size_t get_matrix_width() const
#define LBANN_DEBUG_ASSERT_POINTER(ptr)
WeightsProxy(WeightsProxy const &other)
Copy a WeightsProxy object.
ValuesType const & values() const noexcept(!LBANN_IN_DEBUG_MODE)
Access the values.
ViewingWeightsPtr master_weights_pointer() const noexcept(!LBANN_IN_DEBUG_MODE)
El::DistData get_matrix_distribution() const
ValuesPtrType setup_values_(data_type_weights< OtherT > const &w) const
Establish the target matrix storage.
Proxy a weights object as a different data type.
ValuesPtrType setup_values_(weights const &w) const
Establish the target matrix storage.
ValuesPtrType values_
The values in this data type.
ValuesPtrType setup_values_(data_type_weights< TensorDataType > const &dtw) const
Establish the view of the master data.
void setup(ViewingWeightsPtr const &w)
Provide setup function for delayed construction.
WeightsProxy(ViewingWeightsPtr const &w)
Construct a proxy given the master object.
std::weak_ptr< weights > ViewingWeightsPtr
Smart pointer to reference a weights object.
Definition: layer.hpp:89
WeightsProxy()=default
Construct an empty proxy.
El::AbstractDistMatrix< TensorDataType > ValuesType
The type of weights values.
void swap(WeightsProxy< TensorDataType > &other)
Swap contents with another WeightsProxy object.
WeightsProxy & operator=(WeightsProxy const &other)
Copy assignment operator.
size_t get_matrix_height() const
ValuesPtrType setup_values_as_copy_(weights const &w) const
Create the matrix object to store the copied weights.
#define LBANN_IN_DEBUG_MODE
WeightsProxy & operator=(WeightsProxy &&other) noexcept
Move assignment from another proxy object.
ViewingWeightsPtr master_weights_
The proxied master weights.
weights const & master_weights() const
Access the master weights object directly.
~WeightsProxy() noexcept
Destructor.
WeightsProxy & operator=(WeightsProxy< T > const &other)
Assignment from WeightsProxy object of a different type.