Skip to content

Commit

Permalink
Merge pull request #163 from numericalEFT/computgraph_zhiyi
Browse files Browse the repository at this point in the history
refactor `compiler_python.jl` for supporting both `jax` and `mindspore` framework
  • Loading branch information
peter0627ustc authored Dec 6, 2023
2 parents dd29172 + f398cd0 commit cfe8054
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/backend/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ using ..RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(Compilers)

include("static.jl")
include("toMindspore.jl")
include("compiler_python.jl")

end
21 changes: 14 additions & 7 deletions src/backend/toMindspore.jl → src/backend/compiler_python.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,21 @@ function to_pystatic(::Type{ComputationalGraphs.Power{N}}, subgraphs::Vector{Fey
end

"""
function to_julia_str(graphs::AbstractVector{<:AbstractGraph})
Compile a list of graphs into a string for a python static function and output a python script which support the static graph representation in mindspore framework.
function to_python_str(graphs::AbstractVector{<:AbstractGraph})
Compile a list of graphs into a string for a python static function and output a python script which support the mindspore and jax framework.
# Arguments:
- `graphs` vector of computational graphs
- `framework` the type of the python frameworks, including `:jax` and `mindspore`.
"""
function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
head = "import mindspore as ms\n@ms.jit\n"
# head *= "def graphfunc(leaf):\n"
# body = " graph_list = []\n"
function to_python_str(graphs::AbstractVector{<:AbstractGraph}, framework::Symbol=:jax)
if framework == :jax
head = ""
elseif framework == :mindspore
head = "import mindspore as ms\n@ms.jit\n"
else
error("no support for $type framework")
end
body = ""
leafidx = 1
root = [id(g) for g in graphs]
Expand Down

0 comments on commit cfe8054

Please sign in to comment.