jf_plonk/proof_system/
batch_arg.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the Jellyfish library.

// You should have received a copy of the MIT License
// along with the Jellyfish library. If not, see <https://mit-license.org/>.

//! An argument system that proves/verifies multiple instances in a batch.
8use crate::{
9    errors::{PlonkError, SnarkError::ParameterError},
10    proof_system::{
11        structs::{BatchProof, OpenKey, ProvingKey, ScalarsAndBases, UniversalSrs, VerifyingKey},
12        verifier::Verifier,
13        PlonkKzgSnark, UniversalSNARK,
14    },
15    transcript::PlonkTranscript,
16};
17use ark_ec::{
18    pairing::Pairing,
19    short_weierstrass::{Affine, SWCurveConfig},
20};
21use ark_ff::One;
22use ark_std::{
23    format,
24    marker::PhantomData,
25    rand::{CryptoRng, RngCore},
26    string::ToString,
27    vec,
28    vec::Vec,
29};
30use jf_relation::{gadgets::ecc::SWToTEConParam, Circuit, MergeableCircuitType, PlonkCircuit};
31use jf_rescue::RescueParameter;
32use jf_utils::multi_pairing;
33
34/// A batching argument.
35pub struct BatchArgument<E: Pairing>(PhantomData<E>);
36
37/// A circuit instance that consists of the corresponding proving
38/// key/verification key/circuit.
39#[derive(Clone)]
40pub struct Instance<E: Pairing> {
41    // TODO: considering giving instance an ID
42    prove_key: ProvingKey<E>, // the verification key can be obtained inside the proving key.
43    circuit: PlonkCircuit<E::ScalarField>,
44    _circuit_type: MergeableCircuitType,
45}
46
47impl<E: Pairing> Instance<E> {
48    /// Get verification key by reference.
49    pub fn verify_key_ref(&self) -> &VerifyingKey<E> {
50        &self.prove_key.vk
51    }
52
53    /// Get mutable circuit by reference.
54    pub fn circuit_mut_ref(&mut self) -> &mut PlonkCircuit<E::ScalarField> {
55        &mut self.circuit
56    }
57}
58
59impl<E, F, P> BatchArgument<E>
60where
61    E: Pairing<BaseField = F, G1Affine = Affine<P>>,
62    F: RescueParameter + SWToTEConParam,
63    P: SWCurveConfig<BaseField = F>,
64{
65    /// Setup the circuit and the proving key for a (mergeable) instance.
66    pub fn setup_instance(
67        srs: &UniversalSrs<E>,
68        mut circuit: PlonkCircuit<E::ScalarField>,
69        circuit_type: MergeableCircuitType,
70    ) -> Result<Instance<E>, PlonkError> {
71        circuit.finalize_for_mergeable_circuit(circuit_type)?;
72        let (prove_key, _) = PlonkKzgSnark::preprocess(srs, &circuit)?;
73        Ok(Instance {
74            prove_key,
75            circuit,
76            _circuit_type: circuit_type,
77        })
78    }
79
80    /// Prove satisfiability of multiple instances in a batch.
81    pub fn batch_prove<R, T>(
82        prng: &mut R,
83        instances_type_a: &[Instance<E>],
84        instances_type_b: &[Instance<E>],
85    ) -> Result<BatchProof<E>, PlonkError>
86    where
87        R: CryptoRng + RngCore,
88        T: PlonkTranscript<F>,
89    {
90        if instances_type_a.len() != instances_type_b.len() {
91            return Err(ParameterError(format!(
92                "the number of type A instances {} is different from the number of type B instances {}.", 
93                instances_type_a.len(),
94                instances_type_b.len())
95            ).into());
96        }
97        let pks = instances_type_a
98            .iter()
99            .zip(instances_type_b.iter())
100            .map(|(pred_a, pred_b)| pred_a.prove_key.merge(&pred_b.prove_key))
101            .collect::<Result<Vec<_>, _>>()?;
102
103        let circuits = instances_type_a
104            .iter()
105            .zip(instances_type_b.iter())
106            .map(|(pred_a, pred_b)| pred_a.circuit.merge(&pred_b.circuit))
107            .collect::<Result<Vec<_>, _>>()?;
108        let pks_ref: Vec<&ProvingKey<E>> = pks.iter().collect();
109        let circuits_ref: Vec<&PlonkCircuit<E::ScalarField>> = circuits.iter().collect();
110
111        PlonkKzgSnark::batch_prove::<_, _, T>(prng, &circuits_ref, &pks_ref)
112    }
113
114    /// Partially verify a batched proof without performing the pairing. Return
115    /// the two group elements used in the final pairing.
116    pub fn partial_verify<T>(
117        beta_g: &E::G1Affine,
118        generator_g: &E::G1Affine,
119        merged_vks: &[VerifyingKey<E>],
120        shared_public_input: &[E::ScalarField],
121        batch_proof: &BatchProof<E>,
122        blinding_factor: E::ScalarField,
123    ) -> Result<(E::G1, E::G1), PlonkError>
124    where
125        T: PlonkTranscript<F>,
126    {
127        if merged_vks.is_empty() {
128            return Err(ParameterError("empty merged verification keys".to_string()).into());
129        }
130        if merged_vks.len() != batch_proof.len() {
131            return Err(ParameterError(format!(
132                "the number of verification keys {} is different from the number of instances {}.",
133                merged_vks.len(),
134                batch_proof.len()
135            ))
136            .into());
137        }
138        let domain_size = merged_vks[0].domain_size;
139        for (i, vk) in merged_vks.iter().skip(1).enumerate() {
140            if vk.domain_size != domain_size {
141                return Err(ParameterError(format!(
142                    "the {}-th verification key's domain size {} is different from {}.",
143                    i, vk.domain_size, domain_size
144                ))
145                .into());
146            }
147        }
148        let verifier = Verifier::new(domain_size)?;
149        // we need to copy the public input once after merging the circuit
150        let shared_public_input = [shared_public_input, shared_public_input].concat();
151        let public_inputs = vec![&shared_public_input[..]; merged_vks.len()];
152        let merged_vks_ref: Vec<&VerifyingKey<E>> = merged_vks.iter().collect();
153        let pcs_info =
154            verifier.prepare_pcs_info::<T>(&merged_vks_ref, &public_inputs, batch_proof, &None)?;
155
156        // inner1 = [open_proof] + u * [shifted_open_proof] + blinding_factor * [1]1
157        let mut scalars_and_bases = ScalarsAndBases::<E>::new();
158        scalars_and_bases.push(E::ScalarField::one(), pcs_info.opening_proof.0);
159        scalars_and_bases.push(pcs_info.u, pcs_info.shifted_opening_proof.0);
160        scalars_and_bases.push(blinding_factor, *generator_g);
161        let inner1 = scalars_and_bases.multi_scalar_mul();
162
163        // inner2 = eval_point * [open_proof] + next_eval_point * u *
164        // [shifted_open_proof] + [aggregated_comm] - aggregated_eval * [1]1 +
165        // blinding_factor * [beta]1
166        let mut scalars_and_bases = pcs_info.comm_scalars_and_bases;
167        scalars_and_bases.push(pcs_info.eval_point, pcs_info.opening_proof.0);
168        scalars_and_bases.push(
169            pcs_info.next_eval_point * pcs_info.u,
170            pcs_info.shifted_opening_proof.0,
171        );
172        scalars_and_bases.push(-pcs_info.eval, *generator_g);
173        scalars_and_bases.push(blinding_factor, *beta_g);
174        let inner2 = scalars_and_bases.multi_scalar_mul();
175
176        Ok((inner1, inner2))
177    }
178}
179
180impl<E> BatchArgument<E>
181where
182    E: Pairing,
183{
184    /// Aggregate verification keys
185    pub fn aggregate_verify_keys(
186        vks_type_a: &[&VerifyingKey<E>],
187        vks_type_b: &[&VerifyingKey<E>],
188    ) -> Result<Vec<VerifyingKey<E>>, PlonkError> {
189        if vks_type_a.len() != vks_type_b.len() {
190            return Err(ParameterError(format!(
191                "the number of type A verification keys {} is different from the number of type B verification keys {}.", 
192                vks_type_a.len(),
193                vks_type_b.len())
194            ).into());
195        }
196        vks_type_a
197            .iter()
198            .zip(vks_type_b.iter())
199            .map(|(vk_a, vk_b)| vk_a.merge(vk_b))
200            .collect::<Result<Vec<_>, PlonkError>>()
201    }
202
203    /// Perform the final pairing to verify the proof.
204    pub fn decide(open_key: &OpenKey<E>, inner1: E::G1, inner2: E::G1) -> Result<bool, PlonkError> {
205        // check e(elem1, [beta]2) ?= e(elem2, [1]2)
206        let g1_elems: Vec<<E as Pairing>::G1Affine> = vec![inner1.into(), (-inner2).into()];
207        let g2_elems = vec![open_key.beta_h, open_key.h];
208        Ok(multi_pairing::<E>(&g1_elems, &g2_elems).0 == E::TargetField::one())
209    }
210}
211
212pub(crate) fn new_mergeable_circuit_for_test<E: Pairing>(
213    shared_public_input: E::ScalarField,
214    i: usize,
215    circuit_type: MergeableCircuitType,
216) -> Result<PlonkCircuit<E::ScalarField>, PlonkError> {
217    let mut circuit = PlonkCircuit::new_turbo_plonk();
218    let shared_pub_var = circuit.create_public_variable(shared_public_input)?;
219    let mut var = shared_pub_var;
220    if circuit_type == MergeableCircuitType::TypeA {
221        // compute type A instances: add `shared_public_input` by i times
222        for _ in 0..i {
223            var = circuit.add(var, shared_pub_var)?;
224        }
225    } else {
226        // compute type B instances: mul `shared_public_input` by i times
227        for _ in 0..i {
228            var = circuit.mul(var, shared_pub_var)?;
229        }
230    }
231    Ok(circuit)
232}
233
234/// Create `num_instances` type A/B instance verifying keys and
235/// compute the corresponding batch proof. Only used for testing.
236#[allow(clippy::type_complexity)]
237pub fn build_batch_proof_and_vks_for_test<E, F, P, R, T>(
238    rng: &mut R,
239    srs: &UniversalSrs<E>,
240    num_instances: usize,
241    shared_public_input: E::ScalarField,
242) -> Result<(BatchProof<E>, Vec<VerifyingKey<E>>, Vec<VerifyingKey<E>>), PlonkError>
243where
244    E: Pairing<BaseField = F, G1Affine = Affine<P>>,
245    F: RescueParameter + SWToTEConParam,
246    P: SWCurveConfig<BaseField = F>,
247    R: CryptoRng + RngCore,
248    T: PlonkTranscript<F>,
249{
250    let mut instances_type_a = vec![];
251    let mut instances_type_b = vec![];
252    let mut vks_type_a = vec![];
253    let mut vks_type_b = vec![];
254    for i in 10..10 + num_instances {
255        let circuit = new_mergeable_circuit_for_test::<E>(
256            shared_public_input,
257            i,
258            MergeableCircuitType::TypeA,
259        )?;
260        let instance = BatchArgument::setup_instance(srs, circuit, MergeableCircuitType::TypeA)?;
261        vks_type_a.push(instance.verify_key_ref().clone());
262        instances_type_a.push(instance);
263
264        let circuit = new_mergeable_circuit_for_test::<E>(
265            shared_public_input,
266            i,
267            MergeableCircuitType::TypeB,
268        )?;
269        let instance = BatchArgument::setup_instance(srs, circuit, MergeableCircuitType::TypeB)?;
270        vks_type_b.push(instance.verify_key_ref().clone());
271        instances_type_b.push(instance);
272    }
273
274    let batch_proof =
275        BatchArgument::batch_prove::<_, T>(rng, &instances_type_a, &instances_type_b)?;
276    Ok((batch_proof, vks_type_a, vks_type_b))
277}
278
#[cfg(test)]
mod test {
    use super::*;
    use crate::transcript::RescueTranscript;
    use ark_bls12_377::{Bls12_377, Fq as Fq377};
    use ark_std::UniformRand;
    use jf_utils::test_rng;

