jf_vid/advz/
payload_prover.rs

1// Copyright (c) 2023 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7//! Implementations of [`PayloadProver`] for `Advz`.
8//!
9//! Two implementations:
10//! 1. `PROOF = `[`SmallRangeProof`]: Useful for small sub-slices of `payload`
11//!    such as an individual transaction within a block. Not snark-friendly
12//!    because it requires a pairing. Consists of metadata required to verify a
13//!    KZG batch proof.
14//! 2. `PROOF = `[`LargeRangeProof`]: Useful for large sub-slices of `payload`
15//!    such as a complete namespace. Snark-friendly because it does not require
16//!    a pairing. Consists of metadata required to rebuild a KZG commitment.
17
18use super::{
19    bytes_to_field::{bytes_to_field, elem_byte_capacity},
20    AdvzInternal, KzgEval, KzgProof, MaybeGPU, PolynomialCommitmentScheme, Vec, VidResult,
21};
22use crate::{
23    payload_prover::{PayloadProver, Statement},
24    vid, VidError, VidScheme,
25};
26use anyhow::anyhow;
27use ark_ec::pairing::Pairing;
28use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
29use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
30use ark_std::{format, ops::Range};
31use itertools::Itertools;
32use jf_merkle_tree::hasher::HasherDigest;
33use jf_pcs::prelude::UnivariateKzgPCS;
34use jf_utils::canonical;
35use serde::{Deserialize, Serialize};
36
37/// A proof intended for use on small payload subslices.
38///
39/// KZG batch proofs and accompanying metadata.
40///
41/// TODO use batch proof instead of `Vec<P>` <https://github.com/EspressoSystems/jellyfish/issues/387>
42#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
43#[serde(bound = "P: CanonicalSerialize + CanonicalDeserialize")]
44pub struct SmallRangeProof<P> {
45    #[serde(with = "canonical")]
46    proofs: Vec<P>,
47    prefix_bytes: Vec<u8>,
48    suffix_bytes: Vec<u8>,
49}
50
51/// A proof intended for use on large payload subslices.
52///
53/// Metadata needed to recover a KZG commitment.
54#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
55#[serde(bound = "F: CanonicalSerialize + CanonicalDeserialize")]
56pub struct LargeRangeProof<F> {
57    #[serde(with = "canonical")]
58    prefix_elems: Vec<F>,
59    #[serde(with = "canonical")]
60    suffix_elems: Vec<F>,
61    prefix_bytes: Vec<u8>,
62    suffix_bytes: Vec<u8>,
63}
64
65impl<E, H, T> PayloadProver<SmallRangeProof<KzgProof<E>>> for AdvzInternal<E, H, T>
66where
67    E: Pairing,
68    H: HasherDigest,
69    T: Sync,
70    AdvzInternal<E, H, T>: MaybeGPU<E>,
71{
72    fn payload_proof<B>(
73        &self,
74        payload: B,
75        range: Range<usize>,
76    ) -> VidResult<SmallRangeProof<KzgProof<E>>>
77    where
78        B: AsRef<[u8]>,
79    {
80        let payload = payload.as_ref();
81        check_range_nonempty_and_in_bounds(payload.len(), &range)?;
82
83        // index conversion
84        let multiplicity = self.min_multiplicity(payload.len())?;
85        let range_elem = self.range_byte_to_elem(&range);
86        let range_poly = self.range_elem_to_poly(&range_elem, multiplicity);
87        let range_elem_byte = self.range_elem_to_byte_clamped(&range_elem, payload.len());
88        let range_poly_byte =
89            self.range_poly_to_byte_clamped(&range_poly, payload.len(), multiplicity);
90        let offset_elem =
91            self.offset_poly_to_elem(range_poly.start, range_elem.start, multiplicity);
92        let final_points_range_end =
93            self.final_poly_points_range_end(range_elem.len(), offset_elem, multiplicity);
94
95        // prepare list of input points
96        //
97        // perf: if payload is small enough to fit into a single polynomial then
98        // we don't need all the points in this domain.
99        let points: Vec<_> = Self::eval_domain(
100            usize::try_from(self.recovery_threshold * multiplicity).map_err(vid)?,
101        )?
102        .elements()
103        .collect();
104
105        let elems_iter = bytes_to_field::<_, KzgEval<E>>(&payload[range_poly_byte]);
106        let mut proofs = Vec::with_capacity(range_poly.len() * points.len());
107        for (i, evals_iter) in elems_iter
108            .chunks((self.recovery_threshold * multiplicity) as usize)
109            .into_iter()
110            .enumerate()
111        {
112            let poly = Self::interpolate_polynomial(
113                evals_iter,
114                (self.recovery_threshold * multiplicity) as usize,
115            )?;
116            let points_range = Range {
117                // first polynomial? skip to the start of the proof range
118                start: if i == 0 { offset_elem } else { 0 },
119                // final polynomial? stop at the end of the proof range
120                end: if i == range_poly.len() - 1 {
121                    final_points_range_end
122                } else {
123                    points.len()
124                },
125            };
126            proofs.extend(
127                UnivariateKzgPCS::multi_open(&self.ck, &poly, &points[points_range])
128                    .map_err(vid)?
129                    .0,
130            );
131        }
132
133        Ok(SmallRangeProof {
134            proofs,
135            prefix_bytes: payload[range_elem_byte.start..range.start].to_vec(),
136            suffix_bytes: payload[range.end..range_elem_byte.end].to_vec(),
137        })
138    }
139
140    fn payload_verify(
141        &self,
142        stmt: Statement<Self>,
143        proof: &SmallRangeProof<KzgProof<E>>,
144    ) -> VidResult<Result<(), ()>> {
145        Self::check_stmt_consistency(&stmt)?;
146
147        // prepare list of data elems
148        let data_elems: Vec<_> = bytes_to_field::<_, KzgEval<E>>(
149            proof
150                .prefix_bytes
151                .iter()
152                .chain(stmt.payload_subslice)
153                .chain(proof.suffix_bytes.iter()),
154        )
155        .collect();
156
157        if data_elems.len() != proof.proofs.len() {
158            return Err(VidError::Argument(format!(
159                "data len {} differs from proof len {}",
160                data_elems.len(),
161                proof.proofs.len()
162            )));
163        }
164
165        // index conversion
166        let range_elem = self.range_byte_to_elem(&stmt.range);
167        let range_poly = self.range_elem_to_poly(&range_elem, stmt.common.multiplicity);
168        let offset_elem =
169            self.offset_poly_to_elem(range_poly.start, range_elem.start, stmt.common.multiplicity);
170        let final_points_range_end = self.final_poly_points_range_end(
171            range_elem.len(),
172            offset_elem,
173            stmt.common.multiplicity,
174        );
175
176        // prepare list of input points
177        //
178        // perf: if payload is small enough to fit into a single polynomial then
179        // we don't need all the points in this domain.
180        let points: Vec<_> = Self::eval_domain(
181            usize::try_from(self.recovery_threshold * stmt.common.multiplicity).map_err(vid)?,
182        )?
183        .elements()
184        .collect();
185
186        // verify proof
187        let mut cur_proof_index = 0;
188        for (i, poly_commit) in stmt.common.poly_commits[range_poly.clone()]
189            .iter()
190            .enumerate()
191        {
192            let points_range = Range {
193                // first polynomial? skip to the start of the proof range
194                start: if i == 0 { offset_elem } else { 0 },
195                // final polynomial? stop at the end of the proof range
196                end: if i == range_poly.len() - 1 {
197                    final_points_range_end
198                } else {
199                    points.len()
200                },
201            };
202            // TODO naive verify for multi_open https://github.com/EspressoSystems/jellyfish/issues/387
203            for point in points[points_range].iter() {
204                let data_elem = data_elems
205                    .get(cur_proof_index)
206                    .ok_or_else(|| VidError::Internal(anyhow!("ran out of data elems")))?;
207                let cur_proof = proof
208                    .proofs
209                    .get(cur_proof_index)
210                    .ok_or_else(|| VidError::Internal(anyhow!("ran out of proofs")))?;
211                if !UnivariateKzgPCS::verify(&self.vk, poly_commit, point, data_elem, cur_proof)
212                    .map_err(vid)?
213                {
214                    return Ok(Err(()));
215                }
216                cur_proof_index += 1;
217            }
218        }
219        assert_eq!(cur_proof_index, proof.proofs.len()); // sanity
220        Ok(Ok(()))
221    }
222}
223
224impl<E, H, T> PayloadProver<LargeRangeProof<KzgEval<E>>> for AdvzInternal<E, H, T>
225where
226    E: Pairing,
227    H: HasherDigest,
228    T: Sync,
229    AdvzInternal<E, H, T>: MaybeGPU<E>,
230{
231    fn payload_proof<B>(
232        &self,
233        payload: B,
234        range: Range<usize>,
235    ) -> VidResult<LargeRangeProof<KzgEval<E>>>
236    where
237        B: AsRef<[u8]>,
238    {
239        let payload = payload.as_ref();
240        check_range_nonempty_and_in_bounds(payload.len(), &range)?;
241
242        // index conversion
243        let multiplicity = self.min_multiplicity(payload.len())?;
244        let range_elem = self.range_byte_to_elem(&range);
245        let range_poly = self.range_elem_to_poly(&range_elem, multiplicity);
246        let range_elem_byte = self.range_elem_to_byte_clamped(&range_elem, payload.len());
247        let range_poly_byte =
248            self.range_poly_to_byte_clamped(&range_poly, payload.len(), multiplicity);
249        let offset_elem =
250            self.offset_poly_to_elem(range_poly.start, range_elem.start, multiplicity);
251
252        // compute the prefix and suffix elems
253        let mut elems_iter = bytes_to_field::<_, KzgEval<E>>(payload[range_poly_byte].iter());
254        let prefix_elems: Vec<_> = elems_iter.by_ref().take(offset_elem).collect();
255        let suffix_elems: Vec<_> = elems_iter.skip(range_elem.len()).collect();
256
257        Ok(LargeRangeProof {
258            prefix_elems,
259            suffix_elems,
260            prefix_bytes: payload[range_elem_byte.start..range.start].to_vec(),
261            suffix_bytes: payload[range.end..range_elem_byte.end].to_vec(),
262        })
263    }
264
265    fn payload_verify(
266        &self,
267        stmt: Statement<Self>,
268        proof: &LargeRangeProof<KzgEval<E>>,
269    ) -> VidResult<Result<(), ()>> {
270        Self::check_stmt_consistency(&stmt)?;
271
272        // index conversion
273        let range_poly = self.range_byte_to_poly(&stmt.range, stmt.common.multiplicity);
274
275        // rebuild the needed payload elements from statement and proof
276        let elems_iter = proof
277            .prefix_elems
278            .iter()
279            .cloned()
280            .chain(bytes_to_field::<_, KzgEval<E>>(
281                proof
282                    .prefix_bytes
283                    .iter()
284                    .chain(stmt.payload_subslice)
285                    .chain(proof.suffix_bytes.iter()),
286            ))
287            .chain(proof.suffix_elems.iter().cloned());
288        // rebuild the poly commits, check against `common`
289        for (commit_index, evals_iter) in range_poly.into_iter().zip(
290            elems_iter
291                .chunks((self.recovery_threshold * stmt.common.multiplicity) as usize)
292                .into_iter(),
293        ) {
294            let poly = Self::interpolate_polynomial(
295                evals_iter,
296                (stmt.common.multiplicity * self.recovery_threshold) as usize,
297            )?;
298            let poly_commit = UnivariateKzgPCS::commit(&self.ck, &poly).map_err(vid)?;
299            if poly_commit != stmt.common.poly_commits[commit_index] {
300                return Ok(Err(()));
301            }
302        }
303        Ok(Ok(()))
304    }
305}
306
307impl<E, H, T> AdvzInternal<E, H, T>
308where
309    E: Pairing,
310    H: HasherDigest,
311    T: Sync,
312    AdvzInternal<E, H, T>: MaybeGPU<E>,
313{
314    // lots of index manipulation
315    fn range_byte_to_elem(&self, range: &Range<usize>) -> Range<usize> {
316        range_coarsen(range, elem_byte_capacity::<KzgEval<E>>())
317    }
318    fn range_elem_to_byte_clamped(&self, range: &Range<usize>, len: usize) -> Range<usize> {
319        let result = range_refine(range, elem_byte_capacity::<KzgEval<E>>());
320        Range {
321            end: ark_std::cmp::min(result.end, len),
322            ..result
323        }
324    }
325    fn range_elem_to_poly(&self, range: &Range<usize>, multiplicity: u32) -> Range<usize> {
326        range_coarsen(range, (self.recovery_threshold * multiplicity) as usize)
327    }
328    fn range_byte_to_poly(&self, range: &Range<usize>, multiplicity: u32) -> Range<usize> {
329        range_coarsen(
330            range,
331            (self.recovery_threshold * multiplicity) as usize * elem_byte_capacity::<KzgEval<E>>(),
332        )
333    }
334    fn range_poly_to_byte_clamped(
335        &self,
336        range: &Range<usize>,
337        len: usize,
338        multiplicity: u32,
339    ) -> Range<usize> {
340        let result = range_refine(
341            range,
342            (self.recovery_threshold * multiplicity) as usize * elem_byte_capacity::<KzgEval<E>>(),
343        );
344        Range {
345            end: ark_std::cmp::min(result.end, len),
346            ..result
347        }
348    }
349    fn offset_poly_to_elem(
350        &self,
351        range_poly_start: usize,
352        range_elem_start: usize,
353        multiplicity: u32,
354    ) -> usize {
355        let start_poly_byte = index_refine(
356            range_poly_start,
357            (self.recovery_threshold * multiplicity) as usize * elem_byte_capacity::<KzgEval<E>>(),
358        );
359        range_elem_start - index_coarsen(start_poly_byte, elem_byte_capacity::<KzgEval<E>>())
360    }
361    fn final_poly_points_range_end(
362        &self,
363        range_elem_len: usize,
364        offset_elem: usize,
365        multiplicity: u32,
366    ) -> usize {
367        (range_elem_len + offset_elem - 1) % (self.recovery_threshold * multiplicity) as usize + 1
368    }
369
370    fn check_stmt_consistency(stmt: &Statement<Self>) -> VidResult<()> {
371        check_range_nonempty_and_in_bounds(
372            stmt.common.payload_byte_len.try_into().map_err(vid)?,
373            &stmt.range,
374        )?;
375        if stmt.payload_subslice.len() != stmt.range.len() {
376            return Err(VidError::Argument(format!(
377                "payload_subslice length {} inconsistent with range length {}",
378                stmt.payload_subslice.len(),
379                stmt.range.len()
380            )));
381        }
382        Self::is_consistent(stmt.commit, stmt.common)
383    }
384}
385
386fn range_coarsen(range: &Range<usize>, denominator: usize) -> Range<usize> {
387    assert!(!range.is_empty(), "{:?}", range);
388    Range {
389        start: index_coarsen(range.start, denominator),
390        end: index_coarsen(range.end - 1, denominator) + 1,
391    }
392}
393
394fn range_refine(range: &Range<usize>, multiplier: usize) -> Range<usize> {
395    assert!(!range.is_empty(), "{:?}", range);
396    Range {
397        start: index_refine(range.start, multiplier),
398        end: index_refine(range.end, multiplier),
399    }
400}
401
/// Coarse-grained index containing fine-grained `index` (floor division).
fn index_coarsen(index: usize, denominator: usize) -> usize {
    // for unsigned integers `div_euclid` is plain truncating division
    index.div_euclid(denominator)
}
405
/// First fine-grained index of coarse-grained `index`.
fn index_refine(index: usize, multiplier: usize) -> usize {
    multiplier * index
}
409
410fn check_range_nonempty_and_in_bounds(len: usize, range: &Range<usize>) -> VidResult<()> {
411    if range.is_empty() {
412        return Err(VidError::Argument(format!(
413            "empty range ({}..{})",
414            range.start, range.end
415        )));
416    }
417    // no need to check range.start because we already checked range.is_empty()
418    if range.end > len {
419        return Err(VidError::Argument(format!(
420            "range ({}..{}) out of bounds for length {}",
421            range.start, range.end, len
422        )));
423    }
424    Ok(())
425}
426
#[cfg(test)]
mod tests {
    use crate::{
        advz::{
            bytes_to_field::elem_byte_capacity,
            payload_prover::{LargeRangeProof, SmallRangeProof, Statement},
            test::*,
            *,
        },
        payload_prover::PayloadProver,
    };
    use ark_bn254::Bn254;
    use ark_std::{ops::Range, print, println, rand::Rng};
    use sha2::Sha256;

