Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
Scripts for producing the figures that explain the bucketing strategy (
Browse files Browse the repository at this point in the history
…#908)

* Scripts for producing the figures that explain the bucketing strategy

* move
  • Loading branch information
sxjscience authored Sep 3, 2019
1 parent a563293 commit e05288b
Showing 1 changed file with 261 additions and 0 deletions.
261 changes: 261 additions & 0 deletions tools/plot_bucketing_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['hatch.linewidth'] = 0.3
from matplotlib.patches import Rectangle

# Seed NumPy's global RNG so the randomly drawn sequence lengths are the same
# on every run.
np.random.seed(215)

# Face color of a sequence bar, indexed by bucket id.
# Alternative palettes (disabled):
#   ["#247BA0", "#70C1B3", "#B2DBBF", "#F3FFBD", "#FF1654"]
#   ["#abd8ea", "#37aff7", "#2f97ef", "#178fd7", "#0077c0"]
#   ["#e6f0e1", "#d7ecd1", "#bbdfba", "#a4cca5", "#8bbf8c"]
COLORS = ["#deecfb", "#bedaf7", "#7ab3ef", "#368ce7", "#1666ba"]

# Appearance of the hatched rectangle that marks the padded extent of a batch.
PAD_COLOR = "#F0F0F0"
PAD_ALPHA = 1.0
PAD_LINEWIDTH = 1.5

# Appearance of the individual sequence bars.
ALPHA = 1.0
RECT_LINEIWDTH = 0.5  # (sic) spelling kept: the plotting functions use this name
EDGECOLOR = 'k'
BAR_HEIGHT = 5  # not referenced by the code visible in this file

# Length range (low, high) of each bucket, in steps of 5 from 5 up to 30, and
# the matching legend text rendered as a half-open interval.
bucket_ranges = [(low, low + 5) for low in range(5, 30, 5)]
bucket_label = ["[%d, %d)" % (low, high) for low, high in bucket_ranges]
# Number of samples generated for each bucket.
bucket_nums = [48, 32, 16, 8, 8]

# Bar heights are expressed as a fraction of this length.
MAX_LEN = 30
# Figure width per sample; chosen so that sum(bucket_nums) samples give a
# figure that is 20 units wide.
BASE_RATIO = 20.0 / sum(bucket_nums)
# Horizontal gap inserted between neighbouring buckets / batches.
PAD_LEN = 0.005
BATCH_SIZE = 8


def plot_seq(seq_info, x_begin, y_begin, x_end, y_end, save_path=None, title=None):
    """Plot all samples side by side, in the order given, as colored bars.

    Parameters
    ----------
    seq_info : sequence of (seq_len, bucket_id)
        One entry per sample; ``bucket_id`` selects the bar color from ``COLORS``.
    x_begin, y_begin, x_end, y_end : float
        Bounding box that the bars are laid out in.
    save_path : str or None
        When not None, the figure is written to this path.
    title : str or None
        Caption drawn below the bars; ``'Data Samples'`` is used when None.
    """
    num_samples = len(seq_info)
    fig, ax = plt.subplots(figsize=(num_samples * BASE_RATIO, 5))
    ax.set_axis_off()

    bar_width = (x_end - x_begin) / float(num_samples)
    full_height = float(y_end - y_begin)

    # One legend handle per bucket; the last bar drawn for a bucket is kept.
    handles = [None] * len(bucket_label)
    for pos, sample in enumerate(seq_info):
        seq_len, bucket_id = sample
        bar = Rectangle((x_begin + bar_width * pos, y_begin),
                        bar_width,
                        full_height * float(seq_len) / MAX_LEN,
                        facecolor=COLORS[bucket_id],
                        linewidth=RECT_LINEIWDTH,
                        edgecolor=EDGECOLOR,
                        alpha=ALPHA)
        ax.add_patch(bar)
        handles[bucket_id] = bar

    fig.legend(handles, bucket_label, loc="upper center", ncol=len(handles),
               borderaxespad=0.05, fontsize=12)
    if title is None:
        title = r'Data Samples'
    ax.text(0.5, - 0.04, title,
            horizontalalignment='center', verticalalignment='center',
            fontsize=14)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)


