1use crate::{
9 errors::{PlonkError, SnarkError::ParameterError},
10 proof_system::{
11 structs::{BatchProof, OpenKey, ProvingKey, ScalarsAndBases, UniversalSrs, VerifyingKey},
12 verifier::Verifier,
13 PlonkKzgSnark, UniversalSNARK,
14 },
15 transcript::PlonkTranscript,
16};
17use ark_ec::{
18 pairing::Pairing,
19 short_weierstrass::{Affine, SWCurveConfig},
20};
21use ark_ff::One;
22use ark_std::{
23 format,
24 marker::PhantomData,
25 rand::{CryptoRng, RngCore},
26 string::ToString,
27 vec,
28 vec::Vec,
29};
30use jf_relation::{gadgets::ecc::SWToTEConParam, Circuit, MergeableCircuitType, PlonkCircuit};
31use jf_rescue::RescueParameter;
32use jf_utils::multi_pairing;
33
34pub struct BatchArgument<E: Pairing>(PhantomData<E>);
36
37#[derive(Clone)]
40pub struct Instance<E: Pairing> {
41 prove_key: ProvingKey<E>, circuit: PlonkCircuit<E::ScalarField>,
44 _circuit_type: MergeableCircuitType,
45}
46
47impl<E: Pairing> Instance<E> {
48 pub fn verify_key_ref(&self) -> &VerifyingKey<E> {
50 &self.prove_key.vk
51 }
52
53 pub fn circuit_mut_ref(&mut self) -> &mut PlonkCircuit<E::ScalarField> {
55 &mut self.circuit
56 }
57}
58
59impl<E, F, P> BatchArgument<E>
60where
61 E: Pairing<BaseField = F, G1Affine = Affine<P>>,
62 F: RescueParameter + SWToTEConParam,
63 P: SWCurveConfig<BaseField = F>,
64{
65 pub fn setup_instance(
67 srs: &UniversalSrs<E>,
68 mut circuit: PlonkCircuit<E::ScalarField>,
69 circuit_type: MergeableCircuitType,
70 ) -> Result<Instance<E>, PlonkError> {
71 circuit.finalize_for_mergeable_circuit(circuit_type)?;
72 let (prove_key, _) = PlonkKzgSnark::preprocess(srs, &circuit)?;
73 Ok(Instance {
74 prove_key,
75 circuit,
76 _circuit_type: circuit_type,
77 })
78 }
79
80 pub fn batch_prove<R, T>(
82 prng: &mut R,
83 instances_type_a: &[Instance<E>],
84 instances_type_b: &[Instance<E>],
85 ) -> Result<BatchProof<E>, PlonkError>
86 where
87 R: CryptoRng + RngCore,
88 T: PlonkTranscript<F>,
89 {
90 if instances_type_a.len() != instances_type_b.len() {
91 return Err(ParameterError(format!(
92 "the number of type A instances {} is different from the number of type B instances {}.",
93 instances_type_a.len(),
94 instances_type_b.len())
95 ).into());
96 }
97 let pks = instances_type_a
98 .iter()
99 .zip(instances_type_b.iter())
100 .map(|(pred_a, pred_b)| pred_a.prove_key.merge(&pred_b.prove_key))
101 .collect::<Result<Vec<_>, _>>()?;
102
103 let circuits = instances_type_a
104 .iter()
105 .zip(instances_type_b.iter())
106 .map(|(pred_a, pred_b)| pred_a.circuit.merge(&pred_b.circuit))
107 .collect::<Result<Vec<_>, _>>()?;
108 let pks_ref: Vec<&ProvingKey<E>> = pks.iter().collect();
109 let circuits_ref: Vec<&PlonkCircuit<E::ScalarField>> = circuits.iter().collect();
110
111 PlonkKzgSnark::batch_prove::<_, _, T>(prng, &circuits_ref, &pks_ref)
112 }
113
114 pub fn partial_verify<T>(
117 beta_g: &E::G1Affine,
118 generator_g: &E::G1Affine,
119 merged_vks: &[VerifyingKey<E>],
120 shared_public_input: &[E::ScalarField],
121 batch_proof: &BatchProof<E>,
122 blinding_factor: E::ScalarField,
123 ) -> Result<(E::G1, E::G1), PlonkError>
124 where
125 T: PlonkTranscript<F>,
126 {
127 if merged_vks.is_empty() {
128 return Err(ParameterError("empty merged verification keys".to_string()).into());
129 }
130 if merged_vks.len() != batch_proof.len() {
131 return Err(ParameterError(format!(
132 "the number of verification keys {} is different from the number of instances {}.",
133 merged_vks.len(),
134 batch_proof.len()
135 ))
136 .into());
137 }
138 let domain_size = merged_vks[0].domain_size;
139 for (i, vk) in merged_vks.iter().skip(1).enumerate() {
140 if vk.domain_size != domain_size {
141 return Err(ParameterError(format!(
142 "the {}-th verification key's domain size {} is different from {}.",
143 i, vk.domain_size, domain_size
144 ))
145 .into());
146 }
147 }
148 let verifier = Verifier::new(domain_size)?;
149 let shared_public_input = [shared_public_input, shared_public_input].concat();
151 let public_inputs = vec![&shared_public_input[..]; merged_vks.len()];
152 let merged_vks_ref: Vec<&VerifyingKey<E>> = merged_vks.iter().collect();
153 let pcs_info =
154 verifier.prepare_pcs_info::<T>(&merged_vks_ref, &public_inputs, batch_proof, &None)?;
155
156 let mut scalars_and_bases = ScalarsAndBases::<E>::new();
158 scalars_and_bases.push(E::ScalarField::one(), pcs_info.opening_proof.0);
159 scalars_and_bases.push(pcs_info.u, pcs_info.shifted_opening_proof.0);
160 scalars_and_bases.push(blinding_factor, *generator_g);
161 let inner1 = scalars_and_bases.multi_scalar_mul();
162
163 let mut scalars_and_bases = pcs_info.comm_scalars_and_bases;
167 scalars_and_bases.push(pcs_info.eval_point, pcs_info.opening_proof.0);
168 scalars_and_bases.push(
169 pcs_info.next_eval_point * pcs_info.u,
170 pcs_info.shifted_opening_proof.0,
171 );
172 scalars_and_bases.push(-pcs_info.eval, *generator_g);
173 scalars_and_bases.push(blinding_factor, *beta_g);
174 let inner2 = scalars_and_bases.multi_scalar_mul();
175
176 Ok((inner1, inner2))
177 }
178}
179
180impl<E> BatchArgument<E>
181where
182 E: Pairing,
183{
184 pub fn aggregate_verify_keys(
186 vks_type_a: &[&VerifyingKey<E>],
187 vks_type_b: &[&VerifyingKey<E>],
188 ) -> Result<Vec<VerifyingKey<E>>, PlonkError> {
189 if vks_type_a.len() != vks_type_b.len() {
190 return Err(ParameterError(format!(
191 "the number of type A verification keys {} is different from the number of type B verification keys {}.",
192 vks_type_a.len(),
193 vks_type_b.len())
194 ).into());
195 }
196 vks_type_a
197 .iter()
198 .zip(vks_type_b.iter())
199 .map(|(vk_a, vk_b)| vk_a.merge(vk_b))
200 .collect::<Result<Vec<_>, PlonkError>>()
201 }
202
203 pub fn decide(open_key: &OpenKey<E>, inner1: E::G1, inner2: E::G1) -> Result<bool, PlonkError> {
205 let g1_elems: Vec<<E as Pairing>::G1Affine> = vec![inner1.into(), (-inner2).into()];
207 let g2_elems = vec![open_key.beta_h, open_key.h];
208 Ok(multi_pairing::<E>(&g1_elems, &g2_elems).0 == E::TargetField::one())
209 }
210}
211
212pub(crate) fn new_mergeable_circuit_for_test<E: Pairing>(
213 shared_public_input: E::ScalarField,
214 i: usize,
215 circuit_type: MergeableCircuitType,
216) -> Result<PlonkCircuit<E::ScalarField>, PlonkError> {
217 let mut circuit = PlonkCircuit::new_turbo_plonk();
218 let shared_pub_var = circuit.create_public_variable(shared_public_input)?;
219 let mut var = shared_pub_var;
220 if circuit_type == MergeableCircuitType::TypeA {
221 for _ in 0..i {
223 var = circuit.add(var, shared_pub_var)?;
224 }
225 } else {
226 for _ in 0..i {
228 var = circuit.mul(var, shared_pub_var)?;
229 }
230 }
231 Ok(circuit)
232}
233
234#[allow(clippy::type_complexity)]
237pub fn build_batch_proof_and_vks_for_test<E, F, P, R, T>(
238 rng: &mut R,
239 srs: &UniversalSrs<E>,
240 num_instances: usize,
241 shared_public_input: E::ScalarField,
242) -> Result<(BatchProof<E>, Vec<VerifyingKey<E>>, Vec<VerifyingKey<E>>), PlonkError>
243where
244 E: Pairing<BaseField = F, G1Affine = Affine<P>>,
245 F: RescueParameter + SWToTEConParam,
246 P: SWCurveConfig<BaseField = F>,
247 R: CryptoRng + RngCore,
248 T: PlonkTranscript<F>,
249{
250 let mut instances_type_a = vec![];
251 let mut instances_type_b = vec![];
252 let mut vks_type_a = vec![];
253 let mut vks_type_b = vec![];
254 for i in 10..10 + num_instances {
255 let circuit = new_mergeable_circuit_for_test::<E>(
256 shared_public_input,
257 i,
258 MergeableCircuitType::TypeA,
259 )?;
260 let instance = BatchArgument::setup_instance(srs, circuit, MergeableCircuitType::TypeA)?;
261 vks_type_a.push(instance.verify_key_ref().clone());
262 instances_type_a.push(instance);
263
264 let circuit = new_mergeable_circuit_for_test::<E>(
265 shared_public_input,
266 i,
267 MergeableCircuitType::TypeB,
268 )?;
269 let instance = BatchArgument::setup_instance(srs, circuit, MergeableCircuitType::TypeB)?;
270 vks_type_b.push(instance.verify_key_ref().clone());
271 instances_type_b.push(instance);
272 }
273
274 let batch_proof =
275 BatchArgument::batch_prove::<_, T>(rng, &instances_type_a, &instances_type_b)?;
276 Ok((batch_proof, vks_type_a, vks_type_b))
277}
278
#[cfg(test)]
mod test {
    use super::*;
    use crate::transcript::RescueTranscript;
    use ark_bls12_377::{Bls12_377, Fq as Fq377};
    use ark_std::UniformRand;
    use jf_utils::test_rng;

