Skip to content

Commit

Permalink
Avoid heap allocations when decrypting SSv2
Browse files Browse the repository at this point in the history
  • Loading branch information
akonradi-signal authored Jan 6, 2025
1 parent 132378b commit 9d6c2ac
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 65 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/protocol/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ sha2 = { workspace = true }
subtle = { workspace = true }
thiserror = { workspace = true }
uuid = { workspace = true }
zerocopy = { workspace = true, features = ["derive"] }

# WARNING: pqcrypto-kyber 0.8 and 0.7 don't actually coexist, they both depend on the same C symbols.
# We keep this here for if/when that gets cleared up.
Expand Down
60 changes: 37 additions & 23 deletions rust/protocol/benches/sealed_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: AGPL-3.0-only
//

use std::hint::black_box;
use std::time::SystemTime;

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
Expand Down Expand Up @@ -80,18 +81,27 @@ pub fn v1(c: &mut Criterion) {
.expect("valid");

let mut encrypt_it = || {
sealed_sender_encrypt_from_usmc(&bob_address, &usmc, &alice_store.identity_store, &mut rng)
black_box(
sealed_sender_encrypt_from_usmc(
&bob_address,
&usmc,
&alice_store.identity_store,
&mut rng,
)
.now_or_never()
.expect("sync")
.expect("valid")
.expect("valid"),
)
};
let encrypted = encrypt_it();

let mut decrypt_it = || {
sealed_sender_decrypt_to_usmc(&encrypted, &bob_store.identity_store)
.now_or_never()
.expect("sync")
.expect("valid")
black_box(
sealed_sender_decrypt_to_usmc(&encrypted, &bob_store.identity_store)
.now_or_never()
.expect("sync")
.expect("valid"),
)
};
assert_eq!(message, decrypt_it().contents().expect("valid"));

Expand Down Expand Up @@ -164,20 +174,22 @@ pub fn v2(c: &mut Criterion) {
.expect("valid");

let mut encrypt_it = || {
sealed_sender_multi_recipient_encrypt(
&[&bob_address],
&alice_store
.session_store
.load_existing_sessions(&[&bob_address])
.expect("present"),
[],
&usmc,
&alice_store.identity_store,
&mut rng,
black_box(
sealed_sender_multi_recipient_encrypt(
&[&bob_address],
&alice_store
.session_store
.load_existing_sessions(&[&bob_address])
.expect("present"),
[],
&usmc,
&alice_store.identity_store,
&mut rng,
)
.now_or_never()
.expect("sync")
.expect("valid"),
)
.now_or_never()
.expect("sync")
.expect("valid")
};
let outgoing = encrypt_it();

Expand All @@ -186,10 +198,12 @@ pub fn v2(c: &mut Criterion) {
assert_eq!(&incoming_recipient.service_id_string(), bob_address.name());

let mut decrypt_it = || {
sealed_sender_decrypt_to_usmc(&incoming_message, &bob_store.identity_store)
.now_or_never()
.expect("sync")
.expect("valid")
black_box(
sealed_sender_decrypt_to_usmc(&incoming_message, &bob_store.identity_store)
.now_or_never()
.expect("sync")
.expect("valid"),
)
};
assert_eq!(message, decrypt_it().contents().expect("valid"));

Expand Down
79 changes: 37 additions & 42 deletions rust/protocol/src/sealed_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use prost::Message;
use proto::sealed_sender::unidentified_sender_message::message::Type as ProtoMessageType;
use rand::{CryptoRng, Rng};
use subtle::ConstantTimeEq;
use zerocopy::{FromBytes, FromZeroes};

use crate::{
crypto, message_encrypt, proto, session_cipher, Aci, CiphertextMessageType, DeviceId,
Expand Down Expand Up @@ -514,17 +515,17 @@ impl UnidentifiedSenderMessageContent {
}
}

enum UnidentifiedSenderMessage {
enum UnidentifiedSenderMessage<'a> {
V1 {
ephemeral_public: PublicKey,
encrypted_static: Vec<u8>,
encrypted_message: Vec<u8>,
},
V2 {
ephemeral_public: PublicKey,
encrypted_message_key: Box<[u8]>,
authentication_tag: Box<[u8]>,
encrypted_message: Box<[u8]>,
encrypted_message_key: &'a [u8; sealed_sender_v2::MESSAGE_KEY_LEN],
authentication_tag: &'a [u8; sealed_sender_v2::AUTH_TAG_LEN],
encrypted_message: &'a [u8],
},
}

Expand All @@ -534,14 +535,12 @@ const SEALED_SENDER_V2_MAJOR_VERSION: u8 = 2;
const SEALED_SENDER_V2_UUID_FULL_VERSION: u8 = 0x22;
const SEALED_SENDER_V2_SERVICE_ID_FULL_VERSION: u8 = 0x23;

impl UnidentifiedSenderMessage {
fn deserialize(data: &[u8]) -> Result<Self> {
if data.is_empty() {
return Err(SignalProtocolError::InvalidSealedSenderMessage(
"Message was empty".to_owned(),
));
}
let version = data[0] >> 4;
impl<'a> UnidentifiedSenderMessage<'a> {
fn deserialize(data: &'a [u8]) -> Result<Self> {
let (version_byte, remaining) = data.split_first().ok_or_else(|| {
SignalProtocolError::InvalidSealedSenderMessage("Message was empty".to_owned())
})?;
let version = version_byte >> 4;
log::debug!(
"deserializing UnidentifiedSenderMessage with version {}",
version
Expand All @@ -550,7 +549,7 @@ impl UnidentifiedSenderMessage {
match version {
0 | SEALED_SENDER_V1_MAJOR_VERSION => {
// XXX should we really be accepted version == 0 here?
let pb = proto::sealed_sender::UnidentifiedSenderMessage::decode(&data[1..])
let pb = proto::sealed_sender::UnidentifiedSenderMessage::decode(remaining)
.map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;

let ephemeral_public = pb
Expand All @@ -572,27 +571,31 @@ impl UnidentifiedSenderMessage {
})
}
SEALED_SENDER_V2_MAJOR_VERSION => {
// Uses a flat representation: C || AT || E.pub || ciphertext
let remaining = &data[1..];
if remaining.len()
< sealed_sender_v2::MESSAGE_KEY_LEN
+ sealed_sender_v2::AUTH_TAG_LEN
+ sealed_sender_v2::PUBLIC_KEY_LEN
{
return Err(SignalProtocolError::InvalidProtobufEncoding);
/// Uses a flat representation: C || AT || E.pub || ciphertext
#[repr(packed)]
#[derive(FromBytes, FromZeroes)]
struct PrefixRepr {
encrypted_message_key: [u8; sealed_sender_v2::MESSAGE_KEY_LEN],
encrypted_authentication_tag: [u8; sealed_sender_v2::AUTH_TAG_LEN],
ephemeral_public: [u8; sealed_sender_v2::PUBLIC_KEY_LEN],
}
let (encrypted_message_key, remaining) =
remaining.split_at(sealed_sender_v2::MESSAGE_KEY_LEN);
let (encrypted_authentication_tag, remaining) =
remaining.split_at(sealed_sender_v2::AUTH_TAG_LEN);
let (ephemeral_public, encrypted_message) =
remaining.split_at(sealed_sender_v2::PUBLIC_KEY_LEN);
let (prefix, encrypted_message) =
zerocopy::Ref::<_, PrefixRepr>::new_from_prefix(remaining)
.ok_or(SignalProtocolError::InvalidProtobufEncoding)?;

let PrefixRepr {
encrypted_message_key,
encrypted_authentication_tag,
ephemeral_public,
} = prefix.into_ref();

Ok(Self::V2 {
ephemeral_public: PublicKey::from_djb_public_key_bytes(ephemeral_public)?,
encrypted_message_key: encrypted_message_key.into(),
authentication_tag: encrypted_authentication_tag.into(),
encrypted_message: encrypted_message.into(),
ephemeral_public: PublicKey::from_djb_public_key_bytes(
ephemeral_public.as_slice(),
)?,
encrypted_message_key,
authentication_tag: encrypted_authentication_tag,
encrypted_message,
})
}
_ => Err(SignalProtocolError::UnknownSealedSenderVersion(version)),
Expand Down Expand Up @@ -1817,19 +1820,11 @@ pub async fn sealed_sender_decrypt_to_usmc(
authentication_tag,
encrypted_message,
} => {
let encrypted_message_key: [u8; sealed_sender_v2::MESSAGE_KEY_LEN] =
encrypted_message_key.as_ref().try_into().map_err(|_| {
SignalProtocolError::InvalidSealedSenderMessage(format!(
"encrypted message key had incorrect length {} (should be {})",
encrypted_message_key.len(),
sealed_sender_v2::MESSAGE_KEY_LEN
))
})?;
let m = sealed_sender_v2::apply_agreement_xor(
&our_identity.into(),
&ephemeral_public,
Direction::Receiving,
&encrypted_message_key,
encrypted_message_key,
)?;

let keys = sealed_sender_v2::DerivedKeys::new(&m);
Expand All @@ -1839,7 +1834,7 @@ pub async fn sealed_sender_decrypt_to_usmc(
));
}

let mut message_bytes = encrypted_message.into_vec();
let mut message_bytes = Vec::from(encrypted_message);
Aes256GcmSiv::new(&keys.derive_k().into())
.decrypt_in_place(
// There's no nonce because the key is already one-use.
Expand All @@ -1862,7 +1857,7 @@ pub async fn sealed_sender_decrypt_to_usmc(
&usmc.sender()?.key()?.into(),
Direction::Receiving,
&ephemeral_public,
&encrypted_message_key,
encrypted_message_key,
)?;
if !bool::from(authentication_tag.ct_eq(&at)) {
return Err(SignalProtocolError::InvalidSealedSenderMessage(
Expand Down

0 comments on commit 9d6c2ac

Please sign in to comment.