def plot_bucket_seq(seq_info, x_begin, y_begin, x_end, y_end, bucket_sizes, save_path=None, title=None, sort_length=False):
    """Plot the samples grouped into consecutive buckets with a bracket under each.

    Parameters
    ----------
    seq_info : sequence of (seq_len, color_id)
        Samples in drawing order; bucket ``k`` takes the next
        ``bucket_sizes[k]`` entries.
    x_begin, y_begin, x_end, y_end : float
        Bounding box that the bars are laid out in.
    bucket_sizes : list of int
        Number of samples in every bucket.
    save_path : str or None
        When not None, the figure is written to this path.
    title : str or None
        Caption drawn below the plot; ``'Bucket Data Samples'`` when None.
    sort_length : bool
        If True, the samples of each bucket are drawn longest first.
    """
    total = len(seq_info)
    fig, ax = plt.subplots(figsize=(total * BASE_RATIO, 5))
    ax.set_axis_off()

    # The gaps between buckets are taken out of the width before it is split
    # evenly among the bars.
    num_gaps = len(bucket_sizes) - 1
    bar_width = (x_end - x_begin - PAD_LEN * num_gaps) / float(total)
    full_height = float(y_end - y_begin)
    handles = [None] * len(bucket_label)

    drawn = 0  # number of bars placed so far, over all buckets
    print(bucket_sizes)
    for bucket_id, size in enumerate(bucket_sizes):
        members = seq_info[drawn:(drawn + size)]
        if sort_length:
            members = sorted(members, key=lambda ele: ele[0], reverse=True)
        for offset in range(size):
            seq_len, color_id = members[offset]
            bar = Rectangle((x_begin + bar_width * drawn + bucket_id * PAD_LEN, y_begin),
                            bar_width,
                            full_height * float(seq_len) / MAX_LEN,
                            facecolor=COLORS[color_id],
                            linewidth=RECT_LINEIWDTH,
                            edgecolor=EDGECOLOR,
                            alpha=ALPHA)
            drawn += 1
            ax.add_patch(bar)
            handles[color_id] = bar
        # Horizontal centre of the bucket that was just drawn.
        bracket_x = x_begin + bar_width * (drawn - size / 2.0) + bucket_id * PAD_LEN
        bracket = dict(arrowstyle='-[, widthB=%g, lengthB=0.5' % (size * bar_width * 55),
                       lw=1.0)
        ax.annotate(r'$Bucket_{%d}$' % bucket_id,
                    xy=(bracket_x, y_begin - 0.02),
                    xytext=(bracket_x, y_begin - 0.1),
                    xycoords='axes fraction',
                    fontsize=12, ha='center', va='bottom',
                    arrowprops=bracket)

    fig.legend(handles, bucket_label, loc="upper center", ncol=len(handles),
               borderaxespad=0.05, fontsize=12)
    if title is None:
        title = r'Bucket Data Samples'
    ax.text(0.5, - 0.04, title,
            horizontalalignment='center', verticalalignment='center',
            fontsize=14)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)


