Skip to content

Commit

Permalink
Refactor field_argument and formula_statement classes
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardodebenedictis committed Jan 30, 2024
1 parent 7dab8ad commit 4910e30
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 44 deletions.
79 changes: 36 additions & 43 deletions include/statement.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,29 @@ namespace riddle
virtual std::string to_string() const = 0;
};

class field_argument final
{
public:
field_argument(id_token &&object_id, std::unique_ptr<expression> &&expr = nullptr) : object_id(std::move(object_id)), expr(std::move(expr)) {}

void execute(scope &scp, env &ctx) const;

std::string to_string() const
{
if (expr)
return object_id.to_string() + " = " + expr->to_string();
return object_id.to_string();
}

private:
id_token object_id;
std::unique_ptr<expression> expr;
};

class local_field_statement final : public statement
{
public:
local_field_statement(std::vector<id_token> &&field_type, std::vector<id_token> &&fields, std::vector<std::unique_ptr<expression>> &&initializers) : field_type(std::move(field_type)), fields(std::move(fields)), initializers(std::move(initializers)) {}
local_field_statement(std::vector<id_token> &&field_type, std::vector<field_argument> &&fields) : field_type(std::move(field_type)), fields(std::move(fields)) {}

void execute(scope &scp, env &ctx) const override;

Expand All @@ -28,43 +47,36 @@ namespace riddle
for (size_t i = 1; i < field_type.size(); ++i)
result += "." + field_type[i].to_string();
result += " " + fields[0].to_string();
if (initializers[0])
result += " = " + initializers[0]->to_string();
for (size_t i = 1; i < fields.size(); ++i)
{
result += ", " + fields[i].to_string();
if (initializers[i])
result += " = " + initializers[i]->to_string();
}
return result + ";";
}

private:
std::vector<id_token> field_type;
std::vector<id_token> fields;
std::vector<std::unique_ptr<expression>> initializers;
std::vector<field_argument> fields;
};

class assignment_statement final : public statement
{
public:
assignment_statement(std::vector<id_token> &&lhs, id_token &&object_id, std::unique_ptr<expression> &&rhs) : lhs(std::move(lhs)), object_id(std::move(object_id)), rhs(std::move(rhs)) {}
assignment_statement(std::vector<id_token> &&object_id, id_token &&field_name, std::unique_ptr<expression> &&rhs) : object_id(std::move(object_id)), field_name(std::move(field_name)), rhs(std::move(rhs)) {}

void execute(scope &scp, env &ctx) const override;

std::string to_string() const override
{
if (lhs.empty())
return object_id.to_string() + " = " + rhs->to_string() + ";";
std::string result = lhs[0].to_string();
for (size_t i = 1; i < lhs.size(); ++i)
result += "." + lhs[i].to_string();
return result + object_id.to_string() + " = " + rhs->to_string() + ";";
if (object_id.empty())
return field_name.to_string() + " = " + rhs->to_string() + ";";
std::string result = object_id[0].to_string();
for (size_t i = 1; i < object_id.size(); ++i)
result += "." + object_id[i].to_string();
return result + field_name.to_string() + " = " + rhs->to_string() + ";";
}

private:
std::vector<id_token> lhs;
id_token object_id;
std::vector<id_token> object_id;
id_token field_name;
std::unique_ptr<expression> rhs;
};

Expand Down Expand Up @@ -107,20 +119,20 @@ namespace riddle
class disjunction_statement final : public statement
{
public:
disjunction_statement(std::vector<conjunction_statement> &&blocks) : blocks(std::move(blocks)) {}
disjunction_statement(std::vector<std::unique_ptr<conjunction_statement>> &&blocks) : blocks(std::move(blocks)) {}

void execute(scope &scp, env &ctx) const override;

std::string to_string() const override
{
std::string result = blocks[0].to_string();
std::string result = blocks[0]->to_string();
for (size_t i = 1; i < blocks.size(); ++i)
result += " or " + blocks[i].to_string();
result += " or " + blocks[i]->to_string();
return result;
}

private:
std::vector<conjunction_statement> blocks;
std::vector<std::unique_ptr<conjunction_statement>> blocks;
};

class for_all_statement final : public statement
Expand All @@ -147,25 +159,6 @@ namespace riddle
std::vector<std::unique_ptr<statement>> statements;
};

