1use super::{
9 internal::{MerkleNode, MerkleTreeIntoIter, MerkleTreeIter, MerkleTreeProof},
10 DigestAlgorithm, Element, ForgetableMerkleTreeScheme, ForgetableUniversalMerkleTreeScheme,
11 Index, LookupResult, MerkleProof, MerkleTreeScheme, NodeValue,
12 PersistentUniversalMerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme,
13};
14use crate::{
15 errors::MerkleTreeError, impl_forgetable_merkle_tree_scheme, impl_merkle_tree_scheme,
16 VerificationResult,
17};
18use alloc::sync::Arc;
19use ark_std::{borrow::Borrow, fmt::Debug, marker::PhantomData, string::ToString, vec, vec::Vec};
20use num_bigint::BigUint;
21use num_traits::pow::pow;
22use serde::{Deserialize, Serialize};
23
// NOTE(review): `impl_merkle_tree_scheme!` is a crate macro defined outside this file. The
// code below reads and writes the `root`, `height`, `num_leaves` and `_phantom` fields of
// `UniversalMerkleTree` without declaring the struct here, so this invocation presumably
// declares the struct and its `MerkleTreeScheme` impl — confirm against the macro.
impl_merkle_tree_scheme!(UniversalMerkleTree);
// NOTE(review): presumably provides the `ForgetableMerkleTreeScheme` impl (membership
// `forget`/`remember`) for the same type — TODO confirm against the macro definition.
impl_forgetable_merkle_tree_scheme!(UniversalMerkleTree);
27
28impl<E, H, I, const ARITY: usize, T> UniversalMerkleTree<E, H, I, ARITY, T>
29where
30 E: Element,
31 H: DigestAlgorithm<E, I, T>,
32 I: Index + ToTraversalPath<ARITY>,
33 T: NodeValue,
34{
35 pub fn new(height: usize) -> Self {
37 Self {
38 root: Arc::new(MerkleNode::<E, I, T>::Empty),
39 height,
40 num_leaves: 0,
41 _phantom: PhantomData,
42 }
43 }
44
45 pub fn from_kv_set<BI, BE>(
50 height: usize,
51 data: impl IntoIterator<Item = impl Borrow<(BI, BE)>>,
52 ) -> Result<Self, MerkleTreeError>
53 where
54 BI: Borrow<I>,
55 BE: Borrow<E>,
56 {
57 let mut mt = Self::new(height);
58 for tuple in data.into_iter() {
59 let (key, value) = tuple.borrow();
60 UniversalMerkleTreeScheme::update(&mut mt, key.borrow(), value.borrow())?;
61 }
62 Ok(mt)
63 }
64}
65impl<E, H, I, const ARITY: usize, T> UniversalMerkleTreeScheme
66 for UniversalMerkleTree<E, H, I, ARITY, T>
67where
68 E: Element,
69 H: DigestAlgorithm<E, I, T>,
70 I: Index + ToTraversalPath<ARITY>,
71 T: NodeValue,
72{
73 type NonMembershipProof = MerkleTreeProof<T>;
74 type BatchNonMembershipProof = ();
75
76 fn update_with<F>(
77 &mut self,
78 pos: impl Borrow<Self::Index>,
79 f: F,
80 ) -> Result<LookupResult<E, (), ()>, MerkleTreeError>
81 where
82 F: FnOnce(Option<&Self::Element>) -> Option<Self::Element>,
83 {
84 let pos = pos.borrow();
85 let traversal_path = pos.to_traversal_path(self.height);
86 let (new_root, delta, result) =
87 self.root
88 .update_with_internal::<H, ARITY, F>(self.height, pos, &traversal_path, f)?;
89 self.root = new_root;
90 self.num_leaves = (delta + self.num_leaves as i64) as u64;
91 Ok(result)
92 }
93
94 fn non_membership_verify(
95 commitment: impl Borrow<Self::Commitment>,
96 pos: impl Borrow<Self::Index>,
97 proof: impl Borrow<Self::NonMembershipProof>,
98 ) -> Result<VerificationResult, MerkleTreeError> {
99 crate::internal::verify_merkle_proof::<E, H, I, ARITY, T>(
100 commitment.borrow(),
101 pos.borrow(),
102 None,
103 proof.borrow().path_values(),
104 )
105 }
106
107 fn universal_lookup(
108 &self,
109 pos: impl Borrow<Self::Index>,
110 ) -> LookupResult<&Self::Element, Self::MembershipProof, Self::NonMembershipProof> {
111 let pos = pos.borrow();
112 let traversal_path = pos.to_traversal_path(self.height);
113 self.root.lookup_internal(self.height, &traversal_path)
114 }
115}
116
117impl<E, H, I, const ARITY: usize, T> PersistentUniversalMerkleTreeScheme
118 for UniversalMerkleTree<E, H, I, ARITY, T>
119where
120 E: Element,
121 H: DigestAlgorithm<E, I, T>,
122 I: Index + ToTraversalPath<ARITY>,
123 T: NodeValue,
124{
125 fn persistent_update_with<F>(
126 &self,
127 pos: impl Borrow<Self::Index>,
128 f: F,
129 ) -> Result<Self, MerkleTreeError>
130 where
131 F: FnOnce(Option<&Self::Element>) -> Option<Self::Element>,
132 {
133 let pos = pos.borrow();
134 let traversal_path = pos.to_traversal_path(self.height);
135 let (root, delta, _) =
136 self.root
137 .update_with_internal::<H, ARITY, F>(self.height, pos, &traversal_path, f)?;
138 let num_leaves = (delta + self.num_leaves as i64) as u64;
139 Ok(Self {
140 root,
141 height: self.height,
142 num_leaves,
143 _phantom: PhantomData,
144 })
145 }
146}
147
148impl<E, H, I, const ARITY: usize, T> ForgetableUniversalMerkleTreeScheme
149 for UniversalMerkleTree<E, H, I, ARITY, T>
150where
151 E: Element,
152 H: DigestAlgorithm<E, I, T>,
153 I: Index + ToTraversalPath<ARITY>,
154 T: NodeValue,
155{
156 fn universal_forget(
158 &mut self,
159 pos: Self::Index,
160 ) -> LookupResult<Self::Element, Self::MembershipProof, Self::NonMembershipProof> {
161 let traversal_path = pos.to_traversal_path(self.height);
162 let (root, result) = self.root.forget_internal(self.height, &traversal_path);
163 self.root = root;
164 result
165 }
166
167 fn non_membership_remember(
168 &mut self,
169 pos: Self::Index,
170 proof: impl Borrow<Self::NonMembershipProof>,
171 ) -> Result<(), MerkleTreeError> {
172 let pos = pos.borrow();
173 let proof = proof.borrow();
174 if Self::non_membership_verify(&self.commitment(), pos, proof)?.is_err() {
175 Err(MerkleTreeError::InconsistentStructureError(
176 "Wrong proof".to_string(),
177 ))
178 } else {
179 let traversal_path = pos.to_traversal_path(self.height);
180 self.root = self.root.remember_internal::<H, ARITY>(
181 self.height,
182 &traversal_path,
183 pos,
184 None,
185 proof.path_values(),
186 )?;
187 Ok(())
188 }
189 }
190}
191
#[cfg(test)]
mod mt_tests {
    use crate::{
        internal::{MerkleNode, MerkleTreeProof},
        prelude::{RescueHash, RescueSparseMerkleTree},
        DigestAlgorithm, ForgetableMerkleTreeScheme, ForgetableUniversalMerkleTreeScheme, Index,
        LookupResult, MerkleProof, MerkleTreeScheme, PersistentUniversalMerkleTreeScheme,
        ToTraversalPath, UniversalMerkleTreeScheme,
    };
    use ark_bls12_377::Fr as Fr377;
    use ark_bls12_381::Fr as Fr381;
    use ark_bn254::Fr as Fr254;
    use hashbrown::HashMap;
    use jf_rescue::RescueParameter;
    use num_bigint::BigUint;

