From 66219ce37200526c6f4f2e629537c3a9ff9f0199 Mon Sep 17 00:00:00 2001 From: ninthakeey Date: Fri, 19 Feb 2021 00:19:01 +0800 Subject: [PATCH] feat: add time wheel --- .../base/util/time/impl/TimeElemImpl.java | 3 +- .../base/util/time/impl/TimeQueueImpl.java | 113 +++++++++++++++-- .../vproxy/base/util/time/impl/TimeWheel.java | 120 ++++++++++++++++++ .../java/vproxy/test/cases/TestTimeQueue.java | 109 ++++++++++++++++ 4 files changed, 335 insertions(+), 10 deletions(-) create mode 100644 base/src/main/java/vproxy/base/util/time/impl/TimeWheel.java create mode 100644 test/src/test/java/vproxy/test/cases/TestTimeQueue.java diff --git a/base/src/main/java/vproxy/base/util/time/impl/TimeElemImpl.java b/base/src/main/java/vproxy/base/util/time/impl/TimeElemImpl.java index acb5b3fe8..1fa64add1 100644 --- a/base/src/main/java/vproxy/base/util/time/impl/TimeElemImpl.java +++ b/base/src/main/java/vproxy/base/util/time/impl/TimeElemImpl.java @@ -18,7 +18,8 @@ public T get() { return elem; } + @Override public void removeSelf() { - queue.queue.remove(this); + queue.remove(this); } } diff --git a/base/src/main/java/vproxy/base/util/time/impl/TimeQueueImpl.java b/base/src/main/java/vproxy/base/util/time/impl/TimeQueueImpl.java index 8ca67246c..f69a1d993 100644 --- a/base/src/main/java/vproxy/base/util/time/impl/TimeQueueImpl.java +++ b/base/src/main/java/vproxy/base/util/time/impl/TimeQueueImpl.java @@ -3,37 +3,132 @@ import vproxy.base.util.time.TimeElem; import vproxy.base.util.time.TimeQueue; -import java.util.PriorityQueue; +import java.util.*; public class TimeQueueImpl implements TimeQueue { - PriorityQueue> queue = new PriorityQueue<>((a, b) -> (int) (a.triggerTime - b.triggerTime)); + private static final int TIME_WHEEL_LEVEL = 4; + private static final int MAX_TIME_WHEEL_INTERVAL = 1 << (TIME_WHEEL_LEVEL * TimeWheel.WHEEL_SIZE_POWER); + + private final PriorityQueue> queue = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime)); + + private final ArrayList> timeWheels; + + private long lastTickTimestamp; + + public TimeQueueImpl() { + this(System.currentTimeMillis()); + } + + public TimeQueueImpl(long currentTimestamp) { + this.timeWheels = new ArrayList<>(TIME_WHEEL_LEVEL); + for (int i = 0; i < TIME_WHEEL_LEVEL; i++) { + this.timeWheels.add(new TimeWheel<>(1 << (i * TimeWheel.WHEEL_SIZE_POWER), currentTimestamp)); + } + this.lastTickTimestamp = currentTimestamp; + } @Override public TimeElem add(long currentTimestamp, int timeout, T elem) { - TimeElemImpl event = new TimeElemImpl<>(currentTimestamp + timeout, elem, this); - queue.add(event); + final TimeElemImpl event = new TimeElemImpl<>(currentTimestamp + timeout, elem, this); + addTimeElem(event, currentTimestamp); return event; } + private void addTimeElem(TimeElemImpl event, long currentTimestamp) { + long timeout = event.triggerTime - currentTimestamp; + if (timeout >= MAX_TIME_WHEEL_INTERVAL) { + // long timeout task put into queue + queue.add(event); + } else if (timeout <= 0) { + // already timeout task put into the lowest time wheel + this.timeWheels.get(0).add(event, currentTimestamp); + } else { + var index = findTimeWheelIndex(timeout); + this.timeWheels.get(index).add(event, currentTimestamp); + } + } + @Override public T poll() { - TimeElemImpl elem = queue.poll(); - if (elem == null) + TimeElem elem = timeWheels.get(0).poll(); + if (elem == null) { return null; - return elem.elem; + } + return elem.get(); } @Override public boolean isEmpty() { - return queue.isEmpty(); + for (TimeWheel timeWheel : timeWheels) { + if (!timeWheel.isEmpty()) { + return false; + } + } + return true; } @Override public int nextTime(long currentTimestamp) { + tickTimeWheel(currentTimestamp); + for (TimeWheel timeWheel : timeWheels) { + if (timeWheel.isEmpty()) { + continue; + } + return timeWheel.nextTime(currentTimestamp); + } + TimeElemImpl elem = queue.peek(); - if (elem == null) + if (elem == null) { return Integer.MAX_VALUE; + } long triggerTime = elem.triggerTime; return Math.max((int) (triggerTime - currentTimestamp), 0); } + + private void tickTimeWheel(long currentTimestamp) { + if (currentTimestamp <= this.lastTickTimestamp) { + return; + } + + for (int i = TIME_WHEEL_LEVEL - 1; i > 0; i--) { + final var wheel = timeWheels.get(i); + while (wheel.tryTick(currentTimestamp)) { + final Collection> events = wheel.tick(currentTimestamp); + for (TimeElemImpl event : events) { + addTimeElem(event, currentTimestamp); + } + } + } + + // move elements from queue to time wheels + while (!queue.isEmpty()) { + final TimeElemImpl elem = queue.peek(); + long timeout = elem.triggerTime - currentTimestamp; + if (timeout >= MAX_TIME_WHEEL_INTERVAL) { + break; + } + + addTimeElem(elem, currentTimestamp); + queue.poll(); + } + + this.lastTickTimestamp = currentTimestamp; + } + + public void remove(TimeElemImpl elem) { + long timeout = elem.triggerTime - this.lastTickTimestamp; + if (timeout >= MAX_TIME_WHEEL_INTERVAL) { + queue.remove(elem); + } else { + timeWheels.get(findTimeWheelIndex(timeout)).remove(elem); + } + } + + private static int findTimeWheelIndex(long timeout) { + if (timeout <= 0) { + return 0; + } + int bits = 63 - Long.numberOfLeadingZeros(timeout); + return bits / TimeWheel.WHEEL_SIZE_POWER; + } } diff --git a/base/src/main/java/vproxy/base/util/time/impl/TimeWheel.java b/base/src/main/java/vproxy/base/util/time/impl/TimeWheel.java new file mode 100644 index 000000000..4e6be05ea --- /dev/null +++ b/base/src/main/java/vproxy/base/util/time/impl/TimeWheel.java @@ -0,0 +1,120 @@ +package vproxy.base.util.time.impl; + +import vproxy.base.util.time.TimeElem; + +import java.util.*; + +public class TimeWheel { + public static final int WHEEL_SIZE_POWER = 5; + public static final int WHEEL_SIZE = 1 << WHEEL_SIZE_POWER; + + private final PriorityQueue>[] slots = new PriorityQueue[WHEEL_SIZE]; + /** + * min time unit in this time wheel + */ + public final long tickDuration; + /** + * the time wheel max time interval. interval = tickDuration * WHEEL_SIZE + */ + public final long interval; + public final long startTimestamp; + private int tickIndex; + private long elemNum; + private long currentTime; + + public TimeWheel(long tickDuration, long timestamp) { + this.tickDuration = tickDuration; + this.interval = this.tickDuration * WHEEL_SIZE; + this.startTimestamp = timestamp; + this.currentTime = timestamp; + this.tickIndex = findSlotIndex(timestamp); + this.elemNum = 0; + + for (int i = 0; i < slots.length; i++) { + slots[i] = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime)); + } + } + + public void add(TimeElemImpl elem, long timestamp) { + if (elem.triggerTime <= timestamp) { + slots[tickIndex].add(elem); + } else { + slots[findSlotIndex(elem.triggerTime)].add(elem); + } + elemNum++; + } + + private int findSlotIndex(long timestamp) { + long timeout = timestamp - startTimestamp; + return (int) ((timeout & (interval - 1)) / tickDuration); + } + + /** + * return true if it can move. + */ + public boolean tryTick(long timestamp) { + return timestamp - currentTime >= tickDuration; + } + + /** + * move the tick index to point the next slot. + */ + public Collection> tick(long timestamp) { + if (!tryTick(timestamp)) { + return Collections.emptyList(); + } + + int oldIndex = tickIndex; + int nextIndex = (oldIndex + 1) & (WHEEL_SIZE - 1); + if (!slots[oldIndex].isEmpty()) { + slots[nextIndex].addAll(slots[oldIndex]); + slots[oldIndex].clear(); + } + this.tickIndex = nextIndex; + final PriorityQueue> queue = slots[tickIndex]; + slots[tickIndex] = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime)); + + elemNum -= queue.size(); + currentTime += tickDuration; + return queue; + } + + public TimeElem poll() { + var elem = slots[tickIndex].poll(); + if (elem != null) { + elemNum--; + } + return elem; + } + + public boolean isEmpty() { + return elemNum == 0; + } + + public long size() { + return elemNum; + } + + public int nextTime(long timestamp) { + for (int i = tickIndex; i < tickIndex + WHEEL_SIZE; i++) { + final int index = i & (WHEEL_SIZE - 1); + if (slots[index].isEmpty()) { + continue; + } + + long triggerTime = slots[index].peek().triggerTime; + int nextTime = Math.max((int) (triggerTime - timestamp), 0); + if (nextTime == 0 && index != tickIndex){ + slots[tickIndex].add(slots[index].poll()); + } + return nextTime; + } + return Integer.MAX_VALUE; + } + + public void remove(TimeElemImpl elem) { + if (slots[findSlotIndex(elem.triggerTime)].remove(elem)) { + elemNum--; + } + } +} diff --git a/test/src/test/java/vproxy/test/cases/TestTimeQueue.java b/test/src/test/java/vproxy/test/cases/TestTimeQueue.java new file mode 100644 index 000000000..372671924 --- /dev/null +++ b/test/src/test/java/vproxy/test/cases/TestTimeQueue.java @@ -0,0 +1,109 @@ +package vproxy.test.cases; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import vproxy.base.util.Tuple; +import vproxy.base.util.time.TimeQueue; +import vproxy.base.util.time.impl.TimeQueueImpl; +import vproxy.base.util.time.impl.TimeWheel; + +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; + +public class TestTimeQueue { + private TimeQueue queue; + private long current = 0L; + + @Before + public void setUp() throws Exception { + queue = new TimeQueueImpl<>(current); + } + + @After + public void tearDown() throws Exception { + current = 0L; + } + + private int sleepForQueue(long duration) { + current += duration; + return queue.nextTime(current); + } + + private Tuple pushRandomTimeTask(int origin, int bound) { + int timeout = ThreadLocalRandom.current().nextInt(origin, bound); + String elem = timeout + "#" + UUID.randomUUID(); + queue.add(current, timeout, elem); + return new Tuple<>(timeout, elem); + } + + public void buildRandomTest(int origin, int bound, int taskNum) { + final TreeMap> taskMap = new TreeMap<>(); + for (int i = 0; i < taskNum; i++) { + Tuple tuple = pushRandomTimeTask(origin, bound); + taskMap.computeIfAbsent(tuple.getKey(), k -> new ArrayList<>()).add(tuple.getValue()); + } + + for (Map.Entry> entry : taskMap.entrySet()) { + long duration = queue.nextTime(current); + sleepForQueue(duration); + + List strings = entry.getValue(); + for (String ignored : strings) { + Assert.assertEquals(0, queue.nextTime(current)); + String poll = queue.poll(); + Assert.assertTrue(String.format("timestamp=%d, poll=%s", entry.getKey(), poll), strings.contains(poll)); + } + } + Assert.assertTrue(queue.isEmpty()); + } + + @Test + public void firstWheel() { + buildRandomTest(1, TimeWheel.WHEEL_SIZE, 1000); + } + + @Test + public void highWheel() { + buildRandomTest(TimeWheel.WHEEL_SIZE, (int) Math.pow(TimeWheel.WHEEL_SIZE, 4), 1000); + } + + @Test + public void outOfWheel() { + buildRandomTest((int) Math.pow(TimeWheel.WHEEL_SIZE, 4), (int) Math.pow(TimeWheel.WHEEL_SIZE, 5), 1000); + } + + @Test + public void level1() { + String elem = UUID.randomUUID().toString(); + queue.add(current, 10, elem); + Assert.assertNull(queue.poll()); + + sleepForQueue(9); + Assert.assertNull(queue.poll()); + + sleepForQueue(1); + Assert.assertEquals(elem, queue.poll()); + Assert.assertNull(queue.poll()); + } + + @Test + public void sameTime() { + String elem = UUID.randomUUID().toString(); + queue.add(current, 10, elem); + String elem2 = UUID.randomUUID().toString(); + queue.add(current, 10, elem2); + Assert.assertNull(queue.poll()); + + sleepForQueue(9); + Assert.assertNull(queue.poll()); + + Assert.assertEquals(0, sleepForQueue(1)); + Assert.assertEquals(elem, queue.poll()); + + Assert.assertEquals(0, queue.nextTime(current)); + Assert.assertEquals(elem2, queue.poll()); + Assert.assertNull(queue.poll()); + } +}