26 #ifndef LBANN_UTILS_ONNX_UTILS_HPP_INCLUDED 27 #define LBANN_UTILS_ONNX_UTILS_HPP_INCLUDED 29 #include <lbann_config.hpp> 38 #include <onnx/onnx_pb.h> 50 inline void set_datatype(onnx::TensorProto& p, TypeTag<float>)
// float overload: records onnx::TensorProto::FLOAT as the tensor's element
// type. The TypeTag<float> parameter is unnamed; it only selects the overload.
// NOTE(review): the enclosing braces are not visible in this view.
52 p.set_data_type(onnx::TensorProto::FLOAT);
// double overload: records onnx::TensorProto::DOUBLE as the tensor's element
// type. The TypeTag<double> parameter is unnamed; it only selects the overload.
54 inline void set_datatype(onnx::TensorProto& p, TypeTag<double>)
56 p.set_data_type(onnx::TensorProto::DOUBLE);
// Single-precision complex overload: records onnx::TensorProto::COMPLEX64 as
// the tensor's element type.
58 inline void set_datatype(onnx::TensorProto& p, TypeTag<El::Complex<float>>)
60 p.set_data_type(onnx::TensorProto::COMPLEX64);
// Double-precision complex overload: records onnx::TensorProto::COMPLEX128 as
// the tensor's element type.
62 inline void set_datatype(onnx::TensorProto& p, TypeTag<El::Complex<double>>)
64 p.set_data_type(onnx::TensorProto::COMPLEX128);
// float overload of clear_data, selected by the unnamed TypeTag<float>.
// NOTE(review): the body is not visible in this view; presumably it clears the
// float payload, mirroring the double overload's clear_double_data() -- confirm.
67 inline void clear_data(onnx::TensorProto& p, TypeTag<float>)
// double overload of clear_data: empties the tensor's double_data field.
// The TypeTag<double> parameter is unnamed; it only selects the overload.
71 inline void clear_data(onnx::TensorProto& p, TypeTag<double>)
73 p.clear_double_data();
// float overload of add_data: takes the tensor proto and one value by
// const reference.
// NOTE(review): the body is not visible in this view; presumably it appends x
// to the float payload -- confirm.
76 inline void add_data(onnx::TensorProto& p,
float const& x)
// double overload of add_data: takes the tensor proto and one value by
// const reference.
// NOTE(review): the body is not visible in this view; presumably it appends x
// to the double payload -- confirm.
80 inline void add_data(onnx::TensorProto& p,
double const& x)
// cpu_half_type overload: the tensor is tagged as FLOAT, not as a half type.
// This matches the cpu_half_type add_data overload in this file, which casts
// each value to float before appending it.
87 inline void set_datatype(onnx::TensorProto& p, TypeTag<cpu_half_type>)
89 p.set_data_type(onnx::TensorProto::FLOAT);
// cpu_half_type overload of clear_data, selected by the unnamed tag argument.
// NOTE(review): the body is not visible in this view; presumably it clears the
// float payload like the gpu_half_type overload does -- confirm.
91 inline void clear_data(onnx::TensorProto& p, TypeTag<cpu_half_type>)
// cpu_half_type overload of add_data: appends one value to the tensor's
// float_data field.
95 inline void add_data(onnx::TensorProto& p, cpu_half_type
const& x)
// The half value is explicitly converted to float before it is appended.
97 p.add_float_data(static_cast<float>(x));
99 #if defined LBANN_HAS_GPU && defined LBANN_HAS_GPU_FP16 100 inline void set_datatype(onnx::TensorProto& p, TypeTag<gpu_half_type>)
// gpu_half_type overload, guarded by LBANN_HAS_GPU and LBANN_HAS_GPU_FP16.
// As with cpu_half_type, the tensor is tagged as FLOAT.
102 p.set_data_type(onnx::TensorProto::FLOAT);
// gpu_half_type overload of clear_data: empties the tensor's float_data field,
// which is where the gpu_half_type add_data overload stores its values.
104 inline void clear_data(onnx::TensorProto& p, TypeTag<gpu_half_type>)
106 p.clear_float_data();
// gpu_half_type overload of add_data: appends one value to the tensor's
// float_data field.
108 inline void add_data(onnx::TensorProto& p, gpu_half_type
const& x)
// The half value is explicitly converted to float before it is appended.
110 p.add_float_data(static_cast<float>(x));
112 #endif // defined LBANN_HAS_GPU && defined LBANN_HAS_GPU_FP16 113 #endif // LBANN_HAS_HALF 116 template <
typename T>
117 void add_data(onnx::TensorProto& p, El::Complex<T>
const& x)
// Complex overload: one complex value is appended as two consecutive scalars,
// real part first and imaginary part second. Each is forwarded to whichever
// add_data overload matches the type returned by El::RealPart / El::ImagPart.
119 add_data(p, El::RealPart(x));
120 add_data(p, El::ImagPart(x));
// Clears the tensor message ahead of re-serialization for element type T.
// The visible statement delegates payload clearing to the clear_data overload
// for El::Base<T> rather than for T itself.
// NOTE(review): the embedded numbering jumps from 125 to 130, so any statements
// between the signature and the clear_data call are not visible here.
124 template <
typename T>
125 void clear_msg_data(onnx::TensorProto& p, TypeTag<T>)
130 clear_data(p, TypeTag<El::Base<T>>{});
// Serializes a distributed matrix into an ONNX tensor proto as 2-D data.
//
// m: matrix to serialize; read through a [STAR,STAR] CPU read proxy.
// p: tensor proto to fill; existing contents are cleared first.
//
// NOTE(review): the embedded numbering skips 138-139 and 146-147, so some
// statements (e.g. where the tensor's dims are recorded) are not visible here.
135 template <
typename T>
136 void serialize_to_onnx_impl(El::AbstractDistMatrix<T>
const& m,
137 onnx::TensorProto& p)
// Build the read proxy, then take the locked local matrix from it.
140 DistMatrixReadProxy<T, T, STAR, STAR, ELEMENT, Device::CPU> proxy(m);
141 auto& mat = proxy.GetLocked().LockedMatrix();
142 auto const height = mat.Height();
143 auto const width = mat.Width();
// Reset the message before writing the new type, location and payload.
145 clear_msg_data(p, TypeTag<T>{});
148 set_datatype(p, TypeTag<T>{});
149 p.set_data_location(onnx::TensorProto::DEFAULT);
// Rows are the outer loop and columns the inner loop, so entries are appended
// in row-major order. The loop indices take the types of height and width.
150 for (
auto r = decltype(height){0}; r < height; ++r)
151 for (
auto c = decltype(width){0}; c < width; ++c)
152 add_data(p, mat.CRef(r, c));
// Serializes a distributed matrix into an ONNX tensor proto using caller-
// supplied dimensions.
//
// m:    matrix to serialize; read through a [STAR,STAR] CPU read proxy.
// dims: dimensions to associate with the tensor.
// p:    tensor proto to fill; existing contents are cleared first.
//
// NOTE(review): the embedded numbering skips 159-163, 169 and 172-173. Any
// precondition checks and the body of the loop over dims are not visible here.
155 template <
typename T,
typename SizeT>
156 void serialize_to_onnx_impl(El::AbstractDistMatrix<T>
const& m,
157 std::vector<SizeT>
const& dims,
158 onnx::TensorProto& p)
// Build the read proxy, then take the locked local matrix from it.
164 DistMatrixReadProxy<T, T, STAR, STAR, ELEMENT, Device::CPU> proxy(m);
165 auto& mat = proxy.GetLocked().LockedMatrix();
// Reset the message before writing dims, type, location and payload.
167 clear_msg_data(p, TypeTag<T>{});
// Visits every entry of dims; presumably each is recorded on p -- confirm.
168 for (
auto const& d : dims)
170 set_datatype(p, TypeTag<T>{});
171 p.set_data_location(onnx::TensorProto::DEFAULT);
// Only the first `height` entries of the local buffer are read, so the matrix
// is walked as a single column.
// NOTE(review): this assumes the matrix has width 1 -- TODO confirm.
174 auto const height = mat.Height();
175 auto const*
const buf = mat.LockedBuffer();
176 for (
auto ii = Int{0}; ii < height; ++ii) {
177 add_data(p, buf[ii]);
// Public entry point: serializes a distributed matrix into an ONNX tensor
// proto by dispatching to one of the details::serialize_to_onnx_impl overloads.
//
// m:           matrix to serialize.
// height_dims: dimensions associated with the matrix height.
// width_dims:  dimensions associated with the matrix width; may be empty.
// p:           tensor proto to fill.
212 template <
typename T,
typename SizeT>
213 void serialize_to_onnx(El::AbstractDistMatrix<T>
const& m,
214 std::vector<SizeT>
const& height_dims,
215 std::vector<SizeT>
const& width_dims,
216 onnx::TensorProto& p)
// With no width dims, height_dims is forwarded as the tensor's dimensions.
218 if (width_dims.empty())
219 details::serialize_to_onnx_impl(m, height_dims, p);
// NOTE(review): the next two lines are only the tails of statements whose
// beginnings (embedded lines 220-223) are not visible. They compare something
// against m.Height() and m.Width() as size_t; presumably size checks on
// height_dims / width_dims in the else-branch -- confirm.
222 static_cast<size_t>(m.Height()));
224 static_cast<size_t>(m.Width()));
// The overload without dims is used here; it writes the matrix as 2-D data.
225 details::serialize_to_onnx_impl(m, p);
230 #endif // LBANN_HAS_ONNX 231 #endif // LBANN_UTILS_ONNX_UTILS_HPP_INCLUDED auto get_linear_size(std::vector< T > const &dims)
// As written here, both macros have an empty replacement list, so any use
// expands to nothing.
// NOTE(review): these lines follow the closing include guard and may be
// cross-reference residue rather than the real definitions -- confirm.
#define LBANN_ASSERT_DEBUG(cond)
#define LBANN_ASSERT(cond)