use super::Verifier;
use parity_scale_codec::{Decode, Encode};
use scale_info::TypeInfo;
use serde::{Deserialize, Serialize};
use sp_core::{
sr25519::{Public, Signature},
H256,
};
use sp_std::{
collections::{btree_map::BTreeMap, btree_set::BTreeSet},
vec::Vec,
};
/// A verifier that is satisfied once at least `threshold` of the listed
/// `signatories` have produced a valid signature.
///
/// The derived SCALE `Encode`/`Decode` impls serialize fields in declaration
/// order, so the field order below is part of the wire format.
#[derive(Clone, Debug, PartialEq, Eq, Encode, Decode, Serialize, Deserialize, TypeInfo)]
pub struct ThresholdMultiSignature {
    /// Minimum number of distinct, valid signatures required to spend.
    pub threshold: u8,
    /// sr25519 public keys (stored as raw `H256`) of the parties allowed to sign.
    pub signatories: Vec<H256>,
}
impl ThresholdMultiSignature {
pub fn new(threshold: u8, signatories: Vec<H256>) -> Self {
ThresholdMultiSignature {
threshold,
signatories,
}
}
pub fn has_duplicate_signatories(&self) -> bool {
let set: BTreeSet<_> = self.signatories.iter().collect();
set.len() < self.signatories.len()
}
}
/// One entry of a [`ThresholdMultiSignature`] redeemer: a signature together
/// with the position of the signatory that produced it.
#[derive(Clone, Debug, PartialEq, Eq, Encode, Decode, Serialize, Deserialize)]
pub struct SignatureAndIndex {
    /// sr25519 signature over the simplified transaction.
    pub signature: Signature,
    /// Index into `ThresholdMultiSignature::signatories` of the signing key.
    pub index: u8,
}
impl Verifier for ThresholdMultiSignature {
    /// One `(signature, signatory index)` pair per signatory that signed.
    type Redeemer = Vec<SignatureAndIndex>;

    /// Returns `true` when at least `threshold` distinct signatories produced a
    /// valid sr25519 signature over `simplified_tx`.
    ///
    /// Returns `false` (and never panics) when:
    /// - the signatory list itself contains duplicates,
    /// - fewer than `threshold` signatures are supplied,
    /// - any entry references an index outside `self.signatories`,
    /// - two entries reference the same signatory index,
    /// - fewer than `threshold` of the supplied signatures actually verify.
    ///
    /// The `u32` argument is not used by this verifier.
    fn verify(&self, simplified_tx: &[u8], _: u32, sigs: &Vec<SignatureAndIndex>) -> bool {
        let threshold: usize = self.threshold.into();

        if self.has_duplicate_signatories() {
            return false;
        }
        if sigs.len() < threshold {
            return false;
        }

        // Indices address `self.signatories`, not `sigs`. Bounding them against the
        // signatory list keeps the lookup below panic-free and lets any signatory
        // (not only the first `sigs.len()` of them) contribute a signature.
        let index_out_of_bounds = sigs
            .iter()
            .any(|sig| usize::from(sig.index) >= self.signatories.len());
        if index_out_of_bounds {
            return false;
        }

        // Each signatory may be counted at most once; otherwise a single key's
        // signature could be replayed to reach the threshold.
        let distinct_indices: BTreeMap<u8, &Signature> = sigs
            .iter()
            .map(|sig| (sig.index, &sig.signature))
            .collect();
        if distinct_indices.len() < sigs.len() {
            return false;
        }

        // Count only the signatures that genuinely verify against their signatory.
        let valid_sigs = sigs
            .iter()
            .filter(|sig| {
                sp_io::crypto::sr25519_verify(
                    &sig.signature,
                    simplified_tx,
                    &Public::from_h256(self.signatories[usize::from(sig.index)]),
                )
            })
            .count();

        valid_sigs >= threshold
    }

