Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time wheel implementation #19

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ public T get() {
return elem;
}

@Override
public void removeSelf() {
queue.queue.remove(this);
queue.remove(this);
}
}
113 changes: 104 additions & 9 deletions base/src/main/java/vproxy/base/util/time/impl/TimeQueueImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,132 @@
import vproxy.base.util.time.TimeElem;
import vproxy.base.util.time.TimeQueue;

import java.util.PriorityQueue;
import java.util.*;

public class TimeQueueImpl<T> implements TimeQueue<T> {
PriorityQueue<TimeElemImpl<T>> queue = new PriorityQueue<>((a, b) -> (int) (a.triggerTime - b.triggerTime));
private static final int TIME_WHEEL_LEVEL = 4;
private static final int MAX_TIME_WHEEL_INTERVAL = 1 << (TIME_WHEEL_LEVEL * TimeWheel.WHEEL_SIZE_POWER);

private final PriorityQueue<TimeElemImpl<T>> queue = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime));

private final ArrayList<TimeWheel<T>> timeWheels;

private long lastTickTimestamp;

public TimeQueueImpl() {
this(System.currentTimeMillis());
}

public TimeQueueImpl(long currentTimestamp) {
this.timeWheels = new ArrayList<>(TIME_WHEEL_LEVEL);
for (int i = 0; i < TIME_WHEEL_LEVEL; i++) {
this.timeWheels.add(new TimeWheel<>(1 << (i * TimeWheel.WHEEL_SIZE_POWER), currentTimestamp));
}
this.lastTickTimestamp = currentTimestamp;
}

@Override
public TimeElem<T> add(long currentTimestamp, int timeout, T elem) {
TimeElemImpl<T> event = new TimeElemImpl<>(currentTimestamp + timeout, elem, this);
queue.add(event);
final TimeElemImpl<T> event = new TimeElemImpl<>(currentTimestamp + timeout, elem, this);
addTimeElem(event, currentTimestamp);
return event;
}

private void addTimeElem(TimeElemImpl<T> event, long currentTimestamp) {
long timeout = event.triggerTime - currentTimestamp;
if (timeout >= MAX_TIME_WHEEL_INTERVAL) {
// long timeout task put into queue
queue.add(event);
} else if (timeout <= 0) {
// already timeout task put into the lowest time wheel
this.timeWheels.get(0).add(event, currentTimestamp);
} else {
var index = findTimeWheelIndex(timeout);
this.timeWheels.get(index).add(event, currentTimestamp);
}
}

@Override
public T poll() {
TimeElemImpl<T> elem = queue.poll();
if (elem == null)
TimeElem<T> elem = timeWheels.get(0).poll();
if (elem == null) {
return null;
return elem.elem;
}
return elem.get();
}

@Override
public boolean isEmpty() {
return queue.isEmpty();
for (TimeWheel<T> timeWheel : timeWheels) {
if (!timeWheel.isEmpty()) {
return false;
}
}
return true;
}

@Override
public int nextTime(long currentTimestamp) {
tickTimeWheel(currentTimestamp);
for (TimeWheel<T> timeWheel : timeWheels) {
if (timeWheel.isEmpty()) {
continue;
}
return timeWheel.nextTime(currentTimestamp);
}

TimeElemImpl<T> elem = queue.peek();
if (elem == null)
if (elem == null) {
return Integer.MAX_VALUE;
}
long triggerTime = elem.triggerTime;
return Math.max((int) (triggerTime - currentTimestamp), 0);
}

private void tickTimeWheel(long currentTimestamp) {
if (currentTimestamp <= this.lastTickTimestamp) {
return;
}

for (int i = TIME_WHEEL_LEVEL - 1; i > 0; i--) {
final var wheel = timeWheels.get(i);
while (wheel.tryTick(currentTimestamp)) {
final Collection<TimeElemImpl<T>> events = wheel.tick(currentTimestamp);
for (TimeElemImpl<T> event : events) {
addTimeElem(event, currentTimestamp);
}
}
}

// move elements from queue to time wheels
while (!queue.isEmpty()) {
final TimeElemImpl<T> elem = queue.peek();
long timeout = elem.triggerTime - currentTimestamp;
if (timeout >= MAX_TIME_WHEEL_INTERVAL) {
break;
}

addTimeElem(elem, currentTimestamp);
queue.poll();
}

this.lastTickTimestamp = currentTimestamp;
}

