jf_plonk/proof_system/
snark.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the Jellyfish library.

// You should have received a copy of the MIT License
// along with the Jellyfish library. If not, see <https://mit-license.org/>.

//! Instantiations of Plonk-based proof systems
8use super::{
9    prover::Prover,
10    structs::{
11        BatchProof, Challenges, Oracles, PlookupProof, PlookupProvingKey, PlookupVerifyingKey,
12        Proof, ProvingKey, VerifyingKey,
13    },
14    verifier::Verifier,
15    UniversalSNARK,
16};
17use crate::{
18    constants::EXTRA_TRANSCRIPT_MSG_LABEL,
19    errors::{PlonkError, SnarkError::ParameterError},
20    proof_system::structs::UniversalSrs,
21    transcript::*,
22};
23use ark_ec::{
24    pairing::Pairing,
25    short_weierstrass::{Affine, SWCurveConfig},
26};
27use ark_ff::{Field, One};
28use ark_std::{
29    format,
30    marker::PhantomData,
31    rand::{CryptoRng, RngCore},
32    string::ToString,
33    vec,
34    vec::Vec,
35};
36use jf_pcs::{prelude::UnivariateKzgPCS, PolynomialCommitmentScheme, StructuredReferenceString};
37use jf_relation::{
38    constants::compute_coset_representatives, gadgets::ecc::SWToTEConParam, Arithmetization,
39};
40use jf_rescue::RescueParameter;
41use jf_utils::par_utils::parallelizable_slice_iter;
42#[cfg(feature = "parallel")]
43use rayon::prelude::*;
44
45/// A Plonk instantiated with KZG PCS
46pub struct PlonkKzgSnark<E: Pairing>(PhantomData<E>);
47
48impl<E, F, P> PlonkKzgSnark<E>
49where
50    E: Pairing<BaseField = F, G1Affine = Affine<P>>,
51    F: RescueParameter + SWToTEConParam,
52    P: SWCurveConfig<BaseField = F>,
53{
54    #[allow(clippy::new_without_default)]
55    /// A new Plonk KZG SNARK
56    pub fn new() -> Self {
57        Self(PhantomData)
58    }
59
60    /// Generate an aggregated Plonk proof for multiple instances.
61    pub fn batch_prove<C, R, T>(
62        prng: &mut R,
63        circuits: &[&C],
64        prove_keys: &[&ProvingKey<E>],
65    ) -> Result<BatchProof<E>, PlonkError>
66    where
67        C: Arithmetization<E::ScalarField>,
68        R: CryptoRng + RngCore,
69        T: PlonkTranscript<F>,
70    {
71        let (batch_proof, ..) =
72            Self::batch_prove_internal::<_, _, T>(prng, circuits, prove_keys, None)?;
73        Ok(batch_proof)
74    }
75
76    /// Verify a single aggregated Plonk proof.
77    pub fn verify_batch_proof<T>(
78        verify_keys: &[&VerifyingKey<E>],
79        public_inputs: &[&[E::ScalarField]],
80        batch_proof: &BatchProof<E>,
81    ) -> Result<(), PlonkError>
82    where
83        T: PlonkTranscript<F>,
84    {
85        if verify_keys.is_empty() {
86            return Err(ParameterError("empty verification keys".to_string()).into());
87        }
88        let verifier = Verifier::new(verify_keys[0].domain_size)?;
89        let pcs_info =
90            verifier.prepare_pcs_info::<T>(verify_keys, public_inputs, batch_proof, &None)?;
91        if !Verifier::batch_verify_opening_proofs::<T>(
92            &verify_keys[0].open_key, // all open_key are the same
93            &[pcs_info],
94        )? {
95            return Err(PlonkError::WrongProof);
96        }
97        Ok(())
98    }
99
100    /// Batch verify multiple SNARK proofs (w.r.t. different verifying keys).
101    pub fn batch_verify<T>(
102        verify_keys: &[&VerifyingKey<E>],
103        public_inputs: &[&[E::ScalarField]],
104        proofs: &[&Proof<E>],
105        extra_transcript_init_msgs: &[Option<Vec<u8>>],
106    ) -> Result<(), PlonkError>
107    where
108        T: PlonkTranscript<F>,
109    {
110        if public_inputs.len() != proofs.len()
111            || verify_keys.len() != proofs.len()
112            || extra_transcript_init_msgs.len() != proofs.len()
113        {
114            return Err(ParameterError(format!(
115                "verify_keys.len: {}, public_inputs.len: {}, proofs.len: {}, \
116                 extra_transcript_msg.len: {}",
117                verify_keys.len(),
118                public_inputs.len(),
119                proofs.len(),
120                extra_transcript_init_msgs.len()
121            ))
122            .into());
123        }
124        if verify_keys.is_empty() {
125            return Err(
126                ParameterError("the number of instances cannot be zero".to_string()).into(),
127            );
128        }
129
130        let pcs_infos = parallelizable_slice_iter(verify_keys)
131            .zip(parallelizable_slice_iter(proofs))
132            .zip(parallelizable_slice_iter(public_inputs))
133            .zip(parallelizable_slice_iter(extra_transcript_init_msgs))
134            .map(|(((&vk, &proof), &pub_input), extra_msg)| {
135                let verifier = Verifier::new(vk.domain_size)?;
136                verifier.prepare_pcs_info::<T>(
137                    &[vk],
138                    &[pub_input],
139                    &(*proof).clone().into(),
140                    extra_msg,
141                )
142            })
143            .collect::<Result<Vec<_>, PlonkError>>()?;
144
145        if !Verifier::batch_verify_opening_proofs::<T>(
146            &verify_keys[0].open_key, // all open_key are the same
147            &pcs_infos,
148        )? {
149            return Err(PlonkError::WrongProof);
150        }
151        Ok(())
152    }
153
154    /// An internal private API for ease of testing
155    ///
156    /// Batchly compute a Plonk proof for multiple instances. Return the batch
157    /// proof and the corresponding online polynomial oracles and
158    /// challenges. Refer to Sec 8.4 of https://eprint.iacr.org/2019/953.pdf
159    ///
160    /// `circuit` and `prove_key` has to be consistent (with the same evaluation
161    /// domain etc.), otherwise return error.
162    #[allow(clippy::type_complexity)]
163    fn batch_prove_internal<C, R, T>(
164        prng: &mut R,
165        circuits: &[&C],
166        prove_keys: &[&ProvingKey<E>],
167        extra_transcript_init_msg: Option<Vec<u8>>,
168    ) -> Result<
169        (
170            BatchProof<E>,
171            Vec<Oracles<E::ScalarField>>,
172            Challenges<E::ScalarField>,
173        ),
174        PlonkError,
175    >
176    where
177        C: Arithmetization<E::ScalarField>,
178        R: CryptoRng + RngCore,
179        T: PlonkTranscript<F>,
180    {
181        if circuits.is_empty() {
182            return Err(ParameterError("zero number of circuits/proving keys".to_string()).into());
183        }
184        if circuits.len() != prove_keys.len() {
185            return Err(ParameterError(format!(
186                "the number of circuits {} != the number of proving keys {}",
187                circuits.len(),
188                prove_keys.len()
189            ))
190            .into());
191        }
192        let n = circuits[0].eval_domain_size()?;
193        let num_wire_types = circuits[0].num_wire_types();
194        for (circuit, pk) in circuits.iter().zip(prove_keys.iter()) {
195            if circuit.eval_domain_size()? != n {
196                return Err(ParameterError(format!(
197                    "circuit domain size {} != expected domain size {}",
198                    circuit.eval_domain_size()?,
199                    n
200                ))
201                .into());
202            }
203            if pk.domain_size() != n {
204                return Err(ParameterError(format!(
205                    "proving key domain size {} != expected domain size {}",
206                    pk.domain_size(),
207                    n
208                ))
209                .into());
210            }
211            if circuit.num_inputs() != pk.vk.num_inputs {
212                return Err(ParameterError(format!(
213                    "circuit.num_inputs {} != prove_key.num_inputs {}",
214                    circuit.num_inputs(),
215                    pk.vk.num_inputs
216                ))
217                .into());
218            }
219            if circuit.support_lookup() != pk.plookup_pk.is_some() {
220                return Err(ParameterError(
221                    "Mismatched Plonk types between the proving key and the circuit".to_string(),
222                )
223                .into());
224            }
225            if circuit.num_wire_types() != num_wire_types {
226                return Err(ParameterError("inconsistent plonk circuit types".to_string()).into());
227            }
228        }
229
230        // Initialize transcript
231        let mut transcript = T::new(b"PlonkProof");
232        if let Some(msg) = extra_transcript_init_msg {
233            transcript.append_message(EXTRA_TRANSCRIPT_MSG_LABEL, &msg)?;
234        }
235        for (pk, circuit) in prove_keys.iter().zip(circuits.iter()) {
236            transcript.append_vk_and_pub_input(&pk.vk, &circuit.public_input()?)?;
237        }
238        // Initialize verifier challenges and online polynomial oracles.
239        let mut challenges = Challenges::default();
240        let mut online_oracles = vec![Oracles::default(); circuits.len()];
241        let prover = Prover::new(n, num_wire_types)?;
242
243        // Round 1
244        let mut wires_poly_comms_vec = vec![];
245        for i in 0..circuits.len() {
246            let ((wires_poly_comms, wire_polys), pi_poly) =
247                prover.run_1st_round(prng, &prove_keys[i].commit_key, circuits[i])?;
248            online_oracles[i].wire_polys = wire_polys;
249            online_oracles[i].pub_inp_poly = pi_poly;
250            transcript.append_commitments(b"witness_poly_comms", &wires_poly_comms)?;
251            wires_poly_comms_vec.push(wires_poly_comms);
252        }
253
254        // Round 1.5
255        // Plookup: compute and interpolate the sorted concatenation of the (merged)
256        // lookup table and the (merged) witness values
257        if circuits.iter().any(|c| C::support_lookup(c)) {
258            challenges.tau = Some(transcript.get_challenge::<E>(b"tau")?);
259        } else {
260            challenges.tau = None;
261        }
262
263        let mut h_poly_comms_vec = vec![];
264        let mut sorted_vec_list = vec![];
265        let mut merged_table_list = vec![];
266        for i in 0..circuits.len() {
267            let (sorted_vec, h_poly_comms, merged_table) = if circuits[i].support_lookup() {
268                let ((h_poly_comms, h_polys), sorted_vec, merged_table) = prover
269                    .run_plookup_1st_round(
270                        prng,
271                        &prove_keys[i].commit_key,
272                        circuits[i],
273                        challenges.tau.unwrap(),
274                    )?;
275                online_oracles[i].plookup_oracles.h_polys = h_polys;
276                transcript.append_commitments(b"h_poly_comms", &h_poly_comms)?;
277                (Some(sorted_vec), Some(h_poly_comms), Some(merged_table))
278            } else {
279                (None, None, None)
280            };
281            h_poly_comms_vec.push(h_poly_comms);
282            sorted_vec_list.push(sorted_vec);
283            merged_table_list.push(merged_table);
284        }
285
286        // Round 2
287        challenges.beta = transcript.get_challenge::<E>(b"beta")?;
288        challenges.gamma = transcript.get_challenge::<E>(b"gamma")?;
289        let mut prod_perm_poly_comms_vec = vec![];
290        for i in 0..circuits.len() {
291            let (prod_perm_poly_comm, prod_perm_poly) =
292                prover.run_2nd_round(prng, &prove_keys[i].commit_key, circuits[i], &challenges)?;
293            online_oracles[i].prod_perm_poly = prod_perm_poly;
294            transcript.append_commitment(b"perm_poly_comms", &prod_perm_poly_comm)?;
295            prod_perm_poly_comms_vec.push(prod_perm_poly_comm);
296        }
297
298        // Round 2.5
299        // Plookup: compute Plookup product accumulation polynomial
300        let mut prod_lookup_poly_comms_vec = vec![];
301        for i in 0..circuits.len() {
302            let prod_lookup_poly_comm = if circuits[i].support_lookup() {
303                let (prod_lookup_poly_comm, prod_lookup_poly) = prover.run_plookup_2nd_round(
304                    prng,
305                    &prove_keys[i].commit_key,
306                    circuits[i],
307                    &challenges,
308                    merged_table_list[i].as_ref(),
309                    sorted_vec_list[i].as_ref(),
310                )?;
311                online_oracles[i].plookup_oracles.prod_lookup_poly = prod_lookup_poly;
312                transcript.append_commitment(b"plookup_poly_comms", &prod_lookup_poly_comm)?;
313                Some(prod_lookup_poly_comm)
314            } else {
315                None
316            };
317            prod_lookup_poly_comms_vec.push(prod_lookup_poly_comm);
318        }
319
320        // Round 3
321        challenges.alpha = transcript.get_challenge::<E>(b"alpha")?;
322        let (split_quot_poly_comms, split_quot_polys) = prover.run_3rd_round(
323            prng,
324            &prove_keys[0].commit_key,
325            prove_keys,
326            &challenges,
327            &online_oracles,
328            num_wire_types,
329        )?;
330        transcript.append_commitments(b"quot_poly_comms", &split_quot_poly_comms)?;
331
332        // Round 4
333        challenges.zeta = transcript.get_challenge::<E>(b"zeta")?;
334        let mut poly_evals_vec = vec![];
335        for i in 0..circuits.len() {
336            let poly_evals = prover.compute_evaluations(
337                prove_keys[i],
338                &challenges,
339                &online_oracles[i],
340                num_wire_types,
341            );
342            transcript.append_proof_evaluations::<E>(&poly_evals)?;
343            poly_evals_vec.push(poly_evals);
344        }
345
346        // Round 4.5
347        // Plookup: compute evaluations on Plookup-related polynomials
348        let mut plookup_evals_vec = vec![];
349        for i in 0..circuits.len() {
350            let plookup_evals = if circuits[i].support_lookup() {
351                let evals = prover.compute_plookup_evaluations(
352                    prove_keys[i],
353                    &challenges,
354                    &online_oracles[i],
355                )?;
356                transcript.append_plookup_evaluations::<E>(&evals)?;
357                Some(evals)
358            } else {
359                None
360            };
361            plookup_evals_vec.push(plookup_evals);
362        }
363
364        let mut lin_poly = Prover::<E>::compute_quotient_component_for_lin_poly(
365            n,
366            challenges.zeta,
367            &split_quot_polys,
368        )?;
369        let mut alpha_base = E::ScalarField::one();
370        let alpha_3 = challenges.alpha.square() * challenges.alpha;
371        let alpha_7 = alpha_3.square() * challenges.alpha;
372        for i in 0..circuits.len() {
373            lin_poly = lin_poly
374                + prover.compute_non_quotient_component_for_lin_poly(
375                    alpha_base,
376                    prove_keys[i],
377                    &challenges,
378                    &online_oracles[i],
379                    &poly_evals_vec[i],
380                    plookup_evals_vec[i].as_ref(),
381                )?;
382            // update the alpha power term (i.e. the random combiner that aggregates
383            // multiple instances)
384            if plookup_evals_vec[i].is_some() {
385                alpha_base *= alpha_7;
386            } else {
387                alpha_base *= alpha_3;
388            }
389        }
390
391        // Round 5
392        challenges.v = transcript.get_challenge::<E>(b"v")?;
393        let (opening_proof, shifted_opening_proof) = prover.compute_opening_proofs(
394            &prove_keys[0].commit_key,
395            prove_keys,
396            &challenges.zeta,
397            &challenges.v,
398            &online_oracles,
399            &lin_poly,
400        )?;
401
402        // Plookup: build Plookup argument
403        let mut plookup_proofs_vec = vec![];
404        for i in 0..circuits.len() {
405            let plookup_proof = if circuits[i].support_lookup() {
406                Some(PlookupProof {
407                    h_poly_comms: h_poly_comms_vec[i].clone().unwrap(),
408                    prod_lookup_poly_comm: prod_lookup_poly_comms_vec[i].unwrap(),
409                    poly_evals: plookup_evals_vec[i].clone().unwrap(),
410                })
411            } else {
412                None
413            };
414            plookup_proofs_vec.push(plookup_proof);
415        }
416
417        Ok((
418            BatchProof {
419                wires_poly_comms_vec,
420                prod_perm_poly_comms_vec,
421                poly_evals_vec,
422                plookup_proofs_vec,
423                split_quot_poly_comms,
424                opening_proof,
425                shifted_opening_proof,
426            },
427            online_oracles,
428            challenges,
429        ))
430    }
431}
432
433impl<E, F, P> UniversalSNARK<E> for PlonkKzgSnark<E>
434where
435    E: Pairing<BaseField = F, G1Affine = Affine<P>>,
436    F: RescueParameter + SWToTEConParam,
437    P: SWCurveConfig<BaseField = F>,
438{
439    type Proof = Proof<E>;
440    type ProvingKey = ProvingKey<E>;
441    type VerifyingKey = VerifyingKey<E>;
442    type UniversalSRS = UniversalSrs<E>;
443    type Error = PlonkError;
444
445    // FIXME: (alex) see <https://github.com/EspressoSystems/jellyfish/issues/249>
446    #[cfg(any(test, feature = "test-srs"))]
447    fn universal_setup_for_testing<R: RngCore + CryptoRng>(
448        max_degree: usize,
449        rng: &mut R,
450    ) -> Result<Self::UniversalSRS, Self::Error> {
451        use ark_ec::{CurveGroup, ScalarMul};
452        use ark_std::{end_timer, start_timer, UniformRand};
453
454        let setup_time = start_timer!(|| format!("KZG10::Setup with degree {}", max_degree));
455        let beta = E::ScalarField::rand(rng);
456        let g = E::G1::rand(rng);
457        let h = E::G2::rand(rng);
458
459        let mut powers_of_beta = vec![E::ScalarField::one()];
460
461        let mut cur = beta;
462        for _ in 0..max_degree {
463            powers_of_beta.push(cur);
464            cur *= &beta;
465        }
466
467        let g_time = start_timer!(|| "Generating powers of G");
468        let powers_of_g = g
469            .batch_mul(&powers_of_beta)
470            .into_iter()
471            .map(E::G1::from)
472            .collect::<Vec<_>>();
473        end_timer!(g_time);
474
475        let powers_of_g = E::G1::normalize_batch(&powers_of_g);
476
477        let h = h.into_affine();
478        let beta_h = (h * beta).into_affine();
479
480        let pp = UniversalSrs {
481            powers_of_g,
482            h,
483            beta_h,
484            powers_of_h: vec![h, beta_h],
485        };
486        end_timer!(setup_time);
487        Ok(pp)
488    }
489
490    /// Input a circuit and the SRS, precompute the proving key and verification
491    /// key.
492    fn preprocess<C: Arithmetization<E::ScalarField>>(
493        srs: &Self::UniversalSRS,
494        circuit: &C,
495    ) -> Result<(Self::ProvingKey, Self::VerifyingKey), Self::Error> {
496        // Make sure the SRS can support the circuit (with hiding degree of 2 for zk)
497        let domain_size = circuit.eval_domain_size()?;
498        let srs_size = circuit.srs_size()?;
499        let num_inputs = circuit.num_inputs();
500        if srs.max_degree() < circuit.srs_size()? {
501            return Err(PlonkError::IndexTooLarge);
502        }
503        // 1. Compute selector and permutation polynomials.
504        let selectors_polys = circuit.compute_selector_polynomials()?;
505        let sigma_polys = circuit.compute_extended_permutation_polynomials()?;
506
507        // Compute Plookup proving key if support lookup.
508        let plookup_pk = if circuit.support_lookup() {
509            let range_table_poly = circuit.compute_range_table_polynomial()?;
510            let key_table_poly = circuit.compute_key_table_polynomial()?;
511            let table_dom_sep_poly = circuit.compute_table_dom_sep_polynomial()?;
512            let q_dom_sep_poly = circuit.compute_q_dom_sep_polynomial()?;
513            Some(PlookupProvingKey {
514                range_table_poly,
515                key_table_poly,
516                table_dom_sep_poly,
517                q_dom_sep_poly,
518            })
519        } else {
520            None
521        };
522
523        // 2. Compute VerifyingKey
524        let (commit_key, open_key) = srs.trim(srs_size)?;
525        let selector_comms = parallelizable_slice_iter(&selectors_polys)
526            .map(|poly| UnivariateKzgPCS::commit(&commit_key, poly).map_err(PlonkError::PCSError))
527            .collect::<Result<Vec<_>, PlonkError>>()?
528            .into_iter()
529            .collect();
530        let sigma_comms = parallelizable_slice_iter(&sigma_polys)
531            .map(|poly| UnivariateKzgPCS::commit(&commit_key, poly).map_err(PlonkError::PCSError))
532            .collect::<Result<Vec<_>, PlonkError>>()?
533            .into_iter()
534            .collect();
535
536        // Compute Plookup verifying key if support lookup.
537        let plookup_vk = match circuit.support_lookup() {
538            false => None,
539            true => Some(PlookupVerifyingKey {
540                range_table_comm: UnivariateKzgPCS::commit(
541                    &commit_key,
542                    &plookup_pk.as_ref().unwrap().range_table_poly,
543                )?,
544                key_table_comm: UnivariateKzgPCS::commit(
545                    &commit_key,
546                    &plookup_pk.as_ref().unwrap().key_table_poly,
547                )?,
548                table_dom_sep_comm: UnivariateKzgPCS::commit(
549                    &commit_key,
550                    &plookup_pk.as_ref().unwrap().table_dom_sep_poly,
551                )?,
552                q_dom_sep_comm: UnivariateKzgPCS::commit(
553                    &commit_key,
554                    &plookup_pk.as_ref().unwrap().q_dom_sep_poly,
555                )?,
556            }),
557        };
558
559        let vk = VerifyingKey {
560            domain_size,
561            num_inputs,
562            selector_comms,
563            sigma_comms,
564            k: compute_coset_representatives(circuit.num_wire_types(), Some(domain_size)),
565            open_key,
566            plookup_vk,
567            is_merged: false,
568        };
569
570        // Compute ProvingKey (which includes the VerifyingKey)
571        let pk = ProvingKey {
572            sigmas: sigma_polys,
573            selectors: selectors_polys,
574            commit_key,
575            vk: vk.clone(),
576            plookup_pk,
577        };
578
579        Ok((pk, vk))
580    }
581
582    /// Compute a Plonk proof.
583    /// Refer to Sec 8.4 of <https://eprint.iacr.org/2019/953.pdf>
584    ///
585    /// `circuit` and `prove_key` has to be consistent (with the same evaluation
586    /// domain etc.), otherwise return error.
587    fn prove<C, R, T>(
588        rng: &mut R,
589        circuit: &C,
590        prove_key: &Self::ProvingKey,
591        extra_transcript_init_msg: Option<Vec<u8>>,
592    ) -> Result<Self::Proof, Self::Error>
593    where
594        C: Arithmetization<E::ScalarField>,
595        R: CryptoRng + RngCore,
596        T: PlonkTranscript<F>,
597    {
598        let (batch_proof, ..) = Self::batch_prove_internal::<_, _, T>(
599            rng,
600            &[circuit],
601            &[prove_key],
602            extra_transcript_init_msg,
603        )?;
604        Ok(Proof {
605            wires_poly_comms: batch_proof.wires_poly_comms_vec[0].clone(),
606            prod_perm_poly_comm: batch_proof.prod_perm_poly_comms_vec[0],
607            split_quot_poly_comms: batch_proof.split_quot_poly_comms,
608            opening_proof: batch_proof.opening_proof,
609            shifted_opening_proof: batch_proof.shifted_opening_proof,
610            poly_evals: batch_proof.poly_evals_vec[0].clone(),
611            plookup_proof: batch_proof.plookup_proofs_vec[0].clone(),
612        })
613    }
614
615    fn verify<T>(
616        verify_key: &Self::VerifyingKey,
617        public_input: &[E::ScalarField],
618        proof: &Self::Proof,
619        extra_transcript_init_msg: Option<Vec<u8>>,
620    ) -> Result<(), Self::Error>
621    where
622        T: PlonkTranscript<F>,
623    {
624        Self::batch_verify::<T>(
625            &[verify_key],
626            &[public_input],
627            &[proof],
628            &[extra_transcript_init_msg],
629        )
630    }
631}
632
633#[cfg(test)]
634pub mod test {
635    use crate::{
636        errors::PlonkError,
637        proof_system::{
638            structs::{
639                eval_merged_lookup_witness, eval_merged_table, Challenges, Oracles, Proof,
640                ProvingKey, UniversalSrs, VerifyingKey,
641            },
642            PlonkKzgSnark, UniversalSNARK,
643        },
644        transcript::{
645            rescue::RescueTranscript, solidity::SolidityTranscript, standard::StandardTranscript,
646            PlonkTranscript,
647        },
648        PlonkType,
649    };
650    use ark_bls12_377::{Bls12_377, Fq as Fq377};
651    use ark_bls12_381::{Bls12_381, Fq as Fq381};
652    use ark_bn254::{Bn254, Fq as Fq254};
653    use ark_bw6_761::{Fq as Fq761, BW6_761};
654    use ark_ec::{
655        pairing::Pairing,
656        short_weierstrass::{Affine, SWCurveConfig},
657    };
658    use ark_ff::{One, PrimeField, Zero};
659    use ark_poly::{
660        univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Polynomial,
661        Radix2EvaluationDomain,
662    };
663    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
664    use ark_std::{
665        format,
666        rand::{CryptoRng, RngCore},
667        string::ToString,
668        vec,
669        vec::Vec,
670    };
671    use core::ops::{Mul, Neg};
672    use jf_pcs::{
673        prelude::{Commitment, UnivariateKzgPCS},
674        PolynomialCommitmentScheme,
675    };
676    use jf_relation::{
677        constants::GATE_WIDTH, gadgets::ecc::SWToTEConParam, Arithmetization, Circuit,
678        MergeableCircuitType, PlonkCircuit,
679    };
680    use jf_rescue::RescueParameter;
681    use jf_utils::test_rng;
682
683    // Different `m`s lead to different circuits.
684    // Different `a0`s lead to different witness values.
685    // For UltraPlonk circuits, `a0` should be less than or equal to `m+1`
686    pub(crate) fn gen_circuit_for_test<F: PrimeField>(
687        m: usize,
688        a0: usize,
689        plonk_type: PlonkType,
690    ) -> Result<PlonkCircuit<F>, PlonkError> {
691        let range_bit_len = 5;
692        let mut cs: PlonkCircuit<F> = match plonk_type {
693            PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
694            PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(range_bit_len),
695        };
696        // Create variables
697        let mut a = vec![];
698        for i in a0..(a0 + 4 * m) {
699            a.push(cs.create_variable(F::from(i as u64))?);
700        }
701        let b = [
702            cs.create_public_variable(F::from(m as u64 * 2))?,
703            cs.create_public_variable(F::from(a0 as u64 * 2 + m as u64 * 4 - 1))?,
704        ];
705        let c = cs.create_public_variable(
706            (cs.witness(b[1])? + cs.witness(a[0])?) * (cs.witness(b[1])? - cs.witness(a[0])?),
707        )?;
708
709        // Create gates:
710        // 1. a0 + ... + a_{4*m-1} = b0 * b1
711        // 2. (b1 + a0) * (b1 - a0) = c
712        // 3. b0 = 2 * m
713        let mut acc = cs.zero();
714        a.iter().for_each(|&elem| acc = cs.add(acc, elem).unwrap());
715        let b_mul = cs.mul(b[0], b[1])?;
716        cs.enforce_equal(acc, b_mul)?;
717        let b1_plus_a0 = cs.add(b[1], a[0])?;
718        let b1_minus_a0 = cs.sub(b[1], a[0])?;
719        cs.mul_gate(b1_plus_a0, b1_minus_a0, c)?;
720        cs.enforce_constant(b[0], F::from(m as u64 * 2))?;
721
722        if plonk_type == PlonkType::UltraPlonk {
723            // Create range gates
724            // 1. range_table = {0, 1, ..., 31}
725            // 2. a_i \in range_table for i = 0..m-1
726            // 3. b0 \in range_table
727            for &var in a.iter().take(m) {
728                cs.add_range_check_variable(var)?;
729            }
730            cs.add_range_check_variable(b[0])?;
731
732            // Create variable table lookup gates
733            // 1. table = [(a0, a2), (a1, a3), (b0, a0)]
734            let table_vars = [(a[0], a[2]), (a[1], a[3]), (b[0], a[0])];
735            // 2. lookup_witness = [(1, a0+1, a0+3), (2, 2m, a0)]
736            let key0 = cs.one();
737            let key1 = cs.create_variable(F::from(2u8))?;
738            let two_m = cs.create_public_variable(F::from(m as u64 * 2))?;
739            let a1 = cs.add_constant(a[0], &F::one())?;
740            let a3 = cs.add_constant(a[0], &F::from(3u8))?;
741            let lookup_vars = [(key0, a1, a3), (key1, two_m, a[0])];
742            cs.create_table_and_lookup_variables(&lookup_vars, &table_vars)?;
743        }
744
745        // Finalize the circuit.
746        cs.finalize_for_arithmetization()?;
747
748        Ok(cs)
749    }
750
751    #[test]
752    fn test_preprocessing() -> Result<(), PlonkError> {
753        test_preprocessing_helper::<Bn254, Fq254, _>(PlonkType::TurboPlonk)?;
754        test_preprocessing_helper::<Bn254, Fq254, _>(PlonkType::UltraPlonk)?;
755        test_preprocessing_helper::<Bls12_377, Fq377, _>(PlonkType::TurboPlonk)?;
756        test_preprocessing_helper::<Bls12_377, Fq377, _>(PlonkType::UltraPlonk)?;
757        test_preprocessing_helper::<Bls12_381, Fq381, _>(PlonkType::TurboPlonk)?;
758        test_preprocessing_helper::<Bls12_381, Fq381, _>(PlonkType::UltraPlonk)?;
759        test_preprocessing_helper::<BW6_761, Fq761, _>(PlonkType::TurboPlonk)?;
760        test_preprocessing_helper::<BW6_761, Fq761, _>(PlonkType::UltraPlonk)
761    }
762    fn test_preprocessing_helper<E, F, P>(plonk_type: PlonkType) -> Result<(), PlonkError>
763    where
764        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
765        F: RescueParameter + SWToTEConParam,
766        P: SWCurveConfig<BaseField = F>,
767    {
768        let rng = &mut jf_utils::test_rng();
769        let circuit = gen_circuit_for_test(5, 6, plonk_type)?;
770        let domain_size = circuit.eval_domain_size()?;
771        let num_inputs = circuit.num_inputs();
772        let selectors = circuit.compute_selector_polynomials()?;
773        let sigmas = circuit.compute_extended_permutation_polynomials()?;
774
775        let max_degree = 64 + 2;
776        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
777        let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
778
779        // check proving key
780        assert_eq!(pk.selectors, selectors);
781        assert_eq!(pk.sigmas, sigmas);
782        assert_eq!(pk.domain_size(), domain_size);
783        assert_eq!(pk.num_inputs(), num_inputs);
784        let num_wire_types = GATE_WIDTH
785            + 1
786            + match plonk_type {
787                PlonkType::TurboPlonk => 0,
788                PlonkType::UltraPlonk => 1,
789            };
790        assert_eq!(pk.sigmas.len(), num_wire_types);
791        // check plookup proving key
792        if plonk_type == PlonkType::UltraPlonk {
793            let range_table_poly = circuit.compute_range_table_polynomial()?;
794            assert_eq!(
795                pk.plookup_pk.as_ref().unwrap().range_table_poly,
796                range_table_poly
797            );
798
799            let key_table_poly = circuit.compute_key_table_polynomial()?;
800            assert_eq!(
801                pk.plookup_pk.as_ref().unwrap().key_table_poly,
802                key_table_poly
803            );
804        }
805
806        // check verifying key
807        assert_eq!(vk.domain_size, domain_size);
808        assert_eq!(vk.num_inputs, num_inputs);
809        assert_eq!(vk.selector_comms.len(), selectors.len());
810        assert_eq!(vk.sigma_comms.len(), sigmas.len());
811        assert_eq!(vk.sigma_comms.len(), num_wire_types);
812        selectors
813            .iter()
814            .zip(vk.selector_comms.iter())
815            .for_each(|(p, &p_comm)| {
816                let expected_comm = UnivariateKzgPCS::commit(&pk.commit_key, p).unwrap();
817                assert_eq!(expected_comm, p_comm);
818            });
819        sigmas
820            .iter()
821            .zip(vk.sigma_comms.iter())
822            .for_each(|(p, &p_comm)| {
823                let expected_comm = UnivariateKzgPCS::commit(&pk.commit_key, p).unwrap();
824                assert_eq!(expected_comm, p_comm);
825            });
826        // check plookup verification key
827        if plonk_type == PlonkType::UltraPlonk {
828            let expected_comm = UnivariateKzgPCS::commit(
829                &pk.commit_key,
830                &pk.plookup_pk.as_ref().unwrap().range_table_poly,
831            )
832            .unwrap();
833            assert_eq!(
834                expected_comm,
835                vk.plookup_vk.as_ref().unwrap().range_table_comm
836            );
837
838            let expected_comm = UnivariateKzgPCS::commit(
839                &pk.commit_key,
840                &pk.plookup_pk.as_ref().unwrap().key_table_poly,
841            )
842            .unwrap();
843            assert_eq!(
844                expected_comm,
845                vk.plookup_vk.as_ref().unwrap().key_table_comm
846            );
847        }
848
849        Ok(())
850    }
851
852    #[test]
853    fn test_plonk_proof_system() -> Result<(), PlonkError> {
854        // merlin transcripts
855        test_plonk_proof_system_helper::<Bn254, Fq254, _, StandardTranscript>(
856            PlonkType::TurboPlonk,
857        )?;
858        test_plonk_proof_system_helper::<Bn254, Fq254, _, StandardTranscript>(
859            PlonkType::UltraPlonk,
860        )?;
861        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, StandardTranscript>(
862            PlonkType::TurboPlonk,
863        )?;
864        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, StandardTranscript>(
865            PlonkType::UltraPlonk,
866        )?;
867        test_plonk_proof_system_helper::<Bls12_381, Fq381, _, StandardTranscript>(
868            PlonkType::TurboPlonk,
869        )?;
870        test_plonk_proof_system_helper::<Bls12_381, Fq381, _, StandardTranscript>(
871            PlonkType::UltraPlonk,
872        )?;
873        test_plonk_proof_system_helper::<BW6_761, Fq761, _, StandardTranscript>(
874            PlonkType::TurboPlonk,
875        )?;
876        test_plonk_proof_system_helper::<BW6_761, Fq761, _, StandardTranscript>(
877            PlonkType::UltraPlonk,
878        )?;
879
880        // rescue transcripts
881        // currently only available for bls12-377
882        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
883            PlonkType::TurboPlonk,
884        )?;
885        test_plonk_proof_system_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
886            PlonkType::UltraPlonk,
887        )?;
888
889        // solidity-friendly keccak256 transcripts
890        // currently only needed for CAPE using bls12-381
891        test_plonk_proof_system_helper::<Bls12_381, Fq381, _, SolidityTranscript>(
892            PlonkType::TurboPlonk,
893        )?;
894        Ok(())
895    }
896
897    fn test_plonk_proof_system_helper<E, F, P, T>(plonk_type: PlonkType) -> Result<(), PlonkError>
898    where
899        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
900        F: RescueParameter + SWToTEConParam,
901        P: SWCurveConfig<BaseField = F>,
902        T: PlonkTranscript<F>,
903    {
904        // 1. Simulate universal setup
905        let rng = &mut test_rng();
906        let n = 64;
907        let max_degree = n + 2;
908        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
909
910        // 2. Create circuits
911        let circuits = (0..6)
912            .map(|i| {
913                let m = 2 + i / 3;
914                let a0 = 1 + i % 3;
915                gen_circuit_for_test(m, a0, plonk_type)
916            })
917            .collect::<Result<Vec<_>, PlonkError>>()?;
918        // 3. Preprocessing
919        let (pk1, vk1) = PlonkKzgSnark::<E>::preprocess(&srs, &circuits[0])?;
920        let (pk2, vk2) = PlonkKzgSnark::<E>::preprocess(&srs, &circuits[3])?;
921        // 4. Proving
922        let mut proofs = vec![];
923        let mut extra_msgs = vec![];
924        for (i, cs) in circuits.iter().enumerate() {
925            let pk_ref = if i < 3 { &pk1 } else { &pk2 };
926            let extra_msg = if i % 2 == 0 {
927                None
928            } else {
929                Some(format!("extra message: {}", i).into_bytes())
930            };
931            proofs.push(
932                PlonkKzgSnark::<E>::prove::<_, _, T>(rng, cs, pk_ref, extra_msg.clone()).unwrap(),
933            );
934            extra_msgs.push(extra_msg);
935        }
936
937        // 5. Verification
938        let public_inputs: Vec<Vec<E::ScalarField>> = circuits
939            .iter()
940            .map(|cs| cs.public_input())
941            .collect::<Result<Vec<Vec<E::ScalarField>>, _>>(
942        )?;
943        for (i, proof) in proofs.iter().enumerate() {
944            let vk_ref = if i < 3 { &vk1 } else { &vk2 };
945            assert!(PlonkKzgSnark::<E>::verify::<T>(
946                vk_ref,
947                &public_inputs[i],
948                proof,
949                extra_msgs[i].clone(),
950            )
951            .is_ok());
952            // Inconsistent proof should fail the verification.
953            let mut bad_pub_input = public_inputs[i].clone();
954            bad_pub_input[0] = E::ScalarField::from(0u8);
955            assert!(PlonkKzgSnark::<E>::verify::<T>(
956                vk_ref,
957                &bad_pub_input,
958                proof,
959                extra_msgs[i].clone(),
960            )
961            .is_err());
962            // Incorrect extra transcript message should fail
963            assert!(PlonkKzgSnark::<E>::verify::<T>(
964                vk_ref,
965                &bad_pub_input,
966                proof,
967                Some("wrong message".to_string().into_bytes()),
968            )
969            .is_err());
970
971            // Incorrect proof [W_z] = 0, [W_z*g] = 0
972            // attack against some vulnerable implementation described in:
973            // https://cryptosubtlety.medium.com/00-8d4adcf4d255
974            let mut bad_proof = proof.clone();
975            bad_proof.opening_proof = Commitment::default();
976            bad_proof.shifted_opening_proof = Commitment::default();
977            assert!(PlonkKzgSnark::<E>::verify::<T>(
978                vk_ref,
979                &public_inputs[i],
980                &bad_proof,
981                extra_msgs[i].clone(),
982            )
983            .is_err());
984        }
985
986        // 6. Batch verification
987        let vks = vec![&vk1, &vk1, &vk1, &vk2, &vk2, &vk2];
988        let mut public_inputs_ref: Vec<&[E::ScalarField]> = public_inputs
989            .iter()
990            .map(|pub_input| &pub_input[..])
991            .collect();
992        let mut proofs_ref: Vec<&Proof<E>> = proofs.iter().collect();
993        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
994            &vks,
995            &public_inputs_ref,
996            &proofs_ref,
997            &extra_msgs,
998        )
999        .is_ok());
1000
1001        // Inconsistent params
1002        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1003            &vks[..5],
1004            &public_inputs_ref,
1005            &proofs_ref,
1006            &extra_msgs,
1007        )
1008        .is_err());
1009
1010        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1011            &vks,
1012            &public_inputs_ref[..5],
1013            &proofs_ref,
1014            &extra_msgs,
1015        )
1016        .is_err());
1017
1018        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1019            &vks,
1020            &public_inputs_ref,
1021            &proofs_ref[..5],
1022            &extra_msgs,
1023        )
1024        .is_err());
1025
1026        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1027            &vks,
1028            &public_inputs_ref,
1029            &proofs_ref,
1030            &vec![None; vks.len()],
1031        )
1032        .is_err());
1033
1034        assert!(
1035            PlonkKzgSnark::<E>::batch_verify::<T>(&vks, &public_inputs_ref, &proofs_ref, &[],)
1036                .is_err()
1037        );
1038
1039        // Empty params
1040        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(&[], &[], &[], &[],).is_err());
1041
1042        // Error paths
1043        let tmp_pi_ref = public_inputs_ref[0];
1044        public_inputs_ref[0] = public_inputs_ref[1];
1045        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1046            &vks,
1047            &public_inputs_ref,
1048            &proofs_ref,
1049            &extra_msgs,
1050        )
1051        .is_err());
1052        public_inputs_ref[0] = tmp_pi_ref;
1053
1054        proofs_ref[0] = proofs_ref[1];
1055        assert!(PlonkKzgSnark::<E>::batch_verify::<T>(
1056            &vks,
1057            &public_inputs_ref,
1058            &proofs_ref,
1059            &extra_msgs,
1060        )
1061        .is_err());
1062
1063        Ok(())
1064    }
1065
1066    #[test]
1067    fn test_inconsistent_pub_input_len() -> Result<(), PlonkError> {
1068        // merlin transcripts
1069        test_inconsistent_pub_input_len_helper::<Bn254, Fq254, _, StandardTranscript>(
1070            PlonkType::TurboPlonk,
1071        )?;
1072        test_inconsistent_pub_input_len_helper::<Bn254, Fq254, _, StandardTranscript>(
1073            PlonkType::UltraPlonk,
1074        )?;
1075        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1076            PlonkType::TurboPlonk,
1077        )?;
1078        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1079            PlonkType::UltraPlonk,
1080        )?;
1081        test_inconsistent_pub_input_len_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1082            PlonkType::TurboPlonk,
1083        )?;
1084        test_inconsistent_pub_input_len_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1085            PlonkType::UltraPlonk,
1086        )?;
1087        test_inconsistent_pub_input_len_helper::<BW6_761, Fq761, _, StandardTranscript>(
1088            PlonkType::TurboPlonk,
1089        )?;
1090        test_inconsistent_pub_input_len_helper::<BW6_761, Fq761, _, StandardTranscript>(
1091            PlonkType::UltraPlonk,
1092        )?;
1093
1094        // rescue transcripts
1095        // currently only available for bls12-377
1096        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1097            PlonkType::TurboPlonk,
1098        )?;
1099        test_inconsistent_pub_input_len_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1100            PlonkType::UltraPlonk,
1101        )?;
1102
1103        // Solidity-friendly keccak256 transcript
1104        test_inconsistent_pub_input_len_helper::<Bls12_381, Fq381, _, SolidityTranscript>(
1105            PlonkType::TurboPlonk,
1106        )?;
1107
1108        Ok(())
1109    }
1110
1111    fn test_inconsistent_pub_input_len_helper<E, F, P, T>(
1112        plonk_type: PlonkType,
1113    ) -> Result<(), PlonkError>
1114    where
1115        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1116        F: RescueParameter + SWToTEConParam,
1117        P: SWCurveConfig<BaseField = F>,
1118        T: PlonkTranscript<F>,
1119    {
1120        // 1. Simulate universal setup
1121        let rng = &mut test_rng();
1122        let n = 8;
1123        let max_degree = n + 2;
1124        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1125
1126        // 2. Create circuits
1127        let mut cs1: PlonkCircuit<E::ScalarField> = match plonk_type {
1128            PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
1129            PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(2),
1130        };
1131        let var = cs1.create_variable(E::ScalarField::from(1u8))?;
1132        cs1.enforce_constant(var, E::ScalarField::from(1u8))?;
1133        cs1.finalize_for_arithmetization()?;
1134        let mut cs2: PlonkCircuit<E::ScalarField> = match plonk_type {
1135            PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
1136            PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(2),
1137        };
1138        cs2.create_public_variable(E::ScalarField::from(1u8))?;
1139        cs2.finalize_for_arithmetization()?;
1140
1141        // 3. Preprocessing
1142        let (pk1, vk1) = PlonkKzgSnark::<E>::preprocess(&srs, &cs1)?;
1143        let (pk2, vk2) = PlonkKzgSnark::<E>::preprocess(&srs, &cs2)?;
1144
1145        // 4. Proving
1146        assert!(PlonkKzgSnark::<E>::prove::<_, _, T>(rng, &cs2, &pk1, None).is_err());
1147        let proof2 = PlonkKzgSnark::<E>::prove::<_, _, T>(rng, &cs2, &pk2, None)?;
1148
1149        // 5. Verification
1150        assert!(
1151            PlonkKzgSnark::<E>::verify::<T>(&vk2, &[E::ScalarField::from(1u8)], &proof2, None,)
1152                .is_ok()
1153        );
1154        // wrong verification key
1155        assert!(
1156            PlonkKzgSnark::<E>::verify::<T>(&vk1, &[E::ScalarField::from(1u8)], &proof2, None,)
1157                .is_err()
1158        );
1159        // wrong public input
1160        assert!(PlonkKzgSnark::<E>::verify::<T>(&vk2, &[], &proof2, None).is_err());
1161
1162        Ok(())
1163    }
1164
1165    #[test]
1166    fn test_plonk_prover_polynomials() -> Result<(), PlonkError> {
1167        // merlin transcripts
1168        test_plonk_prover_polynomials_helper::<Bn254, Fq254, _, StandardTranscript>(
1169            PlonkType::TurboPlonk,
1170        )?;
1171        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1172            PlonkType::TurboPlonk,
1173        )?;
1174        test_plonk_prover_polynomials_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1175            PlonkType::TurboPlonk,
1176        )?;
1177        test_plonk_prover_polynomials_helper::<BW6_761, Fq761, _, StandardTranscript>(
1178            PlonkType::TurboPlonk,
1179        )?;
1180        test_plonk_prover_polynomials_helper::<Bn254, Fq254, _, StandardTranscript>(
1181            PlonkType::UltraPlonk,
1182        )?;
1183        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1184            PlonkType::UltraPlonk,
1185        )?;
1186        test_plonk_prover_polynomials_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1187            PlonkType::UltraPlonk,
1188        )?;
1189        test_plonk_prover_polynomials_helper::<BW6_761, Fq761, _, StandardTranscript>(
1190            PlonkType::UltraPlonk,
1191        )?;
1192
1193        // rescue transcripts
1194        // currently only available for bls12-377
1195        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1196            PlonkType::TurboPlonk,
1197        )?;
1198        test_plonk_prover_polynomials_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1199            PlonkType::UltraPlonk,
1200        )?;
1201
1202        // Solidity-friendly keccak256 transcript
1203        test_plonk_prover_polynomials_helper::<Bls12_381, Fq381, _, SolidityTranscript>(
1204            PlonkType::TurboPlonk,
1205        )?;
1206
1207        Ok(())
1208    }
1209
1210    fn test_plonk_prover_polynomials_helper<E, F, P, T>(
1211        plonk_type: PlonkType,
1212    ) -> Result<(), PlonkError>
1213    where
1214        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1215        F: RescueParameter + SWToTEConParam,
1216        P: SWCurveConfig<BaseField = F>,
1217        T: PlonkTranscript<F>,
1218    {
1219        // 1. Simulate universal setup
1220        let rng = &mut test_rng();
1221        let n = 64;
1222        let max_degree = n + 2;
1223        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1224
1225        // 2. Create the circuit
1226        let circuit = gen_circuit_for_test(10, 3, plonk_type)?;
1227        assert!(circuit.num_gates() <= n);
1228
1229        // 3. Preprocessing
1230        let (pk, _) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
1231
1232        // 4. Proving
1233        let (_, oracles, challenges) =
1234            PlonkKzgSnark::<E>::batch_prove_internal::<_, _, T>(rng, &[&circuit], &[&pk], None)?;
1235
1236        // 5. Check that the targeted polynomials evaluate to zero on the vanishing set.
1237        check_plonk_prover_polynomials(plonk_type, &oracles[0], &pk, &challenges)?;
1238
1239        Ok(())
1240    }
1241
1242    fn check_plonk_prover_polynomials<E: Pairing>(
1243        plonk_type: PlonkType,
1244        oracles: &Oracles<E::ScalarField>,
1245        pk: &ProvingKey<E>,
1246        challenges: &Challenges<E::ScalarField>,
1247    ) -> Result<(), PlonkError> {
1248        check_circuit_polynomial_on_vanishing_set(oracles, pk)?;
1249        check_perm_polynomials_on_vanishing_set(oracles, pk, challenges)?;
1250        if plonk_type == PlonkType::UltraPlonk {
1251            check_lookup_polynomials_on_vanishing_set(oracles, pk, challenges)?;
1252        }
1253
1254        Ok(())
1255    }
1256
    /// Checks that the gate-constraint polynomial built from the prover's wire
    /// oracles and the proving key's selectors evaluates to zero at every point
    /// of the evaluation domain of size `pk.domain_size()`.
    ///
    /// Selector layout read from `pk.selectors` (by index):
    /// - `[0, GATE_WIDTH)`: linear-combination selectors `q_lc`;
    /// - `[GATE_WIDTH, GATE_WIDTH + 2)`: multiplication selectors `q_mul`;
    /// - `[GATE_WIDTH + 2, 2 * GATE_WIDTH + 2)`: fifth-power selectors `q_hash`;
    /// - `2 * GATE_WIDTH + 2`: output selector `q_o`;
    /// - `2 * GATE_WIDTH + 3`: constant selector `q_c`;
    /// - `2 * GATE_WIDTH + 4`: ecc selector `q_ecc`.
    ///
    /// The polynomial checked is
    /// `q_c + PI + sum_i q_lc[i] * w_i + q_mul[0] * w_0 * w_1 + q_mul[1] * w_2 * w_3
    ///  + q_ecc * w_0 * w_1 * w_2 * w_3 * w_4 + sum_i q_hash[i] * w_i^5 - q_o * w_4`.
    ///
    /// NOTE(review): the expression hard-codes input wires 0..=3 and output wire
    /// 4, i.e. it presumably assumes `GATE_WIDTH == 4`.
    fn check_circuit_polynomial_on_vanishing_set<E: Pairing>(
        oracles: &Oracles<E::ScalarField>,
        pk: &ProvingKey<E>,
    ) -> Result<(), PlonkError> {
        let q_lc: Vec<&DensePolynomial<E::ScalarField>> =
            (0..GATE_WIDTH).map(|j| &pk.selectors[j]).collect();
        let q_mul: Vec<&DensePolynomial<E::ScalarField>> = (GATE_WIDTH..GATE_WIDTH + 2)
            .map(|j| &pk.selectors[j])
            .collect();
        let q_hash: Vec<&DensePolynomial<E::ScalarField>> = (GATE_WIDTH + 2..2 * GATE_WIDTH + 2)
            .map(|j| &pk.selectors[j])
            .collect();
        let q_o = &pk.selectors[2 * GATE_WIDTH + 2];
        let q_c = &pk.selectors[2 * GATE_WIDTH + 3];
        let q_ecc = &pk.selectors[2 * GATE_WIDTH + 4];
        let circuit_poly = q_c
            + &oracles.pub_inp_poly
            // linear-combination terms q_lc[i] * w_i
            + (&oracles.wire_polys[0]).mul(q_lc[0])
            + (&oracles.wire_polys[1]).mul(q_lc[1])
            + (&oracles.wire_polys[2]).mul(q_lc[2])
            + (&oracles.wire_polys[3]).mul(q_lc[3])
            // multiplication terms
            + (&oracles.wire_polys[0])
                .mul(&oracles.wire_polys[1])
                .mul(q_mul[0])
            + (&oracles.wire_polys[2])
                .mul(&oracles.wire_polys[3])
                .mul(q_mul[1])
            // ecc term: product of all five wires
            + (&oracles.wire_polys[0])
                .mul(&oracles.wire_polys[1])
                .mul(&oracles.wire_polys[2])
                .mul(&oracles.wire_polys[3])
                .mul(&oracles.wire_polys[4])
                .mul(q_ecc)
            // fifth-power terms q_hash[i] * w_i^5
            + (&oracles.wire_polys[0])
                .mul(&oracles.wire_polys[0])
                .mul(&oracles.wire_polys[0])
                .mul(&oracles.wire_polys[0])
                .mul(&oracles.wire_polys[0])
                .mul(q_hash[0])
            + (&oracles.wire_polys[1])
                .mul(&oracles.wire_polys[1])
                .mul(&oracles.wire_polys[1])
                .mul(&oracles.wire_polys[1])
                .mul(&oracles.wire_polys[1])
                .mul(q_hash[1])
            + (&oracles.wire_polys[2])
                .mul(&oracles.wire_polys[2])
                .mul(&oracles.wire_polys[2])
                .mul(&oracles.wire_polys[2])
                .mul(&oracles.wire_polys[2])
                .mul(q_hash[2])
            + (&oracles.wire_polys[3])
                .mul(&oracles.wire_polys[3])
                .mul(&oracles.wire_polys[3])
                .mul(&oracles.wire_polys[3])
                .mul(&oracles.wire_polys[3])
                .mul(q_hash[3])
            // output term, subtracted: `.neg()` binds to `w_4 * q_o` only
            + (&oracles.wire_polys[4]).mul(q_o).neg();

        // check that the polynomial evaluates to zero on the vanishing set
        let domain = Radix2EvaluationDomain::<E::ScalarField>::new(pk.domain_size())
            .ok_or(PlonkError::DomainCreationError)?;
        for i in 0..domain.size() {
            assert_eq!(
                circuit_poly.evaluate(&domain.element(i)),
                E::ScalarField::zero()
            );
        }

        Ok(())
    }