def add_padded_batches(ax, all_batch_inds, seq_info, x_begin, y_begin, x_end, y_end):
    """Draw the given batches on ``ax`` and measure the padding they require.

    Every batch is drawn as a hatched background rectangle as tall as its
    longest sequence, with the individual sequence bars on top and an outline
    above both. The area of the background left uncovered is the padding.

    Parameters
    ----------
    ax : matplotlib axes
        Axes that receives the patches and the ``B_k`` batch captions.
    all_batch_inds : list of list of int
        For every batch, the indices into ``seq_info`` of its samples.
    seq_info : sequence of (seq_len, color_id)
        Length and bucket/color id of every sample.
    x_begin, y_begin, x_end, y_end : float
        Bounding box that the batches are laid out in.

    Returns
    -------
    avg_padding : float
        Sum over all drawn samples of (longest length in the batch - own
        length), divided by ``len(seq_info)``.
    legend_objs : list
        Legend handles: index 0 is a padding rectangle, index ``color_id + 1``
        is a bar of that color (None for colors that were never drawn).
    """
    batch_num = len(all_batch_inds)
    legend_objs = [None for _ in range(len(bucket_label) + 1)]
    # The gaps between batches are taken out of the width before it is split
    # evenly among the samples.
    x_len = (x_end - x_begin - PAD_LEN * (batch_num - 1)) / float(len(seq_info))
    y_len = float(y_end - y_begin)
    cnt = 0  # number of sample bars placed so far, over all batches
    padding_val = 0.0
    for bid, batch_inds in enumerate(all_batch_inds):
        max_seq_len = max(seq_info[j][0] for j in batch_inds)
        batch_left = x_begin + x_len * cnt + bid * PAD_LEN
        batch_width = len(batch_inds) * x_len
        batch_height = y_len * float(max_seq_len) / MAX_LEN

        # Hatched background (zorder 1): stays visible wherever a sequence is
        # shorter than the longest one in the batch.
        pad_rect = Rectangle((batch_left, y_begin),
                             batch_width,
                             batch_height,
                             facecolor=PAD_COLOR,
                             linewidth=PAD_LINEWIDTH,
                             linestyle='-',
                             edgecolor=EDGECOLOR,
                             hatch='/',
                             alpha=PAD_ALPHA,
                             zorder=1)
        ax.add_patch(pad_rect)
        # Unfilled outline (zorder 20) so the batch border is drawn above the
        # sequence bars (zorder 10).
        border_rect = Rectangle((batch_left, y_begin),
                                batch_width,
                                batch_height,
                                linewidth=PAD_LINEWIDTH,
                                linestyle='-',
                                edgecolor=EDGECOLOR,
                                fill=False,
                                zorder=20)
        ax.add_patch(border_rect)
        ax.text(batch_left + batch_width / 2.0,
                y_begin - 0.03,
                r'$B_{%d}$' % bid, horizontalalignment='center', verticalalignment='center',
                fontsize=12)
        legend_objs[0] = pad_rect

        for ind in batch_inds:
            seq_len, color_id = seq_info[ind]
            # The label is looked up with the sample's own color id. The
            # previous code indexed with `bucket_id`, which is not a local
            # name here and silently resolved to a module-level variable.
            rect = Rectangle((x_begin + x_len * cnt + bid * PAD_LEN, y_begin), x_len,
                             y_len * float(seq_len) / MAX_LEN,
                             facecolor=COLORS[color_id],
                             linewidth=RECT_LINEIWDTH,
                             edgecolor=EDGECOLOR,
                             alpha=ALPHA,
                             label=bucket_label[color_id],
                             zorder=10)
            cnt += 1
            legend_objs[color_id + 1] = rect
            ax.add_patch(rect)
            padding_val += max_seq_len - seq_len
    avg_padding = padding_val / float(len(seq_info))
    return avg_padding, legend_objs

def plot_batches(seq_info, all_batch_inds, x_begin, y_begin, x_end, y_end, save_path=None,
                 title="Bucketing Strategy. "):
    """Plot one batching strategy and report its average padding.

    Parameters
    ----------
    seq_info : sequence of (seq_len, color_id)
        Length and bucket/color id of every sample.
    all_batch_inds : list of list of int
        For every batch, the indices into ``seq_info`` of its samples.
    x_begin, y_begin, x_end, y_end : float
        Bounding box handed on to ``add_padded_batches``.
    save_path : str or None
        When not None, the figure is written to this path.
    title : str
        Prefix placed before the average padding in the caption and in the
        printed summary line.

    Returns
    -------
    float
        The average padding computed by ``add_padded_batches``.
    """
    fig, ax = plt.subplots(figsize=(len(seq_info) * BASE_RATIO, 5))
    ax.set_axis_off()
    avg_pad, legend_objs = add_padded_batches(
        ax, all_batch_inds, seq_info, x_begin, y_begin, x_end, y_end)

    # The first legend handle is the padding patch, then one per bucket.
    legend_labels = ['padding'] + bucket_label
    fig.legend(legend_objs, legend_labels, loc="upper center",
               ncol=len(legend_objs), borderaxespad=0.04, fontsize=12)

    caption = r'%sAvg Pad = %.1f' % (title, avg_pad)
    ax.text(0.5, - 0.04, caption,
            horizontalalignment='center', verticalalignment='center',
            fontsize=14)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    print('%savg padding=%g' % (title, avg_pad))
    return avg_pad