public void remove(TimeElemImpl<T> elem) {
long timeout = elem.triggerTime - this.lastTickTimestamp;
if (timeout >= MAX_TIME_WHEEL_INTERVAL) {
queue.remove(elem);
} else {
timeWheels.get(findTimeWheelIndex(timeout)).remove(elem);
}
}

private static int findTimeWheelIndex(long timeout) {
if (timeout <= 0) {
return 0;
}
int bits = 63 - Long.numberOfLeadingZeros(timeout);
return bits / TimeWheel.WHEEL_SIZE_POWER;
}
}
120 changes: 120 additions & 0 deletions base/src/main/java/vproxy/base/util/time/impl/TimeWheel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package vproxy.base.util.time.impl;

import vproxy.base.util.time.TimeElem;

import java.util.*;

public class TimeWheel<T> {
public static final int WHEEL_SIZE_POWER = 5;
public static final int WHEEL_SIZE = 1 << WHEEL_SIZE_POWER;

private final PriorityQueue<TimeElemImpl<T>>[] slots = new PriorityQueue[WHEEL_SIZE];
/**
* min time unit in this time wheel
*/
public final long tickDuration;
/**
* the time wheel max time interval. interval = tickDuration * WHEEL_SIZE
*/
public final long interval;
public final long startTimestamp;
private int tickIndex;
private long elemNum;
private long currentTime;

public TimeWheel(long tickDuration, long timestamp) {
this.tickDuration = tickDuration;
this.interval = this.tickDuration * WHEEL_SIZE;
this.startTimestamp = timestamp;
this.currentTime = timestamp;
this.tickIndex = findSlotIndex(timestamp);
this.elemNum = 0;

for (int i = 0; i < slots.length; i++) {
slots[i] = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime));
}
}

public void add(TimeElemImpl<T> elem, long timestamp) {
if (elem.triggerTime <= timestamp) {
slots[tickIndex].add(elem);
} else {
slots[findSlotIndex(elem.triggerTime)].add(elem);
}
elemNum++;
}

private int findSlotIndex(long timestamp) {
long timeout = timestamp - startTimestamp;
return (int) ((timeout & (interval - 1)) / tickDuration);
}

/**
* return true if it can move.
*/
public boolean tryTick(long timestamp) {
return timestamp - currentTime >= tickDuration;
}

/**
* move the tick index to point the next slot.
*/
public Collection<TimeElemImpl<T>> tick(long timestamp) {
if (!tryTick(timestamp)) {
return Collections.emptyList();
}

int oldIndex = tickIndex;
int nextIndex = (oldIndex + 1) & (WHEEL_SIZE - 1);
if (!slots[oldIndex].isEmpty()) {
slots[nextIndex].addAll(slots[oldIndex]);
slots[oldIndex].clear();
}
this.tickIndex = nextIndex;
final PriorityQueue<TimeElemImpl<T>> queue = slots[tickIndex];
slots[tickIndex] = new PriorityQueue<>(Comparator.comparingLong(x -> x.triggerTime));

elemNum -= queue.size();
currentTime += tickDuration;
return queue;
}

public TimeElem<T> poll() {
var elem = slots[tickIndex].poll();
if (elem != null) {
elemNum--;
}
return elem;
}

public boolean isEmpty() {
return elemNum == 0;
}

public long size() {
return elemNum;
}

