-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
451 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
These notes are from poking around `Core.Compiler` to see, how they are different from working just with `CodeInfo` and `IRTools.jl`. Notes are mainly around IRCode. Why there is a `Core.Compiler.IRCode` when there was `Core.CodeInfo`? Seems to be historical reasons. At the beginning, Julia did not have any intermediate representation and code directly emitted LLVM. Then, it has received an `CodeInfo` as in intermediate representation. `IRCode` seems like an evolution of `CodeInfo`. `Core.Compiler` works mostly with `IRCode`, but the `IRCode` can be converted to the `CodeInfo` and the other way around. `IRCode` seems to be designed more for implementation of various optimisation phases. Personal experience tells me it is much nicer to work with even on the low level. | ||
|
||
Throughout the explanation, we assume that `Core.Compiler` was imported as `CC` to decrease the typing load. | ||
|
||
Let's play with a simple silly function | ||
```julia | ||
|
||
function foo(x,y) | ||
z = x * y | ||
z + sin(x) | ||
end | ||
``` | ||
|
||
### IRCode | ||
We can obtain `CC.IRCode` | ||
```julia | ||
import Core.Compiler as CC | ||
(ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1")) | ||
``` | ||
which returns `Core.Compiler.IRCode` in `ir` and return-type `Float64` in `rt`. | ||
The output might look like | ||
``` | ||
julia> (ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1")) | ||
1─ %1 = (_2 * _3)::Float64 | ||
│ %2 = Main.sin(_2)::Float64 | ||
│ %3 = (%1 + %2)::Float64 | ||
└── return %3 | ||
=> Float64 | ||
``` | ||
Options of `optimize_until` are `compact 1`, `compact 2`, `nothing.` I do not see a difference between `compact 2` and `compact 2`. | ||
|
||
The IRCode structure is defined as | ||
``` | ||
struct IRCode | ||
stmts::InstructionStream | ||
argtypes::Vector{Any} | ||
sptypes::Vector{VarState} | ||
linetable::Vector{LineInfoNode} | ||
cfg::CFG | ||
new_nodes::NewNodeStream | ||
meta::Vector{Expr} | ||
end | ||
``` | ||
where | ||
* `stmts` is a stream of instruction (more in this below) | ||
* `argtypes` holds types of arguments of the function whose `IRCode` we have obtained | ||
* `sptypes` is a vector of `VarState`. It seems to be related to parameters of types | ||
* `linetable` is a table of unique lines in the source code from which statements | ||
* `cfg` holds control flow graph, which contains building blocks and jumps between them | ||
* `new_nodes` is an infrastructure that can be used to insert new instructions to the existing `IRCode` . The idea behind is that since insertion requires a renumbering all statements, they are put in a separate queue. They are put to correct position with a correct `SSANumber` by calling `compact!`. | ||
* `meta` is something. | ||
|
||
Before going further, let's take a look on `InstructionStream` defined as | ||
```julia | ||
struct InstructionStream | ||
inst::Vector{Any} | ||
type::Vector{Any} | ||
info::Vector{CallInfo} | ||
line::Vector{Int32} | ||
flag::Vector{UInt8} | ||
end | ||
``` | ||
where | ||
* `inst` is a vector of instructions, stored as `Expr`essions. The allowed fields in `head` are described [here](https://docs.julialang.org/en/v1/devdocs/ast/#Expr-types) | ||
* `type` is the type of the value returned by the corresponding statement | ||
* `CallInfo` is ???some info??? | ||
* `line` is an index into `IRCode.linetable` identifying from which line in source code the statement comes from | ||
* `flag` are some flags providing additional information about the statement. | ||
- `0x01 << 0` = statement is marked as `@inbounds` | ||
- `0x01 << 1` = statement is marked as `@inline` | ||
- `0x01 << 2` = statement is marked as `@noinline` | ||
- `0x01 << 3` = statement is within a block that leads to `throw` call | ||
- `0x01` << 4 = statement may be removed if its result is unused, in particular it is thus be both pure and effect free | ||
- `0x01 << 5-6 = <unused>` | ||
- `0x01 << 7 = <reserved>` has out-of-band info | ||
|
||
For the above `foo` function, the InstructionStream looks like | ||
|
||
```julia | ||
julia> DataFrame(flag = ir.stmts.flag, info = ir.stmts.info, inst = ir.stmts.inst, line = ir.stmts.line, type = ir.stmts.type) | ||
4×5 DataFrame | ||
Row │ flag info inst line type | ||
│ UInt8 CallInfo Any Int32 Any | ||
─────┼──────────────────────────────────────────────────────────────────────── | ||
1 │ 112 MethodMatchInfo(MethodLookupResu… _2 * _3 1 Float64 | ||
2 │ 80 MethodMatchInfo(MethodLookupResu… Main.sin(_2) 2 Float64 | ||
3 │ 112 MethodMatchInfo(MethodLookupResu… %1 + %2 2 Float64 | ||
4 │ 0 NoCallInfo() return %3 2 Any | ||
``` | ||
We can index into the statements as `ir.stmts[1]`, which provides a "view" into the vector. To obtain the first instruction, we can do `ir.stmts[1][:inst]`. | ||
The IRCode is typed, but the fields can contain `Any`. It is up to the user to provide corrrect types of the output and there is no helper functions to perform typing. A workaround is shown in the Petite Diffractor project. Julia's sections of the manual https://docs.julialang.org/en/v1/devdocs/ssair/ and seems incredibly useful. The IR form they talk about seems to be `Core.Compiler.IRCode`. | ||
It seems to be that it is possible to insert IR instructions into the it structure by queuing that to the field `stmts` and then call `compact!`, which would perform the heavy machinery of relabeling everything. | ||
#### Example of modifying the function through IRCode | ||
Below is an MWE that tries to modify the IRCode of a function and execute it. The goal is to change the function `foo` to `fooled`. | ||
```julia | ||
import Core.Compiler as CC | ||
using Core: SSAValue, GlobalRef, ReturnNode | ||
|
||
function foo(x,y) | ||
z = x * y | ||
z + sin(x) | ||
end | ||
|
||
function fooled(x,y) | ||
z = x * y | ||
z + sin(x) + cos(y) | ||
end | ||
|
||
(ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1")); | ||
nr = CC.insert_node!(ir, 2, CC.NewInstruction(Expr(:call, Core.GlobalRef(Main, :cos), Core.Argument(3)), Float64)) | ||
nr2 = CC.insert_node!(ir, 4, CC.NewInstruction(Expr(:call, GlobalRef(Main, :+), SSAValue(3), nr), Float64)) | ||
CC.setindex!(ir.stmts[4], ReturnNode(nr2), :inst) | ||
ir = CC.compact!(ir) | ||
irfooled = Core.OpaqueClosure(ir) | ||
irfooled(1.0, 2.0) == fooled(1.0, 2.0) | ||
``` | ||
So what we did? | ||
1. `(ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1"))` obtain the `IRCode` of the function `foo` when called with both arguments being `Float64`. `rt` contains the return type of the | ||
2. A new instruction `cos` is inserted to the `ir` by `Core.Compiler.insert_node!`, which takes as an argument an `IRCode`, position (2 in our case), and new instruction. The new instruction is created by `NewInstruction` accepting as an input expression `Expr` and a return type. Here, we force it to be `Float64`, but ideally it should be inferred. (This would be the next stage). Or, may-be, we can run it through type inference? . The new instruction is added to the `ir.new_nodes` instruction stream and obtain a new SSAValue returned in `nr`, which can be then used further. | ||
3. We add one more instruction `+` that uses output of the instruction we add in step 2, `nr` and SSAValue from statement 3 of the original IR (at this moment, the IR is still numbered with respect to the old IR, the renumbering will happen later.) The output of this second instruction is returned in `nr2`. | ||
4. Then, we rewrite the return statement to return `nr2` instead of `SSAValue(3)`. | ||
5. `ir = CC.compact!(ir)` is superimportant since it moves the newly added statements from `ir.new_stmts` to `ir.stmts` and importantly renumbers `SSAValues.` *Even though the function is mutating, the mutation here is meant that the argument is changed, but the new correct IRCode is returned and therefore has to be reassigned.* | ||
6. The function is created through `OpaqueClosure.` | ||
7. The last line certifies that the function do what it should do. | ||
There is no infrastructure to make the above manipulation transparent, like is the case of @generated function and codeinfo. It is possible to hook through generated function by converting the IRCode to untyped CodeInfo, in which case you do not have to bother with typing. | ||
#### How to obtain code info the proper way? | ||
This is the way code info is obtained in the diffractor. | ||
```julia | ||
mthds = Base._methods_by_ftype(sig, -1, world) | ||
match = only(mthds) | ||
|
||
mi = Core.Compiler.specialize_method(match) | ||
ci = Core.Compiler.retrieve_code_info(mi, world) | ||
``` | ||
### CodeInfo | ||
`IRTools.jl` are great for modifying `CodeInfo`. I have found two tools for modifying `IRCode` and I wonder if they have been abandoned because they were both dead ends or because of lack of human labor. I am also aware of Also, [this](https://nbviewer.org/gist/tkf/d4734be24d2694a3afd669f8f50e6b0f/00_notebook.ipynb) is quite cool play with IRStuff. | ||
Resources | ||
* https://vchuravy.dev/talks/licm/ | ||
* [CompilerPluginTools](https://github.com/JuliaCompilerPlugins/CompilerPluginTools.jl) | ||
* [CodeInfoTools.jl](https://github.com/JuliaCompilerPlugins/CodeInfoTools.jl). | ||
* TKF's [CodeInfo.jl](https://github.com/tkf/ShowCode.jl) is nice for visualization of the IRCode | ||
* Diffractor is an awesome source of howto. For example function `my_insert_node!` in `src/stage1/hacks.jl` | ||
* https://nbviewer.org/gist/tkf/d4734be24d2694a3afd669f8f50e6b0f/00_notebook.ipynb | ||
* https://github.com/JuliaCompilerPlugins/Mixtape.jl | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
# A simple reverse-mode AD. | ||
# Lots of simplifications have been made (in particular, there is no support for | ||
# control flow). But this illustrates most of the principles behind Zygote. | ||
# https://fluxml.ai/Zygote.jl/dev/internals/ | ||
|
||
|
||
##### | ||
|
||
# We assume to have a set of AD rules (e.g. ChainRules), which for a given function returns its evaluation and pullback. If we are tasked with computing the gradient. | ||
|
||
# 1. If a rule exists for this function, directly return the rule. | ||
# 2. If not, deconstruct the function into a sequence of functions by asking f `IRCode` | ||
# 3. Replace statements by calls to obtain the evaluation of the statements and the pullback. | ||
# 4. Chain pullbacks in reverse order. | ||
# 5. Return the function evaluation and the chained pullback. | ||
|
||
# The idea is that we will replace each statement of `foo` with a statement returning the function value and pullback. At the moment and for simplicity, we assume that appropriate chain is defined. Moreover, we need to keep track of mapping old SSAValues to new SSAValues in ssamap, since their values will differ. | ||
|
||
import Core.Compiler as CC | ||
using ChainRules | ||
using Core: SSAValue, GlobalRef, ReturnNode | ||
|
||
|
||
function get_ciir(f, sig; world = Core.Compiler.get_world_counter(), optimize_until = "compact 1") | ||
mi = only(Base.method_instances(f, sig, world)) | ||
ci = Base.uncompressed_ir(mi.def::Method) | ||
(ir, rt) = only(Base.code_ircode(f, sig; optimize_until)) | ||
(copy(ci), ir, rt) | ||
end | ||
|
||
|
||
|
||
# struct Pullback{S,T} | ||
# pullbacks::T | ||
# end | ||
|
||
|
||
argtype(ir::CC.IRCode, a::Core.Argument) = ir.argtypes[a.n] | ||
argtype(ir::CC.IRCode, a::Core.SSAValue) = ir.stmts.type[a.id] | ||
argtype(ir::CC.IRCode, f::GlobalRef) = typeof(eval(f)) | ||
argtype(ir::CC.IRCode, a) = error("argtype of $(typeof(a)) not supported") | ||
|
||
""" | ||
type_of_pullback(ir, inst) | ||
infer type of the pullback | ||
""" | ||
function type_of_pullback(ir, inst, optimize_until = "compact 1") | ||
inst.head != :call && error("inferrin return type of calls is supported") | ||
params = tuple([argtype(ir, a) for a in inst.args]...) | ||
(ir, rt) = only(Base.code_ircode(ChainRules.rrule, params; optimize_until)) | ||
if !(rt <:Tuple{A,B} where {A,B}) | ||
error("The return type of pullback `ChainRules.rrule($(params))` should be tuple") | ||
end | ||
rt | ||
end | ||
|
||
remap(d, args::Tuple) = map(a -> remap(d,a), args) | ||
remap(d, args::Vector) = map(a -> remap(d,a), args) | ||
remap(d, r::ReturnNode) = ReturnNode(remap(d, r.val)) | ||
remap(d, x::SSAValue) = d[x] | ||
remap(d, x) = x | ||
|
||
function forward(ir) | ||
pullbacks = [] | ||
new_insts = Any[] | ||
new_line = Int32[] | ||
new_types = Any[] | ||
ssamap = Dict{SSAValue,SSAValue}() | ||
fval_ssa = nothing | ||
for (i, stmt) in enumerate(ir.stmts) | ||
inst = stmt[:inst] | ||
if inst isa Expr && inst.head == :call | ||
new_inst = Expr(:call, GlobalRef(ChainRules, :rrule), remap(ssamap, inst.args)...) | ||
tt = type_of_pullback(ir, inst) | ||
push!(new_insts, new_inst) | ||
push!(new_line, stmt[:line]) | ||
push!(new_types, tt) | ||
rrule_ssa = SSAValue(length(new_insts)) | ||
|
||
|
||
push!(new_insts, Expr(:call, :getindex, rrule_ssa, 1)) | ||
push!(new_line, stmt[:line]) | ||
push!(new_types, tt.parameters[1]) | ||
val_ssa = SSAValue(length(new_insts)) | ||
ssamap[SSAValue(i)] = val_ssa | ||
(stmt[:type] != tt.parameters[1]) && @info("pullback of $(inst) has a different type than normal inst") | ||
|
||
push!(new_insts, Expr(:call, :getindex, rrule_ssa, 2)) | ||
pullback_ssa = SSAValue(length(new_insts)) | ||
push!(new_line, stmt[:line]) | ||
push!(new_types, tt.parameters[2]) | ||
push!(pullbacks, pullback_ssa) | ||
continue | ||
end | ||
|
||
if inst isa ReturnNode | ||
fval_ssa = remap(ssamap, inst.val) | ||
continue | ||
end | ||
error("unknown node $(i)") | ||
end | ||
|
||
# construct tuple with all pullbacks | ||
push!(new_insts, Expr(:call, :tuple, pullbacks...)) | ||
pull_ssa = SSAValue(length(new_insts)) | ||
push!(new_line, new_line[end]) | ||
push!(new_types, Tuple{[new_types[x.id] for x in pullbacks]...}) | ||
|
||
# construct the tuple containing forward and reverse | ||
push!(new_insts, Expr(:call, :tuple, fval_ssa, pull_ssa)) | ||
ret_ssa = SSAValue(length(new_insts)) | ||
push!(new_line, new_line[end]) | ||
push!(new_types, Tuple{new_types[fval_ssa.id], new_types[pull_ssa.id]}) | ||
|
||
# put a return statement | ||
push!(new_insts, ReturnNode(ret_ssa)) | ||
push!(new_line, new_line[end]) | ||
push!(new_types, Any) | ||
|
||
# this nightmare construct the IRCode with absolutely useless type information | ||
is = CC.InstructionStream( | ||
new_insts, # inst::Vector{Any} | ||
new_types, # type::Vector{Any} | ||
fill(CC.NoCallInfo(), length(new_insts)), # info::Vector{CallInfo} | ||
new_line, # line::Vector{Int32} | ||
fill(UInt8(0), length(new_insts)), # flag::Vector{UInt8} | ||
) | ||
cfg = CC.compute_basic_blocks(new_insts) | ||
new_ir = CC.IRCode(is, cfg, ir.linetable, ir.argtypes, ir.meta, ir.sptypes) | ||
end | ||
|
||
|
||
|
||
|
||
|
||
function foo(x,y) | ||
z = x * y | ||
z + sin(x) | ||
end | ||
|
||
|
||
(ci, ir, rt) = get_ciir(foo, (Float64, Float64)) | ||
new_ir = forward(ir) | ||
CC.replace_code_newstyle!(ci, ir) | ||
|
||
forw = Core.OpaqueClosure(forward(ir)) | ||
fval, pullbacks = forw(1.0,1.0) | ||
|
||
(1.0,1.0) | ||
|
||
|
||
""" | ||
function reverse(ir) | ||
we construct the reverse using the original `ir` code, since we can obtain it in from the | ||
parameter `S` of the Pullback{S,T}. `S` can contain `(foo, Float64,Float64)` when | ||
we compute the gradient of `foo`. | ||
""" | ||
function reverse(ir) | ||
diffmap = Dict{Any,Any}() # this will hold the mapping where is the gradient with respect to SSA. | ||
# the argument of the pullback we are defining is a gradient with respect to the argument of return | ||
# which we assume to be the last of insturction in `inst` | ||
@assert ir.stmts.inst[end] isa ReturnNode | ||
diffmap[ir.stmts.inst[end].val] = Core.Argument(2) | ||
|
||
reverse_inst = [] | ||
|
||
# a first IR will be to get a structure will pullbacks from the first argument | ||
push!(reverse_inst, Expr(:call, GlobalRef(Core, :getfield), Core.Argument(1), :pullbacks)) | ||
pullbacks_ssa = SSAValue(length(reverse_inst)) | ||
|
||
# we should filter statements from which we can get the pullbacks, but for the trivial | ||
# function without control flow this is not needed | ||
|
||
# now we iterate over pullbacks and execute one by one with correct argument | ||
for i in (length(ir.stmts)-1):-1:1 | ||
inst = ir.stmts[i][:inst] | ||
val_ssa = SSAValue(i) | ||
|
||
#first, we get the pullback | ||
push!(reverse_inst, Expr(:call, GlobalRef(Base, :getindex), pullbacks_ssa, i)) | ||
pullback_ssa = SSAValue(length(reverse_inst)) | ||
|
||
#execute pullback | ||
push!(reverse_inst, Expr(:call, pullback_ssa, diffmap[val_ssa])) | ||
arg_grad = SSAValue(length(reverse_inst)) | ||
for (j, a) in enumerate(inst.args) | ||
j == 1 && continue # we omit gradient with respect to the name of the function and rrule | ||
if haskey(diffmap, a) # we need to perform addition | ||
push!(reverse_inst, Expr(:call, GlobalRef(Base, :getindex), arg_grad, j)) | ||
sv = SSAValue(length(reverse_inst)) | ||
push!(reverse_inst, Expr(:call, GlobalRef(Base, :+), sv, diffmap[a])) | ||
diffmap[a] = SSAValue(length(reverse_inst)) | ||
else | ||
push!(reverse_inst, Expr(:call, GlobalRef(Base, :getindex), arg_grad, j)) | ||
diffmap[a] = SSAValue(length(reverse_inst)) | ||
end | ||
end | ||
end | ||
|
||
end |
Oops, something went wrong.