def get_no_bucket_inds(seq_info, batch_size=None):
    """Split the samples into consecutive batches without any reordering.

    Parameters
    ----------
    seq_info : sequence
        The samples; only its length is used.
    batch_size : int or None
        Number of samples per batch. Defaults to the module-level
        ``BATCH_SIZE`` when None.

    Returns
    -------
    list of list of int
        Sample indices of every batch, in order. The last batch is shorter
        when the number of samples is not a multiple of ``batch_size``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    num_samples = len(seq_info)
    return [list(range(begin, min(begin + batch_size, num_samples)))
            for begin in range(0, num_samples, batch_size)]


def get_sorted_bucket_inds(seq_info, mult=4, batch_size=None):
    """Batch the samples after sorting them by length inside fixed-size chunks.

    The samples are cut into consecutive chunks of ``batch_size * mult``
    samples. Inside each chunk the indices are sorted by sequence length,
    longest first, and then cut into batches of ``batch_size``.

    Parameters
    ----------
    seq_info : sequence of (seq_len, color_id)
        The samples; ``seq_info[i][0]`` is used as the sort key.
    mult : int
        Chunk size expressed as a multiple of the batch size.
    batch_size : int or None
        Number of samples per batch. Defaults to the module-level
        ``BATCH_SIZE`` when None.

    Returns
    -------
    list of list of int
        Sample indices of every batch.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    num_samples = len(seq_info)
    chunk_size = batch_size * mult
    batch_inds = []
    for chunk_begin in range(0, num_samples, chunk_size):
        chunk_end = min(chunk_begin + chunk_size, num_samples)
        # sorted() is stable, so samples of equal length keep their input order.
        ordered = sorted(range(chunk_begin, chunk_end),
                         key=lambda idx: seq_info[idx][0],
                         reverse=True)
        batch_inds.extend(ordered[begin:begin + batch_size]
                          for begin in range(0, len(ordered), batch_size))
    return batch_inds


def get_fixed_bucket_inds(seq_info, ratio=0.0, ranges=None, batch_size=None):
    """Group the samples by their bucket id and batch every bucket separately.

    Each bucket gets its own batch size
    ``max(int(batch_size * ratio * longest / bucket_len), batch_size)``, where
    ``bucket_len`` is the upper bound of the bucket minus one and ``longest``
    is the largest ``bucket_len``. With ``ratio=0.0`` every bucket uses
    ``batch_size``.

    Parameters
    ----------
    seq_info : sequence of (seq_len, color_id)
        The samples; ``color_id`` is the index of the bucket a sample goes to.
    ratio : float
        Scaling factor in the per-bucket batch size formula above.
    ranges : list of (low, high) or None
        Bucket length ranges. Defaults to the module-level ``bucket_ranges``
        when None. The number of buckets follows from its length.
    batch_size : int or None
        Base (minimum) batch size. Defaults to the module-level
        ``BATCH_SIZE`` when None.

    Returns
    -------
    bucket_sample_ids : list of list of int
        Sample indices of every bucket, ordered from the last bucket in
        ``ranges`` to the first.
    batch_inds : list of list of int
        Sample indices of every batch, in that same bucket order.
    """
    if ranges is None:
        ranges = bucket_ranges
    if batch_size is None:
        batch_size = BATCH_SIZE

    # One index list per bucket; sized from `ranges` rather than a fixed 5.
    bucket_sample_ids = [[] for _ in ranges]
    for idx, (_, color_id) in enumerate(seq_info):
        bucket_sample_ids[color_id].append(idx)

    bucket_seq_len = [rng[1] - 1 for rng in ranges]
    longest = max(bucket_seq_len)
    bucket_batch_sizes = [max(int(batch_size * ratio * longest / float(ele_len)), batch_size)
                          for ele_len in bucket_seq_len]

    # Emit the buckets in reverse order of `ranges`.
    bucket_sample_ids = bucket_sample_ids[::-1]
    bucket_batch_sizes = bucket_batch_sizes[::-1]

    batch_inds = []
    for sample_ids, size in zip(bucket_sample_ids, bucket_batch_sizes):
        for begin in range(0, len(sample_ids), size):
            batch_inds.append(sample_ids[begin:begin + size])
    return bucket_sample_ids, batch_inds


