diff --git a/.idea/highlightedFiles.xml b/.idea/highlightedFiles.xml new file mode 100644 index 000000000..ee02cf95a --- /dev/null +++ b/.idea/highlightedFiles.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.settings/org.eclipse.jdt.core.prefs b/.settings/org.eclipse.jdt.core.prefs index 4adc96d52..747676ed2 100644 --- a/.settings/org.eclipse.jdt.core.prefs +++ b/.settings/org.eclipse.jdt.core.prefs @@ -4,8 +4,8 @@ org.eclipse.jdt.core.compiler.annotation.nonnull=org.eclipse.jdt.annotation.NonN org.eclipse.jdt.core.compiler.annotation.nonnullbydefault=org.eclipse.jdt.annotation.NonNullByDefault org.eclipse.jdt.core.compiler.annotation.nullable=org.eclipse.jdt.annotation.Nullable org.eclipse.jdt.core.compiler.annotation.nullanalysis=disabled -org.eclipse.jdt.core.compiler.codegen.targetPlatform=21 -org.eclipse.jdt.core.compiler.compliance=21 +org.eclipse.jdt.core.compiler.codegen.targetPlatform=22 +org.eclipse.jdt.core.compiler.compliance=22 org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning org.eclipse.jdt.core.compiler.problem.nullAnnotationInferenceConflict=error @@ -14,6 +14,6 @@ org.eclipse.jdt.core.compiler.problem.nullSpecViolation=error org.eclipse.jdt.core.compiler.problem.potentialNullReference=ignore org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore org.eclipse.jdt.core.compiler.problem.syntacticNullAnalysisForFields=disabled -org.eclipse.jdt.core.compiler.processAnnotations=disabled +org.eclipse.jdt.core.compiler.processAnnotations=enabled org.eclipse.jdt.core.compiler.release=disabled -org.eclipse.jdt.core.compiler.source=21 +org.eclipse.jdt.core.compiler.source=22 diff --git a/LICENSE b/LICENSE index b435ac4ec..649eb78e6 100644 --- a/LICENSE +++ b/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/README.md b/README.md index a7227ed67..17fe12999 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,12 @@ import javax.annotation.ParametersAreNonnullByDefault; In IntelliJ you can simply copy `package-info.json` from any package into the new one and it will create this. +The project requires at least Java 22. + ## Commands - Run `make setup` to install Git Hooks. The commit hook formats, the pre-push hook runs tests and static analyses. - Build with `make` or `mvn package` - Test (no static analyses) with `make test` or `mvn test` - Test and static anaylses with `make check` or `mvn verify` -- Format with `make format` or `mvn spotless:apply` +- Format with `make format` or `mvn spotless:apply`. This requires to have `npm` installed. diff --git a/src/main/java/org/prlprg/RVersion.java b/src/main/java/org/prlprg/RVersion.java index d21ff66f1..c307ef3b8 100644 --- a/src/main/java/org/prlprg/RVersion.java +++ b/src/main/java/org/prlprg/RVersion.java @@ -50,6 +50,15 @@ public static RVersion parse(String textual) { this(major, minor, patch, null); } + /** + * Encode the version as an integer. It is used for the RDS serialization for instance. + * + * @return + */ + public int encode() { + return patch + 256 * minor + 65536 * major; + } + @Override public String toString() { return major + "." + minor + "." + patch + (suffix == null ? "" : "-" + suffix); diff --git a/src/main/java/org/prlprg/bc/ConstPool.java b/src/main/java/org/prlprg/bc/ConstPool.java index 78711487b..05d1a8f6b 100644 --- a/src/main/java/org/prlprg/bc/ConstPool.java +++ b/src/main/java/org/prlprg/bc/ConstPool.java @@ -3,12 +3,7 @@ import com.google.common.collect.ForwardingList; import com.google.common.collect.ImmutableList; import edu.umd.cs.findbugs.annotations.Nullable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; +import java.util.*; import java.util.function.Function; import javax.annotation.concurrent.Immutable; import org.prlprg.parseprint.ParseMethod; @@ -125,13 +120,13 @@ public static class Builder { private final List values; public Builder() { - this(Collections.emptyList()); + this.index = new HashMap<>(); + this.values = new ArrayList<>(); } - public Builder(List consts) { - index = new HashMap<>(consts.size()); - values = new ArrayList<>(consts.size()); - + public Builder(Collection consts) { + this.index = new HashMap<>(consts.size()); + this.values = new ArrayList<>(consts.size()); for (var e : consts) { add(e); } @@ -142,7 +137,7 @@ public Idx add(S c) { index.computeIfAbsent( c, (ignored) -> { - var x = index.size(); + var x = values.size(); values.add(c); return x; }); @@ -180,11 +175,21 @@ public Idx indexSym(int i) { return index(i, RegSymSXP.class); } - // FIXME: do we need this? + // FIXME: do we need these? --- public @Nullable Idx indexLangOrNilIfNegative(int i) { return i >= 0 ? orNil(i, LangSXP.class) : null; } + public @Nullable Idx indexIntOrNilIfNegative(int i) { + return i >= 0 ? orNil(i, IntSXP.class) : null; + } + + public @Nullable Idx indexStrOrNilIfNegative(int i) { + return i >= 0 ? orNil(i, StrSXP.class) : null; + } + + // -- FIXME + public @Nullable Idx indexStrOrSymOrNil(int i) { return orNil(i, StrOrRegSymSXP.class); } diff --git a/src/main/java/org/prlprg/primitive/Logical.java b/src/main/java/org/prlprg/primitive/Logical.java index acff49a2a..914f9cc0d 100644 --- a/src/main/java/org/prlprg/primitive/Logical.java +++ b/src/main/java/org/prlprg/primitive/Logical.java @@ -26,6 +26,11 @@ public static Logical valueOf(int i) { }; } + /** Convert to GNU-R representation. */ + public int toInt() { + return i; + } + Logical(int i) { this.i = i; } diff --git a/src/main/java/org/prlprg/rds/Flags.java b/src/main/java/org/prlprg/rds/Flags.java index c59707ce3..530e4b252 100644 --- a/src/main/java/org/prlprg/rds/Flags.java +++ b/src/main/java/org/prlprg/rds/Flags.java @@ -1,7 +1,26 @@ package org.prlprg.rds; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * The bitflags describing a {@code SEXP} in RDS. They are described as follows: + * + *
    + *
  • 0-7: describe the SEXP's RDSItemType + *
  • 8: enabled if the SEXP is an object + *
  • 9: enabled if the SEXP has attributes + *
  • 10: enabled if the SEXP has a tag (for the pairlist types) + *
  • 11: unused bit + *
  • 12-27: general purpose (gp) bits, as defined in the {@code SXPINFO} struct + *
+ * + * There are 16 GP bits, which are traditionally present on each SEXP. + */ final class Flags { - private static final int UTF8_MASK = 1 << 3; + private static final int OBJECT_MASK = 1 << 8; private static final int ATTR_MASK = 1 << 9; private static final int TAG_MASK = 1 << 10; @@ -16,32 +35,39 @@ public Flags(int flags) { } } + // Pack the flags of a regular item public Flags( - RDSItemType type, - int levels, - boolean isUTF8, - boolean hasAttributes, - boolean hasTag, - int refIndex) { + RDSItemType type, GPFlags levels, boolean isObject, boolean hasAttributes, boolean hasTag) { + if (type.i() == RDSItemType.Special.REFSXP.i()) + throw new IllegalArgumentException( + "Cannot write REFSXP with this constructor: ref index " + "needed"); this.flags = type.i() - | (levels << 12) - | (isUTF8 ? UTF8_MASK : 0) + | (levels.encode() << 12) + | (isObject ? OBJECT_MASK : 0) | (hasAttributes ? ATTR_MASK : 0) - | (hasTag ? TAG_MASK : 0) - | (refIndex << 8); + | (hasTag ? TAG_MASK : 0); + } + + // Pack the flags of a reference + public Flags(RDSItemType type, int refIndex) { + if (type.i() != RDSItemType.Special.REFSXP.i()) + throw new IllegalArgumentException( + "Cannot write REFSXP with this constructor: ref index " + "needed"); + this.flags = type.i() | (refIndex << 8); } public RDSItemType getType() { return RDSItemType.valueOf(flags & 255); } - public int decodeLevels() { - return flags >> 12; + // The levels contained in Flags are general-purpose bits. + public GPFlags getLevels() { + return new GPFlags(flags >> 12); } - public boolean isUTF8() { - return (decodeLevels() & UTF8_MASK) != 0; + public boolean isObject() { + return (flags & OBJECT_MASK) != 0; } public boolean hasAttributes() { @@ -56,15 +82,28 @@ public int unpackRefIndex() { return flags >> 8; } + /** + * Returns a new Flags identical to this one, but with the hasAttr bit set according to + * hasAttributes. + */ + public Flags withAttributes(boolean hasAttributes) { + return new Flags(this.flags & ~ATTR_MASK | (hasAttributes ? ATTR_MASK : 0)); + } + + /** Returns a new Flags identical to this one, but with the hasTag bit set according to hasTag. */ + public Flags withTag(boolean hasTag) { + return new Flags(this.flags & ~TAG_MASK | (hasTag ? TAG_MASK : 0)); + } + @Override public String toString() { return "Flags{" + "type=" + getType() + ", levels=" - + decodeLevels() - + ", isUTF8=" - + isUTF8() + + getLevels().encode() + + ", isObject=" + + isObject() + ", hasAttributes=" + hasAttributes() + ", hasTag=" @@ -73,4 +112,84 @@ public String toString() { + unpackRefIndex() + '}'; } + + public int encode() { + return flags; + } +} + +/** + * Flags corresponding with the general-purpose (gp) bits found on a SEXP. See R internals + * + *
    + *
  • Bits 14 and 15 are used for 'fancy bindings'. Bit 14 is used to lock a binding or + * environment, and bit 15 is used to indicate an active binding. Bit 15 is used for an + * environment to indicate if it participates in the global cache. + *
  • Bits 1, 2, 3, 5, and 6 are used for a {@code CHARSXP} (we use strings) to indicate + * its encoding. Relevant to us are bits 2, 3, and 6, which indicate Latin-1, UTF-8, and ASCII + * respectively. + *
  • Bit 4 is turned on to mark S4 objects + *
