jf_merkle_tree/
universal_merkle_tree.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7//! Implementation of a typical Sparse Merkle Tree.
8use super::{
9    internal::{MerkleNode, MerkleTreeIntoIter, MerkleTreeIter, MerkleTreeProof},
10    DigestAlgorithm, Element, ForgetableMerkleTreeScheme, ForgetableUniversalMerkleTreeScheme,
11    Index, LookupResult, MerkleProof, MerkleTreeScheme, NodeValue,
12    PersistentUniversalMerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme,
13};
14use crate::{
15    errors::MerkleTreeError, impl_forgetable_merkle_tree_scheme, impl_merkle_tree_scheme,
16    VerificationResult,
17};
18use alloc::sync::Arc;
19use ark_std::{borrow::Borrow, fmt::Debug, marker::PhantomData, string::ToString, vec, vec::Vec};
20use num_bigint::BigUint;
21use num_traits::pow::pow;
22use serde::{Deserialize, Serialize};
23
24// A standard Universal Merkle tree implementation
25impl_merkle_tree_scheme!(UniversalMerkleTree);
26impl_forgetable_merkle_tree_scheme!(UniversalMerkleTree);
27
28impl<E, H, I, const ARITY: usize, T> UniversalMerkleTree<E, H, I, ARITY, T>
29where
30    E: Element,
31    H: DigestAlgorithm<E, I, T>,
32    I: Index + ToTraversalPath<ARITY>,
33    T: NodeValue,
34{
35    /// Initialize an empty Merkle tree.
36    pub fn new(height: usize) -> Self {
37        Self {
38            root: Arc::new(MerkleNode::<E, I, T>::Empty),
39            height,
40            num_leaves: 0,
41            _phantom: PhantomData,
42        }
43    }
44
45    /// Build a universal merkle tree from a key-value set.
46    /// * `height` - height of the merkle tree
47    /// * `data` - an iterator of key-value pairs. Could be a hashmap or simply
48    ///   an array or a slice of (key, value) pairs
49    pub fn from_kv_set<BI, BE>(
50        height: usize,
51        data: impl IntoIterator<Item = impl Borrow<(BI, BE)>>,
52    ) -> Result<Self, MerkleTreeError>
53    where
54        BI: Borrow<I>,
55        BE: Borrow<E>,
56    {
57        let mut mt = Self::new(height);
58        for tuple in data.into_iter() {
59            let (key, value) = tuple.borrow();
60            UniversalMerkleTreeScheme::update(&mut mt, key.borrow(), value.borrow())?;
61        }
62        Ok(mt)
63    }
64}
65impl<E, H, I, const ARITY: usize, T> UniversalMerkleTreeScheme
66    for UniversalMerkleTree<E, H, I, ARITY, T>
67where
68    E: Element,
69    H: DigestAlgorithm<E, I, T>,
70    I: Index + ToTraversalPath<ARITY>,
71    T: NodeValue,
72{
73    type NonMembershipProof = MerkleTreeProof<T>;
74    type BatchNonMembershipProof = ();
75
76    fn update_with<F>(
77        &mut self,
78        pos: impl Borrow<Self::Index>,
79        f: F,
80    ) -> Result<LookupResult<E, (), ()>, MerkleTreeError>
81    where
82        F: FnOnce(Option<&Self::Element>) -> Option<Self::Element>,
83    {
84        let pos = pos.borrow();
85        let traversal_path = pos.to_traversal_path(self.height);
86        let (new_root, delta, result) =
87            self.root
88                .update_with_internal::<H, ARITY, F>(self.height, pos, &traversal_path, f)?;
89        self.root = new_root;
90        self.num_leaves = (delta + self.num_leaves as i64) as u64;
91        Ok(result)
92    }
93
94    fn non_membership_verify(
95        commitment: impl Borrow<Self::Commitment>,
96        pos: impl Borrow<Self::Index>,
97        proof: impl Borrow<Self::NonMembershipProof>,
98    ) -> Result<VerificationResult, MerkleTreeError> {
99        crate::internal::verify_merkle_proof::<E, H, I, ARITY, T>(
100            commitment.borrow(),
101            pos.borrow(),
102            None,
103            proof.borrow().path_values(),
104        )
105    }
106
107    fn universal_lookup(
108        &self,
109        pos: impl Borrow<Self::Index>,
110    ) -> LookupResult<&Self::Element, Self::MembershipProof, Self::NonMembershipProof> {
111        let pos = pos.borrow();
112        let traversal_path = pos.to_traversal_path(self.height);
113        self.root.lookup_internal(self.height, &traversal_path)
114    }
115}
116
117impl<E, H, I, const ARITY: usize, T> PersistentUniversalMerkleTreeScheme
118    for UniversalMerkleTree<E, H, I, ARITY, T>
119where
120    E: Element,
121    H: DigestAlgorithm<E, I, T>,
122    I: Index + ToTraversalPath<ARITY>,
123    T: NodeValue,
124{
125    fn persistent_update_with<F>(
126        &self,
127        pos: impl Borrow<Self::Index>,
128        f: F,
129    ) -> Result<Self, MerkleTreeError>
130    where
131        F: FnOnce(Option<&Self::Element>) -> Option<Self::Element>,
132    {
133        let pos = pos.borrow();
134        let traversal_path = pos.to_traversal_path(self.height);
135        let (root, delta, _) =
136            self.root
137                .update_with_internal::<H, ARITY, F>(self.height, pos, &traversal_path, f)?;
138        let num_leaves = (delta + self.num_leaves as i64) as u64;
139        Ok(Self {
140            root,
141            height: self.height,
142            num_leaves,
143            _phantom: PhantomData,
144        })
145    }
146}
147
148impl<E, H, I, const ARITY: usize, T> ForgetableUniversalMerkleTreeScheme
149    for UniversalMerkleTree<E, H, I, ARITY, T>
150where
151    E: Element,
152    H: DigestAlgorithm<E, I, T>,
153    I: Index + ToTraversalPath<ARITY>,
154    T: NodeValue,
155{
156    /// WARN(#495): this method breaks non-membership proofs.
157    fn universal_forget(
158        &mut self,
159        pos: Self::Index,
160    ) -> LookupResult<Self::Element, Self::MembershipProof, Self::NonMembershipProof> {
161        let traversal_path = pos.to_traversal_path(self.height);
162        let (root, result) = self.root.forget_internal(self.height, &traversal_path);
163        self.root = root;
164        result
165    }
166
167    fn non_membership_remember(
168        &mut self,
169        pos: Self::Index,
170        proof: impl Borrow<Self::NonMembershipProof>,
171    ) -> Result<(), MerkleTreeError> {
172        let pos = pos.borrow();
173        let proof = proof.borrow();
174        if Self::non_membership_verify(&self.commitment(), pos, proof)?.is_err() {
175            Err(MerkleTreeError::InconsistentStructureError(
176                "Wrong proof".to_string(),
177            ))
178        } else {
179            let traversal_path = pos.to_traversal_path(self.height);
180            self.root = self.root.remember_internal::<H, ARITY>(
181                self.height,
182                &traversal_path,
183                pos,
184                None,
185                proof.path_values(),
186            )?;
187            Ok(())
188        }
189    }
190}
191
192#[cfg(test)]
193mod mt_tests {
194    use crate::{
195        internal::{MerkleNode, MerkleTreeProof},
196        prelude::{RescueHash, RescueSparseMerkleTree},
197        DigestAlgorithm, ForgetableMerkleTreeScheme, ForgetableUniversalMerkleTreeScheme, Index,
198        LookupResult, MerkleProof, MerkleTreeScheme, PersistentUniversalMerkleTreeScheme,
199        ToTraversalPath, UniversalMerkleTreeScheme,
200    };
201    use ark_bls12_377::Fr as Fr377;
202    use ark_bls12_381::Fr as Fr381;
203    use ark_bn254::Fr as Fr254;
204    use hashbrown::HashMap;
205    use jf_rescue::RescueParameter;
206    use num_bigint::BigUint;
207
208    #[test]
209    fn test_universal_mt_builder() {
210        test_universal_mt_builder_helper::<Fr254>();
211        test_universal_mt_builder_helper::<Fr377>();
212        test_universal_mt_builder_helper::<Fr381>();
213    }
214
215    fn test_universal_mt_builder_helper<F: RescueParameter>() {
216        let mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(
217            1,
218            [(BigUint::from(1u64), F::from(1u64))],
219        )
220        .unwrap();
221        assert_eq!(mt.num_leaves(), 1);
222
223        let mut hashmap = HashMap::new();
224        hashmap.insert(BigUint::from(1u64), F::from(2u64));
225        hashmap.insert(BigUint::from(2u64), F::from(2u64));
226        hashmap.insert(BigUint::from(1u64), F::from(3u64));
227        let mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(10, &hashmap).unwrap();
228        assert_eq!(mt.num_leaves(), hashmap.len() as u64);
229    }
230
231    #[test]
232    fn test_non_membership_lookup_and_verify() {
233        test_non_membership_lookup_and_verify_helper::<Fr254>();
234        test_non_membership_lookup_and_verify_helper::<Fr377>();
235        test_non_membership_lookup_and_verify_helper::<Fr381>();
236    }
237
238    fn test_non_membership_lookup_and_verify_helper<F: RescueParameter>() {
239        let mut hashmap = HashMap::new();
240        hashmap.insert(BigUint::from(1u64), F::from(2u64));
241        hashmap.insert(BigUint::from(2u64), F::from(2u64));
242        hashmap.insert(BigUint::from(1u64), F::from(3u64));
243        let mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(10, &hashmap).unwrap();
244        assert_eq!(mt.num_leaves(), hashmap.len() as u64);
245
246        let commitment = mt.commitment();
247
248        let mut proof = mt
249            .universal_lookup(BigUint::from(3u64))
250            .expect_not_found()
251            .unwrap();
252
253        let verify_result = RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
254            &commitment,
255            BigUint::from(3u64),
256            &proof,
257        )
258        .unwrap();
259        assert!(verify_result.is_ok());
260
261        let verify_result = RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
262            &commitment,
263            BigUint::from(1u64),
264            &proof,
265        )
266        .unwrap();
267        assert!(verify_result.is_err());
268    }
269
270    #[test]
271    fn test_update_and_lookup() {
272        test_update_and_lookup_helper::<BigUint, Fr254>();
273        test_update_and_lookup_helper::<BigUint, Fr377>();
274        test_update_and_lookup_helper::<BigUint, Fr381>();
275
276        test_update_and_lookup_helper::<Fr254, Fr254>();
277        test_update_and_lookup_helper::<Fr377, Fr377>();
278        test_update_and_lookup_helper::<Fr381, Fr381>();
279    }
280
281    fn test_update_and_lookup_helper<I, F>()
282    where
283        I: Index + ToTraversalPath<3>,
284        F: RescueParameter + ToTraversalPath<3>,
285        RescueHash<F>: DigestAlgorithm<F, I, F>,
286    {
287        let mut mt = RescueSparseMerkleTree::<F, F>::new(10);
288        for i in 0..2 {
289            mt.update(F::from(i as u64), F::from(i as u64)).unwrap();
290        }
291        let commitment = mt.commitment();
292        for i in 0..2 {
293            let (val, proof) = mt.universal_lookup(F::from(i as u64)).expect_ok().unwrap();
294            assert_eq!(val, &F::from(i as u64));
295            assert!(RescueSparseMerkleTree::<F, F>::verify(
296                &commitment,
297                F::from(i as u64),
298                val,
299                &proof
300            )
301            .unwrap()
302            .is_ok());
303        }
304        for i in 0..10 {
305            mt.update_with(F::from(i as u64), |elem| match elem {
306                Some(elem) => Some(*elem),
307                None => Some(F::from(i as u64)),
308            })
309            .unwrap();
310        }
311        assert_eq!(mt.num_leaves(), 10);
312        let commitment = mt.commitment();
313        // test lookup at index 7
314        let (val, proof) = mt.universal_lookup(F::from(7u64)).expect_ok().unwrap();
315        assert_eq!(val, &F::from(7u64));
316        assert!(
317            RescueSparseMerkleTree::<F, F>::verify(&commitment, F::from(7u64), val, &proof)
318                .unwrap()
319                .is_ok()
320        );
321
322        // Remove index 8
323        mt.update_with(F::from(8u64), |_| None).unwrap();
324        assert!(mt
325            .universal_lookup(F::from(8u64))
326            .expect_not_found()
327            .is_ok());
328        assert_eq!(mt.num_leaves(), 9);
329    }
330
331    #[test]
332    fn test_universal_mt_forget_remember() {
333        test_universal_mt_forget_remember_helper::<Fr254>();
334        test_universal_mt_forget_remember_helper::<Fr377>();
335        test_universal_mt_forget_remember_helper::<Fr381>();
336    }
337
338    fn test_universal_mt_forget_remember_helper<F: RescueParameter>() {
339        let mut mt = RescueSparseMerkleTree::<BigUint, F>::from_kv_set(
340            10,
341            [
342                (BigUint::from(0u64), F::from(1u64)),
343                (BigUint::from(2u64), F::from(3u64)),
344            ],
345        )
346        .unwrap();
347        let commitment = mt.commitment();
348
349        // Look up and forget an element that is in the tree.
350        let (lookup_elem, lookup_mem_proof) = mt
351            .universal_lookup(BigUint::from(0u64))
352            .expect_ok()
353            .unwrap();
354        let lookup_elem = *lookup_elem;
355        let (elem, mem_proof) = mt.universal_forget(0u64.into()).expect_ok().unwrap();
356        assert_eq!(lookup_elem, elem);
357        assert_eq!(lookup_mem_proof, mem_proof);
358        assert_eq!(elem, 1u64.into());
359        assert_eq!(mem_proof.height(), 10);
360        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
361            &commitment,
362            BigUint::from(0u64),
363            &elem,
364            &lookup_mem_proof
365        )
366        .unwrap()
367        .is_ok());
368        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
369            &commitment,
370            BigUint::from(0u64),
371            &elem,
372            &mem_proof
373        )
374        .unwrap()
375        .is_ok());
376
377        // Forgetting or looking up an element that is already forgotten should fail.
378        assert!(matches!(
379            mt.universal_forget(0u64.into()),
380            LookupResult::NotInMemory
381        ));
382        assert!(matches!(
383            mt.universal_lookup(BigUint::from(0u64)),
384            LookupResult::NotInMemory
385        ));
386
387        // We should still be able to look up an element that is not forgotten.
388        let (elem, proof) = mt
389            .universal_lookup(BigUint::from(2u64))
390            .expect_ok()
391            .unwrap();
392        assert_eq!(elem, &3u64.into());
393        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
394            &commitment,
395            BigUint::from(2u64),
396            elem,
397            &proof
398        )
399        .unwrap()
400        .is_ok());
401
402        // Look up and forget an empty sub-tree.
403        let lookup_non_mem_proof = match mt.universal_lookup(BigUint::from(1u64)) {
404            LookupResult::NotFound(proof) => proof,
405            res => panic!("expected NotFound, got {:?}", res),
406        };
407        let non_mem_proof = match mt.universal_forget(BigUint::from(1u64)) {
408            LookupResult::NotFound(proof) => proof,
409            res => panic!("expected NotFound, got {:?}", res),
410        };
411        assert_eq!(lookup_non_mem_proof, non_mem_proof);
412        assert_eq!(non_mem_proof.height(), 10);
413        assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
414            &commitment,
415            BigUint::from(1u64),
416            &lookup_non_mem_proof
417        )
418        .unwrap()
419        .is_ok());
420        assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
421            &commitment,
422            BigUint::from(1u64),
423            &non_mem_proof
424        )
425        .unwrap()
426        .is_ok());
427
428        // Forgetting an empty sub-tree will never actually cause any new entries to be
429        // forgotten, since empty sub-trees are _already_ treated as if they
430        // were forgotten when deciding whether to forget their parent. So even
431        // though we "forgot" it, the empty sub-tree is still in memory.
432        match mt.universal_lookup(BigUint::from(1u64)) {
433            LookupResult::NotFound(proof) => {
434                assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
435                    &commitment,
436                    BigUint::from(1u64),
437                    &proof
438                )
439                .unwrap()
440                .is_ok());
441            },
442            res => {
443                panic!("expected NotFound, got {:?}", res);
444            },
445        }
446
447        // We should still be able to look up an element that is not forgotten.
448        let (elem, proof) = mt
449            .universal_lookup(BigUint::from(2u64))
450            .expect_ok()
451            .unwrap();
452        assert_eq!(elem, &3u64.into());
453        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
454            &commitment,
455            BigUint::from(2u64),
456            elem,
457            &proof
458        )
459        .unwrap()
460        .is_ok());
461
462        // Now if we forget the last entry, which is the only thing keeping the root
463        // branch in memory, every entry will be forgotten.
464        mt.universal_forget(BigUint::from(2u64))
465            .expect_ok()
466            .unwrap();
467        assert!(matches!(
468            mt.universal_lookup(BigUint::from(0u64)),
469            LookupResult::NotInMemory
470        ));
471        assert!(matches!(
472            mt.universal_lookup(BigUint::from(1u64)),
473            LookupResult::NotInMemory
474        ));
475        assert!(matches!(
476            mt.universal_lookup(BigUint::from(2u64)),
477            LookupResult::NotInMemory
478        ));
479
480        // Remember should fail if the proof is invalid.
481        mt.remember(BigUint::from(0u64), F::from(2u64), &mem_proof)
482            .unwrap_err();
483        mt.remember(BigUint::from(1u64), F::from(1u64), &mem_proof)
484            .unwrap_err();
485        let mut bad_mem_proof = mem_proof.clone();
486        bad_mem_proof.0[0][0] = F::one();
487        mt.remember(BigUint::from(0u64), F::from(1u64), &bad_mem_proof)
488            .unwrap_err();
489
490        mt.non_membership_remember(0u64.into(), &non_mem_proof)
491            .unwrap_err();
492        let mut bad_non_mem_proof = non_mem_proof.clone();
493        bad_non_mem_proof.0[0][0] = F::one();
494        mt.non_membership_remember(1u64.into(), &bad_non_mem_proof)
495            .unwrap_err();
496
497        // Remember an occupied and an empty  sub-tree.
498        mt.remember(BigUint::from(0u64), F::from(1u64), &mem_proof)
499            .unwrap();
500        mt.non_membership_remember(1u64.into(), &non_mem_proof)
501            .unwrap();
502
503        // We should be able to look up everything we remembered.
504        let (elem, proof) = mt
505            .universal_lookup(BigUint::from(0u64))
506            .expect_ok()
507            .unwrap();
508        assert_eq!(elem, &1u64.into());
509        assert!(RescueSparseMerkleTree::<BigUint, F>::verify(
510            &commitment,
511            BigUint::from(0u64),
512            elem,
513            &proof
514        )
515        .unwrap()
516        .is_ok());
517
518        match mt.universal_lookup(BigUint::from(1u64)) {
519            LookupResult::NotFound(proof) => {
520                assert!(RescueSparseMerkleTree::<BigUint, F>::non_membership_verify(
521                    &commitment,
522                    BigUint::from(1u64),
523                    &proof
524                )
525                .unwrap()
526                .is_ok());
527            },
528            res => {
529                panic!("expected NotFound, got {:?}", res);
530            },
531        }
532    }
533
534    #[test]
535    fn test_persistent_update() {
536        test_persistent_update_helper::<BigUint, Fr254>();
537        test_persistent_update_helper::<BigUint, Fr377>();
538        test_persistent_update_helper::<BigUint, Fr381>();
539
540        test_persistent_update_helper::<Fr254, Fr254>();
541        test_persistent_update_helper::<Fr377, Fr377>();
542        test_persistent_update_helper::<Fr381, Fr381>();
543    }
544
545    fn test_persistent_update_helper<I, F>()
546    where
547        I: Index + ToTraversalPath<3>,
548        F: RescueParameter + ToTraversalPath<3>,
549        RescueHash<F>: DigestAlgorithm<F, I, F>,
550    {
551        let mt = RescueSparseMerkleTree::<F, F>::new(10);
552        let mut mts = ark_std::vec![mt];
553        for i in 1..10u64 {
554            mts.push(
555                mts.last()
556                    .unwrap()
557                    .persistent_update(F::from(i), F::from(i))
558                    .unwrap(),
559            );
560            assert_eq!(mts.last().unwrap().num_leaves(), i);
561        }
562        for i in 1..10u64 {
563            mts.iter().enumerate().for_each(|(j, mt)| {
564                if j as u64 >= i {
565                    assert!(mt.lookup(F::from(i)).expect_ok().is_ok());
566                } else {
567                    assert!(mt.lookup(F::from(i)).expect_not_found().is_ok());
568                }
569            });
570        }
571
572        assert_eq!(mts[5].num_leaves(), 5);
573        let mt = mts[5].persistent_remove(F::from(1u64)).unwrap();
574        assert_eq!(mt.num_leaves(), 4);
575    }
576
577    #[test]
578    fn test_universal_mt_serde() {
579        test_universal_mt_serde_helper::<Fr254>();
580        test_universal_mt_serde_helper::<Fr377>();
581        test_universal_mt_serde_helper::<Fr381>();
582    }
583
584    fn test_universal_mt_serde_helper<F: RescueParameter + ToTraversalPath<3>>() {
585        let mut hashmap = HashMap::new();
586        hashmap.insert(F::from(1u64), F::from(2u64));
587        hashmap.insert(F::from(10u64), F::from(3u64));
588        let mt = RescueSparseMerkleTree::<F, F>::from_kv_set(3, &hashmap).unwrap();
589        let (_, mem_proof) = mt.lookup(F::from(10u64)).expect_ok().unwrap();
590        // let node = (F::from(10u64), elem.clone());
591        let non_mem_proof = match mt.universal_lookup(F::from(9u64)) {
592            LookupResult::NotFound(proof) => proof,
593            res => panic!("expected NotFound, got {:?}", res),
594        };
595
596        assert_eq!(
597            mt,
598            bincode::deserialize(&bincode::serialize(&mt).unwrap()).unwrap()
599        );
600        assert_eq!(
601            mem_proof,
602            bincode::deserialize(&bincode::serialize(&mem_proof).unwrap()).unwrap()
603        );
604        assert_eq!(
605            non_mem_proof,
606            bincode::deserialize(&bincode::serialize(&non_mem_proof).unwrap()).unwrap()
607        );
608    }
609}