// jf_plonk/proof_system/snark.rs
//
// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the Jellyfish library.

// You should have received a copy of the MIT License
// along with the Jellyfish library. If not, see <https://mit-license.org/>.

//! Instantiations of Plonk-based proof systems
8use super::{
9    prover::Prover,
10    structs::{
11        BatchProof, Challenges, Oracles, PlookupProof, PlookupProvingKey, PlookupVerifyingKey,
12        Proof, ProvingKey, VerifyingKey,
13    },
14    verifier::Verifier,
15    UniversalSNARK,
16};
17use crate::{
18    constants::EXTRA_TRANSCRIPT_MSG_LABEL,
19    errors::{PlonkError, SnarkError::ParameterError},
20    proof_system::structs::UniversalSrs,
21    transcript::*,
22};
23use ark_ec::{
24    pairing::Pairing,
25    short_weierstrass::{Affine, SWCurveConfig},
26};
27use ark_ff::{Field, One};
28use ark_std::{
29    format,
30    marker::PhantomData,
31    rand::{CryptoRng, RngCore},
32    string::ToString,
33    vec,
34    vec::Vec,
35};
36use jf_pcs::{prelude::UnivariateKzgPCS, PolynomialCommitmentScheme, StructuredReferenceString};
37use jf_relation::{
38    constants::compute_coset_representatives, gadgets::ecc::SWToTEConParam, Arithmetization,
39};
40use jf_rescue::RescueParameter;
41use jf_utils::par_utils::parallelizable_slice_iter;
42#[cfg(feature = "parallel")]
43use rayon::prelude::*;
44
45/// A Plonk instantiated with KZG PCS
46pub struct PlonkKzgSnark<E: Pairing>(PhantomData<E>);
47
48impl<E, F, P> PlonkKzgSnark<E>
49where
50    E: Pairing<BaseField = F, G1Affine = Affine<P>>,
51    F: RescueParameter + SWToTEConParam,
52    P: SWCurveConfig<BaseField = F>,
53{
54    #[allow(clippy::new_without_default)]
55    /// A new Plonk KZG SNARK
56    pub fn new() -> Self {
57        Self(PhantomData)
58    }
59
60    /// Generate an aggregated Plonk proof for multiple instances.
61    pub fn batch_prove<C, R, T>(
62        prng: &mut R,
63        circuits: &[&C],
64        prove_keys: &[&ProvingKey<E>],
65    ) -> Result<BatchProof<E>, PlonkError>
66    where
67        C: Arithmetization<E::ScalarField>,
68        R: CryptoRng + RngCore,
69        T: PlonkTranscript<F>,
70    {
71        let (batch_proof, ..) =
72            Self::batch_prove_internal::<_, _, T>(prng, circuits, prove_keys, None)?;
73        Ok(batch_proof)
74    }
75
76    /// Verify a single aggregated Plonk proof.
77    pub fn verify_batch_proof<T>(
78        verify_keys: &[&VerifyingKey<E>],
79        public_inputs: &[&[E::ScalarField]],
80        batch_proof: &BatchProof<E>,
81    ) -> Result<(), PlonkError>
82    where
83        T: PlonkTranscript<F>,
84    {
85        if verify_keys.is_empty() {
86            return Err(ParameterError("empty verification keys".to_string()).into());
87        }
88        let verifier = Verifier::new(verify_keys[0].domain_size)?;
89        let pcs_info =
90            verifier.prepare_pcs_info::<T>(verify_keys, public_inputs, batch_proof, &None)?;
91        if !Verifier::batch_verify_opening_proofs::<T>(
92            &verify_keys[0].open_key, // all open_key are the same
93            &[pcs_info],
94        )? {
95            return Err(PlonkError::WrongProof);
96        }
97        Ok(())
98    }
99
100    /// Batch verify multiple SNARK proofs (w.r.t. different verifying keys).
101    pub fn batch_verify<T>(
102        verify_keys: &[&VerifyingKey<E>],
103        public_inputs: &[&[E::ScalarField]],
104        proofs: &[&Proof<E>],
105        extra_transcript_init_msgs: &[Option<Vec<u8>>],
106    ) -> Result<(), PlonkError>
107    where
108        T: PlonkTranscript<F>,
109    {
110        if public_inputs.len() != proofs.len()
111            || verify_keys.len() != proofs.len()
112            || extra_transcript_init_msgs.len() != proofs.len()
113        {
114            return Err(ParameterError(format!(
115                "verify_keys.len: {}, public_inputs.len: {}, proofs.len: {}, \
116                 extra_transcript_msg.len: {}",
117                verify_keys.len(),
118                public_inputs.len(),
119                proofs.len(),
120                extra_transcript_init_msgs.len()
121            ))
122            .into());
123        }
124        if verify_keys.is_empty() {
125            return Err(
126                ParameterError("the number of instances cannot be zero".to_string()).into(),
127            );
128        }
129
130        let pcs_infos = parallelizable_slice_iter(verify_keys)
131            .zip(parallelizable_slice_iter(proofs))
132            .zip(parallelizable_slice_iter(public_inputs))
133            .zip(parallelizable_slice_iter(extra_transcript_init_msgs))
134            .map(|(((&vk, &proof), &pub_input), extra_msg)| {
135                let verifier = Verifier::new(vk.domain_size)?;
136                verifier.prepare_pcs_info::<T>(
137                    &[vk],
138                    &[pub_input],
139                    &(*proof).clone().into(),
140                    extra_msg,
141                )
142            })
143            .collect::<Result<Vec<_>, PlonkError>>()?;
144
145        if !Verifier::batch_verify_opening_proofs::<T>(
146            &verify_keys[0].open_key, // all open_key are the same
147            &pcs_infos,
148        )? {
149            return Err(PlonkError::WrongProof);
150        }
151        Ok(())
152    }
153
154    /// An internal private API for ease of testing
155    ///
156    /// Batchly compute a Plonk proof for multiple instances. Return the batch
157    /// proof and the corresponding online polynomial oracles and
158    /// challenges. Refer to Sec 8.4 of https://eprint.iacr.org/2019/953.pdf
159    ///
160    /// `circuit` and `prove_key` has to be consistent (with the same evaluation
161    /// domain etc.), otherwise return error.
162    #[allow(clippy::type_complexity)]
163    fn batch_prove_internal<C, R, T>(
164        prng: &mut R,
165        circuits: &[&C],
166        prove_keys: &[&ProvingKey<E>],
167        extra_transcript_init_msg: Option<Vec<u8>>,
168    ) -> Result<
169        (
170            BatchProof<E>,
171            Vec<Oracles<E::ScalarField>>,
172            Challenges<E::ScalarField>,
173        ),
174        PlonkError,
175    >
176    where
177        C: Arithmetization<E::ScalarField>,
178        R: CryptoRng + RngCore,
179        T: PlonkTranscript<F>,
180    {
181        if circuits.is_empty() {
182            return Err(ParameterError("zero number of circuits/proving keys".to_string()).into());
183        }
184        if circuits.len() != prove_keys.len() {
185            return Err(ParameterError(format!(
186                "the number of circuits {} != the number of proving keys {}",
187                circuits.len(),
188                prove_keys.len()
189            ))
190            .into());
191        }
192        let n = circuits[0].eval_domain_size()?;
193        let num_wire_types = circuits[0].num_wire_types();
194        for (circuit, pk) in circuits.iter().zip(prove_keys.iter()) {
195            if circuit.eval_domain_size()? != n {
196                return Err(ParameterError(format!(
197                    "circuit domain size {} != expected domain size {}",
198                    circuit.eval_domain_size()?,
199                    n
200                ))
201                .into());
202            }
203            if pk.domain_size() != n {
204                return Err(ParameterError(format!(
205                    "proving key domain size {} != expected domain size {}",
206                    pk.domain_size(),
207                    n
208                ))
209                .into());
210            }
211            if circuit.num_inputs() != pk.vk.num_inputs {
212                return Err(ParameterError(format!(
213                    "circuit.num_inputs {} != prove_key.num_inputs {}",
214                    circuit.num_inputs(),
215                    pk.vk.num_inputs
216                ))
217                .into());
218            }
219            if circuit.support_lookup() != pk.plookup_pk.is_some() {
220                return Err(ParameterError(
221                    "Mismatched Plonk types between the proving key and the circuit".to_string(),
222                )
223                .into());
224            }
225            if circuit.num_wire_types() != num_wire_types {
226                return Err(ParameterError("inconsistent plonk circuit types".to_string()).into());
227            }
228        }
229
230        // Initialize transcript
231        let mut transcript = T::new(b"PlonkProof");
232        if let Some(msg) = extra_transcript_init_msg {
233            transcript.append_message(EXTRA_TRANSCRIPT_MSG_LABEL, &msg)?;
234        }
235        for (pk, circuit) in prove_keys.iter().zip(circuits.iter()) {
236            transcript.append_vk_and_pub_input(&pk.vk, &circuit.public_input()?)?;
237        }
238        // Initialize verifier challenges and online polynomial oracles.
239        let mut challenges = Challenges::default();
240        let mut online_oracles = vec![Oracles::default(); circuits.len()];
241        let prover = Prover::new(n, num_wire_types)?;
242
243        // Round 1
244        let mut wires_poly_comms_vec = vec![];
245        for i in 0..circuits.len() {
246            let ((wires_poly_comms, wire_polys), pi_poly) =
247                prover.run_1st_round(prng, &prove_keys[i].commit_key, circuits[i])?;
248            online_oracles[i].wire_polys = wire_polys;
249            online_oracles[i].pub_inp_poly = pi_poly;
250            transcript.append_commitments(b"witness_poly_comms", &wires_poly_comms)?;
251            wires_poly_comms_vec.push(wires_poly_comms);
252        }
253
254        // Round 1.5
255        // Plookup: compute and interpolate the sorted concatenation of the (merged)
256        // lookup table and the (merged) witness values
257        if circuits.iter().any(|c| C::support_lookup(c)) {
258            challenges.tau = Some(transcript.get_challenge::<E>(b"tau")?);
259        } else {
260            challenges.tau = None;
261        }
262
263        let mut h_poly_comms_vec = vec![];
264        let mut sorted_vec_list = vec![];
265        let mut merged_table_list = vec![];
266        for i in 0..circuits.len() {
267            let (sorted_vec, h_poly_comms, merged_table) = if circuits[i].support_lookup() {
268                let ((h_poly_comms, h_polys), sorted_vec, merged_table) = prover
269                    .run_plookup_1st_round(
270                        prng,
271                        &prove_keys[i].commit_key,
272                        circuits[i],
273                        challenges.tau.unwrap(),
274                    )?;
275                online_oracles[i].plookup_oracles.h_polys = h_polys;
276                transcript.append_commitments(b"h_poly_comms", &h_poly_comms)?;
277                (Some(sorted_vec), Some(h_poly_comms), Some(merged_table))
278            } else {
279                (None, None, None)
280            };
281            h_poly_comms_vec.push(h_poly_comms);
282            sorted_vec_list.push(sorted_vec);
283            merged_table_list.push(merged_table);
284        }
285
286        // Round 2
287        challenges.beta = transcript.get_challenge::<E>(b"beta")?;
288        challenges.gamma = transcript.get_challenge::<E>(b"gamma")?;
289        let mut prod_perm_poly_comms_vec = vec![];
290        for i in 0..circuits.len() {
291            let (prod_perm_poly_comm, prod_perm_poly) =
292                prover.run_2nd_round(prng, &prove_keys[i].commit_key, circuits[i], &challenges)?;
293            online_oracles[i].prod_perm_poly = prod_perm_poly;
294            transcript.append_commitment(b"perm_poly_comms", &prod_perm_poly_comm)?;
295            prod_perm_poly_comms_vec.push(prod_perm_poly_comm);
296        }
297
298        // Round 2.5
299        // Plookup: compute Plookup product accumulation polynomial
300        let mut prod_lookup_poly_comms_vec = vec![];
301        for i in 0..circuits.len() {
302            let prod_lookup_poly_comm = if circuits[i].support_lookup() {
303                let (prod_lookup_poly_comm, prod_lookup_poly) = prover.run_plookup_2nd_round(
304                    prng,
305                    &prove_keys[i].commit_key,
306                    circuits[i],
307                    &challenges,
308                    merged_table_list[i].as_ref(),
309                    sorted_vec_list[i].as_ref(),
310                )?;
311                online_oracles[i].plookup_oracles.prod_lookup_poly = prod_lookup_poly;
312                transcript.append_commitment(b"plookup_poly_comms", &prod_lookup_poly_comm)?;
313                Some(prod_lookup_poly_comm)
314            } else {
315                None
316            };
317            prod_lookup_poly_comms_vec.push(prod_lookup_poly_comm);
318        }
319
320        // Round 3
321        challenges.alpha = transcript.get_challenge::<E>(b"alpha")?;
322        let (split_quot_poly_comms, split_quot_polys) = prover.run_3rd_round(
323            prng,
324            &prove_keys[0].commit_key,
325            prove_keys,
326            &challenges,
327            &online_oracles,
328            num_wire_types,
329        )?;
330        transcript.append_commitments(b"quot_poly_comms", &split_quot_poly_comms)?;
331
332        // Round 4
333        challenges.zeta = transcript.get_challenge::<E>(b"zeta")?;
334        let mut poly_evals_vec = vec![];
335        for i in 0..circuits.len() {
336            let poly_evals = prover.compute_evaluations(
337                prove_keys[i],
338                &challenges,
339                &online_oracles[i],
340                num_wire_types,
341            );
342            transcript.append_proof_evaluations::<E>(&poly_evals)?;
343            poly_evals_vec.push(poly_evals);
344        }
345
346        // Round 4.5
347        // Plookup: compute evaluations on Plookup-related polynomials
348        let mut plookup_evals_vec = vec![];
349        for i in 0..circuits.len() {
350            let plookup_evals = if circuits[i].support_lookup() {
351                let evals = prover.compute_plookup_evaluations(
352                    prove_keys[i],
353                    &challenges,
354                    &online_oracles[i],
355                )?;
356                transcript.append_plookup_evaluations::<E>(&evals)?;
357                Some(evals)
358            } else {
359                None
360            };
361            plookup_evals_vec.push(plookup_evals);
362        }
363
364        let mut lin_poly = Prover::<E>::compute_quotient_component_for_lin_poly(
365            n,
366            challenges.zeta,
367            &split_quot_polys,
368        )?;
369        let mut alpha_base = E::ScalarField::one();
370        let alpha_3 = challenges.alpha.square() * challenges.alpha;
371        let alpha_7 = alpha_3.square() * challenges.alpha;
372        for i in 0..circuits.len() {
373            lin_poly = lin_poly
374                + prover.compute_non_quotient_component_for_lin_poly(
375                    alpha_base,
376                    prove_keys[i],
377                    &challenges,
378                    &online_oracles[i],
379                    &poly_evals_vec[i],
380                    plookup_evals_vec[i].as_ref(),
381                )?;
382            // update the alpha power term (i.e. the random combiner that aggregates
383            // multiple instances)
384            if plookup_evals_vec[i].is_some() {
385                alpha_base *= alpha_7;
386            } else {
387                alpha_base *= alpha_3;
388            }
389        }
390
391        // Round 5
392        challenges.v = transcript.get_challenge::<E>(b"v")?;
393        let (opening_proof, shifted_opening_proof) = prover.compute_opening_proofs(
394            &prove_keys[0].commit_key,
395            prove_keys,
396            &challenges.zeta,
397            &challenges.v,
398            &online_oracles,
399            &lin_poly,
400        )?;
401
402        // Plookup: build Plookup argument
403        let mut plookup_proofs_vec = vec![];
404        for i in 0..circuits.len() {
405            let plookup_proof = if circuits[i].support_lookup() {
406                Some(PlookupProof {
407                    h_poly_comms: h_poly_comms_vec[i].clone().unwrap(),
408                    prod_lookup_poly_comm: prod_lookup_poly_comms_vec[i].unwrap(),
409                    poly_evals: plookup_evals_vec[i].clone().unwrap(),
410                })
411            } else {
412                None
413            };
414            plookup_proofs_vec.push(plookup_proof);
415        }
416
417        Ok((
418            BatchProof {
419                wires_poly_comms_vec,
420                prod_perm_poly_comms_vec,
421                poly_evals_vec,
422                plookup_proofs_vec,
423                split_quot_poly_comms,
424                opening_proof,
425                shifted_opening_proof,
426            },
427            online_oracles,
428            challenges,
429        ))
430    }
431}
432
433impl<E, F, P> UniversalSNARK<E> for PlonkKzgSnark<E>
434where
435    E: Pairing<BaseField = F, G1Affine = Affine<P>>,
436    F: RescueParameter + SWToTEConParam,
437    P: SWCurveConfig<BaseField = F>,
438{
439    type Proof = Proof<E>;
440    type ProvingKey = ProvingKey<E>;
441    type VerifyingKey = VerifyingKey<E>;
442    type UniversalSRS = UniversalSrs<E>;
443    type Error = PlonkError;
444
445    // FIXME: (alex) see <https://github.com/EspressoSystems/jellyfish/issues/249>
446    #[cfg(any(test, feature = "test-srs"))]
447    fn universal_setup_for_testing<R: RngCore + CryptoRng>(
448        max_degree: usize,
449        rng: &mut R,
450    ) -> Result<Self::UniversalSRS, Self::Error> {
451        use ark_ec::{scalar_mul::fixed_base::FixedBase, CurveGroup};
452        use ark_ff::PrimeField;
453        use ark_std::{end_timer, start_timer, UniformRand};
454
455        let setup_time = start_timer!(|| format!("KZG10::Setup with degree {}", max_degree));
456        let beta = E::ScalarField::rand(rng);
457        let g = E::G1::rand(rng);
458        let h = E::G2::rand(rng);
459
460        let mut powers_of_beta = vec![E::ScalarField::one()];
461
462        let mut cur = beta;
463        for _ in 0..max_degree {
464            powers_of_beta.push(cur);
465            cur *= &beta;
466        }
467
468        let window_size = FixedBase::get_mul_window_size(max_degree + 1);
469
470        let scalar_bits = E::ScalarField::MODULUS_BIT_SIZE as usize;
471        let g_time = start_timer!(|| "Generating powers of G");
472        // TODO: parallelization
473        let g_table = FixedBase::get_window_table(scalar_bits, window_size, g);
474        let powers_of_g =
475            FixedBase::msm::<E::G1>(scalar_bits, window_size, &g_table, &powers_of_beta);
476        end_timer!(g_time);
477
478        let powers_of_g = E::G1::normalize_batch(&powers_of_g);
479
480        let h = h.into_affine();
481        let beta_h = (h * beta).into_affine();
482
483        let pp = UniversalSrs {
484            powers_of_g,
485            h,
486            beta_h,
487            powers_of_h: vec![h, beta_h],
488        };
489        end_timer!(setup_time);
490        Ok(pp)
491    }
492
493    /// Input a circuit and the SRS, precompute the proving key and verification
494    /// key.
495    fn preprocess<C: Arithmetization<E::ScalarField>>(
496        srs: &Self::UniversalSRS,
497        circuit: &C,
498    ) -> Result<(Self::ProvingKey, Self::VerifyingKey), Self::Error> {
499        // Make sure the SRS can support the circuit (with hiding degree of 2 for zk)
500        let domain_size = circuit.eval_domain_size()?;
501        let srs_size = circuit.srs_size()?;
502        let num_inputs = circuit.num_inputs();
503        if srs.max_degree() < circuit.srs_size()? {
504            return Err(PlonkError::IndexTooLarge);
505        }
506        // 1. Compute selector and permutation polynomials.
507        let selectors_polys = circuit.compute_selector_polynomials()?;
508        let sigma_polys = circuit.compute_extended_permutation_polynomials()?;
509
510        // Compute Plookup proving key if support lookup.
511        let plookup_pk = if circuit.support_lookup() {
512            let range_table_poly = circuit.compute_range_table_polynomial()?;
513            let key_table_poly = circuit.compute_key_table_polynomial()?;
514            let table_dom_sep_poly = circuit.compute_table_dom_sep_polynomial()?;
515            let q_dom_sep_poly = circuit.compute_q_dom_sep_polynomial()?;
516            Some(PlookupProvingKey {
517                range_table_poly,
518                key_table_poly,
519                table_dom_sep_poly,
520                q_dom_sep_poly,
521            })
522        } else {
523            None
524        };
525
526        // 2. Compute VerifyingKey
527        let (commit_key, open_key) = srs.trim(srs_size)?;
528        let selector_comms = parallelizable_slice_iter(&selectors_polys)
529            .map(|poly| UnivariateKzgPCS::commit(&commit_key, poly).map_err(PlonkError::PCSError))
530            .collect::<Result<Vec<_>, PlonkError>>()?
531            .into_iter()
532            .collect();
533        let sigma_comms = parallelizable_slice_iter(&sigma_polys)
534            .map(|poly| UnivariateKzgPCS::commit(&commit_key, poly).map_err(PlonkError::PCSError))
535            .collect::<Result<Vec<_>, PlonkError>>()?
536            .into_iter()
537            .collect();
538
539        // Compute Plookup verifying key if support lookup.
540        let plookup_vk = match circuit.support_lookup() {
541            false => None,
542            true => Some(PlookupVerifyingKey {
543                range_table_comm: UnivariateKzgPCS::commit(
544                    &commit_key,
545                    &plookup_pk.as_ref().unwrap().range_table_poly,
546                )?,
547                key_table_comm: UnivariateKzgPCS::commit(
548                    &commit_key,
549                    &plookup_pk.as_ref().unwrap().key_table_poly,
550                )?,
551                table_dom_sep_comm: UnivariateKzgPCS::commit(
552                    &commit_key,
553                    &plookup_pk.as_ref().unwrap().table_dom_sep_poly,
554                )?,
555                q_dom_sep_comm: UnivariateKzgPCS::commit(
556                    &commit_key,
557                    &plookup_pk.as_ref().unwrap().q_dom_sep_poly,
558                )?,
559            }),
560        };
561
562        let vk = VerifyingKey {
563            domain_size,
564            num_inputs,
565            selector_comms,
566            sigma_comms,
567            k: compute_coset_representatives(circuit.num_wire_types(), Some(domain_size)),
568            open_key,
569            plookup_vk,
570            is_merged: false,
571        };
572
573        // Compute ProvingKey (which includes the VerifyingKey)
574        let pk = ProvingKey {
575            sigmas: sigma_polys,
576            selectors: selectors_polys,
577            commit_key,
578            vk: vk.clone(),
579            plookup_pk,
580        };
581
582        Ok((pk, vk))
583    }
584
585    /// Compute a Plonk proof.
586    /// Refer to Sec 8.4 of <https://eprint.iacr.org/2019/953.pdf>
587    ///
588    /// `circuit` and `prove_key` has to be consistent (with the same evaluation
589    /// domain etc.), otherwise return error.
590    fn prove<C, R, T>(
591        rng: &mut R,
592        circuit: &C,
593        prove_key: &Self::ProvingKey,
594        extra_transcript_init_msg: Option<Vec<u8>>,
595    ) -> Result<Self::Proof, Self::Error>
596    where
597        C: Arithmetization<E::ScalarField>,
598        R: CryptoRng + RngCore,
599        T: PlonkTranscript<F>,
600    {
601        let (batch_proof, ..) = Self::batch_prove_internal::<_, _, T>(
602            rng,
603            &[circuit],
604            &[prove_key],
605            extra_transcript_init_msg,
606        )?;
607        Ok(Proof {
608            wires_poly_comms: batch_proof.wires_poly_comms_vec[0].clone(),
609            prod_perm_poly_comm: batch_proof.prod_perm_poly_comms_vec[0],
610            split_quot_poly_comms: batch_proof.split_quot_poly_comms,
611            opening_proof: batch_proof.opening_proof,
612            shifted_opening_proof: batch_proof.shifted_opening_proof,
613            poly_evals: batch_proof.poly_evals_vec[0].clone(),
614            plookup_proof: batch_proof.plookup_proofs_vec[0].clone(),
615        })
616    }
617
618    fn verify<T>(
619        verify_key: &Self::VerifyingKey,
620        public_input: &[E::ScalarField],
621        proof: &Self::Proof,
622        extra_transcript_init_msg: Option<Vec<u8>>,
623    ) -> Result<(), Self::Error>
624    where
625        T: PlonkTranscript<F>,
626    {
627        Self::batch_verify::<T>(
628            &[verify_key],
629            &[public_input],
630            &[proof],
631            &[extra_transcript_init_msg],
632        )
633    }
634}
635
636#[cfg(test)]
637pub mod test {
638    use crate::{
639        errors::PlonkError,
640        proof_system::{
641            structs::{
642                eval_merged_lookup_witness, eval_merged_table, Challenges, Oracles, Proof,
643                ProvingKey, UniversalSrs, VerifyingKey,
644            },
645            PlonkKzgSnark, UniversalSNARK,
646        },
647        transcript::{
648            rescue::RescueTranscript, solidity::SolidityTranscript, standard::StandardTranscript,
649            PlonkTranscript,
650        },
651        PlonkType,
652    };
653    use ark_bls12_377::{Bls12_377, Fq as Fq377};
654    use ark_bls12_381::{Bls12_381, Fq as Fq381};
655    use ark_bn254::{Bn254, Fq as Fq254};
656    use ark_bw6_761::{Fq as Fq761, BW6_761};
657    use ark_ec::{
658        pairing::Pairing,
659        short_weierstrass::{Affine, SWCurveConfig},
660    };
661    use ark_ff::{One, PrimeField, Zero};
662    use ark_poly::{
663        univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Polynomial,
664        Radix2EvaluationDomain,
665    };
666    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
667    use ark_std::{
668        format,
669        rand::{CryptoRng, RngCore},
670        string::ToString,
671        vec,
672        vec::Vec,
673    };
674    use core::ops::{Mul, Neg};
675    use jf_pcs::{
676        prelude::{Commitment, UnivariateKzgPCS},
677        PolynomialCommitmentScheme,
678    };
679    use jf_relation::{
680        constants::GATE_WIDTH, gadgets::ecc::SWToTEConParam, Arithmetization, Circuit,
681        MergeableCircuitType, PlonkCircuit,
682    };
683    use jf_rescue::RescueParameter;
684    use jf_utils::test_rng;
685
686    // Different `m`s lead to different circuits.
687    // Different `a0`s lead to different witness values.
688    // For UltraPlonk circuits, `a0` should be less than or equal to `m+1`
689    pub(crate) fn gen_circuit_for_test<F: PrimeField>(
690        m: usize,
691        a0: usize,
692        plonk_type: PlonkType,
693    ) -> Result<PlonkCircuit<F>, PlonkError> {
694        let range_bit_len = 5;
695        let mut cs: PlonkCircuit<F> = match plonk_type {
696            PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
697            PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(range_bit_len),
698        };
699        // Create variables
700        let mut a = vec![];
701        for i in a0..(a0 + 4 * m) {
702            a.push(cs.create_variable(F::from(i as u64))?);
703        }
704        let b = [
705            cs.create_public_variable(F::from(m as u64 * 2))?,
706            cs.create_public_variable(F::from(a0 as u64 * 2 + m as u64 * 4 - 1))?,
707        ];
708        let c = cs.create_public_variable(
709            (cs.witness(b[1])? + cs.witness(a[0])?) * (cs.witness(b[1])? - cs.witness(a[0])?),
710        )?;
711
712        // Create gates:
713        // 1. a0 + ... + a_{4*m-1} = b0 * b1
714        // 2. (b1 + a0) * (b1 - a0) = c
715        // 3. b0 = 2 * m
716        let mut acc = cs.zero();
717        a.iter().for_each(|&elem| acc = cs.add(acc, elem).unwrap());
718        let b_mul = cs.mul(b[0], b[1])?;
719        cs.enforce_equal(acc, b_mul)?;
720        let b1_plus_a0 = cs.add(b[1], a[0])?;
721        let b1_minus_a0 = cs.sub(b[1], a[0])?;
722        cs.mul_gate(b1_plus_a0, b1_minus_a0, c)?;
723        cs.enforce_constant(b[0], F::from(m as u64 * 2))?;
724
725        if plonk_type == PlonkType::UltraPlonk {
726            // Create range gates
727            // 1. range_table = {0, 1, ..., 31}
728            // 2. a_i \in range_table for i = 0..m-1
729            // 3. b0 \in range_table
730            for &var in a.iter().take(m) {
731                cs.add_range_check_variable(var)?;
732            }
733            cs.add_range_check_variable(b[0])?;
734
735            // Create variable table lookup gates
736            // 1. table = [(a0, a2), (a1, a3), (b0, a0)]
737            let table_vars = [(a[0], a[2]), (a[1], a[3]), (b[0], a[0])];
738            // 2. lookup_witness = [(1, a0+1, a0+3), (2, 2m, a0)]
739            let key0 = cs.one();
740            let key1 = cs.create_variable(F::from(2u8))?;
741            let two_m = cs.create_public_variable(F::from(m as u64 * 2))?;
742            let a1 = cs.add_constant(a[0], &F::one())?;
743            let a3 = cs.add_constant(a[0], &F::from(3u8))?;
744            let lookup_vars = [(key0, a1, a3), (key1, two_m, a[0])];
745            cs.create_table_and_lookup_variables(&lookup_vars, &table_vars)?;
746        }
747
748        // Finalize the circuit.
749        cs.finalize_for_arithmetization()?;
750
751        Ok(cs)
752    }
753
754    #[test]
755    fn test_preprocessing() -> Result<(), PlonkError> {
756        test_preprocessing_helper::<Bn254, Fq254, _>(PlonkType::TurboPlonk)?;
757        test_preprocessing_helper::<Bn254, Fq254, _>(PlonkType::UltraPlonk)?;
758        test_preprocessing_helper::<Bls12_377, Fq377, _>(PlonkType::TurboPlonk)?;
759        test_preprocessing_helper::<Bls12_377, Fq377, _>(PlonkType::UltraPlonk)?;
760        test_preprocessing_helper::<Bls12_381, Fq381, _>(PlonkType::TurboPlonk)?;
761        test_preprocessing_helper::<Bls12_381, Fq381, _>(PlonkType::UltraPlonk)?;
762        test_preprocessing_helper::<BW6_761, Fq761, _>(PlonkType::TurboPlonk)?;
763        test_preprocessing_helper::<BW6_761, Fq761, _>(PlonkType::UltraPlonk)
764    }
765    fn test_preprocessing_helper<E, F, P>(plonk_type: PlonkType) -> Result<(), PlonkError>
766    where
767        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
768        F: RescueParameter + SWToTEConParam,
769        P: SWCurveConfig<BaseField = F>,
770    {
771        let rng = &mut jf_utils::test_rng();
772        let circuit = gen_circuit_for_test(5, 6, plonk_type)?;
773        let domain_size = circuit.eval_domain_size()?;
774        let num_inputs = circuit.num_inputs();
775        let selectors = circuit.compute_selector_polynomials()?;
776        let sigmas = circuit.compute_extended_permutation_polynomials()?;
777
778        let max_degree = 64 + 2;
779        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
780        let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
781
782        // check proving key
783        assert_eq!(pk.selectors, selectors);
784        assert_eq!(pk.sigmas, sigmas);
785        assert_eq!(pk.domain_size(), domain_size);
786        assert_eq!(pk.num_inputs(), num_inputs);
787        let num_wire_types = GATE_WIDTH
788            + 1
789            + match plonk_type {
790                PlonkType::TurboPlonk => 0,
791                PlonkType::UltraPlonk => 1,
792            };
793        assert_eq!(pk.sigmas.len(), num_wire_types);
794        // check plookup proving key
795        if plonk_type == PlonkType::UltraPlonk {
796            let range_table_poly = circuit.compute_range_table_polynomial()?;
797            assert_eq!(
798                pk.plookup_pk.as_ref().unwrap().range_table_poly,
799                range_table_poly
800            );
801
802            let key_table_poly = circuit.compute_key_table_polynomial()?;
803            assert_eq!(
804                pk.plookup_pk.as_ref().unwrap().key_table_poly,
805                key_table_poly
806            );
807        }
808
809        // check verifying key
810        assert_eq!(vk.domain_size, domain_size);
811        assert_eq!(vk.num_inputs, num_inputs);
812        assert_eq!(vk.selector_comms.len(), selectors.len());
813        assert_eq!(vk.sigma_comms.len(), sigmas.len());
814        assert_eq!(vk.sigma_comms.len(), num_wire_types);
815        selectors
816            .iter()
817            .zip(vk.selector_comms.iter())
818            .for_each(|(p, &p_comm)| {
819                let expected_comm = UnivariateKzgPCS::commit(&pk.commit_key, p).unwrap();
820                assert_eq!(expected_comm, p_comm);
821            });
822        sigmas
823            .iter()
824            .zip(vk.sigma_comms.iter())
825            .for_each(|(p, &p_comm)| {
826                let expected_comm = UnivariateKzgPCS::commit(&pk.commit_key, p).unwrap();
827                assert_eq!(expected_comm, p_comm);
828            });
829        // check plookup verification key
830        if plonk_type == PlonkType::UltraPlonk {
831            let expected_comm = UnivariateKzgPCS::commit(
832                &pk.commit_key,
833                &pk.plookup_pk.as_ref().unwrap().range_table_poly,
834            )
835            .unwrap();
836            assert_eq!(
837                expected_comm,
838                vk.plookup_vk.as_ref().unwrap().range_table_comm
839            );
840
841            let expected_comm = UnivariateKzgPCS::commit(
842                &pk.commit_key,
843                &pk.plookup_pk.as_ref().unwrap().key_table_poly,
844            )
845            .unwrap();
846            assert_eq!(
847                expected_comm,
848                vk.plookup_vk.as_ref().unwrap().key_table_comm
849            );
850        }
851
852        Ok(())
853    }
854
855    #[test]
856    fn test_plonk_proof_system() -> Result<(), PlonkError> {
857        // merlin transcripts
858        test_plonk_proof_system_helper::<Bn254, Fq254, _, StandardTranscript>(
859            PlonkType::TurboPlonk,
860        )?;
861        test_plonk_proof_system_helper::<Bn254, Fq254, _, StandardTranscript>(
862            PlonkType::UltraPlonk,
863        )?;
864        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, StandardTranscript>(
865            PlonkType::TurboPlonk,
866        )?;
867        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, StandardTranscript>(
868            PlonkType::UltraPlonk,
869        )?;
870        test_plonk_proof_system_helper::<Bls12_381, Fq381, _, StandardTranscript>(
871            PlonkType::TurboPlonk,
872        )?;
873        test_plonk_proof_system_helper::<Bls12_381, Fq381, _, StandardTranscript>(
874            PlonkType::UltraPlonk,
875        )?;
876        test_plonk_proof_system_helper::<BW6_761, Fq761, _, StandardTranscript>(
877            PlonkType::TurboPlonk,
878        )?;
879        test_plonk_proof_system_helper::<BW6_761, Fq761, _, StandardTranscript>(
880            PlonkType::UltraPlonk,
881        )?;
882
883        // rescue transcripts
884        // currently only available for bls12-377
885        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
886            PlonkType::TurboPlonk,
887        )?;
888        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
889            PlonkType::UltraPlonk,
890        )?;
891
892        // solidity-friendly keccak256 transcripts
893        // currently only needed for CAPE using bls12-381
894        test_plonk_proof_system_helper::<Bls12_381, Fq381, _, SolidityTranscript>(
895            PlonkType::TurboPlonk,
896        )?;
897        Ok(())
898    }
899
900    fn test_plonk_proof_system_helper<E, F, P, T>(plonk_type: PlonkType) -> Result<(), PlonkError>
901    where
902        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
903        F: RescueParameter + SWToTEConParam,
904        P: SWCurveConfig<BaseField = F>,
905        T: PlonkTranscript<F>,
906    {
907        // 1. Simulate universal setup
908        let rng = &mut test_rng();
909        let n = 64;
910        let max_degree = n + 2;
911        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
912
913        // 2. Create circuits
914        let circuits = (0..6)
915            .map(|i| {
916                let m = 2 + i / 3;
917                let a0 = 1 + i % 3;
918                gen_circuit_for_test(m, a0, plonk_type)
919            })
920            .collect::<Result<Vec<_>, PlonkError>>()?;
921        // 3. Preprocessing
922        let (pk1, vk1) = PlonkKzgSnark::<E>::preprocess(&srs, &circuits[0])?;
923        let (pk2, vk2) = PlonkKzgSnark::<E>::preprocess(&srs, &circuits[3])?;
924        // 4. Proving
925        let mut proofs = vec![];
926        let mut extra_msgs = vec![];
927        for (i, cs) in circuits.iter().enumerate() {
928            let pk_ref = if i < 3 { &pk1 } else { &pk2 };
929            let extra_msg = if i % 2 == 0 {
930                None
931            } else {
932                Some(format!("extra message: {}", i).into_bytes())
933            };
934            proofs.push(
935                PlonkKzgSnark::<E>::prove::<_, _, T>(rng, cs, pk_ref, extra_msg.clone()).unwrap(),
936            );
937            extra_msgs.push(extra_msg);
938        }
939
940        // 5. Verification
941        let public_inputs: Vec<Vec<E::ScalarField>> = circuits
942            .iter()
943            .map(|cs| cs.public_input())
944            .collect::<Result<Vec<Vec<E::ScalarField>>, _>>(
945        )?;
946        for (i, proof) in proofs.iter().enumerate() {
947            let vk_ref = if i < 3 { &vk1 } else { &vk2 };
948            assert!(PlonkKzgSnark::<E>::verify::<T>(
949                vk_ref,
950                &public_inputs[i],
951                proof,
952                extra_msgs[i].clone(),
953            )
954            .is_ok());
955            // Inconsistent proof should fail the verification.
956            let mut bad_pub_input = public_inputs[i].clone();
957            bad_pub_input[0] = E::ScalarField::from(0u8);
958            assert!(PlonkKzgSnark::<E>::verify::<T>(
959                vk_ref,
960                &bad_pub_input,
961                proof,
962                extra_msgs[i].clone(),
963            )
964            .is_err());
965            // Incorrect extra transcript message should fail
966            assert!(PlonkKzgSnark::<E>::verify::<T>(
967                vk_ref,
968                &bad_pub_input,
969                proof,
970                Some("wrong message".to_string().into_bytes()),
971            )
972            .is_err());
973
974            // Incorrect proof [W_z] = 0, [W_z*g] = 0
975            // attack against some vulnerable implementation described in:
976            // https://cryptosubtlety.medium.com/00-8d4adcf4d255
977            let mut bad_proof = proof.clone();
978            bad_proof.opening_proof = Commitment::default();
979            bad_proof.shifted_opening_proof = Commitment::default();
980            assert!(PlonkKzgSnark::<E>::verify::<T>(
981                vk_ref,
982                &public_inputs[i],
983                &bad_proof,
984                extra_msgs[i].clone(),
985            )
986            .is_err());
987        }
988
989        // 6. Batch verification
990        let vks = vec![&vk1, &vk1, &vk1, &vk2, &vk2, &vk2];
991        let mut public_inputs_ref: Vec<&[E::ScalarField]> = public_inputs
992            .iter()
993            .map(|pub_input| &pub_input[..])
994            .collect();
995        let mut proofs_ref: Vec<&Proof<E>> = proofs.iter().collect();
996        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
997            &vks,
998            &public_inputs_ref,
999            &proofs_ref,
1000            &extra_msgs,
1001        )
1002        .is_ok());
1003
1004        // Inconsistent params
1005        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1006            &vks[..5],
1007            &public_inputs_ref,
1008            &proofs_ref,
1009            &extra_msgs,
1010        )
1011        .is_err());
1012
1013        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1014            &vks,
1015            &public_inputs_ref[..5],
1016            &proofs_ref,
1017            &extra_msgs,
1018        )
1019        .is_err());
1020
1021        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1022            &vks,
1023            &public_inputs_ref,
1024            &proofs_ref[..5],
1025            &extra_msgs,
1026        )
1027        .is_err());
1028
1029        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1030            &vks,
1031            &public_inputs_ref,
1032            &proofs_ref,
1033            &vec![None; vks.len()],
1034        )
1035        .is_err());
1036
1037        assert!(
1038            PlonkKzgSnark::<E>::batch_verify::<T>(&vks, &public_inputs_ref, &proofs_ref, &[],)
1039                .is_err()
1040        );
1041
1042        // Empty params
1043        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(&[], &[], &[], &[],).is_err());
1044
1045        // Error paths
1046        let tmp_pi_ref = public_inputs_ref[0];
1047        public_inputs_ref[0] = public_inputs_ref[1];
1048        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1049            &vks,
1050            &public_inputs_ref,
1051            &proofs_ref,
1052            &extra_msgs,
1053        )
1054        .is_err());
1055        public_inputs_ref[0] = tmp_pi_ref;
1056
1057        proofs_ref[0] = proofs_ref[1];
1058        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1059            &vks,
1060            &public_inputs_ref,
1061            &proofs_ref,
1062            &extra_msgs,
1063        )
1064        .is_err());
1065
1066        Ok(())
1067    }
1068
1069    #[test]
1070    fn test_inconsistent_pub_input_len() -> Result<(), PlonkError> {
1071        // merlin transcripts
1072        test_inconsistent_pub_input_len_helper::<Bn254, Fq254, _, StandardTranscript>(
1073            PlonkType::TurboPlonk,
1074        )?;
1075        test_inconsistent_pub_input_len_helper::<Bn254, Fq254, _, StandardTranscript>(
1076            PlonkType::UltraPlonk,
1077        )?;
1078        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1079            PlonkType::TurboPlonk,
1080        )?;
1081        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1082            PlonkType::UltraPlonk,
1083        )?;
1084        test_inconsistent_pub_input_len_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1085            PlonkType::TurboPlonk,
1086        )?;
1087        test_inconsistent_pub_input_len_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1088            PlonkType::UltraPlonk,
1089        )?;
1090        test_inconsistent_pub_input_len_helper::<BW6_761, Fq761, _, StandardTranscript>(
1091            PlonkType::TurboPlonk,
1092        )?;
1093        test_inconsistent_pub_input_len_helper::<BW6_761, Fq761, _, StandardTranscript>(
1094            PlonkType::UltraPlonk,
1095        )?;
1096
1097        // rescue transcripts
1098        // currently only available for bls12-377
1099        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1100            PlonkType::TurboPlonk,
1101        )?;
1102        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1103            PlonkType::UltraPlonk,
1104        )?;
1105
1106        // Solidity-friendly keccak256 transcript
1107        test_inconsistent_pub_input_len_helper::<Bls12_381, Fq381, _, SolidityTranscript>(
1108            PlonkType::TurboPlonk,
1109        )?;
1110
1111        Ok(())
1112    }
1113
1114    fn test_inconsistent_pub_input_len_helper<E, F, P, T>(
1115        plonk_type: PlonkType,
1116    ) -> Result<(), PlonkError>
1117    where
1118        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1119        F: RescueParameter + SWToTEConParam,
1120        P: SWCurveConfig<BaseField = F>,
1121        T: PlonkTranscript<F>,
1122    {
1123        // 1. Simulate universal setup
1124        let rng = &mut test_rng();
1125        let n = 8;
1126        let max_degree = n + 2;
1127        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1128
1129        // 2. Create circuits
1130        let mut cs1: PlonkCircuit<E::ScalarField> = match plonk_type {
1131            PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
1132            PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(2),
1133        };
1134        let var = cs1.create_variable(E::ScalarField::from(1u8))?;
1135        cs1.enforce_constant(var, E::ScalarField::from(1u8))?;
1136        cs1.finalize_for_arithmetization()?;
1137        let mut cs2: PlonkCircuit<E::ScalarField> = match plonk_type {
1138            PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
1139            PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(2),
1140        };
1141        cs2.create_public_variable(E::ScalarField::from(1u8))?;
1142        cs2.finalize_for_arithmetization()?;
1143
1144        // 3. Preprocessing
1145        let (pk1, vk1) = PlonkKzgSnark::<E>::preprocess(&srs, &cs1)?;
1146        let (pk2, vk2) = PlonkKzgSnark::<E>::preprocess(&srs, &cs2)?;
1147
1148        // 4. Proving
1149        assert!(PlonkKzgSnark::<E>::prove::<_, _, T>(rng, &cs2, &pk1, None).is_err());
1150        let proof2 = PlonkKzgSnark::<E>::prove::<_, _, T>(rng, &cs2, &pk2, None)?;
1151
1152        // 5. Verification
1153        assert!(
1154            PlonkKzgSnark::<E>::verify::<T>(&vk2, &[E::ScalarField::from(1u8)], &proof2, None,)
1155                .is_ok()
1156        );
1157        // wrong verification key
1158        assert!(
1159            PlonkKzgSnark::<E>::verify::<T>(&vk1, &[E::ScalarField::from(1u8)], &proof2, None,)
1160                .is_err()
1161        );
1162        // wrong public input
1163        assert!(PlonkKzgSnark::<E>::verify::<T>(&vk2, &[], &proof2, None).is_err());
1164
1165        Ok(())
1166    }
1167
1168    #[test]
1169    fn test_plonk_prover_polynomials() -> Result<(), PlonkError> {
1170        // merlin transcripts
1171        test_plonk_prover_polynomials_helper::<Bn254, Fq254, _, StandardTranscript>(
1172            PlonkType::TurboPlonk,
1173        )?;
1174        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1175            PlonkType::TurboPlonk,
1176        )?;
1177        test_plonk_prover_polynomials_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1178            PlonkType::TurboPlonk,
1179        )?;
1180        test_plonk_prover_polynomials_helper::<BW6_761, Fq761, _, StandardTranscript>(
1181            PlonkType::TurboPlonk,
1182        )?;
1183        test_plonk_prover_polynomials_helper::<Bn254, Fq254, _, StandardTranscript>(
1184            PlonkType::UltraPlonk,
1185        )?;
1186        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1187            PlonkType::UltraPlonk,
1188        )?;
1189        test_plonk_prover_polynomials_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1190            PlonkType::UltraPlonk,
1191        )?;
1192        test_plonk_prover_polynomials_helper::<BW6_761, Fq761, _, StandardTranscript>(
1193            PlonkType::UltraPlonk,
1194        )?;
1195
1196        // rescue transcripts
1197        // currently only available for bls12-377
1198        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1199            PlonkType::TurboPlonk,
1200        )?;
1201        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1202            PlonkType::UltraPlonk,
1203        )?;
1204
1205        // Solidity-friendly keccak256 transcript
1206        test_plonk_prover_polynomials_helper::<Bls12_381, Fq381, _, SolidityTranscript>(
1207            PlonkType::TurboPlonk,
1208        )?;
1209
1210        Ok(())
1211    }
1212
1213    fn test_plonk_prover_polynomials_helper<E, F, P, T>(
1214        plonk_type: PlonkType,
1215    ) -> Result<(), PlonkError>
1216    where
1217        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1218        F: RescueParameter + SWToTEConParam,
1219        P: SWCurveConfig<BaseField = F>,
1220        T: PlonkTranscript<F>,
1221    {
1222        // 1. Simulate universal setup
1223        let rng = &mut test_rng();
1224        let n = 64;
1225        let max_degree = n + 2;
1226        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1227
1228        // 2. Create the circuit
1229        let circuit = gen_circuit_for_test(10, 3, plonk_type)?;
1230        assert!(circuit.num_gates() <= n);
1231
1232        // 3. Preprocessing
1233        let (pk, _) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
1234
1235        // 4. Proving
1236        let (_, oracles, challenges) =
1237            PlonkKzgSnark::<E>::batch_prove_internal::<_, _, T>(rng, &[&circuit], &[&pk], None)?;
1238
1239        // 5. Check that the targeted polynomials evaluate to zero on the vanishing set.
1240        check_plonk_prover_polynomials(plonk_type, &oracles[0], &pk, &challenges)?;
1241
1242        Ok(())
1243    }
1244
1245    fn check_plonk_prover_polynomials<E: Pairing>(
1246        plonk_type: PlonkType,
1247        oracles: &Oracles<E::ScalarField>,
1248        pk: &ProvingKey<E>,
1249        challenges: &Challenges<E::ScalarField>,
1250    ) -> Result<(), PlonkError> {
1251        check_circuit_polynomial_on_vanishing_set(oracles, pk)?;
1252        check_perm_polynomials_on_vanishing_set(oracles, pk, challenges)?;
1253        if plonk_type == PlonkType::UltraPlonk {
1254            check_lookup_polynomials_on_vanishing_set(oracles, pk, challenges)?;
1255        }
1256
1257        Ok(())
1258    }
1259
1260    fn check_circuit_polynomial_on_vanishing_set<E: Pairing>(
1261        oracles: &Oracles<E::ScalarField>,
1262        pk: &ProvingKey<E>,
1263    ) -> Result<(), PlonkError> {
1264        let q_lc: Vec<&DensePolynomial<E::ScalarField>> =
1265            (0..GATE_WIDTH).map(|j| &pk.selectors[j]).collect();
1266        let q_mul: Vec<&DensePolynomial<E::ScalarField>> = (GATE_WIDTH..GATE_WIDTH + 2)
1267            .map(|j| &pk.selectors[j])
1268            .collect();
1269        let q_hash: Vec<&DensePolynomial<E::ScalarField>> = (GATE_WIDTH + 2..2 * GATE_WIDTH + 2)
1270            .map(|j| &pk.selectors[j])
1271            .collect();
1272        let q_o = &pk.selectors[2 * GATE_WIDTH + 2];
1273        let q_c = &pk.selectors[2 * GATE_WIDTH + 3];
1274        let q_ecc = &pk.selectors[2 * GATE_WIDTH + 4];
1275        let circuit_poly = q_c
1276            + &oracles.pub_inp_poly
1277            + oracles.wire_polys[0].mul(q_lc[0])
1278            + oracles.wire_polys[1].mul(q_lc[1])
1279            + oracles.wire_polys[2].mul(q_lc[2])
1280            + oracles.wire_polys[3].mul(q_lc[3])
1281            + oracles.wire_polys[0]
1282                .mul(&oracles.wire_polys[1])
1283                .mul(q_mul[0])
1284            + oracles.wire_polys[2]
1285                .mul(&oracles.wire_polys[3])
1286                .mul(q_mul[1])
1287            + oracles.wire_polys[0]
1288                .mul(&oracles.wire_polys[1])
1289                .mul(&oracles.wire_polys[2])
1290                .mul(&oracles.wire_polys[3])
1291                .mul(&oracles.wire_polys[4])
1292                .mul(q_ecc)
1293            + oracles.wire_polys[0]
1294                .mul(&oracles.wire_polys[0])
1295                .mul(&oracles.wire_polys[0])
1296                .mul(&oracles.wire_polys[0])
1297                .mul(&oracles.wire_polys[0])
1298                .mul(q_hash[0])
1299            + oracles.wire_polys[1]
1300                .mul(&oracles.wire_polys[1])
1301                .mul(&oracles.wire_polys[1])
1302                .mul(&oracles.wire_polys[1])
1303                .mul(&oracles.wire_polys[1])
1304                .mul(q_hash[1])
1305            + oracles.wire_polys[2]
1306                .mul(&oracles.wire_polys[2])
1307                .mul(&oracles.wire_polys[2])
1308                .mul(&oracles.wire_polys[2])
1309                .mul(&oracles.wire_polys[2])
1310                .mul(q_hash[2])
1311            + oracles.wire_polys[3]
1312                .mul(&oracles.wire_polys[3])
1313                .mul(&oracles.wire_polys[3])
1314                .mul(&oracles.wire_polys[3])
1315                .mul(&oracles.wire_polys[3])
1316                .mul(q_hash[3])
1317            + oracles.wire_polys[4].mul(q_o).neg();
1318
1319        // check that the polynomial evaluates to zero on the vanishing set
1320        let domain = Radix2EvaluationDomain::<E::ScalarField>::new(pk.domain_size())
1321            .ok_or(PlonkError::DomainCreationError)?;
1322        for i in 0..domain.size() {
1323            assert_eq!(
1324                circuit_poly.evaluate(&domain.element(i)),
1325                E::ScalarField::zero()
1326            );
1327        }
1328
1329        Ok(())
1330    }
1331
1332    fn check_perm_polynomials_on_vanishing_set<E: Pairing>(
1333        oracles: &Oracles<E::ScalarField>,
1334        pk: &ProvingKey<E>,
1335        challenges: &Challenges<E::ScalarField>,
1336    ) -> Result<(), PlonkError> {
1337        let beta = challenges.beta;
1338        let gamma = challenges.gamma;
1339
1340        // check that \prod_i [w_i(X) + beta * k_i * X + gamma] * z(X) = \prod_i [w_i(X)
1341        // + beta * sigma_i(X) + gamma] * z(wX) on the vanishing set
1342        let one_poly = DensePolynomial::from_coefficients_vec(vec![E::ScalarField::one()]);
1343        let poly_1 = oracles
1344            .wire_polys
1345            .iter()
1346            .enumerate()
1347            .fold(one_poly.clone(), |acc, (j, w)| {
1348                let poly =
1349                    &DensePolynomial::from_coefficients_vec(vec![gamma, beta * pk.k()[j]]) + w;
1350                acc.mul(&poly)
1351            });
1352        let poly_2 =
1353            oracles
1354                .wire_polys
1355                .iter()
1356                .zip(pk.sigmas.iter())
1357                .fold(one_poly, |acc, (w, sigma)| {
1358                    let poly = w.clone()
1359                        + sigma.mul(beta)
1360                        + DensePolynomial::from_coefficients_vec(vec![gamma]);
1361                    acc.mul(&poly)
1362                });
1363
1364        let domain = Radix2EvaluationDomain::<E::ScalarField>::new(pk.domain_size())
1365            .ok_or(PlonkError::DomainCreationError)?;
1366        for i in 0..domain.size() {
1367            let point = domain.element(i);
1368            let eval_1 = poly_1.evaluate(&point) * oracles.prod_perm_poly.evaluate(&point);
1369            let eval_2 = poly_2.evaluate(&point)
1370                * oracles.prod_perm_poly.evaluate(&(point * domain.group_gen));
1371            assert_eq!(eval_1, eval_2);
1372        }
1373
1374        // check z(X) = 1 at point 1
1375        assert_eq!(
1376            oracles.prod_perm_poly.evaluate(&domain.element(0)),
1377            E::ScalarField::one()
1378        );
1379
1380        Ok(())
1381    }
1382
1383    fn check_lookup_polynomials_on_vanishing_set<E: Pairing>(
1384        oracles: &Oracles<E::ScalarField>,
1385        pk: &ProvingKey<E>,
1386        challenges: &Challenges<E::ScalarField>,
1387    ) -> Result<(), PlonkError> {
1388        let beta = challenges.beta;
1389        let gamma = challenges.gamma;
1390        let n = pk.domain_size();
1391        let domain = Radix2EvaluationDomain::<E::ScalarField>::new(n)
1392            .ok_or(PlonkError::DomainCreationError)?;
1393        let prod_poly = &oracles.plookup_oracles.prod_lookup_poly;
1394        let h_polys = &oracles.plookup_oracles.h_polys;
1395
1396        // check z(X) = 1 at point 1
1397        assert_eq!(
1398            prod_poly.evaluate(&domain.element(0)),
1399            E::ScalarField::one()
1400        );
1401
1402        // check z(X) = 1 at point w^{n-1}
1403        assert_eq!(
1404            prod_poly.evaluate(&domain.element(n - 1)),
1405            E::ScalarField::one()
1406        );
1407
1408        // check h1(X) = h2(w * X) at point w^{n-1}
1409        assert_eq!(
1410            h_polys[0].evaluate(&domain.element(n - 1)),
1411            h_polys[1].evaluate(&domain.element(0))
1412        );
1413
1414        // check z(X) *
1415        //      (1+beta) * (gamma + merged_lookup_wire(X)) *
1416        //      (gamma(1+beta) + merged_table(X) + beta * merged_table(Xw))
1417        //     = z(Xw) *
1418        //      (gamma(1+beta) + h1(X) + beta * h1(Xw)) *
1419        //      (gamma(1+beta) + h2(x) + beta * h2(Xw))
1420        // on the vanishing set excluding point w^{n-1}
1421        let beta_plus_one = E::ScalarField::one() + beta;
1422        let gamma_mul_beta_plus_one = gamma * beta_plus_one;
1423
1424        let range_table_poly_ref = &pk.plookup_pk.as_ref().unwrap().range_table_poly;
1425        let key_table_poly_ref = &pk.plookup_pk.as_ref().unwrap().key_table_poly;
1426        let table_dom_sep_poly_ref = &pk.plookup_pk.as_ref().unwrap().table_dom_sep_poly;
1427        let q_dom_sep_poly_ref = &pk.plookup_pk.as_ref().unwrap().q_dom_sep_poly;
1428
1429        for i in 0..domain.size() - 1 {
1430            let point = domain.element(i);
1431            let next_point = point * domain.group_gen;
1432            let merged_lookup_wire_eval = eval_merged_lookup_witness::<E>(
1433                challenges.tau.unwrap(),
1434                oracles.wire_polys[5].evaluate(&point),
1435                oracles.wire_polys[0].evaluate(&point),
1436                oracles.wire_polys[1].evaluate(&point),
1437                oracles.wire_polys[2].evaluate(&point),
1438                pk.q_lookup_poly()?.evaluate(&point),
1439                q_dom_sep_poly_ref.evaluate(&point),
1440            );
1441            let merged_table_eval = eval_merged_table::<E>(
1442                challenges.tau.unwrap(),
1443                range_table_poly_ref.evaluate(&point),
1444                key_table_poly_ref.evaluate(&point),
1445                pk.q_lookup_poly()?.evaluate(&point),
1446                oracles.wire_polys[3].evaluate(&point),
1447                oracles.wire_polys[4].evaluate(&point),
1448                table_dom_sep_poly_ref.evaluate(&point),
1449            );
1450            let merged_table_next_eval = eval_merged_table::<E>(
1451                challenges.tau.unwrap(),
1452                range_table_poly_ref.evaluate(&next_point),
1453                key_table_poly_ref.evaluate(&next_point),
1454                pk.q_lookup_poly()?.evaluate(&next_point),
1455                oracles.wire_polys[3].evaluate(&next_point),
1456                oracles.wire_polys[4].evaluate(&next_point),
1457                table_dom_sep_poly_ref.evaluate(&next_point),
1458            );
1459
1460            let eval_1 = prod_poly.evaluate(&point)
1461                * beta_plus_one
1462                * (gamma + merged_lookup_wire_eval)
1463                * (gamma_mul_beta_plus_one + merged_table_eval + beta * merged_table_next_eval);
1464            let eval_2 = prod_poly.evaluate(&next_point)
1465                * (gamma_mul_beta_plus_one
1466                    + h_polys[0].evaluate(&point)
1467                    + beta * h_polys[0].evaluate(&next_point))
1468                * (gamma_mul_beta_plus_one
1469                    + h_polys[1].evaluate(&point)
1470                    + beta * h_polys[1].evaluate(&next_point));
1471            assert_eq!(eval_1, eval_2, "i={}, domain_size={}", i, domain.size());
1472        }
1473
1474        Ok(())
1475    }
1476
1477    #[test]
1478    fn test_proof_from_to_fields() -> Result<(), PlonkError> {
1479        test_proof_from_to_fields_helper::<Bn254, _>()?;
1480        test_proof_from_to_fields_helper::<Bls12_381, _>()?;
1481        test_proof_from_to_fields_helper::<Bls12_377, _>()?;
1482        test_proof_from_to_fields_helper::<BW6_761, _>()?;
1483        Ok(())
1484    }
1485
1486    fn test_proof_from_to_fields_helper<E, P>() -> Result<(), PlonkError>
1487    where
1488        E: Pairing<G1Affine = Affine<P>>,
1489        E::BaseField: RescueParameter + SWToTEConParam,
1490        P: SWCurveConfig<BaseField = E::BaseField, ScalarField = E::ScalarField>,
1491    {
1492        let rng = &mut jf_utils::test_rng();
1493        let circuit = gen_circuit_for_test(3, 4, PlonkType::TurboPlonk)?;
1494        let max_degree = 80;
1495        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1496
1497        let (pk, _) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
1498        let proof =
1499            PlonkKzgSnark::<E>::prove::<_, _, StandardTranscript>(rng, &circuit, &pk, None)?;
1500
1501        let base_fields: Vec<E::BaseField> = proof.clone().into();
1502        let res: Proof<E> = base_fields.try_into()?;
1503        assert_eq!(res, proof);
1504
1505        Ok(())
1506    }
1507
1508    #[test]
1509    fn test_serde() -> Result<(), PlonkError> {
1510        // merlin transcripts
1511        test_serde_helper::<Bn254, Fq254, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1512        test_serde_helper::<Bn254, Fq254, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1513        test_serde_helper::<Bls12_377, Fq377, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1514        test_serde_helper::<Bls12_377, Fq377, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1515        test_serde_helper::<Bls12_381, Fq381, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1516        test_serde_helper::<Bls12_381, Fq381, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1517        test_serde_helper::<BW6_761, Fq761, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1518        test_serde_helper::<BW6_761, Fq761, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1519
1520        // rescue transcripts
1521        // currently only available for bls12-377
1522        test_serde_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(PlonkType::TurboPlonk)?;
1523        test_serde_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(PlonkType::UltraPlonk)?;
1524
1525        // Solidity-friendly keccak256 transcript
1526        test_serde_helper::<Bls12_381, Fq381, _, SolidityTranscript>(PlonkType::TurboPlonk)?;
1527
1528        Ok(())
1529    }
1530
1531    fn test_serde_helper<E, F, P, T>(plonk_type: PlonkType) -> Result<(), PlonkError>
1532    where
1533        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1534        F: RescueParameter + SWToTEConParam,
1535        P: SWCurveConfig<BaseField = F>,
1536        T: PlonkTranscript<F>,
1537    {
1538        let rng = &mut jf_utils::test_rng();
1539        let circuit = gen_circuit_for_test(3, 4, plonk_type)?;
1540        let max_degree = 80;
1541        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1542
1543        let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
1544        let proof = PlonkKzgSnark::<E>::prove::<_, _, T>(rng, &circuit, &pk, None)?;
1545
1546        let mut ser_bytes = Vec::new();
1547        srs.serialize_compressed(&mut ser_bytes)?;
1548        let de = UniversalSrs::<E>::deserialize_compressed(&ser_bytes[..])?;
1549        assert_eq!(de, srs);
1550
1551        let mut ser_bytes = Vec::new();
1552        pk.serialize_compressed(&mut ser_bytes)?;
1553        let de = ProvingKey::<E>::deserialize_compressed(&ser_bytes[..])?;
1554        assert_eq!(de, pk);
1555
1556        let mut ser_bytes = Vec::new();
1557        vk.serialize_compressed(&mut ser_bytes)?;
1558        let de = VerifyingKey::<E>::deserialize_compressed(&ser_bytes[..])?;
1559        assert_eq!(de, vk);
1560
1561        let mut ser_bytes = Vec::new();
1562        proof.serialize_compressed(&mut ser_bytes)?;
1563        let de = Proof::<E>::deserialize_compressed(&ser_bytes[..])?;
1564        assert_eq!(de, proof);
1565
1566        Ok(())
1567    }
1568
1569    #[test]
1570    fn test_key_aggregation_and_batch_prove() -> Result<(), PlonkError> {
1571        // merlin transcripts
1572        test_key_aggregation_and_batch_prove_helper::<Bn254, Fq254, _, StandardTranscript>(
1573            PlonkType::TurboPlonk,
1574        )?;
1575        test_key_aggregation_and_batch_prove_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1576            PlonkType::TurboPlonk,
1577        )?;
1578        test_key_aggregation_and_batch_prove_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1579            PlonkType::TurboPlonk,
1580        )?;
1581        test_key_aggregation_and_batch_prove_helper::<BW6_761, Fq761, _, StandardTranscript>(
1582            PlonkType::TurboPlonk,
1583        )?;
1584
1585        // rescue transcripts
1586        // currently only available for bls12-377
1587        test_key_aggregation_and_batch_prove_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1588            PlonkType::TurboPlonk,
1589        )
1590    }
1591
1592    fn test_key_aggregation_and_batch_prove_helper<E, F, P, T>(
1593        plonk_type: PlonkType,
1594    ) -> Result<(), PlonkError>
1595    where
1596        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1597        F: RescueParameter + SWToTEConParam,
1598        P: SWCurveConfig<BaseField = F>,
1599        T: PlonkTranscript<F>,
1600    {
1601        // 1. Simulate universal setup
1602        let rng = &mut test_rng();
1603        let n = 128;
1604        let max_degree = n + 2;
1605        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1606
1607        // 2. Create many circuits with same domain size
1608        let circuits = (6..13)
1609            .map(|i| gen_circuit_for_test::<E::ScalarField>(i, i, plonk_type))
1610            .collect::<Result<Vec<_>, PlonkError>>()?; // the number of gates = 4m + 11
1611        let cs_ref: Vec<&PlonkCircuit<E::ScalarField>> = circuits.iter().collect();
1612
1613        // 3. Preprocessing
1614        let mut prove_keys = vec![];
1615        let mut verify_keys = vec![];
1616        for circuit in circuits.iter() {
1617            let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, circuit)?;
1618            prove_keys.push(pk);
1619            verify_keys.push(vk);
1620        }
1621        let pks_ref: Vec<&ProvingKey<E>> = prove_keys.iter().collect();
1622        let vks_ref: Vec<&VerifyingKey<E>> = verify_keys.iter().collect();
1623
1624        // 4. Batch Proving and verification
1625        check_batch_prove_and_verify::<_, _, _, _, T>(rng, &cs_ref, &pks_ref, &vks_ref)?;
1626
1627        // Batch proving with circuit/key aggregation
1628        //
1629        // 2. Create circuits
1630        let type_a_circuits: Vec<PlonkCircuit<E::ScalarField>> = (6..13)
1631            .map(|i| gen_mergeable_circuit(i, i, MergeableCircuitType::TypeA))
1632            .collect::<Result<Vec<_>, PlonkError>>()?; // the number of gates = 4m + 11
1633        let type_b_circuits = (6..13)
1634            .map(|i| gen_mergeable_circuit(i, i, MergeableCircuitType::TypeB))
1635            .collect::<Result<Vec<_>, PlonkError>>()?;
1636        // merge circuits
1637        let circuits = type_a_circuits
1638            .iter()
1639            .zip(type_b_circuits.iter())
1640            .map(|(cs_a, cs_b)| cs_a.merge(cs_b))
1641            .collect::<Result<Vec<_>, _>>()?;
1642        let cs_ref: Vec<&PlonkCircuit<E::ScalarField>> = circuits.iter().collect();
1643
1644        // 3. Preprocessing
1645        let mut pks_type_a = vec![];
1646        let mut vks_type_a = vec![];
1647        for cs_a in type_a_circuits.iter() {
1648            let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, cs_a)?;
1649            pks_type_a.push(pk);
1650            vks_type_a.push(vk);
1651        }
1652        let mut pks_type_b = vec![];
1653        let mut vks_type_b = vec![];
1654        for cs_b in type_b_circuits.iter() {
1655            let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, cs_b)?;
1656            pks_type_b.push(pk);
1657            vks_type_b.push(vk);
1658        }
1659        // merge proving keys
1660        let pks = pks_type_a
1661            .iter()
1662            .zip(pks_type_b.iter())
1663            .map(|(pk_a, pk_b)| pk_a.merge(pk_b))
1664            .collect::<Result<Vec<_>, PlonkError>>()?;
1665        // merge verification keys
1666        let vks = vks_type_a
1667            .iter()
1668            .zip(vks_type_b.iter())
1669            .map(|(vk_a, vk_b)| vk_a.merge(vk_b))
1670            .collect::<Result<Vec<_>, PlonkError>>()?;
1671        // check that the merged keys are correct
1672        for (cs, vk) in circuits.iter().zip(vks.iter()) {
1673            let (_, mut expected_vk) = PlonkKzgSnark::<E>::preprocess(&srs, cs)?;
1674            expected_vk.is_merged = true;
1675            assert_eq!(*vk, expected_vk);
1676        }
1677
1678        let pks_ref: Vec<&ProvingKey<E>> = pks.iter().collect();
1679        let vks_ref: Vec<&VerifyingKey<E>> = vks.iter().collect();
1680
1681        // 4. Batch Proving and verification
1682        check_batch_prove_and_verify::<_, _, _, _, T>(rng, &cs_ref, &pks_ref, &vks_ref)?;
1683
1684        Ok(())
1685    }
1686
1687    fn check_batch_prove_and_verify<E, F, P, R, T>(
1688        rng: &mut R,
1689        cs_ref: &[&PlonkCircuit<E::ScalarField>],
1690        pks_ref: &[&ProvingKey<E>],
1691        vks_ref: &[&VerifyingKey<E>],
1692    ) -> Result<(), PlonkError>
1693    where
1694        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1695        F: RescueParameter + SWToTEConParam,
1696        P: SWCurveConfig<BaseField = F>,
1697        R: CryptoRng + RngCore,
1698        T: PlonkTranscript<F>,
1699    {
1700        // Batch Proving
1701        let batch_proof = PlonkKzgSnark::<E>::batch_prove::<_, _, T>(rng, cs_ref, pks_ref)?;
1702
1703        // Verification
1704        let public_inputs: Vec<Vec<E::ScalarField>> = cs_ref
1705            .iter()
1706            .map(|&cs| cs.public_input())
1707            .collect::<Result<Vec<Vec<E::ScalarField>>, _>>(
1708        )?;
1709        let pi_ref: Vec<&[E::ScalarField]> = public_inputs
1710            .iter()
1711            .map(|pub_input| &pub_input[..])
1712            .collect();
1713        assert!(
1714            PlonkKzgSnark::<E>::verify_batch_proof::<T>(vks_ref, &pi_ref, &batch_proof,).is_ok()
1715        );
1716        let mut bad_pi_ref = pi_ref.clone();
1717        bad_pi_ref[0] = bad_pi_ref[1];
1718        assert!(
1719            PlonkKzgSnark::<E>::verify_batch_proof::<T>(vks_ref, &bad_pi_ref, &batch_proof,)
1720                .is_err()
1721        );
1722
1723        Ok(())
1724    }
1725
1726    fn gen_mergeable_circuit<F: PrimeField>(
1727        m: usize,
1728        a0: usize,
1729        circuit_type: MergeableCircuitType,
1730    ) -> Result<PlonkCircuit<F>, PlonkError> {
1731        let mut cs: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
1732        // Create variables
1733        let mut a = vec![];
1734        for i in a0..(a0 + 4 * m) {
1735            a.push(cs.create_variable(F::from(i as u64))?);
1736        }
1737        let b = [
1738            cs.create_public_variable(F::from(m as u64 * 2))?,
1739            cs.create_public_variable(F::from(a0 as u64 * 2 + m as u64 * 4 - 1))?,
1740        ];
1741        let c = cs.create_public_variable(
1742            (cs.witness(b[1])? + cs.witness(a[0])?) * (cs.witness(b[1])? - cs.witness(a[0])?),
1743        )?;
1744
1745        // Create gates:
1746        // 1. a0 + ... + a_{4*m-1} = b0 * b1
1747        // 2. (b1 + a0) * (b1 - a0) = c
1748        // 3. b0 = 2 * m
1749        let mut acc = cs.zero();
1750        a.iter().for_each(|&elem| acc = cs.add(acc, elem).unwrap());
1751        let b_mul = cs.mul(b[0], b[1])?;
1752        cs.enforce_equal(acc, b_mul)?;
1753        let b1_plus_a0 = cs.add(b[1], a[0])?;
1754        let b1_minus_a0 = cs.sub(b[1], a[0])?;
1755        cs.mul_gate(b1_plus_a0, b1_minus_a0, c)?;
1756        cs.enforce_constant(b[0], F::from(m as u64 * 2))?;
1757
1758        cs.finalize_for_mergeable_circuit(circuit_type)?;
1759
1760        Ok(cs)
1761    }
1762}