LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
onnx_utils.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 #ifndef LBANN_UTILS_ONNX_UTILS_HPP_INCLUDED
27 #define LBANN_UTILS_ONNX_UTILS_HPP_INCLUDED
28 
29 #include <lbann_config.hpp>
30 
31 #ifdef LBANN_HAS_ONNX
32 #include "El.hpp"
33 
37 
38 #include <onnx/onnx_pb.h>
39 
40 #include <vector>
41 
42 namespace lbann {
43 namespace details {
44 
45 // TODO: There are two extra considerations. ONNX supports both
46 // "external data" (stored on disk) and segmented data. Both of these
47 // could be useful in certain circumstances.
48 
49 // Basic types first.
50 inline void set_datatype(onnx::TensorProto& p, TypeTag<float>)
51 {
52  p.set_data_type(onnx::TensorProto::FLOAT);
53 }
54 inline void set_datatype(onnx::TensorProto& p, TypeTag<double>)
55 {
56  p.set_data_type(onnx::TensorProto::DOUBLE);
57 }
58 inline void set_datatype(onnx::TensorProto& p, TypeTag<El::Complex<float>>)
59 {
60  p.set_data_type(onnx::TensorProto::COMPLEX64);
61 }
62 inline void set_datatype(onnx::TensorProto& p, TypeTag<El::Complex<double>>)
63 {
64  p.set_data_type(onnx::TensorProto::COMPLEX128);
65 }
66 
67 inline void clear_data(onnx::TensorProto& p, TypeTag<float>)
68 {
69  p.clear_float_data();
70 }
71 inline void clear_data(onnx::TensorProto& p, TypeTag<double>)
72 {
73  p.clear_double_data();
74 }
75 
76 inline void add_data(onnx::TensorProto& p, float const& x)
77 {
78  p.add_float_data(x);
79 }
80 inline void add_data(onnx::TensorProto& p, double const& x)
81 {
82  p.add_double_data(x);
83 }
84 
85 // Now FP16 types
#ifdef LBANN_HAS_HALF
// Half-precision elements are written as single-precision: the tensor
// is tagged FLOAT and every element is widened with a static_cast
// before being appended to float_data.

// CPU half type: tensor is tagged as FLOAT.
inline void set_datatype(onnx::TensorProto& p, TypeTag<cpu_half_type> /*tag*/)
{
  p.set_data_type(onnx::TensorProto::FLOAT);
}

// CPU half type: data is held in float_data, so that is what gets cleared.
inline void clear_data(onnx::TensorProto& p, TypeTag<cpu_half_type> /*tag*/)
{
  p.clear_float_data();
}

// CPU half type: widen to float, then append to float_data.
inline void add_data(onnx::TensorProto& p, cpu_half_type const& x)
{
  p.add_float_data(static_cast<float>(x));
}

#if defined LBANN_HAS_GPU && defined LBANN_HAS_GPU_FP16
// GPU half type: tensor is tagged as FLOAT.
inline void set_datatype(onnx::TensorProto& p, TypeTag<gpu_half_type> /*tag*/)
{
  p.set_data_type(onnx::TensorProto::FLOAT);
}

// GPU half type: data is held in float_data, so that is what gets cleared.
inline void clear_data(onnx::TensorProto& p, TypeTag<gpu_half_type> /*tag*/)
{
  p.clear_float_data();
}

// GPU half type: widen to float, then append to float_data.
inline void add_data(onnx::TensorProto& p, gpu_half_type const& x)
{
  p.add_float_data(static_cast<float>(x));
}
#endif // defined LBANN_HAS_GPU && defined LBANN_HAS_GPU_FP16
#endif // LBANN_HAS_HALF
114 
115 // Finally, complex types.
116 template <typename T>
117 void add_data(onnx::TensorProto& p, El::Complex<T> const& x)
118 {
119  add_data(p, El::RealPart(x));
120  add_data(p, El::ImagPart(x));
121 }
122 
123 // Clear any data present in the message.
124 template <typename T>
125 void clear_msg_data(onnx::TensorProto& p, TypeTag<T>)
126 {
127  p.clear_dims();
128  p.clear_data_type();
129  // Complex types are stored in their base type.
130  clear_data(p, TypeTag<El::Base<T>>{});
131 }
132 
133 // Serialize the matrix into the given TensorProto message in
134 // row-major ordering.
135 template <typename T>
136 void serialize_to_onnx_impl(El::AbstractDistMatrix<T> const& m,
137  onnx::TensorProto& p)
138 {
139  using namespace El;
140  DistMatrixReadProxy<T, T, STAR, STAR, ELEMENT, Device::CPU> proxy(m);
141  auto& mat = proxy.GetLocked().LockedMatrix();
142  auto const height = mat.Height();
143  auto const width = mat.Width();
144 
145  clear_msg_data(p, TypeTag<T>{});
146  p.add_dims(height);
147  p.add_dims(width);
148  set_datatype(p, TypeTag<T>{});
149  p.set_data_location(onnx::TensorProto::DEFAULT);
150  for (auto r = decltype(height){0}; r < height; ++r)
151  for (auto c = decltype(width){0}; c < width; ++c)
152  add_data(p, mat.CRef(r, c));
153 }
154 
155 template <typename T, typename SizeT>
156 void serialize_to_onnx_impl(El::AbstractDistMatrix<T> const& m,
157  std::vector<SizeT> const& dims,
158  onnx::TensorProto& p)
159 {
160  using namespace El;
161  LBANN_ASSERT(lbann::get_linear_size(dims) == static_cast<size_t>(m.Height()));
162  LBANN_ASSERT(m.Width() == static_cast<Int>(1));
163 
164  DistMatrixReadProxy<T, T, STAR, STAR, ELEMENT, Device::CPU> proxy(m);
165  auto& mat = proxy.GetLocked().LockedMatrix();
166 
167  clear_msg_data(p, TypeTag<T>{});
168  for (auto const& d : dims)
169  p.add_dims(d);
170  set_datatype(p, TypeTag<T>{});
171  p.set_data_location(onnx::TensorProto::DEFAULT);
172 
173  // This may be unnecessary
174  auto const height = mat.Height();
175  auto const* const buf = mat.LockedBuffer();
176  for (auto ii = Int{0}; ii < height; ++ii) {
177  add_data(p, buf[ii]);
178  }
179 }
180 
181 } // namespace details
182 
212 template <typename T, typename SizeT>
213 void serialize_to_onnx(El::AbstractDistMatrix<T> const& m,
214  std::vector<SizeT> const& height_dims,
215  std::vector<SizeT> const& width_dims,
216  onnx::TensorProto& p)
217 {
218  if (width_dims.empty())
219  details::serialize_to_onnx_impl(m, height_dims, p);
220  else {
222  static_cast<size_t>(m.Height()));
224  static_cast<size_t>(m.Width()));
225  details::serialize_to_onnx_impl(m, p);
226  }
227 }
228 
229 } // namespace lbann
230 #endif // LBANN_HAS_ONNX
231 #endif // LBANN_UTILS_ONNX_UTILS_HPP_INCLUDED
auto get_linear_size(std::vector< T > const &dims)
Definition: dim_helpers.hpp:59
#define LBANN_ASSERT_DEBUG(cond)
Definition: exception.hpp:104
#define LBANN_ASSERT(cond)
Definition: exception.hpp:97