    /// Boundary ranges inside `min..max`: ranges touching the first byte, the
    /// last byte, or both.
    fn make_edge_cases(min: usize, max: usize) -> Vec<Range<usize>> {
        vec![
            min..min + 1,
            min..min + 2,
            min..max - 1,
            min..max,
            min + 1..min + 2,
            min + 1..min + 3,
            min + 1..max - 1,
            min + 1..max,
            max - 2..max - 1,
            max - 2..max,
            max - 1..max,
        ]
    }

    fn correctness_generic<E, H>()
    where
        E: Pairing,
        H: HasherDigest,
    {
        // tunable parameters
        let (recovery_threshold, num_storage_nodes, max_multiplicity) = (4, 6, 2);
        let num_polys = 3;
        let num_random_cases: usize = 20;

        // sizes derived from the parameters above
        let elem_bytes = elem_byte_capacity::<E::ScalarField>();
        let poly_elems_len = recovery_threshold as usize * max_multiplicity as usize;
        let payload_elems_len = num_polys * poly_elems_len;
        let poly_bytes_len = poly_elems_len * elem_bytes;
        let payload_bytes_base_len = payload_elems_len * elem_bytes;

        let mut rng = jf_utils::test_rng();
        let srs = init_srs(payload_elems_len, &mut rng);
        let mut advz = Advz::<E, H>::with_multiplicity(
            num_storage_nodes,
            recovery_threshold,
            max_multiplicity,
            srs,
        )
        .unwrap();

        // TEST: payload byte lengths shorter than the base length by varying amounts
        let payload_len_cases = [0, poly_bytes_len / 2, poly_bytes_len - 1]
            .map(|noise| payload_bytes_base_len - noise);

        // TEST: ranges to prove. Trying every (polynomial, start, len) combo
        // is too slow, so use edge cases plus random cases.
        let mut edge_cases = make_edge_cases(0, poly_bytes_len); // inside the first polynomial
        edge_cases.extend(make_edge_cases(
            payload_bytes_base_len - poly_bytes_len,
            payload_bytes_base_len,
        )); // inside the final polynomial
        edge_cases.extend(make_edge_cases(0, payload_bytes_base_len)); // spanning the entire payload

        let random_cases: Vec<Range<usize>> = (0..num_random_cases)
            .map(|_| {
                let start = rng.gen_range(0..payload_bytes_base_len - 1);
                let end = rng.gen_range(start + 1..payload_bytes_base_len);
                start..end
            })
            .collect();

        let all_cases = [(edge_cases, "edge"), (random_cases, "rand")];

        // at least one payload length must exercise multiplicity > 1
        let mut nontrivial_multiplicity = false;

        for payload_len in payload_len_cases {
            let payload = init_random_payload(payload_len, &mut rng);
            let d = advz.disperse(&payload).unwrap();
            nontrivial_multiplicity |= d.common.multiplicity > 1;
            println!("payload byte len case: {}", payload.len());

            for (cases, label) in all_cases.iter() {
                for range in cases {
                    print!("{} case: {:?}", label, range);

                    // skip ranges starting past the payload; clamp ranges ending past it
                    if range.start >= payload.len() {
                        println!(" outside payload len {}, skipping", payload.len());
                        continue;
                    }
                    let range = if range.end > payload.len() {
                        println!(" clamped to payload len {}", payload.len());
                        range.start..payload.len()
                    } else {
                        println!();
                        range.clone()
                    };

                    let stmt = Statement {
                        payload_subslice: &payload[range.clone()],
                        range: range.clone(),
                        commit: &d.commit,
                        common: &d.common,
                    };

                    // honest proofs of both kinds must verify
                    let small_range_proof: SmallRangeProof<_> =
                        advz.payload_proof(&payload, range.clone()).unwrap();
                    advz.payload_verify(stmt.clone(), &small_range_proof)
                        .unwrap()
                        .unwrap();

                    let large_range_proof: LargeRangeProof<_> =
                        advz.payload_proof(&payload, range.clone()).unwrap();
                    advz.payload_verify(stmt.clone(), &large_range_proof)
                        .unwrap()
                        .unwrap();

                    // a corrupted subslice (every byte incremented) must be rejected
                    let corrupted_subslice: Vec<u8> = stmt
                        .payload_subslice
                        .iter()
                        .map(|&b| b.wrapping_add(1))
                        .collect();
                    let stmt_corrupted = Statement {
                        payload_subslice: &corrupted_subslice,
                        ..stmt
                    };
                    advz.payload_verify(stmt_corrupted.clone(), &small_range_proof)
                        .unwrap()
                        .unwrap_err();
                    advz.payload_verify(stmt_corrupted, &large_range_proof)
                        .unwrap()
                        .unwrap_err();

                    // TODO more tests for bad proofs, eg:
                    // - valid proof, different range
                    // - corrupt proof
                    // - etc
                }
            }
        }

        assert!(
            nontrivial_multiplicity,
            "at least one payload size should use multiplicity > 1"
        );
    }

    #[test]
    fn correctness() {
        correctness_generic::<Bn254, Sha256>();
    }
}
630}