Bug 1949394 - Part 2: Vendor mls-platform-api 5d88241b and mls-rs to b747d7ef (based on mls-rs 0.45.0). r=nss-reviewers,nkulatova

Differential Revision: https://phabricator.services.mozilla.com/D238924
This commit is contained in:
Benjamin Beurdouche
2025-02-21 11:20:34 +00:00
parent 1a446fa999
commit 5007442efa
140 changed files with 5263 additions and 2799 deletions

View File

@@ -1 +1 @@
{"files":{".github/workflows/test.yml":"d36004039f93ba4ca04a12cedc5fdb2e0acce7b415a9f0c46973e04dc199edd4","Cargo.lock":"329280f4d248432d82e3d86255ea615c1dedce49d35182ebed8e1bea608ce4e3","Cargo.toml":"b3225a1e31abbf72cfa2be631e55f054d4c68d20748cfd32faae370c6776ceb7","LICENSE-apache":"cfc7749b96f63bd31c3c42b5c471bf756814053e847c10f3eb003417bc523d30","LICENSE-mit":"929fb738c9a0bc0c816ec321b201c23b8980a9f161f88e85341d1fd0e5c075b3","README.md":"557b9747d251f151b40f2abd92bfc4c01ded4fa0b21f03e36bf54ccb0d044e1e","src/lib.rs":"6e22184c054f8900b505be477f548053c501004591d38ffe95aab3d4485ebb56","src/more.rs":"a3e95bf8fef7393ae58a763c9c6dbefe917c16712b51fee6baa9e7dbefb56017","src/state.rs":"35c5f02862a564eaad336d4cf893635592de5f96fefb65c78cedae01a69d5db1","tests/group_add.rs":"761b52534e899a660e84ea569c646d117cca683ed3fb01a70de20d3ff2596f2d","tests/group_close.rs":"43de437fad5150d134939fa457cd6842d9445ec2be5a729d82c06d469a4bd507","tests/group_create.rs":"ffa625df62b436d956a16d3a7cb07c59392cfa54f9e7ab5d33d213e4aa090766","tests/group_join.rs":"4d7fb2de369ac52bcea65a7c546d43cc91a988cb59dd06a5372f23b350ea040b","tests/group_propose_add.rs":"10da5b2b7911e134dd792f58efbdcd0bee0af3f8dd1e88b6d5f36ab5bcbaee04","tests/group_propose_remove.rs":"386763b93c1a36bcbf8274294a0781ac47b1cfb292643f591116bab695141de5","tests/group_remove.rs":"5cfd673f1e72b0f7b28151bfb2e453e543b5a2fe7037113695c363d4d5e1359f","tests/group_update.rs":"6af952fdedd7860770b4892b276c7dd385fab5e781d68a0527248a6c8a9adc7b","tests/main.rs":"8ffcddceac8524f4a40eca68fcf4489db0971d9f5c364e617fc4f1f1014c4f77","tests/send_receive.rs":"17597ba0e6c5a81547251f5635847d9699de129c86a5ae37343ec9108919f1a1"},"package":null} {"files":{".github/workflows/test.yml":"d36004039f93ba4ca04a12cedc5fdb2e0acce7b415a9f0c46973e04dc199edd4","Cargo.lock":"4b8d6356c7f41dbe9114b2f35b620a28d08b13cc57c5491acc0a946f3f3146bc","Cargo.toml":"259fca6aecc50645fd915271f26e23d7c9428628264c6560ace21eed33ba5442","LICENSE-apache":"cfc7749b96f63bd31c3c42b5c471bf756814053e847c10f3eb003417bc523d30","LICENSE-mit":"929fb738c9a0bc0c816ec321b201c23b8980a9f161f88e85341d1fd0e5c075b3","README.md":"557b9747d251f151b40f2abd92bfc4c01ded4fa0b21f03e36bf54ccb0d044e1e","src/lib.rs":"d10f601a77a374643f6427cd67f4bfc4c252a1d6966a4f864360db6e0d64fa50","src/more.rs":"a3e95bf8fef7393ae58a763c9c6dbefe917c16712b51fee6baa9e7dbefb56017","src/state.rs":"b6008a69016795b9fecd20f8110b26c5f844d5f3f56ddc9d840ac31cd4eaf27d","tests/clear_pending_commit.rs":"476e747e8ae576e833a7ec14925cb075a5d581811504a114762476fe9accb27c","tests/clear_pending_proposals.rs":"74df241422e5fab45c284213ef685f8948772e642d0b1ea8a3ef3983dcbcc870","tests/group_add.rs":"1fa4a21203a57c625e78d9397bb1e4cefb44f8e0a1745c66d780989590558fe0","tests/group_close.rs":"b4d5b065a24c6cf70f82b9574c4da78685d1f3bd275ece96f5d6cb91ac2c5323","tests/group_create.rs":"03d7c9f612e1425d1ee7f66e7997974bd456416e445a72e6ced6e227b6798640","tests/group_join.rs":"ddc984f88320bc89cc612cbbae5801fe107c97b1bfa485a8f105a11943ebf7f6","tests/group_propose_add.rs":"68b6b1bcba6e84047fff4ca14a090dcf7f47ae565dcd6d62055e62f403d42204","tests/group_propose_remove.rs":"0421bbc2f1f32039cae9621dd21bc92cb8afb6123213ff4f8e4e813df0b920a6","tests/group_remove.rs":"33a5c0d2e2feca86fda10e4efdb50b7331ac6a40c22f71d12cd72b5ff33cde82","tests/group_update.rs":"4eeb87e89be776c3e4efc28641cf286f6a2305c44ef8f08a75c674728bbdc332","tests/main.rs":"704ff188009a21657880dc409770e9d3b9112caf5c0d3238bf9820883d338ce8","tests/send_receive.rs":"da7b489f73ca1c269aa27f4bbcfbd9be41f0f42b5ff7ace3a44f81019616c8d9"},"package":null}

View File

@@ -489,8 +489,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs" name = "mls-rs"
version = "0.39.1" version = "0.45.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"cfg-if", "cfg-if",
@@ -514,8 +514,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-codec" name = "mls-rs-codec"
version = "0.5.3" version = "0.6.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"mls-rs-codec-derive", "mls-rs-codec-derive",
"thiserror", "thiserror",
@@ -524,8 +524,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-codec-derive" name = "mls-rs-codec-derive"
version = "0.1.1" version = "0.2.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"darling", "darling",
"proc-macro2", "proc-macro2",
@@ -535,8 +535,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-core" name = "mls-rs-core"
version = "0.18.0" version = "0.21.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"hex", "hex",
@@ -551,8 +551,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-crypto-hpke" name = "mls-rs-crypto-hpke"
version = "0.9.0" version = "0.14.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"cfg-if", "cfg-if",
@@ -566,7 +566,7 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-crypto-nss" name = "mls-rs-crypto-nss"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"getrandom", "getrandom",
"hex", "hex",
@@ -583,8 +583,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-crypto-traits" name = "mls-rs-crypto-traits"
version = "0.10.0" version = "0.15.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"maybe-async", "maybe-async",
@@ -593,8 +593,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-identity-x509" name = "mls-rs-identity-x509"
version = "0.11.0" version = "0.15.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"maybe-async", "maybe-async",
@@ -605,8 +605,8 @@ dependencies = [
[[package]] [[package]]
name = "mls-rs-provider-sqlite" name = "mls-rs-provider-sqlite"
version = "0.11.0" version = "0.15.0"
source = "git+https://github.com/beurdouche/mls-rs?rev=eedb37e50e3fca51863f460755afd632137da57c#eedb37e50e3fca51863f460755afd632137da57c" source = "git+https://github.com/beurdouche/mls-rs?rev=b747d7efb85a776b97ad8afa8d1b32893fa5efa3#b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"hex", "hex",

View File

@@ -26,6 +26,14 @@ license = "Apache-2.0 OR MIT"
name = "mls_platform_api" name = "mls_platform_api"
path = "src/lib.rs" path = "src/lib.rs"
[[test]]
name = "clear_pending_commit"
path = "tests/clear_pending_commit.rs"
[[test]]
name = "clear_pending_proposals"
path = "tests/clear_pending_proposals.rs"
[[test]] [[test]]
name = "group_add" name = "group_add"
path = "tests/group_add.rs" path = "tests/group_add.rs"
@@ -78,7 +86,7 @@ features = ["serde"]
[dependencies.mls-rs] [dependencies.mls-rs]
git = "https://github.com/beurdouche/mls-rs" git = "https://github.com/beurdouche/mls-rs"
rev = "eedb37e50e3fca51863f460755afd632137da57c" rev = "b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
features = [ features = [
"sqlcipher-bundled", "sqlcipher-bundled",
"serde", "serde",
@@ -86,11 +94,11 @@ features = [
[dependencies.mls-rs-crypto-nss] [dependencies.mls-rs-crypto-nss]
git = "https://github.com/beurdouche/mls-rs" git = "https://github.com/beurdouche/mls-rs"
rev = "eedb37e50e3fca51863f460755afd632137da57c" rev = "b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
[dependencies.mls-rs-provider-sqlite] [dependencies.mls-rs-provider-sqlite]
git = "https://github.com/beurdouche/mls-rs" git = "https://github.com/beurdouche/mls-rs"
rev = "eedb37e50e3fca51863f460755afd632137da57c" rev = "b747d7efb85a776b97ad8afa8d1b32893fa5efa3"
[dependencies.serde] [dependencies.serde]
version = "1.0" version = "1.0"

View File

@@ -5,7 +5,7 @@ mod state;
use mls_rs::error::{AnyError, IntoAnyError}; use mls_rs::error::{AnyError, IntoAnyError};
use mls_rs::group::proposal::{CustomProposal, ProposalType}; use mls_rs::group::proposal::{CustomProposal, ProposalType};
use mls_rs::group::{Capabilities, ExportedTree, ReceivedMessage}; use mls_rs::group::{Capabilities, CommitEffect, ExportedTree, ReceivedMessage};
use mls_rs::identity::SigningIdentity; use mls_rs::identity::SigningIdentity;
use mls_rs::mls_rs_codec::{MlsDecode, MlsEncode}; use mls_rs::mls_rs_codec::{MlsDecode, MlsEncode};
use mls_rs::{CipherSuiteProvider, CryptoProvider, Extension, ExtensionList}; use mls_rs::{CipherSuiteProvider, CryptoProvider, Extension, ExtensionList};
@@ -171,7 +171,7 @@ pub fn mls_generate_credential_basic(content: &[u8]) -> Result<MlsCredential, Pl
/// ///
/// Generate a Signature Keypair /// Generate a Signature Keypair
/// ///
pub fn mls_generate_signature_keypair( pub fn mls_generate_identity(
state: &PlatformState, state: &PlatformState,
cs: CipherSuite, cs: CipherSuite,
// _randomness: Option<Vec<u8>>, // _randomness: Option<Vec<u8>>,
@@ -218,7 +218,11 @@ pub fn mls_generate_key_package(
let client = state.client(myself, Some(decoded_cred), ProtocolVersion::MLS_10, config)?; let client = state.client(myself, Some(decoded_cred), ProtocolVersion::MLS_10, config)?;
// Generate a KeyPackage from that client_default // Generate a KeyPackage from that client_default
let key_package = client.generate_key_package_message()?; let key_package_extensions = config.key_package_extensions.clone().unwrap_or_default();
let leaf_node_extensions = config.leaf_node_extensions.clone().unwrap_or_default();
let key_package =
client.generate_key_package_message(key_package_extensions, leaf_node_extensions)?;
// Result // Result
Ok(key_package) Ok(key_package)
@@ -236,7 +240,7 @@ pub struct ClientIdentifiers {
} }
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct GroupMembers { pub struct GroupDetails {
pub group_id: MlsGroupId, pub group_id: MlsGroupId,
pub group_epoch: u64, pub group_epoch: u64,
pub group_members: Vec<ClientIdentifiers>, pub group_members: Vec<ClientIdentifiers>,
@@ -244,11 +248,11 @@ pub struct GroupMembers {
// Note: The identity is needed because it is allowed to have multiple // Note: The identity is needed because it is allowed to have multiple
// identities in a group. // identities in a group.
pub fn mls_group_members( pub fn mls_group_details(
state: &PlatformState, state: &PlatformState,
gid: MlsGroupIdArg, gid: MlsGroupIdArg,
myself: IdentityArg, myself: IdentityArg,
) -> Result<GroupMembers, PlatformError> { ) -> Result<GroupDetails, PlatformError> {
let crypto_provider = DefaultCryptoProvider::default(); let crypto_provider = DefaultCryptoProvider::default();
let group = state.client_default(myself)?.load_group(gid)?; let group = state.client_default(myself)?.load_group(gid)?;
@@ -272,13 +276,13 @@ pub fn mls_group_members(
}) })
.collect::<Result<Vec<_>, PlatformError>>()?; .collect::<Result<Vec<_>, PlatformError>>()?;
let members = GroupMembers { let group_details = GroupDetails {
group_id: gid.to_vec(), group_id: gid.to_vec(),
group_epoch: epoch, group_epoch: epoch,
group_members: members, group_members: members,
}; };
Ok(members) Ok(group_details)
} }
/// ///
@@ -305,8 +309,12 @@ pub fn mls_group_create(
Some(gid) => client.create_group_with_id( Some(gid) => client.create_group_with_id(
gid.to_vec(), gid.to_vec(),
group_context_extensions.unwrap_or_default().clone(), group_context_extensions.unwrap_or_default().clone(),
config.leaf_node_extensions.clone().unwrap_or_default(),
)?,
None => client.create_group(
group_context_extensions.unwrap_or_default().clone(),
config.leaf_node_extensions.clone().unwrap_or_default(),
)?, )?,
None => client.create_group(group_context_extensions.unwrap_or_default().clone())?,
}; };
// The state needs to be returned or stored somewhere // The state needs to be returned or stored somewhere
@@ -637,6 +645,10 @@ pub fn mls_group_update(
commit_builder = commit_builder.set_group_context_ext(group_context_extensions)?; commit_builder = commit_builder.set_group_context_ext(group_context_extensions)?;
} }
if let Some(leaf_node_extensions) = config.leaf_node_extensions.clone() {
commit_builder = commit_builder.set_leaf_node_extensions(leaf_node_extensions);
}
let identity = if let Some((key, cred)) = signature_key.zip(credential) { let identity = if let Some((key, cred)) = signature_key.zip(credential) {
let signature_secret_key = key.to_vec().into(); let signature_secret_key = key.to_vec().into();
let signature_public_key = cipher_suite_provider let signature_public_key = cipher_suite_provider
@@ -805,29 +817,32 @@ pub fn mls_receive(
} }
ReceivedMessage::Commit(commit) => { ReceivedMessage::Commit(commit) => {
// Check if the group is active or not after applying the commit // Check if the group is active or not after applying the commit
if !commit.state_update.is_active() { match commit.effect {
// Delete the group from the state of the client CommitEffect::Removed { .. } => {
pstate.delete_group(gid, myself)?; // Delete the group from the state of the client
pstate.delete_group(gid, myself)?;
// Return the group id and 0xFF..FF epoch to signal the group is closed // Return the group id and 0xFF..FF epoch to signal the group is closed
let group_epoch = GroupIdEpoch { let group_epoch = GroupIdEpoch {
group_id: group.group_id().to_vec(), group_id: group.group_id().to_vec(),
group_epoch: 0xFFFFFFFFFFFFFFFF, group_epoch: 0xFFFFFFFFFFFFFFFF,
}; };
Ok((gid.to_vec(), Received::GroupIdEpoch(group_epoch))) Ok((gid.to_vec(), Received::GroupIdEpoch(group_epoch)))
} else { }
// TODO: Receiving a group_close commit means the sender receiving _ => {
// is left alone in the group. We should be able delete group automatically. // TODO: Receiving a group_close commit means the sender receiving
// As of now, the user calling group_close has to delete group manually. // is left alone in the group. We should be able delete group automatically.
// As of now, the user calling group_close has to delete group manually.
// If this is a normal commit, return the affected group and new epoch // If this is a normal commit, return the affected group and new epoch
let group_epoch = GroupIdEpoch { let group_epoch = GroupIdEpoch {
group_id: group.group_id().to_vec(), group_id: group.group_id().to_vec(),
group_epoch: group.current_epoch(), group_epoch: group.current_epoch(),
}; };
Ok((gid.to_vec(), Received::GroupIdEpoch(group_epoch))) Ok((gid.to_vec(), Received::GroupIdEpoch(group_epoch)))
}
} }
} }
// TODO: We could make this more user friendly by allowing to // TODO: We could make this more user friendly by allowing to
@@ -841,6 +856,27 @@ pub fn mls_receive(
Ok(result) Ok(result)
} }
pub fn mls_has_pending_proposals(
pstate: &PlatformState,
gid: MlsGroupIdArg,
myself: IdentityArg,
) -> Result<bool, PlatformError> {
let group = pstate.client_default(myself)?.load_group(gid)?;
let result = group.commit_required();
Ok(result)
}
pub fn mls_clear_pending_proposals(
pstate: &PlatformState,
gid: MlsGroupIdArg,
myself: IdentityArg,
) -> Result<bool, PlatformError> {
let mut group = pstate.client_default(myself)?.load_group(gid)?;
group.clear_proposal_cache();
group.write_to_storage()?;
Ok(true)
}
pub fn mls_has_pending_commit( pub fn mls_has_pending_commit(
pstate: &PlatformState, pstate: &PlatformState,
gid: MlsGroupIdArg, gid: MlsGroupIdArg,
@@ -875,29 +911,32 @@ pub fn mls_apply_pending_commit(
let result = match received_message? { let result = match received_message? {
ReceivedMessage::Commit(commit) => { ReceivedMessage::Commit(commit) => {
// Check if the group is active or not after applying the commit // Check if the group is active or not after applying the commit
if !commit.state_update.is_active() { match commit.effect {
// Delete the group from the state of the client CommitEffect::Removed { .. } => {
pstate.delete_group(gid, myself)?; // Delete the group from the state of the client
pstate.delete_group(gid, myself)?;
// Return the group id and 0xFF..FF epoch to signal the group is closed // Return the group id and 0xFF..FF epoch to signal the group is closed
let group_epoch = GroupIdEpoch { let group_epoch = GroupIdEpoch {
group_id: group.group_id().to_vec(), group_id: group.group_id().to_vec(),
group_epoch: 0xFFFFFFFFFFFFFFFF, group_epoch: 0xFFFFFFFFFFFFFFFF,
}; };
Ok(Received::GroupIdEpoch(group_epoch)) Ok(Received::GroupIdEpoch(group_epoch))
} else { }
// TODO: Receiving a group_close commit means the sender receiving _ => {
// is left alone in the group. We should be able delete group automatically. // TODO: Receiving a group_close commit means the sender receiving
// As of now, the user calling group_close has to delete group manually. // is left alone in the group. We should be able delete group automatically.
// As of now, the user calling group_close has to delete group manually.
// If this is a normal commit, return the affected group and new epoch // If this is a normal commit, return the affected group and new epoch
let group_epoch = GroupIdEpoch { let group_epoch = GroupIdEpoch {
group_id: group.group_id().to_vec(), group_id: group.group_id().to_vec(),
group_epoch: group.current_epoch(), group_epoch: group.current_epoch(),
}; };
Ok(Received::GroupIdEpoch(group_epoch)) Ok(Received::GroupIdEpoch(group_epoch))
}
} }
} }
_ => Err(PlatformError::UnsupportedMessage), _ => Err(PlatformError::UnsupportedMessage),
@@ -1141,53 +1180,14 @@ pub fn mls_get_group_id(message_or_ack: &MessageOrAck) -> Result<Vec<u8>, Platfo
Ok(gid.to_vec()) Ok(gid.to_vec())
} }
use serde_json::{Error, Value}; pub fn mls_get_group_epoch(message_or_ack: &MessageOrAck) -> Result<u64, PlatformError> {
let group_epoch: Option<u64> = match &message_or_ack {
MessageOrAck::MlsMessage(message) => message.epoch(),
_ => None,
};
// This function takes a JSON string and converts byte arrays into hex strings. Ok(group_epoch.expect("Group epoch not found"))
fn convert_bytes_fields_to_hex(input_str: &str) -> Result<String, Error> {
// Parse the JSON string into a serde_json::Value
let mut value: Value = serde_json::from_str(input_str)?;
// Recursive function to process each element
fn process_element(element: &mut Value) {
match element {
Value::Array(ref mut vec) => {
if vec
.iter()
.all(|x| matches!(x, Value::Number(n) if n.is_u64()))
{
// Convert all elements to a Vec<u8> if they are numbers
let bytes: Vec<u8> = vec
.iter()
.filter_map(|x| x.as_u64().map(|n| n as u8))
.collect();
// Check if the conversion makes sense (the length matches)
if bytes.len() == vec.len() {
*element = Value::String(hex::encode(bytes));
} else {
vec.iter_mut().for_each(process_element);
}
} else {
vec.iter_mut().for_each(process_element);
}
}
Value::Object(ref mut map) => {
map.values_mut().for_each(process_element);
}
_ => {}
}
}
// Process the element and return the new Json string
process_element(&mut value);
serde_json::to_string(&value)
} }
// This function accepts bytes, converts them to a string, and then processes the string. // TODO:
pub fn utils_json_bytes_to_string_custom(input_bytes: &[u8]) -> Result<String, PlatformError> { // - Is key available for the message ?
// Convert input bytes to a string
let input_str =
std::str::from_utf8(input_bytes).map_err(|_| PlatformError::JsonConversionError)?;
// Call the original function with the decoded string
convert_bytes_fields_to_hex(input_str).map_err(|_| PlatformError::JsonConversionError)
}

View File

@@ -121,14 +121,6 @@ impl PlatformState {
) )
.protocol_version(version); .protocol_version(version);
if let Some(key_package_extensions) = &config.key_package_extensions {
builder = builder.key_package_extensions(key_package_extensions.clone());
};
if let Some(leaf_node_extensions) = &config.leaf_node_extensions {
builder = builder.leaf_node_extensions(leaf_node_extensions.clone());
}
if let Some(key_package_lifetime_s) = config.key_package_lifetime_s { if let Some(key_package_lifetime_s) = config.key_package_lifetime_s {
builder = builder.key_package_lifetime(key_package_lifetime_s); builder = builder.key_package_lifetime(key_package_lifetime_s);
} }
@@ -182,7 +174,7 @@ impl PlatformState {
.map_err(|e| PlatformError::StorageError(e.into_any_error()))?; .map_err(|e| PlatformError::StorageError(e.into_any_error()))?;
storage storage
.insert(hex::encode(identifier), data) .insert(&hex::encode(identifier), &data)
.map_err(|e| PlatformError::StorageError(e.into_any_error()))?; .map_err(|e| PlatformError::StorageError(e.into_any_error()))?;
Ok(()) Ok(())
} }

View File

@@ -0,0 +1,140 @@
// Copyright (c) 2024 Mozilla Corporation and contributors.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use mls_platform_api::MessageOrAck;
use mls_platform_api::PlatformError;
//
// Scenario
//
// * Alice, Bob and Charlie create signing identity (generate_signature_keypair)
// * Alice, Bob and Charlie create credentials (generate_credential_basic)
// * Bob and Charlie create key packages (generate_key_package)
// * Alice creates a group (group_create)
// * Alice adds Bob to the group (group_add)
// - Alice has a pending commit and cannot do another operation
// * Alice clears the pending commit
// * Alice can add Charlie to the group
#[test]
fn test_clear_pending_commit() -> Result<(), PlatformError> {
// Default group configuration
let group_config = mls_platform_api::GroupConfig::default();
// Storage states
let mut state_global = mls_platform_api::state_access("global.db", &[0u8; 32])?;
// Credentials
let alice_cred = mls_platform_api::mls_generate_credential_basic("alice".as_bytes())?;
let bob_cred = mls_platform_api::mls_generate_credential_basic("bob".as_bytes())?;
let charlie_cred = mls_platform_api::mls_generate_credential_basic("charlie".as_bytes())?;
println!("\nAlice credential: {}", hex::encode(&alice_cred));
println!("Bob credential: {}", hex::encode(&bob_cred));
println!("Charlie credential: {}", hex::encode(&charlie_cred));
// Create signature keypairs and store them in the state
let alice_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let charlie_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id));
println!("Charlie identifier: {}", hex::encode(&charlie_id));
// Create Key Package for Bob
let bob_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&bob_id,
&bob_cred,
&Default::default(),
)?;
// Create Key Package for Charlie
let charlie_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&charlie_id,
&charlie_cred,
&Default::default(),
)?;
// Create a group with Alice
let gide = mls_platform_api::mls_group_create(
&mut state_global,
&alice_id,
&alice_cred,
None,
None,
&Default::default(),
)?;
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}");
//
// Alice adds Bob to a group
//
println!("\nAlice adds Bob to the Group");
let _commit_output = mls_platform_api::mls_group_add(
&mut state_global,
&gide.group_id,
&alice_id,
vec![bob_kp],
)?;
// Check if there's a pending commit
let pending =
mls_platform_api::mls_has_pending_commit(&state_global, &gide.group_id, &alice_id)?;
assert!(pending);
// Try to add Charlie while there's a pending commit - should fail
println!("\nAlice tries to add Charlie while there's a pending commit");
let result = mls_platform_api::mls_group_add(
&mut state_global,
&gide.group_id,
&alice_id,
vec![charlie_kp.clone()],
);
// Verify we get an error
assert!(result.is_err());
println!("Got expected error when trying to add with pending commit: {result:?}");
// Discard the pending commit
println!("\nAlice discards the pending commit");
mls_platform_api::mls_clear_pending_commit(&mut state_global, &gide.group_id, &alice_id)?;
// Check if there's a pending commit again
let pending =
mls_platform_api::mls_has_pending_commit(&state_global, &gide.group_id, &alice_id)?;
assert!(!pending);
// Alice can now add Charlie to the group
println!("\nAlice adds Charlie to the Group");
let commit_output = mls_platform_api::mls_group_add(
&mut state_global,
&gide.group_id,
&alice_id,
vec![charlie_kp.clone()],
)?;
// Alice process her own commit
println!("\nAlice process her commit to add Bob to the Group");
mls_platform_api::mls_receive(
&state_global,
&alice_id,
&MessageOrAck::MlsMessage(commit_output.commit),
)?;
// List the members of the group
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding charlie): {members:?}");
Ok(())
}

View File

@@ -0,0 +1,152 @@
// Copyright (c) 2024 Mozilla Corporation and contributors.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use mls_platform_api::MessageOrAck;
use mls_platform_api::PlatformError;
//
// Scenario
//
// * Alice, Bob and Charlie create signing identity (generate_signature_keypair)
// * Alice, Bob and Charlie create credentials (generate_credential_basic)
// * Bob and Charlie create key packages (generate_key_package)
// * Alice creates a group (group_create)
// * Alice proposes to add Bob to the group (group_propose_add)
// - Alice decides to clear the proposal
// * Alice can propose to add Charlie to the group
#[test]
fn test_clear_pending_proposals() -> Result<(), PlatformError> {
// Default group configuration
let group_config = mls_platform_api::GroupConfig::default();
// Storage states
let mut state_global = mls_platform_api::state_access("global.db", &[0u8; 32])?;
// Credentials
let alice_cred = mls_platform_api::mls_generate_credential_basic("alice".as_bytes())?;
let bob_cred = mls_platform_api::mls_generate_credential_basic("bob".as_bytes())?;
let charlie_cred = mls_platform_api::mls_generate_credential_basic("charlie".as_bytes())?;
println!("\nAlice credential: {}", hex::encode(&alice_cred));
println!("Bob credential: {}", hex::encode(&bob_cred));
println!("Charlie credential: {}", hex::encode(&charlie_cred));
// Create signature keypairs and store them in the state
let alice_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let charlie_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id));
println!("Charlie identifier: {}", hex::encode(&charlie_id));
// Create Key Package for Bob
let bob_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&bob_id,
&bob_cred,
&Default::default(),
)?;
// Create Key Package for Charlie
let charlie_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&charlie_id,
&charlie_cred,
&Default::default(),
)?;
// Create a group with Alice
let gide = mls_platform_api::mls_group_create(
&mut state_global,
&alice_id,
&alice_cred,
None,
None,
&Default::default(),
)?;
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}");
//
// Alice proposes to add Bob to the group
//
println!("\nAlice proposes to add Bob to the Group");
let _commit_output = mls_platform_api::mls_group_propose_add(
&mut state_global,
&gide.group_id,
&alice_id,
bob_kp,
)?;
// Check if there's a pending proposal
let pending =
mls_platform_api::mls_has_pending_proposals(&state_global, &gide.group_id, &alice_id)?;
assert!(pending);
println!("\nAlice proposes to add Charlie to the Group");
let _commit_output = mls_platform_api::mls_group_propose_add(
&mut state_global,
&gide.group_id,
&alice_id,
charlie_kp.clone(),
)?;
// Check if there's a pending proposal
let pending =
mls_platform_api::mls_has_pending_proposals(&state_global, &gide.group_id, &alice_id)?;
assert!(pending);
// Try to add Charlie while there's a pending proposal - should fail
println!("\nAlice tries to add Charlie while there's a pending proposal");
let result = mls_platform_api::mls_group_add(
&mut state_global,
&gide.group_id,
&alice_id,
vec![charlie_kp.clone()],
);
// Verify we get an error
assert!(result.is_err());
println!("Got expected error when trying to add because of pending proposal: {result:?}");
// Discard the pending proposal
println!("\nAlice discards the pending proposal");
mls_platform_api::mls_clear_pending_proposals(&mut state_global, &gide.group_id, &alice_id)?;
// Check if there's a pending proposal again
let pending =
mls_platform_api::mls_has_pending_proposals(&state_global, &gide.group_id, &alice_id)?;
assert!(!pending);
// Alice can now add Charlie to the group
println!("\nAlice adds Charlie to the Group");
let commit_output = mls_platform_api::mls_group_add(
&mut state_global,
&gide.group_id,
&alice_id,
vec![charlie_kp.clone()],
)?;
// Alice process her own commit
println!("\nAlice process her commit to add Bob to the Group");
mls_platform_api::mls_receive(
&state_global,
&alice_id,
&MessageOrAck::MlsMessage(commit_output.commit),
)?;
// List the members of the group
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding charlie): {members:?}");
Ok(())
}

View File

@@ -31,10 +31,9 @@ fn test_group_add() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -60,7 +59,7 @@ fn test_group_add() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -83,8 +82,95 @@ fn test_group_add() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
Ok(()) Ok(())
} }
#[test]
fn test_group_add_multiple() -> Result<(), PlatformError> {
// Default group configuration
let group_config = mls_platform_api::GroupConfig::default();
// Storage states
let mut state_global = mls_platform_api::state_access("global.db", &[0u8; 32])?;
// Credentials
let alice_cred = mls_platform_api::mls_generate_credential_basic("alice".as_bytes())?;
let bob_cred = mls_platform_api::mls_generate_credential_basic("bob".as_bytes())?;
let charlie_cred = mls_platform_api::mls_generate_credential_basic("charlie".as_bytes())?;
println!("\nAlice credential: {}", hex::encode(&alice_cred));
println!("Bob credential: {}", hex::encode(&bob_cred));
println!("Charlie credential: {}", hex::encode(&charlie_cred));
// Create signature keypairs and store them in the state
let alice_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let charlie_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id));
println!("Charlie identifier: {}", hex::encode(&charlie_id));
// Create Key Package for Bob
let bob_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&bob_id,
&bob_cred,
&Default::default(),
)?;
// Create Key Package for Charlie
let charlie_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&charlie_id,
&charlie_cred,
&Default::default(),
)?;
// Create a group with Alice
let gide = mls_platform_api::mls_group_create(
&mut state_global,
&alice_id,
&alice_cred,
None,
None,
&Default::default(),
)?;
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}");
//
// Alice adds Bob to a group
//
println!("\nAlice adds Bob and Charlie to the Group");
let commit_output = mls_platform_api::mls_group_add(
&mut state_global,
&gide.group_id,
&alice_id,
vec![bob_kp, charlie_kp],
)?;
// Alice process her own commit
println!("\nAlice process her commit to add Bob and Charlie to the Group");
mls_platform_api::mls_receive(
&state_global,
&alice_id,
&MessageOrAck::MlsMessage(commit_output.commit),
)?;
// List the members of the group
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob and charlie): {members:?}");
Ok(())
}

View File

@@ -44,13 +44,12 @@ fn test_group_close() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
let charlie_id = let charlie_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -85,7 +84,7 @@ fn test_group_close() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -114,7 +113,7 @@ fn test_group_close() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
// Bob joins // Bob joins
@@ -122,7 +121,7 @@ fn test_group_close() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?; mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after joining the group): {members:?}"); println!("Members (bob, after joining the group): {members:?}");
// //
@@ -174,7 +173,7 @@ fn test_group_close() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after adding charlie): {members:?}"); println!("Members (bob, after adding charlie): {members:?}");
// Alice receives the commit // Alice receives the commit
@@ -190,7 +189,7 @@ fn test_group_close() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?; mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after joining the group): {members:?}"); println!("Members (charlie, after joining the group): {members:?}");
// //
@@ -226,7 +225,7 @@ fn test_group_close() -> Result<(), PlatformError> {
println!("\nCharlie processes the close commit"); println!("\nCharlie processes the close commit");
mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_6_msg)?; mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_6_msg)?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after processing their group_close commit): {members:?}"); println!("Members (charlie, after processing their group_close commit): {members:?}");
// Charlie deletes her state for the group // Charlie deletes her state for the group

View File

@@ -26,7 +26,7 @@ fn test_group_create() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
@@ -43,7 +43,7 @@ fn test_group_create() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, empty group): {members:?}"); println!("Members (alice, empty group): {members:?}");
Ok(()) Ok(())

View File

@@ -32,10 +32,9 @@ fn test_group_join() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -61,7 +60,7 @@ fn test_group_join() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -90,7 +89,7 @@ fn test_group_join() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
// Bob joins // Bob joins
@@ -98,7 +97,7 @@ fn test_group_join() -> Result<(), PlatformError> {
let gide_2 = mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?; let gide_2 = mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?;
// List the members of the group // List the members of the group
let members_2 = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members_2 = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after joining the group): {members_2:?}"); println!("Members (bob, after joining the group): {members_2:?}");
// Assert that the group identifier is the same for Alice and Bob // Assert that the group identifier is the same for Alice and Bob
@@ -108,3 +107,121 @@ fn test_group_join() -> Result<(), PlatformError> {
assert!(members == members_2); assert!(members == members_2);
Ok(()) Ok(())
} }
#[test]
fn test_group_join_multiple_adds() -> Result<(), PlatformError> {
// Default group configuration
let group_config = mls_platform_api::GroupConfig::default();
// Storage states
let mut state_global = mls_platform_api::state_access("global.db", &[0u8; 32])?;
// Credentials
let alice_cred = mls_platform_api::mls_generate_credential_basic("alice".as_bytes())?;
let bob_cred = mls_platform_api::mls_generate_credential_basic("bob".as_bytes())?;
let charlie_cred = mls_platform_api::mls_generate_credential_basic("charlie".as_bytes())?;
println!("\nAlice credential: {}", hex::encode(&alice_cred));
println!("Bob credential: {}", hex::encode(&bob_cred));
println!("Charlie credential: {}", hex::encode(&charlie_cred));
// Create signature keypairs and store them in the state
let alice_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let charlie_id =
mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id));
println!("Charlie identifier: {}", hex::encode(&charlie_id));
// Create Key Package for Bob
let bob_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&bob_id,
&bob_cred,
&Default::default(),
)?;
// Create Key Package for Charlie
let charlie_kp = mls_platform_api::mls_generate_key_package(
&state_global,
&charlie_id,
&charlie_cred,
&Default::default(),
)?;
// Create a group with Alice
let gide = mls_platform_api::mls_group_create(
&mut state_global,
&alice_id,
&alice_cred,
None,
None,
&Default::default(),
)?;
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob and charlie): {members:?}");
//
// Alice adds Bob and Charlie to a group
//
println!("\nAlice adds Bob and Charlie to the Group");
let commit_output = mls_platform_api::mls_group_add(
&mut state_global,
&gide.group_id,
&alice_id,
vec![bob_kp, charlie_kp],
)?;
let welcome = commit_output
.welcome
.first()
.expect("No welcome messages found")
.clone();
// Alice process her own commit
println!("\nAlice process her commit to add Bob and Charlie to the Group");
mls_platform_api::mls_receive(
&state_global,
&alice_id,
&MessageOrAck::MlsMessage(commit_output.commit.clone()),
)?;
// List the members of the group from Alice's perspective
let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob and charlie): {members:?}");
// Bob joins
println!("\nBob joins the group created by Alice");
let gide_2 = mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?;
// Charlie joins
println!("\nCharlie joins the group created by Alice");
let gide_3 = mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome, None)?;
// List the members of the group from Bob's perspective
let members_2 = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after joining the group): {members_2:?}");
// List the members of the group from Charlie's perspective
let members_3 =
mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after joining the group): {members_3:?}");
// Assert that the group identifier is the same for Alice, Bob and Charlie
assert!(gide.group_id == gide_2.group_id);
assert!(gide.group_id == gide_3.group_id);
// Assert that the membership is the same for Alice, Bob and Charlie
assert!(members == members_2);
assert!(members == members_3);
Ok(())
}

View File

@@ -41,13 +41,12 @@ fn test_propose_add() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
let charlie_id = let charlie_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -82,7 +81,7 @@ fn test_propose_add() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -111,7 +110,7 @@ fn test_propose_add() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
// Bob joins // Bob joins
@@ -119,7 +118,7 @@ fn test_propose_add() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?; mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after joining the group): {members:?}"); println!("Members (bob, after joining the group): {members:?}");
// //
@@ -154,7 +153,7 @@ fn test_propose_add() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members_bob = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members_bob = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after adding charlie): {members_bob:?}"); println!("Members (bob, after adding charlie): {members_bob:?}");
// Alice receives the commit // Alice receives the commit
@@ -167,7 +166,7 @@ fn test_propose_add() -> Result<(), PlatformError> {
// List the members of the group // List the members of the group
let members_alice = let members_alice =
mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding charlie): {members_alice:?}"); println!("Members (alice, after adding charlie): {members_alice:?}");
// Extract the welcome from the commit output // Extract the welcome from the commit output
@@ -183,7 +182,7 @@ fn test_propose_add() -> Result<(), PlatformError> {
// List the members of the group // List the members of the group
let members_charlie = let members_charlie =
mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after joining the group): {members_charlie:?}"); println!("Members (charlie, after joining the group): {members_charlie:?}");
// Test that Alice, Bob and Charlie are in the same group // Test that Alice, Bob and Charlie are in the same group

View File

@@ -47,13 +47,12 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
let charlie_id = let charlie_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -88,7 +87,7 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -117,7 +116,7 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
// Bob joins // Bob joins
@@ -125,7 +124,7 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?; mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after joining the group): {members:?}"); println!("Members (bob, after joining the group): {members:?}");
// //
@@ -155,7 +154,7 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after adding charlie): {members:?}"); println!("Members (bob, after adding charlie): {members:?}");
// Alice receives the commit // Alice receives the commit
@@ -171,7 +170,7 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?; mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after joining the group): {members:?}"); println!("Members (charlie, after joining the group): {members:?}");
// //
@@ -207,7 +206,7 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
mls_platform_api::mls_receive(&state_global, &alice_id, &commit_5_msg)?; mls_platform_api::mls_receive(&state_global, &alice_id, &commit_5_msg)?;
let members_alice = let members_alice =
mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after removing Bob): {members_alice:?}"); println!("Members (alice, after removing Bob): {members_alice:?}");
// Charlie processes the remove commit // Charlie processes the remove commit
@@ -215,7 +214,7 @@ fn test_group_propose_self_remove() -> Result<(), PlatformError> {
mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_5_msg)?; mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_5_msg)?;
let members_charlie = let members_charlie =
mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after removing bob): {members_charlie:?}"); println!("Members (charlie, after removing bob): {members_charlie:?}");
// Bob processes the remove commit // Bob processes the remove commit

View File

@@ -45,13 +45,12 @@ fn test_group_remove() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
let charlie_id = let charlie_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -86,7 +85,7 @@ fn test_group_remove() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -115,7 +114,7 @@ fn test_group_remove() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
// Bob joins // Bob joins
@@ -149,7 +148,7 @@ fn test_group_remove() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after adding charlie): {members:?}"); println!("Members (bob, after adding charlie): {members:?}");
// Alice receives the commit // Alice receives the commit
@@ -165,7 +164,7 @@ fn test_group_remove() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?; mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after joining the group): {members:?}"); println!("Members (charlie, after joining the group): {members:?}");
// //
@@ -184,7 +183,7 @@ fn test_group_remove() -> Result<(), PlatformError> {
&MessageOrAck::Ack(gide.group_id.to_vec()), &MessageOrAck::Ack(gide.group_id.to_vec()),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after removing alice): {members:?}"); println!("Members (charlie, after removing alice): {members:?}");
// Alice receives the commit from Charlie // Alice receives the commit from Charlie
@@ -205,7 +204,7 @@ fn test_group_remove() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(commit_3.clone()), &MessageOrAck::MlsMessage(commit_3.clone()),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after receiving alice's removal the group): {members:?}"); println!("Members (bob, after receiving alice's removal the group): {members:?}");
// Check if Alice is still in the members list // Check if Alice is still in the members list

View File

@@ -44,13 +44,12 @@ fn test_group_external_join() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
let diana_id = let diana_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -77,7 +76,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -106,7 +105,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
// Bob joins // Bob joins
@@ -114,7 +113,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?; mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after joining the group): {members:?}"); println!("Members (bob, after joining the group): {members:?}");
// //
@@ -144,7 +143,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
)?; )?;
let members_alice = let members_alice =
mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!( println!(
"Members (alice, after receiving the commit allowing external join): {members_alice:?}" "Members (alice, after receiving the commit allowing external join): {members_alice:?}"
); );
@@ -156,7 +155,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(commit_4_output.commit.clone()), &MessageOrAck::MlsMessage(commit_4_output.commit.clone()),
)?; )?;
let members_bob = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members_bob = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after commit allowing external join): {members_bob:?}"); println!("Members (bob, after commit allowing external join): {members_bob:?}");
// //
@@ -177,7 +176,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
println!("Externally joined group {:?}", &external_commit_output.gid); println!("Externally joined group {:?}", &external_commit_output.gid);
let members_diana = let members_diana =
mls_platform_api::mls_group_members(&state_global, &gide.group_id, &diana_id)?; mls_platform_api::mls_group_details(&state_global, &gide.group_id, &diana_id)?;
println!("Members (diane, after joining): {members_diana:?}"); println!("Members (diane, after joining): {members_diana:?}");
// Alice receives Diana's commit // Alice receives Diana's commit
@@ -189,7 +188,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
)?; )?;
let members_alice = let members_alice =
mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after receiving the commit from Diana): {members_alice:?}"); println!("Members (alice, after receiving the commit from Diana): {members_alice:?}");
// Bob receives Diana's commit // Bob receives Diana's commit
@@ -200,7 +199,7 @@ fn test_group_external_join() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(external_commit_output.external_commit.clone()), &MessageOrAck::MlsMessage(external_commit_output.external_commit.clone()),
)?; )?;
let members_bob = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members_bob = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after receiving the commit from Diana): {members_bob:?}"); println!("Members (bob, after receiving the commit from Diana): {members_bob:?}");
// Check if Diana is in the members list // Check if Diana is in the members list

View File

@@ -68,16 +68,15 @@ fn main() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
let charlie_id = let charlie_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let diana_id = let diana_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -113,7 +112,7 @@ fn main() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //
@@ -142,7 +141,7 @@ fn main() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, after adding bob): {members:?}"); println!("Members (alice, after adding bob): {members:?}");
// Bob joins // Bob joins
@@ -150,7 +149,7 @@ fn main() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?; mls_platform_api::mls_group_join(&state_global, &bob_id, &welcome, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after joining the group): {members:?}"); println!("Members (bob, after joining the group): {members:?}");
// //
@@ -201,7 +200,7 @@ fn main() -> Result<(), PlatformError> {
)?; )?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after adding charlie): {members:?}"); println!("Members (bob, after adding charlie): {members:?}");
// Alice receives the commit // Alice receives the commit
@@ -217,7 +216,7 @@ fn main() -> Result<(), PlatformError> {
mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?; mls_platform_api::mls_group_join(&state_global, &charlie_id, &welcome_2, None)?;
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after joining the group): {members:?}"); println!("Members (charlie, after joining the group): {members:?}");
// //
@@ -236,7 +235,7 @@ fn main() -> Result<(), PlatformError> {
&MessageOrAck::Ack(gide.group_id.to_vec()), &MessageOrAck::Ack(gide.group_id.to_vec()),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after removing alice): {members:?}"); println!("Members (charlie, after removing alice): {members:?}");
// Alice receives the commit from Charlie // Alice receives the commit from Charlie
@@ -257,7 +256,7 @@ fn main() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(commit_3.clone()), &MessageOrAck::MlsMessage(commit_3.clone()),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after receiving alice's removal the group): {members:?}"); println!("Members (bob, after receiving alice's removal the group): {members:?}");
// //
@@ -286,7 +285,7 @@ fn main() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(commit_4_output.commit.clone()), &MessageOrAck::MlsMessage(commit_4_output.commit.clone()),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
println!("Members (bob, after commit allowing external join): {members:?}"); println!("Members (bob, after commit allowing external join): {members:?}");
// Charlie receives Bob's commit with GroupInfo for External Join // Charlie receives Bob's commit with GroupInfo for External Join
@@ -296,7 +295,7 @@ fn main() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(commit_4_output.commit), &MessageOrAck::MlsMessage(commit_4_output.commit),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after bob's commit allowing external join): {members:?}"); println!("Members (charlie, after bob's commit allowing external join): {members:?}");
// //
@@ -316,7 +315,7 @@ fn main() -> Result<(), PlatformError> {
println!("Externally joined group {:?}", &external_commit_output.gid); println!("Externally joined group {:?}", &external_commit_output.gid);
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &diana_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &diana_id)?;
println!("Members (diane, after joining): {members:?}"); println!("Members (diane, after joining): {members:?}");
// //
@@ -338,7 +337,7 @@ fn main() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(external_commit_output.external_commit.clone()), &MessageOrAck::MlsMessage(external_commit_output.external_commit.clone()),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &diana_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &diana_id)?;
println!("Members (bob, after diane joined externally): {members:?}"); println!("Members (bob, after diane joined externally): {members:?}");
// Bob receives Diana's application message // Bob receives Diana's application message
@@ -365,7 +364,7 @@ fn main() -> Result<(), PlatformError> {
&MessageOrAck::MlsMessage(external_commit_output.external_commit), &MessageOrAck::MlsMessage(external_commit_output.external_commit),
)?; )?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after diane joined externally): {members:?}"); println!("Members (charlie, after diane joined externally): {members:?}");
// Charlie receives Diana's application message // Charlie receives Diana's application message
@@ -408,7 +407,7 @@ fn main() -> Result<(), PlatformError> {
println!("\nDiana processes the remove commit"); println!("\nDiana processes the remove commit");
mls_platform_api::mls_receive(&state_global, &diana_id, &commit_5_msg)?; mls_platform_api::mls_receive(&state_global, &diana_id, &commit_5_msg)?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &diana_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &diana_id)?;
println!("Members (diana, after removing bob): {members:?}"); println!("Members (diana, after removing bob): {members:?}");
// Charlie processes the remove commit // Charlie processes the remove commit
@@ -420,7 +419,7 @@ fn main() -> Result<(), PlatformError> {
)?; )?;
mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_5_msg)?; mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_5_msg)?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after removing bob): {members:?}"); println!("Members (charlie, after removing bob): {members:?}");
// Bob processes the remove commit // Bob processes the remove commit
@@ -435,7 +434,7 @@ fn main() -> Result<(), PlatformError> {
println!("Bob's state for the group has been removed"); println!("Bob's state for the group has been removed");
// Note: Bob cannot look at its own group state because it was already removed // Note: Bob cannot look at its own group state because it was already removed
// let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &bob_id)?; // let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &bob_id)?;
// let members_str = mls_platform_api::utils_json_bytes_to_string_custom(&members)?; // let members_str = mls_platform_api::utils_json_bytes_to_string_custom(&members)?;
// println!("Members (bob, after its removal): {members_str:?}"); // println!("Members (bob, after its removal): {members_str:?}");
@@ -464,7 +463,7 @@ fn main() -> Result<(), PlatformError> {
println!("\nCharlie processes the close commit"); println!("\nCharlie processes the close commit");
mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_6_msg)?; mls_platform_api::mls_receive(&state_global, &charlie_id, &commit_6_msg)?;
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &charlie_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &charlie_id)?;
println!("Members (charlie, after processing their group_close commit): {members:?}"); println!("Members (charlie, after processing their group_close commit): {members:?}");
// Charlie deletes her state for the group // Charlie deletes her state for the group

View File

@@ -35,10 +35,9 @@ fn test_send_receive() -> Result<(), PlatformError> {
// Create signature keypairs and store them in the state // Create signature keypairs and store them in the state
let alice_id = let alice_id =
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?; mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
let bob_id = let bob_id = mls_platform_api::mls_generate_identity(&state_global, group_config.ciphersuite)?;
mls_platform_api::mls_generate_signature_keypair(&state_global, group_config.ciphersuite)?;
println!("\nAlice identifier: {}", hex::encode(&alice_id)); println!("\nAlice identifier: {}", hex::encode(&alice_id));
println!("Bob identifier: {}", hex::encode(&bob_id)); println!("Bob identifier: {}", hex::encode(&bob_id));
@@ -64,7 +63,7 @@ fn test_send_receive() -> Result<(), PlatformError> {
println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id)); println!("\nGroup created by Alice: {}", hex::encode(&gide.group_id));
// List the members of the group // List the members of the group
let members = mls_platform_api::mls_group_members(&state_global, &gide.group_id, &alice_id)?; let members = mls_platform_api::mls_group_details(&state_global, &gide.group_id, &alice_id)?;
println!("Members (alice, before adding bob): {members:?}"); println!("Members (alice, before adding bob): {members:?}");
// //

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"47e853fb8eaac9e5edcb8d807e9914215137826800588e442ca0decc47054c7c","src/lib.rs":"e689947b850ea193ffe23d58a6f05d61e208c4f2d3dd5599824bc29b7daf97b7"},"package":null} {"files":{"Cargo.toml":"24c5a8bdedcd0069829fb5b4431ac7c52938f096b46f9c9a4e84feb0557887f7","src/lib.rs":"e689947b850ea193ffe23d58a6f05d61e208c4f2d3dd5599824bc29b7daf97b7"},"package":null}

View File

@@ -12,7 +12,7 @@
[package] [package]
edition = "2021" edition = "2021"
name = "mls-rs-codec-derive" name = "mls-rs-codec-derive"
version = "0.1.1" version = "0.2.0"
build = false build = false
autolib = false autolib = false
autobins = false autobins = false

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"b452825ce75b75ee9e4466a7fd75596b0232ac8993eaddcfa7dd0954d12cb598","src/array.rs":"4fe2c298fb948e19456811e106434bea1e9b8db5e9a98d3985c914d2e5ed26ee","src/byte_vec.rs":"3b62de99b0c7ac368723980040d6795b5cf3f60469e3c39cd7119b6623fdcd62","src/cow.rs":"93fa5e53dbe5e07671071a8372268980ed27df06ffa1bcfe81f9ff2e9e49a34c","src/iter.rs":"f7224ebe151e09fec949977a7521d8fd9b6c4a467217d7762d0224667f25784e","src/lib.rs":"fd11155ec25a6798287b3dacad9ffbc5ba068e1e6ba818cd32a43d1fc4da4717","src/map.rs":"94a48e55db2b69702772006419fafd1f7e5db5e33a9ffa25b5e065d36864e38f","src/option.rs":"17737173897cfa6b7dcec5949bedc91f10216e7bb4b7e25552fa9a24d3b87e26","src/stdint.rs":"731e2ea9475a725f43cbff02b6e3deef692a9eb07efc5706c9e0d7ccabe46527","src/string.rs":"c272864e91e5d87563d330d771056f82d6be8c06f11384af779e40d436e4f9ad","src/tuple.rs":"f35165f5fe6e5d7611a36be2028e9d8f7805eef7b33f5dbe9fab5b543192cf85","src/varint.rs":"19f89bf5151a2fdaba89965b49edda31f1dce8745aed45fbc3da74be96625390","src/vec.rs":"d76b1d4782481deec47c5104a2f62c7f7f7125c230286521ebfa767f3807642e","src/writer.rs":"d374e6b497a1e5c30f435b9a92b80d00d989f2da36951a307e45255400c8309c","tests/macro_usage.rs":"9b6b1243d3783957f1064427da3ba84bbb333ae3d0d1724777f09fe3c7aef340"},"package":null} {"files":{"Cargo.toml":"cd169b49dfc90d7ed3be4b60dc07e53e4eb7b4339c423cae71177cd553e46037","src/array.rs":"4fe2c298fb948e19456811e106434bea1e9b8db5e9a98d3985c914d2e5ed26ee","src/bool.rs":"57ffa2118fa698527ecb6b67844edf7d0a748d8988deeb6fb40663def069673f","src/byte_vec.rs":"3b62de99b0c7ac368723980040d6795b5cf3f60469e3c39cd7119b6623fdcd62","src/cow.rs":"3c902f4568fe9713865990055b360ad94ffd812b9e2b8554082467ed553c4e8f","src/iter.rs":"f7224ebe151e09fec949977a7521d8fd9b6c4a467217d7762d0224667f25784e","src/lib.rs":"db5c0a03bf8bade9f00185efb2af86ca29df87001e618252cd82c6c081892e78","src/map.rs":"94a48e55db2b69702772006419fafd1f7e5db5e33a9ffa25b5e065d36864e38f","src/option.rs":"17737173897cfa6b7dcec5949bedc91f10216e7bb4b7e25552fa9a24d3b87e26","src/stdint.rs":"731e2ea9475a725f43cbff02b6e3deef692a9eb07efc5706c9e0d7ccabe46527","src/string.rs":"c272864e91e5d87563d330d771056f82d6be8c06f11384af779e40d436e4f9ad","src/tuple.rs":"f35165f5fe6e5d7611a36be2028e9d8f7805eef7b33f5dbe9fab5b543192cf85","src/varint.rs":"19f89bf5151a2fdaba89965b49edda31f1dce8745aed45fbc3da74be96625390","src/vec.rs":"d76b1d4782481deec47c5104a2f62c7f7f7125c230286521ebfa767f3807642e","src/writer.rs":"d374e6b497a1e5c30f435b9a92b80d00d989f2da36951a307e45255400c8309c","tests/macro_usage.rs":"9b6b1243d3783957f1064427da3ba84bbb333ae3d0d1724777f09fe3c7aef340"},"package":null}

View File

@@ -12,7 +12,7 @@
[package] [package]
edition = "2021" edition = "2021"
name = "mls-rs-codec" name = "mls-rs-codec"
version = "0.5.3" version = "0.6.0"
build = false build = false
autolib = false autolib = false
autobins = false autobins = false
@@ -38,7 +38,7 @@ name = "macro_usage"
path = "tests/macro_usage.rs" path = "tests/macro_usage.rs"
[dependencies.mls-rs-codec-derive] [dependencies.mls-rs-codec-derive]
version = "0.1.1" version = "0.2.0"
path = "../mls-rs-codec-derive" path = "../mls-rs-codec-derive"
[dependencies.thiserror] [dependencies.thiserror]
@@ -60,13 +60,9 @@ std = ["dep:thiserror"]
version = "0.2.79" version = "0.2.79"
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test] [target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test]
version = "0.3.26" version = "0.3"
default-features = false
[lints.rust.unexpected_cfgs] [lints.rust.unexpected_cfgs]
level = "warn" level = "warn"
priority = 0 priority = 0
check-cfg = [ check-cfg = ["cfg(mls_build_async)"]
"cfg(mls_build_async)",
"cfg(coverage_nightly)",
]

View File

@@ -0,0 +1,45 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
use crate::{MlsDecode, MlsEncode, MlsSize};
use alloc::vec::Vec;
impl MlsSize for bool {
fn mls_encoded_len(&self) -> usize {
1
}
}
impl MlsEncode for bool {
fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), crate::Error> {
writer.push(*self as u8);
Ok(())
}
}
impl MlsDecode for bool {
fn mls_decode(reader: &mut &[u8]) -> Result<Self, crate::Error> {
MlsDecode::mls_decode(reader).map(|i: u8| i != 0)
}
}
#[cfg(test)]
mod tests {
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test as test;
use crate::{MlsDecode, MlsEncode};
use alloc::vec;
#[test]
fn round_trip() {
assert_eq!(false.mls_encode_to_vec().unwrap(), vec![0]);
assert_eq!(true.mls_encode_to_vec().unwrap(), vec![1]);
let vec = vec![true, true, false];
let bytes = vec.mls_encode_to_vec().unwrap();
assert_eq!(vec, Vec::mls_decode(&mut &*bytes).unwrap())
}
}

View File

@@ -5,7 +5,7 @@ use alloc::{
use crate::{Error, MlsDecode, MlsEncode, MlsSize}; use crate::{Error, MlsDecode, MlsEncode, MlsSize};
impl<'a, T> MlsSize for Cow<'a, T> impl<T> MlsSize for Cow<'_, T>
where where
T: MlsSize + ToOwned, T: MlsSize + ToOwned,
{ {
@@ -14,7 +14,7 @@ where
} }
} }
impl<'a, T> MlsEncode for Cow<'a, T> impl<T> MlsEncode for Cow<'_, T>
where where
T: MlsEncode + ToOwned, T: MlsEncode + ToOwned,
{ {
@@ -24,7 +24,7 @@ where
} }
} }
impl<'a, T> MlsDecode for Cow<'a, T> impl<T> MlsDecode for Cow<'_, T>
where where
T: ToOwned, T: ToOwned,
<T as ToOwned>::Owned: MlsDecode, <T as ToOwned>::Owned: MlsDecode,

View File

@@ -18,6 +18,7 @@ pub mod byte_vec;
pub mod iter; pub mod iter;
mod bool;
mod cow; mod cow;
mod map; mod map;
mod option; mod option;
@@ -121,10 +122,7 @@ pub trait MlsDecode: Sized {
fn mls_decode(reader: &mut &[u8]) -> Result<Self, Error>; fn mls_decode(reader: &mut &[u8]) -> Result<Self, Error>;
} }
impl<T> MlsDecode for Box<T> impl<T: MlsDecode> MlsDecode for Box<T> {
where
T: MlsDecode + ?Sized,
{
#[inline] #[inline]
fn mls_decode(reader: &mut &[u8]) -> Result<Self, Error> { fn mls_decode(reader: &mut &[u8]) -> Result<Self, Error> {
T::mls_decode(reader).map(Box::new) T::mls_decode(reader).map(Box::new)

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"c02d241248889a50524d9ab83670de089acdd13a086f64c3484b1e71d0b72287","src/crypto.rs":"a89d2272f3cb420a8332c9e086a5f8d2f30294c71125c4e1f48022dfd1b9f85e","src/crypto/cipher_suite.rs":"f9280f8e8ed3e62c826be459c1b8a541887c1c4967d10c471a981db32c51197b","src/crypto/test_suite.rs":"7296e457e41aa945b656b9e2f90fe3c563ea2882b814f339e7d7c5c1034cecd5","src/debug.rs":"19624870ad950983ada527bef80a29a452ed691c7719d6682ef74190b6e5cdc7","src/error.rs":"542f30724c055fbc6786d25d278e2e4a4f126f9e08a2a6bcc7bfe4ae334bc64c","src/extension.rs":"999539807f27d2cd5e806b9d6facccf3f46f866056ae9d9bb6c6e202051bb520","src/extension/list.rs":"1fe6808512a29827299a9e3a8cb541aaa796c350ed4406afe84885d85f07de39","src/group.rs":"9bba0509892862be12cd44599720be3445436445678637649cd6bf21c7b77b01","src/group/group_state.rs":"391c404f56bd915c54b38cf3e386adbb45e29d640403b7facc4eafb14f21e501","src/group/proposal_type.rs":"b59a4dc591ef96a74b7ac97475c052ab2d496ba5cf274ac7721eec14c687d219","src/group/roster.rs":"be60397a182acb2bede6e43ac3871a9ca5207a6bca4959971ddc5183b4d6f1b7","src/identity.rs":"69da494bb3a086976e90320bd375efc3d992182bcf999775f5248884105269fb","src/identity/basic.rs":"34ee955ef31b252ab269f78b0d6608a471e7b325f5fd0e6d56ca778446b35b68","src/identity/credential.rs":"e253007d12be0471200ae93368c637c5cb15612d8b1b109e15a4a6803f4e9d37","src/identity/provider.rs":"3f3d52bc63356e8d36eb0f515165ed8848ee9037055a80283962a68f0add3968","src/identity/signing_identity.rs":"ef6051a82fe86b0bb01f99e479f660f25168f81cbad9af7641f0597b22f3e343","src/identity/x509.rs":"439e42c348efe865322b6a99afb4b2986a3b85fc1d4a25b21788e7b4331e87b3","src/key_package.rs":"f647a60dc077b60b3afb74492ae57b253a6d9c014a67cef1710ca3a8822076fa","src/lib.rs":"d25aec48ec29456eac819e1be3501eea92645b1af4d520aba41737d24c3562fc","src/protocol_version.rs":"6ea04c3d4e30dfc060bafab23988471c2b5f08a8b9ade9bc0aeaccd389345a4f","src/psk.rs":"7e6c7906c4219e3405e17de8120b2a2506828cdf9c683ef2bca7e8f3ca7c6682","src/secret.rs":"cbb0895c13e51e1726c4dfcfe2706f77315c30d44a279512d85432b1d46e717f","src/time.rs":"f7f8a5a5ec6e99d3f5075cac4390a7a4b988d6a369cf3d7badfcc17c0f6d7a47"},"package":null} {"files":{"Cargo.toml":"e98ac432c2c50c37d37cefb93fdd368752e7688c6f33ca2bc371d9cfa3c92dad","src/crypto.rs":"68c121f31161755b113ef2165b3a816b0c9735d48c3f1e5e2ea22087f23d516a","src/crypto/cipher_suite.rs":"51030c3f63942f14132b5d021f1f4041a7e3930d47900171cb9092d2abc20638","src/crypto/test_suite.rs":"8cd8e19a9dd75ad7d11329ac62df0993c6bdbec73669e90eb4ac75318e466e47","src/debug.rs":"19624870ad950983ada527bef80a29a452ed691c7719d6682ef74190b6e5cdc7","src/error.rs":"542f30724c055fbc6786d25d278e2e4a4f126f9e08a2a6bcc7bfe4ae334bc64c","src/extension.rs":"c7da8ff888293d6792e0a4f845a59000bab7b8596288382a56f81de597d335e0","src/extension/list.rs":"1fe6808512a29827299a9e3a8cb541aaa796c350ed4406afe84885d85f07de39","src/group.rs":"0f8d58fdd577bb7fcd9fd90122a4432a572fb271190449f804d7ba8274608bad","src/group/context.rs":"e51cd0ea7e8f448b699fd1fd8ff39da2767e68c2ec99e7355ef8b51a56721627","src/group/group_state.rs":"391c404f56bd915c54b38cf3e386adbb45e29d640403b7facc4eafb14f21e501","src/group/proposal_type.rs":"b59a4dc591ef96a74b7ac97475c052ab2d496ba5cf274ac7721eec14c687d219","src/group/roster.rs":"be60397a182acb2bede6e43ac3871a9ca5207a6bca4959971ddc5183b4d6f1b7","src/identity.rs":"69da494bb3a086976e90320bd375efc3d992182bcf999775f5248884105269fb","src/identity/basic.rs":"34ee955ef31b252ab269f78b0d6608a471e7b325f5fd0e6d56ca778446b35b68","src/identity/credential.rs":"e253007d12be0471200ae93368c637c5cb15612d8b1b109e15a4a6803f4e9d37","src/identity/provider.rs":"b7b0e46d0d02f1ef84deb7964de256fa0914e6119e7e7d22c3dc6268ec74dd2b","src/identity/signing_identity.rs":"ef6051a82fe86b0bb01f99e479f660f25168f81cbad9af7641f0597b22f3e343","src/identity/x509.rs":"439e42c348efe865322b6a99afb4b2986a3b85fc1d4a25b21788e7b4331e87b3","src/key_package.rs":"f647a60dc077b60b3afb74492ae57b253a6d9c014a67cef1710ca3a8822076fa","src/lib.rs":"d25aec48ec29456eac819e1be3501eea92645b1af4d520aba41737d24c3562fc","src/protocol_version.rs":"6ea04c3d4e30dfc060bafab23988471c2b5f08a8b9ade9bc0aeaccd389345a4f","src/psk.rs":"7e6c7906c4219e3405e17de8120b2a2506828cdf9c683ef2bca7e8f3ca7c6682","src/secret.rs":"cbb0895c13e51e1726c4dfcfe2706f77315c30d44a279512d85432b1d46e717f","src/time.rs":"f7f8a5a5ec6e99d3f5075cac4390a7a4b988d6a369cf3d7badfcc17c0f6d7a47"},"package":null}

View File

@@ -12,7 +12,7 @@
[package] [package]
edition = "2021" edition = "2021"
name = "mls-rs-core" name = "mls-rs-core"
version = "0.18.0" version = "0.21.0"
build = false build = false
exclude = ["test_data"] exclude = ["test_data"]
autolib = false autolib = false
@@ -52,7 +52,7 @@ version = "0.12"
optional = true optional = true
[dependencies.mls-rs-codec] [dependencies.mls-rs-codec]
version = "0.5.2" version = "0.6"
path = "../mls-rs-codec" path = "../mls-rs-codec"
default-features = false default-features = false
@@ -100,6 +100,8 @@ default = [
] ]
fast_serialize = ["mls-rs-codec/preallocate"] fast_serialize = ["mls-rs-codec/preallocate"]
ffi = [] ffi = []
last_resort_key_package_ext = []
post-quantum = []
rfc_compliant = ["x509"] rfc_compliant = ["x509"]
serde = [ serde = [
"dep:serde", "dep:serde",
@@ -127,8 +129,7 @@ async-trait = "^0.1"
version = "^0.2.79" version = "^0.2.79"
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test] [target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test]
version = "0.3.26" version = "0.3"
default-features = false
[lints.rust.unexpected_cfgs] [lints.rust.unexpected_cfgs]
level = "warn" level = "warn"

View File

@@ -21,6 +21,10 @@ pub mod test_suite;
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
all(feature = "ffi", not(test)),
safer_ffi_gen::ffi_type(clone, opaque)
)]
/// Ciphertext produced by [`CipherSuiteProvider::hpke_seal`] /// Ciphertext produced by [`CipherSuiteProvider::hpke_seal`]
pub struct HpkeCiphertext { pub struct HpkeCiphertext {
#[mls_codec(with = "mls_rs_codec::byte_vec")] #[mls_codec(with = "mls_rs_codec::byte_vec")]
@@ -229,6 +233,12 @@ impl From<Vec<u8>> for SignaturePublicKey {
} }
} }
impl From<SignaturePublicKey> for Vec<u8> {
fn from(value: SignaturePublicKey) -> Self {
value.0
}
}
/// Byte representation of a signature key. /// Byte representation of a signature key.
// #[cfg_attr( // #[cfg_attr(
// all(feature = "ffi", not(test)), // all(feature = "ffi", not(test)),

View File

@@ -67,6 +67,18 @@ impl CipherSuite {
/// MLS_256_DHKEMP384_AES256GCM_SHA384_P384 /// MLS_256_DHKEMP384_AES256GCM_SHA384_P384
pub const P384_AES256: CipherSuite = CipherSuite(7); pub const P384_AES256: CipherSuite = CipherSuite(7);
/// So far, there are no official PQ cipher suites
#[cfg(feature = "post-quantum")]
pub const ML_KEM_512: CipherSuite = CipherSuite(65001);
#[cfg(feature = "post-quantum")]
pub const ML_KEM_768: CipherSuite = CipherSuite(65002);
#[cfg(feature = "post-quantum")]
pub const ML_KEM_1024: CipherSuite = CipherSuite(65003);
/// So far, there are no official PQ cipher suites
#[cfg(feature = "post-quantum")]
pub const ML_KEM_768_X25519: CipherSuite = CipherSuite(65100);
/// Ciphersuite from a raw value. /// Ciphersuite from a raw value.
pub const fn new(value: u16) -> CipherSuite { pub const fn new(value: u16) -> CipherSuite {
CipherSuite(value) CipherSuite(value)

View File

@@ -11,6 +11,7 @@ use super::{
CipherSuiteProvider, CryptoProvider, HpkeCiphertext, HpkeContextS, HpkePublicKey, HpkeSecretKey, CipherSuiteProvider, CryptoProvider, HpkeCiphertext, HpkeContextS, HpkePublicKey, HpkeSecretKey,
}; };
#[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
const PATH: &str = concat!( const PATH: &str = concat!(
env!("CARGO_MANIFEST_DIR"), env!("CARGO_MANIFEST_DIR"),
"/test_data/crypto_provider.json" "/test_data/crypto_provider.json"

View File

@@ -34,6 +34,9 @@ impl ExtensionType {
pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4); pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5); pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);
#[cfg(feature = "last_resort_key_package_ext")]
pub const LAST_RESORT_KEY_PACKAGE: ExtensionType = ExtensionType(0x000A);
/// Default extension types defined /// Default extension types defined
/// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents) /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents)
pub const DEFAULT: &'static [ExtensionType] = &[ pub const DEFAULT: &'static [ExtensionType] = &[

View File

@@ -2,10 +2,12 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
mod context;
mod group_state; mod group_state;
mod proposal_type; mod proposal_type;
mod roster; mod roster;
pub use context::*;
pub use group_state::*; pub use group_state::*;
pub use proposal_type::*; pub use proposal_type::*;
pub use roster::*; pub use roster::*;

View File

@@ -2,16 +2,46 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use alloc::vec; use crate::{crypto::CipherSuite, extension::ExtensionList, protocol_version::ProtocolVersion};
use alloc::vec::Vec; use alloc::{vec, vec::Vec};
use core::fmt::{self, Debug}; use core::{
fmt::{self, Debug},
ops::Deref,
};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use crate::{cipher_suite::CipherSuite, protocol_version::ProtocolVersion, ExtensionList}; #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConfirmedTranscriptHash(
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
Vec<u8>,
);
use super::ConfirmedTranscriptHash; impl Debug for ConfirmedTranscriptHash {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
crate::debug::pretty_bytes(&self.0)
.named("ConfirmedTranscriptHash")
.fmt(f)
}
}
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)] impl Deref for ConfirmedTranscriptHash {
type Target = Vec<u8>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Vec<u8>> for ConfirmedTranscriptHash {
fn from(value: Vec<u8>) -> Self {
Self(value)
}
}
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
// #[cfg_attr( // #[cfg_attr(
// all(feature = "ffi", not(test)), // all(feature = "ffi", not(test)),
@@ -19,17 +49,17 @@ use super::ConfirmedTranscriptHash;
// )] // )]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GroupContext { pub struct GroupContext {
pub(crate) protocol_version: ProtocolVersion, pub protocol_version: ProtocolVersion,
pub(crate) cipher_suite: CipherSuite, pub cipher_suite: CipherSuite,
#[mls_codec(with = "mls_rs_codec::byte_vec")] #[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
pub(crate) group_id: Vec<u8>, pub group_id: Vec<u8>,
pub(crate) epoch: u64, pub epoch: u64,
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
#[mls_codec(with = "mls_rs_codec::byte_vec")] #[mls_codec(with = "mls_rs_codec::byte_vec")]
pub(crate) tree_hash: Vec<u8>, pub tree_hash: Vec<u8>,
pub(crate) confirmed_transcript_hash: ConfirmedTranscriptHash, pub confirmed_transcript_hash: ConfirmedTranscriptHash,
pub(crate) extensions: ExtensionList, pub extensions: ExtensionList,
} }
impl Debug for GroupContext { impl Debug for GroupContext {
@@ -37,15 +67,9 @@ impl Debug for GroupContext {
f.debug_struct("GroupContext") f.debug_struct("GroupContext")
.field("protocol_version", &self.protocol_version) .field("protocol_version", &self.protocol_version)
.field("cipher_suite", &self.cipher_suite) .field("cipher_suite", &self.cipher_suite)
.field( .field("group_id", &crate::debug::pretty_group_id(&self.group_id))
"group_id",
&mls_rs_core::debug::pretty_group_id(&self.group_id),
)
.field("epoch", &self.epoch) .field("epoch", &self.epoch)
.field( .field("tree_hash", &crate::debug::pretty_bytes(&self.tree_hash))
"tree_hash",
&mls_rs_core::debug::pretty_bytes(&self.tree_hash),
)
.field("confirmed_transcript_hash", &self.confirmed_transcript_hash) .field("confirmed_transcript_hash", &self.confirmed_transcript_hash)
.field("extensions", &self.extensions) .field("extensions", &self.extensions)
.finish() .finish()
@@ -54,20 +78,21 @@ impl Debug for GroupContext {
// #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)] // #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
impl GroupContext { impl GroupContext {
pub(crate) fn new_group( /// Create a group context for a new MLS group.
pub fn new(
protocol_version: ProtocolVersion, protocol_version: ProtocolVersion,
cipher_suite: CipherSuite, cipher_suite: CipherSuite,
group_id: Vec<u8>, group_id: Vec<u8>,
tree_hash: Vec<u8>, tree_hash: Vec<u8>,
extensions: ExtensionList, extensions: ExtensionList,
) -> Self { ) -> GroupContext {
GroupContext { GroupContext {
protocol_version, protocol_version,
cipher_suite, cipher_suite,
group_id, group_id,
epoch: 0, epoch: 0,
tree_hash, tree_hash,
confirmed_transcript_hash: ConfirmedTranscriptHash::from(vec![]), confirmed_transcript_hash: vec![].into(),
extensions, extensions,
} }
} }

View File

@@ -2,13 +2,37 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use crate::{error::IntoAnyError, extension::ExtensionList, time::MlsTime}; use crate::{error::IntoAnyError, extension::ExtensionList, group::GroupContext, time::MlsTime};
#[cfg(mls_build_async)] #[cfg(mls_build_async)]
use alloc::boxed::Box; use alloc::boxed::Box;
use alloc::vec::Vec; use alloc::vec::Vec;
use super::{CredentialType, SigningIdentity}; use super::{CredentialType, SigningIdentity};
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize,))]
#[non_exhaustive]
pub enum MemberValidationContext<'a> {
ForCommit {
current_context: &'a GroupContext,
new_extensions: &'a ExtensionList,
},
ForNewGroup {
current_context: &'a GroupContext,
},
None,
}
impl MemberValidationContext<'_> {
pub fn new_extensions(&self) -> Option<&ExtensionList> {
match self {
Self::ForCommit { new_extensions, .. } => Some(*new_extensions),
Self::ForNewGroup { current_context } => Some(&current_context.extensions),
Self::None => None,
}
}
}
/// Identity system that can be used to validate a /// Identity system that can be used to validate a
/// [`SigningIdentity`](mls-rs-core::identity::SigningIdentity) /// [`SigningIdentity`](mls-rs-core::identity::SigningIdentity)
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
@@ -26,7 +50,7 @@ pub trait IdentityProvider: Send + Sync {
&self, &self,
signing_identity: &SigningIdentity, signing_identity: &SigningIdentity,
timestamp: Option<MlsTime>, timestamp: Option<MlsTime>,
extensions: Option<&ExtensionList>, context: MemberValidationContext<'_>,
) -> Result<(), Self::Error>; ) -> Result<(), Self::Error>;
/// Determine if `signing_identity` is valid for an external sender in /// Determine if `signing_identity` is valid for an external sender in

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"b5402db1a96d1d70d4aaafc8a47745fbd2fafb170df76cfdab4e05763276b92a","src/context.rs":"e0bfb7c263739804f8bdfcabd773f063995a72f17816c7e4fd85b90acad097a9","src/dhkem.rs":"8707c390e23539dd13f5a2df427c7c02b78199ea147ab69596e5bb9f57bea9a6","src/hpke.rs":"a17ce2540c081d79bc0984511652a7577b2eb5781ec1a72094044b53a7639add","src/kdf.rs":"61d4bf34df2edcbb559f8b8c740fbdfd56f97445ffcc3b4a827b24bf50e306b6","src/lib.rs":"f10574d579bc55ca0f3893f229d50961d91d78afcc008f2fd80fd5d67ef88921","src/test_utils.rs":"4675d03c435558d02ec9ed7e23fe81842f3313a884de499888c60b63ecbdce2e"},"package":null} {"files":{"Cargo.toml":"36f17057356825750e59e111019dba4e4a5ae3f1efa3b53d1442858fdf9e905c","src/context.rs":"e0bfb7c263739804f8bdfcabd773f063995a72f17816c7e4fd85b90acad097a9","src/dhkem.rs":"ba2392d925c38f2df15aef7333d4c1b850fbbdce38c1f02e269a34e349094943","src/hpke.rs":"cb0c0c58c492f818a2328854481187e13a718d79e602c2e4f71274dbd17810a4","src/kdf.rs":"61d4bf34df2edcbb559f8b8c740fbdfd56f97445ffcc3b4a827b24bf50e306b6","src/kem_combiner.rs":"7d9a270a7028420b0d0bf49cd4f2a664ea386e673ef80cbf4f620971569b6efc","src/lib.rs":"544ac3ed080cb33f00893b55f2cbacec22e8c80d9779f1e3ddaa2f5e9acaf220","src/test_utils.rs":"4675d03c435558d02ec9ed7e23fe81842f3313a884de499888c60b63ecbdce2e"},"package":null}

View File

@@ -12,7 +12,7 @@
[package] [package]
edition = "2021" edition = "2021"
name = "mls-rs-crypto-hpke" name = "mls-rs-crypto-hpke"
version = "0.9.0" version = "0.14.0"
build = false build = false
autolib = false autolib = false
autobins = false autobins = false
@@ -43,12 +43,12 @@ cfg-if = "^1"
maybe-async = "0.2.10" maybe-async = "0.2.10"
[dependencies.mls-rs-core] [dependencies.mls-rs-core]
version = "0.18.0" version = "0.21.0"
path = "../mls-rs-core" path = "../mls-rs-core"
default-features = false default-features = false
[dependencies.mls-rs-crypto-traits] [dependencies.mls-rs-crypto-traits]
version = "0.10.0" version = "0.15.0"
path = "../mls-rs-crypto-traits" path = "../mls-rs-crypto-traits"
default-features = false default-features = false
@@ -73,7 +73,7 @@ version = "^0.4.3"
features = ["serde"] features = ["serde"]
[dev-dependencies.mls-rs-crypto-traits] [dev-dependencies.mls-rs-crypto-traits]
version = "0.10.0" version = "0.15.0"
path = "../mls-rs-crypto-traits" path = "../mls-rs-crypto-traits"
features = ["mock"] features = ["mock"]
@@ -105,8 +105,7 @@ version = "0.2"
features = ["js"] features = ["js"]
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test] [target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test]
version = "0.3.26" version = "0.3"
default-features = false
[lints.rust.unexpected_cfgs] [lints.rust.unexpected_cfgs]
level = "warn" level = "warn"

View File

@@ -2,7 +2,7 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use mls_rs_crypto_traits::{DhType, KdfType, KemResult, KemType}; use mls_rs_crypto_traits::{DhType, KdfType, KemResult, KemType, SamplingMethod};
use mls_rs_core::{ use mls_rs_core::{
crypto::{HpkePublicKey, HpkeSecretKey}, crypto::{HpkePublicKey, HpkeSecretKey},
@@ -75,24 +75,31 @@ impl<DH: DhType, KDF: KdfType> KemType for DhKem<DH, KDF> {
self.kem_id self.kem_id
} }
async fn derive(&self, ikm: &[u8]) -> Result<(HpkeSecretKey, HpkePublicKey), DhKemError> { async fn generate_deterministic(
let dkp_prk = self &self,
.kdf seed: &[u8],
.labeled_extract(&[], b"dkp_prk", ikm) ) -> Result<(HpkeSecretKey, HpkePublicKey), DhKemError> {
.await match self.dh.bitmask_for_rejection_sampling() {
.map_err(|e| DhKemError::KdfError(e.into_any_error()))?; SamplingMethod::HpkeWithBitmask(bitmask) => {
self.derive_with_rejection_sampling(seed, bitmask).await
if let Some(bitmask) = self.dh.bitmask_for_rejection_sampling() { }
self.derive_with_rejection_sampling(&dkp_prk, bitmask).await SamplingMethod::HpkeWithoutBitmask => {
} else { self.derive_without_rejection_sampling(seed).await
self.derive_without_rejection_sampling(&dkp_prk).await }
SamplingMethod::Raw => self.derive_raw(seed.to_vec()).await,
} }
} }
async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error> { async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error> {
#[cfg(feature = "test_utils")] #[cfg(feature = "test_utils")]
if !self.test_key_data.is_empty() { if !self.test_key_data.is_empty() {
return self.derive(&self.test_key_data).await; let dkp_prk = self
.kdf
.labeled_extract(&[], b"dkp_prk", &self.test_key_data)
.await
.map_err(|e| DhKemError::KdfError(e.into_any_error()))?;
return self.generate_deterministic(&dkp_prk).await;
} }
self.dh self.dh
@@ -150,6 +157,17 @@ impl<DH: DhType, KDF: KdfType> KemType for DhKem<DH, KDF> {
.public_key_validate(key) .public_key_validate(key)
.map_err(|e| DhKemError::DhError(e.into_any_error())) .map_err(|e| DhKemError::DhError(e.into_any_error()))
} }
fn seed_length_for_derive(&self) -> usize {
self.n_secret
}
fn public_key_size(&self) -> usize {
self.dh.public_key_size()
}
fn secret_key_size(&self) -> usize {
self.dh.secret_key_size()
}
} }
impl<DH: DhType, KDF: KdfType> DhKem<DH, KDF> { impl<DH: DhType, KDF: KdfType> DhKem<DH, KDF> {
@@ -194,8 +212,17 @@ impl<DH: DhType, KDF: KdfType> DhKem<DH, KDF> {
.kdf .kdf
.labeled_expand(dkp_prk, b"sk", &[], self.dh.secret_key_size()) .labeled_expand(dkp_prk, b"sk", &[], self.dh.secret_key_size())
.await .await
.map_err(|e| DhKemError::KdfError(e.into_any_error()))? .map_err(|e| DhKemError::KdfError(e.into_any_error()))?;
.into();
self.derive_raw(sk).await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn derive_raw(
&self,
seed: Vec<u8>,
) -> Result<(HpkeSecretKey, HpkePublicKey), DhKemError> {
let sk = seed.into();
let pk = self let pk = self
.dh .dh

View File

@@ -70,6 +70,7 @@ pub struct Hpke<KEM: KemType, KDF: KdfType, AEAD: AeadType> {
pub(crate) kem: KEM, pub(crate) kem: KEM,
pub(crate) kdf: HpkeKdf<KDF>, pub(crate) kdf: HpkeKdf<KDF>,
pub(crate) aead: Option<AEAD>, pub(crate) aead: Option<AEAD>,
kem_kdf: HpkeKdf<KDF>,
} }
#[derive(Debug, Clone, Eq, PartialEq, Default)] #[derive(Debug, Clone, Eq, PartialEq, Default)]
@@ -104,8 +105,14 @@ where
] ]
.concat(); .concat();
let kdf = HpkeKdf::new(suite_id, kdf); let kem_suite_id = [b"KEM", &kem.kem_id().to_be_bytes() as &[u8]].concat();
Self { kem, kdf, aead }
Self {
kem,
aead,
kdf: HpkeKdf::new(suite_id, kdf.clone()),
kem_kdf: HpkeKdf::new(kem_suite_id, kdf),
}
} }
/// Based on RFC 9180 Single-Shot APIs. This function combines the action /// Based on RFC 9180 Single-Shot APIs. This function combines the action
@@ -210,8 +217,14 @@ where
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn derive(&self, ikm: &[u8]) -> Result<(HpkeSecretKey, HpkePublicKey), HpkeError> { pub async fn derive(&self, ikm: &[u8]) -> Result<(HpkeSecretKey, HpkePublicKey), HpkeError> {
let dkp_prk = self
.kem_kdf
.labeled_extract(&[], b"dkp_prk", ikm)
.await
.map_err(|e| HpkeError::KdfError(e.into_any_error()))?;
self.kem self.kem
.derive(ikm) .generate_deterministic(&dkp_prk)
.await .await
.map_err(|e| HpkeError::KemError(e.into_any_error())) .map_err(|e| HpkeError::KemError(e.into_any_error()))
} }

View File

@@ -0,0 +1,543 @@
use alloc::vec::Vec;
use mls_rs_core::{
crypto::{HpkePublicKey, HpkeSecretKey},
error::{AnyError, IntoAnyError},
};
use mls_rs_crypto_traits::{Hash, KemResult, KemType, VariableLengthHash};
use zeroize::Zeroize;
#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum Error {
#[cfg_attr(feature = "std", error(transparent))]
KemError(AnyError),
#[cfg_attr(feature = "std", error(transparent))]
HashError(AnyError),
#[cfg_attr(feature = "std", error("invalid key data"))]
InvalidKeyData,
#[cfg_attr(feature = "std", error(transparent))]
MlsCodecError(mls_rs_core::mls_rs_codec::Error),
}
impl From<mls_rs_core::mls_rs_codec::Error> for Error {
#[inline]
fn from(e: mls_rs_core::mls_rs_codec::Error) -> Self {
Error::MlsCodecError(e)
}
}
impl IntoAnyError for Error {}
#[derive(Clone)]
pub struct CombinedKem<KEM1, KEM2, H, VH, F> {
kem1: KEM1,
kem2: KEM2,
hash: H,
variable_length_hash: VH,
shared_secret_hash_input: F,
}
impl<KEM1, KEM2, H, VH, F> CombinedKem<KEM1, KEM2, H, VH, F> {
pub fn new_custom(
kem1: KEM1,
kem2: KEM2,
hash: H,
variable_length_hash: VH,
shared_secret_hash_input: F,
) -> Self {
Self {
kem1,
kem2,
hash,
variable_length_hash,
shared_secret_hash_input,
}
}
}
pub trait SharedSecretHashInput: Send + Sync {
fn input<'a>(
&self,
ss_details1: SharedSecretDetails<'a>,
ss_details2: SharedSecretDetails<'a>,
) -> Vec<u8>;
}
#[derive(Debug, Clone, Copy)]
pub struct DefaultSharedSecretHashInput;
impl<KEM1, KEM2, H, VH> CombinedKem<KEM1, KEM2, H, VH, DefaultSharedSecretHashInput> {
pub fn new(kem1: KEM1, kem2: KEM2, hash: H, variable_length_hash: VH) -> Self {
Self {
kem1,
kem2,
hash,
variable_length_hash,
shared_secret_hash_input: DefaultSharedSecretHashInput,
}
}
}
/// Secure for any combiner KEMs.
impl SharedSecretHashInput for DefaultSharedSecretHashInput {
fn input<'a>(
&self,
ss_details1: SharedSecretDetails<'a>,
ss_details2: SharedSecretDetails<'a>,
) -> Vec<u8> {
[
ss_details1.enc,
ss_details1.shared_secret,
ss_details1.public_key,
ss_details2.enc,
ss_details2.shared_secret,
ss_details2.public_key,
]
.concat()
}
}
#[derive(Debug, Clone, Copy)]
pub struct XWingSharedSecretHashInput;
impl<KEM1, KEM2, H, VH> CombinedKem<KEM1, KEM2, H, VH, XWingSharedSecretHashInput> {
pub fn new_xwing(kem1: KEM1, kem2: KEM2, hash: H, variable_length_hash: VH) -> Self {
Self {
kem1,
kem2,
hash,
variable_length_hash,
shared_secret_hash_input: XWingSharedSecretHashInput,
}
}
}
/// Defined in https://www.ietf.org/archive/id/draft-connolly-cfrg-xwing-kem-01.html
///
/// IND-CCA secure for some KEMs (also, IND-RCCA secure for all KEMs)
impl SharedSecretHashInput for XWingSharedSecretHashInput {
fn input<'a>(
&self,
ss_details1: SharedSecretDetails<'a>,
ss_details2: SharedSecretDetails<'a>,
) -> Vec<u8> {
[
b"\\./\n/^\\",
ss_details1.shared_secret,
ss_details2.shared_secret,
ss_details2.enc,
ss_details2.public_key,
]
.concat()
}
}
pub struct SharedSecretDetails<'a> {
pub shared_secret: &'a [u8],
pub enc: &'a [u8],
pub public_key: &'a HpkePublicKey,
}
impl<'a> SharedSecretDetails<'a> {
pub fn new(shared_secret: &'a [u8], enc: &'a [u8], public_key: &'a HpkePublicKey) -> Self {
Self {
shared_secret,
enc,
public_key,
}
}
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
#[cfg_attr(
all(not(target_arch = "wasm32"), mls_build_async),
maybe_async::must_be_async
)]
impl<KEM1, KEM2, H, VH, F> KemType for CombinedKem<KEM1, KEM2, H, VH, F>
where
KEM1: KemType,
KEM2: KemType,
H: Hash,
VH: VariableLengthHash,
F: SharedSecretHashInput,
{
type Error = Error;
fn kem_id(&self) -> u16 {
// TODO not set by any RFC
15
}
async fn generate_deterministic(
&self,
seed: &[u8],
) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error> {
self.generate_deterministic(seed).await
}
async fn encap(&self, remote_key: &HpkePublicKey) -> Result<KemResult, Self::Error> {
let (pk1, pk2) = self.parse_key(remote_key, self.kem1.public_key_size())?;
let pk1 = pk1.into();
let pk2 = pk2.into();
let ct1 = self
.kem1
.encap(&pk1)
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let ct2 = self
.kem2
.encap(&pk2)
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let enc = [&ct1.enc[..], &ct2.enc].concat();
let ss_details1 = SharedSecretDetails::new(&ct1.shared_secret, &ct1.enc, &pk1);
let ss_details2 = SharedSecretDetails::new(&ct2.shared_secret, &ct2.enc, &pk2);
let mut shared_secret_input = self
.shared_secret_hash_input
.input(ss_details1, ss_details2);
let shared_secret = self
.hash
.hash(&shared_secret_input)
.map_err(|e| Error::KemError(e.into_any_error()))?;
shared_secret_input.zeroize();
Ok(KemResult { shared_secret, enc })
}
async fn decap(
&self,
enc: &[u8],
secret_key: &HpkeSecretKey,
local_public: &HpkePublicKey,
) -> Result<Vec<u8>, Self::Error> {
let (pk1, pk2) = self.parse_key(local_public, self.kem1.public_key_size())?;
let (sk1, sk2) = self.parse_key(secret_key, self.kem1.secret_key_size())?;
let (enc1, enc2) = self.parse_key(enc, self.kem1.enc_size())?;
let pk1 = pk1.into();
let pk2 = pk2.into();
let sk1 = sk1.into();
let sk2 = sk2.into();
let shared_secret1 = self
.kem1
.decap(&enc1, &sk1, &pk1)
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let shared_secret2 = self
.kem2
.decap(&enc2, &sk2, &pk2)
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let ss_details1 = SharedSecretDetails::new(&shared_secret1, &enc1, &pk1);
let ss_details2 = SharedSecretDetails::new(&shared_secret2, &enc2, &pk2);
let mut shared_secret_input = self
.shared_secret_hash_input
.input(ss_details1, ss_details2);
let shared_secret = self
.hash
.hash(&shared_secret_input)
.map_err(|e| Error::KemError(e.into_any_error()))?;
shared_secret_input.zeroize();
Ok(shared_secret)
}
fn public_key_validate(&self, _key: &HpkePublicKey) -> Result<(), Self::Error> {
// TODO Not clear how to do this for Kyber or how useful it is.
Ok(())
}
async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error> {
let (sk1, pk1) = self
.kem1
.generate()
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let (sk2, pk2) = self
.kem2
.generate()
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let sk = [sk1.as_ref(), &sk2].concat();
let pk = [pk1.as_ref(), &pk2].concat();
Ok((sk.into(), pk.into()))
}
fn seed_length_for_derive(&self) -> usize {
self.kem1.seed_length_for_derive() + self.kem2.seed_length_for_derive()
}
fn public_key_size(&self) -> usize {
self.kem1.public_key_size() + self.kem2.public_key_size()
}
fn secret_key_size(&self) -> usize {
self.kem1.secret_key_size() + self.kem2.secret_key_size()
}
fn enc_size(&self) -> usize {
self.kem1.enc_size() + self.kem2.enc_size()
}
}
impl<KEM1, KEM2, H, VH, F> CombinedKem<KEM1, KEM2, H, VH, F>
where
KEM1: KemType,
KEM2: KemType,
H: Hash,
VH: VariableLengthHash,
F: SharedSecretHashInput,
{
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn generate_deterministic(
&self,
ikm: &[u8],
) -> Result<(HpkeSecretKey, HpkePublicKey), Error> {
let ikm = self
.variable_length_hash
.hash(ikm, self.seed_length_for_derive())
.map_err(|e| Error::KemError(e.into_any_error()))?;
let (ikm1, ikm2) = ikm.split_at(self.kem1.seed_length_for_derive());
self.generate_key_pair_derand(ikm1, ikm2).await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn generate_key_pair_derand(
&self,
ikm1: &[u8],
ikm2: &[u8],
) -> Result<(HpkeSecretKey, HpkePublicKey), Error> {
let (sk1, pk1) = self
.kem1
.generate_deterministic(ikm1)
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let (sk2, pk2) = self
.kem2
.generate_deterministic(ikm2)
.await
.map_err(|e| Error::KemError(e.into_any_error()))?;
let sk = [sk1.as_ref(), &sk2].concat();
let pk = [pk1.as_ref(), &pk2].concat();
Ok((sk.into(), pk.into()))
}
fn parse_key(&self, key: &[u8], size: usize) -> Result<(Vec<u8>, Vec<u8>), Error> {
(key.len() >= size)
.then_some(())
.ok_or(Error::InvalidKeyData)?;
let (key1, key2) = key.split_at(size);
Ok((key1.to_vec(), key2.to_vec()))
}
}
// Makes no sense to test this both in sync and async mode
#[cfg(all(test, not(mls_build_async)))]
mod tests {
use mls_rs_core::crypto::{HpkePublicKey, HpkeSecretKey};
use mls_rs_crypto_traits::{
mock::{MockHash, MockKemType, MockVariableLengthHash},
KemResult, KemType,
};
use super::{
CombinedKem, DefaultSharedSecretHashInput, SharedSecretHashInput,
XWingSharedSecretHashInput,
};
fn pk(i: u8) -> HpkePublicKey {
if i == 12 {
b"pk1pk2".to_vec().into()
} else {
format!("pk{i}").into_bytes().into()
}
}
fn sk(i: u8) -> HpkeSecretKey {
if i == 12 {
b"sk1sk2".to_vec().into()
} else {
format!("sk{i}").into_bytes().into()
}
}
fn enc(i: u8) -> Vec<u8> {
if i == 12 {
b"enc1enc2".to_vec()
} else {
format!("enc{i}").into_bytes()
}
}
fn ss(i: u8) -> Vec<u8> {
format!("ss{i}").into_bytes()
}
fn ikm(i: u8) -> Vec<u8> {
format!("ikm{i}").into_bytes()
}
#[test]
fn generate_deterministic() {
let mut kem1 = MockKemType::new();
let mut kem2 = MockKemType::new();
let hash = MockHash::new();
let mut variable_length_hash = MockVariableLengthHash::new();
variable_length_hash
.expect_hash()
.withf(|ikm, ikm1_len| ikm == b"test ikm" && *ikm1_len == 8)
.return_once(|_, _| Ok([ikm(1), ikm(2)].concat()));
kem1.expect_seed_length_for_derive().returning(|| 4);
kem2.expect_seed_length_for_derive().returning(|| 4);
kem1.expect_generate_deterministic()
.withf(|ikm1| ikm1 == ikm(1))
.return_once(|_| Ok((sk(1), pk(1))));
kem2.expect_generate_deterministic()
.withf(|ikm1| ikm1 == ikm(2))
.return_once(|_| Ok((sk(2), pk(2))));
let kem = CombinedKem::new(kem1, kem2, hash, variable_length_hash);
let keypair = kem.generate_deterministic(b"test ikm").unwrap();
assert_eq!(keypair.0, sk(12));
assert_eq!(keypair.1, pk(12));
}
#[test]
fn generate() {
let mut kem1 = MockKemType::new();
let mut kem2 = MockKemType::new();
let hash = MockHash::new();
let variable_length_hash = MockVariableLengthHash::new();
kem1.expect_generate().return_once(|| Ok((sk(1), pk(1))));
kem2.expect_generate().return_once(|| Ok((sk(2), pk(2))));
let kem = CombinedKem::new(kem1, kem2, hash, variable_length_hash);
let keypair = kem.generate().unwrap();
assert_eq!(keypair.0, sk(12));
assert_eq!(keypair.1, pk(12));
}
fn encap_test<F: SharedSecretHashInput>(hash_input_bytes: Vec<u8>, hash_input_fn: F) {
let mut kem1 = MockKemType::new();
let mut kem2 = MockKemType::new();
let mut hash = MockHash::new();
let variable_length_hash = MockVariableLengthHash::new();
kem1.expect_public_key_size().returning(|| pk(1).len());
kem1.expect_enc_size().returning(|| enc(1).len());
kem1.expect_encap()
.withf(|pk1| pk1 == &pk(1))
.return_once(|_| Ok(KemResult::new(ss(1), enc(1))));
kem2.expect_encap()
.withf(|pk2| pk2 == &pk(2))
.return_once(|_| Ok(KemResult::new(ss(2), enc(2))));
hash.expect_hash()
.withf(move |input| input == hash_input_bytes)
.return_once(|_| Ok(b"shared secret".to_vec()));
let kem = CombinedKem::new_custom(kem1, kem2, hash, variable_length_hash, hash_input_fn);
let encap_result = kem.encap(&pk(12)).unwrap();
assert_eq!(encap_result.enc, enc(12));
assert_eq!(encap_result.shared_secret, b"shared secret");
}
#[test]
fn encap() {
encap_test(
[&enc(1)[..], &ss(1), &pk(1), &enc(2), &ss(2), &pk(2)].concat(),
DefaultSharedSecretHashInput,
);
encap_test(
[b"\\./\n/^\\".as_slice(), &ss(1), &ss(2), &enc(2), &pk(2)].concat(),
XWingSharedSecretHashInput,
);
}
#[test]
fn decap() {
let mut kem1 = MockKemType::new();
let mut kem2 = MockKemType::new();
let mut hash = MockHash::new();
let variable_length_hash = MockVariableLengthHash::new();
kem1.expect_public_key_size().returning(|| pk(1).len());
kem1.expect_enc_size().returning(|| enc(1).len());
kem1.expect_secret_key_size().returning(|| sk(1).len());
kem1.expect_decap()
.withf(|enc1, sk1, pk1| enc1 == enc(1) && sk1 == &sk(1) && pk1 == &pk(1))
.return_once(|_, _, _| Ok(ss(1)));
kem2.expect_decap()
.withf(|enc2, sk2, pk2| enc2 == enc(2) && sk2 == &sk(2) && pk2 == &pk(2))
.return_once(|_, _, _| Ok(ss(2)));
hash.expect_hash()
.withf(|input| input == [&enc(1)[..], &ss(1), &pk(1), &enc(2), &ss(2), &pk(2)].concat())
.return_once(|_| Ok(b"shared secret".to_vec()));
let kem = CombinedKem::new(kem1, kem2, hash, variable_length_hash);
let decap_result = kem.decap(&enc(12), &sk(12), &pk(12)).unwrap();
assert_eq!(decap_result.as_slice(), b"shared secret");
}
#[test]
fn sizes() {
let mut kem1 = MockKemType::new();
let mut kem2 = MockKemType::new();
let hash = MockHash::new();
let variable_length_hash = MockVariableLengthHash::new();
kem1.expect_public_key_size().returning(|| 1);
kem1.expect_enc_size().returning(|| 10);
kem1.expect_secret_key_size().returning(|| 100);
kem2.expect_public_key_size().returning(|| 1000);
kem2.expect_enc_size().returning(|| 10000);
kem2.expect_secret_key_size().returning(|| 100000);
let kem = CombinedKem::new(kem1, kem2, hash, variable_length_hash);
assert_eq!(kem.public_key_size(), 1001);
assert_eq!(kem.secret_key_size(), 100100);
assert_eq!(kem.enc_size(), 10010);
}
}

View File

@@ -12,6 +12,7 @@ pub mod context;
pub mod dhkem; pub mod dhkem;
pub mod hpke; pub mod hpke;
pub mod kdf; pub mod kdf;
pub mod kem_combiner;
#[cfg(feature = "test_utils")] #[cfg(feature = "test_utils")]
mod test_utils; mod test_utils;

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"84c09a17a98c7672a7a81c2422882fd16d47885287de96a992f37c900d8ca3c8","src/aead.rs":"7aca3017a46ba936306e5e6939d50b4683f9b58b0cc7e70a61e8d811da726044","src/ec.rs":"96a4b0e3eb194a7e0440d6e27a125439271c3d8e332134ec94e32ce48d953eb5","src/ec_signer.rs":"cb756c5e5eab32d25ce559a85b5ec6aa183fccbc87a5a3d3f61977d49c79b1ac","src/ecdh.rs":"38bfe1d727fea9e3b9b27626c905fb36b381adefcc8ef15413b9d264e2808798","src/kdf.rs":"fe2b421e88c65566111b12c6c47cca4a213c4632d774d283cca43bb941bf8a86","src/lib.rs":"82d51e46925f3cc0da9365c8515410372297eb0a8821755f6dbbeedad017793f","src/mac.rs":"30b9d6e26e203bd25fa06a87a507b74fc2684c4fbe48649b6695939816bee6a2","test_data/test_ecdh.json":"2d38901d22a3e7f4d5ea080ab54b09c627bcbe826d8999183c7b27514a65c73c","test_data/test_ecdsa_eddsa.json":"10e41600049e7b9ec55dab3c6eb45bb2991639f5f17a9b88fdfd619a3620d80f","test_data/test_hash.json":"c5bde32059b20f5d65e0eb1a135edf060c2104aae520496c4df2384ecd811b0d","test_data/test_private_keys.json":"623425f7fc055a411955231a70dff15e4ee5c260d244da2e77f36e24de88dff4","test_data/test_public_keys.json":"b7cedee64377fcc7b2e825d93bed6bb013a9205b4b0f37e4793a70591c6b4851","test_data/x509/another_ca.der":"e3969dfe7880f1d965f081cd4e78c294430c3042cc706430d0ce57a5a2d6a8cf","test_data/x509/ca.der":"431c5c28d3189da986d788a1467f7a7b2e512f87f0072ca8ff0d96d6db886d23","test_data/x509/cert_ip.der":"9c4117c6858a18f03427be92da5d57a14b20d6f491411fd3cca15afaad10a21d","test_data/x509/github_intermediate.der":"f7a9a1b2fd964a3f2670bd668d561fb7c55d3aa9ab8391e7e169702db8a3dbcf","test_data/x509/github_leaf.der":"b7bc5510cc1c637b5e5fb785816a773dbb394b68337b1b117ca5ab43ccf778cf","test_data/x509/intermediate.der":"29d687905fca5f7c56d20b9fbf644b176b56e08aded62bf27dffe8d19fc50bca","test_data/x509/intermediate_ca/cert.der":"0e78c2bdf6252ebbd2371a5ed7990dc49be1f8a3fda4a7dcaff64fe5c4221571","test_data/x509/intermediate_ca/cert.pem":"5064d0025b31a9a6bcd198da4430f45886a8d98c0ecb953afa5627776d7dbffb","test_data/x509/intermediate_ca/csr.pem":"911bb5f7a2b1cb637d4fc79e76b08b2fe62f323879043979d61c8c224444126e","test_data/x509/intermediate_ca/key.pem":"695df243272468925c39aefe62d74aaab83e789d7824ad068ae4a1d05c67956b","test_data/x509/leaf.der":"225336abba7cb1fa74e5778a63cfd1a0ab7676b363aec06f5e6e5c4b10198764","test_data/x509/leaf/cert.der":"51992c356265f39e970dd86924371838bf0c45819c7cc02afac21897753af9e5","test_data/x509/leaf/cert.pem":"e39488ee664217fc85c5d9a47288d37b58029c194a4515f583062d0a0d22de7e","test_data/x509/leaf/csr.der":"609e9f076dfb795a60bb45e6227385fd274d1433ebbba9a7c2cb1938348599c6","test_data/x509/leaf/csr.pem":"27308af2703d97a111705e1b2c48727156170625dad0dfa010bf24d00076ba10","test_data/x509/leaf/key":"9c7a4e9828e4b38296eb5dc522d3df90826e4f10722baf2ee90d747abf6fde92","test_data/x509/leaf/key.pem":"6ab57b02c521fb16f1879ecd80ea90bdc5a4c114e578ec1578a6f2d065b3b6d9","test_data/x509/root_ca/cert.der":"5b7be0abfe6d4cfa1de47795472f2988d93dc95d43a213aad1051d715185beeb","test_data/x509/root_ca/cert.pem":"37f2239755d3ba1fbc60fdae9ab580f8d10c0a394bc223a6cbe5d62078c4d3b2","test_data/x509/root_ca/csr.der":"43f62ffead1969dd6f0872a14575536e6e93d7039e612fd9a03befd7be6c1019","test_data/x509/root_ca/csr.pem":"f241e27006816e953aa7cf4894fc129e7edae3c8dd2381cced53515c13d54cdc","test_data/x509/root_ca/key":"dd4e29082b01643ec501d61ccfc90a9c327513b2915d997299cf80cc0395b4cc","test_data/x509/root_ca/key.pem":"91ee6d662d5f363fbce5bba74f3866896a52f4d763cd061e3e2cf4c968647bc8"},"package":null} {"files":{"Cargo.toml":"8919f2e378364b2432eb567d5174652b6a026c2dc42c2a22094a83aa86355000","src/aead.rs":"7aca3017a46ba936306e5e6939d50b4683f9b58b0cc7e70a61e8d811da726044","src/ec.rs":"96a4b0e3eb194a7e0440d6e27a125439271c3d8e332134ec94e32ce48d953eb5","src/ec_signer.rs":"cb756c5e5eab32d25ce559a85b5ec6aa183fccbc87a5a3d3f61977d49c79b1ac","src/ecdh.rs":"daa9dbef41d19e7dc82a09ea20d4dfbf6d2296450393df337b925bdfa2c17489","src/kdf.rs":"fe2b421e88c65566111b12c6c47cca4a213c4632d774d283cca43bb941bf8a86","src/lib.rs":"82d51e46925f3cc0da9365c8515410372297eb0a8821755f6dbbeedad017793f","src/mac.rs":"30b9d6e26e203bd25fa06a87a507b74fc2684c4fbe48649b6695939816bee6a2","test_data/test_ecdh.json":"2d38901d22a3e7f4d5ea080ab54b09c627bcbe826d8999183c7b27514a65c73c","test_data/test_ecdsa_eddsa.json":"10e41600049e7b9ec55dab3c6eb45bb2991639f5f17a9b88fdfd619a3620d80f","test_data/test_hash.json":"c5bde32059b20f5d65e0eb1a135edf060c2104aae520496c4df2384ecd811b0d","test_data/test_private_keys.json":"623425f7fc055a411955231a70dff15e4ee5c260d244da2e77f36e24de88dff4","test_data/test_public_keys.json":"b7cedee64377fcc7b2e825d93bed6bb013a9205b4b0f37e4793a70591c6b4851","test_data/x509/another_ca.der":"e3969dfe7880f1d965f081cd4e78c294430c3042cc706430d0ce57a5a2d6a8cf","test_data/x509/ca.der":"431c5c28d3189da986d788a1467f7a7b2e512f87f0072ca8ff0d96d6db886d23","test_data/x509/cert_ip.der":"9c4117c6858a18f03427be92da5d57a14b20d6f491411fd3cca15afaad10a21d","test_data/x509/github_intermediate.der":"f7a9a1b2fd964a3f2670bd668d561fb7c55d3aa9ab8391e7e169702db8a3dbcf","test_data/x509/github_leaf.der":"b7bc5510cc1c637b5e5fb785816a773dbb394b68337b1b117ca5ab43ccf778cf","test_data/x509/intermediate.der":"29d687905fca5f7c56d20b9fbf644b176b56e08aded62bf27dffe8d19fc50bca","test_data/x509/intermediate_ca/cert.der":"0e78c2bdf6252ebbd2371a5ed7990dc49be1f8a3fda4a7dcaff64fe5c4221571","test_data/x509/intermediate_ca/cert.pem":"5064d0025b31a9a6bcd198da4430f45886a8d98c0ecb953afa5627776d7dbffb","test_data/x509/intermediate_ca/csr.pem":"911bb5f7a2b1cb637d4fc79e76b08b2fe62f323879043979d61c8c224444126e","test_data/x509/intermediate_ca/key.pem":"695df243272468925c39aefe62d74aaab83e789d7824ad068ae4a1d05c67956b","test_data/x509/leaf.der":"225336abba7cb1fa74e5778a63cfd1a0ab7676b363aec06f5e6e5c4b10198764","test_data/x509/leaf/cert.der":"51992c356265f39e970dd86924371838bf0c45819c7cc02afac21897753af9e5","test_data/x509/leaf/cert.pem":"e39488ee664217fc85c5d9a47288d37b58029c194a4515f583062d0a0d22de7e","test_data/x509/leaf/csr.der":"609e9f076dfb795a60bb45e6227385fd274d1433ebbba9a7c2cb1938348599c6","test_data/x509/leaf/csr.pem":"27308af2703d97a111705e1b2c48727156170625dad0dfa010bf24d00076ba10","test_data/x509/leaf/key":"9c7a4e9828e4b38296eb5dc522d3df90826e4f10722baf2ee90d747abf6fde92","test_data/x509/leaf/key.pem":"6ab57b02c521fb16f1879ecd80ea90bdc5a4c114e578ec1578a6f2d065b3b6d9","test_data/x509/root_ca/cert.der":"5b7be0abfe6d4cfa1de47795472f2988d93dc95d43a213aad1051d715185beeb","test_data/x509/root_ca/cert.pem":"37f2239755d3ba1fbc60fdae9ab580f8d10c0a394bc223a6cbe5d62078c4d3b2","test_data/x509/root_ca/csr.der":"43f62ffead1969dd6f0872a14575536e6e93d7039e612fd9a03befd7be6c1019","test_data/x509/root_ca/csr.pem":"f241e27006816e953aa7cf4894fc129e7edae3c8dd2381cced53515c13d54cdc","test_data/x509/root_ca/key":"dd4e29082b01643ec501d61ccfc90a9c327513b2915d997299cf80cc0395b4cc","test_data/x509/root_ca/key.pem":"91ee6d662d5f363fbce5bba74f3866896a52f4d763cd061e3e2cf4c968647bc8"},"package":null}

View File

@@ -46,17 +46,17 @@ version = "^0.4.3"
features = ["serde"] features = ["serde"]
[dependencies.mls-rs-core] [dependencies.mls-rs-core]
version = "0.18.0" version = "0.21.0"
path = "../mls-rs-core" path = "../mls-rs-core"
default-features = false default-features = false
[dependencies.mls-rs-crypto-hpke] [dependencies.mls-rs-crypto-hpke]
version = "0.9.0" version = "0.14.0"
path = "../mls-rs-crypto-hpke" path = "../mls-rs-crypto-hpke"
default-features = false default-features = false
[dependencies.mls-rs-crypto-traits] [dependencies.mls-rs-crypto-traits]
version = "0.10.0" version = "0.15.0"
path = "../mls-rs-crypto-traits" path = "../mls-rs-crypto-traits"
default-features = false default-features = false
@@ -90,12 +90,12 @@ default-features = false
assert_matches = "1.5.0" assert_matches = "1.5.0"
[dev-dependencies.mls-rs-core] [dev-dependencies.mls-rs-core]
version = "0.18.0" version = "0.21.0"
path = "../mls-rs-core" path = "../mls-rs-core"
features = ["test_suite"] features = ["test_suite"]
[dev-dependencies.mls-rs-crypto-hpke] [dev-dependencies.mls-rs-crypto-hpke]
version = "0.9.0" version = "0.14.0"
path = "../mls-rs-crypto-hpke" path = "../mls-rs-crypto-hpke"
features = ["test_utils"] features = ["test_utils"]
default-features = false default-features = false

View File

@@ -6,12 +6,11 @@ use core::ops::Deref;
use alloc::vec::Vec; use alloc::vec::Vec;
use mls_rs_crypto_traits::{Curve, DhType};
use mls_rs_core::{ use mls_rs_core::{
crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey}, crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey},
error::IntoAnyError, error::IntoAnyError,
}; };
use mls_rs_crypto_traits::{Curve, DhType, SamplingMethod};
use crate::ec::{ use crate::ec::{
generate_keypair, private_key_bytes_to_public, private_key_ecdh, private_key_from_bytes, generate_keypair, private_key_bytes_to_public, private_key_ecdh, private_key_from_bytes,
@@ -85,8 +84,8 @@ impl DhType for Ecdh {
Ok((key_pair.secret.into(), key_pair.public.into())) Ok((key_pair.secret.into(), key_pair.public.into()))
} }
fn bitmask_for_rejection_sampling(&self) -> Option<u8> { fn bitmask_for_rejection_sampling(&self) -> SamplingMethod {
self.curve_bitmask() self.0.hpke_sampling_method()
} }
fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error> { fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error> {
@@ -96,6 +95,10 @@ impl DhType for Ecdh {
fn secret_key_size(&self) -> usize { fn secret_key_size(&self) -> usize {
self.0.secret_key_size() self.0.secret_key_size()
} }
fn public_key_size(&self) -> usize {
self.0.public_key_size()
}
} }
impl Ecdh { impl Ecdh {

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"04f304120111e2b999f03a6087e510b8d656dd2e24db0eb33cbfadb9b87dac1a","src/aead.rs":"94d2d8aa578e6a37761c378f83de36173336de836e17702ab1d5ce56cf2059f6","src/dh.rs":"34e12aa126a50b36332152be83d90ae02a098ea6b1926f27cf22cf5a57142a02","src/ec.rs":"d4073bd3ceb6e78a25dff4626a92c2e12a91bc9d677fab3ea956e25700e4802d","src/kdf.rs":"185c0749ab69a513d2a3900cb3e15e81a1c6f24ddc6b1ca7c9e55f47026e3c77","src/kem.rs":"81b24968680b8955c9347b3e82d503e3b73e33a46e871f6ef712ebd54a002039","src/lib.rs":"285fe7de9edc81704c5b96546fe8049b1d84b4d68776cfd80f844045ba643683","src/mock.rs":"4146c616c60859f9a733e9cca888adf7d78429ded7b4c14d4ad4110534f1a876"},"package":null} {"files":{"Cargo.toml":"9390f7614441b9e1ff9c22a661596edb893ae695902d7fe8f7912ad35b80b48f","src/aead.rs":"94d2d8aa578e6a37761c378f83de36173336de836e17702ab1d5ce56cf2059f6","src/dh.rs":"b417fb405a3d3b4d77739d0caa11f2a2ed14f8df7e7016e70406be40dc473691","src/ec.rs":"ccda3dc3e4805d0efb5939a2de2ac9b3f3b59aa96a4c98814d3e4ab945b914b5","src/kdf.rs":"185c0749ab69a513d2a3900cb3e15e81a1c6f24ddc6b1ca7c9e55f47026e3c77","src/kem.rs":"aaa38f46c3a10c9092858f13c697bf84fa1e0d7e9c5a162eec8fbaf5f2413ca1","src/lib.rs":"72a995c052aa27d98876d5dc02120c6de7d30fe99a9a2f07b2e268666b7284f1","src/mock.rs":"84ab2de89878c13e294420d8c09b13d81a4916e78e4b5b354209a7b9e637f6aa"},"package":null}

View File

@@ -12,7 +12,7 @@
[package] [package]
edition = "2021" edition = "2021"
name = "mls-rs-crypto-traits" name = "mls-rs-crypto-traits"
version = "0.10.0" version = "0.15.0"
build = false build = false
autolib = false autolib = false
autobins = false autobins = false
@@ -37,7 +37,7 @@ path = "src/lib.rs"
maybe-async = "0.2.10" maybe-async = "0.2.10"
[dependencies.mls-rs-core] [dependencies.mls-rs-core]
version = "0.18.0" version = "0.21.0"
path = "../mls-rs-core" path = "../mls-rs-core"
default-features = false default-features = false

View File

@@ -12,6 +12,13 @@ use alloc::vec::Vec;
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
use mockall::automock; use mockall::automock;
#[derive(Clone, Debug, Copy)]
pub enum SamplingMethod {
HpkeWithBitmask(u8),
HpkeWithoutBitmask,
Raw,
}
/// A trait that provides the required DH functions, as in RFC 9180,Section 4.1 /// A trait that provides the required DH functions, as in RFC 9180,Section 4.1
#[cfg_attr(feature = "mock", automock(type Error = crate::mock::TestError;))] #[cfg_attr(feature = "mock", automock(type Error = crate::mock::TestError;))]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
@@ -48,9 +55,10 @@ pub trait DhType: Send + Sync {
/// significant one are filtered out), /// significant one are filtered out),
/// * `Some(0xFF)`for curves P-256 and P-384 (rejection sampling is needed but no /// * `Some(0xFF)`for curves P-256 and P-384 (rejection sampling is needed but no
/// bits need to be filtered). /// bits need to be filtered).
fn bitmask_for_rejection_sampling(&self) -> Option<u8>; fn bitmask_for_rejection_sampling(&self) -> SamplingMethod;
fn secret_key_size(&self) -> usize; fn secret_key_size(&self) -> usize;
fn public_key_size(&self) -> usize;
fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error>; fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error>;
} }

View File

@@ -4,6 +4,8 @@
use mls_rs_core::crypto::CipherSuite; use mls_rs_core::crypto::CipherSuite;
use crate::SamplingMethod;
/// Elliptic curve types /// Elliptic curve types
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)] #[repr(u8)]
@@ -66,11 +68,11 @@ impl Curve {
} }
#[inline(always)] #[inline(always)]
pub fn curve_bitmask(&self) -> Option<u8> { pub fn hpke_sampling_method(&self) -> SamplingMethod {
match self { match self {
Curve::P256 | Curve::P384 => Some(0xFF), Curve::P256 | Curve::P384 => SamplingMethod::HpkeWithBitmask(0xFF),
Curve::P521 => Some(0x01), Curve::P521 => SamplingMethod::HpkeWithBitmask(0x01),
_ => None, _ => SamplingMethod::HpkeWithoutBitmask,
} }
} }
} }

View File

@@ -5,6 +5,7 @@
use mls_rs_core::{ use mls_rs_core::{
crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey}, crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey},
error::IntoAnyError, error::IntoAnyError,
mls_rs_codec::{self, MlsDecode, MlsEncode, MlsSize},
}; };
use alloc::vec::Vec; use alloc::vec::Vec;
@@ -20,13 +21,17 @@ use mockall::automock;
maybe_async::must_be_async maybe_async::must_be_async
)] )]
#[cfg_attr(feature = "mock", automock(type Error = crate::mock::TestError;))] #[cfg_attr(feature = "mock", automock(type Error = crate::mock::TestError;))]
pub trait KemType: Send + Sync { pub trait KemType: Send + Sync + Sized {
type Error: IntoAnyError + Send + Sync; type Error: IntoAnyError + Send + Sync;
/// KEM Id, as specified in RFC 9180, Section 5.1 and Table 2. /// KEM Id, as specified in RFC 9180, Section 5.1 and Table 2.
fn kem_id(&self) -> u16; fn kem_id(&self) -> u16;
async fn derive(&self, ikm: &[u8]) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error>; async fn generate_deterministic(
&self,
seed: &[u8],
) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error>;
async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error>; async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error>;
fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error>; fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error>;
@@ -38,9 +43,18 @@ pub trait KemType: Send + Sync {
secret_key: &HpkeSecretKey, secret_key: &HpkeSecretKey,
local_public: &HpkePublicKey, local_public: &HpkePublicKey,
) -> Result<Vec<u8>, Self::Error>; ) -> Result<Vec<u8>, Self::Error>;
fn seed_length_for_derive(&self) -> usize;
fn public_key_size(&self) -> usize;
fn secret_key_size(&self) -> usize;
fn enc_size(&self) -> usize {
self.public_key_size()
}
} }
/// Struct to represent the output of the kem [encap](KemType::encap) function /// Struct to represent the output of the kem [encap](KemType::encap) function
#[derive(Clone, Debug, MlsDecode, MlsEncode, MlsSize)]
pub struct KemResult { pub struct KemResult {
pub shared_secret: Vec<u8>, pub shared_secret: Vec<u8>,
pub enc: Vec<u8>, pub enc: Vec<u8>,

View File

@@ -12,10 +12,27 @@ mod kdf;
mod kem; mod kem;
pub use aead::{AeadId, AeadType, AEAD_ID_EXPORT_ONLY, AES_TAG_LEN}; pub use aead::{AeadId, AeadType, AEAD_ID_EXPORT_ONLY, AES_TAG_LEN};
pub use dh::DhType; pub use dh::{DhType, SamplingMethod};
pub use ec::Curve; pub use ec::Curve;
pub use kdf::{KdfId, KdfType}; pub use kdf::{KdfId, KdfType};
pub use kem::{KemId, KemResult, KemType}; pub use kem::{KemId, KemResult, KemType};
use mls_rs_core::error::IntoAnyError;
#[cfg(feature = "mock")] #[cfg(feature = "mock")]
pub mod mock; pub mod mock;
use alloc::vec::Vec;
#[cfg_attr(feature = "mock", mockall::automock(type Error = crate::mock::TestError;))]
pub trait Hash: Send + Sync {
type Error: IntoAnyError + Send + Sync;
fn hash(&self, input: &[u8]) -> Result<Vec<u8>, Self::Error>;
}
#[cfg_attr(feature = "mock", mockall::automock(type Error = crate::mock::TestError;))]
pub trait VariableLengthHash: Send + Sync {
type Error: IntoAnyError + Send + Sync;
fn hash(&self, input: &[u8], out_len: usize) -> Result<Vec<u8>, Self::Error>;
}

View File

@@ -2,7 +2,10 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
pub use crate::{aead::MockAeadType, dh::MockDhType, kdf::MockKdfType, kem::MockKemType}; pub use crate::{
aead::MockAeadType, dh::MockDhType, kdf::MockKdfType, kem::MockKemType, MockHash,
MockVariableLengthHash,
};
#[derive(Debug)] #[derive(Debug)]
pub struct TestError {} pub struct TestError {}

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"95fcf38a08016ed1876d902a1d60b1f6c1034e26a21acdea0ee93dff0e239cb4","src/error.rs":"f7301b842e465168c460767c5c0f87ebb37cad5d88d2b9637ee2e73f66a46c7b","src/identity_extractor.rs":"12477ac92fee3154ee119b7298bfb23dd23fb8cd4930499e1a2686d71c48d9c8","src/lib.rs":"b4b8eaa1c37835175e0ae658c329c894d91345404615c5059b87f7ccc2230f65","src/provider.rs":"ac9acec23541eed0b03d89aaca606f4f034b76b9166c4c84365276a721cb556a","src/traits.rs":"b2b89e3a57c888e5f3e9df5010001a19d1cac645817e6a4680f521c3ba700ec8","src/util.rs":"bc004021a55a56cea3abd12c14715a65f8646c9ca699047c51311bc3b38601b8"},"package":null} {"files":{"Cargo.toml":"350e7b0fbfbc2896e9c5ce6b7b31382f18845df5ef9de24a42ba662fb00f3720","src/error.rs":"f7301b842e465168c460767c5c0f87ebb37cad5d88d2b9637ee2e73f66a46c7b","src/identity_extractor.rs":"12477ac92fee3154ee119b7298bfb23dd23fb8cd4930499e1a2686d71c48d9c8","src/lib.rs":"483e734debb67a626ad14684b445dc5edb898745fe850a99b63d88af5e03db26","src/provider.rs":"89dfea0590c3e06abc84e39840f364b59c0abb0b72b98eef5508c662197affeb","src/traits.rs":"b2b89e3a57c888e5f3e9df5010001a19d1cac645817e6a4680f521c3ba700ec8","src/util.rs":"bc004021a55a56cea3abd12c14715a65f8646c9ca699047c51311bc3b38601b8"},"package":null}

View File

@@ -12,7 +12,7 @@
[package] [package]
edition = "2021" edition = "2021"
name = "mls-rs-identity-x509" name = "mls-rs-identity-x509"
version = "0.11.0" version = "0.15.0"
build = false build = false
autolib = false autolib = false
autobins = false autobins = false
@@ -37,7 +37,7 @@ path = "src/lib.rs"
maybe-async = "0.2.10" maybe-async = "0.2.10"
[dependencies.mls-rs-core] [dependencies.mls-rs-core]
version = "0.18.0" version = "0.21.0"
path = "../mls-rs-core" path = "../mls-rs-core"
features = ["x509"] features = ["x509"]
default-features = false default-features = false
@@ -71,8 +71,7 @@ version = "0.2"
features = ["js"] features = ["js"]
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test] [target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test]
version = "0.3.26" version = "0.3"
default-features = false
[lints.rust.unexpected_cfgs] [lints.rust.unexpected_cfgs]
level = "warn" level = "warn"

View File

@@ -21,8 +21,10 @@ pub use traits::*;
pub use mls_rs_core::identity::{CertificateChain, DerCertificate}; pub use mls_rs_core::identity::{CertificateChain, DerCertificate};
#[cfg(all(test, target_arch = "wasm32"))]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
#[derive(Clone, PartialEq, Eq)] #[derive(Clone, PartialEq, Eq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
/// X.509 certificate request in DER format. /// X.509 certificate request in DER format.
pub struct DerCertificateRequest(Vec<u8>); pub struct DerCertificateRequest(Vec<u8>);
@@ -34,7 +36,6 @@ impl Debug for DerCertificateRequest {
} }
} }
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl DerCertificateRequest { impl DerCertificateRequest {
/// Create a DER certificate request from raw bytes. /// Create a DER certificate request from raw bytes.
pub fn new(data: Vec<u8>) -> DerCertificateRequest { pub fn new(data: Vec<u8>) -> DerCertificateRequest {

View File

@@ -9,7 +9,7 @@ use mls_rs_core::{
crypto::SignaturePublicKey, crypto::SignaturePublicKey,
error::IntoAnyError, error::IntoAnyError,
extension::ExtensionList, extension::ExtensionList,
identity::{CredentialType, IdentityProvider}, identity::{CredentialType, IdentityProvider, MemberValidationContext, SigningIdentity},
time::MlsTime, time::MlsTime,
}; };
@@ -80,10 +80,10 @@ where
/// Determine if a certificate is valid based on the behavior of the /// Determine if a certificate is valid based on the behavior of the
/// underlying validator provided. /// underlying validator provided.
pub fn validate( fn validate(
&self, &self,
signing_identity: &mls_rs_core::identity::SigningIdentity, signing_identity: &SigningIdentity,
timestamp: Option<mls_rs_core::time::MlsTime>, timestamp: Option<MlsTime>,
) -> Result<(), X509IdentityError> { ) -> Result<(), X509IdentityError> {
let chain = credential_to_chain(&signing_identity.credential)?; let chain = credential_to_chain(&signing_identity.credential)?;
@@ -98,40 +98,6 @@ where
Ok(()) Ok(())
} }
/// Produce a unique identity value to represent the entity controlling a
/// certificate credential within an MLS group.
pub fn identity(
&self,
signing_id: &mls_rs_core::identity::SigningIdentity,
) -> Result<Vec<u8>, X509IdentityError> {
self.identity_extractor
.identity(&credential_to_chain(&signing_id.credential)?)
.map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
}
/// Determine if `successor` is controlled by the same entity as
/// `predecessor` based on the behavior of the underlying identity
/// extractor provided.
pub fn valid_successor(
&self,
predecessor: &mls_rs_core::identity::SigningIdentity,
successor: &mls_rs_core::identity::SigningIdentity,
) -> Result<bool, X509IdentityError> {
self.identity_extractor
.valid_successor(
&credential_to_chain(&predecessor.credential)?,
&credential_to_chain(&successor.credential)?,
)
.map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
}
/// Supported credential types.
///
/// Only [`CredentialType::X509`] is supported.
pub fn supported_types(&self) -> Vec<mls_rs_core::identity::CredentialType> {
vec![CredentialType::X509]
}
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
@@ -143,48 +109,66 @@ where
{ {
type Error = X509IdentityError; type Error = X509IdentityError;
/// Determine if a certificate is valid based on the behavior of the
/// underlying validator provided.
async fn validate_member( async fn validate_member(
&self, &self,
signing_identity: &mls_rs_core::identity::SigningIdentity, signing_identity: &SigningIdentity,
timestamp: Option<MlsTime>, timestamp: Option<MlsTime>,
_extensions: Option<&ExtensionList>, _context: MemberValidationContext<'_>,
) -> Result<(), Self::Error> { ) -> Result<(), X509IdentityError> {
self.validate(signing_identity, timestamp) self.validate(signing_identity, timestamp)
} }
/// Produce a unique identity value to represent the entity controlling a
/// certificate credential within an MLS group.
async fn identity(
&self,
signing_id: &SigningIdentity,
_extensions: &ExtensionList,
) -> Result<Vec<u8>, X509IdentityError> {
self.identity_extractor
.identity(&credential_to_chain(&signing_id.credential)?)
.map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
}
/// Determine if `successor` is controlled by the same entity as
/// `predecessor` based on the behavior of the underlying identity
/// extractor provided.
async fn valid_successor(
&self,
predecessor: &SigningIdentity,
successor: &SigningIdentity,
_extensions: &ExtensionList,
) -> Result<bool, X509IdentityError> {
self.identity_extractor
.valid_successor(
&credential_to_chain(&predecessor.credential)?,
&credential_to_chain(&successor.credential)?,
)
.map_err(|e| X509IdentityError::IdentityExtractorError(e.into_any_error()))
}
async fn validate_external_sender( async fn validate_external_sender(
&self, &self,
signing_identity: &mls_rs_core::identity::SigningIdentity, signing_identity: &SigningIdentity,
timestamp: Option<MlsTime>, timestamp: Option<MlsTime>,
_extensions: Option<&ExtensionList>, _extensions: Option<&ExtensionList>,
) -> Result<(), Self::Error> { ) -> Result<(), Self::Error> {
self.validate(signing_identity, timestamp) self.validate(signing_identity, timestamp)
} }
async fn identity( /// Supported credential types.
&self, ///
signing_id: &mls_rs_core::identity::SigningIdentity, /// Only [`CredentialType::X509`] is supported.
_extensions: &ExtensionList,
) -> Result<Vec<u8>, Self::Error> {
self.identity(signing_id)
}
async fn valid_successor(
&self,
predecessor: &mls_rs_core::identity::SigningIdentity,
successor: &mls_rs_core::identity::SigningIdentity,
_extensions: &ExtensionList,
) -> Result<bool, Self::Error> {
self.valid_successor(predecessor, successor)
}
fn supported_types(&self) -> Vec<CredentialType> { fn supported_types(&self) -> Vec<CredentialType> {
self.supported_types() vec![CredentialType::X509]
} }
} }
#[cfg(all(test, feature = "std"))] #[cfg(all(test, feature = "std"))]
mod tests { mod tests {
use super::*;
use mls_rs_core::{crypto::SignaturePublicKey, identity::CredentialType, time::MlsTime}; use mls_rs_core::{crypto::SignaturePublicKey, identity::CredentialType, time::MlsTime};
use crate::{ use crate::{

View File

@@ -1 +1 @@
{"files":{"Cargo.toml":"5272c1fffa9dde0f7910abff8978f97d50d9a87f01103e1a3dc938e556dcdabb","src/application.rs":"3bbb37ba1f06391f2bae0a6a7d5caf1c5ea425115c746061efba64d389fd4268","src/cipher.rs":"1f500c0127565354086eac00c385bbaa560b7f457b97168ebaa17e23bbccebe3","src/connection_strategy.rs":"969e260c18baaf6df2a7e8489a43baed2aba319231f83c0fea77dd8c8358c180","src/group_state.rs":"17a012ada3e8d85157363ae2cea55199e2d9bbc9723e49532ad1f54e13246317","src/key_package.rs":"67607d4566d236423ce1219365b56975aff039137404ee4d43ea86323900dac6","src/lib.rs":"4a07b898cb6f2fe27a7b43cdb59d94c20175ec5ccb01a1ebeee4001cccfec9c9","src/psk.rs":"90d2255549b625f7f130272639c975b8845dd3ee261cc2629fc70b7554e3a209","src/test_utils.rs":"547d254e73606dcfcef9f7c89ca9ad7fdb51dbe16c49dc2332d24bb587939ee6"},"package":null} {"files":{"Cargo.toml":"32728aaa0a94bdbe90670fd0bb0324cc1042ff2956b1e736376588b890e478a7","src/application.rs":"cbc60559eb8e3ec2c1f9b70b9108049dbc9abd076318632d4b79638d5717bc6b","src/cipher.rs":"1f500c0127565354086eac00c385bbaa560b7f457b97168ebaa17e23bbccebe3","src/connection_strategy.rs":"969e260c18baaf6df2a7e8489a43baed2aba319231f83c0fea77dd8c8358c180","src/group_state.rs":"17a012ada3e8d85157363ae2cea55199e2d9bbc9723e49532ad1f54e13246317","src/key_package.rs":"34cf12aa7f8fee48ebb36409c863b2814afed871942746b986c0a89782d29ebc","src/lib.rs":"a5ab3024e302cb1b218fae1543714f7ebd06dffaa19de5961128d0152d1cb4f0","src/psk.rs":"3c0201c67e980cb44b84d3787f2b04b9254839a2975545a907ded400ac854caf","src/test_utils.rs":"547d254e73606dcfcef9f7c89ca9ad7fdb51dbe16c49dc2332d24bb587939ee6"},"package":null}

View File

@@ -12,7 +12,7 @@
[package] [package]
edition = "2021" edition = "2021"
name = "mls-rs-provider-sqlite" name = "mls-rs-provider-sqlite"
version = "0.11.0" version = "0.15.0"
build = false build = false
autolib = false autolib = false
autobins = false autobins = false
@@ -43,7 +43,7 @@ thiserror = "1.0.40"
version = "0.4" version = "0.4"
[dependencies.mls-rs-core] [dependencies.mls-rs-core]
version = "0.18.0" version = "0.21.0"
path = "../mls-rs-core" path = "../mls-rs-core"
[dependencies.rusqlite] [dependencies.rusqlite]
@@ -73,7 +73,7 @@ sqlcipher-bundled = [
"sqlite", "sqlite",
"rusqlite/bundled-sqlcipher", "rusqlite/bundled-sqlcipher",
] ]
sqlite = ["rusqlite/modern_sqlite"] sqlite = []
sqlite-bundled = [ sqlite-bundled = [
"sqlite", "sqlite",
"rusqlite/bundled", "rusqlite/bundled",

View File

@@ -30,7 +30,7 @@ impl SqLiteApplicationStorage {
/// Insert `value` into storage indexed by `key`. /// Insert `value` into storage indexed by `key`.
/// ///
/// If a value already exists for `key` it will be overwritten. /// If a value already exists for `key` it will be overwritten.
pub fn insert(&self, key: String, value: Vec<u8>) -> Result<(), SqLiteDataStorageError> { pub fn insert(&self, key: &str, value: &[u8]) -> Result<(), SqLiteDataStorageError> {
let connection = self.connection.lock().unwrap(); let connection = self.connection.lock().unwrap();
// Upsert into the database // Upsert into the database
@@ -41,13 +41,13 @@ impl SqLiteApplicationStorage {
} }
/// Execute multiple [`SqLiteApplicationStorage::insert`] operations in a transaction. /// Execute multiple [`SqLiteApplicationStorage::insert`] operations in a transaction.
pub fn transact_insert(&self, items: Vec<Item>) -> Result<(), SqLiteDataStorageError> { pub fn transact_insert(&self, items: &[Item]) -> Result<(), SqLiteDataStorageError> {
let mut connection = self.connection.lock().unwrap(); let mut connection = self.connection.lock().unwrap();
// Upsert into the database // Upsert into the database
let tx = connection.transaction().map_err(sql_engine_error)?; let tx = connection.transaction().map_err(sql_engine_error)?;
items.into_iter().try_for_each(|item| { items.iter().try_for_each(|item| {
tx.execute(INSERT_SQL, params![item.key, item.value]) tx.execute(INSERT_SQL, params![item.key, item.value])
.map_err(sql_engine_error) .map_err(sql_engine_error)
.map(|_| ()) .map(|_| ())
@@ -179,7 +179,7 @@ mod tests {
let (key, value) = test_kv(); let (key, value) = test_kv();
let storage = test_storage(); let storage = test_storage();
storage.insert(key.clone(), value.clone()).unwrap(); storage.insert(&key, &value).unwrap();
let from_storage = storage.get(&key).unwrap().unwrap(); let from_storage = storage.get(&key).unwrap().unwrap();
assert_eq!(from_storage, value); assert_eq!(from_storage, value);
@@ -192,8 +192,8 @@ mod tests {
let storage = test_storage(); let storage = test_storage();
storage.insert(key.clone(), value).unwrap(); storage.insert(&key, &value).unwrap();
storage.insert(key.clone(), new_value.clone()).unwrap(); storage.insert(&key, &new_value).unwrap();
let from_storage = storage.get(&key).unwrap().unwrap(); let from_storage = storage.get(&key).unwrap().unwrap();
assert_eq!(from_storage, new_value); assert_eq!(from_storage, new_value);
@@ -204,7 +204,7 @@ mod tests {
let (key, value) = test_kv(); let (key, value) = test_kv();
let storage = test_storage(); let storage = test_storage();
storage.insert(key.clone(), value).unwrap(); storage.insert(&key, &value).unwrap();
storage.delete(&key).unwrap(); storage.delete(&key).unwrap();
assert!(storage.get(&key).unwrap().is_none()); assert!(storage.get(&key).unwrap().is_none());
@@ -217,8 +217,7 @@ mod tests {
let storage = test_storage(); let storage = test_storage();
keys.iter() keys.iter().for_each(|k| storage.insert(k, &value).unwrap());
.for_each(|k| storage.insert(k.clone(), value.clone()).unwrap());
let mut expected = vec![ let mut expected = vec![
Item::new(keys[0].clone(), value.clone()), Item::new(keys[0].clone(), value.clone()),
@@ -249,15 +248,9 @@ mod tests {
fn test_special_characters() { fn test_special_characters() {
let storage = test_storage(); let storage = test_storage();
storage storage.insert("%$_ƕ❤_$%", &gen_rand_bytes(5)).unwrap();
.insert("%$_ƕ❤_$%".to_string(), gen_rand_bytes(5)) storage.insert("%$_ƕ❤a$%", &gen_rand_bytes(5)).unwrap();
.unwrap(); storage.insert("%$_ƕ❤Ḉ$%", &gen_rand_bytes(5)).unwrap();
storage
.insert("%$_ƕ❤a$%".to_string(), gen_rand_bytes(5))
.unwrap();
storage
.insert("%$_ƕ❤Ḉ$%".to_string(), gen_rand_bytes(5))
.unwrap();
let items = storage.get_by_prefix("%$_ƕ❤_").unwrap(); let items = storage.get_by_prefix("%$_ƕ❤_").unwrap();
let keys = items.into_iter().map(|i| i.key).collect::<Vec<_>>(); let keys = items.into_iter().map(|i| i.key).collect::<Vec<_>>();
@@ -269,7 +262,7 @@ mod tests {
let storage = test_storage(); let storage = test_storage();
let items = vec![test_item(), test_item(), test_item()]; let items = vec![test_item(), test_item(), test_item()];
storage.transact_insert(items.clone()).unwrap(); storage.transact_insert(&items).unwrap();
for item in items { for item in items {
assert_eq!(storage.get(&item.key).unwrap(), Some(item.value)); assert_eq!(storage.get(&item.key).unwrap(), Some(item.value));

View File

@@ -75,10 +75,13 @@ impl SqLiteKeyPackageStorage {
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
} }
/// Delete key packages that are expired based on the current system clock time.
pub fn delete_expired(&self) -> Result<(), SqLiteDataStorageError> { pub fn delete_expired(&self) -> Result<(), SqLiteDataStorageError> {
self.delete_expired_by_time(MlsTime::now().seconds_since_epoch()) self.delete_expired_by_time(MlsTime::now().seconds_since_epoch())
} }
/// Delete key packages that are expired based on an application provided time in seconds since
/// unix epoch.
pub fn delete_expired_by_time(&self, time: u64) -> Result<(), SqLiteDataStorageError> { pub fn delete_expired_by_time(&self, time: u64) -> Result<(), SqLiteDataStorageError> {
let connection = self.connection.lock().unwrap(); let connection = self.connection.lock().unwrap();
@@ -90,6 +93,32 @@ impl SqLiteKeyPackageStorage {
.map(|_| ()) .map(|_| ())
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
} }
/// Total number of key packages held in storage.
pub fn count(&self) -> Result<usize, SqLiteDataStorageError> {
let connection = self.connection.lock().unwrap();
connection
.query_row("SELECT count(*) FROM key_package", params![], |row| {
row.get(0)
})
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
}
/// Total number of key packages that will still remain in storage at a specific application provided
/// time in seconds since unix epoch. This assumes that the application would also be calling
/// [SqLiteKeyPackageStorage::delete_expired] at a reasonable cadence to be accurate.
pub fn count_at_time(&self, time: u64) -> Result<usize, SqLiteDataStorageError> {
let connection = self.connection.lock().unwrap();
connection
.query_row(
"SELECT count(*) FROM key_package where expiration >= ?",
params![time],
|row| row.get(0),
)
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
}
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
@@ -215,4 +244,37 @@ mod tests {
assert!(storage.get(&data[2].0).unwrap().is_none()); assert!(storage.get(&data[2].0).unwrap().is_none());
assert!(storage.get(&data[3].0).unwrap().is_none()); assert!(storage.get(&data[3].0).unwrap().is_none());
} }
#[test]
fn key_count() {
let mut storage = test_storage();
let test_packages = (0..10).map(|_| test_key_package()).collect::<Vec<_>>();
test_packages
.into_iter()
.for_each(|(key_package_id, key_package)| {
storage.insert(&key_package_id, key_package).unwrap();
});
assert_eq!(storage.count().unwrap(), 10);
}
#[test]
fn key_count_at_time() {
let mut storage = test_storage();
let mut kp_1 = test_key_package();
kp_1.1.expiration = 1;
storage.insert(&kp_1.0, kp_1.1).unwrap();
let mut kp_2 = test_key_package();
kp_2.1.expiration = 2;
storage.insert(&kp_2.0, kp_2.1).unwrap();
assert_eq!(storage.count_at_time(3).unwrap(), 0);
assert_eq!(storage.count_at_time(2).unwrap(), 1);
assert_eq!(storage.count_at_time(1).unwrap(), 2);
assert_eq!(storage.count_at_time(0).unwrap(), 2);
}
} }

View File

@@ -54,6 +54,31 @@ impl mls_rs_core::error::IntoAnyError for SqLiteDataStorageError {
} }
} }
#[derive(Clone, Debug)]
pub enum JournalMode {
Delete,
Truncate,
Persist,
Memory,
Wal,
Off,
}
/// Note: for in-memory dbs (such as what the tests use), the only available options are MEMORY or OFF
/// Invalid modes do not error, only no-op
impl JournalMode {
fn as_str(&self) -> &'static str {
match self {
JournalMode::Delete => "DELETE",
JournalMode::Truncate => "TRUNCATE",
JournalMode::Persist => "PERSIST",
JournalMode::Memory => "MEMORY",
JournalMode::Wal => "WAL",
JournalMode::Off => "OFF",
}
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
/// SQLite data storage engine. /// SQLite data storage engine.
pub struct SqLiteDataStorageEngine<CS> pub struct SqLiteDataStorageEngine<CS>
@@ -62,6 +87,7 @@ where
{ {
connection_strategy: CS, connection_strategy: CS,
group_state_context: Option<Vec<u8>>, group_state_context: Option<Vec<u8>>,
journal_mode: Option<JournalMode>,
} }
impl<CS> SqLiteDataStorageEngine<CS> impl<CS> SqLiteDataStorageEngine<CS>
@@ -74,6 +100,7 @@ where
Ok(SqLiteDataStorageEngine { Ok(SqLiteDataStorageEngine {
connection_strategy, connection_strategy,
group_state_context: None, group_state_context: None,
journal_mode: None,
}) })
} }
@@ -84,6 +111,14 @@ where
} }
} }
/// A `journal_mode` of `None` means the SQLite default is used.
pub fn with_journal_mode(self, journal_mode: Option<JournalMode>) -> Self {
Self {
journal_mode,
..self
}
}
fn create_connection(&self) -> Result<Connection, SqLiteDataStorageError> { fn create_connection(&self) -> Result<Connection, SqLiteDataStorageError> {
let connection = self.connection_strategy.make_connection()?; let connection = self.connection_strategy.make_connection()?;
@@ -92,6 +127,12 @@ where
.pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0)) .pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0))
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?; .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
if let Some(journal_mode) = &self.journal_mode {
connection
.pragma_update(None, "journal_mode", journal_mode.as_str())
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
}
if current_schema != 1 { if current_schema != 1 {
create_tables_v1(&connection)?; create_tables_v1(&connection)?;
} }
@@ -164,7 +205,12 @@ fn create_tables_v1(connection: &Connection) -> Result<(), SqLiteDataStorageErro
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{connection_strategy::MemoryStrategy, SqLiteDataStorageEngine}; use tempfile::tempdir;
use crate::{
connection_strategy::{FileConnectionStrategy, MemoryStrategy},
SqLiteDataStorageEngine,
};
#[test] #[test]
pub fn user_version_test() { pub fn user_version_test() {
@@ -182,4 +228,26 @@ mod tests {
assert_eq!(current_schema, 1); assert_eq!(current_schema, 1);
} }
#[test]
pub fn journal_mode_test() {
let temp = tempdir().unwrap();
// Connect with journal_mode other than the default of MEMORY
let database = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(
&temp.path().join("test_db.sqlite"),
))
.unwrap();
let connection = database
.with_journal_mode(Some(crate::JournalMode::Truncate))
.create_connection()
.unwrap();
let journal_mode = connection
.pragma_query_value(None, "journal_mode", |rows| rows.get::<_, String>(0))
.unwrap();
assert_eq!(journal_mode, "truncate");
}
} }

View File

@@ -24,7 +24,7 @@ impl SqLitePreSharedKeyStorage {
} }
/// Insert a pre-shared key into storage. /// Insert a pre-shared key into storage.
pub fn insert(&self, psk_id: Vec<u8>, psk: PreSharedKey) -> Result<(), SqLiteDataStorageError> { pub fn insert(&self, psk_id: &[u8], psk: &PreSharedKey) -> Result<(), SqLiteDataStorageError> {
let connection = self.connection.lock().unwrap(); let connection = self.connection.lock().unwrap();
// Upsert into the database // Upsert into the database
@@ -103,7 +103,7 @@ mod tests {
let (psk_id, psk) = test_psk(); let (psk_id, psk) = test_psk();
let storage = test_storage(); let storage = test_storage();
storage.insert(psk_id.clone(), psk.clone()).unwrap(); storage.insert(&psk_id, &psk).unwrap();
let from_storage = storage.get(&psk_id).unwrap().unwrap(); let from_storage = storage.get(&psk_id).unwrap().unwrap();
assert_eq!(from_storage, psk); assert_eq!(from_storage, psk);
@@ -116,8 +116,8 @@ mod tests {
let storage = test_storage(); let storage = test_storage();
storage.insert(psk_id.clone(), psk).unwrap(); storage.insert(&psk_id, &psk).unwrap();
storage.insert(psk_id.clone(), new_psk.clone()).unwrap(); storage.insert(&psk_id, &new_psk).unwrap();
let from_storage = storage.get(&psk_id).unwrap().unwrap(); let from_storage = storage.get(&psk_id).unwrap().unwrap();
assert_eq!(from_storage, new_psk); assert_eq!(from_storage, new_psk);
@@ -128,7 +128,7 @@ mod tests {
let (psk_id, psk) = test_psk(); let (psk_id, psk) = test_psk();
let storage = test_storage(); let storage = test_storage();
storage.insert(psk_id.clone(), psk).unwrap(); storage.insert(&psk_id, &psk).unwrap();
storage.delete(&psk_id).unwrap(); storage.delete(&psk_id).unwrap();
assert!(storage.get(&psk_id).unwrap().is_none()); assert!(storage.get(&psk_id).unwrap().is_none());

File diff suppressed because one or more lines are too long

View File

@@ -13,7 +13,7 @@
edition = "2021" edition = "2021"
rust-version = "1.68.2" rust-version = "1.68.2"
name = "mls-rs" name = "mls-rs"
version = "0.39.1" version = "0.45.0"
build = false build = false
exclude = ["test_data"] exclude = ["test_data"]
autolib = false autolib = false
@@ -142,28 +142,33 @@ default-features = false
version = "0.2.10" version = "0.2.10"
[dependencies.mls-rs-codec] [dependencies.mls-rs-codec]
version = "0.5.2" version = "0.6"
path = "../mls-rs-codec" path = "../mls-rs-codec"
default-features = false default-features = false
[dependencies.mls-rs-core] [dependencies.mls-rs-core]
version = "0.18.0" version = "0.21.0"
path = "../mls-rs-core" path = "../mls-rs-core"
default-features = false default-features = false
[dependencies.mls-rs-crypto-awslc]
version = "0.15"
path = "../mls-rs-crypto-awslc"
optional = true
[dependencies.mls-rs-crypto-openssl] [dependencies.mls-rs-crypto-openssl]
version = "0.9.0" version = "0.14.0"
path = "../mls-rs-crypto-openssl" path = "../mls-rs-crypto-openssl"
optional = true optional = true
[dependencies.mls-rs-identity-x509] [dependencies.mls-rs-identity-x509]
version = "0.11.0" version = "0.15.0"
path = "../mls-rs-identity-x509" path = "../mls-rs-identity-x509"
optional = true optional = true
default-features = false default-features = false
[dependencies.mls-rs-provider-sqlite] [dependencies.mls-rs-provider-sqlite]
version = "0.11.0" version = "0.15.0"
path = "../mls-rs-provider-sqlite" path = "../mls-rs-provider-sqlite"
optional = true optional = true
default-features = false default-features = false
@@ -232,6 +237,7 @@ arbitrary = [
"dep:arbitrary", "dep:arbitrary",
"mls-rs-core/arbitrary", "mls-rs-core/arbitrary",
] ]
benchmark_pq_crypto = ["mls-rs-crypto-awslc/post-quantum"]
benchmark_util = [ benchmark_util = [
"test_util", "test_util",
"default", "default",
@@ -256,6 +262,8 @@ fuzz_util = [
"dep:mls-rs-crypto-openssl", "dep:mls-rs-crypto-openssl",
] ]
grease = ["std"] grease = ["std"]
last_resort_key_package_ext = ["mls-rs-core/last_resort_key_package_ext"]
non_domain_separated_hpke_encrypt_decrypt = []
out_of_order = ["private_message"] out_of_order = ["private_message"]
prior_epoch = [] prior_epoch = []
private_message = [] private_message = []
@@ -265,7 +273,6 @@ rayon = [
"dep:rayon", "dep:rayon",
] ]
rfc_compliant = [ rfc_compliant = [
"state_update",
"private_message", "private_message",
"custom_proposal", "custom_proposal",
"out_of_order", "out_of_order",
@@ -298,7 +305,6 @@ sqlite-bundled = [
"sqlite", "sqlite",
"mls-rs-provider-sqlite/sqlite-bundled", "mls-rs-provider-sqlite/sqlite-bundled",
] ]
state_update = []
std = [ std = [
"mls-rs-core/std", "mls-rs-core/std",
"mls-rs-codec/std", "mls-rs-codec/std",
@@ -337,7 +343,7 @@ features = [
] ]
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies.mls-rs-crypto-openssl] [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies.mls-rs-crypto-openssl]
version = "0.9.0" version = "0.14.0"
path = "../mls-rs-crypto-openssl" path = "../mls-rs-crypto-openssl"
[target.'cfg(target_arch = "wasm32")'.dependencies.getrandom] [target.'cfg(target_arch = "wasm32")'.dependencies.getrandom]
@@ -354,7 +360,7 @@ features = ["alloc"]
default-features = false default-features = false
[target.'cfg(target_arch = "wasm32")'.dependencies.wasm-bindgen] [target.'cfg(target_arch = "wasm32")'.dependencies.wasm-bindgen]
version = "^0.2.79" version = "0.2"
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.criterion] [target.'cfg(target_arch = "wasm32")'.dev-dependencies.criterion]
version = "0.5.1" version = "0.5.1"
@@ -367,12 +373,11 @@ features = [
default-features = false default-features = false
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.mls-rs-crypto-webcrypto] [target.'cfg(target_arch = "wasm32")'.dev-dependencies.mls-rs-crypto-webcrypto]
version = "0.4.0" version = "0.8.0"
path = "../mls-rs-crypto-webcrypto" path = "../mls-rs-crypto-webcrypto"
[target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test] [target.'cfg(target_arch = "wasm32")'.dev-dependencies.wasm-bindgen-test]
version = "0.3.26" version = "0.3"
default-features = false
[lints.rust.unexpected_cfgs] [lints.rust.unexpected_cfgs]
level = "warn" level = "warn"

View File

@@ -60,6 +60,7 @@ For cipher suite descriptions see the RFC documentation [here](https://www.rfc-e
| Rust Crypto | 1,2,3 | ⚠️ Experimental | | Rust Crypto | 1,2,3 | ⚠️ Experimental |
| Web Crypto | ⚠️ Experimental 2,5,7 | Unsupported | | Web Crypto | ⚠️ Experimental 2,5,7 | Unsupported |
| CryptoKit | 1,2,3,5,7 | Unsupported | | CryptoKit | 1,2,3,5,7 | Unsupported |
| NSS | 1,2,3 | Unsupported |
## Security Notice ## Security Notice

View File

@@ -10,13 +10,13 @@ use mls_rs::{
SigningIdentity, SigningIdentity,
}, },
mls_rules::{CommitOptions, DefaultMlsRules}, mls_rules::{CommitOptions, DefaultMlsRules},
CipherSuite, CipherSuiteProvider, Client, CryptoProvider, test_utils::benchmarks::{MlsCryptoProvider, BENCH_CIPHER_SUITE},
CipherSuiteProvider, Client, CryptoProvider,
}; };
use mls_rs_crypto_openssl::OpensslCryptoProvider;
fn bench(c: &mut Criterion) { fn bench(c: &mut Criterion) {
let alice = make_client("alice") let alice = make_client("alice")
.create_group(Default::default()) .create_group(Default::default(), Default::default())
.unwrap(); .unwrap();
const MAX_ADD_COUNT: usize = 1000; const MAX_ADD_COUNT: usize = 1000;
@@ -24,7 +24,7 @@ fn bench(c: &mut Criterion) {
let key_packages = (0..MAX_ADD_COUNT) let key_packages = (0..MAX_ADD_COUNT)
.map(|i| { .map(|i| {
make_client(&format!("bob-{i}")) make_client(&format!("bob-{i}"))
.generate_key_package_message() .generate_key_package_message(Default::default(), Default::default())
.unwrap() .unwrap()
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -59,8 +59,8 @@ criterion::criterion_group!(benches, bench);
criterion::criterion_main!(benches); criterion::criterion_main!(benches);
fn make_client(name: &str) -> Client<impl MlsConfig> { fn make_client(name: &str) -> Client<impl MlsConfig> {
let crypto_provider = OpensslCryptoProvider::new(); let crypto_provider = MlsCryptoProvider::new();
let cipher_suite = CipherSuite::CURVE25519_AES128; let cipher_suite = BENCH_CIPHER_SUITE;
let (secret_key, public_key) = crypto_provider let (secret_key, public_key) = crypto_provider
.cipher_suite_provider(cipher_suite) .cipher_suite_provider(cipher_suite)

View File

@@ -3,13 +3,11 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use criterion::{BatchSize, BenchmarkId, Criterion, Throughput}; use criterion::{BatchSize, BenchmarkId, Criterion, Throughput};
use mls_rs::test_utils::benchmarks::load_group_states; use mls_rs::test_utils::benchmarks::{load_group_states, BENCH_CIPHER_SUITE};
use mls_rs::CipherSuite;
use rand::RngCore; use rand::RngCore;
fn bench(c: &mut Criterion) { fn bench(c: &mut Criterion) {
let cipher_suite = CipherSuite::CURVE25519_AES128; let group_states = load_group_states().pop().unwrap();
let group_states = load_group_states(cipher_suite).pop().unwrap();
let mut bytes = vec![0; 1000000]; let mut bytes = vec![0; 1000000];
rand::thread_rng().fill_bytes(&mut bytes); rand::thread_rng().fill_bytes(&mut bytes);
@@ -21,7 +19,7 @@ fn bench(c: &mut Criterion) {
while n <= 1000000 { while n <= 1000000 {
bench_group.throughput(Throughput::Bytes(n as u64)); bench_group.throughput(Throughput::Bytes(n as u64));
bench_group.bench_with_input( bench_group.bench_with_input(
BenchmarkId::new(format!("{cipher_suite:?}"), n), BenchmarkId::new(format!("{BENCH_CIPHER_SUITE:?}"), n),
&n, &n,
|b, _| { |b, _| {
b.iter_batched_ref( b.iter_batched_ref(

View File

@@ -3,16 +3,16 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use criterion::{BatchSize, BenchmarkId, Criterion}; use criterion::{BatchSize, BenchmarkId, Criterion};
use mls_rs::{test_utils::benchmarks::load_group_states, CipherSuite}; use mls_rs::test_utils::benchmarks::{load_group_states, BENCH_CIPHER_SUITE};
fn bench(c: &mut Criterion) { fn bench(c: &mut Criterion) {
let cipher_suite = CipherSuite::CURVE25519_AES128; let group_states = load_group_states();
let group_states = load_group_states(cipher_suite);
let mut bench_group = c.benchmark_group("group_commit"); let mut bench_group = c.benchmark_group("group_commit");
for (i, group_states) in group_states.into_iter().enumerate() { for (i, group_states) in group_states.into_iter().enumerate() {
bench_group.bench_with_input( bench_group.bench_with_input(
BenchmarkId::new(format!("{cipher_suite:?}"), i), BenchmarkId::new(format!("{BENCH_CIPHER_SUITE:?}"), i),
&i, &i,
|b, _| { |b, _| {
b.iter_batched_ref( b.iter_batched_ref(

View File

@@ -3,16 +3,15 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use criterion::{BatchSize, BenchmarkId, Criterion}; use criterion::{BatchSize, BenchmarkId, Criterion};
use mls_rs::{test_utils::benchmarks::load_group_states, CipherSuite}; use mls_rs::test_utils::benchmarks::{load_group_states, BENCH_CIPHER_SUITE};
fn bench(c: &mut Criterion) { fn bench(c: &mut Criterion) {
let cipher_suite = CipherSuite::CURVE25519_AES128; let group_states = load_group_states();
let group_states = load_group_states(cipher_suite);
let mut bench_group = c.benchmark_group("group_receive_commit"); let mut bench_group = c.benchmark_group("group_receive_commit");
for (i, mut group_states) in group_states.into_iter().enumerate() { for (i, mut group_states) in group_states.into_iter().enumerate() {
bench_group.bench_with_input( bench_group.bench_with_input(
BenchmarkId::new(format!("{cipher_suite:?}"), i), BenchmarkId::new(format!("{BENCH_CIPHER_SUITE:?}"), i),
&i, &i,
|b, _| { |b, _| {
b.iter_batched_ref( b.iter_batched_ref(

View File

@@ -2,25 +2,28 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use mls_rs::{test_utils::benchmarks::load_group_states, CipherSuite}; use mls_rs::test_utils::benchmarks::{load_group_states, BENCH_CIPHER_SUITE};
use criterion::{BenchmarkId, Criterion}; use criterion::{BenchmarkId, Criterion};
fn bench_serialize(c: &mut Criterion) { fn bench_serialize(c: &mut Criterion) {
use criterion::BatchSize; use criterion::BatchSize;
let cs = CipherSuite::CURVE25519_AES128; let group_states = load_group_states();
let group_states = load_group_states(cs);
let mut bench_group = c.benchmark_group("group_serialize"); let mut bench_group = c.benchmark_group("group_serialize");
for (i, group_states) in group_states.into_iter().enumerate() { for (i, group_states) in group_states.into_iter().enumerate() {
bench_group.bench_with_input(BenchmarkId::new(format!("{cs:?}"), i), &i, |b, _| { bench_group.bench_with_input(
b.iter_batched_ref( BenchmarkId::new(format!("{BENCH_CIPHER_SUITE:?}"), i),
|| group_states.sender.clone(), &i,
move |sender| sender.write_to_storage().unwrap(), |b, _| {
BatchSize::SmallInput, b.iter_batched_ref(
) || group_states.sender.clone(),
}); move |sender| sender.write_to_storage().unwrap(),
BatchSize::SmallInput,
)
},
);
} }
bench_group.finish(); bench_group.finish();

View File

@@ -140,8 +140,9 @@ fn main() -> Result<(), MlsError> {
let bob = make_client("bob")?; let bob = make_client("bob")?;
// Alice creates a group with bob // Alice creates a group with bob
let mut alice_group = alice.create_group(ExtensionList::default())?; let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;
let bob_key_package = bob.generate_key_package_message()?; let bob_key_package =
bob.generate_key_package_message(Default::default(), Default::default())?;
let welcome = &alice_group let welcome = &alice_group
.commit_builder() .commit_builder()

View File

@@ -44,10 +44,11 @@ fn main() -> Result<(), MlsError> {
let bob = make_client(crypto_provider.clone(), "bob")?; let bob = make_client(crypto_provider.clone(), "bob")?;
// Alice creates a new group. // Alice creates a new group.
let mut alice_group = alice.create_group(ExtensionList::default())?; let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;
// Bob generates a key package that Alice needs to add Bob to the group. // Bob generates a key package that Alice needs to add Bob to the group.
let bob_key_package = bob.generate_key_package_message()?; let bob_key_package =
bob.generate_key_package_message(Default::default(), Default::default())?;
// Alice issues a commit that adds Bob to the group. // Alice issues a commit that adds Bob to the group.
let alice_commit = alice_group let alice_commit = alice_group

View File

@@ -29,7 +29,7 @@ use mls_rs::{
error::MlsError, error::MlsError,
group::{ group::{
proposal::{MlsCustomProposal, Proposal}, proposal::{MlsCustomProposal, Proposal},
Roster, Sender, GroupContext, Roster, Sender,
}, },
mls_rules::{ mls_rules::{
CommitDirection, CommitOptions, CommitSource, EncryptionOptions, ProposalBundle, CommitDirection, CommitOptions, CommitSource, EncryptionOptions, ProposalBundle,
@@ -44,7 +44,10 @@ use mls_rs_core::{
error::IntoAnyError, error::IntoAnyError,
extension::{ExtensionError, ExtensionType, MlsCodecExtension}, extension::{ExtensionError, ExtensionType, MlsCodecExtension},
group::ProposalType, group::ProposalType,
identity::{Credential, CredentialType, CustomCredential, MlsCredential, SigningIdentity}, identity::{
Credential, CredentialType, CustomCredential, MemberValidationContext, MlsCredential,
SigningIdentity,
},
time::MlsTime, time::MlsTime,
}; };
@@ -128,12 +131,16 @@ impl MlsRules for CustomMlsRules {
_: CommitDirection, _: CommitDirection,
_: CommitSource, _: CommitSource,
_: &Roster, _: &Roster,
extension_list: &ExtensionList, context: &GroupContext,
mut proposals: ProposalBundle, mut proposals: ProposalBundle,
) -> Result<ProposalBundle, Self::Error> { ) -> Result<ProposalBundle, Self::Error> {
// Find our extension // Find our extension
let mut roster: RosterExtension = let mut roster: RosterExtension = context
extension_list.get_as().ok().flatten().ok_or(CustomError)?; .extensions
.get_as()
.ok()
.flatten()
.ok_or(CustomError)?;
// Find AddUser proposals // Find AddUser proposals
let add_user_proposals = proposals let add_user_proposals = proposals
@@ -149,7 +156,7 @@ impl MlsRules for CustomMlsRules {
} }
// Issue GroupContextExtensions proposal to modify our roster (eventually we don't have to do this if there were no AddUser proposals) // Issue GroupContextExtensions proposal to modify our roster (eventually we don't have to do this if there were no AddUser proposals)
let mut new_extensions = extension_list.clone(); let mut new_extensions = context.extensions.clone();
new_extensions.set_from(roster)?; new_extensions.set_from(roster)?;
let gce_proposal = Proposal::GroupContextExtensions(new_extensions); let gce_proposal = Proposal::GroupContextExtensions(new_extensions);
proposals.add(gce_proposal, Sender::Member(0), ProposalSource::Local); proposals.add(gce_proposal, Sender::Member(0), ProposalSource::Local);
@@ -160,7 +167,7 @@ impl MlsRules for CustomMlsRules {
fn commit_options( fn commit_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
_: &ProposalBundle, _: &ProposalBundle,
) -> Result<CommitOptions, Self::Error> { ) -> Result<CommitOptions, Self::Error> {
Ok(CommitOptions::new()) Ok(CommitOptions::new())
@@ -169,7 +176,7 @@ impl MlsRules for CustomMlsRules {
fn encryption_options( fn encryption_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
) -> Result<EncryptionOptions, Self::Error> { ) -> Result<EncryptionOptions, Self::Error> {
Ok(EncryptionOptions::new(false, PaddingMode::None)) Ok(EncryptionOptions::new(false, PaddingMode::None))
} }
@@ -202,9 +209,9 @@ impl IdentityProvider for CustomIdentityProvider {
&self, &self,
signing_identity: &SigningIdentity, signing_identity: &SigningIdentity,
_: Option<MlsTime>, _: Option<MlsTime>,
extensions: Option<&ExtensionList>, context: MemberValidationContext<'_>,
) -> Result<(), Self::Error> { ) -> Result<(), Self::Error> {
let Some(extensions) = extensions else { let Some(extensions) = context.new_extensions() else {
return Ok(()); return Ok(());
}; };
@@ -369,11 +376,13 @@ fn main() -> Result<(), CustomError> {
let roster = vec![alice.credential]; let roster = vec![alice.credential];
context_extensions.set_from(RosterExtension { roster })?; context_extensions.set_from(RosterExtension { roster })?;
let mut alice_tablet_group = make_client(alice_tablet)?.create_group(context_extensions)?; let mut alice_tablet_group =
make_client(alice_tablet)?.create_group(context_extensions, Default::default())?;
// Alice can add her other device // Alice can add her other device
let alice_pc_client = make_client(alice_pc)?; let alice_pc_client = make_client(alice_pc)?;
let key_package = alice_pc_client.generate_key_package_message()?; let key_package =
alice_pc_client.generate_key_package_message(Default::default(), Default::default())?;
let welcome = alice_tablet_group let welcome = alice_tablet_group
.commit_builder() .commit_builder()
@@ -387,7 +396,8 @@ fn main() -> Result<(), CustomError> {
// Alice cannot add bob's devices yet // Alice cannot add bob's devices yet
let bob_tablet_client = make_client(bob_tablet)?; let bob_tablet_client = make_client(bob_tablet)?;
let key_package = bob_tablet_client.generate_key_package_message()?; let key_package =
bob_tablet_client.generate_key_package_message(Default::default(), Default::default())?;
let res = alice_tablet_group let res = alice_tablet_group
.commit_builder() .commit_builder()

View File

@@ -58,7 +58,7 @@ fn make_groups_best_case<P: CryptoProvider + Clone>(
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> { ) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
let bob_client = make_client(crypto_provider.clone(), &make_name(0))?; let bob_client = make_client(crypto_provider.clone(), &make_name(0))?;
let bob_group = bob_client.create_group(Default::default())?; let bob_group = bob_client.create_group(Default::default(), Default::default())?;
let mut groups = vec![bob_group]; let mut groups = vec![bob_group];
@@ -66,7 +66,8 @@ fn make_groups_best_case<P: CryptoProvider + Clone>(
let bob_client = make_client(crypto_provider.clone(), &make_name(i + 1))?; let bob_client = make_client(crypto_provider.clone(), &make_name(i + 1))?;
// The new client generates a key package. // The new client generates a key package.
let bob_kpkg = bob_client.generate_key_package_message()?; let bob_kpkg =
bob_client.generate_key_package_message(Default::default(), Default::default())?;
// Last group sends a commit adding the new client to the group. // Last group sends a commit adding the new client to the group.
let commit = groups let commit = groups
@@ -100,7 +101,7 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> { ) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
let alice_client = make_client(crypto_provider.clone(), &make_name(0))?; let alice_client = make_client(crypto_provider.clone(), &make_name(0))?;
let mut alice_group = alice_client.create_group(Default::default())?; let mut alice_group = alice_client.create_group(Default::default(), Default::default())?;
let bob_clients = (0..(num_groups - 1)) let bob_clients = (0..(num_groups - 1))
.map(|i| make_client(crypto_provider.clone(), &make_name(i + 1))) .map(|i| make_client(crypto_provider.clone(), &make_name(i + 1)))
@@ -110,7 +111,8 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
let mut commit_builder = alice_group.commit_builder(); let mut commit_builder = alice_group.commit_builder();
for bob_client in &bob_clients { for bob_client in &bob_clients {
let bob_kpkg = bob_client.generate_key_package_message()?; let bob_kpkg =
bob_client.generate_key_package_message(Default::default(), Default::default())?;
commit_builder = commit_builder.add_member(bob_kpkg)?; commit_builder = commit_builder.add_member(bob_kpkg)?;
} }

View File

@@ -31,7 +31,9 @@ fn main() {
.signing_identity(signing_identity, secret_key, CIPHERSUITE) .signing_identity(signing_identity, secret_key, CIPHERSUITE)
.build(); .build();
let mut alice_group = alice_client.create_group(Default::default()).unwrap(); let mut alice_group = alice_client
.create_group(Default::default(), Default::default())
.unwrap();
alice_group.commit(Vec::new()).unwrap(); alice_group.commit(Vec::new()).unwrap();
alice_group.apply_pending_commit().unwrap(); alice_group.apply_pending_commit().unwrap();

View File

@@ -7,13 +7,16 @@ use crate::client_builder::{recreate_config, BaseConfig, ClientBuilder, MakeConf
use crate::client_config::ClientConfig; use crate::client_config::ClientConfig;
use crate::group::framing::MlsMessage; use crate::group::framing::MlsMessage;
use crate::group::{cipher_suite_provider, validate_group_info_joiner, GroupInfo};
use crate::group::{
framing::MlsMessagePayload, snapshot::Snapshot, ExportedTree, Group, NewMemberInfo,
};
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
use crate::group::{ use crate::group::{
framing::{Content, MlsMessagePayload, PublicMessage, Sender, WireFormat}, framing::{Content, PublicMessage, Sender, WireFormat},
message_signature::AuthenticatedContent, message_signature::AuthenticatedContent,
proposal::{AddProposal, Proposal}, proposal::{AddProposal, Proposal},
}; };
use crate::group::{snapshot::Snapshot, ExportedTree, Group, NewMemberInfo};
use crate::identity::SigningIdentity; use crate::identity::SigningIdentity;
use crate::key_package::{KeyPackageGeneration, KeyPackageGenerator}; use crate::key_package::{KeyPackageGeneration, KeyPackageGenerator};
use crate::protocol_version::ProtocolVersion; use crate::protocol_version::ProtocolVersion;
@@ -24,7 +27,7 @@ use mls_rs_core::crypto::{CryptoProvider, SignatureSecretKey};
use mls_rs_core::error::{AnyError, IntoAnyError}; use mls_rs_core::error::{AnyError, IntoAnyError};
use mls_rs_core::extension::{ExtensionError, ExtensionList, ExtensionType}; use mls_rs_core::extension::{ExtensionError, ExtensionList, ExtensionType};
use mls_rs_core::group::{GroupStateStorage, ProposalType}; use mls_rs_core::group::{GroupStateStorage, ProposalType};
use mls_rs_core::identity::CredentialType; use mls_rs_core::identity::{CredentialType, IdentityProvider, MemberValidationContext};
use mls_rs_core::key_package::KeyPackageStorage; use mls_rs_core::key_package::KeyPackageStorage;
use crate::group::external_commit::ExternalCommitBuilder; use crate::group::external_commit::ExternalCommitBuilder;
@@ -335,6 +338,8 @@ pub enum MlsError {
InvalidGroupInfo, InvalidGroupInfo,
#[cfg_attr(feature = "std", error("Invalid welcome message"))] #[cfg_attr(feature = "std", error("Invalid welcome message"))]
InvalidWelcomeMessage, InvalidWelcomeMessage,
#[cfg_attr(feature = "std", error("Exporter deleted"))]
ExporterDeleted,
} }
impl IntoAnyError for MlsError { impl IntoAnyError for MlsError {
@@ -426,12 +431,23 @@ where
/// ///
/// A key package message may only be used once. /// A key package message may only be used once.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn generate_key_package_message(&self) -> Result<MlsMessage, MlsError> { pub async fn generate_key_package_message(
Ok(self.generate_key_package().await?.key_package_message()) &self,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<MlsMessage, MlsError> {
Ok(self
.generate_key_package(key_package_extensions, leaf_node_extensions)
.await?
.key_package_message())
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn generate_key_package(&self) -> Result<KeyPackageGeneration, MlsError> { async fn generate_key_package(
&self,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<KeyPackageGeneration, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?; let (signing_identity, cipher_suite) = self.signing_identity()?;
let cipher_suite_provider = self let cipher_suite_provider = self
@@ -445,15 +461,14 @@ where
cipher_suite_provider: &cipher_suite_provider, cipher_suite_provider: &cipher_suite_provider,
signing_key: self.signer()?, signing_key: self.signer()?,
signing_identity, signing_identity,
identity_provider: &self.config.identity_provider(),
}; };
let key_pkg_gen = key_package_generator let key_pkg_gen = key_package_generator
.generate( .generate(
self.config.lifetime(), self.config.lifetime(),
self.config.capabilities(), self.config.capabilities(),
self.config.key_package_extensions(), key_package_extensions,
self.config.leaf_node_extensions(), leaf_node_extensions,
) )
.await?; .await?;
@@ -484,6 +499,7 @@ where
&self, &self,
group_id: Vec<u8>, group_id: Vec<u8>,
group_context_extensions: ExtensionList, group_context_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<Group<C>, MlsError> { ) -> Result<Group<C>, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?; let (signing_identity, cipher_suite) = self.signing_identity()?;
@@ -494,6 +510,7 @@ where
self.version, self.version,
signing_identity.clone(), signing_identity.clone(),
group_context_extensions, group_context_extensions,
leaf_node_extensions,
self.signer()?.clone(), self.signer()?.clone(),
) )
.await .await
@@ -508,6 +525,7 @@ where
pub async fn create_group( pub async fn create_group(
&self, &self,
group_context_extensions: ExtensionList, group_context_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<Group<C>, MlsError> { ) -> Result<Group<C>, MlsError> {
let (signing_identity, cipher_suite) = self.signing_identity()?; let (signing_identity, cipher_suite) = self.signing_identity()?;
@@ -518,6 +536,7 @@ where
self.version, self.version,
signing_identity.clone(), signing_identity.clone(),
group_context_extensions, group_context_extensions,
leaf_node_extensions,
self.signer()?.clone(), self.signer()?.clone(),
) )
.await .await
@@ -547,6 +566,50 @@ where
.await .await
} }
/// Decrypt GroupInfo encrypted in the Welcome message without actually joining
/// the group. The ratchet tree is not needed.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn examine_welcome_message(
&self,
welcome_message: &MlsMessage,
) -> Result<GroupInfo, MlsError> {
Group::decrypt_group_info(welcome_message, &self.config).await
}
/// Validate GroupInfo message. This does NOT validate the ratchet tree in case
/// it is provided in the extension. It validates the signature, identity of the
/// signer, identities of external senders and cipher suite.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn validate_group_info(
&self,
group_info_message: &MlsMessage,
signer: &SigningIdentity,
) -> Result<(), MlsError> {
let MlsMessagePayload::GroupInfo(group_info) = &group_info_message.payload else {
return Err(MlsError::UnexpectedMessageType);
};
let cs = cipher_suite_provider(
self.config.crypto_provider(),
group_info.group_context.cipher_suite,
)?;
let id = self.config.identity_provider();
validate_group_info_joiner(group_info_message.version, group_info, signer, &id, &cs)
.await?;
let context = MemberValidationContext::ForNewGroup {
current_context: &group_info.group_context,
};
id.validate_member(signer, None, context)
.await
.map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
Ok(())
}
/// 0-RTT add to an existing [group](crate::group::Group) /// 0-RTT add to an existing [group](crate::group::Group)
/// ///
/// External commits allow for immediate entry into a /// External commits allow for immediate entry into a
@@ -620,6 +683,31 @@ where
Group::from_snapshot(self.config.clone(), snapshot).await Group::from_snapshot(self.config.clone(), snapshot).await
} }
/// Load an existing group state into this client using the
/// [GroupStateStorage](crate::GroupStateStorage) that
/// this client was configured to use. The tree is taken from
/// `tree_data` instead of the stored state.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[inline(never)]
pub async fn load_group_with_ratchet_tree(
&self,
group_id: &[u8],
tree_data: ExportedTree<'_>,
) -> Result<Group<C>, MlsError> {
let snapshot = self
.config
.group_state_storage()
.state(group_id)
.await
.map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
.ok_or(MlsError::GroupNotFound)?;
let mut snapshot = Snapshot::mls_decode(&mut &*snapshot)?;
snapshot.state.public_tree.nodes = tree_data.0.into_owned();
Group::from_snapshot(self.config.clone(), snapshot).await
}
/// Request to join an existing [group](crate::group::Group). /// Request to join an existing [group](crate::group::Group).
/// ///
/// An existing group member will need to perform a /// An existing group member will need to perform a
@@ -632,6 +720,8 @@ where
group_info: &MlsMessage, group_info: &MlsMessage,
tree_data: Option<crate::group::ExportedTree<'_>>, tree_data: Option<crate::group::ExportedTree<'_>>,
authenticated_data: Vec<u8>, authenticated_data: Vec<u8>,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
) -> Result<MlsMessage, MlsError> { ) -> Result<MlsMessage, MlsError> {
let protocol_version = group_info.version; let protocol_version = group_info.version;
@@ -651,7 +741,7 @@ where
.cipher_suite_provider(cipher_suite) .cipher_suite_provider(cipher_suite)
.ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?; .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?;
crate::group::validate_group_info_joiner( crate::group::validate_tree_and_info_joiner(
protocol_version, protocol_version,
group_info, group_info,
tree_data, tree_data,
@@ -660,7 +750,10 @@ where
) )
.await?; .await?;
let key_package = self.generate_key_package().await?.key_package; let key_package = self
.generate_key_package(key_package_extensions, leaf_node_extensions)
.await?
.key_package;
(key_package.cipher_suite == cipher_suite) (key_package.cipher_suite == cipher_suite)
.then_some(()) .then_some(())
@@ -703,11 +796,6 @@ where
.ok_or(MlsError::SignerNotFound) .ok_or(MlsError::SignerNotFound)
} }
/// Returns key package extensions used by this client
pub fn key_package_extensions(&self) -> ExtensionList {
self.config.key_package_extensions()
}
/// The [KeyPackageStorage] that this client was configured to use. /// The [KeyPackageStorage] that this client was configured to use.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)] #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
pub fn key_package_store(&self) -> <C as ClientConfig>::KeyPackageRepository { pub fn key_package_store(&self) -> <C as ClientConfig>::KeyPackageRepository {
@@ -726,6 +814,12 @@ where
pub fn group_state_storage(&self) -> <C as ClientConfig>::GroupStateStorage { pub fn group_state_storage(&self) -> <C as ClientConfig>::GroupStateStorage {
self.config.group_state_storage() self.config.group_state_storage()
} }
/// The [IdentityProvider](crate::IdentityProvider) that this client was configured to use.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
pub fn identity_provider(&self) -> <C as ClientConfig>::IdentityProvider {
self.config.identity_provider()
}
} }
#[cfg(test)] #[cfg(test)]
@@ -745,7 +839,15 @@ pub(crate) mod test_utils {
cipher_suite: CipherSuite, cipher_suite: CipherSuite,
identity: &str, identity: &str,
) -> (Client<TestClientConfig>, MlsMessage) { ) -> (Client<TestClientConfig>, MlsMessage) {
test_client_with_key_pkg_custom(protocol_version, cipher_suite, identity, |_| {}).await test_client_with_key_pkg_custom(
protocol_version,
cipher_suite,
identity,
Default::default(),
Default::default(),
|_| {},
)
.await
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
@@ -753,6 +855,8 @@ pub(crate) mod test_utils {
protocol_version: ProtocolVersion, protocol_version: ProtocolVersion,
cipher_suite: CipherSuite, cipher_suite: CipherSuite,
identity: &str, identity: &str,
key_package_extensions: ExtensionList,
leaf_node_extensions: ExtensionList,
mut config: F, mut config: F,
) -> (Client<TestClientConfig>, MlsMessage) ) -> (Client<TestClientConfig>, MlsMessage)
where where
@@ -768,7 +872,10 @@ pub(crate) mod test_utils {
config(&mut client.config); config(&mut client.config);
let key_package = client.generate_key_package_message().await.unwrap(); let key_package = client
.generate_key_package_message(key_package_extensions, leaf_node_extensions)
.await
.unwrap();
(client, key_package) (client, key_package)
} }
@@ -786,16 +893,17 @@ mod tests {
}; };
use assert_matches::assert_matches; use assert_matches::assert_matches;
use crate::{ #[cfg(feature = "by_ref_proposal")]
group::{ use crate::group::message_processor::ProposalMessageDescription;
message_processor::ProposalMessageDescription, #[cfg(feature = "by_ref_proposal")]
proposal::Proposal, use crate::group::proposal::Proposal;
test_utils::{test_group, test_group_custom_config}, use crate::group::test_utils::test_group;
ReceivedMessage, #[cfg(feature = "psk")]
}, use crate::group::test_utils::test_group_custom_config;
psk::{ExternalPskId, PreSharedKey}, #[cfg(feature = "by_ref_proposal")]
}; use crate::group::ReceivedMessage;
#[cfg(feature = "psk")]
use crate::psk::{ExternalPskId, PreSharedKey};
use alloc::vec; use alloc::vec;
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
@@ -814,7 +922,10 @@ mod tests {
.build(); .build();
// TODO: Tests around extensions // TODO: Tests around extensions
let key_package = client.generate_key_package_message().await.unwrap(); let key_package = client
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();
assert_eq!(key_package.version, protocol_version); assert_eq!(key_package.version, protocol_version);
@@ -850,15 +961,16 @@ mod tests {
let proposal = bob let proposal = bob
.external_add_proposal( .external_add_proposal(
&alice_group.group.group_info_message(true).await.unwrap(), &alice_group.group_info_message(true).await.unwrap(),
None, None,
vec![], vec![],
Default::default(),
Default::default(),
) )
.await .await
.unwrap(); .unwrap();
let message = alice_group let message = alice_group
.group
.process_incoming_message(proposal) .process_incoming_message(proposal)
.await .await
.unwrap(); .unwrap();
@@ -870,12 +982,11 @@ mod tests {
) if p.key_package.leaf_node.signing_identity == bob_identity ) if p.key_package.leaf_node.signing_identity == bob_identity
); );
alice_group.group.commit(vec![]).await.unwrap(); alice_group.commit(vec![]).await.unwrap();
alice_group.group.apply_pending_commit().await.unwrap(); alice_group.apply_pending_commit().await.unwrap();
// Check that the new member is in the group // Check that the new member is in the group
assert!(alice_group assert!(alice_group
.group
.roster() .roster()
.members_iter() .members_iter()
.any(|member| member.signing_identity == bob_identity)) .any(|member| member.signing_identity == bob_identity))
@@ -888,6 +999,8 @@ mod tests {
// interim_transcript_hash to be computed from the confirmed_transcript_hash and // interim_transcript_hash to be computed from the confirmed_transcript_hash and
// confirmation_tag, which is not the case for the initial interim_transcript_hash. // confirmation_tag, which is not the case for the initial interim_transcript_hash.
use crate::group::{message_processor::CommitEffect, CommitMessageDescription};
let psk = PreSharedKey::from(b"psk".to_vec()); let psk = PreSharedKey::from(b"psk".to_vec());
let psk_id = ExternalPskId::new(b"psk id".to_vec()); let psk_id = ExternalPskId::new(b"psk id".to_vec());
@@ -905,7 +1018,6 @@ mod tests {
.unwrap(); .unwrap();
let group_info_msg = alice_group let group_info_msg = alice_group
.group
.group_info_message_allowing_ext_commit(true) .group_info_message_allowing_ext_commit(true)
.await .await
.unwrap(); .unwrap();
@@ -937,37 +1049,40 @@ mod tests {
assert_eq!(new_group.roster().members_iter().count(), num_members); assert_eq!(new_group.roster().members_iter().count(), num_members);
let _ = alice_group let _ = alice_group
.group
.process_incoming_message(external_commit.clone()) .process_incoming_message(external_commit.clone())
.await .await
.unwrap(); .unwrap();
let bob_current_epoch = bob_group.group.current_epoch(); let bob_current_epoch = bob_group.current_epoch();
let message = bob_group let message = bob_group
.group
.process_incoming_message(external_commit) .process_incoming_message(external_commit)
.await .await
.unwrap(); .unwrap();
assert!(alice_group.group.roster().members_iter().count() == num_members); assert!(alice_group.roster().members_iter().count() == num_members);
if !do_remove { if !do_remove {
assert!(bob_group.group.roster().members_iter().count() == num_members); assert!(bob_group.roster().members_iter().count() == num_members);
} else { } else {
// Bob was removed so his epoch must stay the same // Bob was removed so his epoch must stay the same
assert_eq!(bob_group.group.current_epoch(), bob_current_epoch); assert_eq!(bob_group.current_epoch(), bob_current_epoch);
#[cfg(feature = "state_update")] assert_matches!(
assert_matches!(message, ReceivedMessage::Commit(desc) if !desc.state_update.active); message,
ReceivedMessage::Commit(CommitMessageDescription {
#[cfg(not(feature = "state_update"))] effect: CommitEffect::Removed {
assert_matches!(message, ReceivedMessage::Commit(_)); new_epoch: _,
remover: _
},
..
})
);
} }
// Comparing epoch authenticators is sufficient to check that members are in sync. // Comparing epoch authenticators is sufficient to check that members are in sync.
assert_eq!( assert_eq!(
alice_group.group.epoch_authenticator().unwrap(), alice_group.epoch_authenticator().unwrap(),
new_group.epoch_authenticator().unwrap() new_group.epoch_authenticator().unwrap()
); );
@@ -996,7 +1111,10 @@ mod tests {
.signing_identity(alice_identity.clone(), secret_key, TEST_CIPHER_SUITE) .signing_identity(alice_identity.clone(), secret_key, TEST_CIPHER_SUITE)
.build(); .build();
let msg = alice.generate_key_package_message().await.unwrap(); let msg = alice
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();
let res = alice.commit_external(msg).await.map(|_| ()); let res = alice.commit_external(msg).await.map(|_| ());
assert_matches!(res, Err(MlsError::UnexpectedMessageType)); assert_matches!(res, Err(MlsError::UnexpectedMessageType));
@@ -1007,11 +1125,10 @@ mod tests {
let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let mut bob_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let mut bob_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
bob_group.group.commit(vec![]).await.unwrap(); bob_group.commit(vec![]).await.unwrap();
bob_group.group.apply_pending_commit().await.unwrap(); bob_group.apply_pending_commit().await.unwrap();
let group_info_msg = bob_group let group_info_msg = bob_group
.group
.group_info_message_allowing_ext_commit(true) .group_info_message_allowing_ext_commit(true)
.await .await
.unwrap(); .unwrap();
@@ -1031,10 +1148,7 @@ mod tests {
.unwrap(); .unwrap();
// If Carol tries to join Alice's group using the group info from Bob's group, that fails. // If Carol tries to join Alice's group using the group info from Bob's group, that fails.
let res = alice_group let res = alice_group.process_incoming_message(external_commit).await;
.group
.process_incoming_message(external_commit)
.await;
assert_matches!(res, Err(_)); assert_matches!(res, Err(_));
} }
@@ -1046,4 +1160,70 @@ mod tests {
let bob = alice.to_builder().extension_type(34.into()).build(); let bob = alice.to_builder().extension_type(34.into()).build();
assert_eq!(bob.config.supported_extensions(), [33, 34].map(Into::into)); assert_eq!(bob.config.supported_extensions(), [33, 34].map(Into::into));
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn examine_welcome_message() {
let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
.await
.group;
let (bob, kp) =
test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
let commit = alice
.commit_builder()
.add_member(kp)
.unwrap()
.build()
.await
.unwrap();
alice.apply_pending_commit().await.unwrap();
let mut group_info = bob
.examine_welcome_message(&commit.welcome_messages[0])
.await
.unwrap();
// signature is random so we won't compare it
group_info.signature = vec![];
group_info.ungrease();
let mut expected_group_info = alice
.group_info_message(commit.ratchet_tree.is_none())
.await
.unwrap()
.into_group_info()
.unwrap();
expected_group_info.signature = vec![];
expected_group_info.ungrease();
assert_eq!(expected_group_info, group_info);
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn validate_group_info() {
let alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
.await
.group;
let bob = test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob")
.await
.0;
let group_info = alice.group_info_message(false).await.unwrap();
let alice_signer = alice.current_member_signing_identity().unwrap().clone();
bob.validate_group_info(&group_info, &alice_signer)
.await
.unwrap();
let other_signer = get_test_signing_identity(TEST_CIPHER_SUITE, b"alice")
.await
.0;
let res = bob.validate_group_info(&group_info, &other_signer).await;
assert_matches!(res, Err(MlsError::InvalidSignature));
}
} }

View File

@@ -10,7 +10,7 @@ use crate::{
cipher_suite::CipherSuite, cipher_suite::CipherSuite,
client::Client, client::Client,
client_config::ClientConfig, client_config::ClientConfig,
extension::{ExtensionType, MlsExtension}, extension::ExtensionType,
group::{ group::{
mls_rules::{DefaultMlsRules, MlsRules}, mls_rules::{DefaultMlsRules, MlsRules},
proposal::ProposalType, proposal::ProposalType,
@@ -297,56 +297,6 @@ impl<C: IntoConfig> ClientBuilder<C> {
ClientBuilder(c) ClientBuilder(c)
} }
/// Add a key package extension to the list of key package extensions supported by the client.
pub fn key_package_extension<T>(
self,
extension: T,
) -> Result<ClientBuilder<IntoConfigOutput<C>>, ExtensionError>
where
T: MlsExtension,
Self: Sized,
{
let mut c = self.0.into_config();
c.0.settings.key_package_extensions.set_from(extension)?;
Ok(ClientBuilder(c))
}
/// Add multiple key package extensions to the list of key package extensions supported by the
/// client.
pub fn key_package_extensions(
self,
extensions: ExtensionList,
) -> ClientBuilder<IntoConfigOutput<C>> {
let mut c = self.0.into_config();
c.0.settings.key_package_extensions.append(extensions);
ClientBuilder(c)
}
/// Add a leaf node extension to the list of leaf node extensions supported by the client.
pub fn leaf_node_extension<T>(
self,
extension: T,
) -> Result<ClientBuilder<IntoConfigOutput<C>>, ExtensionError>
where
T: MlsExtension,
Self: Sized,
{
let mut c = self.0.into_config();
c.0.settings.leaf_node_extensions.set_from(extension)?;
Ok(ClientBuilder(c))
}
/// Add multiple leaf node extensions to the list of leaf node extensions supported by the
/// client.
pub fn leaf_node_extensions(
self,
extensions: ExtensionList,
) -> ClientBuilder<IntoConfigOutput<C>> {
let mut c = self.0.into_config();
c.0.settings.leaf_node_extensions.append(extensions);
ClientBuilder(c)
}
/// Set the lifetime duration in seconds of key packages generated by the client. /// Set the lifetime duration in seconds of key packages generated by the client.
pub fn key_package_lifetime(self, duration_in_s: u64) -> ClientBuilder<IntoConfigOutput<C>> { pub fn key_package_lifetime(self, duration_in_s: u64) -> ClientBuilder<IntoConfigOutput<C>> {
let mut c = self.0.into_config(); let mut c = self.0.into_config();
@@ -733,14 +683,6 @@ where
self.crypto_provider.clone() self.crypto_provider.clone()
} }
fn key_package_extensions(&self) -> ExtensionList {
self.settings.key_package_extensions.clone()
}
fn leaf_node_extensions(&self) -> ExtensionList {
self.settings.leaf_node_extensions.clone()
}
fn lifetime(&self) -> Lifetime { fn lifetime(&self) -> Lifetime {
#[cfg(feature = "std")] #[cfg(feature = "std")]
let now_timestamp = MlsTime::now().seconds_since_epoch(); let now_timestamp = MlsTime::now().seconds_since_epoch();
@@ -840,14 +782,6 @@ impl<T: MlsConfig> ClientConfig for T {
self.get().crypto_provider() self.get().crypto_provider()
} }
fn key_package_extensions(&self) -> ExtensionList {
self.get().key_package_extensions()
}
fn leaf_node_extensions(&self) -> ExtensionList {
self.get().leaf_node_extensions()
}
fn lifetime(&self) -> Lifetime { fn lifetime(&self) -> Lifetime {
self.get().lifetime() self.get().lifetime()
} }
@@ -870,8 +804,6 @@ pub(crate) struct Settings {
pub(crate) extension_types: Vec<ExtensionType>, pub(crate) extension_types: Vec<ExtensionType>,
pub(crate) protocol_versions: Vec<ProtocolVersion>, pub(crate) protocol_versions: Vec<ProtocolVersion>,
pub(crate) custom_proposal_types: Vec<ProposalType>, pub(crate) custom_proposal_types: Vec<ProposalType>,
pub(crate) key_package_extensions: ExtensionList,
pub(crate) leaf_node_extensions: ExtensionList,
pub(crate) lifetime_in_s: u64, pub(crate) lifetime_in_s: u64,
#[cfg(any(test, feature = "test_util"))] #[cfg(any(test, feature = "test_util"))]
pub(crate) key_package_not_before: Option<u64>, pub(crate) key_package_not_before: Option<u64>,
@@ -882,8 +814,6 @@ impl Default for Settings {
Self { Self {
extension_types: Default::default(), extension_types: Default::default(),
protocol_versions: Default::default(), protocol_versions: Default::default(),
key_package_extensions: Default::default(),
leaf_node_extensions: Default::default(),
lifetime_in_s: 365 * 24 * 3600, lifetime_in_s: 365 * 24 * 3600,
custom_proposal_types: Default::default(), custom_proposal_types: Default::default(),
#[cfg(any(test, feature = "test_util"))] #[cfg(any(test, feature = "test_util"))]
@@ -903,8 +833,6 @@ pub(crate) fn recreate_config<T: ClientConfig>(
extension_types: c.supported_extensions(), extension_types: c.supported_extensions(),
protocol_versions: c.supported_protocol_versions(), protocol_versions: c.supported_protocol_versions(),
custom_proposal_types: c.supported_custom_proposals(), custom_proposal_types: c.supported_custom_proposals(),
key_package_extensions: c.key_package_extensions(),
leaf_node_extensions: c.leaf_node_extensions(),
lifetime_in_s: { lifetime_in_s: {
let l = c.lifetime(); let l = c.lifetime();
l.not_after - l.not_before l.not_after - l.not_before
@@ -979,7 +907,6 @@ mod private {
use mls_rs_core::{ use mls_rs_core::{
crypto::{CryptoProvider, SignatureSecretKey}, crypto::{CryptoProvider, SignatureSecretKey},
extension::{ExtensionError, ExtensionList},
group::GroupStateStorage, group::GroupStateStorage,
identity::IdentityProvider, identity::IdentityProvider,
key_package::KeyPackageStorage, key_package::KeyPackageStorage,

View File

@@ -37,8 +37,6 @@ pub trait ClientConfig: Send + Sync + Clone {
fn identity_provider(&self) -> Self::IdentityProvider; fn identity_provider(&self) -> Self::IdentityProvider;
fn crypto_provider(&self) -> Self::CryptoProvider; fn crypto_provider(&self) -> Self::CryptoProvider;
fn key_package_extensions(&self) -> ExtensionList;
fn leaf_node_extensions(&self) -> ExtensionList;
fn lifetime(&self) -> Lifetime; fn lifetime(&self) -> Lifetime;
fn capabilities(&self) -> Capabilities { fn capabilities(&self) -> Capabilities {
@@ -59,10 +57,10 @@ pub trait ClientConfig: Send + Sync + Clone {
self.identity_provider().supported_types() self.identity_provider().supported_types()
} }
fn leaf_properties(&self) -> ConfigProperties { fn leaf_properties(&self, leaf_node_extensions: ExtensionList) -> ConfigProperties {
ConfigProperties { ConfigProperties {
capabilities: self.capabilities(), capabilities: self.capabilities(),
extensions: self.leaf_node_extensions(), extensions: leaf_node_extensions,
} }
} }
} }

View File

@@ -5,10 +5,16 @@
pub use mls_rs_core::extension::{ExtensionType, MlsCodecExtension, MlsExtension}; pub use mls_rs_core::extension::{ExtensionType, MlsCodecExtension, MlsExtension};
pub(crate) use built_in::*; pub(crate) use built_in::*;
#[cfg(feature = "last_resort_key_package_ext")]
pub(crate) use recommended::*;
/// Default extension types required by the MLS RFC. /// Default extension types required by the MLS RFC.
pub mod built_in; pub mod built_in;
/// Extension types which are not mandatory, but still recommended.
#[cfg(feature = "last_resort_key_package_ext")]
pub mod recommended;
#[cfg(test)] #[cfg(test)]
pub(crate) mod test_utils { pub(crate) mod test_utils {
use alloc::vec::Vec; use alloc::vec::Vec;

View File

@@ -0,0 +1,29 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//! Recommended MLS extensions.
//!
//! Optional, but recommended extensions from [The Messaging Layer
//! Security (MLS) Extensions][1].
//!
//! [1]: https://datatracker.ietf.org/doc/html/draft-ietf-mls-extensions-04
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::extension::{ExtensionType, MlsCodecExtension};
/// Last resort key packages.
///
/// The extension allows clients that pre-publish key packages to
/// signal to the Delivery Service which key packages are meant to be
/// used as last resort key packages.
#[cfg(feature = "last_resort_key_package_ext")]
#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub struct LastResortKeyPackageExt;
#[cfg(feature = "last_resort_key_package_ext")]
impl MlsCodecExtension for LastResortKeyPackageExt {
fn extension_type() -> ExtensionType {
ExtensionType::LAST_RESORT_KEY_PACKAGE
}
}

View File

@@ -97,6 +97,21 @@ where
ExternalGroup::from_snapshot(self.config.clone(), snapshot).await ExternalGroup::from_snapshot(self.config.clone(), snapshot).await
} }
/// Load an existing observed group by loading a snapshot that was
/// generated by
/// [ExternalGroup::snapshot](self::ExternalGroup::snapshot). The tree
/// is taken from `tree_data` instead of the stored state.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn load_group_with_ratchet_tree(
&self,
mut snapshot: ExternalSnapshot,
tree_data: ExportedTree<'_>,
) -> Result<ExternalGroup<C>, MlsError> {
snapshot.state.public_tree.nodes = tree_data.0.into_owned();
ExternalGroup::from_snapshot(self.config.clone(), snapshot).await
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn validate_key_package( pub async fn validate_key_package(
&self, &self,
@@ -120,6 +135,11 @@ where
Ok(key_package) Ok(key_package)
} }
/// The [IdentityProvider](crate::IdentityProvider) that this client was configured to use.
pub fn identity_provider(&self) -> <C as ExternalClientConfig>::IdentityProvider {
self.config.identity_provider()
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -582,7 +582,7 @@ pub(crate) mod test_utils {
impl TestExternalClientBuilder { impl TestExternalClientBuilder {
pub fn new_for_test() -> Self { pub fn new_for_test() -> Self {
ExternalClientBuilder::new() ExternalClientBuilder::new()
.crypto_provider(TestCryptoProvider::default()) .crypto_provider(TestCryptoProvider::new())
.identity_provider(BasicIdentityProvider::new()) .identity_provider(BasicIdentityProvider::new())
} }

View File

@@ -21,10 +21,12 @@ use crate::{
ApplicationMessageDescription, CommitMessageDescription, EventOrContent, ApplicationMessageDescription, CommitMessageDescription, EventOrContent,
MessageProcessor, ProposalMessageDescription, ProvisionalState, MessageProcessor, ProposalMessageDescription, ProvisionalState,
}, },
proposal::RemoveProposal,
proposal_filter::ProposalInfo,
snapshot::RawGroupState, snapshot::RawGroupState,
state::GroupState, state::GroupState,
transcript_hash::InterimTranscriptHash, transcript_hash::InterimTranscriptHash,
validate_group_info_joiner, ContentType, ExportedTree, GroupContext, GroupInfo, Roster, validate_tree_and_info_joiner, ContentType, ExportedTree, GroupContext, GroupInfo, Roster,
Welcome, Welcome,
}, },
identity::SigningIdentity, identity::SigningIdentity,
@@ -56,7 +58,7 @@ use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId};
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
use crate::{ use crate::{
extension::ExternalSendersExt, extension::ExternalSendersExt,
group::proposal::{AddProposal, ReInitProposal, RemoveProposal}, group::proposal::{AddProposal, ReInitProposal},
}; };
#[cfg(all(feature = "by_ref_proposal", feature = "psk"))] #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
@@ -127,7 +129,7 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
group_info.group_context.cipher_suite, group_info.group_context.cipher_suite,
)?; )?;
let public_tree = validate_group_info_joiner( let public_tree = validate_tree_and_info_joiner(
protocol_version, protocol_version,
&group_info, &group_info,
tree_data, tree_data,
@@ -414,9 +416,14 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
.await .await
} }
/// Issue an external proposal.
///
/// This function is useful for reissuing external proposals that
/// are returned in [crate::group::NewEpoch::unused_proposals]
/// after a commit is processed.
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn propose( pub async fn propose(
&mut self, &mut self,
proposal: Proposal, proposal: Proposal,
authenticated_data: Vec<u8>, authenticated_data: Vec<u8>,
@@ -450,11 +457,8 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
) )
.await?; .await?;
self.state.proposals.insert( let proposal_ref =
ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?, ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?;
proposal,
sender,
);
let plaintext = PublicMessage { let plaintext = PublicMessage {
content: auth_content.content, content: auth_content.content,
@@ -462,10 +466,14 @@ impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
membership_tag: None, membership_tag: None,
}; };
Ok(MlsMessage::new( let message = MlsMessage::new(
self.group_context().version(), self.group_context().version(),
MlsMessagePayload::Plain(plaintext), MlsMessagePayload::Plain(plaintext),
)) );
self.state.proposals.insert(proposal_ref, proposal, sender);
Ok(message)
} }
/// Delete all sent and received proposals cached for commit. /// Delete all sent and received proposals cached for commit.
@@ -582,7 +590,6 @@ where
&self.cipher_suite_provider, &self.cipher_suite_provider,
message, message,
None, None,
None,
&self.state, &self.state,
) )
.await?; .await?;
@@ -633,8 +640,11 @@ where
&mut self.state &mut self.state
} }
fn can_continue_processing(&self, _provisional_state: &ProvisionalState) -> bool { fn removal_proposal(
true &self,
_provisional_state: &ProvisionalState,
) -> Option<ProposalInfo<RemoveProposal>> {
None
} }
#[cfg(feature = "private_message")] #[cfg(feature = "private_message")]
@@ -653,7 +663,7 @@ where
#[derive(Debug, MlsEncode, MlsSize, MlsDecode, PartialEq, Clone)] #[derive(Debug, MlsEncode, MlsSize, MlsDecode, PartialEq, Clone)]
pub struct ExternalSnapshot { pub struct ExternalSnapshot {
version: u16, version: u16,
state: RawGroupState, pub(crate) state: RawGroupState,
signing_data: Option<(SignatureSecretKey, SigningIdentity)>, signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
} }
@@ -667,6 +677,11 @@ impl ExternalSnapshot {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> { pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
Ok(Self::mls_decode(&mut &*bytes)?) Ok(Self::mls_decode(&mut &*bytes)?)
} }
/// Group context encoded in the snapshot
pub fn context(&self) -> &GroupContext {
&self.state.context
}
} }
impl<C> ExternalGroup<C> impl<C> ExternalGroup<C>
@@ -682,6 +697,23 @@ where
} }
} }
/// Create a snapshot of this group's current internal state.
/// The tree is not included in the state and can be stored
/// separately by calling [`Group::export_tree`].
pub fn snapshot_without_ratchet_tree(&mut self) -> ExternalSnapshot {
let tree = std::mem::take(&mut self.state.public_tree.nodes);
let snapshot = ExternalSnapshot {
state: RawGroupState::export(&self.state),
version: 1,
signing_data: self.signing_data.clone(),
};
self.state.public_tree.nodes = tree;
snapshot
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn from_snapshot( pub(crate) async fn from_snapshot(
config: C, config: C,
@@ -777,7 +809,6 @@ pub(crate) mod test_utils {
config, config,
None, None,
group group
.group
.group_info_message_allowing_ext_commit(true) .group_info_message_allowing_ext_commit(true)
.await .await
.unwrap(), .unwrap(),
@@ -802,14 +833,15 @@ mod tests {
external_client::{ external_client::{
group::test_utils::make_external_group_with_config, group::test_utils::make_external_group_with_config,
tests_utils::{TestExternalClientBuilder, TestExternalClientConfig}, tests_utils::{TestExternalClientBuilder, TestExternalClientConfig},
ExternalGroup, ExternalReceivedMessage, ExternalSnapshot, ExternalClient, ExternalGroup, ExternalReceivedMessage, ExternalSnapshot,
}, },
group::{ group::{
framing::{Content, MlsMessagePayload}, framing::{Content, MlsMessagePayload},
message_processor::CommitEffect,
proposal::{AddProposal, Proposal, ProposalOrRef}, proposal::{AddProposal, Proposal, ProposalOrRef},
proposal_ref::ProposalRef, proposal_ref::ProposalRef,
test_utils::{test_group, TestGroup}, test_utils::{test_group, TestGroup},
ProposalMessageDescription, CommitMessageDescription, ExportedTree, ProposalMessageDescription,
}, },
identity::{test_utils::get_test_signing_identity, SigningIdentity}, identity::{test_utils::get_test_signing_identity, SigningIdentity},
key_package::test_utils::{test_key_package, test_key_package_message}, key_package::test_utils::{test_key_package, test_key_package_message},
@@ -822,7 +854,7 @@ mod tests {
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn test_group_with_one_commit(v: ProtocolVersion, cs: CipherSuite) -> TestGroup { async fn test_group_with_one_commit(v: ProtocolVersion, cs: CipherSuite) -> TestGroup {
let mut group = test_group(v, cs).await; let mut group = test_group(v, cs).await;
group.group.commit(Vec::new()).await.unwrap(); group.commit(Vec::new()).await.unwrap();
group.process_pending_commit().await.unwrap(); group.process_pending_commit().await.unwrap();
group group
} }
@@ -837,11 +869,7 @@ mod tests {
let bob_key_package = test_key_package_message(v, cs, "bob").await; let bob_key_package = test_key_package_message(v, cs, "bob").await;
let mut commit_builder = group let mut commit_builder = group.commit_builder().add_member(bob_key_package).unwrap();
.group
.commit_builder()
.add_member(bob_key_package)
.unwrap();
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
if let Some(ext_signer) = ext_identity { if let Some(ext_signer) = ext_identity {
@@ -877,15 +905,15 @@ mod tests {
async fn external_group_can_process_commit() { async fn external_group_can_process_commit() {
let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let mut server = make_external_group(&alice).await; let mut server = make_external_group(&alice).await;
let commit_output = alice.group.commit(Vec::new()).await.unwrap(); let commit_output = alice.commit(Vec::new()).await.unwrap();
alice.group.apply_pending_commit().await.unwrap(); alice.apply_pending_commit().await.unwrap();
server server
.process_incoming_message(commit_output.commit_message) .process_incoming_message(commit_output.commit_message)
.await .await
.unwrap(); .unwrap();
assert_eq!(alice.group.state, server.state); assert_eq!(alice.state, server.state);
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
@@ -909,25 +937,29 @@ mod tests {
ExternalReceivedMessage::Proposal(ProposalMessageDescription { ref proposal, ..}) if proposal == &add_proposal ExternalReceivedMessage::Proposal(ProposalMessageDescription { ref proposal, ..}) if proposal == &add_proposal
); );
let commit_output = alice.group.commit(vec![]).await.unwrap(); let commit_output = alice.commit(vec![]).await.unwrap();
alice.group.apply_pending_commit().await.unwrap(); alice.apply_pending_commit().await.unwrap();
let commit_result = server let new_epoch = match server
.process_incoming_message(commit_output.commit_message) .process_incoming_message(commit_output.commit_message)
.await .await
.unwrap(); .unwrap()
{
ExternalReceivedMessage::Commit(CommitMessageDescription {
effect: CommitEffect::NewEpoch(new_epoch),
..
}) => new_epoch,
_ => panic!("Expected processed commit"),
};
#[cfg(feature = "state_update")] assert_eq!(new_epoch.applied_proposals.len(), 1);
assert_matches!(
commit_result,
ExternalReceivedMessage::Commit(commit_description)
if commit_description.state_update.roster_update.added().iter().any(|added| added.index == 1)
);
#[cfg(not(feature = "state_update"))] assert!(new_epoch
assert_matches!(commit_result, ExternalReceivedMessage::Commit(_)); .applied_proposals
.into_iter()
.any(|p| p.proposal == add_proposal));
assert_eq!(alice.group.state, server.state); assert_eq!(alice.state, server.state);
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
@@ -936,17 +968,28 @@ mod tests {
let mut server = make_external_group(&alice).await; let mut server = make_external_group(&alice).await;
let (_, commit) = alice.join("bob").await; let (_, commit) = alice.join("bob").await;
let update = match server.process_incoming_message(commit).await.unwrap() { let new_epoch = match server.process_incoming_message(commit).await.unwrap() {
ExternalReceivedMessage::Commit(update) => update.state_update, ExternalReceivedMessage::Commit(CommitMessageDescription {
effect: CommitEffect::NewEpoch(new_epoch),
..
}) => new_epoch,
_ => panic!("Expected processed commit"), _ => panic!("Expected processed commit"),
}; };
#[cfg(feature = "state_update")] assert_eq!(new_epoch.applied_proposals.len(), 1);
assert_eq!(update.roster_update.added().len(), 1);
assert_eq!(
new_epoch
.applied_proposals
.into_iter()
.filter(|p| matches!(p.proposal, Proposal::Add(_)))
.count(),
1
);
assert_eq!(server.state.public_tree.get_leaf_nodes().len(), 2); assert_eq!(server.state.public_tree.get_leaf_nodes().len(), 2);
assert_eq!(alice.group.state, server.state); assert_eq!(alice.state, server.state);
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
@@ -954,7 +997,7 @@ mod tests {
let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let mut server = make_external_group(&alice).await; let mut server = make_external_group(&alice).await;
let mut commit_output = alice.group.commit(vec![]).await.unwrap(); let mut commit_output = alice.commit(vec![]).await.unwrap();
match commit_output.commit_message.payload { match commit_output.commit_message.payload {
MlsMessagePayload::Plain(ref mut plain) => plain.content.epoch = 0, MlsMessagePayload::Plain(ref mut plain) => plain.content.epoch = 0,
@@ -978,7 +1021,7 @@ mod tests {
) )
.await; .await;
let mut commit_output = alice.group.commit(Vec::new()).await.unwrap(); let mut commit_output = alice.commit(Vec::new()).await.unwrap();
match commit_output.commit_message.payload { match commit_output.commit_message.payload {
MlsMessagePayload::Plain(ref mut plain) => plain.auth.signature = Vec::new().into(), MlsMessagePayload::Plain(ref mut plain) => plain.auth.signature = Vec::new().into(),
@@ -1018,7 +1061,6 @@ mod tests {
config, config,
None, None,
alice alice
.group
.group_info_message_allowing_ext_commit(true) .group_info_message_allowing_ext_commit(true)
.await .await
.unwrap(), .unwrap(),
@@ -1040,7 +1082,6 @@ mod tests {
let config = TestExternalClientBuilder::new_for_test().build_config(); let config = TestExternalClientBuilder::new_for_test().build_config();
let mut group_info = alice let mut group_info = alice
.group
.group_info_message_allowing_ext_commit(true) .group_info_message_allowing_ext_commit(true)
.await .await
.unwrap(); .unwrap();
@@ -1093,7 +1134,7 @@ mod tests {
alice.process_message(external_proposal).await.unwrap(); alice.process_message(external_proposal).await.unwrap();
// Alice commits the proposal // Alice commits the proposal
let commit_output = alice.group.commit(vec![]).await.unwrap(); let commit_output = alice.commit(vec![]).await.unwrap();
let commit = match commit_output let commit = match commit_output
.commit_message .commit_message
@@ -1119,7 +1160,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
assert_eq!(alice.group.state, server.state); assert_eq!(alice.state, server.state);
} }
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
@@ -1211,12 +1252,11 @@ mod tests {
.await; .await;
let old_application_msg = alice let old_application_msg = alice
.group
.encrypt_application_message(&[], vec![]) .encrypt_application_message(&[], vec![])
.await .await
.unwrap(); .unwrap();
let commit_output = alice.group.commit(vec![]).await.unwrap(); let commit_output = alice.commit(vec![]).await.unwrap();
server server
.process_incoming_message(commit_output.commit_message) .process_incoming_message(commit_output.commit_message)
@@ -1240,9 +1280,9 @@ mod tests {
) )
.await; .await;
let proposal = alice.group.propose_update(vec![]).await.unwrap(); let proposal = alice.propose_update(vec![]).await.unwrap();
let commit_output = alice.group.commit(vec![]).await.unwrap(); let commit_output = alice.commit(vec![]).await.unwrap();
server server
.process_incoming_message(proposal.clone()) .process_incoming_message(proposal.clone())
@@ -1262,7 +1302,6 @@ mod tests {
let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let info = alice let info = alice
.group
.group_info_message_allowing_ext_commit(true) .group_info_message_allowing_ext_commit(true)
.await .await
.unwrap(); .unwrap();
@@ -1271,7 +1310,7 @@ mod tests {
let mut server = ExternalGroup::join(config, None, info, None).await.unwrap(); let mut server = ExternalGroup::join(config, None, info, None).await.unwrap();
for _ in 0..2 { for _ in 0..2 {
let commit = alice.group.commit(vec![]).await.unwrap().commit_message; let commit = alice.commit(vec![]).await.unwrap().commit_message;
alice.process_pending_commit().await.unwrap(); alice.process_pending_commit().await.unwrap();
server.process_incoming_message(commit).await.unwrap(); server.process_incoming_message(commit).await.unwrap();
} }
@@ -1299,7 +1338,6 @@ mod tests {
let mut server = make_external_group(&alice).await; let mut server = make_external_group(&alice).await;
let info = alice let info = alice
.group
.group_info_message_allowing_ext_commit(false) .group_info_message_allowing_ext_commit(false)
.await .await
.unwrap(); .unwrap();
@@ -1329,7 +1367,6 @@ mod tests {
let mut server = make_external_group(&alice).await; let mut server = make_external_group(&alice).await;
let [welcome] = alice let [welcome] = alice
.group
.commit_builder() .commit_builder()
.add_member( .add_member(
test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await, test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await,
@@ -1346,4 +1383,37 @@ mod tests {
assert_matches!(update, ExternalReceivedMessage::Welcome); assert_matches!(update, ExternalReceivedMessage::Welcome);
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn external_group_can_be_stored_without_tree() {
let mut server =
make_external_group(&test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await).await;
let snapshot_with_tree = server.snapshot().mls_encode_to_vec().unwrap();
let snapshot_without_tree = server
.snapshot_without_ratchet_tree()
.mls_encode_to_vec()
.unwrap();
let tree = server.state.public_tree.nodes.mls_encode_to_vec().unwrap();
let empty_tree = Vec::<u8>::new().mls_encode_to_vec().unwrap();
assert_eq!(
snapshot_with_tree.len() - snapshot_without_tree.len(),
tree.len() - empty_tree.len()
);
let exported_tree = server.export_tree().unwrap();
let restored = ExternalClient::new(server.config.clone(), None)
.load_group_with_ratchet_tree(
ExternalSnapshot::from_bytes(&snapshot_without_tree).unwrap(),
ExportedTree::from_bytes(&exported_tree).unwrap(),
)
.await
.unwrap();
assert_eq!(restored.group_state(), server.group_state());
}
} }

View File

@@ -55,6 +55,10 @@ impl GroupInfo {
pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> { pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
grease_functions::grease_extensions(&mut self.extensions, cs).map(|_| ()) grease_functions::grease_extensions(&mut self.extensions, cs).map(|_| ())
} }
pub fn ungrease(&mut self) {
grease_functions::ungrease_extensions(&mut self.extensions)
}
} }
impl NewMemberInfo { impl NewMemberInfo {

View File

@@ -305,10 +305,10 @@ mod test {
let content = AuthenticatedContent::new_signed( let content = AuthenticatedContent::new_signed(
&provider, &provider,
group.group.context(), group.context(),
Sender::Member(0), Sender::Member(0),
Content::Application(ApplicationData::from(b"test".to_vec())), Content::Application(ApplicationData::from(b"test".to_vec())),
&group.group.signer, &group.signer,
WireFormat::PrivateMessage, WireFormat::PrivateMessage,
vec![], vec![],
) )
@@ -331,7 +331,7 @@ mod test {
.await .await
.unwrap(); .unwrap();
receiver_group.group.private_tree.self_index = LeafIndex::new(1); receiver_group.private_tree.self_index = LeafIndex::new(1);
let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite); let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite);
@@ -401,7 +401,7 @@ mod test {
.unwrap(); .unwrap();
ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len()); ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len());
receiver_group.group.private_tree.self_index = LeafIndex::new(1); receiver_group.private_tree.self_index = LeafIndex::new(1);
let res = ciphertext_processor.open(&ciphertext).await; let res = ciphertext_processor.open(&ciphertext).await;

View File

@@ -2,14 +2,12 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use alloc::boxed::Box;
use alloc::vec; use alloc::vec;
use alloc::vec::Vec; use alloc::vec::Vec;
use core::fmt::{self, Debug}; use core::fmt::Debug;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::{ use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
crypto::{CipherSuiteProvider, SignatureSecretKey},
error::IntoAnyError,
};
use crate::{ use crate::{
cipher_suite::CipherSuite, cipher_suite::CipherSuite,
@@ -43,12 +41,14 @@ use super::{
confirmation_tag::ConfirmationTag, confirmation_tag::ConfirmationTag,
framing::{Content, MlsMessage, MlsMessagePayload, Sender}, framing::{Content, MlsMessage, MlsMessagePayload, Sender},
key_schedule::{KeySchedule, WelcomeSecret}, key_schedule::{KeySchedule, WelcomeSecret},
message_hash::MessageHash,
message_processor::{path_update_required, MessageProcessor}, message_processor::{path_update_required, MessageProcessor},
message_signature::AuthenticatedContent, message_signature::AuthenticatedContent,
mls_rules::CommitDirection, mls_rules::CommitDirection,
proposal::{Proposal, ProposalOrRef}, proposal::{Proposal, ProposalOrRef},
ConfirmedTranscriptHash, EncryptedGroupSecrets, ExportedTree, Group, GroupContext, GroupInfo, CommitEffect, CommitMessageDescription, EncryptedGroupSecrets, EpochSecrets, ExportedTree,
Welcome, Group, GroupContext, GroupInfo, GroupState, InterimTranscriptHash, NewEpoch,
PendingCommitSnapshot, Welcome,
}; };
#[cfg(not(feature = "by_ref_proposal"))] #[cfg(not(feature = "by_ref_proposal"))]
@@ -66,40 +66,34 @@ pub(crate) struct Commit {
} }
#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)] #[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct PendingCommit {
pub(super) struct CommitGeneration { pub(crate) state: GroupState,
pub content: AuthenticatedContent, pub(crate) epoch_secrets: EpochSecrets,
pub pending_private_tree: TreeKemPrivate, pub(crate) private_tree: TreeKemPrivate,
pub pending_commit_secret: PathSecret, pub(crate) key_schedule: KeySchedule,
pub commit_message_hash: CommitHash, pub(crate) signer: SignatureSecretKey,
pub(crate) output: CommitMessageDescription,
pub(crate) commit_message_hash: MessageHash,
} }
#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)] #[cfg_attr(
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] all(feature = "ffi", not(test)),
pub(crate) struct CommitHash( safer_ffi_gen::ffi_type(clone, opaque)
#[mls_codec(with = "mls_rs_codec::byte_vec")] )]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] #[derive(Clone)]
Vec<u8>, pub struct CommitSecrets(pub(crate) PendingCommitSnapshot);
);
impl Debug for CommitHash { impl CommitSecrets {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { /// Deserialize the commit secrets from bytes
mls_rs_core::debug::pretty_bytes(&self.0) pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
.named("CommitHash") Ok(MlsDecode::mls_decode(&mut &*bytes).map(Self)?)
.fmt(f)
} }
}
impl CommitHash { /// Serialize the commit secrets to bytes
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
pub(crate) async fn compute<CS: CipherSuiteProvider>( Ok(self.0.mls_encode_to_vec()?)
cs: &CS,
commit: &MlsMessage,
) -> Result<Self, MlsError> {
cs.hash(&commit.mls_encode_to_vec()?)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
.map(Self)
} }
} }
@@ -133,6 +127,8 @@ pub struct CommitOutput {
/// Proposals that were received in the prior epoch but not included in the following commit. /// Proposals that were received in the prior epoch but not included in the following commit.
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>, pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
/// Indicator that the commit contains a path update
pub contains_update_path: bool,
} }
#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)] #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
@@ -189,6 +185,7 @@ where
group_info_extensions: ExtensionList, group_info_extensions: ExtensionList,
new_signer: Option<SignatureSecretKey>, new_signer: Option<SignatureSecretKey>,
new_signing_identity: Option<SigningIdentity>, new_signing_identity: Option<SigningIdentity>,
new_leaf_node_extensions: Option<ExtensionList>,
} }
impl<'a, C> CommitBuilder<'a, C> impl<'a, C> CommitBuilder<'a, C>
@@ -292,7 +289,7 @@ where
/// Insert a proposal that was previously constructed such as when a /// Insert a proposal that was previously constructed such as when a
/// proposal is returned from /// proposal is returned from
/// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals). /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
pub fn raw_proposal(mut self, proposal: Proposal) -> Self { pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
self.proposals.push(proposal); self.proposals.push(proposal);
self self
@@ -300,7 +297,7 @@ where
/// Insert proposals that were previously constructed such as when a /// Insert proposals that were previously constructed such as when a
/// proposal is returned from /// proposal is returned from
/// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals). /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self { pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
self.proposals.append(&mut proposals); self.proposals.append(&mut proposals);
self self
@@ -337,6 +334,14 @@ where
} }
} }
/// Change the committer's leaf node extensions as part of making this commit.
pub fn set_leaf_node_extensions(self, new_leaf_node_extensions: ExtensionList) -> Self {
Self {
new_leaf_node_extensions: Some(new_leaf_node_extensions),
..self
}
}
/// Finalize the commit to send. /// Finalize the commit to send.
/// ///
/// # Errors /// # Errors
@@ -347,7 +352,8 @@ where
/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules). /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn build(self) -> Result<CommitOutput, MlsError> { pub async fn build(self) -> Result<CommitOutput, MlsError> {
self.group let (output, pending_commit) = self
.group
.commit_internal( .commit_internal(
self.proposals, self.proposals,
None, None,
@@ -355,8 +361,40 @@ where
self.group_info_extensions, self.group_info_extensions,
self.new_signer, self.new_signer,
self.new_signing_identity, self.new_signing_identity,
self.new_leaf_node_extensions,
) )
.await .await?;
self.group.pending_commit = pending_commit.try_into()?;
Ok(output)
}
/// The same function as `GroupBuilder::build` except the secrets generated
/// for the commit are outputted instead of being cached internally.
///
/// A detached commit can be applied using `Group::apply_detached_commit`.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn build_detached(self) -> Result<(CommitOutput, CommitSecrets), MlsError> {
let (output, pending_commit) = self
.group
.commit_internal(
self.proposals,
None,
self.authenticated_data,
self.group_info_extensions,
self.new_signer,
self.new_signing_identity,
self.new_leaf_node_extensions,
)
.await?;
Ok((
output,
CommitSecrets(PendingCommitSnapshot::PendingCommit(
pending_commit.mls_encode_to_vec()?,
)),
))
} }
} }
@@ -406,15 +444,25 @@ where
/// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit. /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> { pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
self.commit_internal( self.commit_builder()
vec![], .authenticated_data(authenticated_data)
None, .build()
authenticated_data, .await
Default::default(), }
None,
None, /// The same function as `Group::commit` except the secrets generated
) /// for the commit are outputted instead of being cached internally.
.await ///
/// A detached commit can be applied using `Group::apply_detached_commit`.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn commit_detached(
&mut self,
authenticated_data: Vec<u8>,
) -> Result<(CommitOutput, CommitSecrets), MlsError> {
self.commit_builder()
.authenticated_data(authenticated_data)
.build_detached()
.await
} }
/// Create a new commit builder that can include proposals /// Create a new commit builder that can include proposals
@@ -427,6 +475,7 @@ where
group_info_extensions: Default::default(), group_info_extensions: Default::default(),
new_signer: Default::default(), new_signer: Default::default(),
new_signing_identity: Default::default(), new_signing_identity: Default::default(),
new_leaf_node_extensions: Default::default(),
} }
} }
@@ -442,8 +491,9 @@ where
mut welcome_group_info_extensions: ExtensionList, mut welcome_group_info_extensions: ExtensionList,
new_signer: Option<SignatureSecretKey>, new_signer: Option<SignatureSecretKey>,
new_signing_identity: Option<SigningIdentity>, new_signing_identity: Option<SigningIdentity>,
) -> Result<CommitOutput, MlsError> { new_leaf_node_extensions: Option<ExtensionList>,
if self.pending_commit.is_some() { ) -> Result<(CommitOutput, PendingCommit), MlsError> {
if !self.pending_commit.is_none() {
return Err(MlsError::ExistingPendingCommit); return Err(MlsError::ExistingPendingCommit);
} }
@@ -464,7 +514,7 @@ where
Sender::Member(*self.private_tree.self_index) Sender::Member(*self.private_tree.self_index)
}; };
let new_signer_ref = new_signer.as_ref().unwrap_or(&self.signer); let new_signer = new_signer.unwrap_or_else(|| self.signer.clone());
let old_signer = &self.signer; let old_signer = &self.signer;
#[cfg(feature = "std")] #[cfg(feature = "std")]
@@ -505,15 +555,13 @@ where
self.private_tree.self_index = provisional_private_tree.self_index; self.private_tree.self_index = provisional_private_tree.self_index;
} }
let mut provisional_group_context = provisional_state.group_context;
// Decide whether to populate the path field: If the path field is required based on the // Decide whether to populate the path field: If the path field is required based on the
// proposals that are in the commit (see above), then it MUST be populated. Otherwise, the // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
// sender MAY omit the path field at its discretion. // sender MAY omit the path field at its discretion.
let commit_options = mls_rules let commit_options = mls_rules
.commit_options( .commit_options(
&provisional_state.public_tree.roster(), &provisional_state.public_tree.roster(),
&provisional_group_context.extensions, &provisional_state.group_context,
&provisional_state.applied_proposals, &provisional_state.applied_proposals,
) )
.map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?; .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
@@ -528,15 +576,25 @@ where
// group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
// GroupContext object. The leaf_key_package for this UpdatePath must have a // GroupContext object. The leaf_key_package for this UpdatePath must have a
// parent_hash extension. // parent_hash extension.
let new_leaf_node_extensions =
new_leaf_node_extensions.or(external_leaf.map(|ln| ln.ungreased_extensions()));
let new_leaf_node_extensions = match new_leaf_node_extensions {
Some(extensions) => extensions,
// If we are not setting new extensions and this is not an external leaf then the current node MUST exist.
None => self.current_user_leaf_node()?.ungreased_extensions(),
};
let encap_gen = TreeKem::new( let encap_gen = TreeKem::new(
&mut provisional_state.public_tree, &mut provisional_state.public_tree,
&mut provisional_private_tree, &mut provisional_private_tree,
) )
.encap( .encap(
&mut provisional_group_context, &mut provisional_state.group_context,
&provisional_state.indexes_of_added_kpkgs, &provisional_state.indexes_of_added_kpkgs,
new_signer_ref, &new_signer,
self.config.leaf_properties(), Some(self.config.leaf_properties(new_leaf_node_extensions)),
new_signing_identity, new_signing_identity,
&self.cipher_suite_provider, &self.cipher_suite_provider,
#[cfg(test)] #[cfg(test)]
@@ -559,7 +617,7 @@ where
) )
.await?; .await?;
provisional_group_context.tree_hash = provisional_state provisional_state.group_context.tree_hash = provisional_state
.public_tree .public_tree
.tree_hash(&self.cipher_suite_provider) .tree_hash(&self.cipher_suite_provider)
.await?; .await?;
@@ -583,7 +641,7 @@ where
.collect(); .collect();
let commit = Commit { let commit = Commit {
proposals: provisional_state.applied_proposals.into_proposals_or_refs(), proposals: provisional_state.applied_proposals.proposals_or_refs(),
path: update_path, path: update_path,
}; };
@@ -591,7 +649,7 @@ where
&self.cipher_suite_provider, &self.cipher_suite_provider,
self.context(), self.context(),
sender, sender,
Content::Commit(alloc::boxed::Box::new(commit)), Content::Commit(Box::new(commit)),
old_signer, old_signer,
#[cfg(feature = "private_message")] #[cfg(feature = "private_message")]
self.encryption_options()?.control_wire_format(sender), self.encryption_options()?.control_wire_format(sender),
@@ -603,21 +661,21 @@ where
// Use the signature, the commit_secret and the psk_secret to advance the key schedule and // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
// compute the confirmation_tag value in the MlsPlaintext. // compute the confirmation_tag value in the MlsPlaintext.
let confirmed_transcript_hash = ConfirmedTranscriptHash::create( let confirmed_transcript_hash = super::transcript_hash::create(
self.cipher_suite_provider(), self.cipher_suite_provider(),
&self.state.interim_transcript_hash, &self.state.interim_transcript_hash,
&auth_content, &auth_content,
) )
.await?; .await?;
provisional_group_context.confirmed_transcript_hash = confirmed_transcript_hash; provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
let key_schedule_result = KeySchedule::from_key_schedule( let key_schedule_result = KeySchedule::from_key_schedule(
&self.key_schedule, &self.key_schedule,
&commit_secret, &commit_secret,
&provisional_group_context, &provisional_state.group_context,
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))] #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
self.state.public_tree.total_leaf_count(), provisional_state.public_tree.total_leaf_count(),
&psk_secret, &psk_secret,
&self.cipher_suite_provider, &self.cipher_suite_provider,
) )
@@ -625,11 +683,18 @@ where
let confirmation_tag = ConfirmationTag::create( let confirmation_tag = ConfirmationTag::create(
&key_schedule_result.confirmation_key, &key_schedule_result.confirmation_key,
&provisional_group_context.confirmed_transcript_hash, &provisional_state.group_context.confirmed_transcript_hash,
&self.cipher_suite_provider, &self.cipher_suite_provider,
) )
.await?; .await?;
let interim_transcript_hash = InterimTranscriptHash::create(
self.cipher_suite_provider(),
&provisional_state.group_context.confirmed_transcript_hash,
&confirmation_tag,
)
.await?;
auth_content.auth.confirmation_tag = Some(confirmation_tag.clone()); auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
let ratchet_tree_ext = commit_options let ratchet_tree_ext = commit_options
@@ -656,10 +721,10 @@ where
let info = self let info = self
.make_group_info( .make_group_info(
&provisional_group_context, &provisional_state.group_context,
extensions, extensions,
&confirmation_tag, &confirmation_tag,
new_signer_ref, &new_signer,
) )
.await?; .await?;
@@ -679,10 +744,10 @@ where
let welcome_group_info = self let welcome_group_info = self
.make_group_info( .make_group_info(
&provisional_group_context, &provisional_state.group_context,
welcome_group_info_extensions, welcome_group_info_extensions,
&confirmation_tag, &confirmation_tag,
new_signer_ref, &new_signer,
) )
.await?; .await?;
@@ -705,11 +770,11 @@ where
#[cfg(not(any(mls_build_async, not(feature = "rayon"))))] #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
let encrypted_path_secrets: Vec<_> = added_key_pkgs let encrypted_path_secrets: Vec<_> = added_key_pkgs
.into_par_iter() .into_par_iter()
.zip(provisional_state.indexes_of_added_kpkgs) .zip(&provisional_state.indexes_of_added_kpkgs)
.map(|(key_package, leaf_index)| { .map(|(key_package, leaf_index)| {
self.encrypt_group_secrets( self.encrypt_group_secrets(
&key_package, &key_package,
leaf_index, *leaf_index,
&key_schedule_result.joiner_secret, &key_schedule_result.joiner_secret,
path_secrets, path_secrets,
#[cfg(feature = "psk")] #[cfg(feature = "psk")]
@@ -725,12 +790,12 @@ where
for (key_package, leaf_index) in added_key_pkgs for (key_package, leaf_index) in added_key_pkgs
.into_iter() .into_iter()
.zip(provisional_state.indexes_of_added_kpkgs) .zip(&provisional_state.indexes_of_added_kpkgs)
{ {
secrets.push( secrets.push(
self.encrypt_group_secrets( self.encrypt_group_secrets(
&key_package, &key_package,
leaf_index, *leaf_index,
&key_schedule_result.joiner_secret, &key_schedule_result.joiner_secret,
path_secrets, path_secrets,
#[cfg(feature = "psk")] #[cfg(feature = "psk")]
@@ -756,31 +821,61 @@ where
let commit_message = self.format_for_wire(auth_content.clone()).await?; let commit_message = self.format_for_wire(auth_content.clone()).await?;
let pending_commit = CommitGeneration { // TODO is it necessary to clone the tree here? or can we just output serialized bytes?
content: auth_content, let ratchet_tree = (!commit_options.ratchet_tree_extension)
pending_private_tree: provisional_private_tree, .then(|| ExportedTree::new(provisional_state.public_tree.nodes.clone()));
pending_commit_secret: commit_secret,
commit_message_hash: CommitHash::compute(&self.cipher_suite_provider, &commit_message) let pending_reinit = provisional_state
.applied_proposals
.reinitializations
.first();
let pending_commit = PendingCommit {
output: CommitMessageDescription {
is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
authenticated_data: auth_content.content.authenticated_data,
committer: *provisional_private_tree.self_index,
effect: match pending_reinit {
Some(r) => CommitEffect::ReInit(r.clone()),
None => CommitEffect::NewEpoch(
NewEpoch::new(self.state.clone(), &provisional_state).into(),
),
},
},
state: GroupState {
#[cfg(feature = "by_ref_proposal")]
proposals: crate::group::ProposalCache::new(
self.protocol_version(),
self.group_id().to_vec(),
),
context: provisional_state.group_context,
public_tree: provisional_state.public_tree,
interim_transcript_hash,
pending_reinit: pending_reinit.map(|r| r.proposal.clone()),
confirmation_tag,
},
commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
.await?, .await?,
signer: new_signer,
epoch_secrets: key_schedule_result.epoch_secrets,
key_schedule: key_schedule_result.key_schedule,
private_tree: provisional_private_tree,
}; };
self.pending_commit = Some(pending_commit); let output = CommitOutput {
let ratchet_tree = (!commit_options.ratchet_tree_extension)
.then(|| ExportedTree::new(provisional_state.public_tree.nodes));
if let Some(signer) = new_signer {
self.signer = signer;
}
Ok(CommitOutput {
commit_message, commit_message,
welcome_messages, welcome_messages,
ratchet_tree, ratchet_tree,
external_commit_group_info, external_commit_group_info,
contains_update_path: perform_path_update,
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
unused_proposals: provisional_state.unused_proposals, unused_proposals: provisional_state.unused_proposals,
}) };
Ok((output, pending_commit))
} }
// Construct a GroupInfo reflecting the new state // Construct a GroupInfo reflecting the new state
@@ -856,25 +951,14 @@ pub(crate) mod test_utils {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use alloc::boxed::Box;
use mls_rs_core::{ use mls_rs_core::{
error::IntoAnyError, error::IntoAnyError,
extension::ExtensionType, extension::ExtensionType,
identity::{CredentialType, IdentityProvider}, identity::{CredentialType, IdentityProvider, MemberValidationContext},
time::MlsTime, time::MlsTime,
}; };
use crate::{ use crate::extension::RequiredCapabilitiesExt;
crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
group::{mls_rules::DefaultMlsRules, test_utils::test_group_custom},
mls_rules::CommitOptions,
Client,
};
#[cfg(feature = "by_ref_proposal")]
use crate::extension::ExternalSendersExt;
use crate::{ use crate::{
client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
client_builder::{ client_builder::{
@@ -882,7 +966,9 @@ mod tests {
WithIdentityProvider, WithIdentityProvider,
}, },
client_config::ClientConfig, client_config::ClientConfig,
crypto::test_utils::TestCryptoProvider,
extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE}, extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
group::test_utils::{test_group, test_group_custom},
group::{ group::{
proposal::ProposalType, proposal::ProposalType,
test_utils::{test_group_custom_config, test_n_member_group}, test_utils::{test_group_custom_config, test_n_member_group},
@@ -890,9 +976,16 @@ mod tests {
identity::test_utils::get_test_signing_identity, identity::test_utils::get_test_signing_identity,
identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential}, identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
key_package::test_utils::test_key_package_message, key_package::test_utils::test_key_package_message,
mls_rules::CommitOptions,
Client,
}; };
use crate::extension::RequiredCapabilitiesExt; #[cfg(feature = "by_ref_proposal")]
use crate::crypto::test_utils::test_cipher_suite_provider;
#[cfg(feature = "by_ref_proposal")]
use crate::extension::ExternalSendersExt;
#[cfg(feature = "by_ref_proposal")]
use crate::group::mls_rules::DefaultMlsRules;
#[cfg(feature = "psk")] #[cfg(feature = "psk")]
use crate::{ use crate::{
@@ -1219,19 +1312,11 @@ mod tests {
let (bob, bob_kp) = let (bob, bob_kp) =
test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await; test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
group group.propose_add(alice_kp.clone(), vec![]).await.unwrap();
.group
.propose_add(alice_kp.clone(), vec![])
.await
.unwrap();
group group.propose_add(bob_kp.clone(), vec![]).await.unwrap();
.group
.propose_add(bob_kp.clone(), vec![])
.await
.unwrap();
let output = group.group.commit(Vec::new()).await.unwrap(); let output = group.commit(Vec::new()).await.unwrap();
let welcomes = output.welcome_messages; let welcomes = output.welcome_messages;
let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
@@ -1257,7 +1342,6 @@ mod tests {
let (identity, secret_key) = get_test_signing_identity(cs, b"member").await; let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
let commit_output = groups[0] let commit_output = groups[0]
.group
.commit_builder() .commit_builder()
.set_new_signing_identity(secret_key, identity.clone()) .set_new_signing_identity(secret_key, identity.clone())
.build() .build()
@@ -1266,7 +1350,7 @@ mod tests {
// Check that the credential was updated by in the committer's state. // Check that the credential was updated by in the committer's state.
groups[0].process_pending_commit().await.unwrap(); groups[0].process_pending_commit().await.unwrap();
let new_member = groups[0].group.roster().member_with_index(0).unwrap(); let new_member = groups[0].roster().member_with_index(0).unwrap();
assert_eq!( assert_eq!(
new_member.signing_identity.credential, new_member.signing_identity.credential,
@@ -1284,7 +1368,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let new_member = groups[1].group.roster().member_with_index(0).unwrap(); let new_member = groups[1].roster().member_with_index(0).unwrap();
assert_eq!( assert_eq!(
new_member.signing_identity.credential, new_member.signing_identity.credential,
@@ -1306,8 +1390,7 @@ mod tests {
None, None,
Some(CommitOptions::new().with_ratchet_tree_extension(false)), Some(CommitOptions::new().with_ratchet_tree_extension(false)),
) )
.await .await;
.group;
let commit = group.commit(vec![]).await.unwrap(); let commit = group.commit(vec![]).await.unwrap();
@@ -1327,8 +1410,7 @@ mod tests {
None, None,
Some(CommitOptions::new().with_ratchet_tree_extension(true)), Some(CommitOptions::new().with_ratchet_tree_extension(true)),
) )
.await .await;
.group;
let commit = group.commit(vec![]).await.unwrap(); let commit = group.commit(vec![]).await.unwrap();
@@ -1348,8 +1430,7 @@ mod tests {
.with_ratchet_tree_extension(false), .with_ratchet_tree_extension(false),
), ),
) )
.await .await;
.group;
let commit = group.commit(vec![]).await.unwrap(); let commit = group.commit(vec![]).await.unwrap();
@@ -1376,8 +1457,7 @@ mod tests {
.with_ratchet_tree_extension(true), .with_ratchet_tree_extension(true),
), ),
) )
.await .await;
.group;
let commit = group.commit(vec![]).await.unwrap(); let commit = group.commit(vec![]).await.unwrap();
@@ -1400,8 +1480,7 @@ mod tests {
None, None,
Some(CommitOptions::new().with_allow_external_commit(false)), Some(CommitOptions::new().with_allow_external_commit(false)),
) )
.await .await;
.group;
let commit = group.commit(vec![]).await.unwrap(); let commit = group.commit(vec![]).await.unwrap();
@@ -1411,10 +1490,16 @@ mod tests {
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn member_identity_is_validated_against_new_extensions() { async fn member_identity_is_validated_against_new_extensions() {
let alice = client_with_test_extension(b"alice").await; let alice = client_with_test_extension(b"alice").await;
let mut alice = alice.create_group(ExtensionList::new()).await.unwrap(); let mut alice = alice
.create_group(ExtensionList::new(), Default::default())
.await
.unwrap();
let bob = client_with_test_extension(b"bob").await; let bob = client_with_test_extension(b"bob").await;
let bob_kp = bob.generate_key_package_message().await.unwrap(); let bob_kp = bob
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();
let mut extension_list = ExtensionList::new(); let mut extension_list = ExtensionList::new();
let extension = TestExtension { foo: b'a' }; let extension = TestExtension { foo: b'a' };
@@ -1435,7 +1520,11 @@ mod tests {
alice alice
.commit_builder() .commit_builder()
.add_member(alex.generate_key_package_message().await.unwrap()) .add_member(
alex.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap(),
)
.unwrap() .unwrap()
.set_group_context_ext(extension_list.clone()) .set_group_context_ext(extension_list.clone())
.unwrap() .unwrap()
@@ -1448,7 +1537,10 @@ mod tests {
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn server_identity_is_validated_against_new_extensions() { async fn server_identity_is_validated_against_new_extensions() {
let alice = client_with_test_extension(b"alice").await; let alice = client_with_test_extension(b"alice").await;
let mut alice = alice.create_group(ExtensionList::new()).await.unwrap(); let mut alice = alice
.create_group(ExtensionList::new(), Default::default())
.await
.unwrap();
let mut extension_list = ExtensionList::new(); let mut extension_list = ExtensionList::new();
let extension = TestExtension { foo: b'a' }; let extension = TestExtension { foo: b'a' };
@@ -1538,9 +1630,9 @@ mod tests {
&self, &self,
identity: &SigningIdentity, identity: &SigningIdentity,
timestamp: Option<MlsTime>, timestamp: Option<MlsTime>,
extensions: Option<&ExtensionList>, context: MemberValidationContext<'_>,
) -> Result<(), Self::Error> { ) -> Result<(), Self::Error> {
self.starts_with_foo(identity, timestamp, extensions) self.starts_with_foo(identity, timestamp, context.new_extensions())
.await .await
.then_some(()) .then_some(())
.ok_or(IdentityProviderWithExtensionError {}) .ok_or(IdentityProviderWithExtensionError {})
@@ -1598,4 +1690,14 @@ mod tests {
.signing_identity(identity, secret_key, TEST_CIPHER_SUITE) .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
.build() .build()
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn detached_commit() {
let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let (_commit, secrets) = group.commit_builder().build_detached().await.unwrap();
assert!(group.pending_commit.is_none());
group.apply_detached_commit(secrets).await.unwrap();
assert_eq!(group.context().epoch, 1);
}
} }

View File

@@ -0,0 +1,26 @@
use crate::client::MlsError;
use alloc::vec::Vec;
use mls_rs_codec::{MlsEncode, MlsSize};
pub type ComponentID = u32;
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
pub struct ComponentOperationLabel<'a> {
label: &'static [u8],
component_id: ComponentID,
context: &'a [u8],
}
impl<'a> ComponentOperationLabel<'a> {
pub fn new(component_id: u32, context: &'a [u8]) -> Self {
Self {
label: b"MLS 1.0 Application",
component_id,
context,
}
}
pub fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
self.mls_encode_to_vec().map_err(Into::into)
}
}

View File

@@ -3,7 +3,7 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use crate::CipherSuiteProvider; use crate::CipherSuiteProvider;
use crate::{client::MlsError, group::transcript_hash::ConfirmedTranscriptHash}; use crate::{client::MlsError, group::ConfirmedTranscriptHash};
use alloc::vec::Vec; use alloc::vec::Vec;
use core::{ use core::{
fmt::{self, Debug}, fmt::{self, Debug},
@@ -12,7 +12,7 @@ use core::{
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::error::IntoAnyError; use mls_rs_core::error::IntoAnyError;
#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)] #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode, Default)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConfirmationTag( pub struct ConfirmationTag(

View File

@@ -7,6 +7,8 @@ use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use crate::{client::MlsError, tree_kem::node::NodeVec}; use crate::{client::MlsError, tree_kem::node::NodeVec};
use super::Roster;
#[cfg_attr( #[cfg_attr(
all(feature = "ffi", not(test)), all(feature = "ffi", not(test)),
safer_ffi_gen::ffi_type(clone, opaque) safer_ffi_gen::ffi_type(clone, opaque)
@@ -35,6 +37,12 @@ impl<'a> ExportedTree<'a> {
pub fn into_owned(self) -> ExportedTree<'static> { pub fn into_owned(self) -> ExportedTree<'static> {
ExportedTree(Cow::Owned(self.0.into_owned())) ExportedTree(Cow::Owned(self.0.into_owned()))
} }
pub fn roster(&'a self) -> Roster<'a> {
Roster {
public_tree: &self.0,
}
}
} }
#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)] #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]

View File

@@ -2,7 +2,9 @@
// Copyright by contributors to this project. // Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT) // SPDX-License-Identifier: (Apache-2.0 OR MIT)
use mls_rs_core::{crypto::SignatureSecretKey, identity::SigningIdentity}; use mls_rs_core::{
crypto::SignatureSecretKey, extension::ExtensionList, identity::SigningIdentity,
};
use crate::{ use crate::{
client_config::ClientConfig, client_config::ClientConfig,
@@ -39,13 +41,14 @@ use crate::group::{
PreSharedKeyProposal, {JustPreSharedKeyID, PreSharedKeyID}, PreSharedKeyProposal, {JustPreSharedKeyID, PreSharedKeyID},
}; };
use super::{validate_group_info_joiner, ExportedTree}; use super::{validate_tree_and_info_joiner, ExportedTree};
/// A builder that aids with the construction of an external commit. /// A builder that aids with the construction of an external commit.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))] #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
pub struct ExternalCommitBuilder<C: ClientConfig> { pub struct ExternalCommitBuilder<C: ClientConfig> {
signer: SignatureSecretKey, signer: SignatureSecretKey,
signing_identity: SigningIdentity, signing_identity: SigningIdentity,
leaf_node_extensions: ExtensionList,
config: C, config: C,
tree_data: Option<ExportedTree<'static>>, tree_data: Option<ExportedTree<'static>>,
to_remove: Option<u32>, to_remove: Option<u32>,
@@ -70,6 +73,7 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
authenticated_data: Vec::new(), authenticated_data: Vec::new(),
signer, signer,
signing_identity, signing_identity,
leaf_node_extensions: Default::default(),
config, config,
#[cfg(feature = "psk")] #[cfg(feature = "psk")]
external_psks: Vec::new(), external_psks: Vec::new(),
@@ -140,6 +144,14 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
self self
} }
/// Change the committer's leaf node extensions as part of making this commit.
pub fn with_leaf_node_extensions(self, leaf_node_extensions: ExtensionList) -> Self {
Self {
leaf_node_extensions,
..self
}
}
/// Build the external commit using a GroupInfo message provided by an existing group member. /// Build the external commit using a GroupInfo message provided by an existing group member.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn build(self, group_info: MlsMessage) -> Result<(Group<C>, MlsMessage), MlsError> { pub async fn build(self, group_info: MlsMessage) -> Result<(Group<C>, MlsMessage), MlsError> {
@@ -163,7 +175,7 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
.get_as::<ExternalPubExt>()? .get_as::<ExternalPubExt>()?
.ok_or(MlsError::MissingExternalPubExtension)?; .ok_or(MlsError::MissingExternalPubExtension)?;
let public_tree = validate_group_info_joiner( let public_tree = validate_tree_and_info_joiner(
protocol_version, protocol_version,
&group_info, &group_info,
self.tree_data, self.tree_data,
@@ -174,7 +186,7 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
let (leaf_node, _) = LeafNode::generate( let (leaf_node, _) = LeafNode::generate(
&cipher_suite, &cipher_suite,
self.config.leaf_properties(), self.config.leaf_properties(self.leaf_node_extensions),
self.signing_identity, self.signing_identity,
&self.signer, &self.signer,
self.config.lifetime(), self.config.lifetime(),
@@ -233,9 +245,7 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
}; };
let auth_content = AuthenticatedContent::from(plaintext.clone()); let auth_content = AuthenticatedContent::from(plaintext.clone());
verify_plaintext_authentication(&cipher_suite, plaintext, None, &group.state).await?;
verify_plaintext_authentication(&cipher_suite, plaintext, None, None, &group.state)
.await?;
group group
.process_event_or_content(EventOrContent::Content(auth_content), true, None) .process_event_or_content(EventOrContent::Content(auth_content), true, None)
@@ -248,7 +258,7 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
})); }));
} }
let commit_output = group let (commit_output, pending_commit) = group
.commit_internal( .commit_internal(
proposals, proposals,
Some(&leaf_node), Some(&leaf_node),
@@ -256,9 +266,11 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
Default::default(), Default::default(),
None, None,
None, None,
None,
) )
.await?; .await?;
group.pending_commit = pending_commit.try_into()?;
group.apply_pending_commit().await?; group.apply_pending_commit().await?;
Ok((group, commit_output.commit_message)) Ok((group, commit_output.commit_message))

View File

@@ -410,6 +410,13 @@ impl MlsMessage {
} }
} }
pub fn as_key_package(&self) -> Option<&KeyPackage> {
match &self.payload {
MlsMessagePayload::KeyPackage(kp) => Some(kp),
_ => None,
}
}
/// The wire format value describing the contents of this message. /// The wire format value describing the contents of this message.
pub fn wire_format(&self) -> WireFormat { pub fn wire_format(&self) -> WireFormat {
match self.payload { match self.payload {
@@ -505,7 +512,7 @@ impl MlsMessage {
} }
/// If this is a plaintext proposal, return the proposal reference that can be matched e.g. with /// If this is a plaintext proposal, return the proposal reference that can be matched e.g. with
/// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals). /// [`NewEpoch::unused_proposals`](super::NewEpoch::unused_proposals).
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn into_proposal_reference<C: CipherSuiteProvider>( pub async fn into_proposal_reference<C: CipherSuiteProvider>(

View File

@@ -41,7 +41,6 @@ impl Debug for GroupInfo {
} }
} }
// #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
impl GroupInfo { impl GroupInfo {
/// Group context. /// Group context.
pub fn group_context(&self) -> &GroupContext { pub fn group_context(&self) -> &GroupContext {
@@ -68,7 +67,7 @@ struct SignableGroupInfo<'a> {
signer: LeafIndex, signer: LeafIndex,
} }
impl<'a> Signable<'a> for GroupInfo { impl Signable<'_> for GroupInfo {
const SIGN_LABEL: &'static str = "GroupInfoTBS"; const SIGN_LABEL: &'static str = "GroupInfoTBS";
type SigningContext = (); type SigningContext = ();

View File

@@ -250,7 +250,10 @@ async fn invite_passive_client<P: CipherSuiteProvider>(
.signing_identity(identity.clone(), secret_key.clone(), cs.cipher_suite()) .signing_identity(identity.clone(), secret_key.clone(), cs.cipher_suite())
.build(); .build();
let key_pckg = client.generate_key_package_message().await.unwrap(); let key_pckg = client
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();
let (_, key_pckg_secrets) = key_package_repo.key_packages()[0].clone(); let (_, key_pckg_secrets) = key_package_repo.key_packages()[0].clone();
@@ -489,7 +492,10 @@ async fn create_key_package(cs: CipherSuite) -> MlsMessage {
) )
.await; .await;
client.generate_key_package_message().await.unwrap() client
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap()
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
@@ -554,7 +560,10 @@ pub async fn generate_passive_client_random_tests() -> Vec<TestCase> {
generate_basic_client(cs, VERSION, 0, None, false, &crypto, Some(ETERNAL_LIFETIME)) generate_basic_client(cs, VERSION, 0, None, false, &crypto, Some(ETERNAL_LIFETIME))
.await; .await;
let creator_group = creator.create_group(Default::default()).await.unwrap(); let creator_group = creator
.create_group(Default::default(), Default::default())
.await
.unwrap();
let mut groups = vec![creator_group]; let mut groups = vec![creator_group];
@@ -646,7 +655,10 @@ pub async fn add_random_members<C: MlsConfig>(
let mut key_packages = Vec::new(); let mut key_packages = Vec::new();
for client in &clients { for client in &clients {
let key_package = client.generate_key_package_message().await.unwrap(); let key_package = client
.generate_key_package_message(Default::default(), Default::default())
.await
.unwrap();
key_packages.push(key_package); key_packages.push(key_package);
} }

View File

@@ -166,7 +166,7 @@ async fn generate_update(i: u32, tree: &TreeWithSigners) -> Proposal {
&test_cipher_suite_provider(TEST_CIPHER_SUITE), &test_cipher_suite_provider(TEST_CIPHER_SUITE),
TEST_GROUP, TEST_GROUP,
i, i,
default_properties(), Some(default_properties()),
None, None,
signer, signer,
) )

View File

@@ -83,6 +83,10 @@ impl KeySchedule {
} }
} }
pub fn delete_exporter(&mut self) {
self.exporter_secret = Default::default();
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn derive_for_external<P: CipherSuiteProvider>( pub async fn derive_for_external<P: CipherSuiteProvider>(
&self, &self,
@@ -234,6 +238,10 @@ impl KeySchedule {
len: usize, len: usize,
cipher_suite: &P, cipher_suite: &P,
) -> Result<Zeroizing<Vec<u8>>, MlsError> { ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
if self.exporter_secret.is_empty() {
return Err(MlsError::ExporterDeleted);
}
let secret = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?; let secret = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?;
let context_hash = cipher_suite let context_hash = cipher_suite

View File

@@ -0,0 +1,37 @@
use alloc::vec::Vec;
use core::fmt;
use core::fmt::Debug;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::crypto::CipherSuiteProvider;
use crate::{client::MlsError, error::IntoAnyError, MlsMessage};
#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct MessageHash(
#[mls_codec(with = "mls_rs_codec::byte_vec")]
#[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
Vec<u8>,
);
impl Debug for MessageHash {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
mls_rs_core::debug::pretty_bytes(&self.0)
.named("MessageHash")
.fmt(f)
}
}
impl MessageHash {
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn compute<CS: CipherSuiteProvider>(
cs: &CS,
message: &MlsMessage,
) -> Result<Self, MlsError> {
cs.hash(&message.mls_encode_to_vec()?)
.await
.map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
.map(Self)
}
}

View File

@@ -13,7 +13,8 @@ use super::{
proposal_filter::ProposalBundle, proposal_filter::ProposalBundle,
state::GroupState, state::GroupState,
transcript_hash::InterimTranscriptHash, transcript_hash::InterimTranscriptHash,
transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, Welcome, transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, ReInitProposal,
RemoveProposal, Welcome,
}; };
use crate::{ use crate::{
client::MlsError, client::MlsError,
@@ -27,12 +28,16 @@ use crate::{
}, },
CipherSuiteProvider, KeyPackage, CipherSuiteProvider, KeyPackage,
}; };
#[cfg(mls_build_async)] use itertools::Itertools;
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use alloc::boxed::Box; use alloc::boxed::Box;
use alloc::vec::Vec; use alloc::vec::Vec;
use core::fmt::{self, Debug}; use core::fmt::{self, Debug};
use mls_rs_core::{ use mls_rs_core::{
identity::IdentityProvider, protocol_version::ProtocolVersion, psk::PreSharedKeyStorage, identity::{IdentityProvider, MemberValidationContext},
protocol_version::ProtocolVersion,
psk::PreSharedKeyStorage,
}; };
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
@@ -41,36 +46,12 @@ use super::proposal_ref::ProposalRef;
#[cfg(not(feature = "by_ref_proposal"))] #[cfg(not(feature = "by_ref_proposal"))]
use crate::group::proposal_cache::resolve_for_commit; use crate::group::proposal_cache::resolve_for_commit;
#[cfg(feature = "by_ref_proposal")]
use super::proposal::Proposal; use super::proposal::Proposal;
#[cfg(feature = "custom_proposal")]
use super::proposal_filter::ProposalInfo; use super::proposal_filter::ProposalInfo;
#[cfg(feature = "state_update")]
use mls_rs_core::{
crypto::CipherSuite,
group::{MemberUpdate, RosterUpdate},
};
#[cfg(all(feature = "state_update", feature = "psk"))]
use mls_rs_core::psk::ExternalPskId;
#[cfg(feature = "state_update")]
use crate::tree_kem::UpdatePath;
#[cfg(feature = "state_update")]
use super::{member_from_key_package, member_from_leaf_node};
#[cfg(all(feature = "state_update", feature = "custom_proposal"))]
use super::proposal::CustomProposal;
#[cfg(feature = "private_message")] #[cfg(feature = "private_message")]
use crate::group::framing::PrivateMessage; use crate::group::framing::PrivateMessage;
#[cfg(feature = "by_ref_proposal")]
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct ProvisionalState { pub(crate) struct ProvisionalState {
pub(crate) public_tree: TreeKemPublic, pub(crate) public_tree: TreeKemPublic,
@@ -78,8 +59,7 @@ pub(crate) struct ProvisionalState {
pub(crate) group_context: GroupContext, pub(crate) group_context: GroupContext,
pub(crate) external_init_index: Option<LeafIndex>, pub(crate) external_init_index: Option<LeafIndex>,
pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>, pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
#[cfg(feature = "by_ref_proposal")] pub(crate) unused_proposals: Vec<ProposalInfo<Proposal>>,
pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
} }
//By default, the path field of a Commit MUST be populated. The path field MAY be omitted if //By default, the path field of a Commit MUST be populated. The path field MAY be omitted if
@@ -93,7 +73,7 @@ pub(crate) struct ProvisionalState {
// psk // psk
// reinit // reinit
pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool { pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
let res = proposals.external_init_proposals().first().is_some(); let res = !proposals.external_init_proposals().is_empty();
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
let res = res || !proposals.update_proposals().is_empty(); let res = res || !proposals.update_proposals().is_empty();
@@ -103,72 +83,114 @@ pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
|| !proposals.remove_proposals().is_empty() || !proposals.remove_proposals().is_empty()
} }
/// Representation of changes made by a [commit](crate::Group::commit). #[cfg_attr(
#[cfg(feature = "state_update")] all(feature = "ffi", not(test)),
#[derive(Clone, Debug, PartialEq)] safer_ffi_gen::ffi_type(clone, opaque)
pub struct StateUpdate { )]
pub(crate) roster_update: RosterUpdate, #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg(feature = "psk")] #[non_exhaustive]
pub(crate) added_psks: Vec<ExternalPskId>, pub struct NewEpoch {
pub(crate) pending_reinit: Option<CipherSuite>, pub epoch: u64,
pub(crate) active: bool, pub prior_state: GroupState,
pub(crate) epoch: u64, pub applied_proposals: Vec<ProposalInfo<Proposal>>,
#[cfg(feature = "custom_proposal")] pub unused_proposals: Vec<ProposalInfo<Proposal>>,
pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
#[cfg(feature = "by_ref_proposal")]
pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
} }
#[cfg(not(feature = "state_update"))] impl NewEpoch {
#[non_exhaustive] pub(crate) fn new(prior_state: GroupState, provisional_state: &ProvisionalState) -> NewEpoch {
#[derive(Clone, Debug, PartialEq)] NewEpoch {
pub struct StateUpdate {} epoch: provisional_state.group_context.epoch,
prior_state,
#[cfg(feature = "state_update")] unused_proposals: provisional_state.unused_proposals.clone(),
impl StateUpdate { applied_proposals: provisional_state
/// Changes to the roster as a result of proposals. .applied_proposals
pub fn roster_update(&self) -> &RosterUpdate { .clone()
&self.roster_update .into_proposals()
.collect_vec(),
}
} }
}
#[cfg(feature = "psk")] #[cfg(all(feature = "ffi", not(test)))]
/// Pre-shared keys that have been added to the group. #[safer_ffi_gen::safer_ffi_gen]
pub fn added_psks(&self) -> &[ExternalPskId] { impl NewEpoch {
&self.added_psks pub fn epoch(&self) -> u64 {
}
/// Flag to indicate if the group is now pending reinitialization due to
/// receiving a [`ReInit`](crate::group::proposal::Proposal::ReInit)
/// proposal.
pub fn is_pending_reinit(&self) -> bool {
self.pending_reinit.is_some()
}
/// Flag to indicate the group is still active. This will be false if the
/// member processing the commit has been removed from the group.
pub fn is_active(&self) -> bool {
self.active
}
/// The new epoch of the group state.
pub fn new_epoch(&self) -> u64 {
self.epoch self.epoch
} }
/// Custom proposals that were committed to. pub fn prior_state(&self) -> &GroupState {
#[cfg(feature = "custom_proposal")] &self.prior_state
pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
&self.custom_proposals
} }
/// Proposals that were received in the prior epoch but not committed to. pub fn applied_proposals(&self) -> &[ProposalInfo<Proposal>] {
#[cfg(feature = "by_ref_proposal")] &self.applied_proposals
pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] { }
pub fn unused_proposals(&self) -> &[ProposalInfo<Proposal>] {
&self.unused_proposals &self.unused_proposals
} }
}
pub fn pending_reinit_ciphersuite(&self) -> Option<CipherSuite> { #[cfg_attr(
self.pending_reinit all(feature = "ffi", not(test)),
safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug, PartialEq)]
pub enum CommitEffect {
NewEpoch(Box<NewEpoch>),
Removed {
new_epoch: Box<NewEpoch>,
remover: Sender,
},
ReInit(ProposalInfo<ReInitProposal>),
}
impl MlsSize for CommitEffect {
fn mls_encoded_len(&self) -> usize {
0u8.mls_encoded_len()
+ match self {
Self::NewEpoch(e) => e.mls_encoded_len(),
Self::Removed { new_epoch, remover } => {
new_epoch.mls_encoded_len() + remover.mls_encoded_len()
}
Self::ReInit(r) => r.mls_encoded_len(),
}
}
}
impl MlsEncode for CommitEffect {
fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
match self {
Self::NewEpoch(e) => {
1u8.mls_encode(writer)?;
e.mls_encode(writer)?;
}
Self::Removed { new_epoch, remover } => {
2u8.mls_encode(writer)?;
new_epoch.mls_encode(writer)?;
remover.mls_encode(writer)?;
}
Self::ReInit(r) => {
3u8.mls_encode(writer)?;
r.mls_encode(writer)?;
}
}
Ok(())
}
}
impl MlsDecode for CommitEffect {
fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
match u8::mls_decode(reader)? {
1u8 => Ok(Self::NewEpoch(NewEpoch::mls_decode(reader)?.into())),
2u8 => Ok(Self::Removed {
new_epoch: NewEpoch::mls_decode(reader)?.into(),
remover: Sender::mls_decode(reader)?,
}),
3u8 => Ok(Self::ReInit(ProposalInfo::mls_decode(reader)?)),
_ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
}
} }
} }
@@ -272,7 +294,7 @@ impl ApplicationMessageDescription {
// all(feature = "ffi", not(test)), // all(feature = "ffi", not(test)),
// safer_ffi_gen::ffi_type(clone, opaque) // safer_ffi_gen::ffi_type(clone, opaque)
// )] // )]
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive] #[non_exhaustive]
/// Description of a processed MLS commit message. /// Description of a processed MLS commit message.
pub struct CommitMessageDescription { pub struct CommitMessageDescription {
@@ -281,8 +303,9 @@ pub struct CommitMessageDescription {
/// The index in the group state of the member who performed this commit. /// The index in the group state of the member who performed this commit.
pub committer: u32, pub committer: u32,
/// A full description of group state changes as a result of this commit. /// A full description of group state changes as a result of this commit.
pub state_update: StateUpdate, pub effect: CommitEffect,
/// Plaintext authenticated data in the received MLS packet. /// Plaintext authenticated data in the received MLS packet.
#[mls_codec(with = "mls_rs_codec::byte_vec")]
pub authenticated_data: Vec<u8>, pub authenticated_data: Vec<u8>,
} }
@@ -291,7 +314,7 @@ impl Debug for CommitMessageDescription {
f.debug_struct("CommitMessageDescription") f.debug_struct("CommitMessageDescription")
.field("is_external", &self.is_external) .field("is_external", &self.is_external)
.field("committer", &self.committer) .field("committer", &self.committer)
.field("state_update", &self.state_update) .field("effect", &self.effect)
.field( .field(
"authenticated_data", "authenticated_data",
&mls_rs_core::debug::pretty_bytes(&self.authenticated_data), &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
@@ -300,16 +323,18 @@ impl Debug for CommitMessageDescription {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
/// Proposal sender type. /// Proposal sender type.
pub enum ProposalSender { pub enum ProposalSender {
/// A current member of the group by index in the group state. /// A current member of the group by index in the group state.
Member(u32), Member(u32) = 1u8,
/// An external entity by index within an /// An external entity by index within an
/// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt). /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
External(u32), External(u32) = 2u8,
/// A new member proposing their addition to the group. /// A new member proposing their addition to the group.
NewMember, NewMember = 3u8,
} }
impl TryFrom<Sender> for ProposalSender { impl TryFrom<Sender> for ProposalSender {
@@ -332,7 +357,8 @@ impl TryFrom<Sender> for ProposalSender {
// all(feature = "ffi", not(test)), // all(feature = "ffi", not(test)),
// safer_ffi_gen::ffi_type(clone, opaque) // safer_ffi_gen::ffi_type(clone, opaque)
// )] // )]
#[derive(Clone)] #[derive(Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive] #[non_exhaustive]
/// Description of a processed MLS proposal message. /// Description of a processed MLS proposal message.
pub struct ProposalMessageDescription { pub struct ProposalMessageDescription {
@@ -401,6 +427,20 @@ impl ProposalMessageDescription {
pub fn proposal_ref(&self) -> Vec<u8> { pub fn proposal_ref(&self) -> Vec<u8> {
self.proposal_ref.to_vec() self.proposal_ref.to_vec()
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn new<C: CipherSuiteProvider>(
cs: &C,
content: &AuthenticatedContent,
proposal: Proposal,
) -> Result<Self, MlsError> {
Ok(ProposalMessageDescription {
authenticated_data: content.content.authenticated_data.clone(),
proposal,
sender: content.content.sender.try_into()?,
proposal_ref: ProposalRef::from_content(cs, content).await?,
})
}
} }
#[cfg(not(feature = "by_ref_proposal"))] #[cfg(not(feature = "by_ref_proposal"))]
@@ -587,127 +627,24 @@ pub(crate) trait MessageProcessor: Send + Sync {
proposal: &Proposal, proposal: &Proposal,
cache_proposal: bool, cache_proposal: bool,
) -> Result<ProposalMessageDescription, MlsError> { ) -> Result<ProposalMessageDescription, MlsError> {
let proposal_ref = let proposal = ProposalMessageDescription::new(
ProposalRef::from_content(self.cipher_suite_provider(), auth_content).await?; self.cipher_suite_provider(),
auth_content,
proposal.clone(),
)
.await?;
let group_state = self.group_state_mut(); let group_state = self.group_state_mut();
if cache_proposal { if cache_proposal {
let proposal_ref = proposal_ref.clone();
group_state.proposals.insert( group_state.proposals.insert(
proposal_ref.clone(), proposal.proposal_ref.clone(),
proposal.clone(), proposal.proposal.clone(),
auth_content.content.sender, auth_content.content.sender,
); );
} }
Ok(ProposalMessageDescription { Ok(proposal)
authenticated_data: auth_content.content.authenticated_data.clone(),
proposal: proposal.clone(),
sender: auth_content.content.sender.try_into()?,
proposal_ref,
})
}
#[cfg(feature = "state_update")]
async fn make_state_update(
&self,
provisional: &ProvisionalState,
path: Option<&UpdatePath>,
sender: LeafIndex,
) -> Result<StateUpdate, MlsError> {
let added = provisional
.applied_proposals
.additions
.iter()
.zip(provisional.indexes_of_added_kpkgs.iter())
.map(|(p, index)| member_from_key_package(&p.proposal.key_package, *index))
.collect::<Vec<_>>();
let mut added = added;
let old_tree = &self.group_state().public_tree;
let removed = provisional
.applied_proposals
.removals
.iter()
.map(|p| {
let index = p.proposal.to_remove;
let node = old_tree.nodes.borrow_as_leaf(index)?;
Ok(member_from_leaf_node(node, index))
})
.collect::<Result<_, MlsError>>()?;
#[cfg(feature = "by_ref_proposal")]
let mut updated = provisional
.applied_proposals
.update_senders
.iter()
.map(|index| {
let prior = old_tree
.get_leaf_node(*index)
.map(|n| member_from_leaf_node(n, *index))?;
let new = provisional
.public_tree
.get_leaf_node(*index)
.map(|n| member_from_leaf_node(n, *index))?;
Ok::<_, MlsError>(MemberUpdate::new(prior, new))
})
.collect::<Result<Vec<_>, _>>()?;
#[cfg(not(feature = "by_ref_proposal"))]
let mut updated = Vec::new();
if let Some(path) = path {
if !provisional
.applied_proposals
.external_initializations
.is_empty()
{
added.push(member_from_leaf_node(&path.leaf_node, sender))
} else {
let prior = old_tree
.get_leaf_node(sender)
.map(|n| member_from_leaf_node(n, sender))?;
let new = member_from_leaf_node(&path.leaf_node, sender);
updated.push(MemberUpdate::new(prior, new))
}
}
#[cfg(feature = "psk")]
let psks = provisional
.applied_proposals
.psks
.iter()
.filter_map(|psk| psk.proposal.external_psk_id().cloned())
.collect::<Vec<_>>();
let roster_update = RosterUpdate::new(added, removed, updated);
let update = StateUpdate {
roster_update,
#[cfg(feature = "psk")]
added_psks: psks,
pending_reinit: provisional
.applied_proposals
.reinitializations
.first()
.map(|ri| ri.proposal.new_cipher_suite()),
active: true,
epoch: provisional.group_context.epoch,
#[cfg(feature = "custom_proposal")]
custom_proposals: provisional.applied_proposals.custom_proposals.clone(),
#[cfg(feature = "by_ref_proposal")]
unused_proposals: provisional.unused_proposals.clone(),
};
Ok(update)
} }
async fn process_commit( async fn process_commit(
@@ -763,33 +700,14 @@ pub(crate) trait MessageProcessor: Send + Sync {
let sender = commit_sender(&auth_content.content.sender, &provisional_state)?; let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
#[cfg(feature = "state_update")]
let mut state_update = self
.make_state_update(&provisional_state, commit.path.as_ref(), sender)
.await?;
#[cfg(not(feature = "state_update"))]
let state_update = StateUpdate {};
//Verify that the path value is populated if the proposals vector contains any Update //Verify that the path value is populated if the proposals vector contains any Update
// or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted. // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() { if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
return Err(MlsError::CommitMissingPath); return Err(MlsError::CommitMissingPath);
} }
if !self.can_continue_processing(&provisional_state) { let self_removed = self.removal_proposal(&provisional_state);
#[cfg(feature = "state_update")] let is_self_removed = self_removed.is_some();
{
state_update.active = false;
}
return Ok(CommitMessageDescription {
is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
authenticated_data: auth_content.content.authenticated_data,
committer: *sender,
state_update,
});
}
let update_path = match commit.path { let update_path = match commit.path {
Some(update_path) => Some( Some(update_path) => Some(
@@ -800,18 +718,36 @@ pub(crate) trait MessageProcessor: Send + Sync {
&provisional_state, &provisional_state,
sender, sender,
time_sent, time_sent,
&group_state.context,
) )
.await?, .await?,
), ),
None => None, None => None,
}; };
let commit_effect =
if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
self.group_state_mut().pending_reinit = Some(reinit.proposal.clone());
CommitEffect::ReInit(reinit)
} else if let Some(remove_proposal) = self_removed {
let new_epoch = NewEpoch::new(self.group_state().clone(), &provisional_state);
CommitEffect::Removed {
remover: remove_proposal.sender,
new_epoch: Box::new(new_epoch),
}
} else {
CommitEffect::NewEpoch(Box::new(NewEpoch::new(
self.group_state().clone(),
&provisional_state,
)))
};
let new_secrets = match update_path { let new_secrets = match update_path {
Some(update_path) => { Some(update_path) if !is_self_removed => {
self.apply_update_path(sender, &update_path, &mut provisional_state) self.apply_update_path(sender, &update_path, &mut provisional_state)
.await .await
} }
None => Ok(None), _ => Ok(None),
}?; }?;
// Update the transcript hash to get the new context. // Update the transcript hash to get the new context.
@@ -829,30 +765,23 @@ pub(crate) trait MessageProcessor: Send + Sync {
.tree_hash(self.cipher_suite_provider()) .tree_hash(self.cipher_suite_provider())
.await?; .await?;
if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
self.group_state_mut().pending_reinit = Some(reinit.proposal);
#[cfg(feature = "state_update")]
{
state_update.active = false;
}
}
if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag { if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
// Update the key schedule to calculate new private keys if !is_self_removed {
self.update_key_schedule( // Update the key schedule to calculate new private keys
new_secrets, self.update_key_schedule(
interim_transcript_hash, new_secrets,
confirmation_tag, interim_transcript_hash,
provisional_state, confirmation_tag,
) provisional_state,
.await?; )
.await?;
}
Ok(CommitMessageDescription { Ok(CommitMessageDescription {
is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit), is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
authenticated_data: auth_content.content.authenticated_data, authenticated_data: auth_content.content.authenticated_data,
committer: *sender, committer: *sender,
state_update, effect: commit_effect,
}) })
} else { } else {
Err(MlsError::InvalidConfirmationTag) Err(MlsError::InvalidConfirmationTag)
@@ -865,7 +794,11 @@ pub(crate) trait MessageProcessor: Send + Sync {
fn identity_provider(&self) -> Self::IdentityProvider; fn identity_provider(&self) -> Self::IdentityProvider;
fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider; fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
fn psk_storage(&self) -> Self::PreSharedKeyStorage; fn psk_storage(&self) -> Self::PreSharedKeyStorage;
fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool;
fn removal_proposal(
&self,
provisional_state: &ProvisionalState,
) -> Option<ProposalInfo<RemoveProposal>>;
#[cfg(feature = "private_message")] #[cfg(feature = "private_message")]
fn min_epoch_available(&self) -> Option<u64>; fn min_epoch_available(&self) -> Option<u64>;
@@ -1017,7 +950,7 @@ pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProv
cs: &C, cs: &C,
id: &I, id: &I,
) -> Result<(), MlsError> { ) -> Result<(), MlsError> {
let validator = LeafNodeValidator::new(cs, id, None); let validator = LeafNodeValidator::new(cs, id, MemberValidationContext::None);
#[cfg(feature = "std")] #[cfg(feature = "std")]
let context = Some(MlsTime::now()); let context = Some(MlsTime::now());
@@ -1035,3 +968,49 @@ pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProv
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use alloc::{vec, vec::Vec};
use mls_rs_codec::{MlsDecode, MlsEncode};
use crate::{
client::test_utils::TEST_PROTOCOL_VERSION,
group::{test_utils::get_test_group_context, GroupState, Sender},
};
use super::{CommitEffect, NewEpoch};
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn commit_effect_codec() {
let epoch = NewEpoch {
epoch: 7,
prior_state: GroupState {
#[cfg(feature = "by_ref_proposal")]
proposals: crate::group::ProposalCache::new(TEST_PROTOCOL_VERSION, vec![]),
context: get_test_group_context(7, 7.into()).await,
public_tree: Default::default(),
interim_transcript_hash: vec![].into(),
pending_reinit: None,
confirmation_tag: Default::default(),
},
applied_proposals: vec![],
unused_proposals: vec![],
};
let effects = vec![
CommitEffect::NewEpoch(epoch.clone().into()),
CommitEffect::Removed {
new_epoch: epoch.into(),
remover: Sender::Member(0),
},
];
let bytes = effects.mls_encode_to_vec().unwrap();
assert_eq!(
effects,
Vec::<CommitEffect>::mls_decode(&mut &*bytes).unwrap()
);
}
}

View File

@@ -159,7 +159,7 @@ pub(crate) struct AuthenticatedContentTBS<'a> {
pub(crate) context: Option<&'a GroupContext>, pub(crate) context: Option<&'a GroupContext>,
} }
impl<'a> MlsSize for AuthenticatedContentTBS<'a> { impl MlsSize for AuthenticatedContentTBS<'_> {
fn mls_encoded_len(&self) -> usize { fn mls_encoded_len(&self) -> usize {
self.protocol_version.mls_encoded_len() self.protocol_version.mls_encoded_len()
+ self.wire_format.mls_encoded_len() + self.wire_format.mls_encoded_len()
@@ -168,7 +168,7 @@ impl<'a> MlsSize for AuthenticatedContentTBS<'a> {
} }
} }
impl<'a> MlsEncode for AuthenticatedContentTBS<'a> { impl MlsEncode for AuthenticatedContentTBS<'_> {
fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
self.protocol_version.mls_encode(writer)?; self.protocol_version.mls_encode(writer)?;
self.wire_format.mls_encode(writer)?; self.wire_format.mls_encode(writer)?;

View File

@@ -38,7 +38,6 @@ pub(crate) async fn verify_plaintext_authentication<P: CipherSuiteProvider>(
cipher_suite_provider: &P, cipher_suite_provider: &P,
plaintext: PublicMessage, plaintext: PublicMessage,
key_schedule: Option<&KeySchedule>, key_schedule: Option<&KeySchedule>,
self_index: Option<LeafIndex>,
state: &GroupState, state: &GroupState,
) -> Result<AuthenticatedContent, MlsError> { ) -> Result<AuthenticatedContent, MlsError> {
let tag = plaintext.membership_tag.clone(); let tag = plaintext.membership_tag.clone();
@@ -52,7 +51,7 @@ pub(crate) async fn verify_plaintext_authentication<P: CipherSuiteProvider>(
// Verify the membership tag if needed // Verify the membership tag if needed
match &auth_content.content.sender { match &auth_content.content.sender {
Sender::Member(index) => { Sender::Member(_) => {
if let Some(key_schedule) = key_schedule { if let Some(key_schedule) = key_schedule {
let expected_tag = &key_schedule let expected_tag = &key_schedule
.get_membership_tag(&auth_content, context, cipher_suite_provider) .get_membership_tag(&auth_content, context, cipher_suite_provider)
@@ -64,10 +63,6 @@ pub(crate) async fn verify_plaintext_authentication<P: CipherSuiteProvider>(
return Err(MlsError::InvalidMembershipTag); return Err(MlsError::InvalidMembershipTag);
} }
} }
if self_index == Some(LeafIndex(*index)) {
return Err(MlsError::CantProcessMessageFromSelf);
}
} }
_ => { _ => {
tag.is_none() tag.is_none()
@@ -218,6 +213,7 @@ fn signing_identity_for_new_member_proposal(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use crate::{ use crate::{
client::{ client::{
test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
@@ -231,7 +227,6 @@ mod tests {
test_utils::{test_group_custom, TestGroup}, test_utils::{test_group_custom, TestGroup},
Group, PublicMessage, Group, PublicMessage,
}, },
tree_kem::node::LeafIndex,
}; };
use alloc::vec; use alloc::vec;
use assert_matches::assert_matches; use assert_matches::assert_matches;
@@ -255,6 +250,7 @@ mod tests {
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
use alloc::boxed::Box; use alloc::boxed::Box;
#[cfg(feature = "by_ref_proposal")]
use crate::group::{ use crate::group::{
test_utils::{test_group, test_member}, test_utils::{test_group, test_member},
Sender, Sender,
@@ -263,8 +259,6 @@ mod tests {
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
use crate::identity::test_utils::get_test_signing_identity; use crate::identity::test_utils::get_test_signing_identity;
use super::{verify_auth_content_signature, verify_plaintext_authentication};
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn make_signed_plaintext(group: &mut Group<TestClientConfig>) -> PublicMessage { async fn make_signed_plaintext(group: &mut Group<TestClientConfig>) -> PublicMessage {
group group
@@ -297,7 +291,6 @@ mod tests {
test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await; test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
let commit_output = alice let commit_output = alice
.group
.commit_builder() .commit_builder()
.add_member(bob_key_pkg) .add_member(bob_key_pkg)
.unwrap() .unwrap()
@@ -305,7 +298,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
alice.group.apply_pending_commit().await.unwrap(); alice.apply_pending_commit().await.unwrap();
let (bob, _) = Group::join( let (bob, _) = Group::join(
&commit_output.welcome_messages[0], &commit_output.welcome_messages[0],
@@ -327,14 +320,13 @@ mod tests {
async fn valid_plaintext_is_verified() { async fn valid_plaintext_is_verified() {
let mut env = TestEnv::new().await; let mut env = TestEnv::new().await;
let message = make_signed_plaintext(&mut env.alice.group).await; let message = make_signed_plaintext(&mut env.alice).await;
verify_plaintext_authentication( verify_plaintext_authentication(
&env.bob.group.cipher_suite_provider, &env.bob.cipher_suite_provider,
message, message,
Some(&env.bob.group.key_schedule), Some(&env.bob.key_schedule),
None, &env.bob.state,
&env.bob.group.state,
) )
.await .await
.unwrap(); .unwrap();
@@ -344,12 +336,12 @@ mod tests {
async fn valid_auth_content_is_verified() { async fn valid_auth_content_is_verified() {
let mut env = TestEnv::new().await; let mut env = TestEnv::new().await;
let message = AuthenticatedContent::from(make_signed_plaintext(&mut env.alice.group).await); let message = AuthenticatedContent::from(make_signed_plaintext(&mut env.alice).await);
verify_auth_content_signature( verify_auth_content_signature(
&env.bob.group.cipher_suite_provider, &env.bob.cipher_suite_provider,
super::SignaturePublicKeysContainer::RatchetTree(&env.bob.group.state.public_tree), super::SignaturePublicKeysContainer::RatchetTree(&env.bob.state.public_tree),
env.bob.group.context(), env.bob.context(),
&message, &message,
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
&[], &[],
@@ -361,28 +353,26 @@ mod tests {
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn invalid_plaintext_is_not_verified() { async fn invalid_plaintext_is_not_verified() {
let mut env = TestEnv::new().await; let mut env = TestEnv::new().await;
let mut message = make_signed_plaintext(&mut env.alice.group).await; let mut message = make_signed_plaintext(&mut env.alice).await;
message.auth.signature = MessageSignature::from(b"test".to_vec()); message.auth.signature = MessageSignature::from(b"test".to_vec());
message.membership_tag = env message.membership_tag = env
.alice .alice
.group
.key_schedule .key_schedule
.get_membership_tag( .get_membership_tag(
&AuthenticatedContent::from(message.clone()), &AuthenticatedContent::from(message.clone()),
env.alice.group.context(), env.alice.context(),
&test_cipher_suite_provider(env.alice.group.cipher_suite()), &test_cipher_suite_provider(env.alice.cipher_suite()),
) )
.await .await
.unwrap() .unwrap()
.into(); .into();
let res = verify_plaintext_authentication( let res = verify_plaintext_authentication(
&env.bob.group.cipher_suite_provider, &env.bob.cipher_suite_provider,
message, message,
Some(&env.bob.group.key_schedule), Some(&env.bob.key_schedule),
None, &env.bob.state,
&env.bob.group.state,
) )
.await; .await;
@@ -392,15 +382,14 @@ mod tests {
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn plaintext_from_member_requires_membership_tag() { async fn plaintext_from_member_requires_membership_tag() {
let mut env = TestEnv::new().await; let mut env = TestEnv::new().await;
let mut message = make_signed_plaintext(&mut env.alice.group).await; let mut message = make_signed_plaintext(&mut env.alice).await;
message.membership_tag = None; message.membership_tag = None;
let res = verify_plaintext_authentication( let res = verify_plaintext_authentication(
&env.bob.group.cipher_suite_provider, &env.bob.cipher_suite_provider,
message, message,
Some(&env.bob.group.key_schedule), Some(&env.bob.key_schedule),
None, &env.bob.state,
&env.bob.group.state,
) )
.await; .await;
@@ -410,15 +399,14 @@ mod tests {
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn plaintext_fails_with_invalid_membership_tag() { async fn plaintext_fails_with_invalid_membership_tag() {
let mut env = TestEnv::new().await; let mut env = TestEnv::new().await;
let mut message = make_signed_plaintext(&mut env.alice.group).await; let mut message = make_signed_plaintext(&mut env.alice).await;
message.membership_tag = Some(MembershipTag::from(b"test".to_vec())); message.membership_tag = Some(MembershipTag::from(b"test".to_vec()));
let res = verify_plaintext_authentication( let res = verify_plaintext_authentication(
&env.bob.group.cipher_suite_provider, &env.bob.cipher_suite_provider,
message, message,
Some(&env.bob.group.key_schedule), Some(&env.bob.key_schedule),
None, &env.bob.state,
&env.bob.group.state,
) )
.await; .await;
@@ -437,8 +425,8 @@ mod tests {
F: FnMut(&mut AuthenticatedContent), F: FnMut(&mut AuthenticatedContent),
{ {
let mut content = AuthenticatedContent::new_signed( let mut content = AuthenticatedContent::new_signed(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
test_group.group.context(), test_group.context(),
Sender::NewMemberProposal, Sender::NewMemberProposal,
Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal { Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal {
key_package: key_pkg_gen.key_package, key_package: key_pkg_gen.key_package,
@@ -453,16 +441,12 @@ mod tests {
edit(&mut content); edit(&mut content);
let signing_context = MessageSigningContext { let signing_context = MessageSigningContext {
group_context: Some(test_group.group.context()), group_context: Some(test_group.context()),
protocol_version: test_group.group.protocol_version(), protocol_version: test_group.protocol_version(),
}; };
content content
.sign( .sign(&test_group.cipher_suite_provider, signer, &signing_context)
&test_group.group.cipher_suite_provider,
signer,
&signing_context,
)
.await .await
.unwrap(); .unwrap();
@@ -482,11 +466,10 @@ mod tests {
let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await; let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await;
verify_plaintext_authentication( verify_plaintext_authentication(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
message, message,
Some(&test_group.group.key_schedule), Some(&test_group.key_schedule),
None, &test_group.state,
&test_group.group.state,
) )
.await .await
.unwrap(); .unwrap();
@@ -503,11 +486,10 @@ mod tests {
message.membership_tag = Some(MembershipTag::from(vec![])); message.membership_tag = Some(MembershipTag::from(vec![]));
let res = verify_plaintext_authentication( let res = verify_plaintext_authentication(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
message, message,
Some(&test_group.group.key_schedule), Some(&test_group.key_schedule),
None, &test_group.state,
&test_group.group.state,
) )
.await; .await;
@@ -529,11 +511,10 @@ mod tests {
.await; .await;
let res: Result<AuthenticatedContent, MlsError> = verify_plaintext_authentication( let res: Result<AuthenticatedContent, MlsError> = verify_plaintext_authentication(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
message, message,
Some(&test_group.group.key_schedule), Some(&test_group.key_schedule),
None, &test_group.state,
&test_group.group.state,
) )
.await; .await;
@@ -553,11 +534,10 @@ mod tests {
.await; .await;
let res = verify_plaintext_authentication( let res = verify_plaintext_authentication(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
message, message,
Some(&test_group.group.key_schedule), Some(&test_group.key_schedule),
None, &test_group.state,
&test_group.group.state,
) )
.await; .await;
@@ -582,7 +562,6 @@ mod tests {
.unwrap(); .unwrap();
test_group test_group
.group
.commit_builder() .commit_builder()
.set_group_context_ext(extensions) .set_group_context_ext(extensions)
.unwrap() .unwrap()
@@ -590,7 +569,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
test_group.group.apply_pending_commit().await.unwrap(); test_group.apply_pending_commit().await.unwrap();
let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| { let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
msg.content.sender = Sender::External(0) msg.content.sender = Sender::External(0)
@@ -598,11 +577,10 @@ mod tests {
.await; .await;
verify_plaintext_authentication( verify_plaintext_authentication(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
message, message,
Some(&test_group.group.key_schedule), Some(&test_group.key_schedule),
None, &test_group.state,
&test_group.group.state,
) )
.await .await
.unwrap(); .unwrap();
@@ -622,11 +600,10 @@ mod tests {
.await; .await;
let res = verify_plaintext_authentication( let res = verify_plaintext_authentication(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
message, message,
Some(&test_group.group.key_schedule), Some(&test_group.key_schedule),
None, &test_group.state,
&test_group.group.state,
) )
.await; .await;
@@ -649,32 +626,13 @@ mod tests {
message.membership_tag = Some(MembershipTag::from(vec![])); message.membership_tag = Some(MembershipTag::from(vec![]));
let res = verify_plaintext_authentication( let res = verify_plaintext_authentication(
&test_group.group.cipher_suite_provider, &test_group.cipher_suite_provider,
message, message,
Some(&test_group.group.key_schedule), Some(&test_group.key_schedule),
None, &test_group.state,
&test_group.group.state,
) )
.await; .await;
assert_matches!(res, Err(MlsError::MembershipTagForNonMember)); assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn plaintext_from_self_fails_verification() {
let mut env = TestEnv::new().await;
let message = make_signed_plaintext(&mut env.alice.group).await;
let res = verify_plaintext_authentication(
&env.alice.group.cipher_suite_provider,
message,
Some(&env.alice.group.key_schedule),
Some(LeafIndex::new(env.alice.group.current_member_index())),
&env.alice.group.state,
)
.await;
assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
}
} }

View File

@@ -12,9 +12,9 @@ use crate::{
use alloc::boxed::Box; use alloc::boxed::Box;
use core::convert::Infallible; use core::convert::Infallible;
use mls_rs_core::{ use mls_rs_core::{error::IntoAnyError, group::Member, identity::SigningIdentity};
error::IntoAnyError, extension::ExtensionList, group::Member, identity::SigningIdentity,
}; use super::GroupContext;
#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CommitDirection { pub enum CommitDirection {
@@ -143,7 +143,7 @@ pub trait MlsRules: Send + Sync {
direction: CommitDirection, direction: CommitDirection,
source: CommitSource, source: CommitSource,
current_roster: &Roster, current_roster: &Roster,
extension_list: &ExtensionList, current_context: &GroupContext,
proposals: ProposalBundle, proposals: ProposalBundle,
) -> Result<ProposalBundle, Self::Error>; ) -> Result<ProposalBundle, Self::Error>;
@@ -156,7 +156,7 @@ pub trait MlsRules: Send + Sync {
fn commit_options( fn commit_options(
&self, &self,
new_roster: &Roster, new_roster: &Roster,
new_extension_list: &ExtensionList, new_context: &GroupContext,
proposals: &ProposalBundle, proposals: &ProposalBundle,
) -> Result<CommitOptions, Self::Error>; ) -> Result<CommitOptions, Self::Error>;
@@ -168,7 +168,7 @@ pub trait MlsRules: Send + Sync {
fn encryption_options( fn encryption_options(
&self, &self,
current_roster: &Roster, current_roster: &Roster,
current_extension_list: &ExtensionList, current_context: &GroupContext,
) -> Result<EncryptionOptions, Self::Error>; ) -> Result<EncryptionOptions, Self::Error>;
} }
@@ -185,29 +185,29 @@ macro_rules! delegate_mls_rules {
direction: CommitDirection, direction: CommitDirection,
source: CommitSource, source: CommitSource,
current_roster: &Roster, current_roster: &Roster,
extension_list: &ExtensionList, context: &GroupContext,
proposals: ProposalBundle, proposals: ProposalBundle,
) -> Result<ProposalBundle, Self::Error> { ) -> Result<ProposalBundle, Self::Error> {
(**self) (**self)
.filter_proposals(direction, source, current_roster, extension_list, proposals) .filter_proposals(direction, source, current_roster, context, proposals)
.await .await
} }
fn commit_options( fn commit_options(
&self, &self,
roster: &Roster, roster: &Roster,
extension_list: &ExtensionList, context: &GroupContext,
proposals: &ProposalBundle, proposals: &ProposalBundle,
) -> Result<CommitOptions, Self::Error> { ) -> Result<CommitOptions, Self::Error> {
(**self).commit_options(roster, extension_list, proposals) (**self).commit_options(roster, context, proposals)
} }
fn encryption_options( fn encryption_options(
&self, &self,
roster: &Roster, roster: &Roster,
extension_list: &ExtensionList, context: &GroupContext,
) -> Result<EncryptionOptions, Self::Error> { ) -> Result<EncryptionOptions, Self::Error> {
(**self).encryption_options(roster, extension_list) (**self).encryption_options(roster, context)
} }
} }
}; };
@@ -258,7 +258,7 @@ impl MlsRules for DefaultMlsRules {
_direction: CommitDirection, _direction: CommitDirection,
_source: CommitSource, _source: CommitSource,
_current_roster: &Roster, _current_roster: &Roster,
_extension_list: &ExtensionList, _: &GroupContext,
proposals: ProposalBundle, proposals: ProposalBundle,
) -> Result<ProposalBundle, Self::Error> { ) -> Result<ProposalBundle, Self::Error> {
Ok(proposals) Ok(proposals)
@@ -267,7 +267,7 @@ impl MlsRules for DefaultMlsRules {
fn commit_options( fn commit_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
_: &ProposalBundle, _: &ProposalBundle,
) -> Result<CommitOptions, Self::Error> { ) -> Result<CommitOptions, Self::Error> {
Ok(self.commit_options) Ok(self.commit_options)
@@ -276,7 +276,7 @@ impl MlsRules for DefaultMlsRules {
fn encryption_options( fn encryption_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
) -> Result<EncryptionOptions, Self::Error> { ) -> Result<EncryptionOptions, Self::Error> {
Ok(self.encryption_options) Ok(self.encryption_options)
} }

File diff suppressed because it is too large Load Diff

View File

@@ -33,6 +33,12 @@ pub struct AddProposal {
} }
impl AddProposal { impl AddProposal {
/// The [`KeyPackage`] used by this proposal to add
/// a [`Member`](mls_rs_core::group::Member) to the group.
pub fn key_package(&self) -> &KeyPackage {
&self.key_package
}
/// The [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member) /// The [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member)
/// that will be added by this proposal. /// that will be added by this proposal.
pub fn signing_identity(&self) -> &SigningIdentity { pub fn signing_identity(&self) -> &SigningIdentity {

View File

@@ -7,6 +7,7 @@ use alloc::vec::Vec;
use super::{ use super::{
message_processor::ProvisionalState, message_processor::ProvisionalState,
mls_rules::{CommitDirection, CommitSource, MlsRules}, mls_rules::{CommitDirection, CommitSource, MlsRules},
proposal_filter::prepare_proposals_for_mls_rules,
GroupState, ProposalOrRef, GroupState, ProposalOrRef,
}; };
use crate::{ use crate::{
@@ -19,13 +20,13 @@ use crate::{
}; };
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
use crate::group::{proposal_filter::FilterStrategy, ProposalRef, ProtocolVersion}; use crate::{
group::{message_hash::MessageHash, ProposalMessageDescription, ProposalRef, ProtocolVersion},
MlsMessage,
};
use crate::tree_kem::leaf_node::LeafNode; use crate::tree_kem::leaf_node::LeafNode;
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
@@ -46,14 +47,21 @@ pub struct CachedProposal {
} }
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
#[derive(Clone, PartialEq)] #[derive(Clone, MlsSize, MlsEncode, MlsDecode)]
pub(crate) struct ProposalCache { pub(crate) struct ProposalCache {
protocol_version: ProtocolVersion, protocol_version: ProtocolVersion,
group_id: Vec<u8>, group_id: Vec<u8>,
#[cfg(feature = "std")] pub(crate) proposals: crate::map::SmallMap<ProposalRef, CachedProposal>,
pub(crate) proposals: HashMap<ProposalRef, CachedProposal>, pub(crate) own_proposals: crate::map::SmallMap<MessageHash, ProposalMessageDescription>,
#[cfg(not(feature = "std"))] }
pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
#[cfg(feature = "by_ref_proposal")]
impl PartialEq for ProposalCache {
fn eq(&self, other: &Self) -> bool {
self.protocol_version == other.protocol_version
&& self.group_id == other.group_id
&& self.proposals == other.proposals
}
} }
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
@@ -77,28 +85,30 @@ impl ProposalCache {
protocol_version, protocol_version,
group_id, group_id,
proposals: Default::default(), proposals: Default::default(),
own_proposals: Default::default(),
} }
} }
pub fn import( pub fn import(
protocol_version: ProtocolVersion, protocol_version: ProtocolVersion,
group_id: Vec<u8>, group_id: Vec<u8>,
#[cfg(feature = "std")] proposals: HashMap<ProposalRef, CachedProposal>, proposals: crate::map::SmallMap<ProposalRef, CachedProposal>,
#[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>, own_proposals: crate::map::SmallMap<MessageHash, ProposalMessageDescription>,
) -> Self { ) -> Self {
Self { Self {
protocol_version, protocol_version,
group_id, group_id,
proposals, proposals,
own_proposals,
} }
} }
#[inline]
pub fn clear(&mut self) { pub fn clear(&mut self) {
self.proposals.clear(); self.proposals.clear();
self.own_proposals.clear();
} }
#[cfg(feature = "private_message")] #[cfg(feature = "by_ref_proposal")]
#[inline] #[inline]
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.proposals.is_empty() self.proposals.is_empty()
@@ -115,6 +125,26 @@ impl ProposalCache {
self.proposals.push((proposal_ref, cached_proposal)); self.proposals.push((proposal_ref, cached_proposal));
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn insert_own<CS: CipherSuiteProvider>(
&mut self,
proposal: ProposalMessageDescription,
message: &MlsMessage,
sender: Sender,
cs: &CS,
) -> Result<(), MlsError> {
self.insert(
proposal.proposal_ref.clone(),
proposal.proposal.clone(),
sender,
);
let message_hash = MessageHash::compute(cs, message).await?;
self.own_proposals.insert(message_hash, proposal);
Ok(())
}
pub fn prepare_commit( pub fn prepare_commit(
&self, &self,
sender: Sender, sender: Sender,
@@ -169,6 +199,17 @@ impl ProposalCache {
Ok(proposals) Ok(proposals)
} }
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn get_own<CS: CipherSuiteProvider>(
&self,
cs: &CS,
message: &MlsMessage,
) -> Result<Option<ProposalMessageDescription>, MlsError> {
let message_hash = MessageHash::compute(cs, message).await?;
Ok(self.own_proposals.get(&message_hash).cloned())
}
} }
#[cfg(not(feature = "by_ref_proposal"))] #[cfg(not(feature = "by_ref_proposal"))]
@@ -223,7 +264,6 @@ impl GroupState {
CSP: CipherSuiteProvider, CSP: CipherSuiteProvider,
{ {
let roster = self.public_tree.roster(); let roster = self.public_tree.roster();
let group_extensions = &self.context.extensions;
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
let all_proposals = proposals.clone(); let all_proposals = proposals.clone();
@@ -243,36 +283,26 @@ impl GroupState {
)), )),
}?; }?;
prepare_proposals_for_mls_rules(&mut proposals, direction, &self.public_tree)?;
proposals = user_rules proposals = user_rules
.filter_proposals(direction, origin, &roster, group_extensions, proposals) .filter_proposals(direction, origin, &roster, &self.context, proposals)
.await .await
.map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?; .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
let applier = ProposalApplier::new( let applier = ProposalApplier::new(
&self.public_tree, &self.public_tree,
self.context.protocol_version,
cipher_suite_provider, cipher_suite_provider,
group_extensions, &self.context,
external_leaf, external_leaf,
identity_provider, identity_provider,
psk_storage, psk_storage,
#[cfg(feature = "by_ref_proposal")]
&self.context.group_id,
); );
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
let applier_output = match direction { let applier_output = applier
CommitDirection::Send => { .apply_proposals(direction.into(), &sender, proposals, commit_time)
applier .await?;
.apply_proposals(FilterStrategy::IgnoreByRef, &sender, proposals, commit_time)
.await?
}
CommitDirection::Receive => {
applier
.apply_proposals(FilterStrategy::IgnoreNone, &sender, proposals, commit_time)
.await?
}
};
#[cfg(not(feature = "by_ref_proposal"))] #[cfg(not(feature = "by_ref_proposal"))]
let applier_output = applier let applier_output = applier
@@ -288,6 +318,9 @@ impl GroupState {
&applier_output.applied_proposals, &applier_output.applied_proposals,
); );
#[cfg(not(feature = "by_ref_proposal"))]
let unused_proposals = alloc::vec::Vec::default();
let mut group_context = self.context.clone(); let mut group_context = self.context.clone();
group_context.epoch += 1; group_context.epoch += 1;
@@ -304,7 +337,6 @@ impl GroupState {
applied_proposals: proposals, applied_proposals: proposals,
external_init_index: applier_output.external_init_index, external_init_index: applier_output.external_init_index,
indexes_of_added_kpkgs: applier_output.indexes_of_added_kpkgs, indexes_of_added_kpkgs: applier_output.indexes_of_added_kpkgs,
#[cfg(feature = "by_ref_proposal")]
unused_proposals, unused_proposals,
}) })
} }
@@ -629,8 +661,8 @@ mod tests {
use crate::group::proposal_ref::test_utils::auth_content_from_proposal; use crate::group::proposal_ref::test_utils::auth_content_from_proposal;
use crate::group::proposal_ref::ProposalRef; use crate::group::proposal_ref::ProposalRef;
use crate::group::{ use crate::group::{
AddProposal, AuthenticatedContent, Content, ExternalInit, Proposal, ProposalOrRef, AddProposal, AuthenticatedContent, Content, ExternalInit, GroupContext, Proposal,
ReInitProposal, RemoveProposal, Roster, Sender, UpdateProposal, ProposalOrRef, ReInitProposal, RemoveProposal, Roster, Sender, UpdateProposal,
}; };
use crate::key_package::test_utils::test_key_package_with_signer; use crate::key_package::test_utils::test_key_package_with_signer;
use crate::signer::Signable; use crate::signer::Signable;
@@ -751,7 +783,7 @@ mod tests {
&test_cipher_suite_provider(TEST_CIPHER_SUITE), &test_cipher_suite_provider(TEST_CIPHER_SUITE),
TEST_GROUP, TEST_GROUP,
leaf_index, leaf_index,
default_properties(), Some(default_properties()),
None, None,
&signer, &signer,
) )
@@ -854,7 +886,6 @@ mod tests {
group_context: get_test_group_context(1, cipher_suite).await, group_context: get_test_group_context(1, cipher_suite).await,
external_init_index: None, external_init_index: None,
indexes_of_added_kpkgs: vec![LeafIndex(1)], indexes_of_added_kpkgs: vec![LeafIndex(1)],
#[cfg(feature = "state_update")]
unused_proposals: vec![], unused_proposals: vec![],
applied_proposals: bundle, applied_proposals: bundle,
}; };
@@ -923,8 +954,8 @@ mod tests {
} }
fn assert_matches(mut expected_state: ProvisionalState, state: ProvisionalState) { fn assert_matches(mut expected_state: ProvisionalState, state: ProvisionalState) {
let expected_proposals = expected_state.applied_proposals.into_proposals_or_refs(); let expected_proposals = expected_state.applied_proposals.proposals_or_refs();
let proposals = state.applied_proposals.into_proposals_or_refs(); let proposals = state.applied_proposals.proposals_or_refs();
assert_eq!(proposals.len(), expected_proposals.len()); assert_eq!(proposals.len(), expected_proposals.len());
@@ -955,7 +986,6 @@ mod tests {
assert_eq!(expected_state.public_tree, state.public_tree); assert_eq!(expected_state.public_tree, state.public_tree);
#[cfg(feature = "state_update")]
assert_eq!(expected_state.unused_proposals, state.unused_proposals); assert_eq!(expected_state.unused_proposals, state.unused_proposals);
} }
@@ -1119,7 +1149,7 @@ mod tests {
assert!(!provisional_state assert!(!provisional_state
.applied_proposals .applied_proposals
.into_proposals_or_refs() .proposals_or_refs()
.contains(&ProposalOrRef::Reference(update_proposal_ref))) .contains(&ProposalOrRef::Reference(update_proposal_ref)))
} }
@@ -1254,7 +1284,7 @@ mod tests {
let proposals = expected_effects let proposals = expected_effects
.applied_proposals .applied_proposals
.clone() .clone()
.into_proposals_or_refs(); .proposals_or_refs();
let resolution = cache let resolution = cache
.resolve_for_commit_default( .resolve_for_commit_default(
@@ -1314,7 +1344,7 @@ mod tests {
&test_cipher_suite_provider(TEST_CIPHER_SUITE), &test_cipher_suite_provider(TEST_CIPHER_SUITE),
TEST_GROUP, TEST_GROUP,
0, 0,
default_properties(), Some(default_properties()),
None, None,
&signer, &signer,
) )
@@ -1331,7 +1361,7 @@ mod tests {
let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()]; let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let public_tree = &group.group.state.public_tree; let public_tree = &group.state.public_tree;
let res = cache let res = cache
.resolve_for_commit_default( .resolve_for_commit_default(
@@ -1340,7 +1370,7 @@ mod tests {
ExternalInit { kem_output }, ExternalInit { kem_output },
)))], )))],
None, None,
&group.group.context().extensions, &group.context().extensions,
&BasicIdentityProvider, &BasicIdentityProvider,
&cipher_suite_provider, &cipher_suite_provider,
public_tree, public_tree,
@@ -1371,14 +1401,14 @@ mod tests {
); );
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let public_tree = &group.group.state.public_tree; let public_tree = &group.state.public_tree;
let res = cache let res = cache
.resolve_for_commit_default( .resolve_for_commit_default(
Sender::NewMemberCommit, Sender::NewMemberCommit,
vec![ProposalOrRef::Reference(proposal_ref)], vec![ProposalOrRef::Reference(proposal_ref)],
Some(&test_node().await), Some(&test_node().await),
&group.group.context().extensions, &group.context().extensions,
&BasicIdentityProvider, &BasicIdentityProvider,
&cipher_suite_provider, &cipher_suite_provider,
public_tree, public_tree,
@@ -1396,7 +1426,7 @@ mod tests {
let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()]; let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let public_tree = &group.group.state.public_tree; let public_tree = &group.state.public_tree;
let res = cache let res = cache
.resolve_for_commit_default( .resolve_for_commit_default(
@@ -1411,7 +1441,7 @@ mod tests {
.map(|p| ProposalOrRef::Proposal(Box::new(p))) .map(|p| ProposalOrRef::Proposal(Box::new(p)))
.collect(), .collect(),
Some(&test_node().await), Some(&test_node().await),
&group.group.context().extensions, &group.context().extensions,
&BasicIdentityProvider, &BasicIdentityProvider,
&cipher_suite_provider, &cipher_suite_provider,
public_tree, public_tree,
@@ -1432,7 +1462,7 @@ mod tests {
let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()]; let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let public_tree = &group.group.state.public_tree; let public_tree = &group.state.public_tree;
cache cache
.resolve_for_commit_default( .resolve_for_commit_default(
@@ -1445,7 +1475,7 @@ mod tests {
.map(|p| ProposalOrRef::Proposal(Box::new(p))) .map(|p| ProposalOrRef::Proposal(Box::new(p)))
.collect(), .collect(),
Some(&test_node().await), Some(&test_node().await),
&group.group.context().extensions, &group.context().extensions,
&BasicIdentityProvider, &BasicIdentityProvider,
&cipher_suite_provider, &cipher_suite_provider,
public_tree, public_tree,
@@ -1476,7 +1506,7 @@ mod tests {
let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()]; let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let group_extensions = group.group.context().extensions.clone(); let group_extensions = group.context().extensions.clone();
let mut public_tree = group.group.state.public_tree; let mut public_tree = group.group.state.public_tree;
let foo = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await; let foo = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
@@ -1530,7 +1560,7 @@ mod tests {
let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()]; let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let group_extensions = group.group.context().extensions.clone(); let group_extensions = group.context().extensions.clone();
let mut public_tree = group.group.state.public_tree; let mut public_tree = group.group.state.public_tree;
let node = get_basic_test_node(TEST_CIPHER_SUITE, "bar").await; let node = get_basic_test_node(TEST_CIPHER_SUITE, "bar").await;
@@ -1579,7 +1609,7 @@ mod tests {
let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()]; let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let group_extensions = group.group.context().extensions.clone(); let group_extensions = group.context().extensions.clone();
let mut public_tree = group.group.state.public_tree; let mut public_tree = group.group.state.public_tree;
let node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await; let node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
@@ -1674,14 +1704,14 @@ mod tests {
let cache = make_proposal_cache(); let cache = make_proposal_cache();
let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE); let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
let public_tree = &group.group.state.public_tree; let public_tree = &group.state.public_tree;
let res = cache let res = cache
.resolve_for_commit_default( .resolve_for_commit_default(
Sender::NewMemberCommit, Sender::NewMemberCommit,
Vec::new(), Vec::new(),
Some(&test_node().await), Some(&test_node().await),
&group.group.context().extensions, &group.context().extensions,
&BasicIdentityProvider, &BasicIdentityProvider,
&cipher_suite_provider, &cipher_suite_provider,
public_tree, public_tree,
@@ -1993,7 +2023,7 @@ mod tests {
) )
.await?; .await?;
let proposals = state.applied_proposals.clone().into_proposals_or_refs(); let proposals = state.applied_proposals.clone().proposals_or_refs();
Ok((proposals, state)) Ok((proposals, state))
} }
@@ -2002,7 +2032,7 @@ mod tests {
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn key_package_with_invalid_signature() -> KeyPackage { async fn key_package_with_invalid_signature() -> KeyPackage {
let mut kp = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "mallory").await; let mut kp = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "mallory").await;
kp.signature.clear(); kp.signature = vec![1, 2, 3];
kp kp
} }
@@ -2036,6 +2066,7 @@ mod tests {
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn receiving_add_with_invalid_key_package_fails() { async fn receiving_add_with_invalid_key_package_fails() {
let (alice, tree) = new_tree("alice").await; let (alice, tree) = new_tree("alice").await;
let kp = key_package_with_invalid_signature().await;
let res = CommitReceiver::new( let res = CommitReceiver::new(
&tree, &tree,
@@ -2043,9 +2074,7 @@ mod tests {
alice, alice,
test_cipher_suite_provider(TEST_CIPHER_SUITE), test_cipher_suite_provider(TEST_CIPHER_SUITE),
) )
.receive([Proposal::Add(Box::new(AddProposal { .receive([Proposal::Add(Box::new(AddProposal { key_package: kp }))])
key_package: key_package_with_invalid_signature().await,
}))])
.await; .await;
assert_matches!(res, Err(MlsError::InvalidSignature)); assert_matches!(res, Err(MlsError::InvalidSignature));
@@ -2088,7 +2117,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2135,7 +2163,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2183,7 +2210,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2242,7 +2268,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2324,7 +2349,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2393,7 +2417,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2492,7 +2515,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2546,7 +2568,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2597,7 +2618,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2651,7 +2671,6 @@ mod tests {
assert_eq!(processed_proposals.0, vec![remove_ref.into()]); assert_eq!(processed_proposals.0, vec![remove_ref.into()]);
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]);
} }
@@ -2721,7 +2740,6 @@ mod tests {
let add_refs = [add_ref_one, add_ref_two]; let add_refs = [add_ref_one, add_ref_two];
assert!(add_refs.contains(committed_add_ref)); assert!(add_refs.contains(committed_add_ref));
#[cfg(feature = "state_update")]
assert_matches!( assert_matches!(
&*processed_proposals.1.unused_proposals, &*processed_proposals.1.unused_proposals,
[rejected_add_info] if committed_add_ref != rejected_add_info.proposal_ref().unwrap() && add_refs.contains(rejected_add_info.proposal_ref().unwrap()) [rejected_add_info] if committed_add_ref != rejected_add_info.proposal_ref().unwrap() && add_refs.contains(rejected_add_info.proposal_ref().unwrap())
@@ -2768,7 +2786,6 @@ mod tests {
// Bob proposed the update, so it is not listed as rejected when Alice commits it because // Bob proposed the update, so it is not listed as rejected when Alice commits it because
// she didn't propose it. // she didn't propose it.
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]);
} }
@@ -2859,7 +2876,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -2938,7 +2954,6 @@ mod tests {
assert!(proposal_info.contains(&committed_info)); assert!(proposal_info.contains(&committed_info));
#[cfg(feature = "state_update")]
match &*processed_proposals.1.unused_proposals { match &*processed_proposals.1.unused_proposals {
[r] => { [r] => {
assert_ne!(*r, committed_info); assert_ne!(*r, committed_info);
@@ -2991,8 +3006,8 @@ mod tests {
); );
} }
fn make_extension_list(foo: u8) -> ExtensionList { fn make_extension_list(something: u8) -> ExtensionList {
vec![TestExtension { foo }.into_extension().unwrap()].into() vec![TestExtension { foo: something }.into_extension().unwrap()].into()
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
@@ -3069,7 +3084,6 @@ mod tests {
assert!(gce_info.contains(&committed_gce_info)); assert!(gce_info.contains(&committed_gce_info));
#[cfg(feature = "state_update")]
assert_matches!( assert_matches!(
&*processed_proposals.1.unused_proposals, &*processed_proposals.1.unused_proposals,
[rejected_gce_info] if committed_gce_info != *rejected_gce_info && gce_info.contains(rejected_gce_info) [rejected_gce_info] if committed_gce_info != *rejected_gce_info && gce_info.contains(rejected_gce_info)
@@ -3148,7 +3162,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -3208,7 +3221,6 @@ mod tests {
assert_eq!(processed_proposals.0, vec![add_ref.into()]); assert_eq!(processed_proposals.0, vec![add_ref.into()]);
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![reinit_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![reinit_info]);
} }
@@ -3272,7 +3284,6 @@ mod tests {
assert!(*processed_ref == reinit_ref || *processed_ref == other_reinit_ref); assert!(*processed_ref == reinit_ref || *processed_ref == other_reinit_ref);
#[cfg(feature = "state_update")]
{ {
let (rejected_ref, unused_proposal) = match &*processed_proposals.1.unused_proposals { let (rejected_ref, unused_proposal) = match &*processed_proposals.1.unused_proposals {
[r] => (r.proposal_ref().unwrap().clone(), r.proposal.clone()), [r] => (r.proposal_ref().unwrap().clone(), r.proposal.clone()),
@@ -3338,7 +3349,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!( assert_eq!(
processed_proposals.1.unused_proposals, processed_proposals.1.unused_proposals,
vec![external_init_info] vec![external_init_info]
@@ -3410,7 +3420,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -3587,7 +3596,6 @@ mod tests {
cipher_suite_provider: &test_cipher_suite_provider(TEST_CIPHER_SUITE), cipher_suite_provider: &test_cipher_suite_provider(TEST_CIPHER_SUITE),
signing_identity: &signing_identity, signing_identity: &signing_identity,
signing_key: &secret_key, signing_key: &secret_key,
identity_provider: &BasicWithCustomProvider::new(BasicIdentityProvider::new()),
}; };
generator generator
@@ -3656,7 +3664,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![add_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![add_info]);
} }
@@ -3702,7 +3709,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![custom_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![custom_info]);
} }
@@ -3786,7 +3792,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -3858,7 +3863,6 @@ mod tests {
assert_eq!(processed_proposals.0, Vec::new()); assert_eq!(processed_proposals.0, Vec::new());
#[cfg(feature = "state_update")]
assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]); assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
} }
@@ -3876,7 +3880,7 @@ mod tests {
_: CommitDirection, _: CommitDirection,
_: CommitSource, _: CommitSource,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
mut proposals: ProposalBundle, mut proposals: ProposalBundle,
) -> Result<ProposalBundle, Self::Error> { ) -> Result<ProposalBundle, Self::Error> {
proposals.group_context_extensions.clear(); proposals.group_context_extensions.clear();
@@ -3887,7 +3891,7 @@ mod tests {
fn commit_options( fn commit_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
_: &ProposalBundle, _: &ProposalBundle,
) -> Result<CommitOptions, Self::Error> { ) -> Result<CommitOptions, Self::Error> {
Ok(Default::default()) Ok(Default::default())
@@ -3897,7 +3901,7 @@ mod tests {
fn encryption_options( fn encryption_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
) -> Result<EncryptionOptions, Self::Error> { ) -> Result<EncryptionOptions, Self::Error> {
Ok(Default::default()) Ok(Default::default())
} }
@@ -3928,7 +3932,7 @@ mod tests {
_: CommitDirection, _: CommitDirection,
_: CommitSource, _: CommitSource,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
_: ProposalBundle, _: ProposalBundle,
) -> Result<ProposalBundle, Self::Error> { ) -> Result<ProposalBundle, Self::Error> {
Err(MlsError::InvalidSignature) Err(MlsError::InvalidSignature)
@@ -3938,7 +3942,7 @@ mod tests {
fn commit_options( fn commit_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
_: &ProposalBundle, _: &ProposalBundle,
) -> Result<CommitOptions, Self::Error> { ) -> Result<CommitOptions, Self::Error> {
Ok(Default::default()) Ok(Default::default())
@@ -3948,14 +3952,14 @@ mod tests {
fn encryption_options( fn encryption_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
) -> Result<EncryptionOptions, Self::Error> { ) -> Result<EncryptionOptions, Self::Error> {
Ok(Default::default()) Ok(Default::default())
} }
} }
struct InjectMlsRules { struct InjectMlsRules {
to_inject: Proposal, to_inject: Vec<Proposal>,
source: ProposalSource, source: ProposalSource,
} }
@@ -3969,14 +3973,13 @@ mod tests {
_: CommitDirection, _: CommitDirection,
_: CommitSource, _: CommitSource,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
mut proposals: ProposalBundle, mut proposals: ProposalBundle,
) -> Result<ProposalBundle, Self::Error> { ) -> Result<ProposalBundle, Self::Error> {
proposals.add( for proposal in self.to_inject.iter().cloned() {
self.to_inject.clone(), proposals.add(proposal, Sender::Member(0), self.source.clone());
Sender::Member(0), }
self.source.clone(),
);
Ok(proposals) Ok(proposals)
} }
@@ -3984,7 +3987,7 @@ mod tests {
fn commit_options( fn commit_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
_: &ProposalBundle, _: &ProposalBundle,
) -> Result<CommitOptions, Self::Error> { ) -> Result<CommitOptions, Self::Error> {
Ok(Default::default()) Ok(Default::default())
@@ -3994,7 +3997,7 @@ mod tests {
fn encryption_options( fn encryption_options(
&self, &self,
_: &Roster, _: &Roster,
_: &ExtensionList, _: &GroupContext,
) -> Result<EncryptionOptions, Self::Error> { ) -> Result<EncryptionOptions, Self::Error> {
Ok(Default::default()) Ok(Default::default())
} }
@@ -4009,7 +4012,7 @@ mod tests {
let (committed, _) = let (committed, _) =
CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules { .with_user_rules(InjectMlsRules {
to_inject: test_proposal.clone(), to_inject: vec![test_proposal.clone()],
source: ProposalSource::ByValue, source: ProposalSource::ByValue,
}) })
.send() .send()
@@ -4031,7 +4034,7 @@ mod tests {
let (committed, _) = let (committed, _) =
CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules { .with_user_rules(InjectMlsRules {
to_inject: test_proposal.clone(), to_inject: vec![test_proposal.clone()],
source: ProposalSource::Local, source: ProposalSource::Local,
}) })
.send() .send()
@@ -4051,7 +4054,7 @@ mod tests {
let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE)) let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules { .with_user_rules(InjectMlsRules {
to_inject: test_proposal.clone(), to_inject: vec![test_proposal.clone()],
source: ProposalSource::ByValue, source: ProposalSource::ByValue,
}) })
.send() .send()
@@ -4060,6 +4063,25 @@ mod tests {
assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender { .. })) assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender { .. }))
} }
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn sending_invalid_local_proposal_fails() {
let (alice, tree) = new_tree("alice").await;
let gce_proposal = Proposal::GroupContextExtensions(Default::default());
let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules {
to_inject: vec![gce_proposal.clone(), gce_proposal],
source: ProposalSource::Local,
})
.send()
.await;
assert_matches!(
res,
Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
);
}
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn user_defined_filter_can_refuse_to_send_commit() { async fn user_defined_filter_can_refuse_to_send_commit() {
let (alice, tree) = new_tree("alice").await; let (alice, tree) = new_tree("alice").await;
@@ -4146,17 +4168,27 @@ mod tests {
#[cfg(feature = "by_ref_proposal")] #[cfg(feature = "by_ref_proposal")]
let receiver = receiver.with_extensions(extensions); let receiver = receiver.with_extensions(extensions);
let (receiver, proposals, proposer) = if by_ref { let (receiver, proposals, proposer, source) = if by_ref {
let proposal_ref = make_proposal_ref(proposal, proposer).await; let proposal_ref = make_proposal_ref(proposal, proposer).await;
let receiver = receiver.cache(proposal_ref.clone(), proposal.clone(), proposer); let receiver = receiver.cache(proposal_ref.clone(), proposal.clone(), proposer);
(receiver, vec![ProposalOrRef::from(proposal_ref)], proposer) (
receiver,
vec![ProposalOrRef::from(proposal_ref.clone())],
proposer,
ProposalSource::ByReference(proposal_ref),
)
} else { } else {
(receiver, vec![proposal.clone().into()], committer) (
receiver,
vec![proposal.clone().into()],
committer,
ProposalSource::Local,
)
}; };
let res = receiver.receive(proposals).await; let res = receiver.receive(proposals).await;
if proposer_can_propose(proposer, proposal.proposal_type(), by_ref).is_err() { if proposer_can_propose(proposer, proposal.proposal_type(), &source).is_err() {
assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender)); assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
} else { } else {
let is_self_update = proposal.proposal_type() == ProposalType::UPDATE let is_self_update = proposal.proposal_type() == ProposalType::UPDATE

Some files were not shown because too many files have changed in this diff Show More