added comments about IRCode
pevnak committed Nov 20, 2023
1 parent 64d76ce commit 4e5421e
Showing 4 changed files with 451 additions and 1 deletion.
7 changes: 6 additions & 1 deletion docs/src/lecture_07/lecture.md
@@ -164,6 +164,11 @@ end

Observe that macro dispatch is based on the types of AST that are handed to the macro, not the types that the AST evaluates to at runtime.

We can list all defined versions (methods) of a macro:
```julia
methods(var"@showarg")
```
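
As an illustrative sketch (independent of the lecture's `@showarg`), a macro can define separate methods for a `Symbol` argument and for an `Expr` argument, and `methods` lists both:
```julia
macro kindof(x::Symbol)
    :(println("got a symbol"))
end
macro kindof(x::Expr)
    :(println("got an expression"))
end

@kindof a        # prints "got a symbol"
@kindof a + b    # prints "got an expression"
methods(var"@kindof")   # lists both methods
```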

## [Notes on quotation](@id lec7_quotation)
In the previous lecture we have seen that we can *quote a block of code*, which tells the compiler to treat the input as data and parse it. We have talked about three ways of quoting code.
1. `:(quoted code)`
@@ -280,7 +285,7 @@ let
@show tstart
end
```
We see that the variable `r` has not been assigned during the evaluation of the macro. We have also used a `let` block in order not to define any variables in the global scope. The problem with the above is that it cannot be nested.
Why is that?
Let's observe how the macro was expanded
```julia
162 changes: 162 additions & 0 deletions docs/src/lecture_09/ircode.md
@@ -0,0 +1,162 @@
These notes come from poking around `Core.Compiler` to see how it differs from working just with `CodeInfo` and `IRTools.jl`; they revolve mainly around `IRCode`. Why is there a `Core.Compiler.IRCode` when there was already `Core.CodeInfo`? The reasons seem to be historical. In the beginning, Julia did not have an intermediate representation and emitted LLVM directly. It then gained `CodeInfo` as an intermediate representation, and `IRCode` appears to be an evolution of `CodeInfo`. `Core.Compiler` works mostly with `IRCode`, but `IRCode` can be converted to `CodeInfo` and the other way around. `IRCode` seems to be designed mainly for implementing the various optimisation phases; personal experience tells me it is much nicer to work with, even at the low level.

Throughout the explanation, we assume that `Core.Compiler` was imported as `CC` to decrease the typing load.

Let's play with a simple silly function
```julia

function foo(x,y)
z = x * y
z + sin(x)
end
```

### IRCode
We can obtain `CC.IRCode`
```julia
import Core.Compiler as CC
(ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1"))
```
which returns `Core.Compiler.IRCode` in `ir` and return-type `Float64` in `rt`.
The output might look like
```
julia> (ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1"))
1─ %1 = (_2 * _3)::Float64
│ %2 = Main.sin(_2)::Float64
│ %3 = (%1 + %2)::Float64
└── return %3
=> Float64
```
The options for `optimize_until` are `compact 1`, `compact 2`, and `nothing`. I do not see a difference between `compact 2` and `nothing`.
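
A quick sketch for comparing the stages side by side, reusing the `foo` defined above:
```julia
# Compare the IR at different stages of the optimisation pipeline.
ir1, _ = only(Base.code_ircode(foo, (Float64, Float64); optimize_until = "compact 1"))
ir2, _ = only(Base.code_ircode(foo, (Float64, Float64); optimize_until = "compact 2"))
ir3, _ = only(Base.code_ircode(foo, (Float64, Float64)))  # default: run all optimisation passes
```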

The IRCode structure is defined as
```
struct IRCode
stmts::InstructionStream
argtypes::Vector{Any}
sptypes::Vector{VarState}
linetable::Vector{LineInfoNode}
cfg::CFG
new_nodes::NewNodeStream
meta::Vector{Expr}
end
```
where
* `stmts` is a stream of instructions (more on this below)
* `argtypes` holds the types of the arguments of the function whose `IRCode` we have obtained
* `sptypes` is a vector of `VarState`; it seems to be related to the static parameters of the method (type variables in the signature)
* `linetable` is a table of unique lines in the source code from which the statements originate
* `cfg` holds the control flow graph, which contains the basic blocks and the jumps between them
* `new_nodes` is the infrastructure used to insert new instructions into an existing `IRCode`. Since insertion requires renumbering all statements, the new nodes are first put into a separate queue; they are moved to the correct position, with correct SSA numbers, by calling `compact!`.
* `meta` holds additional metadata in the form of `Expr`s.
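
A short sketch of poking at these fields on the `ir` obtained above (the exact printing depends on the Julia version):
```julia
ir.argtypes           # argument types, e.g. Any[Core.Const(foo), Float64, Float64]
ir.stmts              # the InstructionStream described below
ir.cfg                # the control flow graph; `foo` consists of a single basic block
length(ir.linetable)  # number of recorded source locations
```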

Before going further, let's take a look at `InstructionStream`, defined as
```julia
struct InstructionStream
inst::Vector{Any}
type::Vector{Any}
info::Vector{CallInfo}
line::Vector{Int32}
flag::Vector{UInt8}
end
```
where
* `inst` is a vector of instructions, stored as expressions (`Expr`); the allowed values of `head` are described [here](https://docs.julialang.org/en/v1/devdocs/ast/#Expr-types)
* `type` is the type of the value returned by the corresponding statement
* `info` is a vector of `CallInfo` carrying additional information about each call (e.g. the `MethodMatchInfo` recorded during inference, or `NoCallInfo`)
* `line` is an index into `IRCode.linetable` identifying the source line the statement comes from
* `flag` is a set of bit flags providing additional information about the statement:
  - `0x01 << 0` = statement is marked as `@inbounds`
  - `0x01 << 1` = statement is marked as `@inline`
  - `0x01 << 2` = statement is marked as `@noinline`
  - `0x01 << 3` = statement is within a block that leads to a `throw` call
  - `0x01 << 4` = statement may be removed if its result is unused; in particular, it is both pure and effect-free
  - `0x01 << 5-6` = `<unused>`
  - `0x01 << 7` = `<reserved>` statement has out-of-band info
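
A bit test against these flags might look like this (a sketch; `Core.Compiler` also defines named constants for them, whose exact names and widths differ across Julia versions):
```julia
# Check whether the first statement carries the "may be removed if unused" flag (bit 4).
flag = ir.stmts.flag[1]
removable = (flag & (UInt8(1) << 4)) != 0
```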

For the above `foo` function, the `InstructionStream` looks like this (shown here with the `DataFrames` package):

```julia
julia> DataFrame(flag = ir.stmts.flag, info = ir.stmts.info, inst = ir.stmts.inst, line = ir.stmts.line, type = ir.stmts.type)
4×5 DataFrame
 Row │ flag   info                              inst          line   type
     │ UInt8  CallInfo                          Any           Int32  Any
─────┼────────────────────────────────────────────────────────────────────────
   1 │   112  MethodMatchInfo(MethodLookupResu  _2 * _3           1  Float64
   2 │    80  MethodMatchInfo(MethodLookupResu  Main.sin(_2)      2  Float64
   3 │   112  MethodMatchInfo(MethodLookupResu  %1 + %2           2  Float64
   4 │     0  NoCallInfo()                      return %3         2  Any
```
We can index into the statements as `ir.stmts[1]`, which provides a "view" into the vector. To obtain the first instruction, we can do `ir.stmts[1][:inst]`.
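
For example (a sketch of the REPL output; the exact printing may differ):
```julia
ir.stmts[1]          # an Instruction "view" into row 1 of the stream
ir.stmts[1][:inst]   # :(_2 * _3)
ir.stmts[1][:type]   # Float64
```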
The `IRCode` is typed, but the fields can contain `Any`. It is up to the user to provide correct types of the output and there are no helper functions to perform the typing; a workaround is shown in the Petite Diffractor project. The section of Julia's developer documentation on SSA IR, https://docs.julialang.org/en/v1/devdocs/ssair/, seems incredibly useful; the IR form it talks about appears to be `Core.Compiler.IRCode`.
It seems to be possible to insert new IR instructions by queuing them into the `new_nodes` field (e.g. via `CC.insert_node!`) and then calling `compact!`, which performs the heavy machinery of relabelling everything.
#### Example of modifying the function through IRCode
Below is an MWE that tries to modify the IRCode of a function and execute it. The goal is to change the function `foo` to `fooled`.
```julia
import Core.Compiler as CC
using Core: SSAValue, GlobalRef, ReturnNode

function foo(x,y)
z = x * y
z + sin(x)
end

function fooled(x,y)
z = x * y
z + sin(x) + cos(y)
end

(ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1"));
nr = CC.insert_node!(ir, 2, CC.NewInstruction(Expr(:call, Core.GlobalRef(Main, :cos), Core.Argument(3)), Float64))
nr2 = CC.insert_node!(ir, 4, CC.NewInstruction(Expr(:call, GlobalRef(Main, :+), SSAValue(3), nr), Float64))
CC.setindex!(ir.stmts[4], ReturnNode(nr2), :inst)
ir = CC.compact!(ir)
irfooled = Core.OpaqueClosure(ir)
irfooled(1.0, 2.0) == fooled(1.0, 2.0)
```
So what did we do?
1. `(ir, rt) = only(Base.code_ircode(foo, (Float64, Float64), optimize_until = "compact 1"))` obtains the `IRCode` of the function `foo` when called with both arguments being `Float64`. `rt` contains the return type of the function.
2. A new instruction `cos` is inserted into the `ir` by `Core.Compiler.insert_node!`, which takes as arguments an `IRCode`, a position (2 in our case), and the new instruction. The new instruction is created by `NewInstruction`, which accepts an expression (`Expr`) and a return type as inputs. Here we force the type to be `Float64`, but ideally it should be inferred (that would be the next stage; or maybe we can run it through type inference?). The new instruction is added to the `ir.new_nodes` instruction stream and we obtain a new SSA value, returned in `nr`, which can then be used further.
3. We add one more instruction, `+`, that uses the output of the instruction added in step 2 (`nr`) and the SSA value of statement 3 of the original IR (at this moment the IR is still numbered with respect to the old IR; the renumbering happens later). The output of this second instruction is returned in `nr2`.
4. Then we rewrite the return statement to return `nr2` instead of `SSAValue(3)`.
5. `ir = CC.compact!(ir)` is super important, since it moves the newly added statements from `ir.new_nodes` to `ir.stmts` and, importantly, renumbers the `SSAValue`s. *Even though the function is mutating, "mutating" here means that the argument is changed; the new, correct `IRCode` is returned and therefore has to be reassigned.*
6. The function is created through `OpaqueClosure`.
7. The last line verifies that the new function does what it should.
There is no infrastructure to make the above manipulation transparent, as there is for `@generated` functions and `CodeInfo`. It is possible to hook in through a generated function by converting the `IRCode` to an untyped `CodeInfo`, in which case you do not have to bother with the typing.
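
A hedged sketch of that conversion route, using the two-argument `CC.replace_code_newstyle!` call that also appears in `petite_diffractor.jl` below (the exact signatures vary across Julia versions):
```julia
# Sketch: write the statements of the modified `ir` back into a CodeInfo, which a
# @generated function could then return. Assumes `foo` and `ir` from the example above.
world = Base.get_world_counter()
mi = only(Base.method_instances(foo, (Float64, Float64), world))
ci = Base.uncompressed_ir(mi.def::Method)
CC.replace_code_newstyle!(ci, ir)
```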
#### How to obtain code info the proper way?
This is the way the code info is obtained in Diffractor.jl:
```julia
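# `sig` is the full call signature, e.g. Tuple{typeof(foo), Float64, Float64}, and
# `world` is a world-age counter, e.g. Base.get_world_counter(); both are assumed to be defined.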
mthds = Base._methods_by_ftype(sig, -1, world)
match = only(mthds)

mi = Core.Compiler.specialize_method(match)
ci = Core.Compiler.retrieve_code_info(mi, world)
```
### CodeInfo
`IRTools.jl` is great for modifying `CodeInfo`. I have found two tools for modifying `IRCode`, and I wonder if they have been abandoned because they were both dead ends or because of a lack of human labour. Also, [this notebook](https://nbviewer.org/gist/tkf/d4734be24d2694a3afd669f8f50e6b0f/00_notebook.ipynb) is quite a cool playground for IR stuff.
Resources
* https://vchuravy.dev/talks/licm/
* [CompilerPluginTools](https://github.com/JuliaCompilerPlugins/CompilerPluginTools.jl)
* [CodeInfoTools.jl](https://github.com/JuliaCompilerPlugins/CodeInfoTools.jl).
* TKF's [ShowCode.jl](https://github.com/tkf/ShowCode.jl) is nice for visualization of the IRCode
* Diffractor is an awesome source of how-tos, for example the function `my_insert_node!` in `src/stage1/hacks.jl`
* https://nbviewer.org/gist/tkf/d4734be24d2694a3afd669f8f50e6b0f/00_notebook.ipynb
* https://github.com/JuliaCompilerPlugins/Mixtape.jl
202 changes: 202 additions & 0 deletions docs/src/lecture_09/petite_diffractor.jl
@@ -0,0 +1,202 @@
# A simple reverse-mode AD.
# Lots of simplifications have been made (in particular, there is no support for
# control flow). But this illustrates most of the principles behind Zygote.
# https://fluxml.ai/Zygote.jl/dev/internals/


#####

# We assume we have a set of AD rules (e.g. ChainRules), which for a given function return its evaluation and a pullback. If we are tasked with computing the gradient:

# 1. If a rule exists for this function, directly return the rule.
# 2. If not, deconstruct the function into a sequence of calls by asking for its `IRCode`.
# 3. Replace statements by calls to obtain the evaluation of the statements and the pullback.
# 4. Chain pullbacks in reverse order.
# 5. Return the function evaluation and the chained pullback.

# The idea is that we will replace each statement of `foo` with a statement returning the function value and the pullback. At the moment, and for simplicity, we assume that an appropriate chain rule is defined. Moreover, we need to keep track of the mapping from old SSAValues to new SSAValues in `ssamap`, since their values will differ.

import Core.Compiler as CC
using ChainRules
using Core: SSAValue, GlobalRef, ReturnNode


function get_ciir(f, sig; world = Core.Compiler.get_world_counter(), optimize_until = "compact 1")
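# return a copy of the untyped CodeInfo, the typed IRCode, and the inferred return type of `f` for `sig`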
mi = only(Base.method_instances(f, sig, world))
ci = Base.uncompressed_ir(mi.def::Method)
(ir, rt) = only(Base.code_ircode(f, sig; optimize_until))
(copy(ci), ir, rt)
end



# struct Pullback{S,T}
# pullbacks::T
# end


argtype(ir::CC.IRCode, a::Core.Argument) = ir.argtypes[a.n]
argtype(ir::CC.IRCode, a::Core.SSAValue) = ir.stmts.type[a.id]
argtype(ir::CC.IRCode, f::GlobalRef) = typeof(eval(f))
argtype(ir::CC.IRCode, a) = error("argtype of $(typeof(a)) not supported")

"""
type_of_pullback(ir, inst)
infer type of the pullback
"""
function type_of_pullback(ir, inst, optimize_until = "compact 1")
inst.head != :call && error("inferring the return type is only supported for call expressions")
params = tuple([argtype(ir, a) for a in inst.args]...)
(ir, rt) = only(Base.code_ircode(ChainRules.rrule, params; optimize_until))
if !(rt <: Tuple{A,B} where {A,B})
error("The return type of the pullback `ChainRules.rrule($(params))` should be a tuple")
end
rt
end
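
# remap substitutes old SSAValues in the arguments with the new SSAValues recorded in the dictionary `d`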

remap(d, args::Tuple) = map(a -> remap(d,a), args)
remap(d, args::Vector) = map(a -> remap(d,a), args)
remap(d, r::ReturnNode) = ReturnNode(remap(d, r.val))
remap(d, x::SSAValue) = d[x]
remap(d, x) = x

function forward(ir)
pullbacks = []
new_insts = Any[]
new_line = Int32[]
new_types = Any[]
ssamap = Dict{SSAValue,SSAValue}()
fval_ssa = nothing
for (i, stmt) in enumerate(ir.stmts)
inst = stmt[:inst]
if inst isa Expr && inst.head == :call
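# replace the original call `f(args...)` with `rrule(f, args...)`; rrule returns
# (value, pullback), which the two getindex statements below split apart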
new_inst = Expr(:call, GlobalRef(ChainRules, :rrule), remap(ssamap, inst.args)...)
tt = type_of_pullback(ir, inst)
push!(new_insts, new_inst)
push!(new_line, stmt[:line])
push!(new_types, tt)
rrule_ssa = SSAValue(length(new_insts))


push!(new_insts, Expr(:call, :getindex, rrule_ssa, 1))
push!(new_line, stmt[:line])
push!(new_types, tt.parameters[1])
val_ssa = SSAValue(length(new_insts))
ssamap[SSAValue(i)] = val_ssa
(stmt[:type] != tt.parameters[1]) && @info("pullback of $(inst) has a different type than normal inst")

push!(new_insts, Expr(:call, :getindex, rrule_ssa, 2))
pullback_ssa = SSAValue(length(new_insts))
push!(new_line, stmt[:line])
push!(new_types, tt.parameters[2])
push!(pullbacks, pullback_ssa)
continue
end

if inst isa ReturnNode
fval_ssa = remap(ssamap, inst.val)
continue
end
error("unknown node $(i)")
end

# construct tuple with all pullbacks
push!(new_insts, Expr(:call, :tuple, pullbacks...))
pull_ssa = SSAValue(length(new_insts))
push!(new_line, new_line[end])
push!(new_types, Tuple{[new_types[x.id] for x in pullbacks]...})

# construct the tuple containing forward and reverse
push!(new_insts, Expr(:call, :tuple, fval_ssa, pull_ssa))
ret_ssa = SSAValue(length(new_insts))
push!(new_line, new_line[end])
push!(new_types, Tuple{new_types[fval_ssa.id], new_types[pull_ssa.id]})

# put a return statement
push!(new_insts, ReturnNode(ret_ssa))
push!(new_line, new_line[end])
push!(new_types, Any)

# this nightmare constructs the IRCode with absolutely useless type information
is = CC.InstructionStream(
new_insts, # inst::Vector{Any}
new_types, # type::Vector{Any}
fill(CC.NoCallInfo(), length(new_insts)), # info::Vector{CallInfo}
new_line, # line::Vector{Int32}
fill(UInt8(0), length(new_insts)), # flag::Vector{UInt8}
)
cfg = CC.compute_basic_blocks(new_insts)
new_ir = CC.IRCode(is, cfg, ir.linetable, ir.argtypes, ir.meta, ir.sptypes)
end





function foo(x,y)
z = x * y
z + sin(x)
end


(ci, ir, rt) = get_ciir(foo, (Float64, Float64))
new_ir = forward(ir)
CC.replace_code_newstyle!(ci, ir)

forw = Core.OpaqueClosure(forward(ir))
fval, pullbacks = forw(1.0,1.0)

(1.0,1.0)


"""
    reverse(ir)

We construct the reverse pass using the original `ir` code, since we can obtain it from the
parameter `S` of `Pullback{S,T}`; `S` can contain `(foo, Float64, Float64)` when
we compute the gradient of `foo`.
"""
function reverse(ir)
diffmap = Dict{Any,Any}() # maps an SSA value (or argument) to the SSA value holding its gradient
# the argument of the pullback we are defining is the gradient with respect to the returned value,
# which we assume to be produced by the last instruction in `inst`
@assert ir.stmts.inst[end] isa ReturnNode
diffmap[ir.stmts.inst[end].val] = Core.Argument(2)

reverse_inst = []

# the first instruction will be to get the structure with pullbacks from the first argument
push!(reverse_inst, Expr(:call, GlobalRef(Core, :getfield), Core.Argument(1), :pullbacks))
pullbacks_ssa = SSAValue(length(reverse_inst))

# we should filter statements from which we can get the pullbacks, but for the trivial
# function without control flow this is not needed

# now we iterate over pullbacks and execute one by one with correct argument
for i in (length(ir.stmts)-1):-1:1
inst = ir.stmts[i][:inst]
val_ssa = SSAValue(i)

#first, we get the pullback
push!(reverse_inst, Expr(:call, GlobalRef(Base, :getindex), pullbacks_ssa, i))
pullback_ssa = SSAValue(length(reverse_inst))

#execute pullback
push!(reverse_inst, Expr(:call, pullback_ssa, diffmap[val_ssa]))
arg_grad = SSAValue(length(reverse_inst))
for (j, a) in enumerate(inst.args)
j == 1 && continue # we omit gradient with respect to the name of the function and rrule
if haskey(diffmap, a) # we need to perform addition
push!(reverse_inst, Expr(:call, GlobalRef(Base, :getindex), arg_grad, j))
sv = SSAValue(length(reverse_inst))
push!(reverse_inst, Expr(:call, GlobalRef(Base, :+), sv, diffmap[a]))
diffmap[a] = SSAValue(length(reverse_inst))
else
push!(reverse_inst, Expr(:call, GlobalRef(Base, :getindex), arg_grad, j))
diffmap[a] = SSAValue(length(reverse_inst))
end
end
end

end
