diff --git a/packages/contracts/contracts/libs/FeistelShuffle.sol b/packages/contracts/contracts/libs/FeistelShuffle.sol new file mode 100644 index 000000000..1450fd27a --- /dev/null +++ b/packages/contracts/contracts/libs/FeistelShuffle.sol @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT +pragma solidity >=0.8.10; + +/// @title FeistelShuffle (Reference) +/// @author kevincharm +/// @notice Lazy shuffling using generalised Feistel ciphers. +library FeistelShuffle { + /// @notice Integer sqrt (rounding down), adapted from uniswap/v2-core + /// @param s integer to sqrt + /// @return z sqrt(s), rounding to zero + function sqrt(uint256 s) private pure returns (uint256 z) { + if (s > 3) { + z = s; + uint256 x = s / 2 + 1; + while (x < z) { + z = x; + x = (s / x + x) / 2; + } + } else if (s != 0) { + z = 1; + } + } + + /// @notice Feistel round function + /// @param x index of element in the list + /// @param i hash iteration index + /// @param seed random seed + /// @param modulus cardinality of list + /// @return hashed hash of x (mod `modulus`) + function f( + uint256 x, + uint256 i, + uint256 seed, + uint256 modulus + ) private pure returns (uint256 hashed) { + return uint256(keccak256(abi.encodePacked(x, i, seed, modulus))); + } + + /// @notice Next perfect square + /// @param n Number to get next perfect square of, unless it's already a + /// perfect square. + function nextPerfectSquare(uint256 n) private pure returns (uint256) { + uint256 sqrtN = sqrt(n); + if (sqrtN ** 2 == n) { + return n; + } + return (sqrtN + 1) ** 2; + } + + /// @notice Compute a Feistel shuffle mapping for index `x` + /// @param x index of element in the list + /// @param domain Number of elements in the list + /// @param seed Random seed; determines the permutation + /// @param rounds Number of Feistel rounds to perform + /// @return resulting shuffled index + function shuffle( + uint256 x, + uint256 domain, + uint256 seed, + uint256 rounds + ) internal pure returns (uint256) { + require(domain != 0, "modulus must be > 0"); + require(x < domain, "x too large"); + require((rounds & 1) == 0, "rounds must be even"); + + uint256 h = sqrt(nextPerfectSquare(domain)); + do { + uint256 L = x % h; + uint256 R = x / h; + for (uint256 i = 0; i < rounds; ++i) { + uint256 nextR = (L + f(R, i, seed, domain)) % h; + L = R; + R = nextR; + } + x = h * R + L; + } while (x >= domain); + return x; + } + + /// @notice Compute the inverse Feistel shuffle mapping for the shuffled + /// index `xPrime` + /// @param xPrime shuffled index of element in the list + /// @param domain Number of elements in the list + /// @param seed Random seed; determines the permutation + /// @param rounds Number of Feistel rounds that was performed in the + /// original shuffle. + /// @return resulting shuffled index + function deshuffle( + uint256 xPrime, + uint256 domain, + uint256 seed, + uint256 rounds + ) internal pure returns (uint256) { + require(domain != 0, "modulus must be > 0"); + require(xPrime < domain, "x too large"); + require((rounds & 1) == 0, "rounds must be even"); + + uint256 h = sqrt(nextPerfectSquare(domain)); + do { + uint256 L = xPrime % h; + uint256 R = xPrime / h; + for (uint256 i = 0; i < rounds; ++i) { + uint256 nextL = (R + + h - + (f(L, rounds - i - 1, seed, domain) % h)) % h; + R = L; + L = nextL; + } + xPrime = h * R + L; + } while (xPrime >= domain); + return xPrime; + } +} diff --git a/packages/contracts/contracts/test/FeistelShuffleConsumer.sol b/packages/contracts/contracts/test/FeistelShuffleConsumer.sol new file mode 100644 index 000000000..488445ae9 --- /dev/null +++ b/packages/contracts/contracts/test/FeistelShuffleConsumer.sol @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +pragma solidity >=0.8.10; + +import {FeistelShuffle} from "../libs/FeistelShuffle.sol"; +import {FeistelShuffleOptimised} from "../libs/FeistelShuffleOptimised.sol"; + +contract FeistelShuffleConsumer { + function shuffle( + uint256 x, + uint256 domain, + uint256 seed, + uint256 rounds + ) public pure returns (uint256) { + return FeistelShuffle.shuffle(x, domain, seed, rounds); + } + + function deshuffle( + uint256 xPrime, + uint256 domain, + uint256 seed, + uint256 rounds + ) public pure returns (uint256) { + return FeistelShuffle.deshuffle(xPrime, domain, seed, rounds); + } + + function shuffle__OPT( + uint256 x, + uint256 domain, + uint256 seed, + uint256 rounds + ) public pure returns (uint256) { + return FeistelShuffleOptimised.shuffle(x, domain, seed, rounds); + } + + function deshuffle__OPT( + uint256 xPrime, + uint256 domain, + uint256 seed, + uint256 rounds + ) public pure returns (uint256) { + return FeistelShuffleOptimised.deshuffle(xPrime, domain, seed, rounds); + } +} diff --git a/packages/contracts/package.json b/packages/contracts/package.json index 4fea51b1c..e98a57d04 100644 --- a/packages/contracts/package.json +++ b/packages/contracts/package.json @@ -41,6 +41,7 @@ }, "devDependencies": { "@ethereum-waffle/chai": "^3.4.3", + "@kevincharm/gfc-fpe": "^1.1.0", "@nomicfoundation/hardhat-verify": "^2.0.5", "@nomiclabs/hardhat-ethers": "^2.0.2", "@nomiclabs/hardhat-waffle": "^2.0.1", diff --git a/packages/contracts/test/contracts/FeistelShuffle.test.ts b/packages/contracts/test/contracts/FeistelShuffle.test.ts new file mode 100644 index 000000000..b2e620232 --- /dev/null +++ b/packages/contracts/test/contracts/FeistelShuffle.test.ts @@ -0,0 +1,175 @@ +import { expect } from 'chai' +import { ethers } from 'hardhat' +import { FeistelShuffleConsumer, FeistelShuffleConsumer__factory } from '../../build' +import { SignerWithAddress } from '@nomiclabs/hardhat-ethers/signers' +import { BigNumber, BigNumberish } from 'ethers' +import { randomBytes } from 'crypto' +import * as tsFeistel from '@kevincharm/gfc-fpe' +import { solidityKeccak256 } from 'ethers/lib/utils' + +const f = (R: bigint, i: bigint, seed: bigint, domain: bigint) => + BigNumber.from(solidityKeccak256(['uint256', 'uint256', 'uint256', 'uint256'], [R, i, seed, domain])).toBigInt() + +describe('FeistelShuffle', () => { + let deployer: SignerWithAddress + let feistelShuffle: FeistelShuffleConsumer + let indices: number[] + let seed: string + before(async () => { + const signers = await ethers.getSigners() + deployer = signers[0] + feistelShuffle = await new FeistelShuffleConsumer__factory(deployer).deploy() + indices = Array(100) + .fill(0) + .map((_, i) => i) + seed = ethers.utils.defaultAbiCoder.encode(['bytes32'], ['0x' + randomBytes(32).toString('hex')]) + }) + + function assertSetEquality(left: number[], right: number[]) { + const set = new Set() + for (const l of left) { + set.add(l) + } + expect(set.size).to.equal(left.length) + for (const r of right) { + expect(set.delete(r)).to.equal(true, `${r} exists in left`) + } + expect(set.size).to.equal(0) + } + + /** + * Same as calling `feistelShuffle.shuffle(...)`, but additionally + * checks the return value against the reference implementation and asserts + * they're equal. + * + * @param x + * @param domain + * @param seed + * @param rounds + * @returns + */ + async function checkedShuffle(x: BigNumberish, domain: BigNumberish, seed: BigNumberish, rounds: number) { + const contractRefAnswer = await feistelShuffle.shuffle(x, domain, seed, rounds) + const refAnswer = await tsFeistel.encrypt( + BigNumber.from(x).toBigInt(), + BigNumber.from(domain).toBigInt(), + BigNumber.from(seed).toBigInt(), + BigNumber.from(rounds).toBigInt(), + f + ) + expect(contractRefAnswer).to.equal(refAnswer) + // Compute x from x' using the inverse function + expect(await feistelShuffle.deshuffle(contractRefAnswer, domain, seed, rounds)).to.eq(x) + expect(await feistelShuffle.deshuffle__OPT(contractRefAnswer, domain, seed, rounds)).to.eq(x) + return contractRefAnswer + } + + it('should create permutation with FeistelShuffle', async () => { + const rounds = 4 + const shuffled: BigNumber[] = [] + for (let i = 0; i < indices.length; i++) { + const s = await feistelShuffle.shuffle__OPT(i, indices.length, seed, rounds) + shuffled.push(s) + } + assertSetEquality( + indices, + shuffled.map((s) => s.toNumber()) + ) + }) + + it('should match reference implementation', async () => { + const rounds = 4 + const shuffled: number[] = [] + for (const i of indices) { + // Test both unoptimised & optimised versions + const s = await checkedShuffle(i, indices.length, seed, rounds) + // Test that optimised Yul version spits out the same output + const sOpt = await feistelShuffle.shuffle__OPT(i, indices.length, seed, rounds) + expect(s).to.equal(sOpt) + shuffled.push(sOpt.toNumber()) + } + + const specOutput: number[] = [] + for (const index of indices) { + const xPrime = await tsFeistel.encrypt( + BigInt(index), + BigInt(indices.length), + BigNumber.from(seed).toBigInt(), + BigNumber.from(rounds).toBigInt(), + f + ) + specOutput.push(Number(xPrime)) + } + + expect(shuffled).to.deep.equal(specOutput) + }) + + it('should revert if x >= modulus', async () => { + const rounds = 4 + // on boundary + await expect(feistelShuffle.shuffle(100, 100, seed, rounds)).to.be.revertedWith('x too large') + await expect(feistelShuffle.shuffle__OPT(100, 100, seed, rounds)).to.be.reverted + // past boundary + await expect(feistelShuffle.shuffle(101, 100, seed, rounds)).to.be.revertedWith('x too large') + await expect(feistelShuffle.shuffle__OPT(101, 100, seed, rounds)).to.be.reverted + }) + + it('should revert if modulus == 0', async () => { + const rounds = 4 + await expect(feistelShuffle.shuffle(0, 0, seed, rounds)).to.be.revertedWith('modulus must be > 0') + await expect(feistelShuffle.shuffle__OPT(0, 0, seed, rounds)).to.be.reverted + }) + + it('should handle small modulus', async () => { + // This is mainly to ensure the sqrt / nextPerfectSquare functions are correct + const rounds = 4 + + // list size of 1 + let modulus = 1 + const permutedOneRef = await checkedShuffle(0, modulus, seed, rounds) + expect(permutedOneRef).to.equal(0) + expect(permutedOneRef).to.equal(await feistelShuffle.shuffle__OPT(0, modulus, seed, rounds)) + + // list size of 2 + modulus = 2 + const shuffledTwo = new Set() + for (let i = 0; i < modulus; i++) { + shuffledTwo.add((await checkedShuffle(i, modulus, seed, rounds)).toNumber()) + } + // |shuffledSet| = modulus + expect(shuffledTwo.size).to.equal(modulus) + // set equality with optimised version + for (let i = 0; i < modulus; i++) { + shuffledTwo.delete((await feistelShuffle.shuffle__OPT(i, modulus, seed, rounds)).toNumber()) + } + expect(shuffledTwo.size).to.equal(0) + + // list size of 3 + modulus = 3 + const shuffledThree = new Set() + for (let i = 0; i < modulus; i++) { + shuffledThree.add((await checkedShuffle(i, modulus, seed, rounds)).toNumber()) + } + // |shuffledSet| = modulus + expect(shuffledThree.size).to.equal(modulus) + // set equality with optimised version + for (let i = 0; i < modulus; i++) { + shuffledThree.delete((await feistelShuffle.shuffle__OPT(i, modulus, seed, rounds)).toNumber()) + } + expect(shuffledThree.size).to.equal(0) + + // list size of 4 (past boundary) + modulus = 4 + const shuffledFour = new Set() + for (let i = 0; i < modulus; i++) { + shuffledFour.add((await checkedShuffle(i, modulus, seed, rounds)).toNumber()) + } + // |shuffledSet| = modulus + expect(shuffledFour.size).to.equal(modulus) + // set equality with optimised version + for (let i = 0; i < modulus; i++) { + shuffledFour.delete((await feistelShuffle.shuffle__OPT(i, modulus, seed, rounds)).toNumber()) + } + expect(shuffledFour.size).to.equal(0) + }) +}) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9ffce267c..5e56353ba 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -17,6 +17,9 @@ importers: '@ethereum-waffle/chai': specifier: ^3.4.3 version: 3.4.4 + '@kevincharm/gfc-fpe': + specifier: ^1.1.0 + version: 1.1.0 '@nomicfoundation/hardhat-verify': specifier: ^2.0.5 version: 2.0.5(hardhat@2.19.4) @@ -1058,6 +1061,10 @@ packages: wrap-ansi-cjs: /wrap-ansi@7.0.0 dev: true + /@kevincharm/gfc-fpe@1.1.0: + resolution: {integrity: sha512-Kg9wOa3T1ttOVxqeomid3zqWcqXVWZA7l1nXhOLrn/NTKmaTIM873rComzuG/b3Xv1CogLTZgK/+HeVf3Fn0fg==} + dev: true + /@lit-labs/ssr-dom-shim@1.2.0: resolution: {integrity: sha512-yWJKmpGE6lUURKAaIltoPIE/wrbY3TEkqQt+X0m+7fQNnAv0keydnYvbiJFP1PnMhizmIWRWOG5KLhYyc/xl+g==} dev: false