class formula_argument final
{
public:
formula_argument(id_token &&object_id, std::unique_ptr<expression> &&expr) : object_id(std::move(object_id)), expr(std::move(expr)) {}

void execute(scope &scp, env &ctx) const;

std::string to_string() const
{
if (expr)
return object_id.to_string() + " = " + expr->to_string();
return object_id.to_string();
}

private:
id_token object_id;
std::unique_ptr<expression> expr;
};

class return_statement final : public statement
{
public:
Expand All @@ -187,7 +180,7 @@ namespace riddle
class formula_statement final : public statement
{
public:
formula_statement(bool is_fact, id_token &&formula_name, std::vector<id_token> &&formula_scope, id_token &&predicate_name, std::vector<formula_argument> &&arguments) : is_fact(is_fact), formula_name(std::move(formula_name)), formula_scope(std::move(formula_scope)), predicate_name(std::move(predicate_name)), arguments(std::move(arguments)) {}
formula_statement(bool is_fact, id_token &&formula_name, std::vector<id_token> &&formula_scope, id_token &&predicate_name, std::vector<field_argument> &&arguments) : is_fact(is_fact), formula_name(std::move(formula_name)), formula_scope(std::move(formula_scope)), predicate_name(std::move(predicate_name)), arguments(std::move(arguments)) {}

void execute(scope &scp, env &ctx) const override;

Expand Down Expand Up @@ -217,6 +210,6 @@ namespace riddle
id_token formula_name;
std::vector<id_token> formula_scope;
id_token predicate_name;
std::vector<formula_argument> arguments;
std::vector<field_argument> arguments;
};
} // namespace riddle
226 changes: 225 additions & 1 deletion src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,231 @@ namespace riddle

