Skip to content

Commit

Permalink
feat: add time wheel
Browse files Browse the repository at this point in the history
  • Loading branch information
nintha committed Jul 19, 2021
1 parent 69deca1 commit 5bd90d0
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ public T get() {
return elem;
}

@Override
public void removeSelf() {
queue.queue.remove(this);
queue.remove(this);
}
}
113 changes: 104 additions & 9 deletions base/src/main/java/vproxy/base/util/time/impl/TimeQueueImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,132 @@
import vproxy.base.util.time.TimeElem;
import vproxy.base.util.time.TimeQueue;

import java.util.PriorityQueue;
import java.util.*;

public class TimeQueueImpl<T> implements TimeQueue<T> {
PriorityQueue<TimeElemImpl<T>> queue = new PriorityQueue<>((a, b) -> (int) (a.triggerTime - b.triggerTime));
private static final int TIME_WHEEL_LEVEL = 4;
private static final int MAX_TIME_WHEEL_INTERVAL = 1 << (TIME_WHEEL_LEVEL * TimeWheel.WHEEL_SIZE_POWER);

private final PriorityQueue<TimeElemImpl<T>> queue = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime));

private final ArrayList<TimeWheel<T>> timeWheels;

private long lastTickTimestamp;

public TimeQueueImpl() {
this(System.currentTimeMillis());
}

public TimeQueueImpl(long currentTimestamp) {
this.timeWheels = new ArrayList<>(TIME_WHEEL_LEVEL);
for (int i = 0; i < TIME_WHEEL_LEVEL; i++) {
this.timeWheels.add(new TimeWheel<>(1 << (i * TimeWheel.WHEEL_SIZE_POWER), currentTimestamp));
}
this.lastTickTimestamp = currentTimestamp;
}

@Override
public TimeElem<T> add(long currentTimestamp, int timeout, T elem) {
TimeElemImpl<T> event = new TimeElemImpl<>(currentTimestamp + timeout, elem, this);
queue.add(event);
final TimeElemImpl<T> event = new TimeElemImpl<>(currentTimestamp + timeout, elem, this);
addTimeElem(event, currentTimestamp);
return event;
}

private void addTimeElem(TimeElemImpl<T> event, long currentTimestamp) {
long timeout = event.triggerTime - currentTimestamp;
if (timeout >= MAX_TIME_WHEEL_INTERVAL) {
// long timeout task put into queue
queue.add(event);
} else if (timeout <= 0) {
// already timeout task put into the lowest time wheel
this.timeWheels.get(0).add(event, currentTimestamp);
} else {
var index = findTimeWheelIndex(timeout);
this.timeWheels.get(index).add(event, currentTimestamp);
}
}

@Override
public T poll() {
TimeElemImpl<T> elem = queue.poll();
if (elem == null)
TimeElem<T> elem = timeWheels.get(0).poll();
if (elem == null) {
return null;
return elem.elem;
}
return elem.get();
}

@Override
public boolean isEmpty() {
return queue.isEmpty();
for (TimeWheel<T> timeWheel : timeWheels) {
if (!timeWheel.isEmpty()) {
return false;
}
}
return true;
}

@Override
public int nextTime(long currentTimestamp) {
tickTimeWheel(currentTimestamp);
for (TimeWheel<T> timeWheel : timeWheels) {
if (timeWheel.isEmpty()) {
continue;
}
return timeWheel.nextTime(currentTimestamp);
}

TimeElemImpl<T> elem = queue.peek();
if (elem == null)
if (elem == null) {
return Integer.MAX_VALUE;
}
long triggerTime = elem.triggerTime;
return Math.max((int) (triggerTime - currentTimestamp), 0);
}