    #[test]
    fn test_universal_mt_builder() {
        test_universal_mt_builder_helper::<Fr254>();
        test_universal_mt_builder_helper::<Fr377>();
        test_universal_mt_builder_helper::<Fr381>();
    }

    fn test_universal_mt_builder_helper<F: RescueParameter>() {
        let mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(
            1,
            [(BigUint::from(1u64), F::from(1u64))],
        )
        .unwrap();
        assert_eq!(mt.num_leaves(), 1);

        // Re-inserting key 1 overwrites it in the map, so the tree must end up with exactly
        // `hashmap.len()` leaves.
        let mut hashmap = HashMap::new();
        hashmap.insert(BigUint::from(1u64), F::from(2u64));
        hashmap.insert(BigUint::from(2u64), F::from(2u64));
        hashmap.insert(BigUint::from(1u64), F::from(3u64));
        let mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(10, &hashmap).unwrap();
        assert_eq!(mt.num_leaves(), hashmap.len() as u64);
    }

    #[test]
    fn test_non_membership_lookup_and_verify() {
        test_non_membership_lookup_and_verify_helper::<Fr254>();
        test_non_membership_lookup_and_verify_helper::<Fr377>();
        test_non_membership_lookup_and_verify_helper::<Fr381>();
    }

    fn test_non_membership_lookup_and_verify_helper<F: RescueParameter>() {
        let mut hashmap = HashMap::new();
        hashmap.insert(BigUint::from(1u64), F::from(2u64));
        hashmap.insert(BigUint::from(2u64), F::from(2u64));
        hashmap.insert(BigUint::from(1u64), F::from(3u64));
        let mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(10, &hashmap).unwrap();
        assert_eq!(mt.num_leaves(), hashmap.len() as u64);

        let commitment = mt.commitment();

        let proof = mt
            .universal_lookup(BigUint::from(3u64))
            .expect_not_found()
            .unwrap();

        // The proof is valid for the empty index it was produced for...
        let verify_result = RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
            &commitment,
            BigUint::from(3u64),
            &proof,
        )
        .unwrap();
        assert!(verify_result.is_ok());

