Skip to content

Commit

Permalink
Improve SecureRandom benchmarks (especially in multi-threaded settings)
Browse files Browse the repository at this point in the history
1. Add more benchmarked randomness generation algorithms/providers (SUN/NativePrng, java.util.random, and some baseline using copy instead of randomness generation). SUN/NativePrng is the default Java implementation of SecureRandom when ACCP is not installed.
2. Remove configuration line in build.gradle.kts forcing single thread in all benchmarks. This is because command line arguments override annotations, and we need to use multiple threads in some benchmarks.
3. Use average measurement (ns/op) in benchmarks of randomness generation instead of throughput (op/ns). This is because for multi-threaded benchmarks of randomness generation, we want to know the duration of a single randomness generation on a single thread (and we want to check if performance drops because of multi-threading).
4. Make the data variable (the target of randomness generation) thread-local to avoid L1 cache contention.
5. Add a few more benchmarks:
    a. multiThreadedLocal where the SecureRandom instance is thread local, instead of global: this is useful for users of ACCP to know whether they should have SecureRandom be thread local or not.
    b. singleThreadedNew creating a new instance of SecureRandom: this is useful for users of ACCP to know whether it is ok to instantiate a SecureRandom every time they need it, or to instead create a global instance.
    c. A variant of the benchmarks generating 4 bytes of random data, instead of 1,024 bytes. For example, generating a random int requires 4 bytes of random data.
  • Loading branch information
Fabrice Benhamouda committed Apr 10, 2024
1 parent 03dcaef commit 0fdb6e4
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 43 deletions.
6 changes: 3 additions & 3 deletions benchmarks/lib/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ jmh {
includes.add(includeBenchmark)
}
fork.set(1)
benchmarkMode.add("thrpt")
threads.set(1)
timeUnit.set("s")
// Do not specify benchmarkMode nor timeUnit to allow each benchmark to use their own
// Do not set threads.set(1) as it prevents multi-threaded benchmarks
// Classes without any annotation will use a single thread and ops/s by default
iterations.set(5)
timeOnIteration.set("3s")
warmup.set("1s")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,71 +2,196 @@
// SPDX-License-Identifier: Apache-2.0
package com.amazon.corretto.crypto.provider.benchmarks;

import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SecureRandom;
import java.util.concurrent.TimeUnit;

import com.amazon.corretto.crypto.provider.AmazonCorrettoCryptoProvider;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Threads;

