Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort model equations alphabetically and wrap equations with more than 5 terms #6560

Merged
merged 5 commits into from
Feb 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 58 additions & 26 deletions packages/mira/tasks/generate_model_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@ def main():
# Generate LaTeX code string from MMT model
# =========================================

odeterms = {var: 0 for var in model.get_concepts_name_map().keys()}
odeterms = {var: [] for var in sorted(model.get_concepts_name_map().keys())}

for template in model.templates:
if hasattr(template, "subject"):
var = template.subject.name
odeterms[var] -= template.rate_law.args[0]
odeterms[var].append(-template.rate_law.args[0])

if hasattr(template, "outcome"):
var = template.outcome.name
odeterms[var] += template.rate_law.args[0]
odeterms[var].append(template.rate_law.args[0])

# Sort the terms such that all negative ones come first
odeterms = {var: sorted(terms, key = lambda term: str(term)) for var, terms in odeterms.items()}

# Time
if model.time and model.time.name:
Expand All @@ -41,40 +44,69 @@ def main():
time = "t"
t = sympy.Symbol(time)

# Construct Sympy equations
odesys = []
# Add "(t)" to all the state variables as time-dependent functions
for var, terms in odeterms.items():
for i, term in enumerate(terms):
if hasattr(term, 'atoms'):
for atom in term.atoms(sympy.Symbol):
if str(atom) in odeterms.keys():
term = term.subs(atom, sympy.Function(str(atom))(t))
terms[i] = term


# Construct equations
num_terms = 5 # Max number of terms per line in LaTeX align
odesys = []
exprs = ""
for i, (var, terms) in enumerate(odeterms.items()):

lhs = sympy.diff(sympy.Function(var)(t), t)
rhs = sum(terms)
exprs += sympy.latex(lhs) + " ={}& "

# Few equation terms = no wrapping needed
if len(terms) < num_terms:
exprs += sympy.latex(rhs)

# otherwise, wrap around
else:
rhs = [sympy.latex(sum(terms[j:(j + num_terms)])) for j in range(0, len(terms), num_terms)]
rhs = [line if (j == 0) | (line[0] == '-') else "+ " + line for j, line in enumerate(rhs)] # Add '+ ' to all lines past the first if not start with '- '
exprs += " \\\\ \n &".join(rhs)

# Write (time-dependent) symbols with "(t)"
rhs = terms
if hasattr(terms, 'atoms'):
for atom in terms.atoms(sympy.Symbol):
if str(atom) in odeterms.keys():
rhs = rhs.subs(atom, sympy.Function(str(atom))(t))
if i < (len(odeterms) - 1):
exprs += " \\\\ \n"

odesys = [exprs]

odesys.append(sympy.latex(sympy.Eq(lhs, rhs)))

# Observables
# Repeat for observables if present
if len(model.observables) > 0:

# Write (time-dependent) symbols with "(t)"
obs_eqs = []
for obs in model.observables.values():
lhs = sympy.Function(obs.name)(t)
terms = obs.expression.args[0]
rhs = terms
if hasattr(terms, 'atoms'):
for atom in terms.atoms(sympy.Symbol):
# Sort observables alphabetically
observables = {obs: model.observables[obs].expression.args[0] for obs in sorted(model.observables.keys())}

# Add "(t)" for all the state variables as time-dependent symbols
for obs, expr in observables.items():
if hasattr(expr, 'atoms'):
for atom in expr.atoms(sympy.Symbol):
if str(atom) in odeterms.keys():
rhs = rhs.subs(atom, sympy.Function(str(atom))(t))
obs_eqs.append(sympy.latex(sympy.Eq(lhs, rhs)))
expr = expr.subs(atom, sympy.Function(str(atom))(t))
observables[obs] = expr

for i, (obs, expr) in enumerate(observables.items()):
lhs = sympy.Function(obs)(t)
rhs = expr
exprs = " " + sympy.latex(lhs) + " ={}& " + sympy.latex(rhs)
if i == 0:
exprs = " \\\\ \n" + exprs
if i < (len(observables) - 1):
exprs += " \\\\ \n"

# Add observables
odesys += obs_eqs
odesys[0] += exprs

# Reformat:
odesys = "\\begin{align} \n " + " \\\\ \n ".join([eq for eq in odesys]) + "\n\\end{align}"
odesys = "\\begin{align*} \n " + odesys[0] + "\n\\end{align*}"
# =========================================

taskrunner.write_output_dict_with_timeout({"response": odesys})
Expand Down