    /// Runs the batch argument end to end over BLS12-377 with the Rescue
    /// transcript.
    #[test]
    fn test_batch_argument() -> Result<(), PlonkError> {
        test_batch_argument_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>()
    }

    /// Proves a batch of merged instances, verifies it, and checks every
    /// parameter-validation error path along the way.
    fn test_batch_argument_helper<E, F, P, T>() -> Result<(), PlonkError>
    where
        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
        F: RescueParameter + SWToTEConParam,
        P: SWCurveConfig<BaseField = F>,
        T: PlonkTranscript<F>,
    {
        // 1. Simulate universal setup
        let rng = &mut test_rng();
        let max_degree = 128 + 2;
        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;

        // 2. Setup instances
        let shared_public_input = E::ScalarField::rand(rng);
        let setup =
            |size: usize, circuit_type: MergeableCircuitType| -> Result<Instance<E>, PlonkError> {
                let circuit =
                    new_mergeable_circuit_for_test::<E>(shared_public_input, size, circuit_type)?;
                BatchArgument::setup_instance(&srs, circuit, circuit_type)
            };
        let mut instances_type_a = Vec::new();
        let mut instances_type_b = Vec::new();
        for size in 32..50 {
            instances_type_a.push(setup(size, MergeableCircuitType::TypeA)?);
            instances_type_b.push(setup(size, MergeableCircuitType::TypeB)?);
        }

        // 3. Batch Proving
        let batch_proof =
            BatchArgument::batch_prove::<_, T>(rng, &instances_type_a, &instances_type_b)?;
        // error path: the two instance lists have different lengths
        let mismatched_proof =
            BatchArgument::batch_prove::<_, T>(rng, &instances_type_a[1..], &instances_type_b);
        assert!(mismatched_proof.is_err());

        // 4. Aggregate verification keys
        let vks_type_a: Vec<&VerifyingKey<E>> = instances_type_a
            .iter()
            .map(Instance::verify_key_ref)
            .collect();
        let vks_type_b: Vec<&VerifyingKey<E>> = instances_type_b
            .iter()
            .map(Instance::verify_key_ref)
            .collect();
        let merged_vks = BatchArgument::aggregate_verify_keys(&vks_type_a, &vks_type_b)?;
        // error path: the two key lists have different lengths
        assert!(BatchArgument::aggregate_verify_keys(&vks_type_a[1..], &vks_type_b).is_err());

        // 5. Verification
        let open_key_ref = &vks_type_a[0].open_key;
        let beta_g_ref = &srs.powers_of_g[1];
        let blinding_factor = E::ScalarField::rand(rng);
        // Every call below differs only in the verification keys it is given.
        let partial_verify = |vks: &[VerifyingKey<E>]| {
            BatchArgument::partial_verify::<T>(
                beta_g_ref,
                &open_key_ref.g,
                vks,
                &[shared_public_input],
                &batch_proof,
                blinding_factor,
            )
        };

        let (inner1, inner2) = partial_verify(&merged_vks)?;
        assert!(BatchArgument::decide(open_key_ref, inner1, inner2)?);

        // error paths
        // empty merged_vks
        assert!(partial_verify(&[]).is_err());
        // the number of vks is different from the number of instances
        assert!(partial_verify(&merged_vks[1..]).is_err());
        // inconsistent domain size between verification keys
        let mut bad_merged_vks = merged_vks;
        bad_merged_vks[0].domain_size /= 2;
        assert!(partial_verify(&bad_merged_vks).is_err());

        Ok(())
    }
}