LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
confusion_matrix.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_CALLBACKS_CALLBACK_CONFUSION_MATRIX_HPP_INCLUDED
28 #define LBANN_CALLBACKS_CALLBACK_CONFUSION_MATRIX_HPP_INCLUDED
29 
31 
32 namespace lbann {
33 namespace callback {
34 
42 {
43 public:
44  using AbsDistMatType = El::AbstractDistMatrix<DataType>;
45 
46 public:
47  confusion_matrix(std::string&& prediction_layer, // ctor (move-in overload): name of the layer whose output is read as predictions
48  std::string&& label_layer, // name of the layer whose output is read as labels
49  std::string&& prefix); // presumably the output path/file prefix used when saving — TODO confirm in the .cpp
50  confusion_matrix(std::string const& prediction_layer, // ctor (copy-in overload): same parameters as the overload above
51  std::string const& label_layer,
52  std::string const& prefix);
55  confusion_matrix* copy() const override // polymorphic clone: returns a new heap-allocated copy; the caller is responsible for its lifetime
56  {
57  return new confusion_matrix(*this); // NOTE(review): needs a user-defined copy ctor, since the unique_ptr members are move-only; its declaration is not visible in this excerpt
58  }
59  std::string name() const override { return "confusion matrix"; } // fixed display name of this callback
60 
61  void setup(model* m) override; // called once after all layers are set up; implementation is not in this header
62 
63  void on_epoch_begin(model* m) override { reset_counts(*m); } // start of each epoch: reset counts
64  void on_epoch_end(model* m) override { save_confusion_matrix(*m); } // end of each epoch: save the confusion matrix
65  void on_validation_begin(model* m) override { reset_counts(*m); } // start of validation: reset counts
66  void on_validation_end(model* m) override { save_confusion_matrix(*m); } // end of validation: save the confusion matrix
67  void on_test_begin(model* m) override { reset_counts(*m); } // start of testing: reset counts
68  void on_test_end(model* m) override { save_confusion_matrix(*m); } // end of testing: save the confusion matrix
69  void on_batch_end(model* m) override { update_counts(*m); } // after each (mini-)batch: update counts
70  void on_batch_evaluate_end(model* m) override { update_counts(*m); } // after each evaluation (validation/testing) batch: update counts
71 
72 private:
74  void write_specific_proto(lbann_data::Callback& proto) const final; // writes this callback's settings into proto (presumably the fields the builder reads — TODO confirm)
75 
79  std::string m_prediction_layer; // name of the layer providing predictions
83  std::string m_label_layer; // name of the layer providing labels
85  std::string m_prefix; // presumably the output path/file prefix for saved matrices — TODO confirm
86 
92  std::map<execution_mode, std::vector<El::Int>> m_counts; // counts kept separately per execution_mode; the vector layout is not shown here (presumably a flattened classes x classes matrix — TODO confirm)
93 
98  std::unique_ptr<AbsDistMatType> m_predictions_v; // owned prediction matrix (the "_v" suffix suggests a view; exact role not visible in this header)
104  std::unique_ptr<AbsDistMatType> m_labels_v; // owned label matrix (same caveat as m_predictions_v)
105 
107  const AbsDistMatType& get_predictions(const model& m) const; // read-only access to the predictions for model m
109  const AbsDistMatType& get_labels(const model& m) const; // read-only access to the labels for model m
110 
112  void reset_counts(const model& m); // called from on_epoch_begin / on_validation_begin / on_test_begin
117  void update_counts(const model& m); // called from on_batch_end / on_batch_evaluate_end
119  void save_confusion_matrix(const model& m); // called from on_epoch_end / on_validation_end / on_test_end
120 };
121 
122 // Builder function: constructs a confusion_matrix callback from its protobuf configuration message and returns an owning pointer to it.
123 std::unique_ptr<callback_base> build_confusion_matrix_callback_from_pbuf(
124  const google::protobuf::Message&, // callback configuration message
125  std::shared_ptr<lbann_summary> const&); // summary handle; how it is used is not visible in this header
126 
127 } // namespace callback
128 } // namespace lbann
129 
130 #endif // LBANN_CALLBACKS_CALLBACK_CONFUSION_MATRIX_HPP_INCLUDED
const AbsDistMatType & get_predictions(const model &m) const
void on_validation_end(model *m) override
Called immediately after the end of validation.
void update_counts(const model &m)
void on_batch_evaluate_end(model *m) override
Called at the end of a (mini-)batch evaluation (validation / testing).
std::unique_ptr< AbsDistMatType > m_predictions_v
Base class for callbacks during training/testing.
Definition: callback.hpp:76
void on_epoch_end(model *m) override
Called immediately after the end of each epoch.
Abstract base class for neural network models.
Definition: model.hpp:83
const AbsDistMatType & get_labels(const model &m) const
confusion_matrix(std::string &&prediction_layer, std::string &&label_layer, std::string &&prefix)
void setup(model *m) override
Called once to set up the callback on the model (after all layers are set up).
void reset_counts(const model &m)
void on_batch_end(model *m) override
Called immediately after the end of a (mini-)batch.
std::unique_ptr< callback_base > build_confusion_matrix_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
void on_test_end(model *m) override
Called immediately after the end of testing.
void on_test_begin(model *m) override
Called at the beginning of testing.
void write_specific_proto(lbann_data::Callback &proto) const final
std::string name() const override
Return this callback's name.
confusion_matrix * copy() const override
std::map< execution_mode, std::vector< El::Int > > m_counts
void on_epoch_begin(model *m) override
Called at the beginning of each epoch.
void save_confusion_matrix(const model &m)
void on_validation_begin(model *m) override
Called at the beginning of validation.
std::unique_ptr< AbsDistMatType > m_labels_v
confusion_matrix & operator=(const confusion_matrix &)
El::AbstractDistMatrix< DataType > AbsDistMatType