1pub use crate::{
10 append_only::MerkleTree,
11 impl_to_traversal_path_biguint, impl_to_traversal_path_primitives,
12 internal::{MerkleNode, MerkleTreeProof},
13 universal_merkle_tree::UniversalMerkleTree,
14 AppendableMerkleTreeScheme, DigestAlgorithm, Element, ForgetableMerkleTreeScheme,
15 ForgetableUniversalMerkleTreeScheme, Index, LookupResult, MerkleTreeScheme, NodeValue,
16 ToTraversalPath, UniversalMerkleTreeScheme,
17};
18
19use super::light_weight::LightWeightMerkleTree;
20use crate::errors::MerkleTreeError;
21use ark_ff::PrimeField;
22use ark_serialize::{
23 CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate,
24 Write,
25};
26use ark_std::{fmt, marker::PhantomData, string::ToString, vec, vec::Vec};
27use jf_crhf::CRHF;
28use jf_poseidon2::{
29 crhf::FixedLenPoseidon2Hash, sponge::Poseidon2Sponge, Poseidon2, Poseidon2Params,
30};
31use jf_rescue::{crhf::RescueCRHF, RescueParameter};
32use nimue::hash::sponge::Sponge;
33use sha3::{Digest, Keccak256, Sha3_256};
34
35#[derive(Clone, Copy, Debug, PartialEq, Eq)]
37pub struct RescueHash<F: RescueParameter> {
38 phantom_f: PhantomData<F>,
39}
40
41fn leaf_hash_dom_sep<F: PrimeField>() -> F {
43 F::one()
44}
45
46fn internal_hash_dom_sep<F: PrimeField>() -> F {
48 F::zero()
49}
50
/// Byte-string domain separator fed to the hasher before a serialized leaf
/// element. The 32-byte-output hash instantiations use it.
// `const` references are implicitly `'static`; spelling it out is redundant.
pub(crate) const LEAF_HASH_DOM_SEP: &[u8; 1] = b"1";
/// Byte-string domain separator fed to the hasher before the children of an
/// internal node. The 32-byte-output hash instantiations use it.
pub(crate) const INTERNAL_HASH_DOM_SEP: &[u8; 1] = b"0";
55
56impl<I: Index, F: RescueParameter + From<I>> DigestAlgorithm<F, I, F> for RescueHash<F> {
57 fn digest(data: &[F]) -> Result<F, MerkleTreeError> {
58 let mut input = vec![internal_hash_dom_sep()];
59 input.extend(data.iter());
60 Ok(RescueCRHF::<F>::sponge_with_zero_padding(&input, 1)[0])
61 }
62
63 fn digest_leaf(pos: &I, elem: &F) -> Result<F, MerkleTreeError> {
64 let data = [leaf_hash_dom_sep(), F::from(pos.clone()), *elem];
65 Ok(RescueCRHF::<F>::sponge_with_zero_padding(&data, 1)[0])
66 }
67}
68
69pub type RescueMerkleTree<F> = MerkleTree<F, RescueHash<F>, u64, 3, F>;
71
72pub type RescueLightWeightMerkleTree<F> = LightWeightMerkleTree<F, RescueHash<F>, u64, 3, F>;
74
75pub type RescueSparseMerkleTree<I, F> = UniversalMerkleTree<F, RescueHash<F>, I, 3, F>;
77
78impl<I, F, S, const INPUT_SIZE: usize> DigestAlgorithm<F, I, F>
82 for FixedLenPoseidon2Hash<F, S, INPUT_SIZE, 1>
83where
84 I: Index,
85 F: PrimeField + From<I> + nimue::Unit,
86 S: Sponge<U = F> + Poseidon2Sponge,
87{
88 fn digest(data: &[F]) -> Result<F, MerkleTreeError> {
89 let mut input = vec![internal_hash_dom_sep()];
90 input.extend(data.iter());
91 Ok(FixedLenPoseidon2Hash::<F, S, INPUT_SIZE, 1>::evaluate(input)?[0])
92 }
93
94 fn digest_leaf(pos: &I, elem: &F) -> Result<F, MerkleTreeError> {
95 if INPUT_SIZE < 3 {
96 return Err(MerkleTreeError::ParametersError(ark_std::format!(
97 "INPUT_SIZE {} too short",
98 INPUT_SIZE
99 )));
100 }
101 let mut input = vec![F::zero(); INPUT_SIZE];
102 input[0] = leaf_hash_dom_sep();
103 input[INPUT_SIZE - 2] = F::from(pos.clone());
104 input[INPUT_SIZE - 1] = *elem;
105 Ok(FixedLenPoseidon2Hash::<F, S, INPUT_SIZE, 1>::evaluate(input)?[0])
106 }
107}
108
/// Defines a 32-byte Merkle tree node type and a [`DigestAlgorithm`] for a
/// hash function with 32 bytes of output.
///
/// # Usage
/// `impl_mt_hash_256!(Sha3_256, Sha3Node, Sha3Digest);` will
/// - introduce `struct Sha3Node`, a node type holding a 32-byte `Sha3_256`
///   output;
/// - introduce `struct Sha3Digest`, which implements `DigestAlgorithm` using
///   `Sha3_256`.
macro_rules! impl_mt_hash_256 {
    ($hasher:ident, $node_name:ident, $digest_name:ident) => {
        /// Merkle tree node holding a 32-byte hash output.
        #[derive(Default, Eq, PartialEq, Clone, Copy, Ord, PartialOrd, Hash)]
        pub struct $node_name(pub(crate) [u8; 32]);

        impl fmt::Debug for $node_name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // Render the digest as hex instead of a raw byte array.
                f.debug_tuple(stringify!($node_name))
                    .field(&hex::encode(self.0))
                    .finish()
            }
        }

        impl AsRef<[u8]> for $node_name {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }

        // Encoding is the 32 raw bytes. It ignores `Compress`, so compressed
        // and uncompressed forms are identical.
        impl CanonicalSerialize for $node_name {
            fn serialize_with_mode<W: Write>(
                &self,
                mut writer: W,
                _compress: Compress,
            ) -> Result<(), SerializationError> {
                writer.write_all(&self.0)?;
                Ok(())
            }

            fn serialized_size(&self, _compress: Compress) -> usize {
                32
            }
        }

        impl CanonicalDeserialize for $node_name {
            fn deserialize_with_mode<R: Read>(
                mut reader: R,
                _compress: Compress,
                _validate: Validate,
            ) -> Result<Self, SerializationError> {
                let mut ret = [0u8; 32];
                reader.read_exact(&mut ret)?;
                Ok(Self(ret))
            }
        }

        // Any 32-byte array is accepted, so there is nothing to validate.
        impl Valid for $node_name {
            fn check(&self) -> Result<(), SerializationError> {
                Ok(())
            }
        }

        /// Digest algorithm that hashes Merkle tree leaves and internal nodes
        /// into 32-byte nodes.
        #[derive(Clone, Debug, Hash, Eq, PartialEq)]
        pub struct $digest_name;

        impl<E: Element + CanonicalSerialize, I: Index> DigestAlgorithm<E, I, $node_name>
            for $digest_name
        {
            /// Hashes `INTERNAL_HASH_DOM_SEP` followed by the raw bytes of
            /// every child node.
            fn digest(data: &[$node_name]) -> Result<$node_name, MerkleTreeError> {
                let mut h = $hasher::new();
                h.update(INTERNAL_HASH_DOM_SEP);
                for value in data {
                    h.update(value);
                }
                Ok($node_name(h.finalize().into()))
            }

            /// Hashes `LEAF_HASH_DOM_SEP` followed by the compressed
            /// serialization of `elem`. The leaf position is not hashed.
            ///
            /// # Errors
            /// Returns [`MerkleTreeError::ParametersError`] if `elem` fails to
            /// serialize, instead of panicking.
            fn digest_leaf(_pos: &I, elem: &E) -> Result<$node_name, MerkleTreeError> {
                let mut writer = Vec::new();
                elem.serialize_compressed(&mut writer)
                    .map_err(|e| MerkleTreeError::ParametersError(e.to_string()))?;
                let mut h = $hasher::new();
                h.update(LEAF_HASH_DOM_SEP);
                h.update(writer);
                Ok($node_name(h.finalize().into()))
            }
        }
    };
}
192
193impl_mt_hash_256!(Sha3_256, Sha3Node, Sha3Digest);
194impl_mt_hash_256!(Keccak256, Keccak256Node, Keccak256Digest);
195
196pub type SHA3MerkleTree<E> = MerkleTree<E, Sha3Digest, u64, 3, Sha3Node>;
198pub type LightWeightSHA3MerkleTree<E> = LightWeightMerkleTree<E, Sha3Digest, u64, 3, Sha3Node>;
200
201pub type Keccak256MerkleTree<E> = MerkleTree<E, Keccak256Node, u64, 3, Keccak256Digest>;
203pub type LightWeightKeccak256MerkleTree<E> =
205 LightWeightMerkleTree<E, Keccak256Digest, u64, 3, Keccak256Node>;
206
#[cfg(test)]
mod tests {
    use super::{
        leaf_hash_dom_sep, DigestAlgorithm, MerkleTreeScheme, RescueHash, RescueMerkleTree,
    };
    use ark_bls12_377::Fr as Fr377;
    use ark_bls12_381::Fr as Fr381;
    use ark_bn254::Fr as Fr254;
    use jf_rescue::RescueParameter;

