1use super::{
10 internal::{
11 build_tree_internal, MerkleNode, MerkleTreeIntoIter, MerkleTreeIter, MerkleTreeProof,
12 },
13 AppendableMerkleTreeScheme, DigestAlgorithm, Element, ForgetableMerkleTreeScheme, Index,
14 LookupResult, MerkleProof, MerkleTreeScheme, NodeValue, ToTraversalPath,
15};
16use crate::{
17 errors::MerkleTreeError, impl_forgetable_merkle_tree_scheme, impl_merkle_tree_scheme,
18 VerificationResult,
19};
20use alloc::sync::Arc;
21use ark_std::{borrow::Borrow, fmt::Debug, marker::PhantomData, string::ToString, vec, vec::Vec};
22use num_bigint::BigUint;
23use num_traits::pow::pow;
24use serde::{Deserialize, Serialize};
25
// Generate the shared Merkle-tree machinery for `MerkleTree`.
// NOTE(review): the `MerkleTree` struct (with the `root`, `height`, `num_leaves` and
// `_phantom` fields used below) is not declared in this file's visible code, so these
// macros presumably define it together with the `MerkleTreeScheme` impl — confirm in
// the crate's macros module. The many otherwise-unused imports at the top of the file
// (e.g. `BigUint`, `pow`, serde derives) are presumably consumed by the expansion.
impl_merkle_tree_scheme!(MerkleTree);
// Forget/remember support (presumably the `ForgetableMerkleTreeScheme` impl).
impl_forgetable_merkle_tree_scheme!(MerkleTree);
28
29impl<E, H, I, const ARITY: usize, T> MerkleTree<E, H, I, ARITY, T>
30where
31 E: Element,
32 H: DigestAlgorithm<E, I, T>,
33 I: Index,
34 T: NodeValue,
35{
36 pub fn new(height: usize) -> Self {
38 Self {
39 root: Arc::new(MerkleNode::<E, I, T>::Empty),
40 height,
41 num_leaves: 0,
42 _phantom: PhantomData,
43 }
44 }
45}
46
47impl<E, H, const ARITY: usize, T> MerkleTree<E, H, u64, ARITY, T>
48where
49 E: Element,
50 H: DigestAlgorithm<E, u64, T>,
51 T: NodeValue,
52{
53 pub fn from_elems(
59 height: Option<usize>,
60 elems: impl IntoIterator<Item = impl Borrow<E>>,
61 ) -> Result<Self, MerkleTreeError> {
62 let (root, height, num_leaves) = build_tree_internal::<E, H, ARITY, T>(height, elems)?;
63 Ok(Self {
64 root,
65 height,
66 num_leaves,
67 _phantom: PhantomData,
68 })
69 }
70}
71
72impl<E, H, const ARITY: usize, T> AppendableMerkleTreeScheme for MerkleTree<E, H, u64, ARITY, T>
73where
74 E: Element,
75 H: DigestAlgorithm<E, u64, T>,
76 T: NodeValue,
77{
78 fn push(&mut self, elem: impl Borrow<Self::Element>) -> Result<(), MerkleTreeError> {
79 <Self as AppendableMerkleTreeScheme>::extend(self, [elem])
80 }
81
82 fn extend(
83 &mut self,
84 elems: impl IntoIterator<Item = impl Borrow<Self::Element>>,
85 ) -> Result<(), MerkleTreeError> {
86 let mut iter = elems.into_iter().peekable();
87
88 let traversal_path =
89 ToTraversalPath::<ARITY>::to_traversal_path(&self.num_leaves, self.height);
90 let (root, num_inserted) = self.root.extend_internal::<H, ARITY>(
91 self.height,
92 &self.num_leaves,
93 &traversal_path,
94 true,
95 &mut iter,
96 )?;
97 self.root = root;
98 self.num_leaves += num_inserted;
99 if iter.peek().is_some() {
100 return Err(MerkleTreeError::ExceedCapacity);
101 }
102 Ok(())
103 }
104}
105
// Unit tests for the append-only `MerkleTree`, exercised through the Rescue-based
// aliases `RescueMerkleTree` / `RescueSparseMerkleTree`. Each `#[test]` runs a
// generic `*_helper` over three scalar fields (BN254, BLS12-377, BLS12-381).
#[cfg(test)]
mod mt_tests {
    use crate::{
        internal::{MerkleNode, MerkleTreeProof},
        prelude::{RescueMerkleTree, RescueSparseMerkleTree},
        *,
    };
    use ark_bls12_377::Fr as Fr377;
    use ark_bls12_381::Fr as Fr381;
    use ark_bn254::Fr as Fr254;
    use jf_rescue::RescueParameter;

    #[test]
    fn test_mt_builder() {
        test_mt_builder_helper::<Fr254>();
        test_mt_builder_helper::<Fr377>();
        test_mt_builder_helper::<Fr381>();
    }

