forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_bytecode.h
59 lines (50 loc) · 1.53 KB
/
export_bytecode.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#pragma once
#include <unordered_map>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
namespace torch {
namespace jit {
/// One function scheduled for bytecode export, bundled with the module it was
/// taken from and the graph to emit for it.
///
/// Because `function` is a reference member, the struct is not copy- or
/// move-assignable. Because `optimizedGraph` is a `std::unique_ptr`, it is not
/// copyable; it can still be move-constructed.
struct ExportedFunction {
  /// @param owner_module  Module associated with @p source_fn; copied into `mod`.
  /// @param source_fn     Function being exported. Only a reference is stored,
  ///                      so the referenced Function must outlive this object.
  /// @param graph         Graph to emit for @p source_fn. Ownership is taken.
  ///                      Presumably an optimized copy of the function's graph,
  ///                      per the member name; TODO confirm with callers.
  /// @param is_toplevel   Initial value of `toplevel`.
  ExportedFunction(
      const Module& owner_module,
      const Function& source_fn,
      std::unique_ptr<Graph> graph,
      bool is_toplevel)
      : mod(owner_module),
        function(source_fn),
        optimizedGraph(std::move(graph)),
        toplevel(is_toplevel) {}

  Module mod;
  // Non-owning: see the lifetime requirement on the constructor.
  const Function& function;
  std::unique_ptr<Graph> optimizedGraph;
  // BytecodeExportSet::visit visits entries with this flag set before the rest.
  // NOTE(review): presumably marks functions exported directly rather than
  // reached transitively; confirm against callers of BytecodeExportSet::update.
  bool toplevel;
};
class TORCH_API BytecodeExportSet {
public:
BytecodeExportSet() = default;
BytecodeExportSet(const BytecodeExportSet&) = delete;
BytecodeExportSet& operator=(const BytecodeExportSet&) = delete;
BytecodeExportSet(BytecodeExportSet&&) = default;
BytecodeExportSet& operator=(BytecodeExportSet&&) = default;
void add(const c10::QualifiedName& qn, ExportedFunction);
void update(const c10::QualifiedName& qn, bool toplevel);
bool contains(const c10::QualifiedName& qn) const;
template <typename F>
void visit(F&& f) {
for (auto& item : items_) {
if (item.second.toplevel) {
f(item.first, item.second);
}
}
for (auto& item : items_) {
if (!item.second.toplevel) {
f(item.first, item.second);
}
}
}
private:
std::unordered_map<c10::QualifiedName, ExportedFunction> items_;
};
} // namespace jit
} // namespace torch