From bf955e74738eaa6915314b26da9790066ce633d7 Mon Sep 17 00:00:00 2001 From: jeshecdom Date: Wed, 6 Nov 2024 19:44:39 +0100 Subject: [PATCH 1/5] First version of the expression simplification phase. Still a couple of TODOs, like adding the CLI option to activate/deactivate optimization phase. --- src/optimizer/expr_simplification.ts | 288 +++++++++++++++++++++++++++ src/optimizer/optimization_phase.ts | 12 ++ src/pipeline/build.ts | 10 + src/types/resolveDescriptors.ts | 7 + src/types/resolveExpression.ts | 2 +- 5 files changed, 318 insertions(+), 1 deletion(-) create mode 100644 src/optimizer/expr_simplification.ts create mode 100644 src/optimizer/optimization_phase.ts diff --git a/src/optimizer/expr_simplification.ts b/src/optimizer/expr_simplification.ts new file mode 100644 index 000000000..0db56e554 --- /dev/null +++ b/src/optimizer/expr_simplification.ts @@ -0,0 +1,288 @@ +import { CompilerContext } from "../context"; +import { TactConstEvalError } from "../errors"; +import { resolveFuncType } from "../generator/writers/resolveFuncType"; +import { AstCondition, AstExpression, AstFunctionDef, AstReceiver, AstStatement, AstValue, createAstNode } from "../grammar/ast"; +import { Interpreter } from "../interpreter"; +import { getAllStaticFunctions, getAllTypes, getStaticFunction, replaceStaticFunctions } from "../types/resolveDescriptors"; +import { getExpType, registerExpType } from "../types/resolveExpression"; +import { FunctionDescription, ReceiverDescription, Value } from "../types/types"; +import { makeValueExpression } from "./util"; + +export function simplify_expressions(ctx: CompilerContext): CompilerContext { + + // The interpreter in charge of simplifiying expressions + const interpreter = new Interpreter(ctx); + + // Traverse the program and attempt to evaluate every expression + + // Process functions + const newFunctions: Map = new Map(); + + for (const f of getAllStaticFunctions(ctx)) { + if (f.ast.kind === "function_def") { + const statementsResult = process_statements(f.ast.statements, ctx, interpreter); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + const newFunctionCode = createAstNode({...f.ast, statements: newStatements}) as AstFunctionDef; + newFunctions.set(f.name, {...f, ast: newFunctionCode}); + } + // The rest of kinds do not have explicit Tact expressions. + } + ctx = replaceStaticFunctions(ctx, newFunctions); + + // Process all types + for (const t of getAllTypes(ctx)) { + + // Process init + if (t.init) { + process_statements(t.init.ast.statements, ctx, interpreter); + } + + // TODO: Need to replace initializer function + + // Process receivers + + const newReceivers: ReceiverDescription[] = []; + + for (const r of t.receivers) { + const statementsResult = process_statements(r.ast.statements, ctx, interpreter); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + const newReceiverCode = createAstNode({...r.ast, statements: newStatements}) as AstReceiver; + newReceivers.push({...r, ast: newReceiverCode}); + } + + // TODO: Need to replace the receivers in the type + + // Process methods + for (const m of t.functions.values()) { + if (m.ast.kind === "function_def") { + process_statements(m.ast.statements, ctx, interpreter); + } + // The rest of kinds do not have explicit Tact expressions. + } + + // TODO: Need to replace methods + + } + return ctx; + +} + +function process_statements(statements: AstStatement[], ctx: CompilerContext, interpreter: Interpreter): { stmts: AstStatement[], ctx: CompilerContext } { + const newStatements: AstStatement[] = []; + + for (const stmt of statements) { + const result = process_statement(stmt, ctx, interpreter); + newStatements.push(result.stmt); + ctx = result.ctx; + } + + return { stmts: newStatements, ctx: ctx }; +} + +function process_statement(stmt: AstStatement, ctx: CompilerContext, interpreter: Interpreter): { stmt: AstStatement, ctx: CompilerContext } { + switch (stmt.kind) { + case "statement_assign": + case "statement_expression": + case "statement_let": + case "statement_destruct": + case "statement_augmentedassign": { + const value = tryExpression(stmt.expression, interpreter); + if (value !== undefined) { + const new_expr = makeValueExpression(value); + // Register the new expression in the context + ctx = registerExpType(ctx, new_expr, getExpType(ctx, stmt.expression)); + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + expression: new_expr, + }) as AstStatement, + ctx: ctx + }; + } + return { + stmt: stmt, + ctx: ctx + }; + } + case "statement_return": { + if (stmt.expression !== null) { + const value = tryExpression(stmt.expression, interpreter); + if (value !== undefined) { + const new_expr = makeValueExpression(value); + // Register the new expression in the context + ctx = registerExpType(ctx, new_expr, getExpType(ctx, stmt.expression)); + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + expression: new_expr, + }) as AstStatement, + ctx: ctx + }; + } + } + return { + stmt: stmt, + ctx: ctx + }; + } + case "statement_condition": { + const value = tryExpression(stmt.condition, interpreter); + let newCondition = stmt.condition; + if (value !== undefined) { + newCondition = makeValueExpression(value); + // Register the new expression in the context + ctx = registerExpType(ctx, newCondition, getExpType(ctx, stmt.condition)); + } + + const trueStatementsResult = process_statements(stmt.trueStatements, ctx, interpreter); + const newTrueStatements = trueStatementsResult.stmts; + ctx = trueStatementsResult.ctx; + + let newFalseStatements: AstStatement[] | null = null; + if (stmt.falseStatements !== null) { + const falseStatementsResult = process_statements(stmt.falseStatements, ctx, interpreter); + newFalseStatements = falseStatementsResult.stmts; + ctx = falseStatementsResult.ctx; + } + + let newElseIf: AstCondition | null = null; + if (stmt.elseif !== null) { + const elseIfResult = process_statement(stmt.elseif, ctx, interpreter); + newElseIf = elseIfResult.stmt as AstCondition; + ctx = elseIfResult.ctx; + } + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + condition: newCondition, + trueStatements: newTrueStatements, + falseStatements: newFalseStatements, + elseif: newElseIf + }) as AstStatement, + ctx: ctx + }; + } + case "statement_foreach": { + const value = tryExpression(stmt.map, interpreter); + let newMap = stmt.map; + if (value !== undefined) { + newMap = makeValueExpression(value); + // Register the new expression in the context + ctx = registerExpType(ctx, newMap, getExpType(ctx, stmt.map)); + } + const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + map: newMap, + statements: newStatements, + }) as AstStatement, + ctx: ctx + }; + } + case "statement_until": + case "statement_while": { + const value = tryExpression(stmt.condition, interpreter); + let newCondition = stmt.condition; + if (value !== undefined) { + newCondition = makeValueExpression(value); + // Register the new expression in the context + ctx = registerExpType(ctx, newCondition, getExpType(ctx, stmt.condition)); + } + const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + condition: newCondition, + statements: newStatements, + }) as AstStatement, + ctx: ctx + }; + } + case "statement_repeat": { + const value = tryExpression(stmt.iterations, interpreter); + let newIterations = stmt.iterations; + if (value !== undefined) { + newIterations = makeValueExpression(value); + // Register the new expression in the context + ctx = registerExpType(ctx, newIterations, getExpType(ctx, stmt.iterations)); + } + const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + iterations: newIterations, + statements: newStatements, + }) as AstStatement, + ctx: ctx + }; + } + case "statement_try": { + const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + statements: newStatements, + }) as AstStatement, + ctx: ctx + }; + } + case "statement_try_catch": { + const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + + const catchStatementsResult = process_statements(stmt.catchStatements, ctx, interpreter); + const newCatchStatements = catchStatementsResult.stmts; + ctx = catchStatementsResult.ctx; + + // Create the replacement node + return { + stmt: createAstNode({ + ...stmt, + statements: newStatements, + catchStatements: newCatchStatements + }) as AstStatement, + ctx: ctx + }; + } + } +} + +function tryExpression(expr: AstExpression, interpreter: Interpreter): Value | undefined { + try { + // Eventually, this will be replaced by the partial evaluator. + return interpreter.interpretExpression(expr); + } catch (e) { + if (e instanceof TactConstEvalError) { + if (!e.fatal) { + return undefined; + } + } + throw e; + } +} diff --git a/src/optimizer/optimization_phase.ts b/src/optimizer/optimization_phase.ts new file mode 100644 index 000000000..72005dc1e --- /dev/null +++ b/src/optimizer/optimization_phase.ts @@ -0,0 +1,12 @@ +import { CompilerContext } from "../context"; +import { simplify_expressions } from "./expr_simplification"; + +export function optimize_tact(ctx: CompilerContext): CompilerContext { + + // Call the expression simplification phase + ctx = simplify_expressions(ctx); + + // Here, we will call the constant propagation analyzer + + return ctx; +} \ No newline at end of file diff --git a/src/pipeline/build.ts b/src/pipeline/build.ts index 91d93131a..e968c5074 100644 --- a/src/pipeline/build.ts +++ b/src/pipeline/build.ts @@ -21,6 +21,7 @@ import { precompile } from "./precompile"; import { getCompilerVersion } from "./version"; import { idText } from "../grammar/ast"; import { TactErrorCollection } from "../errors"; +import { optimize_tact } from "../optimizer/optimization_phase"; export function enableFeatures( ctx: CompilerContext, @@ -86,6 +87,15 @@ export async function build(args: { return { ok: true, error: [] }; } + // Run high level optimization phase + try { + // TODO: Add configuration option to dump optimized code and also activate/deactivate optimization phase + ctx = optimize_tact(ctx); + } catch (e) { + logger.error(e as Error); + return { ok: false, error: [e as Error] }; + } + // Compile contracts let ok = true; const errorMessages: TactErrorCollection[] = []; diff --git a/src/types/resolveDescriptors.ts b/src/types/resolveDescriptors.ts index f6eaafa3e..c4fc6d8fc 100644 --- a/src/types/resolveDescriptors.ts +++ b/src/types/resolveDescriptors.ts @@ -2007,6 +2007,13 @@ export function getStaticFunction( return r; } +export function replaceStaticFunctions(ctx: CompilerContext, newFunctions: Map): CompilerContext { + for (const [name, funcDesc] of newFunctions) { + ctx = staticFunctionsStore.set(ctx, name, funcDesc) + } + return ctx; +} + export function hasStaticFunction(ctx: CompilerContext, name: string) { return !!staticFunctionsStore.get(ctx, name); } diff --git a/src/types/resolveExpression.ts b/src/types/resolveExpression.ts index 702523f0f..27ac22f86 100644 --- a/src/types/resolveExpression.ts +++ b/src/types/resolveExpression.ts @@ -52,7 +52,7 @@ export function getExpType(ctx: CompilerContext, exp: AstExpression) { return t.description; } -function registerExpType( +export function registerExpType( ctx: CompilerContext, exp: AstExpression, description: TypeRef, From b297ab25ad0b9bc12bc09526b8d7d15ea5787eb5 Mon Sep 17 00:00:00 2001 From: jeshecdom Date: Fri, 15 Nov 2024 03:41:55 +0100 Subject: [PATCH 2/5] Further fixes to computing the new AST after expression simplifications. Added CLI options for skipping optimization phase and for dumping optimized tact code. --- schemas/configSchema.json | 10 + src/config/parseConfig.ts | 9 + src/grammar/ast.ts | 7 +- src/interpreter.ts | 2 +- src/optimizer/expr_simplification.ts | 600 ++++++++++++++---- src/optimizer/optimization_phase.ts | 32 +- .../expr_simplification.spec.ts.snap | 11 + .../test/expr_simplification.spec.ts | 44 ++ .../interpreter-called-when-no-contract.tact | 11 + src/optimizer/util.ts | 49 +- src/pipeline/build.ts | 38 +- src/prettyPrinter.ts | 2 + src/types/resolveDescriptors.ts | 27 +- 13 files changed, 704 insertions(+), 138 deletions(-) create mode 100644 src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap create mode 100644 src/optimizer/test/expr_simplification.spec.ts create mode 100644 src/optimizer/test/failed/interpreter-called-when-no-contract.tact diff --git a/schemas/configSchema.json b/schemas/configSchema.json index 1245a24de..ada0fd544 100644 --- a/schemas/configSchema.json +++ b/schemas/configSchema.json @@ -56,6 +56,16 @@ "default": false, "description": "False by default. If set to true, enables generation of a getter the information on the interfaces provided by the contract.\n\nRead more about supported interfaces: https://docs.tact-lang.org/ref/evolution/OTP-001." }, + "skipTactOptimizationPhase": { + "type": "boolean", + "default": false, + "description": "False by default. If set to true, skips the Tact code optimization phase." + }, + "dumpOptimizedTactCode": { + "type": "boolean", + "default": false, + "description": "False by default. If set to true, dumps the code produced by the Tact code optimization phase. In case the optimization phase is skipped, this option is ignored." + }, "experimental": { "type": "object", "description": "Experimental options that might be removed in the future. Use with caution!", diff --git a/src/config/parseConfig.ts b/src/config/parseConfig.ts index f4633b186..66f52143a 100644 --- a/src/config/parseConfig.ts +++ b/src/config/parseConfig.ts @@ -33,6 +33,15 @@ export const optionsSchema = z * Read more: https://docs.tact-lang.org/book/contracts#interfaces */ interfacesGetter: z.boolean().optional(), + /** + * If set to true, skips the Tact code optimization phase. + */ + skipTactOptimizationPhase: z.boolean().optional(), + /** + * If set to true, dumps the code produced by the Tact code optimization phase. + * In case the optimization phase is skipped, this option is ignored. + */ + dumpOptimizedTactCode: z.boolean().optional(), /** * Experimental options that might be removed in the future. Use with caution! */ diff --git a/src/grammar/ast.ts b/src/grammar/ast.ts index cc75370b2..1546fedab 100644 --- a/src/grammar/ast.ts +++ b/src/grammar/ast.ts @@ -638,7 +638,12 @@ export type AstNull = { loc: SrcInfo; }; -export type AstValue = AstNumber | AstBoolean | AstNull | AstString; +export type AstValue = + | AstNumber + | AstBoolean + | AstNull + | AstString + | AstStructInstance; export type AstConstantAttribute = | { type: "virtual"; loc: SrcInfo } diff --git a/src/interpreter.ts b/src/interpreter.ts index 66213c44b..8a91d1620 100644 --- a/src/interpreter.ts +++ b/src/interpreter.ts @@ -960,7 +960,7 @@ export class Interpreter { if (foundContractConst.value !== undefined) { return foundContractConst.value; } else { - throwErrorConstEval( + throwNonFatalErrorConstEval( `cannot evaluate declared contract/trait constant ${idTextErr(ast.field)} as it does not have a body`, ast.field.loc, ); diff --git a/src/optimizer/expr_simplification.ts b/src/optimizer/expr_simplification.ts index 0db56e554..9ba55dfe1 100644 --- a/src/optimizer/expr_simplification.ts +++ b/src/optimizer/expr_simplification.ts @@ -1,75 +1,371 @@ import { CompilerContext } from "../context"; import { TactConstEvalError } from "../errors"; -import { resolveFuncType } from "../generator/writers/resolveFuncType"; -import { AstCondition, AstExpression, AstFunctionDef, AstReceiver, AstStatement, AstValue, createAstNode } from "../grammar/ast"; +import { + AstCondition, + AstContractDeclaration, + AstExpression, + AstStatement, + AstTraitDeclaration, + AstTypeDecl, + cloneAstNode, + idText, +} from "../grammar/ast"; import { Interpreter } from "../interpreter"; -import { getAllStaticFunctions, getAllTypes, getStaticFunction, replaceStaticFunctions } from "../types/resolveDescriptors"; +import { + getAllStaticConstants, + getAllStaticFunctions, + getAllTypes, + replaceStaticConstants, + replaceStaticFunctions, + replaceTypes, +} from "../types/resolveDescriptors"; import { getExpType, registerExpType } from "../types/resolveExpression"; -import { FunctionDescription, ReceiverDescription, Value } from "../types/types"; +import { + ConstantDescription, + FieldDescription, + FunctionDescription, + InitDescription, + ReceiverDescription, + TypeDescription, + Value, +} from "../types/types"; import { makeValueExpression } from "./util"; export function simplify_expressions(ctx: CompilerContext): CompilerContext { - - // The interpreter in charge of simplifiying expressions + // The interpreter in charge of simplifying expressions const interpreter = new Interpreter(ctx); // Traverse the program and attempt to evaluate every expression // Process functions - const newFunctions: Map = new Map(); + const newStaticFunctions: Map = new Map(); for (const f of getAllStaticFunctions(ctx)) { if (f.ast.kind === "function_def") { - const statementsResult = process_statements(f.ast.statements, ctx, interpreter); + const statementsResult = process_statements( + f.ast.statements, + ctx, + interpreter, + ); const newStatements = statementsResult.stmts; ctx = statementsResult.ctx; - const newFunctionCode = createAstNode({...f.ast, statements: newStatements}) as AstFunctionDef; - newFunctions.set(f.name, {...f, ast: newFunctionCode}); + const newFunctionCode = cloneAstNode({ + ...f.ast, + statements: newStatements, + }); + newStaticFunctions.set(f.name, { ...f, ast: newFunctionCode }); + } else { + // The rest of kinds do not have explicit Tact expressions, so just copy the current function description + newStaticFunctions.set(f.name, f); + } + } + ctx = replaceStaticFunctions(ctx, newStaticFunctions); + + // Process all static constants + const newStaticConstants: Map = new Map(); + + for (const c of getAllStaticConstants(ctx)) { + if (c.ast.kind === "constant_def") { + const expressionResult = process_expression( + c.ast.initializer, + ctx, + interpreter, + ); + const newInitializer = expressionResult.expr; + ctx = expressionResult.ctx; + const newConstantCode = cloneAstNode({ + ...c.ast, + initializer: newInitializer, + }); + newStaticConstants.set(c.name, { ...c, ast: newConstantCode }); + } else { + // The rest of kinds do not have explicit Tact expressions, so just copy the current description + newStaticConstants.set(c.name, c); } - // The rest of kinds do not have explicit Tact expressions. } - ctx = replaceStaticFunctions(ctx, newFunctions); + ctx = replaceStaticConstants(ctx, newStaticConstants); // Process all types + + /** + * By calling the function getAllTypes on the context object "ctx", one gets an array of TypeDescriptions. + * Each TypeDescription stores the type declarations in two different ways: + * - Directly in the TypeDescription object there are fields, constants, and method + * declarations. However, these declarations are "coalesced" in the following sense: + * If the TypeDescription is a contract, it will contain copies of methods, constants and fields of traits that the + * contract inherits from. Similarly, each trait will have declarations of other traits + * that the trait inherits from. + * + * For example, if we look into the "functions" property of the TypeDescription object of a contract + * we will find functions defined in BaseTrait. + * + * - Indirectly in the "ast" property of the TypeDescription. Contrary to the previous case, + * the fields, constants and methods in the ast property are NOT coalesced. This means, for example, + * that the methods in a TypeDescription of a contract will be methods that are actually + * declared in the contract and not in some trait that the contract inherits from. + * + * The above means that we will need to process the properties in TypeDescription first, + * and then use those properties to build the AST (carefully ensuring that only fields, constants and methods + * that were in the original AST, remain in the new AST). + */ + const newTypes: Map = new Map(); + for (const t of getAllTypes(ctx)) { + let newInitializer: InitDescription | null = null; // Process init if (t.init) { - process_statements(t.init.ast.statements, ctx, interpreter); + const statementsResult = process_statements( + t.init.ast.statements, + ctx, + interpreter, + ); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + const newInitCode = cloneAstNode({ + ...t.init.ast, + statements: newStatements, + }); + newInitializer = { ...t.init, ast: newInitCode }; } - // TODO: Need to replace initializer function + // Process constants + const newConstants: ConstantDescription[] = []; + + // This map will be used to quickly recover the new definitions when + // building the AST later + const newConstantsMap: Map = new Map(); + + for (const c of t.constants) { + if (c.ast.kind === "constant_def") { + const expressionResult = process_expression( + c.ast.initializer, + ctx, + interpreter, + ); + const newInitializer = expressionResult.expr; + ctx = expressionResult.ctx; + const newConstantCode = cloneAstNode({ + ...c.ast, + initializer: newInitializer, + }); + const newConstantDescription = { ...c, ast: newConstantCode }; + newConstants.push(newConstantDescription); + newConstantsMap.set(c.name, newConstantDescription); + } else { + // The rest of kinds do not have explicit Tact expressions, so just copy the current description + newConstants.push(c); + newConstantsMap.set(c.name, c); + } + } - // Process receivers + // Process fields + const newFields: FieldDescription[] = []; + + // This map will be used to quickly recover the new definitions when + // building the AST later + const newFieldsMap: Map = new Map(); + + for (const f of t.fields) { + if (f.ast.initializer !== null) { + const expressionResult = process_expression( + f.ast.initializer, + ctx, + interpreter, + ); + const newInitializer = expressionResult.expr; + ctx = expressionResult.ctx; + const newFieldCode = cloneAstNode({ + ...f.ast, + initializer: newInitializer, + }); + const newFieldDescription = { ...f, ast: newFieldCode }; + newFields.push(newFieldDescription); + newFieldsMap.set(f.name, newFieldDescription); + } else { + // Field without initializer, no expression to simplify inside + newFields.push(f); + newFieldsMap.set(f.name, f); + } + } + // Process receivers const newReceivers: ReceiverDescription[] = []; + // This map will be used to quickly recover the new definitions when + // building the AST later. + // Since receivers do not have names, I will use their id in their original ast + // as key. + const newReceiversMap: Map = new Map(); + for (const r of t.receivers) { - const statementsResult = process_statements(r.ast.statements, ctx, interpreter); + const statementsResult = process_statements( + r.ast.statements, + ctx, + interpreter, + ); const newStatements = statementsResult.stmts; ctx = statementsResult.ctx; - const newReceiverCode = createAstNode({...r.ast, statements: newStatements}) as AstReceiver; - newReceivers.push({...r, ast: newReceiverCode}); + const newReceiverCode = cloneAstNode({ + ...r.ast, + statements: newStatements, + }); + const newReceiverDescription = { ...r, ast: newReceiverCode }; + newReceivers.push(newReceiverDescription); + newReceiversMap.set(r.ast.id, newReceiverDescription); } - // TODO: Need to replace the receivers in the type - // Process methods - for (const m of t.functions.values()) { + + // This is already a map in TypeDescription. This is the reason + // I did not need a separate map, like in the previous cases. + const newMethods: Map = new Map(); + + for (const [name, m] of t.functions) { if (m.ast.kind === "function_def") { - process_statements(m.ast.statements, ctx, interpreter); + const statementsResult = process_statements( + m.ast.statements, + ctx, + interpreter, + ); + const newStatements = statementsResult.stmts; + ctx = statementsResult.ctx; + const newMethodCode = cloneAstNode({ + ...m.ast, + statements: newStatements, + }); + newMethods.set(name, { ...m, ast: newMethodCode }); + } else { + // The rest of kinds do not have explicit Tact expressions, so just copy the current function description + newMethods.set(name, m); + } + } + + // Now, we need to create the new AST, depending on its kind. + let newAst: AstTypeDecl; + + switch (t.ast.kind) { + case "primitive_type_decl": { + newAst = t.ast; + break; + } + case "struct_decl": + case "message_decl": { + newAst = cloneAstNode({ + ...t.ast, + fields: newFields.map((f) => f.ast), + }); + break; + } + case "trait": { + const newDeclarations: AstTraitDeclaration[] = []; + + for (const decl of t.ast.declarations) { + switch (decl.kind) { + case "asm_function_def": + case "function_decl": + case "function_def": { + const newCode = newMethods.get(idText(decl.name))! + .ast as AstTraitDeclaration; + newDeclarations.push(newCode); + break; + } + case "constant_decl": + case "constant_def": { + const newCode = newConstantsMap.get( + idText(decl.name), + )!.ast; + newDeclarations.push(newCode); + break; + } + case "field_decl": { + const newCode = newFieldsMap.get( + idText(decl.name), + )!.ast; + newDeclarations.push(newCode); + break; + } + case "receiver": { + const newCode = newReceiversMap.get(decl.id)!.ast; + newDeclarations.push(newCode); + break; + } + } + } + + newAst = cloneAstNode({ + ...t.ast, + declarations: newDeclarations, + }); + + break; + } + case "contract": { + const newDeclarations: AstContractDeclaration[] = []; + + for (const decl of t.ast.declarations) { + switch (decl.kind) { + case "asm_function_def": + case "function_def": { + const newCode = newMethods.get(idText(decl.name))! + .ast as AstContractDeclaration; + newDeclarations.push(newCode); + break; + } + case "constant_def": { + const newCode = newConstantsMap.get( + idText(decl.name), + )!.ast as AstContractDeclaration; + newDeclarations.push(newCode); + break; + } + case "field_decl": { + const newCode = newFieldsMap.get( + idText(decl.name), + )!.ast; + newDeclarations.push(newCode); + break; + } + case "receiver": { + const newCode = newReceiversMap.get(decl.id)!.ast; + newDeclarations.push(newCode); + break; + } + case "contract_init": + newDeclarations.push(newInitializer!.ast); + break; + } + } + + newAst = cloneAstNode({ + ...t.ast, + declarations: newDeclarations, + }); + + break; } - // The rest of kinds do not have explicit Tact expressions. } - - // TODO: Need to replace methods + newTypes.set(t.name, { + ...t, + ast: newAst, + init: newInitializer, + constants: newConstants, + fields: newFields, + functions: newMethods, + receivers: newReceivers, + }); } - return ctx; + ctx = replaceTypes(ctx, newTypes); + return ctx; } -function process_statements(statements: AstStatement[], ctx: CompilerContext, interpreter: Interpreter): { stmts: AstStatement[], ctx: CompilerContext } { +function process_statements( + statements: AstStatement[], + ctx: CompilerContext, + interpreter: Interpreter, +): { stmts: AstStatement[]; ctx: CompilerContext } { const newStatements: AstStatement[] = []; for (const stmt of statements) { @@ -81,199 +377,265 @@ function process_statements(statements: AstStatement[], ctx: CompilerContext, in return { stmts: newStatements, ctx: ctx }; } -function process_statement(stmt: AstStatement, ctx: CompilerContext, interpreter: Interpreter): { stmt: AstStatement, ctx: CompilerContext } { +function process_statement( + stmt: AstStatement, + ctx: CompilerContext, + interpreter: Interpreter, +): { stmt: AstStatement; ctx: CompilerContext } { switch (stmt.kind) { case "statement_assign": case "statement_expression": case "statement_let": case "statement_destruct": case "statement_augmentedassign": { - const value = tryExpression(stmt.expression, interpreter); - if (value !== undefined) { - const new_expr = makeValueExpression(value); - // Register the new expression in the context - ctx = registerExpType(ctx, new_expr, getExpType(ctx, stmt.expression)); + const expressionResult = process_expression( + stmt.expression, + ctx, + interpreter, + ); + const new_expr = expressionResult.expr; + ctx = expressionResult.ctx; - // Create the replacement node - return { - stmt: createAstNode({ - ...stmt, - expression: new_expr, - }) as AstStatement, - ctx: ctx - }; - } + // Create the replacement node return { - stmt: stmt, - ctx: ctx + stmt: cloneAstNode({ + ...stmt, + expression: new_expr, + }), + ctx: ctx, }; } case "statement_return": { if (stmt.expression !== null) { - const value = tryExpression(stmt.expression, interpreter); - if (value !== undefined) { - const new_expr = makeValueExpression(value); - // Register the new expression in the context - ctx = registerExpType(ctx, new_expr, getExpType(ctx, stmt.expression)); - - // Create the replacement node - return { - stmt: createAstNode({ - ...stmt, - expression: new_expr, - }) as AstStatement, - ctx: ctx - }; - } + const expressionResult = process_expression( + stmt.expression, + ctx, + interpreter, + ); + const new_expr = expressionResult.expr; + ctx = expressionResult.ctx; + + // Create the replacement node + return { + stmt: cloneAstNode({ + ...stmt, + expression: new_expr, + }), + ctx: ctx, + }; } return { stmt: stmt, - ctx: ctx + ctx: ctx, }; } case "statement_condition": { - const value = tryExpression(stmt.condition, interpreter); - let newCondition = stmt.condition; - if (value !== undefined) { - newCondition = makeValueExpression(value); - // Register the new expression in the context - ctx = registerExpType(ctx, newCondition, getExpType(ctx, stmt.condition)); - } - - const trueStatementsResult = process_statements(stmt.trueStatements, ctx, interpreter); + const expressionResult = process_expression( + stmt.condition, + ctx, + interpreter, + ); + const newCondition = expressionResult.expr; + ctx = expressionResult.ctx; + + const trueStatementsResult = process_statements( + stmt.trueStatements, + ctx, + interpreter, + ); const newTrueStatements = trueStatementsResult.stmts; ctx = trueStatementsResult.ctx; let newFalseStatements: AstStatement[] | null = null; if (stmt.falseStatements !== null) { - const falseStatementsResult = process_statements(stmt.falseStatements, ctx, interpreter); + const falseStatementsResult = process_statements( + stmt.falseStatements, + ctx, + interpreter, + ); newFalseStatements = falseStatementsResult.stmts; ctx = falseStatementsResult.ctx; } let newElseIf: AstCondition | null = null; if (stmt.elseif !== null) { - const elseIfResult = process_statement(stmt.elseif, ctx, interpreter); + const elseIfResult = process_statement( + stmt.elseif, + ctx, + interpreter, + ); newElseIf = elseIfResult.stmt as AstCondition; ctx = elseIfResult.ctx; } // Create the replacement node return { - stmt: createAstNode({ + stmt: cloneAstNode({ ...stmt, condition: newCondition, trueStatements: newTrueStatements, falseStatements: newFalseStatements, - elseif: newElseIf - }) as AstStatement, - ctx: ctx + elseif: newElseIf, + }), + ctx: ctx, }; } case "statement_foreach": { - const value = tryExpression(stmt.map, interpreter); - let newMap = stmt.map; - if (value !== undefined) { - newMap = makeValueExpression(value); - // Register the new expression in the context - ctx = registerExpType(ctx, newMap, getExpType(ctx, stmt.map)); - } - const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const expressionResult = process_expression( + stmt.map, + ctx, + interpreter, + ); + const newMap = expressionResult.expr; + ctx = expressionResult.ctx; + + const statementsResult = process_statements( + stmt.statements, + ctx, + interpreter, + ); const newStatements = statementsResult.stmts; ctx = statementsResult.ctx; // Create the replacement node return { - stmt: createAstNode({ + stmt: cloneAstNode({ ...stmt, map: newMap, statements: newStatements, - }) as AstStatement, - ctx: ctx + }), + ctx: ctx, }; } case "statement_until": case "statement_while": { - const value = tryExpression(stmt.condition, interpreter); - let newCondition = stmt.condition; - if (value !== undefined) { - newCondition = makeValueExpression(value); - // Register the new expression in the context - ctx = registerExpType(ctx, newCondition, getExpType(ctx, stmt.condition)); - } - const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const expressionResult = process_expression( + stmt.condition, + ctx, + interpreter, + ); + const newCondition = expressionResult.expr; + ctx = expressionResult.ctx; + + const statementsResult = process_statements( + stmt.statements, + ctx, + interpreter, + ); const newStatements = statementsResult.stmts; ctx = statementsResult.ctx; // Create the replacement node return { - stmt: createAstNode({ + stmt: cloneAstNode({ ...stmt, condition: newCondition, statements: newStatements, - }) as AstStatement, - ctx: ctx + }), + ctx: ctx, }; } case "statement_repeat": { - const value = tryExpression(stmt.iterations, interpreter); - let newIterations = stmt.iterations; - if (value !== undefined) { - newIterations = makeValueExpression(value); - // Register the new expression in the context - ctx = registerExpType(ctx, newIterations, getExpType(ctx, stmt.iterations)); - } - const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const expressionResult = process_expression( + stmt.iterations, + ctx, + interpreter, + ); + const newIterations = expressionResult.expr; + ctx = expressionResult.ctx; + + const statementsResult = process_statements( + stmt.statements, + ctx, + interpreter, + ); const newStatements = statementsResult.stmts; ctx = statementsResult.ctx; // Create the replacement node return { - stmt: createAstNode({ + stmt: cloneAstNode({ ...stmt, iterations: newIterations, statements: newStatements, - }) as AstStatement, - ctx: ctx + }), + ctx: ctx, }; } case "statement_try": { - const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const statementsResult = process_statements( + stmt.statements, + ctx, + interpreter, + ); const newStatements = statementsResult.stmts; ctx = statementsResult.ctx; // Create the replacement node return { - stmt: createAstNode({ + stmt: cloneAstNode({ ...stmt, statements: newStatements, - }) as AstStatement, - ctx: ctx + }), + ctx: ctx, }; } case "statement_try_catch": { - const statementsResult = process_statements(stmt.statements, ctx, interpreter); + const statementsResult = process_statements( + stmt.statements, + ctx, + interpreter, + ); const newStatements = statementsResult.stmts; ctx = statementsResult.ctx; - const catchStatementsResult = process_statements(stmt.catchStatements, ctx, interpreter); + const catchStatementsResult = process_statements( + stmt.catchStatements, + ctx, + interpreter, + ); const newCatchStatements = catchStatementsResult.stmts; ctx = catchStatementsResult.ctx; // Create the replacement node return { - stmt: createAstNode({ + stmt: cloneAstNode({ ...stmt, statements: newStatements, - catchStatements: newCatchStatements - }) as AstStatement, - ctx: ctx + catchStatements: newCatchStatements, + }), + ctx: ctx, }; } } } -function tryExpression(expr: AstExpression, interpreter: Interpreter): Value | undefined { +function process_expression( + expr: AstExpression, + ctx: CompilerContext, + interpreter: Interpreter, +): { expr: AstExpression; ctx: CompilerContext } { + const value = tryExpression(expr, interpreter); + let newExpr = expr; + if (value !== undefined) { + try { + newExpr = makeValueExpression(value); + // Register the new expression in the context + ctx = registerExpType(ctx, newExpr, getExpType(ctx, expr)); + } catch (_) { + // This means that transforming the value into an AST node is + // unsupported or it failed to register the type of the expression. + // Just use the original expression. + newExpr = expr; + } + } + return { expr: newExpr, ctx: ctx }; +} + +function tryExpression( + expr: AstExpression, + interpreter: Interpreter, +): Value | undefined { try { // Eventually, this will be replaced by the partial evaluator. return interpreter.interpretExpression(expr); diff --git a/src/optimizer/optimization_phase.ts b/src/optimizer/optimization_phase.ts index 72005dc1e..37df6f4e5 100644 --- a/src/optimizer/optimization_phase.ts +++ b/src/optimizer/optimization_phase.ts @@ -1,12 +1,40 @@ import { CompilerContext } from "../context"; +import { prettyPrint } from "../prettyPrinter"; +import { + getAllStaticConstants, + getAllStaticFunctions, + getAllTypes, +} from "../types/resolveDescriptors"; import { simplify_expressions } from "./expr_simplification"; +import { writeFileSync } from "fs"; export function optimize_tact(ctx: CompilerContext): CompilerContext { - // Call the expression simplification phase ctx = simplify_expressions(ctx); // Here, we will call the constant propagation analyzer return ctx; -} \ No newline at end of file +} + +export function dump_tact_code(ctx: CompilerContext, file: string) { + let program = ""; + + for (const c of getAllStaticConstants(ctx)) { + program += `${prettyPrint(c.ast)}\n`; + } + + program += "\n"; + + for (const f of getAllStaticFunctions(ctx)) { + program += `${prettyPrint(f.ast)}\n\n`; + } + + for (const t of getAllTypes(ctx)) { + program += `${prettyPrint(t.ast)}\n\n`; + } + + writeFileSync(file, program, { + flag: "w", + }); +} diff --git a/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap b/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap new file mode 100644 index 000000000..2e48a0932 --- /dev/null +++ b/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap @@ -0,0 +1,11 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`expression-simplification should fail expression simplification for interpreter-called-when-no-contract 1`] = ` +":6:9: Cannot evaluate expression to a constant: divisor expression must be non-zero +Line 6, col 9: + 5 | fun blowup(v: Int) { +> 6 | 1 / v; + ^ + 7 | } +" +`; diff --git a/src/optimizer/test/expr_simplification.spec.ts b/src/optimizer/test/expr_simplification.spec.ts new file mode 100644 index 000000000..e344d8992 --- /dev/null +++ b/src/optimizer/test/expr_simplification.spec.ts @@ -0,0 +1,44 @@ +import { featureEnable } from "../../config/features"; +import { CompilerContext } from "../../context"; +import { __DANGER_resetNodeId } from "../../grammar/ast"; +import { openContext } from "../../grammar/store"; +import { resolveDescriptors } from "../../types/resolveDescriptors"; +import { getAllExpressionTypes } from "../../types/resolveExpression"; +import { resolveStatements } from "../../types/resolveStatements"; +import { loadCases } from "../../utils/loadCases"; +import { simplify_expressions } from "../expr_simplification"; + +describe("expression-simplification", () => { + beforeEach(() => { + __DANGER_resetNodeId(); + }); + for (const r of loadCases(__dirname + "/success/")) { + it("should pass expression simplification for " + r.name, () => { + let ctx = openContext( + new CompilerContext(), + [{ code: r.code, path: "", origin: "user" }], + [], + ); + ctx = featureEnable(ctx, "external"); + ctx = resolveDescriptors(ctx); + ctx = resolveStatements(ctx); + ctx = simplify_expressions(ctx); + expect(getAllExpressionTypes(ctx)).toMatchSnapshot(); + }); + } + for (const r of loadCases(__dirname + "/failed/")) { + it("should fail expression simplification for " + r.name, () => { + let ctx = openContext( + new CompilerContext(), + [{ code: r.code, path: "", origin: "user" }], + [], + ); + ctx = featureEnable(ctx, "external"); + ctx = resolveDescriptors(ctx); + ctx = resolveStatements(ctx); + expect(() => { + simplify_expressions(ctx); + }).toThrowErrorMatchingSnapshot(); + }); + } +}); diff --git a/src/optimizer/test/failed/interpreter-called-when-no-contract.tact b/src/optimizer/test/failed/interpreter-called-when-no-contract.tact new file mode 100644 index 000000000..d5c8777c4 --- /dev/null +++ b/src/optimizer/test/failed/interpreter-called-when-no-contract.tact @@ -0,0 +1,11 @@ +// Interpreter should be called even if no contract is declared in the file. + +primitive Int; + +fun blowup(v: Int) { + 1 / v; +} + +fun test() { + blowup(0); // Interpreter will execute this call and inform of a division by zero. +} diff --git a/src/optimizer/util.ts b/src/optimizer/util.ts index 4b98ab779..03b61d630 100644 --- a/src/optimizer/util.ts +++ b/src/optimizer/util.ts @@ -5,15 +5,16 @@ import { createAstNode, AstValue, isValue, + AstId, + AstStructFieldInitializer, + idText, } from "../grammar/ast"; import { dummySrcInfo } from "../grammar/grammar"; import { throwInternalCompilerError } from "../errors"; -import { Value } from "../types/types"; +import { StructValue, Value } from "../types/types"; export function extractValue(ast: AstValue): Value { - switch ( - ast.kind // Missing structs - ) { + switch (ast.kind) { case "null": return null; case "boolean": @@ -22,6 +23,16 @@ export function extractValue(ast: AstValue): Value { return ast.value; case "string": return ast.value; + case "struct_instance": + return ast.args.reduce( + (resObj, fieldWithInit) => { + resObj[idText(fieldWithInit.field)] = extractValue( + fieldWithInit.initializer as AstValue, + ); + return resObj; + }, + { $tactStruct: idText(ast.type) } as StructValue, + ); } } @@ -58,11 +69,39 @@ export function makeValueExpression(value: Value): AstValue { }); return result as AstValue; } + if (typeof value === "object" && "$tactStruct" in value) { + const fields = Object.entries(value) + .filter(([name, _]) => name !== "$tactStruct") + .map(([name, val]) => { + return createAstNode({ + kind: "struct_field_initializer", + field: makeIdExpression(name), + initializer: makeValueExpression(val), + loc: dummySrcInfo, + }) as AstStructFieldInitializer; + }); + const result = createAstNode({ + kind: "struct_instance", + type: makeIdExpression(value["$tactStruct"] as string), + args: fields, + loc: dummySrcInfo, + }); + return result as AstValue; + } throwInternalCompilerError( - `structs, addresses, cells, and comment values are not supported at the moment.`, + `addresses, cells, and comment values are not supported as AST nodes at the moment.`, ); } +function makeIdExpression(name: string): AstId { + const result = createAstNode({ + kind: "id", + text: name, + loc: dummySrcInfo, + }); + return result as AstId; +} + export function makeUnaryExpression( op: AstUnaryOperation, operand: AstExpression, diff --git a/src/pipeline/build.ts b/src/pipeline/build.ts index e968c5074..c1a5f5426 100644 --- a/src/pipeline/build.ts +++ b/src/pipeline/build.ts @@ -21,7 +21,7 @@ import { precompile } from "./precompile"; import { getCompilerVersion } from "./version"; import { idText } from "../grammar/ast"; import { TactErrorCollection } from "../errors"; -import { optimize_tact } from "../optimizer/optimization_phase"; +import { dump_tact_code, optimize_tact } from "../optimizer/optimization_phase"; export function enableFeatures( ctx: CompilerContext, @@ -87,13 +87,35 @@ export async function build(args: { return { ok: true, error: [] }; } - // Run high level optimization phase - try { - // TODO: Add configuration option to dump optimized code and also activate/deactivate optimization phase - ctx = optimize_tact(ctx); - } catch (e) { - logger.error(e as Error); - return { ok: false, error: [e as Error] }; + // Run high level optimization phase, if active in the options. + if ( + config.options?.skipTactOptimizationPhase === undefined || + !config.options.skipTactOptimizationPhase + ) { + try { + if (config.options?.dumpOptimizedTactCode) { + // Dump the code before optimization + dump_tact_code( + ctx, + config.output + + `/${config.name}_unoptimized_tact_dump.tact`, + ); + } + + ctx = optimize_tact(ctx); + + if (config.options?.dumpOptimizedTactCode) { + // Dump the code after optimization + dump_tact_code( + ctx, + config.output + `/${config.name}_optimized_tact_dump.tact`, + ); + } + } catch (e) { + logger.error("Tact code optimization failed."); + logger.error(e as Error); + return { ok: false, error: [e as Error] }; + } } // Compile contracts diff --git a/src/prettyPrinter.ts b/src/prettyPrinter.ts index 4c7685802..79f888a36 100644 --- a/src/prettyPrinter.ts +++ b/src/prettyPrinter.ts @@ -784,6 +784,8 @@ export function prettyPrint(node: AstNode): string { return pp.ppAstStructFieldInit(node); case "import": return pp.ppAstImport(node); + case "asm_function_def": + return pp.ppAstAsmFunctionDef(node); default: throwInternalCompilerError( `Unsupported AST type: ${JSONbig.stringify(node, null, 2)}`, diff --git a/src/types/resolveDescriptors.ts b/src/types/resolveDescriptors.ts index c4fc6d8fc..adf4bd01a 100644 --- a/src/types/resolveDescriptors.ts +++ b/src/types/resolveDescriptors.ts @@ -2007,9 +2007,32 @@ export function getStaticFunction( return r; } -export function replaceStaticFunctions(ctx: CompilerContext, newFunctions: Map): CompilerContext { +export function replaceStaticConstants( + ctx: CompilerContext, + newConstants: Map, +): CompilerContext { + for (const [name, constDesc] of newConstants) { + ctx = staticConstantsStore.set(ctx, name, constDesc); + } + return ctx; +} + +export function replaceStaticFunctions( + ctx: CompilerContext, + newFunctions: Map, +): CompilerContext { for (const [name, funcDesc] of newFunctions) { - ctx = staticFunctionsStore.set(ctx, name, funcDesc) + ctx = staticFunctionsStore.set(ctx, name, funcDesc); + } + return ctx; +} + +export function replaceTypes( + ctx: CompilerContext, + newTypes: Map, +): CompilerContext { + for (const [name, typeDesc] of newTypes) { + ctx = store.set(ctx, name, typeDesc); } return ctx; } From c5d86780a5c94adc4da0bedb006960c935609b70 Mon Sep 17 00:00:00 2001 From: jeshecdom Date: Mon, 18 Nov 2024 16:18:42 +0100 Subject: [PATCH 3/5] Added positive tests. Further fixes: - subexpressions inside struct instances did not have their types registered correctly. - makeValueExpression now receives a SrcInfo object. --- package.json | 2 +- src/optimizer/expr_simplification.ts | 62 ++++++- .../expr_simplification.spec.ts.snap | 173 ++++++++++++++++++ ...terpreter-simplifies-when-no-contract.tact | 33 ++++ src/optimizer/util.ts | 60 ++++-- yarn.lock | 2 +- 6 files changed, 310 insertions(+), 22 deletions(-) create mode 100644 src/optimizer/test/success/interpreter-simplifies-when-no-contract.tact diff --git a/package.json b/package.json index 4db5dec06..8f281028f 100644 --- a/package.json +++ b/package.json @@ -77,7 +77,7 @@ "@typescript-eslint/parser": "^7.0.2", "ajv-cli": "^5.0.0", "cross-env": "^7.0.3", - "cspell": "^8.8.3", + "cspell": "^8.16.0", "eslint": "^8.56.0", "glob": "^8.1.0", "husky": "^9.1.5", diff --git a/src/optimizer/expr_simplification.ts b/src/optimizer/expr_simplification.ts index 9ba55dfe1..ba601f1d0 100644 --- a/src/optimizer/expr_simplification.ts +++ b/src/optimizer/expr_simplification.ts @@ -1,5 +1,5 @@ import { CompilerContext } from "../context"; -import { TactConstEvalError } from "../errors"; +import { TactConstEvalError, throwInternalCompilerError } from "../errors"; import { AstCondition, AstContractDeclaration, @@ -7,14 +7,18 @@ import { AstStatement, AstTraitDeclaration, AstTypeDecl, + AstValue, cloneAstNode, idText, + isValue, + SrcInfo, } from "../grammar/ast"; import { Interpreter } from "../interpreter"; import { getAllStaticConstants, getAllStaticFunctions, getAllTypes, + getType, replaceStaticConstants, replaceStaticFunctions, replaceTypes, @@ -27,6 +31,7 @@ import { InitDescription, ReceiverDescription, TypeDescription, + TypeRef, Value, } from "../types/types"; import { makeValueExpression } from "./util"; @@ -101,7 +106,7 @@ export function simplify_expressions(ctx: CompilerContext): CompilerContext { * * - Indirectly in the "ast" property of the TypeDescription. Contrary to the previous case, * the fields, constants and methods in the ast property are NOT coalesced. This means, for example, - * that the methods in a TypeDescription of a contract will be methods that are actually + * that the methods in a TypeDescription's ast of a contract will be methods that are actually * declared in the contract and not in some trait that the contract inherits from. * * The above means that we will need to process the properties in TypeDescription first, @@ -615,13 +620,13 @@ function process_expression( ctx: CompilerContext, interpreter: Interpreter, ): { expr: AstExpression; ctx: CompilerContext } { - const value = tryExpression(expr, interpreter); + const value = tryExpressionSimplification(expr, interpreter); let newExpr = expr; if (value !== undefined) { try { - newExpr = makeValueExpression(value); + newExpr = makeValueExpression(value, expr.loc); // Register the new expression in the context - ctx = registerExpType(ctx, newExpr, getExpType(ctx, expr)); + ctx = registerAllSubExpTypes(ctx, newExpr, getExpType(ctx, expr)); } catch (_) { // This means that transforming the value into an AST node is // unsupported or it failed to register the type of the expression. @@ -632,7 +637,7 @@ function process_expression( return { expr: newExpr, ctx: ctx }; } -function tryExpression( +function tryExpressionSimplification( expr: AstExpression, interpreter: Interpreter, ): Value | undefined { @@ -648,3 +653,48 @@ function tryExpression( throw e; } } + +function registerAllSubExpTypes(ctx: CompilerContext, expr: AstValue, expType: TypeRef): CompilerContext { + switch(expr.kind) { + case "boolean": + case "number": + case "string": + case "null": { + ctx = registerExpType(ctx, expr, expType); + break; + } + case "struct_instance": { + ctx = registerExpType(ctx, expr, expType); + + const structFields = getType(ctx, expr.type).fields; + const fieldTypes: Map = new Map(); + + for (const field of structFields) { + fieldTypes.set(field.name, field.type); + } + + for (const fieldValue of expr.args) { + // Typechecking ensures that each field in the struct instance has a type + const fieldType = fieldTypes.get(idText(fieldValue.field)); + if (fieldType === undefined) { + throwInternalCompilerError(`Field ${idText(fieldValue.field)} does not have a declared type in struct ${idText(expr.type)}.`, fieldValue.loc); + } + ctx = registerAllSubExpTypes(ctx, ensureAstValue(fieldValue.initializer, fieldValue.loc), fieldType); + } + } + } + return ctx; +} + +function ensureAstValue(expr: AstExpression, src: SrcInfo): AstValue { + switch(expr.kind) { + case "boolean": + case "null": + case "number": + case "string": + case "struct_instance": + return expr; + default: + throwInternalCompilerError(`Expressions of kind ${expr.kind} are not ASTValues.`, src); + } +} diff --git a/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap b/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap index 2e48a0932..158bd8c9f 100644 --- a/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap +++ b/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap @@ -9,3 +9,176 @@ Line 6, col 9: 7 | } " `; + +exports[`expression-simplification should pass expression simplification for interpreter-simplifies-when-no-contract 1`] = ` +[ + [ + "v", + "Int", + ], + [ + "10", + "Int", + ], + [ + "v + 10", + "Int", + ], + [ + "3", + "Int", + ], + [ + "v + 10 + 3", + "Int", + ], + [ + "3", + "Int", + ], + [ + "7", + "Int", + ], + [ + "3 + 7", + "Int", + ], + [ + "A {a: v + 10 + 3, b: 3 + 7}", + "A", + ], + [ + "s", + "A", + ], + [ + "true", + "Bool", + ], + [ + "B {nested: s, c: true}", + "B", + ], + [ + "0", + "Int", + ], + [ + "exprFun1(0)", + "A", + ], + [ + "exprFun1(0).a", + "Int", + ], + [ + "2", + "Int", + ], + [ + "exprFun1(2)", + "A", + ], + [ + "exprFun2(exprFun1(2))", + "B", + ], + [ + "1", + "Int", + ], + [ + "exprFun1(1)", + "A", + ], + [ + "exprFun2(exprFun1(1))", + "B", + ], + [ + "exprFun2(exprFun1(1)).nested", + "A", + ], + [ + "2", + "Int", + ], + [ + "exprFun1(2)", + "A", + ], + [ + "exprFun2(exprFun1(2))", + "B", + ], + [ + "exprFun2(exprFun1(2)).c", + "Bool", + ], + [ + "c1", + "Int", + ], + [ + "exprFun1(c1)", + "A", + ], + [ + "exprFun1(c1).a", + "Int", + ], + [ + "0", + "Int", + ], + [ + "exprFun1(c1).a > 0", + "Bool", + ], + [ + "exprFun2(exprFun1(2)).c || exprFun1(c1).a > 0", + "Bool", + ], + [ + "13", + "Int", + ], + [ + "B { nested: A { a: 15, b: 10 }, c: true }", + "B", + ], + [ + "A { a: 15, b: 10 }", + "A", + ], + [ + "15", + "Int", + ], + [ + "10", + "Int", + ], + [ + "true", + "Bool", + ], + [ + "A { a: 14, b: 10 }", + "A", + ], + [ + "14", + "Int", + ], + [ + "10", + "Int", + ], + [ + "true", + "Bool", + ], +] +`; diff --git a/src/optimizer/test/success/interpreter-simplifies-when-no-contract.tact b/src/optimizer/test/success/interpreter-simplifies-when-no-contract.tact new file mode 100644 index 000000000..adf794a4f --- /dev/null +++ b/src/optimizer/test/success/interpreter-simplifies-when-no-contract.tact @@ -0,0 +1,33 @@ +// Interpreter should be called even if no contract is declared in the file. + +primitive Int; +primitive Bool; + +struct A { + a: Int; + b: Int; +} + +struct B { + nested: A; + c: Bool; +} + +fun exprFun1(v: Int): A { + return A {a: v + 10 + 3, b: 3 + 7}; // Interpreter cannot simplify field "b", because field "a" fails interpretation, which means + // that the entire struct instance fails interpretation. + // To actually simplify field "b" (and also field "a" which could be simplified to "v + 13"), + // we need here the partial evaluator. +} + +fun exprFun2(s: A): B { + return B {nested: s, c: true}; +} + +fun test() { + let c1 = exprFun1(0).a; // Interpreter simplifies to 13 + let c2 = exprFun2(exprFun1(2)); // Interpreter simplifies to "B {nested: A {a: 15, b: 10}, c: true}" + let c3 = exprFun2(exprFun1(1)).nested; // Interpreter simplifies to "A {a: 14, b: 10}" + let c4 = exprFun2(exprFun1(2)).c || exprFun1(c1).a > 0; // Interpreter simplifies to "true" because || short-circuits, + // even if exprFun1(c1) cannot be simplified. +} \ No newline at end of file diff --git a/src/optimizer/util.ts b/src/optimizer/util.ts index 03b61d630..8dcd2761d 100644 --- a/src/optimizer/util.ts +++ b/src/optimizer/util.ts @@ -9,7 +9,7 @@ import { AstStructFieldInitializer, idText, } from "../grammar/ast"; -import { dummySrcInfo } from "../grammar/grammar"; +import { dummySrcInfo, SrcInfo } from "../grammar/grammar"; import { throwInternalCompilerError } from "../errors"; import { StructValue, Value } from "../types/types"; @@ -36,11 +36,16 @@ export function extractValue(ast: AstValue): Value { } } -export function makeValueExpression(value: Value): AstValue { +export function makeValueExpression(value: Value, baseSrc: SrcInfo = dummySrcInfo): AstValue { + const valueString = valueToString(value); + // Keep all the info of the original source, but force the contents to have the + // new expression. + const newSrc = new SrcInfo({...baseSrc.interval, contents: valueString}, baseSrc.file, baseSrc.origin); + if (value === null) { const result = createAstNode({ kind: "null", - loc: dummySrcInfo, + loc: newSrc, }); return result as AstValue; } @@ -48,7 +53,7 @@ export function makeValueExpression(value: Value): AstValue { const result = createAstNode({ kind: "string", value: value, - loc: dummySrcInfo, + loc: newSrc, }); return result as AstValue; } @@ -57,7 +62,7 @@ export function makeValueExpression(value: Value): AstValue { kind: "number", base: 10, value: value, - loc: dummySrcInfo, + loc: newSrc, }); return result as AstValue; } @@ -65,7 +70,7 @@ export function makeValueExpression(value: Value): AstValue { const result = createAstNode({ kind: "boolean", value: value, - loc: dummySrcInfo, + loc: newSrc, }); return result as AstValue; } @@ -75,33 +80,60 @@ export function makeValueExpression(value: Value): AstValue { .map(([name, val]) => { return createAstNode({ kind: "struct_field_initializer", - field: makeIdExpression(name), - initializer: makeValueExpression(val), - loc: dummySrcInfo, + field: makeIdExpression(name, baseSrc), + initializer: makeValueExpression(val, baseSrc), + loc: newSrc, }) as AstStructFieldInitializer; }); const result = createAstNode({ kind: "struct_instance", - type: makeIdExpression(value["$tactStruct"] as string), + type: makeIdExpression(value["$tactStruct"] as string, baseSrc), args: fields, - loc: dummySrcInfo, + loc: newSrc, }); return result as AstValue; } + value; throwInternalCompilerError( - `addresses, cells, and comment values are not supported as AST nodes at the moment.`, + "addresses, cells, slices, and comment values are not supported as AST nodes at the moment.", ); } -function makeIdExpression(name: string): AstId { +function makeIdExpression(name: string, baseSrc: SrcInfo): AstId { const result = createAstNode({ kind: "id", text: name, - loc: dummySrcInfo, + loc: baseSrc, }); return result as AstId; } +export function valueToString(value: Value): string { + if (value === null) { + return "null"; + } + if (typeof value === "string") { + return value; + } + if (typeof value === "bigint") { + return value.toString(); + } + if (typeof value === "boolean") { + return value.toString(); + } + if (typeof value === "object" && "$tactStruct" in value) { + const fields = Object.entries(value) + .filter(([name, _]) => name !== "$tactStruct") + .map(([name, val]) => { + return `${name}: ${valueToString(val)}` + }).join(", "); + return `${value["$tactStruct"]} { ${fields} }`; + } + throwInternalCompilerError( + "Transformation of addresses, cells, slices or comment values into strings is not supported at the moment.", + ); +} + export function makeUnaryExpression( op: AstUnaryOperation, operand: AstExpression, diff --git a/yarn.lock b/yarn.lock index ebf553d9b..42cfe71dc 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2354,7 +2354,7 @@ cspell-trie-lib@8.16.0: "@cspell/cspell-types" "8.16.0" gensequence "^7.0.0" -cspell@^8.8.3: +cspell@^8.16.0: version "8.16.0" resolved "https://registry.npmjs.org/cspell/-/cspell-8.16.0.tgz#1897d123f8854304bc84ac332590c730e93c5123" integrity sha512-U6Up/4nODE+Ca+zqwZXTgBioGuF2JQHLEUIuoRJkJzAZkIBYDqrMXM+zdSL9E39+xb9jAtr9kPAYJf1Eybgi9g== From e9f02f5c5127e3b954dfa6702cc7c32112ad4d2b Mon Sep 17 00:00:00 2001 From: jeshecdom Date: Mon, 18 Nov 2024 16:24:46 +0100 Subject: [PATCH 4/5] Lint, prettier, spell, knip check --- src/optimizer/expr_simplification.ts | 29 ++++++++++++++++++++-------- src/optimizer/util.ts | 23 ++++++++++++++-------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/optimizer/expr_simplification.ts b/src/optimizer/expr_simplification.ts index ba601f1d0..f07598f65 100644 --- a/src/optimizer/expr_simplification.ts +++ b/src/optimizer/expr_simplification.ts @@ -10,7 +10,6 @@ import { AstValue, cloneAstNode, idText, - isValue, SrcInfo, } from "../grammar/ast"; import { Interpreter } from "../interpreter"; @@ -654,8 +653,12 @@ function tryExpressionSimplification( } } -function registerAllSubExpTypes(ctx: CompilerContext, expr: AstValue, expType: TypeRef): CompilerContext { - switch(expr.kind) { +function registerAllSubExpTypes( + ctx: CompilerContext, + expr: AstValue, + expType: TypeRef, +): CompilerContext { + switch (expr.kind) { case "boolean": case "number": case "string": @@ -668,7 +671,7 @@ function registerAllSubExpTypes(ctx: CompilerContext, expr: AstValue, expType: T const structFields = getType(ctx, expr.type).fields; const fieldTypes: Map = new Map(); - + for (const field of structFields) { fieldTypes.set(field.name, field.type); } @@ -677,9 +680,16 @@ function registerAllSubExpTypes(ctx: CompilerContext, expr: AstValue, expType: T // Typechecking ensures that each field in the struct instance has a type const fieldType = fieldTypes.get(idText(fieldValue.field)); if (fieldType === undefined) { - throwInternalCompilerError(`Field ${idText(fieldValue.field)} does not have a declared type in struct ${idText(expr.type)}.`, fieldValue.loc); + throwInternalCompilerError( + `Field ${idText(fieldValue.field)} does not have a declared type in struct ${idText(expr.type)}.`, + fieldValue.loc, + ); } - ctx = registerAllSubExpTypes(ctx, ensureAstValue(fieldValue.initializer, fieldValue.loc), fieldType); + ctx = registerAllSubExpTypes( + ctx, + ensureAstValue(fieldValue.initializer, fieldValue.loc), + fieldType, + ); } } } @@ -687,7 +697,7 @@ function registerAllSubExpTypes(ctx: CompilerContext, expr: AstValue, expType: T } function ensureAstValue(expr: AstExpression, src: SrcInfo): AstValue { - switch(expr.kind) { + switch (expr.kind) { case "boolean": case "null": case "number": @@ -695,6 +705,9 @@ function ensureAstValue(expr: AstExpression, src: SrcInfo): AstValue { case "struct_instance": return expr; default: - throwInternalCompilerError(`Expressions of kind ${expr.kind} are not ASTValues.`, src); + throwInternalCompilerError( + `Expressions of kind ${expr.kind} are not ASTValues.`, + src, + ); } } diff --git a/src/optimizer/util.ts b/src/optimizer/util.ts index 8dcd2761d..d8d5810a3 100644 --- a/src/optimizer/util.ts +++ b/src/optimizer/util.ts @@ -36,11 +36,18 @@ export function extractValue(ast: AstValue): Value { } } -export function makeValueExpression(value: Value, baseSrc: SrcInfo = dummySrcInfo): AstValue { +export function makeValueExpression( + value: Value, + baseSrc: SrcInfo = dummySrcInfo, +): AstValue { const valueString = valueToString(value); - // Keep all the info of the original source, but force the contents to have the + // Keep all the info of the original source, but force the contents to have the // new expression. - const newSrc = new SrcInfo({...baseSrc.interval, contents: valueString}, baseSrc.file, baseSrc.origin); + const newSrc = new SrcInfo( + { ...baseSrc.interval, contents: valueString }, + baseSrc.file, + baseSrc.origin, + ); if (value === null) { const result = createAstNode({ @@ -93,7 +100,6 @@ export function makeValueExpression(value: Value, baseSrc: SrcInfo = dummySrcInf }); return result as AstValue; } - value; throwInternalCompilerError( "addresses, cells, slices, and comment values are not supported as AST nodes at the moment.", ); @@ -108,7 +114,7 @@ function makeIdExpression(name: string, baseSrc: SrcInfo): AstId { return result as AstId; } -export function valueToString(value: Value): string { +function valueToString(value: Value): string { if (value === null) { return "null"; } @@ -125,9 +131,10 @@ export function valueToString(value: Value): string { const fields = Object.entries(value) .filter(([name, _]) => name !== "$tactStruct") .map(([name, val]) => { - return `${name}: ${valueToString(val)}` - }).join(", "); - return `${value["$tactStruct"]} { ${fields} }`; + return `${name}: ${valueToString(val)}`; + }) + .join(", "); + return `${value["$tactStruct"] as string} { ${fields} }`; } throwInternalCompilerError( "Transformation of addresses, cells, slices or comment values into strings is not supported at the moment.", From 90bcde3e556a40546c4961ec5c6c59e3988bfbe8 Mon Sep 17 00:00:00 2001 From: jeshecdom Date: Mon, 16 Dec 2024 20:17:34 +0100 Subject: [PATCH 5/5] Addressed issues in review. Pending changes related to Value type. These will be addressed once Value type gets refactored in issue #1190. --- schemas/configSchema.json | 4 +- src/config/parseConfig.ts | 5 +- src/grammar/ast.ts | 66 + src/optimizer/expr-simplification.ts | 367 ++++++ src/optimizer/expr_simplification.ts | 713 ---------- src/optimizer/optimization-phase.ts | 1152 +++++++++++++++++ src/optimizer/optimization_phase.ts | 40 - ....snap => expr-simplification.spec.ts.snap} | 76 ++ ...on.spec.ts => expr-simplification.spec.ts} | 11 +- src/optimizer/util.ts | 12 +- src/pipeline/build.ts | 66 +- src/types/resolveDescriptors.ts | 30 - src/types/resolveExpression.ts | 8 + 13 files changed, 1729 insertions(+), 821 deletions(-) create mode 100644 src/optimizer/expr-simplification.ts delete mode 100644 src/optimizer/expr_simplification.ts create mode 100644 src/optimizer/optimization-phase.ts delete mode 100644 src/optimizer/optimization_phase.ts rename src/optimizer/test/__snapshots__/{expr_simplification.spec.ts.snap => expr-simplification.spec.ts.snap} (71%) rename src/optimizer/test/{expr_simplification.spec.ts => expr-simplification.spec.ts} (78%) diff --git a/schemas/configSchema.json b/schemas/configSchema.json index ada0fd544..82595f59d 100644 --- a/schemas/configSchema.json +++ b/schemas/configSchema.json @@ -61,10 +61,10 @@ "default": false, "description": "False by default. If set to true, skips the Tact code optimization phase." }, - "dumpOptimizedTactCode": { + "dumpCodeBeforeAndAfterTactOptimizationPhase": { "type": "boolean", "default": false, - "description": "False by default. If set to true, dumps the code produced by the Tact code optimization phase. In case the optimization phase is skipped, this option is ignored." + "description": "False by default. If set to true, dumps the code produced before and after the Tact code optimization phase." }, "experimental": { "type": "object", diff --git a/src/config/parseConfig.ts b/src/config/parseConfig.ts index 66f52143a..ac01c0815 100644 --- a/src/config/parseConfig.ts +++ b/src/config/parseConfig.ts @@ -38,10 +38,9 @@ export const optionsSchema = z */ skipTactOptimizationPhase: z.boolean().optional(), /** - * If set to true, dumps the code produced by the Tact code optimization phase. - * In case the optimization phase is skipped, this option is ignored. + * If set to true, dumps the code produced before and after the Tact code optimization phase. */ - dumpOptimizedTactCode: z.boolean().optional(), + dumpCodeBeforeAndAfterTactOptimizationPhase: z.boolean().optional(), /** * Experimental options that might be removed in the future. Use with caution! */ diff --git a/src/grammar/ast.ts b/src/grammar/ast.ts index b2cc17a7d..d39e46bdd 100644 --- a/src/grammar/ast.ts +++ b/src/grammar/ast.ts @@ -1,3 +1,4 @@ +import { throwInternalCompilerError } from "../errors"; import { dummySrcInfo, SrcInfo } from "./grammar"; export type AstModule = { @@ -917,6 +918,71 @@ export function isValue(ast: AstExpression): boolean { case "field_access": case "static_call": return false; + default: + throwInternalCompilerError("Unrecognized AstExpression"); + } +} + +export function isAstExpression(ast: AstNode): ast is AstExpression { + switch (ast.kind) { + case "null": + case "boolean": + case "number": + case "string": + case "id": + case "struct_instance": + case "method_call": + case "init_of": + case "op_unary": + case "op_binary": + case "conditional": + case "field_access": + case "static_call": + return true; + + case "asm_function_def": + case "bounced_message_type": + case "constant_decl": + case "constant_def": + case "contract": + case "contract_init": + case "destruct_end": + case "destruct_mapping": + case "field_decl": + case "func_id": + case "function_attribute": + case "function_decl": + case "function_def": + case "import": + case "map_type": + case "message_decl": + case "module": + case "native_function_decl": + case "optional_type": + case "primitive_type_decl": + case "receiver": + case "statement_assign": + case "statement_augmentedassign": + case "statement_condition": + case "statement_destruct": + case "statement_expression": + case "statement_foreach": + case "statement_let": + case "statement_repeat": + case "statement_return": + case "statement_try": + case "statement_try_catch": + case "statement_until": + case "statement_while": + case "struct_decl": + case "struct_field_initializer": + case "trait": + case "type_id": + case "typed_parameter": + return false; + + default: + throwInternalCompilerError("Unrecognized AstNode"); } } diff --git a/src/optimizer/expr-simplification.ts b/src/optimizer/expr-simplification.ts new file mode 100644 index 000000000..aa2851278 --- /dev/null +++ b/src/optimizer/expr-simplification.ts @@ -0,0 +1,367 @@ +import { CompilerContext } from "../context"; +import { TactConstEvalError, throwInternalCompilerError } from "../errors"; +import { + AstConstantDef, + AstContractDeclaration, + AstExpression, + AstFieldDecl, + AstStatement, + AstTraitDeclaration, + AstValue, + idText, + SrcInfo, +} from "../grammar/ast"; +import { Interpreter } from "../interpreter"; +import { getType } from "../types/resolveDescriptors"; +import { getExpType, registerExpType } from "../types/resolveExpression"; +import { TypeRef, Value } from "../types/types"; +import { + OptimizationContext, + registerAstNodeChange, +} from "./optimization-phase"; +import { makeValueExpression, UnsupportedOperation } from "./util"; + +export function simplifyAllExpressions(optCtx: OptimizationContext) { + // The interpreter in charge of simplifying expressions + const interpreter = new Interpreter(optCtx.ctx); + + // Traverse the program and attempt to evaluate every expression + + for (const moduleItem of optCtx.modifiedAst.items) { + switch (moduleItem.kind) { + case "asm_function_def": + case "native_function_decl": + case "primitive_type_decl": + // Nothing to simplify + break; + case "struct_decl": + case "message_decl": { + moduleItem.fields.forEach((field) => { + simplifyFieldDecl(field, optCtx, interpreter); + }); + break; + } + case "constant_def": { + simplifyConstantDef(moduleItem, optCtx, interpreter); + break; + } + case "function_def": { + moduleItem.statements.forEach((stmt) => { + simplifyStatement(stmt, optCtx, interpreter); + }); + break; + } + case "contract": { + moduleItem.declarations.forEach((decl) => { + simplifyContractDeclaration(decl, optCtx, interpreter); + }); + break; + } + case "trait": { + moduleItem.declarations.forEach((decl) => { + simplifyTraitDeclaration(decl, optCtx, interpreter); + }); + break; + } + default: + throwInternalCompilerError("Unrecognized module item kind"); + } + } +} + +function simplifyFieldDecl( + ast: AstFieldDecl, + optCtx: OptimizationContext, + interpreter: Interpreter, +) { + if (ast.initializer !== null) { + ast.initializer = simplifyExpression( + ast.initializer, + optCtx, + interpreter, + ); + } +} + +function simplifyConstantDef( + ast: AstConstantDef, + optCtx: OptimizationContext, + interpreter: Interpreter, +) { + ast.initializer = simplifyExpression(ast.initializer, optCtx, interpreter); +} + +function simplifyContractDeclaration( + decl: AstContractDeclaration, + optCtx: OptimizationContext, + interpreter: Interpreter, +) { + switch (decl.kind) { + case "asm_function_def": { + // This kind is not changed by the optimizer + break; + } + case "field_decl": { + simplifyFieldDecl(decl, optCtx, interpreter); + break; + } + case "constant_def": { + simplifyConstantDef(decl, optCtx, interpreter); + break; + } + case "function_def": + case "receiver": + case "contract_init": { + decl.statements.forEach((stmt) => { + simplifyStatement(stmt, optCtx, interpreter); + }); + break; + } + default: + throwInternalCompilerError( + "Unrecognized contract declaration kind", + ); + } +} + +function simplifyTraitDeclaration( + decl: AstTraitDeclaration, + optCtx: OptimizationContext, + interpreter: Interpreter, +) { + switch (decl.kind) { + case "asm_function_def": + case "constant_decl": + case "function_decl": { + // These kinds are not changed by the optimizer + break; + } + case "field_decl": { + simplifyFieldDecl(decl, optCtx, interpreter); + break; + } + case "constant_def": { + simplifyConstantDef(decl, optCtx, interpreter); + break; + } + case "function_def": + case "receiver": { + decl.statements.forEach((stmt) => { + simplifyStatement(stmt, optCtx, interpreter); + }); + break; + } + default: + throwInternalCompilerError("Unrecognized trait declaration kind"); + } +} + +function simplifyStatement( + stmt: AstStatement, + optCtx: OptimizationContext, + interpreter: Interpreter, +) { + switch (stmt.kind) { + case "statement_assign": + case "statement_expression": + case "statement_let": + case "statement_destruct": + case "statement_augmentedassign": { + stmt.expression = simplifyExpression( + stmt.expression, + optCtx, + interpreter, + ); + break; + } + case "statement_return": { + if (stmt.expression !== null) { + stmt.expression = simplifyExpression( + stmt.expression, + optCtx, + interpreter, + ); + } + break; + } + case "statement_condition": { + stmt.condition = simplifyExpression( + stmt.condition, + optCtx, + interpreter, + ); + stmt.trueStatements.forEach((trueStmt) => { + simplifyStatement(trueStmt, optCtx, interpreter); + }); + + if (stmt.falseStatements !== null) { + stmt.falseStatements.forEach((falseStmt) => { + simplifyStatement(falseStmt, optCtx, interpreter); + }); + } + + if (stmt.elseif !== null) { + simplifyStatement(stmt.elseif, optCtx, interpreter); + } + break; + } + case "statement_foreach": { + stmt.map = simplifyExpression(stmt.map, optCtx, interpreter); + stmt.statements.forEach((loopStmt) => { + simplifyStatement(loopStmt, optCtx, interpreter); + }); + break; + } + case "statement_until": + case "statement_while": { + stmt.condition = simplifyExpression( + stmt.condition, + optCtx, + interpreter, + ); + stmt.statements.forEach((loopStmt) => { + simplifyStatement(loopStmt, optCtx, interpreter); + }); + break; + } + case "statement_repeat": { + stmt.iterations = simplifyExpression( + stmt.iterations, + optCtx, + interpreter, + ); + stmt.statements.forEach((loopStmt) => { + simplifyStatement(loopStmt, optCtx, interpreter); + }); + break; + } + case "statement_try": { + stmt.statements.forEach((tryStmt) => { + simplifyStatement(tryStmt, optCtx, interpreter); + }); + break; + } + case "statement_try_catch": { + stmt.statements.forEach((tryStmt) => { + simplifyStatement(tryStmt, optCtx, interpreter); + }); + stmt.catchStatements.forEach((catchStmt) => { + simplifyStatement(catchStmt, optCtx, interpreter); + }); + break; + } + default: + throwInternalCompilerError("Unrecognized statement kind"); + } +} + +function simplifyExpression( + expr: AstExpression, + optCtx: OptimizationContext, + interpreter: Interpreter, +): AstExpression { + const value = tryExpressionSimplification(expr, interpreter); + let newExpr = expr; + if (typeof value !== "undefined") { + try { + newExpr = makeValueExpression(value, expr.loc); + // Register the new expression in the context + registerAstNodeChange(optCtx, expr, newExpr); + // To maintain consistency with types in the CompilerContext, register the + // types of all newly created expressions + optCtx.ctx = registerAllSubExpTypes( + optCtx.ctx, + newExpr, + getExpType(optCtx.ctx, expr), + ); + } catch (e) { + if (e instanceof UnsupportedOperation) { + // This means that transforming the value into an AST node is + // unsupported. Just use the original expression. + newExpr = expr; + } else { + throw e; + } + } + } + + return newExpr; +} + +function tryExpressionSimplification( + expr: AstExpression, + interpreter: Interpreter, +): Value | undefined { + try { + // Eventually, this will be replaced by the partial evaluator. + return interpreter.interpretExpression(expr); + } catch (e) { + if (e instanceof TactConstEvalError) { + if (!e.fatal) { + return undefined; + } + } + throw e; + } +} + +function registerAllSubExpTypes( + ctx: CompilerContext, + expr: AstValue, + expType: TypeRef, +): CompilerContext { + switch (expr.kind) { + case "boolean": + case "number": + case "string": + case "null": { + ctx = registerExpType(ctx, expr, expType); + break; + } + case "struct_instance": { + ctx = registerExpType(ctx, expr, expType); + + const structFields = getType(ctx, expr.type).fields; + const fieldTypes: Map = new Map(); + + for (const field of structFields) { + fieldTypes.set(field.name, field.type); + } + + for (const fieldValue of expr.args) { + const fieldType = fieldTypes.get(idText(fieldValue.field)); + if (typeof fieldType === "undefined") { + throwInternalCompilerError( + `Field ${idText(fieldValue.field)} does not have a declared type in struct ${idText(expr.type)}.`, + fieldValue.loc, + ); + } + ctx = registerAllSubExpTypes( + ctx, + ensureAstValue(fieldValue.initializer, fieldValue.loc), + fieldType, + ); + } + break; + } + default: + throwInternalCompilerError("Unrecognized AstValue."); + } + return ctx; +} + +function ensureAstValue(expr: AstExpression, src: SrcInfo): AstValue { + switch (expr.kind) { + case "boolean": + case "null": + case "number": + case "string": + case "struct_instance": + return expr; + default: + throwInternalCompilerError( + `Expressions of kind ${expr.kind} are not ASTValues.`, + src, + ); + } +} diff --git a/src/optimizer/expr_simplification.ts b/src/optimizer/expr_simplification.ts deleted file mode 100644 index f07598f65..000000000 --- a/src/optimizer/expr_simplification.ts +++ /dev/null @@ -1,713 +0,0 @@ -import { CompilerContext } from "../context"; -import { TactConstEvalError, throwInternalCompilerError } from "../errors"; -import { - AstCondition, - AstContractDeclaration, - AstExpression, - AstStatement, - AstTraitDeclaration, - AstTypeDecl, - AstValue, - cloneAstNode, - idText, - SrcInfo, -} from "../grammar/ast"; -import { Interpreter } from "../interpreter"; -import { - getAllStaticConstants, - getAllStaticFunctions, - getAllTypes, - getType, - replaceStaticConstants, - replaceStaticFunctions, - replaceTypes, -} from "../types/resolveDescriptors"; -import { getExpType, registerExpType } from "../types/resolveExpression"; -import { - ConstantDescription, - FieldDescription, - FunctionDescription, - InitDescription, - ReceiverDescription, - TypeDescription, - TypeRef, - Value, -} from "../types/types"; -import { makeValueExpression } from "./util"; - -export function simplify_expressions(ctx: CompilerContext): CompilerContext { - // The interpreter in charge of simplifying expressions - const interpreter = new Interpreter(ctx); - - // Traverse the program and attempt to evaluate every expression - - // Process functions - const newStaticFunctions: Map = new Map(); - - for (const f of getAllStaticFunctions(ctx)) { - if (f.ast.kind === "function_def") { - const statementsResult = process_statements( - f.ast.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - const newFunctionCode = cloneAstNode({ - ...f.ast, - statements: newStatements, - }); - newStaticFunctions.set(f.name, { ...f, ast: newFunctionCode }); - } else { - // The rest of kinds do not have explicit Tact expressions, so just copy the current function description - newStaticFunctions.set(f.name, f); - } - } - ctx = replaceStaticFunctions(ctx, newStaticFunctions); - - // Process all static constants - const newStaticConstants: Map = new Map(); - - for (const c of getAllStaticConstants(ctx)) { - if (c.ast.kind === "constant_def") { - const expressionResult = process_expression( - c.ast.initializer, - ctx, - interpreter, - ); - const newInitializer = expressionResult.expr; - ctx = expressionResult.ctx; - const newConstantCode = cloneAstNode({ - ...c.ast, - initializer: newInitializer, - }); - newStaticConstants.set(c.name, { ...c, ast: newConstantCode }); - } else { - // The rest of kinds do not have explicit Tact expressions, so just copy the current description - newStaticConstants.set(c.name, c); - } - } - ctx = replaceStaticConstants(ctx, newStaticConstants); - - // Process all types - - /** - * By calling the function getAllTypes on the context object "ctx", one gets an array of TypeDescriptions. - * Each TypeDescription stores the type declarations in two different ways: - * - Directly in the TypeDescription object there are fields, constants, and method - * declarations. However, these declarations are "coalesced" in the following sense: - * If the TypeDescription is a contract, it will contain copies of methods, constants and fields of traits that the - * contract inherits from. Similarly, each trait will have declarations of other traits - * that the trait inherits from. - * - * For example, if we look into the "functions" property of the TypeDescription object of a contract - * we will find functions defined in BaseTrait. - * - * - Indirectly in the "ast" property of the TypeDescription. Contrary to the previous case, - * the fields, constants and methods in the ast property are NOT coalesced. This means, for example, - * that the methods in a TypeDescription's ast of a contract will be methods that are actually - * declared in the contract and not in some trait that the contract inherits from. - * - * The above means that we will need to process the properties in TypeDescription first, - * and then use those properties to build the AST (carefully ensuring that only fields, constants and methods - * that were in the original AST, remain in the new AST). - */ - const newTypes: Map = new Map(); - - for (const t of getAllTypes(ctx)) { - let newInitializer: InitDescription | null = null; - - // Process init - if (t.init) { - const statementsResult = process_statements( - t.init.ast.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - const newInitCode = cloneAstNode({ - ...t.init.ast, - statements: newStatements, - }); - newInitializer = { ...t.init, ast: newInitCode }; - } - - // Process constants - const newConstants: ConstantDescription[] = []; - - // This map will be used to quickly recover the new definitions when - // building the AST later - const newConstantsMap: Map = new Map(); - - for (const c of t.constants) { - if (c.ast.kind === "constant_def") { - const expressionResult = process_expression( - c.ast.initializer, - ctx, - interpreter, - ); - const newInitializer = expressionResult.expr; - ctx = expressionResult.ctx; - const newConstantCode = cloneAstNode({ - ...c.ast, - initializer: newInitializer, - }); - const newConstantDescription = { ...c, ast: newConstantCode }; - newConstants.push(newConstantDescription); - newConstantsMap.set(c.name, newConstantDescription); - } else { - // The rest of kinds do not have explicit Tact expressions, so just copy the current description - newConstants.push(c); - newConstantsMap.set(c.name, c); - } - } - - // Process fields - const newFields: FieldDescription[] = []; - - // This map will be used to quickly recover the new definitions when - // building the AST later - const newFieldsMap: Map = new Map(); - - for (const f of t.fields) { - if (f.ast.initializer !== null) { - const expressionResult = process_expression( - f.ast.initializer, - ctx, - interpreter, - ); - const newInitializer = expressionResult.expr; - ctx = expressionResult.ctx; - const newFieldCode = cloneAstNode({ - ...f.ast, - initializer: newInitializer, - }); - const newFieldDescription = { ...f, ast: newFieldCode }; - newFields.push(newFieldDescription); - newFieldsMap.set(f.name, newFieldDescription); - } else { - // Field without initializer, no expression to simplify inside - newFields.push(f); - newFieldsMap.set(f.name, f); - } - } - - // Process receivers - const newReceivers: ReceiverDescription[] = []; - - // This map will be used to quickly recover the new definitions when - // building the AST later. - // Since receivers do not have names, I will use their id in their original ast - // as key. - const newReceiversMap: Map = new Map(); - - for (const r of t.receivers) { - const statementsResult = process_statements( - r.ast.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - const newReceiverCode = cloneAstNode({ - ...r.ast, - statements: newStatements, - }); - const newReceiverDescription = { ...r, ast: newReceiverCode }; - newReceivers.push(newReceiverDescription); - newReceiversMap.set(r.ast.id, newReceiverDescription); - } - - // Process methods - - // This is already a map in TypeDescription. This is the reason - // I did not need a separate map, like in the previous cases. - const newMethods: Map = new Map(); - - for (const [name, m] of t.functions) { - if (m.ast.kind === "function_def") { - const statementsResult = process_statements( - m.ast.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - const newMethodCode = cloneAstNode({ - ...m.ast, - statements: newStatements, - }); - newMethods.set(name, { ...m, ast: newMethodCode }); - } else { - // The rest of kinds do not have explicit Tact expressions, so just copy the current function description - newMethods.set(name, m); - } - } - - // Now, we need to create the new AST, depending on its kind. - let newAst: AstTypeDecl; - - switch (t.ast.kind) { - case "primitive_type_decl": { - newAst = t.ast; - break; - } - case "struct_decl": - case "message_decl": { - newAst = cloneAstNode({ - ...t.ast, - fields: newFields.map((f) => f.ast), - }); - break; - } - case "trait": { - const newDeclarations: AstTraitDeclaration[] = []; - - for (const decl of t.ast.declarations) { - switch (decl.kind) { - case "asm_function_def": - case "function_decl": - case "function_def": { - const newCode = newMethods.get(idText(decl.name))! - .ast as AstTraitDeclaration; - newDeclarations.push(newCode); - break; - } - case "constant_decl": - case "constant_def": { - const newCode = newConstantsMap.get( - idText(decl.name), - )!.ast; - newDeclarations.push(newCode); - break; - } - case "field_decl": { - const newCode = newFieldsMap.get( - idText(decl.name), - )!.ast; - newDeclarations.push(newCode); - break; - } - case "receiver": { - const newCode = newReceiversMap.get(decl.id)!.ast; - newDeclarations.push(newCode); - break; - } - } - } - - newAst = cloneAstNode({ - ...t.ast, - declarations: newDeclarations, - }); - - break; - } - case "contract": { - const newDeclarations: AstContractDeclaration[] = []; - - for (const decl of t.ast.declarations) { - switch (decl.kind) { - case "asm_function_def": - case "function_def": { - const newCode = newMethods.get(idText(decl.name))! - .ast as AstContractDeclaration; - newDeclarations.push(newCode); - break; - } - case "constant_def": { - const newCode = newConstantsMap.get( - idText(decl.name), - )!.ast as AstContractDeclaration; - newDeclarations.push(newCode); - break; - } - case "field_decl": { - const newCode = newFieldsMap.get( - idText(decl.name), - )!.ast; - newDeclarations.push(newCode); - break; - } - case "receiver": { - const newCode = newReceiversMap.get(decl.id)!.ast; - newDeclarations.push(newCode); - break; - } - case "contract_init": - newDeclarations.push(newInitializer!.ast); - break; - } - } - - newAst = cloneAstNode({ - ...t.ast, - declarations: newDeclarations, - }); - - break; - } - } - - newTypes.set(t.name, { - ...t, - ast: newAst, - init: newInitializer, - constants: newConstants, - fields: newFields, - functions: newMethods, - receivers: newReceivers, - }); - } - ctx = replaceTypes(ctx, newTypes); - - return ctx; -} - -function process_statements( - statements: AstStatement[], - ctx: CompilerContext, - interpreter: Interpreter, -): { stmts: AstStatement[]; ctx: CompilerContext } { - const newStatements: AstStatement[] = []; - - for (const stmt of statements) { - const result = process_statement(stmt, ctx, interpreter); - newStatements.push(result.stmt); - ctx = result.ctx; - } - - return { stmts: newStatements, ctx: ctx }; -} - -function process_statement( - stmt: AstStatement, - ctx: CompilerContext, - interpreter: Interpreter, -): { stmt: AstStatement; ctx: CompilerContext } { - switch (stmt.kind) { - case "statement_assign": - case "statement_expression": - case "statement_let": - case "statement_destruct": - case "statement_augmentedassign": { - const expressionResult = process_expression( - stmt.expression, - ctx, - interpreter, - ); - const new_expr = expressionResult.expr; - ctx = expressionResult.ctx; - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - expression: new_expr, - }), - ctx: ctx, - }; - } - case "statement_return": { - if (stmt.expression !== null) { - const expressionResult = process_expression( - stmt.expression, - ctx, - interpreter, - ); - const new_expr = expressionResult.expr; - ctx = expressionResult.ctx; - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - expression: new_expr, - }), - ctx: ctx, - }; - } - return { - stmt: stmt, - ctx: ctx, - }; - } - case "statement_condition": { - const expressionResult = process_expression( - stmt.condition, - ctx, - interpreter, - ); - const newCondition = expressionResult.expr; - ctx = expressionResult.ctx; - - const trueStatementsResult = process_statements( - stmt.trueStatements, - ctx, - interpreter, - ); - const newTrueStatements = trueStatementsResult.stmts; - ctx = trueStatementsResult.ctx; - - let newFalseStatements: AstStatement[] | null = null; - if (stmt.falseStatements !== null) { - const falseStatementsResult = process_statements( - stmt.falseStatements, - ctx, - interpreter, - ); - newFalseStatements = falseStatementsResult.stmts; - ctx = falseStatementsResult.ctx; - } - - let newElseIf: AstCondition | null = null; - if (stmt.elseif !== null) { - const elseIfResult = process_statement( - stmt.elseif, - ctx, - interpreter, - ); - newElseIf = elseIfResult.stmt as AstCondition; - ctx = elseIfResult.ctx; - } - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - condition: newCondition, - trueStatements: newTrueStatements, - falseStatements: newFalseStatements, - elseif: newElseIf, - }), - ctx: ctx, - }; - } - case "statement_foreach": { - const expressionResult = process_expression( - stmt.map, - ctx, - interpreter, - ); - const newMap = expressionResult.expr; - ctx = expressionResult.ctx; - - const statementsResult = process_statements( - stmt.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - map: newMap, - statements: newStatements, - }), - ctx: ctx, - }; - } - case "statement_until": - case "statement_while": { - const expressionResult = process_expression( - stmt.condition, - ctx, - interpreter, - ); - const newCondition = expressionResult.expr; - ctx = expressionResult.ctx; - - const statementsResult = process_statements( - stmt.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - condition: newCondition, - statements: newStatements, - }), - ctx: ctx, - }; - } - case "statement_repeat": { - const expressionResult = process_expression( - stmt.iterations, - ctx, - interpreter, - ); - const newIterations = expressionResult.expr; - ctx = expressionResult.ctx; - - const statementsResult = process_statements( - stmt.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - iterations: newIterations, - statements: newStatements, - }), - ctx: ctx, - }; - } - case "statement_try": { - const statementsResult = process_statements( - stmt.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - statements: newStatements, - }), - ctx: ctx, - }; - } - case "statement_try_catch": { - const statementsResult = process_statements( - stmt.statements, - ctx, - interpreter, - ); - const newStatements = statementsResult.stmts; - ctx = statementsResult.ctx; - - const catchStatementsResult = process_statements( - stmt.catchStatements, - ctx, - interpreter, - ); - const newCatchStatements = catchStatementsResult.stmts; - ctx = catchStatementsResult.ctx; - - // Create the replacement node - return { - stmt: cloneAstNode({ - ...stmt, - statements: newStatements, - catchStatements: newCatchStatements, - }), - ctx: ctx, - }; - } - } -} - -function process_expression( - expr: AstExpression, - ctx: CompilerContext, - interpreter: Interpreter, -): { expr: AstExpression; ctx: CompilerContext } { - const value = tryExpressionSimplification(expr, interpreter); - let newExpr = expr; - if (value !== undefined) { - try { - newExpr = makeValueExpression(value, expr.loc); - // Register the new expression in the context - ctx = registerAllSubExpTypes(ctx, newExpr, getExpType(ctx, expr)); - } catch (_) { - // This means that transforming the value into an AST node is - // unsupported or it failed to register the type of the expression. - // Just use the original expression. - newExpr = expr; - } - } - return { expr: newExpr, ctx: ctx }; -} - -function tryExpressionSimplification( - expr: AstExpression, - interpreter: Interpreter, -): Value | undefined { - try { - // Eventually, this will be replaced by the partial evaluator. - return interpreter.interpretExpression(expr); - } catch (e) { - if (e instanceof TactConstEvalError) { - if (!e.fatal) { - return undefined; - } - } - throw e; - } -} - -function registerAllSubExpTypes( - ctx: CompilerContext, - expr: AstValue, - expType: TypeRef, -): CompilerContext { - switch (expr.kind) { - case "boolean": - case "number": - case "string": - case "null": { - ctx = registerExpType(ctx, expr, expType); - break; - } - case "struct_instance": { - ctx = registerExpType(ctx, expr, expType); - - const structFields = getType(ctx, expr.type).fields; - const fieldTypes: Map = new Map(); - - for (const field of structFields) { - fieldTypes.set(field.name, field.type); - } - - for (const fieldValue of expr.args) { - // Typechecking ensures that each field in the struct instance has a type - const fieldType = fieldTypes.get(idText(fieldValue.field)); - if (fieldType === undefined) { - throwInternalCompilerError( - `Field ${idText(fieldValue.field)} does not have a declared type in struct ${idText(expr.type)}.`, - fieldValue.loc, - ); - } - ctx = registerAllSubExpTypes( - ctx, - ensureAstValue(fieldValue.initializer, fieldValue.loc), - fieldType, - ); - } - } - } - return ctx; -} - -function ensureAstValue(expr: AstExpression, src: SrcInfo): AstValue { - switch (expr.kind) { - case "boolean": - case "null": - case "number": - case "string": - case "struct_instance": - return expr; - default: - throwInternalCompilerError( - `Expressions of kind ${expr.kind} are not ASTValues.`, - src, - ); - } -} diff --git a/src/optimizer/optimization-phase.ts b/src/optimizer/optimization-phase.ts new file mode 100644 index 000000000..9240073f2 --- /dev/null +++ b/src/optimizer/optimization-phase.ts @@ -0,0 +1,1152 @@ +import { CompilerContext } from "../context"; +import { throwInternalCompilerError } from "../errors"; +import { + AstCondition, + AstConditional, + AstConstantDef, + AstContract, + AstContractDeclaration, + AstContractInit, + AstExpression, + AstFieldAccess, + AstFieldDecl, + AstFunctionDef, + AstInitOf, + AstMessageDecl, + AstMethodCall, + AstModule, + AstModuleItem, + AstOpBinary, + AstOpUnary, + AstReceiver, + AstStatement, + AstStatementAssign, + AstStatementAugmentedAssign, + AstStatementDestruct, + AstStatementExpression, + AstStatementForEach, + AstStatementLet, + AstStatementRepeat, + AstStatementReturn, + AstStatementTry, + AstStatementTryCatch, + AstStatementUntil, + AstStatementWhile, + AstStaticCall, + AstStructDecl, + AstStructFieldInitializer, + AstStructInstance, + AstTrait, + AstTraitDeclaration, + AstTypeDecl, + cloneAstNode, + createAstNode, + isAstExpression, +} from "../grammar/ast"; +import { prettyPrint } from "../prettyPrinter"; +import { + getAllStaticConstants, + getAllStaticFunctions, + getAllTypes, +} from "../types/resolveDescriptors"; +import { writeFile } from "node:fs/promises"; +import { simplifyAllExpressions } from "./expr-simplification"; +import { TypeDescription } from "../types/types"; +import { getExpTypeById, registerExpType } from "../types/resolveExpression"; + +/* These are the node types that the optimization phase is allowed to modify */ +type AstMutableNode = + | AstExpression + | AstStatement + | AstTypeDecl + | AstFieldDecl + | AstFunctionDef + | AstModule + | AstContractInit + | AstReceiver + | AstConstantDef + | AstStructFieldInitializer; + +export type OptimizationContext = { + originalAst: AstModule; + modifiedAst: AstModule; + nodeReplacements: Map; + originalIDs: Map; + ctx: CompilerContext; +}; + +export function optimizeTact(ctx: OptimizationContext) { + // Call the expression simplification phase + simplifyAllExpressions(ctx); + + // Here, we will call the constant propagation analyzer +} + +export function prepareAstForOptimization( + ctx: CompilerContext, + doOptimizationFlag: boolean, +): OptimizationContext { + // Create a module AST that stores the entire program. + const moduleItems: AstModuleItem[] = []; + + // Extract constants + for (const c of getAllStaticConstants(ctx)) { + if (c.ast.kind === "constant_decl") { + throwInternalCompilerError( + "Constant declarations cannot be top level module declarations.", + ); + } + moduleItems.push(c.ast); + } + + // Extract functions + for (const f of getAllStaticFunctions(ctx)) { + if (f.ast.kind === "function_decl") { + throwInternalCompilerError( + "Function declarations cannot be top level module declarations.", + ); + } + moduleItems.push(f.ast); + } + + // Extract type declarations + for (const t of getAllTypes(ctx)) { + moduleItems.push(t.ast); + } + + // Uses an empty list of imports. AstModule nodes will be deleted at the end of the optimization phase anyway, + // because everything needs to be put back into the format inside of CompilerContext + const moduleAst = createAstNode({ + kind: "module", + items: moduleItems, + imports: [], + }) as AstModule; + + if (doOptimizationFlag) { + const changedIds: Map = new Map(); + const newAst = makeUnfrozenCopyOfModule(moduleAst, changedIds); + return buildOptimizationContext(moduleAst, newAst, changedIds, ctx); + } else { + return buildOptimizationContext(moduleAst, moduleAst, new Map(), ctx); + } +} + +function makeUnfrozenCopyOfModule( + ast: AstModule, + changedNodeIds: Map, +): AstModule { + const newItems: AstModuleItem[] = []; + + for (const moduleItem of ast.items) { + switch (moduleItem.kind) { + case "asm_function_def": + case "native_function_decl": + case "primitive_type_decl": { + // These kinds are not modified by the optimizer at this moment. + // So, just pass the [frozen] node. + newItems.push(moduleItem); + break; + } + case "constant_def": { + newItems.push( + makeUnfrozenCopyOfConstantDef(moduleItem, changedNodeIds), + ); + break; + } + case "function_def": { + newItems.push( + makeUnfrozenCopyOfFunctionDef(moduleItem, changedNodeIds), + ); + break; + } + case "message_decl": { + newItems.push( + makeUnfrozenCopyOfMessageDecl(moduleItem, changedNodeIds), + ); + break; + } + case "struct_decl": { + newItems.push( + makeUnfrozenCopyOfStructDecl(moduleItem, changedNodeIds), + ); + break; + } + case "trait": { + newItems.push( + makeUnfrozenCopyOfTrait(moduleItem, changedNodeIds), + ); + break; + } + case "contract": { + newItems.push( + makeUnfrozenCopyOfContract(moduleItem, changedNodeIds), + ); + break; + } + default: + throwInternalCompilerError("Unrecognized AstMutable node"); + } + } + + const newModuleNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newModuleNode.items = newItems; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newModuleNode); + return newModuleNode; +} + +function makeUnfrozenCopyOfConstantDef( + ast: AstConstantDef, + changedNodeIds: Map, +): AstConstantDef { + const newInitializer = makeUnfrozenCopyOfExpression( + ast.initializer, + changedNodeIds, + ); + const newConstantDefNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newConstantDefNode.initializer = newInitializer; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newConstantDefNode); + return newConstantDefNode; +} + +function makeUnfrozenCopyOfFunctionDef( + ast: AstFunctionDef, + changedNodeIds: Map, +): AstFunctionDef { + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newFunctionDefNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newFunctionDefNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newFunctionDefNode); + return newFunctionDefNode; +} + +function makeUnfrozenCopyOfMessageDecl( + ast: AstMessageDecl, + changedNodeIds: Map, +): AstMessageDecl { + const newFields = ast.fields.map((field) => + makeUnfrozenCopyOfFieldDecl(field, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.fields = newFields; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfStructDecl( + ast: AstStructDecl, + changedNodeIds: Map, +): AstStructDecl { + const newFields = ast.fields.map((field) => + makeUnfrozenCopyOfFieldDecl(field, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.fields = newFields; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfTrait( + ast: AstTrait, + changedNodeIds: Map, +): AstTrait { + const newDeclarations: AstTraitDeclaration[] = []; + + for (const decl of ast.declarations) { + switch (decl.kind) { + case "asm_function_def": + case "constant_decl": + case "function_decl": { + // These kinds are not changed by the optimizer + newDeclarations.push(decl); + break; + } + case "field_decl": { + newDeclarations.push( + makeUnfrozenCopyOfFieldDecl(decl, changedNodeIds), + ); + break; + } + case "constant_def": { + newDeclarations.push( + makeUnfrozenCopyOfConstantDef(decl, changedNodeIds), + ); + break; + } + case "function_def": { + newDeclarations.push( + makeUnfrozenCopyOfFunctionDef(decl, changedNodeIds), + ); + break; + } + case "receiver": { + newDeclarations.push( + makeUnfrozenCopyOfReceiver(decl, changedNodeIds), + ); + break; + } + default: + throwInternalCompilerError( + "Unrecognized AstTrait declaration kind", + ); + } + } + + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.declarations = newDeclarations; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfContract( + ast: AstContract, + changedNodeIds: Map, +): AstContract { + const newDeclarations: AstContractDeclaration[] = []; + + for (const decl of ast.declarations) { + switch (decl.kind) { + case "asm_function_def": { + // This kind is not changed by the optimizer + newDeclarations.push(decl); + break; + } + case "field_decl": { + newDeclarations.push( + makeUnfrozenCopyOfFieldDecl(decl, changedNodeIds), + ); + break; + } + case "constant_def": { + newDeclarations.push( + makeUnfrozenCopyOfConstantDef(decl, changedNodeIds), + ); + break; + } + case "function_def": { + newDeclarations.push( + makeUnfrozenCopyOfFunctionDef(decl, changedNodeIds), + ); + break; + } + case "receiver": { + newDeclarations.push( + makeUnfrozenCopyOfReceiver(decl, changedNodeIds), + ); + break; + } + case "contract_init": { + newDeclarations.push( + makeUnfrozenCopyOfContractInit(decl, changedNodeIds), + ); + break; + } + default: + throwInternalCompilerError( + "Unrecognized AstContract declaration kind", + ); + } + } + + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.declarations = newDeclarations; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfFieldDecl( + ast: AstFieldDecl, + changedNodeIds: Map, +): AstFieldDecl { + if (ast.initializer === null) { + // If there is no initializer expression, + // just use the original node because there is nothing to change + return ast; + } + + const newInitializer = makeUnfrozenCopyOfExpression( + ast.initializer, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.initializer = newInitializer; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfReceiver( + ast: AstReceiver, + changedNodeIds: Map, +): AstReceiver { + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfContractInit( + ast: AstContractInit, + changedNodeIds: Map, +): AstContractInit { + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfStatement( + ast: AstStatement, + changedNodeIds: Map, +): AstStatement { + switch (ast.kind) { + case "statement_assign": { + return makeUnfrozenCopyOfAssign(ast, changedNodeIds); + } + case "statement_augmentedassign": { + return makeUnfrozenCopyOfAugmentedAssign(ast, changedNodeIds); + } + case "statement_expression": { + return makeUnfrozenCopyOfStatementExpression(ast, changedNodeIds); + } + case "statement_let": { + return makeUnfrozenCopyOfLet(ast, changedNodeIds); + } + case "statement_destruct": { + return makeUnfrozenCopyOfDestruct(ast, changedNodeIds); + } + case "statement_return": { + return makeUnfrozenCopyOfReturn(ast, changedNodeIds); + } + case "statement_until": { + return makeUnfrozenCopyOfUntil(ast, changedNodeIds); + } + case "statement_while": { + return makeUnfrozenCopyOfWhile(ast, changedNodeIds); + } + case "statement_repeat": { + return makeUnfrozenCopyOfRepeat(ast, changedNodeIds); + } + case "statement_foreach": { + return makeUnfrozenCopyOfForEach(ast, changedNodeIds); + } + case "statement_condition": { + return makeUnfrozenCopyOfCondition(ast, changedNodeIds); + } + case "statement_try": { + return makeUnfrozenCopyOfTry(ast, changedNodeIds); + } + case "statement_try_catch": { + return makeUnfrozenCopyOfTryCatch(ast, changedNodeIds); + } + default: + throwInternalCompilerError("Unrecognized AstStatement kind"); + } +} + +function makeUnfrozenCopyOfAssign( + ast: AstStatementAssign, + changedNodeIds: Map, +): AstStatementAssign { + const newExpr = makeUnfrozenCopyOfExpression( + ast.expression, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.expression = newExpr; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfAugmentedAssign( + ast: AstStatementAugmentedAssign, + changedNodeIds: Map, +): AstStatementAugmentedAssign { + const newExpr = makeUnfrozenCopyOfExpression( + ast.expression, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.expression = newExpr; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfStatementExpression( + ast: AstStatementExpression, + changedNodeIds: Map, +): AstStatementExpression { + const newExpr = makeUnfrozenCopyOfExpression( + ast.expression, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.expression = newExpr; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfLet( + ast: AstStatementLet, + changedNodeIds: Map, +): AstStatementLet { + const newExpr = makeUnfrozenCopyOfExpression( + ast.expression, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.expression = newExpr; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfDestruct( + ast: AstStatementDestruct, + changedNodeIds: Map, +): AstStatementDestruct { + const newExpr = makeUnfrozenCopyOfExpression( + ast.expression, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.expression = newExpr; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfReturn( + ast: AstStatementReturn, + changedNodeIds: Map, +): AstStatementReturn { + if (ast.expression === null) { + return ast; + } + const newExpr = makeUnfrozenCopyOfExpression( + ast.expression, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.expression = newExpr; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfUntil( + ast: AstStatementUntil, + changedNodeIds: Map, +): AstStatementUntil { + const newCondition = makeUnfrozenCopyOfExpression( + ast.condition, + changedNodeIds, + ); + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.condition = newCondition; + newNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfWhile( + ast: AstStatementWhile, + changedNodeIds: Map, +): AstStatementWhile { + const newCondition = makeUnfrozenCopyOfExpression( + ast.condition, + changedNodeIds, + ); + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.condition = newCondition; + newNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfRepeat( + ast: AstStatementRepeat, + changedNodeIds: Map, +): AstStatementRepeat { + const newIterations = makeUnfrozenCopyOfExpression( + ast.iterations, + changedNodeIds, + ); + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.iterations = newIterations; + newNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfForEach( + ast: AstStatementForEach, + changedNodeIds: Map, +): AstStatementForEach { + const newMap = makeUnfrozenCopyOfExpression(ast.map, changedNodeIds); + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.map = newMap; + newNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfCondition( + ast: AstCondition, + changedNodeIds: Map, +): AstCondition { + const newCondition = makeUnfrozenCopyOfExpression( + ast.condition, + changedNodeIds, + ); + const newTrueStatements = ast.trueStatements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newFalseStatements = + ast.falseStatements !== null + ? ast.falseStatements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ) + : null; + const newElseIf = + ast.elseif !== null + ? makeUnfrozenCopyOfCondition(ast.elseif, changedNodeIds) + : null; + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.condition = newCondition; + newNode.trueStatements = newTrueStatements; + newNode.falseStatements = newFalseStatements; + newNode.elseif = newElseIf; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfTry( + ast: AstStatementTry, + changedNodeIds: Map, +): AstStatementTry { + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.statements = newStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfTryCatch( + ast: AstStatementTryCatch, + changedNodeIds: Map, +): AstStatementTryCatch { + const newStatements = ast.statements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newCatchStatements = ast.catchStatements.map((stmt) => + makeUnfrozenCopyOfStatement(stmt, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.statements = newStatements; + newNode.catchStatements = newCatchStatements; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfExpression( + ast: AstExpression, + changedNodeIds: Map, +): AstExpression { + switch (ast.kind) { + case "id": + case "null": + case "boolean": + case "number": + case "string": + // These leaf nodes are never changed. + return ast; + case "struct_instance": + return makeUnfrozenCopyOfStructInstance(ast, changedNodeIds); + case "field_access": + return makeUnfrozenCopyOfFieldAccess(ast, changedNodeIds); + case "method_call": + return makeUnfrozenCopyOfMethodCall(ast, changedNodeIds); + case "static_call": + return makeUnfrozenCopyOfStaticCall(ast, changedNodeIds); + case "op_unary": + return makeUnfrozenCopyOfUnaryOp(ast, changedNodeIds); + case "op_binary": + return makeUnfrozenCopyOfBinaryOp(ast, changedNodeIds); + case "init_of": + return makeUnfrozenCopyOfInitOf(ast, changedNodeIds); + case "conditional": + return makeUnfrozenCopyOfConditional(ast, changedNodeIds); + default: + throwInternalCompilerError("Unrecognized AstExpression kind"); + } +} + +function makeUnfrozenCopyOfStructInstance( + ast: AstStructInstance, + changedNodeIds: Map, +): AstStructInstance { + const newArgs = ast.args.map((initializer) => + makeUnfrozenCopyOfFieldInitializer(initializer, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.args = newArgs; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfFieldAccess( + ast: AstFieldAccess, + changedNodeIds: Map, +): AstFieldAccess { + const newAggregate = makeUnfrozenCopyOfExpression( + ast.aggregate, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.aggregate = newAggregate; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfMethodCall( + ast: AstMethodCall, + changedNodeIds: Map, +): AstMethodCall { + const newArgs = ast.args.map((expr) => + makeUnfrozenCopyOfExpression(expr, changedNodeIds), + ); + const newSelf = makeUnfrozenCopyOfExpression(ast.self, changedNodeIds); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.args = newArgs; + newNode.self = newSelf; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfStaticCall( + ast: AstStaticCall, + changedNodeIds: Map, +): AstStaticCall { + const newArgs = ast.args.map((expr) => + makeUnfrozenCopyOfExpression(expr, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.args = newArgs; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfUnaryOp( + ast: AstOpUnary, + changedNodeIds: Map, +): AstOpUnary { + const newOperand = makeUnfrozenCopyOfExpression( + ast.operand, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.operand = newOperand; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfBinaryOp( + ast: AstOpBinary, + changedNodeIds: Map, +): AstOpBinary { + const newLeft = makeUnfrozenCopyOfExpression(ast.left, changedNodeIds); + const newRight = makeUnfrozenCopyOfExpression(ast.right, changedNodeIds); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.left = newLeft; + newNode.right = newRight; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfInitOf( + ast: AstInitOf, + changedNodeIds: Map, +): AstInitOf { + const newArgs = ast.args.map((expr) => + makeUnfrozenCopyOfExpression(expr, changedNodeIds), + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.args = newArgs; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfConditional( + ast: AstConditional, + changedNodeIds: Map, +): AstConditional { + const newCondition = makeUnfrozenCopyOfExpression( + ast.condition, + changedNodeIds, + ); + const newThen = makeUnfrozenCopyOfExpression( + ast.thenBranch, + changedNodeIds, + ); + const newElse = makeUnfrozenCopyOfExpression( + ast.elseBranch, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.condition = newCondition; + newNode.thenBranch = newThen; + newNode.elseBranch = newElse; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +function makeUnfrozenCopyOfFieldInitializer( + ast: AstStructFieldInitializer, + changedNodeIds: Map, +): AstStructFieldInitializer { + const newInitializer = makeUnfrozenCopyOfExpression( + ast.initializer, + changedNodeIds, + ); + const newNode = cloneAstNode(ast); + // The rest of properties will not be touched by the optimizer. + newNode.initializer = newInitializer; + // Remember the ID of the new node + changedNodeIds.set(ast.id, newNode); + return newNode; +} + +export function updateCompilerContext( + optCtx: OptimizationContext, +): CompilerContext { + processStaticFunctions(optCtx.ctx, optCtx.nodeReplacements); + + processStaticConstants(optCtx.ctx, optCtx.nodeReplacements); + + processTypes(optCtx.ctx, optCtx.nodeReplacements); + + return optCtx.ctx; +} + +export function dumpTactCode(ast: AstModule, file: string) { + const program = prettyPrint(ast); + + void writeFile(file, program); +} + +function buildOptimizationContext( + ast: AstModule, + newAst: AstModule, + changedIds: Map, + ctx: CompilerContext, +): OptimizationContext { + // Build inverse map + const originalIDs: Map = new Map(); + + for (const [id, node] of changedIds) { + originalIDs.set(node.id, id); + } + + // To maintain consistency with types in the CompilerContext, register the + // types of all newly created expressions in changedIds. + for (const [id, node] of changedIds) { + if (isAstExpression(node)) { + ctx = registerExpType(ctx, node, getExpTypeById(ctx, id)); + } + } + + return { + originalAst: ast, + modifiedAst: newAst, + nodeReplacements: changedIds, + originalIDs: originalIDs, + ctx: ctx, + }; +} + +export function registerAstNodeChange( + optCtx: OptimizationContext, + nodeToReplace: AstMutableNode, + newNode: AstMutableNode, +) { + const idToReplace = nodeToReplace.id; + + // Is the idToReplace already a replacement of an original ID? + if (optCtx.originalIDs.has(idToReplace)) { + // Obtain the original ID + const originalID = optCtx.originalIDs.get(idToReplace)!; + // Now replace the original node + optCtx.nodeReplacements.set(originalID, newNode); + // Update the inverse map + optCtx.originalIDs.set(newNode.id, originalID); + } else { + // idToReplace is an original node + optCtx.nodeReplacements.set(idToReplace, newNode); + // Update the inverse map + optCtx.originalIDs.set(newNode.id, idToReplace); + } +} + +function processStaticFunctions( + ctx: CompilerContext, + nodeReplacements: Map, +) { + for (const f of getAllStaticFunctions(ctx)) { + if (nodeReplacements.has(f.ast.id)) { + f.ast = ensureFunction(nodeReplacements.get(f.ast.id)!); + } + } +} + +function processStaticConstants( + ctx: CompilerContext, + nodeReplacements: Map, +) { + for (const c of getAllStaticConstants(ctx)) { + if (nodeReplacements.has(c.ast.id)) { + c.ast = ensureConstant(nodeReplacements.get(c.ast.id)!); + } + } +} + +function processTypeInitDescription( + t: TypeDescription, + nodeReplacements: Map, +) { + if (t.init === null) { + return; + } + + if (nodeReplacements.has(t.init.ast.id)) { + t.init.ast = ensureContractInit(nodeReplacements.get(t.init.ast.id)!); + } +} + +function processTypeConstantDescriptions( + t: TypeDescription, + nodeReplacements: Map, +) { + for (const c of t.constants) { + if (nodeReplacements.has(c.ast.id)) { + c.ast = ensureConstant(nodeReplacements.get(c.ast.id)!); + } + } +} + +function processTypeFieldDescriptions( + t: TypeDescription, + nodeReplacements: Map, +) { + for (const f of t.fields) { + if (nodeReplacements.has(f.ast.id)) { + f.ast = ensureFieldDecl(nodeReplacements.get(f.ast.id)!); + } + } +} + +function processTypeReceiverDescriptions( + t: TypeDescription, + nodeReplacements: Map, +) { + for (const r of t.receivers) { + if (nodeReplacements.has(r.ast.id)) { + r.ast = ensureReceiver(nodeReplacements.get(r.ast.id)!); + } + } +} + +function processTypeFunctionDescriptions( + t: TypeDescription, + nodeReplacements: Map, +) { + for (const [_, m] of t.functions) { + if (nodeReplacements.has(m.ast.id)) { + m.ast = ensureFunction(nodeReplacements.get(m.ast.id)!); + } + } +} + +function processTypes( + ctx: CompilerContext, + nodeReplacements: Map, +) { + /** + * By calling the function getAllTypes on the context object "ctx", one gets an array of TypeDescriptions. + * Each TypeDescription stores the type declarations in two different ways: + * - Directly in the TypeDescription object there are fields, constants, and method + * declarations. However, these declarations are "coalesced" in the following sense: + * If the TypeDescription is a contract, it will contain copies of methods, constants and fields of traits that the + * contract inherits from. Similarly, each trait will have declarations of other traits + * that the trait inherits from. + * + * For example, if we look into the "functions" property of the TypeDescription object of a contract + * we will find functions defined in BaseTrait. + * + * - Indirectly in the "ast" property of the TypeDescription. Contrary to the previous case, + * the fields, constants and methods in the ast property are NOT coalesced. This means, for example, + * that the methods in a TypeDescription's ast of a contract will be methods that are actually + * declared in the contract and not in some trait that the contract inherits from. + */ + + for (const t of getAllTypes(ctx)) { + // First, process all the coalesced data + + processTypeInitDescription(t, nodeReplacements); + + processTypeConstantDescriptions(t, nodeReplacements); + + processTypeFieldDescriptions(t, nodeReplacements); + + processTypeReceiverDescriptions(t, nodeReplacements); + + processTypeFunctionDescriptions(t, nodeReplacements); + + // Now, the non-coalesced data, which is simply the changed node in nodeReplacements + if (nodeReplacements.has(t.ast.id)) { + t.ast = ensureTypeDecl(nodeReplacements.get(t.ast.id)!); + } + } +} + +function ensureFunction(ast: AstMutableNode): AstFunctionDef { + // Type AstMutableNode restricts the possibilities of the + // function type to AstFunctionDef + if (ast.kind === "function_def") { + return ast; + } else { + throwInternalCompilerError(`kind ${ast.kind} is not a function kind`); + } +} + +function ensureConstant(ast: AstMutableNode): AstConstantDef { + // Type AstMutableNode restricts the possibilities of the + // constant type to AstConstantDef + if (ast.kind === "constant_def") { + return ast; + } else { + throwInternalCompilerError(`kind ${ast.kind} is not a constant kind`); + } +} + +function ensureContractInit(ast: AstMutableNode): AstContractInit { + if (ast.kind === "contract_init") { + return ast; + } else { + throwInternalCompilerError( + `kind ${ast.kind} is not a contract initialization method`, + ); + } +} + +function ensureFieldDecl(ast: AstMutableNode): AstFieldDecl { + if (ast.kind === "field_decl") { + return ast; + } else { + throwInternalCompilerError( + `kind ${ast.kind} is not a field declaration`, + ); + } +} + +function ensureReceiver(ast: AstMutableNode): AstReceiver { + if (ast.kind === "receiver") { + return ast; + } else { + throwInternalCompilerError(`kind ${ast.kind} is not a receiver`); + } +} + +function ensureTypeDecl(ast: AstMutableNode): AstTypeDecl { + switch (ast.kind) { + case "contract": + case "message_decl": + case "primitive_type_decl": + case "struct_decl": + case "trait": + return ast; + default: + throwInternalCompilerError( + `kind ${ast.kind} is not a type declaration`, + ); + } +} diff --git a/src/optimizer/optimization_phase.ts b/src/optimizer/optimization_phase.ts deleted file mode 100644 index 37df6f4e5..000000000 --- a/src/optimizer/optimization_phase.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { CompilerContext } from "../context"; -import { prettyPrint } from "../prettyPrinter"; -import { - getAllStaticConstants, - getAllStaticFunctions, - getAllTypes, -} from "../types/resolveDescriptors"; -import { simplify_expressions } from "./expr_simplification"; -import { writeFileSync } from "fs"; - -export function optimize_tact(ctx: CompilerContext): CompilerContext { - // Call the expression simplification phase - ctx = simplify_expressions(ctx); - - // Here, we will call the constant propagation analyzer - - return ctx; -} - -export function dump_tact_code(ctx: CompilerContext, file: string) { - let program = ""; - - for (const c of getAllStaticConstants(ctx)) { - program += `${prettyPrint(c.ast)}\n`; - } - - program += "\n"; - - for (const f of getAllStaticFunctions(ctx)) { - program += `${prettyPrint(f.ast)}\n\n`; - } - - for (const t of getAllTypes(ctx)) { - program += `${prettyPrint(t.ast)}\n\n`; - } - - writeFileSync(file, program, { - flag: "w", - }); -} diff --git a/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap b/src/optimizer/test/__snapshots__/expr-simplification.spec.ts.snap similarity index 71% rename from src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap rename to src/optimizer/test/__snapshots__/expr-simplification.spec.ts.snap index 158bd8c9f..8362a83f2 100644 --- a/src/optimizer/test/__snapshots__/expr_simplification.spec.ts.snap +++ b/src/optimizer/test/__snapshots__/expr-simplification.spec.ts.snap @@ -140,6 +140,82 @@ exports[`expression-simplification should pass expression simplification for int "exprFun2(exprFun1(2)).c || exprFun1(c1).a > 0", "Bool", ], + [ + "v + 10", + "Int", + ], + [ + "v + 10 + 3", + "Int", + ], + [ + "3 + 7", + "Int", + ], + [ + "A {a: v + 10 + 3, b: 3 + 7}", + "A", + ], + [ + "B {nested: s, c: true}", + "B", + ], + [ + "exprFun1(0)", + "A", + ], + [ + "exprFun1(0).a", + "Int", + ], + [ + "exprFun1(2)", + "A", + ], + [ + "exprFun2(exprFun1(2))", + "B", + ], + [ + "exprFun1(1)", + "A", + ], + [ + "exprFun2(exprFun1(1))", + "B", + ], + [ + "exprFun2(exprFun1(1)).nested", + "A", + ], + [ + "exprFun1(2)", + "A", + ], + [ + "exprFun2(exprFun1(2))", + "B", + ], + [ + "exprFun2(exprFun1(2)).c", + "Bool", + ], + [ + "exprFun1(c1)", + "A", + ], + [ + "exprFun1(c1).a", + "Int", + ], + [ + "exprFun1(c1).a > 0", + "Bool", + ], + [ + "exprFun2(exprFun1(2)).c || exprFun1(c1).a > 0", + "Bool", + ], [ "13", "Int", diff --git a/src/optimizer/test/expr_simplification.spec.ts b/src/optimizer/test/expr-simplification.spec.ts similarity index 78% rename from src/optimizer/test/expr_simplification.spec.ts rename to src/optimizer/test/expr-simplification.spec.ts index e344d8992..0b58606f0 100644 --- a/src/optimizer/test/expr_simplification.spec.ts +++ b/src/optimizer/test/expr-simplification.spec.ts @@ -6,7 +6,8 @@ import { resolveDescriptors } from "../../types/resolveDescriptors"; import { getAllExpressionTypes } from "../../types/resolveExpression"; import { resolveStatements } from "../../types/resolveStatements"; import { loadCases } from "../../utils/loadCases"; -import { simplify_expressions } from "../expr_simplification"; +import { simplifyAllExpressions } from "../expr-simplification"; +import { prepareAstForOptimization } from "../optimization-phase"; describe("expression-simplification", () => { beforeEach(() => { @@ -22,8 +23,9 @@ describe("expression-simplification", () => { ctx = featureEnable(ctx, "external"); ctx = resolveDescriptors(ctx); ctx = resolveStatements(ctx); - ctx = simplify_expressions(ctx); - expect(getAllExpressionTypes(ctx)).toMatchSnapshot(); + const optCtx = prepareAstForOptimization(ctx, true); + simplifyAllExpressions(optCtx); + expect(getAllExpressionTypes(optCtx.ctx)).toMatchSnapshot(); }); } for (const r of loadCases(__dirname + "/failed/")) { @@ -36,8 +38,9 @@ describe("expression-simplification", () => { ctx = featureEnable(ctx, "external"); ctx = resolveDescriptors(ctx); ctx = resolveStatements(ctx); + const optCtx = prepareAstForOptimization(ctx, true); expect(() => { - simplify_expressions(ctx); + simplifyAllExpressions(optCtx); }).toThrowErrorMatchingSnapshot(); }); } diff --git a/src/optimizer/util.ts b/src/optimizer/util.ts index d8d5810a3..fab4c6674 100644 --- a/src/optimizer/util.ts +++ b/src/optimizer/util.ts @@ -10,9 +10,11 @@ import { idText, } from "../grammar/ast"; import { dummySrcInfo, SrcInfo } from "../grammar/grammar"; -import { throwInternalCompilerError } from "../errors"; import { StructValue, Value } from "../types/types"; +export class UnsupportedOperation extends Error {} + +// TODO: This method will disappear once Value type is refactored: issue #1190 export function extractValue(ast: AstValue): Value { switch (ast.kind) { case "null": @@ -36,6 +38,7 @@ export function extractValue(ast: AstValue): Value { } } +// TODO: This method will disappear once Value type is refactored: issue #1190 export function makeValueExpression( value: Value, baseSrc: SrcInfo = dummySrcInfo, @@ -100,7 +103,8 @@ export function makeValueExpression( }); return result as AstValue; } - throwInternalCompilerError( + // TODO: These types will be included once the Value type gets refactored. + throw new UnsupportedOperation( "addresses, cells, slices, and comment values are not supported as AST nodes at the moment.", ); } @@ -127,6 +131,7 @@ function valueToString(value: Value): string { if (typeof value === "boolean") { return value.toString(); } + // TODO: This code will change once Value type gets refactored: issue 1190 if (typeof value === "object" && "$tactStruct" in value) { const fields = Object.entries(value) .filter(([name, _]) => name !== "$tactStruct") @@ -136,7 +141,8 @@ function valueToString(value: Value): string { .join(", "); return `${value["$tactStruct"] as string} { ${fields} }`; } - throwInternalCompilerError( + // TODO: These types will be included once the Value type gets refactored: issue 1190 + throw new UnsupportedOperation( "Transformation of addresses, cells, slices or comment values into strings is not supported at the moment.", ); } diff --git a/src/pipeline/build.ts b/src/pipeline/build.ts index c1a5f5426..a3fc59353 100644 --- a/src/pipeline/build.ts +++ b/src/pipeline/build.ts @@ -20,8 +20,13 @@ import { compile } from "./compile"; import { precompile } from "./precompile"; import { getCompilerVersion } from "./version"; import { idText } from "../grammar/ast"; -import { TactErrorCollection } from "../errors"; -import { dump_tact_code, optimize_tact } from "../optimizer/optimization_phase"; +import { TactErrorCollection, throwInternalCompilerError } from "../errors"; +import { + dumpTactCode, + optimizeTact, + prepareAstForOptimization, + updateCompilerContext, +} from "../optimizer/optimization-phase"; export function enableFeatures( ctx: CompilerContext, @@ -87,37 +92,46 @@ export async function build(args: { return { ok: true, error: [] }; } - // Run high level optimization phase, if active in the options. - if ( + // Prepare ast for optimization phase (true = do it, false = skip it) + const doOptimizationFlag = config.options?.skipTactOptimizationPhase === undefined || - !config.options.skipTactOptimizationPhase - ) { - try { - if (config.options?.dumpOptimizedTactCode) { - // Dump the code before optimization - dump_tact_code( - ctx, - config.output + - `/${config.name}_unoptimized_tact_dump.tact`, - ); - } + !config.options.skipTactOptimizationPhase; - ctx = optimize_tact(ctx); + const optimizationCtx = prepareAstForOptimization(ctx, doOptimizationFlag); - if (config.options?.dumpOptimizedTactCode) { - // Dump the code after optimization - dump_tact_code( - ctx, - config.output + `/${config.name}_optimized_tact_dump.tact`, - ); - } + // Dump the code before optimization phase + if (config.options?.dumpCodeBeforeAndAfterTactOptimizationPhase) { + dumpTactCode( + optimizationCtx.originalAst, + config.output + `/${config.name}-unoptimized-tact-dump.tact`, + ); + } + + // Run high level optimization phase + if (doOptimizationFlag) { + try { + optimizeTact(optimizationCtx); + ctx = updateCompilerContext(optimizationCtx); } catch (e) { - logger.error("Tact code optimization failed."); - logger.error(e as Error); - return { ok: false, error: [e as Error] }; + // TODO: e is not an Error in general. Change interface of logger. + if (e instanceof Error) { + logger.error("Tact code optimization failed."); + logger.error(e); + return { ok: false, error: [e] }; + } else { + throwInternalCompilerError("Not an instance of Error"); + } } } + // Dump the code after optimization phase + if (config.options?.dumpCodeBeforeAndAfterTactOptimizationPhase) { + dumpTactCode( + optimizationCtx.modifiedAst, + config.output + `/${config.name}-optimized-tact-dump.tact`, + ); + } + // Compile contracts let ok = true; const errorMessages: TactErrorCollection[] = []; diff --git a/src/types/resolveDescriptors.ts b/src/types/resolveDescriptors.ts index 291511a97..557c83fb5 100644 --- a/src/types/resolveDescriptors.ts +++ b/src/types/resolveDescriptors.ts @@ -2008,36 +2008,6 @@ export function getStaticFunction( return r; } -export function replaceStaticConstants( - ctx: CompilerContext, - newConstants: Map, -): CompilerContext { - for (const [name, constDesc] of newConstants) { - ctx = staticConstantsStore.set(ctx, name, constDesc); - } - return ctx; -} - -export function replaceStaticFunctions( - ctx: CompilerContext, - newFunctions: Map, -): CompilerContext { - for (const [name, funcDesc] of newFunctions) { - ctx = staticFunctionsStore.set(ctx, name, funcDesc); - } - return ctx; -} - -export function replaceTypes( - ctx: CompilerContext, - newTypes: Map, -): CompilerContext { - for (const [name, typeDesc] of newTypes) { - ctx = store.set(ctx, name, typeDesc); - } - return ctx; -} - export function hasStaticFunction(ctx: CompilerContext, name: string) { return !!staticFunctionsStore.get(ctx, name); } diff --git a/src/types/resolveExpression.ts b/src/types/resolveExpression.ts index 9ebd01708..a26c4f907 100644 --- a/src/types/resolveExpression.ts +++ b/src/types/resolveExpression.ts @@ -53,6 +53,14 @@ export function getExpType(ctx: CompilerContext, exp: AstExpression) { return t.description; } +export function getExpTypeById(ctx: CompilerContext, id: number) { + const t = store.get(ctx, id); + if (!t) { + throwInternalCompilerError(`Expression ${id} not found`); + } + return t.description; +} + export function registerExpType( ctx: CompilerContext, exp: AstExpression,