jf_merkle_tree/
internal.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7use super::{DigestAlgorithm, Element, Index, LookupResult, NodeValue, ToTraversalPath};
8use crate::{errors::MerkleTreeError, prelude::MerkleTree, VerificationResult, FAIL, SUCCESS};
9use alloc::sync::Arc;
10use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
11use ark_std::{borrow::Borrow, format, iter::Peekable, string::ToString, vec, vec::Vec};
12use derivative::Derivative;
13use itertools::Itertools;
14use jf_utils::canonical;
15use num_bigint::BigUint;
16use serde::{Deserialize, Serialize};
17use tagged_base64::tagged;
18
/// An internal Merkle node.
///
/// A node is one of four shapes: an empty subtree, a branch that owns its
/// children, a leaf that owns an element, or a forgotten subtree for which
/// only the hash value is kept.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(bound = "E: CanonicalSerialize + CanonicalDeserialize,
                 I: CanonicalSerialize + CanonicalDeserialize,")]
pub enum MerkleNode<E: Element, I: Index, T: NodeValue> {
    /// An empty subtree.
    Empty,
    /// An internal branching node
    Branch {
        /// Merkle hash value of this subtree
        #[serde(with = "canonical")]
        value: T,
        /// All its children, shared through reference counting
        children: Vec<Arc<MerkleNode<E, I, T>>>,
    },
    /// A leaf node
    Leaf {
        /// Merkle hash value of this leaf
        #[serde(with = "canonical")]
        value: T,
        /// Index of this leaf
        #[serde(with = "canonical")]
        pos: I,
        /// Associated element of this leaf
        #[serde(with = "canonical")]
        elem: E,
    },
    /// The subtree is forgotten from the memory; only its hash is retained
    ForgottenSubtree {
        /// Merkle hash value of this forgotten subtree
        #[serde(with = "canonical")]
        value: T,
    },
}
53
54impl<E, I, T> MerkleNode<E, I, T>
55where
56    E: Element,
57    I: Index,
58    T: NodeValue,
59{
60    /// Return the value of this [`MerkleNode`].
61    #[inline]
62    pub(crate) fn value(&self) -> T {
63        match self {
64            Self::Empty => T::default(),
65            Self::Leaf {
66                value,
67                pos: _,
68                elem: _,
69            } => *value,
70            Self::Branch { value, children: _ } => *value,
71            Self::ForgottenSubtree { value } => *value,
72        }
73    }
74
75    #[inline]
76    pub(crate) fn is_forgotten(&self) -> bool {
77        matches!(self, Self::ForgottenSubtree { .. })
78    }
79}
80
/// A (non)membership Merkle proof consists of all values of siblings of a
/// Merkle path.
///
/// The wrapped vector has one entry per tree level, ordered from the level
/// closest to the leaf up to the root; each entry lists the values of the
/// siblings at that level. In a non-membership proof an entry may be empty
/// when the corresponding subtree is empty.
#[derive(
    Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, CanonicalSerialize, CanonicalDeserialize,
)]
#[tagged("MERKLE_PROOF")]
pub struct MerkleTreeProof<T: NodeValue>(pub Vec<Vec<T>>);
88
89impl<T: NodeValue> super::MerkleProof<T> for MerkleTreeProof<T> {
90    /// Expected height of the Merkle tree.
91    fn height(&self) -> usize {
92        self.0.len()
93    }
94
95    /// Return all values of siblings of this Merkle path
96    fn path_values(&self) -> &[Vec<T>] {
97        &self.0
98    }
99}
100
101/// Verify a merkle proof
102/// * `commitment` - a merkle tree commitment
103/// * `pos` - zero-based index of the leaf in the tree
104/// * `element` - the leaf value, None if verifying a non-membership proof
105/// * `proof` - a membership proof for `element` at given `pos`
106/// * `returns` - Ok(true) if the proof is accepted, Ok(false) if not. Err() if
107///   the proof is not well structured, E.g. not for this merkle tree.
108pub(crate) fn verify_merkle_proof<E, H, I, const ARITY: usize, T>(
109    commitment: &T,
110    pos: &I,
111    element: Option<&E>,
112    proof: &[Vec<T>],
113) -> Result<VerificationResult, MerkleTreeError>
114where
115    E: Element,
116    I: Index + ToTraversalPath<ARITY>,
117    T: NodeValue,
118    H: DigestAlgorithm<E, I, T>,
119{
120    let init = if let Some(elem) = element {
121        H::digest_leaf(pos, elem)?
122    } else {
123        T::default()
124    };
125    match element {
126        // only strictly checking this during membership proof
127        // in non-membership proof, empty leaf/branch node can have empty vector in
128        // their merkle proof
129        Some(_) => {
130            if proof.iter().any(|v| v.len() != ARITY - 1) {
131                return Err(MerkleTreeError::InconsistentStructureError(
132                    "Malformed proof".to_string(),
133                ));
134            }
135        },
136        None => {
137            if !proof.iter().all(|v| v.len() == ARITY - 1 || v.is_empty()) {
138                return Err(MerkleTreeError::InconsistentStructureError(
139                    "Malformed proof".to_string(),
140                ));
141            }
142        },
143    };
144    if element.is_some() && proof.iter().any(|v| v.len() != ARITY - 1) {
145        // only strictly checking this during membership proof
146        // in non-membership proof, empty leaf/branch node can have empty vector in
147        // their merkle proof
148        return Err(MerkleTreeError::InconsistentStructureError(
149            "Malformed proof".to_string(),
150        ));
151    }
152    let mut data = [T::default(); ARITY];
153    let computed_root = pos
154        .to_traversal_path(proof.len())
155        .iter()
156        .zip(proof.iter())
157        .try_fold(
158            init,
159            |val, (branch, values)| -> Result<T, MerkleTreeError> {
160                if values.len() == 0 {
161                    Ok(T::default())
162                } else {
163                    data[..*branch].copy_from_slice(&values[..*branch]);
164                    data[*branch] = val;
165                    data[*branch + 1..].copy_from_slice(&values[*branch..]);
166                    H::digest(&data)
167                }
168            },
169        )?;
170    if computed_root == *commitment {
171        Ok(SUCCESS)
172    } else {
173        Ok(FAIL)
174    }
175}
176
177#[allow(clippy::type_complexity)]
178pub(crate) fn build_tree_internal<E, H, const ARITY: usize, T>(
179    height: Option<usize>,
180    elems: impl IntoIterator<Item = impl Borrow<E>>,
181) -> Result<(Arc<MerkleNode<E, u64, T>>, usize, u64), MerkleTreeError>
182where
183    E: Element,
184    H: DigestAlgorithm<E, u64, T>,
185    T: NodeValue,
186{
187    let leaves: Vec<_> = elems.into_iter().collect();
188    let num_leaves = leaves.len() as u64;
189    let height = height.unwrap_or_else(|| {
190        let mut height = 0usize;
191        let mut capacity = 1;
192        while capacity < num_leaves {
193            height += 1;
194            capacity *= ARITY as u64;
195        }
196        height
197    });
198    let capacity = BigUint::from(ARITY as u64).pow(height as u32);
199
200    if BigUint::from(num_leaves) > capacity {
201        Err(MerkleTreeError::ExceedCapacity)
202    } else if num_leaves == 0 {
203        Ok((Arc::new(MerkleNode::<E, u64, T>::Empty), height, 0))
204    } else if height == 0usize {
205        Ok((
206            Arc::new(MerkleNode::Leaf {
207                value: H::digest_leaf(&0, leaves[0].borrow())?,
208                pos: 0,
209                elem: leaves[0].borrow().clone(),
210            }),
211            height,
212            1,
213        ))
214    } else {
215        let mut cur_nodes = leaves
216            .into_iter()
217            .enumerate()
218            .chunks(ARITY)
219            .into_iter()
220            .map(|chunk| {
221                let children = chunk
222                    .map(|(pos, elem)| {
223                        let pos = pos as u64;
224                        Ok(Arc::new(MerkleNode::Leaf {
225                            value: H::digest_leaf(&pos, elem.borrow())?,
226                            pos,
227                            elem: elem.borrow().clone(),
228                        }))
229                    })
230                    .pad_using(ARITY, |_| Ok(Arc::new(MerkleNode::Empty)))
231                    .collect::<Result<Vec<_>, MerkleTreeError>>()?;
232                Ok(Arc::new(MerkleNode::<E, u64, T>::Branch {
233                    value: digest_branch::<E, H, u64, T>(&children)?,
234                    children,
235                }))
236            })
237            .collect::<Result<Vec<_>, MerkleTreeError>>()?;
238        for _ in 1..height {
239            cur_nodes = cur_nodes
240                .into_iter()
241                .chunks(ARITY)
242                .into_iter()
243                .map(|chunk| {
244                    let children: Vec<_> = chunk
245                        .pad_using(ARITY, |_| Arc::new(MerkleNode::<E, u64, T>::Empty))
246                        .collect();
247                    Ok(Arc::new(MerkleNode::<E, u64, T>::Branch {
248                        value: digest_branch::<E, H, u64, T>(&children)?,
249                        children,
250                    }))
251                })
252                .collect::<Result<Vec<_>, MerkleTreeError>>()?;
253        }
254        Ok((cur_nodes[0].clone(), height, num_leaves))
255    }
256}
257
258#[allow(clippy::type_complexity)]
259pub(crate) fn build_light_weight_tree_internal<E, H, const ARITY: usize, T>(
260    height: Option<usize>,
261    elems: impl IntoIterator<Item = impl Borrow<E>>,
262) -> Result<(Arc<MerkleNode<E, u64, T>>, usize, u64), MerkleTreeError>
263where
264    E: Element,
265    H: DigestAlgorithm<E, u64, T>,
266    T: NodeValue,
267{
268    let leaves: Vec<_> = elems.into_iter().collect();
269    let num_leaves = leaves.len() as u64;
270    let height = height.unwrap_or_else(|| {
271        let mut height = 0usize;
272        let mut capacity = 1;
273        while capacity < num_leaves {
274            height += 1;
275            capacity *= ARITY as u64;
276        }
277        height
278    });
279    let capacity = num_traits::checked_pow(ARITY as u64, height).ok_or_else(|| {
280        MerkleTreeError::ParametersError("Merkle tree size too large.".to_string())
281    })?;
282
283    if num_leaves > capacity {
284        Err(MerkleTreeError::ExceedCapacity)
285    } else if num_leaves == 0 {
286        Ok((Arc::new(MerkleNode::<E, u64, T>::Empty), height, 0))
287    } else if height == 0usize {
288        Ok((
289            Arc::new(MerkleNode::Leaf {
290                value: H::digest_leaf(&0, leaves[0].borrow())?,
291                pos: 0,
292                elem: leaves[0].borrow().clone(),
293            }),
294            height,
295            1,
296        ))
297    } else {
298        let mut cur_nodes = leaves
299            .into_iter()
300            .enumerate()
301            .chunks(ARITY)
302            .into_iter()
303            .map(|chunk| {
304                let children = chunk
305                    .map(|(pos, elem)| {
306                        let pos = pos as u64;
307                        Ok(if pos < num_leaves - 1 {
308                            Arc::new(MerkleNode::ForgottenSubtree {
309                                value: H::digest_leaf(&pos, elem.borrow())?,
310                            })
311                        } else {
312                            Arc::new(MerkleNode::Leaf {
313                                value: H::digest_leaf(&pos, elem.borrow())?,
314                                pos,
315                                elem: elem.borrow().clone(),
316                            })
317                        })
318                    })
319                    .pad_using(ARITY, |_| Ok(Arc::new(MerkleNode::Empty)))
320                    .collect::<Result<Vec<_>, MerkleTreeError>>()?;
321                Ok(Arc::new(MerkleNode::<E, u64, T>::Branch {
322                    value: digest_branch::<E, H, u64, T>(&children)?,
323                    children,
324                }))
325            })
326            .collect::<Result<Vec<_>, MerkleTreeError>>()?;
327        for i in 1..cur_nodes.len() - 1 {
328            cur_nodes[i] = Arc::new(MerkleNode::ForgottenSubtree {
329                value: cur_nodes[i].value(),
330            })
331        }
332        for _ in 1..height {
333            cur_nodes = cur_nodes
334                .into_iter()
335                .chunks(ARITY)
336                .into_iter()
337                .map(|chunk| {
338                    let children = chunk
339                        .pad_using(ARITY, |_| Arc::new(MerkleNode::<E, u64, T>::Empty))
340                        .collect::<Vec<_>>();
341                    Ok(Arc::new(MerkleNode::<E, u64, T>::Branch {
342                        value: digest_branch::<E, H, u64, T>(&children)?,
343                        children,
344                    }))
345                })
346                .collect::<Result<Vec<_>, MerkleTreeError>>()?;
347            for i in 1..cur_nodes.len() - 1 {
348                cur_nodes[i] = Arc::new(MerkleNode::ForgottenSubtree {
349                    value: cur_nodes[i].value(),
350                })
351            }
352        }
353        Ok((cur_nodes[0].clone(), height, num_leaves))
354    }
355}
356
357pub(crate) fn digest_branch<E, H, I, T>(
358    data: &[Arc<MerkleNode<E, I, T>>],
359) -> Result<T, MerkleTreeError>
360where
361    E: Element,
362    H: DigestAlgorithm<E, I, T>,
363    I: Index,
364    T: NodeValue,
365{
366    // Question(Chengyu): any more efficient implementation?
367    let data = data.iter().map(|node| node.value()).collect::<Vec<_>>();
368    H::digest(&data)
369}
370
371impl<E, I, T> MerkleNode<E, I, T>
372where
373    E: Element,
374    I: Index,
375    T: NodeValue,
376{
377    /// Forget a leaf from the merkle tree. Internal branch merkle node will
378    /// also be forgotten if all its leaves are forgotten.
379    /// WARN(#495): this method breaks non-membership proofs.
380    #[allow(clippy::type_complexity)]
381    pub(crate) fn forget_internal(
382        &self,
383        height: usize,
384        traversal_path: &[usize],
385    ) -> (
386        Arc<Self>,
387        LookupResult<E, MerkleTreeProof<T>, MerkleTreeProof<T>>,
388    ) {
389        match self {
390            MerkleNode::Empty => (
391                Arc::new(self.clone()),
392                LookupResult::NotFound(MerkleTreeProof(vec![])),
393            ),
394            MerkleNode::Branch { value, children } => {
395                let mut children = children.clone();
396                let (new_child, result) = children[traversal_path[height - 1]]
397                    .forget_internal(height - 1, traversal_path);
398                match result {
399                    LookupResult::Ok(elem, mut membership_proof) => {
400                        membership_proof.0.push(
401                            children
402                                .iter()
403                                .enumerate()
404                                .filter(|(id, _)| *id != traversal_path[height - 1])
405                                .map(|(_, child)| child.value())
406                                .collect::<Vec<_>>(),
407                        );
408                        children[traversal_path[height - 1]] = new_child;
409                        if children.iter().all(|child| {
410                            matches!(
411                                **child,
412                                MerkleNode::Empty | MerkleNode::ForgottenSubtree { .. }
413                            )
414                        }) {
415                            (
416                                Arc::new(MerkleNode::ForgottenSubtree { value: *value }),
417                                LookupResult::Ok(elem, membership_proof),
418                            )
419                        } else {
420                            (
421                                Arc::new(MerkleNode::Branch {
422                                    value: *value,
423                                    children,
424                                }),
425                                LookupResult::Ok(elem, membership_proof),
426                            )
427                        }
428                    },
429                    LookupResult::NotInMemory => {
430                        (Arc::new(self.clone()), LookupResult::NotInMemory)
431                    },
432                    LookupResult::NotFound(mut non_membership_proof) => {
433                        non_membership_proof.0.push(
434                            children
435                                .iter()
436                                .enumerate()
437                                .filter(|(id, _)| *id != traversal_path[height - 1])
438                                .map(|(_, child)| child.value())
439                                .collect::<Vec<_>>(),
440                        );
441                        (
442                            Arc::new(self.clone()),
443                            LookupResult::NotFound(non_membership_proof),
444                        )
445                    },
446                }
447            },
448            MerkleNode::Leaf { value, pos, elem } => (
449                Arc::new(MerkleNode::ForgottenSubtree { value: *value }),
450                LookupResult::Ok(elem.clone(), MerkleTreeProof(vec![])),
451            ),
452            _ => (Arc::new(self.clone()), LookupResult::NotInMemory),
453        }
454    }
455
456    /// Re-insert a forgotten leaf to the Merkle tree.
457    /// It also fails if the Merkle proof is invalid.
458    pub(crate) fn remember_internal<H, const ARITY: usize>(
459        &self,
460        height: usize,
461        traversal_path: &[usize],
462        pos: &I,
463        element: Option<&E>,
464        proof: &[Vec<T>],
465    ) -> Result<Arc<Self>, MerkleTreeError>
466    where
467        H: DigestAlgorithm<E, I, T>,
468    {
469        match self {
470            MerkleNode::Empty => Ok(Arc::new(self.clone())),
471            MerkleNode::Leaf {
472                value,
473                pos: leaf_pos,
474                elem,
475            } => {
476                if height != 0 {
477                    // Reach a leaf before it should
478                    Err(MerkleTreeError::InconsistentStructureError(
479                        "Malformed Merkle tree or proof".to_string(),
480                    ))
481                } else {
482                    Ok(Arc::new(self.clone()))
483                }
484            },
485            MerkleNode::Branch { value, children } => {
486                if height == 0 {
487                    // Reach a branch when there should be a leaf
488                    Err(MerkleTreeError::InconsistentStructureError(
489                        "Malformed merkle tree".to_string(),
490                    ))
491                } else {
492                    let branch = traversal_path[height - 1];
493                    let mut children = children.clone();
494                    children[branch] = children[branch].remember_internal::<H, ARITY>(
495                        height - 1,
496                        traversal_path,
497                        pos,
498                        element,
499                        proof,
500                    )?;
501                    Ok(Arc::new(MerkleNode::Branch {
502                        value: *value,
503                        children,
504                    }))
505                }
506            },
507            MerkleNode::ForgottenSubtree { value } => Ok(Arc::new(if height == 0 {
508                if let Some(element) = element {
509                    MerkleNode::Leaf {
510                        value: H::digest_leaf(pos, element)?,
511                        pos: pos.clone(),
512                        elem: element.clone(),
513                    }
514                } else {
515                    MerkleNode::Empty
516                }
517            } else {
518                let branch = traversal_path[height - 1];
519                let mut values = proof[height - 1].clone();
520                values.insert(branch, *value);
521                let mut children = values
522                    .iter()
523                    .map(|&value| Arc::new(MerkleNode::ForgottenSubtree { value }))
524                    .collect::<Vec<_>>();
525                children[branch] = children[branch].remember_internal::<H, ARITY>(
526                    height - 1,
527                    traversal_path,
528                    pos,
529                    element,
530                    proof,
531                )?;
532                values[branch] = children[branch].value();
533                MerkleNode::Branch {
534                    value: H::digest(&values)?,
535                    children,
536                }
537            })),
538        }
539    }
540
541    /// Query the given index at the current Merkle node. Return the element
542    /// with a membership proof if presence, otherwise return a non-membership
543    /// proof.
544    #[allow(clippy::type_complexity)]
545    pub(crate) fn lookup_internal(
546        &self,
547        height: usize,
548        traversal_path: &[usize],
549    ) -> LookupResult<&E, MerkleTreeProof<T>, MerkleTreeProof<T>> {
550        match self {
551            MerkleNode::Empty => LookupResult::NotFound(MerkleTreeProof(vec![vec![]; height])),
552            MerkleNode::Branch { value: _, children } => {
553                match children[traversal_path[height - 1]]
554                    .lookup_internal(height - 1, traversal_path)
555                {
556                    LookupResult::Ok(elem, mut membership_proof) => {
557                        membership_proof.0.push(
558                            children
559                                .iter()
560                                .enumerate()
561                                .filter(|(id, _)| *id != traversal_path[height - 1])
562                                .map(|(_, child)| child.value())
563                                .collect::<Vec<_>>(),
564                        );
565                        LookupResult::Ok(elem, membership_proof)
566                    },
567                    LookupResult::NotInMemory => LookupResult::NotInMemory,
568                    LookupResult::NotFound(mut non_membership_proof) => {
569                        non_membership_proof.0.push(
570                            children
571                                .iter()
572                                .enumerate()
573                                .filter(|(id, _)| *id != traversal_path[height - 1])
574                                .map(|(_, child)| child.value())
575                                .collect::<Vec<_>>(),
576                        );
577                        LookupResult::NotFound(non_membership_proof)
578                    },
579                }
580            },
581            MerkleNode::Leaf {
582                elem,
583                value: _,
584                pos: _,
585            } => LookupResult::Ok(elem, MerkleTreeProof(vec![])),
586            _ => LookupResult::NotInMemory,
587        }
588    }
589
590    /// Update the element at the given index.
591    /// * `returns` - `Err()` if any error happens internally. `Ok(delta,
592    ///   result)`, `delta` represents the changes to the overall number of
593    ///   leaves of the tree, `result` contains the original lookup information
594    ///   at the given location.
595    #[allow(clippy::type_complexity)]
596    pub(crate) fn update_with_internal<H, const ARITY: usize, F>(
597        &self,
598        height: usize,
599        pos: impl Borrow<I>,
600        traversal_path: &[usize],
601        f: F,
602    ) -> Result<(Arc<Self>, i64, LookupResult<E, (), ()>), MerkleTreeError>
603    where
604        H: DigestAlgorithm<E, I, T>,
605        F: FnOnce(Option<&E>) -> Option<E>,
606    {
607        let pos = pos.borrow();
608        match self {
609            MerkleNode::Leaf {
610                elem: node_elem,
611                value: _,
612                pos,
613            } => {
614                let result = LookupResult::Ok(node_elem.clone(), ());
615                match f(Some(node_elem)) {
616                    Some(elem) => Ok((
617                        Arc::new(MerkleNode::Leaf {
618                            value: H::digest_leaf(pos, &elem)?,
619                            pos: pos.clone(),
620                            elem,
621                        }),
622                        0i64,
623                        result,
624                    )),
625                    None => Ok((Arc::new(MerkleNode::Empty), -1i64, result)),
626                }
627            },
628            MerkleNode::Branch { value, children } => {
629                let branch = traversal_path[height - 1];
630                let result = children[branch].update_with_internal::<H, ARITY, _>(
631                    height - 1,
632                    pos,
633                    traversal_path,
634                    f,
635                )?;
636                let mut children = children.clone();
637                children[branch] = result.0;
638                if matches!(*children[branch], MerkleNode::ForgottenSubtree { .. }) {
639                    // If the branch containing the update was forgotten by
640                    // user, the update failed and nothing was changed, so we
641                    // can short-circuit without recomputing this node's value.
642                    Ok((
643                        Arc::new(MerkleNode::Branch {
644                            value: *value,
645                            children,
646                        }),
647                        result.1,
648                        result.2,
649                    ))
650                } else if children
651                    .iter()
652                    .all(|child| matches!(**child, MerkleNode::Empty))
653                {
654                    Ok((Arc::new(MerkleNode::Empty), result.1, result.2))
655                } else {
656                    // Otherwise, an entry has been updated and the value of one of our children has
657                    // changed, so we must recompute our own value.
658                    // *value = digest_branch::<E, H, I, T>(&children)?;
659                    Ok((
660                        Arc::new(MerkleNode::Branch {
661                            value: digest_branch::<E, H, I, T>(&children)?,
662                            children,
663                        }),
664                        result.1,
665                        result.2,
666                    ))
667                }
668            },
669            MerkleNode::Empty => {
670                if height == 0 {
671                    if let Some(elem) = f(None) {
672                        Ok((
673                            Arc::new(MerkleNode::Leaf {
674                                value: H::digest_leaf(pos, &elem)?,
675                                pos: pos.clone(),
676                                elem,
677                            }),
678                            1i64,
679                            LookupResult::NotFound(()),
680                        ))
681                    } else {
682                        Ok((
683                            Arc::new(MerkleNode::Empty),
684                            0i64,
685                            LookupResult::NotFound(()),
686                        ))
687                    }
688                } else {
689                    let branch = traversal_path[height - 1];
690                    let mut children: Vec<_> = (0..ARITY).map(|_| Arc::new(Self::Empty)).collect();
691                    // Inserting new leave here, shortcutting
692                    let result = children[branch].update_with_internal::<H, ARITY, _>(
693                        height - 1,
694                        pos,
695                        traversal_path,
696                        f,
697                    )?;
698                    children[branch] = result.0;
699                    if matches!(*children[branch], MerkleNode::Empty) {
700                        // No update performed.
701                        Ok((Arc::new(MerkleNode::Empty), 0i64, result.2))
702                    } else {
703                        Ok((
704                            Arc::new(MerkleNode::Branch {
705                                value: digest_branch::<E, H, I, T>(&children)?,
706                                children,
707                            }),
708                            result.1,
709                            result.2,
710                        ))
711                    }
712                }
713            },
714            MerkleNode::ForgottenSubtree { .. } => Err(MerkleTreeError::ForgottenLeaf),
715        }
716    }
717}
718
719impl<E, T> MerkleNode<E, u64, T>
720where
721    E: Element,
722    T: NodeValue,
723{
724    /// Batch insertion for the given Merkle node.
725    pub(crate) fn extend_internal<H, const ARITY: usize>(
726        &self,
727        height: usize,
728        pos: &u64,
729        traversal_path: &[usize],
730        at_frontier: bool,
731        data: &mut Peekable<impl Iterator<Item = impl Borrow<E>>>,
732    ) -> Result<(Arc<Self>, u64), MerkleTreeError>
733    where
734        H: DigestAlgorithm<E, u64, T>,
735    {
736        if data.peek().is_none() {
737            return Ok((Arc::new(self.clone()), 0));
738        }
739        let mut cur_pos = *pos;
740        match self {
741            MerkleNode::Branch { value: _, children } => {
742                let mut cnt = 0u64;
743                let mut frontier = if at_frontier {
744                    traversal_path[height - 1]
745                } else {
746                    0
747                };
748                let cap = ARITY;
749                let mut children = children.clone();
750                while data.peek().is_some() && frontier < cap {
751                    let (new_child, increment) = children[frontier].extend_internal::<H, ARITY>(
752                        height - 1,
753                        &cur_pos,
754                        traversal_path,
755                        at_frontier && frontier == traversal_path[height - 1],
756                        data,
757                    )?;
758                    children[frontier] = new_child;
759                    cnt += increment;
760                    cur_pos += increment;
761                    frontier += 1;
762                }
763                let value = digest_branch::<E, H, u64, T>(&children)?;
764                Ok((Arc::new(Self::Branch { value, children }), cnt))
765            },
766            MerkleNode::Empty => {
767                if height == 0 {
768                    let elem = data.next().unwrap();
769                    let elem = elem.borrow();
770                    Ok((
771                        Arc::new(MerkleNode::Leaf {
772                            value: H::digest_leaf(pos, elem)?,
773                            pos: *pos,
774                            elem: elem.clone(),
775                        }),
776                        1,
777                    ))
778                } else {
779                    let mut cnt = 0u64;
780                    let mut frontier = if at_frontier {
781                        traversal_path[height - 1]
782                    } else {
783                        0
784                    };
785                    let cap = ARITY;
786                    let mut children: Vec<_> = (0..cap).map(|_| Arc::new(Self::Empty)).collect();
787                    while data.peek().is_some() && frontier < cap {
788                        let (new_child, increment) = children[frontier]
789                            .extend_internal::<H, ARITY>(
790                                height - 1,
791                                &cur_pos,
792                                traversal_path,
793                                at_frontier && frontier == traversal_path[height - 1],
794                                data,
795                            )?;
796                        children[frontier] = new_child;
797                        cnt += increment;
798                        cur_pos += increment;
799                        frontier += 1;
800                    }
801                    Ok((
802                        Arc::new(MerkleNode::Branch {
803                            value: digest_branch::<E, H, u64, T>(&children)?,
804                            children,
805                        }),
806                        cnt,
807                    ))
808                }
809            },
810            MerkleNode::Leaf { .. } => Err(MerkleTreeError::ExistingLeaf),
811            MerkleNode::ForgottenSubtree { .. } => Err(MerkleTreeError::ForgottenLeaf),
812        }
813    }
814
815    /// Similar to [`extend_internal`], but this function will automatically
816    /// forget every leaf except for the Merkle tree frontier.
817    pub(crate) fn extend_and_forget_internal<H, const ARITY: usize>(
818        &self,
819        height: usize,
820        pos: &u64,
821        traversal_path: &[usize],
822        at_frontier: bool,
823        data: &mut Peekable<impl Iterator<Item = impl Borrow<E>>>,
824    ) -> Result<(Arc<Self>, u64), MerkleTreeError>
825    where
826        H: DigestAlgorithm<E, u64, T>,
827    {
828        if data.peek().is_none() {
829            return Ok((Arc::new(self.clone()), 0));
830        }
831        let mut cur_pos = *pos;
832        match self {
833            MerkleNode::Branch { value: _, children } => {
834                let mut cnt = 0u64;
835                let mut frontier = if at_frontier {
836                    traversal_path[height - 1]
837                } else {
838                    0
839                };
840                let cap = ARITY;
841                let mut children = children.clone();
842                while data.peek().is_some() && frontier < cap {
843                    if frontier > 0 && !children[frontier - 1].is_forgotten() {
844                        children[frontier - 1] =
845                            Arc::new(MerkleNode::<E, u64, T>::ForgottenSubtree {
846                                value: children[frontier - 1].value(),
847                            });
848                    }
849                    let (new_child, increment) = children[frontier]
850                        .extend_and_forget_internal::<H, ARITY>(
851                            height - 1,
852                            &cur_pos,
853                            traversal_path,
854                            at_frontier && frontier == traversal_path[height - 1],
855                            data,
856                        )?;
857                    children[frontier] = new_child;
858                    cnt += increment;
859                    cur_pos += increment;
860                    frontier += 1;
861                }
862                let value = digest_branch::<E, H, u64, T>(&children)?;
863                Ok((Arc::new(Self::Branch { value, children }), cnt))
864            },
865            MerkleNode::Empty => {
866                if height == 0 {
867                    let elem = data.next().unwrap();
868                    let elem = elem.borrow();
869                    Ok((
870                        Arc::new(MerkleNode::Leaf {
871                            value: H::digest_leaf(pos, elem)?,
872                            pos: *pos,
873                            elem: elem.clone(),
874                        }),
875                        1,
876                    ))
877                } else {
878                    let mut cnt = 0u64;
879                    let mut frontier = if at_frontier {
880                        traversal_path[height - 1]
881                    } else {
882                        0
883                    };
884                    let cap = ARITY;
885                    let mut children: Vec<_> = (0..cap).map(|_| Arc::new(Self::Empty)).collect();
886                    while data.peek().is_some() && frontier < cap {
887                        if frontier > 0 && !children[frontier - 1].is_forgotten() {
888                            children[frontier - 1] =
889                                Arc::new(MerkleNode::<E, u64, T>::ForgottenSubtree {
890                                    value: children[frontier - 1].value(),
891                                });
892                        }
893                        let (new_child, increment) = children[frontier]
894                            .extend_and_forget_internal::<H, ARITY>(
895                                height - 1,
896                                &cur_pos,
897                                traversal_path,
898                                at_frontier && frontier == traversal_path[height - 1],
899                                data,
900                            )?;
901                        children[frontier] = new_child;
902                        cnt += increment;
903                        cur_pos += increment;
904                        frontier += 1;
905                    }
906                    Ok((
907                        Arc::new(MerkleNode::Branch {
908                            value: digest_branch::<E, H, u64, T>(&children)?,
909                            children,
910                        }),
911                        cnt,
912                    ))
913                }
914            },
915            MerkleNode::Leaf { .. } => Err(MerkleTreeError::ExistingLeaf),
916            MerkleNode::ForgottenSubtree { .. } => Err(MerkleTreeError::ForgottenLeaf),
917        }
918    }
919}
920
921/// Iterator type for a merkle tree
922pub struct MerkleTreeIter<'a, E: Element, I: Index, T: NodeValue> {
923    stack: Vec<&'a MerkleNode<E, I, T>>,
924}
925
926impl<'a, E: Element, I: Index, T: NodeValue> MerkleTreeIter<'a, E, I, T> {
927    /// Initialize an iterator
928    pub fn new(root: &'a MerkleNode<E, I, T>) -> Self {
929        Self { stack: vec![root] }
930    }
931}
932
933impl<'a, E, I, T> Iterator for MerkleTreeIter<'a, E, I, T>
934where
935    E: Element,
936    I: Index,
937    T: NodeValue,
938{
939    type Item = (&'a I, &'a E);
940
941    fn next(&mut self) -> Option<Self::Item> {
942        while let Some(node) = self.stack.pop() {
943            match node {
944                MerkleNode::Branch { value: _, children } => {
945                    children
946                        .iter()
947                        .rev()
948                        .filter(|child| {
949                            matches!(
950                                ***child,
951                                MerkleNode::Branch { .. } | MerkleNode::Leaf { .. }
952                            )
953                        })
954                        .for_each(|child| self.stack.push(child));
955                },
956                MerkleNode::Leaf {
957                    value: _,
958                    pos,
959                    elem,
960                } => {
961                    return Some((pos, elem));
962                },
963                _ => {},
964            }
965        }
966        None
967    }
968}
969
970/// An owned iterator type for a merkle tree
971pub struct MerkleTreeIntoIter<E: Element, I: Index, T: NodeValue> {
972    stack: Vec<Arc<MerkleNode<E, I, T>>>,
973}
974
975impl<E: Element, I: Index, T: NodeValue> MerkleTreeIntoIter<E, I, T> {
976    /// Initialize an iterator
977    pub fn new(root: Arc<MerkleNode<E, I, T>>) -> Self {
978        Self { stack: vec![root] }
979    }
980}
981
982impl<E, I, T> Iterator for MerkleTreeIntoIter<E, I, T>
983where
984    E: Element,
985    I: Index,
986    T: NodeValue,
987{
988    type Item = (I, E);
989
990    fn next(&mut self) -> Option<Self::Item> {
991        while let Some(node) = self.stack.pop() {
992            match node.as_ref() {
993                MerkleNode::Branch { value: _, children } => {
994                    children
995                        .iter()
996                        .rev()
997                        .filter(|child| {
998                            matches!(
999                                (**child).as_ref(),
1000                                MerkleNode::Branch { .. } | MerkleNode::Leaf { .. }
1001                            )
1002                        })
1003                        .for_each(|child| self.stack.push(child.clone()));
1004                },
1005                MerkleNode::Leaf {
1006                    value: _,
1007                    pos,
1008                    elem,
1009                } => {
1010                    return Some((pos.clone(), elem.clone()));
1011                },
1012                _ => {},
1013            }
1014        }
1015        None
1016    }
1017}