From 0abf0f93dc5443bce544ab8f7679f551bd4ed02d Mon Sep 17 00:00:00 2001
From: Kevin Charm <kevin@kevincharm.com>
Date: Mon, 29 Apr 2024 15:10:37 +0200
Subject: [PATCH] add FeistelShuffle reference contract, add
 FeistelShuffle/FeistelShuffleOptimised test coverage

---
 .../contracts/libs/FeistelShuffle.sol         | 113 +++++++++++
 .../contracts/test/FeistelShuffleConsumer.sol |  43 +++++
 packages/contracts/package.json               |   1 +
 .../test/contracts/FeistelShuffle.test.ts     | 175 ++++++++++++++++++
 pnpm-lock.yaml                                |   7 +
 5 files changed, 339 insertions(+)
 create mode 100644 packages/contracts/contracts/libs/FeistelShuffle.sol
 create mode 100644 packages/contracts/contracts/test/FeistelShuffleConsumer.sol
 create mode 100644 packages/contracts/test/contracts/FeistelShuffle.test.ts

diff --git a/packages/contracts/contracts/libs/FeistelShuffle.sol b/packages/contracts/contracts/libs/FeistelShuffle.sol
new file mode 100644
index 000000000..1450fd27a
--- /dev/null
+++ b/packages/contracts/contracts/libs/FeistelShuffle.sol
@@ -0,0 +1,113 @@
+// SPDX-License-Identifier: MIT
+pragma solidity >=0.8.10;
+
+/// @title FeistelShuffle (Reference)
+/// @author kevincharm
+/// @notice Lazy shuffling using generalised Feistel ciphers.
+library FeistelShuffle {
+    /// @notice Integer sqrt (rounding down), adapted from uniswap/v2-core
+    /// @param s integer to sqrt
+    /// @return z sqrt(s), rounding to zero
+    function sqrt(uint256 s) private pure returns (uint256 z) {
+        if (s > 3) {
+            z = s;
+            uint256 x = s / 2 + 1;
+            while (x < z) {
+                z = x;
+                x = (s / x + x) / 2;
+            }
+        } else if (s != 0) {
+            z = 1;
+        }
+    }
+
+    /// @notice Feistel round function
+    /// @param x index of element in the list
+    /// @param i hash iteration index
+    /// @param seed random seed
+    /// @param modulus cardinality of list
+    /// @return hashed hash of x (mod `modulus`)
+    function f(
+        uint256 x,
+        uint256 i,
+        uint256 seed,
+        uint256 modulus
+    ) private pure returns (uint256 hashed) {
+        return uint256(keccak256(abi.encodePacked(x, i, seed, modulus)));
+    }
+
+    /// @notice Next perfect square
+    /// @param n Number to get next perfect square of, unless it's already a
+    ///     perfect square.
+    function nextPerfectSquare(uint256 n) private pure returns (uint256) {
+        uint256 sqrtN = sqrt(n);
+        if (sqrtN ** 2 == n) {
+            return n;
+        }
+        return (sqrtN + 1) ** 2;
+    }
+
+    /// @notice Compute a Feistel shuffle mapping for index `x`
+    /// @param x index of element in the list
+    /// @param domain Number of elements in the list
+    /// @param seed Random seed; determines the permutation
+    /// @param rounds Number of Feistel rounds to perform
+    /// @return resulting shuffled index
+    function shuffle(
+        uint256 x,
+        uint256 domain,
+        uint256 seed,
+        uint256 rounds
+    ) internal pure returns (uint256) {
+        require(domain != 0, "modulus must be > 0");
+        require(x < domain, "x too large");
+        require((rounds & 1) == 0, "rounds must be even");
+
+        uint256 h = sqrt(nextPerfectSquare(domain));
+        do {
+            uint256 L = x % h;
+            uint256 R = x / h;
+            for (uint256 i = 0; i < rounds; ++i) {
+                uint256 nextR = (L + f(R, i, seed, domain)) % h;
+                L = R;
+                R = nextR;
+            }
+            x = h * R + L;
+        } while (x >= domain);
+        return x;
+    }
+
+    /// @notice Compute the inverse Feistel shuffle mapping for the shuffled
+    ///     index `xPrime`
+    /// @param xPrime shuffled index of element in the list
+    /// @param domain Number of elements in the list
+    /// @param seed Random seed; determines the permutation
+    /// @param rounds Number of Feistel rounds that was performed in the
+    ///     original shuffle.
+    /// @return resulting shuffled index
+    function deshuffle(
+        uint256 xPrime,
+        uint256 domain,
+        uint256 seed,
+        uint256 rounds
+    ) internal pure returns (uint256) {
+        require(domain != 0, "modulus must be > 0");
+        require(xPrime < domain, "x too large");
+        require((rounds & 1) == 0, "rounds must be even");
+
+        uint256 h = sqrt(nextPerfectSquare(domain));
+        do {
+            uint256 L = xPrime % h;
+            uint256 R = xPrime / h;
+            for (uint256 i = 0; i < rounds; ++i) {
+                uint256 nextL = (R +
+                    h -
+                    (f(L, rounds - i - 1, seed, domain) % h)) % h;
+                R = L;
+                L = nextL;
+            }
+            xPrime = h * R + L;
+        } while (xPrime >= domain);
+        return xPrime;
+    }
+}
diff --git a/packages/contracts/contracts/test/FeistelShuffleConsumer.sol b/packages/contracts/contracts/test/FeistelShuffleConsumer.sol
new file mode 100644
index 000000000..488445ae9
--- /dev/null
+++ b/packages/contracts/contracts/test/FeistelShuffleConsumer.sol
@@ -0,0 +1,43 @@
+// SPDX-License-Identifier: MIT
+pragma solidity >=0.8.10;
+
+import {FeistelShuffle} from "../libs/FeistelShuffle.sol";
+import {FeistelShuffleOptimised} from "../libs/FeistelShuffleOptimised.sol";
+
+contract FeistelShuffleConsumer {
+    function shuffle(
+        uint256 x,
+        uint256 domain,
+        uint256 seed,
+        uint256 rounds
+    ) public pure returns (uint256) {
+        return FeistelShuffle.shuffle(x, domain, seed, rounds);
+    }
+
+    function deshuffle(
+        uint256 xPrime,
+        uint256 domain,
+        uint256 seed,
+        uint256 rounds
+    ) public pure returns (uint256) {
+        return FeistelShuffle.deshuffle(xPrime, domain, seed, rounds);
+    }
+
+    function shuffle__OPT(
+        uint256 x,
+        uint256 domain,
+        uint256 seed,
+        uint256 rounds
+    ) public pure returns (uint256) {
+        return FeistelShuffleOptimised.shuffle(x, domain, seed, rounds);
+    }
+
+    function deshuffle__OPT(
+        uint256 xPrime,
+        uint256 domain,
+        uint256 seed,
+        uint256 rounds
+    ) public pure returns (uint256) {
+        return FeistelShuffleOptimised.deshuffle(xPrime, domain, seed, rounds);
+    }
+}
diff --git a/packages/contracts/package.json b/packages/contracts/package.json
index 4fea51b1c..e98a57d04 100644
--- a/packages/contracts/package.json
+++ b/packages/contracts/package.json
@@ -41,6 +41,7 @@
   },
   "devDependencies": {
     "@ethereum-waffle/chai": "^3.4.3",
+    "@kevincharm/gfc-fpe": "^1.1.0",
     "@nomicfoundation/hardhat-verify": "^2.0.5",
     "@nomiclabs/hardhat-ethers": "^2.0.2",
     "@nomiclabs/hardhat-waffle": "^2.0.1",
diff --git a/packages/contracts/test/contracts/FeistelShuffle.test.ts b/packages/contracts/test/contracts/FeistelShuffle.test.ts
new file mode 100644
index 000000000..b2e620232
--- /dev/null
+++ b/packages/contracts/test/contracts/FeistelShuffle.test.ts
@@ -0,0 +1,175 @@
+import { expect } from 'chai'
+import { ethers } from 'hardhat'
+import { FeistelShuffleConsumer, FeistelShuffleConsumer__factory } from '../../build'
+import { SignerWithAddress } from '@nomiclabs/hardhat-ethers/signers'
+import { BigNumber, BigNumberish } from 'ethers'
+import { randomBytes } from 'crypto'
+import * as tsFeistel from '@kevincharm/gfc-fpe'
+import { solidityKeccak256 } from 'ethers/lib/utils'
+
+const f = (R: bigint, i: bigint, seed: bigint, domain: bigint) =>
+  BigNumber.from(solidityKeccak256(['uint256', 'uint256', 'uint256', 'uint256'], [R, i, seed, domain])).toBigInt()
+
+describe('FeistelShuffle', () => {
+  let deployer: SignerWithAddress
+  let feistelShuffle: FeistelShuffleConsumer
+  let indices: number[]
+  let seed: string
+  before(async () => {
+    const signers = await ethers.getSigners()
+    deployer = signers[0]
+    feistelShuffle = await new FeistelShuffleConsumer__factory(deployer).deploy()
+    indices = Array(100)
+      .fill(0)
+      .map((_, i) => i)
+    seed = ethers.utils.defaultAbiCoder.encode(['bytes32'], ['0x' + randomBytes(32).toString('hex')])
+  })
+
+  function assertSetEquality(left: number[], right: number[]) {
+    const set = new Set<number>()
+    for (const l of left) {
+      set.add(l)
+    }
+    expect(set.size).to.equal(left.length)
+    for (const r of right) {
+      expect(set.delete(r)).to.equal(true, `${r} exists in left`)
+    }
+    expect(set.size).to.equal(0)
+  }
+
+  /**
+   * Same as calling `feistelShuffle.shuffle(...)`, but additionally
+   * checks the return value against the reference implementation and asserts
+   * they're equal.
+   *
+   * @param x
+   * @param domain
+   * @param seed
+   * @param rounds
+   * @returns
+   */
+  async function checkedShuffle(x: BigNumberish, domain: BigNumberish, seed: BigNumberish, rounds: number) {
+    const contractRefAnswer = await feistelShuffle.shuffle(x, domain, seed, rounds)
+    const refAnswer = await tsFeistel.encrypt(
+      BigNumber.from(x).toBigInt(),
+      BigNumber.from(domain).toBigInt(),
+      BigNumber.from(seed).toBigInt(),
+      BigNumber.from(rounds).toBigInt(),
+      f
+    )
+    expect(contractRefAnswer).to.equal(refAnswer)
+    // Compute x from x' using the inverse function
+    expect(await feistelShuffle.deshuffle(contractRefAnswer, domain, seed, rounds)).to.eq(x)
+    expect(await feistelShuffle.deshuffle__OPT(contractRefAnswer, domain, seed, rounds)).to.eq(x)
+    return contractRefAnswer
+  }
+
+  it('should create permutation with FeistelShuffle', async () => {
+    const rounds = 4
+    const shuffled: BigNumber[] = []
+    for (let i = 0; i < indices.length; i++) {
+      const s = await feistelShuffle.shuffle__OPT(i, indices.length, seed, rounds)
+      shuffled.push(s)
+    }
+    assertSetEquality(
+      indices,
+      shuffled.map((s) => s.toNumber())
+    )
+  })
+
+  it('should match reference implementation', async () => {
+    const rounds = 4
+    const shuffled: number[] = []
+    for (const i of indices) {
+      // Test both unoptimised & optimised versions
+      const s = await checkedShuffle(i, indices.length, seed, rounds)
+      // Test that optimised Yul version spits out the same output
+      const sOpt = await feistelShuffle.shuffle__OPT(i, indices.length, seed, rounds)
+      expect(s).to.equal(sOpt)
+      shuffled.push(sOpt.toNumber())
+    }
+
+    const specOutput: number[] = []
+    for (const index of indices) {
+      const xPrime = await tsFeistel.encrypt(
+        BigInt(index),
+        BigInt(indices.length),
+        BigNumber.from(seed).toBigInt(),
+        BigNumber.from(rounds).toBigInt(),
+        f
+      )
+      specOutput.push(Number(xPrime))
+    }
+
+    expect(shuffled).to.deep.equal(specOutput)
+  })
+
+  it('should revert if x >= modulus', async () => {
+    const rounds = 4
+    // on boundary
+    await expect(feistelShuffle.shuffle(100, 100, seed, rounds)).to.be.revertedWith('x too large')
+    await expect(feistelShuffle.shuffle__OPT(100, 100, seed, rounds)).to.be.reverted
+    // past boundary
+    await expect(feistelShuffle.shuffle(101, 100, seed, rounds)).to.be.revertedWith('x too large')
+    await expect(feistelShuffle.shuffle__OPT(101, 100, seed, rounds)).to.be.reverted
+  })
+
+  it('should revert if modulus == 0', async () => {
+    const rounds = 4
+    await expect(feistelShuffle.shuffle(0, 0, seed, rounds)).to.be.revertedWith('modulus must be > 0')
+    await expect(feistelShuffle.shuffle__OPT(0, 0, seed, rounds)).to.be.reverted
+  })
+
+  it('should handle small modulus', async () => {
+    // This is mainly to ensure the sqrt / nextPerfectSquare functions are correct
+    const rounds = 4
+
+    // list size of 1
+    let modulus = 1
+    const permutedOneRef = await checkedShuffle(0, modulus, seed, rounds)
+    expect(permutedOneRef).to.equal(0)
+    expect(permutedOneRef).to.equal(await feistelShuffle.shuffle__OPT(0, modulus, seed, rounds))
+
+    // list size of 2
+    modulus = 2
+    const shuffledTwo = new Set<number>()
+    for (let i = 0; i < modulus; i++) {
+      shuffledTwo.add((await checkedShuffle(i, modulus, seed, rounds)).toNumber())
+    }
+    // |shuffledSet| = modulus
+    expect(shuffledTwo.size).to.equal(modulus)
+    // set equality with optimised version
+    for (let i = 0; i < modulus; i++) {
+      shuffledTwo.delete((await feistelShuffle.shuffle__OPT(i, modulus, seed, rounds)).toNumber())
+    }
+    expect(shuffledTwo.size).to.equal(0)
+
+    // list size of 3
+    modulus = 3
+    const shuffledThree = new Set<number>()
+    for (let i = 0; i < modulus; i++) {
+      shuffledThree.add((await checkedShuffle(i, modulus, seed, rounds)).toNumber())
+    }
+    // |shuffledSet| = modulus
+    expect(shuffledThree.size).to.equal(modulus)
+    // set equality with optimised version
+    for (let i = 0; i < modulus; i++) {
+      shuffledThree.delete((await feistelShuffle.shuffle__OPT(i, modulus, seed, rounds)).toNumber())
+    }
+    expect(shuffledThree.size).to.equal(0)
+
+    // list size of 4 (past boundary)
+    modulus = 4
+    const shuffledFour = new Set<number>()
+    for (let i = 0; i < modulus; i++) {
+      shuffledFour.add((await checkedShuffle(i, modulus, seed, rounds)).toNumber())
+    }
+    // |shuffledSet| = modulus
+    expect(shuffledFour.size).to.equal(modulus)
+    // set equality with optimised version
+    for (let i = 0; i < modulus; i++) {
+      shuffledFour.delete((await feistelShuffle.shuffle__OPT(i, modulus, seed, rounds)).toNumber())
+    }
+    expect(shuffledFour.size).to.equal(0)
+  })
+})
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 9ffce267c..5e56353ba 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -17,6 +17,9 @@ importers:
       '@ethereum-waffle/chai':
         specifier: ^3.4.3
         version: 3.4.4
+      '@kevincharm/gfc-fpe':
+        specifier: ^1.1.0
+        version: 1.1.0
       '@nomicfoundation/hardhat-verify':
         specifier: ^2.0.5
         version: 2.0.5(hardhat@2.19.4)
@@ -1058,6 +1061,10 @@ packages:
       wrap-ansi-cjs: /wrap-ansi@7.0.0
     dev: true
 
+  /@kevincharm/gfc-fpe@1.1.0:
+    resolution: {integrity: sha512-Kg9wOa3T1ttOVxqeomid3zqWcqXVWZA7l1nXhOLrn/NTKmaTIM873rComzuG/b3Xv1CogLTZgK/+HeVf3Fn0fg==}
+    dev: true
+
   /@lit-labs/ssr-dom-shim@1.2.0:
     resolution: {integrity: sha512-yWJKmpGE6lUURKAaIltoPIE/wrbY3TEkqQt+X0m+7fQNnAv0keydnYvbiJFP1PnMhizmIWRWOG5KLhYyc/xl+g==}
     dev: false