Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-215: abstract bibc file format reader & writer for use on memory buffers #54

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions zirgen/Dialect/BigInt/Bytecode/decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ mlir::func::FuncOp decode(mlir::ModuleOp module, const Program& prog) {
}
}

// Add terminator op, for the sake of propriety.
builder.create<mlir::func::ReturnOp>(loc);
return out;
}

Expand Down
155 changes: 120 additions & 35 deletions zirgen/Dialect/BigInt/Bytecode/file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "file.h"
#include <array>
#include <cstring>
#include <stdexcept>
#include <string>

Expand Down Expand Up @@ -78,25 +79,27 @@ IOException::IOException(const char* file, const char* func, int line, const cha

namespace {

void writeU16(uint16_t value, FILE* stream) {
struct Writer {
virtual void write(const void*, size_t) = 0;
};

void writeU16(uint16_t value, Writer& stream) {
std::array<uint8_t, 2> buf;
buf[0] = (value >> 0x00) & 0xFFU;
buf[1] = (value >> 0x08) & 0xFFU;
size_t writ = fwrite(buf.data(), buf.size(), 1, stream);
check(ferror(stream) || !writ);
stream.write(buf.data(), buf.size());
}

void writeU32(uint32_t value, FILE* stream) {
void writeU32(uint32_t value, Writer& stream) {
std::array<uint8_t, 4> buf;
buf[0] = (value >> 0x00) & 0xFFU;
buf[1] = (value >> 0x08) & 0xFFU;
buf[2] = (value >> 0x10) & 0xFFU;
buf[3] = (value >> 0x18) & 0xFFU;
size_t writ = fwrite(buf.data(), buf.size(), 1, stream);
check(ferror(stream) || !writ);
stream.write(buf.data(), buf.size());
}

void writeU64(uint64_t value, FILE* stream) {
void writeU64(uint64_t value, Writer& stream) {
std::array<uint8_t, 8> buf;
buf[0] = (value >> 0x00) & 0xFFU;
buf[1] = (value >> 0x08) & 0xFFU;
Expand All @@ -106,11 +109,10 @@ void writeU64(uint64_t value, FILE* stream) {
buf[5] = (value >> 0x28) & 0xFFU;
buf[6] = (value >> 0x30) & 0xFFU;
buf[7] = (value >> 0x38) & 0xFFU;
size_t writ = fwrite(buf.data(), buf.size(), 1, stream);
check(ferror(stream) || !writ);
stream.write(buf.data(), buf.size());
}

void writeHeader(const Program& p, FILE* stream) {
void writeHeader(const Program& p, Writer& stream) {
writeU32(MAGIC, stream);
writeU32(1, stream);
writeU32(p.inputs.size(), stream);
Expand All @@ -119,21 +121,21 @@ void writeHeader(const Program& p, FILE* stream) {
writeU32(p.ops.size(), stream);
}

void writeInput(const Input& i, FILE* stream) {
void writeInput(const Input& i, Writer& stream) {
writeU64(i.label, stream);
writeU32(i.bitWidth, stream);
writeU16(i.minBits, stream);
writeU16(i.isPublic ? 1 : 0, stream);
}

void writeType(const Type& t, FILE* stream) {
void writeType(const Type& t, Writer& stream) {
writeU64(t.coeffs, stream);
writeU64(t.maxPos, stream);
writeU64(t.maxNeg, stream);
writeU64(t.minBits, stream);
}

void writeOp(const Op& o, FILE* stream) {
void writeOp(const Op& o, Writer& stream) {
// Pack operation struct fields into a single 64-bit word
uint64_t w = 0;
check(static_cast<uint8_t>(o.code) >= 0x10);
Expand All @@ -147,9 +149,7 @@ void writeOp(const Op& o, FILE* stream) {
writeU64(w, stream);
}

} // namespace

void write(const Program& p, FILE* stream) {
void writeProgram(const Program& p, Writer& stream) {
writeHeader(p, stream);
// inputs referenced through 24-bit operand
check(p.inputs.size() > 0x00FFFFFF);
Expand All @@ -173,37 +173,86 @@ void write(const Program& p, FILE* stream) {
}
}

struct Teller : public Writer {
void write(const void*, size_t len) override {
// Count the bytes that would be written, if we were going to write.
total += len;
}
size_t total = 0;
};

struct FileWriter : public Writer {
FileWriter(FILE* stream) : stream(stream) { check(!stream); }
void write(const void* buf, size_t len) override {
size_t writ = fwrite(buf, len, 1, stream);
check(ferror(stream) || !writ);
}

private:
FILE* stream = nullptr;
};

struct BufWriter : public Writer {
BufWriter(void* buf, size_t len) : buf(buf), remaining(len) { check(!buf); }
void write(const void* src, size_t len) override {
check(remaining < len);
std::memmove(buf, src, len);
remaining -= len;
buf = &static_cast<uint8_t*>(buf)[len];
}

private:
void* buf = nullptr;
size_t remaining = 0;
};

} // namespace

size_t tell(const Program& p) {
Teller dest;
writeProgram(p, dest);
return dest.total;
}

void write(const Program& p, FILE* stream) {
FileWriter dest(stream);
writeProgram(p, dest);
}

void write(const Program& p, void* buf, size_t len) {
BufWriter dest(buf, len);
writeProgram(p, dest);
}

namespace {

uint32_t readU16(FILE* stream) {
check(feof(stream));
struct Reader {
virtual void read(void*, size_t) = 0;
};

uint32_t readU16(Reader& stream) {
std::array<uint8_t, 2> buf;
size_t got = fread(buf.data(), buf.size(), 1, stream);
check(ferror(stream) || !got);
stream.read(buf.data(), buf.size());
return (static_cast<uint16_t>(buf[0]) << 0x00) | (static_cast<uint16_t>(buf[1]) << 0x08);
}

uint32_t readU32(FILE* stream) {
check(feof(stream));
uint32_t readU32(Reader& stream) {
std::array<uint8_t, 4> buf;
size_t got = fread(buf.data(), buf.size(), 1, stream);
check(ferror(stream) || !got);
stream.read(buf.data(), buf.size());
return (static_cast<uint32_t>(buf[0]) << 0x00) | (static_cast<uint32_t>(buf[1]) << 0x08) |
(static_cast<uint32_t>(buf[2]) << 0x10) | (static_cast<uint32_t>(buf[3]) << 0x18);
}

uint64_t readU64(FILE* stream) {
check(feof(stream));
uint64_t readU64(Reader& stream) {
std::array<uint8_t, 8> buf;
size_t got = fread(buf.data(), buf.size(), 1, stream);
check(ferror(stream) || !got);
stream.read(buf.data(), buf.size());
return (static_cast<uint64_t>(buf[0]) << 0x00) | (static_cast<uint64_t>(buf[1]) << 0x08) |
(static_cast<uint64_t>(buf[2]) << 0x10) | (static_cast<uint64_t>(buf[3]) << 0x18) |
(static_cast<uint64_t>(buf[4]) << 0x20) | (static_cast<uint64_t>(buf[5]) << 0x28) |
(static_cast<uint64_t>(buf[6]) << 0x30) | (static_cast<uint64_t>(buf[7]) << 0x38);
}

void readHeader(Program& p, FILE* stream) {
void readHeader(Program& p, Reader& stream) {
check(MAGIC != readU32(stream));
check(1 != readU32(stream));
p.inputs.resize(readU32(stream));
Expand All @@ -212,31 +261,29 @@ void readHeader(Program& p, FILE* stream) {
p.ops.resize(readU32(stream));
}

void readInput(Input& wire, FILE* stream) {
void readInput(Input& wire, Reader& stream) {
wire.label = readU64(stream);
wire.bitWidth = readU32(stream);
wire.minBits = readU16(stream);
wire.isPublic = readU16(stream) != 0;
}

void readType(Type& t, FILE* stream) {
void readType(Type& t, Reader& stream) {
t.coeffs = readU64(stream);
t.maxPos = readU64(stream);
t.maxNeg = readU64(stream);
t.minBits = readU64(stream);
}

void readOp(Op& o, FILE* stream) {
void readOp(Op& o, Reader& stream) {
uint64_t bits = readU64(stream);
o.code = (bits >> 0) & 0x0F;
o.type = (bits >> 4) & 0x0FFF;
o.operandA = (bits >> 16) & 0x00FFFFFF;
o.operandB = (bits >> 40) & 0x00FFFFFF;
}

} // namespace

void read(Program& p, FILE* stream) {
void readProgram(Program& p, Reader& stream) {
p.clear();
readHeader(p, stream);
for (size_t i = 0; i < p.inputs.size(); ++i) {
Expand All @@ -253,4 +300,42 @@ void read(Program& p, FILE* stream) {
}
}

struct FileReader : public Reader {
FileReader(FILE* stream) : stream(stream) { check(!stream); }
void read(void* buf, size_t len) override {
check(feof(stream));
size_t got = fread(buf, len, 1, stream);
check(ferror(stream) || !got);
}

private:
FILE* stream = nullptr;
};

struct BufReader : public Reader {
BufReader(const void* buf, size_t len) : buf(buf), remaining(len) { check(!buf); }
void read(void* dest, size_t len) override {
check(remaining < len);
std::memmove(dest, buf, len);
remaining -= len;
buf = &static_cast<const uint8_t*>(buf)[len];
}

private:
const void* buf = nullptr;
size_t remaining = 0;
};

} // namespace

void read(Program& p, FILE* stream) {
FileReader reader(stream);
readProgram(p, reader);
}

void read(Program& p, const void* buf, size_t len) {
BufReader reader(buf, len);
readProgram(p, reader);
}

} // namespace zirgen::BigInt::Bytecode
3 changes: 3 additions & 0 deletions zirgen/Dialect/BigInt/Bytecode/file.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@

namespace zirgen::BigInt::Bytecode {

size_t tell(const Program&);
void write(const Program&, FILE*);
void write(const Program&, void* buf, size_t len);
void read(Program&, FILE*);
void read(Program&, const void* buf, size_t len);

struct IOException : public std::runtime_error {
IOException(const char*, const char*, int, const char*);
Expand Down
21 changes: 21 additions & 0 deletions zirgen/Dialect/BigInt/Bytecode/test/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package(
default_visibility = ["//visibility:public"],
)

load("//bazel/rules/lit:defs.bzl", "glob_lit_tests")

glob_lit_tests()

cc_test(
name = "test",
srcs = [
"test.cpp",
],
deps = [
"//risc0/core/test:gtest_main",
"//zirgen/Dialect/BigInt/Bytecode",
"//zirgen/Dialect/BigInt/IR",
"//zirgen/Dialect/BigInt/Transforms",
"//zirgen/circuit/bigint:lib",
],
)
Loading
Loading