Skip to content

Commit

Permalink
Add some tests for the generated mask.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 23, 2024
1 parent 4244d59 commit 13a541e
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions candle-nn/tests/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,36 @@ fn rotating_kv_cache() -> Result<()> {
assert_eq!(cache.current_seq_len(), 13);
assert_eq!(cache.offset(), 1);

let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
);
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
assert_eq!(cache.current_seq_len(), 22);
assert_eq!(cache.offset(), 0);

let mask = cache.attn_mask(1, &Device::Cpu)?;
assert!(mask.is_none());
let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let t = Tensor::new(&[42.], &Device::Cpu)?;

let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
assert_eq!(cache.current_seq_len(), 23);
Expand Down

0 comments on commit 13a541e

Please sign in to comment.