diff --git a/base/src/main/java/io/vproxy/base/util/Utils.java b/base/src/main/java/io/vproxy/base/util/Utils.java index 1813bc18..752bb619 100644 --- a/base/src/main/java/io/vproxy/base/util/Utils.java +++ b/base/src/main/java/io/vproxy/base/util/Utils.java @@ -470,6 +470,17 @@ public static boolean allZerosAfter(ByteArray bytes, int index) { return true; } + public static int minPow2GreaterThan(int n) { + n -= 1; + n |= n >>> 1; + n |= n >>> 2; + n |= n >>> 4; + n |= n >>> 8; + n |= n >>> 16; + n += 1; + return n; + } + public static boolean assertOn() { return assertOn; } diff --git a/base/src/main/java/io/vproxy/base/util/lock/ReadWriteSpinLock.java b/base/src/main/java/io/vproxy/base/util/lock/ReadWriteSpinLock.java new file mode 100644 index 00000000..5957d5e6 --- /dev/null +++ b/base/src/main/java/io/vproxy/base/util/lock/ReadWriteSpinLock.java @@ -0,0 +1,55 @@ +package io.vproxy.base.util.lock; + +import java.util.concurrent.atomic.AtomicInteger; + +public class ReadWriteSpinLock { + private static final int WRITE_LOCKED = 0x80_00_00_00; + // 32 31 ------ 0 + // W RRRR...RRRR + private final AtomicInteger lock = new AtomicInteger(0); + private final AtomicInteger wLockPending = new AtomicInteger(0); + private final int spinTimes; + + public ReadWriteSpinLock() { + this(20); + } + + public ReadWriteSpinLock(int spinTimes) { + this.spinTimes = spinTimes; + } + + public void readLock() { + while (true) { + if (wLockPending.get() != 0) { + spinWait(); + continue; + } + if (lock.incrementAndGet() < 0) { + continue; + } + break; + } + } + + public void readUnlock() { + lock.decrementAndGet(); + } + + public void writeLock() { + wLockPending.incrementAndGet(); + while (!lock.compareAndSet(0, WRITE_LOCKED)) { + spinWait(); + } + } + + public void writeUnlock() { + lock.set(0); + wLockPending.decrementAndGet(); + } + + private void spinWait() { + for (int i = 0; i < spinTimes; ++i) { + Thread.onSpinWait(); + } + } +} diff --git a/base/src/main/java/io/vproxy/base/util/objectpool/ConcurrentObjectPool.java b/base/src/main/java/io/vproxy/base/util/objectpool/ConcurrentObjectPool.java index 1ee1f65c..08a6c72d 100644 --- a/base/src/main/java/io/vproxy/base/util/objectpool/ConcurrentObjectPool.java +++ b/base/src/main/java/io/vproxy/base/util/objectpool/ConcurrentObjectPool.java @@ -1,15 +1,17 @@ package io.vproxy.base.util.objectpool; +import io.vproxy.base.util.Utils; +import io.vproxy.base.util.lock.ReadWriteSpinLock; +import io.vproxy.base.util.thread.VProxyThread; + import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReferenceArray; /** * The pool is split into a few partitions, each partition has a read array and a write array. * When adding, elements will be added into the write array. * When polling, elements will be polled from the read array. - * If read array is empty and write array is full, and when running polling, the two arrays will be swapped - * (they will not be swapped when adding). + * If read array is empty and write array is full, and when running polling or adding, the two arrays will be swapped * The arrays will not be operated when they are being swapped. * When concurrency occurs, the operations will retry for maximum 10 times. * @@ -17,37 +19,48 @@ */ public class ConcurrentObjectPool { private final int partitionCount; + private final int partitionCountMinusOne; private final Partition[] partitions; + private final int maxTraversal; public ConcurrentObjectPool(int capacityHint) { - this(capacityHint, 16, 4); + this(capacityHint, 16, 0); } - public ConcurrentObjectPool(int capacityHint, int partitionCountHint, int minPartitionCapHint) { - capacityHint -= 1; - capacityHint |= capacityHint >>> 1; - capacityHint |= capacityHint >>> 2; - capacityHint |= capacityHint >>> 4; - capacityHint |= capacityHint >>> 8; - capacityHint |= capacityHint >>> 16; - capacityHint += 1; + public ConcurrentObjectPool(int capacityHint, int partitionCountHint, int maxTraversalHint) { + capacityHint = Utils.minPow2GreaterThan(capacityHint) / 2; + partitionCountHint = Utils.minPow2GreaterThan(partitionCountHint); - if (capacityHint / minPartitionCapHint == 0) { + if (capacityHint / partitionCountHint == 0) { partitionCount = 1; } else { - partitionCount = Math.min(capacityHint / minPartitionCapHint, partitionCountHint); + partitionCount = partitionCountHint; } + partitionCountMinusOne = partitionCount - 1; //noinspection unchecked this.partitions = new Partition[partitionCount]; for (int i = 0; i < partitionCount; ++i) { partitions[i] = new Partition<>(capacityHint / partitionCount); } + + if (maxTraversalHint <= 0 || maxTraversalHint >= partitionCount) { + maxTraversal = partitionCount; + } else { + maxTraversal = maxTraversalHint; + } + } + + private int hashForPartition() { + var tid = VProxyThread.current().threadId; + return (int) (tid & partitionCountMinusOne); } public boolean add(E e) { - for (int i = 0; i < partitionCount; ++i) { - if (partitions[i].add(e)) { + int m = maxTraversal; + int hash = hashForPartition(); + for (int i = hash; m > 0; ++i, --m) { + if (partitions[i & partitionCountMinusOne].add(e)) { return true; } } @@ -55,8 +68,10 @@ public boolean add(E e) { } public E poll() { - for (int i = 0; i < partitionCount; ++i) { - E e = partitions[i].poll(); + int m = maxTraversal; + int hash = hashForPartition(); + for (int i = hash; m > 0; ++i, --m) { + E e = partitions[i & partitionCountMinusOne].poll(); if (e != null) { return e; } @@ -73,27 +88,37 @@ public int size() { } private static class Partition { - private final AtomicReference> read; - private volatile StorageArray write; - private final StorageArray _1; - private final StorageArray _2; + private final ReadWriteSpinLock lock = new ReadWriteSpinLock(); + private volatile ArrayQueue read; + private volatile ArrayQueue write; + private final ArrayQueue _1; + private final ArrayQueue _2; public Partition(int capacity) { - _1 = new StorageArray<>(capacity); - _2 = new StorageArray<>(capacity); - read = new AtomicReference<>(_1); + _1 = new ArrayQueue<>(capacity, lock); + _2 = new ArrayQueue<>(capacity, lock); + read = _1; write = _2; } public boolean add(E e) { - StorageArray write = this.write; + return add(1, e); + } + + private boolean add(int retry, E e) { + if (retry > 10) { // max retry for 10 times + return false; // too many retries + } - // adding is always safe - //noinspection RedundantIfStatement + var write = this.write; if (write.add(e)) { return true; } - // $write is full, storing fails + + // the $write is full now + if (swap(read, write, false)) { + return add(retry + 1, e); + } return false; } @@ -106,38 +131,59 @@ private E poll(int retry) { return null; // too many retries } - StorageArray read = this.read.get(); - StorageArray write = this.write; + var read = this.read; + var write = this.write; - // polling is always safe - E ret = read.poll(); + var ret = read.poll(); if (ret != null) { return ret; } // no elements in the $read now - // check whether we can swap (whether $write is full) + if (swap(read, write, true)) { + return poll(retry + 1); + } + return null; + } - int writeEnd = write.end.get(); - if (writeEnd < write.capacity) { - return null; // capacity not reached, do not swap and return nothing - // no retry here because the write array will not change until something written into it + // return true -> need retry + // return false -> failed and should not retry + private boolean swap(ArrayQueue read, ArrayQueue write, boolean isPolling) { + // check whether we can swap + if (read == write) { + // is being swapped by another thread + return true; } - // also we should check whether there are no elements being stored - if (write.storing.get() != 0) { // element is being stored into the array - return poll(retry + 1); // try again + + if (isPolling) { // $read is empty + int writeEnd = write.end.get(); + if (writeEnd < write.capacity) { + return false; // capacity not reached, do not swap and return nothing + // no retry here because the write array will not change until something written into it + } + } else { // $write is full + int readStart = read.start.get(); + if (readStart < read.end.get()) { + return false; // still have objects to fetch, do not swap + // no retry here because the read array will not change until something polling from it + } } - // now we can know that writing operations will not happen in this partition - // we can safely swap the two arrays now - if (!this.read.compareAndSet(read, write)) { - return poll(retry + 1); // concurrency detected: another thread is swapping + lock.writeLock(); + if (this.read != read) { + // already swapped by another thread + lock.writeUnlock(); + return true; } + // we can safely swap the two arrays now + this.read = write; // the $read is expected to be empty assert read.size() == 0; read.reset(); // reset the cursors, so further operations can store data into this array this.write = read; // swapping is done - return poll(retry + 1); // poll again + lock.writeUnlock(); + + return true; } public int size() { @@ -145,64 +191,74 @@ public int size() { } } - private static class StorageArray { + private static class ArrayQueue { private final int capacity; + private final ReadWriteSpinLock lock; private final AtomicReferenceArray array; - private final AtomicInteger start = new AtomicInteger(-1); - private final AtomicInteger end = new AtomicInteger(-1); - private final AtomicInteger storing = new AtomicInteger(0); + private final AtomicInteger start = new AtomicInteger(0); + private final AtomicInteger end = new AtomicInteger(0); - private StorageArray(int capacity) { + private ArrayQueue(int capacity, ReadWriteSpinLock lock) { this.capacity = capacity; + this.lock = lock; this.array = new AtomicReferenceArray<>(capacity); } boolean add(E e) { - storing.incrementAndGet(); + lock.readLock(); if (end.get() >= capacity) { - storing.decrementAndGet(); + lock.readUnlock(); return false; // exceeds capacity } - int index = end.incrementAndGet(); + int index = end.getAndIncrement(); if (index < capacity) { // storing should succeed array.set(index, e); - storing.decrementAndGet(); + lock.readUnlock(); return true; } else { // storing failed - storing.decrementAndGet(); + lock.readUnlock(); return false; } } E poll() { - if (start.get() + 1 >= end.get() || start.get() + 1 >= capacity) { + lock.readLock(); + + if (start.get() >= end.get() || start.get() >= capacity) { + lock.readUnlock(); return null; } - int idx = start.incrementAndGet(); + int idx = start.getAndIncrement(); if (idx >= end.get() || idx >= capacity) { + lock.readUnlock(); return null; // concurrent polling } - return array.get(idx); + var e = array.get(idx); + lock.readUnlock(); + return e; } int size() { - int start = this.start.get() + 1; + int start = this.start.get(); if (start >= capacity) { return 0; } - int cap = end.get() + 1; + int cap = end.get(); if (cap > capacity) { cap = capacity; } + if (start > cap) { + return 0; + } return cap - start; } void reset() { - end.set(-1); - start.set(-1); + end.set(0); + start.set(0); } } } diff --git a/base/src/main/java/io/vproxy/base/util/thread/VProxyThread.java b/base/src/main/java/io/vproxy/base/util/thread/VProxyThread.java index 301364a7..a8e82bdc 100644 --- a/base/src/main/java/io/vproxy/base/util/thread/VProxyThread.java +++ b/base/src/main/java/io/vproxy/base/util/thread/VProxyThread.java @@ -12,6 +12,7 @@ import vjson.util.StringDictionary; import java.util.UUID; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.BooleanSupplier; public interface VProxyThread { @@ -51,7 +52,14 @@ default void interrupt() { thread().interrupt(); } + default void join() throws InterruptedException { + thread().join(); + } + class VProxyThreadVariable { + private static final AtomicLong threadIdGenerator = new AtomicLong(); + + public final long threadId = threadIdGenerator.getAndIncrement(); public SelectorEventLoop loop; public ArrayParser threadLocalArrayParser; diff --git a/base/src/main/java/module-info.java b/base/src/main/java/module-info.java index b23558f1..0aac2b16 100644 --- a/base/src/main/java/module-info.java +++ b/base/src/main/java/module-info.java @@ -66,6 +66,7 @@ exports io.vproxy.base.util.functional; exports io.vproxy.base.util.io; exports io.vproxy.base.util.kt; + exports io.vproxy.base.util.lock; exports io.vproxy.base.util.log; exports io.vproxy.base.util.misc; exports io.vproxy.base.util.net; diff --git a/build.gradle b/build.gradle index 5293a603..66678215 100644 --- a/build.gradle +++ b/build.gradle @@ -513,7 +513,7 @@ project(':test') { } def testCase = c def m = System.getProperty("method") - if (m != null) { + if (m != null && !m.isEmpty()) { testCase += '.' + m } diff --git a/test/src/test/java/io/vproxy/test/cases/TestUtils.java b/test/src/test/java/io/vproxy/test/cases/TestUtils.java index ab62b40e..b40f08cd 100644 --- a/test/src/test/java/io/vproxy/test/cases/TestUtils.java +++ b/test/src/test/java/io/vproxy/test/cases/TestUtils.java @@ -5,8 +5,10 @@ import io.vproxy.base.util.bytearray.RandomAccessFileByteArray; import io.vproxy.base.util.file.MappedByteBufferLogger; import io.vproxy.base.util.nio.ByteArrayChannel; +import io.vproxy.base.util.objectpool.ConcurrentObjectPool; import io.vproxy.base.util.ringbuffer.SimpleRingBuffer; import io.vproxy.base.util.ringbuffer.SimpleRingBufferReaderCommitter; +import io.vproxy.base.util.thread.VProxyThread; import io.vproxy.commons.util.IOUtils; import org.junit.After; import org.junit.Test; @@ -20,6 +22,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.*; +import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; import static org.junit.Assert.*; @@ -493,4 +496,125 @@ public void randomAccessFileByteArray() throws Exception { assertEquals("abcdefghijk", str); } } + + @Test + public void concurrentObjectPoolPoll() { + var pool = new ConcurrentObjectPool(65536, 16, 0); + for (int i = 0; i < 65536; ++i) { + assertTrue("add(" + i + ")", pool.add(i)); + } + // pool is full, adding should fail now + for (int i = 0; i < 10; ++i) { + assertFalse(pool.add(1234)); + } + var results = new ArrayList(); + var pollCount = new ArrayList(); + var threads = new ArrayList(); + for (int i = 0; i < 16; ++i) { + threads.add(new Thread(() -> testConcurrentObjectPoolPoll(pool, results, pollCount))); + } + for (var t : threads) { + t.start(); + } + for (var t : threads) { + try { + t.join(); + } catch (InterruptedException ignore) { + } + } + assertEquals(65536, results.size()); + assertEquals(65536, new HashSet<>(results).size()); + results.sort(Integer::compareTo); + assertEquals(0, results.getFirst().intValue()); + assertEquals(65535, results.getLast().intValue()); + int sum = 0; + for (var n : pollCount) { + sum += n; + } + assertEquals(65536, sum); + for (int i = 0; i < 65536; ++i) { + assertTrue("check:add(" + i + ")", pool.add(i)); + } + } + + private void testConcurrentObjectPoolPoll(ConcurrentObjectPool pool, List results, List pollCount) { + var resultsOfThisThread = new ArrayList(); + while (true) { + var e = pool.poll(); + if (e == null) { + break; + } + resultsOfThisThread.add(e); + Thread.yield(); + } + //noinspection SynchronizationOnLocalVariableOrMethodParameter + synchronized (results) { + results.addAll(resultsOfThisThread); + } + //noinspection SynchronizationOnLocalVariableOrMethodParameter + synchronized (pollCount) { + pollCount.add(resultsOfThisThread.size()); + } + } + + @Test + public void concurrentObjectPoolAdd() { + var pool = new ConcurrentObjectPool(65536, 4, 0); + + var added = new ArrayList(); + var addCount = new ArrayList(); + var threads = new ArrayList(); + for (int i = 0; i < 64; ++i) { + threads.add(VProxyThread.create(() -> testConcurrentObjectPoolAdd(pool, added, addCount), "add:" + i)); + } + for (var t : threads) { + t.start(); + } + for (var t : threads) { + try { + t.join(); + } catch (InterruptedException ignore) { + } + } + assertEquals(65536, added.size()); + added.sort(Integer::compareTo); + var results = new ArrayList(); + for (int i = 0; i < 65536; ++i) { + var e = pool.poll(); + assertNotNull("check:poll():" + i, e); + results.add(e); + } + // the pool is empty so nothing would be polled + for (int i = 0; i < 10; ++i) { + assertNull(pool.poll()); + } + results.sort(Integer::compareTo); + assertEquals(added, results); + + int cnt = 0; + for (var n : addCount) { + cnt += n; + } + assertEquals(65536, cnt); + } + + private void testConcurrentObjectPoolAdd(ConcurrentObjectPool pool, List added, List addCount) { + var addedOfThisThread = new ArrayList(); + while (true) { + var e = ThreadLocalRandom.current().nextInt(); + if (!pool.add(e)) { + break; + } + addedOfThisThread.add(e); + Thread.yield(); + } + //noinspection SynchronizationOnLocalVariableOrMethodParameter + synchronized (added) { + added.addAll(addedOfThisThread); + } + //noinspection SynchronizationOnLocalVariableOrMethodParameter + synchronized (addCount) { + addCount.add(addedOfThisThread.size()); + } + } }