diff --git a/.idea/.gitignore b/.idea/.gitignore index 13566b81b..a9d7db9c0 100644 --- a/.idea/.gitignore +++ b/.idea/.gitignore @@ -6,3 +6,5 @@ # Datasource local storage ignored files /dataSources/ /dataSources.local.xml +# GitHub Copilot persisted chat sessions +/copilot/chatSessions diff --git a/.idea/codeStyles/Project.xml b/.idea/codeStyles/Project.xml index 919ce1f1f..edd5ef344 100644 --- a/.idea/codeStyles/Project.xml +++ b/.idea/codeStyles/Project.xml @@ -3,5 +3,288 @@ + + + + + GETTERS_AND_SETTERS + KEEP + + + OVERRIDDEN_METHODS + KEEP + + + +
+ + + + true + true + true + true + + + +
+
+ + + + true + true + true + true + + + +
+
+ + + + true + true + true + true + + + +
+
+ + + + true + true + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + true + + + +
+
+ + + + true + true + + + +
+
+ + + + true + true + + + +
+
+ + + + true + true + + + +
+
+ + + + true + true + + + +
+
+ + + true + + +
+
+ + + true + + +
+
+ + + true + + +
+
+ + + + true + true + + + BY_NAME + +
+
+ + + + true + true + true + + + +
+
+ + + true + + +
+
+ + + true + + +
+
+ + + true + + +
+
+ + + + true + true + + + +
+
+ + + true + + +
+
+
+
\ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 7ff59b18f..9cd683e93 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -57,7 +57,7 @@ - + \ No newline at end of file diff --git a/pom.xml b/pom.xml index 7fe4aa71d..ff0249201 100644 --- a/pom.xml +++ b/pom.xml @@ -67,7 +67,7 @@ + at their actual locations (bundled are overridden on recompile). --> @@ -196,19 +196,20 @@ - - com.github.spotbugs - spotbugs-maven-plugin - 4.8.2.0 - - - - - check - - - - + + + + + + + + + + + + + + org.apache.maven.plugins diff --git a/src/main/java/org/prlprg/RSession.java b/src/main/java/org/prlprg/RSession.java index 1d289b656..fcd6ba5e6 100644 --- a/src/main/java/org/prlprg/RSession.java +++ b/src/main/java/org/prlprg/RSession.java @@ -2,8 +2,11 @@ import org.prlprg.sexp.BaseEnvSXP; import org.prlprg.sexp.GlobalEnvSXP; +import org.prlprg.sexp.NamespaceEnvSXP; public interface RSession { + NamespaceEnvSXP baseNamespace(); + BaseEnvSXP baseEnv(); GlobalEnvSXP globalEnv(); @@ -13,4 +16,6 @@ public interface RSession { boolean isSpecial(String name); boolean isBuiltinInternal(String name); + + NamespaceEnvSXP getNamespace(String name, String version); } diff --git a/src/main/java/org/prlprg/bc/Bc.java b/src/main/java/org/prlprg/bc/Bc.java index b1170ded8..577a7521a 100644 --- a/src/main/java/org/prlprg/bc/Bc.java +++ b/src/main/java/org/prlprg/bc/Bc.java @@ -2,6 +2,7 @@ import com.google.common.primitives.ImmutableIntArray; import java.util.*; +import java.util.function.Function; import javax.annotation.Nullable; import org.prlprg.primitive.Constants; import org.prlprg.sexp.IntSXP; @@ -13,27 +14,13 @@ * constants. */ public record Bc(BcCode code, ConstPool consts) { + /** * The only version of R bytecodes we support, which is also the latest version. The bytecode's * version is denoted by the first integer in its code. */ public static final int R_BC_VERSION = 12; - /** - * Create from the raw GNU-R representation, bytecodes not including the initial version number. - */ - public static Bc fromRaw(ImmutableIntArray bytecodes, List consts) - throws BcFromRawException { - var poolAndMakeIdx = ConstPool.fromRaw(consts); - var pool = poolAndMakeIdx.first(); - var makePoolIdx = poolAndMakeIdx.second(); - try { - return new Bc(BcCode.fromRaw(bytecodes, makePoolIdx), pool); - } catch (BcFromRawException e) { - throw new BcFromRawException("malformed bytecode\nConstants: " + pool, e); - } - } - @Override public String toString() { return code + "\n" + consts; @@ -56,6 +43,7 @@ public static class Builder { private @Nullable IntSXP currentSrcRef = null; private boolean trackSrcRefs = true; private boolean trackExpressions = true; + private final Map> patches = new LinkedHashMap<>(); public void setTrackSrcRefs(boolean track) { this.trackSrcRefs = track; @@ -66,27 +54,30 @@ public void setTrackExpressions(boolean track) { } /** Append a constant and return its index. */ - public ConstPool.TypedIdx addConst(S c) { + public ConstPool.Idx addConst(S c) { return consts.add(c); } /** Append an instruction. */ - public void addInstr(BcInstr instr) { + public int addInstr(BcInstr instr) { + var idx = code.size(); code.add(instr); if (trackExpressions) { assert currentExpr != null; for (var i = 0; i <= instr.op().nArgs(); i++) { - expressions.add(addConst(currentExpr).idx); + expressions.add(addConst(currentExpr).idx()); } } if (trackSrcRefs) { assert currentSrcRef != null; for (var i = 0; i <= instr.op().nArgs(); i++) { - srcRefs.add(addConst(currentSrcRef).idx); + srcRefs.add(addConst(currentSrcRef).idx()); } } + + return idx; } public BcLabel makeLabel() { @@ -105,6 +96,10 @@ void patchLabel(BcLabel label) { * @return The bytecode. */ public Bc build() { + // this is the cb$patchlabels() + patches.forEach(code::patch); + + // this is the cb$commitlocs() if (trackExpressions) { var expressionsIndex = SEXPs.integer(expressions.build()).withClass("expressionsIndex"); addConst(expressionsIndex); @@ -131,5 +126,9 @@ public void setCurrentLoc(Loc loc) { public Loc getCurrentLoc() { return new Loc(trackExpressions ? currentExpr : null, trackSrcRefs ? currentSrcRef : null); } + + public void addInstrPatch(int instrIdx, Function patch) { + patches.put(instrIdx, patch); + } } } diff --git a/src/main/java/org/prlprg/bc/BcCode.java b/src/main/java/org/prlprg/bc/BcCode.java index 2b502ddad..b315142b8 100644 --- a/src/main/java/org/prlprg/bc/BcCode.java +++ b/src/main/java/org/prlprg/bc/BcCode.java @@ -2,11 +2,12 @@ import com.google.common.collect.ForwardingList; import com.google.common.collect.ImmutableList; -import com.google.common.primitives.ImmutableIntArray; import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Objects; +import java.util.function.Function; import javax.annotation.concurrent.Immutable; /** @@ -26,69 +27,6 @@ protected List delegate() { return code; } - /** - * Create from the raw GNU-R representation, not including the initial version number. - * - * @param makePoolIdx A function to create pool indices from raw integers - */ - static BcCode fromRaw(ImmutableIntArray bytecodes, ConstPool.MakeIdx makePoolIdx) - throws BcFromRawException { - if (bytecodes.isEmpty()) { - throw new BcFromRawException("Bytecode is empty, needs at least version number"); - } - if (bytecodes.get(0) != Bc.R_BC_VERSION) { - throw new BcFromRawException("Unsupported bytecode version: " + bytecodes.get(0)); - } - - var labelMap = labelFactoryFromRaw(bytecodes); - - var builder = new Builder(); - int i = 1; - int sanityCheckJ = 0; - while (i < bytecodes.length()) { - try { - var instrAndI = BcInstrs.fromRaw(bytecodes, i, labelMap, makePoolIdx); - var instr = instrAndI.first(); - i = instrAndI.second(); - - builder.add(instr); - sanityCheckJ++; - - try { - var sanityCheckJFromI = labelMap.make(i).getTarget(); - if (sanityCheckJFromI != sanityCheckJ) { - throw new AssertionError( - "expected target offset " + sanityCheckJ + ", got " + sanityCheckJFromI); - } - } catch (IllegalArgumentException | AssertionError e) { - throw new AssertionError( - "BcInstrs.fromRaw and BcInstrs.sizeFromRaw are out of sync, at instruction " + instr, - e); - } - } catch (BcFromRawException e) { - throw new BcFromRawException( - "malformed bytecode at " + i + "\nBytecode up to this point: " + builder.build(), e); - } - } - return builder.build(); - } - - static BcLabel.Factory labelFactoryFromRaw(ImmutableIntArray bytecodes) { - var builder = new BcLabel.Factory.Builder(); - int i = 1; - while (i < bytecodes.length()) { - try { - var size = BcInstrs.sizeFromRaw(bytecodes, i); - builder.step(size, 1); - i += size; - } catch (BcFromRawException e) { - throw new BcFromRawException( - "malformed bytecode at " + i + "\nBytecode up to this point: " + builder.build(), e); - } - } - return builder.build(); - } - @Override public String toString() { StringBuilder sb = new StringBuilder("=== CODE ==="); @@ -99,6 +37,20 @@ public String toString() { return sb.toString(); } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + BcCode bcInstrs = (BcCode) o; + return Objects.equals(code, bcInstrs.code); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), code); + } + /** * A builder class for creating BcArray instances. * @@ -129,5 +81,10 @@ public BcCode build() { public int size() { return code.size(); } + + public void patch(int idx, Function patch) { + assert (idx >= 0 && idx < code.size()); + code.set(idx, patch.apply(code.get(idx))); + } } } diff --git a/src/main/java/org/prlprg/bc/BcFromRawException.java b/src/main/java/org/prlprg/bc/BcFromRawException.java deleted file mode 100644 index 792dea894..000000000 --- a/src/main/java/org/prlprg/bc/BcFromRawException.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.prlprg.bc; - -/** - * Exception thrown when a raw byte-array and constant list can't be converted to typed bytecode - * ({@link Bc}). - */ -public class BcFromRawException extends RuntimeException { - BcFromRawException(String message) { - super(message); - } - - BcFromRawException(String message, Throwable cause) { - super(message, cause); - } -} diff --git a/src/main/java/org/prlprg/bc/BcInstr.java b/src/main/java/org/prlprg/bc/BcInstr.java index 5785dfeac..b1bfae32c 100644 --- a/src/main/java/org/prlprg/bc/BcInstr.java +++ b/src/main/java/org/prlprg/bc/BcInstr.java @@ -1,11 +1,8 @@ package org.prlprg.bc; -import com.google.common.primitives.ImmutableIntArray; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import javax.annotation.Nullable; import org.prlprg.sexp.*; -import org.prlprg.util.Either; -import org.prlprg.util.Pair; /** * A single bytecode instruction, consists of an operation and arguments. The operation is @@ -31,7 +28,7 @@ public BcOp op() { } } - record BrIfNot(ConstPool.TypedIdx ast, BcLabel label) implements BcInstr { + record BrIfNot(ConstPool.Idx ast, BcLabel label) implements BcInstr { @Override public BcOp op() { return BcOp.BRIFNOT; @@ -87,8 +84,7 @@ public BcOp op() { } } - record StartFor( - ConstPool.TypedIdx ast, ConstPool.TypedIdx elemName, BcLabel end) + record StartFor(ConstPool.Idx ast, ConstPool.Idx elemName, BcLabel end) implements BcInstr { @Override public BcOp op() { @@ -124,7 +120,7 @@ public BcOp op() { } } - record LdConst(ConstPool.Idx constant) implements BcInstr { + record LdConst(ConstPool.Idx constant) implements BcInstr { @Override public BcOp op() { return BcOp.LDCONST; @@ -152,56 +148,56 @@ public BcOp op() { } } - record GetVar(ConstPool.TypedIdx name) implements BcInstr { + record GetVar(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.GETVAR; } } - record DdVal(ConstPool.TypedIdx name) implements BcInstr { + record DdVal(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.DDVAL; } } - record SetVar(ConstPool.TypedIdx name) implements BcInstr { + record SetVar(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.SETVAR; } } - record GetFun(ConstPool.TypedIdx name) implements BcInstr { + record GetFun(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.GETFUN; } } - record GetGlobFun(ConstPool.TypedIdx name) implements BcInstr { + record GetGlobFun(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.GETGLOBFUN; } } - record GetSymFun(ConstPool.TypedIdx name) implements BcInstr { + record GetSymFun(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.GETSYMFUN; } } - record GetBuiltin(ConstPool.TypedIdx name) implements BcInstr { + record GetBuiltin(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.GETBUILTIN; } } - record GetIntlBuiltin(ConstPool.TypedIdx name) implements BcInstr { + record GetIntlBuiltin(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.GETINTLBUILTIN; @@ -216,7 +212,7 @@ public BcOp op() { } /** {@code code} is usually but not always bytecode (see eval.c). */ - record MakeProm(ConstPool.Idx code) implements BcInstr { + record MakeProm(ConstPool.Idx code) implements BcInstr { @Override public BcOp op() { return BcOp.MAKEPROM; @@ -230,7 +226,7 @@ public BcOp op() { } } - record SetTag(@Nullable ConstPool.TypedIdx tag) implements BcInstr { + record SetTag(@Nullable ConstPool.Idx tag) implements BcInstr { @Override public BcOp op() { return BcOp.SETTAG; @@ -251,7 +247,7 @@ public BcOp op() { } } - record PushConstArg(ConstPool.Idx constant) implements BcInstr { + record PushConstArg(ConstPool.Idx constant) implements BcInstr { @Override public BcOp op() { return BcOp.PUSHCONSTARG; @@ -279,28 +275,28 @@ public BcOp op() { } } - record Call(ConstPool.TypedIdx ast) implements BcInstr { + record Call(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.CALL; } } - record CallBuiltin(ConstPool.TypedIdx ast) implements BcInstr { + record CallBuiltin(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.CALLBUILTIN; } } - record CallSpecial(ConstPool.TypedIdx ast) implements BcInstr { + record CallSpecial(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.CALLSPECIAL; } } - record MakeClosure(ConstPool.TypedIdx arg) implements BcInstr { + record MakeClosure(ConstPool.Idx arg) implements BcInstr { public ListSXP formals(ConstPool pool) { return (ListSXP) pool.get(this.arg).get(0); } @@ -320,126 +316,126 @@ public BcOp op() { } } - record UMinus(ConstPool.TypedIdx ast) implements BcInstr { + record UMinus(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.UMINUS; } } - record UPlus(ConstPool.TypedIdx ast) implements BcInstr { + record UPlus(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.UPLUS; } } - record Add(ConstPool.TypedIdx ast) implements BcInstr { + record Add(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.ADD; } } - record Sub(ConstPool.TypedIdx ast) implements BcInstr { + record Sub(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.SUB; } } - record Mul(ConstPool.TypedIdx ast) implements BcInstr { + record Mul(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.MUL; } } - record Div(ConstPool.TypedIdx ast) implements BcInstr { + record Div(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.DIV; } } - record Expt(ConstPool.TypedIdx ast) implements BcInstr { + record Expt(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.EXPT; } } - record Sqrt(ConstPool.TypedIdx ast) implements BcInstr { + record Sqrt(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.SQRT; } } - record Exp(ConstPool.TypedIdx ast) implements BcInstr { + record Exp(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.EXP; } } - record Eq(ConstPool.TypedIdx ast) implements BcInstr { + record Eq(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.EQ; } } - record Ne(ConstPool.TypedIdx ast) implements BcInstr { + record Ne(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.NE; } } - record Lt(ConstPool.TypedIdx ast) implements BcInstr { + record Lt(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.LT; } } - record Le(ConstPool.TypedIdx ast) implements BcInstr { + record Le(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.LE; } } - record Ge(ConstPool.TypedIdx ast) implements BcInstr { + record Ge(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.GE; } } - record Gt(ConstPool.TypedIdx ast) implements BcInstr { + record Gt(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.GT; } } - record And(ConstPool.TypedIdx ast) implements BcInstr { + record And(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.AND; } } - record Or(ConstPool.TypedIdx ast) implements BcInstr { + record Or(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.OR; } } - record Not(ConstPool.TypedIdx ast) implements BcInstr { + record Not(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.NOT; @@ -453,21 +449,21 @@ public BcOp op() { } } - record StartAssign(ConstPool.TypedIdx name) implements BcInstr { + record StartAssign(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.STARTASSIGN; } } - record EndAssign(ConstPool.TypedIdx name) implements BcInstr { + record EndAssign(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.ENDASSIGN; } } - record StartSubset(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubset(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBSET; @@ -481,7 +477,7 @@ public BcOp op() { } } - record StartSubassign(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubassign(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBASSIGN; @@ -495,7 +491,7 @@ public BcOp op() { } } - record StartC(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartC(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTC; @@ -509,7 +505,7 @@ public BcOp op() { } } - record StartSubset2(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubset2(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBSET2; @@ -523,7 +519,7 @@ public BcOp op() { } } - record StartSubassign2(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubassign2(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBASSIGN2; @@ -537,15 +533,14 @@ public BcOp op() { } } - record Dollar(ConstPool.TypedIdx ast, ConstPool.TypedIdx member) - implements BcInstr { + record Dollar(ConstPool.Idx ast, ConstPool.Idx member) implements BcInstr { @Override public BcOp op() { return BcOp.DOLLAR; } } - record DollarGets(ConstPool.TypedIdx ast, ConstPool.TypedIdx member) + record DollarGets(ConstPool.Idx ast, ConstPool.Idx member) implements BcInstr { @Override public BcOp op() { @@ -616,73 +611,74 @@ public BcOp op() { } } - // ???: call-idx can be negative? We make TypedIdx null to support this case, but not sure if + // ???: call-idx can be negative? We make TypedIdx null to support this case, + // but not sure if // it's possible. - // This applies to every other `@Nullable call` in this file. - record VecSubset(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + // This applies to every other `@Nullable call` in this file. + record VecSubset(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.VECSUBSET; } } - record MatSubset(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + record MatSubset(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.MATSUBSET; } } - record VecSubassign(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + record VecSubassign(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.VECSUBASSIGN; } } - record MatSubassign(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + record MatSubassign(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.MATSUBASSIGN; } } - record And1st(ConstPool.TypedIdx ast, BcLabel shortCircuit) implements BcInstr { + record And1st(ConstPool.Idx ast, BcLabel shortCircuit) implements BcInstr { @Override public BcOp op() { return BcOp.AND1ST; } } - record And2nd(ConstPool.TypedIdx ast) implements BcInstr { + record And2nd(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.AND2ND; } } - record Or1st(ConstPool.TypedIdx ast, BcLabel shortCircuit) implements BcInstr { + record Or1st(ConstPool.Idx ast, BcLabel shortCircuit) implements BcInstr { @Override public BcOp op() { return BcOp.OR1ST; } } - record Or2nd(ConstPool.TypedIdx ast) implements BcInstr { + record Or2nd(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.OR2ND; } } - record GetVarMissOk(ConstPool.TypedIdx name) implements BcInstr { + record GetVarMissOk(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.GETVAR_MISSOK; } } - record DdValMissOk(ConstPool.TypedIdx name) implements BcInstr { + record DdValMissOk(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.DDVAL_MISSOK; @@ -696,35 +692,35 @@ public BcOp op() { } } - record SetVar2(ConstPool.TypedIdx name) implements BcInstr { + record SetVar2(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.SETVAR2; } } - record StartAssign2(ConstPool.TypedIdx name) implements BcInstr { + record StartAssign2(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.STARTASSIGN2; } } - record EndAssign2(ConstPool.TypedIdx name) implements BcInstr { + record EndAssign2(ConstPool.Idx name) implements BcInstr { @Override public BcOp op() { return BcOp.ENDASSIGN2; } } - record SetterCall(ConstPool.TypedIdx ast, ConstPool.Idx valueExpr) implements BcInstr { + record SetterCall(ConstPool.Idx ast, ConstPool.Idx valueExpr) implements BcInstr { @Override public BcOp op() { return BcOp.SETTER_CALL; } } - record GetterCall(ConstPool.TypedIdx ast) implements BcInstr { + record GetterCall(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.GETTER_CALL; @@ -746,11 +742,19 @@ public BcOp op() { } } + /** + * The OP_SWITCH instruction. + * + * @param ast + * @param names {@code null} represents {@code NilSXP} + * @param chrLabelsIdx {@code null} represents {@code NilSXP} + * @param numLabelsIdx {@code null} represents {@code NilSXP} + */ record Switch( - ConstPool.TypedIdx ast, - @Nullable Either, ConstPool.TypedIdx> names, - @Nullable ConstPool.TypedIdx cOffsets, - @Nullable ConstPool.TypedIdx iOffsets) + ConstPool.Idx ast, + @Nullable ConstPool.Idx names, + @Nullable ConstPool.Idx chrLabelsIdx, + @Nullable ConstPool.Idx numLabelsIdx) implements BcInstr { @Override public BcOp op() { @@ -765,140 +769,140 @@ public BcOp op() { } } - record StartSubsetN(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubsetN(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBSET_N; } } - record StartSubassignN(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubassignN(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBASSIGN_N; } } - record VecSubset2(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + record VecSubset2(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.VECSUBSET2; } } - record MatSubset2(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + record MatSubset2(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.MATSUBSET2; } } - record VecSubassign2(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + record VecSubassign2(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.VECSUBASSIGN2; } } - record MatSubassign2(@Nullable ConstPool.TypedIdx ast) implements BcInstr { + record MatSubassign2(@Nullable ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.MATSUBASSIGN2; } } - record StartSubset2N(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubset2N(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBSET2_N; } } - record StartSubassign2N(ConstPool.TypedIdx ast, BcLabel after) implements BcInstr { + record StartSubassign2N(ConstPool.Idx ast, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.STARTSUBASSIGN2_N; } } - record SubsetN(@Nullable ConstPool.TypedIdx ast, int n) implements BcInstr { + record SubsetN(@Nullable ConstPool.Idx ast, int n) implements BcInstr { @Override public BcOp op() { return BcOp.SUBSET_N; } } - record Subset2N(@Nullable ConstPool.TypedIdx ast, int n) implements BcInstr { + record Subset2N(@Nullable ConstPool.Idx ast, int n) implements BcInstr { @Override public BcOp op() { return BcOp.SUBSET2_N; } } - record SubassignN(@Nullable ConstPool.TypedIdx ast, int n) implements BcInstr { + record SubassignN(@Nullable ConstPool.Idx ast, int n) implements BcInstr { @Override public BcOp op() { return BcOp.SUBASSIGN_N; } } - record Subassign2N(@Nullable ConstPool.TypedIdx ast, int n) implements BcInstr { + record Subassign2N(@Nullable ConstPool.Idx ast, int n) implements BcInstr { @Override public BcOp op() { return BcOp.SUBASSIGN2_N; } } - record Log(ConstPool.TypedIdx ast) implements BcInstr { + record Log(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.LOG; } } - record LogBase(ConstPool.TypedIdx ast) implements BcInstr { + record LogBase(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.LOGBASE; } } - record Math1(ConstPool.TypedIdx ast, int funId) implements BcInstr { + record Math1(ConstPool.Idx ast, int funId) implements BcInstr { @Override public BcOp op() { return BcOp.MATH1; } } - record DotCall(ConstPool.TypedIdx ast, int numArgs) implements BcInstr { + record DotCall(ConstPool.Idx ast, int numArgs) implements BcInstr { @Override public BcOp op() { return BcOp.DOTCALL; } } - record Colon(ConstPool.TypedIdx ast) implements BcInstr { + record Colon(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.COLON; } } - record SeqAlong(ConstPool.TypedIdx ast) implements BcInstr { + record SeqAlong(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.SEQALONG; } } - record SeqLen(ConstPool.TypedIdx ast) implements BcInstr { + record SeqLen(ConstPool.Idx ast) implements BcInstr { @Override public BcOp op() { return BcOp.SEQLEN; } } - record BaseGuard(ConstPool.TypedIdx expr, BcLabel after) implements BcInstr { + record BaseGuard(ConstPool.Idx expr, BcLabel after) implements BcInstr { @Override public BcOp op() { return BcOp.BASEGUARD; @@ -940,256 +944,3 @@ public BcOp op() { } } } - -class BcInstrs { - /** - * Create from the raw GNU-R representation. - * - * @param bytecodes The full list of GNU-R bytecodes including ones before and after this one. - * @param i The index in the list where this instruction starts. - * @param Label So we can create labels from GNU-R labels. - * @param makePoolIdx A function to create pool indices from raw integers. - * @return The instruction and the index in the list where the next instruction starts. - * @apiNote This has to be in a separate class because it's package-private but interface methods - * are public. - */ - static Pair fromRaw( - ImmutableIntArray bytecodes, int i, BcLabel.Factory Label, ConstPool.MakeIdx makePoolIdx) { - BcOp op; - try { - op = BcOp.valueOf(bytecodes.get(i++)); - } catch (IllegalArgumentException e) { - throw new BcFromRawException("invalid opcode (instruction) at " + bytecodes.get(i - 1)); - } - - try { - var instr = - switch (op) { - case BCMISMATCH -> - throw new BcFromRawException("invalid opcode " + BcOp.BCMISMATCH.value()); - case RETURN -> new BcInstr.Return(); - case GOTO -> new BcInstr.Goto(Label.make(bytecodes.get(i++))); - case BRIFNOT -> - new BcInstr.BrIfNot( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case POP -> new BcInstr.Pop(); - case DUP -> new BcInstr.Dup(); - case PRINTVALUE -> new BcInstr.PrintValue(); - case STARTLOOPCNTXT -> - new BcInstr.StartLoopCntxt(bytecodes.get(i++) != 0, Label.make(bytecodes.get(i++))); - case ENDLOOPCNTXT -> new BcInstr.EndLoopCntxt(bytecodes.get(i++) != 0); - case DOLOOPNEXT -> new BcInstr.DoLoopNext(); - case DOLOOPBREAK -> new BcInstr.DoLoopBreak(); - case STARTFOR -> - new BcInstr.StartFor( - makePoolIdx.lang(bytecodes.get(i++)), - makePoolIdx.sym(bytecodes.get(i++)), - Label.make(bytecodes.get(i++))); - case STEPFOR -> new BcInstr.StepFor(Label.make(bytecodes.get(i++))); - case ENDFOR -> new BcInstr.EndFor(); - case SETLOOPVAL -> new BcInstr.SetLoopVal(); - case INVISIBLE -> new BcInstr.Invisible(); - case LDCONST -> new BcInstr.LdConst(makePoolIdx.any(bytecodes.get(i++))); - case LDNULL -> new BcInstr.LdNull(); - case LDTRUE -> new BcInstr.LdTrue(); - case LDFALSE -> new BcInstr.LdFalse(); - case GETVAR -> new BcInstr.GetVar(makePoolIdx.sym(bytecodes.get(i++))); - case DDVAL -> new BcInstr.DdVal(makePoolIdx.sym(bytecodes.get(i++))); - case SETVAR -> new BcInstr.SetVar(makePoolIdx.sym(bytecodes.get(i++))); - case GETFUN -> new BcInstr.GetFun(makePoolIdx.sym(bytecodes.get(i++))); - case GETGLOBFUN -> new BcInstr.GetGlobFun(makePoolIdx.sym(bytecodes.get(i++))); - case GETSYMFUN -> new BcInstr.GetSymFun(makePoolIdx.sym(bytecodes.get(i++))); - case GETBUILTIN -> new BcInstr.GetBuiltin(makePoolIdx.sym(bytecodes.get(i++))); - case GETINTLBUILTIN -> new BcInstr.GetIntlBuiltin(makePoolIdx.sym(bytecodes.get(i++))); - case CHECKFUN -> new BcInstr.CheckFun(); - case MAKEPROM -> new BcInstr.MakeProm(makePoolIdx.any(bytecodes.get(i++))); - case DOMISSING -> new BcInstr.DoMissing(); - case SETTAG -> new BcInstr.SetTag(makePoolIdx.strOrSymOrNil(bytecodes.get(i++))); - case DODOTS -> new BcInstr.DoDots(); - case PUSHARG -> new BcInstr.PushArg(); - case PUSHCONSTARG -> new BcInstr.PushConstArg(makePoolIdx.any(bytecodes.get(i++))); - case PUSHNULLARG -> new BcInstr.PushNullArg(); - case PUSHTRUEARG -> new BcInstr.PushTrueArg(); - case PUSHFALSEARG -> new BcInstr.PushFalseArg(); - case CALL -> new BcInstr.Call(makePoolIdx.lang(bytecodes.get(i++))); - case CALLBUILTIN -> new BcInstr.CallBuiltin(makePoolIdx.lang(bytecodes.get(i++))); - case CALLSPECIAL -> new BcInstr.CallSpecial(makePoolIdx.lang(bytecodes.get(i++))); - case MAKECLOSURE -> - new BcInstr.MakeClosure(makePoolIdx.formalsBodyAndMaybeSrcRef(bytecodes.get(i++))); - case UMINUS -> new BcInstr.UMinus(makePoolIdx.lang(bytecodes.get(i++))); - case UPLUS -> new BcInstr.UPlus(makePoolIdx.lang(bytecodes.get(i++))); - case ADD -> new BcInstr.Add(makePoolIdx.lang(bytecodes.get(i++))); - case SUB -> new BcInstr.Sub(makePoolIdx.lang(bytecodes.get(i++))); - case MUL -> new BcInstr.Mul(makePoolIdx.lang(bytecodes.get(i++))); - case DIV -> new BcInstr.Div(makePoolIdx.lang(bytecodes.get(i++))); - case EXPT -> new BcInstr.Expt(makePoolIdx.lang(bytecodes.get(i++))); - case SQRT -> new BcInstr.Sqrt(makePoolIdx.lang(bytecodes.get(i++))); - case EXP -> new BcInstr.Exp(makePoolIdx.lang(bytecodes.get(i++))); - case EQ -> new BcInstr.Eq(makePoolIdx.lang(bytecodes.get(i++))); - case NE -> new BcInstr.Ne(makePoolIdx.lang(bytecodes.get(i++))); - case LT -> new BcInstr.Lt(makePoolIdx.lang(bytecodes.get(i++))); - case LE -> new BcInstr.Le(makePoolIdx.lang(bytecodes.get(i++))); - case GE -> new BcInstr.Ge(makePoolIdx.lang(bytecodes.get(i++))); - case GT -> new BcInstr.Gt(makePoolIdx.lang(bytecodes.get(i++))); - case AND -> new BcInstr.And(makePoolIdx.lang(bytecodes.get(i++))); - case OR -> new BcInstr.Or(makePoolIdx.lang(bytecodes.get(i++))); - case NOT -> new BcInstr.Not(makePoolIdx.lang(bytecodes.get(i++))); - case DOTSERR -> new BcInstr.DotsErr(); - case STARTASSIGN -> new BcInstr.StartAssign(makePoolIdx.sym(bytecodes.get(i++))); - case ENDASSIGN -> new BcInstr.EndAssign(makePoolIdx.sym(bytecodes.get(i++))); - case STARTSUBSET -> - new BcInstr.StartSubset( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case DFLTSUBSET -> new BcInstr.DfltSubset(); - case STARTSUBASSIGN -> - new BcInstr.StartSubassign( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case DFLTSUBASSIGN -> new BcInstr.DfltSubassign(); - case STARTC -> - new BcInstr.StartC( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case DFLTC -> new BcInstr.DfltC(); - case STARTSUBSET2 -> - new BcInstr.StartSubset2( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case DFLTSUBSET2 -> new BcInstr.DfltSubset2(); - case STARTSUBASSIGN2 -> - new BcInstr.StartSubassign2( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case DFLTSUBASSIGN2 -> new BcInstr.DfltSubassign2(); - case DOLLAR -> - new BcInstr.Dollar( - makePoolIdx.lang(bytecodes.get(i++)), makePoolIdx.sym(bytecodes.get(i++))); - case DOLLARGETS -> - new BcInstr.DollarGets( - makePoolIdx.lang(bytecodes.get(i++)), makePoolIdx.sym(bytecodes.get(i++))); - case ISNULL -> new BcInstr.IsNull(); - case ISLOGICAL -> new BcInstr.IsLogical(); - case ISINTEGER -> new BcInstr.IsInteger(); - case ISDOUBLE -> new BcInstr.IsDouble(); - case ISCOMPLEX -> new BcInstr.IsComplex(); - case ISCHARACTER -> new BcInstr.IsCharacter(); - case ISSYMBOL -> new BcInstr.IsSymbol(); - case ISOBJECT -> new BcInstr.IsObject(); - case ISNUMERIC -> new BcInstr.IsNumeric(); - case VECSUBSET -> new BcInstr.VecSubset(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case MATSUBSET -> new BcInstr.MatSubset(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case VECSUBASSIGN -> - new BcInstr.VecSubassign(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case MATSUBASSIGN -> - new BcInstr.MatSubassign(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case AND1ST -> - new BcInstr.And1st( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case AND2ND -> new BcInstr.And2nd(makePoolIdx.lang(bytecodes.get(i++))); - case OR1ST -> - new BcInstr.Or1st( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case OR2ND -> new BcInstr.Or2nd(makePoolIdx.lang(bytecodes.get(i++))); - case GETVAR_MISSOK -> new BcInstr.GetVarMissOk(makePoolIdx.sym(bytecodes.get(i++))); - case DDVAL_MISSOK -> new BcInstr.DdValMissOk(makePoolIdx.sym(bytecodes.get(i++))); - case VISIBLE -> new BcInstr.Visible(); - case SETVAR2 -> new BcInstr.SetVar2(makePoolIdx.sym(bytecodes.get(i++))); - case STARTASSIGN2 -> new BcInstr.StartAssign2(makePoolIdx.sym(bytecodes.get(i++))); - case ENDASSIGN2 -> new BcInstr.EndAssign2(makePoolIdx.sym(bytecodes.get(i++))); - case SETTER_CALL -> - new BcInstr.SetterCall( - makePoolIdx.lang(bytecodes.get(i++)), makePoolIdx.any(bytecodes.get(i++))); - case GETTER_CALL -> new BcInstr.GetterCall(makePoolIdx.lang(bytecodes.get(i++))); - case SWAP -> new BcInstr.SpecialSwap(); - case DUP2ND -> new BcInstr.Dup2nd(); - case SWITCH -> - new BcInstr.Switch( - makePoolIdx.lang(bytecodes.get(i++)), - makePoolIdx.strOrNilOrOther(bytecodes.get(i++)), - makePoolIdx.intOrOther(bytecodes.get(i++)), - makePoolIdx.intOrOther(bytecodes.get(i++))); - case RETURNJMP -> new BcInstr.ReturnJmp(); - case STARTSUBSET_N -> - new BcInstr.StartSubsetN( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case STARTSUBASSIGN_N -> - new BcInstr.StartSubassignN( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case VECSUBSET2 -> - new BcInstr.VecSubset2(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case MATSUBSET2 -> - new BcInstr.MatSubset2(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case VECSUBASSIGN2 -> - new BcInstr.VecSubassign2(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case MATSUBASSIGN2 -> - new BcInstr.MatSubassign2(makePoolIdx.langOrNegative(bytecodes.get(i++))); - case STARTSUBSET2_N -> - new BcInstr.StartSubset2N( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case STARTSUBASSIGN2_N -> - new BcInstr.StartSubassign2N( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case SUBSET_N -> - new BcInstr.SubsetN( - makePoolIdx.langOrNegative(bytecodes.get(i++)), bytecodes.get(i++)); - case SUBSET2_N -> - new BcInstr.Subset2N( - makePoolIdx.langOrNegative(bytecodes.get(i++)), bytecodes.get(i++)); - case SUBASSIGN_N -> - new BcInstr.SubassignN( - makePoolIdx.langOrNegative(bytecodes.get(i++)), bytecodes.get(i++)); - case SUBASSIGN2_N -> - new BcInstr.Subassign2N( - makePoolIdx.langOrNegative(bytecodes.get(i++)), bytecodes.get(i++)); - case LOG -> new BcInstr.Log(makePoolIdx.lang(bytecodes.get(i++))); - case LOGBASE -> new BcInstr.LogBase(makePoolIdx.lang(bytecodes.get(i++))); - case MATH1 -> - new BcInstr.Math1(makePoolIdx.lang(bytecodes.get(i++)), bytecodes.get(i++)); - case DOTCALL -> - new BcInstr.DotCall(makePoolIdx.lang(bytecodes.get(i++)), bytecodes.get(i++)); - case COLON -> new BcInstr.Colon(makePoolIdx.lang(bytecodes.get(i++))); - case SEQALONG -> new BcInstr.SeqAlong(makePoolIdx.lang(bytecodes.get(i++))); - case SEQLEN -> new BcInstr.SeqLen(makePoolIdx.lang(bytecodes.get(i++))); - case BASEGUARD -> - new BcInstr.BaseGuard( - makePoolIdx.lang(bytecodes.get(i++)), Label.make(bytecodes.get(i++))); - case INCLNK -> new BcInstr.IncLnk(); - case DECLNK -> new BcInstr.DecLnk(); - case DECLNK_N -> new BcInstr.DeclnkN(bytecodes.get(i++)); - case INCLNKSTK -> new BcInstr.IncLnkStk(); - case DECLNKSTK -> new BcInstr.DecLnkStk(); - }; - return new Pair<>(instr, i); - } catch (IllegalArgumentException e) { - throw new BcFromRawException("invalid opcode " + op + " (arguments)", e); - } catch (ArrayIndexOutOfBoundsException e) { - throw new BcFromRawException( - "invalid opcode " + op + " (arguments, unexpected end of bytecode stream)"); - } - } - - /** - * Get the GNU-R size of the instruction at the position without creating it. - * - * @param bytecodes The full list of GNU-R bytecodes including ones before and after this one. - * @param i The index in the list where this instruction starts. - * @return The size of the instruction (we don't return next position because it can be computed - * from this). - * @apiNote This has to be in a separate class because it's package-private but interface methods - * are public. - */ - @SuppressWarnings({"DuplicateBranchesInSwitch", "DuplicatedCode"}) - static int sizeFromRaw(ImmutableIntArray bytecodes, int i) { - BcOp op; - try { - op = BcOp.valueOf(bytecodes.get(i++)); - } catch (IllegalArgumentException e) { - throw new BcFromRawException("invalid opcode (instruction) " + bytecodes.get(i - 1)); - } - - try { - return 1 + op.nArgs(); - } catch (IllegalArgumentException e) { - throw new BcFromRawException("invalid opcode (arguments) " + op, e); - } catch (ArrayIndexOutOfBoundsException e) { - throw new BcFromRawException( - "invalid opcode (arguments, unexpected end of bytecode stream) " + op); - } - } -} diff --git a/src/main/java/org/prlprg/bc/BcLabel.java b/src/main/java/org/prlprg/bc/BcLabel.java index b040cbd92..1724dade3 100644 --- a/src/main/java/org/prlprg/bc/BcLabel.java +++ b/src/main/java/org/prlprg/bc/BcLabel.java @@ -1,8 +1,6 @@ package org.prlprg.bc; import com.google.common.base.Objects; -import com.google.common.primitives.ImmutableIntArray; -import com.google.errorprone.annotations.CanIgnoreReturnValue; /** A branch instruction destination. */ public final class BcLabel { @@ -37,87 +35,4 @@ public int hashCode() { public String toString() { return "BcLabel(" + target + ')'; } - - /** - * Create labels from GNU-R labels. - * - * @implNote This contains a map of positions in GNU-R bytecode to positions in our bytecode. We - * need this because every index in our bytecode maps to an instruction, while indexes in - * GNU-R's bytecode also map to the bytecode version and instruction metadata. - */ - static class Factory { - private final ImmutableIntArray posMap; - - private Factory(ImmutableIntArray posMap) { - this.posMap = posMap; - } - - /** Create a label from a GNU-R label. */ - BcLabel make(int gnurLabel) { - if (gnurLabel == 0) { - throw new IllegalArgumentException("GNU-R label 0 is reserved for the version number"); - } - - var target = posMap.get(gnurLabel); - if (target == -1) { - var gnurEarlier = gnurLabel - 1; - int earlier; - do { - earlier = posMap.get(gnurEarlier); - } while (earlier == -1); - var gnurLater = gnurLabel + 1; - int later; - do { - later = posMap.get(gnurLater); - } while (later == -1); - throw new IllegalArgumentException( - "GNU-R position maps to the middle of one of our instructions: " - + gnurLabel - + " between " - + earlier - + " and " - + later); - } - return new BcLabel(target); - } - - /** - * Create an object which creates labels from GNU-R labels, by building the map of positions in - * GNU-R bytecode to positions in our bytecode (see {@link Factory} implNote). - */ - static class Builder { - private final ImmutableIntArray.Builder map = ImmutableIntArray.builder(); - private int targetPc = 0; - - Builder() { - // Add initial mapping of 1 -> 0 (version # is 0) - map.add(-1); - map.add(0); - } - - /** Step m times in the source bytecode and n times in the target bytecode */ - @CanIgnoreReturnValue - Builder step(int sourceOffset, @SuppressWarnings("SameParameterValue") int targetOffset) { - if (sourceOffset < 0 || targetOffset < 0) { - throw new IllegalArgumentException("offsets must be nonnegative"); - } - - targetPc += targetOffset; - // Offsets before sourceOffset map to the middle of the previous instruction - for (int i = 0; i < sourceOffset - 1; i++) { - map.add(-1); - } - // Add target position - if (sourceOffset > 0) { - map.add(targetPc); - } - - return this; - } - - Factory build() { - return new Factory(map.build()); - } - } - } } diff --git a/src/main/java/org/prlprg/bc/BcOp.java b/src/main/java/org/prlprg/bc/BcOp.java index b0ddf2d25..268f8a936 100644 --- a/src/main/java/org/prlprg/bc/BcOp.java +++ b/src/main/java/org/prlprg/bc/BcOp.java @@ -151,7 +151,7 @@ public int value() { @SuppressWarnings("DuplicateBranchesInSwitch") public int nArgs() { return switch (this) { - case BCMISMATCH -> throw new BcFromRawException("invalid opcode " + BCMISMATCH.value()); + case BCMISMATCH -> throw new IllegalStateException("invalid opcode " + BCMISMATCH.value()); case RETURN -> 0; case GOTO -> 1; case BRIFNOT -> 2; diff --git a/src/main/java/org/prlprg/bc/Compiler.java b/src/main/java/org/prlprg/bc/Compiler.java index e10cb9644..647e64b3a 100644 --- a/src/main/java/org/prlprg/bc/Compiler.java +++ b/src/main/java/org/prlprg/bc/Compiler.java @@ -1,31 +1,42 @@ package org.prlprg.bc; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; +import static org.prlprg.sexp.SEXPType.*; + +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.ImmutableIntArray; +import java.util.*; +import java.util.function.*; +import java.util.stream.IntStream; import javax.annotation.Nullable; import org.prlprg.RSession; import org.prlprg.bc.BcInstr.*; import org.prlprg.sexp.*; -import org.prlprg.util.NotImplementedError; - -// FIXME: use null instead of Optional (except for return types) -// FIXME: update the SEXP API based on the experience with this code -// - especially the clumsy ListSXP -// TODO: 11.4 Inlining simple .Internal functions -// TODO: 12 The switch function -// TODO: 13 Assignments expressions -// TODO: 16 Improved subset and sub-assignment handling -// TODO: simple interpreter for the constantFoldCode +/** + * The R bytecode compiler that aims to be byte-to-byte compatible with GNU-R's bytecode compiler. + * + *

The compiler follows the implementation of the GNU-R bytecode compiler as described in the R + * compiler package documentation [1]. It uses the Method Object pattern, i.e., in the constructor + * it is set what shall be compiled and the {@link #compile()} method is called to perform the + * compilation. + * + *

In the comments below the {@code >> } indicates comments takes directly form [1]. + * + *

[1] A Byte Code Compiler for R by Luke Tierney, University of Iowa, accessed on August 23, + * 2023 from https://homepage.cs.uiowa.edu/~luke/R/compiler/compiler.pdf + */ public class Compiler { + + /** SEXP types that can participate in constan folding. */ + private static final Set ALLOWED_FOLDABLE_MODES = Set.of(LGL, INT, REAL, CPLX, STR); + private static final Set MAYBE_NSE_SYMBOLS = Set.of("bquote"); - private static final Set ALLOWED_INLINES = + + /** + * List of functions that gets special treatment when considering inlining (cf. {@link + * #getInlineInfo(String, boolean)}). + */ + private static final Set LANGUAGE_FUNS = Set.of( "^", "~", @@ -74,8 +85,12 @@ public class Compiler { "return", "switch"); - // one-parameter functions evaluated by the math1 function in arithmetic.c - // the order is important + /** + * The list of one-parameter functions evaluated by the math1 function in arithmetic.c + * + *

NOTE that the order is important (and has to be the same as in R)! The MATH1_OP takes an + * index into an array of math functions. + */ private static final List MATH1_FUNS = List.of( "floor", @@ -103,46 +118,134 @@ public class Compiler { "sinpi", "tanpi"); + private static final Set SAFE_BASE_INTERNALS = + Set.of( + "atan2", + "besselY", + "beta", + "choose", + "drop", + "inherits", + "is.vector", + "lbeta", + "lchoose", + "nchar", + "polyroot", + "typeof", + "vector", + "which.max", + "which.min", + "is.loaded", + "identical", + "match", + "rep.int", + "rep_len"); + + /** Black list of functions that should not be inlined. */ private static final Set FORBIDDEN_INLINES = Set.of("standardGeneric"); + /** should match DOTCALL_MAX in eval.c */ + private static final int DOTCALL_MAX = 16; + + /** R limit for the max vector size to participate in constant folding */ private static final int MAX_CONST_SIZE = 10; - // I also did not know: - // > x <- function() TRUE - // > .Internal(inspect(body(x))) - // @5c0d69684e48 10 LGLSXP g0c1 [REF(2)] (len=1, tl=0) 1 - // > x <- function() T - // > .Internal(inspect(body(x))) - // @5c0d6769c910 01 SYMSXP g0c0 [MARK,REF(4),LCK,gp=0x4000] "T" (has value) + /** + * The set of constants that can be folded. + * + *

Note (I also did not know): + * + *


+   * > x <- function() TRUE
+   * > .Internal(inspect(body(x)))
+   * @5c0d69684e48 10 LGLSXP g0c1 [REF(2)] (len=1, tl=0) 1
+   * > x <- function() T
+   * > .Internal(inspect(body(x)))
+   * @5c0d6769c910 01 SYMSXP g0c0 [MARK,REF(4),LCK,gp=0x4000] "T" (has value)
+   * 
+ */ private static final Set ALLOWED_FOLDABLE_CONSTS = Set.of("pi", "T", "F"); - private static final Set ALLOWED_FOLDABLE_FUNS = Set.of(); + /** The set of functions that can be folded. */ + private static final Set ALLOWED_FOLDABLE_FUNS = + Set.of("c", "+", "*", "/", ":", "-", "^", "(", "log2", "log", "sqrt", "rep", "seq.int"); - // should match DOTCALL_MAX in eval.c - private static final int DOTCALL_MAX = 16; + /** + * The set of functions that if encounted in loop body allow us to stop the search whether loop + * context can be avoided (cf. {@link #canSkipLoopContext(SEXP, boolean)}) + */ + private static final Set LOOP_STOP_FUNS = Set.of("function", "for", "while", "repeat"); + + /** + * The set of functions that has to be handled recursively in the search for loop context (cf. + * {@link #canSkipLoopContext(SEXP, boolean)}) + */ + private static final Set LOOP_TOP_FUNS = Set.of("(", "{", "if"); + + /** + * The set of functions that indicates the need for loop context (cf. {@link + * #canSkipLoopContext(SEXP, boolean)}) + */ + private static final Set LOOP_BREAK_FUNS = Set.of("break", "next"); + + /** + * The set of functions that indicates the need for loop context (cf. {@link + * #canSkipLoopContext(SEXP, boolean)}) + */ + private static final Set EVAL_FUNS = Set.of("eval", "evalq", "source"); + /** The target for the byte code of this compiler. */ private final Bc.Builder cb = new Bc.Builder(); + /** Corresponding R session used to lookup symbols. */ private final RSession rsession; /** The initial expression to compile. */ private final SEXP expr; + /** The current compilation context. */ private Context ctx; - /* - * 0 - No inlining - * 1 - Functions in the base packages found through a namespace that are not shadowed by - * function arguments or visible local assignments may be inlined. - * 2 - In addition to the inlining permitted by Level 1, functions that are syntactically special - * or are considered core language functions and are found via the global environment at compile - * time may be inlined. Other functions in the base packages found via the global environment - * may be inlined with a guard that ensures at runtime that the inlined function has not been - * masked; if it has, then the call in handled by the AST interpreter. - * 3 - Any function in the base packages found via the global environment may be inlined. + /** + * The optimization level: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
LevelDescription (from compiler.R)
0No inlining
1Functions in the base packages found through a namespace that are not + * shadowed by function arguments or visible local assignments may be inlined
2In addition to the inlining permitted by Level 1, functions that are + * syntactically special or are considered core language functions and are found via the global + * environment at compile time may be inlined. Other functions in the base packages found via the + * global environment may be inlined with a guard that ensures at runtime that the inlined function + * has not been masked; if it has, then the call in handled by the AST interpreter.
3Any function in the base packages found via the global environment may be inlined.
*/ private int optimizationLevel = 2; + /** + * Creates a compiler for the given expression, context, session, and location. + * + * @param expr the expression to be compiled + * @param ctx the context used for the compilation + * @param rsession the session used for symbol resolution + * @param loc the source code location of the expression + */ private Compiler(SEXP expr, Context ctx, RSession rsession, Loc loc) { this.expr = expr; this.ctx = ctx; @@ -159,10 +262,23 @@ private Compiler(SEXP expr, Context ctx, RSession rsession, Loc loc) { cb.setCurrentLoc(loc); } + /** + * Creates a compiler for the given function. + * + * @param fun the function to be compiled + * @param rsession the session used for symbol resolution + */ public Compiler(CloSXP fun, RSession rsession) { - this(fun.body(), Context.functionContext(fun), rsession, functionLoc(fun)); + this(fun.bodyAST(), Context.functionContext(fun), rsession, functionLoc(fun)); } + /** + * Creates a clone of the current state of the compiler except for the expression and context. + * + * @param expr the new expression to compile + * @param ctx the new context for the compilation + * @param loc the source code location of the expression + */ private Compiler fork(SEXP expr, Context ctx, Loc loc) { var compiler = new Compiler(expr, ctx, rsession, loc); compiler.setOptimizationLevel(optimizationLevel); @@ -173,142 +289,149 @@ public void setOptimizationLevel(int level) { this.optimizationLevel = level; } - public Bc compile() { - cb.addConst(expr); - compile(expr, false); - return cb.build(); - } - - private static Loc functionLoc(CloSXP fun) { - var body = fun.body(); - - IntSXP srcRef; - if (!(body instanceof LangSXP b - && b.fun() instanceof RegSymSXP sym - && sym.name().equals("{"))) { // FIXME: ugly - // try to get the srcRef from the function itself - // normally, it would be attached to the `{` - srcRef = fun.getSrcRef(); - } else { - srcRef = extractSrcRef(body, 0); + /** + * Executes the compilation and returns the compiled code if the expression does not contain any + * calls to the browser function. + * + * @return the compiled code or empty if the expression contains a call to the browser function + */ + public Optional compile() { + if (mayCallBrowser(expr)) { + return Optional.empty(); } - return new Loc(body, srcRef); + return Optional.of(genCode()); + } + + private Bc genCode() { + cb.addConst(expr); + compile(expr, false, false); + return cb.build(); } private void compile(SEXP expr) { - compile(expr, true); + compile(expr, false, true); } - private void compile(SEXP expr, boolean setLoc) { + /** + * Compiles an expression. This is the entry point for the recursive compilation process. + * + * @param expr the expression to compile + * @param missingOK passing this flag to {@code compileSym} + * @param setLoc whether to set the current location from {@code expr} + */ + private void compile(SEXP expr, boolean missingOK, boolean setLoc) { Loc loc = null; if (setLoc) { loc = cb.getCurrentLoc(); - cb.setCurrentLoc(new Loc(expr, extractSrcRef(expr, 0))); + cb.setCurrentLoc(new Loc(expr, extractSrcRef(expr, 0).orElse(null))); } - constantFold(expr) - .ifPresentOrElse( - this::compileConst, - () -> { - switch (expr) { - case LangSXP e -> compileCall(e, true); - case RegSymSXP e -> compileSym(e, false); - case SpecialSymSXP e -> stop("unhandled special symbol: "); - case PromSXP ignored -> stop("cannot compile promise literals in code"); - case BCodeSXP ignored -> stop("cannot compile byte code literals in code"); - default -> compileConst(expr); - } - }); + constantFold(expr).ifPresentOrElse(this::compileConst, () -> compileNonConst(expr, missingOK)); if (loc != null) { cb.setCurrentLoc(loc); } } - private void compileSym(RegSymSXP e, boolean missingOk) { - if (e.isEllipsis()) { - // TODO: notifyWrongDotsUse - cb.addInstr(new DotsErr()); - } else if (e.isDdSym()) { - // TODO: if (!findLocVar("...")) - // notifyWrongDotsUse - var idx = cb.addConst(e); - cb.addInstr(missingOk ? new DdValMissOk(idx) : new DdVal(idx)); - checkTailCall(); - } else { - // TODO: if (!findVar(sym)) - // notifyUndefVar - var idx = cb.addConst(e); - cb.addInstr(missingOk ? new GetVarMissOk(idx) : new GetVar(idx)); - checkTailCall(); + private void compileConst(SEXP val) { + switch (val) { + case NilSXP ignored -> cb.addInstr(new LdNull()); + case LglSXP x when x == SEXPs.TRUE -> cb.addInstr(new LdTrue()); + case LglSXP x when x == SEXPs.FALSE -> cb.addInstr(new LdFalse()); + default -> cb.addInstr(new LdConst(cb.addConst(val))); } + + tailCallReturn(); } - @SuppressFBWarnings( - value = "DLS_DEAD_LOCAL_STORE", - justification = "False positive, probably because of ignored switch case") - private void compileConst(SEXP expr) { + private void compileNonConst(SEXP expr, boolean missingOK) { switch (expr) { - case NilSXP ignored -> cb.addInstr(new LdNull()); - case LglSXP x when x == SEXPs.TRUE -> cb.addInstr(new LdTrue()); - case LglSXP x when x == SEXPs.FALSE -> cb.addInstr(new LdFalse()); - default -> cb.addInstr(new LdConst(cb.addConst(expr))); + case LangSXP e -> compileCall(e, true); + case RegSymSXP e -> compileSym(e, missingOK); + case SpecialSymSXP e -> stop("unhandled special symbol: " + e); + case PromSXP ignored -> stop("cannot compile promise literals in code"); + case BCodeSXP ignored -> stop("cannot compile byte code literals in code"); + default -> compileConst(expr); } + } - checkTailCall(); + /** + * Compiles a symbol. + * + * @param sym the symbol to compile + * @param missingOK specifies whether to use DDLVAL_MISSOK or DDLVAL instruction + */ + private void compileSym(RegSymSXP sym, boolean missingOK) { + if (sym.isEllipsis()) { + // TODO: notifyWrongDotsUse + cb.addInstr(new DotsErr()); + } else if (sym.isDdSym()) { + // TODO: notifyWrongDotsUse + var idx = cb.addConst(sym); + cb.addInstr(missingOK ? new DdValMissOk(idx) : new DdVal(idx)); + tailCallReturn(); + } else { + // TODO: notifyUndefVar + var idx = cb.addConst(sym); + cb.addInstr(missingOK ? new GetVarMissOk(idx) : new GetVar(idx)); + tailCallReturn(); + } } private void compileCall(LangSXP call, boolean canInline) { var loc = cb.getCurrentLoc(); - cb.setCurrentLoc(new Loc(call, extractSrcRef(call, 0))); + cb.setCurrentLoc(new Loc(call, extractSrcRef(call, 0).orElse(null))); var args = call.args(); switch (call.fun()) { case RegSymSXP fun -> { if (!(canInline && tryInlineCall(fun, call))) { - // TODO: check call - compileCallSymFun(fun, args, call); + // TODO: notifyBadCall + compileCallSymFun(call, fun, args); } } case SpecialSymSXP fun -> throw new IllegalStateException("Trying to call special symbol: " + fun); - case LangSXP fun -> { - if (fun.fun() instanceof RegSymSXP sym && LOOP_BREAK_FUNS.contains(sym.name())) { - // From the R source code: - // ## **** this hack is needed for now because of the way the - // ## **** parser handles break() and next() calls - // Consequently, the RDSReader returns a LangSXP(LangSXP(break/next, NULL), NULL) for - // break() and next() calls - compile(fun); - } else { - compileCallExprFun(fun, args, call); - } - } + case LangSXP fun -> + fun.funName() + .filter(LOOP_BREAK_FUNS::contains) + .ifPresentOrElse( + ignored -> + // >> ## **** this hack is needed for now because of the way the + // >> ## **** parser handles break() and next() calls + // >> Consequently, the RDSReader returns a LangSXP(LangSXP(break/next, NULL), + // NULL) for + // >> break() and next() calls + compile(fun), + () -> compileCallExprFun(call, fun, args)); } cb.setCurrentLoc(loc); } - private void compileCallSymFun(RegSymSXP fun, ListSXP args, LangSXP call) { + private void compileCallSymFun(LangSXP call, RegSymSXP fun, ListSXP args) { cb.addInstr(new GetFun(cb.addConst(fun))); var nse = MAYBE_NSE_SYMBOLS.contains(fun.name()); compileArgs(args, nse); cb.addInstr(new Call(cb.addConst(call))); - checkTailCall(); + tailCallReturn(); } - private void compileCallExprFun(LangSXP fun, ListSXP args, LangSXP call) { + private void compileCallExprFun(LangSXP call, LangSXP fun, ListSXP args) { usingCtx(ctx.nonTailContext(), () -> compile(fun)); cb.addInstr(new CheckFun()); compileArgs(args, false); cb.addInstr(new Call(cb.addConst(call))); - checkTailCall(); + tailCallReturn(); } - @SuppressFBWarnings( - value = "DLS_DEAD_LOCAL_STORE", - justification = "False positive, probably because of ignored switch case") + /** + * Compiles the arguments of a function call. + * + * @param args the arguments to compile + * @param nse whether to compile the arguments in non-standard evaluation mode + */ private void compileArgs(ListSXP args, boolean nse) { for (var arg : args) { var tag = arg.tag(); @@ -320,8 +443,7 @@ private void compileArgs(ListSXP args, boolean nse) { compileTag(tag); } case SymSXP x when x.isEllipsis() -> - // TODO: if (!findLocVar("...")) - // notifyWrongDotsUse + // TODO: notifyWrongDotsUse cb.addInstr(new DoDots()); case SymSXP x -> { compileNormArg(x, nse); @@ -342,17 +464,18 @@ private void compileArgs(ListSXP args, boolean nse) { } private void compileNormArg(SEXP arg, boolean nse) { - if (!nse) { + ConstPool.Idx idx; + + if (nse) { + idx = cb.addConst(arg); + } else { var compiler = fork(arg, ctx.promiseContext(), cb.getCurrentLoc()); - var bc = compiler.compile(); - arg = SEXPs.bcode(bc); + idx = cb.addConst(SEXPs.bcode(compiler.genCode())); } - cb.addInstr(new MakeProm(cb.addConst(arg))); + + cb.addInstr(new MakeProm(idx)); } - @SuppressFBWarnings( - value = "DLS_DEAD_LOCAL_STORE", - justification = "False positive, probably because of ignored switch case") private void compileConstArg(SEXP arg) { switch (arg) { case NilSXP ignored -> cb.addInstr(new PushNullArg()); @@ -364,161 +487,252 @@ private void compileConstArg(SEXP arg) { private void compileTag(@Nullable String tag) { if (tag != null && !tag.isEmpty()) { - cb.addInstr(new SetTag(cb.addConst(SEXPs.string(tag)))); + cb.addInstr(new SetTag(cb.addConst(SEXPs.symbol(tag)))); } } - private void checkTailCall() { + /** A helper to add a return when the context is in tail position. */ + @SuppressWarnings("UnusedReturnValue") + private boolean tailCallReturn() { if (ctx.isTailCall()) { cb.addInstr(new Return()); - } - } - - private boolean tryInlineCall(RegSymSXP fun, LangSXP call) { - if (optimizationLevel == 0) { + return true; + } else { return false; } + } - // it seems that there is no way to pattern match on Optional - var binding = ctx.resolve(fun.name()).orElse(null); - if (binding == null) { - return false; - } - if (binding.first() instanceof BaseEnvSXP) { - return tryInlineBase(fun.name(), call, true); + /** A helper to add an invisible return when the context is in tail position. */ + private boolean tailCallInvisibleReturn() { + if (ctx.isTailCall()) { + cb.addInstr(new Invisible()); + cb.addInstr(new Return()); + return true; } else { return false; } } - private @Nullable Function getBaseInlineHandler(String name) { - // feels better this way than to pay the price to allocate the whole, big - // inlining table at class construction - return switch (name) { - case "{" -> this::inlineBlock; - case "if" -> this::inlineCondition; - case "function" -> this::inlineFunction; - case "(" -> this::inlineParentheses; - case "local" -> this::inlineLocal; - case "return" -> this::inlineReturn; - case ".Internal" -> this::inlineInternal; - case "&&" -> (c) -> inlineLogicalAndOr(c, true); - case "||" -> (c) -> inlineLogicalAndOr(c, false); - case "repeat" -> this::inlineRepeat; - case "break" -> (c) -> inlineBreakNext(c, true); - case "next" -> (c) -> inlineBreakNext(c, false); - case "while" -> this::inlineWhile; - case "for" -> this::inlineFor; - case "+" -> (c) -> inlineAddSub(c, true); - case "-" -> (c) -> inlineAddSub(c, false); - case "*" -> (c) -> inlinePrim2(c, Mul::new); - case "/" -> (c) -> inlinePrim2(c, Div::new); - case "^" -> (c) -> inlinePrim2(c, Expt::new); - case "exp" -> (c) -> inlinePrim1(c, Exp::new); - case "sqrt" -> (c) -> inlinePrim1(c, Sqrt::new); - case "log" -> this::inlineLog; - case "==" -> (c) -> inlinePrim2(c, Eq::new); - case "!=" -> (c) -> inlinePrim2(c, Ne::new); - case "<" -> (c) -> inlinePrim2(c, Lt::new); - case "<=" -> (c) -> inlinePrim2(c, Le::new); - case ">" -> (c) -> inlinePrim2(c, Gt::new); - case ">=" -> (c) -> inlinePrim2(c, Ge::new); - case "&" -> (c) -> inlinePrim2(c, And::new); - case "|" -> (c) -> inlinePrim2(c, Or::new); - case "!" -> (c) -> inlinePrim1(c, Not::new); - case "$" -> this::inlineDollar; - case "is.character" -> (c) -> inlineIsXyz(c, IsCharacter::new); - case "is.complex" -> (c) -> inlineIsXyz(c, IsComplex::new); - case "is.double" -> (c) -> inlineIsXyz(c, IsDouble::new); - case "is.integer" -> (c) -> inlineIsXyz(c, IsInteger::new); - case "is.logical" -> (c) -> inlineIsXyz(c, IsLogical::new); - case "is.name" -> (c) -> inlineIsXyz(c, IsSymbol::new); - case "is.null" -> (c) -> inlineIsXyz(c, IsNull::new); - case "is.object" -> (c) -> inlineIsXyz(c, IsObject::new); - case "is.symbol" -> (c) -> inlineIsXyz(c, IsSymbol::new); - case ".Call" -> this::inlineDotCall; - case ":" -> (c) -> inlinePrim2(c, Colon::new); - case "seq_along" -> (c) -> inlinePrim1(c, SeqAlong::new); - case "seq_len" -> (c) -> inlinePrim1(c, SeqLen::new); - case "::", ":::" -> this::inlineMultiColon; - case "with", "require" -> this::compileSuppressingUndefined; - case String s when MATH1_FUNS.contains(s) -> (c) -> inlineMath1(c, MATH1_FUNS.indexOf(s)); - case String s when rsession.isBuiltin(s) -> (c) -> inlineBuiltin(c, false); - case String s when rsession.isSpecial(s) -> this::inlineSpecial; - default -> null; - }; - } - /** - * Tries to inline a function from the base package. + * The entry point for the inlining process. It will either inline the given call or return false. * - * @param name - * @param call + * @param fun the function to inline + * @param call the entire call * @return true if the function was inlined, false otherwise */ - private boolean tryInlineBase(String name, LangSXP call, boolean allowWithGuard) { - boolean guarded = false; - - if (FORBIDDEN_INLINES.contains(name)) { + private boolean tryInlineCall(RegSymSXP fun, LangSXP call) { + if (optimizationLevel == 0) { return false; } - if (optimizationLevel < 1) { - return false; - } + return getInlineInfo(fun.name(), true) + .filter(info -> info.env().isBase()) + .map(info -> tryInlineBase(call, info)) + .orElse(false); + } - if (optimizationLevel == 2 && !ALLOWED_INLINES.contains(name)) { - if (allowWithGuard) { - guarded = true; - } else { - return false; - } - } + private Optional> getBaseInlineHandler(String name) { + // feels better this way than to pay the price to allocate the whole, big + // inlining table at class construction + Function fun = + switch (name) { + case "{" -> this::inlineBlock; + case "if" -> this::inlineCondition; + case "function" -> this::inlineFunction; + case "(" -> this::inlineParentheses; + case "local" -> this::inlineLocal; + case "return" -> this::inlineReturn; + case ".Internal" -> this::inlineDotInternalCall; + case "&&" -> (c) -> inlineLogicalAndOr(c, true); + case "||" -> (c) -> inlineLogicalAndOr(c, false); + case "repeat" -> this::inlineRepeat; + case "break" -> (c) -> inlineBreakNext(c, true); + case "next" -> (c) -> inlineBreakNext(c, false); + case "while" -> this::inlineWhile; + case "for" -> this::inlineFor; + case "+" -> (c) -> inlineAddSub(c, true); + case "-" -> (c) -> inlineAddSub(c, false); + case "*" -> (c) -> inlinePrim2(c, Mul::new); + case "/" -> (c) -> inlinePrim2(c, Div::new); + case "^" -> (c) -> inlinePrim2(c, Expt::new); + case "exp" -> (c) -> inlinePrim1(c, Exp::new); + case "sqrt" -> (c) -> inlinePrim1(c, Sqrt::new); + case "log" -> this::inlineLog; + case "==" -> (c) -> inlinePrim2(c, Eq::new); + case "!=" -> (c) -> inlinePrim2(c, Ne::new); + case "<" -> (c) -> inlinePrim2(c, Lt::new); + case "<=" -> (c) -> inlinePrim2(c, Le::new); + case ">" -> (c) -> inlinePrim2(c, Gt::new); + case ">=" -> (c) -> inlinePrim2(c, Ge::new); + case "&" -> (c) -> inlinePrim2(c, And::new); + case "|" -> (c) -> inlinePrim2(c, Or::new); + case "!" -> (c) -> inlinePrim1(c, Not::new); + case "$" -> this::inlineDollar; + case "is.character" -> (c) -> inlineIsXyz(c, IsCharacter::new); + case "is.complex" -> (c) -> inlineIsXyz(c, IsComplex::new); + case "is.double" -> (c) -> inlineIsXyz(c, IsDouble::new); + case "is.integer" -> (c) -> inlineIsXyz(c, IsInteger::new); + case "is.logical" -> (c) -> inlineIsXyz(c, IsLogical::new); + case "is.name" -> (c) -> inlineIsXyz(c, IsSymbol::new); + case "is.null" -> (c) -> inlineIsXyz(c, IsNull::new); + case "is.object" -> (c) -> inlineIsXyz(c, IsObject::new); + case "is.symbol" -> (c) -> inlineIsXyz(c, IsSymbol::new); + case ".Call" -> this::inlineDotCall; + case ":" -> (c) -> inlinePrim2(c, Colon::new); + case "seq_along" -> (c) -> inlinePrim1(c, SeqAlong::new); + case "seq_len" -> (c) -> inlinePrim1(c, SeqLen::new); + case "::", ":::" -> this::inlineMultiColon; + case "with", "require" -> this::compileSuppressingUndefined; + case "switch" -> this::inlineSwitch; + case "<-", "=", "<<-" -> this::inlineAssign; + case "[" -> (call) -> inlineSubset(false, call); + case "[[" -> (call) -> inlineSubset(true, call); + case String s when MATH1_FUNS.contains(s) -> (c) -> inlineMath1(c, MATH1_FUNS.indexOf(s)); + case String s when rsession.isBuiltin(s) -> (c) -> inlineBuiltin(c, false); + case String s when rsession.isSpecial(s) -> this::inlineSpecial; + case String s when SAFE_BASE_INTERNALS.contains(s) -> this::inlineSimpleInternal; + default -> null; + }; - var inline = getBaseInlineHandler(name); - if (inline == null) { - return false; + return Optional.ofNullable(fun); + } + + private Optional> getSetterInlineHandler( + InlineInfo info) { + if (!(info.env().isBase())) { + return Optional.empty(); } - if (guarded) { - var end = cb.makeLabel(); - usingCtx( - ctx.nonTailContext(), - () -> { - // The BASEGUARD checks the validity of the inline code, i.e. if what - // was from base at compile time hasn't changed. - // if the inlined code is not valid the guard instruction will evaluate the call in - // the AST interpreter and jump over the inlined code. - cb.addInstr(new BaseGuard(cb.addConst(call), end)); - if (!inline.apply(call)) { - // At this point the guard is useless and the following code - // should run. - // I guess the likelihood that something changed is slim, - // to care about removing it. - compileCall(call, false); - } - }); - cb.patchLabel(end); - checkTailCall(); - return true; - } else { - return inline.apply(call); + BiFunction fun = + switch (info.name) { + case "$<-" -> this::inlineDollarAssign; + case "[<-" -> (flhs, call) -> inlineSquareBracketAssign(false, flhs, call); + case "[[<-" -> (flhs, call) -> inlineSquareBracketAssign(true, flhs, call); + case "@<-" -> this::inlineSlotAssign; + default -> null; + }; + + return Optional.ofNullable(fun); + } + + private Optional> getGetterInlineHandler(InlineInfo info) { + if (!info.env().isBase()) { + return Optional.empty(); } + + Function fun = + switch (info.name) { + case "$" -> this::inlineDollarSubset; + case "[" -> (call) -> inlineSquareBracketSubSet(false, call); + case "[[" -> (call) -> inlineSquareBracketSubSet(true, call); + default -> null; + }; + + return Optional.ofNullable(fun); } /** - * From the R documentation: + * Returns inline information for the given function name or empty if the function is not + * inlinable. + * + *

It tries to resolve the function name in the current context {@link #ctx} and then based on + * the given rules from R compiler figure out whether the function should be inlinable or not. * - *

The inlining handler for `{` needs to consider that a pair of braces { and } can - * surround zero, one, or more expressions. A set of empty braces is equivalent to the constant - * NULL. If there is more than one expression, then all the values of all expressions other than - * the last are ignored. These expressions are compiled in a no-value context (currently - * equivalent to a non-tail-call context), and then code is generated to pop their values off the - * stack. The final expression is then compiled according to the context in which the braces - * expression occurs. + * @param name the name of the function to inline + * @param guardOK whether to allow inlining with a guard + * @return the inline information or empty if the function is not inlinable + */ + private Optional getInlineInfo(String name, boolean guardOK) { + if (FORBIDDEN_INLINES.contains(name) || optimizationLevel < 1) { + return Optional.empty(); + } + + // FIXME: this considers everything else "global" which is not true, but cannot be + // fixed until we have a proper environment chain supported in the Rsession + return ctx.resolve(name) + .map( + res -> { + if (res.first() instanceof NamespaceEnvSXP) { + // if is in a namespace we do not have to worry about + // shadowing + return new InlineInfo(name, res.first(), res.second(), false); + } else if (optimizationLevel >= 3 + || (optimizationLevel == 2 && LANGUAGE_FUNS.contains(name))) { + return new InlineInfo(name, res.first(), res.second(), false); + } else if (guardOK && res.first().isBase()) { + // this is the case when the function comes from baseenv() + // therefore it could be shadowed by some other function + // and thus needs to be guarded + return new InlineInfo(name, res.first(), res.second(), true); + } else { + return null; + } + }); + } + + /** + * Tries to inline a function from the base package. * - * @param call + * @param call the call to inline + * @param info the inline information + * @return true if the function was inlined, false otherwise */ + private boolean tryInlineBase(LangSXP call, InlineInfo info) { + assert (info.env().isBase()); + + return getBaseInlineHandler(info.name) + .map( + inline -> { + if (info.guard()) { + var end = cb.makeLabel(); + + usingCtx( + ctx.nonTailContext(), + () -> { + // >> The BASEGUARD checks the validity of the inline code, i.e. if what + // >> was from base at compile time hasn't changed. If the inlined code is + // >> not valid the guard instruction will evaluate the call in the AST + // >> interpreter and jump over the inlined code. + cb.addInstr(new BaseGuard(cb.addConst(call), end)); + if (!inline.apply(call)) { + // if the inlining failed, we need to compile the call + compileCall(call, false); + } + }); + cb.patchLabel(end); + tailCallReturn(); + return true; + } else { + return inline.apply(call); + } + }) + .orElse(false); + } + + private boolean trySetterInline(RegSymSXP funSym, FlattenLHS flhs, LangSXP call) { + return getInlineInfo(funSym.name(), false) + .flatMap(this::getSetterInlineHandler) + .map(handler -> handler.apply(flhs, call)) + .orElse(false); + } + + private boolean tryGetterInline(RegSymSXP funSym, LangSXP call) { + return getInlineInfo(funSym.name(), false) + .flatMap(this::getGetterInlineHandler) + .map(handler -> handler.apply(call)) + .orElse(false); + } + + // >> The inlining handler for `{` needs to consider that a pair of braces { and } can + // >> surround zero, one, or more expressions. A set of empty braces is equivalent to the + // constant + // >> NULL. If there is more than one expression, then all the values of all expressions other + // than + // >> the last are ignored. These expressions are compiled in a no-value context (currently + // >> equivalent to a non-tail-call context), and then code is generated to pop their values + // off the + // >> stack. The final expression is then compiled according to the context in which the braces + // >> expression occurs. private boolean inlineBlock(LangSXP call) { var n = call.args().size(); if (n == 0) { @@ -531,30 +745,47 @@ private boolean inlineBlock(LangSXP call) { ctx.nonTailContext(), () -> { for (var i = 0; i < n - 1; i++) { - var arg = call.arg(i).value(); + var arg = call.arg(i); // i + 1 because the block srcref's first element is the opening brace - cb.setCurrentLoc(new Loc(arg, extractSrcRef(call, i + 1))); - compile(arg, false); + cb.setCurrentLoc(new Loc(arg, extractSrcRef(call, i + 1).orElse(null))); + compile(arg, false, false); cb.addInstr(new Pop()); } }); } - var last = call.arg(n - 1).value(); - cb.setCurrentLoc(new Loc(last, extractSrcRef(call, n))); - compile(last, false); + var last = call.arg(n - 1); + cb.setCurrentLoc(new Loc(last, extractSrcRef(call, n).orElse(null))); + compile(last, false, false); cb.setCurrentLoc(loc); } return true; } + // all inline functions have the same signature + // (it is then easier to reference the method in the getInlineHandler) + @SuppressWarnings("SameReturnValue") private boolean inlineCondition(LangSXP call) { - var test = call.arg(0).value(); - var thenBranch = call.arg(1).value(); - var elseBranch = Optional.ofNullable(call.args().size() == 3 ? call.arg(2).value() : null); - - // TODO: constant fold + var test = call.arg(0); + var thenBranch = call.arg(1); + var elseBranch = Optional.ofNullable(call.args().size() == 3 ? call.arg(2) : null); + + var ct = constantFold(test).orElse(null); + + if (ct instanceof LglSXP lgl && lgl.isScalar()) { + if (lgl == SEXPs.TRUE) { + compile(thenBranch); + } else if (elseBranch.isPresent()) { + compile(elseBranch.get()); + } else if (ctx.isTailCall()) { + cb.addInstr(new LdNull()); + tailCallInvisibleReturn(); + } else { + cb.addInstr(new LdNull()); + } + return true; + } usingCtx(ctx.nonTailContext(), () -> compile(test)); @@ -569,99 +800,54 @@ private boolean inlineCondition(LangSXP call) { this::compile, () -> { cb.addInstr(new LdNull()); - cb.addInstr(new Invisible()); - cb.addInstr(new Return()); + tailCallInvisibleReturn(); }); } else { var endLabel = cb.makeLabel(); cb.addInstr(new Goto(endLabel)); cb.patchLabel(elseLabel); - elseBranch.ifPresentOrElse( - this::compile, - () -> { - cb.addInstr(new LdNull()); - }); + elseBranch.ifPresentOrElse(this::compile, () -> cb.addInstr(new LdNull())); cb.patchLabel(endLabel); } return true; } - /** - * From the R documentation: - * - *

Compiling of function expressions is somewhat similar to compiling promises for - * function arguments. The body of a function is compiled into a separate byte code object and - * stored in the constant pool together with the formals. Then code is emitted for creating a - * closure from the formals, compiled body, and the current environment. For now, only the body of - * functions is compiled, not the default argument expressions. This should be changed in future - * versions of the compiler. - * - * @param call - */ private boolean inlineFunction(LangSXP call) { - // TODO: sourcerefs - // TODO: if (mayCallBrowser(body, cntxt)) return(FALSE) - - var formals = (ListSXP) call.arg(0).value(); - var body = call.arg(1).value(); - var sref = call.args().size() > 2 ? call.arg(2).value() : SEXPs.NULL; + var formals = (ListSXP) call.arg(0); + var body = call.arg(1); + if (mayCallBrowser(body)) { + return false; + } + var sref = call.args().size() > 2 ? call.arg(2) : SEXPs.NULL; var compiler = fork(body, ctx.functionContext(formals, body), cb.getCurrentLoc()); - var cbody = compiler.compile(); - var cbodysxp = SEXPs.bcode(cbody); + var cbody = compiler.compile().map(SEXPs::bcode).orElse(body); - // FIXME: ugly - var cnst = SEXPs.vec(formals, cbodysxp, sref); + // this is not CLOSXP, but vector of these elements as + // required by the bytecode op + var cfun = SEXPs.vec(formals, cbody, sref); - cb.addInstr(new MakeClosure(cb.addConst(cnst))); + cb.addInstr(new MakeClosure(cb.addConst(cfun))); - checkTailCall(); + tailCallReturn(); return true; } - /** - * From the R documentation: - * - *

- * In R an expression of the form (expr) is interpreted as a call to the function ( with the argument - * expr. Parentheses are used to guide the parser, and for the most part (expr) is equivalent to expr. - * There are two exceptions: - *

    - *
  • Since ( is a function an expression of the form (...) is legal whereas - * just ... may not be, - * depending on the context. A runtime error will occur unless the ... argument expands to - * exactly one non-missing argument.
  • - *
  • In tail position a call to ( sets the visible flag to TRUE. So at top level for example the result - * of an assignment expression x <- 1 would not be printed, but the result of (x <- 1 - * would be printed. It is not clear that this feature really needs to be preserved within - * functions — it could be made a feature of the read-eval-print loop — but for now it is a - * feature of the interpreter that the compiler should preserve.
  • - *
- * - * The inlining handler for ( calls handles a ... argument case or a case with - * fewer or more than one argument as a generic BUILTIN call. If the expression is in tail position - * then the argument is compiled in a non-tail-call context, a VISIBLE instruction is emitted to set - * the visible flag to TRUE, and a RETURN instruction is emitted. If the expression is in non-tail - * position, then code for the argument is generated in the current context. - *

- * - * @param call - */ private boolean inlineParentheses(LangSXP call) { if (anyDots(call.args())) { return inlineBuiltin(call, false); } else if (call.args().size() != 1) { - // TODO: notifyWrongArgCount("(", cntxt, loc = cb$savecurloc()) + // TODO: notifyWrongArgCount return inlineBuiltin(call, false); } else if (ctx.isTailCall()) { - usingCtx(ctx.nonTailContext(), () -> compile(call.arg(0).value())); + usingCtx(ctx.nonTailContext(), () -> compile(call.arg(0))); cb.addInstr(new Visible()); cb.addInstr(new Return()); return true; } else { - compile(call.arg(0).value()); + compile(call.arg(0)); return true; } } @@ -688,7 +874,7 @@ private boolean inlineBuiltin(LangSXP call, boolean internal) { // call cb.addInstr(new CallBuiltin(cb.addConst(call))); - checkTailCall(); + tailCallReturn(); return true; } @@ -713,16 +899,11 @@ private void compileBuiltinArgs(ListSXP args, boolean missingOK) { .ifPresentOrElse( this::compileConstArg, () -> { - compileSym(sym, missingOK); + usingCtx(ctx.argContext(), () -> compileSym(sym, missingOK)); cb.addInstr(new PushArg()); }); case LangSXP call -> { - // FIXME: GNUR does: - // cmp(a, cb, ncntxt) - // which is weird since it says in the doc: - // > ... Constant folding is needed here since it doesn’t go through cmp. - // a possible reason why to go through cmp is to set location... - compileCall(call, true); + usingCtx(ctx.argContext(), () -> compile(call)); cb.addInstr(new PushArg()); } default -> compileConstArg(arg.value()); @@ -734,22 +915,16 @@ private void compileBuiltinArgs(ListSXP args, boolean missingOK) { private boolean inlineSpecial(LangSXP call) { cb.addInstr(new CallSpecial(cb.addConst(call))); - checkTailCall(); + tailCallReturn(); return true; } private boolean inlineLocal(LangSXP call) { - // From the R documentation: - // - // > While local is currently implemented as a closure, because of its importance relative to - // local - // > variable determination it is a good idea to inline it as well. The current semantics are - // such that - // > the interpreter treats - // > local(expr) - // > essentially the same as - // > (function() expr)() + // >> the interpreter treats + // >> local(expr) + // >> essentially the same as + // >> (function() expr)() if (call.args().size() != 1) { return false; @@ -757,35 +932,19 @@ private boolean inlineLocal(LangSXP call) { var closure = SEXPs.lang( - SEXPs.lang(SEXPs.symbol("function"), SEXPs.list(SEXPs.NULL, call.arg(0).value())), + SEXPs.lang(SEXPs.symbol("function"), SEXPs.list(SEXPs.NULL, call.arg(0), SEXPs.NULL)), SEXPs.list()); compile(closure); return true; } private boolean inlineReturn(LangSXP call) { - // From the R documentation: - // - // > A call to return causes a return from the associated function call, as determined by the - // lexical - // > context in which the return expression is defined. If the return is captured in a closure - // and is - // > executed within a callee then this requires a longjmp. A longjmp is also needed if the - // return call - // > occurs within a loop that is compiled to a separate code object to support a setjmp for - // break or - // > next calls. The RETURNJMP instruction is provided for this purpose. In all other cases an - // ordinary - // > RETURN instruction can be used. return calls with ..., which may be legal if ... contains - // only one - // > argument, or missing arguments or more than one argument, which will produce runtime - // errors, - // > are compiled as generic SPECIAL calls. - if (dotsOrMissing(call.args()) || call.args().size() > 1) { + ListSXP args = call.args(); + if (dotsOrMissing(args) || args.size() > 1) { return inlineSpecial(call); } - var v = call.args().isEmpty() ? SEXPs.NULL : call.arg(0).value(); + var v = args.isEmpty() ? SEXPs.NULL : args.value(0); usingCtx(ctx.nonTailContext(), () -> compile(v)); cb.addInstr(ctx.isReturnJump() ? new ReturnJmp() : new Return()); @@ -793,16 +952,15 @@ private boolean inlineReturn(LangSXP call) { return true; } - private boolean inlineInternal(LangSXP call) { - if (!(call.arg(0).value() instanceof LangSXP subCall)) { + private boolean inlineDotInternalCall(LangSXP call) { + if (!(call.arg(0) instanceof LangSXP subCall)) { return false; } + if (!(subCall.fun() instanceof RegSymSXP sym)) { return false; } - // we cannot do the .Internal(is.builtin.internal(sym)) check - // so we will believe that the rsession is right if (rsession.isBuiltinInternal(sym.name())) { return inlineBuiltin(subCall, true); } else { @@ -810,25 +968,99 @@ private boolean inlineInternal(LangSXP call) { } } - /** - * From the R documentation: - * - *

> In many languages it is possible to convert the expression a && b to an equivalent if - * expression > of the form > if (a) { if (b) TRUE else FALSE } > Similarly, in these languages - * the expression a || b is equivalent to > if (a) TRUE else if (b) TRUE else FALSE > Compilation - * of these expressions is thus reduced to compiling if expressions. > Unfortunately, because of - * the possibility of NA values, these equivalencies do not hold in R. In > R, NA || TRUE should - * evaluate to TRUE and NA && FALSE to FALSE. This is handled by introducing > special - * instructions AND1ST and AND2ND for && expressions and OR1ST and OR2ND for ||. > The code - * generator for && expressions generates code to evaluate the first argument and then > emits an - * AND1ST instruction. The AND1ST instruction has one operand, the label for the instruction > - * following code for the second argument. If the value on the stack produced by the first - * argument > is FALSE then AND1ST jumps to the label and skips evaluation of the second argument; - * the value > of the expression is FALSE. The code for the second argument is generated next, - * followed by an > AND2ND instruction. This removes the values of the two arguments to && from - * the stack and pushes > the value of the expression onto the stack. A RETURN instruction is - * generated if the && expression > was in tail position. - */ + private boolean isSimpleFormals(CloSXP def) { + var formals = def.formals(); + var names = formals.names(); + + if (names.contains("...")) { + return false; + } + + for (var x : formals.values()) { + if (!missing(x) && x.typeOneOf(SYM, LANG, PROM, BCODE)) { + return false; + } + } + + return true; + } + + private boolean hasSimpleArgs(LangSXP call, List formals) { + for (var arg : call.args().values()) { + if (missing(arg)) { + return false; + } else if (arg instanceof RegSymSXP sym) { + if (!formals.contains(sym.name())) { + return false; + } + } else if (arg.typeOneOf(LANG, PROM, BCODE)) { + return false; + } + } + + return true; + } + + private Optional extractSimpleInternal(CloSXP def) { + if (!isSimpleFormals(def)) { + return Optional.empty(); + } + + var b = def.bodyAST(); + + if (b instanceof LangSXP lb && lb.funName("{") && lb.args().size() == 1) { + // unwrap the { call if it has just one argument + b = lb.arg(0); + } + + return b.asLang() + .filter(call -> call.funName(".Internal")) + .flatMap(call -> call.arg(0).asLang()) + .filter( + internalCall -> internalCall.funName().map(rsession::isBuiltinInternal).orElse(false)) + .filter(internalCall -> hasSimpleArgs(internalCall, def.formals().names())); + } + + private Optional tryConvertToDotInternalCall(LangSXP call) { + return Optional.of(call) + .filter(c -> !dotsOrMissing(c.args())) + .flatMap(LangSXP::funName) + .flatMap(ctx::findFunDef) + .flatMap( + def -> + extractSimpleInternal(def) + .map( + internalCall -> { + var cenv = new HashMap(); + + def.formals().forEach((x) -> cenv.put(x.tag(), x.value())); + matchCall(def, call).args().forEach((x) -> cenv.put(x.tag(), x.value())); + + var args = + internalCall.args().stream() + .map( + (x) -> + (x.value() instanceof RegSymSXP sym) + ? cenv.get(sym.name()) + : x.value()) + .toList(); + + return SEXPs.lang( + SEXPs.symbol(".Internal"), + SEXPs.list(SEXPs.lang(internalCall.fun(), SEXPs.list2(args)))); + })); + } + + private boolean inlineSimpleInternal(LangSXP call) { + if (anyDots(call.args())) { + return false; + } + + return tryConvertToDotInternalCall(call).map(this::inlineDotInternalCall).orElse(false); + } + + // Because of the possibility of NA values, R cannot reduce && and || to a simple rewriting + // into if / else. private boolean inlineLogicalAndOr(LangSXP call, boolean isAnd) { var callIdx = cb.addConst(call); var label = cb.makeLabel(); @@ -836,20 +1068,20 @@ private boolean inlineLogicalAndOr(LangSXP call, boolean isAnd) { usingCtx( ctx.argContext(), () -> { - compile(call.arg(0).value()); + compile(call.arg(0)); cb.addInstr(isAnd ? new And1st(callIdx, label) : new Or1st(callIdx, label)); - compile(call.arg(1).value()); + compile(call.arg(1)); cb.addInstr(isAnd ? new And2nd(callIdx) : new Or2nd(callIdx)); }); cb.patchLabel(label); - checkTailCall(); + tailCallReturn(); return true; } private boolean inlineRepeat(LangSXP call) { - var body = call.arg(0).value(); + var body = call.arg(0); return inlineSimpleLoop(body, this::compileRepeatBody); } @@ -869,43 +1101,86 @@ private boolean inlineSimpleLoop(SEXP body, Consumer cmpBody) { } cb.addInstr(new LdNull()); - if (ctx.isTailCall()) { - cb.addInstr(new Invisible()); - cb.addInstr(new Return()); - } + tailCallInvisibleReturn(); return true; } - private void compileRepeatBody(SEXP body) { - var startLabel = cb.makeLabel(); - var endLabel = cb.makeLabel(); - cb.patchLabel(startLabel); - usingCtx(ctx.loopContext(startLabel, endLabel), () -> compile(body)); - cb.addInstr(new Pop()); - cb.addInstr(new Goto(startLabel)); - cb.patchLabel(endLabel); + private boolean canSkipLoopContextList(ListSXP list, boolean breakOK) { + return list.values().stream().noneMatch(x -> !missing(x) && !canSkipLoopContext(x, breakOK)); } - private boolean inlineBreakNext(LangSXP call, boolean isBreak) { - // Java's pattern matching is so pathetic that it is simply not worth it - if (ctx.loop() instanceof Loop.InLoop loop) { - if (loop.gotoOK()) { - cb.addInstr(new Goto(isBreak ? loop.end() : loop.start())); - return true; - } else { - return inlineSpecial(call); - } - } else { - // TODO: notifyWrongBreakNext("break", cntxt, loc = cb$savecurloc()) - // or notifyWrongBreakNext("next", cntxt, loc = cb$savecurloc()) - return inlineSpecial(call); + private boolean canSkipLoopContext(SEXP body, boolean breakOK) { + if (body instanceof LangSXP l) { + if (l.fun() instanceof RegSymSXP s) { + var name = s.name(); + if (!breakOK && LOOP_BREAK_FUNS.contains(name)) { + // FIXME: why don't we need to check if it is a base version? + // GNUR does not do that, but: + // > `break` <- function() print("b") + // > i <- 0 + // > repeat({ i <<- i + 1; if (i == 10) break; }) + // I mean all of this is very much unsound, just why in this case do we care + // less? + return false; + } else if (LOOP_STOP_FUNS.contains(name) && ctx.isBaseVersion(name)) { + return true; + } else if (LOOP_TOP_FUNS.contains(name) && ctx.isBaseVersion(name)) { + // recursively check the rest of the body + // this branch keeps the breakOK! + return canSkipLoopContextList(l.args(), breakOK); + } else if (EVAL_FUNS.contains(name)) { + // FIXME: again no check if it is a base version + + // From R documentation: + // > Loops that include a call to eval (or evalq, source) are compiled with + // > context to support a programming pattern present e.g. in package Rmpi: a + // server + // application is + // > implemented using an infinite loop, which evaluates de-serialized code + // received from + // the client; the + // > server shuts down when it receives a serialized version of break. + return false; + } else { + // recursively check the rest of the body + return canSkipLoopContextList(l.args(), false); + } + } else { + return canSkipLoopContextList(l.asList(), false); + } + } + return true; + } + + private void compileRepeatBody(SEXP body) { + var startLabel = cb.makeLabel(); + var endLabel = cb.makeLabel(); + cb.patchLabel(startLabel); + usingCtx(ctx.loopContext(startLabel, endLabel), () -> compile(body)); + cb.addInstr(new Pop()); + cb.addInstr(new Goto(startLabel)); + cb.patchLabel(endLabel); + } + + private boolean inlineBreakNext(LangSXP call, boolean isBreak) { + var loop = ctx.loop(); + if (loop != null) { + if (loop.gotoOK()) { + cb.addInstr(new Goto(isBreak ? loop.end() : loop.start())); + return true; + } else { + return inlineSpecial(call); + } + } else { + // TODO: notifyWrongBreakNext or notifyWrongBreakNext + return inlineSpecial(call); } } private boolean inlineWhile(LangSXP call) { - var test = call.arg(0).value(); - var body = call.arg(1).value(); + var test = call.arg(0); + var body = call.arg(1); return inlineSimpleLoop(body, (b) -> compileWhileBody(call, test, b)); } @@ -927,18 +1202,20 @@ private void compileWhileBody(LangSXP call, SEXP test, SEXP body) { } private boolean inlineFor(LangSXP call) { - var loopVar = call.arg(0).value(); - var seq = call.arg(1).value(); - var body = call.arg(2).value(); + var loopVar = call.arg(0); + var seq = call.arg(1); + var body = call.arg(2); if (!(loopVar instanceof RegSymSXP loopSym)) { - // From R source code: - // > ## not worth warning here since the parser should not allow this + // >> not worth warning here since the parser should not allow this return false; } usingCtx(ctx.nonTailContext(), () -> compile(seq)); + var ci = cb.addConst(loopSym); + var callIdx = cb.addConst(call); + if (canSkipLoopContext(body, true)) { compileForBody(call, body, loopSym); } else { @@ -947,7 +1224,7 @@ private boolean inlineFor(LangSXP call) { () -> { var startLabel = cb.makeLabel(); var endLabel = cb.makeLabel(); - cb.addInstr(new StartFor(cb.addConst(call), cb.addConst(loopSym), startLabel)); + cb.addInstr(new StartFor(callIdx, ci, startLabel)); cb.patchLabel(startLabel); cb.addInstr(new StartLoopCntxt(true, endLabel)); compileForBody(call, body, null); @@ -957,10 +1234,7 @@ private boolean inlineFor(LangSXP call) { } cb.addInstr(new EndFor()); - if (ctx.isTailCall()) { - cb.addInstr(new Invisible()); - cb.addInstr(new Return()); - } + tailCallInvisibleReturn(); return true; } @@ -992,85 +1266,84 @@ private boolean inlineAddSub(LangSXP call, boolean isAdd) { } } - private boolean inlinePrim1(LangSXP call, Function, BcInstr> makeOp) { + private boolean inlinePrim1(LangSXP call, Function, BcInstr> makeOp) { if (dotsOrMissing(call.args())) { return inlineBuiltin(call, false); } if (call.args().size() != 1) { - // TODO: notifyWrongArgCount(e[[1]], cntxt, loc = cb$savecurloc()) + // TODO: notifyWrongArgCount return inlineBuiltin(call, false); } - usingCtx(ctx.nonTailContext(), () -> compile(call.arg(0).value())); + usingCtx(ctx.nonTailContext(), () -> compile(call.arg(0))); cb.addInstr(makeOp.apply(cb.addConst(call))); - checkTailCall(); + tailCallReturn(); return true; } - private boolean inlinePrim2(LangSXP call, Function, BcInstr> makeOp) { + private boolean inlinePrim2(LangSXP call, Function, BcInstr> makeOp) { if (dotsOrMissing(call.args())) { return inlineBuiltin(call, false); } if (call.args().size() != 2) { - // TODO: notifyWrongArgCount(e[[1]], cntxt, loc = cb$savecurloc()) + // TODO: notifyWrongArgCount return inlineBuiltin(call, false); } usingCtx( ctx.nonTailContext(), () -> { - compile(call.arg(0).value()); - - // From the R documentation: - // > the second argument has to - // > be compiled with an argument context since the stack already has the value of the - // first argument - // > on it and that would need to be popped before a jump. - usingCtx(ctx.argContext(), () -> compile(call.arg(1).value())); + compile(call.arg(0)); + + // >> the second argument has to be compiled with an argument context + // >> since the stack already has the value of the first argument + // >> on it and that would need to be popped before a jump. + usingCtx(ctx.argContext(), () -> compile(call.arg(1))); }); cb.addInstr(makeOp.apply(cb.addConst(call))); - checkTailCall(); + tailCallReturn(); return true; } private boolean inlineLog(LangSXP call) { - if (dotsOrMissing(call.args()) - || call.args().names().stream().anyMatch(Objects::nonNull) - || call.args().isEmpty() - || call.args().size() > 2) { - return inlineBuiltin(call, false); + ListSXP args = call.args(); + + if (dotsOrMissing(args) || args.hasTags() || args.isEmpty() || args.size() > 2) { + return inlineSpecial(call); } var idx = cb.addConst(call); - usingCtx(ctx.nonTailContext(), () -> compile(call.arg(0).value())); - if (call.args().size() == 1) { + usingCtx(ctx.nonTailContext(), () -> compile(args.value(0))); + if (args.size() == 1) { cb.addInstr(new Log(idx)); } else { - usingCtx(ctx.argContext(), () -> compile(call.arg(1).value())); + usingCtx(ctx.argContext(), () -> compile(args.value(1))); cb.addInstr(new LogBase(idx)); } - checkTailCall(); + tailCallReturn(); return true; } private boolean inlineMath1(LangSXP call, int idx) { - if (dotsOrMissing(call.args())) { + ListSXP args = call.args(); + + if (dotsOrMissing(args)) { return inlineBuiltin(call, false); } - if (call.args().size() != 1) { - // TODO: notifyWrongArgCount(e[[1]], cntxt, loc = cb$savecurloc()) + if (args.size() != 1) { + // TODO: notifyWrongArgCount return inlineBuiltin(call, false); } - usingCtx(ctx.nonTailContext(), () -> compile(call.arg(0).value())); + usingCtx(ctx.nonTailContext(), () -> compile(args.value(0))); cb.addInstr(new Math1(cb.addConst(call), idx)); - checkTailCall(); + tailCallReturn(); return true; } @@ -1080,19 +1353,19 @@ private boolean inlineDollar(LangSXP call) { } SEXP sym; - if (call.arg(1).value() instanceof StrSXP s && s.size() == 1 && !s.get(0).isEmpty()) { + if (call.arg(1) instanceof StrSXP s && s.size() == 1 && !s.get(0).isEmpty()) { // > list(a=1)$"a" sym = SEXPs.symbol(s.get(0)); } else { - sym = call.arg(1).value(); + sym = call.arg(1); } if (sym instanceof RegSymSXP s) { - usingCtx(ctx.argContext(), () -> compile(call.arg(0).value())); + usingCtx(ctx.argContext(), () -> compile(call.arg(0))); var callIdx = cb.addConst(call); var symIdx = cb.addConst(s); cb.addInstr(new Dollar(callIdx, symIdx)); - checkTailCall(); + tailCallReturn(); return true; } else { return inlineSpecial(call); @@ -1104,61 +1377,662 @@ private boolean inlineIsXyz(LangSXP c, Supplier makeOp) { return inlineBuiltin(c, false); } - usingCtx(ctx.argContext(), () -> compile(c.arg(0).value())); + usingCtx(ctx.argContext(), () -> compile(c.arg(0))); cb.addInstr(makeOp.get()); - checkTailCall(); + tailCallReturn(); return true; } private boolean inlineDotCall(LangSXP call) { - if (dotsOrMissing(call.args()) - || call.args().names().stream().anyMatch(Objects::nonNull) - || call.args().isEmpty() - || call.args().size() > DOTCALL_MAX) { + ListSXP args = call.args(); + if (dotsOrMissing(args) || args.hasTags() || args.isEmpty() || args.size() > DOTCALL_MAX) { return inlineBuiltin(call, false); } - usingCtx(ctx.nonTailContext(), () -> compile(call.arg(0).value())); - usingCtx(ctx.argContext(), () -> call.args().values(1).forEach(this::compile)); - cb.addInstr(new DotCall(cb.addConst(call), call.args().size() - 1)); + usingCtx(ctx.nonTailContext(), () -> compile(args.value(0))); + usingCtx(ctx.argContext(), () -> args.values(1).forEach(this::compile)); + cb.addInstr(new DotCall(cb.addConst(call), args.size() - 1)); - checkTailCall(); + tailCallReturn(); return true; } private boolean inlineMultiColon(LangSXP call) { - if (!dotsOrMissing(call.args()) && call.args().size() == 2) { - - // FIXME: ugly - String s1 = - switch (call.arg(0).value()) { - case StrSXP s when s.size() == 1 -> s.get(0); - case RegSymSXP s -> s.name(); - default -> null; - }; - - String s2 = - switch (call.arg(1).value()) { - case StrSXP s when s.size() == 1 -> s.get(0); - case RegSymSXP s -> s.name(); - default -> null; - }; + ListSXP args = call.args(); + + if (!dotsOrMissing(args) && args.size() == 2) { + Function extractName = + (x) -> + switch (x) { + case StrSXP s when s.size() == 1 -> s.get(0); + case RegSymSXP s -> s.name(); + default -> null; + }; + + String s1 = extractName.apply(args.value(0)); + String s2 = extractName.apply(args.value(1)); if (s1 == null || s2 == null) { return false; } - var args = SEXPs.list(SEXPs.string(s1), SEXPs.string(s2)); - compileCallSymFun((RegSymSXP) call.fun(), args, call); + var newArgs = SEXPs.list(SEXPs.string(s1), SEXPs.string(s2)); + compileCallSymFun(call, (RegSymSXP) call.fun(), newArgs); return true; } return false; } + private boolean inlineSwitch(LangSXP call) { + ListSXP args = call.args(); + + if (args.isEmpty() || anyDots(args)) { + return inlineSpecial(call); + } + + // before reading on a taste of switch in R: + // + // > switch("b", a=1,b=,c=,e=2,b=3,b=4,c=,d=5,6) + // + // so you can appreciate the complexity of the + // code bellow + + // 1. extract the switch expression components + var expr = args.value(0); + var cases = args.values(1); + + // TODO: notifyNoSwitchcases + + var names = args.names(1); + // allow for corner cases like switch(x, 1) which always + // returns 1 if x is a character scalar. + if (cases.size() == 1 && names.getFirst() == null) { + names = List.of(""); + } + + // 2. figure out which type of switch (numeric / character) we are compiling + + // number of default cases + boolean haveNames; + boolean haveCharDefault; + + var numberOfDefaults = names.stream().filter(String::isEmpty).count(); + if (numberOfDefaults == cases.size()) { + // none of the case is named -- this might be the first case when the expr is + // numeric + haveNames = false; + haveCharDefault = false; + } else if (numberOfDefaults == 1) { + // one default + haveNames = true; + haveCharDefault = true; + } else if (numberOfDefaults == 0) { + // no default case + haveNames = true; + haveCharDefault = false; + } else { + // more than one default (which confuses the fuck out of me) + // TODO: notifyMultipleSwitchDefaults + return inlineSpecial(call); + } + + // a boolean vector indicating which (if any) arguments are missing + var miss = cases.stream().map(this::missing).toList(); + + // 3. build labels for cases + + // the label for code that signals an error if + // a numerical selector expression chooses a case with an empty argument + var missLabel = miss.contains(true) ? cb.makeLabel() : null; + + // will be for code that invisibly procures the value NULL, which is the default + // case for a + // numerical selector argument and also for a character selector when no unnamed + // default case is provided. + var defaultLabel = cb.makeLabel(); + var labels = new ArrayList(miss.size() + 1); + miss.stream().map(x -> x ? missLabel : cb.makeLabel()).forEach(labels::add); + labels.add(defaultLabel); + + // needed as the GOTO target for a switch expression that is not in tail + // position + var endLabel = ctx.isTailCall() ? null : cb.makeLabel(); + + var nLabels = new ArrayList(); + var uniqueNames = new ArrayList(); + + if (haveNames) { + names.stream().distinct().filter(x -> !x.isEmpty()).forEachOrdered(uniqueNames::add); + if (haveCharDefault) { + uniqueNames.add(""); + } + + // the following acrobacy is, so we compile, quite unexpectedly IMHO, + // switch("b", a=1,b=,c=,e=2,b=3,b=4,c=,d=5,6) + // a code that returns 2 (matching e label on b input) + var aidxBuilder = ImmutableIntArray.builder(); + IntStream.range(0, miss.size()).filter(x -> !miss.get(x)).forEach(aidxBuilder::add); + aidxBuilder.add(names.size()); + var aidx = aidxBuilder.build(); + + for (var n : uniqueNames) { + var start = names.indexOf(n); + var idx = aidx.stream().filter(x -> x >= start).min().getAsInt(); + nLabels.add(labels.get(idx)); + } + + if (!haveCharDefault) { + uniqueNames.add(""); + nLabels.add(defaultLabel); + } + } + + // 4. compile the expression on which we dispatch to the cases + usingCtx(ctx.nonTailContext(), () -> compile(expr)); + var callIdx = cb.addConst(call); + + // 5. emit the switch instruction + + // this is more complicated than it should be, but there is no easy way around the restrictions + // how + // the BC representation is set: + // - instructions are records thus immutable with non-null fields + // - we want the BC to be the same as GNU-R one + // + // In R: the labels for the individual cases are represented by + // lists which are directly placed in the bytecode itself. + // At the end, when R is closing the code buffer, it calls cb$patchLabels() + // which in turns all "chars" and "lists" pushes into const pool. + // The effect is that the label vectors will be pushed last, and so we need to + // follow the same logic and path the instruction at the end. + // + // So we cannot really add any meaningful args to switch at this point, we need to patch it + // later. + var switchIdx = 0; + + if (haveNames) { + var cni = cb.addConst(SEXPs.string(uniqueNames)); + switchIdx = cb.addInstr(new Switch(callIdx, cni, null, null)); + } else { + // even though we use null to represent NilSXP + // we still need to add it into the const pool here so the order is kept + cb.addConst(SEXPs.NULL); + switchIdx = cb.addInstr(new Switch(callIdx, null, null, null)); + } + + cb.addInstrPatch( + switchIdx, + (instr) -> { + var oldSwitch = (Switch) instr; + var numLabelsIdx = SEXPs.integer(labels.stream().map(BcLabel::getTarget).toList()); + Switch newSwitch; + + if (haveNames) { + var chrLabelsIdx = SEXPs.integer(nLabels.stream().map(BcLabel::getTarget).toList()); + newSwitch = + new Switch( + oldSwitch.ast(), + oldSwitch.names(), + cb.addConst(chrLabelsIdx), + cb.addConst(numLabelsIdx)); + } else { + newSwitch = + new Switch(oldSwitch.ast(), oldSwitch.names(), null, cb.addConst(numLabelsIdx)); + } + + return newSwitch; + }); + + // 6. compile the cases + + // > emit code to signal an error if a numeric switch hist an + // > empty alternative (fall through, as for character, might + // > make more sense but that isn't the way switch() works) + if (miss.contains(true)) { + cb.patchLabel(missLabel); + compile( + SEXPs.lang( + SEXPs.symbol("stop"), + SEXPs.list(SEXPs.string("empty alternative in numeric switch")))); + } + + // code for the default case + cb.patchLabel(defaultLabel); + cb.addInstr(new LdNull()); + if (!tailCallInvisibleReturn()) { + cb.addInstr(new Goto(endLabel)); + } + + // code for the non-empty cases + + // > Finally the labels and code for the non-empty alternatives are written to + // the code buffer. In + // > non-tail position the code is followed by a GOTO instruction that jumps to + // endLabel. The final case + // > does not need this GOTO. + for (int i = 0; i < cases.size(); i++) { + if (miss.get(i)) { + continue; + } + cb.patchLabel(labels.get(i)); + compile(cases.get(i)); + if (!ctx.isTailCall()) { + cb.addInstr(new Goto(endLabel)); + } + } + + if (!ctx.isTailCall()) { + cb.patchLabel(endLabel); + } + + return true; + } + + private boolean inlineAssign(LangSXP call) { + if (!checkAssign(call)) { + return inlineSpecial(call); + } + + var superAssign = call.fun().equals(SEXPs.SUPER_ASSIGN); + var lhs = call.arg(0); + var value = call.arg(1); + var symbolOpt = Context.getAssignVar(call); + + if (symbolOpt.isPresent() && lhs instanceof StrOrRegSymSXP) { + compileSymbolAssign(symbolOpt.get(), value, superAssign); + return true; + } else if (symbolOpt.isPresent() && lhs instanceof LangSXP left) { + compileComplexAssign(symbolOpt.get(), left, value, superAssign); + return true; + } else { + return inlineSpecial(call); + } + } + + private void compileSymbolAssign(String name, SEXP value, boolean superAssign) { + usingCtx(ctx.nonTailContext(), () -> compile(value)); + var ci = cb.addConst(SEXPs.symbol(name)); + cb.addInstr(superAssign ? new SetVar2(ci) : new SetVar(ci)); + + tailCallInvisibleReturn(); + } + + /* + * >> Assignments for complex LVAL specifications. This is the stuff that + * >> nightmares are made of ... + */ + private void compileComplexAssign(String name, LangSXP lhs, SEXP value, boolean superAssign) { + + // > The stack invariant maintained by the assignment process is + // > that the current right hand side value is on the top, followed by the evaluated left hand + // side values, + // > the binding cell, and the original right hand side value. Thus the start instruction leaves + // the right + // > hand side value on the top, then the value of the left hand side variable, the binding + // cell, and again + // > the right hand side value on the stack. + + if (!ctx.isTopLevel()) { + cb.addInstr(new IncLnkStk()); + } + usingCtx(ctx.nonTailContext(), () -> compile(value)); + var csi = cb.addConst(SEXPs.symbol(name)); + cb.addInstr(superAssign ? new StartAssign2(csi) : new StartAssign(csi)); + + var flat = flattenPlace(lhs, cb.getCurrentLoc()); + + usingCtx( + ctx.argContext(), + () -> { + for (int i = flat.size() - 1; i >= 1; i--) { + compileGetterCall(flat.get(i)); + } + compileSetterCall(flat.getFirst(), value); + for (int i = 1; i < flat.size(); i++) { + compileSetterCall(flat.get(i), SEXPs.ASSIGN_VTMP); + } + }); + + cb.addInstr(superAssign ? new EndAssign2(csi) : new EndAssign(csi)); + if (!ctx.isTopLevel()) { + cb.addInstr(new DecLnkStk()); + } + + tailCallInvisibleReturn(); + } + + private void compileSetterCall(FlattenLHS flhs, SEXP value) { + var afun = + Context.getAssignFun(flhs.temp().fun()) + .orElseThrow(() -> new CompilerException("invalid function in complex assignment")); + var aargs = flhs.temp().args().set(0, null, SEXPs.ASSIGN_TMP).appended("value", value); + var acall = SEXPs.lang(afun, aargs); + + var sloc = cb.getCurrentLoc(); + + var cargs = flhs.original().args().appended("value", value); + var cexpr = SEXPs.lang(afun, cargs); + cb.setCurrentLoc(new Loc(cexpr, null)); + + if (afun instanceof RegSymSXP afunSym) { + if (!trySetterInline(afunSym, flhs, acall)) { + var ci = cb.addConst(afunSym); + cb.addInstr(new GetFun(ci)); + cb.addInstr(new PushNullArg()); + compileArgs(flhs.temp().args().subList(1), false); + var cci = cb.addConst(acall); + var cvi = cb.addConst(value); + cb.addInstr(new SetterCall(cci, cvi)); + } + } else { + compile(afun); + cb.addInstr(new CheckFun()); + cb.addInstr(new PushNullArg()); + compileArgs(flhs.temp().args().subList(1), false); + var cci = cb.addConst(acall); + var cvi = cb.addConst(value); + cb.addInstr(new SetterCall(cci, cvi)); + } + + cb.setCurrentLoc(sloc); + } + + private void compileGetterCall(FlattenLHS flhs) { + var place = flhs.temp(); + var sloc = cb.getCurrentLoc(); + + cb.setCurrentLoc(new Loc(flhs.original(), null)); + + var fun = place.fun(); + if (fun instanceof RegSymSXP funSym) { + if (!tryGetterInline(funSym, place)) { + var ci = cb.addConst(funSym); + cb.addInstr(new GetFun(ci)); + cb.addInstr(new PushNullArg()); + compileArgs(place.args().subList(1), false); + var cci = cb.addConst(place); + cb.addInstr(new GetterCall(cci)); + cb.addInstr(new SpecialSwap()); + } + } else { + compile(fun); + cb.addInstr(new CheckFun()); + cb.addInstr(new PushNullArg()); + compileArgs(place.args().subList(1), false); + var cci = cb.addConst(place); + cb.addInstr(new GetterCall(cci)); + cb.addInstr(new SpecialSwap()); + } + + cb.setCurrentLoc(sloc); + } + + private List flattenPlace(SEXP lhs, Loc loc) { + var places = new ArrayList(); + + while (lhs instanceof LangSXP orig) { + if (orig.args().isEmpty()) { + stop("bad assignment 1", loc); + } + + var temp = SEXPs.lang(orig.fun(), orig.args().set(0, null, SEXPs.ASSIGN_TMP)); + places.add(new FlattenLHS(orig, temp)); + lhs = orig.arg(0); + } + + if (!(lhs instanceof RegSymSXP)) { + stop("bad assignment 2", loc); + } + + return places; + } + + private boolean checkAssign(LangSXP call) { + if (call.args().size() != 2) { + return false; + } + + var lhs = call.arg(0); + return switch (lhs) { + case RegSymSXP ignored -> true; + case StrSXP s -> s.size() == 1; + case LangSXP ignored -> { + while (lhs instanceof LangSXP l) { + var fun = l.fun(); + var args = l.args(); + + // >> A valid left hand side call must have a function that is either a symbol or is of + // >> the form foo::bar or foo:::bar, and the first argument must be a symbol or + // >> another valid left hand side call. + if (!(fun instanceof RegSymSXP) + && !(fun instanceof LangSXP && args.size() == 2) + && args.value(0) instanceof RegSymSXP innerFun + && (innerFun.name().equals("::") || innerFun.name().equals(":::"))) { + // TODO: notifyBadAssignFun + yield false; + } + lhs = l.arg(0); + } + yield lhs instanceof RegSymSXP; + } + default -> false; + }; + } + + private boolean inlineDollarSubset(LangSXP call) { + if (anyDots(call.args()) || call.args().size() != 2) { + return false; + } + + var what = call.arg(1); + if (what instanceof StrSXP str) { + what = SEXPs.symbol(str.get(0)); + } + if (what instanceof RegSymSXP sym) { + var ci = cb.addConst(call); + var csi = cb.addConst(sym); + + cb.addInstr(new Dup2nd()); + cb.addInstr(new Dollar(ci, csi)); + cb.addInstr(new SpecialSwap()); + + return true; + } else { + return false; + } + } + + private boolean inlineSquareBracketSubSet(boolean doubleBracket, LangSXP call) { + if (dotsOrMissing(call.args()) || call.args().hasTags() || call.args().size() < 2) { + // inline cmpGetterDispatch from the R compiler + + if (anyDots(call.args())) { + return false; + } + + var ci = cb.addConst(call); + var label = cb.makeLabel(); + cb.addInstr(new Dup2nd()); + cb.addInstr(doubleBracket ? new StartSubset2(ci, label) : new StartSubset(ci, label)); + + var args = call.args().subList(1); + compileBuiltinArgs(args, true); + + cb.addInstr(doubleBracket ? new DfltSubset2() : new DfltSubset()); + cb.patchLabel(label); + cb.addInstr(new SpecialSwap()); + + return true; + } + + var ci = cb.addConst(call); + var label = cb.makeLabel(); + cb.addInstr(new Dup2nd()); + cb.addInstr(doubleBracket ? new StartSubset2N(ci, label) : new StartSubsetN(ci, label)); + var indices = call.args().subList(1); + compileIndices(indices); + + switch (indices.size()) { + case 1: + cb.addInstr(doubleBracket ? new VecSubset2(ci) : new VecSubset(ci)); + break; + case 2: + cb.addInstr(doubleBracket ? new MatSubset2(ci) : new MatSubset(ci)); + break; + default: + cb.addInstr( + doubleBracket ? new Subset2N(ci, indices.size()) : new SubsetN(ci, indices.size())); + } + + cb.patchLabel(label); + cb.addInstr(new SpecialSwap()); + + return true; + } + + private boolean inlineDollarAssign(FlattenLHS flhs, LangSXP call) { + var place = flhs.temp(); + if (anyDots(place.args()) || place.args().size() != 2) { + return false; + } else { + SEXP sym = place.arg(1); + if (sym instanceof StrSXP s) { + sym = SEXPs.symbol(s.get(0)); + } + if (sym instanceof RegSymSXP s) { + var ci = cb.addConst(call); + var csi = cb.addConst(s); + cb.addInstr(new DollarGets(ci, csi)); + return true; + } else { + return false; + } + } + } + + private boolean inlineSquareBracketAssign(boolean doubleSquare, FlattenLHS flhs, LangSXP call) { + var place = flhs.temp(); + + if (dotsOrMissing(place.args()) || place.args().hasTags() || place.args().size() < 2) { + // inlined cmpSetterDispatch + if (anyDots(place.args())) { + return false; + } + + var ci = cb.addConst(call); + var endLabel = cb.makeLabel(); + cb.addInstr( + doubleSquare ? new StartSubassign2(ci, endLabel) : new StartSubassign(ci, endLabel)); + var args = place.args().subList(1); + compileBuiltinArgs(args, true); + cb.addInstr(doubleSquare ? new DfltSubassign2() : new DfltSubassign()); + cb.patchLabel(endLabel); + return true; + } + + var ci = cb.addConst(call); + var label = cb.makeLabel(); + cb.addInstr(doubleSquare ? new StartSubassign2N(ci, label) : new StartSubassignN(ci, label)); + var indices = place.args().subList(1); + compileIndices(indices); + + switch (indices.size()) { + case 1: + cb.addInstr(doubleSquare ? new VecSubassign2(ci) : new VecSubassign(ci)); + break; + case 2: + cb.addInstr(doubleSquare ? new MatSubassign2(ci) : new MatSubassign(ci)); + break; + default: + cb.addInstr( + doubleSquare + ? new Subassign2N(ci, indices.size()) + : new SubassignN(ci, indices.size())); + } + + cb.patchLabel(label); + + return true; + } + + private void compileIndices(ListSXP indices) { + for (var idx : indices.values()) { + compile(idx, true, true); + } + } + + private boolean inlineSubset(boolean doubleSquare, LangSXP call) { + if (dotsOrMissing(call.args()) || call.args().hasTags() || call.args().size() < 2) { + if (anyDots(call.args()) || call.args().isEmpty()) { + return inlineSpecial(call); + } + + var oe = call.arg(0); + if (missing(oe)) { + return inlineSpecial(call); + } + usingCtx(ctx.argContext(), () -> compile(oe)); + var ci = cb.addConst(call); + var endLabel = cb.makeLabel(); + cb.addInstr(doubleSquare ? new StartSubset2(ci, endLabel) : new StartSubset(ci, endLabel)); + var args = call.args().subList(1); + compileBuiltinArgs(args, true); + cb.addInstr(doubleSquare ? new DfltSubset2() : new DfltSubset()); + cb.patchLabel(endLabel); + + tailCallReturn(); + return true; + } + + var oe = call.arg(0); + if (missing(oe)) { + stop("cannot compile this expression"); + } + + var ci = cb.addConst(call); + var endLabel = cb.makeLabel(); + usingCtx(ctx.argContext(), () -> compile(oe)); + cb.addInstr(doubleSquare ? new StartSubset2N(ci, endLabel) : new StartSubsetN(ci, endLabel)); + var indices = call.args().subList(1); + usingCtx(ctx.argContext(), () -> compileIndices(indices)); + + switch (indices.size()) { + case 1: + cb.addInstr(doubleSquare ? new VecSubset2(ci) : new VecSubset(ci)); + break; + case 2: + cb.addInstr(doubleSquare ? new MatSubset2(ci) : new MatSubset(ci)); + break; + default: + cb.addInstr( + doubleSquare ? new Subset2N(ci, indices.size()) : new SubsetN(ci, indices.size())); + } + + cb.patchLabel(endLabel); + tailCallReturn(); + + return true; + } + + private boolean inlineSlotAssign(FlattenLHS flhs, LangSXP call) { + var place = flhs.temp(); + + if (!dotsOrMissing(place.args()) + && place.args().size() == 2 + && place.arg(1) instanceof RegSymSXP s) { + var newPlace = SEXPs.lang(place.fun(), place.args().set(1, null, SEXPs.string(s.name()))); + var vexpr = call.args().values().getLast(); + compileSetterCall(new FlattenLHS(flhs.original(), newPlace), vexpr); + return true; + } else { + return false; + } + } + private boolean compileSuppressingUndefined(LangSXP call) { - // TODO: cntxt$suppressUndefined <- TRUE - compileCallSymFun((RegSymSXP) call.fun(), call.args(), call); + // TODO: cntxt$suppressUndefined <- TRUE (related to notification) + compileCallSymFun(call, (RegSymSXP) call.fun(), call.args()); return true; } @@ -1169,7 +2043,6 @@ private void usingCtx(Context ctx, Runnable thunk) { this.ctx = old; } - @SuppressFBWarnings("DLS_DEAD_LOCAL_STORE") private Optional constantFold(SEXP expr) { return switch (expr) { case LangSXP l -> constantFoldCall(l); @@ -1180,16 +2053,11 @@ private Optional constantFold(SEXP expr) { }; } - @SuppressFBWarnings("DLS_DEAD_LOCAL_STORE") private Optional checkConst(SEXP e) { var r = switch (e) { case NilSXP ignored -> e; - case ListOrVectorSXP xs when xs.size() <= MAX_CONST_SIZE -> - switch (xs.type()) { - case INT, REAL, LGL, CPLX, STR -> e; - default -> null; - }; + case VectorSXP xs when xs.size() <= MAX_CONST_SIZE -> e; default -> null; }; @@ -1198,34 +2066,100 @@ private Optional checkConst(SEXP e) { private Optional constantFoldSym(RegSymSXP sym) { var name = sym.name(); + if (ALLOWED_FOLDABLE_CONSTS.contains(name)) { - return ctx.resolve(name) - .filter(x -> x.first() instanceof BaseEnvSXP) - .flatMap(x -> checkConst(x.second())); + return ctx.resolve(name).filter(x -> x.first().isBase()).flatMap(x -> checkConst(x.second())); } else { return Optional.empty(); } } + /** + * Try to constant fold a call. The supported functions are defined in {@link + * #ALLOWED_FOLDABLE_FUNS}. It is a subset of the functions that are supported by the constant + * folding in GNU-R. This implementation is done on the best-effort basis and more are added as + * needed. + * + * @param call to be constant folded + * @return the constant folded value or empty if it cannot be constant folded + */ private Optional constantFoldCall(LangSXP call) { - if (!(call.fun() instanceof RegSymSXP funSym && isFoldableFun(funSym))) { - return Optional.empty(); + return call.funName() + .filter(this::isFoldableFun) + .flatMap(name -> buildArgs(call).flatMap(args -> doConstantFoldCall(name, args))) + .flatMap(this::checkConst); + } + + private Optional> buildArgs(LangSXP call) { + var argsBuilder = new ImmutableList.Builder(); + + for (var arg : call.args()) { + if (missing(arg.value())) { + return Optional.empty(); + } + + var namedArg = arg.value().attributes() != null ? arg.namedValue() : arg.value(); + var val = constantFold(namedArg); + + if (val.isPresent()) { + var v = val.get(); + if (!ALLOWED_FOLDABLE_MODES.contains(v.type())) { + return Optional.empty(); + } + + argsBuilder.add(v); + } else { + return Optional.empty(); + } } - // fold args -- check consts - // do.call <- need a basic interpreter - throw new NotImplementedError(); + return Optional.of(argsBuilder.build()); } - private boolean isFoldableFun(RegSymSXP sym) { - var name = sym.name(); - if (ALLOWED_FOLDABLE_FUNS.contains(name)) { - return ctx.resolve(name) - .filter(x -> x.first() instanceof BaseEnvSXP && x.second() instanceof CloSXP) - .isPresent(); - } else { - return false; + private Optional doConstantFoldCall(String funName, List args) { + return switch (funName) { + case "(" -> constantFoldParen(args); + case "c" -> ConstantFolding.c(args); + case "+" -> { + if (args.size() == 1) { + yield ConstantFolding.plus(args); + } else { + yield ConstantFolding.add(args); + } + } + case "*" -> ConstantFolding.mul(args); + case "/" -> ConstantFolding.div(args); + case "-" -> { + if (args.size() == 1) { + yield ConstantFolding.minus(args); + } else { + yield ConstantFolding.sub(args); + } + } + case ":" -> ConstantFolding.colon(args); + case "^" -> ConstantFolding.pow(args); + case "log" -> ConstantFolding.log(args); + case "log2" -> ConstantFolding.log2(args); + case "sqrt" -> ConstantFolding.sqrt(args); + case "rep" -> ConstantFolding.rep(args); + case "seq.int" -> ConstantFolding.seqInt(args); + default -> Optional.empty(); + }; + } + + private boolean isFoldableFun(String name) { + return Optional.of(name) + .filter(ALLOWED_FOLDABLE_FUNS::contains) + .flatMap(n -> getInlineInfo(n, false)) + .map(x -> x.env.isBase() && x.value != null && x.value.isFunction()) + .orElse(false); + } + + private Optional constantFoldParen(List args) { + if (args.size() != 1) { + return Optional.empty(); } + return constantFold(args.getFirst()); } private R stop(String message) throws CompilerException { @@ -1236,58 +2170,16 @@ private R stop(String message, Loc loc) throws CompilerException { throw new CompilerException(message, loc); } - private static final Set LOOP_STOP_FUNS = Set.of("function", "for", "while", "repeat"); - private static final Set LOOP_TOP_FUNS = Set.of("(", "{", "if"); - private static final Set LOOP_BREAK_FUNS = Set.of("break", "next"); - private static final Set EVAL_FUNS = Set.of("eval", "evalq", "source"); - - private boolean canSkipLoopContext(SEXP body, boolean breakOK) { - if (body instanceof LangSXP l) { - if (l.fun() instanceof RegSymSXP s) { - var name = s.name(); - if (!breakOK && LOOP_BREAK_FUNS.contains(name)) { - // FIXME: why don't we need to check if it is a base version? - // GNUR does not do that, but: - // > `break` <- function() print("b") - // > i <- 0 - // > repeat({ i <<- i + 1; if (i == 10) break; }) - // I mean all of this is very much unsound, just why in this case do we care less? - return false; - } else if (LOOP_STOP_FUNS.contains(name) && ctx.isBaseVersion(name)) { - return true; - } else if (EVAL_FUNS.contains(name)) { - // FIXME: again no check if it is a base version - - // From R documentation: - // > Loops that include a call to eval (or evalq, source) are compiled with - // > context to support a programming pattern present e.g. in package Rmpi: a server - // application is - // > implemented using an infinite loop, which evaluates de-serialized code received from - // the client; the - // > server shuts down when it receives a serialized version of break. - return false; - } else if (LOOP_TOP_FUNS.contains(name) && ctx.isBaseVersion(name)) { - // recursively check the rest of the body - return l.args().values().stream() - .noneMatch(x -> !missing(x) && !canSkipLoopContext(x, false)); - } - } else { - return l.asList().stream().noneMatch(x -> !missing(x) && !canSkipLoopContext(x, false)); - } - } - return true; - } - - private static boolean anyDots(ListSXP l) { + private boolean anyDots(ListSXP l) { return l.values().stream() .anyMatch(x -> !missing(x) && x instanceof SymSXP s && s.isEllipsis()); } - private static boolean dotsOrMissing(ListSXP l) { + private boolean dotsOrMissing(ListSXP l) { return l.values().stream().anyMatch(x -> missing(x) || x instanceof SymSXP s && s.isEllipsis()); } - private static boolean missing(SEXP x) { + private boolean missing(SEXP x) { // FIXME: this is a great oversimplification from the do_missing in R if (x instanceof SymSXP s) { return s.isMissing(); @@ -1296,26 +2188,115 @@ private static boolean missing(SEXP x) { } } - private static @Nullable IntSXP extractSrcRef(SEXP expr, int idx) { + private boolean mayCallBrowser(SEXP body) { + if (body instanceof LangSXP call && call.fun() instanceof RegSymSXP s) { + var name = s.name(); + if (name.equals("browser")) { + return true; + } else if (name.equals("function") && ctx.isBaseVersion(name)) { + return false; + } else { + return call.args().values().stream().anyMatch(this::mayCallBrowser); + } + } + + return false; + } + + // This is a very primitive implementation of the {@code match.call} + // it simply tries to match the named arguments to parameters + // followed by the rest of the arguments. + static LangSXP matchCall(CloSXP definition, LangSXP call) { + var matched = ImmutableList.builder(); + var remaining = new ArrayList(); + var formals = definition.formals(); + + if (formals.size() < call.args().size()) { + throw new IllegalArgumentException("Too many arguments and we do not support ... yet"); + } + + for (var actual : call.args()) { + if (actual.tag() != null) { + matched.add(actual); + formals = formals.remove(actual.tag()); + } else { + remaining.add(actual.value()); + } + } + + for (int i = 0; i < remaining.size(); i++) { + matched.add(new TaggedElem(formals.get(i).tag(), remaining.get(i))); + } + + return SEXPs.lang(call.fun(), SEXPs.list(matched.build())); + } + + // extracts the source reference from the function + // using either the body, if the body is wrapped in a block + // or from function itself + private static Loc functionLoc(CloSXP fun) { + var body = fun.bodyAST(); + + Optional srcRef; + + if (body instanceof LangSXP b && b.funName("{")) { + srcRef = extractSrcRef(body, 0); + } else { + // try to get the srcRef from the function itself + // normally, it would be attached to the `{` + srcRef = fun.getSrcRef(); + } + + return new Loc(body, srcRef.orElse(null)); + } + + /** + * Extracts source reference from the given expression. It uses the {@code srcref} attribute of + * the expression. If the expression is a block, then the {@code srcref} will be a vector of + * srcrefs in which case it will use the given index. + * + * @param expr the expression to get the source reference from + * @param idx the index of the source reference in the vector in the case the expression is a + * block + * @return source code reference or empty if not found + */ + private static Optional extractSrcRef(SEXP expr, int idx) { var attrs = expr.attributes(); if (attrs == null) { - return null; + return Optional.empty(); } var srcref = attrs.get("srcref"); if (srcref == null) { - return null; + return Optional.empty(); } if (srcref instanceof IntSXP i && i.size() >= 6) { - return i; + return Optional.of(i); } else if (srcref instanceof VecSXP v && v.size() >= idx && v.get(idx) instanceof IntSXP i && i.size() >= 6) { - return i; + return Optional.of(i); } else { - return null; + return Optional.empty(); } } + + /** + * Information inlining for an R symbol. + * + * @param name the symbol name + * @param env the environment where the symbol was found + * @param value the value of the symbol or null if no value was found + * @param guard whether the inlining should be guarded which depends on the environment, {@link + * #optimizationLevel} and target. + */ + record InlineInfo(String name, EnvSXP env, @Nullable SEXP value, boolean guard) {} + + /** + * Helper struct for inlining `[[` and the like. Essentially links the `*tmp*` with the + * corresponding code. + */ + record FlattenLHS(LangSXP original, LangSXP temp) {} } diff --git a/src/main/java/org/prlprg/bc/ConstPool.java b/src/main/java/org/prlprg/bc/ConstPool.java index 791d4bd4e..3292e1a1c 100644 --- a/src/main/java/org/prlprg/bc/ConstPool.java +++ b/src/main/java/org/prlprg/bc/ConstPool.java @@ -1,132 +1,41 @@ package org.prlprg.bc; -import com.google.common.base.Objects; import com.google.common.collect.ForwardingList; import com.google.common.collect.ImmutableList; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import java.util.Collection; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import javax.annotation.Nullable; +import edu.umd.cs.findbugs.annotations.Nullable; +import java.util.*; +import java.util.function.Function; import javax.annotation.concurrent.Immutable; import org.prlprg.sexp.*; -import org.prlprg.util.Either; -import org.prlprg.util.Pair; -/** - * A pool (array) of constants. Underneath this is an immutable list, but the elements are only - * accessible with typed integers. - */ -@SuppressFBWarnings( - value = "JCIP_FIELD_ISNT_FINAL_IN_IMMUTABLE_CLASS", - justification = - "The class isn't technically immutable but is after we use the thread-unsafe builder, " - + "so practically we treat it as immutable") +/** A pool (array) of constants. */ @Immutable public final class ConstPool extends ForwardingList { - @Nullable private ImmutableList consts; + private final ImmutableList consts; - private ConstPool() {} + private ConstPool(List consts) { + this.consts = ImmutableList.copyOf(consts); + } @Override protected List delegate() { - if (consts == null) { - throw new IllegalStateException("ConstPool is not yet built"); - } return consts; } /** * Get the element at the given pool index * - * @throws WrongPoolException if the index is for a different pool * @throws IndexOutOfBoundsException if the index is out of bounds + * @throws ClassCastException if the element is not of the expected type */ - public SEXP get(Idx idx) { - if (consts == null) { - throw new IllegalStateException("ConstPool is not yet built"); - } - return consts.get(idx.unwrapIdx(this)); - } - - /** - * Get the element at the given pool index - * - * @throws WrongPoolException if the index is for a different pool - * @throws IndexOutOfBoundsException if the index is out of bounds - */ - public S get(TypedIdx idx) { - if (consts == null) { - throw new IllegalStateException("ConstPool is not yet built"); - } - assert idx.checkType(); - @SuppressWarnings("unchecked") - var res = (S) consts.get(idx.unwrapIdx(this)); - return res; - } - - /** If the SEXP is a constant, returns its index. Otherwise returns null. */ - public @Nullable TypedIdx indexOf(S c) { - if (consts == null) { - throw new IllegalStateException("ConstPool is not yet built"); - } - var i = consts.indexOf(c); - if (i == -1) { - return null; - } - // This is only valid because TypedIdx is covariant, and only accepted because Java erases - // generics. - // The conversion from Class to Class changes the generic. - @SuppressWarnings("unchecked") - var idx = new TypedIdx<>(this, i, (Class) c.getClass()); - assert idx.checkType(); - return idx; - } - - /** Iterate all indices */ - public Iterable indices() { - return () -> - new Iterator<>() { - int i = 0; - - @Override - public boolean hasNext() { - if (consts == null) { - throw new IllegalStateException("ConstPool is not yet built"); - } - return i < consts.size(); - } - - @Override - public Idx next() { - if (!hasNext()) { - throw new IndexOutOfBoundsException(); - } - return new Idx(ConstPool.this, i++); - } - }; - } - - /** - * Create from a constant list (raw GNU-R representation). - * - * @return The pool and a function to create pool indices from raw integers, since that isn't - * ordinarily exposed. - */ - static Pair fromRaw(List consts) throws BcFromRawException { - var builder = new Builder(); - for (var c : consts) { - builder.add(c); - } - - var pool = builder.build(); - return new Pair<>(pool, new MakeIdx(pool)); + public S get(Idx idx) { + var res = consts.get(idx.idx()); + return idx.type().cast(res); } @Override public String toString() { - StringBuilder sb = new StringBuilder("=== CONSTS " + debugId() + " ==="); + StringBuilder sb = new StringBuilder("=== CONSTS ==="); var idx = 0; for (var c : this) { var cStr = c.toString(); @@ -140,13 +49,18 @@ public String toString() { return sb.toString(); } - private String debugId() { - // FIXME: this is bad! - if (consts != null) { - return "@" + hashCode(); - } else { - return "@"; - } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + ConstPool sexps = (ConstPool) o; + return consts.equals(sexps.consts); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), consts); } /** @@ -154,129 +68,107 @@ private String debugId() { * *

It also contains a reference to the owner pool which is checked at runtime for extra safety. */ - public static sealed class Idx permits TypedIdx { - protected final ConstPool pool; - protected final int idx; - - private Idx(ConstPool pool, int idx) { - this.pool = pool; - this.idx = idx; - } - - /** - * Return the underlying index if the given pool is the one this was originally created with. - * - * @throws WrongPoolException if the given pool is not the one this was originally created with - */ - protected int unwrapIdx(ConstPool parent) { - if (parent != pool) { - throw new WrongPoolException(); - } - return idx; - } - + public record Idx(int idx, Class type) { @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof Idx idx1)) return false; - return idx == idx1.idx && Objects.equal(pool, idx1.pool); - } - - @Override - public int hashCode() { - return Objects.hashCode(pool, idx); + public String toString() { + // TODO: add sexp type? + return String.format("%d", idx); } - @Override - public String toString() { - return String.format("%s.%d", pool.debugId(), idx); + public static Idx create(int i, S value) { + @SuppressWarnings("unchecked") + var idx = new Idx<>(i, (Class) value.getCanonicalType()); + return idx; } } /** - * A {@link Idx} (typed index into a bytecode pool) which further checks that the {@link SEXP} is - * of a specific type. + * A builder class for creating constant pools. + * + *

Not synchronized, so don't use from multiple threads. */ - @SuppressFBWarnings( - value = "EQ_DOESNT_OVERRIDE_EQUALS", - justification = - "Idx and TypedIdx only compare `pool` and `index`, we want different types to be equal") - public static final class TypedIdx extends Idx { - private final Class sexpInterface; + public static class Builder { + private final Map index; + private final List values; + + public Builder() { + this(Collections.emptyList()); + } - private TypedIdx(ConstPool pool, int idx, Class sexpInterface) { - super(pool, idx); - if (!SEXP.class.isAssignableFrom(sexpInterface)) { - throw new IllegalArgumentException("sexpInterface must be inherit SEXP: " + sexpInterface); + public Builder(List consts) { + index = new HashMap<>(consts.size()); + values = new ArrayList<>(consts.size()); + + for (var e : consts) { + add(e); } - this.sexpInterface = sexpInterface; } - private boolean checkType() { - // The cast to Idx is required because the TypedIdx version does an `assert`. - return sexpInterface.isInstance(pool.get((Idx) this)); + public Idx add(S c) { + var i = + index.computeIfAbsent( + c, + (ignored) -> { + var x = index.size(); + values.add(c); + return x; + }); + + return Idx.create(i, c); } - } - static final class MakeIdx { - private final ConstPool pool; + /** + * Finish building the pool. + * + * @return The pool. + */ + public ConstPool build() { + return new ConstPool(ImmutableList.copyOf(values)); + } - private MakeIdx(ConstPool pool) { - this.pool = pool; + public Idx index(int i) { + return Idx.create(i, values.get(i)); } - TypedIdx lang(int i) { - return of(i, LangSXP.class); + private Idx index(int i, Class type) { + var value = values.get(i); + if (type.isInstance(value)) { + return Idx.create(i, type.cast(value)); + } else { + throw new IllegalArgumentException("Expected " + type + ", but got " + value.getClass()); + } } - TypedIdx sym(int i) { - return of(i, RegSymSXP.class); + public Idx indexLang(int i) { + return index(i, LangSXP.class); } - @Nullable TypedIdx symOrNil(int i) { - return tryOf(i, RegSymSXP.class); + public Idx indexSym(int i) { + return index(i, RegSymSXP.class); } - @Nullable TypedIdx langOrNegative(int i) { - return i >= 0 ? tryOf(i, LangSXP.class) : null; + // FIXME: do we need this? + public @Nullable Idx indexLangOrNilIfNegative(int i) { + return i >= 0 ? orNil(i, LangSXP.class) : null; } - @Nullable TypedIdx intOrOther(int i) { - return tryOf(i, IntSXP.class); + public @Nullable Idx indexStrOrSymOrNil(int i) { + return orNil(i, StrOrRegSymSXP.class); } - @Nullable TypedIdx strOrSymOrNil(int i) { - var asStrOrSymbol = tryOf(i, StrOrRegSymSXP.class); - if (asStrOrSymbol != null) { - return asStrOrSymbol; - } - var asNil = tryOf(i, NilSXP.class); - if (asNil != null) { - return null; - } else { - throw new IllegalArgumentException( - "Expected StrSXP, SymSXP or NilSXP, got " + pool.get(new Idx(pool, i))); - } + public @Nullable Idx indexStrOrNil(int i) { + return orNil(i, StrSXP.class); } - @Nullable Either, TypedIdx> strOrNilOrOther(int i) { - var asSymbol = tryOf(i, StrSXP.class); - if (asSymbol != null) { - return Either.left(asSymbol); - } - var asNil = tryOf(i, NilSXP.class); - if (asNil != null) { - return Either.right(asNil); - } else { - return null; - } + public @Nullable Idx indexIntOrNil(int i) { + return orNil(i, IntSXP.class); } - TypedIdx formalsBodyAndMaybeSrcRef(int i) { - var idx = of(i, VecSXP.class); + public Idx indexClosure(int i) { + var idx = index(i, VecSXP.class); - // Check vector shape - var vec = pool.get(idx); + // check vector shape + var vec = (VecSXP) values.get(i); if (vec.size() != 2 && vec.size() != 3) { throw new IllegalArgumentException( "Malformed formals/body/srcref vector, expected length 2 or 3, got " + vec); @@ -290,77 +182,21 @@ TypedIdx formalsBodyAndMaybeSrcRef(int i) { return idx; } - Idx any(int i) { - return new Idx(pool, i); - } - - private TypedIdx of(int i, Class sexpInterface) { - var idx = new TypedIdx<>(pool, i, sexpInterface); - if (!idx.checkType()) { - // The cast to Idx is required because the TypedIdx version does an `assert`. - throw new IllegalArgumentException( - "Expected " + sexpInterface.getSimpleName() + ", got " + pool.get((Idx) idx)); - } - return idx; - } - - private @Nullable TypedIdx tryOf(int i, Class sexpInterface) { - var idx = new TypedIdx<>(pool, i, sexpInterface); - if (!idx.checkType()) { + public @Nullable Idx orNil(int i, Class clazz) { + var value = values.get(i); + if (clazz.isInstance(value)) { + return Idx.create(i, clazz.cast(value)); + } else if (value instanceof NilSXP) { return null; + } else { + throw new IllegalArgumentException( + "Expected " + clazz + " or NilSXP, but got " + value.getClass()); } - return idx; - } - } - - /** Caused by trying to subscript one bytecode pool with an index from another. */ - public static class WrongPoolException extends RuntimeException { - private WrongPoolException() { - super("Wrong pool"); - } - } - - /** - * A builder class for creating constant pools. - * - *

Not synchronized, so don't use from multiple threads. - */ - public static class Builder { - private final ConstPool pool = new ConstPool(); - private final LinkedHashMap consts = new LinkedHashMap<>(); - - /** Create a new builder. */ - public Builder() {} - - /** Append a constant and return the index. */ - public TypedIdx add(S c) { - // This only works because TypedIdx is covariant the and generic gets erased. - // We actually cast TypedIdx into TypedIdx. - @SuppressWarnings("unchecked") - var idx = - (TypedIdx) - consts.computeIfAbsent( - c, (ignored) -> new TypedIdx<>(pool, consts.size(), (Class) c.getClass())); - return idx; - } - - /** Append instructions. */ - public ImmutableList> addAll(Collection c) { - var builder = ImmutableList.>builder(); - for (var e : c) { - builder.add(add(e)); - } - return builder.build(); } - /** - * Finish building the pool. - * - * @return The pool. - */ - public ConstPool build() { - pool.consts = ImmutableList.copyOf(consts.sequencedKeySet()); - return pool; + @SuppressWarnings("unchecked") + public void reset(Idx idx, Function fun) { + values.set(idx.idx(), fun.apply((S) values.get(idx.idx()))); } } } diff --git a/src/main/java/org/prlprg/bc/ConstantFolding.java b/src/main/java/org/prlprg/bc/ConstantFolding.java new file mode 100644 index 000000000..a74443020 --- /dev/null +++ b/src/main/java/org/prlprg/bc/ConstantFolding.java @@ -0,0 +1,365 @@ +package org.prlprg.bc; + +import com.google.common.collect.ImmutableList; +import com.google.common.math.DoubleMath; +import com.google.common.primitives.ImmutableDoubleArray; +import com.google.common.primitives.ImmutableIntArray; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; +import org.prlprg.primitive.Complex; +import org.prlprg.primitive.Logical; +import org.prlprg.sexp.*; +import org.prlprg.util.Arithmetic; + +/** + * Implements constant folding for some of the R functions. All functions are should have the same + * signature: {@code Optional f(List args)} + * + *

They shall be called from Compiler and can assume that each argument is one of the {@link + * Compiler#ALLOWED_FOLDABLE_MODES}. + * + *

The main reason it is extracted from the Compiler is to make have all the folding code in one + * place without making Compiler too big. + */ +public final class ConstantFolding { + private ConstantFolding() {} + + public static Optional add(List args) { + return math2(Arithmetic.Operation.ADD, args); + } + + public static Optional div(List args) { + return math2(Arithmetic.Operation.DIV, args); + } + + public static Optional log(List args) { + return doubleMath1(args, Math::log); + } + + public static Optional log2(List args) { + return doubleMath1(args, DoubleMath::log2); + } + + public static Optional minus(List args) { + return math1(Arithmetic.Operation.MINUS, args); + } + + public static Optional mul(List args) { + return math2(Arithmetic.Operation.MUL, args); + } + + public static Optional plus(List args) { + return math1(Arithmetic.Operation.PLUS, args); + } + + public static Optional pow(List args) { + if (args.size() != 2) { + return Optional.empty(); + } + + if (Coercions.commonType(args) == SEXPType.INT) { + return math2(Arithmetic.Operation.POW, args, Arithmetic.DOUBLE); + } else { + return math2(Arithmetic.Operation.POW, args); + } + } + + public static Optional rep(List args) { + if (args.size() != 2) { + return Optional.empty(); + } + + if (!(args.getFirst() instanceof VectorSXP x)) { + return Optional.empty(); + } + + if (!(args.getLast() instanceof NumericSXP times)) { + return Optional.empty(); + } + + if (times.size() == 1) { + return Optional.of(doRep1(x, times.asInt(0))); + } else if (times.size() == x.size()) { + return Optional.of(doRep2(x, times)); + } else { + return Optional.empty(); + } + } + + public static Optional seqInt(List args) { + if (args.size() != 3) { + return Optional.empty(); + } + + if (!(args.get(0) instanceof NumericSXP fromV) || fromV.size() != 1) { + return Optional.empty(); + } + + if (!(args.get(1) instanceof NumericSXP toV) || toV.size() != 1) { + return Optional.empty(); + } + + if (!(args.get(2) instanceof NumericSXP byV) || byV.size() != 1) { + return Optional.empty(); + } + + var type = Coercions.commonType(fromV.type(), toV.type(), byV.type()); + return switch (type) { + case INT -> { + var from = fromV.asInt(0); + var to = toV.asInt(0); + var by = byV.asInt(0); + var ans = Arithmetic.INTEGER.createResult((to - from) / by + 1); + for (int i = 0, x = from; x <= to; i++, x += by) { + ans[i] = x; + } + yield Optional.of(SEXPs.integer(ans)); + } + case REAL -> { + var from = fromV.asReal(0); + var to = toV.asReal(0); + var by = byV.asReal(0); + var size = (int) ((to - from) / by) + 1; + var ans = Arithmetic.DOUBLE.createResult(size); + var x = from; + for (int i = 0; i < size; i++) { + ans[i] = x; + x += by; + } + yield Optional.of(SEXPs.real(ans)); + } + default -> Optional.empty(); + }; + } + + public static Optional sqrt(List args) { + return doubleMath1(args, Math::sqrt); + } + + public static Optional sub(List args) { + return math2(Arithmetic.Operation.SUB, args); + } + + private static SEXP doMath1(Arithmetic.Operation op, VectorSXP va, Arithmetic arith) { + var ax = arith.fromSEXP(va); + return arith.toSEXP(doMath1(ax, arith::createResult, arith.getUnaryFun(op))); + } + + private static R[] doMath1(T[] ax, Function createResult, Function f) { + var l = ax.length; + if (l == 0) { + return createResult.apply(0); + } + + var ans = createResult.apply(l); + for (int i = 0; i < l; i++) { + ans[i] = f.apply(ax[i]); + } + + return ans; + } + + /** + * Implements the binary operation for two vectors using R semantics of recycling vectors. + * + * @param ax the left hand side operand + * @param bx the right hand side operand + * @param createResult the function to create the result vector of the corresponding type + * @param f the binary operation to apply + * @return {@code f} applied to elements of {@code ax} and {@code bx} based on R recycling rules. + * @param the type of operands + * @param the type of the result + */ + private static R[] doMath2( + T[] ax, T[] bx, Function createResult, BiFunction f) { + var la = ax.length; + var lb = bx.length; + + if (la == 0 || lb == 0) { + return createResult.apply(0); + } + + var l = Math.max(la, lb); + var ans = createResult.apply(l); + + for (int i = 0, ia = 0, ib = 0; + i < l; + ia = (++ia == la) ? 0 : ia, ib = (++ib == lb) ? 0 : ib, i++) { + + var a = ax[ia]; + var b = bx[ib]; + + ans[i] = f.apply(a, b); + } + + return ans; + } + + private static SEXP doRep1(VectorSXP x, int times) { + var res = new ImmutableList.Builder(); + + for (int i = 0; i < times; i++) { + res.addAll(x); + } + + return SEXPs.vector(x.type(), res.build()); + } + + private static SEXP doRep2(VectorSXP xs, NumericSXP times) { + var res = new ImmutableList.Builder(); + + for (int j = 0; j < times.size(); j++) { + var n = times.asInt(j); + var x = xs.get(j); + for (int i = 0; i < n; i++) { + res.add(x); + } + } + + return SEXPs.vector(xs.type(), res.build()); + } + + private static Optional doubleMath1(List args, Function f) { + if (args.size() != 1) { + return Optional.empty(); + } + if (!(args.getFirst() instanceof NumericSXP n)) { + return Optional.empty(); + } + + var res = Arrays.copyOf(n.coerceToReals(), n.size()); + for (var i = 0; i < res.length; i++) { + res[i] = f.apply(res[i]); + } + + return Optional.of(SEXPs.real(res)); + } + + private static Optional math1(Arithmetic.Operation op, List args) { + if (args.size() != 1) { + return Optional.empty(); + } + + if (!(args.getFirst() instanceof VectorSXP va)) { + return Optional.empty(); + } + + return Arithmetic.forType(va.type()).map(arith -> doMath1(op, va, arith)); + } + + private static Optional math2(Arithmetic.Operation op, List args) { + return Arithmetic.forType(args).flatMap(arith -> math2(op, args, arith)); + } + + private static Optional math2( + Arithmetic.Operation op, List args, Arithmetic arith) { + if (args.size() != 2) { + return Optional.empty(); + } + + if (!(args.get(0) instanceof VectorSXP va)) { + return Optional.empty(); + } + + if (!(args.get(1) instanceof VectorSXP vb)) { + return Optional.empty(); + } + + var ax = arith.fromSEXP(va); + var bx = arith.fromSEXP(vb); + var ans = arith.toSEXP(doMath2(ax, bx, arith::createResult, arith.getBinaryFun(op))); + + return Optional.of(ans); + } + + public static Optional c(List args) { + if (args.isEmpty()) { + return Optional.of(SEXPs.NULL); + } + + var type = args.getFirst().type(); + var capacity = 0; + + // compute the target type, the SEXPTYPE is ordered in a way that we can just take the max + for (var arg : args) { + type = Coercions.commonType(type, arg.type()); + capacity += ((VectorSXP) arg).size(); + } + + // this is safe as we have proved that all args are VectorSXP + @SuppressWarnings("unchecked") + var vecArgs = (List>) args; + + Optional vals = + switch (type) { + case STR -> { + var res = new ImmutableList.Builder(); + vecArgs.forEach(x -> res.add(x.coerceToStrings())); + yield Optional.of(SEXPs.string(res.build())); + } + case REAL -> { + var res = ImmutableDoubleArray.builder(capacity); + vecArgs.forEach(x -> Arrays.stream(x.coerceToReals()).forEach(res::add)); + yield Optional.of(SEXPs.real(res.build())); + } + case INT -> { + var res = ImmutableIntArray.builder(capacity); + vecArgs.forEach(x -> Arrays.stream(x.coerceToInts()).forEach(res::add)); + yield Optional.of(SEXPs.integer(res.build())); + } + case LGL -> { + var res = new ImmutableList.Builder(); + vecArgs.forEach(x -> res.add(x.coerceToLogicals())); + yield Optional.of(SEXPs.logical(res.build())); + } + case CPLX -> { + var res = new ImmutableList.Builder(); + vecArgs.forEach(x -> res.add(x.coerceToComplexes())); + yield Optional.of(SEXPs.complex(res.build())); + } + default -> Optional.empty(); + }; + + return vals.map( + x -> { + var names = + args.stream() + .map(SEXP::names) + .reduce( + new ArrayList<>(), + (acc, y) -> { + acc.addAll(y); + return acc; + }); + return x.withNames(names); + }); + } + + public static Optional colon(List args) { + if (args.size() != 2) { + return Optional.empty(); + } + + if (!(args.get(0) instanceof NumericSXP min) || min.size() != 1) { + return Optional.empty(); + } + + if (!(args.get(1) instanceof NumericSXP max) || min.size() != 1) { + return Optional.empty(); + } + + var imin = min.asInt(0); + var imax = max.asInt(0); + var ints = ImmutableIntArray.builder(Math.abs(imax - imin)); + var inc = imin < imax ? 1 : -1; + for (var i = imin; i != imax + inc; i += inc) { + ints.add(i); + } + + return Optional.of(SEXPs.integer(ints.build())); + } +} diff --git a/src/main/java/org/prlprg/bc/Context.java b/src/main/java/org/prlprg/bc/Context.java index 4d614cce4..5441f1fe5 100644 --- a/src/main/java/org/prlprg/bc/Context.java +++ b/src/main/java/org/prlprg/bc/Context.java @@ -5,15 +5,22 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import javax.annotation.Nullable; import org.prlprg.sexp.*; import org.prlprg.util.Pair; +record Loop(BcLabel start, BcLabel end, boolean gotoOK) { + public Loop gotoNotOK() { + return new Loop(start, end, false); + } +} + public class Context { private final boolean topLevel; private final boolean tailCall; private final boolean returnJump; private final EnvSXP environment; - private final Loop loop; + private final @Nullable Loop loop; /** * @param topLevel {@code true} for top level expressions, {@code false} otherwise (e.g., @@ -23,9 +30,14 @@ public class Context { * is contained in a loop. * @param returnJump {@code true} indicated that the call to return needs {@code RETURNJMP}. * @param environment the compilation environment. - * @param loop + * @param loop the loop context or {@code null} if the context does not contain a loop. */ - Context(boolean topLevel, boolean tailCall, boolean returnJump, EnvSXP environment, Loop loop) { + Context( + boolean topLevel, + boolean tailCall, + boolean returnJump, + EnvSXP environment, + @Nullable Loop loop) { this.topLevel = topLevel; this.tailCall = tailCall; this.returnJump = returnJump; @@ -33,16 +45,20 @@ public class Context { this.loop = loop; } + public static Context topLevelContext(EnvSXP env) { + return new Context(true, true, false, env, null); + } + public static Context functionContext(CloSXP fun) { var env = new UserEnvSXP(fun.env()); - var ctx = new Context(false, true, false, env, new Loop.NotInLoop()); + var ctx = topLevelContext(env); - return ctx.functionContext(fun.formals(), fun.body()); + return ctx.functionContext(fun.formals(), fun.bodyAST()); } public Context functionContext(ListSXP formals, SEXP body) { var env = new UserEnvSXP(environment); - var ctx = new Context(false, true, false, env, loop); + var ctx = new Context(true, true, false, env, loop); formals.names().forEach(x -> env.set(x, SEXPs.UNBOUND_VALUE)); for (var v : formals.values()) { @@ -65,11 +81,12 @@ public Context promiseContext() { // The promise context also sets returnJump since a return call that is triggered by forcing a // promise // requires a longjmp to return from the appropriate function. - return new Context(false, true, true, environment, loop.gotoNotOK()); + return new Context(false, true, true, environment, loop != null ? loop.gotoNotOK() : null); } public Context argContext() { - return new Context(false, false, returnJump, environment, loop.gotoNotOK()); + return new Context( + false, false, returnJump, environment, loop != null ? loop.gotoNotOK() : null); } public Context loopContext(BcLabel start, BcLabel end) { @@ -79,11 +96,11 @@ public Context loopContext(BcLabel start, BcLabel end) { nctx.tailCall, nctx.returnJump, nctx.environment, - new Loop.InLoop(start, end, true)); + new Loop(start, end, true)); } public boolean isBaseVersion(String name) { - return environment.find(name).map(b -> b.first() instanceof BaseEnvSXP).orElse(false); + return environment.find(name).map(b -> b.first().isBase()).orElse(false); } public Optional> resolve(String name) { @@ -105,11 +122,11 @@ public Set findLocals(SEXP e) { var elem = todo.removeFirst(); if (elem instanceof LangSXP l && l.fun() instanceof RegSymSXP fun) { var args = l.args().values(); - var local = + Optional local = switch (fun.name()) { case "=", "<-" -> { todo.addAll(args.subList(1, args.size())); - yield getAssignedVar(l); + yield getAssignVar(l); } case "for" -> { todo.addAll(args.subList(1, args.size())); @@ -126,24 +143,24 @@ public Set findLocals(SEXP e) { if (args.size() == 2 && args.getFirst() instanceof StrOrRegSymSXP v) { yield v.reifyString(); } else { - yield Optional.empty(); + yield Optional.empty(); } } - case "function" -> { - // Variables defined within local functions created by function expressions do not - // shadow globals - // within the containing expression and therefore function expressions do not - // contribute any new - // local variables. - yield Optional.empty(); - } + case "function" -> + // Variables defined within local functions created by function expressions do not + // shadow globals + // within the containing expression and therefore function expressions do not + // contribute any new + // local variables. + Optional.empty(); + case "~", "expression", "quote" -> { // they do not evaluate their arguments and so do not contribute new local // variables. if (shadowed.contains(fun.name()) || locals.contains(fun.name())) { todo.addAll(args); } - yield Optional.empty(); + yield Optional.empty(); } case "local" -> { // local calls without an environment argument create a new environment @@ -156,11 +173,11 @@ public Set findLocals(SEXP e) { || args.size() != 1) { todo.addAll(args); } - yield Optional.empty(); + yield Optional.empty(); } default -> { todo.addAll(args); - yield Optional.empty(); + yield Optional.empty(); } }; local.ifPresent(locals::add); @@ -172,30 +189,45 @@ public Set findLocals(SEXP e) { return locals; } - private static Optional getAssignedVar(LangSXP l) { - var v = l.arg(0).value(); + public static Optional getAssignVar(LangSXP call) { + var v = call.arg(0); if (v == SEXPs.MISSING_ARG) { - throw new CompilerException("Bad assignment: " + l); + throw new CompilerException("Bad assignment: " + call); } else if (v instanceof StrOrRegSymSXP s) { return s.reifyString(); } else { - if (l.args().isEmpty()) { - throw new CompilerException("Bad assignment: " + l); + if (call.args().isEmpty()) { + throw new CompilerException("Bad assignment: " + call); } - switch (l.arg(0).value()) { + switch (call.arg(0)) { case LangSXP ll -> { - return getAssignedVar(ll); + return getAssignVar(ll); } case StrOrRegSymSXP s -> { return s.reifyString(); } - default -> { - throw new CompilerException("Bad assignment: " + l); - } + default -> throw new CompilerException("Bad assignment: " + call); } } } + public static Optional getAssignFun(SEXP fun) { + if (fun instanceof RegSymSXP s) { + return Optional.of(SEXPs.symbol(s.name() + "<-")); + } else + // >> check for and handle foo::bar(x) <- y assignments here + if (fun instanceof LangSXP call + && call.args().size() == 2 + && (call.funName("::") || call.funName(":::")) + && call.arg(0) instanceof RegSymSXP + && call.arg(1) instanceof RegSymSXP) { + var args = call.args().set(1, null, SEXPs.symbol(call.arg(1) + "<-")); + return Optional.of(SEXPs.lang(call.fun(), args)); + } else { + return Optional.empty(); + } + } + public boolean isTailCall() { return tailCall; } @@ -204,7 +236,20 @@ public boolean isReturnJump() { return returnJump; } - public Loop loop() { + @SuppressWarnings("BooleanMethodIsAlwaysInverted") + public boolean isTopLevel() { + return topLevel; + } + + @Nullable Loop loop() { return loop; } + + public Optional findFunDef(String name) { + return environment + .find(name) + .map(Pair::second) + .filter(CloSXP.class::isInstance) + .map(CloSXP.class::cast); + } } diff --git a/src/main/java/org/prlprg/bc/Loop.java b/src/main/java/org/prlprg/bc/Loop.java deleted file mode 100644 index 1d5d5f220..000000000 --- a/src/main/java/org/prlprg/bc/Loop.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.prlprg.bc; - -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; - -@SuppressFBWarnings({"EI_EXPOSE_REP2", "EI_EXPOSE_REP"}) -public sealed interface Loop { - record InLoop(BcLabel start, BcLabel end, boolean gotoOK) implements Loop { - @Override - public Loop gotoNotOK() { - return new InLoop(start, end, false); - } - } - - record NotInLoop() implements Loop { - @Override - public Loop gotoNotOK() { - return this; - } - } - - Loop gotoNotOK(); -} diff --git a/src/main/java/org/prlprg/primitive/Complex.java b/src/main/java/org/prlprg/primitive/Complex.java index cc821e7f7..ebd292930 100644 --- a/src/main/java/org/prlprg/primitive/Complex.java +++ b/src/main/java/org/prlprg/primitive/Complex.java @@ -1,9 +1,53 @@ package org.prlprg.primitive; /** Complex number */ -public record Complex(double real, double imaginary) { +public record Complex(double real, double imag) { + public static Complex fromReal(double x) { + return new Complex(x, 0); + } + @Override public String toString() { - return real + "+" + imaginary + "i"; + return real + (imag >= 0 ? "+" : "") + imag + "i"; + } + + public Complex add(Complex other) { + return new Complex(real + other.real, imag + other.imag); + } + + public Complex sub(Complex other) { + return new Complex(real - other.real, imag - other.imag); + } + + public Complex mul(Complex other) { + return new Complex( + real * other.real - imag * other.imag, real * other.imag + imag * other.real); + } + + public Complex div(Complex other) { + var D = other.real * other.real + other.imag * other.imag; + return new Complex( + (real * other.real + imag * other.imag) / D, (imag * other.real - real * other.imag) / D); + } + + public Complex pow(Complex that) { + double r = Math.hypot(this.real, this.imag); + double i = Math.atan2(this.imag, this.real); + double theta = i * that.real; + double rho; + + if (that.imag == 0) { + rho = Math.pow(r, that.real); + } else { + r = Math.log(r); + theta += r * that.imag; + rho = Math.exp(r * that.real - i * that.imag); + } + + return new Complex(rho * Math.cos(theta), rho * Math.sin(theta)); + } + + public Complex minus() { + return new Complex(-real, -imag); } } diff --git a/src/main/java/org/prlprg/primitive/Constants.java b/src/main/java/org/prlprg/primitive/Constants.java index 86e89ff05..afe545f4a 100644 --- a/src/main/java/org/prlprg/primitive/Constants.java +++ b/src/main/java/org/prlprg/primitive/Constants.java @@ -1,6 +1,7 @@ package org.prlprg.primitive; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.util.Set; import org.prlprg.sexp.SEXPs; /** Constants for R runtime primitives. Other constants are in {@link SEXPs}. */ @@ -20,6 +21,9 @@ public final class Constants { /** The actual value of an {@code NA} element of a real vector in GNU-R. */ public static final double NA_REAL = Double.NaN; + /** The actual value of an {@code NA} element of a complex vector in GNU-R. */ + public static final Complex NA_COMPLEX = new Complex(NA_REAL, NA_REAL); + /** * The "NA string": a unique string that is compared for identity which represents NA values. * @@ -29,11 +33,11 @@ public final class Constants { @SuppressWarnings("StringOperationCanBeSimplified") public static final String NA_STRING = new String("!!!NA_STRING!!!"); - /** Check if a string is the NA string. */ - @SuppressWarnings("StringEquality") - public static boolean isNaString(String s) { - return s == NA_STRING; - } + /** String representations of true values. (cf. StringTrue and truenames in util.c) */ + public static final Set TRUE_NAMES = Set.of("T", "True", "TRUE", "true"); + + /** String representations of false values. (cf. StringFalse and falsenames in util.c) */ + public static final Set FALSE_NAMES = Set.of("F", "False", "FALSE", "false"); private Constants() {} } diff --git a/src/main/java/org/prlprg/rds/GNURByteCodeDecoderFactory.java b/src/main/java/org/prlprg/rds/GNURByteCodeDecoderFactory.java new file mode 100644 index 000000000..3fea7eb17 --- /dev/null +++ b/src/main/java/org/prlprg/rds/GNURByteCodeDecoderFactory.java @@ -0,0 +1,391 @@ +package org.prlprg.rds; + +import com.google.common.primitives.ImmutableIntArray; +import java.util.List; +import org.prlprg.bc.*; +import org.prlprg.sexp.IntSXP; +import org.prlprg.sexp.SEXP; +import org.prlprg.sexp.SEXPs; + +class GNURByteCodeDecoderFactory { + private final ImmutableIntArray byteCode; + private final ConstPool.Builder cpb; + private final BcCode.Builder cbb; + private final LabelMapping labelMapping; + int curr; + + GNURByteCodeDecoderFactory(ImmutableIntArray byteCode, List consts) { + this.byteCode = byteCode; + + cpb = new ConstPool.Builder(consts); + cbb = new BcCode.Builder(); + labelMapping = LabelMapping.from(byteCode); + + curr = 1; + } + + public Bc create() { + var code = buildCode(); + var pool = cpb.build(); + return new Bc(code, pool); + } + + private BcCode buildCode() { + if (byteCode.isEmpty()) { + throw new IllegalArgumentException("Bytecode is empty, needs at least version number"); + } + if (byteCode.get(0) != Bc.R_BC_VERSION) { + throw new IllegalArgumentException("Unsupported bytecode version: " + byteCode.get(0)); + } + + int sanityCheckJ = 0; + while (curr < byteCode.length()) { + try { + var instr = decode(); + + cbb.add(instr); + sanityCheckJ++; + + // FIXME: too many exceptions + try { + var sanityCheckJFromI = labelMapping.make(curr).getTarget(); + if (sanityCheckJFromI != sanityCheckJ) { + throw new AssertionError( + "expected target offset " + sanityCheckJ + ", got " + sanityCheckJFromI); + } + } catch (IllegalArgumentException | AssertionError e) { + throw new AssertionError( + "BcInstrs.fromRaw and BcInstrs.sizeFromRaw are out of sync, at instruction " + instr, + e); + } + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "malformed bytecode at " + curr + "\nBytecode up to this point: " + cbb.build(), e); + } + } + + return cbb.build(); + } + + BcInstr decode() { + BcOp op; + try { + op = BcOp.valueOf(byteCode.get(curr++)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "invalid opcode (instruction) at " + byteCode.get(curr - 1)); + } + + try { + return switch (op) { + case BCMISMATCH -> + throw new IllegalArgumentException("invalid opcode " + BcOp.BCMISMATCH.value()); + case RETURN -> new BcInstr.Return(); + case GOTO -> new BcInstr.Goto(labelMapping.make(byteCode.get(curr++))); + case BRIFNOT -> + new BcInstr.BrIfNot( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case POP -> new BcInstr.Pop(); + case DUP -> new BcInstr.Dup(); + case PRINTVALUE -> new BcInstr.PrintValue(); + case STARTLOOPCNTXT -> + new BcInstr.StartLoopCntxt( + byteCode.get(curr++) != 0, labelMapping.make(byteCode.get(curr++))); + case ENDLOOPCNTXT -> new BcInstr.EndLoopCntxt(byteCode.get(curr++) != 0); + case DOLOOPNEXT -> new BcInstr.DoLoopNext(); + case DOLOOPBREAK -> new BcInstr.DoLoopBreak(); + case STARTFOR -> + new BcInstr.StartFor( + cpb.indexLang(byteCode.get(curr++)), + cpb.indexSym(byteCode.get(curr++)), + labelMapping.make(byteCode.get(curr++))); + case STEPFOR -> new BcInstr.StepFor(labelMapping.make(byteCode.get(curr++))); + case ENDFOR -> new BcInstr.EndFor(); + case SETLOOPVAL -> new BcInstr.SetLoopVal(); + case INVISIBLE -> new BcInstr.Invisible(); + case LDCONST -> new BcInstr.LdConst(cpb.index(byteCode.get(curr++))); + case LDNULL -> new BcInstr.LdNull(); + case LDTRUE -> new BcInstr.LdTrue(); + case LDFALSE -> new BcInstr.LdFalse(); + case GETVAR -> new BcInstr.GetVar(cpb.indexSym(byteCode.get(curr++))); + case DDVAL -> new BcInstr.DdVal(cpb.indexSym(byteCode.get(curr++))); + case SETVAR -> new BcInstr.SetVar(cpb.indexSym(byteCode.get(curr++))); + case GETFUN -> new BcInstr.GetFun(cpb.indexSym(byteCode.get(curr++))); + case GETGLOBFUN -> new BcInstr.GetGlobFun(cpb.indexSym(byteCode.get(curr++))); + case GETSYMFUN -> new BcInstr.GetSymFun(cpb.indexSym(byteCode.get(curr++))); + case GETBUILTIN -> new BcInstr.GetBuiltin(cpb.indexSym(byteCode.get(curr++))); + case GETINTLBUILTIN -> new BcInstr.GetIntlBuiltin(cpb.indexSym(byteCode.get(curr++))); + case CHECKFUN -> new BcInstr.CheckFun(); + case MAKEPROM -> new BcInstr.MakeProm(cpb.index(byteCode.get(curr++))); + case DOMISSING -> new BcInstr.DoMissing(); + case SETTAG -> new BcInstr.SetTag(cpb.indexStrOrSymOrNil(byteCode.get(curr++))); + case DODOTS -> new BcInstr.DoDots(); + case PUSHARG -> new BcInstr.PushArg(); + case PUSHCONSTARG -> new BcInstr.PushConstArg(cpb.index(byteCode.get(curr++))); + case PUSHNULLARG -> new BcInstr.PushNullArg(); + case PUSHTRUEARG -> new BcInstr.PushTrueArg(); + case PUSHFALSEARG -> new BcInstr.PushFalseArg(); + case CALL -> new BcInstr.Call(cpb.indexLang(byteCode.get(curr++))); + case CALLBUILTIN -> new BcInstr.CallBuiltin(cpb.indexLang(byteCode.get(curr++))); + case CALLSPECIAL -> new BcInstr.CallSpecial(cpb.indexLang(byteCode.get(curr++))); + case MAKECLOSURE -> new BcInstr.MakeClosure(cpb.indexClosure(byteCode.get(curr++))); + case UMINUS -> new BcInstr.UMinus(cpb.indexLang(byteCode.get(curr++))); + case UPLUS -> new BcInstr.UPlus(cpb.indexLang(byteCode.get(curr++))); + case ADD -> new BcInstr.Add(cpb.indexLang(byteCode.get(curr++))); + case SUB -> new BcInstr.Sub(cpb.indexLang(byteCode.get(curr++))); + case MUL -> new BcInstr.Mul(cpb.indexLang(byteCode.get(curr++))); + case DIV -> new BcInstr.Div(cpb.indexLang(byteCode.get(curr++))); + case EXPT -> new BcInstr.Expt(cpb.indexLang(byteCode.get(curr++))); + case SQRT -> new BcInstr.Sqrt(cpb.indexLang(byteCode.get(curr++))); + case EXP -> new BcInstr.Exp(cpb.indexLang(byteCode.get(curr++))); + case EQ -> new BcInstr.Eq(cpb.indexLang(byteCode.get(curr++))); + case NE -> new BcInstr.Ne(cpb.indexLang(byteCode.get(curr++))); + case LT -> new BcInstr.Lt(cpb.indexLang(byteCode.get(curr++))); + case LE -> new BcInstr.Le(cpb.indexLang(byteCode.get(curr++))); + case GE -> new BcInstr.Ge(cpb.indexLang(byteCode.get(curr++))); + case GT -> new BcInstr.Gt(cpb.indexLang(byteCode.get(curr++))); + case AND -> new BcInstr.And(cpb.indexLang(byteCode.get(curr++))); + case OR -> new BcInstr.Or(cpb.indexLang(byteCode.get(curr++))); + case NOT -> new BcInstr.Not(cpb.indexLang(byteCode.get(curr++))); + case DOTSERR -> new BcInstr.DotsErr(); + case STARTASSIGN -> new BcInstr.StartAssign(cpb.indexSym(byteCode.get(curr++))); + case ENDASSIGN -> new BcInstr.EndAssign(cpb.indexSym(byteCode.get(curr++))); + case STARTSUBSET -> + new BcInstr.StartSubset( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case DFLTSUBSET -> new BcInstr.DfltSubset(); + case STARTSUBASSIGN -> + new BcInstr.StartSubassign( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case DFLTSUBASSIGN -> new BcInstr.DfltSubassign(); + case STARTC -> + new BcInstr.StartC( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case DFLTC -> new BcInstr.DfltC(); + case STARTSUBSET2 -> + new BcInstr.StartSubset2( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case DFLTSUBSET2 -> new BcInstr.DfltSubset2(); + case STARTSUBASSIGN2 -> + new BcInstr.StartSubassign2( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case DFLTSUBASSIGN2 -> new BcInstr.DfltSubassign2(); + case DOLLAR -> + new BcInstr.Dollar( + cpb.indexLang(byteCode.get(curr++)), cpb.indexSym(byteCode.get(curr++))); + case DOLLARGETS -> + new BcInstr.DollarGets( + cpb.indexLang(byteCode.get(curr++)), cpb.indexSym(byteCode.get(curr++))); + case ISNULL -> new BcInstr.IsNull(); + case ISLOGICAL -> new BcInstr.IsLogical(); + case ISINTEGER -> new BcInstr.IsInteger(); + case ISDOUBLE -> new BcInstr.IsDouble(); + case ISCOMPLEX -> new BcInstr.IsComplex(); + case ISCHARACTER -> new BcInstr.IsCharacter(); + case ISSYMBOL -> new BcInstr.IsSymbol(); + case ISOBJECT -> new BcInstr.IsObject(); + case ISNUMERIC -> new BcInstr.IsNumeric(); + case VECSUBSET -> new BcInstr.VecSubset(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case MATSUBSET -> new BcInstr.MatSubset(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case VECSUBASSIGN -> + new BcInstr.VecSubassign(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case MATSUBASSIGN -> + new BcInstr.MatSubassign(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case AND1ST -> + new BcInstr.And1st( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case AND2ND -> new BcInstr.And2nd(cpb.indexLang(byteCode.get(curr++))); + case OR1ST -> + new BcInstr.Or1st( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case OR2ND -> new BcInstr.Or2nd(cpb.indexLang(byteCode.get(curr++))); + case GETVAR_MISSOK -> new BcInstr.GetVarMissOk(cpb.indexSym(byteCode.get(curr++))); + case DDVAL_MISSOK -> new BcInstr.DdValMissOk(cpb.indexSym(byteCode.get(curr++))); + case VISIBLE -> new BcInstr.Visible(); + case SETVAR2 -> new BcInstr.SetVar2(cpb.indexSym(byteCode.get(curr++))); + case STARTASSIGN2 -> new BcInstr.StartAssign2(cpb.indexSym(byteCode.get(curr++))); + case ENDASSIGN2 -> new BcInstr.EndAssign2(cpb.indexSym(byteCode.get(curr++))); + case SETTER_CALL -> + new BcInstr.SetterCall( + cpb.indexLang(byteCode.get(curr++)), cpb.index(byteCode.get(curr++))); + case GETTER_CALL -> new BcInstr.GetterCall(cpb.indexLang(byteCode.get(curr++))); + case SWAP -> new BcInstr.SpecialSwap(); + case DUP2ND -> new BcInstr.Dup2nd(); + case SWITCH -> { + var ast = cpb.indexLang(byteCode.get(curr++)); + var names = cpb.indexStrOrNil(byteCode.get(curr++)); + var chrLabelsIdx = cpb.indexIntOrNil(byteCode.get(curr++)); + var numlabelsIdx = cpb.indexIntOrNil(byteCode.get(curr++)); + + // in the case switch does not have any named labels this will be null, + if (chrLabelsIdx != null) { + cpb.reset(chrLabelsIdx, this::remapLabels); + } + + // FIXME: can this ever be null? there always have to be some number labels? or in the + // case of empty switch? + // in some cases, the number labels can be the same as the chrLabels + // and we do not want to remap twice + if (numlabelsIdx != null && !numlabelsIdx.equals(chrLabelsIdx)) { + cpb.reset(numlabelsIdx, this::remapLabels); + } + + yield new BcInstr.Switch(ast, names, chrLabelsIdx, numlabelsIdx); + } + case RETURNJMP -> new BcInstr.ReturnJmp(); + case STARTSUBSET_N -> + new BcInstr.StartSubsetN( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case STARTSUBASSIGN_N -> + new BcInstr.StartSubassignN( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case VECSUBSET2 -> + new BcInstr.VecSubset2(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case MATSUBSET2 -> + new BcInstr.MatSubset2(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case VECSUBASSIGN2 -> + new BcInstr.VecSubassign2(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case MATSUBASSIGN2 -> + new BcInstr.MatSubassign2(cpb.indexLangOrNilIfNegative(byteCode.get(curr++))); + case STARTSUBSET2_N -> + new BcInstr.StartSubset2N( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case STARTSUBASSIGN2_N -> + new BcInstr.StartSubassign2N( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case SUBSET_N -> + new BcInstr.SubsetN( + cpb.indexLangOrNilIfNegative(byteCode.get(curr++)), byteCode.get(curr++)); + case SUBSET2_N -> + new BcInstr.Subset2N( + cpb.indexLangOrNilIfNegative(byteCode.get(curr++)), byteCode.get(curr++)); + case SUBASSIGN_N -> + new BcInstr.SubassignN( + cpb.indexLangOrNilIfNegative(byteCode.get(curr++)), byteCode.get(curr++)); + case SUBASSIGN2_N -> + new BcInstr.Subassign2N( + cpb.indexLangOrNilIfNegative(byteCode.get(curr++)), byteCode.get(curr++)); + case LOG -> new BcInstr.Log(cpb.indexLang(byteCode.get(curr++))); + case LOGBASE -> new BcInstr.LogBase(cpb.indexLang(byteCode.get(curr++))); + case MATH1 -> new BcInstr.Math1(cpb.indexLang(byteCode.get(curr++)), byteCode.get(curr++)); + case DOTCALL -> + new BcInstr.DotCall(cpb.indexLang(byteCode.get(curr++)), byteCode.get(curr++)); + case COLON -> new BcInstr.Colon(cpb.indexLang(byteCode.get(curr++))); + case SEQALONG -> new BcInstr.SeqAlong(cpb.indexLang(byteCode.get(curr++))); + case SEQLEN -> new BcInstr.SeqLen(cpb.indexLang(byteCode.get(curr++))); + case BASEGUARD -> + new BcInstr.BaseGuard( + cpb.indexLang(byteCode.get(curr++)), labelMapping.make(byteCode.get(curr++))); + case INCLNK -> new BcInstr.IncLnk(); + case DECLNK -> new BcInstr.DecLnk(); + case DECLNK_N -> new BcInstr.DeclnkN(byteCode.get(curr++)); + case INCLNKSTK -> new BcInstr.IncLnkStk(); + case DECLNKSTK -> new BcInstr.DecLnkStk(); + }; + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("invalid opcode " + op + " (arguments)", e); + } catch (ArrayIndexOutOfBoundsException e) { + throw new IllegalArgumentException( + "invalid opcode " + op + " (arguments, unexpected end of bytecode stream)"); + } + } + + private IntSXP remapLabels(IntSXP oldLabels) { + var remapped = oldLabels.data().stream().map(labelMapping::getTarget).toArray(); + return SEXPs.integer(remapped); + } +} + +/** + * Create labels from GNU-R labels. + * + * @implNote This contains a map of positions in GNU-R bytecode to positions in our bytecode. We + * need this because every index in our bytecode maps to an instruction, while indexes in + * GNU-R's bytecode also map to the bytecode version and instruction metadata. + */ +class LabelMapping { + private final ImmutableIntArray posMap; + + private LabelMapping(ImmutableIntArray posMap) { + this.posMap = posMap; + } + + /** Create a label from a GNU-R label. */ + BcLabel make(int gnurLabel) { + return new BcLabel(getTarget(gnurLabel)); + } + + int getTarget(int gnurLabel) { + if (gnurLabel == 0) { + throw new IllegalArgumentException("GNU-R label 0 is reserved for the version number"); + } + + var target = posMap.get(gnurLabel); + if (target == -1) { + var gnurEarlier = gnurLabel - 1; + int earlier = posMap.get(gnurEarlier); + var gnurLater = gnurLabel + 1; + int later = posMap.get(gnurLater); + throw new IllegalArgumentException( + "GNU-R position maps to the middle of one of our instructions: " + + gnurLabel + + " between " + + earlier + + " and " + + later); + } + + return target; + } + + static LabelMapping from(ImmutableIntArray gnurBC) { + var builder = new Builder(); + // skip the BC version number + int i = 1; + while (i < gnurBC.length()) { + try { + var op = BcOp.valueOf(gnurBC.get(i)); + var size = 1 + op.nArgs(); + builder.step(size, 1); + i += size; + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "malformed bytecode at " + i + "\nBytecode up to this point: " + builder.build(), e); + } + } + return builder.build(); + } + + // FIXME: inline + static class Builder { + private final ImmutableIntArray.Builder map = ImmutableIntArray.builder(); + private int targetPc = 0; + + Builder() { + // Add initial mapping of 1 -> 0 (version # is 0) + map.add(-1); + map.add(0); + } + + /** Step m times in the source bytecode and n times in the target bytecode */ + void step(int sourceOffset, @SuppressWarnings("SameParameterValue") int targetOffset) { + if (sourceOffset < 0 || targetOffset < 0) { + throw new IllegalArgumentException("offsets must be nonnegative"); + } + + targetPc += targetOffset; + // Offsets before sourceOffset map to the middle of the previous instruction + for (int i = 0; i < sourceOffset - 1; i++) { + map.add(-1); + } + // Add target position + if (sourceOffset > 0) { + map.add(targetPc); + } + } + + LabelMapping build() { + return new LabelMapping(map.build()); + } + } +} diff --git a/src/main/java/org/prlprg/rds/RDSInputStream.java b/src/main/java/org/prlprg/rds/RDSInputStream.java index 0386c84a0..ea7db7c83 100644 --- a/src/main/java/org/prlprg/rds/RDSInputStream.java +++ b/src/main/java/org/prlprg/rds/RDSInputStream.java @@ -18,8 +18,14 @@ public void close() throws IOException { in.close(); } - public boolean isAtEnd() throws IOException { - return in.available() == 0; + /** + * Reads the next byte of data from the input stream. + * + * @return the next byte of data, or -1 if the end of the stream is reached. + * @throws IOException if an I/O error occurs. + */ + public int readRaw() throws IOException { + return in.read(); } public byte readByte() throws IOException { diff --git a/src/main/java/org/prlprg/rds/RDSItemType.java b/src/main/java/org/prlprg/rds/RDSItemType.java index 810fabec0..ef0dd5c67 100644 --- a/src/main/java/org/prlprg/rds/RDSItemType.java +++ b/src/main/java/org/prlprg/rds/RDSItemType.java @@ -30,6 +30,7 @@ static RDSItemType valueOf(int i) { case 241 -> Special.BASEENV_SXP; case 240 -> Special.ATTRLANGSXP; case 239 -> Special.ATTRLISTSXP; + case 238 -> Special.ALTREPSXP; default -> { try { yield new Sexp(SEXPType.valueOf(i)); @@ -66,7 +67,8 @@ enum Special implements RDSItemType { EMPTYENV_SXP(242), BASEENV_SXP(241), ATTRLANGSXP(240), - ATTRLISTSXP(239); + ATTRLISTSXP(239), + ALTREPSXP(238); private final int i; diff --git a/src/main/java/org/prlprg/rds/RDSReader.java b/src/main/java/org/prlprg/rds/RDSReader.java index e4847d6ed..d2318ec9d 100644 --- a/src/main/java/org/prlprg/rds/RDSReader.java +++ b/src/main/java/org/prlprg/rds/RDSReader.java @@ -5,21 +5,18 @@ import java.io.*; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; +import java.util.*; import javax.annotation.Nullable; import org.prlprg.RSession; import org.prlprg.bc.Bc; -import org.prlprg.bc.BcFromRawException; +import org.prlprg.primitive.Complex; import org.prlprg.primitive.Constants; import org.prlprg.primitive.Logical; import org.prlprg.sexp.*; import org.prlprg.util.IO; public class RDSReader implements Closeable { - private final RSession session; + private final RSession rsession; private final RDSInputStream in; private final List refTable = new ArrayList<>(128); @@ -27,22 +24,22 @@ public class RDSReader implements Closeable { private Charset nativeEncoding = Charset.defaultCharset(); private RDSReader(RSession session, InputStream in) { - this.session = session; + this.rsession = session; this.in = new RDSInputStream(in); } - public static SEXP readStream(RSession session, InputStream input) throws IOException { - try (var reader = new RDSReader(session, input)) { - return reader.read(); - } - } - public static SEXP readFile(RSession session, File file) throws IOException { try (var input = new FileInputStream(file)) { return readStream(session, IO.maybeDecompress(input)); } } + public static SEXP readStream(RSession session, InputStream input) throws IOException { + try (var reader = new RDSReader(session, input)) { + return reader.read(); + } + } + private void readHeader() throws IOException { var type = in.readByte(); if (type != 'X') { @@ -53,28 +50,21 @@ private void readHeader() throws IOException { // versions var formatVersion = in.readInt(); + if (formatVersion != 2) { + // we do not support RDS version 3 because it uses ALTREP + throw new RDSException("Unsupported RDS version: " + formatVersion); + } + // writer version in.readInt(); // minimal reader version in.readInt(); - - // native encoding for version == 3 - switch (formatVersion) { - case 2: - break; - case 3: - var natEncSize = in.readInt(); - nativeEncoding = Charset.forName(in.readString(natEncSize, StandardCharsets.US_ASCII)); - break; - default: - throw new RDSException("Unsupported version: " + formatVersion); - } } public SEXP read() throws IOException { readHeader(); var sexp = readItem(); - if (!in.isAtEnd()) { + if (in.readRaw() != -1) { throw new RDSException("Expected end of file"); } return sexp; @@ -100,14 +90,18 @@ private SEXP readItem() throws IOException { case BCODE -> readByteCode(); case EXPR -> readExpr(flags); case PROM -> readPromise(flags); + case BUILTIN -> readBuiltin(false); + case SPECIAL -> readBuiltin(true); + case CPLX -> readComplex(flags); default -> throw new RDSException("Unsupported SEXP type: " + s.sexp()); }; case RDSItemType.Special s -> switch (s) { case NILVALUE_SXP -> SEXPs.NULL; case MISSINGARG_SXP -> SEXPs.MISSING_ARG; - case GLOBALENV_SXP -> session.globalEnv(); - case BASEENV_SXP -> session.baseEnv(); + case GLOBALENV_SXP -> rsession.globalEnv(); + case BASEENV_SXP -> rsession.baseEnv(); + case BASENAMESPACE_SXP -> rsession.baseNamespace(); case EMPTYENV_SXP -> SEXPs.EMPTY_ENV; case REFSXP -> readRef(flags); case NAMESPACESXP -> readNamespace(); @@ -115,23 +109,43 @@ private SEXP readItem() throws IOException { throw new RDSException("Unexpected bytecode reference here (not in bytecode)"); case ATTRLANGSXP, ATTRLISTSXP -> throw new RDSException("Unexpected attr here"); case UNBOUNDVALUE_SXP -> SEXPs.UNBOUND_VALUE; - case GENERICREFSXP, BASENAMESPACE_SXP, PACKAGESXP, PERSISTSXP, CLASSREFSXP -> + case GENERICREFSXP, PACKAGESXP, PERSISTSXP, CLASSREFSXP, ALTREPSXP -> throw new RDSException("Unsupported RDS special type: " + s); }; }; } + private SEXP readComplex(Flags flags) throws IOException { + var length = in.readInt(); + var cplx = ImmutableList.builder(); + for (int i = 0; i < length; i++) { + var real = in.readDouble(); + var im = in.readDouble(); + cplx.add(new Complex(real, im)); + } + var attributes = readAttributes(flags); + return SEXPs.complex(cplx.build(), attributes); + } + + private SEXP readBuiltin(boolean special) throws IOException { + var length = in.readInt(); + var name = in.readString(length, nativeEncoding); + return special ? SEXPs.special(name) : SEXPs.builtin(name); + } + private SEXP readPromise(Flags flags) throws IOException { readAttributes(flags); - - if (!(readItem() instanceof EnvSXP env)) { - throw new RDSException("Expected promise ENV to be environment"); - } + var tag = flags.hasTag() ? readItem() : SEXPs.NULL; var val = readItem(); var expr = readItem(); - // TODO: attributes? - return new PromSXP(expr, val, env); + if (tag instanceof NilSXP) { + return new PromSXP(expr, val, SEXPs.EMPTY_ENV); + } else if (tag instanceof EnvSXP env) { + return new PromSXP(expr, val, env); + } else { + throw new RDSException("Expected promise ENV to be environment"); + } } private SEXP readNamespace() throws IOException { @@ -140,9 +154,7 @@ private SEXP readNamespace() throws IOException { throw new RDSException("Expected 2-element list, got: " + namespaceInfo); } - // FIXME: this should be loaded from RSession - var namespace = - new NamespaceEnvSXP(SEXPs.EMPTY_ENV, namespaceInfo.get(0), namespaceInfo.get(1)); + var namespace = rsession.getNamespace(namespaceInfo.get(0), namespaceInfo.get(1)); refTable.add(namespace); return namespace; @@ -194,11 +206,8 @@ private BCodeSXP readByteCode1(SEXP[] reps) throws IOException { } var consts = readByteCodeConsts(reps); - try { - return SEXPs.bcode(Bc.fromRaw(code.data(), consts)); - } catch (BcFromRawException e) { - throw new RDSException("Error reading bytecode", e); - } + var factory = new GNURByteCodeDecoderFactory(code.data(), consts); + return SEXPs.bcode(factory.create()); } private List readByteCodeConsts(SEXP[] reps) throws IOException { @@ -253,52 +262,58 @@ private SEXP readByteCodeLang1(RDSItemType type, SEXP[] reps) throws IOException ? readAttributes() : Attributes.NONE; - SEXP tagSexp; + SEXP tagSexp = readItem(); SEXP ans; - if (type.isSexp(SEXPType.LANG) || type == RDSItemType.Special.ATTRLANGSXP) { - tagSexp = readItem(); - if (tagSexp != SEXPs.NULL) { - throw new RDSException("Expected NULL tag"); - } + if (!type.isSexp(SEXPType.LANG) + && !type.isSexp(SEXPType.LIST) + && type != RDSItemType.Special.ATTRLANGSXP + && type != RDSItemType.Special.ATTRLISTSXP) { + throw new RDSException("Unexpected bclang type: " + type); + } - var fun = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); - if (!(fun instanceof SymOrLangSXP funSymOrLang)) { - throw new RDSException("Expected symbol or language, got: " + fun.type()); - } + String tag; - var args = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); - if (!(args instanceof ListSXP argsList)) { - throw new RDSException("Expected list, got: " + args.type()); - } + if (tagSexp instanceof RegSymSXP sym) { + tag = sym.name(); + } else if (tagSexp instanceof NilSXP) { + tag = null; + } else { + throw new RDSException("Expected regular symbol or nil"); + } - ans = SEXPs.lang(funSymOrLang, argsList, attributes); - } else if (type.isSexp(SEXPType.LIST) || type == RDSItemType.Special.ATTRLISTSXP) { - tagSexp = readItem(); - String tag; + var head = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); + var tail = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); - if (tagSexp instanceof RegSymSXP sym) { - tag = sym.name(); - } else if (tagSexp instanceof NilSXP) { - tag = null; - } else { - throw new RDSException("Expected regular symbol or nil"); - } + ListSXP tailList; - var head = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); - var tail = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); + // In R LISTSXP and LANGSXP are pretty much the same. + // The RDS relies on this fact quite a bit. + // The tail could thus be a LANGSXP (for example: methods::externalRefMethod) + if (tail instanceof ListSXP) { + tailList = (ListSXP) tail; + } else if (tail instanceof LangSXP lang) { + tailList = lang.args().prepend(new TaggedElem(null, lang.fun())); + } else { + throw new RDSException("Expected list or language, got: " + tail.type()); + } + + ans = tailList.prepend(new TaggedElem(tag, head)).withAttributes(attributes); + + if (type.isSexp(SEXPType.LANG) || type == RDSItemType.Special.ATTRLANGSXP) { + ListSXP ansList = (ListSXP) ans; - var data = new ImmutableList.Builder(); - data.add(new TaggedElem(tag, head)); + var fun = ansList.get(0).value(); + if (!(fun instanceof SymOrLangSXP funSymOrLang)) { + throw new RDSException("Expected symbol or language, got: " + fun.type()); + } - if (!(tail instanceof ListSXP tailList)) { - throw new RDSException("Expected list, got: " + tail.type()); + ListSXP args = SEXPs.NULL; + if (ansList.size() > 1) { + args = ansList.subList(1); } - ListSXP.flatten(tailList, data); - ans = SEXPs.list(data.build(), attributes); - } else { - throw new RDSException("Unexpected bclang type: " + type); + ans = SEXPs.lang(funSymOrLang, args, attributes); } if (pos >= 0) { @@ -368,7 +383,7 @@ private UserEnvSXP readEnv() throws IOException { // enclosing environment - parent switch (readItem()) { case EnvSXP parent -> item.setParent(parent); - case NilSXP ignored -> item.setParent(session.baseEnv()); + case NilSXP ignored -> item.setParent(rsession.baseEnv()); default -> throw new RDSException("Expected environment (ENCLOS)"); } @@ -402,9 +417,7 @@ private UserEnvSXP readEnv() throws IOException { default -> throw new RDSException("Expected list (HASHTAB)"); } - item.setAttributes(readAttributes()); - - return item; + return item.withAttributes(readAttributes()); } private VecSXP readVec(Flags flags) throws IOException { diff --git a/src/main/java/org/prlprg/sexp/AbstractEnvSXP.java b/src/main/java/org/prlprg/sexp/AbstractEnvSXP.java new file mode 100644 index 000000000..ffa764278 --- /dev/null +++ b/src/main/java/org/prlprg/sexp/AbstractEnvSXP.java @@ -0,0 +1,46 @@ +package org.prlprg.sexp; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +public abstract sealed class AbstractEnvSXP implements EnvSXP + permits BaseEnvSXP, GlobalEnvSXP, NamespaceEnvSXP, UserEnvSXP { + + protected EnvSXP parent; + protected final Map bindings; + + public AbstractEnvSXP(EnvSXP parent) { + this.parent = parent; + this.bindings = new HashMap<>(); + } + + @Override + public EnvSXP parent() { + return parent; + } + + @Override + public Optional get(String name) { + return getLocal(name).or(() -> parent.get(name)); + } + + @Override + public Optional getLocal(String name) { + return Optional.ofNullable(bindings.get(name)); + } + + @Override + public Iterable> bindings() { + return bindings.entrySet(); + } + + @Override + public int size() { + return bindings.size(); + } + + public void set(String name, SEXP value) { + bindings.put(name, value); + } +} diff --git a/src/main/java/org/prlprg/sexp/Attributes.java b/src/main/java/org/prlprg/sexp/Attributes.java index f755367ca..8baf88f44 100644 --- a/src/main/java/org/prlprg/sexp/Attributes.java +++ b/src/main/java/org/prlprg/sexp/Attributes.java @@ -45,6 +45,18 @@ public int hashCode() { return Objects.hashCode(super.hashCode(), attrs); } + public Attributes excluding(String name) { + var builder = new ImmutableMap.Builder(); + attrs.forEach( + (key, value) -> { + if (!key.equals(name)) { + builder.put(key, value); + } + }); + + return new Attributes(builder.build()); + } + /** Build an {@link Attributes} instance. */ public static class Builder { private final ImmutableMap.Builder attrs = ImmutableMap.builder(); diff --git a/src/main/java/org/prlprg/sexp/BCodeSXP.java b/src/main/java/org/prlprg/sexp/BCodeSXP.java index ea0d722d9..dec3d209f 100644 --- a/src/main/java/org/prlprg/sexp/BCodeSXP.java +++ b/src/main/java/org/prlprg/sexp/BCodeSXP.java @@ -7,19 +7,18 @@ @Immutable public sealed interface BCodeSXP extends SEXP { - /** - * The typed compiled code. - * - *

TODO will be refactored so BCodeSXP stores raw data and can generate this, the method will - * be named something like {@code getBc()} or {@code generateBc()} to make it clear that this is a - * more expensive operation than a virtual getter. - */ + /** The typed compiled code. */ Bc bc(); @Override default SEXPType type() { return SEXPType.BCODE; } + + @Override + default Class getCanonicalType() { + return BCodeSXP.class; + } } record BCodeSXPImpl(Bc bc) implements BCodeSXP { diff --git a/src/main/java/org/prlprg/sexp/BaseEnvSXP.java b/src/main/java/org/prlprg/sexp/BaseEnvSXP.java index 1b6ef8230..efc8a6213 100644 --- a/src/main/java/org/prlprg/sexp/BaseEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/BaseEnvSXP.java @@ -1,29 +1,12 @@ package org.prlprg.sexp; -import com.google.common.collect.ImmutableMap; import java.util.HashMap; -import java.util.Optional; -public final class BaseEnvSXP implements EnvSXP { - private final ImmutableMap bindings; +public final class BaseEnvSXP extends AbstractEnvSXP implements EnvSXP { public BaseEnvSXP(HashMap bindings) { - this.bindings = ImmutableMap.copyOf(bindings); - } - - @Override - public EnvSXP parent() { - return SEXPs.EMPTY_ENV; - } - - @Override - public Optional get(String name) { - return getLocal(name); - } - - @Override - public Optional getLocal(String name) { - return Optional.ofNullable(bindings.get(name)); + super(EmptyEnvSXP.INSTANCE); + bindings.forEach(this::set); } @Override diff --git a/src/main/java/org/prlprg/sexp/BuiltinSXP.java b/src/main/java/org/prlprg/sexp/BuiltinSXP.java new file mode 100644 index 000000000..9418d2b4c --- /dev/null +++ b/src/main/java/org/prlprg/sexp/BuiltinSXP.java @@ -0,0 +1,13 @@ +package org.prlprg.sexp; + +public record BuiltinSXP(String name) implements SEXP { + @Override + public SEXPType type() { + return SEXPType.BUILTIN; + } + + @Override + public Class getCanonicalType() { + return BuiltinSXP.class; + } +} diff --git a/src/main/java/org/prlprg/sexp/CloSXP.java b/src/main/java/org/prlprg/sexp/CloSXP.java index 7af385a08..6ac73720f 100644 --- a/src/main/java/org/prlprg/sexp/CloSXP.java +++ b/src/main/java/org/prlprg/sexp/CloSXP.java @@ -1,12 +1,15 @@ package org.prlprg.sexp; -import javax.annotation.Nullable; +import java.util.Optional; /** Closure SEXP. */ public sealed interface CloSXP extends SEXP { /** The argument names and default values. */ ListSXP formals(); + /** If the body is a BCodeSXP, returns the AST which is stored in the first constant pool slot. */ + SEXP bodyAST(); + /** The closure body. */ SEXP body(); @@ -24,14 +27,28 @@ default SEXPType type() { @Override CloSXP withAttributes(Attributes attributes); - @Nullable IntSXP getSrcRef(); + Optional getSrcRef(); + + @Override + default Class getCanonicalType() { + return CloSXP.class; + } } record CloSXPImpl(ListSXP formals, SEXP body, EnvSXP env, @Override Attributes attributes) implements CloSXP { @Override public String toString() { - return SEXPs.toString(this, env(), formals(), "\n → ", body()); + return SEXPs.toString(this, env(), formals(), "\n → ", body); + } + + @Override + public SEXP bodyAST() { + if (body instanceof BCodeSXP bc) { + return bc.bc().consts().getFirst(); + } else { + return body; + } } @Override @@ -40,7 +57,7 @@ public CloSXP withAttributes(Attributes attributes) { } @Override - public @Nullable IntSXP getSrcRef() { - return (IntSXP) attributes.get("srcref"); + public Optional getSrcRef() { + return Optional.ofNullable((IntSXP) attributes.get("srcref")); } } diff --git a/src/main/java/org/prlprg/sexp/Coercions.java b/src/main/java/org/prlprg/sexp/Coercions.java new file mode 100644 index 000000000..2dc299ddf --- /dev/null +++ b/src/main/java/org/prlprg/sexp/Coercions.java @@ -0,0 +1,240 @@ +package org.prlprg.sexp; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.util.List; +import java.util.regex.Pattern; +import org.prlprg.primitive.Complex; +import org.prlprg.primitive.Constants; +import org.prlprg.primitive.Logical; + +// ideally we would keep this in the actual classes +// and use double-dispatch, but since we rely on Java classes, we cannot +public final class Coercions { + private static final Pattern REAL_NUM = Pattern.compile("\\s*([+-]?)\\s*(\\d*\\.?\\d+)"); + + private Coercions() {} + + public static Object coerce(T t, SEXPType targetType) { + return switch (t) { + case Logical x -> + switch (targetType) { + case LGL -> x; + case STR -> stringFromLogical(x); + case INT -> integerFromLogical(x); + case REAL -> realFromLogical(x); + case CPLX -> complexFromLogical(x); + default -> throw new IllegalArgumentException("Unsupported target type: " + targetType); + }; + case Integer x -> + switch (targetType) { + case LGL -> logicalFromInteger(x); + case STR -> stringFromInteger(x); + case INT -> x; + case REAL -> realFromInteger(x); + case CPLX -> complexFromInteger(x); + default -> throw new IllegalArgumentException("Unsupported target type: " + targetType); + }; + case Double x -> + switch (targetType) { + case LGL -> logicalFromReal(x); + case STR -> stringFromReal(x); + case INT -> integerFromReal(x); + case REAL -> x; + case CPLX -> complexFromReal(x); + default -> throw new IllegalArgumentException("Unsupported target type: " + targetType); + }; + case String x -> + switch (targetType) { + case LGL -> logicalFromString(x); + case STR -> x; + case INT -> integerFromString(x); + case REAL -> realFromString(x); + case CPLX -> complexFromString(x); + default -> throw new IllegalArgumentException("Unsupported target type: " + targetType); + }; + case Complex x -> + switch (targetType) { + case LGL -> logicalFromComplex(x); + case STR -> stringFromComplex(x); + case INT -> integerFromComplex(x); + case REAL -> realFromComplex(x); + case CPLX -> x; + default -> throw new IllegalArgumentException("Unsupported target type: " + targetType); + }; + default -> throw new IllegalArgumentException("Unsupported type: " + t.getClass()); + }; + } + + public static SEXPType commonType(SEXPType... types) { + if (types.length == 0) { + throw new IllegalArgumentException("No types provided"); + } + + var max = types[0].i; + for (var i = 1; i < types.length; i++) { + max = Math.max(max, types[i].i); + } + + return SEXPType.valueOf(max); + } + + public static SEXPType commonType(List args) { + if (args.isEmpty()) { + throw new IllegalArgumentException("No arguments provided"); + } + + return args.stream().map(SEXP::type).reduce(args.getFirst().type(), Coercions::commonType); + } + + public static Complex complexFromInteger(int x) { + return isNA(x) ? Constants.NA_COMPLEX : Complex.fromReal(x); + } + + public static Complex complexFromLogical(Logical x) { + return isNA(x) + ? Constants.NA_COMPLEX + : x == Logical.TRUE ? Complex.fromReal(1) : Complex.fromReal(0); + } + + public static Complex complexFromReal(double x) { + return Complex.fromReal(x); + } + + public static Complex complexFromString(String x) { + if (isNA(x) || x.isBlank()) { + return Constants.NA_COMPLEX; + } + + x = x.strip(); + var hasI = false; + if (x.endsWith("i")) { + hasI = true; + x = x.substring(0, x.length() - 1); + } + + var mrs = REAL_NUM.matcher(x).results().toList(); + switch (mrs.size()) { + case 1 -> { + var m = mrs.getFirst(); + var d = Double.parseDouble(m.group(2)); + d = m.group(1).equals("-") ? -d : d; + return hasI ? new Complex(.0, d) : new Complex(d, .0); + } + + case 2 -> { + if (!hasI) { + throw new IllegalArgumentException("Invalid complex number: " + x); + } + + var m1 = mrs.getFirst(); + var m2 = mrs.getLast(); + + var r = Double.parseDouble(m1.group(2)); + r = m1.group(1).equals("-") ? -r : r; + + var i = Double.parseDouble(m2.group(2)); + i = m2.group(1).equals("-") ? -i : i; + + return new Complex(r, i); + } + + default -> throw new IllegalArgumentException("Invalid complex number: " + x); + } + } + + public static int integerFromComplex(Complex x) { + return isNA(x) ? Constants.NA_INT : (int) x.real(); + } + + public static int integerFromLogical(Logical x) { + return switch (x) { + case TRUE -> 1; + case FALSE -> 0; + case NA -> Constants.NA_INT; + }; + } + + public static int integerFromReal(double x) { + return Double.isNaN(x) ? Constants.NA_INT : (int) x; + } + + public static int integerFromString(String x) { + return isNA(x) ? Constants.NA_INT : Integer.parseInt(x); + } + + @SuppressFBWarnings("ES_COMPARING_PARAMETER_STRING_WITH_EQ") + public static boolean isNA(String x) { + //noinspection StringEquality + return x == Constants.NA_STRING; + } + + public static boolean isNA(Logical x) { + return x == Logical.NA; + } + + public static boolean isNA(int x) { + return x == Constants.NA_INT; + } + + public static boolean isNA(double x) { + return Double.isNaN(x); + } + + public static boolean isNA(Complex x) { + return isNA(x.real()) || isNA(x.imag()); + } + + public static Logical logicalFromComplex(Complex x) { + return (Double.isNaN(x.real()) || Double.isNaN(x.imag())) + ? Logical.NA + : x.real() != 0 ? Logical.TRUE : Logical.FALSE; + } + + public static Logical logicalFromInteger(int x) { + return isNA(x) ? Logical.NA : x != 0 ? Logical.TRUE : Logical.FALSE; + } + + public static Logical logicalFromReal(double x) { + return Double.isNaN(x) ? Logical.NA : x != 0 ? Logical.TRUE : Logical.FALSE; + } + + public static Logical logicalFromString(String x) { + if (isNA(x)) { + return Logical.NA; + } + + return Constants.TRUE_NAMES.contains(x) ? Logical.TRUE : Logical.FALSE; + } + + public static double realFromComplex(Complex x) { + return isNA(x) ? Constants.NA_INT : x.real(); + } + + public static double realFromInteger(int x) { + return isNA(x) ? Constants.NA_REAL : (double) x; + } + + public static double realFromLogical(Logical x) { + return isNA(x) ? Constants.NA_REAL : x == Logical.TRUE ? 1.0 : 0.0; + } + + public static double realFromString(String x) { + return isNA(x) ? Constants.NA_REAL : Double.parseDouble(x); + } + + public static String stringFromComplex(Complex x) { + return isNA(x) ? Constants.NA_STRING : x.toString(); + } + + public static String stringFromInteger(int x) { + return isNA(x) ? Constants.NA_STRING : Integer.toString(x); + } + + public static String stringFromLogical(Logical x) { + return isNA(x) ? Constants.NA_STRING : x == Logical.TRUE ? "TRUE" : "FALSE"; + } + + public static String stringFromReal(double x) { + return Double.isNaN(x) ? Constants.NA_STRING : Double.toString(x); + } +} diff --git a/src/main/java/org/prlprg/sexp/ComplexSXP.java b/src/main/java/org/prlprg/sexp/ComplexSXP.java index cbc3bfadf..8df5b47a0 100644 --- a/src/main/java/org/prlprg/sexp/ComplexSXP.java +++ b/src/main/java/org/prlprg/sexp/ComplexSXP.java @@ -8,7 +8,7 @@ /** Complex vector SEXP. */ @Immutable public sealed interface ComplexSXP extends VectorSXP - permits ComplexSXPImpl, EmptyComplexSXPImpl, SimpleComplexSXP { + permits ComplexSXPImpl, EmptyComplexSXPImpl, ScalarComplexSXP { @Override default SEXPType type() { return SEXPType.CPLX; @@ -16,6 +16,11 @@ default SEXPType type() { @Override ComplexSXP withAttributes(Attributes attributes); + + @Override + default Class getCanonicalType() { + return ComplexSXP.class; + } } /** Complex vector which doesn't fit any of the more specific subclasses. */ @@ -47,6 +52,22 @@ public ComplexSXP withAttributes(Attributes attributes) { } } +final class ScalarComplexSXP extends ScalarSXPImpl implements ComplexSXP { + ScalarComplexSXP(Complex data) { + super(data); + } + + @SuppressWarnings("MissingJavadoc") + public Complex value() { + return data; + } + + @Override + public ComplexSXP withAttributes(Attributes attributes) { + return SEXPs.complex(data, attributes); + } +} + /** Empty complex vector with no ALTREP, ATTRIB, or OBJECT. */ final class EmptyComplexSXPImpl extends EmptyVectorSXPImpl implements ComplexSXP { static final EmptyComplexSXPImpl INSTANCE = new EmptyComplexSXPImpl(); diff --git a/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java b/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java index d92a48267..b39ab89cb 100644 --- a/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java @@ -1,6 +1,8 @@ package org.prlprg.sexp; +import java.util.Map; import java.util.Optional; +import java.util.Set; import org.prlprg.util.Pair; public final class EmptyEnvSXP implements EnvSXP { @@ -19,6 +21,11 @@ public Optional get(String name) { return Optional.empty(); } + @Override + public void set(String name, SEXP value) { + throw new UnsupportedOperationException("cannot set a value in the empty environment"); + } + @Override public Optional getLocal(String name) { return Optional.empty(); @@ -29,6 +36,16 @@ public Optional> find(String name) { return Optional.empty(); } + @Override + public Iterable> bindings() { + return Set.of(); + } + + @Override + public int size() { + return 0; + } + @Override public String toString() { return ""; diff --git a/src/main/java/org/prlprg/sexp/EmptyVectorSXPImpl.java b/src/main/java/org/prlprg/sexp/EmptyVectorSXPImpl.java deleted file mode 100644 index 20e396eb4..000000000 --- a/src/main/java/org/prlprg/sexp/EmptyVectorSXPImpl.java +++ /dev/null @@ -1,36 +0,0 @@ -package org.prlprg.sexp; - -import com.google.common.collect.Iterators; -import com.google.common.collect.UnmodifiableIterator; -import javax.annotation.concurrent.Immutable; - -/** Class for representing a scalar SEXP of a primitive type with no attributes. */ -@Immutable -abstract non-sealed class EmptyVectorSXPImpl implements VectorSXP { - protected EmptyVectorSXPImpl() {} - - @Override - public UnmodifiableIterator iterator() { - return Iterators.forArray(); - } - - @Override - public T get(int i) { - throw new IndexOutOfBoundsException(); - } - - @Override - public int size() { - return 0; - } - - @Override - public String toString() { - return ""; - } - - @Override - public Attributes attributes() { - return Attributes.NONE; - } -} diff --git a/src/main/java/org/prlprg/sexp/EnvSXP.java b/src/main/java/org/prlprg/sexp/EnvSXP.java index faad7074b..c122bc5d0 100644 --- a/src/main/java/org/prlprg/sexp/EnvSXP.java +++ b/src/main/java/org/prlprg/sexp/EnvSXP.java @@ -1,10 +1,11 @@ package org.prlprg.sexp; +import java.util.Map; import java.util.Optional; import org.prlprg.util.Pair; public sealed interface EnvSXP extends SEXP - permits BaseEnvSXP, EmptyEnvSXP, GlobalEnvSXP, NamespaceEnvSXP, UserEnvSXP { + permits AbstractEnvSXP, BaseEnvSXP, EmptyEnvSXP, GlobalEnvSXP, NamespaceEnvSXP, UserEnvSXP { /** * Environments are linked in a parent chain. Every environment, except the empty environment, has * a parent that will be returned by this function. @@ -21,6 +22,14 @@ public sealed interface EnvSXP extends SEXP */ Optional get(String name); + /** + * Set the value of a symbol in the environment. + * + * @param name the name of the symbol + * @param value the value of the symbol + */ + void set(String name, SEXP value); + /** * Get the value of a symbol in the environment, without following the parent chain. * @@ -43,4 +52,29 @@ default SEXPType type() { default Optional> find(String name) { return getLocal(name).map(v -> new Pair<>(this, v)).or(() -> parent().find(name)); } + + Iterable> bindings(); + + /** + * Get the number of symbols in the environment. + * + * @return the number of symbols in the environment + */ + int size(); + + /** + * Returns {@code true} if this is the base environment ({@code baseenv()}) or a base namespace + * ({@code .BaseNamespaceEnv}). namespace. + * + * @return + */ + default boolean isBase() { + return this instanceof BaseEnvSXP + || this instanceof NamespaceEnvSXP ns && ns.getName().equals("base"); + } + + @Override + default Class getCanonicalType() { + return EnvSXP.class; + } } diff --git a/src/main/java/org/prlprg/sexp/ExprSXP.java b/src/main/java/org/prlprg/sexp/ExprSXP.java index aa495d863..f8ca5d2d3 100644 --- a/src/main/java/org/prlprg/sexp/ExprSXP.java +++ b/src/main/java/org/prlprg/sexp/ExprSXP.java @@ -18,6 +18,11 @@ default SEXPType type() { @Override ExprSXP withAttributes(Attributes attributes); + + @Override + default Class getCanonicalType() { + return ExprSXP.class; + } } record ExprSXPImpl(ImmutableList data, @Override Attributes attributes) implements ExprSXP { diff --git a/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java b/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java index 4c15a65ca..e58f1323a 100644 --- a/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java @@ -1,33 +1,8 @@ package org.prlprg.sexp; -import java.util.Optional; - -public final class GlobalEnvSXP implements EnvSXP { - private final EnvSXP parent; - +public final class GlobalEnvSXP extends AbstractEnvSXP implements EnvSXP { public GlobalEnvSXP(EnvSXP parent) { - this.parent = parent; - } - - @Override - public EnvSXP parent() { - return parent; - } - - // FIXME: parent should return the proper namespaces - // in default R session that is: - // stats, graphics, grDevices, utils, datasets, methods, Autoloads, base - - @Override - public Optional get(String name) { - // FIXME: implement - return Optional.empty(); - } - - @Override - public Optional getLocal(String name) { - // FIXME: implement - return Optional.empty(); + super(parent); } @Override diff --git a/src/main/java/org/prlprg/sexp/IntSXP.java b/src/main/java/org/prlprg/sexp/IntSXP.java index 6977d40c5..cf6680139 100644 --- a/src/main/java/org/prlprg/sexp/IntSXP.java +++ b/src/main/java/org/prlprg/sexp/IntSXP.java @@ -6,8 +6,8 @@ /** Integer vector SEXP. */ @Immutable -public sealed interface IntSXP extends VectorSXP - permits EmptyIntSXPImpl, IntSXPImpl, SimpleIntSXP { +public sealed interface IntSXP extends NumericSXP + permits EmptyIntSXPImpl, IntSXPImpl, ScalarIntSXP { /** * The data contained in this vector. Note that if it's an empty or scalar, those aren't actually * backed by an {@link ImmutableIntArray}, so this gets created and returns every access. @@ -21,6 +21,11 @@ default SEXPType type() { @Override IntSXP withAttributes(Attributes attributes); + + @Override + default Class getCanonicalType() { + return IntSXP.class; + } } /** Int vector which doesn't fit any of the more specific subclasses. */ @@ -50,6 +55,61 @@ public String toString() { public IntSXP withAttributes(Attributes attributes) { return SEXPs.integer(data, attributes); } + + @Override + public int asInt(int index) { + return data.get(index); + } + + @Override + public double asReal(int index) { + return get(index); + } +} + +/** Simple scalar integer = int vector of size 1 with no ALTREP, ATTRIB, or OBJECT. */ +final class ScalarIntSXP extends ScalarSXPImpl implements IntSXP { + ScalarIntSXP(int data) { + super(data); + } + + @SuppressWarnings("MissingJavadoc") + public int value() { + return data; + } + + @Override + public ImmutableIntArray data() { + return ImmutableIntArray.of(data); + } + + @Override + public IntSXP withAttributes(Attributes attributes) { + return SEXPs.integer(data, attributes); + } + + @Override + public String[] coerceToStrings() { + return new String[] {String.valueOf(data)}; + } + + @Override + public int asInt(int index) { + if (index == 0) { + return data; + } else { + throw new ArrayIndexOutOfBoundsException("Index out of bounds: " + index); + } + } + + @Override + public double asReal(int index) { + if (index == 0) { + return data.doubleValue(); + } else { + throw new ArrayIndexOutOfBoundsException("Index out of bounds: " + index); + } + } } /** Empty int vector with no ALTREP, ATTRIB, or OBJECT. */ @@ -69,4 +129,14 @@ public ImmutableIntArray data() { public IntSXP withAttributes(Attributes attributes) { return SEXPs.integer(ImmutableIntArray.of(), attributes); } + + @Override + public int asInt(int index) { + throw new ArrayIndexOutOfBoundsException("Empty int vector"); + } + + @Override + public double asReal(int index) { + throw new ArrayIndexOutOfBoundsException("Empty int vector"); + } } diff --git a/src/main/java/org/prlprg/sexp/LangSXP.java b/src/main/java/org/prlprg/sexp/LangSXP.java index 2ccc584a2..0c9b8ba8d 100644 --- a/src/main/java/org/prlprg/sexp/LangSXP.java +++ b/src/main/java/org/prlprg/sexp/LangSXP.java @@ -1,6 +1,7 @@ package org.prlprg.sexp; import com.google.common.collect.ImmutableList; +import java.util.Optional; import javax.annotation.concurrent.Immutable; import org.prlprg.primitive.Names; @@ -18,15 +19,28 @@ default SEXPType type() { return SEXPType.LANG; } + @Override + default Class getCanonicalType() { + return LangSXP.class; + } + @Override Attributes attributes(); @Override LangSXP withAttributes(Attributes attributes); - TaggedElem arg(int i); + SEXP arg(int i); + + ListSXP asList(); - ImmutableList asList(); + default Optional funName() { + return fun() instanceof RegSymSXP funSym ? Optional.of(funSym.name()) : Optional.empty(); + } + + default boolean funName(String name) { + return funName().map(name::equals).orElse(false); + } } record LangSXPImpl(SymOrLangSXP fun, ListSXP args, @Override Attributes attributes) @@ -44,7 +58,7 @@ private String deparse() { return args.get(0) + " " + funName + " " + args.get(1); } } - return fun().toString() + (args() instanceof NilSXP ? "()" : args().toString()); + return fun() + (args() instanceof NilSXP ? "()" : args().toString()); } @Override @@ -53,12 +67,13 @@ public LangSXP withAttributes(Attributes attributes) { } @Override - public TaggedElem arg(int i) { - return args.get(i); + public SEXP arg(int i) { + return args.get(i).value(); } @Override - public ImmutableList asList() { - return new ImmutableList.Builder().add(fun).addAll(args.values()).build(); + public ListSXP asList() { + var l = new ImmutableList.Builder().add(fun).addAll(args.values()).build(); + return SEXPs.list2(l); } } diff --git a/src/main/java/org/prlprg/sexp/LglSXP.java b/src/main/java/org/prlprg/sexp/LglSXP.java index 82ae65756..f83cb54fd 100644 --- a/src/main/java/org/prlprg/sexp/LglSXP.java +++ b/src/main/java/org/prlprg/sexp/LglSXP.java @@ -8,7 +8,7 @@ /** Logical vector SEXP. */ @Immutable public sealed interface LglSXP extends VectorSXP - permits EmptyLglSXPImpl, LglSXPImpl, SimpleLglSXP { + permits EmptyLglSXPImpl, LglSXPImpl, ScalarLglSXP { @Override default SEXPType type() { return SEXPType.LGL; @@ -19,6 +19,11 @@ default SEXPType type() { @Override LglSXP withAttributes(Attributes attributes); + + @Override + default Class getCanonicalType() { + return LglSXP.class; + } } /** Logical vector which doesn't fit any of the more specific subclasses. */ @@ -49,6 +54,27 @@ public LglSXP withAttributes(Attributes attributes) { } } +/** Simple scalar logical = logical vector of size 1 with no ALTREP, ATTRIB, or OBJECT. */ +final class ScalarLglSXP extends ScalarSXPImpl implements LglSXP { + static final ScalarLglSXP TRUE = new ScalarLglSXP(Logical.TRUE); + static final ScalarLglSXP FALSE = new ScalarLglSXP(Logical.FALSE); + static final ScalarLglSXP NA = new ScalarLglSXP(Logical.NA); + + private ScalarLglSXP(Logical data) { + super(data); + } + + @SuppressWarnings("MissingJavadoc") + public Logical value() { + return data; + } + + @Override + public LglSXP withAttributes(Attributes attributes) { + return SEXPs.logical(data, attributes); + } +} + /** Empty logical vector with no ALTREP, ATTRIB, or OBJECT. */ final class EmptyLglSXPImpl extends EmptyVectorSXPImpl implements LglSXP { static final EmptyLglSXPImpl INSTANCE = new EmptyLglSXPImpl(); diff --git a/src/main/java/org/prlprg/sexp/ListSXP.java b/src/main/java/org/prlprg/sexp/ListSXP.java index 92c3c0733..ac6482af7 100644 --- a/src/main/java/org/prlprg/sexp/ListSXP.java +++ b/src/main/java/org/prlprg/sexp/ListSXP.java @@ -3,7 +3,11 @@ import com.google.common.collect.ImmutableList; import java.util.Iterator; import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; +import java.util.stream.Stream; +import javax.annotation.Nullable; /** * R "list". Confusingly, this is actually like Lua's @@ -14,23 +18,6 @@ * because it's more efficient. */ public sealed interface ListSXP extends ListOrVectorSXP permits NilSXP, ListSXPImpl { - /** - * Flatten {@code src} while adding its elements to {@code target}. Ex: - * - *

-   *   b = []; flatten([1, [2, 3], 4], b) ==> b = [1, 2, 3, 4]
-   * 
- */ - static void flatten(ListSXP src, ImmutableList.Builder target) { - for (var i : src) { - if (i.value() instanceof ListSXP lst) { - flatten(lst, target); - } else { - target.add(i); - } - } - } - @Override ListSXP withAttributes(Attributes attributes); @@ -39,6 +26,37 @@ static void flatten(ListSXP src, ImmutableList.Builder target) { List values(int fromIndex); List names(); + + List names(int fromIndex); + + ListSXP set(int index, @Nullable String tag, SEXP value); + + ListSXP appended(String tag, SEXP value); + + ListSXP appended(ListSXP other); + + ListSXP subList(int fromIndex); + + default boolean hasTags() { + return names().stream().anyMatch(x -> !x.isEmpty()); + } + + ListSXP remove(String tag); + + Stream stream(); + + Optional get(String name); + + ListSXP prepend(TaggedElem elem); + + @Override + default Class getCanonicalType() { + return ListSXP.class; + } + + default SEXP value(int i) { + return get(i).value(); + } } record ListSXPImpl(ImmutableList data, @Override Attributes attributes) @@ -60,12 +78,74 @@ public List values() { @Override public List values(int fromIndex) { - return values().subList(1, size()); + return values().subList(fromIndex, size()); } @Override public List names() { - return data.stream().map(TaggedElem::tag).toList(); + return data.stream().map(TaggedElem::tagOrEmpty).toList(); + } + + @Override + public List names(int fromIndex) { + return names().subList(fromIndex, size()); + } + + @Override + public ListSXP set(int index, @Nullable String tag, SEXP value) { + return new ListSXPImpl( + ImmutableList.builder() + .addAll(data.subList(0, index)) + .add(new TaggedElem(tag, value)) + .addAll(data.subList(index + 1, data.size())) + .build(), + attributes); + } + + @Override + public ListSXP appended(@Nullable String tag, SEXP value) { + return new ListSXPImpl( + ImmutableList.builder().addAll(data).add(new TaggedElem(tag, value)).build(), + attributes); + } + + @Override + public ListSXP appended(ListSXP other) { + return new ListSXPImpl( + ImmutableList.builder().addAll(data).addAll(other.iterator()).build(), + attributes); + } + + @Override + public ListSXP subList(int fromIndex) { + return new ListSXPImpl(data.subList(fromIndex, data.size()), attributes); + } + + @Override + public ListSXP remove(String tag) { + var builder = ImmutableList.builder(); + for (var i : this) { + if (!tag.equals(i.tag())) { + builder.add(i); + } + } + return new ListSXPImpl(builder.build(), Objects.requireNonNull(attributes())); + } + + @Override + public Stream stream() { + return data.stream(); + } + + @Override + public Optional get(String name) { + return Optional.empty(); + } + + @Override + public ListSXP prepend(TaggedElem elem) { + return new ListSXPImpl( + ImmutableList.builder().add(elem).addAll(data).build(), attributes); } @Override diff --git a/src/main/java/org/prlprg/sexp/NamespaceEnvSXP.java b/src/main/java/org/prlprg/sexp/NamespaceEnvSXP.java index 6045f37b1..f5846b48c 100644 --- a/src/main/java/org/prlprg/sexp/NamespaceEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/NamespaceEnvSXP.java @@ -1,39 +1,26 @@ package org.prlprg.sexp; -import java.util.Optional; +import java.util.Map; -public final class NamespaceEnvSXP implements EnvSXP { - private final EnvSXP parent; +public final class NamespaceEnvSXP extends AbstractEnvSXP implements EnvSXP { private final String name; private final String version; - public NamespaceEnvSXP(EnvSXP parent, String name, String version) { - this.parent = parent; + public NamespaceEnvSXP(String name, String version, EnvSXP parent, Map bindings) { + super(parent); + bindings.forEach(this::set); this.name = name; this.version = version; } - @Override - public EnvSXP parent() { - return parent; - } - - @Override - public Optional get(String name) { - // TODO: implement - return Optional.empty(); - } - - @Override - public Optional getLocal(String name) { - // TODO: implement - return Optional.empty(); - } - public String getVersion() { return version; } + public String getName() { + return name; + } + @Override public String toString() { // TODO: add some link to the R session? diff --git a/src/main/java/org/prlprg/sexp/NilSXP.java b/src/main/java/org/prlprg/sexp/NilSXP.java index e3e15fbf7..3df81c87b 100644 --- a/src/main/java/org/prlprg/sexp/NilSXP.java +++ b/src/main/java/org/prlprg/sexp/NilSXP.java @@ -1,7 +1,11 @@ package org.prlprg.sexp; +import com.google.common.collect.ImmutableList; import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; +import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; import org.prlprg.util.EmptyIterator; @@ -20,6 +24,11 @@ public SEXPType type() { return SEXPType.NIL; } + @Override + public Class getCanonicalType() { + return NilSXP.class; + } + @Override public String toString() { return "NULL"; @@ -71,6 +80,61 @@ public List names() { return Collections.emptyList(); } + @Override + public List names(int fromIndex) { + if (fromIndex == 0) { + return Collections.emptyList(); + } else { + throw new UnsupportedOperationException("NULL is empty"); + } + } + + @Override + public ListSXP set(int index, @Nullable String tag, SEXP value) { + throw new UnsupportedOperationException("NULL is empty"); + } + + @Override + public ListSXP appended(String tag, SEXP value) { + return new ListSXPImpl( + new ImmutableList.Builder().add(new TaggedElem(tag, value)).build(), + Attributes.NONE); + } + + @Override + public ListSXP appended(ListSXP other) { + return other; + } + + @Override + public ListSXP subList(int fromIndex) { + if (fromIndex == 0) { + return this; + } else { + throw new UnsupportedOperationException("NULL is empty"); + } + } + + @Override + public ListSXP remove(String tag) { + throw new UnsupportedOperationException("NULL is empty"); + } + + @Override + public Stream stream() { + throw new UnsupportedOperationException("NULL is empty"); + } + + @Override + public Optional get(String name) { + return Optional.empty(); + } + + @Override + public ListSXP prepend(TaggedElem elem) { + return SEXPs.list(List.of(elem)); + } + @Override public ListSXP withAttributes(Attributes attributes) { throw new UnsupportedOperationException("Cannot set attributes on NULL"); diff --git a/src/main/java/org/prlprg/sexp/NumericSXP.java b/src/main/java/org/prlprg/sexp/NumericSXP.java new file mode 100644 index 000000000..4d3671412 --- /dev/null +++ b/src/main/java/org/prlprg/sexp/NumericSXP.java @@ -0,0 +1,7 @@ +package org.prlprg.sexp; + +public sealed interface NumericSXP extends VectorSXP permits IntSXP, RealSXP { + int asInt(int index); + + double asReal(int index); +} diff --git a/src/main/java/org/prlprg/sexp/PromSXP.java b/src/main/java/org/prlprg/sexp/PromSXP.java index 4556ebde7..58433ef8b 100644 --- a/src/main/java/org/prlprg/sexp/PromSXP.java +++ b/src/main/java/org/prlprg/sexp/PromSXP.java @@ -27,4 +27,9 @@ public EnvSXP getEnv() { public SEXPType type() { return SEXPType.PROM; } + + @Override + public Class getCanonicalType() { + return PromSXP.class; + } } diff --git a/src/main/java/org/prlprg/sexp/RealSXP.java b/src/main/java/org/prlprg/sexp/RealSXP.java index 52de9918e..59052d6ea 100644 --- a/src/main/java/org/prlprg/sexp/RealSXP.java +++ b/src/main/java/org/prlprg/sexp/RealSXP.java @@ -1,13 +1,17 @@ package org.prlprg.sexp; +import com.google.common.math.DoubleMath; import com.google.common.primitives.ImmutableDoubleArray; +import java.util.Objects; import java.util.PrimitiveIterator; import javax.annotation.concurrent.Immutable; /** Real vector SEXP. */ @Immutable -public sealed interface RealSXP extends VectorSXP - permits EmptyRealSXPImpl, RealSXPImpl, SimpleRealSXP { +public sealed interface RealSXP extends NumericSXP + permits EmptyRealSXPImpl, RealSXPImpl, ScalarRealSXP { + double DOUBLE_CMP_DELTA = 0.000001d; + @Override default SEXPType type() { return SEXPType.REAL; @@ -18,6 +22,11 @@ default SEXPType type() { @Override RealSXP withAttributes(Attributes attributes); + + @Override + default Class getCanonicalType() { + return RealSXP.class; + } } /** Real vector which doesn't fit any of the more specific subclasses. */ @@ -46,6 +55,86 @@ public String toString() { public RealSXP withAttributes(Attributes attributes) { return SEXPs.real(data, attributes); } + + @Override + public int asInt(int index) { + return (int) data.get(index); + } + + @Override + public double asReal(int index) { + return data.get(index); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + var that = (RealSXPImpl) o; + var data2 = that.data; + if (data.length() != data2.length()) { + return false; + } + for (int i = 0; i < data.length(); i++) { + if (!DoubleMath.fuzzyEquals(data.get(i), data2.get(i), DOUBLE_CMP_DELTA)) { + return false; + } + } + return Objects.equals(attributes, that.attributes); + } + + @Override + public int hashCode() { + return Objects.hash(data, attributes); + } +} + +/** Simple scalar real = vector of size 1 with no ALTERP, ATTRIB, or OBJECT. */ +final class ScalarRealSXP extends ScalarSXPImpl implements RealSXP { + ScalarRealSXP(double data) { + super(data); + } + + @SuppressWarnings("MissingJavadoc") + public double value() { + return data; + } + + @Override + public RealSXP withAttributes(Attributes attributes) { + return SEXPs.real(data, attributes); + } + + @Override + public int asInt(int index) { + if (index == 0) { + return data.intValue(); + } else { + throw new ArrayIndexOutOfBoundsException("Index out of bounds: " + index); + } + } + + @Override + public double asReal(int index) { + if (index == 0) { + return data; + } else { + throw new ArrayIndexOutOfBoundsException("Index out of bounds: " + index); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + var that = (ScalarRealSXP) o; + return DoubleMath.fuzzyEquals(data, that.data, DOUBLE_CMP_DELTA); + } + + @Override + public int hashCode() { + return Objects.hash(data); + } } /** Empty real vector with no ALTREP, ATTRIB, or OBJECT. */ @@ -60,4 +149,14 @@ private EmptyRealSXPImpl() { public RealSXP withAttributes(Attributes attributes) { return SEXPs.real(ImmutableDoubleArray.of(), attributes); } + + @Override + public int asInt(int index) { + throw new ArrayIndexOutOfBoundsException("Empty real vector"); + } + + @Override + public double asReal(int index) { + throw new ArrayIndexOutOfBoundsException("Empty real vector"); + } } diff --git a/src/main/java/org/prlprg/sexp/RegSymSXP.java b/src/main/java/org/prlprg/sexp/RegSymSXP.java index a29f532c0..b8696f5c3 100644 --- a/src/main/java/org/prlprg/sexp/RegSymSXP.java +++ b/src/main/java/org/prlprg/sexp/RegSymSXP.java @@ -1,24 +1,14 @@ package org.prlprg.sexp; import com.google.common.base.Objects; -import com.google.common.collect.ImmutableList; import java.util.Optional; /** Symbol which isn't "unbound value" or "missing arg" */ public final class RegSymSXP implements SymSXP, StrOrRegSymSXP { - private static final ImmutableList LITERAL_NAMES = - ImmutableList.of("TRUE", "FALSE", "NULL", "NA", "Inf", "NaN"); - private final String name; private final boolean isEscaped; RegSymSXP(String name) { - if (name.isBlank()) { - throw new IllegalArgumentException("Symbol name cannot be blank"); - } - if (LITERAL_NAMES.contains(name)) { - throw new IllegalArgumentException("Symbol name reserved by literal: " + name); - } this.name = name; isEscaped = name.chars().anyMatch(c -> !Character.isAlphabetic(c) && c != '.' && c != '_'); } diff --git a/src/main/java/org/prlprg/sexp/SEXP.java b/src/main/java/org/prlprg/sexp/SEXP.java index 8cb874178..cf22a2019 100644 --- a/src/main/java/org/prlprg/sexp/SEXP.java +++ b/src/main/java/org/prlprg/sexp/SEXP.java @@ -1,6 +1,7 @@ package org.prlprg.sexp; -import java.util.Objects; +import com.google.common.collect.Streams; +import java.util.*; import javax.annotation.Nullable; /** @@ -11,7 +12,15 @@ * suspect GNU-R SEXPs aren't actually S-expressions. */ public sealed interface SEXP - permits StrOrRegSymSXP, SymOrLangSXP, ListOrVectorSXP, CloSXP, EnvSXP, BCodeSXP, PromSXP { + permits StrOrRegSymSXP, + SymOrLangSXP, + ListOrVectorSXP, + CloSXP, + EnvSXP, + BCodeSXP, + PromSXP, + BuiltinSXP, + SpecialSXP { /** * SEXPTYPE. It's important to distinguish these from the SEXP's class, because there's a class * for every type but not vice versa due to subclasses (e.g. simple-scalar ints have the same @@ -19,9 +28,18 @@ public sealed interface SEXP */ SEXPType type(); + /** + * The canonical class of this SEXP in the Java land. For example, there are specialized classes + * for simple scalars, but they all have the same SEXPType. Every SEXP that override the {@link + * #type()} method should also override this method. + * + * @return the Java class of this SEXP + */ + Class getCanonicalType(); + /** * @return {@code null} if the SEXP doesn't support attributes ({@link #withAttributes} throws an - * exception) and {@code Attributes.NONE} if it does but there are none. + * exception) and {@link Attributes.NONE} if it does but there are none. */ default @Nullable Attributes attributes() { return null; @@ -43,4 +61,51 @@ default SEXP withClass(String name) { var attrs = Objects.requireNonNull(attributes()).including("class", SEXPs.string(name)); return withAttributes(attrs); } + + /** + * The implementation of the is.function() which eventually calls the isFunction() from + * Rinlinedfuns.h + * + * @return {@code true} if this SEXP is a function (closure, builtin, or special). + */ + default boolean isFunction() { + return this instanceof CloSXP || this instanceof BuiltinSXP || this instanceof SpecialSXP; + } + + default SEXP withNames(String name) { + return withNames(SEXPs.string(name)); + } + + default SEXP withNames(Collection names) { + return withNames(SEXPs.string(names)); + } + + default SEXP withNames(StrSXP names) { + if (names.isEmpty()) { + return withAttributes(Objects.requireNonNull(attributes()).excluding("names")); + } else { + return withAttributes(Objects.requireNonNull(attributes()).including("names", names)); + } + } + + default List names() { + var names = Objects.requireNonNull(attributes()).get("names"); + if (names == null) { + return List.of(); + } else { + return Streams.stream((StrSXP) names).toList(); + } + } + + default boolean typeOneOf(SEXPType... types) { + return Arrays.stream(types).anyMatch(t -> t == type()); + } + + default Optional asLang() { + return as(LangSXP.class); + } + + default Optional as(Class clazz) { + return clazz.isInstance(this) ? Optional.of(clazz.cast(this)) : Optional.empty(); + } } diff --git a/src/main/java/org/prlprg/sexp/SEXPType.java b/src/main/java/org/prlprg/sexp/SEXPType.java index a07092c2b..ce8d4a031 100644 --- a/src/main/java/org/prlprg/sexp/SEXPType.java +++ b/src/main/java/org/prlprg/sexp/SEXPType.java @@ -6,7 +6,7 @@ *

SEXPTYPEs are fixed in GNU-R: we can represent SEXPs of custom types on the server, but they * have to be converted to something like external pointers if we have a client using the GNU-R * runtime (which is the only planned runtime). Furthermore, we don't refine existing SEXPTypes even - * when we refine the SEXP class: e.g. {@link SimpleIntSXP} has the same SEXPType as any other int + * when we refine the SEXP class: e.g. {@link ScalarIntSXP} has the same SEXPType as any other int * vector since the SEXPType for all int vectors is {@code INT}. Therefore it's important to * distinguish {@link SEXP#type} from the SEXP's class. * diff --git a/src/main/java/org/prlprg/sexp/SEXPs.java b/src/main/java/org/prlprg/sexp/SEXPs.java index fb45fac91..ae2e6cb03 100644 --- a/src/main/java/org/prlprg/sexp/SEXPs.java +++ b/src/main/java/org/prlprg/sexp/SEXPs.java @@ -3,8 +3,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.ImmutableDoubleArray; import com.google.common.primitives.ImmutableIntArray; -import java.util.Arrays; -import java.util.Collection; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nullable; @@ -18,12 +17,13 @@ public final class SEXPs { // region constants public static final NilSXP NULL = NilSXP.INSTANCE; - public static final SimpleLglSXP TRUE = SimpleLglSXP.TRUE; - public static final SimpleLglSXP FALSE = SimpleLglSXP.FALSE; - public static final SimpleLglSXP NA_LOGICAL = SimpleLglSXP.NA; - public static final SimpleIntSXP NA_INTEGER = new SimpleIntSXP(Constants.NA_INT); - public static final RealSXP NA_REAL = new SimpleRealSXP(Constants.NA_REAL); - public static final StrSXP NA_STRING = new SimpleStrSXP(Constants.NA_STRING); + public static final LglSXP TRUE = ScalarLglSXP.TRUE; + public static final LglSXP FALSE = ScalarLglSXP.FALSE; + public static final LglSXP NA_LOGICAL = ScalarLglSXP.NA; + public static final IntSXP NA_INTEGER = new ScalarIntSXP(Constants.NA_INT); + public static final RealSXP NA_REAL = new ScalarRealSXP(Constants.NA_REAL); + public static final StrSXP NA_STRING = new ScalarStrSXP(Constants.NA_STRING); + public static final ComplexSXP NA_COMPLEX = new ScalarComplexSXP(Constants.NA_COMPLEX); public static final LglSXP EMPTY_LOGICAL = EmptyLglSXPImpl.INSTANCE; public static final IntSXP EMPTY_INTEGER = EmptyIntSXPImpl.INSTANCE; public static final RealSXP EMPTY_REAL = EmptyRealSXPImpl.INSTANCE; @@ -33,7 +33,20 @@ public final class SEXPs { public static final SpecialSymSXP MISSING_ARG = new SpecialSymSXP("MISSING_ARG"); - public static final RegSymSXP ELLIPSIS = new RegSymSXP("..."); + private static final Map SYMBOL_POOL = new HashMap<>(); + + public static final RegSymSXP DOTS_SYMBOL = symbol("..."); + public static final RegSymSXP SUPER_ASSIGN = symbol("<<-"); + public static final RegSymSXP ASSIGN_TMP = symbol("*tmp*"); + public static final RegSymSXP ASSIGN_VTMP = symbol("*vtmp*"); + + static { + Set.of("TRUE", "FALSE", "NULL", "NA", "Inf", "NaN") + .forEach( + x -> { + SYMBOL_POOL.put(x, new RegSymSXP(x)); + }); + } public static final EmptyEnvSXP EMPTY_ENV = EmptyEnvSXP.INSTANCE; @@ -48,20 +61,20 @@ public static LglSXP logical(Logical data) { }; } - public static SimpleIntSXP integer(int data) { - return new SimpleIntSXP(data); + public static IntSXP integer(int data) { + return new ScalarIntSXP(data); } - public static SimpleRealSXP real(double data) { - return new SimpleRealSXP(data); + public static RealSXP real(double data) { + return new ScalarRealSXP(data); } - public static SimpleStrSXP string(String data) { - return new SimpleStrSXP(data); + public static StrSXP string(String data) { + return new ScalarStrSXP(data); } - public static SimpleComplexSXP complex(Complex data) { - return new SimpleComplexSXP(data); + public static ComplexSXP complex(Complex data) { + return new ScalarComplexSXP(data); } public static IntSXP integer(int first, int... rest) { @@ -96,10 +109,18 @@ public static IntSXP integer(int[] data) { return integer(ImmutableIntArray.copyOf(data)); } + public static IntSXP integer(Integer[] data) { + return integer(ImmutableIntArray.copyOf(Arrays.asList(data))); + } + public static RealSXP real(double[] data) { return real(ImmutableDoubleArray.copyOf(data)); } + public static RealSXP real(Double[] data) { + return real(ImmutableDoubleArray.copyOf(Arrays.asList(data))); + } + public static StrSXP string(String[] data) { return string(ImmutableList.copyOf(data)); } @@ -116,8 +137,8 @@ public static ExprSXP expr(SEXP... data) { return expr(ImmutableList.copyOf(data)); } - public static SimpleComplexSXP complex(double real, double imaginary) { - return new SimpleComplexSXP(new Complex(real, imaginary)); + public static ComplexSXP complex(double real, double imaginary) { + return new ScalarComplexSXP(new Complex(real, imaginary)); } public static LglSXP logical(Logical data, Attributes attributes) { @@ -316,7 +337,7 @@ public static ExprSXP expr(Collection data, Attributes attributes) { } public static ListSXP list(SEXP... data) { - return list(Arrays.stream(data).map(TaggedElem::new).collect(Collectors.toList())); + return list(Arrays.stream(data).map(TaggedElem::new).toList()); } public static ListSXP list(ImmutableList data) { @@ -327,6 +348,11 @@ public static ListSXP list(Collection data) { return list(ImmutableList.copyOf(data)); } + // FIXME: ugly + public static ListSXP list2(Collection data) { + return list(data.stream().map(TaggedElem::new).toList()); + } + public static ListSXP list(TaggedElem[] data, Attributes attributes) { return list(ImmutableList.copyOf(data), attributes); } @@ -365,10 +391,7 @@ public static LangSXP lang(SymOrLangSXP fun, ListSXP args, Attributes attributes } public static RegSymSXP symbol(String name) { - if (name.equals("...")) { - return ELLIPSIS; - } - return new RegSymSXP(name); + return SYMBOL_POOL.computeIfAbsent(name, RegSymSXP::new); } // endregion @@ -389,4 +412,24 @@ static String toString(SEXP sexp, Object... data) { } private SEXPs() {} + + public static SEXP builtin(String name) { + return new BuiltinSXP(name); + } + + public static SEXP special(String name) { + return new SpecialSXP(name); + } + + @SuppressWarnings("unchecked") + public static VectorSXP vector(SEXPType type, ImmutableList build) { + return switch (type) { + case LGL -> (VectorSXP) logical((List) build); + case INT -> (VectorSXP) integer((List) build); + case REAL -> (VectorSXP) real((List) build); + case STR -> (VectorSXP) string((List) build); + case CPLX -> (VectorSXP) complex((List) build); + default -> throw new IllegalArgumentException("Unsupported type: " + type); + }; + } } diff --git a/src/main/java/org/prlprg/sexp/SimpleComplexSXP.java b/src/main/java/org/prlprg/sexp/SimpleComplexSXP.java deleted file mode 100644 index 281fd86e6..000000000 --- a/src/main/java/org/prlprg/sexp/SimpleComplexSXP.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.prlprg.sexp; - -import org.prlprg.primitive.Complex; - -/** Simple scalar complex = vector of size 1 with no ALTERP, ATTRIB, or OBJECT. */ -public final class SimpleComplexSXP extends SimpleScalarSXPImpl implements ComplexSXP { - SimpleComplexSXP(Complex data) { - super(data); - } - - @SuppressWarnings("MissingJavadoc") - public Complex value() { - return data; - } - - @Override - public ComplexSXP withAttributes(Attributes attributes) { - return SEXPs.complex(data, attributes); - } -} diff --git a/src/main/java/org/prlprg/sexp/SimpleIntSXP.java b/src/main/java/org/prlprg/sexp/SimpleIntSXP.java deleted file mode 100644 index 6876acd3b..000000000 --- a/src/main/java/org/prlprg/sexp/SimpleIntSXP.java +++ /dev/null @@ -1,25 +0,0 @@ -package org.prlprg.sexp; - -import com.google.common.primitives.ImmutableIntArray; - -/** Simple scalar integer = int vector of size 1 with no ALTREP, ATTRIB, or OBJECT. */ -public final class SimpleIntSXP extends SimpleScalarSXPImpl implements IntSXP { - SimpleIntSXP(int data) { - super(data); - } - - @SuppressWarnings("MissingJavadoc") - public int value() { - return data; - } - - @Override - public ImmutableIntArray data() { - return ImmutableIntArray.of(data); - } - - @Override - public IntSXP withAttributes(Attributes attributes) { - return SEXPs.integer(data, attributes); - } -} diff --git a/src/main/java/org/prlprg/sexp/SimpleLglSXP.java b/src/main/java/org/prlprg/sexp/SimpleLglSXP.java deleted file mode 100644 index 8f2fd1a9a..000000000 --- a/src/main/java/org/prlprg/sexp/SimpleLglSXP.java +++ /dev/null @@ -1,24 +0,0 @@ -package org.prlprg.sexp; - -import org.prlprg.primitive.Logical; - -/** Simple scalar logical = logical vector of size 1 with no ALTREP, ATTRIB, or OBJECT. */ -public final class SimpleLglSXP extends SimpleScalarSXPImpl implements LglSXP { - static final SimpleLglSXP TRUE = new SimpleLglSXP(Logical.TRUE); - static final SimpleLglSXP FALSE = new SimpleLglSXP(Logical.FALSE); - static final SimpleLglSXP NA = new SimpleLglSXP(Logical.NA); - - private SimpleLglSXP(Logical data) { - super(data); - } - - @SuppressWarnings("MissingJavadoc") - public Logical value() { - return data; - } - - @Override - public LglSXP withAttributes(Attributes attributes) { - return SEXPs.logical(data, attributes); - } -} diff --git a/src/main/java/org/prlprg/sexp/SimpleRealSXP.java b/src/main/java/org/prlprg/sexp/SimpleRealSXP.java deleted file mode 100644 index a0cf0f829..000000000 --- a/src/main/java/org/prlprg/sexp/SimpleRealSXP.java +++ /dev/null @@ -1,18 +0,0 @@ -package org.prlprg.sexp; - -/** Simple scalar real = vector of size 1 with no ALTERP, ATTRIB, or OBJECT. */ -public final class SimpleRealSXP extends SimpleScalarSXPImpl implements RealSXP { - SimpleRealSXP(double data) { - super(data); - } - - @SuppressWarnings("MissingJavadoc") - public double value() { - return data; - } - - @Override - public RealSXP withAttributes(Attributes attributes) { - return SEXPs.real(data, attributes); - } -} diff --git a/src/main/java/org/prlprg/sexp/SimpleScalarSXPImpl.java b/src/main/java/org/prlprg/sexp/SimpleScalarSXPImpl.java deleted file mode 100644 index 078d170ab..000000000 --- a/src/main/java/org/prlprg/sexp/SimpleScalarSXPImpl.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.prlprg.sexp; - -import com.google.common.base.Objects; -import com.google.common.collect.Iterators; -import com.google.common.collect.UnmodifiableIterator; -import javax.annotation.concurrent.Immutable; - -/** Class for representing a scalar SEXP of a primitive type with no attributes. */ -@Immutable -abstract class SimpleScalarSXPImpl { - final T data; - - protected SimpleScalarSXPImpl(T data) { - this.data = data; - } - - public UnmodifiableIterator iterator() { - return Iterators.forArray(data); - } - - public T get(int i) { - if (i != 0) { - throw new IndexOutOfBoundsException(); - } - return data; - } - - public int size() { - return 1; - } - - @Override - public String toString() { - return data.toString(); - } - - public Attributes attributes() { - return Attributes.NONE; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof SimpleScalarSXPImpl that)) return false; - return Objects.equal(data, that.data); - } - - @Override - public int hashCode() { - return Objects.hashCode(data); - } -} diff --git a/src/main/java/org/prlprg/sexp/SimpleStrSXP.java b/src/main/java/org/prlprg/sexp/SimpleStrSXP.java deleted file mode 100644 index 11fc9728d..000000000 --- a/src/main/java/org/prlprg/sexp/SimpleStrSXP.java +++ /dev/null @@ -1,30 +0,0 @@ -package org.prlprg.sexp; - -import java.util.Optional; - -/** Simple scalar string = vector of size 1 with no ALTERP, ATTRIB, or OBJECT. */ -public final class SimpleStrSXP extends SimpleScalarSXPImpl implements StrSXP { - SimpleStrSXP(String data) { - super(data); - } - - @SuppressWarnings("MissingJavadoc") - public String value() { - return data; - } - - @Override - public String toString() { - return StrSXPs.quoteString(data); - } - - @Override - public StrSXP withAttributes(Attributes attributes) { - return SEXPs.string(data, attributes); - } - - @Override - public Optional reifyString() { - return Optional.of(data); - } -} diff --git a/src/main/java/org/prlprg/sexp/SpecialSXP.java b/src/main/java/org/prlprg/sexp/SpecialSXP.java new file mode 100644 index 000000000..c24fc2546 --- /dev/null +++ b/src/main/java/org/prlprg/sexp/SpecialSXP.java @@ -0,0 +1,13 @@ +package org.prlprg.sexp; + +public record SpecialSXP(String name) implements SEXP { + @Override + public SEXPType type() { + return SEXPType.SPECIAL; + } + + @Override + public Class getCanonicalType() { + return SpecialSXP.class; + } +} diff --git a/src/main/java/org/prlprg/sexp/StrSXP.java b/src/main/java/org/prlprg/sexp/StrSXP.java index a0d005357..176b4b74d 100644 --- a/src/main/java/org/prlprg/sexp/StrSXP.java +++ b/src/main/java/org/prlprg/sexp/StrSXP.java @@ -6,17 +6,21 @@ import com.google.common.escape.Escapers; import java.util.Optional; import javax.annotation.concurrent.Immutable; -import org.prlprg.primitive.Constants; /** String vector SEXP. */ @Immutable public sealed interface StrSXP extends VectorSXP, StrOrRegSymSXP - permits EmptyStrSXPImpl, SimpleStrSXP, StrSXPImpl { + permits EmptyStrSXPImpl, ScalarStrSXP, StrSXPImpl { @Override default SEXPType type() { return SEXPType.STR; } + @Override + default Class getCanonicalType() { + return StrSXP.class; + } + @Override Attributes attributes(); @@ -57,6 +61,32 @@ public Optional reifyString() { } } +final class ScalarStrSXP extends ScalarSXPImpl implements StrSXP { + ScalarStrSXP(String data) { + super(data); + } + + @SuppressWarnings("MissingJavadoc") + public String value() { + return data; + } + + @Override + public String toString() { + return StrSXPs.quoteString(data); + } + + @Override + public StrSXP withAttributes(Attributes attributes) { + return SEXPs.string(data, attributes); + } + + @Override + public Optional reifyString() { + return Optional.of(data); + } +} + /** Empty string vector with no ALTREP, ATTRIB, or OBJECT. */ final class EmptyStrSXPImpl extends EmptyVectorSXPImpl implements StrSXP { static final EmptyStrSXPImpl INSTANCE = new EmptyStrSXPImpl(); @@ -85,7 +115,7 @@ final class StrSXPs { .build(); static String quoteString(String s) { - return Constants.isNaString(s) ? "NA" : "\"" + rEscaper.escape(s) + "\""; + return Coercions.isNA(s) ? "NA" : "\"" + rEscaper.escape(s) + "\""; } private StrSXPs() {} diff --git a/src/main/java/org/prlprg/sexp/SymSXP.java b/src/main/java/org/prlprg/sexp/SymSXP.java index 42adae122..3e608592f 100644 --- a/src/main/java/org/prlprg/sexp/SymSXP.java +++ b/src/main/java/org/prlprg/sexp/SymSXP.java @@ -10,9 +10,14 @@ default SEXPType type() { return SEXPType.SYM; } + @Override + default Class getCanonicalType() { + return SymSXP.class; + } + /** Whether this is the ellipsis symbol. */ default boolean isEllipsis() { - return this == SEXPs.ELLIPSIS; + return this == SEXPs.DOTS_SYMBOL; } /** Whether this is the missing symbol. */ diff --git a/src/main/java/org/prlprg/sexp/TaggedElem.java b/src/main/java/org/prlprg/sexp/TaggedElem.java index 9c3d46f95..be6a1a0f1 100644 --- a/src/main/java/org/prlprg/sexp/TaggedElem.java +++ b/src/main/java/org/prlprg/sexp/TaggedElem.java @@ -22,4 +22,22 @@ public String toString() { ? value.toString() : value == SEXPs.MISSING_ARG ? tag + "=" : tag + "=" + value; } + + public SEXP namedValue() { + if (tag == null) { + return value; + } else { + return value.withNames(tag); + } + } + + /** + * Returns the tag or an empty string if the tag is null. This is to follow what GNU-R does when + * printing names. + * + * @return the tag or an empty string if the tag is null + */ + public String tagOrEmpty() { + return tag == null ? "" : tag; + } } diff --git a/src/main/java/org/prlprg/sexp/UserEnvSXP.java b/src/main/java/org/prlprg/sexp/UserEnvSXP.java index 7db31298e..6ab89500a 100644 --- a/src/main/java/org/prlprg/sexp/UserEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/UserEnvSXP.java @@ -1,13 +1,8 @@ package org.prlprg.sexp; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; import javax.annotation.Nullable; -public final class UserEnvSXP implements EnvSXP { - private EnvSXP parent; - private final Map entries; +public final class UserEnvSXP extends AbstractEnvSXP implements EnvSXP { private @Nullable Attributes attributes; public UserEnvSXP() { @@ -15,23 +10,11 @@ public UserEnvSXP() { } public UserEnvSXP(EnvSXP parent) { - this.parent = parent; - this.entries = new HashMap<>(); + super(parent); } - @Override - public EnvSXP parent() { - return parent; - } - - @Override - public Optional get(String name) { - return getLocal(name).or(() -> parent.get(name)); - } - - @Override - public Optional getLocal(String name) { - return Optional.ofNullable(entries.get(name)); + public void setParent(EnvSXP parent) { + this.parent = parent; } @Override @@ -45,15 +28,7 @@ public UserEnvSXP withAttributes(Attributes attributes) { return this; } - public void setParent(EnvSXP parent) { - this.parent = parent; - } - - public void set(String name, SEXP value) { - entries.put(name, value); - } - - public void setAttributes(Attributes attributes) { - this.attributes = attributes; + @Nullable public Attributes getAttributes() { + return attributes; } } diff --git a/src/main/java/org/prlprg/sexp/VecSXP.java b/src/main/java/org/prlprg/sexp/VecSXP.java index 1ebc4dfc0..9876c5264 100644 --- a/src/main/java/org/prlprg/sexp/VecSXP.java +++ b/src/main/java/org/prlprg/sexp/VecSXP.java @@ -1,7 +1,10 @@ package org.prlprg.sexp; +import com.google.common.base.Objects; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterators; import com.google.common.collect.UnmodifiableIterator; +import javax.annotation.concurrent.Immutable; /** Generic vector SEXP = vector which contains SEXPs. */ public sealed interface VecSXP extends VectorSXP { @@ -10,6 +13,11 @@ default SEXPType type() { return SEXPType.VEC; } + @Override + default Class getCanonicalType() { + return VecSXP.class; + } + @Override Attributes attributes(); @@ -43,3 +51,86 @@ public VecSXP withAttributes(Attributes attributes) { return SEXPs.vec(data, attributes); } } + +/** Class for representing a scalar SEXP of a primitive type with no attributes. */ +@Immutable +abstract sealed class ScalarSXPImpl implements VectorSXP + permits ScalarComplexSXP, ScalarIntSXP, ScalarLglSXP, ScalarRealSXP, ScalarStrSXP { + final T data; + + protected ScalarSXPImpl(T data) { + this.data = data; + } + + public UnmodifiableIterator iterator() { + return Iterators.forArray(data); + } + + public T get(int i) { + if (i != 0) { + throw new IndexOutOfBoundsException(); + } + return data; + } + + public int size() { + return 1; + } + + @Override + public String toString() { + return data.toString(); + } + + public Attributes attributes() { + return Attributes.NONE; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof ScalarSXPImpl that)) return false; + return Objects.equal(data, that.data); + } + + @Override + public int hashCode() { + return Objects.hashCode(data); + } +} + +/** Class for representing a scalar SEXP of a primitive type with no attributes. */ +@Immutable +abstract sealed class EmptyVectorSXPImpl implements VectorSXP + permits EmptyComplexSXPImpl, + EmptyIntSXPImpl, + EmptyLglSXPImpl, + EmptyRealSXPImpl, + EmptyStrSXPImpl { + protected EmptyVectorSXPImpl() {} + + @Override + public UnmodifiableIterator iterator() { + return Iterators.forArray(); + } + + @Override + public T get(int i) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int size() { + return 0; + } + + @Override + public String toString() { + return ""; + } + + @Override + public Attributes attributes() { + return Attributes.NONE; + } +} diff --git a/src/main/java/org/prlprg/sexp/VectorSXP.java b/src/main/java/org/prlprg/sexp/VectorSXP.java index 617cfe306..9f36a7a86 100644 --- a/src/main/java/org/prlprg/sexp/VectorSXP.java +++ b/src/main/java/org/prlprg/sexp/VectorSXP.java @@ -1,15 +1,82 @@ package org.prlprg.sexp; import java.util.stream.BaseStream; +import org.prlprg.primitive.Complex; +import org.prlprg.primitive.Logical; /** SEXP vector (immutable list). */ public sealed interface VectorSXP extends ListOrVectorSXP - permits ComplexSXP, ExprSXP, IntSXP, LglSXP, RealSXP, StrSXP, VecSXP, EmptyVectorSXPImpl { + permits ComplexSXP, + EmptyVectorSXPImpl, + ScalarSXPImpl, + ExprSXP, + LglSXP, + NumericSXP, + StrSXP, + VecSXP { @Override Attributes attributes(); @Override VectorSXP withAttributes(Attributes attributes); + + default boolean isScalar() { + return size() == 1; + } + + /** + * Coerce the elements of this vector to strings. + * + * @return the elements as strings. + */ + default String[] coerceToStrings() { + return coerceTo(String.class); + } + + default Double[] coerceToReals() { + return coerceTo(Double.class); + } + + default Integer[] coerceToInts() { + return coerceTo(Integer.class); + } + + default Logical[] coerceToLogicals() { + return coerceTo(Logical.class); + } + + default Complex[] coerceToComplexes() { + return coerceTo(Complex.class); + } + + default R[] coerceTo(Class clazz) { + Object[] target; + SEXPType targetType; + if (clazz == String.class) { + target = new String[size()]; + targetType = SEXPType.STR; + } else if (clazz == Double.class) { + target = new Double[size()]; + targetType = SEXPType.REAL; + } else if (clazz == Integer.class) { + target = new Integer[size()]; + targetType = SEXPType.INT; + } else if (clazz == Logical.class) { + target = new Logical[size()]; + targetType = SEXPType.LGL; + } else if (clazz == Complex.class) { + target = new Complex[size()]; + targetType = SEXPType.CPLX; + } else { + throw new IllegalArgumentException("Unsupported target type: " + clazz); + } + + for (int i = 0; i < size(); i++) { + target[i] = Coercions.coerce(get(i), targetType); + } + + return (R[]) target; + } } final class VectorSXPs { diff --git a/src/main/java/org/prlprg/util/Arithmetic.java b/src/main/java/org/prlprg/util/Arithmetic.java new file mode 100644 index 000000000..a3322316d --- /dev/null +++ b/src/main/java/org/prlprg/util/Arithmetic.java @@ -0,0 +1,234 @@ +package org.prlprg.util; + +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; +import org.prlprg.primitive.Complex; +import org.prlprg.sexp.*; + +public interface Arithmetic { + Arithmetic INTEGER = + new Arithmetic<>() { + // FIXME: check for overflow + @Override + public Integer add(Integer a, Integer b) { + return a + b; + } + + @Override + public Integer sub(Integer a, Integer b) { + return a - b; + } + + // FIXME: check for overflow + @Override + public Integer mul(Integer a, Integer b) { + return a * b; + } + + @Override + public Integer div(Integer a, Integer b) { + return a / b; + } + + @Override + public Integer pow(Integer a, Integer b) { + throw new UnsupportedOperationException("pow on integers"); + } + + @Override + public Integer plus(Integer a) { + return a; + } + + @Override + public Integer minus(Integer a) { + return -a; + } + + @Override + public Integer[] createResult(int size) { + return new Integer[size]; + } + + @Override + public Integer[] fromSEXP(VectorSXP vec) { + return vec.coerceToInts(); + } + + @Override + public VectorSXP toSEXP(Integer[] ts) { + return SEXPs.integer(ts); + } + }; + + Arithmetic DOUBLE = + new Arithmetic<>() { + @Override + public Double add(Double a, Double b) { + return a + b; + } + + @Override + public Double sub(Double a, Double b) { + return a - b; + } + + @Override + public Double mul(Double a, Double b) { + return a * b; + } + + @Override + public Double div(Double a, Double b) { + return a / b; + } + + @Override + public Double pow(Double a, Double b) { + return Math.pow(a, b); + } + + @Override + public Double plus(Double a) { + return a; + } + + @Override + public Double minus(Double a) { + return -a; + } + + @Override + public Double[] createResult(int size) { + return new Double[size]; + } + + @Override + public Double[] fromSEXP(VectorSXP vec) { + return vec.coerceToReals(); + } + + @Override + public VectorSXP toSEXP(Double[] ts) { + return SEXPs.real(ts); + } + }; + Arithmetic COMPLEX = + new Arithmetic<>() { + @Override + public Complex add(Complex a, Complex b) { + return a.add(b); + } + + @Override + public Complex sub(Complex a, Complex b) { + return a.sub(b); + } + + @Override + public Complex mul(Complex a, Complex b) { + return a.mul(b); + } + + @Override + public Complex div(Complex a, Complex b) { + return a.div(b); + } + + @Override + public Complex pow(Complex a, Complex b) { + return a.pow(b); + } + + @Override + public Complex plus(Complex a) { + return a; + } + + @Override + public Complex minus(Complex a) { + return a.minus(); + } + + @Override + public Complex[] createResult(int size) { + return new Complex[size]; + } + + @Override + public Complex[] fromSEXP(VectorSXP vec) { + return vec.coerceToComplexes(); + } + + @Override + public VectorSXP toSEXP(Complex[] ts) { + return SEXPs.complex(ts); + } + }; + + static Optional> forType(SEXPType type) { + var arith = + switch (type) { + case INT -> INTEGER; + case REAL -> DOUBLE; + case CPLX -> COMPLEX; + default -> null; + }; + return Optional.ofNullable(arith); + } + + static Optional> forType(List args) { + return forType(Coercions.commonType(args)); + } + + T add(T a, T b); + + T sub(T a, T b); + + T mul(T a, T b); + + T div(T a, T b); + + T pow(T a, T b); + + T plus(T a); + + T minus(T a); + + T[] createResult(int size); + + T[] fromSEXP(VectorSXP vec); + + VectorSXP toSEXP(T[] ts); + + default BiFunction getBinaryFun(Operation op) { + return switch (op) { + case ADD -> this::add; + case SUB -> this::sub; + case MUL -> this::mul; + case DIV -> this::div; + case POW -> this::pow; + default -> throw new IllegalArgumentException("Unsupported binary operation: " + op); + }; + } + + default Function getUnaryFun(Operation op) { + return switch (op) { + case PLUS -> this::plus; + case MINUS -> this::minus; + default -> throw new IllegalArgumentException("Unsupported unary operation: " + op); + }; + } + + enum Operation { + ADD, + SUB, + MUL, + DIV, + POW, + PLUS, + MINUS + } +} diff --git a/src/main/java/org/prlprg/util/Either.java b/src/main/java/org/prlprg/util/Either.java index 22dc5d89a..ea521e5fd 100644 --- a/src/main/java/org/prlprg/util/Either.java +++ b/src/main/java/org/prlprg/util/Either.java @@ -1,5 +1,7 @@ package org.prlprg.util; +import java.util.NoSuchElementException; + @SuppressWarnings("MissingJavadoc") public sealed interface Either permits Left, Right { static Either left(L left) { @@ -9,8 +11,56 @@ static Either left(L left) { static Either right(R right) { return new Right<>(right); } + + boolean isLeft(); + + boolean isRight(); + + L getLeft(); + + R getRight(); } -record Left(L left) implements Either {} +record Left(L left) implements Either { + @Override + public boolean isLeft() { + return true; + } -record Right(R right) implements Either {} + @Override + public boolean isRight() { + return false; + } + + @Override + public L getLeft() { + return left; + } + + @Override + public R getRight() { + throw new NoSuchElementException("This either contains left value"); + } +} + +record Right(R right) implements Either { + @Override + public boolean isLeft() { + return false; + } + + @Override + public boolean isRight() { + return true; + } + + @Override + public L getLeft() { + throw new NoSuchElementException("This either contains right value"); + } + + @Override + public R getRight() { + return right; + } +} diff --git a/src/test/java/org/prlprg/bc/BcTests.java b/src/test/java/org/prlprg/bc/BcTests.java index ec723e5b7..fc7978682 100644 --- a/src/test/java/org/prlprg/bc/BcTests.java +++ b/src/test/java/org/prlprg/bc/BcTests.java @@ -10,35 +10,39 @@ public class BcTests { @Test @DisplayName("Create bytecode array") void createBcArray() { - var bcBuilder = new Bc.Builder(); - bcBuilder.setTrackSrcRefs(false); - bcBuilder.setTrackExpressions(false); + var bcb = new Bc.Builder(); + bcb.setTrackSrcRefs(false); + bcb.setTrackExpressions(false); + var ast = SEXPs.lang(SEXPs.symbol("+"), SEXPs.list(SEXPs.integer(1), SEXPs.integer(2))); - // It doesn't make sense to implement SEXP#clone because you'd just reuse the SEXP since - // they are immutable. + + // It doesn't make sense to implement SEXP#clone because you'd just reuse the + // SEXP since they are immutable. // Only SEXP.withXYZ(...) methods. var astClone = SEXPs.lang(SEXPs.symbol("+"), SEXPs.list(SEXPs.integer(1), SEXPs.integer(2))); - bcBuilder.addInstr(new BcInstr.LdConst(bcBuilder.addConst(SEXPs.integer(1)))); - bcBuilder.addInstr(new BcInstr.LdConst(bcBuilder.addConst(SEXPs.integer(2)))); - bcBuilder.addInstr(new BcInstr.Add(bcBuilder.addConst(ast))); - bcBuilder.addInstr(new BcInstr.Return()); - var bc = bcBuilder.build(); + bcb.addInstr(new BcInstr.LdConst(bcb.addConst(SEXPs.integer(1)))); + bcb.addInstr(new BcInstr.LdConst(bcb.addConst(SEXPs.integer(2)))); + bcb.addInstr(new BcInstr.Add(bcb.addConst(ast))); + bcb.addInstr(new BcInstr.Return()); + + var bc = bcb.build(); + assertEquals(4, bc.code().size()); assertEquals(BcOp.LDCONST, bc.code().get(0).op()); assertEquals(BcOp.LDCONST, bc.code().get(1).op()); assertEquals(BcOp.ADD, bc.code().get(2).op()); assertEquals(BcOp.RETURN, bc.code().get(3).op()); + var consts = bc.consts().stream().toList(); + assertEquals(3, consts.size()); assertEquals(SEXPs.integer(1), consts.get(0)); assertEquals(SEXPs.integer(2), consts.get(1)); assertEquals(astClone, consts.get(2)); - assertEquals(0, ((BcInstr.LdConst) bc.code().get(0)).constant().idx); - assertEquals(bc.consts(), ((BcInstr.LdConst) bc.code().get(0)).constant().pool); - assertEquals(1, ((BcInstr.LdConst) bc.code().get(1)).constant().idx); - assertEquals(bc.consts(), ((BcInstr.LdConst) bc.code().get(1)).constant().pool); - assertEquals(2, ((BcInstr.Add) bc.code().get(2)).ast().idx); - assertEquals(bc.consts(), ((BcInstr.Add) bc.code().get(2)).ast().pool); + + assertEquals(0, ((BcInstr.LdConst) bc.code().get(0)).constant().idx()); + assertEquals(1, ((BcInstr.LdConst) bc.code().get(1)).constant().idx()); + assertEquals(2, ((BcInstr.Add) bc.code().get(2)).ast().idx()); } } diff --git a/src/test/java/org/prlprg/bc/CompilerTest.java b/src/test/java/org/prlprg/bc/CompilerTest.java index 053017013..cb076b26c 100644 --- a/src/test/java/org/prlprg/bc/CompilerTest.java +++ b/src/test/java/org/prlprg/bc/CompilerTest.java @@ -1,22 +1,34 @@ package org.prlprg.bc; +import static com.google.common.truth.Truth.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.prlprg.util.StructuralUtils.printStructurally; +import static org.junit.jupiter.api.Assertions.fail; +import com.google.common.collect.Streams; import java.io.File; import java.io.IOException; +import java.util.List; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; -import org.prlprg.RSession; -import org.prlprg.rsession.TestRSession; -import org.prlprg.sexp.BCodeSXP; -import org.prlprg.sexp.CloSXP; -import org.prlprg.sexp.SEXPs; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.prlprg.sexp.*; import org.prlprg.util.*; +import org.prlprg.util.AbstractGNURBasedTest; -public class CompilerTest implements Tests { +public class CompilerTest extends AbstractGNURBasedTest implements Tests { - private final RSession rsession = new TestRSession(); - private final GNUR R = new GNUR(rsession); + @Test + public void testEmptyList() { + assertBytecode( + """ + function () + { + function(other = list()) 1 + } + """); + } @Test public void testEmptyBlock() { @@ -69,8 +81,6 @@ public void testFunctionLeftParenInlining() { assertBytecode(""" function(x) (...) """); - - // TODO: (x <- 1) } @Test @@ -101,93 +111,93 @@ public void specialsInlining() { @Test public void inlineLocal() { assertBytecode(""" -function(x) local(x) -"""); + function(x) local(x) + """); } @Test public void inlineReturn() { assertBytecode(""" -function(x) return(x) -"""); + function(x) return(x) + """); } @Test public void inlineBuiltinsInternal() { assertBytecode(""" -function(x) .Internal(inspect(x)) -"""); + function(x) .Internal(inspect(x)) + """); assertBytecode(""" -function(x) .Internal(inspect2(x)) -"""); + function(x) .Internal(inspect2(x)) + """); } @Test public void inlineLogicalAnd() { assertBytecode(""" -function(x, y) x && y -"""); + function(x, y) x && y + """); assertBytecode(""" -function(x, y, z) x && y && z -"""); + function(x, y, z) x && y && z + """); assertBytecode(""" -function(x, y) x && y && (x && y) -"""); + function(x, y) x && y && (x && y) + """); } @Test public void inlineLogicalOr() { assertBytecode(""" -function(x, y) x || y -"""); + function(x, y) x || y + """); assertBytecode(""" -function(x, y, z) x || y || z -"""); + function(x, y, z) x || y || z + """); assertBytecode(""" -function(x, y) x || y || (x || y) -"""); + function(x, y) x || y || (x || y) + """); } @Test public void inlineLogicalAndOr() { assertBytecode(""" -function(x, y) x && y || y -"""); + function(x, y) x && y || y + """); assertBytecode(""" -function(x, y, z) x || y && z -"""); + function(x, y, z) x || y && z + """); } @Test public void inlineRepeat() { assertBytecode(""" - function(x) repeat(x) - """); + function(x) repeat(x) + """); assertBytecode(""" function(x, y) repeat({ if (x) break() else y }) """); assertBytecode(""" - function(x, y) repeat({ if (x) next() else y }) - """); + function(x, y) repeat({ if (x) next() else y }) + """); assertBytecode(""" - function(x, y=break()) repeat({ if (x) y else 1 }) - """); + function(x, y=break()) repeat({ if (x) y else 1 }) + """); } @Test public void inlineWhile() { assertBytecode(""" - function(x) while(x) 1 - """); + function(x) while(x) 1 + """); assertBytecode(""" function(x, y) while(x) { break() } @@ -201,58 +211,59 @@ public void inlineWhile() { @Test public void inlineFor() { assertBytecode(""" -function(x) for (i in x) 1 -"""); + function(x) for (i in x) 1 + """); assertBytecode(""" -function(x) for (i in x) if (i) break() else 1 -"""); + function(x) for (i in x) if (i) break() else 1 + """); } @Test public void inlineArithmetics() { assertBytecode(""" - function(x, y) x + y - """); + function(x, y) x + y + """); assertBytecode(""" - function(x, y) x - y - """); + function(x, y) x - y + """); - assertBytecode(""" - function(x, y) { - list(x + y - x + 10, -x + 1, +y) - } - """); + assertBytecode( + """ + function(x, y) { + list(x + y - x + 10, -x + 1, +y) + } + """); assertBytecode( """ - function(x, y) { - list(x * y / x * 10, exp(x) ^ 2, sqrt(exp(x))) - } - """); + function(x, y) { + list(x * y / x * 10, exp(x) ^ 2, sqrt(exp(x))) + } + """); assertBytecode(""" function(x, y) { list(log(x), log(x, y)) } - """); + """); } @Test public void inlineMath1() { assertBytecode( """ - function(x) { - list( - floor(x), ceiling(x), sign(x), - expm1(x), log1p(x), - cos(x), sin(x), tan(x), acos(x), asin(x), atan(x), - cosh(x), sinh(x), tanh(x), acosh(x), asinh(x), atanh(x), - lgamma(x), gamma(x), digamma(x), trigamma(x), - cospi(x), sinpi(x), tanpi(x) - ) - } + function(x) { + list( + floor(x), ceiling(x), sign(x), + expm1(x), log1p(x), + cos(x), sin(x), tan(x), acos(x), asin(x), atan(x), + cosh(x), sinh(x), tanh(x), acosh(x), asinh(x), atanh(x), + lgamma(x), gamma(x), digamma(x), trigamma(x), + cospi(x), sinpi(x), tanpi(x) + ) + } """); } @@ -260,80 +271,383 @@ public void inlineMath1() { public void inlineLogical() { assertBytecode( """ - function(x, y) { - list( - x == y, x != y, x < y, x <= y, x > y, x >= y, x & y, x | y, !x - ) - } - """); + function(x, y) { + list( + x == y, x != y, x < y, x <= y, x > y, x >= y, x & y, x | y, !x + ) + } + """); } @Test public void inlineDollar() { assertBytecode( """ - # xs <- list(a=1, b=list(c=2)) - function(xs) { - xs$a - xs$"a" - xs$b$c - xs$"b"$c - xs$"b"$"c" - } - """); + # xs <- list(a=1, b=list(c=2)) + function(xs) { + xs$a + xs$"a" + xs$b$c + xs$"b"$c + xs$"b"$"c" + } + """); } @Test public void inlineIsXYZ() { assertBytecode( """ - function(x) { - list( - is.character(x), - is.complex(x), - is.double(x), - is.integer(x), - is.logical(x), - is.name(x), - is.null(x), - is.object(x), - is.symbol(x) - ) - } - """); + function(x) { + list( + is.character(x), + is.complex(x), + is.double(x), + is.integer(x), + is.logical(x), + is.name(x), + is.null(x), + is.object(x), + is.symbol(x) + ) + } + """); } @Test public void inlineDotCall() { assertBytecode( """ - function(x) { - .Call("bar") - .Call("foo", x, 1, TRUE) - } - """); + function(x) { + .Call("bar") + .Call("foo", x, 1, TRUE) + } + """); } @Test public void inlineIntGeneratingSequences() { assertBytecode( """ - function(x, xs) { - list(x:100, seq_along(xs), seq_len(x)) - } - """); + function(x, xs) { + list(x:100, seq_along(xs), seq_len(x)) + } + """); } - // TODO: with / require - @Test public void multiColon() { assertBytecode( """ - function() { - list(compiler::cmpfun, compiler:::makeCenv) - } - """); + function() { + list(compiler::cmpfun, compiler:::makeCenv) + } + """); + } + + @Test + public void inlineSwitch() { + assertBytecode( + """ + function(x) { + if (switch(x, 1, 2, g(3))) { + if (y) 4 else 5 + } + } + """); + } + + @Test + public void inlineAssign1() { + assertBytecode(""" + function() { + x <- 1 + } + """); + + assertBytecode(""" + function() { + y <<- 2 + } + """); + + assertBytecode( + """ + function() { + a::b <- 1 + a:::b <- 3 + a:::b <<- 3 + } + """); + } + + @Test + public void inlineAssign2() { + assertBytecode(""" + function() { + f(x) <- 1 + } + """); + + assertBytecode(""" + function() { + pkg::f(x) <- 1 + } + """); + } + + @Test + public void inlineAssign3() { + assertBytecode(""" + function() { + f(g(h(x, k), j), i) <- v + } + """); + } + + @Test + public void inlineDollarAssign() { + assertBytecode( + """ + function() { + x$y <- 1 + x$"z" <- 2 + a::b$c <- 3 + } + """); + } + + @Test + public void inlineSquareAssign1() { + assertBytecode( + """ + function() { + x[y == 1] <- 1 + x[[y == 1]] <- 1 + } + """); + } + + @Test + public void inlineSquareAssign2() { + assertBytecode( + """ + function() { + x[y == 1, z == 2] <- 1 + x[[y == 1, z == 2]] <- 1 + } + """); + } + + @Test + public void inlineSquareAssign3() { + assertBytecode( + """ + function() { + x[y == 1, ] <- 1 + x[[y == 1, ]] <- 1 + } + """); + } + + @Test + public void inlineSquareAssign4() { + assertBytecode(""" + function() { + x$y[-c(1,2)] <- 1 + } + """); + } + + @Test + public void inlineSquareSubset1() { + assertBytecode(""" + function() { + x[y == 1] + x[[y == 1]] + } + """); + } + + @Test + public void inlineSquareSubset2() { + assertBytecode( + """ + function() { + x[y == 1, z == 2] + x[[y == 1, z == 2]] + } + """); + } + + @Test + public void inlineSquareSubset3() { + assertBytecode(""" + function() { + x[y == 1,] + x[[y == 1,]] + } + """); + } + + @Test + public void inlineSquareSubset4() { + assertBytecode(""" + function() { + x[a=1,] + x[[a=1,]] + } + """); + } + + @Test + public void inlineSlotAssign() { + assertBytecode( + """ + function() { + setClass("A", slots = list(x = "numeric")) + a <- new("A", x = 42) + a@x <- 43 + } + """); + } + + @Test + public void inlineIdentical() { + assertBytecode(""" + function(x) { + identical(unzip, "internal") + } + """); + } + + @Test + public void testMatchCall() { + var def = (CloSXP) R.eval("f <- function(a=1,b,c=100,d) {}"); + var call = (LangSXP) R.eval("quote(f(d=1,3,a=2))"); + + var matched = Compiler.matchCall(def, call); + + assertThat(matched).isEqualTo(R.eval("quote(f(d=1,a=2,b=3))")); + } + + @Test + public void constantFoldingC() { + // no constant folding - c is resolved from baseenv() + assertBytecode( + """ + function () { + c("%Y-%m-%d", "%d-%m-%Y", "%m-%d-%Y") + } + """); + + // constant folding - optlevel 3 + assertBytecode( + """ + function () { + c("%Y-%m-%d", "%d-%m-%Y", "%m-%d-%Y") + } + """, + 3); + } + + @Test + public void constantFoldNamedC() { + var code = """ + function(x) c(i = 1, d = 1, s = 1) + """; + var sexp = compile(code, 3); + var bc = ((BCodeSXP) sexp).bc(); + // FIXME: use some matchers + var i = (BcInstr.LdConst) bc.code().getFirst(); + var v = ((RealSXP) bc.consts().get(i.constant())); + assertEquals(3, v.size()); + assertEquals(List.of("i", "d", "s"), v.names()); + assertBytecode(code); + } + + @Test + public void constantFoldMul() { + assertBytecode(""" + function() { + 2 * 3 * 4 + } + """); + } + + @Test + public void constantFoldMul2() { + var code = """ + function(x) { + 2 * 3 * x + } + """; + var sexp = compile(code, 2); + var bc = ((BCodeSXP) sexp).bc(); + // FIXME: use some matchers + var i = (BcInstr.LdConst) bc.code().getFirst(); + var v = ((RealSXP) bc.consts().get(i.constant())); + assertEquals(1, v.size()); + assertEquals(6, v.get(0)); + assertBytecode(code); + } + + @Test + public void constantFoldAdd() { + var code = """ + function(x) 1 + 2 + """; + var sexp = compile(code, 3); + var bc = ((BCodeSXP) sexp).bc(); + + // FIXME: use some matchers + var i = (BcInstr.LdConst) bc.code().getFirst(); + var v = ((RealSXP) bc.consts().get(i.constant())); + assertEquals(1, v.size()); + assertEquals(3, v.get(0)); + } + + @Test + public void constantFoldAdd2() { + var code = """ + function(x) TRUE + c(10, 11) + """; + var sexp = compile(code, 3); + var bc = ((BCodeSXP) sexp).bc(); + // FIXME: use some matchers + var i = (BcInstr.LdConst) bc.code().getFirst(); + var v = ((RealSXP) bc.consts().get(i.constant())); + assertEquals(2, v.size()); + assertEquals(11, v.get(0)); + assertEquals(12, v.get(1)); + } + + @ParameterizedTest + @MethodSource("stdlibFunctionsList") + public void stdlibFunctions(String name) { + assertBytecode(name); + } + + private Stream stdlibFunctionsList() { + StrSXP base = + (StrSXP) + R.eval( + """ + list_functions <- function(name) { + namespace <- getNamespace(name) + p <- function(x) { + f <- get(x, envir=namespace) + is.function(f) && identical(environment(f), namespace) + } + Filter(p, ls(namespace, all.names = TRUE)) + } + + pkgs <- c("base", "tools", "utils", "graphics", "methods", "stats") + funs <- sapply(pkgs, function(x) paste0(x, ":::`", list_functions(x), "`")) + do.call(c, funs) + """); + + return Streams.stream(base.iterator()).map(Arguments::of); } private void assertBytecode(String code) { @@ -352,25 +666,50 @@ private void assertBytecode(String funCode, int optimizationLevel) { String code = "parse(file = '" + temp.getAbsolutePath() - + "', keep.source = TRUE)" // TODO: set conditionally + + "', keep.source = TRUE)" + " |> eval()" + " |> compiler::cmpfun(options = list(optimize=" + optimizationLevel + "))"; - var gnurfun = (CloSXP) R.eval(code); - var gnurbc = ((BCodeSXP) gnurfun.body()).bc(); - var astfun = - SEXPs.closure( - gnurfun.formals(), gnurbc.consts().getFirst(), gnurfun.env(), gnurfun.attributes()); + var gnurFun = (CloSXP) R.eval(code); + var astFun = + SEXPs.closure(gnurFun.formals(), gnurFun.bodyAST(), gnurFun.env(), gnurFun.attributes()); - var compiler = new Compiler(astfun, rsession); - compiler.setOptimizationLevel(optimizationLevel); - var ourbc = compiler.compile(); + // if a function calls browser() it won't be compiled into bytecode + if (gnurFun.body() instanceof BCodeSXP gnurBc) { + var ourBody = compile(astFun, optimizationLevel); - assertEquals( - printStructurally(gnurbc), - printStructurally(ourbc), - "`compile(read(ast)) == read(R.compile(ast))`"); + if (ourBody instanceof BCodeSXP ourBc) { + var eq = gnurBc.equals(ourBc); + + if (!eq) { + // bytecode can be large, so we only want to do it when it is different + assertEquals( + gnurBc.toString(), ourBc.toString(), "`compile(read(ast)) == read(R.compile(ast))`"); + fail("Produced bytecode is different, but the toString() representation is the same."); + } + } else { + assertEquals(gnurBc.toString(), ourBody.toString()); + } + } else { + var ourBody = compile(astFun, optimizationLevel); + + if (ourBody instanceof BCodeSXP ourBc) { + assertEquals(astFun.body().toString(), ourBc.toString()); + } else { + assertEquals(astFun.body(), ourBody); + } + } + } + + private SEXP compile(String fun, int optimizationLevel) { + return compile((CloSXP) R.eval(fun), optimizationLevel); + } + + private SEXP compile(CloSXP fun, int optimizationLevel) { + var compiler = new Compiler(fun, rsession); + compiler.setOptimizationLevel(optimizationLevel); + return compiler.compile().map(SEXPs::bcode).orElse(fun.body()); } } diff --git a/src/test/java/org/prlprg/bc/ConstantFoldingTest.java b/src/test/java/org/prlprg/bc/ConstantFoldingTest.java new file mode 100644 index 000000000..31b35dca0 --- /dev/null +++ b/src/test/java/org/prlprg/bc/ConstantFoldingTest.java @@ -0,0 +1,28 @@ +package org.prlprg.bc; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; +import org.junit.jupiter.api.Test; +import org.prlprg.sexp.SEXPs; + +public class ConstantFoldingTest { + + @Test + public void testPlus() { + var a = SEXPs.complex(1, 2); + var b = SEXPs.integer(5); + var res = ConstantFolding.add(List.of(a, b)).get(); + var expected = SEXPs.complex(6, 2); + assertEquals(res, expected); + } + + @Test + public void textExp() { + var a = SEXPs.real(new double[] {2, 4}); + var b = SEXPs.real(new double[] {2, 3}); + var res = ConstantFolding.pow(List.of(a, b)).get(); + var expected = SEXPs.real(new double[] {4, 64}); + assertEquals(res, expected); + } +} diff --git a/src/test/java/org/prlprg/bc/ContextTest.java b/src/test/java/org/prlprg/bc/ContextTest.java index 1c1c3d3ab..fe0bec3b0 100644 --- a/src/test/java/org/prlprg/bc/ContextTest.java +++ b/src/test/java/org/prlprg/bc/ContextTest.java @@ -1,22 +1,18 @@ package org.prlprg.bc; import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.HashSet; import org.junit.jupiter.api.Test; -import org.prlprg.RSession; -import org.prlprg.rsession.TestRSession; import org.prlprg.sexp.CloSXP; +import org.prlprg.sexp.NamespaceEnvSXP; import org.prlprg.sexp.PromSXP; import org.prlprg.sexp.SEXPs; -import org.prlprg.util.GNUR; +import org.prlprg.util.AbstractGNURBasedTest; import org.prlprg.util.Pair; -public class ContextTest { - - private final RSession rsession = new TestRSession(); - private final GNUR R = new GNUR(rsession); - +public class ContextTest extends AbstractGNURBasedTest { @Test public void testFindLocals() { var fun = @@ -32,7 +28,7 @@ public void testFindLocals() { """); var ctx = Context.functionContext(fun); - assertThat(ctx.findLocals(fun.body())).containsExactly("y", "z", "zz"); + assertThat(ctx.findLocals(fun.bodyAST())).containsExactly("y", "z", "zz"); } @Test @@ -66,7 +62,7 @@ public void testFindLocalsWithShadowing() { var ctx = Context.functionContext(fun); var locals = new HashSet<>(); locals.addAll(ctx.findLocals(fun.formals())); - locals.addAll(ctx.findLocals(fun.body())); + locals.addAll(ctx.findLocals(fun.bodyAST())); assertThat(locals).containsExactly("local", "x"); } @@ -88,8 +84,6 @@ public void testBindingInNestedFunction() { """); var ctx = Context.functionContext(fun); - System.out.println(ctx); - var x = ctx.resolve("x"); assertThat(x).hasValue(new Pair<>(fun.env(), SEXPs.MISSING_ARG)); @@ -134,12 +128,15 @@ > f(42) """); var ctx = Context.functionContext(fun); - assertThat(ctx.findLocals(fun.body())).containsExactly("x"); + assertThat(ctx.findLocals(fun.bodyAST())).containsExactly("x"); } @Test public void testFrameTypes() { - var fun = (CloSXP) R.eval("tools:::Rcmd"); - System.out.println(fun); + var fun = (CloSXP) R.eval("utils::unzip"); + var ctx = Context.functionContext(fun); + // FIXME: ugly - can we have some matchers for this? + var identical = ctx.resolve("identical").get(); + assertTrue(identical.first() instanceof NamespaceEnvSXP ns && ns.getName().equals("base")); } } diff --git a/src/test/java/org/prlprg/primitive/ComplexTest.java b/src/test/java/org/prlprg/primitive/ComplexTest.java new file mode 100644 index 000000000..70505d7f9 --- /dev/null +++ b/src/test/java/org/prlprg/primitive/ComplexTest.java @@ -0,0 +1,19 @@ +package org.prlprg.primitive; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.prlprg.sexp.RealSXP; +import org.prlprg.util.Tests; + +public class ComplexTest implements Tests { + + @Test + void testPow() { + var a = new Complex(1, 2); + var b = new Complex(3, 4); + var res = a.pow(b); + assertEquals(res.real(), 0.1290095, RealSXP.DOUBLE_CMP_DELTA); + assertEquals(res.imag(), 0.0339241, RealSXP.DOUBLE_CMP_DELTA); + } +} diff --git a/src/test/java/org/prlprg/rds/RDSReaderTest.java b/src/test/java/org/prlprg/rds/RDSReaderTest.java index a9e1f1b22..331e228b2 100644 --- a/src/test/java/org/prlprg/rds/RDSReaderTest.java +++ b/src/test/java/org/prlprg/rds/RDSReaderTest.java @@ -1,25 +1,17 @@ package org.prlprg.rds; import static com.google.common.truth.Truth.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.*; +import static org.prlprg.sexp.Coercions.isNA; import java.util.Objects; import org.junit.jupiter.api.Test; -import org.prlprg.RSession; import org.prlprg.primitive.Constants; import org.prlprg.primitive.Logical; -import org.prlprg.rsession.TestRSession; import org.prlprg.sexp.*; -import org.prlprg.util.GNUR; -import org.prlprg.util.Tests; - -public class RDSReaderTest implements Tests { - private final RSession rsession = new TestRSession(); - private final GNUR R = new GNUR(rsession); - - // TODO: rewrite using GNUR +import org.prlprg.util.AbstractGNURBasedTest; +public class RDSReaderTest extends AbstractGNURBasedTest { @Test public void testInts() throws Exception { var sexp = R.eval("c(-.Machine$integer.max, -1L, 0L, NA, 1L, .Machine$integer.max)"); @@ -29,7 +21,7 @@ public void testInts() throws Exception { assertEquals(Constants.INT_MIN, ints.get(0)); assertEquals(-1, ints.get(1)); assertEquals(0, ints.get(2)); - assertEquals(Constants.NA_INT, ints.get(3)); + assertTrue(isNA(ints.get(3))); assertEquals(1, ints.get(4)); assertEquals(Integer.MAX_VALUE, ints.get(5)); } else { @@ -45,7 +37,7 @@ public void testLgls() throws Exception { assertEquals(3, logs.size()); assertEquals(Logical.TRUE, logs.get(0)); assertEquals(Logical.FALSE, logs.get(1)); - assertEquals(Logical.NA, logs.get(2)); + assertTrue(isNA(logs.get(2))); } else { fail("Expected LglSXP"); } @@ -61,7 +53,7 @@ public void testReals() throws Exception { // assertEquals(Double.MIN_VALUE, reals.get(0)); assertEquals(-1.0, reals.get(1)); assertEquals(.0, reals.get(2)); - assertEquals(Constants.NA_REAL, reals.get(3)); + assertTrue(isNA(reals.get(3))); assertEquals(1.0, reals.get(4)); assertEquals(Double.MAX_VALUE, reals.get(5)); } else { @@ -128,7 +120,7 @@ public void testClosure() throws Exception { assertEquals(new TaggedElem("x", SEXPs.MISSING_ARG), formals.get(0)); // TODO: this should really be a snapshot test - var body = sexp.body(); + var body = sexp.bodyAST(); assertThat(body).isInstanceOf(LangSXP.class); assertThat(body.toString()).isEqualTo("\"abc\" + x + length(y)"); } @@ -151,4 +143,43 @@ public void testExpression() throws Exception { var sexp = R.eval("parse(text='function() {}', keep.source = TRUE)"); assertThat(sexp).isInstanceOf(ExprSXP.class); } + + @Test + public void testNullInParams() throws Exception { + var sexp = R.eval("quote(match('AsIs', cl, 0L, NULL))"); + // FIXME: assert on the number of parameters + assertThat(sexp).isInstanceOf(LangSXP.class); + } + + @Test + public void testNullInParamsInBC() throws Exception { + var sexp = (BCodeSXP) R.eval("compiler::compile(quote(match('AsIs', cl, 0L, NULL)))"); + var ast = (LangSXP) sexp.bc().consts().getFirst(); + // here we want to make sure that the trailing NULL did not get lost + assertThat(ast.args()).hasSize(4); + } + + @Test + public void testFormatAsIs() throws Exception { + var sexp = R.eval("format.AsIs"); + assertThat(sexp).isInstanceOf(CloSXP.class); + } + + @Test + public void testComplex() throws Exception { + var sexp = R.eval("c(-1+1i, 0+0i, 1+1i)"); + assertThat(sexp).isInstanceOf(ComplexSXP.class); + } + + @Test + public void testRoundPOSIXt() throws Exception { + var sexp = R.eval("round.POSIXt"); + assertThat(sexp).isInstanceOf(CloSXP.class); + } + + @Test + public void testLocalFuncationBC() throws Exception { + var sexp = R.eval("compiler::cmpfun(function(x) local(x))"); + assertThat(sexp).isInstanceOf(CloSXP.class); + } } diff --git a/src/test/java/org/prlprg/rsession/TestRSession.java b/src/test/java/org/prlprg/rsession/TestRSession.java index 5011a0879..974eed741 100644 --- a/src/test/java/org/prlprg/rsession/TestRSession.java +++ b/src/test/java/org/prlprg/rsession/TestRSession.java @@ -3,6 +3,7 @@ import com.google.common.collect.ImmutableSet; import java.io.IOException; import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.Set; import javax.annotation.Nullable; @@ -11,22 +12,25 @@ import org.prlprg.sexp.*; import org.prlprg.util.IO; +// http://adv-r.had.co.nz/Environments.html public class TestRSession implements RSession { - private static final String BASE_SYMBOLS_RDS_FILE = "base.RDS"; + private static final String BASE_SYMBOLS_RDS_FILE = "basevars.RDS"; + private static final String BASE_ENV_RDS_FILE = "baseenv.RDS"; private static final String BUILTINS_SYMBOLS_RDS_FILE = "builtins.RDS"; private static final String SPECIALS_SYMBOLS_RDS_FILE = "specials.RDS"; private static final String BUILTINS_INTERNAL_SYMBOLS_RDS_FILE = "builtins-internal.RDS"; private @Nullable BaseEnvSXP baseEnv = null; + private @Nullable NamespaceEnvSXP baseNamespace = null; private @Nullable GlobalEnvSXP globalEnv = null; private @Nullable Set builtins = null; private @Nullable Set specials = null; private @Nullable Set builtinsInternal = null; - private BaseEnvSXP loadBaseEnv() { + private void bootstrapBase() { try { - // this will work as long as the base.RDS does not need - // to load base or global environment itself + // 1. Load just the symbol names. This will work as long as loading the STRSXP does not need + // baseenv itself var names = (StrSXP) RDSReader.readStream( @@ -35,18 +39,47 @@ private BaseEnvSXP loadBaseEnv() { Objects.requireNonNull( TestRSession.class.getResourceAsStream(BASE_SYMBOLS_RDS_FILE)))); - var frame = new HashMap(names.size()); - names.forEach(x -> frame.put(x, SEXPs.UNBOUND_VALUE)); + // 2. Create empty bindings + var bindings = new HashMap(names.size()); + names.forEach(x -> bindings.put(x, SEXPs.UNBOUND_VALUE)); - return new BaseEnvSXP(frame); + // 3. Create a temporary baseenv and temporart base namespace + baseEnv = new BaseEnvSXP(bindings); + // the 4.3.2 should correspond to the R version that written the RDS files used in this class + baseNamespace = new NamespaceEnvSXP("base", "4.3.2", baseEnv, bindings); + + // 4. Load the values + var temp = + (EnvSXP) + RDSReader.readStream( + this, + IO.maybeDecompress( + Objects.requireNonNull( + TestRSession.class.getResourceAsStream(BASE_ENV_RDS_FILE)))); + + // 5. update them in the baseenv and base namespace + temp.bindings() + .forEach( + x -> { + baseEnv.set(x.getKey(), x.getValue()); + baseNamespace.set(x.getKey(), x.getValue()); + }); } catch (IOException e) { throw new RuntimeException("Failed to load the base environment", e); } } + @Override + public NamespaceEnvSXP baseNamespace() { + if (baseNamespace == null) { + bootstrapBase(); + } + return baseNamespace; + } + public synchronized BaseEnvSXP baseEnv() { if (baseEnv == null) { - baseEnv = loadBaseEnv(); + bootstrapBase(); } return baseEnv; } @@ -129,4 +162,13 @@ public boolean isSpecial(String name) { public boolean isBuiltinInternal(String name) { return builtinsInternal().contains(name); } + + @Override + public NamespaceEnvSXP getNamespace(String name, String version) { + if (name.equals("base")) { + return baseNamespace(); + } else { + return new NamespaceEnvSXP(name, version, baseNamespace(), Map.of()); + } + } } diff --git a/src/test/java/org/prlprg/sexp/CoercionsTest.java b/src/test/java/org/prlprg/sexp/CoercionsTest.java new file mode 100644 index 000000000..1b5bd692b --- /dev/null +++ b/src/test/java/org/prlprg/sexp/CoercionsTest.java @@ -0,0 +1,46 @@ +package org.prlprg.sexp; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.prlprg.sexp.Coercions.isNA; + +import org.junit.jupiter.api.Test; +import org.prlprg.primitive.Constants; + +public class CoercionsTest { + + @Test + public void testComplexFromString() { + var c = Coercions.complexFromString("1+2i"); + assertEquals(1.0, c.real()); + assertEquals(2.0, c.imag()); + + c = Coercions.complexFromString("1-2i"); + assertEquals(1.0, c.real()); + assertEquals(-2.0, c.imag()); + + c = Coercions.complexFromString(" - 1-2i"); + assertEquals(-1.0, c.real()); + assertEquals(-2.0, c.imag()); + + c = Coercions.complexFromString("1"); + assertEquals(1.0, c.real()); + assertEquals(0.0, c.imag()); + + c = Coercions.complexFromString("-1"); + assertEquals(-1.0, c.real()); + assertEquals(0.0, c.imag()); + + c = Coercions.complexFromString("-1i"); + assertEquals(0.0, c.real()); + assertEquals(-1.0, c.imag()); + + c = Coercions.complexFromString("1i"); + assertEquals(0.0, c.real()); + assertEquals(1.0, c.imag()); + + c = Coercions.complexFromString(Constants.NA_STRING); + assertTrue(isNA(c.real())); + assertTrue(isNA(c.imag())); + } +} diff --git a/src/test/java/org/prlprg/util/AbstractGNURBasedTest.java b/src/test/java/org/prlprg/util/AbstractGNURBasedTest.java new file mode 100644 index 000000000..815b861ee --- /dev/null +++ b/src/test/java/org/prlprg/util/AbstractGNURBasedTest.java @@ -0,0 +1,31 @@ +package org.prlprg.util; + +import java.io.IOException; +import java.io.InputStream; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.TestInstance; +import org.prlprg.RSession; +import org.prlprg.rds.RDSReader; +import org.prlprg.rsession.TestRSession; +import org.prlprg.sexp.SEXP; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class AbstractGNURBasedTest { + protected static RSession rsession = new TestRSession(); + protected GNUR R; + + @BeforeAll + public void startR() { + R = GNUR.spawn(rsession); + } + + @AfterAll + public void stopR() { + R.close(); + } + + protected SEXP readRDS(InputStream input) throws IOException { + return RDSReader.readStream(rsession, IO.maybeDecompress(input)); + } +} diff --git a/src/test/java/org/prlprg/util/ArithmeticTest.java b/src/test/java/org/prlprg/util/ArithmeticTest.java new file mode 100644 index 000000000..4c1d1fc88 --- /dev/null +++ b/src/test/java/org/prlprg/util/ArithmeticTest.java @@ -0,0 +1,50 @@ +package org.prlprg.util; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +import java.util.List; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.prlprg.bc.ConstantFolding; +import org.prlprg.primitive.Complex; +import org.prlprg.sexp.SEXPs; +import org.prlprg.sexp.VectorSXP; + +public class ArithmeticTest { + + public static Stream complexAdd() { + return Stream.of( + arguments( + new Complex[] {new Complex(1, 2), new Complex(3, 4)}, + new Complex[] {new Complex(5, 6), new Complex(7, 8)}, + new Complex[] {new Complex(6, 8), new Complex(10, 12)}), + arguments( + new Complex[] {new Complex(1, 2), new Complex(3, 4)}, + new Complex[] {new Complex(5, 6)}, + new Complex[] {new Complex(6, 8), new Complex(8, 10)}), + arguments( + new Complex[] {new Complex(1, 2)}, + new Complex[] {new Complex(5, 6), new Complex(3, 4)}, + new Complex[] {new Complex(6, 8), new Complex(4, 6)}), + arguments( + new Complex[] {new Complex(1, 2), new Complex(2, 3), new Complex(3, 4)}, + new Complex[] {new Complex(5, 6), new Complex(3, 4)}, + new Complex[] {new Complex(6, 8), new Complex(5, 7), new Complex(8, 10)}), + arguments( + new Complex[] {new Complex(1, 2), new Complex(2, 3), new Complex(3, 4)}, + new Complex[] {}, + new Complex[] {})); + } + + @ParameterizedTest + @MethodSource("complexAdd") + public void testComplexBinary(Complex[] a, Complex[] b, Complex[] expected) { + var res = ConstantFolding.add(List.of(SEXPs.complex(a), SEXPs.complex(b))); + assertTrue(res.isPresent()); + assertArrayEquals(((VectorSXP) res.get()).coerceToComplexes(), expected); + } +} diff --git a/src/test/java/org/prlprg/util/GNUR.java b/src/test/java/org/prlprg/util/GNUR.java index f10c9726c..b454466a2 100644 --- a/src/test/java/org/prlprg/util/GNUR.java +++ b/src/test/java/org/prlprg/util/GNUR.java @@ -1,21 +1,44 @@ package org.prlprg.util; -import java.io.File; -import java.io.PrintWriter; +import static java.lang.String.format; + +import java.io.*; +import java.util.UUID; +import javax.annotation.concurrent.NotThreadSafe; import org.prlprg.RSession; import org.prlprg.rds.RDSReader; import org.prlprg.sexp.SEXP; -public class GNUR { +@NotThreadSafe +public class GNUR implements AutoCloseable { public static final String R_BIN = "R"; private final RSession rsession; + private final Process rprocess; + private final PrintStream rin; + private final BufferedReader rout; - public GNUR(RSession rsession) { + public GNUR(RSession rsession, Process rprocess) { this.rsession = rsession; + this.rprocess = rprocess; + this.rin = new PrintStream(rprocess.getOutputStream()); + this.rout = new BufferedReader(new InputStreamReader(rprocess.getInputStream())); + } + + private void run(String code) { + var requestId = UUID.randomUUID().toString(); + + if (!rprocess.isAlive()) { + throw new RuntimeException("R is not running"); + } + + rin.println(code); + rin.printf("cat('%s\n')", requestId); + rin.println(); + rin.flush(); + waitForCommand(requestId); } - // TODO: keep a session open - do not start a new R every time public SEXP eval(String source) { try { var sourceFile = File.createTempFile("RCS-test", ".R"); @@ -26,10 +49,11 @@ public SEXP eval(String source) { } var code = - String.format( - "saveRDS(eval(parse(file=\"%s\")), \"%s\", compress=FALSE)", + format( + "saveRDS(eval(parse(file='%s'), envir=new.env(parent=baseenv())), '%s', version=2, compress=FALSE)", sourceFile.getAbsoluteFile(), targetFile.getAbsoluteFile()); - runCode(code); + + run(code); var sxp = RDSReader.readFile(rsession, targetFile); @@ -42,19 +66,51 @@ public SEXP eval(String source) { } } - private static void runCode(String code) { + private void waitForCommand(String requestId) { + var output = new StringBuilder(); try { - var pb = - new ProcessBuilder(R_BIN, "--slave", "--vanilla", "--no-save", "-e", code) - .redirectErrorStream(true); - var proc = pb.start(); - var output = new String(proc.getInputStream().readAllBytes()); - var exit = proc.waitFor(); - if (exit != 0) { - throw new RuntimeException("R exited with code " + exit + ":\n" + output); + while (true) { + if (!rprocess.isAlive()) { + throw new RuntimeException("R exited unexpectedly"); + } + + var line = rout.readLine(); + if (line == null) { + throw new RuntimeException("R exited unexpectedly"); + } + + if (line.equals(requestId)) { + return; + } + + output.append(line).append("\n"); } } catch (Exception e) { - throw new RuntimeException("Unable to run R code: " + code, e); + int exit; + try { + exit = rprocess.waitFor(); + + throw new RuntimeException( + "R REPL died (status: " + exit + ") Output so far:\n " + output, e); + + } catch (InterruptedException ex) { + throw new RuntimeException("Interrupted waiting for R process to finish dying", ex); + } + } + } + + @Override + public void close() { + rprocess.destroy(); + } + + public static GNUR spawn(RSession session) { + try { + var proc = + new ProcessBuilder(R_BIN, "--slave", "--vanilla").redirectErrorStream(true).start(); + return new GNUR(session, proc); + } catch (Exception e) { + throw new RuntimeException("Unable to start R", e); } } } diff --git a/src/test/java/org/prlprg/util/StructuralUtils.java b/src/test/java/org/prlprg/util/StructuralUtils.java deleted file mode 100644 index 3e58bff51..000000000 --- a/src/test/java/org/prlprg/util/StructuralUtils.java +++ /dev/null @@ -1,45 +0,0 @@ -package org.prlprg.util; - -import java.util.Set; -import java.util.regex.Pattern; - -public class StructuralUtils { - private static final String HASH_PATTERN = "-?[0-9a-fA-F]{1,16}"; - private static final Set> PATTERNS = - Set.of( - new Pair<>(Pattern.compile("@" + HASH_PATTERN), "#"), - new Pair<>(Pattern.compile(""), "")); - - /** - * Calls {@link Object#toString}, then replaces obvious references and hash-codes with - * deterministic values. This means that you can test that two objects are structurally equivalent - * by comparing their {@code printStructurally}. - * - *

It uses heuristics to find and references and hash-codes, so it can't be relied on for - * any data-structures we can't or aren't willing to change the {@link Object#toString} - * representation of to make them pass the tests. - */ - public static String printStructurally(Object object) { - var string = object.toString(); - - // Get hashes in order of occurrence - for (var pattern : PATTERNS) { - var hashMatcher = pattern.first().matcher(string); - var hashes = new java.util.LinkedHashSet(); - while (hashMatcher.find()) { - hashes.add(hashMatcher.group()); - } - - // Replace each hash with its index - var idx = 0; - for (var hash : hashes) { - var replacement = pattern.second().replace("#", Integer.toString(idx)); - string = string.replace(hash, replacement); - idx++; - } - } - return string; - } - - private StructuralUtils() {} -} diff --git a/src/test/resources/org/prlprg/rsession/base.RDS b/src/test/resources/org/prlprg/rsession/base.RDS deleted file mode 100644 index 11ac62eba..000000000 Binary files a/src/test/resources/org/prlprg/rsession/base.RDS and /dev/null differ diff --git a/src/test/resources/org/prlprg/rsession/baseenv.RDS b/src/test/resources/org/prlprg/rsession/baseenv.RDS new file mode 100644 index 000000000..e24347bb0 Binary files /dev/null and b/src/test/resources/org/prlprg/rsession/baseenv.RDS differ diff --git a/src/test/resources/org/prlprg/rsession/basevars.RDS b/src/test/resources/org/prlprg/rsession/basevars.RDS new file mode 100644 index 000000000..680345372 Binary files /dev/null and b/src/test/resources/org/prlprg/rsession/basevars.RDS differ diff --git a/src/test/resources/org/prlprg/rsession/bootstrap.R b/src/test/resources/org/prlprg/rsession/bootstrap.R index 720b1f71e..c3d0c607a 100644 --- a/src/test/resources/org/prlprg/rsession/bootstrap.R +++ b/src/test/resources/org/prlprg/rsession/bootstrap.R @@ -1,8 +1,44 @@ basevars <- ls("package:base", all.names = TRUE) types <- sapply(basevars, \(x) typeof(get(x))) -saveRDS(basevars[types == "special"], "specials.RDS") -saveRDS(basevars[types == "builtin"], "builtins.RDS") -saveRDS(basevars, "base.RDS") +specials <- basevars[types == "special"] +cat("saving ", length(specials), " specials\n") +saveRDS(specials, "specials.RDS", version = 2) -saveRDS(builtins(internal = TRUE), "builtins-internal.RDS") +builtins <- basevars[types == "builtin"] +cat("saving ", length(builtins), " builtins\n") +saveRDS(builtins, "builtins.RDS", version = 2) + +builtin_internals <- builtins(internal = TRUE) +simple_builtin_internals <- builtin_internals[sapply(builtin_internals, \(x) .Internal(is.builtin.internal(as.name(x))))] +cat("saving ", length(simple_builtin_internals), " builtin internals\n") +saveRDS(simple_builtin_internals, "builtins-internal.RDS", version = 2) + +base_env_funs <- basevars[sapply(basevars, \(x) is.function(get(x)))] +base_env <- sapply(base_env_funs, \(x) if (x %in% builtin_internals) get(x) else as.function(c(formals(get(x)), list(NULL)))) +base_env <- as.environment(base_env) +base_env$pi <- pi +base_env$T <- T +base_env$F <- F + +cat("saving ", length(basevars), " base variables\n") +saveRDS(basevars, "basevars.RDS", version = 2) + +cat("saving ", length(base_env), " baseenv symbols\n") +saveRDS(base_env, "baseenv.RDS", version = 2) + +# moved to CompilerTest.RDS +# list_functions <- function(name) { +# namespace <- getNamespace(name) +# p <- function(x) { +# f <- get(x, envir=namespace) +# is.function(f) && identical(environment(f), namespace) +# } +# Filter(p, ls(namespace, all.names = TRUE)) +# } +# +# pkgs <- c("base", "tools", "utils", "graphics", "methods", "stats") +# funs <- sapply(pkgs, \(x) paste0(x, ":::`", list_functions(x), "`")) +# funs <- do.call(c, funs) +# cat("saving", length(funs), "functions\n") +# saveRDS(funs, "functions.RDS", version = 2) diff --git a/src/test/resources/org/prlprg/rsession/builtins-internal.RDS b/src/test/resources/org/prlprg/rsession/builtins-internal.RDS index bc56e5a4f..4c86bfd36 100644 Binary files a/src/test/resources/org/prlprg/rsession/builtins-internal.RDS and b/src/test/resources/org/prlprg/rsession/builtins-internal.RDS differ diff --git a/src/test/resources/org/prlprg/rsession/builtins.RDS b/src/test/resources/org/prlprg/rsession/builtins.RDS index a2289b8e2..3217f3055 100644 Binary files a/src/test/resources/org/prlprg/rsession/builtins.RDS and b/src/test/resources/org/prlprg/rsession/builtins.RDS differ diff --git a/src/test/resources/org/prlprg/rsession/specials.RDS b/src/test/resources/org/prlprg/rsession/specials.RDS index 2146c810b..f5927f73c 100644 Binary files a/src/test/resources/org/prlprg/rsession/specials.RDS and b/src/test/resources/org/prlprg/rsession/specials.RDS differ diff --git a/src/test/snapshots/README.md b/src/test/snapshots/README.md deleted file mode 100644 index ddee9abde..000000000 --- a/src/test/snapshots/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Snapshots - -For regression testing, some also serve as cached computations (the `.rds` files created by invoking GNU-R). - -**These don't get bundled because** we have to write to them, so we access them at their actual locations (bundled are overridden on recompile). diff --git a/src/test/snapshots/org/prlprg/bc/basics.R/block_with_one_expression.ast.rds b/src/test/snapshots/org/prlprg/bc/basics.R/block_with_one_expression.ast.rds deleted file mode 100644 index 33d533628..000000000 Binary files a/src/test/snapshots/org/prlprg/bc/basics.R/block_with_one_expression.ast.rds and /dev/null differ diff --git a/src/test/snapshots/org/prlprg/bc/basics.R/block_with_one_expression.bc.rds b/src/test/snapshots/org/prlprg/bc/basics.R/block_with_one_expression.bc.rds deleted file mode 100644 index 07daa7e56..000000000 Binary files a/src/test/snapshots/org/prlprg/bc/basics.R/block_with_one_expression.bc.rds and /dev/null differ diff --git a/src/test/snapshots/org/prlprg/bc/basics.R/empty_block.ast.rds b/src/test/snapshots/org/prlprg/bc/basics.R/empty_block.ast.rds deleted file mode 100644 index caadcb598..000000000 Binary files a/src/test/snapshots/org/prlprg/bc/basics.R/empty_block.ast.rds and /dev/null differ diff --git a/src/test/snapshots/org/prlprg/bc/basics.R/empty_block.bc.rds b/src/test/snapshots/org/prlprg/bc/basics.R/empty_block.bc.rds deleted file mode 100644 index bc52c91ea..000000000 Binary files a/src/test/snapshots/org/prlprg/bc/basics.R/empty_block.bc.rds and /dev/null differ diff --git a/src/test/snapshots/org/prlprg/bc/basics.R/nested_function.ast.rds b/src/test/snapshots/org/prlprg/bc/basics.R/nested_function.ast.rds deleted file mode 100644 index 89d9a466f..000000000 Binary files a/src/test/snapshots/org/prlprg/bc/basics.R/nested_function.ast.rds and /dev/null differ diff --git a/src/test/snapshots/org/prlprg/bc/basics.R/nested_function.bc.rds b/src/test/snapshots/org/prlprg/bc/basics.R/nested_function.bc.rds deleted file mode 100644 index a08200284..000000000 Binary files a/src/test/snapshots/org/prlprg/bc/basics.R/nested_function.bc.rds and /dev/null differ