    /// End-to-end batch-argument test on BLS12-377 with a Rescue transcript.
    #[test]
    fn test_batch_argument() -> Result<(), PlonkError> {
        test_batch_argument_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>()
    }

    /// Generic driver. It builds paired type-A/type-B instances, batch-proves
    /// them, merges the verifying keys and checks that partial verification
    /// plus `decide` accepts the proof. It also checks that mismatched or
    /// malformed inputs are rejected.
    fn test_batch_argument_helper<E, F, P, T>() -> Result<(), PlonkError>
    where
        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
        F: RescueParameter + SWToTEConParam,
        P: SWCurveConfig<BaseField = F>,
        T: PlonkTranscript<F>,
    {
        let rng = &mut test_rng();
        let n = 128;
        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(n + 2, rng)?;
        let shared_public_input = E::ScalarField::rand(rng);

        // Paired instances: both types are built from the same gate count.
        let mut instances_type_a = vec![];
        let mut instances_type_b = vec![];
        for size in 32..50 {
            let circuit_a = new_mergeable_circuit_for_test::<E>(
                shared_public_input,
                size,
                MergeableCircuitType::TypeA,
            )?;
            instances_type_a.push(BatchArgument::setup_instance(
                &srs,
                circuit_a,
                MergeableCircuitType::TypeA,
            )?);

            let circuit_b = new_mergeable_circuit_for_test::<E>(
                shared_public_input,
                size,
                MergeableCircuitType::TypeB,
            )?;
            instances_type_b.push(BatchArgument::setup_instance(
                &srs,
                circuit_b,
                MergeableCircuitType::TypeB,
            )?);
        }

        // Proving works on matched inputs and rejects mismatched counts.
        let batch_proof =
            BatchArgument::batch_prove::<_, T>(rng, &instances_type_a, &instances_type_b)?;
        assert!(
            BatchArgument::batch_prove::<_, T>(rng, &instances_type_a[1..], &instances_type_b)
                .is_err()
        );

        // Key aggregation works on matched inputs and rejects mismatched counts.
        let vks_type_a: Vec<&VerifyingKey<E>> =
            instances_type_a.iter().map(Instance::verify_key_ref).collect();
        let vks_type_b: Vec<&VerifyingKey<E>> =
            instances_type_b.iter().map(Instance::verify_key_ref).collect();
        let merged_vks = BatchArgument::aggregate_verify_keys(&vks_type_a, &vks_type_b)?;
        assert!(BatchArgument::aggregate_verify_keys(&vks_type_a[1..], &vks_type_b).is_err());

        let open_key_ref = &vks_type_a[0].open_key;
        let beta_g_ref = &srs.powers_of_g[1];
        let blinding_factor = E::ScalarField::rand(rng);
        // Only the verifying keys vary between the calls below.
        let run_partial_verify = |vks: &[VerifyingKey<E>]| {
            BatchArgument::partial_verify::<T>(
                beta_g_ref,
                &open_key_ref.g,
                vks,
                &[shared_public_input],
                &batch_proof,
                blinding_factor,
            )
        };

        let (inner1, inner2) = run_partial_verify(&merged_vks[..])?;
        assert!(BatchArgument::decide(open_key_ref, inner1, inner2)?);

        // Rejections: no keys, too few keys, inconsistent domain sizes.
        assert!(run_partial_verify(&[]).is_err());
        assert!(run_partial_verify(&merged_vks[1..]).is_err());
        let mut bad_merged_vks = merged_vks;
        bad_merged_vks[0].domain_size /= 2;
        assert!(run_partial_verify(&bad_merged_vks[..]).is_err());

        Ok(())
    }
}
401}