    /// `from_elems` succeeds when the height is inferred, and fails when an explicit
    /// height is too small for the number of elements.
    fn test_mt_builder_helper<F: RescueParameter>() {
        // No height given: 3 leaves are accepted.
        assert!(RescueMerkleTree::<F>::from_elems(None, [F::from(0u64); 3]).is_ok());
        // Height 1 cannot hold 4 leaves. NOTE(review): assumes this alias has arity 3,
        // consistent with the capacity of 9 at height 2 asserted in the insertion test.
        assert!(RescueMerkleTree::<F>::from_elems(Some(1), [F::from(0u64); 4]).is_err());
    }

    #[test]
    fn test_mt_insertion() {
        test_mt_insertion_helper::<Fr254>();
        test_mt_insertion_helper::<Fr377>();
        test_mt_insertion_helper::<Fr381>();
    }

    /// Appending with `push`/`extend` up to and past capacity.
    fn test_mt_insertion_helper<F: RescueParameter>() {
        let mut mt = RescueMerkleTree::<F>::new(2);
        assert_eq!(mt.capacity(), BigUint::from(9u64));
        assert!(mt.push(F::from(2u64)).is_ok());
        assert!(mt.push(F::from(3u64)).is_ok());
        // Only 7 of these 9 elements fit. The call reports an error, but the elements
        // that fit stay inserted, which leaves the tree full.
        assert!(mt.extend(&[F::from(0u64); 9]).is_err());
        assert_eq!(mt.num_leaves(), 9u64);
        // A full tree rejects any further non-empty insertion...
        assert!(mt.push(F::from(0u64)).is_err());
        // ...but extending with nothing still succeeds.
        assert!(mt.extend(&[]).is_ok());
        assert!(mt.extend(&[F::from(1u64)]).is_err());
    }

    #[test]
    fn test_mt_lookup() {
        test_mt_lookup_helper::<Fr254>();
        test_mt_lookup_helper::<Fr377>();
        test_mt_lookup_helper::<Fr381>();
    }

    /// `lookup` returns the stored element plus a membership proof. The proof
    /// verifies only for the matching commitment, index and element.
    fn test_mt_lookup_helper<F: RescueParameter>() {
        // Single-element tree with an inferred height.
        let mt = RescueMerkleTree::<F>::from_elems(None, [F::from(0u64)]).unwrap();
        let (elem, _) = mt.lookup(0).expect_ok().unwrap();
        assert_eq!(elem, &F::from(0u64));

        let mt =
            RescueMerkleTree::<F>::from_elems(Some(2), [F::from(3u64), F::from(1u64)]).unwrap();
        let commitment = mt.commitment();
        let (elem, proof) = mt.lookup(0).expect_ok().unwrap();
        assert_eq!(elem, &F::from(3u64));
        assert_eq!(proof.height(), 2);
        // `verify(..).unwrap()` yields the verification verdict: `is_ok()` means the
        // proof was accepted, `is_err()` that it was rejected.
        assert!(
            RescueMerkleTree::<F>::verify(&commitment, 0u64, elem, &proof)
                .unwrap()
                .is_ok()
        );

        // Wrong element for the proven position.
        assert!(
            RescueMerkleTree::<F>::verify(&commitment, 0, F::from(14u64), &proof)
                .unwrap()
                .is_err()
        );

        // Correct element, wrong position.
        assert!(RescueMerkleTree::<F>::verify(&commitment, 1, elem, &proof)
            .unwrap()
            .is_err());

        // Tampered proof: overwrite the first value of the proof's first entry.
        let mut bad_proof = proof.clone();
        bad_proof.0[0][0] = F::one();

        assert!(
            RescueMerkleTree::<F>::verify(&commitment, 0, elem, &bad_proof)
                .unwrap()
                .is_err()
        );
    }

    #[test]
    fn test_mt_forget_remember() {
        test_mt_forget_remember_helper::<Fr254>();
        test_mt_forget_remember_helper::<Fr377>();
        test_mt_forget_remember_helper::<Fr381>();
    }