return std::make_unique<predicate_declaration>(std::move(name), std::move(params), std::move(base_predicates), std::move(stmts));
}
std::unique_ptr<statement> parser::parse_statement() {}
std::unique_ptr<statement> parser::parse_statement()
{
switch (tk->sym)
{
case BOOL_ID:
case INT_ID:
case REAL_ID:
case TIME_ID:
case STRING_ID:
{ // a local field having a primitive type..
std::vector<id_token> field_type;
std::vector<field_argument> fields;

field_type.emplace_back(tk->to_string(), tk->start_line, tk->start_pos, tk->end_line, tk->end_pos);
tk = next_token();

do
{ // the fields..
if (!match(ID_ID))
error("Expected identifier..");

auto name = *static_cast<const id_token *>(tokens[pos - 2].get()); // the name of the field..

if (match(EQ_ID))
fields.emplace_back(std::move(name), parse_expression());
else
fields.emplace_back(std::move(name));
} while (match(COMMA_ID));

if (!match(SEMICOLON_ID))
error("Expected `;`..");

return std::make_unique<local_field_statement>(std::move(field_type), std::move(fields));
}
case ID_ID:
{ // either a local field, an assignment or an expression..
size_t c_pos = pos;
std::vector<id_token> object_id;
object_id.emplace_back(*static_cast<const id_token *>(tk));
tk = next_token();
while (match(DOT_ID))
{
if (!match(ID_ID))
error("Expected identifier..");
object_id.emplace_back(*static_cast<const id_token *>(tokens[pos - 2].get()));
}

switch (tk->sym)
{
case ID_ID:
{ // a local field..
std::vector<field_argument> fields;

do
{ // the fields..
if (!match(ID_ID))
error("Expected identifier..");

auto name = *static_cast<const id_token *>(tokens[pos - 2].get()); // the name of the field..

if (match(EQ_ID))
fields.emplace_back(std::move(name), parse_expression());
else
fields.emplace_back(std::move(name));
} while (match(COMMA_ID));
}
case EQ_ID:
{ // an assignment..
id_token field_name = object_id.back();
object_id.pop_back();
std::unique_ptr<expression> expr = parse_expression();

if (!match(SEMICOLON_ID))
error("Expected `;`..");

return std::make_unique<assignment_statement>(std::move(object_id), std::move(field_name), std::move(expr));
}
case PLUS_ID: // an expression..
case MINUS_ID:
case STAR_ID:
case SLASH_ID:
case LT_ID:
case LTEQ_ID:
case EQEQ_ID:
case GTEQ_ID:
case GT_ID:
case BANGEQ_ID:
case IMPLICATION_ID:
case BAR_ID:
case AMP_ID:
case CARET_ID:
case SEMICOLON_ID:
{
backtrack(c_pos);
std::unique_ptr<expression> expr = parse_expression();
if (!match(SEMICOLON_ID))
error("Expected `;`..");
return std::make_unique<expression_statement>(std::move(expr));
}
default:
error("Expected either identifier or `=` or `+` or `-` or `*` or `/` or `<` or `<=` or `==` or `>=` or `>` or `!=` or `=>` or `|` or `&` or `^` or `;` or identifier..");
}
}
case LBRACE_ID:
{ // either a block or a disjunction..
std::vector<std::unique_ptr<conjunction_statement>> conjuncts;
backtrack(pos - 1);
do
{
std::vector<std::unique_ptr<statement>> stmts;
if (!match(LBRACE_ID))
error("Expected `{`..");
while (!match(RBRACE_ID))
stmts.emplace_back(parse_statement());

if (match(LBRACKET_ID))
{ // a priced conjunct..
std::unique_ptr<expression> cst = parse_expression();
if (!match(RBRACKET_ID))
error("Expected `]`..");
conjuncts.emplace_back(std::make_unique<conjunction_statement>(std::move(stmts), std::move(cst)));
}
else
conjuncts.emplace_back(std::make_unique<conjunction_statement>(std::move(stmts)));
} while (match(OR_ID));

if (conjuncts.size() == 1)
return std::move(conjuncts[0]);

return std::make_unique<disjunction_statement>(std::move(conjuncts));
}
case FOR_ID:
{ // a for loop..
if (!match(LPAREN_ID))
error("Expected `(`..");

std::vector<id_token> enum_type;
std::vector<std::unique_ptr<statement>> statements;

do // the enum type..
{
if (!match(ID_ID))
error("Expected identifier..");
enum_type.emplace_back(*static_cast<const id_token *>(tokens[pos - 2].get()));
} while (match(DOT_ID));

if (!match(ID_ID))
error("Expected identifier..");

auto name = *static_cast<const id_token *>(tokens[pos - 2].get()); // the name of the enum..

if (!match(RPAREN_ID))
error("Expected `)`..");

if (!match(LBRACE_ID))
error("Expected `{`..");

while (!match(RBRACE_ID))
statements.emplace_back(parse_statement());

return std::make_unique<for_all_statement>(std::move(enum_type), std::move(name), std::move(statements));
}
case RETURN_ID:
{ // a return statement..
tk = next_token();
std::unique_ptr<expression> expr = parse_expression();
if (!match(SEMICOLON_ID))
error("Expected `;`..");
return std::make_unique<return_statement>(std::move(expr));
}
case FACT_ID:
case GOAL_ID:
{ // either a fact or a goal..
bool is_fact = match(FACT_ID);
std::vector<id_token> formula_scope;
std::vector<field_argument> arguments;

if (!match(ID_ID))
error("Expected identifier..");

auto name = *static_cast<const id_token *>(tokens[pos - 2].get()); // the name of the formula..

if (!match(EQ_ID))
error("Expected `=`..");

if (!match(NEW_ID))
error("Expected `new`..");

do // the formula scope..
{
if (!match(ID_ID))
error("Expected identifier..");
formula_scope.emplace_back(*static_cast<const id_token *>(tokens[pos - 2].get()));
} while (match(DOT_ID));

auto predicate_name = formula_scope.back();
formula_scope.pop_back();

if (!match(LPAREN_ID))
error("Expected `(`..");

if (!match(RPAREN_ID))
do // the arguments..
{
if (!match(ID_ID))
error("Expected identifier..");

auto name = *static_cast<const id_token *>(tokens[pos - 2].get()); // the name of the argument..

if (match(EQ_ID))
arguments.emplace_back(std::move(name), parse_expression());
else
arguments.emplace_back(std::move(name));
} while (match(COMMA_ID));

if (!match(RPAREN_ID))
error("Expected `)`..");

if (!match(SEMICOLON_ID))
error("Expected `;`..");

return std::make_unique<formula_statement>(is_fact, std::move(name), std::move(formula_scope), std::move(predicate_name), std::move(arguments));
}
}
}
std::unique_ptr<expression> parser::parse_expression(const size_t &pr)
{
std::unique_ptr<expression> expr;
Expand Down

0 comments on commit 4910e30

Please sign in to comment.