27 #ifndef LBANN_CALLBACKS_CALLBACK_CONFUSION_MATRIX_HPP_INCLUDED 28 #define LBANN_CALLBACKS_CALLBACK_CONFUSION_MATRIX_HPP_INCLUDED 48 std::string&& label_layer,
49 std::string&& prefix);
51 std::string
const& label_layer,
52 std::string
const& prefix);
59 std::string
name()
const override {
return "confusion matrix"; }
92 std::map<execution_mode, std::vector<El::Int>>
m_counts;
124 const google::protobuf::Message&,
125 std::shared_ptr<lbann_summary>
const&);
130 #endif // LBANN_CALLBACKS_CALLBACK_CONFUSION_MATRIX_HPP_INCLUDED const AbsDistMatType & get_predictions(const model &m) const
void on_validation_end(model *m) override
Called immediately after the end of validation.
void update_counts(const model &m)
void on_batch_evaluate_end(model *m) override
Called at the end of a (mini-)batch evaluation (validation / testing).
std::unique_ptr< AbsDistMatType > m_predictions_v
std::string m_label_layer
std::string m_prediction_layer
Base class for callbacks during training/testing.
void on_epoch_end(model *m) override
Called immediate after the end of each epoch.
Abstract base class for neural network models.
const AbsDistMatType & get_labels(const model &m) const
confusion_matrix(std::string &&prediction_layer, std::string &&label_layer, std::string &&prefix)
void setup(model *m) override
Called once to set up the callback on the model (after all layers are set up).
void reset_counts(const model &m)
void on_batch_end(model *m) override
Called immediately after the end of a (mini-)batch.
std::unique_ptr< callback_base > build_confusion_matrix_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
void on_test_end(model *m) override
Called immediately after the end of testing.
void on_test_begin(model *m) override
Called at the beginning of testing.
void write_specific_proto(lbann_data::Callback &proto) const final
std::string name() const override
Return this callback's name.
confusion_matrix * copy() const override
std::map< execution_mode, std::vector< El::Int > > m_counts
void on_epoch_begin(model *m) override
Called at the beginning of each epoch.
void save_confusion_matrix(const model &m)
void on_validation_begin(model *m) override
Called at the beginning of validation.
std::unique_ptr< AbsDistMatType > m_labels_v
confusion_matrix & operator=(const confusion_matrix &)
El::AbstractDistMatrix< DataType > AbsDistMatType