private void tickTimeWheel(long currentTimestamp) {
if (currentTimestamp <= this.lastTickTimestamp) {
return;
}

for (int i = TIME_WHEEL_LEVEL - 1; i > 0; i--) {
final var wheel = timeWheels.get(i);
while (wheel.tryTick(currentTimestamp)) {
final Collection<TimeElemImpl<T>> events = wheel.tick(currentTimestamp);
for (TimeElemImpl<T> event : events) {
addTimeElem(event, currentTimestamp);
}
}
}

// move elements from queue to time wheels
while (!queue.isEmpty()) {
final TimeElemImpl<T> elem = queue.peek();
long timeout = elem.triggerTime - currentTimestamp;
if (timeout >= MAX_TIME_WHEEL_INTERVAL) {
break;
}

addTimeElem(elem, currentTimestamp);
queue.poll();
}

this.lastTickTimestamp = currentTimestamp;
}

public void remove(TimeElemImpl<T> elem) {
long timeout = elem.triggerTime - this.lastTickTimestamp;
if (timeout >= MAX_TIME_WHEEL_INTERVAL) {
queue.remove(elem);
} else {
timeWheels.get(findTimeWheelIndex(timeout)).remove(elem);
}
}

private static int findTimeWheelIndex(long timeout) {
if (timeout <= 0) {
return 0;
}
int bits = 63 - Long.numberOfLeadingZeros(timeout);
return bits / TimeWheel.WHEEL_SIZE_POWER;
}
}
120 changes: 120 additions & 0 deletions base/src/main/java/vproxy/base/util/time/impl/TimeWheel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package vproxy.base.util.time.impl;

import vproxy.base.util.time.TimeElem;

import java.util.*;

public class TimeWheel<T> {
public static final int WHEEL_SIZE_POWER = 5;
public static final int WHEEL_SIZE = 1 << WHEEL_SIZE_POWER;

private final PriorityQueue<TimeElemImpl<T>>[] slots = new PriorityQueue[WHEEL_SIZE];
/**
* min time unit in this time wheel
*/
public final long tickDuration;
/**
* the time wheel max time interval. interval = tickDuration * WHEEL_SIZE
*/
public final long interval;
public final long startTimestamp;
private int tickIndex;
private long elemNum;
private long currentTime;

public TimeWheel(long tickDuration, long timestamp) {
this.tickDuration = tickDuration;
this.interval = this.tickDuration * WHEEL_SIZE;
this.startTimestamp = timestamp;
this.currentTime = timestamp;
this.tickIndex = findSlotIndex(timestamp);
this.elemNum = 0;

for (int i = 0; i < slots.length; i++) {
slots[i] = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime));
}
}

public void add(TimeElemImpl<T> elem, long timestamp) {
if (elem.triggerTime <= timestamp) {
slots[tickIndex].add(elem);
} else {
slots[findSlotIndex(elem.triggerTime)].add(elem);
}
elemNum++;
}

private int findSlotIndex(long timestamp) {
long timeout = timestamp - startTimestamp;
return (int) ((timeout & (interval - 1)) / tickDuration);
}

/**
* return true if it can move.
*/
public boolean tryTick(long timestamp) {
return timestamp - currentTime >= tickDuration;
}

/**
* move the tick index to point the next slot.
*/
public Collection<TimeElemImpl<T>> tick(long timestamp) {
if (!tryTick(timestamp)) {
return Collections.emptyList();
}

int oldIndex = tickIndex;
int nextIndex = (oldIndex + 1) & (WHEEL_SIZE - 1);
if (!slots[oldIndex].isEmpty()) {
slots[nextIndex].addAll(slots[oldIndex]);
slots[oldIndex].clear();
}
this.tickIndex = nextIndex;
final PriorityQueue<TimeElemImpl<T>> queue = slots[tickIndex];
slots[tickIndex] = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime));

elemNum -= queue.size();
currentTime += tickDuration;
return queue;
}

public TimeElem<T> poll() {
var elem = slots[tickIndex].poll();
if (elem != null) {
elemNum--;
}
return elem;
}

public boolean isEmpty() {
return elemNum == 0;
}

public long size() {
return elemNum;
}

public int nextTime(long timestamp) {
for (int i = tickIndex; i < tickIndex + WHEEL_SIZE; i++) {
final int index = i & (WHEEL_SIZE - 1);
if (slots[index].isEmpty()) {
continue;
}

long triggerTime = slots[index].peek().triggerTime;
int nextTime = Math.max((int) (triggerTime - timestamp), 0);
if (nextTime == 0 && index != tickIndex){
slots[tickIndex].add(slots[index].poll());
}
return nextTime;
}
return Integer.MAX_VALUE;
}