    /// `forget` drops a leaf from memory and returns the same element and proof as
    /// `lookup`. `remember` restores the leaf only when given the correct index,
    /// element and an untampered proof. The sequence runs for the last leaf (3) and
    /// then for the first leaf (0).
    fn test_mt_forget_remember_helper<F: RescueParameter>() {
        let mut mt = RescueMerkleTree::<F>::from_elems(
            Some(2),
            [F::from(3u64), F::from(1u64), F::from(2u64), F::from(5u64)],
        )
        .unwrap();
        let commitment = mt.commitment();
        let (&lookup_elem, mut lookup_proof) = mt.lookup(3).expect_ok().unwrap();
        let (elem, proof) = mt.forget(3).expect_ok().unwrap();
        assert_eq!(lookup_elem, elem);
        assert_eq!(lookup_proof, proof);
        assert_eq!(elem, F::from(5u64));
        assert_eq!(proof.height(), 2);
        assert!(
            RescueMerkleTree::<F>::verify(&commitment, 3, elem, &lookup_proof)
                .unwrap()
                .is_ok()
        );
        assert!(RescueMerkleTree::<F>::verify(&commitment, 3, elem, &proof)
            .unwrap()
            .is_ok());

        // A second `forget` no longer finds the leaf, and `lookup` reports it as
        // not in memory.
        assert!(mt.forget(3).expect_ok().is_err());
        assert!(matches!(mt.lookup(3), LookupResult::NotInMemory));

        // Wrong element.
        assert!(mt.remember(3, F::from(19u64), &proof).is_err());
        // Wrong index.
        assert!(mt.remember(1, elem, &proof).is_err());
        // Tampered proof.
        lookup_proof.0[0][0] = F::one();
        assert!(mt.remember(3, elem, &lookup_proof).is_err());

        // Correct inputs bring the leaf back into memory.
        assert!(mt.remember(3, elem, &proof).is_ok());
        assert!(mt.lookup(3).expect_ok().is_ok());

        // Same sequence for the first leaf.
        let (&lookup_elem, mut lookup_proof) = mt.lookup(0).expect_ok().unwrap();
        let (elem, proof) = mt.forget(0).expect_ok().unwrap();
        assert_eq!(lookup_elem, elem);
        assert_eq!(lookup_proof, proof);
        assert_eq!(elem, F::from(3u64));
        assert_eq!(proof.height(), 2);
        assert!(
            RescueMerkleTree::<F>::verify(&commitment, 0, elem, &lookup_proof)
                .unwrap()
                .is_ok()
        );
        assert!(RescueMerkleTree::<F>::verify(&commitment, 0, elem, &proof)
            .unwrap()
            .is_ok());

        assert!(mt.forget(0).expect_ok().is_err());
        assert!(matches!(mt.lookup(0), LookupResult::NotInMemory));

        // Wrong element.
        assert!(mt.remember(0, F::from(19u64), &proof).is_err());
        // Wrong index.
        assert!(mt.remember(1, elem, &proof).is_err());
        // Tampered proof.
        lookup_proof.0[0][0] = F::one();
        assert!(mt.remember(0, elem, &lookup_proof).is_err());

        assert!(mt.remember(0, elem, &proof).is_ok());
        assert!(mt.lookup(0).expect_ok().is_ok());
    }

    #[test]
    fn test_mt_serde() {
        test_mt_serde_helper::<Fr254>();
        test_mt_serde_helper::<Fr377>();
        test_mt_serde_helper::<Fr381>();
    }

    /// A tree and a proof both survive a bincode serialize/deserialize round trip.
    fn test_mt_serde_helper<F: RescueParameter>() {
        let mt =
            RescueMerkleTree::<F>::from_elems(Some(2), [F::from(3u64), F::from(1u64)]).unwrap();
        let (_, proof) = mt.lookup(0).expect_ok().unwrap();

        assert_eq!(
            mt,
            bincode::deserialize(&bincode::serialize(&mt).unwrap()).unwrap()
        );
        assert_eq!(
            proof,
            bincode::deserialize(&bincode::serialize(&proof).unwrap()).unwrap()
        );
    }

    #[test]
    fn test_mt_iter() {
        test_mt_iter_helper::<Fr254>();
        test_mt_iter_helper::<Fr377>();
        test_mt_iter_helper::<Fr381>();
    }

    /// Iteration over the in-memory leaves of both the append-only tree and the
    /// sparse tree.
    fn test_mt_iter_helper<F: RescueParameter>() {
        let mut mt = RescueMerkleTree::<F>::from_elems(
            Some(2),
            [F::from(0u64), F::from(1u64), F::from(2u64)],
        )
        .unwrap();
        // `iter` yields `(index, element)` pairs; here element i equals i.
        assert!(mt.iter().all(|(index, elem)| { elem == &F::from(*index) }));

        // Forgetting a leaf keeps the leaf count unchanged but removes the leaf
        // from iteration.
        assert!(mt.forget(1).expect_ok().is_ok());
        assert_eq!(mt.num_leaves(), 3);
        let leaves = mt.into_iter().collect::<Vec<_>>();
        assert_eq!(leaves, [(0, F::from(0u64)), (2, F::from(2u64))]);

        // Sparse tree keyed by `BigUint`.
        let kv_set = [
            (BigUint::from(64u64), F::from(32u64)),
            (BigUint::from(123u64), F::from(234u64)),
        ];
        let mut mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(10, &kv_set).unwrap();
        let kv_refs = kv_set
            .iter()
            .map(|tuple| (&tuple.0, &tuple.1))
            .collect::<Vec<_>>();
        assert_eq!(mt.iter().collect::<Vec<_>>(), kv_refs);
        // Set key 32 and forget key 123. The owning iterator then yields only the
        // in-memory entries, and the expected order here is ascending by key.
        mt.update(BigUint::from(32u64), F::from(16u64)).unwrap();
        mt.forget(BigUint::from(123u64)).expect_ok().unwrap();
        assert_eq!(
            mt.into_iter().collect::<Vec<_>>(),
            [
                (BigUint::from(32u64), F::from(16u64)),
                (BigUint::from(64u64), F::from(32u64)),
            ]
        );
    }
}