From f8829af13b3d43e8b8ddbf81bef557716d40cd7e Mon Sep 17 00:00:00 2001 From: ZhiyiLi Date: Mon, 20 Nov 2023 22:50:08 +0800 Subject: [PATCH] add mindspore.jl --- src/backend/compiler.jl | 1 + src/backend/mindspore.jl | 119 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 src/backend/mindspore.jl diff --git a/src/backend/compiler.jl b/src/backend/compiler.jl index 6559b418..10dedb99 100644 --- a/src/backend/compiler.jl +++ b/src/backend/compiler.jl @@ -8,5 +8,6 @@ using ..RuntimeGeneratedFunctions RuntimeGeneratedFunctions.init(Compilers) include("static.jl") +include("mindspore.jl") end \ No newline at end of file diff --git a/src/backend/mindspore.jl b/src/backend/mindspore.jl new file mode 100644 index 00000000..4f2b1715 --- /dev/null +++ b/src/backend/mindspore.jl @@ -0,0 +1,119 @@ +using PyCall +ms = pyimport("mindspore") + +""" + function to_static(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector) + +Returns the static representation of a computational graph node `g` with operator `operator`, subgraphs `subgraphs`, and subgraph factors `subgraph_factors`. +""" +function to_pystatic(operator::Type, subgraphs::AbstractVector{<:AbstractGraph}, subgraph_factors::AbstractVector) + error( + "Static representation for computational graph nodes with operator $(operator) not yet implemented! " * + "Please define a method `to_static(::Type{$(operator)}, subgraphs::$(typeof(subgraphs)), subgraph_factors::$(typeof(subgraph_factors)))`." + ) +end + +function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " + ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " * ") * ")" + # return "(" * join(["g$(g.id)" for g in subgraphs], " * ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Graph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "((g$(subgraphs[1].id))**$N$factor_str)" +end + +function to_pystatic(::Type{ComputationalGraphs.Sum}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " + ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Prod}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {F,W} + if length(subgraphs) == 1 + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "(g$(subgraphs[1].id)$factor_str)" + else + terms = ["g$(g.id)" * (gfactor == 1 ? "" : " * $gfactor") for (g, gfactor) in zip(subgraphs, subgraph_factors)] + return "(" * join(terms, " * ") * ")" + end +end + +function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{FeynmanGraph{F,W}}, subgraph_factors::Vector{F}) where {N,F,W} + factor_str = subgraph_factors[1] == 1 ? "" : " * $(subgraph_factors[1])" + return "((g$(subgraphs[1].id))**$N$factor_str)" +end + +function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph}) + head = "import mindspore as ms\n@ms.jit\n" + head *= "def graphfunc():\n" + body = " graph_list = []\n" + leafidx = 1 + root = [id(g) for g in graphs] + inds_visitedleaf = Int[] + inds_visitednode = Int[] + for graph in graphs + for g in PostOrderDFS(graph) #leaf first search + g_id = id(g) + target = "g$(g_id)" + isroot = false + if g_id in root + isroot = true + end + if isempty(subgraphs(g)) #leaf + g_id in inds_visitedleaf && continue + factor_str = factor(g) == 1 ? "" : " * $(factor(g))" + body *= " $target = ms.Tensor(1.0)$factor_str\n" + leafidx += 1 + push!(inds_visitedleaf, g_id) + else + g_id in inds_visitednode && continue + factor_str = factor(g) == 1 ? "" : " * $(factor(g))" + body *= " $target = $(to_pystatic(operator(g), subgraphs(g), subgraph_factors(g)))$factor_str\n" + push!(inds_visitednode, g_id) + end + if isroot + body *= " graph_list.append($target)\n" + end + end + end + tail = " return graph_list\n" + tail *= "output = graphfunc()" + expr = head * body * tail + # return head * body * tail + f = open("GraphFunc.py","w") + write(f,expr) +end + + +# function to_mindspore_graph(graphs::AbstractVector{<:AbstractGraph}) +# pyexpr = to_python_str_ms(graphs) +# py""" +# import mindspore as ms +# exec($pyexpr) +# ms_graph = jit(fn=graphfunc) +# out = ms_graph() +# """ +# return py"out" +# end +