1use super::{PointVariable, TEPoint};
10use crate::{Circuit, CircuitError, PlonkCircuit, Variable};
11use ark_ec::{
12 twisted_edwards::{Projective, TECurveConfig as Config},
13 CurveConfig,
14};
15use ark_ff::{BigInteger, PrimeField};
16use ark_std::{format, vec, vec::Vec};
17use jf_utils::fq_to_fr;
18
19pub trait MultiScalarMultiplicationCircuit<F, P>
21where
22 F: PrimeField,
23 P: Config<BaseField = F>,
24{
25 fn msm(
30 &mut self,
31 bases: &[PointVariable],
32 scalars: &[Variable],
33 ) -> Result<PointVariable, CircuitError>;
34
35 fn msm_with_var_scalar_length(
38 &mut self,
39 bases: &[PointVariable],
40 scalars: &[Variable],
41 scalar_bit_length: usize,
42 ) -> Result<PointVariable, CircuitError>;
43}
44
45impl<F, P> MultiScalarMultiplicationCircuit<F, P> for PlonkCircuit<F>
46where
47 F: PrimeField,
48 P: Config<BaseField = F>,
49{
50 fn msm(
51 &mut self,
52 bases: &[PointVariable],
53 scalars: &[Variable],
54 ) -> Result<PointVariable, CircuitError> {
55 let scalar_bit_length = <P as CurveConfig>::ScalarField::MODULUS_BIT_SIZE as usize;
56 MultiScalarMultiplicationCircuit::<F, P>::msm_with_var_scalar_length(
57 self,
58 bases,
59 scalars,
60 scalar_bit_length,
61 )
62 }
63
64 fn msm_with_var_scalar_length(
65 &mut self,
66 bases: &[PointVariable],
67 scalars: &[Variable],
68 scalar_bit_length: usize,
69 ) -> Result<PointVariable, CircuitError> {
70 if bases.len() != scalars.len() {
71 return Err(CircuitError::ParameterError(format!(
72 "bases length ({}) does not match scalar length ({})",
73 bases.len(),
74 scalars.len()
75 )));
76 }
77
78 if self.support_lookup() {
79 msm_pippenger::<F, P>(self, bases, scalars, scalar_bit_length)
80 } else {
81 msm_naive::<F, P>(self, bases, scalars, scalar_bit_length)
82 }
83 }
84}
85
86fn msm_naive<F, P>(
126 circuit: &mut PlonkCircuit<F>,
127 bases: &[PointVariable],
128 scalars: &[Variable],
129 scalar_bit_length: usize,
130) -> Result<PointVariable, CircuitError>
131where
132 F: PrimeField,
133 P: Config<BaseField = F>,
134{
135 circuit.check_vars_bound(scalars)?;
136 for base in bases.iter() {
137 circuit.check_point_var_bound(base)?;
138 }
139
140 let scalar_0_bits_le = circuit.unpack(scalars[0], scalar_bit_length)?;
141 let mut res = circuit.variable_base_binary_scalar_mul::<P>(&scalar_0_bits_le, &bases[0])?;
142
143 for (base, scalar) in bases.iter().zip(scalars.iter()).skip(1) {
144 let scalar_bits_le = circuit.unpack(*scalar, scalar_bit_length)?;
145 let tmp = circuit.variable_base_binary_scalar_mul::<P>(&scalar_bits_le, base)?;
146 res = circuit.ecc_add::<P>(&res, &tmp)?;
147 }
148
149 Ok(res)
150}
151
152fn msm_pippenger<F, P>(
190 circuit: &mut PlonkCircuit<F>,
191 bases: &[PointVariable],
192 scalars: &[Variable],
193 scalar_bit_length: usize,
194) -> Result<PointVariable, CircuitError>
195where
196 F: PrimeField,
197 P: Config<BaseField = F>,
198{
199 for (&scalar, base) in scalars.iter().zip(bases.iter()) {
203 circuit.check_var_bound(scalar)?;
204 circuit.check_point_var_bound(base)?;
205 }
206
207 let c = if scalar_bit_length < 32 {
211 3
212 } else {
213 ln_without_floats(scalar_bit_length)
214 };
215
216 let point_zero_var = circuit.neutral_point_variable();
220 let mut window_sums = Vec::new();
224 for (base_var, &scalar_var) in bases.iter().zip(scalars.iter()) {
225 let decomposed_scalar_vars =
227 decompose_scalar_var(circuit, scalar_var, c, scalar_bit_length)?;
228
229 let mut table_point_vars = vec![point_zero_var, *base_var];
231 for _ in 0..((1 << c) - 2) {
232 let point_var = circuit.ecc_add::<P>(base_var, table_point_vars.last().unwrap())?;
233 table_point_vars.push(point_var);
234 }
235
236 let mut lookup_point_vars = Vec::new();
238 for &scalar_var in decomposed_scalar_vars.iter() {
239 let lookup_point = compute_scalar_mul_value::<F, P>(circuit, scalar_var, base_var)?;
240 let lookup_point_var = circuit.create_point_variable(lookup_point)?;
241 lookup_point_vars.push(lookup_point_var);
242 }
243
244 create_point_lookup_gates(
245 circuit,
246 &table_point_vars,
247 &decomposed_scalar_vars,
248 &lookup_point_vars,
249 )?;
250
251 if window_sums.is_empty() {
253 window_sums = lookup_point_vars;
254 } else {
255 for (window_sum_mut, lookup_point_var) in
256 window_sums.iter_mut().zip(lookup_point_vars.iter())
257 {
258 *window_sum_mut = circuit.ecc_add::<P>(window_sum_mut, lookup_point_var)?;
259 }
260 }
261 }
262
263 let lowest = *window_sums.first().unwrap();
268
269 let b = &window_sums[1..]
271 .iter()
272 .rev()
273 .fold(point_zero_var, |mut total, sum_i| {
274 total = circuit.ecc_add::<P>(&total, sum_i).unwrap();
276 for _ in 0..c {
277 total = circuit.ecc_add::<P>(&total, &total).unwrap();
279 }
280 total
281 });
282 circuit.ecc_add::<P>(&lowest, b)
283}
284
285#[inline]
286fn create_point_lookup_gates<F>(
287 circuit: &mut PlonkCircuit<F>,
288 table_point_vars: &[PointVariable],
289 lookup_scalar_vars: &[Variable],
290 lookup_point_vars: &[PointVariable],
291) -> Result<(), CircuitError>
292where
293 F: PrimeField,
294{
295 let table_vars: Vec<(Variable, Variable)> = table_point_vars
296 .iter()
297 .map(|p| (p.get_x(), p.get_y()))
298 .collect();
299 let lookup_vars: Vec<(Variable, Variable, Variable)> = lookup_scalar_vars
300 .iter()
301 .zip(lookup_point_vars.iter())
302 .map(|(&s, pt)| (s, pt.get_x(), pt.get_y()))
303 .collect();
304 circuit.create_table_and_lookup_variables(&lookup_vars, &table_vars)
305}
306
307#[inline]
308fn decompose_scalar_var<F>(
311 circuit: &mut PlonkCircuit<F>,
312 scalar_var: Variable,
313 c: usize,
314 scalar_bit_length: usize,
315) -> Result<Vec<Variable>, CircuitError>
316where
317 F: PrimeField,
318{
319 let m = (scalar_bit_length - 1) / c + 1;
321 let mut scalar_val = circuit.witness(scalar_var)?.into_bigint();
322 let decomposed_scalar_vars = (0..m)
323 .map(|_| {
324 let scalar_u64 = scalar_val.as_ref()[0] % (1 << c);
326 scalar_val.divn(c as u32);
329 circuit.create_variable(F::from(scalar_u64))
330 })
331 .collect::<Result<Vec<_>, _>>()?;
332
333 let range_size = F::from((1 << c) as u32);
335 circuit.decomposition_gate(decomposed_scalar_vars.clone(), scalar_var, range_size)?;
336
337 Ok(decomposed_scalar_vars)
338}
339
340#[inline]
341fn compute_scalar_mul_value<F, P>(
344 circuit: &PlonkCircuit<F>,
345 scalar_var: Variable,
346 base_var: &PointVariable,
347) -> Result<TEPoint<F>, CircuitError>
348where
349 F: PrimeField,
350 P: Config<BaseField = F>,
351{
352 let curve_point: Projective<P> = circuit.point_witness(base_var)?.into();
353 let scalar = fq_to_fr::<F, P>(&circuit.witness(scalar_var)?);
354 let res = curve_point * scalar;
355 Ok(res.into())
356}
357
358fn ln_without_floats(a: usize) -> usize {
363 (ark_std::log2(a) * 69 / 100) as usize
365}
366
#[cfg(test)]
mod tests {