+ * + * Currently, only the character encoding flag is used. + */ +final class GPFlags { + // HASHASH_MASK: 1; + // private static final int BYTES_MASK = 1 << 1; + private static final int LATIN1_MASK = 1 << 2; + private static final int UTF8_MASK = 1 << 3; + // S4_OBJECT_MASK: 1 << 4 + // CACHED_MASK = 1 << 5; + private static final int ASCII_MASK = 1 << 6; + private static final int LOCKED_MASK = 1 << 14; + + private final int flags; + + GPFlags(@Nullable Charset charset, boolean locked) { + // NOTE: CACHED_MASK and HASHASH_MASK should be off when packing RDS flags for SEXPType.CHAR, + // but since we don't have external input for levels I think they should be off anyways + this.flags = + (locked ? LOCKED_MASK : 0) + | (charset == StandardCharsets.UTF_8 ? UTF8_MASK : 0) + | (charset == StandardCharsets.US_ASCII ? ASCII_MASK : 0) + | (charset == StandardCharsets.ISO_8859_1 ? LATIN1_MASK : 0); + } + + GPFlags(int levels) { + this.flags = levels; + } + + GPFlags() { + this.flags = 0; + } + + public int encode() { + return flags; + } + + public @Nullable Charset encoding() { + if ((flags & LATIN1_MASK) != 0) { + return StandardCharsets.ISO_8859_1; // ISO_8859_1 is LATIN1 + } else if ((flags & UTF8_MASK) != 0) { + return StandardCharsets.UTF_8; + } else if ((flags & ASCII_MASK) != 0) { + return StandardCharsets.US_ASCII; + } else { + return null; + } + } + + public boolean isLocked() { + return (this.flags & LOCKED_MASK) != 0; + } + + public String toString() { + return "GPFlags{" + + "encoding=" + + (this.encoding() == null ? "null" : Objects.requireNonNull(this.encoding()).name()) + + ", locked=" + + this.isLocked() + + '}'; + } } diff --git a/src/main/java/org/prlprg/rds/GNURByteCodeDecoderFactory.java b/src/main/java/org/prlprg/rds/GNURByteCodeDecoderFactory.java index 85f45af27..7a58bbf8f 100644 --- a/src/main/java/org/prlprg/rds/GNURByteCodeDecoderFactory.java +++ b/src/main/java/org/prlprg/rds/GNURByteCodeDecoderFactory.java @@ -19,7 +19,7 @@ class GNURByteCodeDecoderFactory { cpb = new ConstPool.Builder(consts); cbb = new BcCode.Builder(); - labelMapping = LabelMapping.from(byteCode); + labelMapping = LabelMapping.fromGNUR(byteCode); curr = 1; } @@ -213,9 +213,9 @@ BcInstr decode() { case DUP2ND -> new BcInstr.Dup2nd(); case SWITCH -> { var ast = cpb.indexLang(byteCode.get(curr++)); - var names = cpb.indexStrOrNil(byteCode.get(curr++)); - var chrLabelsIdx = cpb.indexIntOrNil(byteCode.get(curr++)); - var numlabelsIdx = cpb.indexIntOrNil(byteCode.get(curr++)); + var names = cpb.indexStrOrNilIfNegative(byteCode.get(curr++)); + var chrLabelsIdx = cpb.indexIntOrNilIfNegative(byteCode.get(curr++)); + var numLabelsIdx = cpb.indexIntOrNilIfNegative(byteCode.get(curr++)); // in the case switch does not have any named labels this will be null, if (chrLabelsIdx != null) { @@ -226,11 +226,11 @@ BcInstr decode() { // case of empty switch? // in some cases, the number labels can be the same as the chrLabels // and we do not want to remap twice - if (numlabelsIdx != null && !numlabelsIdx.equals(chrLabelsIdx)) { - cpb.reset(numlabelsIdx, this::remapLabels); + if (numLabelsIdx != null && !numLabelsIdx.equals(chrLabelsIdx)) { + cpb.reset(numLabelsIdx, this::remapLabels); } - yield new BcInstr.Switch(ast, names, chrLabelsIdx, numlabelsIdx); + yield new BcInstr.Switch(ast, names, chrLabelsIdx, numLabelsIdx); } case RETURNJMP -> new BcInstr.ReturnJmp(); case STARTSUBSET_N -> @@ -295,97 +295,3 @@ private IntSXP remapLabels(IntSXP oldLabels) { return SEXPs.integer(remapped); } } - -/** - * Create labels from GNU-R labels. - * - * @implNote This contains a map of positions in GNU-R bytecode to positions in our bytecode. We - * need this because every index in our bytecode maps to an instruction, while indexes in - * GNU-R's bytecode also map to the bytecode version and instruction metadata. - */ -class LabelMapping { - private final ImmutableIntArray posMap; - - private LabelMapping(ImmutableIntArray posMap) { - this.posMap = posMap; - } - - /** Create a label from a GNU-R label. */ - BcLabel make(int gnurLabel) { - return new BcLabel(getTarget(gnurLabel)); - } - - int getTarget(int gnurLabel) { - if (gnurLabel == 0) { - throw new IllegalArgumentException("GNU-R label 0 is reserved for the version number"); - } - - var target = posMap.get(gnurLabel); - if (target == -1) { - var gnurEarlier = gnurLabel - 1; - int earlier = posMap.get(gnurEarlier); - var gnurLater = gnurLabel + 1; - int later = posMap.get(gnurLater); - throw new IllegalArgumentException( - "GNU-R position maps to the middle of one of our instructions: " - + gnurLabel - + " between " - + earlier - + " and " - + later); - } - - return target; - } - - static LabelMapping from(ImmutableIntArray gnurBC) { - var builder = new Builder(); - // skip the BC version number - int i = 1; - while (i < gnurBC.length()) { - try { - var op = BcOp.valueOf(gnurBC.get(i)); - var size = 1 + op.nArgs(); - builder.step(size, 1); - i += size; - } catch (IllegalArgumentException e) { - throw new IllegalArgumentException( - "malformed bytecode at " + i + "\nBytecode up to this point: " + builder.build(), e); - } - } - return builder.build(); - } - - // FIXME: inline - static class Builder { - private final ImmutableIntArray.Builder map = ImmutableIntArray.builder(); - private int targetPc = 0; - - Builder() { - // Add initial mapping of 1 -> 0 (version # is 0) - map.add(-1); - map.add(0); - } - - /** Step m times in the source bytecode and n times in the target bytecode */ - void step(int sourceOffset, @SuppressWarnings("SameParameterValue") int targetOffset) { - if (sourceOffset < 0 || targetOffset < 0) { - throw new IllegalArgumentException("offsets must be nonnegative"); - } - - targetPc += targetOffset; - // Offsets before sourceOffset map to the middle of the previous instruction - for (int i = 0; i < sourceOffset - 1; i++) { - map.add(-1); - } - // Add target position - if (sourceOffset > 0) { - map.add(targetPc); - } - } - - LabelMapping build() { - return new LabelMapping(map.build()); - } - } -} diff --git a/src/main/java/org/prlprg/rds/GNURByteCodeEncoderFactory.java b/src/main/java/org/prlprg/rds/GNURByteCodeEncoderFactory.java new file mode 100644 index 000000000..c8ae5d16f --- /dev/null +++ b/src/main/java/org/prlprg/rds/GNURByteCodeEncoderFactory.java @@ -0,0 +1,197 @@ +package org.prlprg.rds; + +import com.google.common.primitives.ImmutableIntArray; +import org.jetbrains.annotations.NotNull; +import org.prlprg.bc.Bc; +import org.prlprg.bc.BcCode; +import org.prlprg.bc.BcInstr; +import org.prlprg.bc.ConstPool; +import org.prlprg.sexp.IntSXP; +import org.prlprg.sexp.SEXPs; + +public class GNURByteCodeEncoderFactory { + private final BcCode bc; + private final ImmutableIntArray.Builder builder; + private final LabelMapping labelMapping; + private final ConstPool.Builder cpb; + + GNURByteCodeEncoderFactory(Bc bc) { + this.bc = bc.code(); + this.builder = ImmutableIntArray.builder(); + this.labelMapping = LabelMapping.toGNUR(this.bc); + this.cpb = new ConstPool.Builder(bc.consts()); + } + + public static class GNURByteCode { + private final ImmutableIntArray instructions; + private final ConstPool consts; + + private GNURByteCode(ImmutableIntArray instructions, ConstPool consts) { + this.instructions = instructions; + this.consts = consts; + } + + public ImmutableIntArray getInstructions() { + return instructions; + } + + public ConstPool getConsts() { + return consts; + } + } + + public GNURByteCode buildRaw() { + // Write the bytecode version first + builder.add(Bc.R_BC_VERSION); + // Write the serialized instruction, containing the opcode and the arguments + for (var instr : bc) { + // Add the opcode + builder.add(instr.op().value()); + // Add the arguments + var args = args(instr, cpb); + if (args.length != instr.op().nArgs()) + throw new AssertionError( + "Sanity check failed: number of arguments " + + "serialized for " + + instr.op().name() + + " is not equal to instr.op().nArgs()"); + builder.addAll(args); + } + return new GNURByteCode(builder.build(), cpb.build()); + } + + /** Converts the arguments of the provided BcInstr to a "raw" format; i.e. an array of integers */ + public int[] args(@NotNull BcInstr instr, ConstPool.Builder cpb) { + return switch (instr) { + case BcInstr.Goto i -> new int[] {labelMapping.extract(i.label())}; + case BcInstr.BrIfNot i -> new int[] {i.ast().idx(), labelMapping.extract(i.label())}; + case BcInstr.StartLoopCntxt i -> + new int[] {i.isForLoop() ? 1 : 0, labelMapping.extract(i.end())}; + case BcInstr.EndLoopCntxt i -> + new int[] { + i.isForLoop() ? 1 : 0, + }; + case BcInstr.StartFor i -> + new int[] {i.ast().idx(), i.elemName().idx(), labelMapping.extract(i.step())}; + case BcInstr.StepFor i -> new int[] {labelMapping.extract(i.body())}; + case BcInstr.LdConst i -> new int[] {i.constant().idx()}; + case BcInstr.GetVar i -> new int[] {i.name().idx()}; + case BcInstr.DdVal i -> new int[] {i.name().idx()}; + case BcInstr.SetVar i -> new int[] {i.name().idx()}; + case BcInstr.GetFun i -> new int[] {i.name().idx()}; + case BcInstr.GetGlobFun i -> new int[] {i.name().idx()}; + case BcInstr.GetSymFun i -> new int[] {i.name().idx()}; + case BcInstr.GetBuiltin i -> new int[] {i.name().idx()}; + case BcInstr.GetIntlBuiltin i -> new int[] {i.name().idx()}; + case BcInstr.MakeProm i -> new int[] {i.code().idx()}; + case BcInstr.SetTag i -> new int[] {i.tag() == null ? -1 : i.tag().idx()}; + case BcInstr.PushConstArg i -> new int[] {i.constant().idx()}; + case BcInstr.Call i -> new int[] {i.ast().idx()}; + case BcInstr.CallBuiltin i -> new int[] {i.ast().idx()}; + case BcInstr.CallSpecial i -> new int[] {i.ast().idx()}; + case BcInstr.MakeClosure i -> new int[] {i.arg().idx()}; + case BcInstr.UMinus i -> new int[] {i.ast().idx()}; + case BcInstr.UPlus i -> new int[] {i.ast().idx()}; + case BcInstr.Add i -> new int[] {i.ast().idx()}; + case BcInstr.Sub i -> new int[] {i.ast().idx()}; + case BcInstr.Mul i -> new int[] {i.ast().idx()}; + case BcInstr.Div i -> new int[] {i.ast().idx()}; + case BcInstr.Expt i -> new int[] {i.ast().idx()}; + case BcInstr.Sqrt i -> new int[] {i.ast().idx()}; + case BcInstr.Exp i -> new int[] {i.ast().idx()}; + case BcInstr.Eq i -> new int[] {i.ast().idx()}; + case BcInstr.Ne i -> new int[] {i.ast().idx()}; + case BcInstr.Lt i -> new int[] {i.ast().idx()}; + case BcInstr.Le i -> new int[] {i.ast().idx()}; + case BcInstr.Ge i -> new int[] {i.ast().idx()}; + case BcInstr.Gt i -> new int[] {i.ast().idx()}; + case BcInstr.And i -> new int[] {i.ast().idx()}; + case BcInstr.Or i -> new int[] {i.ast().idx()}; + case BcInstr.Not i -> new int[] {i.ast().idx()}; + case BcInstr.StartAssign i -> new int[] {i.name().idx()}; + case BcInstr.EndAssign i -> new int[] {i.name().idx()}; + case BcInstr.StartSubset i -> new int[] {i.ast().idx(), labelMapping.extract(i.after())}; + case BcInstr.StartSubassign i -> new int[] {i.ast().idx(), labelMapping.extract(i.after())}; + case BcInstr.StartC i -> new int[] {i.ast().idx()}; + case BcInstr.StartSubset2 i -> new int[] {i.ast().idx(), labelMapping.extract(i.after())}; + case BcInstr.StartSubassign2 i -> new int[] {i.ast().idx(), labelMapping.extract(i.after())}; + case BcInstr.Dollar i -> new int[] {i.ast().idx(), i.member().idx()}; + case BcInstr.DollarGets i -> new int[] {i.ast().idx(), i.member().idx()}; + case BcInstr.VecSubset i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.MatSubset i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.VecSubassign i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.MatSubassign i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.And1st i -> new int[] {i.ast().idx(), labelMapping.extract(i.shortCircuit())}; + case BcInstr.And2nd i -> new int[] {i.ast().idx()}; + case BcInstr.Or1st i -> new int[] {i.ast().idx(), labelMapping.extract(i.shortCircuit())}; + case BcInstr.Or2nd i -> new int[] {i.ast().idx()}; + case BcInstr.GetVarMissOk i -> new int[] {i.name().idx()}; + case BcInstr.DdValMissOk i -> new int[] {i.name().idx()}; + case BcInstr.SetVar2 i -> new int[] {i.name().idx()}; + case BcInstr.StartAssign2 i -> new int[] {i.name().idx()}; + case BcInstr.EndAssign2 i -> new int[] {i.name().idx()}; + case BcInstr.SetterCall i -> new int[] {i.ast().idx(), i.valueExpr().idx()}; + case BcInstr.GetterCall i -> new int[] {i.ast().idx()}; + case BcInstr.Switch i -> { + var chrLabelsIdx = i.chrLabelsIdx(); + var numLabelsIdx = i.numLabelsIdx(); + + // Map the contents of the IntSXP referenced at i.chrLabelsIndex to the updated label + // positions + if (chrLabelsIdx != null) { + cpb.reset(chrLabelsIdx, this::remapLabels); + } + // Map the contents of the IntSXP referenced at i.numLabelsIndex to the updated label + // positions + if (numLabelsIdx != null && !numLabelsIdx.equals(chrLabelsIdx)) { + cpb.reset(numLabelsIdx, this::remapLabels); + } + yield new int[] { + i.ast().idx(), + i.names() == null ? -1 : i.names().idx(), + i.chrLabelsIdx() == null ? -1 : i.chrLabelsIdx().idx(), + i.numLabelsIdx() == null ? -1 : i.numLabelsIdx().idx(), + }; + } + + case BcInstr.StartSubsetN i -> + new int[] { + i.ast().idx(), labelMapping.extract(i.after()), + }; + case BcInstr.StartSubassignN i -> + new int[] { + i.ast().idx(), labelMapping.extract(i.after()), + }; + case BcInstr.VecSubset2 i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.MatSubset2 i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.VecSubassign2 i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.MatSubassign2 i -> new int[] {i.ast() == null ? -1 : i.ast().idx()}; + case BcInstr.StartSubset2N i -> new int[] {i.ast().idx(), labelMapping.extract(i.after())}; + case BcInstr.StartSubassign2N i -> new int[] {i.ast().idx(), labelMapping.extract(i.after())}; + case BcInstr.SubsetN i -> new int[] {i.ast() == null ? -1 : i.ast().idx(), i.n()}; + case BcInstr.Subset2N i -> new int[] {i.ast() == null ? -1 : i.ast().idx(), i.n()}; + case BcInstr.SubassignN i -> new int[] {i.ast() == null ? -1 : i.ast().idx(), i.n()}; + case BcInstr.Subassign2N i -> new int[] {i.ast() == null ? -1 : i.ast().idx(), i.n()}; + case BcInstr.Log i -> new int[] {i.ast().idx()}; + case BcInstr.LogBase i -> new int[] {i.ast().idx()}; + case BcInstr.Math1 i -> + new int[] { + i.ast().idx(), i.funId(), + }; + case BcInstr.DotCall i -> new int[] {i.ast().idx(), i.numArgs()}; + case BcInstr.Colon i -> new int[] {i.ast().idx()}; + case BcInstr.SeqAlong i -> new int[] {i.ast().idx()}; + case BcInstr.SeqLen i -> new int[] {i.ast().idx()}; + case BcInstr.BaseGuard i -> new int[] {i.expr().idx(), labelMapping.extract(i.ifFail())}; + case BcInstr.DeclnkN i -> new int[] {i.n()}; + + // Otherwise, there are no arguments we need to serialize + default -> new int[0]; + }; + } + + private IntSXP remapLabels(IntSXP oldLabels) { + var remapped = oldLabels.data().stream().map(labelMapping::getTarget).toArray(); + return SEXPs.integer(remapped); + } +} diff --git a/src/main/java/org/prlprg/rds/LabelMapping.java b/src/main/java/org/prlprg/rds/LabelMapping.java new file mode 100644 index 000000000..c40c93759 --- /dev/null +++ b/src/main/java/org/prlprg/rds/LabelMapping.java @@ -0,0 +1,126 @@ +package org.prlprg.rds; + +import com.google.common.primitives.ImmutableIntArray; +import org.prlprg.bc.BcCode; +import org.prlprg.bc.BcInstr; +import org.prlprg.bc.BcLabel; +import org.prlprg.bc.BcOp; + +/** + * Create labels from GNU-R labels to our labels, or vice versa + * + *

This contains a map of positions in GNU-R bytecode to positions in our bytecode. We need this + * because every index in our bytecode maps to an instruction, while indexes in GNU-R's bytecode + * also map to the bytecode version and instruction metadata. + */ +public class LabelMapping { + private final ImmutableIntArray posMap; + + private LabelMapping(ImmutableIntArray posMap) { + this.posMap = posMap; + } + + /** Make a {@link BcLabel} referencing the target from an int referencing the source */ + BcLabel make(int sourceLabel) { + return new BcLabel(getTarget(sourceLabel)); + } + + /** Extract an int referencing the target from a {@link BcLabel} referencing the source */ + public int extract(BcLabel sourceLabel) { + return getTarget(sourceLabel.target()); + } + + int getTarget(int sourceLabel) { + var target = posMap.get(sourceLabel); + if (target == -1) { + if (sourceLabel == 0) { + throw new IllegalArgumentException( + "Could not get target for source label 0. Note that if" + + "the source is GNU-R bytecode, GNU-R label 0 is reserved for the version number"); + } else { + var prev = posMap.get(sourceLabel - 1); + var next = posMap.get(sourceLabel + 1); + throw new IllegalArgumentException( + "Source position " + + sourceLabel + + " maps to the middle of target instructions: " + + " between " + + prev + + " and " + + next); + } + } + return target; + } + + /** Creates a mapping from GNU-R labels to our labels */ + public static LabelMapping fromGNUR(ImmutableIntArray gnurBC) { + // add source offset of 1 (position 0 is the version in GNU-R) + var builder = new LabelMapping.Builder(1, 0); + for (int i = 1; i < gnurBC.length(); ) { + try { + var op = BcOp.valueOf(gnurBC.get(i)); + var size = 1 + op.nArgs(); + builder.step(size, 1); + i += size; + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "malformed bytecode at " + i + "\nBytecode up to this point: " + builder.build(), e); + } + } + return builder.build(); + } + + /** Creates a mapping from our labels to GNUR labels */ + public static LabelMapping toGNUR(BcCode bc) { + // add target offset of 1 (position 0 is the version in GNU-R) + var builder = new LabelMapping.Builder(0, 1); + for (BcInstr instr : bc) { + var op = instr.op(); + var size = 1 + op.nArgs(); + builder.step(1, size); + } + return builder.build(); + } + + // FIXME: inline + static class Builder { + private final ImmutableIntArray.Builder map; + private int targetPc; + + Builder(int initialSourceOffset, int initialTargetOffset) { + map = ImmutableIntArray.builder(); + + // targetPc should start at initialTargetOffset, and map should be padded by -1 repeated + // initialSourceOffset times. + targetPc = initialTargetOffset; + for (int i = 0; i < initialSourceOffset; i++) { + map.add(-1); + } + } + + /** Step m times in the source bytecode and n times in the target bytecode */ + void step(int sourceOffset, @SuppressWarnings("SameParameterValue") int targetOffset) { + if (sourceOffset < 0 || targetOffset < 0) { + throw new IllegalArgumentException("offsets must be nonnegative"); + } + + // Add target position + if (sourceOffset > 0) { + map.add(targetPc); + } + + // "allocate" positions for the arguments afterward + for (int i = 0; i < sourceOffset - 1; i++) { + map.add(-1); + } + targetPc += targetOffset; + } + + LabelMapping build() { + // Add the final offset + map.add(targetPc); + return new LabelMapping(map.build()); + } + } +} diff --git a/src/main/java/org/prlprg/rds/RDSInputStream.java b/src/main/java/org/prlprg/rds/RDSInputStream.java index ea7db7c83..eafc2c395 100644 --- a/src/main/java/org/prlprg/rds/RDSInputStream.java +++ b/src/main/java/org/prlprg/rds/RDSInputStream.java @@ -49,18 +49,18 @@ public String readString(int natEncSize, Charset charset) throws IOException { public int[] readInts(int length) throws IOException { int[] ints = new int[length]; for (int i = 0; i < length; i++) { - var n = readInt(); + var n = in.readInt(); ints[i] = n; } return ints; } public double[] readDoubles(int length) throws IOException { - double[] ints = new double[length]; + double[] doubles = new double[length]; for (int i = 0; i < length; i++) { - var n = readDouble(); - ints[i] = n; + var n = in.readDouble(); + doubles[i] = n; } - return ints; + return doubles; } } diff --git a/src/main/java/org/prlprg/rds/RDSOutputStream.java b/src/main/java/org/prlprg/rds/RDSOutputStream.java new file mode 100644 index 000000000..dccd27c1e --- /dev/null +++ b/src/main/java/org/prlprg/rds/RDSOutputStream.java @@ -0,0 +1,54 @@ +package org.prlprg.rds; + +import java.io.Closeable; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +public class RDSOutputStream implements Closeable { + private final DataOutputStream out; + + RDSOutputStream(OutputStream out) { + this.out = new DataOutputStream(out); + } + + @Override + public void close() throws IOException { + out.close(); + } + + public void writeByte(byte v) throws IOException { + out.writeByte(v); + } + + public void writeInt(int v) throws IOException { + out.writeInt(v); + } + + public void writeDouble(double v) throws IOException { + out.writeDouble(v); + } + + /** + * Writes a series of bytes to the output stream. + * + *

Note: This replaces the writeString method. This is done since the representation of + * "length" when reading a String is not the actual length of the string in characters, but + * the length of the String in bytes. + */ + public void writeBytes(byte[] v) throws IOException { + out.write(v); + } + + public void writeInts(int[] v) throws IOException { + for (int e : v) { + out.writeInt(e); + } + } + + public void writeDoubles(double[] v) throws IOException { + for (double e : v) { + out.writeDouble(e); + } + } +} diff --git a/src/main/java/org/prlprg/rds/RDSReader.java b/src/main/java/org/prlprg/rds/RDSReader.java index 108753145..1e61421db 100644 --- a/src/main/java/org/prlprg/rds/RDSReader.java +++ b/src/main/java/org/prlprg/rds/RDSReader.java @@ -1,5 +1,7 @@ package org.prlprg.rds; +import static java.util.Objects.requireNonNull; + import com.google.common.collect.ImmutableList; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.Closeable; @@ -8,7 +10,6 @@ import java.io.IOException; import java.io.InputStream; import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; import java.util.*; import javax.annotation.Nullable; import org.prlprg.RSession; @@ -40,24 +41,42 @@ import org.prlprg.util.IO; public class RDSReader implements Closeable { + private final RSession rsession; private final RDSInputStream in; private final List refTable = new ArrayList<>(128); // FIXME: this should include the logic from platform.c - private Charset nativeEncoding = Charset.defaultCharset(); + // or should we individually read the charset property of each SEXP? this will require + // verifying which SEXPs we need to write the charset for--just CHARSXP, or also + // builtin/special? + private final Charset nativeEncoding = Charset.defaultCharset(); private RDSReader(RSession session, InputStream in) { this.rsession = session; this.in = new RDSInputStream(in); } + /** + * Reads a SEXP from the provided file. + * + * @param session The current R session, used to supply special constructs such as the base + * environment and namespace + * @param file The file to read from + */ public static SEXP readFile(RSession session, File file) throws IOException { try (var input = new FileInputStream(file)) { return readStream(session, IO.maybeDecompress(input)); } } + /** + * Reads a SEXP from the provided {@code InputStream}. + * + * @param session The current R session, used to supply special constructs such as the base + * environment and namespace + * @param input The stream to read from + */ public static SEXP readStream(RSession session, InputStream input) throws IOException { try (var reader = new RDSReader(session, input)) { return reader.read(); @@ -65,12 +84,11 @@ public static SEXP readStream(RSession session, InputStream input) throws IOExce } private void readHeader() throws IOException { - var type = in.readByte(); - if (type != 'X') { + + if (in.readByte() != 'X') { throw new RDSException("Unsupported type (possibly compressed)"); } - var nl = in.readByte(); - assert nl == '\n'; + assert in.readByte() == '\n'; // versions var formatVersion = in.readInt(); @@ -78,7 +96,6 @@ private void readHeader() throws IOException { // we do not support RDS version 3 because it uses ALTREP throw new RDSException("Unsupported RDS version: " + formatVersion); } - // writer version in.readInt(); // minimal reader version @@ -88,6 +105,7 @@ private void readHeader() throws IOException { public SEXP read() throws IOException { readHeader(); var sexp = readItem(); + if (in.readRaw() != -1) { throw new RDSException("Expected end of file"); } @@ -114,8 +132,7 @@ private SEXP readItem() throws IOException { case BCODE -> readByteCode(); case EXPR -> readExpr(flags); case PROM -> readPromise(flags); - case BUILTIN -> readBuiltin(false); - case SPECIAL -> readBuiltin(true); + case BUILTIN, SPECIAL -> readBuiltinOrSpecial(); case CPLX -> readComplex(flags); default -> throw new RDSException("Unsupported SEXP type: " + s.sexp()); }; @@ -148,25 +165,37 @@ private SEXP readComplex(Flags flags) throws IOException { cplx.add(new Complex(real, im)); } var attributes = readAttributes(flags); + return SEXPs.complex(cplx.build(), attributes); } - private SEXP readBuiltin(boolean special) throws IOException { + private SEXP readBuiltinOrSpecial() throws IOException { var length = in.readInt(); var name = in.readString(length, nativeEncoding); - return special ? SEXPs.special(name) : SEXPs.builtin(name); + + // For now, we throw an exception upon reading any SpecialSXP or BuiltinSXP. This is because + // RDS serializes builtins via their name, but we do not have any (fully implemented) construct + // representing the name of a builtin (instead, they are represented with indices) + throw new UnsupportedOperationException("Unable to read builtin: " + name); + + // Spec for future implementation: + // - return SEXPs.builtin() or SEXPs.special() depending on the boolean passed to the method } private SEXP readPromise(Flags flags) throws IOException { + // FIXME: do something with the attributes here? readAttributes(flags); var tag = flags.hasTag() ? readItem() : SEXPs.NULL; var val = readItem(); var expr = readItem(); if (tag instanceof NilSXP) { + // If the tag is nil, the promise is evaluated return new PromSXP(expr, val, SEXPs.EMPTY_ENV); } else if (tag instanceof EnvSXP env) { - return new PromSXP(expr, val, env); + // Otherwise, the promise is lazy. We represent lazy promises as having a val of + // SEXPs.UNBOUND_VALUE, so we set it here accordingly + return new PromSXP(expr, SEXPs.UNBOUND_VALUE, env); } else { throw new RDSException("Expected promise ENV to be environment"); } @@ -184,6 +213,7 @@ private SEXP readNamespace() throws IOException { return namespace; } + // Note that this method is not used to StringSXPs. private StrSXP readStringVec() throws IOException { if (in.readInt() != 0) { // cf. InStringVec @@ -192,7 +222,6 @@ private StrSXP readStringVec() throws IOException { var length = in.readInt(); var strings = new ArrayList(length); - for (int i = 0; i < length; i++) { strings.add(readChars()); } @@ -208,12 +237,14 @@ private ExprSXP readExpr(Flags flags) throws IOException { } Attributes attributes = readAttributes(flags); + return SEXPs.expr(sexps, attributes); } private BCodeSXP readByteCode() throws IOException { var length = in.readInt(); var reps = new SEXP[length]; + return readByteCode1(reps); } @@ -231,15 +262,18 @@ private BCodeSXP readByteCode1(SEXP[] reps) throws IOException { var consts = readByteCodeConsts(reps); var factory = new GNURByteCodeDecoderFactory(code.data(), consts); + return SEXPs.bcode(factory.create()); } private List readByteCodeConsts(SEXP[] reps) throws IOException { var length = in.readInt(); + var consts = new ArrayList(length); for (int i = 0; i < length; i++) { - var type = RDSItemType.valueOf(in.readInt()); - switch (type) { + var type = in.readInt(); + + switch (RDSItemType.valueOf(type)) { case RDSItemType.Sexp s -> { switch (s.sexp()) { case BCODE -> consts.add(readByteCode1(reps)); @@ -256,21 +290,46 @@ private List readByteCodeConsts(SEXP[] reps) throws IOException { } } } + return consts; } - private SEXP readByteCodeLang(RDSItemType type, SEXP[] reps) throws IOException { - return switch (type) { - case RDSItemType.Sexp s -> - switch (s.sexp()) { - case LANG, LIST -> readByteCodeLang1(type, reps); - default -> readItem(); - }; + // `type` will not necessarily correspond with a valid RDSItemType. If the next value in the + // stream is not a LangSXP, ListSXP, or bytecode reference / definition, then the stream will + // include a padding int so the next SEXP can be read in full. This is why the function accepts + // an integer instead of an RDSItemType. + private SEXP readByteCodeLang(int type, SEXP[] reps) throws IOException { + // If the type is 0, we encountered a padding bit, meaning we jump back to "regular" SEXP + // processing. + if (type == 0) { + return readItem(); + } + + // Otherwise, we continue with bytecode processing + var rdsType = RDSItemType.valueOf(type); + + return switch (rdsType) { + case RDSItemType.Sexp s -> { + if (s.sexp() == SEXPType.LANG || s.sexp() == SEXPType.LIST) { + yield readByteCodeLang1(rdsType, reps); + } else { + throw new UnsupportedOperationException( + "RDS reader error when reading SEXP: expected a padding bit, lang or list SXP, got: " + + rdsType); + } + } case RDSItemType.Special s -> switch (s) { - case BCREPREF -> reps[in.readInt()]; - case BCREPDEF, ATTRLISTSXP, ATTRLANGSXP -> readByteCodeLang1(type, reps); - default -> readItem(); + case BCREPREF -> { + int pos = in.readInt(); + yield reps[pos]; + } + case BCREPDEF, ATTRLISTSXP, ATTRLANGSXP -> readByteCodeLang1(rdsType, reps); + default -> + throw new UnsupportedOperationException( + "RDS reader error when reading special: expected a padding bit, BCREPDEF, " + + "BCREPREF, ATTRLISTSXP, or ATTRLANGSXP, got: " + + rdsType); }; }; } @@ -279,7 +338,8 @@ private SEXP readByteCodeLang1(RDSItemType type, SEXP[] reps) throws IOException var pos = -1; if (type == RDSItemType.Special.BCREPDEF) { pos = in.readInt(); - type = RDSItemType.valueOf(in.readInt()); + var type_i = in.readInt(); + type = RDSItemType.valueOf(type_i); } var attributes = (type == RDSItemType.Special.ATTRLANGSXP || type == RDSItemType.Special.ATTRLISTSXP) @@ -306,8 +366,11 @@ private SEXP readByteCodeLang1(RDSItemType type, SEXP[] reps) throws IOException throw new RDSException("Expected regular symbol or nil"); } - var head = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); - var tail = readByteCodeLang(RDSItemType.valueOf(in.readInt()), reps); + var headType = in.readInt(); + var head = readByteCodeLang(headType, reps); + + var tailType = in.readInt(); + var tail = readByteCodeLang(tailType, reps); ListSXP tailList; @@ -349,15 +412,23 @@ private SEXP readByteCodeLang1(RDSItemType type, SEXP[] reps) throws IOException private SEXP readRef(Flags flags) throws IOException { var index = flags.unpackRefIndex(); + // if index is 0, it was too large to be packed with the flags and was therefore written + // afterward if (index == 0) { index = in.readInt(); } + + // since index is 1-based return refTable.get(index - 1); } private LangSXP readLang(Flags flags) throws IOException { var attributes = readAttributes(flags); - // FIXME: not sure what it is good for + + // We do not support tags for LangSXPs. It is technically possible to have a tag on a LangSXP + // by adding it manually, but this is very rare. It is more common for the arguments, which + // are members of a ListSXP, to have names associated with them. As such, we read and discard + // any tag that might be present. readTag(flags); if (!(readItem() instanceof SymOrLangSXP fun)) { @@ -374,23 +445,30 @@ private String readChars() throws IOException { if (!flags.getType().isSexp(SEXPType.CHAR)) { throw new RDSException("Expected CHAR"); } + var encoding = flags.getLevels().encoding(); + var length = in.readInt(); + + String out; if (length == -1) { - return Constants.NA_STRING; + out = Constants.NA_STRING; } else { - return in.readString(length, nativeEncoding); + assert encoding != null; + out = in.readString(length, encoding); } + + return out; } private StrSXP readStrs(Flags flags) throws IOException { var length = in.readInt(); var strings = ImmutableList.builderWithExpectedSize(length); - for (int i = 0; i < length; i++) { strings.add(readChars()); } var attributes = readAttributes(flags); + return SEXPs.string(strings.build(), attributes); } @@ -416,7 +494,7 @@ private UserEnvSXP readEnv() throws IOException { case NilSXP _ -> {} case ListSXP frame -> { for (var elem : frame) { - item.set(Objects.requireNonNull(elem.tag()), elem.value()); + item.set(requireNonNull(elem.tag()), elem.value()); } } default -> throw new RDSException("Expected list (FRAME)"); @@ -431,7 +509,7 @@ private UserEnvSXP readEnv() throws IOException { case NilSXP _ -> {} case ListSXP list -> { for (var e : list) { - item.set(Objects.requireNonNull(e.tag()), e.value()); + item.set(requireNonNull(e.tag()), e.value()); } } default -> throw new RDSException("Expected list for the hashtab entries"); @@ -451,6 +529,7 @@ private VecSXP readVec(Flags flags) throws IOException { data.add(readItem()); } var attributes = readAttributes(flags); + return SEXPs.vec(data.build(), attributes); } @@ -458,6 +537,7 @@ private LglSXP readLogicals(Flags flags) throws IOException { var length = in.readInt(); var data = in.readInts(length); var attributes = readAttributes(flags); + return SEXPs.logical( Arrays.stream(data).mapToObj(Logical::valueOf).collect(ImmutableList.toImmutableList()), attributes); @@ -467,6 +547,7 @@ private RealSXP readReals(Flags flags) throws IOException { var length = in.readInt(); var data = in.readDoubles(length); var attributes = readAttributes(flags); + return SEXPs.real(data, attributes); } @@ -474,6 +555,7 @@ private IntSXP readInts(Flags flags) throws IOException { var length = in.readInt(); var data = in.readInts(length); var attributes = readAttributes(flags); + return SEXPs.integer(data, attributes); } @@ -501,6 +583,7 @@ private ListSXP readList(Flags flags) throws IOException { flags = readFlags(); } + // TODO: add the attributes here? return SEXPs.list(data.build()); } @@ -557,7 +640,7 @@ private Attributes readAttributes() throws IOException { private RegSymSXP readSymbol() throws IOException { var flags = readFlags(); - var s = readString(flags); + var s = readChars(flags); var item = SEXPs.symbol(s); refTable.add(item); @@ -565,13 +648,10 @@ private RegSymSXP readSymbol() throws IOException { return item; } - private String readString(Flags flags) throws IOException { + private String readChars(Flags flags) throws IOException { var len = in.readInt(); - var charset = StandardCharsets.US_ASCII; - - if (flags.isUTF8()) { - charset = StandardCharsets.UTF_8; - } + // charset should never be null for strings + var charset = requireNonNull(flags.getLevels().encoding()); return in.readString(len, charset); } diff --git a/src/main/java/org/prlprg/rds/RDSWriter.java b/src/main/java/org/prlprg/rds/RDSWriter.java new file mode 100644 index 000000000..17808eff7 --- /dev/null +++ b/src/main/java/org/prlprg/rds/RDSWriter.java @@ -0,0 +1,599 @@ +package org.prlprg.rds; + +import java.io.*; +import java.nio.charset.Charset; +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.DoubleStream; +import java.util.stream.StreamSupport; +import org.prlprg.RVersion; +import org.prlprg.primitive.Logical; +import org.prlprg.sexp.*; +import org.prlprg.util.UnreachableError; + +public class RDSWriter implements Closeable { + + private final RDSOutputStream out; + // refIndex is 1-based, so the first ref will have index 1 + private int refIndex = 1; + private final HashMap refTable = new HashMap<>(128); + + protected RDSWriter(OutputStream out) { + this.out = new RDSOutputStream(out); + } + + /** + * Writes a SEXP to the provided output stream. + * + * @param output The stream to write to + * @param sexp the SEXP to write + */ + public static void writeStream(OutputStream output, SEXP sexp) throws IOException { + try (var writer = new RDSWriter(output)) { + writer.write(sexp); + } + } + + /** + * Writes a SEXP to the provided file. + * + * @param file The file to write to + * @param sexp the SEXP to write + */ + public static void writeFile(File file, SEXP sexp) throws IOException { + try (var output = new FileOutputStream(file)) { + writeStream(output, sexp); + } + } + + public void writeHeader() throws IOException { + // Could also be "B" (binary) and "A" (ASCII) but we only support XDR. + // XDR just means big endian and DataInputStream/DataOutputStream from Java use BigEndian + out.writeByte((byte) 'X'); + + out.writeByte((byte) '\n'); + + // Write version 2 of the encoding, since we want the writer to align with the reader and + // the reader does not support ALTREP + out.writeInt(2); + + // Version of R for the writer + out.writeInt(RVersion.LATEST_AWARE.encode()); + + // Minimal version of R required to read back + out.writeInt(new RVersion(2, 3, 0, null).encode()); + } + + public void write(SEXP sexp) throws IOException { + writeHeader(); + writeItem(sexp); + } + + // See + // https://github.com/wch/r-source/blob/65892cc124ac20a44950e6e432f9860b1d6e9bf4/src/main/serialize.c#L1021 + public void writeItem(SEXP s) throws IOException { + // Write the flags for this SEXP. This will vary depending on whether the SEXP is a special + // RDS type or not + var flags = flags(s); + out.writeInt(flags.encode()); + + switch (flags.getType()) { + // Special types not handled by Save Special hooks + case RDSItemType.Special special -> { + switch (special) { + case RDSItemType.Special.NAMESPACESXP -> { + // add to the ref table + refAdd(s); + // write details about the namespace + var namespace = (NamespaceEnvSXP) s; + writeStringVec(SEXPs.string(namespace.name(), namespace.version())); + } + case RDSItemType.Special.REFSXP -> { + // If the flags encoded a reference, then we may need to write the ref index (only if + // it was too large to be packed in the flags) + if (flags.unpackRefIndex() == 0) { + out.writeInt(refIndex); + } + } + default -> { + /* nothing to write */ + } + } + } + case RDSItemType.Sexp _ -> { + // Otherwise, write the sexp as normal + switch (s) { + case SymSXP sym -> writeSymbol(sym); + case EnvSXP env -> writeEnv(env); + case ListSXP list -> writeListSXP(list); + case LangSXP lang -> writeLangSXP(lang); + case PromSXP prom -> writePromSXP(prom); + case CloSXP clo -> writeCloSXP(clo); + case BuiltinOrSpecialSXP bos -> writeBuiltinOrSpecialSXP(bos); + case VectorSXP vec -> writeVectorSXP(vec); + case BCodeSXP bc -> writeByteCode(bc); + } + } + } + } + + // UTILITY (not standard SEXPs) ----------------------------------------------------------------- + + /** Adds s to the ref table at the next available index */ + private void refAdd(SEXP s) { + refTable.put(s, refIndex++); + } + + /** + * Determines if the hasTag bit should be set based on the SEXP. + * + *

The meaning of "tag" varies depending on the SEXP. In C, it is defined based on the position + * of the field (namely, the last field of the struct). So, it corresponds with the fields as + * follows: + * + *

    + *
  • {@link CloSXP}: the closure's environment + *
  • {@link PromSXP}: the promise's environment (null if the promise has been evaluated) + *
  • {@link ListSXP}: the name assigned to an element + *
  • {@link LangSXP}: a name assigned to the function (we do not support this since it's such + * a rare case) + *
+ */ + private boolean hasTag(SEXP s) { + return switch (s) { + // CloSXP should always be marked as having a tag + case CloSXP _ -> true; + // The tag for a LangSXP is manually assigned to the function (this is very rare). We + // don't support them. + case LangSXP _ -> false; + // In GNUR, the tag of a promise is its environment. The environment is set to null once + // the promise is evaluated. So, hasTag should return true if and only if the promise is + // unevaluated (lazy) + case PromSXP prom -> prom.isLazy(); + // hasTag is based on the first element + case ListSXP list -> !list.isEmpty() && list.get(0).hasTag(); + default -> false; + }; + } + + /** Determines the RDS item type associated with the provided SEXP. */ + private RDSItemType rdsType(SEXP s) { + return switch (s) { + // "Save Special hooks" from serialize.c + case NilSXP _ -> RDSItemType.Special.NILVALUE_SXP; + case ListSXP l when l.isEmpty() -> RDSItemType.Special.NILVALUE_SXP; + case EmptyEnvSXP _ -> RDSItemType.Special.EMPTYENV_SXP; + case BaseEnvSXP _ -> RDSItemType.Special.BASEENV_SXP; + case GlobalEnvSXP _ -> RDSItemType.Special.GLOBALENV_SXP; + case SpecialSymSXP sexp when sexp == SEXPs.UNBOUND_VALUE -> + RDSItemType.Special.UNBOUNDVALUE_SXP; + case SpecialSymSXP sexp when sexp == SEXPs.MISSING_ARG -> RDSItemType.Special.MISSINGARG_SXP; + // Non-"Save Special" cases + case NamespaceEnvSXP ns -> RDSItemType.Special.NAMESPACESXP; + default -> new RDSItemType.Sexp(s.type()); + }; + } + + /** + * Returns the flags associated with the provided SEXP. If the SEXP is already present in the ref + * table, will return flags associated with its reference index. + */ + final int MAX_PACKED_INDEX = Integer.MAX_VALUE >> 8; + + private Flags flags(SEXP s) { + // If refIndex is greater than 0, the object has already been written, so we can + // write a reference (since the index is 1-based) + int sexpRefIndex = refTable.getOrDefault(s, 0); + if (sexpRefIndex > 0) { + if (sexpRefIndex > MAX_PACKED_INDEX) { + // If the reference index can't be packed in the flags, it will be written afterward + return new Flags(RDSItemType.Special.REFSXP, 0); + } else { + // Otherwise, pack the reference index in the flags + return new Flags(RDSItemType.Special.REFSXP, sexpRefIndex); + } + } + + // Otherwise, write flags based on the RDSType of s + // FIXME: we should actually get the proper "locked" flag, but we currently don't have a + // representation of this in our environments + return new Flags(rdsType(s), new GPFlags(), s.isObject(), s.hasAttributes(), hasTag(s)); + } + + /** + * Writes an {@link Attributes} to the output stream, throwing an exception if it is empty. As + * such, it is essential to check if an object has attributes before invoking this method. + */ + private void writeAttributes(Attributes attrs) throws IOException { + if (attrs.isEmpty()) + throw new IllegalArgumentException("Cannot write an empty set of attributes"); + // convert to ListSXP + var l = attrs.entrySet().stream().map(e -> new TaggedElem(e.getKey(), e.getValue())).toList(); + // Write it + writeItem(SEXPs.list(l)); + } + + private void writeAttributesIfPresent(SEXP s) throws IOException { + if (s.hasAttributes()) writeAttributes(Objects.requireNonNull(s.attributes())); + } + + /** Writes the tag of the provided TaggedElem, if one exists. If none exists, does nothing. */ + private void writeTagIfPresent(TaggedElem elem) throws IOException { + if (elem.hasTag()) { + // Convert the tag to a symbol, since we need to add it to the ref table + writeItem(Objects.requireNonNull(elem.tagAsSymbol())); + } + } + + /** + * Writes a String to the output in the format expected by RDS. Since R represents Strings as a + * CHARSXP, we need to write metadata like flags. + */ + private void writeChars(String s) throws IOException { + var flags = + new Flags( + RDSItemType.valueOf(SEXPType.CHAR.i), + new GPFlags(Charset.defaultCharset(), false), + false, + false, + false); + out.writeInt(flags.encode()); + + // If the string is NA, we write -1 for the length and exit + if (Coercions.isNA(s)) { + out.writeInt(-1); + return; + } + + // Otherwise, do a standard string write (length in bytes, then bytes) + var bytes = s.getBytes(Charset.defaultCharset()); + out.writeInt(bytes.length); + out.writeBytes(bytes); + } + + /** + * Writes a StrSXP with an unused placeholder "name" int before the length. + * + * @apiNote this is NOT used to write regular StrSXPs. It is currently only used to write + * namespace and package environment spec. + */ + private void writeStringVec(StrSXP s) throws IOException { + out.writeInt(0); + out.writeInt(s.size()); + + for (String str : s) { + writeChars(str); + } + } + + // STANDARD SEXPs ------------------------------------------------------------------------------- + + private void writeEnv(EnvSXP env) throws IOException { + // Add to the ref table + refAdd(env); + + if (env instanceof UserEnvSXP userEnv) { + // Write 1 if the environment is locked, or 0 if it is not + // FIXME: implement locked environments, as is this will always be false + out.writeInt(new GPFlags().isLocked() ? 1 : 0); + // Enclosure + writeItem(userEnv.parent()); + // Frame + writeItem(userEnv.frame()); + // Hashtab (NULL or VECSXP) + writeItem(SEXPs.NULL); // simple version here. + // Otherwise, we would have to actually do the hashing as it is done in R + + // Attributes + // R always write something here, as it does not write a hastag bit in the flags + // (it actually has no flags; it just writes the type ENV) + if (env.hasAttributes()) { + writeAttributes(env.attributes()); + } else { + writeItem(SEXPs.NULL); + } + } else { + throw new UnreachableError("Implemented as special RDS type: " + env.type()); + } + } + + private void writeSymbol(SymSXP s) throws IOException { + switch (s) { + case RegSymSXP rs -> { + // Add to the ref table + refAdd(rs); + // Write the symbol + writeChars(rs.name()); + } + case SpecialSymSXP specialSymSXP when specialSymSXP.isEllipsis() -> { + writeChars("..."); // Really? + } + default -> + throw new UnsupportedOperationException("Unreachable: implemented in special sexps."); + } + } + + private void writeBuiltinOrSpecialSXP(BuiltinOrSpecialSXP bos) throws IOException { + // For now, we throw an exception upon writing any SpecialSXP or BuiltinSXP. This is because + // RDS serializes builtins via their name, but we do not have any (fully implemented) construct + // representing the name of a builtin (instead, they are represented with indices) + throw new UnsupportedOperationException("Unable to write builtin: " + bos); + + // Spec for future implementation: + // - write an int representing the length of the BuiltinOrSpecialSXP's name + // - write the name as a String (not a CHARSXP, in that no additional flags are written) + } + + private void writeListSXP(ListSXP lsxp) throws IOException { + Flags listFlags = flags(lsxp); + + // Write the first element. This case is separate because: + // - the first element may have attributes + // - the first element's tag has already been written + writeAttributesIfPresent(lsxp); + + var first = lsxp.get(0); + writeTagIfPresent(first); + writeItem(first.value()); + + // Write the rest of the list + for (var el : lsxp.subList(1)) { + // Write flags + var itemFlags = listFlags.withTag(el.hasTag()).withAttributes(false); + out.writeInt(itemFlags.encode()); + // Write tag + writeTagIfPresent(el); + // Write item + writeItem(el.value()); + } + + // Write a NilSXP to end the list + writeItem(SEXPs.NULL); + } + + private void writeLangSXP(LangSXP lang) throws IOException { + writeAttributesIfPresent(lang); + // LangSXPs can have tags, but we don't support them, so no tag is written here + writeItem(lang.fun()); + writeItem(lang.args()); + } + + private void writePromSXP(PromSXP prom) throws IOException { + writeAttributesIfPresent(prom); + + // TODO: test that this is the correct order of arguments + + // Only write the + if (prom.isLazy()) { + writeItem(prom.env()); + } + + writeItem(prom.val()); + writeItem(prom.expr()); + } + + private void writeCloSXP(CloSXP clo) throws IOException { + writeAttributesIfPresent(clo); + // a closure has the environment, formals, and then body + writeItem(clo.env()); + writeItem(clo.parameters()); + writeItem(clo.body()); + } + + private void writeVectorSXP(VectorSXP s) throws IOException { + var length = s.size(); + out.writeInt(length); + + switch (s) { + case VecSXP vec -> { + // Write all the elements of the vec as individual items + for (var val : vec) { + writeItem(val); + } + } + case ExprSXP exprs -> { + // Write all the exprs to the stream as individual items + for (var val : exprs) { + writeItem(val); + } + } + case IntSXP ints -> { + // Write all the ints to the stream + var vec = StreamSupport.stream(ints.spliterator(), false).mapToInt(i -> i).toArray(); + out.writeInts(vec); + } + case LglSXP lgls -> { + // Write all the logicals to the stream as ints + var vec = + StreamSupport.stream(lgls.spliterator(), false).mapToInt(Logical::toInt).toArray(); + out.writeInts(vec); + } + case RealSXP reals -> { + // Write all the reals to the stream as doubles + var vec = StreamSupport.stream(reals.spliterator(), false).mapToDouble(d -> d).toArray(); + out.writeDoubles(vec); + } + case ComplexSXP cplxs -> { + // For each complex number in the vector, add two doubles representing the real and + // imaginary components via a flat map + var doubles = + StreamSupport.stream(cplxs.spliterator(), false) + .flatMapToDouble(c -> DoubleStream.builder().add(c.real()).add(c.imag()).build()) + .toArray(); + out.writeDoubles(doubles); + } + case StrSXP strs -> { + // For each string in the vector, we write its chars because R represents each string as a + // CHARSXP + for (String str : strs) { + writeChars(str); + } + } + + default -> throw new RuntimeException("Unreachable: implemented in another branch."); + } + + writeAttributesIfPresent(s); + } + + // BYTECODE ------------------------------------------------------------------------------------- + + private void scanForCircles(SEXP sexp, HashMap reps, HashSet seen) { + switch (sexp) { + case LangOrListSXP lol -> { + if (seen.contains(lol)) { + // Add to reps if the cell has already been seen + // We put -1 for the time being so that we can update reps in the correct order later + reps.put(lol, -1); + return; + } + // Otherwise, add to seen and scan recursively + seen.add(lol); + + switch (lol) { + case LangSXP lang -> { + // For LangSXP, we want to scan both the function and the arg values + scanForCircles(lang.fun(), reps, seen); + lang.args().values().forEach((el) -> scanForCircles(el, reps, seen)); + } + case ListSXP list -> { + // For ListSXP, we scan the values + list.values().forEach((el) -> scanForCircles(el, reps, seen)); + } + } + } + case BCodeSXP bc -> { + // For bytecode, we scan the constant pool + bc.bc().consts().forEach((el) -> scanForCircles(el, reps, seen)); + } + default -> { + // do nothing + } + } + } + + private void writeByteCodeLang(SEXP s, HashMap reps, AtomicInteger nextRepIndex) + throws IOException { + if (s instanceof LangOrListSXP lol && lol.type() != SEXPType.NIL) { + var assignedRepIndex = reps.get(lol); + if (assignedRepIndex != null) { + if (assignedRepIndex == -1) { + // If the rep is present in the map but is -1, this is our first time seeing it, so we + // emit a BCREPDEF and update the counter + int newIndex = nextRepIndex.getAndIncrement(); + reps.put(lol, newIndex); + + out.writeInt(RDSItemType.Special.BCREPDEF.i()); + out.writeInt(newIndex); + } else { + // If the rep is present with an index other than -1, we have already seen it, so we + // emit a BCREPREF with the reference index. + out.writeInt(RDSItemType.Special.BCREPREF.i()); + out.writeInt(assignedRepIndex); + // We also return, since the child nodes have already been written, and we don't want + // to write them again + return; + } + } + + var type = RDSItemType.valueOf(lol.type().i); + + // if the item has attributes, we use the special types ATTRLANGSXP and ATTRLISTSXP instead + // of LangSXP and ListSXP. This is done to preserve information on expressions in the + // constant pool of byte code objects. + if (lol.hasAttributes()) { + type = + switch (lol) { + case LangSXP _lang -> RDSItemType.Special.ATTRLANGSXP; + case ListSXP _list -> RDSItemType.Special.ATTRLISTSXP; + }; + } + out.writeInt(type.i()); + writeAttributesIfPresent(lol); + + switch (lol) { + // For a LangSXP, recursively write the function and args + case LangSXP lang -> { + // The tag of a LangSXP is an argument name, but it does not seem that we support them. + writeItem(SEXPs.NULL); + // write head + writeByteCodeLang(lang.fun(), reps, nextRepIndex); + // write tail + writeByteCodeLang(lang.args(), reps, nextRepIndex); + } + // For a ListSXP, recursively write the elements + case ListSXP list -> { + // there will always be a first element because we take a different path when the list + // is empty + var first = list.stream().findFirst().orElseThrow(); + SEXP tag = first.tag() == null ? SEXPs.NULL : SEXPs.symbol(first.tag()); + + // write tag + writeItem(tag); + // write head + writeByteCodeLang(list.value(0), reps, nextRepIndex); + // write tail + writeByteCodeLang(list.subList(1), reps, nextRepIndex); + } + } + } else { // Print a zero as padding and write the item normally + out.writeInt(0); + writeItem(s); + } + } + + private void writeByteCode1(BCodeSXP s, HashMap reps, AtomicInteger nextRepIndex) + throws IOException { + // Decode the bytecode (we will get a vector of integers) + // write the vector of integers + var encoder = new GNURByteCodeEncoderFactory(s.bc()); + + var code_bytes = encoder.buildRaw(); + writeItem(SEXPs.integer(code_bytes.getInstructions())); + writeByteCodeConsts(code_bytes.getConsts(), reps, nextRepIndex); + } + + private void writeByteCodeConsts( + List consts, HashMap reps, AtomicInteger nextRepIndex) + throws IOException { + // write the number of consts in the bytecode + // iterate the consts: if it s bytecode, write the type and recurse + // if it is langsxp or listsxp, write them , using the BCREDPEF, ATTRALANGSXP and ATTRLISTSXP + // else write the type and the value + out.writeInt(consts.size()); + + // Iterate the constant pool and write the values + for (var c : consts) { + switch (c) { + case BCodeSXP bc -> { + out.writeInt(c.type().i); + writeByteCode1(bc, reps, nextRepIndex); + } + case LangOrListSXP l -> { + // writeBCLang writes the type i + writeByteCodeLang(l, reps, nextRepIndex); + } + default -> { + out.writeInt(c.type().i); + writeItem(c); + } + } + } + } + + private void writeByteCode(BCodeSXP s) throws IOException { + // Scan for circles + var reps = new HashMap(); + var seen = new HashSet(); + scanForCircles(s, reps, seen); + out.writeInt(reps.size() + 1); + + var nextRepIndex = new AtomicInteger(0); + writeByteCode1(s, reps, nextRepIndex); + } + + @Override + public void close() throws IOException { + out.close(); + } +} diff --git a/src/main/java/org/prlprg/sexp/BaseEnvSXP.java b/src/main/java/org/prlprg/sexp/BaseEnvSXP.java index d1a12a75b..643ee1c74 100644 --- a/src/main/java/org/prlprg/sexp/BaseEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/BaseEnvSXP.java @@ -17,6 +17,11 @@ public EmptyEnvSXP parent() { return (EmptyEnvSXP) super.parent(); } + @Override + public int size() { + return bindings.size(); + } + @Override public void setParent(StaticEnvSXP parent) { if (parent instanceof EmptyEnvSXP e) { diff --git a/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java b/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java index 783c5d8f5..ce564279d 100644 --- a/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/EmptyEnvSXP.java @@ -4,7 +4,6 @@ import java.util.Optional; import java.util.Set; import org.jetbrains.annotations.UnmodifiableView; -import org.prlprg.parseprint.Printer; import org.prlprg.util.Pair; public final class EmptyEnvSXP implements StaticEnvSXP { @@ -47,6 +46,11 @@ public Optional getLocal(String name) { return Optional.empty(); } + @Override + public int size() { + return 0; + } + @Override public Optional> find(String name) { return Optional.empty(); @@ -57,11 +61,6 @@ public Optional> find(String name) { return Set.of(); } - @Override - public int size() { - return 0; - } - @Override public EnvType envType() { return EnvType.EMPTY; @@ -69,6 +68,6 @@ public EnvType envType() { @Override public String toString() { - return Printer.toString(this); + return ""; } } diff --git a/src/main/java/org/prlprg/sexp/EnvSXP.java b/src/main/java/org/prlprg/sexp/EnvSXP.java index 89d544257..84bfdfb75 100644 --- a/src/main/java/org/prlprg/sexp/EnvSXP.java +++ b/src/main/java/org/prlprg/sexp/EnvSXP.java @@ -49,6 +49,13 @@ public sealed interface EnvSXP extends SEXP permits StaticEnvSXP, UserEnvSXP { */ Optional getLocal(String name); + /** + * Get the number of symbols in the environment (locally) + * + * @return the number of symbols in the environment + */ + int size(); + @Override default SEXPType type() { return SEXPType.ENV; @@ -75,13 +82,6 @@ default Iterable bindingsAsTaggedElems() { return streamBindingsAsTaggedElems()::iterator; } - /** - * Get the number of symbols in the environment. - * - * @return the number of symbols in the environment - */ - int size(); - /** Whether this is a user, global, namespace, base, or empty environment. */ EnvType envType(); diff --git a/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java b/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java index adc95461e..b41948842 100644 --- a/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/GlobalEnvSXP.java @@ -12,6 +12,11 @@ public GlobalEnvSXP(StaticEnvSXP parent, Map bindings) { this.bindings.putAll(bindings); } + @Override + public int size() { + return bindings.size(); + } + @Override public EnvType envType() { return EnvType.GLOBAL; diff --git a/src/main/java/org/prlprg/sexp/LangOrListSXP.java b/src/main/java/org/prlprg/sexp/LangOrListSXP.java new file mode 100644 index 000000000..108f57779 --- /dev/null +++ b/src/main/java/org/prlprg/sexp/LangOrListSXP.java @@ -0,0 +1,7 @@ +package org.prlprg.sexp; + +import javax.annotation.concurrent.Immutable; + +/** Either {@link ListSXP} (AST identifier) or {@link LangSXP} (AST call). */ +@Immutable +public sealed interface LangOrListSXP extends SEXP permits ListSXP, LangSXP {} diff --git a/src/main/java/org/prlprg/sexp/LangSXP.java b/src/main/java/org/prlprg/sexp/LangSXP.java index fa6e28d1a..c6db70e60 100644 --- a/src/main/java/org/prlprg/sexp/LangSXP.java +++ b/src/main/java/org/prlprg/sexp/LangSXP.java @@ -11,7 +11,7 @@ /** AST function call ("language object") SEXP. */ @Immutable -public sealed interface LangSXP extends SymOrLangSXP { +public sealed interface LangSXP extends SymOrLangSXP, LangOrListSXP { /** The function being called. */ SymOrLangSXP fun(); diff --git a/src/main/java/org/prlprg/sexp/ListSXP.java b/src/main/java/org/prlprg/sexp/ListSXP.java index a158d6d62..c2993c16f 100644 --- a/src/main/java/org/prlprg/sexp/ListSXP.java +++ b/src/main/java/org/prlprg/sexp/ListSXP.java @@ -4,7 +4,6 @@ import java.util.Iterator; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.stream.Stream; import javax.annotation.Nullable; import org.jetbrains.annotations.Unmodifiable; @@ -19,7 +18,25 @@ *

Implementation note: in GNU-R this is represented as a linked list, but we internally * use an array-list because it's more efficient. */ -public sealed interface ListSXP extends ListOrVectorSXP permits NilSXP, ListSXPImpl { +public sealed interface ListSXP extends ListOrVectorSXP, LangOrListSXP + permits NilSXP, ListSXPImpl { + /** + * Flatten {@code src} while adding its elements to {@code target}. Ex: + * + *

+   *   b = []; flatten([1, [2, 3], 4], b) ==> b = [1, 2, 3, 4]
+   * 
+ */ + static void flatten(ListSXP src, ImmutableList.Builder target) { + for (var i : src) { + if (i.value() instanceof ListSXP lst) { + flatten(lst, target); + } else { + target.add(i); + } + } + } + @Override ListSXP withAttributes(Attributes attributes); @@ -55,8 +72,6 @@ default boolean hasTags() { Stream stream(); - Optional get(String name); - ListSXP prepend(TaggedElem elem); @Override @@ -149,11 +164,6 @@ public Stream stream() { return data.stream(); } - @Override - public Optional get(String name) { - return Optional.empty(); - } - @Override public ListSXP prepend(TaggedElem elem) { return new ListSXPImpl( diff --git a/src/main/java/org/prlprg/sexp/NilSXP.java b/src/main/java/org/prlprg/sexp/NilSXP.java index c6eb86495..3e3548111 100644 --- a/src/main/java/org/prlprg/sexp/NilSXP.java +++ b/src/main/java/org/prlprg/sexp/NilSXP.java @@ -3,7 +3,6 @@ import com.google.common.collect.ImmutableList; import java.util.Collections; import java.util.List; -import java.util.Optional; import java.util.stream.Stream; import javax.annotation.Nullable; import javax.annotation.concurrent.Immutable; @@ -125,11 +124,6 @@ public Stream stream() { return Stream.of(); } - @Override - public Optional get(String name) { - return Optional.empty(); - } - @Override public ListSXP prepend(TaggedElem elem) { return SEXPs.list(List.of(elem)); diff --git a/src/main/java/org/prlprg/sexp/SEXP.java b/src/main/java/org/prlprg/sexp/SEXP.java index c7a2e4330..0c74e6f13 100644 --- a/src/main/java/org/prlprg/sexp/SEXP.java +++ b/src/main/java/org/prlprg/sexp/SEXP.java @@ -29,6 +29,7 @@ public sealed interface SEXP permits StrOrRegSymSXP, SymOrLangSXP, ListOrVectorSXP, + LangOrListSXP, CloSXP, EnvSXP, BCodeSXP, @@ -176,4 +177,8 @@ private void print(Printer p) { // `toString` is overridden in every subclass to call `Printer.toString(this)`. // endregion serialization and deserialization + + default boolean isObject() { + return attributes() != null && Objects.requireNonNull(attributes()).containsKey("class"); + } } diff --git a/src/main/java/org/prlprg/sexp/UserEnvSXP.java b/src/main/java/org/prlprg/sexp/UserEnvSXP.java index f063cde84..243ed4571 100644 --- a/src/main/java/org/prlprg/sexp/UserEnvSXP.java +++ b/src/main/java/org/prlprg/sexp/UserEnvSXP.java @@ -1,10 +1,13 @@ package org.prlprg.sexp; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterators; +import java.util.Iterator; import java.util.Map; import javax.annotation.Nonnull; /** An environment inside a closure or explicitly defined by the user. */ -public final class UserEnvSXP extends AbstractEnvSXP implements EnvSXP { +public final class UserEnvSXP extends AbstractEnvSXP implements EnvSXP, Iterable { private Attributes attributes = Attributes.NONE; public UserEnvSXP() { @@ -25,6 +28,21 @@ public void setParent(EnvSXP parent) { this.parent = parent; } + public Iterator iterator() { + // We need to transform the entries to TaggedElem to avoid exposing the internal map. + // This is not very efficient though. + return Iterators.transform( + bindings.entrySet().iterator(), e -> new TaggedElem(e.getKey(), e.getValue())); + } + + public ListSXP frame() { + return new ListSXPImpl( + bindings.entrySet().stream() + .map(e -> new TaggedElem(e.getKey(), e.getValue())) + .collect(ImmutableList.toImmutableList()), + Attributes.NONE); + } + @Override public EnvType envType() { return EnvType.USER; @@ -40,4 +58,8 @@ public UserEnvSXP withAttributes(Attributes attributes) { public @Nonnull Attributes attributes() { return attributes; } + + public void setAttributes(Attributes attributes) { + this.attributes = attributes; + } } diff --git a/src/test/java/org/prlprg/RClosureTests.java b/src/test/java/org/prlprg/RClosureTests.java index fa3fcc1db..11899abe6 100644 --- a/src/test/java/org/prlprg/RClosureTests.java +++ b/src/test/java/org/prlprg/RClosureTests.java @@ -19,7 +19,7 @@ */ public abstract class RClosureTests extends AbstractGNURBasedTest { @Test - public void testEmptyList() { + public void testEmptyList() throws Exception { testClosure( """ function () @@ -30,7 +30,7 @@ public void testEmptyList() { } @Test - public void testEmptyBlock() { + public void testEmptyBlock() throws Exception { testClosure( """ function() {} @@ -38,7 +38,7 @@ public void testEmptyBlock() { } @Test - public void testSingleExpressionBlock() { + public void testSingleExpressionBlock() throws Exception { testClosure( """ function() { 1 } @@ -46,7 +46,7 @@ public void testSingleExpressionBlock() { } @Test - public void testMultipleExpressionBlock() { + public void testMultipleExpressionBlock() throws Exception { testClosure( """ function() { 1; 2 } @@ -54,7 +54,7 @@ public void testMultipleExpressionBlock() { } @Test - public void testIf() { + public void testIf() throws Exception { testClosure( """ function(x) if (x) 1 @@ -62,7 +62,7 @@ public void testIf() { } @Test - public void testIfElse() { + public void testIfElse() throws Exception { testClosure( """ function(x) if (x) 1 else 2 @@ -70,7 +70,7 @@ public void testIfElse() { } @Test - public void testFunctionInlining() { + public void testFunctionInlining() throws Exception { testClosure( """ function(x) function(y) 1 @@ -78,7 +78,7 @@ public void testFunctionInlining() { } @Test - public void testFunctionLeftParenInlining() { + public void testFunctionLeftParenInlining() throws Exception { testClosure( """ function(x) (x) @@ -91,7 +91,7 @@ public void testFunctionLeftParenInlining() { } @Test - public void builtinsInlining() { + public void builtinsInlining() throws Exception { // expecting a guard testClosure( """ @@ -115,7 +115,7 @@ public void builtinsInlining() { } @Test - public void specialsInlining() { + public void specialsInlining() throws Exception { testClosure( """ function() rep(1, 10) @@ -123,7 +123,7 @@ public void specialsInlining() { } @Test - public void inlineLocal() { + public void inlineLocal() throws Exception { testClosure( """ function(x) local(x) @@ -131,7 +131,7 @@ public void inlineLocal() { } @Test - public void inlineReturn() { + public void inlineReturn() throws Exception { testClosure( """ function(x) return(x) @@ -139,7 +139,7 @@ public void inlineReturn() { } @Test - public void inlineBuiltinsInternal() { + public void inlineBuiltinsInternal() throws Exception { testClosure( """ function(x) .Internal(inspect(x)) @@ -152,7 +152,7 @@ public void inlineBuiltinsInternal() { } @Test - public void inlineLogicalAnd() { + public void inlineLogicalAnd() throws Exception { testClosure( """ function(x, y) x && y @@ -170,7 +170,7 @@ public void inlineLogicalAnd() { } @Test - public void inlineLogicalOr() { + public void inlineLogicalOr() throws Exception { testClosure( """ function(x, y) x || y @@ -188,7 +188,7 @@ public void inlineLogicalOr() { } @Test - public void inlineLogicalAndOr() { + public void inlineLogicalAndOr() throws Exception { testClosure( """ function(x, y) x && y || y @@ -201,7 +201,7 @@ public void inlineLogicalAndOr() { } @Test - public void inlineRepeat() { + public void inlineRepeat() throws Exception { testClosure( """ function(x) repeat(x) @@ -224,7 +224,7 @@ public void inlineRepeat() { } @Test - public void inlineWhile() { + public void inlineWhile() throws Exception { testClosure( """ function(x) while(x) 1 @@ -242,7 +242,7 @@ public void inlineWhile() { } @Test - public void inlineFor() { + public void inlineFor() throws Exception { testClosure( """ function(x) for (i in x) 1 @@ -255,7 +255,7 @@ public void inlineFor() { } @Test - public void inlineArithmetics() { + public void inlineArithmetics() throws Exception { testClosure( """ function(x, y) x + y @@ -289,7 +289,7 @@ public void inlineArithmetics() { } @Test - public void inlineMath1() { + public void inlineMath1() throws Exception { testClosure( """ function(x) { @@ -306,7 +306,7 @@ public void inlineMath1() { } @Test - public void inlineLogical() { + public void inlineLogical() throws Exception { testClosure( """ function(x, y) { @@ -318,7 +318,7 @@ public void inlineLogical() { } @Test - public void inlineDollar() { + public void inlineDollar() throws Exception { testClosure( """ # xs <- list(a=1, b=list(c=2)) @@ -333,7 +333,7 @@ public void inlineDollar() { } @Test - public void inlineIsXYZ() { + public void inlineIsXYZ() throws Exception { testClosure( """ function(x) { @@ -353,7 +353,7 @@ public void inlineIsXYZ() { } @Test - public void inlineDotCall() { + public void inlineDotCall() throws Exception { testClosure( """ function(x) { @@ -364,7 +364,7 @@ public void inlineDotCall() { } @Test - public void inlineIntGeneratingSequences() { + public void inlineIntGeneratingSequences() throws Exception { testClosure( """ function(x, xs) { @@ -374,7 +374,7 @@ public void inlineIntGeneratingSequences() { } @Test - public void multiColon() { + public void multiColon() throws Exception { testClosure( """ function() { @@ -384,7 +384,7 @@ public void multiColon() { } @Test - public void inlineSwitch() { + public void inlineSwitch() throws Exception { testClosure( """ function(x) { @@ -396,7 +396,7 @@ public void inlineSwitch() { } @Test - public void inlineAssign1() { + public void inlineAssign1() throws Exception { testClosure( """ function() { @@ -422,7 +422,7 @@ public void inlineAssign1() { } @Test - public void inlineAssign2() { + public void inlineAssign2() throws Exception { testClosure( """ function() { @@ -439,7 +439,7 @@ public void inlineAssign2() { } @Test - public void inlineAssign3() { + public void inlineAssign3() throws Exception { testClosure( """ function() { @@ -449,7 +449,7 @@ public void inlineAssign3() { } @Test - public void inlineDollarAssign() { + public void inlineDollarAssign() throws Exception { testClosure( """ function() { @@ -461,7 +461,7 @@ public void inlineDollarAssign() { } @Test - public void inlineSquareAssign1() { + public void inlineSquareAssign1() throws Exception { testClosure( """ function() { @@ -472,7 +472,7 @@ public void inlineSquareAssign1() { } @Test - public void inlineSquareAssign2() { + public void inlineSquareAssign2() throws Exception { testClosure( """ function() { @@ -483,7 +483,7 @@ public void inlineSquareAssign2() { } @Test - public void inlineSquareAssign3() { + public void inlineSquareAssign3() throws Exception { testClosure( """ function() { @@ -494,7 +494,7 @@ public void inlineSquareAssign3() { } @Test - public void inlineSquareAssign4() { + public void inlineSquareAssign4() throws Exception { testClosure( """ function() { @@ -504,7 +504,7 @@ public void inlineSquareAssign4() { } @Test - public void inlineSquareSubset1() { + public void inlineSquareSubset1() throws Exception { testClosure( """ function() { @@ -515,7 +515,7 @@ public void inlineSquareSubset1() { } @Test - public void inlineSquareSubset2() { + public void inlineSquareSubset2() throws Exception { testClosure( """ function() { @@ -526,7 +526,7 @@ public void inlineSquareSubset2() { } @Test - public void inlineSquareSubset3() { + public void inlineSquareSubset3() throws Exception { testClosure( """ function() { @@ -537,7 +537,7 @@ public void inlineSquareSubset3() { } @Test - public void inlineSquareSubset4() { + public void inlineSquareSubset4() throws Exception { testClosure( """ function() { @@ -548,7 +548,7 @@ public void inlineSquareSubset4() { } @Test - public void inlineSlotAssign() { + public void inlineSlotAssign() throws Exception { testClosure( """ function() { @@ -560,7 +560,7 @@ public void inlineSlotAssign() { } @Test - public void inlineIdentical() { + public void inlineIdentical() throws Exception { testClosure( """ function(x) { @@ -570,7 +570,7 @@ public void inlineIdentical() { } @Test - public void constantFoldingC() { + public void constantFoldingC() throws Exception { // no constant folding - c is resolved from baseenv() testClosure( """ @@ -590,7 +590,7 @@ public void constantFoldingC() { } @Test - public void constantFoldMul() { + public void constantFoldMul() throws Exception { testClosure( """ function() { @@ -601,7 +601,7 @@ public void constantFoldMul() { @ParameterizedTest @MethodSource("stdlibFunctionsList") - public void stdlibFunctions(String name) { + public void stdlibFunctions(String name) throws Exception { testClosure(name); } @@ -640,9 +640,9 @@ protected double stdlibTestsRatio() { return 1; } - protected void testClosure(String closure) { + protected void testClosure(String closure) throws Exception { testClosure(closure, Compiler.DEFAULT_OPTIMIZATION_LEVEL); } - protected abstract void testClosure(String closure, int optimizationLevel); + protected abstract void testClosure(String closure, int optimizationLevel) throws Exception; } diff --git a/src/test/java/org/prlprg/bc2ir/IRCompilerTests.java b/src/test/java/org/prlprg/bc2ir/IRCompilerTests.java index 4085e4169..3d23fa7d9 100644 --- a/src/test/java/org/prlprg/bc2ir/IRCompilerTests.java +++ b/src/test/java/org/prlprg/bc2ir/IRCompilerTests.java @@ -8,7 +8,7 @@ /** Test our {@linkplain ClosureCompiler IR closure compiler} specifically. */ public class IRCompilerTests extends RClosureTestsUsingIRCompiler { @Test - public void inlineForReturn() { + public void inlineForReturn() throws Exception { testClosure( """ function(x) for (i in x) if (i) return() else 1 diff --git a/src/test/java/org/prlprg/rds/RDSReaderTest.java b/src/test/java/org/prlprg/rds/RDSReaderTest.java index 51dd489ae..24c99df0e 100644 --- a/src/test/java/org/prlprg/rds/RDSReaderTest.java +++ b/src/test/java/org/prlprg/rds/RDSReaderTest.java @@ -126,10 +126,12 @@ public void testNamedList() { @Test public void testClosure() { var sexp = (CloSXP) R.eval("function(x, y=1) 'abc' + x + length(y)"); + System.out.println(sexp); var formals = sexp.parameters(); assertEquals(2, formals.size()); assertEquals(new TaggedElem("x", SEXPs.MISSING_ARG), formals.get(0)); + assertEquals(new TaggedElem("y", SEXPs.real(1.0)), formals.get(1)); // TODO: this should really be a snapshot test var body = sexp.bodyAST(); @@ -147,6 +149,7 @@ public void testClosureWithBC() { // TODO: this should really be a snapshot test var body = sexp.body(); + System.out.println(body); assertThat(body).isInstanceOf(BCodeSXP.class); } diff --git a/src/test/java/org/prlprg/rds/RDSRoundtripTest.java b/src/test/java/org/prlprg/rds/RDSRoundtripTest.java new file mode 100644 index 000000000..83b91b166 --- /dev/null +++ b/src/test/java/org/prlprg/rds/RDSRoundtripTest.java @@ -0,0 +1,33 @@ +package org.prlprg.rds; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import org.prlprg.RClosureTests; +import org.prlprg.sexp.CloSXP; +import org.prlprg.sexp.SEXP; + +public class RDSRoundtripTest extends RClosureTests { + @Override + protected void testClosure(String closure, int optimizationLevel) throws Exception { + + // Load the closure into Java using eval - this will be our starting point + var clo = R.eval(closure); + + // Write the closure using the RDS writer + var output = new ByteArrayOutputStream(); + RDSWriter.writeStream(output, clo); + + // Read from the stream using the RDS reader + var input = new ByteArrayInputStream(output.toByteArray()); + SEXP res = RDSReader.readStream(rsession, input); + + if (clo instanceof CloSXP c && res instanceof CloSXP r) { + assertEquals(c.body(), r.body()); + assertEquals(c.parameters(), r.parameters()); + } else { + throw new AssertionError("Expected deserialized SEXP to be a closure"); + } + } +} diff --git a/src/test/java/org/prlprg/rds/RDSWriterTest.java b/src/test/java/org/prlprg/rds/RDSWriterTest.java new file mode 100644 index 000000000..cece7fcb2 --- /dev/null +++ b/src/test/java/org/prlprg/rds/RDSWriterTest.java @@ -0,0 +1,382 @@ +package org.prlprg.rds; + +import static java.lang.Double.NaN; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.prlprg.sexp.SEXPs.*; + +import java.io.*; +import java.util.HashMap; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.prlprg.AbstractGNURBasedTest; +import org.prlprg.bc.Compiler; +import org.prlprg.primitive.Complex; +import org.prlprg.primitive.Constants; +import org.prlprg.primitive.Logical; +import org.prlprg.sexp.*; + +public class RDSWriterTest extends AbstractGNURBasedTest { + @Test + public void testInts() throws Exception { + var ints = integer(5, 4, 3, 2, 1); + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, ints); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof IntSXP read_ints) { + assertEquals(5, read_ints.size()); + assertEquals(5, read_ints.get(0)); + assertEquals(4, read_ints.get(1)); + assertEquals(3, read_ints.get(2)); + assertEquals(2, read_ints.get(3)); + assertEquals(1, read_ints.get(4)); + } else { + fail("Expected IntSXP"); + } + } + + @Test + public void testInts_withR() throws Exception { + var ints = integer(5, 4, 3, 2, 1); + var output = + R.eval("typeof(input) == 'integer' && identical(input, c(5L, 4L, 3L, 2L, 1L))", ints); + + if (output instanceof LglSXP read_lgls) { + assertEquals(1, read_lgls.size()); + assertEquals(Logical.TRUE, read_lgls.get(0)); + } else { + fail("Expected LglSXP"); + } + } + + @Test + public void testComplex() throws Exception { + var complexes = complex(new Complex(0, 0), new Complex(1, 2), new Complex(-2, -1)); + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, complexes); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof ComplexSXP read_complexes) { + assertEquals(3, read_complexes.size()); + assertEquals(new Complex(0, 0), read_complexes.get(0)); + assertEquals(new Complex(1, 2), read_complexes.get(1)); + assertEquals(new Complex(-2, -1), read_complexes.get(2)); + } + } + + @Test + public void testComplex_withR() throws Exception { + var complexes = complex(new Complex(0, 0), new Complex(1, 2), new Complex(-2, -1)); + var output = + R.eval("typeof(input) == 'complex' && identical(input, c(0+0i, 1+2i, -2-1i))", complexes); + + if (output instanceof LglSXP read_lgls) { + assertEquals(1, read_lgls.size()); + assertEquals(Logical.TRUE, read_lgls.get(0)); + } else { + fail("Expected LglSXP"); + } + } + + @Test + public void testLang() throws Exception { + var lang = + lang( + symbol("func"), + list(List.of(new TaggedElem("arg", integer(1)), new TaggedElem(integer(2))))); + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, lang); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof LangSXP read_lang) { + var name = read_lang.funName(); + var arg1 = read_lang.args().get(0); + var arg2 = read_lang.args().get(1); + + assert name.isPresent(); + + assertEquals(name.get(), "func"); + assertEquals(arg1, new TaggedElem("arg", integer(1))); + assertEquals(arg2, new TaggedElem(integer(2))); + } + } + + @Test + public void testVecAttributes() throws Exception { + var attrs = + new Attributes.Builder().put("a", integer(1)).put("b", logical(Logical.TRUE)).build(); + var ints = integer(1, attrs); + + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, ints); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof IntSXP read_ints) { + assertEquals(1, read_ints.size()); + assertEquals(1, read_ints.get(0)); + assertEquals(attrs, read_ints.attributes()); + } else { + fail("Expected IntSXP"); + } + } + + @Test + public void testLgls() throws Exception { + var lgls = logical(Logical.TRUE, Logical.FALSE, Logical.NA); + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, lgls); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof LglSXP read_lgls) { + assertEquals(3, read_lgls.size()); + assertEquals(Logical.TRUE, read_lgls.get(0)); + assertEquals(Logical.FALSE, read_lgls.get(1)); + assertEquals(Logical.NA, read_lgls.get(2)); + } else { + fail("Expected LglSXP"); + } + } + + @Test + public void testReals() throws Exception { + var reals = real(5.2, 4.0, Constants.NA_REAL, 2.0, NaN, 1.0); + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, reals); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof RealSXP read_reals) { + assertEquals(6, read_reals.size()); + assertEquals(5.2, read_reals.get(0)); + assertEquals(4.0, read_reals.get(1)); + assertEquals(Constants.NA_REAL, read_reals.get(2)); + assertEquals(2.0, read_reals.get(3)); + assertEquals(NaN, read_reals.get(4)); + assertEquals(1.0, read_reals.get(5)); + } else { + fail("Expected RealSXP"); + } + } + + @Test + public void testNull() throws Exception { + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, NULL); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + assertEquals(NULL, sexp); + } + + @Test + public void testVec() throws Exception { + var vec = vec(integer(1, 2, 3), logical(Logical.TRUE, Logical.FALSE, Logical.NA)); + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, vec); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof VecSXP read_vec) { + assertEquals(2, read_vec.size()); + if (read_vec.get(0) instanceof IntSXP read_ints) { + assertEquals(3, read_ints.size()); + assertEquals(1, read_ints.get(0)); + assertEquals(2, read_ints.get(1)); + assertEquals(3, read_ints.get(2)); + } else { + fail("Expected IntSXP for the 1st element of the VecSXP"); + } + if (read_vec.get(1) instanceof LglSXP read_lgls) { + assertEquals(3, read_lgls.size()); + assertEquals(Logical.TRUE, read_lgls.get(0)); + assertEquals(Logical.FALSE, read_lgls.get(1)); + assertEquals(Logical.NA, read_lgls.get(2)); + } else { + fail("Expected LglSXP for the 2nd element of the VecSXP"); + } + } else { + fail("Expected VecSXP"); + } + } + + @Test + public void testList() throws Exception { + var elems = + new TaggedElem[] { + new TaggedElem("a", integer(1)), + new TaggedElem("b", logical(Logical.TRUE)), + new TaggedElem("c", real(3.14, 2.71)) + }; + var list = list(elems, Attributes.NONE); + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, list); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof ListSXP l) { + assertEquals(3, l.size()); + assertEquals("a", l.get(0).tag()); + if (l.get(0).value() instanceof IntSXP i) { + assertEquals(1, i.get(0)); + } else { + fail("Expected IntSXP for the 1st element of the ListSXP"); + } + assertEquals("b", l.get(1).tag()); + if (l.get(1).value() instanceof LglSXP lgl) { + assertEquals(Logical.TRUE, lgl.get(0)); + } else { + fail("Expected LglSXP for the 2nd element of the ListSXP"); + } + assertEquals("c", l.get(2).tag()); + if (l.get(2).value() instanceof RealSXP r) { + assertEquals(3.14, r.get(0)); + assertEquals(2.71, r.get(1)); + } else { + fail("Expected RealSXP for the 3rd element of the ListSXP"); + } + } else { + fail("Expected ListSXP"); + } + } + + @Test + public void testEnv() throws Exception { + var env = new UserEnvSXP(); + env.set("a", integer(1)); + env.set("b", logical(Logical.TRUE)); + env.set("c", real(3.14, 2.71)); + env.set("d", string("foo", "bar")); + env.setAttributes(new Attributes.Builder().put("test", logical(Logical.TRUE)).build()); + + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, env); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + if (sexp instanceof UserEnvSXP read_env) { + assertEquals(4, read_env.size()); + assertEquals(integer(1), read_env.get("a").orElseThrow()); + assertEquals(logical(Logical.TRUE), read_env.get("b").orElseThrow()); + assertEquals(real(3.14, 2.71), read_env.get("c").orElseThrow()); + assertEquals(string("foo", "bar"), read_env.get("d").orElseThrow()); + assertEquals( + new Attributes.Builder().put("test", logical(Logical.TRUE)).build(), + read_env.attributes()); + } else { + fail("Expected UserEnvSXP"); + } + } + + @Test + public void testEnv_withR() throws Exception { + var env = new UserEnvSXP(); + env.set("a", integer(1)); + env.set("b", logical(Logical.TRUE)); + env.set("c", real(3.14, 2.71)); + env.set("d", string("foo", "bar")); + + var output = R.eval("typeof(input) == 'environment'", env); + + if (output instanceof LglSXP read_lgls) { + assertEquals(1, read_lgls.size()); + assertEquals(Logical.TRUE, read_lgls.get(0)); + } else { + fail("Expected LglSXP"); + } + } + + @Test + public void testClosureEval() throws Exception { + // function(x, y=1) length(x) + x + y + // test by loading the closure into R and evaluating + var clo = + closure( + list(List.of(new TaggedElem("x", MISSING_ARG), new TaggedElem("y", real(3)))), + lang( + symbol("+"), + list( + lang( + symbol("+"), + list(lang(symbol("length"), list(symbol("x"))), symbol("x"))), + symbol("y"))), + new BaseEnvSXP(new HashMap<>())) + .withAttributes(new Attributes.Builder().put("a", integer(1)).build()); + ; + + var output = R.eval("input(x=c(1, 2))", clo); + + assertEquals(output, real(6, 7)); + } + + @Test + public void testClosureWithBC() throws Exception { + // Same closure as `testClosure`, just compiled to bytecode + // Test by serializing and deserializing + var clo = + closure( + list(List.of(new TaggedElem("x", MISSING_ARG), new TaggedElem("y", real(3)))), + lang( + symbol("+"), + list( + lang(symbol("+"), list(lang(symbol("length"), list(symbol("x"))), symbol("x"))), + symbol("y"))), + new BaseEnvSXP(new HashMap<>())); + var bc = new Compiler(clo, rsession).compile().orElseThrow(); + + var output = new ByteArrayOutputStream(); + + RDSWriter.writeStream(output, bcode(bc)); + + var input = new ByteArrayInputStream(output.toByteArray()); + var sexp = RDSReader.readStream(rsession, input); + + assertEquals(sexp, bcode(bc)); + } + + @Test + public void testClosureWithBCEval() { + // Same closure as `testClosure`, just compiled to bytecode + // Test by loading into R and evaluating + var clo = + closure( + list(List.of(new TaggedElem("x", MISSING_ARG), new TaggedElem("y", real(3)))), + lang( + symbol("+"), + list( + lang(symbol("+"), list(lang(symbol("length"), list(symbol("x"))), symbol("x"))), + symbol("y"))), + new BaseEnvSXP(new HashMap<>())); + var bc = new Compiler(clo, rsession).compile().orElseThrow(); + var compiled_clo = closure(clo.parameters(), bcode(bc), clo.env()); + + var output = R.eval("input(x=c(1, 2))", compiled_clo); + + assertEquals(output, real(6, 7)); + } +} diff --git a/src/test/java/org/prlprg/util/GNUR.java b/src/test/java/org/prlprg/util/GNUR.java index 77c520841..8dee07ba7 100644 --- a/src/test/java/org/prlprg/util/GNUR.java +++ b/src/test/java/org/prlprg/util/GNUR.java @@ -9,6 +9,7 @@ import javax.annotation.concurrent.NotThreadSafe; import org.prlprg.RSession; import org.prlprg.rds.RDSReader; +import org.prlprg.rds.RDSWriter; import org.prlprg.sexp.SEXP; @NotThreadSafe @@ -25,7 +26,7 @@ public GNUR(RSession rsession, Process rprocess) { this.rout = new BufferedReader(new InputStreamReader(rprocess.getInputStream())); } - private void run(String code) { + public void run(String code) { var requestId = UUID.randomUUID().toString(); if (!rprocess.isAlive()) { @@ -66,6 +67,25 @@ public SEXP eval(String source) { } } + /** + * Evaluate R source with input SEXP. The SEXP is passed from Java to the R world using RDS. + * + * @param source + * @param input + * @return + */ + public SEXP eval(String source, SEXP input) { + try { + var inputFile = File.createTempFile("RCS-input", ".rds"); + RDSWriter.writeFile(inputFile, input); + String full_source = "input <- readRDS('" + inputFile.getAbsolutePath() + "')\n" + source; + + return eval(full_source); + } catch (Exception e) { + throw new RuntimeException("Unable to eval R source", e); + } + } + private void waitForCommand(String requestId) { var output = new StringBuilder(); try { diff --git a/src/test/resources/org/prlprg/bc/serialize-closures.R b/src/test/resources/org/prlprg/bc/serialize-closures.R index 09795db48..4d408dbd0 100644 --- a/src/test/resources/org/prlprg/bc/serialize-closures.R +++ b/src/test/resources/org/prlprg/bc/serialize-closures.R @@ -12,4 +12,4 @@ for (.name in ls()) { saveRDS(.func, compress = FALSE, file = file.path(.target, paste0(.name, ".ast.rds"))) saveRDS(compiler::cmpfun(.func), compress = FALSE, file = file.path(.target, paste0(.name, ".bc.rds"))) } -} \ No newline at end of file +}