Skip to content

Commit

Permalink
fix rebase + clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
thucpham committed Nov 15, 2023
1 parent a93aa75 commit 7e0d6dd
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 51 deletions.
3 changes: 1 addition & 2 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ namespace ctranslate2 {
const Padder* values_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr,
dim_t step = 0,
int chunk_index = 0) const;
dim_t offset = 0) const;

bool has_positional_embeddings() const {
return _relative_position_keys || _relative_attention_bias || _rotary_embeddings || _alibi;
Expand Down
3 changes: 1 addition & 2 deletions include/ctranslate2/layers/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ namespace ctranslate2 {
const Padder* memory_padder = nullptr,
bool return_normalized_attention = true,
StorageView* position_bias = nullptr,
dim_t step = 0,
int chunk_index = 0) const;
dim_t offset = 0) const;

DataType output_type() const override {
return _ff.output_type();
Expand Down
14 changes: 13 additions & 1 deletion python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,7 +1925,7 @@ def main():

# Cross-attention heads that are highly correlated to the word-level timing,
# i.e. the alignment between audio and text tokens.
# Obtained from https://github.com/openai/whisper/blob/v20230306/whisper/__init__.py#L31-L45
# Obtained from https://github.com/openai/whisper/blob/v20231106/whisper/__init__.py#L32-L47
_WHISPER_ALIGNMENT_HEADS = {
"openai/whisper-tiny.en": [
(1, 0),
Expand Down Expand Up @@ -2039,4 +2039,16 @@ def main():
(26, 12),
(27, 15),
],
"openai/whisper-large-v3": [
(7, 0),
(10, 17),
(12, 18),
(13, 12),
(16, 1),
(17, 14),
(19, 11),
(21, 4),
(24, 1),
(25, 6),
],
}
42 changes: 8 additions & 34 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,17 +431,14 @@ namespace ctranslate2 {
const Padder* values_padder,
bool return_normalized_attention,
StorageView* position_bias,
dim_t step,
int chunk_index) const {
dim_t offset) const {
PROFILE("MultiHeadAttention");
const Device device = queries.device();
const DataType dtype = queries.dtype();
StorageView fused_proj(dtype, device);
StorageView queries_proj(dtype, device);
StorageView keys_proj(dtype, device);
StorageView values_proj(dtype, device);
StorageView tmp_values_lengths;
const StorageView* current_values_lengths = nullptr;

const StorageView* q = &queries;
if (_layer_norm && _pre_norm) {
Expand All @@ -453,7 +450,7 @@ namespace ctranslate2 {

dim_t beam_size = 1;

bool computing_chunking_input = (_sliding_window > 0 && values_lengths);
bool prefilling = (_sliding_window > 0 && values_lengths);

if (!_self_attention) {
queries_proj = std::move(fused_proj);
Expand Down Expand Up @@ -518,33 +515,15 @@ namespace ctranslate2 {
split_heads(queries_proj, _num_heads);
}

_rotary_embeddings->apply(queries_proj, _sliding_window * chunk_index + step);
_rotary_embeddings->apply(keys_proj, _sliding_window * chunk_index + step);
_rotary_embeddings->apply(queries_proj, offset);
_rotary_embeddings->apply(keys_proj, offset);

if (_merge_time_and_head_dims) {
combine_heads(queries_proj, _num_heads);
queries_proj.reshape({queries_proj.dim(0), -1, _d_head});
}
}

if (computing_chunking_input && cached_keys && !cached_keys->empty()) {
auto max_time = _sliding_window + queries_proj.dim(2);
std::unique_ptr<const StorageView> input_lengths = std::make_unique<StorageView>(Shape{queries_proj.dim(0)}, int32_t(max_time), device);
const StorageView* lengths = input_lengths.get();
StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
*lengths,
_num_heads,
max_time,
/*mask_future=*/true,
multi_query());

StorageView tmp_init(lengths_mask.dtype(), lengths_mask.device());
tmp_values_lengths = std::move(tmp_init);
const ops::Slide slide_lengths_op(2, _sliding_window, queries_proj.dim(2));
slide_lengths_op(lengths_mask, tmp_values_lengths);
current_values_lengths = &tmp_values_lengths;
}

if (cached_keys != nullptr) {
if (cached_keys->empty()) {
*cached_keys = std::move(keys_proj);
Expand All @@ -557,7 +536,7 @@ namespace ctranslate2 {
tmp = std::move(*cached_values);
concat_op({&tmp, &values_proj}, *cached_values);

if (!computing_chunking_input && _sliding_window > 0 && cached_keys->shape()[2] > _sliding_window) {
if (!prefilling && _sliding_window > 0 && cached_keys->shape()[2] > _sliding_window) {
// only for generation
const ops::Slide slide_op(2, 1, cached_keys->shape()[2] - 1);
slide_op(*cached_keys, tmp);
Expand All @@ -569,11 +548,6 @@ namespace ctranslate2 {
}
}

if (!current_values_lengths) {
current_values_lengths = values_lengths;
}


if (cached_keys) {
keys_proj.shallow_copy(*cached_keys);
values_proj.shallow_copy(*cached_values);
Expand All @@ -583,7 +557,7 @@ namespace ctranslate2 {
dot_product_attention(queries_proj,
keys_proj,
values_proj,
current_values_lengths,
values_lengths,
_relative_position_keys,
_relative_position_values,
_relative_attention_bias,
Expand All @@ -598,7 +572,7 @@ namespace ctranslate2 {
_alibi,
position_bias);

if (computing_chunking_input && cached_keys->shape()[2] > _sliding_window) {
if (prefilling && cached_keys->shape()[2] > _sliding_window) {
// set only last sliding_window tokens to cached_keys and cached_values after computing attention
const ops::Slide slide_op(2, cached_keys->shape()[2] - _sliding_window, _sliding_window);
StorageView tmp(dtype, device);
Expand Down Expand Up @@ -675,7 +649,7 @@ namespace ctranslate2 {

if (!_sin || offset + max_time > _sin.dim(0)) {
const dim_t cur_num_positions = _sin ? _sin.dim(0) : 0;
const dim_t new_num_positions = cur_num_positions + _num_initial_positions;
const dim_t new_num_positions = std::max(offset + max_time, cur_num_positions + _num_initial_positions);
initialize(new_num_positions, dim, device, dtype);
}

Expand Down
37 changes: 25 additions & 12 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ namespace ctranslate2 {
const Padder* memory_padder,
bool return_normalized_attention,
StorageView* position_bias,
dim_t step,
int chunk_index) const {
dim_t offset) const {
PROFILE("TransformerDecoderLayer");

const DataType dtype = input.dtype();
Expand Down Expand Up @@ -152,8 +151,7 @@ namespace ctranslate2 {
input_padder,
true,
position_bias,
step,
chunk_index);
offset);

if (_post_attention_layer_norm)
(*_post_attention_layer_norm)(input, hidden);
Expand All @@ -177,8 +175,7 @@ namespace ctranslate2 {
input_padder,
true,
position_bias,
step,
chunk_index);
offset);

StorageView context(dtype, device);
if (_encoder_attention) {
Expand Down Expand Up @@ -475,11 +472,9 @@ namespace ctranslate2 {

const dim_t batch_size = layer_in.dim(0);
dim_t max_time;
bool use_sliding_window = false;

if (_sliding_window > 0 && layer_in.dim(1) > _sliding_window) {
max_time = _sliding_window;
use_sliding_window = true;
} else
max_time = layer_in.dim(1);

Expand All @@ -495,13 +490,15 @@ namespace ctranslate2 {
lengths = input_lengths.get();
}

bool multi_query;

if (lengths) {
if (allow_padding_removal) {
input_padder = std::make_unique<Padder>(*lengths, max_time);
input_padder->remove_padding(layer_in);
}

const bool multi_query = _layers.front()->get_self_attention().multi_query();
multi_query = _layers.front()->get_self_attention().multi_query();

StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
*lengths,
Expand Down Expand Up @@ -551,7 +548,7 @@ namespace ctranslate2 {

while (true) {
dim_t prompt_size = layer_in.dim(1);
if (!use_sliding_window || prompt_size <= _sliding_window) {
if (prompt_size <= _sliding_window) {
layer_ins.push_back(std::move(layer_in));
break;
}
Expand Down Expand Up @@ -586,6 +583,23 @@ namespace ctranslate2 {
if (attention && heads_to_select)
layer_attention = std::make_unique<StorageView>(dtype, device);

dim_t offset = _sliding_window * i + step;
if (i > 0) {
auto max_tokens = _sliding_window + layer_in_chunk.dim(1);
StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device);
StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
tmp_lengths,
_num_heads,
max_tokens,
/*mask_future=*/true,
multi_query);

const ops::Slide slide_lengths_op(2, _sliding_window, layer_in_chunk.dim(1));
// reuse tmp_lengths
slide_lengths_op(lengths_mask, tmp_lengths);
input_lengths_mask = std::make_unique<StorageView>(std::move(tmp_lengths));
}

(*_layers[l])(layer_in_chunk,
input_lengths_mask.get(),
memory,
Expand All @@ -600,8 +614,7 @@ namespace ctranslate2 {
memory_padder.get(),
return_normalized_attention(),
&position_bias,
step,
i);
offset);
layer_in_chunk = std::move(layer_out);

if (layer_attention) {
Expand Down

0 comments on commit 7e0d6dd

Please sign in to comment.