Skip to content

Commit

Permalink
Store all information about the AST in FormulaAST (#212)
Browse files Browse the repository at this point in the history
* Store all information about the AST in FormulaAST

Before this commit, for some binary and unary operations
FormulaAST was storing opaque function pointers,
losing the information about what the operation was.
Now we always store the kind of operation as part
of the AST node.

This patch also makes treatment of different kinds
of unary and binary operations more uniform and,
during AST evaluation, removes the indirection
due to calling code through the function pointer.

* Remove Undefined Formula NodeType, it's never used
  • Loading branch information
eguiraud authored Oct 23, 2023
1 parent de34794 commit faf8e9d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 58 deletions.
45 changes: 30 additions & 15 deletions include/correction.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,8 @@ class FormulaAst {
Literal,
Variable,
Parameter,
UnaryCall,
BinaryCall,
UAtom,
Expression,
Undefined,
Unary,
Binary,
};
enum class BinaryOp {
Equal,
Expand All @@ -67,21 +64,40 @@ class FormulaAst {
Div,
Times,
Pow,
Atan2,
Max,
Min
};
enum class UnaryOp { Negative };
typedef double (*UnaryFcn)(double);
typedef double (*BinaryFcn)(double, double);
typedef std::variant<
enum class UnaryOp {
Negative,
Log,
Log10,
Exp,
Erf,
Sqrt,
Abs,
Cos,
Sin,
Tan,
Acos,
Asin,
Atan,
Cosh,
Sinh,
Tanh,
Acosh,
Asinh,
Atanh
};
using NodeData = std::variant<
std::monostate,
double, // literal/parameter
size_t, // parameter/variable index
UnaryOp,
BinaryOp,
UnaryFcn,
BinaryFcn
> NodeData;
BinaryOp
>;
// TODO: try std::unique_ptr<const Ast> child1, child2 or std::array
typedef std::vector<FormulaAst> Children;
using Children = std::vector<FormulaAst>;

static FormulaAst parse(
ParserType type,
Expand All @@ -91,7 +107,6 @@ class FormulaAst {
bool bind_parameters
);

FormulaAst() : nodetype_(NodeType::Undefined) {};
FormulaAst(NodeType nodetype, NodeData data, Children children) :
nodetype_(nodetype), data_(data), children_(children) {};
double evaluate(const std::vector<Variable::Type>& variables, const std::vector<double>& parameters) const;
Expand Down
99 changes: 56 additions & 43 deletions src/formula_ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,58 +115,58 @@ namespace {
if ( opname == "-" ) { op = FormulaAst::UnaryOp::Negative; }
else { throw std::runtime_error("Unrecognized unary operation: " + std::string(opname)); }
return {
FormulaAst::NodeType::UAtom,
FormulaAst::NodeType::Unary,
op,
{translate_tformula_ast(ast->nodes[1], context)}
};
}
else if (ast->name == "CALLU" ) {
if ( ast->nodes.size() != 2 ) { throw std::runtime_error("CALLU without 2 nodes?"); }
FormulaAst::UnaryFcn fun;
FormulaAst::UnaryOp op;
auto name = ast->nodes[0]->token;
// TODO: lookup in static map
if ( name == "log" ) { fun = [](double x) { return std::log(x); }; }
else if ( name == "log10" ) { fun = [](double x) { return std::log10(x); }; }
else if ( name == "exp" ) { fun = [](double x) { return std::exp(x); }; }
else if ( name == "erf" ) { fun = [](double x) { return std::erf(x); }; }
else if ( name == "sqrt" ) { fun = [](double x) { return std::sqrt(x); }; }
else if ( name == "abs" ) { fun = [](double x) { return std::abs(x); }; }
else if ( name == "cos" ) { fun = [](double x) { return std::cos(x); }; }
else if ( name == "sin" ) { fun = [](double x) { return std::sin(x); }; }
else if ( name == "tan" ) { fun = [](double x) { return std::tan(x); }; }
else if ( name == "acos" ) { fun = [](double x) { return std::acos(x); }; }
else if ( name == "asin" ) { fun = [](double x) { return std::asin(x); }; }
else if ( name == "atan" ) { fun = [](double x) { return std::atan(x); }; }
else if ( name == "cosh" ) { fun = [](double x) { return std::cosh(x); }; }
else if ( name == "sinh" ) { fun = [](double x) { return std::sinh(x); }; }
else if ( name == "tanh" ) { fun = [](double x) { return std::tanh(x); }; }
else if ( name == "acosh" ) { fun = [](double x) { return std::acosh(x); }; }
else if ( name == "asinh" ) { fun = [](double x) { return std::asinh(x); }; }
else if ( name == "atanh" ) { fun = [](double x) { return std::atanh(x); }; }
if ( name == "log" ) { op = FormulaAst::UnaryOp::Log; }
else if ( name == "log10" ) { op = FormulaAst::UnaryOp::Log10; }
else if ( name == "exp" ) { op = FormulaAst::UnaryOp::Exp; }
else if ( name == "erf" ) { op = FormulaAst::UnaryOp::Erf; }
else if ( name == "sqrt" ) { op = FormulaAst::UnaryOp::Sqrt; }
else if ( name == "abs" ) { op = FormulaAst::UnaryOp::Abs; }
else if ( name == "cos" ) { op = FormulaAst::UnaryOp::Cos; }
else if ( name == "sin" ) { op = FormulaAst::UnaryOp::Sin; }
else if ( name == "tan" ) { op = FormulaAst::UnaryOp::Tan; }
else if ( name == "acos" ) { op = FormulaAst::UnaryOp::Acos; }
else if ( name == "asin" ) { op = FormulaAst::UnaryOp::Asin; }
else if ( name == "atan" ) { op = FormulaAst::UnaryOp::Atan; }
else if ( name == "cosh" ) { op = FormulaAst::UnaryOp::Cosh; }
else if ( name == "sinh" ) { op = FormulaAst::UnaryOp::Sinh; }
else if ( name == "tanh" ) { op = FormulaAst::UnaryOp::Tanh; }
else if ( name == "acosh" ) { op = FormulaAst::UnaryOp::Acosh; }
else if ( name == "asinh" ) { op = FormulaAst::UnaryOp::Asinh; }
else if ( name == "atanh" ) { op = FormulaAst::UnaryOp::Atanh; }
else {
throw std::runtime_error("unrecognized unary function: " + std::string(name));
}
return {
FormulaAst::NodeType::UnaryCall,
fun,
FormulaAst::NodeType::Unary,
op,
{translate_tformula_ast(ast->nodes[1], context)}
};
}
else if (ast->name == "CALLB" ) {
if ( ast->nodes.size() != 3 ) { throw std::runtime_error("CALLB without 3 nodes?"); }
FormulaAst::BinaryFcn fun;
FormulaAst::BinaryOp op;
auto name = ast->nodes[0]->token;
// TODO: lookup in static map
if ( name == "atan2" ) { fun = [](double x, double y) { return std::atan2(x, y); }; }
else if ( name == "pow" ) { fun = [](double x, double y) { return std::pow(x, y); }; }
else if ( name == "max" ) { fun = [](double x, double y) { return std::max(x, y); }; }
else if ( name == "min" ) { fun = [](double x, double y) { return std::min(x, y); }; }
if ( name == "atan2" ) { op = FormulaAst::BinaryOp::Atan2; }
else if ( name == "pow" ) { op = FormulaAst::BinaryOp::Pow; }
else if ( name == "max" ) { op = FormulaAst::BinaryOp::Max; }
else if ( name == "min" ) { op = FormulaAst::BinaryOp::Min; }
else {
throw std::runtime_error("unrecognized binary function: " + std::string(name));
}
return {
FormulaAst::NodeType::BinaryCall,
fun,
FormulaAst::NodeType::Binary,
op,
{translate_tformula_ast(ast->nodes[1], context), translate_tformula_ast(ast->nodes[2], context)}
};
}
Expand All @@ -187,7 +187,7 @@ namespace {
else if ( opname == "^" ) { op = FormulaAst::BinaryOp::Pow; }
else { throw std::runtime_error("Unrecognized binary operation: " + std::string(opname)); }
return {
FormulaAst::NodeType::Expression,
FormulaAst::NodeType::Binary,
op,
{translate_tformula_ast(ast->nodes[0], context), translate_tformula_ast(ast->nodes[2], context)}
};
Expand Down Expand Up @@ -219,21 +219,31 @@ double FormulaAst::evaluate(const std::vector<Variable::Type>& values, const std
return std::get<double>(values[std::get<size_t>(data_)]);
case NodeType::Parameter:
return params[std::get<size_t>(data_)];
case NodeType::UAtom:
case NodeType::Unary: {
const auto arg = children_[0].evaluate(values, params);
switch (std::get<UnaryOp>(data_)) {
case UnaryOp::Negative: return -children_[0].evaluate(values, params);
case UnaryOp::Negative: return -arg;
case UnaryOp::Log: return std::log(arg);
case UnaryOp::Log10: return std::log10(arg);
case UnaryOp::Exp: return std::exp(arg);
case UnaryOp::Erf: return std::erf(arg);
case UnaryOp::Sqrt: return std::sqrt(arg);
case UnaryOp::Abs: return std::abs(arg);
case UnaryOp::Cos: return std::cos(arg);
case UnaryOp::Sin: return std::sin(arg);
case UnaryOp::Tan: return std::tan(arg);
case UnaryOp::Acos: return std::acos(arg);
case UnaryOp::Asin: return std::asin(arg);
case UnaryOp::Atan: return std::atan(arg);
case UnaryOp::Cosh: return std::cosh(arg);
case UnaryOp::Sinh: return std::sinh(arg);
case UnaryOp::Tanh: return std::tanh(arg);
case UnaryOp::Acosh: return std::acosh(arg);
case UnaryOp::Asinh: return std::asinh(arg);
case UnaryOp::Atanh: return std::atanh(arg);
}
case NodeType::UnaryCall:
return std::get<UnaryFcn>(data_)(
children_[0].evaluate(values, params)
);
case NodeType::BinaryCall:
return std::get<BinaryFcn>(data_)(
children_[0].evaluate(values, params), children_[1].evaluate(values, params)
);
case NodeType::Undefined:
throw std::runtime_error("Unrecognized AST node");
case NodeType::Expression: {
}
case NodeType::Binary: {
auto left = children_[0].evaluate(values, params);
auto right = children_[1].evaluate(values, params);
switch (std::get<BinaryOp>(data_)) {
Expand All @@ -248,6 +258,9 @@ double FormulaAst::evaluate(const std::vector<Variable::Type>& values, const std
case BinaryOp::Div: return left / right;
case BinaryOp::Times: return left * right;
case BinaryOp::Pow: return std::pow(left, right);
case BinaryOp::Atan2: return std::atan2(left, right);
case BinaryOp::Max: return std::max(left, right);
case BinaryOp::Min: return std::min(left, right);
};
}
default:
Expand Down

0 comments on commit faf8e9d

Please sign in to comment.