public int nextTime(long timestamp) {
for (int i = tickIndex; i < tickIndex + WHEEL_SIZE; i++) {
final int index = i & (WHEEL_SIZE - 1);
if (slots[index].isEmpty()) {
continue;
}

long triggerTime = slots[index].peek().triggerTime;
int nextTime = Math.max((int) (triggerTime - timestamp), 0);
if (nextTime == 0 && index != tickIndex){
slots[tickIndex].add(slots[index].poll());
}
return nextTime;
}
return Integer.MAX_VALUE;
}

public void remove(TimeElemImpl<T> elem) {
if (slots[findSlotIndex(elem.triggerTime)].remove(elem)) {
elemNum--;
}
}
}
109 changes: 109 additions & 0 deletions test/src/test/java/vproxy/test/cases/TestTimeQueue.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package vproxy.test.cases;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import vproxy.base.util.Tuple;
import vproxy.base.util.time.TimeQueue;
import vproxy.base.util.time.impl.TimeQueueImpl;
import vproxy.base.util.time.impl.TimeWheel;

import java.util.*;
import java.util.concurrent.ThreadLocalRandom;

public class TestTimeQueue {
private TimeQueue<String> queue;
private long current = 0L;

@Before
public void setUp() throws Exception {
queue = new TimeQueueImpl<>(current);
}

@After
public void tearDown() throws Exception {
current = 0L;
}

private int sleepForQueue(long duration) {
current += duration;
return queue.nextTime(current);
}

private Tuple<Integer, String> pushRandomTimeTask(int origin, int bound) {
int timeout = ThreadLocalRandom.current().nextInt(origin, bound);
String elem = timeout + "#" + UUID.randomUUID();
queue.add(current, timeout, elem);
return new Tuple<>(timeout, elem);
}

public void buildRandomTest(int origin, int bound, int taskNum) {
final TreeMap<Integer, List<String>> taskMap = new TreeMap<>();
for (int i = 0; i < taskNum; i++) {
Tuple<Integer, String> tuple = pushRandomTimeTask(origin, bound);
taskMap.computeIfAbsent(tuple.getKey(), k -> new ArrayList<>()).add(tuple.getValue());
}

for (Map.Entry<Integer, List<String>> entry : taskMap.entrySet()) {
long duration = queue.nextTime(current);
sleepForQueue(duration);

List<String> strings = entry.getValue();
for (String ignored : strings) {
Assert.assertEquals(0, queue.nextTime(current));
String poll = queue.poll();
Assert.assertTrue(String.format("timestamp=%d, poll=%s", entry.getKey(), poll), strings.contains(poll));
}
}
Assert.assertTrue(queue.isEmpty());
}

@Test
public void firstWheel() {
buildRandomTest(1, TimeWheel.WHEEL_SIZE, 1000);
}

@Test
public void highWheel() {
buildRandomTest(TimeWheel.WHEEL_SIZE, (int) Math.pow(TimeWheel.WHEEL_SIZE, 4), 1000);
}

@Test
public void outOfWheel() {
buildRandomTest((int) Math.pow(TimeWheel.WHEEL_SIZE, 4), (int) Math.pow(TimeWheel.WHEEL_SIZE, 5), 1000);
}

@Test
public void level1() {
String elem = UUID.randomUUID().toString();
queue.add(current, 10, elem);
Assert.assertNull(queue.poll());

sleepForQueue(9);
Assert.assertNull(queue.poll());

sleepForQueue(1);
Assert.assertEquals(elem, queue.poll());
Assert.assertNull(queue.poll());
}

@Test
public void sameTime() {
String elem = UUID.randomUUID().toString();
queue.add(current, 10, elem);
String elem2 = UUID.randomUUID().toString();
queue.add(current, 10, elem2);
Assert.assertNull(queue.poll());

sleepForQueue(9);
Assert.assertNull(queue.poll());

Assert.assertEquals(0, sleepForQueue(1));
Assert.assertEquals(elem, queue.poll());

Assert.assertEquals(0, queue.nextTime(current));
Assert.assertEquals(elem2, queue.poll());
Assert.assertNull(queue.poll());
}
}