// NOTE(review): this region is an extracted source listing, not compilable
// code as-is. Original line numbers are fused into the text (27, 28, 39, 41,
// ...). The template parameter list appears twice. Several original lines
// are missing, including the function name, the error-raising statements,
// the declaration of `op`, and the branch/closing braces. The comments below
// describe only what the visible lines show.
27 #ifndef LBANN_UTILS_ENTRYWISE_OPERATOR_HPP 28 #define LBANN_UTILS_ENTRYWISE_OPERATOR_HPP 39 template <
// Apply an entry-wise unary operator to CPU matrix data
// (El::AbstractMatrix overload).
// The function name is on a missing line. It is presumably
// apply_entrywise_unary_operator: that is the name the distributed-matrix
// overload calls and the name in the trailing signature listing.
//
// Template parameters:
//   Op             - class template with one type parameter. It is
//                    instantiated as Op<TensorDataType> and applied per
//                    entry as op(x).
//   TensorDataType - element type of both matrices.
// Parameters:
//   input  - source matrix; checked to be on El::Device::CPU.
//   output - destination matrix. It is checked to be on El::Device::CPU and
//            to have the same height and width as input. Every entry is
//            written.
template <
template <
typename>
class Op,
typename TensorDataType>
41 const El::AbstractMatrix<TensorDataType>& input,
42 El::AbstractMatrix<TensorDataType>& output)
44 using UnaryOperator = Op<TensorDataType>;
// Validate device placement and dimensions. The statements that report each
// failure are on missing lines. They presumably raise via LBANN_ERROR, as the
// distributed overloads visibly do (TODO: confirm).
46 std::stringstream err;
47 if (input.GetDevice() != El::Device::CPU) {
50 else if (output.GetDevice() != El::Device::CPU) {
53 else if (input.Height() != output.Height() ||
54 input.Width() != output.Width()) {
// Build a human-readable dimension-mismatch message.
// NOTE(review): ")" is followed directly by "don't match ...", so the text
// reads "... x W)don't match output matrix dimensions ...". It likely needs a
// leading space in the second literal.
55 err <<
"input matrix dimensions " 56 <<
"(" << input.Height() <<
" x " << input.Width() <<
")" 57 <<
"don't match output matrix dimensions " 58 <<
"(" << output.Height() <<
" x " << output.Width() <<
")";
// Fast path: when both matrices report Contiguous(), walk the raw buffers
// with one flat index over Height() * Width() entries.
63 if (input.Contiguous() && output.Contiguous()) {
64 const auto* input_buffer = input.LockedBuffer();
65 auto* output_buffer = output.Buffer();
66 const size_t size = input.Height() * input.Width();
// `op` is not declared on any visible line. It is presumably a UnaryOperator
// instance declared on a missing line (confirm).
68 for (
size_t i = 0; i < size; ++i) {
70 output_buffer[i] = op(input_buffer[i]);
// General path for non-contiguous storage. The branch opener is on a missing
// line. Entries are visited column by column (col outer, row inner) through
// element access operator()(row, col).
74 auto const width = input.Width();
75 auto const height = input.Height();
77 for (El::Int col = 0; col < width; ++col) {
78 for (El::Int row = 0; row < height; ++row) {
80 output(row, col) = op(input(row, col));
// Apply an entry-wise binary operator to CPU matrix data
// (El::AbstractMatrix overload).
// The function name is on a missing line. It is presumably
// apply_entrywise_binary_operator: that is the name the distributed-matrix
// overload calls and the name in the trailing signature listing.
//
// Template parameters:
//   Op             - class template with one type parameter. It is
//                    instantiated as Op<TensorDataType> and applied per
//                    entry as op(a, b).
//   TensorDataType - element type of all three matrices.
// Parameters:
//   input1, input2 - source matrices; both are checked to be on
//                    El::Device::CPU.
//   output         - destination matrix. It is checked to be on
//                    El::Device::CPU. Its height and width must equal
//                    input1's, and input1 and input2 must match each other.
//                    Every entry is written.
90 template <
template <
typename>
class Op,
typename TensorDataType>
92 const El::AbstractMatrix<TensorDataType>& input1,
93 const El::AbstractMatrix<TensorDataType>& input2,
94 El::AbstractMatrix<TensorDataType>& output)
96 using BinaryOperator = Op<TensorDataType>;
// Validate device placement and dimensions. The failure-reporting statements
// are on missing lines; presumably they raise via LBANN_ERROR (TODO: confirm).
98 if (input1.GetDevice() != El::Device::CPU ||
99 input2.GetDevice() != El::Device::CPU) {
102 else if (output.GetDevice() != El::Device::CPU) {
105 else if (input1.Height() != input2.Height() ||
106 input1.Width() != input2.Width() ||
107 input1.Height() != output.Height() ||
108 input1.Width() != output.Width()) {
// The next line fuses the tail of the mismatch message with the start of the
// fast path. The rest of the message is on missing lines.
// NOTE(review): the unary overload omits a space before "don't match"; check
// whether this message has the same problem.
// Fast path: when all three matrices report Contiguous(), walk the raw
// buffers with one flat index over input1.Height() * input1.Width() entries.
119 "don't match output matrix dimensions " 128 if (input1.Contiguous() && input2.Contiguous() && output.Contiguous()) {
129 const auto* input1_buffer = input1.LockedBuffer();
130 const auto* input2_buffer = input2.LockedBuffer();
131 auto* output_buffer = output.Buffer();
132 const size_t size = input1.Height() * input1.Width();
// `op` is not declared on any visible line. It is presumably a BinaryOperator
// instance declared on a missing line (confirm).
134 for (
size_t i = 0; i < size; ++i) {
136 output_buffer[i] = op(input1_buffer[i], input2_buffer[i]);
// General path for non-contiguous storage. The branch opener is on a missing
// line. Entries are visited column by column (col outer, row inner) through
// element access operator()(row, col).
140 auto const width = input1.Width();
141 auto const height = input1.Height();
143 for (El::Int col = 0; col < width; ++col) {
144 for (El::Int row = 0; row < height; ++row) {
146 output(row, col) = op(input1(row, col), input2(row, col));
// Apply an entry-wise unary operator to distributed matrices
// (El::AbstractDistMatrix overload). The function name is on a missing line.
// Steps:
//   1. Check that input and output have the same Height() and Width().
//   2. Check that their DistData() compare equal. On mismatch it calls
//      LBANN_ERROR with the message shown.
//   3. Forward input.LockedMatrix() and output.Matrix() to
//      apply_entrywise_unary_operator<Op>.
// In Elemental these accessors presumably return the process-local portions
// (confirm), so no communication happens here. The delegated overload
// presumably also enforces its CPU-device checks.
156 template <
template <
typename>
class Op,
typename TensorDataType>
158 const El::AbstractDistMatrix<TensorDataType>& input,
159 El::AbstractDistMatrix<TensorDataType>& output)
// Dimension check. The message construction and error call are mostly on
// missing lines; only one fragment of the message survives below.
161 if (input.Height() != output.Height() || input.Width() != output.Width()) {
// Distribution check: apply locally only when both matrices share the same
// distribution metadata.
168 "don't match output matrix dimensions " 175 else if (input.DistData() != output.DistData()) {
176 LBANN_ERROR(
"input and output matrix distributions don't match");
// Delegate to the local-matrix implementation.
178 apply_entrywise_unary_operator<Op>(input.LockedMatrix(), output.Matrix());
// Apply an entry-wise binary operator to distributed matrices
// (El::AbstractDistMatrix overload). The function name is on a missing line.
// Steps:
//   1. Check that input1, input2 and output all have the same Height() and
//      Width().
//   2. Check that all three DistData() compare equal. On mismatch it calls
//      LBANN_ERROR with the message shown.
//   3. Forward the LockedMatrix() of each input to
//      apply_entrywise_binary_operator<Op>. The final argument is on a
//      missing line and is presumably output.Matrix() (confirm).
// These are presumably the process-local portions, so no communication
// happens here (confirm).
185 template <
template <
typename>
class Op,
typename TensorDataType>
187 const El::AbstractDistMatrix<TensorDataType>& input1,
188 const El::AbstractDistMatrix<TensorDataType>& input2,
189 El::AbstractDistMatrix<TensorDataType>& output)
// Dimension check across all three matrices. The message construction and
// error call are mostly on missing lines.
191 if (input1.Height() != input2.Height() || input1.Width() != input2.Width() ||
192 input1.Height() != output.Height() || input1.Width() != output.Width()) {
// Distribution check: all three matrices must share distribution metadata.
203 "don't match output matrix dimensions " 210 else if (input1.DistData() != input2.DistData() ||
211 input1.DistData() != output.DistData()) {
212 LBANN_ERROR(
"input and output matrix distributions don't match");
// Delegate to the local-matrix implementation (the call continues past the
// visible lines).
214 apply_entrywise_binary_operator<Op>(input1.LockedMatrix(),
215 input2.LockedMatrix(),
// NOTE(review): trailing residue of the extracted listing.
// - First line: closes the include guard. Everything after its `//` is
//   comment text, including the fused binary-operator signature.
// - Second line: a bare unary-operator signature with no body.
// - Last two lines: define LBANN_OMP_PARALLEL_FOR_COLLAPSE2 and
//   LBANN_OMP_PARALLEL_FOR. As written here, they expand to nothing.
// The real definitions presumably emit OpenMP loop pragmas; their uses are on
// missing lines above. Verify against the project's OpenMP/macro header.
221 #endif // LBANN_UTILS_ENTRYWISE_OPERATOR_HPP void apply_entrywise_binary_operator(const El::AbstractMatrix< TensorDataType > &input1, const El::AbstractMatrix< TensorDataType > &input2, El::AbstractMatrix< TensorDataType > &output)
void apply_entrywise_unary_operator(const El::AbstractMatrix< TensorDataType > &input, El::AbstractMatrix< TensorDataType > &output)
#define LBANN_OMP_PARALLEL_FOR_COLLAPSE2
#define LBANN_OMP_PARALLEL_FOR