@State(Scope.Benchmark)
public class Drbg {
@Param({"1024"})
public int size;

@Param({AmazonCorrettoCryptoProvider.PROVIDER_NAME, "BC", "SUN"})
public String provider;

private byte[] data;
private SecureRandom random;

@Setup
public void setup() throws Exception {
BenchmarkUtils.setupProvider(provider);
data = new byte[size];
final String algorithm;
switch (provider) {
case AmazonCorrettoCryptoProvider.PROVIDER_NAME:
case "BC":
algorithm = "DEFAULT";
break;
case "SUN":
algorithm = "DRBG";
break;
default:
throw new RuntimeException("Unknown algorithm for provider " + provider);
/**
* Benchmark Random/SecureRandom implementations
*
* <p>Use average time in ns/op to measure the time per thread instead of the default throughput
* mode (ops/s), because the throughput mode sums the number of operations over all the threads
*/
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public class Random {
@State(Scope.Thread)
public static class ThreadState {
@Param({"4", "1024"})
public int size;

// data is a thread-local variable to prevent L1 cache contention
// benchmarks generating randomness store it in data
private byte[] data;
// sourceData is a random array used as source data for baseline benchmarks
// where sourceData is just copied to data
private byte[] sourceData;

@Setup
public void setup() {
data = new byte[size];
sourceData = new byte[size];
// generating sourceData at random to prevent potential optimizations
(new java.util.Random()).nextBytes(sourceData);
}
}

@State(Scope.Benchmark)
public static class Shared {
// !!! WARNING: java.util.random is not a secure randomness generator
// !!! WARNING: we add it here just for comparison
@Param({
AmazonCorrettoCryptoProvider.PROVIDER_NAME + "/DEFAULT",
"BC/DEFAULT",
"SUN/NativePrng",
"SUN/DRBG",
"java.util.Random"
})
public String provider_algorithm;

private String provider;
private String algorithm;

// random is shared amongst all threads
private java.util.Random random;

// localRandom is thread local
private ThreadLocal<java.util.Random> localRandom;

@Setup
public synchronized void setup() throws Exception {
if ("java.util.Random".equals(provider_algorithm)) {
// java.util.random is a special case as it's not a SecureRandom
random = new java.util.Random();
// !!! WARNING: This is just for benchmarking and should not be used as is.
// Use ThreadLocalRandom.current() for a thread-local non-cryptographic randomness
// generator
localRandom = new ThreadLocal<java.util.Random>();
} else {
final String[] parts = provider_algorithm.split("/", 2);
provider = parts[0];
algorithm = parts[1];

BenchmarkUtils.setupProvider(provider);

random = SecureRandom.getInstance(algorithm, provider);
localRandom =
ThreadLocal.withInitial(
() -> {
try {
return SecureRandom.getInstance(algorithm, provider);
} catch (NoSuchAlgorithmException | NoSuchProviderException e) {
throw new RuntimeException(e);
}
});
}
}
}

@Benchmark
@Threads(1)
public byte[] singleThreaded(Shared shared, ThreadState threadState) {
shared.random.nextBytes(threadState.data);
return threadState.data;
}

/** Benchmark the time needed to get a new instance */
@Benchmark
@Threads(1)
public Object singleThreadedNew(Shared shared)
throws NoSuchAlgorithmException, NoSuchProviderException {
if ("java.util.Random".equals(shared.provider_algorithm)) {
return new Random();
} else {
return SecureRandom.getInstance(shared.algorithm, shared.provider);
}
random = SecureRandom.getInstance(algorithm, provider);
}

/**
* Benchmark of SecureRandom with number of threads = number of hardware threads where
* SecureRandom is shared between all threads
*/
@Benchmark
@Threads(Threads.MAX)
public byte[] multiThreaded(Shared shared, ThreadState threadState) {
shared.random.nextBytes(threadState.data);
return threadState.data;
}

/**
* Benchmark of SecureRandom with number of threads = number of hardware threads where
* SecureRandom is local to each thread
*/
@Benchmark
@Threads(Threads.MAX)
public byte[] multiThreadedLocal(Shared shared, ThreadState threadState) {
shared.localRandom.get().nextBytes(threadState.data);
return threadState.data;
}

/**
* Benchmark of creating a new thread and generate randomness This benchmark is only to find
* potential regressions due to per-thread initialization done by SecureRandom
*
* <p>Note that there will be L1 cache contention because `data` is shared between all the threads
*/
@Benchmark
@Threads(1)
public byte[] newThreadPerRequest(Shared shared, ThreadState threadState)
throws InterruptedException {
Thread t = new Thread(() -> shared.random.nextBytes(threadState.data));
t.start();
t.join();
return threadState.data;
}

/**
* Baseline version of the {@link #singleThreaded singleThreaded} benchmark where instead of
* generating randomness, data is copied from sourceData to data
*/
@Benchmark
@Threads(1)
public byte[] singleThreaded() {
random.nextBytes(data);
return data;
public byte[] singleThreadedBaseline(ThreadState threadState) {
System.arraycopy(threadState.sourceData, 0, threadState.data, 0, threadState.size);
return threadState.data;
}

/**
* Baseline version of the {@link #multiThreaded multiThreaded} benchmark where instead of
* generating randomness, data is copied from sourceData to data
*/
@Benchmark
@Threads(Threads.MAX)
public byte[] multiThreaded() {
random.nextBytes(data);
return data;
public byte[] multiThreadedBaseline(ThreadState threadState) {
System.arraycopy(threadState.sourceData, 0, threadState.data, 0, threadState.size);
return threadState.data;
}

/**
* Baseline version of the {@link #newThreadPerRequest} newThreadPerRequest} benchmark where
* instead of generating randomness, data is copied from sourceData to data
*/
@Benchmark
@Threads(1)
public byte[] newThreadPerRequest() throws InterruptedException {
public byte[] newThreadPerRequestBaseline(ThreadState threadState) throws InterruptedException {
Thread t =
new Thread() {
public void run() {
random.nextBytes(data);
}
};
new Thread(
() ->
System.arraycopy(threadState.sourceData, 0, threadState.data, 0, threadState.size));
t.start();
t.join();
return data;
return threadState.data;
}
}

0 comments on commit 0fdb6e4

Please sign in to comment.