1328
1329    fn check_perm_polynomials_on_vanishing_set<E: Pairing>(
1330        oracles: &Oracles<E::ScalarField>,
1331        pk: &ProvingKey<E>,
1332        challenges: &Challenges<E::ScalarField>,
1333    ) -> Result<(), PlonkError> {
1334        let beta = challenges.beta;
1335        let gamma = challenges.gamma;
1336
1337        // check that \prod_i [w_i(X) + beta * k_i * X + gamma] * z(X) = \prod_i [w_i(X)
1338        // + beta * sigma_i(X) + gamma] * z(wX) on the vanishing set
1339        let one_poly = DensePolynomial::from_coefficients_vec(vec![E::ScalarField::one()]);
1340        let poly_1 = oracles
1341            .wire_polys
1342            .iter()
1343            .enumerate()
1344            .fold(one_poly.clone(), |acc, (j, w)| {
1345                let poly =
1346                    &DensePolynomial::from_coefficients_vec(vec![gamma, beta * pk.k()[j]]) + w;
1347                acc.mul(&poly)
1348            });
1349        let poly_2 =
1350            oracles
1351                .wire_polys
1352                .iter()
1353                .zip(pk.sigmas.iter())
1354                .fold(one_poly, |acc, (w, sigma)| {
1355                    let poly = w.clone()
1356                        + sigma.mul(beta)
1357                        + DensePolynomial::from_coefficients_vec(vec![gamma]);
1358                    acc.mul(&poly)
1359                });
1360
1361        let domain = Radix2EvaluationDomain::<E::ScalarField>::new(pk.domain_size())
1362            .ok_or(PlonkError::DomainCreationError)?;
1363        for i in 0..domain.size() {
1364            let point = domain.element(i);
1365            let eval_1 = poly_1.evaluate(&point) * oracles.prod_perm_poly.evaluate(&point);
1366            let eval_2 = poly_2.evaluate(&point)
1367                * oracles.prod_perm_poly.evaluate(&(point * domain.group_gen));
1368            assert_eq!(eval_1, eval_2);
1369        }
1370
1371        // check z(X) = 1 at point 1
1372        assert_eq!(
1373            oracles.prod_perm_poly.evaluate(&domain.element(0)),
1374            E::ScalarField::one()
1375        );
1376
1377        Ok(())
1378    }
1379
1380    fn check_lookup_polynomials_on_vanishing_set<E: Pairing>(
1381        oracles: &Oracles<E::ScalarField>,
1382        pk: &ProvingKey<E>,
1383        challenges: &Challenges<E::ScalarField>,
1384    ) -> Result<(), PlonkError> {
1385        let beta = challenges.beta;
1386        let gamma = challenges.gamma;
1387        let n = pk.domain_size();
1388        let domain = Radix2EvaluationDomain::<E::ScalarField>::new(n)
1389            .ok_or(PlonkError::DomainCreationError)?;
1390        let prod_poly = &oracles.plookup_oracles.prod_lookup_poly;
1391        let h_polys = &oracles.plookup_oracles.h_polys;
1392
1393        // check z(X) = 1 at point 1
1394        assert_eq!(
1395            prod_poly.evaluate(&domain.element(0)),
1396            E::ScalarField::one()
1397        );
1398
1399        // check z(X) = 1 at point w^{n-1}
1400        assert_eq!(
1401            prod_poly.evaluate(&domain.element(n - 1)),
1402            E::ScalarField::one()
1403        );
1404
1405        // check h1(X) = h2(w * X) at point w^{n-1}
1406        assert_eq!(
1407            h_polys[0].evaluate(&domain.element(n - 1)),
1408            h_polys[1].evaluate(&domain.element(0))
1409        );
1410
1411        // check z(X) *
1412        //      (1+beta) * (gamma + merged_lookup_wire(X)) *
1413        //      (gamma(1+beta) + merged_table(X) + beta * merged_table(Xw))
1414        //     = z(Xw) *
1415        //      (gamma(1+beta) + h1(X) + beta * h1(Xw)) *
1416        //      (gamma(1+beta) + h2(x) + beta * h2(Xw))
1417        // on the vanishing set excluding point w^{n-1}
1418        let beta_plus_one = E::ScalarField::one() + beta;
1419        let gamma_mul_beta_plus_one = gamma * beta_plus_one;
1420
1421        let range_table_poly_ref = &pk.plookup_pk.as_ref().unwrap().range_table_poly;
1422        let key_table_poly_ref = &pk.plookup_pk.as_ref().unwrap().key_table_poly;
1423        let table_dom_sep_poly_ref = &pk.plookup_pk.as_ref().unwrap().table_dom_sep_poly;
1424        let q_dom_sep_poly_ref = &pk.plookup_pk.as_ref().unwrap().q_dom_sep_poly;
1425
1426        for i in 0..domain.size() - 1 {
1427            let point = domain.element(i);
1428            let next_point = point * domain.group_gen;
1429            let merged_lookup_wire_eval = eval_merged_lookup_witness::<E>(
1430                challenges.tau.unwrap(),
1431                oracles.wire_polys[5].evaluate(&point),
1432                oracles.wire_polys[0].evaluate(&point),
1433                oracles.wire_polys[1].evaluate(&point),
1434                oracles.wire_polys[2].evaluate(&point),
1435                pk.q_lookup_poly()?.evaluate(&point),
1436                q_dom_sep_poly_ref.evaluate(&point),
1437            );
1438            let merged_table_eval = eval_merged_table::<E>(
1439                challenges.tau.unwrap(),
1440                range_table_poly_ref.evaluate(&point),
1441                key_table_poly_ref.evaluate(&point),
1442                pk.q_lookup_poly()?.evaluate(&point),
1443                oracles.wire_polys[3].evaluate(&point),
1444                oracles.wire_polys[4].evaluate(&point),
1445                table_dom_sep_poly_ref.evaluate(&point),
1446            );
1447            let merged_table_next_eval = eval_merged_table::<E>(
1448                challenges.tau.unwrap(),
1449                range_table_poly_ref.evaluate(&next_point),
1450                key_table_poly_ref.evaluate(&next_point),
1451                pk.q_lookup_poly()?.evaluate(&next_point),
1452                oracles.wire_polys[3].evaluate(&next_point),
1453                oracles.wire_polys[4].evaluate(&next_point),
1454                table_dom_sep_poly_ref.evaluate(&next_point),
1455            );
1456
1457            let eval_1 = prod_poly.evaluate(&point)
1458                * beta_plus_one
1459                * (gamma + merged_lookup_wire_eval)
1460                * (gamma_mul_beta_plus_one + merged_table_eval + beta * merged_table_next_eval);
1461            let eval_2 = prod_poly.evaluate(&next_point)
1462                * (gamma_mul_beta_plus_one
1463                    + h_polys[0].evaluate(&point)
1464                    + beta * h_polys[0].evaluate(&next_point))
1465                * (gamma_mul_beta_plus_one
1466                    + h_polys[1].evaluate(&point)
1467                    + beta * h_polys[1].evaluate(&next_point));
1468            assert_eq!(eval_1, eval_2, "i={}, domain_size={}", i, domain.size());
1469        }
1470
1471        Ok(())
1472    }
1473
1474    #[test]
1475    fn test_proof_from_to_fields() -> Result<(), PlonkError> {
1476        test_proof_from_to_fields_helper::<Bn254, _>()?;
1477        test_proof_from_to_fields_helper::<Bls12_381, _>()?;
1478        test_proof_from_to_fields_helper::<Bls12_377, _>()?;
1479        test_proof_from_to_fields_helper::<BW6_761, _>()?;
1480        Ok(())
1481    }
1482
1483    fn test_proof_from_to_fields_helper<E, P>() -> Result<(), PlonkError>
1484    where
1485        E: Pairing<G1Affine = Affine<P>>,
1486        E::BaseField: RescueParameter + SWToTEConParam,
1487        P: SWCurveConfig<BaseField = E::BaseField, ScalarField = E::ScalarField>,
1488    {
1489        let rng = &mut jf_utils::test_rng();
1490        let circuit = gen_circuit_for_test(3, 4, PlonkType::TurboPlonk)?;
1491        let max_degree = 80;
1492        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1493
1494        let (pk, _) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
1495        let proof =
1496            PlonkKzgSnark::<E>::prove::<_, _, StandardTranscript>(rng, &circuit, &pk, None)?;
1497
1498        let base_fields: Vec<E::BaseField> = proof.clone().into();
1499        let res: Proof<E> = base_fields.try_into()?;
1500        assert_eq!(res, proof);
1501
1502        Ok(())
1503    }
1504
1505    #[test]
1506    fn test_serde() -> Result<(), PlonkError> {
1507        // merlin transcripts
1508        test_serde_helper::<Bn254, Fq254, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1509        test_serde_helper::<Bn254, Fq254, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1510        test_serde_helper::<Bls12_377, Fq377, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1511        test_serde_helper::<Bls12_377, Fq377, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1512        test_serde_helper::<Bls12_381, Fq381, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1513        test_serde_helper::<Bls12_381, Fq381, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1514        test_serde_helper::<BW6_761, Fq761, _, StandardTranscript>(PlonkType::TurboPlonk)?;
1515        test_serde_helper::<BW6_761, Fq761, _, StandardTranscript>(PlonkType::UltraPlonk)?;
1516
1517        // rescue transcripts
1518        // currently only available for bls12-377
1519        test_serde_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(PlonkType::TurboPlonk)?;
1520        test_serde_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(PlonkType::UltraPlonk)?;
1521
1522        // Solidity-friendly keccak256 transcript
1523        test_serde_helper::<Bls12_381, Fq381, _, SolidityTranscript>(PlonkType::TurboPlonk)?;
1524
1525        Ok(())
1526    }
1527
1528    fn test_serde_helper<E, F, P, T>(plonk_type: PlonkType) -> Result<(), PlonkError>
1529    where
1530        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1531        F: RescueParameter + SWToTEConParam,
1532        P: SWCurveConfig<BaseField = F>,
1533        T: PlonkTranscript<F>,
1534    {
1535        let rng = &mut jf_utils::test_rng();
1536        let circuit = gen_circuit_for_test(3, 4, plonk_type)?;
1537        let max_degree = 80;
1538        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1539
1540        let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, &circuit)?;
1541        let proof = PlonkKzgSnark::<E>::prove::<_, _, T>(rng, &circuit, &pk, None)?;
1542
1543        let mut ser_bytes = Vec::new();
1544        srs.serialize_compressed(&mut ser_bytes)?;
1545        let de = UniversalSrs::<E>::deserialize_compressed(&ser_bytes[..])?;
1546        assert_eq!(de, srs);
1547
1548        let mut ser_bytes = Vec::new();
1549        pk.serialize_compressed(&mut ser_bytes)?;
1550        let de = ProvingKey::<E>::deserialize_compressed(&ser_bytes[..])?;
1551        assert_eq!(de, pk);
1552
1553        let mut ser_bytes = Vec::new();
1554        vk.serialize_compressed(&mut ser_bytes)?;
1555        let de = VerifyingKey::<E>::deserialize_compressed(&ser_bytes[..])?;
1556        assert_eq!(de, vk);
1557
1558        let mut ser_bytes = Vec::new();
1559        proof.serialize_compressed(&mut ser_bytes)?;
1560        let de = Proof::<E>::deserialize_compressed(&ser_bytes[..])?;
1561        assert_eq!(de, proof);
1562
1563        Ok(())
1564    }
1565
1566    #[test]
1567    fn test_key_aggregation_and_batch_prove() -> Result<(), PlonkError> {
1568        // merlin transcripts
1569        test_key_aggregation_and_batch_prove_helper::<Bn254, Fq254, _, StandardTranscript>(
1570            PlonkType::TurboPlonk,
1571        )?;
1572        test_key_aggregation_and_batch_prove_helper::<Bls12_377, Fq377, _, StandardTranscript>(
1573            PlonkType::TurboPlonk,
1574        )?;
1575        test_key_aggregation_and_batch_prove_helper::<Bls12_381, Fq381, _, StandardTranscript>(
1576            PlonkType::TurboPlonk,
1577        )?;
1578        test_key_aggregation_and_batch_prove_helper::<BW6_761, Fq761, _, StandardTranscript>(
1579            PlonkType::TurboPlonk,
1580        )?;
1581
1582        // rescue transcripts
1583        // currently only available for bls12-377
1584        test_key_aggregation_and_batch_prove_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>(
1585            PlonkType::TurboPlonk,
1586        )
1587    }
1588
1589    fn test_key_aggregation_and_batch_prove_helper<E, F, P, T>(
1590        plonk_type: PlonkType,
1591    ) -> Result<(), PlonkError>
1592    where
1593        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1594        F: RescueParameter + SWToTEConParam,
1595        P: SWCurveConfig<BaseField = F>,
1596        T: PlonkTranscript<F>,
1597    {
1598        // 1. Simulate universal setup
1599        let rng = &mut test_rng();
1600        let n = 128;
1601        let max_degree = n + 2;
1602        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;
1603
1604        // 2. Create many circuits with same domain size
1605        let circuits = (6..13)
1606            .map(|i| gen_circuit_for_test::<E::ScalarField>(i, i, plonk_type))
1607            .collect::<Result<Vec<_>, PlonkError>>()?; // the number of gates = 4m + 11
1608        let cs_ref: Vec<&PlonkCircuit<E::ScalarField>> = circuits.iter().collect();
1609
1610        // 3. Preprocessing
1611        let mut prove_keys = vec![];
1612        let mut verify_keys = vec![];
1613        for circuit in circuits.iter() {
1614            let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, circuit)?;
1615            prove_keys.push(pk);
1616            verify_keys.push(vk);
1617        }
1618        let pks_ref: Vec<&ProvingKey<E>> = prove_keys.iter().collect();
1619        let vks_ref: Vec<&VerifyingKey<E>> = verify_keys.iter().collect();
1620
1621        // 4. Batch Proving and verification
1622        check_batch_prove_and_verify::<_, _, _, _, T>(rng, &cs_ref, &pks_ref, &vks_ref)?;
1623
1624        // Batch proving with circuit/key aggregation
1625        //
1626        // 2. Create circuits
1627        let type_a_circuits: Vec<PlonkCircuit<E::ScalarField>> = (6..13)
1628            .map(|i| gen_mergeable_circuit(i, i, MergeableCircuitType::TypeA))
1629            .collect::<Result<Vec<_>, PlonkError>>()?; // the number of gates = 4m + 11
1630        let type_b_circuits = (6..13)
1631            .map(|i| gen_mergeable_circuit(i, i, MergeableCircuitType::TypeB))
1632            .collect::<Result<Vec<_>, PlonkError>>()?;
1633        // merge circuits
1634        let circuits = type_a_circuits
1635            .iter()
1636            .zip(type_b_circuits.iter())
1637            .map(|(cs_a, cs_b)| cs_a.merge(cs_b))
1638            .collect::<Result<Vec<_>, _>>()?;
1639        let cs_ref: Vec<&PlonkCircuit<E::ScalarField>> = circuits.iter().collect();
1640
1641        // 3. Preprocessing
1642        let mut pks_type_a = vec![];
1643        let mut vks_type_a = vec![];
1644        for cs_a in type_a_circuits.iter() {
1645            let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, cs_a)?;
1646            pks_type_a.push(pk);
1647            vks_type_a.push(vk);
1648        }
1649        let mut pks_type_b = vec![];
1650        let mut vks_type_b = vec![];
1651        for cs_b in type_b_circuits.iter() {
1652            let (pk, vk) = PlonkKzgSnark::<E>::preprocess(&srs, cs_b)?;
1653            pks_type_b.push(pk);
1654            vks_type_b.push(vk);
1655        }
1656        // merge proving keys
1657        let pks = pks_type_a
1658            .iter()
1659            .zip(pks_type_b.iter())
1660            .map(|(pk_a, pk_b)| pk_a.merge(pk_b))
1661            .collect::<Result<Vec<_>, PlonkError>>()?;
1662        // merge verification keys
1663        let vks = vks_type_a
1664            .iter()
1665            .zip(vks_type_b.iter())
1666            .map(|(vk_a, vk_b)| vk_a.merge(vk_b))
1667            .collect::<Result<Vec<_>, PlonkError>>()?;
1668        // check that the merged keys are correct
1669        for (cs, vk) in circuits.iter().zip(vks.iter()) {
1670            let (_, mut expected_vk) = PlonkKzgSnark::<E>::preprocess(&srs, cs)?;
1671            expected_vk.is_merged = true;
1672            assert_eq!(*vk, expected_vk);
1673        }
1674
1675        let pks_ref: Vec<&ProvingKey<E>> = pks.iter().collect();
1676        let vks_ref: Vec<&VerifyingKey<E>> = vks.iter().collect();
1677
1678        // 4. Batch Proving and verification
1679        check_batch_prove_and_verify::<_, _, _, _, T>(rng, &cs_ref, &pks_ref, &vks_ref)?;
1680
1681        Ok(())
1682    }
1683
1684    fn check_batch_prove_and_verify<E, F, P, R, T>(
1685        rng: &mut R,
1686        cs_ref: &[&PlonkCircuit<E::ScalarField>],
1687        pks_ref: &[&ProvingKey<E>],
1688        vks_ref: &[&VerifyingKey<E>],
1689    ) -> Result<(), PlonkError>
1690    where
1691        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
1692        F: RescueParameter + SWToTEConParam,
1693        P: SWCurveConfig<BaseField = F>,
1694        R: CryptoRng + RngCore,
1695        T: PlonkTranscript<F>,
1696    {
1697        // Batch Proving
1698        let batch_proof = PlonkKzgSnark::<E>::batch_prove::<_, _, T>(rng, cs_ref, pks_ref)?;
1699
1700        // Verification
1701        let public_inputs: Vec<Vec<E::ScalarField>> = cs_ref
1702            .iter()
1703            .map(|&cs| cs.public_input())
1704            .collect::<Result<Vec<Vec<E::ScalarField>>, _>>(
1705        )?;
1706        let pi_ref: Vec<&[E::ScalarField]> = public_inputs
1707            .iter()
1708            .map(|pub_input| &pub_input[..])
1709            .collect();
1710        assert!(
1711            PlonkKzgSnark::<E>::verify_batch_proof::<T>(vks_ref, &pi_ref, &batch_proof,).is_ok()
1712        );
1713        let mut bad_pi_ref = pi_ref.clone();
1714        bad_pi_ref[0] = bad_pi_ref[1];
1715        assert!(
1716            PlonkKzgSnark::<E>::verify_batch_proof::<T>(vks_ref, &bad_pi_ref, &batch_proof,)
1717                .is_err()
1718        );
1719
1720        Ok(())
1721    }
1722
1723    fn gen_mergeable_circuit<F: PrimeField>(
1724        m: usize,
1725        a0: usize,
1726        circuit_type: MergeableCircuitType,
1727    ) -> Result<PlonkCircuit<F>, PlonkError> {
1728        let mut cs: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
1729        // Create variables
1730        let mut a = vec![];
1731        for i in a0..(a0 + 4 * m) {
1732            a.push(cs.create_variable(F::from(i as u64))?);
1733        }
1734        let b = [
1735            cs.create_public_variable(F::from(m as u64 * 2))?,
1736            cs.create_public_variable(F::from(a0 as u64 * 2 + m as u64 * 4 - 1))?,
1737        ];
1738        let c = cs.create_public_variable(
1739            (cs.witness(b[1])? + cs.witness(a[0])?) * (cs.witness(b[1])? - cs.witness(a[0])?),
1740        )?;
1741
1742        // Create gates:
1743        // 1. a0 + ... + a_{4*m-1} = b0 * b1
1744        // 2. (b1 + a0) * (b1 - a0) = c
1745        // 3. b0 = 2 * m
1746        let mut acc = cs.zero();
1747        a.iter().for_each(|&elem| acc = cs.add(acc, elem).unwrap());
1748        let b_mul = cs.mul(b[0], b[1])?;
1749        cs.enforce_equal(acc, b_mul)?;
1750        let b1_plus_a0 = cs.add(b[1], a[0])?;
1751        let b1_minus_a0 = cs.sub(b[1], a[0])?;
1752        cs.mul_gate(b1_plus_a0, b1_minus_a0, c)?;
1753        cs.enforce_constant(b[0], F::from(m as u64 * 2))?;
1754
1755        cs.finalize_for_mergeable_circuit(circuit_type)?;
1756
1757        Ok(cs)
1758    }
1759}