        // ...and must be rejected for an index that is actually occupied.
        let verify_result = RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
            &commitment,
            BigUint::from(1u64),
            &proof,
        )
        .unwrap();
        assert!(verify_result.is_err());
    }

    #[test]
    fn test_update_and_lookup() {
        test_update_and_lookup_helper::<BigUint, Fr254>();
        test_update_and_lookup_helper::<BigUint, Fr377>();
        test_update_and_lookup_helper::<BigUint, Fr381>();

        test_update_and_lookup_helper::<Fr254, Fr254>();
        test_update_and_lookup_helper::<Fr377, Fr377>();
        test_update_and_lookup_helper::<Fr381, Fr381>();
    }

    // The tree is indexed by `I` (not by the field `F`) so that each instantiation really
    // exercises its own index type.
    fn test_update_and_lookup_helper<I, F>()
    where
        I: Index + ToTraversalPath<3> + From<u64>,
        F: RescueParameter,
        RescueHash<F>: DigestAlgorithm<F, I, F>,
    {
        let mut mt = RescueSparseMerkleTree::<I, F>::new(10);
        for i in 0..2u64 {
            mt.update(I::from(i), F::from(i)).unwrap();
        }
        let commitment = mt.commitment();
        for i in 0..2u64 {
            let (val, proof) = mt.universal_lookup(I::from(i)).expect_ok().unwrap();
            assert_eq!(val, &F::from(i));
            assert!(RescueSparseMerkleTree::<I, F>::verify(
                &commitment,
                I::from(i),
                val,
                &proof
            )
            .unwrap()
            .is_ok());
        }
        // Keep existing values, insert the missing indices: the tree ends with 10 leaves.
        for i in 0..10u64 {
            mt.update_with(I::from(i), |elem| match elem {
                Some(elem) => Some(*elem),
                None => Some(F::from(i)),
            })
            .unwrap();
        }
        assert_eq!(mt.num_leaves(), 10);
        let commitment = mt.commitment();
        let (val, proof) = mt.universal_lookup(I::from(7u64)).expect_ok().unwrap();
        assert_eq!(val, &F::from(7u64));
        assert!(
            RescueSparseMerkleTree::<I, F>::verify(&commitment, I::from(7u64), val, &proof)
                .unwrap()
                .is_ok()
        );

        // Returning `None` from the closure removes the leaf.
        mt.update_with(I::from(8u64), |_| None).unwrap();
        assert!(mt
            .universal_lookup(I::from(8u64))
            .expect_not_found()
            .is_ok());
        assert_eq!(mt.num_leaves(), 9);
    }

    #[test]
    fn test_universal_mt_forget_remember() {
        test_universal_mt_forget_remember_helper::<Fr254>();
        test_universal_mt_forget_remember_helper::<Fr377>();
        test_universal_mt_forget_remember_helper::<Fr381>();
    }

    fn test_universal_mt_forget_remember_helper<F: RescueParameter>() {
        let mut mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(
            10,
            [
                (BigUint::from(0u64), F::from(1u64)),
                (BigUint::from(2u64), F::from(3u64)),
            ],
        )
        .unwrap();
        let commitment = mt.commitment();

        // Forgetting an occupied index returns the same element and proof as a lookup.
        let (lookup_elem, lookup_mem_proof) = mt
            .universal_lookup(BigUint::from(0u64))
            .expect_ok()
            .unwrap();
        let lookup_elem = *lookup_elem;
        let (elem, mem_proof) = mt.universal_forget(0u64.into()).expect_ok().unwrap();
        assert_eq!(lookup_elem, elem);
        assert_eq!(lookup_mem_proof, mem_proof);
        assert_eq!(elem, 1u64.into());
        assert_eq!(mem_proof.height(), 10);
        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
            &commitment,
            BigUint::from(0u64),
            &elem,
            &lookup_mem_proof
        )
        .unwrap()
        .is_ok());
        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
            &commitment,
            BigUint::from(0u64),
            &elem,
            &mem_proof
        )
        .unwrap()
        .is_ok());

        // A forgotten index is no longer in memory.
        assert!(matches!(
            mt.universal_forget(0u64.into()),
            LookupResult::NotInMemory
        ));
        assert!(matches!(
            mt.universal_lookup(BigUint::from(0u64)),
            LookupResult::NotInMemory
        ));

        // A sibling index is unaffected.
        let (elem, proof) = mt
            .universal_lookup(BigUint::from(2u64))
            .expect_ok()
            .unwrap();
        assert_eq!(elem, &3u64.into());
        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
            &commitment,
            BigUint::from(2u64),
            elem,
            &proof
        )
        .unwrap()
        .is_ok());

        // Forgetting an empty index yields a non-membership proof.
        let lookup_non_mem_proof = match mt.universal_lookup(BigUint::from(1u64)) {
            LookupResult::NotFound(proof) => proof,
            res => panic!("expected NotFound, got {:?}", res),
        };
        let non_mem_proof = match mt.universal_forget(BigUint::from(1u64)) {
            LookupResult::NotFound(proof) => proof,
            res => panic!("expected NotFound, got {:?}", res),
        };
        assert_eq!(lookup_non_mem_proof, non_mem_proof);
        assert_eq!(non_mem_proof.height(), 10);
        assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
            &commitment,
            BigUint::from(1u64),
            &lookup_non_mem_proof
        )
        .unwrap()
        .is_ok());
        assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
            &commitment,
            BigUint::from(1u64),
            &non_mem_proof
        )
        .unwrap()
        .is_ok());

        match mt.universal_lookup(BigUint::from(1u64)) {
            LookupResult::NotFound(proof) => {
                assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
                    &commitment,
                    BigUint::from(1u64),
                    &proof
                )
                .unwrap()
                .is_ok());
            },
            res => {
                panic!("expected NotFound, got {:?}", res);
            },
        }

        let (elem, proof) = mt
            .universal_lookup(BigUint::from(2u64))
            .expect_ok()
            .unwrap();
        assert_eq!(elem, &3u64.into());
        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
            &commitment,
            BigUint::from(2u64),
            elem,
            &proof
        )
        .unwrap()
        .is_ok());

        // After forgetting index 2 as well, none of the indices 0..=2 are in memory.
        mt.universal_forget(BigUint::from(2u64))
            .expect_ok()
            .unwrap();
        assert!(matches!(
            mt.universal_lookup(BigUint::from(0u64)),
            LookupResult::NotInMemory
        ));
        assert!(matches!(
            mt.universal_lookup(BigUint::from(1u64)),
            LookupResult::NotInMemory
        ));
        assert!(matches!(
            mt.universal_lookup(BigUint::from(2u64)),
            LookupResult::NotInMemory
        ));

        // Remembering with a wrong element, wrong index or tampered proof must fail.
        mt.remember(BigUint::from(0u64), F::from(2u64), &mem_proof)
            .unwrap_err();
        mt.remember(BigUint::from(1u64), F::from(1u64), &mem_proof)
            .unwrap_err();
        let mut bad_mem_proof = mem_proof.clone();
        bad_mem_proof.0[0][0] = F::one();
        mt.remember(BigUint::from(0u64), F::from(1u64), &bad_mem_proof)
            .unwrap_err();

        mt.non_membership_remember(0u64.into(), &non_mem_proof)
            .unwrap_err();
        let mut bad_non_mem_proof = non_mem_proof.clone();
        bad_non_mem_proof.0[0][0] = F::one();
        mt.non_membership_remember(1u64.into(), &bad_non_mem_proof)
            .unwrap_err();

        // Remembering with the genuine proofs restores both positions.
        mt.remember(BigUint::from(0u64), F::from(1u64), &mem_proof)
            .unwrap();
        mt.non_membership_remember(1u64.into(), &non_mem_proof)
            .unwrap();

        let (elem, proof) = mt
            .universal_lookup(BigUint::from(0u64))
            .expect_ok()
            .unwrap();
        assert_eq!(elem, &1u64.into());
        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
            &commitment,
            BigUint::from(0u64),
            elem,
            &proof
        )
        .unwrap()
        .is_ok());

        match mt.universal_lookup(BigUint::from(1u64)) {
            LookupResult::NotFound(proof) => {
                assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
                    &commitment,
                    BigUint::from(1u64),
                    &proof
                )
                .unwrap()
                .is_ok());
            },
            res => {
                panic!("expected NotFound, got {:?}", res);
            },
        }
    }

    #[test]
    fn test_persistent_update() {
        test_persistent_update_helper::<BigUint, Fr254>();
        test_persistent_update_helper::<BigUint, Fr377>();
        test_persistent_update_helper::<BigUint, Fr381>();

        test_persistent_update_helper::<Fr254, Fr254>();
        test_persistent_update_helper::<Fr377, Fr377>();
        test_persistent_update_helper::<Fr381, Fr381>();
    }

    // As in `test_update_and_lookup_helper`, the tree is indexed by `I` so that the
    // `BigUint` instantiations actually test `BigUint` indices.
    fn test_persistent_update_helper<I, F>()
    where
        I: Index + ToTraversalPath<3> + From<u64>,
        F: RescueParameter,
        RescueHash<F>: DigestAlgorithm<F, I, F>,
    {
        let mt = RescueSparseMerkleTree::<I, F>::new(10);
        let mut mts = ark_std::vec![mt];
        for i in 1..10u64 {
            mts.push(
                mts.last()
                    .unwrap()
                    .persistent_update(I::from(i), F::from(i))
                    .unwrap(),
            );
            assert_eq!(mts.last().unwrap().num_leaves(), i);
        }
        // Version `j` must contain exactly the indices `1..=j`; older versions are unchanged.
        for i in 1..10u64 {
            mts.iter().enumerate().for_each(|(j, mt)| {
                if j as u64 >= i {
                    assert!(mt.lookup(I::from(i)).expect_ok().is_ok());
                } else {
                    assert!(mt.lookup(I::from(i)).expect_not_found().is_ok());
                }
            });
        }

        assert_eq!(mts[5].num_leaves(), 5);
        let mt = mts[5].persistent_remove(I::from(1u64)).unwrap();
        assert_eq!(mt.num_leaves(), 4);
    }

    #[test]
    fn test_universal_mt_serde() {
        test_universal_mt_serde_helper::<Fr254>();
        test_universal_mt_serde_helper::<Fr377>();
        test_universal_mt_serde_helper::<Fr381>();
    }

    fn test_universal_mt_serde_helper<F: RescueParameter + ToTraversalPath<3>>() {
        let mut hashmap = HashMap::new();
        hashmap.insert(F::from(1u64), F::from(2u64));
        hashmap.insert(F::from(10u64), F::from(3u64));
        let mt = RescueSparseMerkleTree::<F, F>::from_kv_set(3, &hashmap).unwrap();
        let (_, mem_proof) = mt.lookup(F::from(10u64)).expect_ok().unwrap();
        let non_mem_proof = match mt.universal_lookup(F::from(9u64)) {
            LookupResult::NotFound(proof) => proof,
            res => panic!("expected NotFound, got {:?}", res),
        };

        // The tree and both proof kinds must round-trip through bincode unchanged.
        assert_eq!(
            mt,
            bincode::deserialize(&bincode::serialize(&mt).unwrap()).unwrap()
        );
        assert_eq!(
            mem_proof,
            bincode::deserialize(&bincode::serialize(&mem_proof).unwrap()).unwrap()
        );
        assert_eq!(
            non_mem_proof,
            bincode::deserialize(&bincode::serialize(&non_mem_proof).unwrap()).unwrap()
        );
    }
}