// jf_relation/gadgets/ecc/msm.rs

// Copyright (c) 2022 Espresso Systems (espressosys.com)
// This file is part of the Jellyfish library.

// You should have received a copy of the MIT License
// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
//! This module implements multi-scalar-multiplication circuits.
8
9use super::{PointVariable, TEPoint};
10use crate::{Circuit, CircuitError, PlonkCircuit, Variable};
11use ark_ec::{
12    twisted_edwards::{Projective, TECurveConfig as Config},
13    CurveConfig,
14};
15use ark_ff::{BigInteger, PrimeField};
16use ark_std::{format, vec, vec::Vec};
17use jf_utils::fq_to_fr;
18
19/// Compute the multi-scalar-multiplications in circuit.
20pub trait MultiScalarMultiplicationCircuit<F, P>
21where
22    F: PrimeField,
23    P: Config<BaseField = F>,
24{
25    /// Compute the multi-scalar-multiplications.
26    /// Use pippenger when the circuit supports lookup;
27    /// Use naive method otherwise.
28    /// Return error if the number bases does not match the number of scalars.
29    fn msm(
30        &mut self,
31        bases: &[PointVariable],
32        scalars: &[Variable],
33    ) -> Result<PointVariable, CircuitError>;
34
35    /// Compute the multi-scalar-multiplications where each scalar has at most
36    /// `scalar_bit_length` bits.
37    fn msm_with_var_scalar_length(
38        &mut self,
39        bases: &[PointVariable],
40        scalars: &[Variable],
41        scalar_bit_length: usize,
42    ) -> Result<PointVariable, CircuitError>;
43}
44
45impl<F, P> MultiScalarMultiplicationCircuit<F, P> for PlonkCircuit<F>
46where
47    F: PrimeField,
48    P: Config<BaseField = F>,
49{
50    fn msm(
51        &mut self,
52        bases: &[PointVariable],
53        scalars: &[Variable],
54    ) -> Result<PointVariable, CircuitError> {
55        let scalar_bit_length = <P as CurveConfig>::ScalarField::MODULUS_BIT_SIZE as usize;
56        MultiScalarMultiplicationCircuit::<F, P>::msm_with_var_scalar_length(
57            self,
58            bases,
59            scalars,
60            scalar_bit_length,
61        )
62    }
63
64    fn msm_with_var_scalar_length(
65        &mut self,
66        bases: &[PointVariable],
67        scalars: &[Variable],
68        scalar_bit_length: usize,
69    ) -> Result<PointVariable, CircuitError> {
70        if bases.len() != scalars.len() {
71            return Err(CircuitError::ParameterError(format!(
72                "bases length ({}) does not match scalar length ({})",
73                bases.len(),
74                scalars.len()
75            )));
76        }
77
78        if self.support_lookup() {
79            msm_pippenger::<F, P>(self, bases, scalars, scalar_bit_length)
80        } else {
81            msm_naive::<F, P>(self, bases, scalars, scalar_bit_length)
82        }
83    }
84}
85
86// A naive way to implement msm by computing them individually.
87// Used for double checking the correctness; also as a fall-back solution
88// to Pippenger.
89//
90// Some typical result on BW6-761 curve is shown below (i.e. the circuit
91// simulates BLS12-377 curve operations). More results are available in the test
92// function.
93//
94// number of basis: 1
95// #variables: 1867
96// #constraints: 1865
97//
98// number of basis: 2
99// #variables: 3734
100// #constraints: 3730
101//
102// number of basis: 4
103// #variables: 7468
104// #constraints: 7460
105//
106// number of basis: 8
107// #variables: 14936
108// #constraints: 14920
109//
110// number of basis: 16
111// #variables: 29872
112// #constraints: 29840
113//
114// number of basis: 32
115// #variables: 59744
116// #constraints: 59680
117//
118// number of basis: 64
119// #variables: 119488
120// #constraints: 119360
121//
122// number of basis: 128
123// #variables: 238976
124// #constraints: 238720
125fn msm_naive<F, P>(
126    circuit: &mut PlonkCircuit<F>,
127    bases: &[PointVariable],
128    scalars: &[Variable],
129    scalar_bit_length: usize,
130) -> Result<PointVariable, CircuitError>
131where
132    F: PrimeField,
133    P: Config<BaseField = F>,
134{
135    circuit.check_vars_bound(scalars)?;
136    for base in bases.iter() {
137        circuit.check_point_var_bound(base)?;
138    }
139
140    let scalar_0_bits_le = circuit.unpack(scalars[0], scalar_bit_length)?;
141    let mut res = circuit.variable_base_binary_scalar_mul::<P>(&scalar_0_bits_le, &bases[0])?;
142
143    for (base, scalar) in bases.iter().zip(scalars.iter()).skip(1) {
144        let scalar_bits_le = circuit.unpack(*scalar, scalar_bit_length)?;
145        let tmp = circuit.variable_base_binary_scalar_mul::<P>(&scalar_bits_le, base)?;
146        res = circuit.ecc_add::<P>(&res, &tmp)?;
147    }
148
149    Ok(res)
150}
151
152// A variant of Pippenger MSM.
153//
154// Some typical result on BW6-761 curve is shown below (i.e. the circuit
155// simulates BLS12-377 curve operations). More results are available in the test
156// function.
157//
158// number of basis: 1
159// #variables: 887
160// #constraints: 783
161//
162// number of basis: 2
163// #variables: 1272
164// #constraints: 1064
165//
166// number of basis: 4
167// #variables: 2042
168// #constraints: 1626
169//
170// number of basis: 8
171// #variables: 3582
172// #constraints: 2750
173//
174// number of basis: 16
175// #variables: 6662
176// #constraints: 4998
177//
178// number of basis: 32
179// #variables: 12822
180// #constraints: 9494
181//
182// number of basis: 64
183// #variables: 25142
184// #constraints: 18486
185//
186// number of basis: 128
187// #variables: 49782
188// #constraints: 36470
189fn msm_pippenger<F, P>(
190    circuit: &mut PlonkCircuit<F>,
191    bases: &[PointVariable],
192    scalars: &[Variable],
193    scalar_bit_length: usize,
194) -> Result<PointVariable, CircuitError>
195where
196    F: PrimeField,
197    P: Config<BaseField = F>,
198{
199    // ================================================
200    // check inputs
201    // ================================================
202    for (&scalar, base) in scalars.iter().zip(bases.iter()) {
203        circuit.check_var_bound(scalar)?;
204        circuit.check_point_var_bound(base)?;
205    }
206
207    // ================================================
208    // set up parameters
209    // ================================================
210    let c = if scalar_bit_length < 32 {
211        3
212    } else {
213        ln_without_floats(scalar_bit_length)
214    };
215
216    // ================================================
217    // compute lookup tables and window sums
218    // ================================================
219    let point_zero_var = circuit.neutral_point_variable();
220    // Each window is of size `c`.
221    // We divide up the bits 0..scalar_bit_length into windows of size `c`, and
222    // in parallel process each such window.
223    let mut window_sums = Vec::new();
224    for (base_var, &scalar_var) in bases.iter().zip(scalars.iter()) {
225        // decompose scalar into c-bit scalars
226        let decomposed_scalar_vars =
227            decompose_scalar_var(circuit, scalar_var, c, scalar_bit_length)?;
228
229        // create point table [0 * base, 1 * base, ..., (2^c-1) * base]
230        let mut table_point_vars = vec![point_zero_var, *base_var];
231        for _ in 0..((1 << c) - 2) {
232            let point_var = circuit.ecc_add::<P>(base_var, table_point_vars.last().unwrap())?;
233            table_point_vars.push(point_var);
234        }
235
236        // create lookup point variables
237        let mut lookup_point_vars = Vec::new();
238        for &scalar_var in decomposed_scalar_vars.iter() {
239            let lookup_point = compute_scalar_mul_value::<F, P>(circuit, scalar_var, base_var)?;
240            let lookup_point_var = circuit.create_point_variable(lookup_point)?;
241            lookup_point_vars.push(lookup_point_var);
242        }
243
244        create_point_lookup_gates(
245            circuit,
246            &table_point_vars,
247            &decomposed_scalar_vars,
248            &lookup_point_vars,
249        )?;
250
251        // update window sums
252        if window_sums.is_empty() {
253            window_sums = lookup_point_vars;
254        } else {
255            for (window_sum_mut, lookup_point_var) in
256                window_sums.iter_mut().zip(lookup_point_vars.iter())
257            {
258                *window_sum_mut = circuit.ecc_add::<P>(window_sum_mut, lookup_point_var)?;
259            }
260        }
261    }
262
263    // ================================================
264    // performing additions
265    // ================================================
266    // We store the sum for the lowest window.
267    let lowest = *window_sums.first().unwrap();
268
269    // We're traversing windows from high to low.
270    let b = &window_sums[1..]
271        .iter()
272        .rev()
273        .fold(point_zero_var, |mut total, sum_i| {
274            // total += sum_i
275            total = circuit.ecc_add::<P>(&total, sum_i).unwrap();
276            for _ in 0..c {
277                // double
278                total = circuit.ecc_add::<P>(&total, &total).unwrap();
279            }
280            total
281        });
282    circuit.ecc_add::<P>(&lowest, b)
283}
284
285#[inline]
286fn create_point_lookup_gates<F>(
287    circuit: &mut PlonkCircuit<F>,
288    table_point_vars: &[PointVariable],
289    lookup_scalar_vars: &[Variable],
290    lookup_point_vars: &[PointVariable],
291) -> Result<(), CircuitError>
292where
293    F: PrimeField,
294{
295    let table_vars: Vec<(Variable, Variable)> = table_point_vars
296        .iter()
297        .map(|p| (p.get_x(), p.get_y()))
298        .collect();
299    let lookup_vars: Vec<(Variable, Variable, Variable)> = lookup_scalar_vars
300        .iter()
301        .zip(lookup_point_vars.iter())
302        .map(|(&s, pt)| (s, pt.get_x(), pt.get_y()))
303        .collect();
304    circuit.create_table_and_lookup_variables(&lookup_vars, &table_vars)
305}
306
307#[inline]
308/// Decompose a `scalar_bit_length`-bit scalar `s` into many c-bit scalar
309/// variables `{s0, ..., s_m}` such that `s = \sum_{j=0..m} 2^{cj} * s_j`
310fn decompose_scalar_var<F>(
311    circuit: &mut PlonkCircuit<F>,
312    scalar_var: Variable,
313    c: usize,
314    scalar_bit_length: usize,
315) -> Result<Vec<Variable>, CircuitError>
316where
317    F: PrimeField,
318{
319    // create witness
320    let m = (scalar_bit_length - 1) / c + 1;
321    let mut scalar_val = circuit.witness(scalar_var)?.into_bigint();
322    let decomposed_scalar_vars = (0..m)
323        .map(|_| {
324            // We mod the remaining bits by 2^{window size}, thus taking `c` bits.
325            let scalar_u64 = scalar_val.as_ref()[0] % (1 << c);
326            // We right-shift by c bits, thus getting rid of the
327            // lower bits.
328            scalar_val.divn(c as u32);
329            circuit.create_variable(F::from(scalar_u64))
330        })
331        .collect::<Result<Vec<_>, _>>()?;
332
333    // create circuit
334    let range_size = F::from((1 << c) as u32);
335    circuit.decomposition_gate(decomposed_scalar_vars.clone(), scalar_var, range_size)?;
336
337    Ok(decomposed_scalar_vars)
338}
339
340#[inline]
341/// Compute the value of scalar multiplication `witness(scalar_var) *
342/// witness(base_var)`. This function does not add any constraints.
343fn compute_scalar_mul_value<F, P>(
344    circuit: &PlonkCircuit<F>,
345    scalar_var: Variable,
346    base_var: &PointVariable,
347) -> Result<TEPoint<F>, CircuitError>
348where
349    F: PrimeField,
350    P: Config<BaseField = F>,
351{
352    let curve_point: Projective<P> = circuit.point_witness(base_var)?.into();
353    let scalar = fq_to_fr::<F, P>(&circuit.witness(scalar_var)?);
354    let res = curve_point * scalar;
355    Ok(res.into())
356}
357
358/// The result of this function is only approximately `ln(a)`
359/// [`Explanation of usage`]
360///
361/// [`Explanation of usage`]: https://github.com/scipr-lab/zexe/issues/79#issue-556220473
362fn ln_without_floats(a: usize) -> usize {
363    // log2(a) * ln(2)
364    (ark_std::log2(a) * 69 / 100) as usize
365}
366
#[cfg(test)]
mod tests {

