ZIR-163: Modify layout to not use 'pinning' concept #28

Merged (1 commit) on Sep 11, 2024
182 changes: 104 additions & 78 deletions zirgen/dsl/passes/GenerateLayout.cpp
@@ -44,88 +44,105 @@ using namespace zirgen::ZStruct;
* that if a vertex is visited twice (i.e. aliased) that it is assigned to the
* same columns.
*
* The second rule is handled by the AllocationTable, which keeps track of which
* columns are allocated in the "current scope." We push a new scope when
* Additionally, to allow the depth first traversal to allocate via a 'bump allocator'
* type of approach, we precompute for each aliased component its innermost
* 'common' component, i.e. the deepest component which all aliased copies live
* inside of. When we reach this component, we allocate space for any aliased
* descendants prior to continuing our normal depth first traversal, knowing that
* the 'memo' will resolve those descendants when the traversal reaches them.
*
* Column reuse across mux arms is handled by tracking the largest column allocated in
* the "current scope." We push a new scope when
* visiting the arms of a mux, such that we can "pop" it and reuse those columns
* on the next arm. Afterwards, we mark any columns used by any mux arm as used,
* pursuant to the third rule. Thus, a mux typically needs as many columns as
* its largest arm (keep reading).
*
* There is an extra complication with layout aliases around muxes: when layouts
* are aliased between mux arms, those layouts must be placed in the exact same
* columns, regardless of "when" they are visited relative to other layouts in
* their respective mux arms. For this reason, it is necessary to reserve those
* columns across the arms of the mux where they are shared -- which is referred
* to here as "pinning."
*
* Note: Currently, only argument components and mux supers are aliased, so we
* manually pin them rather than using the LayoutDAGAnalysis to figure this out,
* and we make the simplifying but suboptimal decision to pin them all the way
* up to the root layout since it seems to work relatively well with our own
* circuits. This should be generalized when supporting manual layout aliasing.
* its largest arm.
*/
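
For illustration only (not part of this diff, and the IDs and paths below are made up), the innermost 'common' component of an aliased vertex is the last entry of the longest prefix shared by every root-to-vertex path. A minimal standalone sketch, using a simplified copy of the sharedPrefix helper that appears later in this diff:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Vec = std::vector<uint32_t>;

// Longest common prefix of two root-to-vertex paths (simplified sharedPrefix).
static Vec sharedPrefix(const Vec& lhs, const Vec& rhs) {
  if (lhs.empty())
    return rhs; // first occurrence: the whole path is the running prefix
  size_t i = 0;
  while (i < std::min(lhs.size(), rhs.size()) && lhs[i] == rhs[i])
    i++;
  return Vec(lhs.begin(), lhs.begin() + i);
}

int main() {
  // Hypothetical depth first IDs: root = 0, mux = 2, armA = 3, armB = 5,
  // and an argument component with ID 7 that both arms alias.
  Vec viaArmA = {0, 2, 3, 7};
  Vec viaArmB = {0, 2, 5, 7};
  Vec common = sharedPrefix(viaArmA, viaArmB); // {0, 2}
  // The prefix ends at the mux (ID 2), so the mux is the innermost common
  // component and preallocates the argument before descending into its arms.
  assert(common.back() == 2);
  return 0;
}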

namespace zirgen {
namespace dsl {
namespace {

class AllocationTable {
public:
AllocationTable() : parent(nullptr), storage(256, /*set=*/false) {}
AllocationTable(AllocationTable* parent) : parent(parent), storage(parent->storage) {}
// For components that are aliased into multiple locations, finds the most
// specific parent element and records it. Then, when generating the concrete
// layout, we can 'preallocate' any such descendants. Note, we need to give
// every component a unique ID (in the order they appear in the depth first
// traversal) rather than only using maps, since otherwise pointer keys cause
// non-deterministic ordering.
class FindPreallocs {
using Ptr = std::shared_ptr<LayoutDAG>;
using Vec = std::vector<uint32_t>;
using MapToVec = std::map<uint32_t, Vec>;
// Result: for each abstract component, the aliased descendants to 'prealloc'
using Result = std::map<Ptr, std::vector<Ptr>>;

// Return the index of the first k consecutive unallocated columns, and mark
// them as allocated. If pinned, also mark them as allocated in the parent.
size_t allocate(size_t k, bool pinned) {
int n = 0;
while (!canAllocateContiguously(n, k)) {
n = nextIndex(n);
}
storage.set(n, n + k);
AllocationTable* ancestor = parent;
while (pinned && ancestor) {
ancestor->storage.set(n, n + k);
ancestor = ancestor->parent;
public:
Result run(const std::shared_ptr<LayoutDAG>& abstract) {
Vec prefix;
computeSharedPrefixes(abstract, prefix);
Result ret;
for (const auto& kvp : prefixes) {
// If the shared prefix ends at the component itself, it is reached by a single path (not aliased)
if (kvp.first == kvp.second.back()) {
continue;
}
// Otherwise, record it under its innermost common component
ret[idToPtr[kvp.second.back()]].push_back(idToPtr[kvp.first]);
}

return n;
}

// If a column is allocated in either this or other, mark it as allocated
AllocationTable& operator|=(const AllocationTable& other) {
storage |= other.storage;
return *this;
return ret;
}

private:
// True iff k columns starting at n are all unallocated
bool canAllocateContiguously(int n, size_t k) {
// BitVector::find_first_in returns the index of the first set bit in a
// range, or -1 if they're all unset. If they're all unset, all k of them
// are unallocated.
return storage.find_first_in(n, n + k, /*set=*/true) == -1;
}

// Return the index of the next unallocated column, resizing storage if necessary
int nextIndex(int n) {
int next = storage.find_next_unset(n);
size_t capacity = storage.getBitCapacity();
assert(next >= -1);
if (next == -1 || (size_t)next >= capacity) {
storage.resize(2 * capacity);
// Give IDs to each entry
std::vector<Ptr> idToPtr;
std::map<Ptr, uint32_t> ptrToId;
// For each entry, the longest prefix shared by every path that reaches it
MapToVec prefixes;

AllocationTable* ancestor = parent;
while (ancestor) {
ancestor->storage.resize(2 * capacity);
ancestor = ancestor->parent;
static Vec sharedPrefix(const Vec& lhs, const Vec& rhs) {
if (lhs.size() == 0) { // Handle initialization case
return rhs;
}
assert(lhs.size() > 0 && rhs.size() > 0);
assert(lhs[0] == rhs[0]);
size_t i = 0;
while (i < std::min(lhs.size(), rhs.size())) {
if (lhs[i] != rhs[i]) {
break;
}
next = storage.find_next_unset(n);
i++;
}
return next;
return Vec(lhs.begin(), lhs.begin() + i);
}

AllocationTable* parent;
BitVector storage;
void computeSharedPrefixes(const std::shared_ptr<LayoutDAG>& abstract, Vec& prefix) {
// Compute ID for element
uint32_t id;
if (ptrToId.count(abstract)) {
id = ptrToId[abstract];
} else {
ptrToId[abstract] = idToPtr.size();
id = idToPtr.size();
idToPtr.push_back(abstract);
}
// Add to the current prefix path and compute the shared prefix
prefix.push_back(id);
prefixes[id] = sharedPrefix(prefixes[id], prefix);
// Descend for recursive types
if (const auto* arr = std::get_if<AbstractArray>(abstract.get())) {
for (auto element : arr->elements) {
computeSharedPrefixes(element, prefix);
}
} else if (const auto* str = std::get_if<AbstractStructure>(abstract.get())) {
for (auto field : str->fields) {
computeSharedPrefixes(field.second, prefix);
}
} else if (const auto* ref = std::get_if<std::shared_ptr<LayoutDAG>>(abstract.get())) {
computeSharedPrefixes(*ref, prefix);
}
// Pop from path
prefix.pop_back();
}
};
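
A minimal self-contained sketch of the prealloc-then-memoize pattern that FindPreallocs enables (the Node type and helper are hypothetical stand-ins, not the real LayoutDAG API): the innermost common component materializes an aliased descendant up front, and the memo then hands every later visit the same columns.

#include <cassert>
#include <cstddef>
#include <map>

struct Node {}; // hypothetical stand-in for a LayoutDAG vertex

// Bump-allocate one column for a leaf, or return the memoized column if the
// leaf was already materialized (e.g. because it was preallocated).
size_t materializeLeaf(const Node* leaf,
                       std::map<const Node*, size_t>& memo,
                       size_t& allocator) {
  auto it = memo.find(leaf);
  if (it != memo.end())
    return it->second;         // aliased visit: reuse the memoized column
  size_t column = allocator++; // first visit: take the next free column
  memo[leaf] = column;
  return column;
}

int main() {
  std::map<const Node*, size_t> memo;
  size_t allocator = 0;
  Node aliased; // a leaf referenced from two different mux arms

  // The innermost common component preallocates the aliased leaf...
  size_t preallocated = materializeLeaf(&aliased, memo, allocator);
  // ...so when each arm later reaches the leaf, the memo returns the very
  // same column regardless of what else that arm allocated first.
  size_t fromArmA = materializeLeaf(&aliased, memo, allocator);
  size_t fromArmB = materializeLeaf(&aliased, memo, allocator);
  assert(preallocated == fromArmA && fromArmA == fromArmB);
  return 0;
}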

struct LayoutGenerator {
@@ -137,59 +154,68 @@ struct LayoutGenerator
if (!component.getLayout())
return Attribute();

// llvm::errs() << component << "\n";
Memo memo;
AllocationTable allocator;
auto layout = solver.lookupState<LayoutDAGAnalysis::Element>(component.getLayout());
return materialize(layout->getValue().get(), memo, allocator);
FindPreallocs findPreallocs;
Preallocs preallocs = findPreallocs.run(layout->getValue().get());
size_t allocator = 0;
return materialize(layout->getValue().get(), memo, allocator, preallocs);
}

private:
// A memo of previously generated abstract layouts
using Preallocs = std::map<std::shared_ptr<LayoutDAG>, std::vector<std::shared_ptr<LayoutDAG>>>;
using Memo = DenseMap<LayoutDAG*, Attribute>;

// Materialize a concrete layout attribute from an abstract layout
Attribute materialize(const std::shared_ptr<LayoutDAG>& abstract,
Memo& memo,
AllocationTable& allocator,
bool pinned = false) {
size_t& allocator,
const Preallocs& preallocs) {
if (memo.contains(abstract.get())) {
return memo.at(abstract.get());
}
auto it = preallocs.find(abstract);
if (it != preallocs.end()) {
for (auto ptr : it->second) {
// Allocate and memoize the 'preallocs' so later visits reuse the same columns
materialize(ptr, memo, allocator, preallocs);
}
}

Attribute attr;
if (const auto* reg = std::get_if<AbstractRegister>(abstract.get())) {
// Allocate multiple columns for extension field elements
size_t size = reg->type.getElement().getFieldK();
size_t index = allocator.allocate(size, pinned);
size_t index = allocator;
allocator += reg->type.getElement().getFieldK();
attr = RefAttr::get(reg->type.getContext(), index, reg->type);
} else if (const auto* arr = std::get_if<AbstractArray>(abstract.get())) {
SmallVector<Attribute, 4> elements;
for (auto element : arr->elements) {
elements.push_back(materialize(element, memo, allocator, pinned));
elements.push_back(materialize(element, memo, allocator, preallocs));
}
attr = ArrayAttr::get(arr->type.getContext(), elements);
} else if (const auto* str = std::get_if<AbstractStructure>(abstract.get())) {
SmallVector<NamedAttribute> fields;
if (str->type.getKind() == LayoutKind::Mux) {
AllocationTable finalAllocator = allocator;
size_t finalAllocator = allocator;
for (auto field : str->fields) {
AllocationTable armAllocator(&allocator);
bool armPinned = pinned || field.first == "@super";
size_t armAllocator = allocator;
fields.emplace_back(field.first,
materialize(field.second, memo, armAllocator, armPinned));
finalAllocator |= armAllocator;
materialize(field.second, memo, armAllocator, preallocs));
finalAllocator = std::max(finalAllocator, armAllocator);
}
allocator = finalAllocator;
} else {
bool strPinned = pinned || (str->type.getKind() == LayoutKind::Argument);
for (auto field : str->fields) {
fields.emplace_back(field.first, materialize(field.second, memo, allocator, strPinned));
fields.emplace_back(field.first, materialize(field.second, memo, allocator, preallocs));
}
}
auto members = DictionaryAttr::get(str->type.getContext(), fields);
attr = StructAttr::get(str->type.getContext(), members, str->type);
} else if (const auto* ref = std::get_if<std::shared_ptr<LayoutDAG>>(abstract.get())) {
attr = materialize(*ref, memo, allocator, pinned);
attr = materialize(*ref, memo, allocator, preallocs);
} else {
llvm_unreachable("bad variant");
}
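
As a quick numeric sketch of the mux handling above (illustrative only; the helper and arm sizes are hypothetical, not from this diff), every arm restarts from the mux's offset and the mux as a whole advances the plain bump allocator by the size of its largest arm:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: each arm size stands in for however many columns that
// arm's layout would bump the allocator by.
size_t materializeMux(size_t& allocator, const std::vector<size_t>& armSizes) {
  size_t finalAllocator = allocator;
  for (size_t armSize : armSizes) {
    size_t armAllocator = allocator;  // every arm restarts at the mux's offset
    armAllocator += armSize;          // the arm bumps its private copy
    finalAllocator = std::max(finalAllocator, armAllocator);
  }
  allocator = finalAllocator;         // the mux advances by its largest arm
  return allocator;
}

int main() {
  size_t allocator = 10;                // columns 0..9 already in use
  materializeMux(allocator, {4, 7, 2}); // arms needing 4, 7 and 2 columns
  assert(allocator == 17);              // 10 + max(4, 7, 2)
  return 0;
}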