LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
dump_model_graph.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 
27 #ifndef LBANN_CALLBACKS_CALLBACK_DUMP_MODEL_GRAPH_HPP_INCLUDED
28 #define LBANN_CALLBACKS_CALLBACK_DUMP_MODEL_GRAPH_HPP_INCLUDED
29 
31 
32 namespace lbann {
33 namespace callback {
34 
42 {
43 public:
44  dump_model_graph(std::string basename, bool print)
45  : m_basename(basename), m_print(print)
46  {}
// Copy construction and copy assignment are explicitly defaulted, i.e. the
// compiler-generated versions are used.
47  dump_model_graph(const dump_model_graph&) = default;
48  dump_model_graph& operator=(const dump_model_graph&) = default;
// Return a pointer to a new heap-allocated duplicate of this callback, built
// with the copy constructor.
// NOTE(review): the object comes from a plain `new`; the caller presumably
// takes ownership of the returned pointer -- confirm against the overridden
// base-class contract.
49  dump_model_graph* copy() const override
50  {
51  return new dump_model_graph{*this};
52  }
// Return this callback's name. The string identifies the dump-model-graph
// callback itself rather than an unrelated "print tensor dimensions" label.
// NOTE(review): if any caller matches on the exact name string, update it too.
53  std::string name() const override { return "dump model graph"; }
54 
// Called at the end of setup. Only declared here; the definition is not part
// of this header.
// m: pointer to a model -- presumably the model whose setup has just finished
// (TODO confirm against the definition).
55  void on_setup_end(model* m) override;
56 
57 private:
// Private, final override taking the lbann_data::Callback message by non-const
// reference; declared const, so it does not modify this callback. Only
// declared here.
// NOTE(review): presumably fills `proto` with this callback's own settings --
// confirm against the definition.
59  void write_specific_proto(lbann_data::Callback& proto) const final;
60 
// Copy of the `basename` constructor argument.
// NOTE(review): presumably the base name of the file the graph is written to
// -- confirm against the implementation.
62  std::string m_basename;
// Copy of the `print` constructor argument.
// NOTE(review): presumably toggles printing of the graph -- confirm against
// the implementation.
64  bool m_print;
65 };
66 
67 // Builder function
// Returns the built callback as std::unique_ptr<callback_base>. Takes a
// protobuf message and a shared lbann_summary pointer, both by const
// reference; the parameters are unnamed in this declaration and the
// definition is not part of this header.
// NOTE(review): the message presumably carries the dump_model_graph settings
// (basename / print) -- confirm against the definition.
68 std::unique_ptr<callback_base> build_dump_model_graph_callback_from_pbuf(
69  const google::protobuf::Message&,
70  std::shared_ptr<lbann_summary> const&);
71 
72 } // namespace callback
73 } // namespace lbann
74 
75 #endif // LBANN_CALLBACKS_CALLBACK_DUMP_MODEL_GRAPH_HPP_INCLUDED
void on_setup_end(model *m) override
Called at the end of setup.
std::unique_ptr< callback_base > build_dump_model_graph_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
void write_specific_proto(lbann_data::Callback &proto) const final
dump_model_graph(std::string basename, bool print)
dump_model_graph * copy() const override
Base class for callbacks during training/testing.
Definition: callback.hpp:76
Abstract base class for neural network models.
Definition: model.hpp:83
void print(const El::Matrix< uint8_t > &mat, El::Int height, El::Int width, El::Int channels=1)
Definition: helper.hpp:87
dump_model_graph & operator=(const dump_model_graph &)=default
Dump model graph callback.
std::string name() const override
Return this callback's name.