    use super::*;
    use crate::PlonkType;
    use ark_bls12_377::{g1::Config as Param377, Fq as Fq377};
    use ark_ec::{
        scalar_mul::variable_base::VariableBaseMSM,
        twisted_edwards::{Affine, TECurveConfig as Config},
    };
    use ark_ed_on_bls12_377::{EdwardsConfig as ParamEd377, Fq as FqEd377};
    use ark_ed_on_bls12_381::{EdwardsConfig as ParamEd381, Fq as FqEd381};
    use ark_ed_on_bn254::{EdwardsConfig as ParamEd254, Fq as FqEd254};
    use ark_ff::UniformRand;
    use jf_utils::fr_to_fq;

    /// Range bit length passed to `PlonkCircuit::new_ultra_plonk` in these tests.
    const RANGE_BIT_LEN_FOR_TEST: usize = 8;

    /// MSM sizes (number of base/scalar pairs) exercised for every curve.
    const MSM_DIMS: [usize; 8] = [1, 2, 4, 8, 16, 32, 64, 128];

    #[test]
    fn test_variable_base_multi_scalar_mul() -> Result<(), CircuitError> {
        // Every curve is checked with a TurboPlonk circuit and then with an
        // UltraPlonk circuit.
        for plonk_type in [PlonkType::TurboPlonk, PlonkType::UltraPlonk] {
            test_variable_base_multi_scalar_mul_helper::<FqEd254, ParamEd254>(plonk_type)?;
        }
        for plonk_type in [PlonkType::TurboPlonk, PlonkType::UltraPlonk] {
            test_variable_base_multi_scalar_mul_helper::<FqEd377, ParamEd377>(plonk_type)?;
        }
        for plonk_type in [PlonkType::TurboPlonk, PlonkType::UltraPlonk] {
            test_variable_base_multi_scalar_mul_helper::<FqEd381, ParamEd381>(plonk_type)?;
        }
        for plonk_type in [PlonkType::TurboPlonk, PlonkType::UltraPlonk] {
            test_variable_base_multi_scalar_mul_helper::<Fq377, Param377>(plonk_type)?;
        }

        // // uncomment the following code to dump the circuit comparison to screen
        // assert!(false);

        Ok(())
    }