    /// A verifier no redeemer can satisfy: it requires one signature but has no
    /// signatories, so every supplied index is out of bounds and `verify`
    /// returns `false`.
    fn new_unspendable() -> Option<Self> {
        Some(Self {
            threshold: 1,
            signatories: Vec::new(),
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::verifier::test::generate_n_pairs;
    use sp_core::crypto::Pair as _;

    /// Payload signed and verified by every test in this module.
    const SIMPLIFIED_TX: &[u8] = b"hello_world";

    /// Exactly `threshold` signatories, all of them sign: verification succeeds.
    #[test]
    fn threshold_multisig_with_enough_sigs_passes() {
        let threshold: u8 = 2;
        let pairs = generate_n_pairs(threshold);
        let threshold_multisig = ThresholdMultiSignature {
            threshold,
            signatories: pairs.iter().map(|p| H256::from(p.public())).collect(),
        };
        let sigs: Vec<SignatureAndIndex> = pairs
            .iter()
            .zip(0u8..)
            .map(|(pair, index)| SignatureAndIndex {
                signature: pair.sign(SIMPLIFIED_TX),
                index,
            })
            .collect();

        assert!(threshold_multisig.verify(SIMPLIFIED_TX, 0, &sigs));
    }

    /// Only `threshold - 1` signatories sign: verification fails.
    #[test]
    fn threshold_multisig_not_enough_sigs_fails() {
        let threshold: u8 = 3;
        let pairs = generate_n_pairs(threshold);
        let threshold_multisig = ThresholdMultiSignature {
            threshold,
            signatories: pairs.iter().map(|p| H256::from(p.public())).collect(),
        };
        let sigs: Vec<SignatureAndIndex> = pairs
            .iter()
            .take(threshold as usize - 1)
            .zip(0u8..)
            .map(|(pair, index)| SignatureAndIndex {
                signature: pair.sign(SIMPLIFIED_TX),
                index,
            })
            .collect();

        assert!(!threshold_multisig.verify(SIMPLIFIED_TX, 0, &sigs));
    }

    /// More signatures than the threshold requires is still accepted.
    #[test]
    fn threshold_multisig_extra_sigs_still_passes() {
        let threshold: u8 = 2;
        let pairs = generate_n_pairs(threshold + 1);
        let threshold_multisig = ThresholdMultiSignature {
            threshold,
            signatories: pairs.iter().map(|p| H256::from(p.public())).collect(),
        };
        let sigs: Vec<SignatureAndIndex> = pairs
            .iter()
            .zip(0u8..)
            .map(|(pair, index)| SignatureAndIndex {
                signature: pair.sign(SIMPLIFIED_TX),
                index,
            })
            .collect();

        assert!(threshold_multisig.verify(SIMPLIFIED_TX, 0, &sigs));
    }

    /// One signatory submitting twice under the same index must not reach the threshold.
    #[test]
    fn threshold_multisig_replay_sig_attack_fails() {
        let threshold: u8 = 2;
        let pairs = generate_n_pairs(threshold);
        let threshold_multisig = ThresholdMultiSignature {
            threshold,
            signatories: pairs.iter().map(|p| H256::from(p.public())).collect(),
        };
        // Sign twice (rather than cloning one signature) so both entries are fresh.
        let sigs: Vec<SignatureAndIndex> = (0..2)
            .map(|_| SignatureAndIndex {
                signature: pairs[0].sign(SIMPLIFIED_TX),
                index: 0,
            })
            .collect();

        assert!(!threshold_multisig.verify(SIMPLIFIED_TX, 0, &sigs));
    }

    /// A signatory list that repeats the same key is rejected outright.
    #[test]
    fn threshold_multisig_has_duplicate_signatories_fails() {
        let threshold: u8 = 2;
        let pairs = generate_n_pairs(threshold);
        let threshold_multisig = ThresholdMultiSignature {
            threshold,
            signatories: vec![H256::from(pairs[0].public()); 2],
        };
        let sigs: Vec<SignatureAndIndex> = pairs
            .iter()
            .zip(0u8..)
            .map(|(pair, index)| SignatureAndIndex {
                signature: pair.sign(SIMPLIFIED_TX),
                index,
            })
            .collect();

        assert!(!threshold_multisig.verify(SIMPLIFIED_TX, 0, &sigs));
    }
}