# ---------------------------------------------------------------------------
# Driver: generate the synthetic samples and render one figure per strategy.
# ---------------------------------------------------------------------------

# Draw `bnum` random lengths from every bucket's range (np.random.randint
# excludes the upper bound) and tag each sample with the bucket it came from.
# NOTE(review): this loop leaves `bucket_id` bound at module scope after it
# finishes; check for readers of that global before restructuring it.
seq_info = []
for bucket_id, (brange, bnum) in enumerate(zip(bucket_ranges, bucket_nums)):
    for _ in range(bnum):
        seq_info.append((np.random.randint(brange[0], brange[1]), bucket_id))
# Shuffle in place so the buckets are mixed.
np.random.shuffle(seq_info)

# Figure 1: the raw, shuffled samples.
plot_seq(seq_info, 0.0, 0.0, 0.99, 0.97, save_path="data_samples.png")

# Figure 2: batches taken in order, without bucketing.
batch_inds = get_no_bucket_inds(seq_info)
plot_batches(seq_info, batch_inds, 0.0, 0.05, 0.99, 0.97,
             save_path="no_bucket_strategy.png", title="No Bucketing Strategy. ")

# Chunk sizes for the sorted strategy: BATCH_SIZE * 4 per chunk, with the last
# entry reduced by the overshoot so that the sizes sum to len(seq_info).
sorted_bucket_size = [BATCH_SIZE * 4 for _ in range(0, len(seq_info), BATCH_SIZE * 4)]
sorted_bucket_size[-1] -= BATCH_SIZE * 4 * len(sorted_bucket_size) - len(seq_info)
# Figures 3 and 4: the same chunks before and after sorting each by length.
plot_bucket_seq(seq_info, 0.0, 0.1, 0.99, 0.97, bucket_sizes=sorted_bucket_size, save_path="sorted_bucket_data_samples.png", title=r"Data Samples. Bucket Size = %d" %(BATCH_SIZE * 4))
plot_bucket_seq(seq_info, 0.0, 0.1, 0.99, 0.97, bucket_sizes=sorted_bucket_size, save_path="sorted_bucket_data_samples_after_sort.png", title=r"Sorted Data Samples. Bucket Size = %d" %(BATCH_SIZE * 4), sort_length=True)
# Disabled call; `plot_seq_batch_size_padded` is not defined in the visible source.
# plot_seq_batch_size_padded(seq_info, 0.0, 0.05, 0.99, 0.97, batch_size=BATCH_SIZE, save_path="no_bucket_strategy.png")

# Figure 5: batches produced by the sorted strategy.
batch_inds = get_sorted_bucket_inds(seq_info, mult=4)
plot_batches(seq_info, batch_inds, 0.0, 0.05, 0.99, 0.97,
             save_path="sorted_bucket_strategy.png", title=r"Sorted Bucketing Strategy. Bucket Size = %d, " % (BATCH_SIZE * 4))


# Figure 6: fixed buckets with ratio 0.0.
bucket_sample_ids, batch_inds = get_fixed_bucket_inds(seq_info, 0.0)
plot_batches(seq_info, batch_inds, 0.0, 0.05, 0.99, 0.97,
             save_path="fixed_bucket_strategy_ratio0.0.png", title="Fixed Bucketing Strategy. Ratio = 0.0, ")

# Figure 7: the samples regrouped by bucket. sum(..., []) concatenates the
# per-bucket index lists into one flat list.
plot_bucket_seq([seq_info[i] for i in sum(bucket_sample_ids, [])], 0.0, 0.1, 0.99, 0.97,
                bucket_sizes=[len(ele) for ele in bucket_sample_ids],
                save_path="fixed_bucket_data_samples.png", title=r"Reorganized Data Samples")

# Figure 8: fixed buckets with ratio 0.7.
bucket_sample_ids, batch_inds = get_fixed_bucket_inds(seq_info, 0.7)
plot_batches(seq_info, batch_inds, 0.0, 0.05, 0.99, 0.97,
             save_path="fixed_bucket_strategy_ratio0.7.png", title="Fixed Bucketing Strategy. Ratio = 0.7, ")


# Display every figure created above.
plt.show()

0 comments on commit e05288b

Please sign in to comment.