use crate::{
errors::{PlonkError, SnarkError::ParameterError},
proof_system::{
structs::{BatchProof, OpenKey, ProvingKey, ScalarsAndBases, UniversalSrs, VerifyingKey},
verifier::Verifier,
PlonkKzgSnark, UniversalSNARK,
},
transcript::PlonkTranscript,
};
use ark_ec::{
pairing::Pairing,
short_weierstrass::{Affine, SWCurveConfig},
};
use ark_ff::One;
use ark_std::{
format,
marker::PhantomData,
rand::{CryptoRng, RngCore},
string::ToString,
vec,
vec::Vec,
};
use jf_relation::{gadgets::ecc::SWToTEConParam, Circuit, MergeableCircuitType, PlonkCircuit};
use jf_rescue::RescueParameter;
use jf_utils::multi_pairing;
/// Zero-sized namespace type for the batch-argument functions in this file.
///
/// It stores no data: the only field is a `PhantomData` marker tying the type
/// to a pairing engine `E`. All functionality is exposed as associated
/// functions (none of them take `self`).
pub struct BatchArgument<E: Pairing>(PhantomData<E>);
/// A single preprocessed circuit, bundled with the key material needed to
/// include it in a batch proof.
///
/// Values are created by `BatchArgument::setup_instance`.
#[derive(Clone)]
pub struct Instance<E>
where
    E: Pairing,
{
    /// Proving key obtained from preprocessing `circuit`; its `vk` field is
    /// what `verify_key_ref` hands out.
    prove_key: ProvingKey<E>,
    /// The finalized circuit this instance proves.
    circuit: PlonkCircuit<E::ScalarField>,
    /// The mergeable-circuit type the circuit was finalized with. Stored at
    /// construction time; not read by any code visible in this file.
    _circuit_type: MergeableCircuitType,
}
impl<E> Instance<E>
where
    E: Pairing,
{
    /// Borrows the verifying key embedded in this instance's proving key.
    pub fn verify_key_ref(&self) -> &VerifyingKey<E> {
        let Self { prove_key, .. } = self;
        &prove_key.vk
    }

    /// Mutably borrows the circuit held by this instance.
    pub fn circuit_mut_ref(&mut self) -> &mut PlonkCircuit<E::ScalarField> {
        let Self { circuit, .. } = self;
        circuit
    }
}
impl<E, F, P> BatchArgument<E>
where
    E: Pairing<BaseField = F, G1Affine = Affine<P>>,
    F: RescueParameter + SWToTEConParam,
    P: SWCurveConfig<BaseField = F>,
{
    /// Builds an [`Instance`] from a circuit.
    ///
    /// The circuit is finalized with `finalize_for_mergeable_circuit` using
    /// `circuit_type`, then preprocessed against `srs` to obtain its proving
    /// key.
    ///
    /// # Errors
    /// Propagates any error returned by circuit finalization or by
    /// preprocessing.
    pub fn setup_instance(
        srs: &UniversalSrs<E>,
        mut circuit: PlonkCircuit<E::ScalarField>,
        circuit_type: MergeableCircuitType,
    ) -> Result<Instance<E>, PlonkError> {
        circuit.finalize_for_mergeable_circuit(circuit_type)?;
        // The standalone verifying key is discarded here; the proving key
        // carries its own copy in `prove_key.vk`.
        let (prove_key, _) = PlonkKzgSnark::preprocess(srs, &circuit)?;
        Ok(Instance {
            prove_key,
            circuit,
            _circuit_type: circuit_type,
        })
    }

    /// Produces a single batch proof over pairwise-merged instances.
    ///
    /// The i-th entry of `instances_type_a` is merged with the i-th entry of
    /// `instances_type_b`: first all proving keys are merged, then all
    /// circuits. The merged keys and circuits are handed to
    /// `PlonkKzgSnark::batch_prove` with transcript type `T`.
    ///
    /// # Errors
    /// * `ParameterError` if the two slices have different lengths.
    /// * Any error from merging proving keys or circuits, or from the
    ///   underlying batch prover.
    pub fn batch_prove<R, T>(
        prng: &mut R,
        instances_type_a: &[Instance<E>],
        instances_type_b: &[Instance<E>],
    ) -> Result<BatchProof<E>, PlonkError>
    where
        R: CryptoRng + RngCore,
        T: PlonkTranscript<F>,
    {
        if instances_type_a.len() != instances_type_b.len() {
            return Err(ParameterError(format!(
                "the number of type A instances {} is different from the number of type B instances {}.",
                instances_type_a.len(),
                instances_type_b.len())
            ).into());
        }
        let pks = instances_type_a
            .iter()
            .zip(instances_type_b.iter())
            .map(|(pred_a, pred_b)| pred_a.prove_key.merge(&pred_b.prove_key))
            .collect::<Result<Vec<_>, _>>()?;
        let circuits = instances_type_a
            .iter()
            .zip(instances_type_b.iter())
            .map(|(pred_a, pred_b)| pred_a.circuit.merge(&pred_b.circuit))
            .collect::<Result<Vec<_>, _>>()?;
        // The batch prover takes slices of references, so borrow the merged
        // values rather than moving them.
        let pks_ref: Vec<&ProvingKey<E>> = pks.iter().collect();
        let circuits_ref: Vec<&PlonkCircuit<E::ScalarField>> = circuits.iter().collect();
        PlonkKzgSnark::batch_prove::<_, _, T>(prng, &circuits_ref, &pks_ref)
    }

    /// Runs batch verification up to, but not including, the final pairing
    /// check, and returns the two G1 elements that check operates on.
    ///
    /// The returned pair `(inner1, inner2)` is meant to be passed to
    /// `BatchArgument::decide`.
    ///
    /// * `beta_g`, `generator_g` - G1 points used in the two linear
    ///   combinations below.
    /// * `merged_vks` - one merged verifying key per instance in
    ///   `batch_proof`; all must share the same `domain_size`.
    /// * `shared_public_input` - public input common to all instances.
    /// * `batch_proof` - the proof to verify.
    /// * `blinding_factor` - scalar that multiplies `generator_g` in the first
    ///   output and `beta_g` in the second.
    ///
    /// # Errors
    /// * `ParameterError` if `merged_vks` is empty, if its length differs from
    ///   `batch_proof.len()`, or if any key's domain size differs from the
    ///   first key's.
    /// * Any error from constructing the verifier or preparing the
    ///   polynomial-commitment information.
    pub fn partial_verify<T>(
        beta_g: &E::G1Affine,
        generator_g: &E::G1Affine,
        merged_vks: &[VerifyingKey<E>],
        shared_public_input: &[E::ScalarField],
        batch_proof: &BatchProof<E>,
        blinding_factor: E::ScalarField,
    ) -> Result<(E::G1, E::G1), PlonkError>
    where
        T: PlonkTranscript<F>,
    {
        if merged_vks.is_empty() {
            return Err(ParameterError("empty merged verification keys".to_string()).into());
        }
        if merged_vks.len() != batch_proof.len() {
            return Err(ParameterError(format!(
                "the number of verification keys {} is different from the number of instances {}.",
                merged_vks.len(),
                batch_proof.len()
            ))
            .into());
        }
        let domain_size = merged_vks[0].domain_size;
        // `enumerate` must run before `skip` so that `i` is the key's real
        // position in `merged_vks` when it is reported in the error message.
        for (i, vk) in merged_vks.iter().enumerate().skip(1) {
            if vk.domain_size != domain_size {
                return Err(ParameterError(format!(
                    "the {}-th verification key's domain size {} is different from {}.",
                    i, vk.domain_size, domain_size
                ))
                .into());
            }
        }
        let verifier = Verifier::new(domain_size)?;

        // Every merged key is verified against the shared input repeated
        // twice.
        // NOTE(review): presumably because a merged circuit exposes the shared
        // input once for each of its two halves — confirm against the merge
        // implementation.
        let shared_public_input = [shared_public_input, shared_public_input].concat();
        let public_inputs = vec![&shared_public_input[..]; merged_vks.len()];
        let merged_vks_ref: Vec<&VerifyingKey<E>> = merged_vks.iter().collect();
        let pcs_info =
            verifier.prepare_pcs_info::<T>(&merged_vks_ref, &public_inputs, batch_proof, &None)?;

        // inner1 = opening_proof + u * shifted_opening_proof
        //          + blinding_factor * generator_g
        let mut scalars_and_bases = ScalarsAndBases::<E>::new();
        scalars_and_bases.push(E::ScalarField::one(), pcs_info.opening_proof.0);
        scalars_and_bases.push(pcs_info.u, pcs_info.shifted_opening_proof.0);
        scalars_and_bases.push(blinding_factor, *generator_g);
        let inner1 = scalars_and_bases.multi_scalar_mul();

        // inner2 = <commitment scalars/bases from pcs_info>
        //          + eval_point * opening_proof
        //          + next_eval_point * u * shifted_opening_proof
        //          - eval * generator_g
        //          + blinding_factor * beta_g
        let mut scalars_and_bases = pcs_info.comm_scalars_and_bases;
        scalars_and_bases.push(pcs_info.eval_point, pcs_info.opening_proof.0);
        scalars_and_bases.push(
            pcs_info.next_eval_point * pcs_info.u,
            pcs_info.shifted_opening_proof.0,
        );
        scalars_and_bases.push(-pcs_info.eval, *generator_g);
        scalars_and_bases.push(blinding_factor, *beta_g);
        let inner2 = scalars_and_bases.multi_scalar_mul();

        Ok((inner1, inner2))
    }
}
impl<E> BatchArgument<E>
where
    E: Pairing,
{
    /// Merges verifying keys pairwise: the i-th type A key with the i-th
    /// type B key.
    ///
    /// # Errors
    /// * `ParameterError` if the two slices have different lengths.
    /// * The first error returned by a key merge, if any.
    pub fn aggregate_verify_keys(
        vks_type_a: &[&VerifyingKey<E>],
        vks_type_b: &[&VerifyingKey<E>],
    ) -> Result<Vec<VerifyingKey<E>>, PlonkError> {
        let (num_a, num_b) = (vks_type_a.len(), vks_type_b.len());
        if num_a != num_b {
            return Err(ParameterError(format!(
                "the number of type A verification keys {} is different from the number of type B verification keys {}.",
                num_a,
                num_b)
            ).into());
        }

        let mut merged_vks = Vec::new();
        for (vk_a, vk_b) in vks_type_a.iter().zip(vks_type_b.iter()) {
            // Stop at the first failing merge.
            merged_vks.push(vk_a.merge(vk_b)?);
        }
        Ok(merged_vks)
    }

    /// Final pairing check over the two values produced by `partial_verify`.
    ///
    /// Returns `Ok(true)` exactly when the product of the pairings of
    /// (`inner1`, `open_key.beta_h`) and (`-inner2`, `open_key.h`) equals the
    /// identity of the target field.
    pub fn decide(open_key: &OpenKey<E>, inner1: E::G1, inner2: E::G1) -> Result<bool, PlonkError> {
        let left: <E as Pairing>::G1Affine = inner1.into();
        let negated_right: <E as Pairing>::G1Affine = (-inner2).into();

        let g1_elems = vec![left, negated_right];
        let g2_elems = vec![open_key.beta_h, open_key.h];

        let pairing_product = multi_pairing::<E>(&g1_elems, &g2_elems);
        let accepted = pairing_product.0 == E::TargetField::one();
        Ok(accepted)
    }
}
/// Builds a small TurboPlonk circuit for tests.
///
/// The circuit has one public variable holding `shared_public_input`. It then
/// chains `i` gates that combine a running variable with that public
/// variable: additions for `MergeableCircuitType::TypeA`, multiplications for
/// any other type. The circuit is returned without being finalized.
///
/// # Errors
/// Propagates any error from creating the public variable or adding a gate.
pub(crate) fn new_mergeable_circuit_for_test<E: Pairing>(
    shared_public_input: E::ScalarField,
    i: usize,
    circuit_type: MergeableCircuitType,
) -> Result<PlonkCircuit<E::ScalarField>, PlonkError> {
    let mut circuit = PlonkCircuit::new_turbo_plonk();
    let shared_pub_var = circuit.create_public_variable(shared_public_input)?;

    // Decide the gate kind once; it does not change between iterations.
    let use_addition = circuit_type == MergeableCircuitType::TypeA;
    let mut running = shared_pub_var;
    for _ in 0..i {
        running = if use_addition {
            circuit.add(running, shared_pub_var)?
        } else {
            circuit.mul(running, shared_pub_var)?
        };
    }
    Ok(circuit)
}
#[allow(clippy::type_complexity)]
pub fn build_batch_proof_and_vks_for_test<E, F, P, R, T>(
rng: &mut R,
srs: &UniversalSrs<E>,
num_instances: usize,
shared_public_input: E::ScalarField,
) -> Result<(BatchProof<E>, Vec<VerifyingKey<E>>, Vec<VerifyingKey<E>>), PlonkError>
where
E: Pairing<BaseField = F, G1Affine = Affine<P>>,
F: RescueParameter + SWToTEConParam,
P: SWCurveConfig<BaseField = F>,
R: CryptoRng + RngCore,
T: PlonkTranscript<F>,
{
let mut instances_type_a = vec![];
let mut instances_type_b = vec![];
let mut vks_type_a = vec![];
let mut vks_type_b = vec![];
for i in 10..10 + num_instances {
let circuit = new_mergeable_circuit_for_test::<E>(
shared_public_input,
i,
MergeableCircuitType::TypeA,
)?;
let instance = BatchArgument::setup_instance(srs, circuit, MergeableCircuitType::TypeA)?;
vks_type_a.push(instance.verify_key_ref().clone());
instances_type_a.push(instance);
let circuit = new_mergeable_circuit_for_test::<E>(
shared_public_input,
i,
MergeableCircuitType::TypeB,
)?;
let instance = BatchArgument::setup_instance(srs, circuit, MergeableCircuitType::TypeB)?;
vks_type_b.push(instance.verify_key_ref().clone());
instances_type_b.push(instance);
}
let batch_proof =
BatchArgument::batch_prove::<_, T>(rng, &instances_type_a, &instances_type_b)?;
Ok((batch_proof, vks_type_a, vks_type_b))
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::transcript::RescueTranscript;
    use ark_bls12_377::{Bls12_377, Fq as Fq377};
    use ark_std::UniformRand;
    use jf_utils::test_rng;

    /// Runs the end-to-end batch-argument test over BLS12-377 with a Rescue
    /// transcript.
    #[test]
    fn test_batch_argument() -> Result<(), PlonkError> {
        test_batch_argument_helper::<Bls12_377, Fq377, _, RescueTranscript<_>>()
    }

    /// Sets up instances, proves them as a batch, verifies the proof, and
    /// checks that malformed inputs are rejected at each stage.
    fn test_batch_argument_helper<E, F, P, T>() -> Result<(), PlonkError>
    where
        E: Pairing<BaseField = F, G1Affine = Affine<P>>,
        F: RescueParameter + SWToTEConParam,
        P: SWCurveConfig<BaseField = F>,
        T: PlonkTranscript<F>,
    {
        let rng = &mut test_rng();

        // Universal setup for the test.
        let n = 128;
        let max_degree = n + 2;
        let srs = PlonkKzgSnark::<E>::universal_setup_for_testing(max_degree, rng)?;

        // Build one type A and one type B instance per gate count.
        let shared_public_input = E::ScalarField::rand(rng);
        let make_instance = |gate_count: usize,
                             circuit_type: MergeableCircuitType|
         -> Result<Instance<E>, PlonkError> {
            let circuit =
                new_mergeable_circuit_for_test::<E>(shared_public_input, gate_count, circuit_type)?;
            BatchArgument::setup_instance(&srs, circuit, circuit_type)
        };
        let mut instances_type_a = vec![];
        let mut instances_type_b = vec![];
        for gate_count in 32..50 {
            instances_type_a.push(make_instance(gate_count, MergeableCircuitType::TypeA)?);
            instances_type_b.push(make_instance(gate_count, MergeableCircuitType::TypeB)?);
        }

        // Proving succeeds on matching inputs and fails when the two lists
        // have different lengths.
        let batch_proof =
            BatchArgument::batch_prove::<_, T>(rng, &instances_type_a, &instances_type_b)?;
        assert!(
            BatchArgument::batch_prove::<_, T>(rng, &instances_type_a[1..], &instances_type_b)
                .is_err()
        );

        // Key aggregation has the same length requirement.
        let vks_type_a: Vec<&VerifyingKey<E>> =
            instances_type_a.iter().map(Instance::verify_key_ref).collect();
        let vks_type_b: Vec<&VerifyingKey<E>> =
            instances_type_b.iter().map(Instance::verify_key_ref).collect();
        let merged_vks = BatchArgument::aggregate_verify_keys(&vks_type_a, &vks_type_b)?;
        assert!(BatchArgument::aggregate_verify_keys(&vks_type_a[1..], &vks_type_b).is_err());

        // Partial verification followed by the pairing check.
        let open_key_ref = &vks_type_a[0].open_key;
        let beta_g_ref = &srs.powers_of_g[1];
        let blinding_factor = E::ScalarField::rand(rng);
        // Only the verifying keys vary between the calls below.
        let partial_verify_with = |vks: &[VerifyingKey<E>]| {
            BatchArgument::partial_verify::<T>(
                beta_g_ref,
                &open_key_ref.g,
                vks,
                &[shared_public_input],
                &batch_proof,
                blinding_factor,
            )
        };

        let (inner1, inner2) = partial_verify_with(&merged_vks)?;
        assert!(BatchArgument::decide(open_key_ref, inner1, inner2)?);

        // No verifying keys at all.
        assert!(partial_verify_with(&[]).is_err());
        // Fewer verifying keys than instances in the proof.
        assert!(partial_verify_with(&merged_vks[1..]).is_err());
        // Verifying keys with inconsistent domain sizes.
        let mut bad_merged_vks = merged_vks;
        bad_merged_vks[0].domain_size /= 2;
        assert!(partial_verify_with(&bad_merged_vks).is_err());

        Ok(())
    }
}