From d3a28510f06517cc5b71b39ef6dc096bb6a9eca9 Mon Sep 17 00:00:00 2001 From: Nelson Liu Date: Mon, 10 Feb 2025 17:06:03 -0500 Subject: [PATCH 1/3] Sort the LHS variable name --- packages/mira/tasks/generate_model_latex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/mira/tasks/generate_model_latex.py b/packages/mira/tasks/generate_model_latex.py index 4ea7053b79..c7b4c7ee8e 100644 --- a/packages/mira/tasks/generate_model_latex.py +++ b/packages/mira/tasks/generate_model_latex.py @@ -23,7 +23,7 @@ def main(): # Generate LaTeX code string from MMT model # ========================================= - odeterms = {var: 0 for var in model.get_concepts_name_map().keys()} + odeterms = {var: 0 for var in sorted(model.get_concepts_name_map().keys())} for template in model.templates: if hasattr(template, "subject"): From f23413f3f31879ef4310bb87d4346a1e5bdd648c Mon Sep 17 00:00:00 2001 From: Nelson Liu Date: Mon, 10 Feb 2025 17:07:31 -0500 Subject: [PATCH 2/3] Sort the terms of each ODE --- packages/mira/tasks/generate_model_latex.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/packages/mira/tasks/generate_model_latex.py b/packages/mira/tasks/generate_model_latex.py index c7b4c7ee8e..fee5197f93 100644 --- a/packages/mira/tasks/generate_model_latex.py +++ b/packages/mira/tasks/generate_model_latex.py @@ -23,16 +23,19 @@ def main(): # Generate LaTeX code string from MMT model # ========================================= - odeterms = {var: 0 for var in sorted(model.get_concepts_name_map().keys())} + odeterms = {var: [] for var in sorted(model.get_concepts_name_map().keys())} for template in model.templates: if hasattr(template, "subject"): var = template.subject.name - odeterms[var] -= template.rate_law.args[0] + odeterms[var].append(-template.rate_law.args[0]) if hasattr(template, "outcome"): var = template.outcome.name - odeterms[var] += template.rate_law.args[0] + odeterms[var].append(template.rate_law.args[0]) + + # Sort the terms such that all negative ones come first + odeterms = {var: sorted(terms, key = lambda term: str(term)) for var, terms in odeterms.items()} # Time if model.time and model.time.name: From fb9fe90cd3d9b9119051b482ad4a7e9e97bfb745 Mon Sep 17 00:00:00 2001 From: Nelson Liu Date: Mon, 10 Feb 2025 23:40:46 -0500 Subject: [PATCH 3/3] Wrap long equations and sort observables --- packages/mira/tasks/generate_model_latex.py | 75 ++++++++++++++------- 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/packages/mira/tasks/generate_model_latex.py b/packages/mira/tasks/generate_model_latex.py index fee5197f93..bf7d46f0b9 100644 --- a/packages/mira/tasks/generate_model_latex.py +++ b/packages/mira/tasks/generate_model_latex.py @@ -44,40 +44,69 @@ def main(): time = "t" t = sympy.Symbol(time) - # Construct Sympy equations - odesys = [] + # Add "(t)" to all the state variables as time-dependent functions for var, terms in odeterms.items(): + for i, term in enumerate(terms): + if hasattr(term, 'atoms'): + for atom in term.atoms(sympy.Symbol): + if str(atom) in odeterms.keys(): + term = term.subs(atom, sympy.Function(str(atom))(t)) + terms[i] = term + + + # Construct equations + num_terms = 5 # Max number of terms per line in LaTeX align + odesys = [] + exprs = "" + for i, (var, terms) in enumerate(odeterms.items()): + lhs = sympy.diff(sympy.Function(var)(t), t) + rhs = sum(terms) + exprs += sympy.latex(lhs) + " ={}& " - # Write (time-dependent) symbols with "(t)" - rhs = terms - if hasattr(terms, 'atoms'): - for atom in terms.atoms(sympy.Symbol): - if str(atom) in odeterms.keys(): - rhs = rhs.subs(atom, sympy.Function(str(atom))(t)) + # Few equation terms = no wrapping needed + if len(terms) < num_terms: + exprs += sympy.latex(rhs) - odesys.append(sympy.latex(sympy.Eq(lhs, rhs))) + # otherwise, wrap around + else: + rhs = [sympy.latex(sum(terms[j:(j + num_terms)])) for j in range(0, len(terms), num_terms)] + rhs = [line if (j == 0) | (line[0] == '-') else "+ " + line for j, line in enumerate(rhs)] # Add '+ ' to all lines past the first if not start with '- ' + exprs += " \\\\ \n &".join(rhs) - # Observables + if i < (len(odeterms) - 1): + exprs += " \\\\ \n" + + odesys = [exprs] + + + # Repeat for observables if present if len(model.observables) > 0: - # Write (time-dependent) symbols with "(t)" - obs_eqs = [] - for obs in model.observables.values(): - lhs = sympy.Function(obs.name)(t) - terms = obs.expression.args[0] - rhs = terms - if hasattr(terms, 'atoms'): - for atom in terms.atoms(sympy.Symbol): + # Sort observables alphabetically + observables = {obs: model.observables[obs].expression.args[0] for obs in sorted(model.observables.keys())} + + # Add "(t)" for all the state variables as time-dependent symbols + for obs, expr in observables.items(): + if hasattr(expr, 'atoms'): + for atom in expr.atoms(sympy.Symbol): if str(atom) in odeterms.keys(): - rhs = rhs.subs(atom, sympy.Function(str(atom))(t)) - obs_eqs.append(sympy.latex(sympy.Eq(lhs, rhs))) + expr = expr.subs(atom, sympy.Function(str(atom))(t)) + observables[obs] = expr + + for i, (obs, expr) in enumerate(observables.items()): + lhs = sympy.Function(obs)(t) + rhs = expr + exprs = " " + sympy.latex(lhs) + " ={}& " + sympy.latex(rhs) + if i == 0: + exprs = " \\\\ \n" + exprs + if i < (len(observables) - 1): + exprs += " \\\\ \n" - # Add observables - odesys += obs_eqs + odesys[0] += exprs # Reformat: - odesys = "\\begin{align} \n " + " \\\\ \n ".join([eq for eq in odesys]) + "\n\\end{align}" + odesys = "\\begin{align*} \n " + odesys[0] + "\n\\end{align*}" # ========================================= taskrunner.write_output_dict_with_timeout({"response": odesys})