Skip to content

Commit

Permalink
Fix pointer of layer_in
Browse files Browse the repository at this point in the history
  • Loading branch information
thucpham committed Oct 31, 2023
1 parent 1ec236f commit be1275f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace ctranslate2 {
const dim_t _num_heads_kv;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
- const int _slide_window;
+ const dim_t _slide_window;
};

enum class RotaryScalingType {
Expand Down
16 changes: 8 additions & 8 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -595,12 +595,12 @@ namespace ctranslate2 {

int offset = 0;

- StorageView* layer_in_slided;
+ StorageView layer_in_slided;
StorageView tmp(dtype, device);
if (!use_slide_window)
- layer_in_slided = &layer_in;
+ layer_in_slided = std::move(layer_in);
else {
- layer_in_slided = &tmp;
+ layer_in_slided = std::move(tmp);
}

for (size_t l = 0; l < _layers.size(); ++l) {
Expand All @@ -626,12 +626,12 @@ namespace ctranslate2 {
}
if (_slide_window * (offset + 1) >= max_time) {
const ops::Slide slide_op(1, max_time - _slide_window, _slide_window);
- slide_op(layer_in, *layer_in_slided);
+ slide_op(layer_in, layer_in_slided);
use_slide_window = false;
}
else {
const ops::Slide slide_op(1, _slide_window * offset, _slide_window);
- slide_op(layer_in, *layer_in_slided);
+ slide_op(layer_in, layer_in_slided);
++offset;
}
}
Expand All @@ -643,7 +643,7 @@ namespace ctranslate2 {

bool compute_attn_ws = _slide_window > 0 && !cached_self_attn_keys->empty()
&& cached_self_attn_keys->dim(2) >= _slide_window;
- (*_layers[l])(*layer_in_slided,
+ (*_layers[l])(layer_in_slided,
input_lengths_mask.get(),
memory,
memory_lengths_mask.get(),
Expand All @@ -659,14 +659,14 @@ namespace ctranslate2 {
&position_bias,
compute_attn_ws);
if (!use_slide_window)
- layer_in_slided = &layer_out;
+ layer_in_slided = std::move(layer_out);

if (layer_attention) {
alignment_heads.emplace_back(dtype, device);
ops::Gather(1, 1)(*layer_attention, *heads_to_select, alignment_heads.back());
}
}
- layer_in = *layer_in_slided;
+ layer_in = std::move(layer_in_slided);

if (step == 0) {
// The memory is no longer needed as its projections were cached in the first step.
Expand Down

0 comments on commit be1275f

Please sign in to comment.