Skip to content

Commit

Permalink
fix forward batch (#1572)
Browse files — browse the repository at this point in the history
Co-authored-by: thucpham <[email protected]>
  • Loading branch information
minhthuc2502 and thucpham authored Dec 4, 2023
1 parent c6f7f3b commit 01cf79d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/layers/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ namespace ctranslate2 {
_alibi,
position_bias);

if (prefilling && cached_keys->shape()[2] > _sliding_window) {
if (prefilling && cached_keys && cached_keys->shape()[2] > _sliding_window) {
// set only last sliding_window tokens to cached_keys and cached_values after computing attention
const ops::Slide slide_op(2, cached_keys->shape()[2] - _sliding_window, _sliding_window);
StorageView tmp(dtype, device);
Expand Down
1 change: 1 addition & 0 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ namespace ctranslate2 {
layer_attention = std::make_unique<StorageView>(dtype, device);

dim_t offset = _sliding_window * i + step;
offset = offset < 0 ? 0 : offset;
if (i > 0) {
auto max_tokens = _sliding_window + layer_in_chunk.dim(1);
StorageView tmp_lengths = StorageView(Shape{layer_in_chunk.dim(0)}, int32_t(max_tokens), device);
Expand Down

0 comments on commit 01cf79d

Please sign in to comment.