    use super::*;
    use crate::PlonkType;
    use ark_bls12_377::{g1::Config as Param377, Fq as Fq377};
    use ark_ec::{
        scalar_mul::variable_base::VariableBaseMSM,
        twisted_edwards::{Affine, TECurveConfig as Config},
    };
    use ark_ed_on_bls12_377::{EdwardsConfig as ParamEd377, Fq as FqEd377};
    use ark_ed_on_bls12_381::{EdwardsConfig as ParamEd381, Fq as FqEd381};
    use ark_ed_on_bn254::{EdwardsConfig as ParamEd254, Fq as FqEd254};
    use ark_ff::UniformRand;
    use jf_utils::fr_to_fq;

    const RANGE_BIT_LEN_FOR_TEST: usize = 8;

    /// Run the MSM checks on every supported curve, with and without lookup
    /// support.
    #[test]
    fn test_variable_base_multi_scalar_mul() -> Result<(), CircuitError> {
        test_variable_base_multi_scalar_mul_helper::<FqEd254, ParamEd254>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd254, ParamEd254>(PlonkType::UltraPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd377, ParamEd377>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd377, ParamEd377>(PlonkType::UltraPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd381, ParamEd381>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<FqEd381, ParamEd381>(PlonkType::UltraPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<Fq377, Param377>(PlonkType::TurboPlonk)?;
        test_variable_base_multi_scalar_mul_helper::<Fq377, Param377>(PlonkType::UltraPlonk)?;

        Ok(())
    }

    /// For several input sizes, compare the in-circuit MSM against the
    /// out-of-circuit result and exercise the error paths.
    fn test_variable_base_multi_scalar_mul_helper<F, P>(
        plonk_type: PlonkType,
    ) -> Result<(), CircuitError>
    where
        F: PrimeField,
        P: Config<BaseField = F>,
    {
        let mut rng = jf_utils::test_rng();

        for dim in [1, 2, 4, 8, 16, 32, 64, 128] {
            let mut circuit: PlonkCircuit<F> = match plonk_type {
                PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
                PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(RANGE_BIT_LEN_FOR_TEST),
            };

            // Reference result computed outside the circuit.
            let bases: Vec<Affine<P>> = (0..dim).map(|_| Affine::<P>::rand(&mut rng)).collect();
            let scalars: Vec<P::ScalarField> =
                (0..dim).map(|_| P::ScalarField::rand(&mut rng)).collect();
            let scalar_reprs: Vec<<P::ScalarField as PrimeField>::BigInt> =
                scalars.iter().map(|x| x.into_bigint()).collect();
            let res = Projective::<P>::msm_bigint(&bases, &scalar_reprs);
            let res_point: TEPoint<F> = res.into();

            // Same inputs as circuit variables.
            let bases_point: Vec<TEPoint<F>> = bases.iter().map(|x| (*x).into()).collect();
            let bases_vars: Vec<PointVariable> = bases_point
                .iter()
                .map(|x| circuit.create_point_variable(*x))
                .collect::<Result<Vec<_>, _>>()?;
            let scalar_vars: Vec<Variable> = scalars
                .iter()
                .map(|x| circuit.create_variable(fr_to_fq::<F, P>(x)))
                .collect::<Result<Vec<_>, _>>()?;

            let res_var = MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars,
                &scalar_vars,
            )?;

            assert_eq!(circuit.point_witness(&res_var)?, res_point);

            // The honest witness must satisfy the circuit; otherwise the
            // tampering check below would pass vacuously.
            assert!(circuit.check_circuit_satisfiability(&[]).is_ok());

            // A tampered witness must be rejected.
            *circuit.witness_mut(2) = F::rand(&mut rng);
            assert!(circuit.check_circuit_satisfiability(&[]).is_err());

            // Mismatching numbers of bases and scalars.
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars[0..dim - 1],
                &scalar_vars
            )
            .is_err());

            // Out-of-bound variables. Lengths are kept equal so that the
            // failure comes from the bound check and not from the length
            // mismatch check.
            let var_number = circuit.num_vars();
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &[PointVariable(var_number, var_number)],
                &scalar_vars[..1]
            )
            .is_err());
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars[..1],
                &[var_number]
            )
            .is_err());
        }
        Ok(())
    }
}