Skip to content

Commit

Permalink
relax hash-kind checks and invert roots
Browse files Browse the repository at this point in the history
  • Loading branch information
rachel-bousfield committed Jul 20, 2022
1 parent 00fc5a7 commit 54c7daa
Show file tree
Hide file tree
Showing 13 changed files with 101 additions and 73 deletions.
2 changes: 1 addition & 1 deletion arbstate/das_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (cert *DataAvailabilityCertificate) RecoverKeyset(
if err != nil {
return nil, err
}
if dastree.Hash(keysetBytes) != cert.KeysetHash {
if !dastree.ValidHash(cert.KeysetHash, keysetBytes) {
return nil, errors.New("keyset hash does not match cert")
}
return DeserializeKeyset(bytes.NewReader(keysetBytes))
Expand Down
2 changes: 1 addition & 1 deletion cmd/datool/datool.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func startRPCClientGetByHash(args []string) error {
}

ctx := context.Background()
message, err := client.GetByHash(ctx, common.BytesToHash((decodedHash)))
message, err := client.GetByHash(ctx, common.BytesToHash(decodedHash))
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion contracts/src/bridge/SequencerInbox.sol
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ contract SequencerInbox is DelegateCallAware, GasRefundEnabled, ISequencerInbox
* @param keysetBytes bytes of the serialized keyset
*/
function setValidKeyset(bytes calldata keysetBytes) external override onlyRollupOwner {
bytes32 ksHash = keccak256(bytes.concat(keccak256(keysetBytes)));
uint256 ksWord = uint256(keccak256(bytes.concat(keccak256(keysetBytes))));
bytes32 ksHash = bytes32(ksWord ^ (1 << 255));

if (dasKeySetInfo[ksHash].isValidKeyset) revert AlreadyValidDASKeyset(ksHash);
dasKeySetInfo[ksHash] = DasKeySetInfo({
Expand Down
2 changes: 1 addition & 1 deletion das/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (a *Aggregator) GetByHash(ctx context.Context, hash common.Hash) ([]byte, e
errorChan <- err
return
}
if dastree.Hash(blob) == hash {
if dastree.ValidHash(hash, blob) {
blobChan <- blob
} else {
errorChan <- fmt.Errorf("DAS (mask %X) returned data that doesn't match requested hash!", d.signersMask)
Expand Down
4 changes: 2 additions & 2 deletions das/chain_fetch_das.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func chainFetchGetByHash(

// try to fetch from the inner DAS
innerRes, err := daReader.GetByHash(ctx, hash)
if err == nil && hash == dastree.Hash(innerRes) {
if err == nil && dastree.ValidHash(hash, innerRes) {
return innerRes, nil
}

Expand All @@ -137,7 +137,7 @@ func chainFetchGetByHash(
return nil, err
}
for iter.Next() {
if hash == dastree.Hash(iter.Event.KeysetBytes) {
if dastree.ValidHash(hash, iter.Event.KeysetBytes) {
cache.put(hash, iter.Event.KeysetBytes)
return iter.Event.KeysetBytes, nil
}
Expand Down
2 changes: 1 addition & 1 deletion das/dasrpc/dasRpcClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (c *DASRPCClient) GetByHash(ctx context.Context, hash common.Hash) ([]byte,
if err := c.clnt.CallContext(ctx, &ret, "das_getByHash", hexutil.Bytes(hash[:])); err != nil {
return nil, err
}
if hash != dastree.Hash(ret) { // check hash because RPC server might be untrusted
if !dastree.ValidHash(hash, ret) { // check hash because RPC server might be untrusted
return nil, arbstate.ErrHashMismatch
}
return ret, nil
Expand Down
88 changes: 45 additions & 43 deletions das/dastree/dastree.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/offchainlabs/nitro/util/arbmath"
"github.com/offchainlabs/nitro/util/colors"
)

const BinSize = 64 * 1024 // 64 kB
const LeafByte = 0xfe
const NodeByte = 0xff

type bytes32 = common.Hash
Expand All @@ -38,25 +36,30 @@ func RecordHash(record func(bytes32, []byte), preimage ...[]byte) bytes32 {
// Intermediate hashes like '*' from above may be recorded via the `record` closure
//

keccord := func(value []byte) bytes32 {
hash := crypto.Keccak256Hash(value)
record(hash, value)
return hash
}

unrolled := []byte{}
for _, slice := range preimage {
unrolled = append(unrolled, slice...)
}
if len(unrolled) == 0 {
single := []byte{LeafByte}
keccak := crypto.Keccak256Hash(single)
record(keccak, single)
return keccak
innerKeccak := keccord([]byte{})
outerKeccak := keccord(innerKeccak.Bytes())
return arbmath.FlipBit(outerKeccak, 0)
}

length := uint32(len(unrolled))
leaves := []node{}
for bin := uint32(0); bin < length; bin += BinSize {
end := arbmath.MinUint32(bin+BinSize, length)
single := append([]byte{LeafByte}, unrolled[bin:end]...)
keccak := crypto.Keccak256Hash(single)
record(keccak, single)
leaves = append(leaves, node{keccak, end - bin})
content := unrolled[bin:end]
innerKeccak := keccord(content)
outerKeccak := keccord(innerKeccak.Bytes())
leaves = append(leaves, node{outerKeccak, end - bin})
}

layer := leaves
Expand All @@ -68,23 +71,21 @@ func RecordHash(record func(bytes32, []byte), preimage ...[]byte) bytes32 {
firstHash := layer[i].hash.Bytes()
otherHash := layer[i+1].hash.Bytes()
sizeUnder := layer[i].size + layer[i+1].size
dataUnder := append([]byte{NodeByte}, firstHash...)
dataUnder := firstHash
dataUnder = append(dataUnder, otherHash...)
dataUnder = append(dataUnder, arbmath.Uint32ToBytes(sizeUnder)...)
parent := node{
crypto.Keccak256Hash(dataUnder),
keccord(dataUnder),
sizeUnder,
}
record(parent.hash, dataUnder)
paired[i/2] = parent
}
if prior%2 == 1 {
paired[after-1] = layer[prior-1]
}
layer = paired
}

return layer[0].hash
return arbmath.FlipBit(layer[0].hash, 0)
}

func Hash(preimage ...[]byte) bytes32 {
Expand All @@ -99,7 +100,7 @@ func HashBytes(preimage ...[]byte) []byte {
func FlatHashToTreeHash(flat bytes32) bytes32 {
// Forms a degenerate dastree that's just a single leaf
// note: the inner preimage may be larger than the 64 kB standard
return crypto.Keccak256Hash(flat[:])
return arbmath.FlipBit(crypto.Keccak256Hash(flat[:]), 0)
}

func ValidHash(hash bytes32, preimage []byte) bool {
Expand All @@ -117,55 +118,47 @@ func Content(root bytes32, oracle func(bytes32) []byte) ([]byte, error) {
// 3. Only the committee can produce trees unwrapped by this function
//

start := arbmath.FlipBit(root, 0)
total := uint32(0)
upper := oracle(root)
switch {
case len(upper) > 0 && upper[0] == LeafByte:
return upper[1:], nil
case len(upper) == 69 && upper[0] == NodeByte:
total = binary.BigEndian.Uint32(upper[65:])
upper := oracle(start)
switch len(upper) {
case 32:
return oracle(common.BytesToHash(upper)), nil
case 68:
total = binary.BigEndian.Uint32(upper[64:])
default:
return nil, fmt.Errorf("invalid root with preimage of size %v: %v %v", len(upper), root, upper)
}

stack := []node{{hash: root, size: total}}
preimage := []byte{}
leaves := []node{}
stack := []node{{hash: start, size: total}}

for len(stack) > 0 {
place := stack[len(stack)-1]
stack = stack[:len(stack)-1]

colors.PrintYellow("here ", place.hash, place.size)

under := oracle(place.hash)

if len(under) == 0 || (under[0] == NodeByte && len(under) != 69) {
return nil, fmt.Errorf("invalid node for hash %v: %v", place.hash, under)
}

kind := under[0]
content := under[1:]

switch kind {
case LeafByte:
if len(content) != int(place.size) {
return nil, fmt.Errorf("leaf has a badly sized bin: %v vs %v", len(under), place.size)
switch len(under) {
case 32:
leaf := node{
hash: common.BytesToHash(under),
size: place.size,
}
preimage = append(preimage, content...)
case NodeByte:
count := binary.BigEndian.Uint32(content[64:])
leaves = append(leaves, leaf)
case 68:
count := binary.BigEndian.Uint32(under[64:])
power := uint32(arbmath.NextOrCurrentPowerOf2(uint64(count)))

if place.size != count {
return nil, fmt.Errorf("invalid size data: %v vs %v for %v", count, place.size, under)
}

prior := node{
hash: common.BytesToHash(content[:32]),
hash: common.BytesToHash(under[:32]),
size: power / 2,
}
after := node{
hash: common.BytesToHash(content[32:64]),
hash: common.BytesToHash(under[32:64]),
size: count - power/2,
}

Expand All @@ -176,6 +169,15 @@ func Content(root bytes32, oracle func(bytes32) []byte) ([]byte, error) {
}
}

preimage := []byte{}
for i, leaf := range leaves {
bin := oracle(leaf.hash)
if len(bin) != int(leaf.size) {
return nil, fmt.Errorf("leaf %v has an incorrectly sized bin: %v vs %v", i, len(bin), leaf.size)
}
preimage = append(preimage, bin...)
}

// Check the hash matches. Given the size data this should never fail but we'll check anyway
if Hash(preimage) != root {
return nil, fmt.Errorf("preimage not canonically hashed")
Expand Down
2 changes: 1 addition & 1 deletion das/fallback_storage_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (f *FallbackStorageService) GetByHash(ctx context.Context, key common.Hash)
if err != nil {
return nil, err
}
if key == dastree.Hash(data) {
if dastree.ValidHash(key, data) {
putErr := f.StorageService.Put(
ctx, data, arbmath.SaturatingUAdd(uint64(time.Now().Unix()), f.backupRetentionSeconds),
)
Expand Down
2 changes: 1 addition & 1 deletion das/restful_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (c *RestfulDasClient) GetByHash(ctx context.Context, hash common.Hash) ([]b
if err != nil {
return nil, err
}
if hash != dastree.Hash(decodedBytes) {
if !dastree.ValidHash(hash, decodedBytes) {
return nil, arbstate.ErrHashMismatch
}

Expand Down
2 changes: 1 addition & 1 deletion das/simple_das_reader_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func (a *SimpleDASReaderAggregator) tryGetByHash(
start := time.Now()
result, err := reader.GetByHash(ctx, hash)
if err == nil {
if dastree.Hash(result) == hash {
if dastree.ValidHash(hash, result) {
stat.success = true
} else {
err = fmt.Errorf("SimpleDASReaderAggregator got result from reader(%v) not matching hash", reader)
Expand Down
37 changes: 37 additions & 0 deletions util/arbmath/bits.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2021-2022, Offchain Labs, Inc.
// For license information, see https://github.com/nitro/blob/master/LICENSE

package arbmath

import (
"encoding/binary"

"github.com/ethereum/go-ethereum/common"
)

type bytes32 = common.Hash

// flips the nth bit in an ethereum word, starting from the left
func FlipBit(data bytes32, bit byte) bytes32 {
data[bit/8] ^= 1 << (7 - bit%8)
return data
}

// the number of eth-words needed to store n bytes
func WordsForBytes(nbytes uint64) uint64 {
return (nbytes + 31) / 32
}

// casts a uint64 to its big-endian representation
func UintToBytes(value uint64) []byte {
result := make([]byte, 8)
binary.BigEndian.PutUint64(result, value)
return result
}

// casts a uint32 to its big-endian representation
func Uint32ToBytes(value uint32) []byte {
result := make([]byte, 4)
binary.BigEndian.PutUint32(result, value)
return result
}
20 changes: 0 additions & 20 deletions util/arbmath/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package arbmath

import (
"encoding/binary"
"math"
"math/big"
"math/bits"
Expand Down Expand Up @@ -296,25 +295,6 @@ func SaturatingCastToUint(value *big.Int) uint64 {
return value.Uint64()
}

// the number of eth-words needed to store n bytes
func WordsForBytes(nbytes uint64) uint64 {
return (nbytes + 31) / 32
}

// casts a uint64 to its big-endian representation
func UintToBytes(value uint64) []byte {
result := make([]byte, 8)
binary.BigEndian.PutUint64(result, value)
return result
}

// casts a uint32 to its big-endian representation
func Uint32ToBytes(value uint32) []byte {
result := make([]byte, 4)
binary.BigEndian.PutUint32(result, value)
return result
}

// Return the Maclaurin series approximation of e^x, where x is denominated in basis points.
// This quartic polynomial will underestimate e^x by about 5% as x approaches 20000 bips.
func ApproxExpBasisPoints(value Bips) Bips {
Expand Down
8 changes: 8 additions & 0 deletions util/pretty/pretty_printing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ func FirstFewBytes(b []byte) string {
}
}

func PrettyBytes(b []byte) string {
hex := common.Bytes2Hex(b)
if len(hex) > 24 {
return fmt.Sprintf("%v...", hex[:24])
}
return hex
}

func PrettyHash(hash common.Hash) string {
return FirstFewBytes(hash.Bytes())
}
Expand Down

0 comments on commit 54c7daa

Please sign in to comment.