    #[test]
    fn test_extension_attack() {
        test_extension_attack_helper::<Fr254>();
        test_extension_attack_helper::<Fr377>();
        test_extension_attack_helper::<Fr381>();
    }

    /// Tries to pass off an internal node as a leaf by adding one forged
    /// level to an honest proof.
    ///
    /// The stored leaf value is the *current* leaf digest of
    /// `(forged_pos, forged_val)`. The forged siblings are
    /// `[leaf separator, attack_pos]`, so the internal hash recomputed by the
    /// verifier has the same payload as the honest leaf hash of
    /// `(attack_pos, val)`. Only the internal-node domain separator tells them
    /// apart, so verification must fail.
    fn test_extension_attack_helper<F: RescueParameter>() {
        let forged_val = F::from(42u64);
        let attack_pos = 5u64;
        // Child at branch 2 of `attack_pos`, one level further down an arity-3 tree.
        let forged_pos = attack_pos * 3 + 2;
        let val =
            <RescueHash<F> as DigestAlgorithm<F, u64, F>>::digest_leaf(&forged_pos, &forged_val)
                .unwrap();
        let elems = ark_std::vec![val; attack_pos as usize + 1];
        let mt = RescueMerkleTree::<F>::from_elems(None, elems).unwrap();
        let commit = mt.commitment();
        let (elem, mut proof) = mt.lookup(attack_pos).expect_ok().unwrap();
        assert!(
            RescueMerkleTree::<F>::verify(&commit, attack_pos, elem, &proof)
                .unwrap()
                .is_ok()
        );
        proof
            .0
            .insert(0, ark_std::vec![leaf_hash_dom_sep(), F::from(attack_pos)]);
        assert!(
            RescueMerkleTree::<F>::verify(&commit, forged_pos, forged_val, &proof)
                .unwrap()
                .is_err()
        );
    }
}