public void remove(TimeElemImpl<T> elem) {
if (slots[findSlotIndex(elem.triggerTime)].remove(elem)) {
elemNum--;
}
}
}
121 changes: 121 additions & 0 deletions test/src/test/java/vproxy/test/cases/TestTimeQueue.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package vproxy.test.cases;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import vproxy.base.util.Tuple;
import vproxy.base.util.time.TimeQueue;
import vproxy.base.util.time.impl.TimeQueueImpl;
import vproxy.base.util.time.impl.TimeWheel;

import java.util.*;
import java.util.concurrent.ThreadLocalRandom;

public class TestTimeQueue {
private TimeQueue<String> queue;
private long current = 0L;

@Before
public void setUp() throws Exception {
queue = new TimeQueueImpl<>(current);
}

@After
public void tearDown() throws Exception {
current = 0L;
}

private int sleepForQueue(long duration) {
current += duration;
return queue.nextTime(current);
}

private Tuple<Integer, String> pushRandomTimeTask(int origin, int bound) {
int timeout = ThreadLocalRandom.current().nextInt(origin, bound);
String elem = timeout + "#" + UUID.randomUUID();
queue.add(current, timeout, elem);
return new Tuple<>(timeout, elem);
}

public void buildRandomTest(int origin, int bound, int taskNum) {
final TreeMap<Integer, List<String>> taskMap = new TreeMap<>();
for (int i = 0; i < taskNum; i++) {
Tuple<Integer, String> tuple = pushRandomTimeTask(origin, bound);
taskMap.computeIfAbsent(tuple.getKey(), k -> new ArrayList<>()).add(tuple.getValue());
}

for (Map.Entry<Integer, List<String>> entry : taskMap.entrySet()) {
long duration = queue.nextTime(current);
sleepForQueue(duration);

List<String> strings = entry.getValue();
for (String ignored : strings) {
Assert.assertEquals(0, queue.nextTime(current));
String poll = queue.poll();
Assert.assertTrue(String.format("timestamp=%d, poll=%s", entry.getKey(), poll), strings.contains(poll));
}
}
Assert.assertTrue(queue.isEmpty());
}

/**
* 测试第一层时间轮
*/
@Test
public void firstWheel() {
buildRandomTest(1, TimeWheel.WHEEL_SIZE, 1000);
}

/**
* 测试高层时间轮
*/
@Test
public void highWheel() {
buildRandomTest(TimeWheel.WHEEL_SIZE, (int) Math.pow(TimeWheel.WHEEL_SIZE, 4), 1000);
}

/**
* 测试时间轮溢出
*/
@Test
public void outOfWheel() {
buildRandomTest((int) Math.pow(TimeWheel.WHEEL_SIZE, 4), (int) Math.pow(TimeWheel.WHEEL_SIZE, 5), 1000);
}

/**
* 第一层时间轮边界
*/
@Test
public void level1() {
String elem = UUID.randomUUID().toString();
queue.add(current, 10, elem);
Assert.assertNull(queue.poll());

sleepForQueue(9);
Assert.assertNull(queue.poll());

sleepForQueue(1);
Assert.assertEquals(elem, queue.poll());
Assert.assertNull(queue.poll());
}

@Test
public void sameTime() {
String elem = UUID.randomUUID().toString();
queue.add(current, 10, elem);
String elem2 = UUID.randomUUID().toString();
queue.add(current, 10, elem2);
Assert.assertNull(queue.poll());

sleepForQueue(9);
Assert.assertNull(queue.poll());

Assert.assertEquals(0, sleepForQueue(1));
Assert.assertEquals(elem, queue.poll());

Assert.assertEquals(0, queue.nextTime(current));
Assert.assertEquals(elem2, queue.poll());
Assert.assertNull(queue.poll());
}
}

0 comments on commit 5bd90d0

Please sign in to comment.