Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-307: Fix host read #148

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions zirgen/circuit/rv32im/v2/dsl/arr.zir
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// This file contains utilities that work with bits and twits.
// RUN: zirgen --test %s

// Vector / List functions
Expand Down Expand Up @@ -37,11 +36,10 @@ component EqArr<SIZE: Val>(a: Array<Val, SIZE>, b: Array<Val, SIZE>) {
// Tests....

test ShiftAndRotate {
// TODO: Now that these support non-bit values, maybe make new tests
// Remember: array entry 0 is the low bit, so there seem backwards
EqArr<8>(ShiftRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 0, 0]);
EqArr<8>(ShiftLeft<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [0, 0, 1, 1, 1, 0, 1, 0]);
EqArr<8>(RotateRight<8>([1, 1, 1, 0, 1, 0, 0, 0], 2), [1, 0, 1, 0, 0, 0, 1, 1]);
EqArr<8>(RotateLeft<8>([1, 1, 1, 0, 1, 0, 0, 1], 2), [0, 1, 1, 1, 1, 0, 1, 0]);
EqArr<8>(ShiftRight<8>([3, 1, 5, 0, 2, 0, 0, 0], 2), [5, 0, 2, 0, 0, 0, 0, 0]);
EqArr<8>(ShiftLeft<8>([1, 4, 2, 0, 6, 0, 0, 0], 2), [0, 0, 1, 4, 2, 0, 6, 0]);
EqArr<8>(RotateRight<8>([7, 6, 1, 0, 2, 0, 0, 0], 2), [1, 0, 2, 0, 0, 0, 7, 6]);
EqArr<8>(RotateLeft<8>([4, 5, 1, 0, 1, 0, 0, 3], 2), [0, 3, 4, 5, 1, 0, 1, 0]);
}

73 changes: 58 additions & 15 deletions zirgen/circuit/rv32im/v2/dsl/inst_ecall.zir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inst;
import consts;


// Prepare to read a certain length, maybe return a smaller one
extern HostReadPrepare(fd: Val, len: Val): Val;

Expand Down Expand Up @@ -87,8 +86,8 @@ component ECallHostReadSetup(cycle: Reg, input: InstInput) {
lenDecomp := DecomposeLow2(newLen);
// Check if length is exactly 1, 2, or 3
len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero);
// Check if things are 'uneven'
uneven := Reg(len123 * ptrDecomp.low2Nonzero);
// Check if things are 'uneven' (this is an 'or')
uneven := Reg(len123 + ptrDecomp.low2Nonzero - len123 * ptrDecomp.low2Nonzero);
// Now pick the next cycle
nextCycle :=
// If length == 0, go back to decoding
Expand Down Expand Up @@ -118,36 +117,80 @@ component ECallHostWrite(cycle: Reg, input: InstInput) {
ECallOutput(StateDecode(), 0, 0, 0)
}

