diff --git a/src/main/java/li/cil/sedna/instruction/InstructionDefinitionLoader.java b/src/main/java/li/cil/sedna/instruction/InstructionDefinitionLoader.java index 18c10b01..2076ea42 100644 --- a/src/main/java/li/cil/sedna/instruction/InstructionDefinitionLoader.java +++ b/src/main/java/li/cil/sedna/instruction/InstructionDefinitionLoader.java @@ -27,7 +27,7 @@ public static HashMap load(final } final ClassReader cr = new ClassReader(stream); - cr.accept(new ClassVisitor(Opcodes.ASM7) { + cr.accept(new ClassVisitor(Opcodes.ASM8) { @Override public MethodVisitor visitMethod(final int access, final String name, final String descriptor, final String signature, final String[] exceptions) { final InstructionFunctionVisitor visitor = new InstructionFunctionVisitor(implementation, name, descriptor, exceptions); @@ -169,7 +169,7 @@ private static final class InstructionFunctionVisitor extends MethodVisitor { private boolean writesPC; public InstructionFunctionVisitor(final Class implementation, final String name, final String descriptor, final String[] exceptions) { - super(Opcodes.ASM7); + super(Opcodes.ASM8); this.implementation = implementation; this.name = name; this.descriptor = descriptor; @@ -180,7 +180,7 @@ public InstructionFunctionVisitor(final Class implementation, final String na @Override public AnnotationVisitor visitParameterAnnotation(final int parameter, final String descriptor, final boolean visible) { if (Objects.equals(descriptor, Type.getDescriptor(InstructionDefinition.Field.class))) { - return new AnnotationVisitor(Opcodes.ASM7) { + return new AnnotationVisitor(Opcodes.ASM8) { @Override public void visit(final String name, final Object value) { super.visit(name, value); @@ -208,7 +208,7 @@ public void visit(final String name, final Object value) { public AnnotationVisitor visitAnnotation(final String descriptor, final boolean visible) { if (Objects.equals(descriptor, Type.getDescriptor(InstructionDefinition.Instruction.class))) { isImplementation = true; - return new AnnotationVisitor(Opcodes.ASM7) { + return new AnnotationVisitor(Opcodes.ASM8) { @Override public void visit(final String name, final Object value) { super.visit(name, value); @@ -345,12 +345,12 @@ private void resolveInvocations(final ArrayList known } final ClassReader reader = new ClassReader(stream); - reader.accept(new ClassVisitor(Opcodes.ASM7) { + reader.accept(new ClassVisitor(Opcodes.ASM8) { @Override public MethodVisitor visitMethod(final int access, final String methodName, final String methodDescriptor, final String signature, final String[] exceptions) { if (methodName.equals(NonStaticMethodInvocation.this.name) && methodDescriptor.equals(NonStaticMethodInvocation.this.descriptor)) { - return new MethodVisitor(Opcodes.ASM7) { + return new MethodVisitor(Opcodes.ASM8) { @Override public void visitMethodInsn(final int opcode, final String invokedMethodOwner, final String invokedMethodName, final String invokedMethodDescriptor, final boolean isInterface) { super.visitMethodInsn(opcode, invokedMethodOwner, invokedMethodName, invokedMethodDescriptor, isInterface); diff --git a/src/main/java/li/cil/sedna/instruction/decoder/DecoderGenerator.java b/src/main/java/li/cil/sedna/instruction/decoder/DecoderGenerator.java index a02f3d43..03744e02 100644 --- a/src/main/java/li/cil/sedna/instruction/decoder/DecoderGenerator.java +++ b/src/main/java/li/cil/sedna/instruction/decoder/DecoderGenerator.java @@ -91,7 +91,7 @@ public DecoderGenerator(final ClassVisitor cv, final Class illegalInstructionExceptionClass, final String decoderMethod, final String decoderHook) { - super(ASM7, cv); + super(ASM8, cv); this.decoderTree = decoderTree; this.definitionProvider = definitionProvider; this.decoderMethod = decoderMethod; @@ -193,7 +193,7 @@ private final class TemplateMethodVisitor extends MethodVisitor implements Opcod private final ClassVisitor classVisitor; public TemplateMethodVisitor(final MethodVisitor methodVisitor, final ClassVisitor classVisitor) { - super(Opcodes.ASM7, methodVisitor); + super(Opcodes.ASM8, methodVisitor); this.classVisitor = classVisitor; } diff --git a/src/main/java/li/cil/sedna/riscv/R5Board.java b/src/main/java/li/cil/sedna/riscv/R5Board.java index 147e1a90..9e1ff825 100644 --- a/src/main/java/li/cil/sedna/riscv/R5Board.java +++ b/src/main/java/li/cil/sedna/riscv/R5Board.java @@ -15,7 +15,7 @@ import li.cil.sedna.device.flash.FlashMemoryDevice; import li.cil.sedna.devicetree.DeviceTreeRegistry; import li.cil.sedna.devicetree.FlattenedDeviceTree; -import li.cil.sedna.gdbstub.GDBStub; +import li.cil.sedna.riscv.gdbstub.GDBStub; import li.cil.sedna.memory.SimpleMemoryMap; import li.cil.sedna.riscv.device.R5CoreLocalInterrupter; import li.cil.sedna.riscv.device.R5PlatformLevelInterruptController; diff --git a/src/main/java/li/cil/sedna/riscv/R5CPU.java b/src/main/java/li/cil/sedna/riscv/R5CPU.java index ac968258..b8525989 100644 --- a/src/main/java/li/cil/sedna/riscv/R5CPU.java +++ b/src/main/java/li/cil/sedna/riscv/R5CPU.java @@ -5,7 +5,7 @@ import li.cil.sedna.api.device.Steppable; import li.cil.sedna.api.device.rtc.RealTimeCounter; import li.cil.sedna.api.memory.MemoryMap; -import li.cil.sedna.gdbstub.CPUDebugInterface; +import li.cil.sedna.riscv.gdbstub.CPUDebugInterface; import javax.annotation.Nullable; diff --git a/src/main/java/li/cil/sedna/riscv/R5CPUTemplate.java b/src/main/java/li/cil/sedna/riscv/R5CPUTemplate.java index 126e586f..a18b2fff 100644 --- a/src/main/java/li/cil/sedna/riscv/R5CPUTemplate.java +++ b/src/main/java/li/cil/sedna/riscv/R5CPUTemplate.java @@ -11,7 +11,8 @@ import li.cil.sedna.api.memory.MappedMemoryRange; import li.cil.sedna.api.memory.MemoryAccessException; import li.cil.sedna.api.memory.MemoryMap; -import li.cil.sedna.gdbstub.CPUDebugInterface; +import li.cil.sedna.riscv.gdbstub.CPUDebugInterface; +import li.cil.sedna.riscv.gdbstub.Watchpoint; import li.cil.sedna.instruction.InstructionDefinition.Field; import li.cil.sedna.instruction.InstructionDefinition.Instruction; import li.cil.sedna.instruction.InstructionDefinition.InstructionSize; @@ -21,12 +22,16 @@ import li.cil.sedna.utils.BitUtils; import li.cil.sedna.utils.SoftDouble; import li.cil.sedna.utils.SoftFloat; +import li.cil.sedna.utils.Interval; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Iterator; +import java.util.List; import java.util.concurrent.atomic.AtomicLong; import java.util.function.LongConsumer; @@ -137,6 +142,8 @@ final class R5CPUTemplate implements R5CPU { private final transient RealTimeCounter rtc; private transient int cycleFrequency = 50_000_000; private final transient DebugInterface debugInterface = new DebugInterface(); + private transient boolean triggeredWatchpoint = false; + private transient long triggeredWatchpointAddress = 0; public R5CPUTemplate(final MemoryMap physicalMemory, @Nullable final RealTimeCounter rtc) { // This cast is necessary so that stack frame computation in ASM does not throw @@ -354,6 +361,10 @@ private void interpret(final boolean singleStep, final boolean ignoreBreakpoints } else { interpretTrace64(device, inst, pc, instOffset, singleStep ? 0 : instEnd, ignoreBreakpoints ? null : cache.breakpoints); } + if(triggeredWatchpoint) { + triggeredWatchpoint = false; + debugInterface.handleWatchpoint(triggeredWatchpointAddress); + } } catch (final R5MemoryAccessException e) { raiseException(e.getType(), e.getAddress()); } @@ -366,6 +377,12 @@ private void interpret(final boolean singleStep, final boolean ignoreBreakpoints private void interpretTrace32(final MemoryMappedDevice device, int inst, long pc, int instOffset, final int instEnd, final LongSet breakpoints) { try { // Catch any exceptions to patch PC field. for (; ; ) { // End of page check at the bottom since we enter with a valid inst. + // Since decode contains "continue" it's important to check at the beginning of the loop for a + // watchpoint trigger + if(triggeredWatchpoint) { + this.pc = pc; + return; + } if (breakpoints != null && breakpoints.contains(pc)) { this.pc = pc; debugInterface.handleBreakpoint(pc); @@ -402,6 +419,12 @@ private void interpretTrace32(final MemoryMappedDevice device, int inst, long pc private void interpretTrace64(final MemoryMappedDevice device, int inst, long pc, int instOffset, final int instEnd, final LongSet breakpoints) { try { // Catch any exceptions to patch PC field. for (; ; ) { // End of page check at the bottom since we enter with a valid inst. + // Since decode contains "continue" it's important to check at the beginning of the loop for a + // watchpoint trigger + if(triggeredWatchpoint) { + this.pc = pc; + return; + } if (breakpoints != null && breakpoints.contains(pc)) { this.pc = pc; debugInterface.handleBreakpoint(pc); @@ -476,16 +499,20 @@ private boolean csrrscx(final int rd, final int rs1, final int csr, final long m return false; } + private boolean isReadonlyCSR(final int csr) { + // csr[11:8] encodes access rights for CSR by convention. Of these, the top-most two bits, + // csr[11:10], encode read-only state, where 0b11: read-only, 0b00..0b10: read-write. + boolean readonly = ((csr & 0b1100_0000_0000) == 0b1100_0000_0000); + // There are also these special cases + readonly |= (csr >= 0xC00 && csr <= 0xC1F) || (csr >= 0xC80 && csr <= 0xC9F); + return readonly; + } private void checkCSR(final int csr, final boolean throwIfReadonly) throws R5IllegalInstructionException { - if (throwIfReadonly && ((csr >= 0xC00 && csr <= 0xC1F) || (csr >= 0xC80 && csr <= 0xC9F))) + if (throwIfReadonly && isReadonlyCSR(csr)) throw new R5IllegalInstructionException(); - // Topmost bits, i.e. csr[11:8], encode access rights for CSR by convention. Of these, the top-most two bits, - // csr[11:10], encode read-only state, where 0b11: read-only, 0b00..0b10: read-write. - if (throwIfReadonly && ((csr & 0b1100_0000_0000) == 0b1100_0000_0000)) - throw new R5IllegalInstructionException(); - // The two following bits, csr[9:8], encode the lowest privilege level that can access the CSR. + // csr[9:8] encodes the lowest privilege level that can access the CSR. if (priv < ((csr >>> 8) & 0b11)) throw new R5IllegalInstructionException(); } @@ -1173,7 +1200,12 @@ private long loadx(final long address, final int size, final int sizeLog2) throw final TLBEntry entry = loadTLB[index]; if (entry.hash == hash) { try { - return entry.device.load((int) (address + entry.toOffset), sizeLog2); + final long result = entry.device.load((int) (address + entry.toOffset), sizeLog2); + if(entry.hasWatchpoint(new Interval(address, 1L << sizeLog2))) { + triggeredWatchpoint = true; + triggeredWatchpointAddress = address; + } + return result; } catch (final MemoryAccessException e) { throw new R5MemoryAccessException(address, R5.EXCEPTION_FAULT_LOAD); } @@ -1191,6 +1223,10 @@ private void storex(final long address, final long value, final int size, final if (entry.hash == hash) { try { entry.device.store((int) (address + entry.toOffset), value, sizeLog2); + if(entry.hasWatchpoint(new Interval(address, 1L << sizeLog2))) { + triggeredWatchpoint = true; + triggeredWatchpointAddress = address; + } } catch (final MemoryAccessException e) { throw new R5MemoryAccessException(address, R5.EXCEPTION_FAULT_STORE); } @@ -1207,10 +1243,11 @@ private TLBEntry fetchPageSlow(final long address) throws R5MemoryAccessExceptio } final TLBEntry tlb = updateTLB(fetchTLB, address, physicalAddress, range); final var subset = debugInterface.breakpoints.subSet(address, address + (1 << R5.PAGE_ADDRESS_SHIFT)); - if (subset.isEmpty()) { + final int subsetSize = subset.size(); + if (subsetSize == 0) { tlb.breakpoints = null; } else { - tlb.breakpoints = new LongOpenHashSet(subset.size()); + tlb.breakpoints = new LongOpenHashSet(subsetSize); tlb.breakpoints.addAll(subset); } return tlb; @@ -1226,6 +1263,18 @@ private long loadSlow(final long address, final int sizeLog2) throws R5MemoryAcc try { if (range.device.supportsFetch()) { final TLBEntry entry = updateTLB(loadTLB, address, physicalAddress, range); + final Interval pageInterval = new Interval(address & ~R5.PAGE_ADDRESS_MASK, 1 << R5.PAGE_ADDRESS_SHIFT); + final Interval addressInterval = new Interval(address, 1L << sizeLog2); + entry.watchpoints = new ArrayList<>(); + for(Watchpoint w : debugInterface.readWatchpoints) { + if(w.range().intersects(pageInterval)) { + entry.watchpoints.add(w); + if(w.range().intersects(addressInterval)) { + triggeredWatchpoint = true; + triggeredWatchpointAddress = address; + } + } + } return entry.device.load((int) (address + entry.toOffset), sizeLog2); } else { return range.device.load((int) (physicalAddress - range.address()), sizeLog2); @@ -1245,6 +1294,18 @@ private void storeSlow(final long address, final long value, final int sizeLog2) try { if (range.device.supportsFetch()) { final TLBEntry entry = updateTLB(storeTLB, address, physicalAddress, range); + Interval pageInterval = new Interval(address & ~R5.PAGE_ADDRESS_MASK, 1 << R5.PAGE_ADDRESS_SHIFT); + final Interval addressInterval = new Interval(address, 1L << sizeLog2); + entry.watchpoints = new ArrayList<>(); + for(Watchpoint w : debugInterface.writeWatchpoints) { + if(w.range().intersects(pageInterval)) { + entry.watchpoints.add(w); + if(w.range().intersects(addressInterval)) { + triggeredWatchpoint = true; + triggeredWatchpointAddress = address; + } + } + } final int offset = (int) (address + entry.toOffset); entry.device.store(offset, value, sizeLog2); physicalMemory.setDirty(range, offset); @@ -3280,12 +3341,35 @@ private static final class TLBEntry { public MemoryMappedDevice device; //Subset of complete breakpoint set public LongSet breakpoints; + @Nonnull + public List watchpoints = new ArrayList<>(); + + private void addWatchpoint(Watchpoint watchpoint) { + watchpoints.add(watchpoint); + } + + private void removeWatchpoint(Watchpoint watchpoint) { + watchpoints.remove(watchpoint); + } + + private boolean hasWatchpoint(Interval interval) { + for(Watchpoint w : watchpoints) { + if(w.range().intersects(interval)) { + return true; + } + } + return false; + } } private final class DebugInterface implements CPUDebugInterface { private final Collection breakpointListeners = new ArrayList<>(); + private final Collection watchpointListeners = new ArrayList<>(); private final LongSortedSet breakpoints = new LongAVLTreeSet(); + private final List readWatchpoints = new ArrayList<>(); + private final List writeWatchpoints = new ArrayList<>(); + @Override public long getProgramCounter() { return pc; @@ -3306,6 +3390,31 @@ public long[] getGeneralRegisters() { return x; } + @Override + public long[] getFloatingRegisters() { + return f; + } + + @Override + public byte getPriv() { + return (byte) priv; + } + + @Override + public void setPriv(byte value) { + setPrivilege(value); + } + + @Override + public long getCSR(short csr) throws R5IllegalInstructionException { + return readCSR(csr); + } + + @Override + public void setCSR(short csr, long value) throws R5IllegalInstructionException { + writeCSR(csr, value); + } + @Override public byte[] loadDebug(final long address, final int size) throws R5MemoryAccessException { final byte[] mem = new byte[size]; @@ -3359,6 +3468,18 @@ public void removeBreakpointListener(final LongConsumer listener) { breakpointListeners.remove(listener); } + @Override + public void addWatchpointListener(final LongConsumer listener) { + if (!watchpointListeners.contains(listener)) { + watchpointListeners.add(listener); + } + } + + @Override + public void removeWatchpointListener(final LongConsumer listener) { + watchpointListeners.remove(listener); + } + @Override public void addBreakpoint(final long address) { breakpoints.add(address); @@ -3382,6 +3503,59 @@ public void removeBreakpoint(final long address) { } } + @Override + public void addWatchpoint(final Watchpoint watchpoint) { + final Interval interval = watchpoint.range(); + + if (watchpoint.read()) { + readWatchpoints.add(watchpoint); + tlbEntriesInInterval(interval, MemoryAccessType.LOAD).forEachRemaining((entry) -> entry.addWatchpoint(watchpoint)); + } + + if (watchpoint.write()) { + writeWatchpoints.add(watchpoint); + tlbEntriesInInterval(interval, MemoryAccessType.STORE).forEachRemaining((entry) -> entry.addWatchpoint(watchpoint)); + } + } + + @Override + public void removeWatchpoint(Watchpoint watchpoint) { + final Interval interval = watchpoint.range(); + if (watchpoint.read()) { + readWatchpoints.remove(watchpoint); + tlbEntriesInInterval(interval, MemoryAccessType.LOAD).forEachRemaining((entry) -> entry.removeWatchpoint(watchpoint)); + } + + if (watchpoint.write()) { + writeWatchpoints.remove(watchpoint); + tlbEntriesInInterval(interval, MemoryAccessType.STORE).forEachRemaining((entry) -> entry.removeWatchpoint(watchpoint)); + } + } + + private Iterator tlbEntriesInInterval(Interval interval, MemoryAccessType accessType) { + return new Iterator<>() { + private long pageAddr = interval.start() & ~R5.PAGE_ADDRESS_MASK; + private TLBEntry next; + + @Override + public boolean hasNext() { + while (pageAddr <= interval.end()) { + next = tryGetTLBEntry(pageAddr, accessType); + pageAddr += (1 << R5.PAGE_ADDRESS_SHIFT); + if (next != null) { + return true; + } + } + return false; + } + + @Override + public TLBEntry next() { + return next; + } + }; + } + /** * Used by the GDB stub for debugging. We have special requirements compared to normal memory access. * 1. Need to bypass access protection, particularly the R/W bits @@ -3425,5 +3599,11 @@ private void handleBreakpoint(final long pc) { listener.accept(pc); } } + + private void handleWatchpoint(final long address) { + for (final LongConsumer listener : watchpointListeners) { + listener.accept(address); + } + } } } diff --git a/src/main/java/li/cil/sedna/gdbstub/CPUDebugInterface.java b/src/main/java/li/cil/sedna/riscv/gdbstub/CPUDebugInterface.java similarity index 53% rename from src/main/java/li/cil/sedna/gdbstub/CPUDebugInterface.java rename to src/main/java/li/cil/sedna/riscv/gdbstub/CPUDebugInterface.java index a53b4280..58f75234 100644 --- a/src/main/java/li/cil/sedna/gdbstub/CPUDebugInterface.java +++ b/src/main/java/li/cil/sedna/riscv/gdbstub/CPUDebugInterface.java @@ -1,18 +1,21 @@ -package li.cil.sedna.gdbstub; +package li.cil.sedna.riscv.gdbstub; +import li.cil.sedna.riscv.exception.R5IllegalInstructionException; import li.cil.sedna.riscv.exception.R5MemoryAccessException; import java.util.function.LongConsumer; public interface CPUDebugInterface { + long[] getGeneralRegisters(); long getProgramCounter(); - void setProgramCounter(long value); + long[] getFloatingRegisters(); + byte getPriv(); + void setPriv(byte value); + long getCSR(short csr) throws R5IllegalInstructionException; + void setCSR(short csr, long value) throws R5IllegalInstructionException; void step(); - - long[] getGeneralRegisters(); - byte[] loadDebug(final long address, final int size) throws R5MemoryAccessException; int storeDebug(final long address, final byte[] data) throws R5MemoryAccessException; @@ -24,4 +27,12 @@ public interface CPUDebugInterface { void addBreakpoint(long address); void removeBreakpoint(long address); + + void addWatchpointListener(final LongConsumer listener); + + void removeWatchpointListener(final LongConsumer listener); + + void addWatchpoint(Watchpoint watchpoint); + + void removeWatchpoint(Watchpoint watchpoint); } diff --git a/src/main/java/li/cil/sedna/riscv/gdbstub/GDBBinaryOutputStream.java b/src/main/java/li/cil/sedna/riscv/gdbstub/GDBBinaryOutputStream.java new file mode 100644 index 00000000..85d1bcf5 --- /dev/null +++ b/src/main/java/li/cil/sedna/riscv/gdbstub/GDBBinaryOutputStream.java @@ -0,0 +1,27 @@ +package li.cil.sedna.riscv.gdbstub; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +/** + * A FilterOutputStream that escapes raw binary for gdb transport, as described in the + * GDB docs + */ +public final class GDBBinaryOutputStream extends FilterOutputStream { + public GDBBinaryOutputStream(OutputStream out) { + super(out); + } + + @Override + public void write(int i) throws IOException { + byte b = (byte) i; + switch (b) { + case '#', '$', '}', '*' -> { + out.write('}'); + out.write(b ^ 0x20); + } + default -> out.write(b); + } + } +} diff --git a/src/main/java/li/cil/sedna/gdbstub/GDBPacketOutputStream.java b/src/main/java/li/cil/sedna/riscv/gdbstub/GDBPacketOutputStream.java similarity index 96% rename from src/main/java/li/cil/sedna/gdbstub/GDBPacketOutputStream.java rename to src/main/java/li/cil/sedna/riscv/gdbstub/GDBPacketOutputStream.java index a4e18a2e..eef4fa03 100644 --- a/src/main/java/li/cil/sedna/gdbstub/GDBPacketOutputStream.java +++ b/src/main/java/li/cil/sedna/riscv/gdbstub/GDBPacketOutputStream.java @@ -1,4 +1,4 @@ -package li.cil.sedna.gdbstub; +package li.cil.sedna.riscv.gdbstub; import java.io.FilterOutputStream; import java.io.IOException; diff --git a/src/main/java/li/cil/sedna/gdbstub/GDBStub.java b/src/main/java/li/cil/sedna/riscv/gdbstub/GDBStub.java similarity index 51% rename from src/main/java/li/cil/sedna/gdbstub/GDBStub.java rename to src/main/java/li/cil/sedna/riscv/gdbstub/GDBStub.java index 45551bd8..e5df1ef0 100644 --- a/src/main/java/li/cil/sedna/gdbstub/GDBStub.java +++ b/src/main/java/li/cil/sedna/riscv/gdbstub/GDBStub.java @@ -1,8 +1,10 @@ -package li.cil.sedna.gdbstub; +package li.cil.sedna.riscv.gdbstub; +import li.cil.sedna.riscv.exception.R5IllegalInstructionException; import li.cil.sedna.riscv.exception.R5MemoryAccessException; import li.cil.sedna.utils.ByteBufferUtils; import li.cil.sedna.utils.HexUtils; +import li.cil.sedna.utils.Interval; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -22,26 +24,35 @@ private enum GDBState { STOP_REPLY } - private enum StopReason { - MESSAGE, - BREAKPOINT - } - private static final Logger LOGGER = LogManager.getLogger(); - + private static final int MAX_PACKET_SIZE = 0x2000; + private final ServerSocketChannel listeningSock; + private final CPUDebugInterface cpu; + private final byte[] targetDescription; private GDBState state = GDBState.DISCONNECTED; private InputStream input; private OutputStream output; - - private final ServerSocketChannel listeningSock; private SocketChannel sock; - private final CPUDebugInterface cpu; - public GDBStub(final ServerSocketChannel socket, final CPUDebugInterface cpu) { this.listeningSock = socket; this.cpu = cpu; this.cpu.addBreakpointListener(this::handleBreakpointHit); + this.cpu.addWatchpointListener(this::handleWatchpointHit); + this.targetDescription = loadTargetDescription(); + } + + private static byte[] loadTargetDescription() { + try (final InputStream stream = GDBStub.class.getResourceAsStream("/gdb/target-riscv64.xml"); + ByteArrayOutputStream out = new ByteArrayOutputStream()) { + if (stream == null) { + throw new RuntimeException("Target description not found"); + } + stream.transferTo(out); + return out.toByteArray(); + } catch (IOException e) { + throw new RuntimeException(e); + } } public static GDBStub createDefault(final CPUDebugInterface cpu, final int port) throws IOException { @@ -53,11 +64,42 @@ public static GDBStub createDefault(final CPUDebugInterface cpu, final int port) public void run(final boolean waitForMessage) { if (isMessageAvailable() || waitForMessage) { - runLoop(StopReason.MESSAGE); + runLoop(new MessageStop()); + } + } + + private static abstract class StopReason { + public abstract void sendStopReply(Writer w) throws IOException; + } + + private static class MessageStop extends StopReason { + @Override + public void sendStopReply(Writer w) throws IOException { + w.write("S05"); + } + } + + private static class BreakpointStop extends StopReason { + @Override + public void sendStopReply(Writer w) throws IOException { + w.write("S05"); + } + } + + private static class WatchpointStop extends StopReason { + private final long address; + public WatchpointStop(long address) { + this.address = address; + } + @Override + public void sendStopReply(Writer w) throws IOException { + w.write("T05watch:"); + HexUtils.put64BE(w, address); + w.write(';'); } } - private void runLoop(final StopReason reason) { + private void runLoop(StopReason reason) { final ByteBuffer packetBuffer = ByteBuffer.allocate(8192); loop: while (true) { @@ -77,7 +119,7 @@ private void runLoop(final StopReason reason) { case STOP_REPLY -> { try (final var s = new GDBPacketOutputStream(output); final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { - w.write("S05"); + reason.sendStopReply(w); state = GDBState.WAITING_FOR_COMMAND; } catch (final IOException e) { disconnect(); @@ -96,28 +138,13 @@ private void runLoop(final StopReason reason) { final byte command = packetBuffer.get(); switch (command) { case '?' -> { - //TODO handle different reasons try (final var s = new GDBPacketOutputStream(output); final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { w.write("S05"); } } //General Query - case 'q' -> { - final byte[] Supported = "Supported:".getBytes(StandardCharsets.US_ASCII); - final byte[] Attached = "Attached".getBytes(StandardCharsets.US_ASCII); - if (ByteBufferUtils.startsWith(packetBuffer, ByteBuffer.wrap(Supported))) { - packetBuffer.position(packetBuffer.position() + Supported.length); - handleSupported(packetBuffer); - } else if (ByteBufferUtils.startsWith(packetBuffer, ByteBuffer.wrap(Attached))) { - try (final var s = new GDBPacketOutputStream(output); - final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { - w.write("1"); - } - } else { - unknownCommand(packetBuffer); - } - } + case 'q' -> handleQuery(packetBuffer); case 'g' -> readGeneralRegisters(); case 'G' -> writeGeneralRegisters(packetBuffer); case 'm' -> handleReadMemory(packetBuffer); @@ -126,6 +153,7 @@ private void runLoop(final StopReason reason) { final byte type = packetBuffer.get(); switch (type) { case '0', '1' -> handleBreakpointAdd(packetBuffer); + case '2', '3', '4' -> handleWatchpointAdd(type, packetBuffer); default -> unknownCommand(packetBuffer); } } @@ -133,6 +161,7 @@ private void runLoop(final StopReason reason) { final byte type = packetBuffer.get(); switch (type) { case '0', '1' -> handleBreakpointRemove(packetBuffer); + case '2', '3', '4' -> handleWatchpointRemove(type, packetBuffer); default -> unknownCommand(packetBuffer); } } @@ -147,7 +176,7 @@ private void runLoop(final StopReason reason) { unknownCommand(packetBuffer); return; } - handleStep(); + reason = handleStep(); } case 'D' -> { try (final var s = new GDBPacketOutputStream(output); @@ -157,6 +186,8 @@ private void runLoop(final StopReason reason) { disconnect(); break loop; } + case 'p' -> handleReadRegister(packetBuffer); + case 'P' -> handleWriteRegister(packetBuffer); default -> unknownCommand(packetBuffer); } } catch (final IOException e) { @@ -257,14 +288,20 @@ private boolean receivePacket(final ByteBuffer buffer) { } private void handleBreakpointHit(final long address) { - runLoop(StopReason.BREAKPOINT); + runLoop(new BreakpointStop()); + } + + // Since pc is incremented before this is called, the pc will be 1 step off + // additionally, for some reason GDB immediately performs a single step before returning to the user + // so the user will observe the pc to be 2 steps away from where the read/write happened + private void handleWatchpointHit(final long address) { + runLoop(new WatchpointStop(address)); } private void handleSupported(final ByteBuffer packet) throws IOException { try (final var s = new GDBPacketOutputStream(output); final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { - // Size in hex - w.write("PacketSize=2000"); + w.write("PacketSize=%x;qXfer:features:read+".formatted(MAX_PACKET_SIZE)); } } @@ -314,7 +351,7 @@ private void handleWriteMemory(final ByteBuffer buffer) throws IOException { private void handleBreakpointAdd(final ByteBuffer buffer) throws IOException { buffer.get(); final var chars = StandardCharsets.US_ASCII.decode(buffer); - final long address = HexUtils.getLong(chars); + final long address = HexUtils.getVarLengthInt(chars); try (final var s = new GDBPacketOutputStream(output); final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { cpu.addBreakpoint(address); @@ -325,7 +362,7 @@ private void handleBreakpointAdd(final ByteBuffer buffer) throws IOException { private void handleBreakpointRemove(final ByteBuffer buffer) throws IOException { buffer.get(); final var chars = StandardCharsets.US_ASCII.decode(buffer); - final long address = HexUtils.getLong(chars); + final long address = HexUtils.getVarLengthInt(chars); try (final var s = new GDBPacketOutputStream(output); final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { cpu.removeBreakpoint(address); @@ -333,9 +370,113 @@ private void handleBreakpointRemove(final ByteBuffer buffer) throws IOException } } - private void handleStep() { + private void handleWatchpointAdd(final byte type, final ByteBuffer buffer) throws IOException { + buffer.get(); + final String command = StandardCharsets.US_ASCII.decode(buffer).toString(); + final int addressCharsEnd = command.indexOf(','); + final long address = Long.parseUnsignedLong(command, 0, addressCharsEnd, 16); + final int length = Integer.parseInt(command, addressCharsEnd + 1, command.length(), 16); + final Interval interval = new Interval(address, length); + try (final var s = new GDBPacketOutputStream(output); + final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { + Watchpoint watchpoint = switch (type) { + case '2' -> new Watchpoint(interval, false, true); + case '3' -> new Watchpoint(interval, true, false); + case '4' -> new Watchpoint(interval, true, true); + default -> throw new IllegalStateException("Unexpected value: " + type); + }; + cpu.addWatchpoint(watchpoint); + w.write("OK"); + } + } + + private void handleWatchpointRemove(final byte type, final ByteBuffer buffer) throws IOException { + buffer.get(); + final String command = StandardCharsets.US_ASCII.decode(buffer).toString(); + final int addressCharsEnd = command.indexOf(','); + final long address = Long.parseUnsignedLong(command, 0, addressCharsEnd, 16); + final int length = Integer.parseInt(command, addressCharsEnd + 1, command.length(), 16); + final Interval interval = new Interval(address, length); + try (final var s = new GDBPacketOutputStream(output); + final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { + Watchpoint watchpoint = switch (type) { + case '2' -> new Watchpoint(interval, false, true); + case '3' -> new Watchpoint(interval, true, false); + case '4' -> new Watchpoint(interval, true, true); + default -> throw new IllegalStateException("Unexpected value: " + type); + }; + cpu.removeWatchpoint(watchpoint); + w.write("OK"); + } + } + + private MessageStop handleStep() { cpu.step(); state = GDBState.STOP_REPLY; + return new MessageStop(); + } + + private void handleQuery(ByteBuffer packetBuffer) throws IOException { + final byte[] Supported = "Supported:".getBytes(StandardCharsets.US_ASCII); + final byte[] Attached = "Attached".getBytes(StandardCharsets.US_ASCII); + final byte[] features = "Xfer:features:read:".getBytes(StandardCharsets.US_ASCII); + if (ByteBufferUtils.startsWith(packetBuffer, ByteBuffer.wrap(Supported))) { + packetBuffer.position(packetBuffer.position() + Supported.length); + handleSupported(packetBuffer); + } else if (ByteBufferUtils.startsWith(packetBuffer, ByteBuffer.wrap(Attached))) { + try (final var s = new GDBPacketOutputStream(output); + final var w = new OutputStreamWriter(s, StandardCharsets.US_ASCII)) { + w.write("1"); + } + } else if (ByteBufferUtils.startsWith(packetBuffer, ByteBuffer.wrap(features))) { + packetBuffer.position(packetBuffer.position() + features.length); + handleReadTargetDescription(packetBuffer); + } else { + unknownCommand(packetBuffer); + } + } + + private void handleReadTargetDescription(ByteBuffer buf) throws IOException { + try (final var s = new GDBPacketOutputStream(output)) { + try { + String annex = ByteBufferUtils.getStringToken(buf, (byte) ':'); + int offset = Integer.parseInt(ByteBufferUtils.getStringToken(buf, (byte) ','), 16); + int length = Integer.parseInt(ByteBufferUtils.tokenAsString(buf), 16); + handleReadTargetDescription(annex, offset, length, s); + } catch (ByteBufferUtils.TokenException e) { + LOGGER.error("Failed to parse qXfer features read packet", e); + s.write("E00".getBytes(StandardCharsets.US_ASCII)); + } + } + } + + private void handleReadTargetDescription(String annex, int offset, int length, OutputStream out) throws IOException { + if(!annex.equals("target.xml")) { + out.write("E00".getBytes(StandardCharsets.US_ASCII)); + return; + } + if(offset > targetDescription.length || offset < 0) { + out.write("E00".getBytes(StandardCharsets.US_ASCII)); + return; + } else if (offset == targetDescription.length) { + out.write('l'); + return; + } + // We need to make sure we don't exceed the max packet size + // Due to escaping each byte may take up to 2 bytes, hence the divide by 2. + // The 5 comes from 1 '$', 2 checksum bytes, 1 '#', and one 'l' for the qXfer read response + final int maxChunkLength = (MAX_PACKET_SIZE / 2) - 5; + final int maxLength = Math.min(targetDescription.length - offset, maxChunkLength); + length = Math.min(length, maxLength); + if(offset + length == targetDescription.length) { + out.write('l'); + } else { + out.write('m'); + } + + try(GDBBinaryOutputStream binOut = new GDBBinaryOutputStream(out)) { + binOut.write(targetDescription, offset, length); + } } private String asciiBytesToEscaped(final ByteBuffer bytes) { @@ -363,9 +504,9 @@ private void readGeneralRegisters() throws IOException { try (final var s = new GDBPacketOutputStream(output); final var w = new BufferedWriter(new OutputStreamWriter(s, StandardCharsets.US_ASCII))) { for (final long l : cpu.getGeneralRegisters()) { - HexUtils.putLong(w, l); + HexUtils.put64(w, l); } - HexUtils.putLong(w, cpu.getProgramCounter()); + HexUtils.put64(w, cpu.getProgramCounter()); } } @@ -382,4 +523,95 @@ private void writeGeneralRegisters(final ByteBuffer buf) throws IOException { w.write("OK"); } } + + // Must be kept in sync with target-riscv64.xml + private static final int regNumFirstX = 0; + private static final int regNumLastX = 31; + private static final int regNumPc = 32; + private static final int regNumFirstF = 33; + private static final int regNumLastF = 64; + private static final int regNumFflags = 65; + private static final int regNumFrm = 66; + private static final int regNumFcsr = 67; + private static final int regNumPriv = 68; + private static final int regNumFirstCSR = 0x1000; + private static final int regNumSwitch32 = 0x1bc0; + private static final int regNumLastCSR = 0x1fff; + + private void handleReadRegister(final ByteBuffer buffer) throws IOException { + final String regNumStr = StandardCharsets.US_ASCII.decode(buffer).toString(); + final int regNum = Integer.parseInt(regNumStr, 16); + try (final var s = new GDBPacketOutputStream(output); + final var w = new BufferedWriter(new OutputStreamWriter(s, StandardCharsets.US_ASCII))) { + try { + if (regNum >= regNumFirstX && regNum <= regNumLastX) + HexUtils.put64(w, cpu.getGeneralRegisters()[regNum - regNumFirstX]); + else if (regNum == regNumPc) HexUtils.put64(w, cpu.getProgramCounter()); + else if (regNum >= regNumFirstF && regNum <= regNumLastF) + HexUtils.put64(w, cpu.getFloatingRegisters()[regNum - regNumFirstF]); + else if (regNum == regNumFflags) HexUtils.put32(w, (int) cpu.getCSR((short) 1)); + else if (regNum == regNumFrm) HexUtils.put32(w, (int) cpu.getCSR((short) 2)); + else if (regNum == regNumFcsr) HexUtils.put32(w, (int) cpu.getCSR((short) 3)); + else if (regNum == regNumPriv) HexUtils.put64(w, cpu.getPriv()); + else if (regNum >= regNumFirstCSR && regNum <= regNumLastCSR) { + if (regNum == regNumSwitch32) { + // This is a write-only register, which GDB doesn't understand. We're + // special casing it so GDB (which always does a read before it writes) can + // write to it + HexUtils.put64(w, 0); + return; + } + try { + short csr = (short) (regNum - 0x1000); + HexUtils.put64(w, cpu.getCSR(csr)); + } catch (R5IllegalInstructionException e) { + w.write("E01"); + } + } else { + w.write("E01"); + } + } catch (R5IllegalInstructionException e) { + // Impossible + throw new RuntimeException(e); + } + } + } + + private void handleWriteRegister(final ByteBuffer buffer) throws IOException { + final String command = StandardCharsets.US_ASCII.decode(buffer).toString(); + final String[] commandArr = command.split("="); + final String regNumStr = commandArr[0]; + final int regNum = Integer.parseInt(regNumStr, 16); + final String regValStr = commandArr[1]; + final ByteBuffer regValRaw = ByteBuffer.wrap(HexFormat.of().parseHex(regValStr)).order(ByteOrder.LITTLE_ENDIAN); + try (final var s = new GDBPacketOutputStream(output); + final var w = new BufferedWriter(new OutputStreamWriter(s, StandardCharsets.US_ASCII))) { + try { + if (regNum >= regNumFirstX && regNum <= regNumLastX) + cpu.getGeneralRegisters()[regNum - regNumFirstX] = regValRaw.getLong(); + else if (regNum == regNumPc) cpu.setProgramCounter(regValRaw.getLong()); + else if (regNum >= regNumFirstF && regNum <= regNumLastF) + cpu.getFloatingRegisters()[regNum - regNumFirstF] = regValRaw.getLong(); + else if (regNum == regNumFflags) cpu.setCSR((short) 1, (byte) regValRaw.getInt()); + else if (regNum == regNumFrm) cpu.setCSR((short) 2, (byte) regValRaw.getInt()); + else if (regNum == regNumFcsr) cpu.setCSR((short) 3, regValRaw.getInt()); + else if (regNum == regNumPriv) cpu.setPriv((byte) regValRaw.getInt()); + else if (regNum >= regNumFirstCSR && regNum <= regNumLastCSR) { + try { + short csr = (short) (regNum - 0x1000); + cpu.setCSR(csr, regValRaw.getLong()); + } catch (R5IllegalInstructionException e) { + w.write("E01"); + return; + } + } else { + w.write("E01"); + return; + } + w.write("OK"); + } catch (R5IllegalInstructionException e) { + throw new RuntimeException(e); + } + } + } } diff --git a/src/main/java/li/cil/sedna/riscv/gdbstub/Watchpoint.java b/src/main/java/li/cil/sedna/riscv/gdbstub/Watchpoint.java new file mode 100644 index 00000000..3466a834 --- /dev/null +++ b/src/main/java/li/cil/sedna/riscv/gdbstub/Watchpoint.java @@ -0,0 +1,5 @@ +package li.cil.sedna.riscv.gdbstub; + +import li.cil.sedna.utils.Interval; + +public record Watchpoint(Interval range, boolean read, boolean write) {} diff --git a/src/main/java/li/cil/sedna/utils/ByteBufferUtils.java b/src/main/java/li/cil/sedna/utils/ByteBufferUtils.java index 11f93b37..8b8aeeb6 100644 --- a/src/main/java/li/cil/sedna/utils/ByteBufferUtils.java +++ b/src/main/java/li/cil/sedna/utils/ByteBufferUtils.java @@ -1,9 +1,46 @@ package li.cil.sedna.utils; +import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.StandardCharsets; public final class ByteBufferUtils { + public static final class TokenException extends Exception { + public TokenException(String message, Throwable cause){ + super(message, cause); + } + } + public static boolean startsWith(final ByteBuffer buffer, final ByteBuffer prefix) { return buffer.remaining() >= prefix.remaining() && buffer.slice(buffer.position(), prefix.remaining()).equals(prefix); } + + public static ByteBuffer getToken(final ByteBuffer buf, byte delimeter) throws TokenException { + ByteBuffer token = buf.slice(); + int len = 0; + try { + while (buf.get() != delimeter) len++; + token.limit(len); + return token; + } catch (BufferUnderflowException ex) { + throw new TokenException("Buffer missing delimeter '%c'".formatted((char)delimeter), ex); + } + } + + public static CharBuffer tokenAsChar(final ByteBuffer buf) { + return StandardCharsets.US_ASCII.decode(buf); + } + + public static CharBuffer getCharToken(final ByteBuffer buf, byte delimeter) throws TokenException { + return tokenAsChar(getToken(buf, delimeter)); + } + + public static String tokenAsString(final ByteBuffer buf) { + return tokenAsChar(buf).toString(); + } + + public static String getStringToken(final ByteBuffer buf, byte delimeter) throws TokenException { + return getCharToken(buf, delimeter).toString(); + } } diff --git a/src/main/java/li/cil/sedna/utils/HexUtils.java b/src/main/java/li/cil/sedna/utils/HexUtils.java index 6a50c4f9..cc10bf1e 100644 --- a/src/main/java/li/cil/sedna/utils/HexUtils.java +++ b/src/main/java/li/cil/sedna/utils/HexUtils.java @@ -3,21 +3,33 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.CharBuffer; -import java.nio.LongBuffer; import java.util.HexFormat; public final class HexUtils { //If we ever go multithreaded, make this a ThreadLocal - private static final ByteBuffer byteBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN); - private static final LongBuffer longBuf = byteBuf.asLongBuffer(); + private static final ByteBuffer byteBuf = ByteBuffer.allocate(8); + public static void put64(final Appendable out, final long l) { + byteBuf.order(ByteOrder.LITTLE_ENDIAN); + byteBuf.putLong(l); + HexFormat.of().formatHex(out, byteBuf.array(), 0, 8); + byteBuf.clear(); + } + + public static void put32(final Appendable out, final int i) { + byteBuf.order(ByteOrder.LITTLE_ENDIAN); + byteBuf.putInt(i); + HexFormat.of().formatHex(out, byteBuf.array(), 0, 4); + byteBuf.clear(); + } - public static void putLong(final Appendable out, final long l) { - longBuf.put(l); - HexFormat.of().formatHex(out, byteBuf.array()); - longBuf.clear(); + public static void put64BE(final Appendable out, final long l) { + byteBuf.order(ByteOrder.BIG_ENDIAN); + byteBuf.putLong(l); + HexFormat.of().formatHex(out, byteBuf.array(), 0, 8); + byteBuf.clear(); } - public static long getLong(final CharBuffer buf) { + public static long getVarLengthInt(final CharBuffer buf) { while (buf.hasRemaining() && HexFormat.isHexDigit(buf.get())) ; buf.flip(); buf.limit(buf.limit() - 1); diff --git a/src/main/java/li/cil/sedna/utils/Interval.java b/src/main/java/li/cil/sedna/utils/Interval.java new file mode 100644 index 00000000..64ed2c3a --- /dev/null +++ b/src/main/java/li/cil/sedna/utils/Interval.java @@ -0,0 +1,58 @@ +package li.cil.sedna.utils; + +import javax.annotation.Nonnull; +import java.util.Objects; + +public final class Interval implements Comparable { + private final long start; + private final long end; + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Interval interval = (Interval) o; + return start == interval.start && end == interval.end; + } + + @Override + public int hashCode() { + return Objects.hash(start, end); + } + + public Interval(final long start, final long length) { + this.start = start; + this.end = start + (length - 1); + } + + @Nonnull + public static Interval fromEndpoint(final long start, final long end) { + if(end < start) throw new IllegalArgumentException("End must be >= start"); + return new Interval(start, end - start + 1); + } + + @Override + public int compareTo(Interval other) { + int cmp = Long.compareUnsigned(start, other.start); + if(cmp == 0) { + cmp = Long.compareUnsigned(end, other.end); + } + return cmp; + } + + public boolean intersects(Interval other) { + return Long.compareUnsigned(start, other.end) <= 0 && Long.compareUnsigned(end, other.start) >= 0; + } + + public long start() { + return start; + } + + public long length() { + return end - start + 1; + } + + public long end() { + return end; + } +} diff --git a/src/main/resources/gdb/target-riscv64.xml b/src/main/resources/gdb/target-riscv64.xml new file mode 100644 index 00000000..a3175da9 --- /dev/null +++ b/src/main/resources/gdb/target-riscv64.xml @@ -0,0 +1,139 @@ + + + + riscv:rv64 + none + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +