use super::{PointVariable, TEPoint};
use crate::{Circuit, CircuitError, PlonkCircuit, Variable};
use ark_ec::{
twisted_edwards::{Projective, TECurveConfig as Config},
CurveConfig,
};
use ark_ff::{BigInteger, PrimeField};
use ark_std::{format, vec, vec::Vec};
use jf_utils::fq_to_fr;
/// In-circuit variable-base multi-scalar multiplication (MSM) over the curve `P`, whose base
/// field is the circuit field `F`.
pub trait MultiScalarMultiplicationCircuit<F, P>
where
    F: PrimeField,
    P: Config<BaseField = F>,
{
    /// Constrains and returns a point variable equal to `sum_i scalars[i] * bases[i]`.
    ///
    /// `bases` and `scalars` are paired index-wise. Each scalar is a circuit variable over `F`
    /// that carries a value of `P::ScalarField` (the `PlonkCircuit` implementation converts it
    /// with `jf_utils::fq_to_fr`).
    ///
    /// # Errors
    /// The `PlonkCircuit` implementation returns `CircuitError::ParameterError` when
    /// `bases.len() != scalars.len()`, and propagates errors from variable bound checks and
    /// gate creation.
    fn msm(
        &mut self,
        bases: &[PointVariable],
        scalars: &[Variable],
    ) -> Result<PointVariable, CircuitError>;
    /// Same as [`Self::msm`], but each scalar is decomposed using only `scalar_bit_length`
    /// bits, which allows a smaller circuit when the scalars are known to be short.
    ///
    /// NOTE(review): scalars wider than `scalar_bit_length` bits are presumably rejected by
    /// the underlying bit decomposition (error or unsatisfiable circuit) — confirm.
    ///
    /// # Errors
    /// Same as [`Self::msm`].
    fn msm_with_var_scalar_length(
        &mut self,
        bases: &[PointVariable],
        scalars: &[Variable],
        scalar_bit_length: usize,
    ) -> Result<PointVariable, CircuitError>;
}
impl<F, P> MultiScalarMultiplicationCircuit<F, P> for PlonkCircuit<F>
where
    F: PrimeField,
    P: Config<BaseField = F>,
{
    /// Full-width MSM: every scalar is decomposed using `P::ScalarField::MODULUS_BIT_SIZE`
    /// bits.
    fn msm(
        &mut self,
        bases: &[PointVariable],
        scalars: &[Variable],
    ) -> Result<PointVariable, CircuitError> {
        let scalar_bit_length = <P as CurveConfig>::ScalarField::MODULUS_BIT_SIZE as usize;
        MultiScalarMultiplicationCircuit::<F, P>::msm_with_var_scalar_length(
            self,
            bases,
            scalars,
            scalar_bit_length,
        )
    }

    /// Validates the inputs, then dispatches to the lookup-based Pippenger gadget when the
    /// circuit supports lookups, or to the naive per-term scalar-multiplication gadget
    /// otherwise.
    ///
    /// # Errors
    /// Returns `CircuitError::ParameterError` if `bases` and `scalars` differ in length or are
    /// empty; otherwise propagates errors from the selected gadget.
    fn msm_with_var_scalar_length(
        &mut self,
        bases: &[PointVariable],
        scalars: &[Variable],
        scalar_bit_length: usize,
    ) -> Result<PointVariable, CircuitError> {
        if bases.len() != scalars.len() {
            return Err(CircuitError::ParameterError(format!(
                "bases length ({}) does not match scalar length ({})",
                bases.len(),
                scalars.len()
            )));
        }
        // Both back-ends start from the first product / first window sum, so an empty MSM
        // would otherwise panic on an out-of-bounds access instead of reporting an error.
        if bases.is_empty() {
            return Err(CircuitError::ParameterError(
                "msm requires at least one base/scalar pair".into(),
            ));
        }
        if self.support_lookup() {
            msm_pippenger::<F, P>(self, bases, scalars, scalar_bit_length)
        } else {
            msm_naive::<F, P>(self, bases, scalars, scalar_bit_length)
        }
    }
}
fn msm_naive<F, P>(
circuit: &mut PlonkCircuit<F>,
bases: &[PointVariable],
scalars: &[Variable],
scalar_bit_length: usize,
) -> Result<PointVariable, CircuitError>
where
F: PrimeField,
P: Config<BaseField = F>,
{
circuit.check_vars_bound(scalars)?;
for base in bases.iter() {
circuit.check_point_var_bound(base)?;
}
let scalar_0_bits_le = circuit.unpack(scalars[0], scalar_bit_length)?;
let mut res = circuit.variable_base_binary_scalar_mul::<P>(&scalar_0_bits_le, &bases[0])?;
for (base, scalar) in bases.iter().zip(scalars.iter()).skip(1) {
let scalar_bits_le = circuit.unpack(*scalar, scalar_bit_length)?;
let tmp = circuit.variable_base_binary_scalar_mul::<P>(&scalar_bits_le, base)?;
res = circuit.ecc_add::<P>(&res, &tmp)?;
}
Ok(res)
}
fn msm_pippenger<F, P>(
circuit: &mut PlonkCircuit<F>,
bases: &[PointVariable],
scalars: &[Variable],
scalar_bit_length: usize,
) -> Result<PointVariable, CircuitError>
where
F: PrimeField,
P: Config<BaseField = F>,
{
for (&scalar, base) in scalars.iter().zip(bases.iter()) {
circuit.check_var_bound(scalar)?;
circuit.check_point_var_bound(base)?;
}
let c = if scalar_bit_length < 32 {
3
} else {
ln_without_floats(scalar_bit_length)
};
let point_zero_var = circuit.neutral_point_variable();
let mut window_sums = Vec::new();
for (base_var, &scalar_var) in bases.iter().zip(scalars.iter()) {
let decomposed_scalar_vars =
decompose_scalar_var(circuit, scalar_var, c, scalar_bit_length)?;
let mut table_point_vars = vec![point_zero_var, *base_var];
for _ in 0..((1 << c) - 2) {
let point_var = circuit.ecc_add::<P>(base_var, table_point_vars.last().unwrap())?;
table_point_vars.push(point_var);
}
let mut lookup_point_vars = Vec::new();
for &scalar_var in decomposed_scalar_vars.iter() {
let lookup_point = compute_scalar_mul_value::<F, P>(circuit, scalar_var, base_var)?;
let lookup_point_var = circuit.create_point_variable(lookup_point)?;
lookup_point_vars.push(lookup_point_var);
}
create_point_lookup_gates(
circuit,
&table_point_vars,
&decomposed_scalar_vars,
&lookup_point_vars,
)?;
if window_sums.is_empty() {
window_sums = lookup_point_vars;
} else {
for (window_sum_mut, lookup_point_var) in
window_sums.iter_mut().zip(lookup_point_vars.iter())
{
*window_sum_mut = circuit.ecc_add::<P>(window_sum_mut, lookup_point_var)?;
}
}
}
let lowest = *window_sums.first().unwrap();
let b = &window_sums[1..]
.iter()
.rev()
.fold(point_zero_var, |mut total, sum_i| {
total = circuit.ecc_add::<P>(&total, sum_i).unwrap();
for _ in 0..c {
total = circuit.ecc_add::<P>(&total, &total).unwrap();
}
total
});
circuit.ecc_add::<P>(&lowest, b)
}
/// Registers a lookup table of point coordinates and the lookups into it.
///
/// Every table row is the `(x, y)` pair of one point in `table_point_vars`. Every lookup row is
/// `(key, x, y)`, built by pairing `lookup_scalar_vars` with `lookup_point_vars` index-wise
/// (extra elements of the longer slice are ignored).
#[inline]
fn create_point_lookup_gates<F>(
    circuit: &mut PlonkCircuit<F>,
    table_point_vars: &[PointVariable],
    lookup_scalar_vars: &[Variable],
    lookup_point_vars: &[PointVariable],
) -> Result<(), CircuitError>
where
    F: PrimeField,
{
    let mut table_rows: Vec<(Variable, Variable)> = Vec::with_capacity(table_point_vars.len());
    for point in table_point_vars {
        table_rows.push((point.get_x(), point.get_y()));
    }

    let mut lookup_rows: Vec<(Variable, Variable, Variable)> = Vec::new();
    for (&key, point) in lookup_scalar_vars.iter().zip(lookup_point_vars) {
        lookup_rows.push((key, point.get_x(), point.get_y()));
    }

    circuit.create_table_and_lookup_variables(&lookup_rows, &table_rows)
}
#[inline]
fn decompose_scalar_var<F>(
circuit: &mut PlonkCircuit<F>,
scalar_var: Variable,
c: usize,
scalar_bit_length: usize,
) -> Result<Vec<Variable>, CircuitError>
where
F: PrimeField,
{
let m = (scalar_bit_length - 1) / c + 1;
let mut scalar_val = circuit.witness(scalar_var)?.into_bigint();
let decomposed_scalar_vars = (0..m)
.map(|_| {
let scalar_u64 = scalar_val.as_ref()[0] % (1 << c);
scalar_val.divn(c as u32);
circuit.create_variable(F::from(scalar_u64))
})
.collect::<Result<Vec<_>, _>>()?;
let range_size = F::from((1 << c) as u32);
circuit.decomposition_gate(decomposed_scalar_vars.clone(), scalar_var, range_size)?;
Ok(decomposed_scalar_vars)
}
/// Computes, outside the constraint system, the point `scalar * base` from the current
/// witnesses of `scalar_var` (interpreted in `P::ScalarField` via `fq_to_fr`) and `base_var`.
///
/// Only reads witnesses; no gates or variables are created.
#[inline]
fn compute_scalar_mul_value<F, P>(
    circuit: &PlonkCircuit<F>,
    scalar_var: Variable,
    base_var: &PointVariable,
) -> Result<TEPoint<F>, CircuitError>
where
    F: PrimeField,
    P: Config<BaseField = F>,
{
    let base: Projective<P> = circuit.point_witness(base_var)?.into();
    let scalar_witness = circuit.witness(scalar_var)?;
    let product = base * fq_to_fr::<F, P>(&scalar_witness);
    Ok(product.into())
}
/// Integer approximation of `ln(a)` without floating point: `ark_std::log2(a)` scaled by
/// `ln(2) ≈ 0.69` and truncated. Used to pick the Pippenger window size.
fn ln_without_floats(a: usize) -> usize {
    const LN_2_PERCENT: u32 = 69;
    let log2_a = ark_std::log2(a);
    (log2_a * LN_2_PERCENT / 100) as usize
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PlonkType;
    use ark_bls12_377::{g1::Config as Param377, Fq as Fq377};
    use ark_ec::{
        scalar_mul::variable_base::VariableBaseMSM,
        twisted_edwards::{Affine, TECurveConfig as Config},
    };
    use ark_ed_on_bls12_377::{EdwardsConfig as ParamEd377, Fq as FqEd377};
    use ark_ed_on_bls12_381::{EdwardsConfig as ParamEd381, Fq as FqEd381};
    use ark_ed_on_bn254::{EdwardsConfig as ParamEd254, Fq as FqEd254};
    use ark_ff::UniformRand;
    use jf_utils::fr_to_fq;
    // Range bit length passed to `PlonkCircuit::new_ultra_plonk`.
    const RANGE_BIT_LEN_FOR_TEST: usize = 8;
    /// Exercises the MSM gadget on several curve configurations, once per Plonk flavour.
    /// NOTE(review): TurboPlonk presumably lacks lookup support (exercising `msm_naive`) and
    /// UltraPlonk has it (exercising `msm_pippenger`) — confirm via `support_lookup`.
    #[test]
    fn test_variable_base_multi_scalar_mul() -> Result<(), CircuitError> {
        test_variable_base_multi_scalar_mul_helper::<FqEd254, ParamEd254>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd254, ParamEd254>(PlonkType::UltraPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd377, ParamEd377>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd377, ParamEd377>(PlonkType::UltraPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd381, ParamEd381>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd381, ParamEd381>(PlonkType::UltraPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<Fq377, Param377>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<Fq377, Param377>(PlonkType::UltraPlonk)?;
        Ok(())
    }
    /// For several MSM sizes, checks that the in-circuit MSM matches arkworks' native
    /// `msm_bigint`, that corrupting a witness breaks satisfiability, and that mismatched
    /// lengths and out-of-bound variables are rejected.
    fn test_variable_base_multi_scalar_mul_helper<F, P>(
        plonk_type: PlonkType,
    ) -> Result<(), CircuitError>
    where
        F: PrimeField,
        P: Config<BaseField = F>,
    {
        let mut rng = jf_utils::test_rng();
        for dim in [1, 2, 4, 8, 16, 32, 64, 128] {
            let mut circuit: PlonkCircuit<F> = match plonk_type {
                PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
                PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(RANGE_BIT_LEN_FOR_TEST),
            };
            // Random inputs and the expected result computed natively by arkworks.
            let bases: Vec<Affine<P>> = (0..dim).map(|_| Affine::<P>::rand(&mut rng)).collect();
            let scalars: Vec<P::ScalarField> =
                (0..dim).map(|_| P::ScalarField::rand(&mut rng)).collect();
            let scalar_reprs: Vec<<P::ScalarField as PrimeField>::BigInt> =
                scalars.iter().map(|x| x.into_bigint()).collect();
            let res = Projective::<P>::msm_bigint(&bases, &scalar_reprs);
            let res_point: TEPoint<F> = res.into();
            // Allocate the bases and scalars as circuit variables (scalars mapped into `F`).
            let bases_point: Vec<TEPoint<F>> = bases.iter().map(|x| (*x).into()).collect();
            let bases_vars: Vec<PointVariable> = bases_point
                .iter()
                .map(|x| circuit.create_point_variable(*x))
                .collect::<Result<Vec<_>, _>>()?;
            let scalar_vars: Vec<Variable> = scalars
                .iter()
                .map(|x| circuit.create_variable(fr_to_fq::<F, P>(x)))
                .collect::<Result<Vec<_>, _>>()?;
            let res_var = MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars,
                &scalar_vars,
            )?;
            assert_eq!(circuit.point_witness(&res_var)?, res_point);
            // Corrupt the witness of variable index 2 and expect satisfiability to fail.
            // NOTE(review): index 2 depends on the variable-creation order above — confirm it
            // is always a constrained variable.
            *circuit.witness_mut(2) = F::rand(&mut rng);
            assert!(circuit.check_circuit_satisfiability(&[]).is_err());
            // Mismatched lengths must be rejected.
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars[0..dim - 1],
                &scalar_vars
            )
            .is_err());
            // Out-of-bound base / scalar variables must be rejected (for dim > 1 the length
            // mismatch also triggers an error).
            let var_number = circuit.num_vars();
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &[PointVariable(var_number, var_number)],
                &scalar_vars
            )
            .is_err());
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars,
                &[var_number]
            )
            .is_err());
        }
        Ok(())
    }
}