component ECallHostReadBytes(cycle: Reg, input: InstInput) {
// TODO
component ECallHostReadBytes(cycle: Reg, input: InstInput, ptrWord: Val, ptrLow2: Val, len: Val) {
input.state = StateHostReadBytes();
0 = 1;
ECallOutput(16, 0, 0, 0)
// Decompose next len
lenDecomp := DecomposeLow2(len - 1);
// Check if length is exactly 1, 2, or 3
len123 := Reg(lenDecomp.highZero * lenDecomp.low2Nonzero);
// Check is next pointer is even (this can only happen if Low2 == 3)
nextPtrEven := IsZero(ptrLow2 - 3);
nextPtrUneven := 1 - nextPtrEven;
nextPtrWord := nextPtrEven * (ptrWord + 1) + nextPtrUneven * ptrWord;
nextPtrLow2 := nextPtrUneven * (ptrLow2 + 1);
// Check if things are 'uneven' (this is an 'or')
uneven := Reg(len123 + nextPtrUneven - len123 * nextPtrUneven);
// Check is length is exactly zero
lenZero := IsZero(len - 1);
// Split low bits into parts
low0 := NondetBitReg(ptrLow2 & 1);
low1 := BitReg((ptrLow2 - low0) / 2);
// Load the original word
origWord := MemoryRead(cycle, ptrWord);
// Write the answer
io := MemoryWriteUnconstrained(cycle, ptrWord).io;
// Make the non-specified half matches
if (low1) {
origWord.low = io.newTxn.dataLow;
} else {
origWord.high = io.newTxn.dataHigh;
};
// Get the half that changed
oldHalf := low1 * origWord.high + (1 - low1) * origWord.low;
newHalf := low1 * io.newTxn.dataHigh + (1 - low1) * io.newTxn.dataLow;
// Split both into bytes
oldBytes := SplitWord(oldHalf);
newBytes := SplitWord(newHalf);
// Make sure the non-specified bytes matchs
if (low0) {
oldBytes.byte0 = newBytes.byte0;
} else {
oldBytes.byte1 = newBytes.byte1;
};
nextCycle :=
// If length == 0, go back to decoding
lenZero * StateDecode() +
// If length != 0 and uneven, do bytes
(1 - lenZero) * uneven * StateHostReadBytes() +
// If length != 0 and even, more words
(1 - lenZero) * (1 - uneven) * StateHostReadWords();
ECallOutput(nextCycle, nextPtrWord, nextPtrLow2, len - 1)
}

component ECallHostReadWords(cycle: Reg, input: InstInput, ptrWord: Val, len: Val) {
input.state = StateHostReadWords();
lenDecomp := DecomposeLow2(len);
wordsDecomp := DecomposeLow2(lenDecomp.high);
doWord := [
wordsDecomp.low2Hot[1] * wordsDecomp.highZero,
wordsDecomp.low2Hot[2] * wordsDecomp.highZero,
wordsDecomp.low2Hot[3] * wordsDecomp.highZero,
(wordsDecomp.low2Hot[1] + wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(wordsDecomp.low2Hot[2] + wordsDecomp.low2Hot[3])* wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(wordsDecomp.low2Hot[3]) * wordsDecomp.highZero + (1 - wordsDecomp.highZero),
(1 - wordsDecomp.highZero)
];
count := reduce doWord init 0 with Add;
for i : 0..4 {
addr := Reg(doWord[i] * (ptrWord + i) + (1 - doWord[i]) * SafeWriteWord());
MemoryWriteUnconstrained(cycle, addr);
};
lenZero := IsZero(len - 4 * count);
newLenHighZero := IsZero(lenDecomp.high - count);
lenZero := Reg(newLenHighZero * (1 - lenDecomp.low2Nonzero));
nextCycle :=
// If length == 0, go back to decoding
lenZero * StateDecode() +
// If length != 0 and uneven, do bytes
(1 - lenZero) * (lenDecomp.low2Nonzero) * StateHostReadBytes() +
// If lengtj != 0 and even, more words
(1 - lenZero) * (1 - lenDecomp.low2Nonzero) * StateHostReadWords();
(1 - lenZero) * newLenHighZero * StateHostReadBytes() +
// If length != 0 and even, more words
(1 - lenZero) * (1 - newLenHighZero) * StateHostReadWords();
ECallOutput(nextCycle, ptrWord + count, 0, len - count * 4)
}

Expand All @@ -163,7 +206,7 @@ component ECall0(cycle: Reg, inst_input: InstInput) {
ECallTerminate(cycle, inst_input),
ECallHostReadSetup(cycle, inst_input),
ECallHostWrite(cycle, inst_input),
ECallHostReadBytes(cycle, inst_input),
ECallHostReadBytes(cycle, inst_input, s0@1, s1@1, s2@1),
ECallHostReadWords(cycle, inst_input, s0@1, s2@1),
IllegalECall(),
IllegalECall()
Expand Down
2 changes: 1 addition & 1 deletion zirgen/circuit/rv32im/v2/dsl/mem.zir
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ component MemoryWrite(cycle: Reg, addr: Val, data: ValU32) {

// Let the host write anythings (used in host read words)
component MemoryWriteUnconstrained(cycle: Reg, addr: Val) {
io := MemoryIO(2*cycle + 1, addr);
public io := MemoryIO(2*cycle + 1, addr);
IsForward(io);
}

Expand Down
4 changes: 0 additions & 4 deletions zirgen/circuit/rv32im/v2/dsl/top.zir
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
// RUN: true

// TODO: Now that the v2 circuit uses an extern to compute major/minor it no
// longer makes sense to do rv32im conformance testing here. Make sure
// integration tests are covering this.

import inst_div;
import inst_misc;
import inst_mul;
Expand Down
1 change: 1 addition & 0 deletions zirgen/circuit/rv32im/v2/emu/preflight.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ struct PreflightContext {
}
size_t rlen = segment.readRecord[curRead].size();
memcpy(data, segment.readRecord[curRead].data(), rlen);
curRead++;
return rlen;
}

Expand Down
22 changes: 15 additions & 7 deletions zirgen/circuit/rv32im/v2/emu/r0vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,22 @@ template <typename Context> struct R0Context {
std::vector<uint8_t> bytes(len);
rlen = context.read(fd, bytes.data(), len);
storeReg(REG_A0, rlen);
uint32_t i = 0;
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
uint32_t i = 0;
while (rlen > 0 && ptr % 4 != 0) {
writeByte(ptr, bytes[i]);
// context.hostReadBytes(ptr);
ptr++;
i++;
rlen--;
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
}
while (rlen >= 4) {
uint32_t words = std::min(rlen / 4, uint32_t(4));
Expand All @@ -195,25 +199,29 @@ template <typename Context> struct R0Context {
word |= bytes[i + k] << (8 * k);
}
storeMem(ptr / 4, word);
ptr += 4;
i += 4;
rlen -= 4;
} else {
storeMem(SAFE_WRITE_WORD, 0);
}
ptr += words;
i += words;
rlen -= words;
}
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
}
while (rlen > 0 && ptr % 4 != 0) {
while (rlen > 0) {
writeByte(ptr, bytes[i]);
// context.hostReadBytes(ptr);
ptr++;
i++;
rlen--;
if (rlen == 0) {
context.pc += 4;
}
context.ecallCycle(curState, nextState(ptr, rlen), ptr / 4, ptr % 4, rlen);
curState = nextState(ptr, rlen);
}
return false;
}
Expand Down
23 changes: 23 additions & 0 deletions zirgen/circuit/rv32im/v2/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,27 @@ cc_binary(
],
)

risc0_cc_kernel_binary(
name = "test_io_kernel",
srcs = [
"entry.s",
"test_io_kernel.cpp",
],
deps = ["//zirgen/circuit/rv32im/v2/platform:core"],
)

cc_test(
name = "test_io",
srcs = [
"test_io.cpp",
],
data = [
":test_io_kernel",
],
deps = [
"//risc0/core",
"//zirgen/circuit/rv32im/v2/run",
]
)

riscv_test_suite()
50 changes: 50 additions & 0 deletions zirgen/circuit/rv32im/v2/test/test_io.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>

#include "zirgen/circuit/rv32im/v2/platform/constants.h"
#include "zirgen/circuit/rv32im/v2/run/run.h"

using namespace zirgen::rv32im_v2;

const std::string kernelName = "zirgen/circuit/rv32im/v2/test/test_io_kernel";

// Allows reads of any size, fill with a pattern to check in kernel
struct RandomReadSizeHandler : public HostIoHandler {
uint32_t write(uint32_t fd, const uint8_t* data, uint32_t len) override { return len; }
uint32_t read(uint32_t fd, uint8_t* data, uint32_t len) override {
std::cout << "DOING READ OF SIZE " << len << "\n";
for (size_t i = 0; i < len; i++) {
data[i] = i;
}
return len;
}
};

int main() {
size_t cycles = 100000;
RandomReadSizeHandler io;

// Load image
auto image = MemoryImage::fromRawElf(kernelName);
// Do executions
auto segments = execute(image, io, cycles, cycles);
// Do 'run' (preflight + expansion)
for (const auto& segment : segments) {
std::cout << "HEY, doing a segment!\n";
runSegment(segment, cycles + 1000);
}
std::cout << "What a fine day\n";
}
85 changes: 85 additions & 0 deletions zirgen/circuit/rv32im/v2/test/test_io_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdint.h>
#include <sys/errno.h>

#include "zirgen/circuit/rv32im/v2/platform/constants.h"

using namespace zirgen::rv32im_v2;

inline void die() {
asm("fence\n");
}

// Implement machine mode ECALLS

inline void terminate(uint32_t val) {
register uintptr_t a0 asm("a0") = val;
register uintptr_t a7 asm("a7") = 0;
asm volatile("ecall\n"
: // no outputs
: "r"(a0), "r"(a7) // inputs
: // no clobbers
);
}

inline uint32_t host_read(uint32_t fd, uint32_t buf, uint32_t len) {
register uintptr_t a0 asm("a0") = fd;
register uintptr_t a1 asm("a1") = buf;
register uintptr_t a2 asm("a2") = len;
register uintptr_t a7 asm("a7") = 1;
asm volatile("ecall\n"
: "+r"(a0) // outputs
: "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs
: // no clobbers
);
return a0;
}

inline uint32_t host_write(uint32_t fd, uint32_t buf, uint32_t len) {
register uintptr_t a0 asm("a0") = fd;
register uintptr_t a1 asm("a1") = buf;
register uintptr_t a2 asm("a2") = len;
register uintptr_t a7 asm("a7") = 2;
asm volatile("ecall\n"
: "+r"(a0) // outputs
: "r"(a0), "r"(a1), "r"(a2), "r"(a7) // inputs
: // no clobbers
);
return a0;
}

constexpr uint32_t sizes[11] = {0, 1, 2, 3, 4, 5, 7, 13, 19, 40, 101};

void test_multi_read() {
uint8_t buf[200];
// Try all 4 alignments
for (size_t i = 0; i < 4; i++) {
// Try a variety of size
for (size_t j = 0; j < 11; j++) {
host_read(0, (uint32_t)(buf + i), sizes[j]);
for (size_t k = 0; k < sizes[j]; k++) {
if (buf[i + k] != k) {
die();
}
}
}
}
}

extern "C" void start() {
test_multi_read();
terminate(0);
}
Loading