    fn test_variable_base_multi_scalar_mul_helper<F, P>(
        plonk_type: PlonkType,
    ) -> Result<(), CircuitError>
    where
        F: PrimeField,
        P: Config<BaseField = F>,
    {
        let mut rng = jf_utils::test_rng();

        for dim in MSM_DIMS {
            let mut circuit: PlonkCircuit<F> = match plonk_type {
                PlonkType::TurboPlonk => PlonkCircuit::new_turbo_plonk(),
                PlonkType::UltraPlonk => PlonkCircuit::new_ultra_plonk(RANGE_BIT_LEN_FOR_TEST),
            };

            // Random inputs, and the expected MSM computed natively.
            let bases: Vec<Affine<P>> = (0..dim).map(|_| Affine::<P>::rand(&mut rng)).collect();
            let scalars: Vec<P::ScalarField> =
                (0..dim).map(|_| P::ScalarField::rand(&mut rng)).collect();
            let scalar_reprs: Vec<<P::ScalarField as PrimeField>::BigInt> =
                scalars.iter().map(|s| s.into_bigint()).collect();
            let expected: TEPoint<F> = Projective::<P>::msm_bigint(&bases, &scalar_reprs).into();

            // Circuit variables for the bases and scalars.
            let mut bases_vars: Vec<PointVariable> = Vec::with_capacity(dim);
            for base in &bases {
                let base_point: TEPoint<F> = (*base).into();
                bases_vars.push(circuit.create_point_variable(base_point)?);
            }
            let mut scalar_vars: Vec<Variable> = Vec::with_capacity(dim);
            for scalar in &scalars {
                scalar_vars.push(circuit.create_variable(fr_to_fq::<F, P>(scalar))?);
            }

            // The in-circuit MSM must agree with the native one.
            let res_var = MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars,
                &scalar_vars,
            )?;
            assert_eq!(circuit.point_witness(&res_var)?, expected);

            // // uncomment the following code to dump the circuit comparison to screen
            // ark_std::println!("number of basis: {}", dim);
            // ark_std::println!("#variables: {}", circuit.num_vars(),);
            // ark_std::println!("#constraints: {}\n", circuit.num_gates(),);

            // Corrupting a witness must break satisfiability.
            *circuit.witness_mut(2) = F::rand(&mut rng);
            assert!(circuit.check_circuit_satisfiability(&[]).is_err());

            // Mismatched numbers of bases and scalars are rejected.
            let too_few_bases = &bases_vars[..dim - 1];
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                too_few_bases,
                &scalar_vars
            )
            .is_err());

            // Variables that were never allocated are rejected.
            let unallocated = circuit.num_vars();
            let bad_bases = [PointVariable(unallocated, unallocated)];
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bad_bases,
                &scalar_vars
            )
            .is_err());
            assert!(MultiScalarMultiplicationCircuit::<F, P>::msm(
                &mut circuit,
                &bases_vars,
                &[unallocated]
            )
            .is_err());
        }
        Ok(())
    }
}