diff --git a/dlc-manager/src/manager.rs b/dlc-manager/src/manager.rs index 1e916096..93d025a6 100644 --- a/dlc-manager/src/manager.rs +++ b/dlc-manager/src/manager.rs @@ -275,14 +275,28 @@ where &mut self, contract_input: &ContractInput, counter_party: PublicKey, + oracle_announcements: Option>>, ) -> Result { contract_input.validate()?; - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + let oracle_announcements = match oracle_announcements { + Some(announcements) => { + // validate that the number of announcements is correct + if announcements.len() != contract_input.contract_infos.len() { + return Err(Error::InvalidParameters(format!( + "Expected {} oracle announcement vectors, got {}", + contract_input.contract_infos.len(), + announcements.len() + ))); + } + announcements + } + None => contract_input + .contract_infos + .iter() + .map(|x| self.get_oracle_announcements(&x.oracles)) + .collect::, Error>>()?, + }; let (offered_contract, offer_msg) = crate::contract_updater::offer_contract( &self.secp, @@ -728,12 +742,26 @@ where &mut self, contract_input: &ContractInput, counter_party: PublicKey, + oracle_announcements: Option>>, ) -> Result { - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + let oracle_announcements = match oracle_announcements { + Some(announcements) => { + // validate that the number of announcements is correct + if announcements.len() != contract_input.contract_infos.len() { + return Err(Error::InvalidParameters(format!( + "Expected {} oracle announcement vectors, got {}", + contract_input.contract_infos.len(), + announcements.len() + ))); + } + announcements + } + None => contract_input + .contract_infos + .iter() + .map(|x| self.get_oracle_announcements(&x.oracles)) + .collect::, Error>>()?, + }; let (offered_channel, offered_contract) = crate::channel_updater::offer_channel( &self.secp, @@ -876,15 +904,29 @@ where channel_id: &ChannelId, counter_payout: u64, contract_input: &ContractInput, + oracle_announcements: Option>>, ) -> Result<(RenewOffer, PublicKey), Error> { let mut signed_channel = get_channel_in_state!(self, channel_id, Signed, None as Option)?; - let oracle_announcements = contract_input - .contract_infos - .iter() - .map(|x| self.get_oracle_announcements(&x.oracles)) - .collect::, Error>>()?; + let oracle_announcements = match oracle_announcements { + Some(announcements) => { + // validate that the number of announcements is correct + if announcements.len() != contract_input.contract_infos.len() { + return Err(Error::InvalidParameters(format!( + "Expected {} oracle announcement vectors, got {}", + contract_input.contract_infos.len(), + announcements.len() + ))); + } + announcements + } + None => contract_input + .contract_infos + .iter() + .map(|x| self.get_oracle_announcements(&x.oracles)) + .collect::, Error>>()?, + }; let (msg, offered_contract) = crate::channel_updater::renew_offer( &self.secp, diff --git a/dlc-manager/tests/channel_execution_tests.rs b/dlc-manager/tests/channel_execution_tests.rs index 0fe286c2..82a6e5eb 100644 --- a/dlc-manager/tests/channel_execution_tests.rs +++ b/dlc-manager/tests/channel_execution_tests.rs @@ -443,6 +443,7 @@ fn channel_execution_test(test_params: TestParams, path: TestPath) { "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" .parse() .unwrap(), + None, ) .expect("Send offer error"); @@ -967,7 +968,7 @@ fn renew_channel( let (renew_offer, _) = first .lock() .unwrap() - .renew_offer(&channel_id, 100000000, contract_input) + .renew_offer(&channel_id, 100000000, contract_input, None) .expect("to be able to renew channel contract"); first_send @@ -1023,7 +1024,7 @@ fn renew_reject( let (renew_offer, _) = first .lock() .unwrap() - .renew_offer(&channel_id, 100000000, contract_input) + .renew_offer(&channel_id, 100000000, contract_input, None) .expect("to be able to renew channel contract"); first_send @@ -1064,13 +1065,13 @@ fn renew_race( let (renew_offer, _) = first .lock() .unwrap() - .renew_offer(&channel_id, 100000000, contract_input) + .renew_offer(&channel_id, 100000000, contract_input, None) .expect("to be able to renew channel contract"); let (renew_offer_2, _) = second .lock() .unwrap() - .renew_offer(&channel_id, 100000000, contract_input) + .renew_offer(&channel_id, 100000000, contract_input, None) .expect("to be able to renew channel contract"); first_send @@ -1148,7 +1149,7 @@ fn renew_timeout( let (renew_offer, _) = first .lock() .unwrap() - .renew_offer(&channel_id, 100000000, contract_input) + .renew_offer(&channel_id, 100000000, contract_input, None) .expect("to be able to offer a settlement of the contract."); first_send diff --git a/dlc-manager/tests/manager_execution_tests.rs b/dlc-manager/tests/manager_execution_tests.rs index f90e4f90..4bed7947 100644 --- a/dlc-manager/tests/manager_execution_tests.rs +++ b/dlc-manager/tests/manager_execution_tests.rs @@ -583,6 +583,7 @@ fn manager_execution_test(test_params: TestParams, path: TestPath) { "0218845781f631c48f1c9709e23092067d06837f30aa0cd0544ac887fe91ddd166" .parse() .unwrap(), + None, ) .expect("Send offer error"); diff --git a/sample/src/cli.rs b/sample/src/cli.rs index fc1484b7..4cef3b34 100644 --- a/sample/src/cli.rs +++ b/sample/src/cli.rs @@ -198,7 +198,7 @@ pub(crate) async fn poll_for_user_input( manager_clone .lock() .unwrap() - .send_offer(&contract_input, pubkey) + .send_offer(&contract_input, pubkey, None) .expect("Error sending offer"), ) } else { @@ -206,7 +206,7 @@ pub(crate) async fn poll_for_user_input( manager_clone .lock() .unwrap() - .offer_channel(&contract_input, pubkey) + .offer_channel(&contract_input, pubkey, None) .expect("Error sending offer channel"), ) } @@ -412,7 +412,7 @@ pub(crate) async fn poll_for_user_input( manager_clone .lock() .unwrap() - .renew_offer(&channel_id, counter_payout, &contract_input) + .renew_offer(&channel_id, counter_payout, &contract_input, None) .expect("Error sending offer") }) .await