diff --git a/Project.toml b/Project.toml index de06d941..aa8ccaee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FeynmanDiagram" uuid = "e424a512-dbd9-41ff-9883-094748823e72" authors = ["Kun Chen", "Pengcheng Hou", "Daniel Cerkoney"] -version = "1.0.1" +version = "1.0.2" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/backend/compiler_python.jl b/src/backend/compiler_python.jl index 1efc40c9..0cf48816 100644 --- a/src/backend/compiler_python.jl +++ b/src/backend/compiler_python.jl @@ -6,28 +6,28 @@ # Arguments: - `graphs` vector of computational graphs """ -function to_python_str(graphs::AbstractVector{<:AbstractGraph}) +function to_python_str(graphs::AbstractVector{<:AbstractGraph}; + root::AbstractVector{Int}=[id(g) for g in graphs], name::String="eval_graph", in_place::Bool=false) head = "" body = "" leafidx = 0 - root = [id(g) for g in graphs] inds_visitedleaf = Int[] inds_visitednode = Int[] - gid_to_leafid = Dict{String,Int64}() - rootidx = 0 + map_validx_leaf = Dict{Int,eltype(graphs)}() # mapping from the index of the leafVal to the leaf graph for graph in graphs for g in PostOrderDFS(graph) #leaf first search g_id = id(g) target = "g$(g_id)" isroot = false if g_id in root + target_root = "root[:, $(findfirst(x -> x == g_id, root)-1)]" isroot = true end if isempty(subgraphs(g)) #leaf g_id in inds_visitedleaf && continue - body *= " $target = leaf[$(leafidx)]\n" - gid_to_leafid[target] = leafidx + body *= " $target = leafVal[:, $(leafidx)]\n" leafidx += 1 + map_validx_leaf[leafidx] = g push!(inds_visitedleaf, g_id) else g_id in inds_visitednode && continue @@ -35,23 +35,25 @@ function to_python_str(graphs::AbstractVector{<:AbstractGraph}) push!(inds_visitednode, g_id) end if isroot - body *= " root$(rootidx) = $target\n" - rootidx += 1 + body *= " $target_root = $target\n" end end end - head *= "def graphfunc(leaf):\n" - output = ["root$(i)" for i in 0:rootidx-1] - output = join(output, ",") - tail = " return $output\n\n" - - expr = head * body * tail + if in_place + head *= "def $name(root, leafVal):\n" + else + head *= "import torch\n" + head *= "def $name(leafVal):\n" + head *= " root = torch.empty(leafVal.shape[0], $(length(graphs)), dtype=leafVal.dtype, device=leafVal.device)\n" + end + tail = " return root\n\n" - return expr, gid_to_leafid + return head * body * tail, map_validx_leaf end -function compile_Python(graphs::AbstractVector{<:AbstractGraph}, filename::String="GraphFunc.py") - py_string, leafmap = to_python_str(graphs) - open(filename, "w") do f +function compile_Python(graphs::AbstractVector{<:AbstractGraph}, filename::String; + root::AbstractVector{Int}=[id(g) for g in graphs], func_name="eval_graph") + py_string, leafmap = to_python_str(graphs, root=root, name=func_name) + open(filename, "a") do f write(f, py_string) end return leafmap