jf_relation/gadgets/arithmetic.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7//! Circuit implementation for arithmetic extensions
8
9use super::utils::next_multiple;
10use crate::{
11    constants::{GATE_WIDTH, N_MUL_SELECTORS},
12    gates::{
13        ConstantAdditionGate, ConstantMultiplicationGate, FifthRootGate, LinCombGate, MulAddGate,
14        QuadPolyGate,
15    },
16    Circuit, CircuitError, PlonkCircuit, Variable,
17};
18use ark_ff::PrimeField;
19use ark_std::{borrow::ToOwned, boxed::Box, string::ToString, vec::Vec};
20use num_bigint::BigUint;
21
22impl<F: PrimeField> PlonkCircuit<F> {
23    /// Arithmetic gates
24    ///
25    /// Quadratic polynomial gate: q1 * a + q2 * b + q3 * c + q4 * d + q12 * a *
26    /// b + q34 * c * d + q_c = q_o * e, where q1, q2, q3, q4, q12, q34,
27    /// q_c, q_o are selectors; a, b, c, d are input wires; e is the output
28    /// wire. Return error if variables are invalid.
29    pub fn quad_poly_gate(
30        &mut self,
31        wires: &[Variable; GATE_WIDTH + 1],
32        q_lc: &[F; GATE_WIDTH],
33        q_mul: &[F; N_MUL_SELECTORS],
34        q_o: F,
35        q_c: F,
36    ) -> Result<(), CircuitError> {
37        self.check_vars_bound(wires)?;
38
39        self.insert_gate(
40            wires,
41            Box::new(QuadPolyGate {
42                q_lc: *q_lc,
43                q_mul: *q_mul,
44                q_o,
45                q_c,
46            }),
47        )?;
48        Ok(())
49    }
50
51    /// Arithmetic gates
52    ///
53    /// Quadratic polynomial gate:
54    /// e = q1 * a + q2 * b + q3 * c + q4 * d + q12 * a *
55    /// b + q34 * c * d + q_c, where q1, q2, q3, q4, q12, q34,
56    /// q_c are selectors; a, b, c, d are input wires
57    ///
58    /// Return the variable for
59    /// Return error if variables are invalid.
60    pub fn gen_quad_poly(
61        &mut self,
62        wires: &[Variable; GATE_WIDTH],
63        q_lc: &[F; GATE_WIDTH],
64        q_mul: &[F; N_MUL_SELECTORS],
65        q_c: F,
66    ) -> Result<Variable, CircuitError> {
67        self.check_vars_bound(wires)?;
68        let output_val = q_lc[0] * self.witness(wires[0])?
69            + q_lc[1] * self.witness(wires[1])?
70            + q_lc[2] * self.witness(wires[2])?
71            + q_lc[3] * self.witness(wires[3])?
72            + q_mul[0] * self.witness(wires[0])? * self.witness(wires[1])?
73            + q_mul[1] * self.witness(wires[2])? * self.witness(wires[3])?
74            + q_c;
75        let output_var = self.create_variable(output_val)?;
76        let wires = [wires[0], wires[1], wires[2], wires[3], output_var];
77
78        self.insert_gate(
79            &wires,
80            Box::new(QuadPolyGate {
81                q_lc: *q_lc,
82                q_mul: *q_mul,
83                q_o: F::one(),
84                q_c,
85            }),
86        )?;
87
88        Ok(output_var)
89    }
90
91    /// Constrain a linear combination gate:
92    /// q1 * a + q2 * b + q3 * c + q4 * d  = y
93    pub fn lc_gate(
94        &mut self,
95        wires: &[Variable; GATE_WIDTH + 1],
96        coeffs: &[F; GATE_WIDTH],
97    ) -> Result<(), CircuitError> {
98        self.check_vars_bound(wires)?;
99
100        let wire_vars = [wires[0], wires[1], wires[2], wires[3], wires[4]];
101        self.insert_gate(&wire_vars, Box::new(LinCombGate { coeffs: *coeffs }))?;
102        Ok(())
103    }
104
105    /// Obtain a variable representing a linear combination.
106    /// Return error if variables are invalid.
107    pub fn lc(
108        &mut self,
109        wires_in: &[Variable; GATE_WIDTH],
110        coeffs: &[F; GATE_WIDTH],
111    ) -> Result<Variable, CircuitError> {
112        self.check_vars_bound(wires_in)?;
113
114        let vals_in: Vec<F> = wires_in
115            .iter()
116            .map(|&var| self.witness(var))
117            .collect::<Result<Vec<_>, CircuitError>>()?;
118
119        // calculate y as the linear combination of coeffs and vals_in
120        let y_val = vals_in
121            .iter()
122            .zip(coeffs.iter())
123            .map(|(&val, &coeff)| val * coeff)
124            .sum();
125        let y = self.create_variable(y_val)?;
126
127        let wires = [wires_in[0], wires_in[1], wires_in[2], wires_in[3], y];
128        self.lc_gate(&wires, coeffs)?;
129        Ok(y)
130    }
131
132    /// Constrain a mul-addition gate:
133    /// q_muls\[0\] * wires\[0\] *  wires\[1\] +  q_muls\[1\] * wires\[2\] *
134    /// wires\[3\] = wires\[4\]
135    pub fn mul_add_gate(
136        &mut self,
137        wires: &[Variable; GATE_WIDTH + 1],
138        q_muls: &[F; N_MUL_SELECTORS],
139    ) -> Result<(), CircuitError> {
140        self.check_vars_bound(wires)?;
141
142        let wire_vars = [wires[0], wires[1], wires[2], wires[3], wires[4]];
143        self.insert_gate(&wire_vars, Box::new(MulAddGate { coeffs: *q_muls }))?;
144        Ok(())
145    }
146
147    /// Obtain a variable representing `q12 * a * b + q34 * c * d`,
148    /// where `a, b, c, d` are input wires, and `q12`, `q34` are selectors.
149    /// Return error if variables are invalid.
150    pub fn mul_add(
151        &mut self,
152        wires_in: &[Variable; GATE_WIDTH],
153        q_muls: &[F; N_MUL_SELECTORS],
154    ) -> Result<Variable, CircuitError> {
155        self.check_vars_bound(wires_in)?;
156
157        let vals_in: Vec<F> = wires_in
158            .iter()
159            .map(|&var| self.witness(var))
160            .collect::<Result<Vec<_>, CircuitError>>()?;
161
162        // calculate y as the mul-addition of coeffs and vals_in
163        let y_val = q_muls[0] * vals_in[0] * vals_in[1] + q_muls[1] * vals_in[2] * vals_in[3];
164        let y = self.create_variable(y_val)?;
165
166        let wires = [wires_in[0], wires_in[1], wires_in[2], wires_in[3], y];
167        self.mul_add_gate(&wires, q_muls)?;
168        Ok(y)
169    }
170
171    /// Obtain a variable representing the sum of a list of variables.
172    /// Return error if variables are invalid.
173    pub fn sum(&mut self, elems: &[Variable]) -> Result<Variable, CircuitError> {
174        if elems.is_empty() {
175            return Err(CircuitError::ParameterError(
176                "Sum over an empty slice of variables is undefined".to_string(),
177            ));
178        }
179        self.check_vars_bound(elems)?;
180
181        let sum = {
182            let sum_val: F = elems
183                .iter()
184                .map(|&elem| self.witness(elem))
185                .collect::<Result<Vec<_>, CircuitError>>()?
186                .iter()
187                .sum();
188            self.create_variable(sum_val)?
189        };
190
191        // pad to ("next multiple of 3" + 1) in length
192        let mut padded: Vec<Variable> = elems.to_owned();
193        let rate = GATE_WIDTH - 1; // rate at which each lc add
194        let padded_len = next_multiple(elems.len() - 1, rate)? + 1;
195        padded.resize(padded_len, self.zero());
196
197        // z_0 = = x_0
198        // z_i = z_i-1 + x_3i-2 + x_3i-1 + x_3i
199        let coeffs = [F::one(); GATE_WIDTH];
200        let mut accum = padded[0];
201        for i in 1..padded_len / rate {
202            accum = self.lc(
203                &[
204                    accum,
205                    padded[rate * i - 2],
206                    padded[rate * i - 1],
207                    padded[rate * i],
208                ],
209                &coeffs,
210            )?;
211        }
212        // final round
213        let wires = [
214            accum,
215            padded[padded_len - 3],
216            padded[padded_len - 2],
217            padded[padded_len - 1],
218            sum,
219        ];
220        self.lc_gate(&wires, &coeffs)?;
221
222        Ok(sum)
223    }
224
225    /// Constrain variable `y` to the addition of `a` and `c`, where `c` is a
226    /// constant value Return error if the input variables are invalid.
227    pub fn add_constant_gate(
228        &mut self,
229        x: Variable,
230        c: F,
231        y: Variable,
232    ) -> Result<(), CircuitError> {
233        self.check_var_bound(x)?;
234        self.check_var_bound(y)?;
235
236        let wire_vars = &[x, self.one(), 0, 0, y];
237        self.insert_gate(wire_vars, Box::new(ConstantAdditionGate(c)))?;
238        Ok(())
239    }
240
241    /// Obtains a variable representing an addition with a constant value
242    /// Return error if the input variable is invalid
243    pub fn add_constant(
244        &mut self,
245        input_var: Variable,
246        elem: &F,
247    ) -> Result<Variable, CircuitError> {
248        self.check_var_bound(input_var)?;
249
250        let input_val = self.witness(input_var).unwrap();
251        let output_val = *elem + input_val;
252        let output_var = self.create_variable(output_val).unwrap();
253
254        self.add_constant_gate(input_var, *elem, output_var)?;
255
256        Ok(output_var)
257    }
258
259    /// Constrain variable `y` to the product of `a` and `c`, where `c` is a
260    /// constant value Return error if the input variables are invalid.
261    pub fn mul_constant_gate(
262        &mut self,
263        x: Variable,
264        c: F,
265        y: Variable,
266    ) -> Result<(), CircuitError> {
267        self.check_var_bound(x)?;
268        self.check_var_bound(y)?;
269
270        let wire_vars = &[x, 0, 0, 0, y];
271        self.insert_gate(wire_vars, Box::new(ConstantMultiplicationGate(c)))?;
272        Ok(())
273    }
274
275    /// Obtains a variable representing a multiplication with a constant value
276    /// Return error if the input variable is invalid
277    pub fn mul_constant(
278        &mut self,
279        input_var: Variable,
280        elem: &F,
281    ) -> Result<Variable, CircuitError> {
282        self.check_var_bound(input_var)?;
283
284        let input_val = self.witness(input_var).unwrap();
285        let output_val = *elem * input_val;
286        let output_var = self.create_variable(output_val).unwrap();
287
288        self.mul_constant_gate(input_var, *elem, output_var)?;
289
290        Ok(output_var)
291    }
292
293    /// Return a variable to be the 11th power of the input variable.
294    /// Cost: 3 constraints.
295    pub fn power_11_gen(&mut self, x: Variable) -> Result<Variable, CircuitError> {
296        self.check_var_bound(x)?;
297
298        // now we prove that x^11 = x_to_11
299        let x_val = self.witness(x)?;
300        let x_to_5_val = x_val.pow([5]);
301        let x_to_5 = self.create_variable(x_to_5_val)?;
302        let wire_vars = &[x, 0, 0, 0, x_to_5];
303        self.insert_gate(wire_vars, Box::new(FifthRootGate))?;
304
305        let x_to_10 = self.mul(x_to_5, x_to_5)?;
306        self.mul(x_to_10, x)
307    }
308
309    /// Constraint a variable to be the 11th power of another variable.
310    /// Cost: 3 constraints.
311    pub fn power_11_gate(&mut self, x: Variable, x_to_11: Variable) -> Result<(), CircuitError> {
312        self.check_var_bound(x)?;
313        self.check_var_bound(x_to_11)?;
314
315        // now we prove that x^11 = x_to_11
316        let x_val = self.witness(x)?;
317        let x_to_5_val = x_val.pow([5]);
318        let x_to_5 = self.create_variable(x_to_5_val)?;
319        let wire_vars = &[x, 0, 0, 0, x_to_5];
320        self.insert_gate(wire_vars, Box::new(FifthRootGate))?;
321
322        let x_to_10 = self.mul(x_to_5, x_to_5)?;
323        self.mul_gate(x_to_10, x, x_to_11)
324    }
325
326    /// Obtain the truncation of the input.
327    /// Constrain that the input and output values congruent modulo
328    /// 2^bit_length. Return error if the input is invalid.
329    pub fn truncate(&mut self, a: Variable, bit_length: usize) -> Result<Variable, CircuitError> {
330        self.check_var_bound(a)?;
331        let a_val = self.witness(a)?;
332        let a_uint: BigUint = a_val.into();
333        let modulus = F::from(2u8).pow([bit_length as u64]);
334        let modulus_uint: BigUint = modulus.into();
335        let res = F::from(a_uint % modulus_uint);
336        let b = self.create_variable(res)?;
337        self.truncate_gate(a, b, bit_length)?;
338        Ok(b)
339    }
340
341    /// Truncation gate.
342    /// Constrain that b == a modulo 2^bit_length.
343    /// Return error if the inputs are invalid; or b >= 2^bit_length.
344    pub fn truncate_gate(
345        &mut self,
346        a: Variable,
347        b: Variable,
348        bit_length: usize,
349    ) -> Result<(), CircuitError> {
350        if !self.support_lookup() {
351            return Err(CircuitError::ParameterError(
352                "does not support range table".to_string(),
353            ));
354        }
355
356        self.check_var_bound(a)?;
357        self.check_var_bound(b)?;
358
359        let a_val = self.witness(a)?;
360        let b_val = self.witness(b)?;
361        let modulus = F::from(2u8).pow([bit_length as u64]);
362        let modulus_uint: BigUint = modulus.into();
363
364        if b_val >= modulus {
365            return Err(CircuitError::ParameterError(
366                "Truncation error: b is greater than 2^bit_length".to_string(),
367            ));
368        }
369
370        let native_field_bit_length = F::MODULUS_BIT_SIZE as usize;
371        if native_field_bit_length <= bit_length {
372            return Err(CircuitError::ParameterError(
373                "Truncation error: native field is not greater than truncation target".to_string(),
374            ));
375        }
376
377        let bit_length_non_lookup_range = bit_length % self.range_bit_len()?;
378        let bit_length_lookup_component = bit_length - bit_length_non_lookup_range;
379
380        // we need to show that a and b satisfy the following
381        // relationship:
382        // (1) b = a mod modulus
383        // where
384        // * a is native_field_bit_length bits
385        // * b is bit_length bits
386        //
387        // which is
388        // (2) a = b + z * modulus
389        // for some z, where
390        // * z < 2^(native_field_bit_length - bit_length)
391        //
392        // So we set delta_length = native_field_bit_length - bit_length
393
394        let delta_length = native_field_bit_length - bit_length;
395        let delta_length_non_lookup_range = delta_length % self.range_bit_len()?;
396        let delta_length_lookup_component = delta_length - delta_length_non_lookup_range;
397
398        // Now (2) becomes
399        // (3) a = b1 + b2 * 2^bit_length_lookup_component
400        //       + modulus * (z1 + 2^delta_length_lookup_component * z2)
401        // with
402        //   b1 < 2^bit_length_lookup_component
403        //   b2 < 2^bit_length_non_lookup_range
404        //   z1 < 2^delta_length_lookup_component
405        //   z2 < 2^delta_length_non_lookup_range
406
407        // The concrete statements we need to prove becomes
408        // (4) b = b1 + b2 * 2^bit_length_lookup_component
409        // (5) a = b + modulus * z1
410        //       + modulus * 2^delta_length_lookup_component * z2
411        // (6) b1 < 2^bit_length_lookup_component
412        // (7) b2 < 2^bit_length_non_lookup_range
413        // (8) z1 < 2^delta_length_lookup_component
414        // (9) z2 < 2^delta_length_non_lookup_range
415
416        // step 1. setup the constants
417        let two_to_bit_length_lookup_component =
418            F::from(2u8).pow([bit_length_lookup_component as u64]);
419        let two_to_bit_length_lookup_component_uint: BigUint =
420            two_to_bit_length_lookup_component.into();
421
422        let two_to_delta_length_lookup_component =
423            F::from(2u8).pow([delta_length_lookup_component as u64]);
424        let two_to_delta_length_lookup_component_uint: BigUint =
425            two_to_delta_length_lookup_component.into();
426
427        let modulus_mul_two_to_delta_length_lookup_component_uint =
428            &two_to_delta_length_lookup_component_uint * &modulus_uint;
429        let modulus_mul_two_to_delta_length_lookup_component =
430            F::from(modulus_mul_two_to_delta_length_lookup_component_uint);
431
432        // step 2. get the intermediate data in the clear
433        let a_uint: BigUint = a_val.into();
434        let b_uint: BigUint = b_val.into();
435        let b1_uint = &b_uint % &two_to_bit_length_lookup_component_uint;
436        let b2_uint = &b_uint / &two_to_bit_length_lookup_component_uint;
437
438        let z_uint = (&a_uint - &b_uint) / &modulus_uint;
439        let z1_uint = &z_uint % &two_to_delta_length_lookup_component_uint;
440        let z2_uint = &z_uint / &two_to_delta_length_lookup_component_uint;
441
442        // step 3. create intermediate variables
443        let b1_var = self.create_variable(F::from(b1_uint))?;
444        let b2_var = self.create_variable(F::from(b2_uint))?;
445        let z1_var = self.create_variable(F::from(z1_uint))?;
446        let z2_var = self.create_variable(F::from(z2_uint))?;
447
448        // step 4. prove equations (4) - (9)
449        // (4) b = b1 + b2 * 2^bit_length_lookup_component
450        let wires = [b1_var, b2_var, self.zero(), self.zero(), b];
451        let coeffs = [
452            F::one(),
453            two_to_bit_length_lookup_component,
454            F::zero(),
455            F::zero(),
456        ];
457        self.lc_gate(&wires, &coeffs)?;
458
459        // (5) a = b + modulus * z1
460        //       + modulus * 2^delta_length_lookup_component * z2
461        let wires = [b, z1_var, z2_var, self.zero(), a];
462        let coeffs = [
463            F::one(),
464            modulus,
465            modulus_mul_two_to_delta_length_lookup_component,
466            F::zero(),
467        ];
468        self.lc_gate(&wires, &coeffs)?;
469
470        // (6) b1 < 2^bit_length_lookup_component
471        // note that bit_length_lookup_component is public information
472        // so we don't need to add a selection gate here
473        if bit_length_lookup_component != 0 {
474            self.range_gate_with_lookup(b1_var, bit_length_lookup_component)?;
475        }
476
477        // (7) b2 < 2^bit_length_non_lookup_range
478        // note that bit_length_non_lookup_range is public information
479        // so we don't need to add a selection gate here
480        if bit_length_non_lookup_range != 0 {
481            self.enforce_in_range(b2_var, bit_length_non_lookup_range)?;
482        }
483
484        // (8) z1 < 2^delta_length_lookup_component
485        // note that delta_length_lookup_component is public information
486        // so we don't need to add a selection gate here
487        if delta_length_lookup_component != 0 {
488            self.range_gate_with_lookup(z1_var, delta_length_lookup_component)?;
489        }
490
491        // (9) z2 < 2^delta_length_non_lookup_range
492        // note that delta_length_non_lookup_range is public information
493        // so we don't need to add a selection gate here
494        if delta_length_non_lookup_range != 0 {
495            self.enforce_in_range(z2_var, delta_length_non_lookup_range)?;
496        }
497
498        Ok(())
499    }
500}
501
502#[cfg(test)]
503mod test {
504    use crate::{
505        constants::GATE_WIDTH, gadgets::test_utils::test_variable_independence_for_circuit,
506        Circuit, CircuitError, PlonkCircuit,
507    };
508    use ark_bls12_377::Fq as Fq377;
509    use ark_ed_on_bls12_377::Fq as FqEd377;
510    use ark_ed_on_bls12_381::Fq as FqEd381;
511    use ark_ed_on_bn254::Fq as FqEd254;
512    use ark_ff::PrimeField;
513    use ark_std::{vec, vec::Vec};
514    use jf_utils::test_rng;
515    use num_bigint::BigUint;
516
517    #[test]
518    fn test_quad_poly_gate() -> Result<(), CircuitError> {
519        test_quad_poly_gate_helper::<FqEd254>()?;
520        test_quad_poly_gate_helper::<FqEd377>()?;
521        test_quad_poly_gate_helper::<FqEd381>()?;
522        test_quad_poly_gate_helper::<Fq377>()
523    }
524    fn test_quad_poly_gate_helper<F: PrimeField>() -> Result<(), CircuitError> {
525        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
526        let q_lc = [F::from(2u32), F::from(3u32), F::from(5u32), F::from(2u32)];
527        let q_mul = [F::one(), F::from(2u8)];
528        let q_o = F::one();
529        let q_c = F::from(9u8);
530        let wires_1: Vec<_> = [
531            F::from(23u32),
532            F::from(8u32),
533            F::from(1u32),
534            -F::from(20u32),
535            F::from(188u32),
536        ]
537        .iter()
538        .map(|val| circuit.create_variable(*val).unwrap())
539        .collect();
540        let wires_2: Vec<_> = [
541            F::zero(),
542            -F::from(8u32),
543            F::from(1u32),
544            F::zero(),
545            -F::from(10u32),
546        ]
547        .iter()
548        .map(|val| circuit.create_variable(*val).unwrap())
549        .collect();
550
551        // 23 * 2 + 8 * 3 + 1 * 5 + (-20) * 2 + 23 * 8 + 2 * 1 * (-20) + 9 = 188
552        let var = wires_1[0];
553        circuit.quad_poly_gate(&wires_1.try_into().unwrap(), &q_lc, &q_mul, q_o, q_c)?;
554        // 0 * 2 + (-8) * 3 + 1 * 5 + 0 * 2 + 0 * -8 + 1 * 0 + 9 = -10
555        circuit.quad_poly_gate(&wires_2.try_into().unwrap(), &q_lc, &q_mul, q_o, q_c)?;
556        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
557        *circuit.witness_mut(var) = F::from(34u32);
558        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
559        // Check variable out of bound error.
560        assert!(circuit
561            .quad_poly_gate(&[0, 1, 1, circuit.num_vars(), 0], &q_lc, &q_mul, q_o, q_c)
562            .is_err());
563
564        let circuit_1 = build_quad_poly_gate_circuit([
565            -F::from(98973u32),
566            F::from(4u32),
567            F::zero(),
568            F::from(79u32),
569            F::one(),
570        ])?;
571        let circuit_2 = build_quad_poly_gate_circuit([
572            F::one(),
573            F::zero(),
574            F::from(6u32),
575            -F::from(9u32),
576            F::one(),
577        ])?;
578        test_variable_independence_for_circuit(circuit_1, circuit_2)?;
579
580        Ok(())
581    }
582    fn build_quad_poly_gate_circuit<F: PrimeField>(
583        wires: [F; GATE_WIDTH + 1],
584    ) -> Result<PlonkCircuit<F>, CircuitError> {
585        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
586        let wires: Vec<_> = wires
587            .iter()
588            .map(|val| circuit.create_variable(*val).unwrap())
589            .collect();
590        let q_lc = [F::from(2u32), F::from(3u32), F::from(5u32), F::from(2u32)];
591        let q_mul = [F::one(), F::from(2u8)];
592        let q_o = F::one();
593        let q_c = F::from(9u8);
594        circuit.quad_poly_gate(&wires.try_into().unwrap(), &q_lc, &q_mul, q_o, q_c)?;
595        circuit.finalize_for_arithmetization()?;
596        Ok(circuit)
597    }
598
599    #[test]
600    fn test_lc() -> Result<(), CircuitError> {
601        test_lc_helper::<FqEd254>()?;
602        test_lc_helper::<FqEd377>()?;
603        test_lc_helper::<FqEd381>()?;
604        test_lc_helper::<Fq377>()
605    }
606    fn test_lc_helper<F: PrimeField>() -> Result<(), CircuitError> {
607        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
608        let wire_in_1: Vec<_> = [
609            F::from(23u32),
610            F::from(8u32),
611            F::from(1u32),
612            -F::from(20u32),
613        ]
614        .iter()
615        .map(|val| circuit.create_variable(*val).unwrap())
616        .collect();
617        let wire_in_2: Vec<_> = [F::zero(), -F::from(8u32), F::from(1u32), F::zero()]
618            .iter()
619            .map(|val| circuit.create_variable(*val).unwrap())
620            .collect();
621        let coeffs = [F::from(2u32), F::from(3u32), F::from(5u32), F::from(2u32)];
622        let y_1 = circuit.lc(&wire_in_1.try_into().unwrap(), &coeffs)?;
623        let y_2 = circuit.lc(&wire_in_2.try_into().unwrap(), &coeffs)?;
624
625        // 23 * 2 + 8 * 3 + 1 * 5 + (-20) * 2 = 35
626        assert_eq!(circuit.witness(y_1)?, F::from(35u32));
627        // 0 * 2 + (-8) * 3 + 1 * 5 + 0 * 2 = -19
628        assert_eq!(circuit.witness(y_2)?, -F::from(19u32));
629        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
630        *circuit.witness_mut(y_1) = F::from(34u32);
631        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
632        // Check variable out of bound error.
633        assert!(circuit.lc(&[0, 1, 1, circuit.num_vars()], &coeffs).is_err());
634
635        let circuit_1 =
636            build_lc_circuit([-F::from(98973u32), F::from(4u32), F::zero(), F::from(79u32)])?;
637        let circuit_2 = build_lc_circuit([F::one(), F::zero(), F::from(6u32), -F::from(9u32)])?;
638        test_variable_independence_for_circuit(circuit_1, circuit_2)?;
639
640        Ok(())
641    }
642
643    fn build_lc_circuit<F: PrimeField>(wires_in: [F; 4]) -> Result<PlonkCircuit<F>, CircuitError> {
644        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
645        let wires_in: Vec<_> = wires_in
646            .iter()
647            .map(|val| circuit.create_variable(*val).unwrap())
648            .collect();
649        let coeffs = [F::from(2u32), F::from(3u32), F::from(5u32), F::from(2u32)];
650        circuit.lc(&wires_in.try_into().unwrap(), &coeffs)?;
651        circuit.finalize_for_arithmetization()?;
652        Ok(circuit)
653    }
654
655    #[test]
656    fn test_mul_add() -> Result<(), CircuitError> {
657        test_mul_add_helper::<FqEd254>()?;
658        test_mul_add_helper::<FqEd377>()?;
659        test_mul_add_helper::<FqEd381>()?;
660        test_mul_add_helper::<Fq377>()
661    }
662
663    fn test_mul_add_helper<F: PrimeField>() -> Result<(), CircuitError> {
664        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
665        let wire_in_1: Vec<_> = [
666            F::from(23u32),
667            F::from(8u32),
668            F::from(1u32),
669            -F::from(20u32),
670        ]
671        .iter()
672        .map(|val| circuit.create_variable(*val).unwrap())
673        .collect();
674        let wire_in_2: Vec<_> = [F::one(), -F::from(8u32), F::one(), F::one()]
675            .iter()
676            .map(|val| circuit.create_variable(*val).unwrap())
677            .collect();
678        let q_muls = [F::from(3u32), F::from(5u32)];
679        let y_1 = circuit.mul_add(&wire_in_1.try_into().unwrap(), &q_muls)?;
680        let y_2 = circuit.mul_add(&wire_in_2.try_into().unwrap(), &q_muls)?;
681
682        // 3 * (23 * 8) + 5 * (1 * -20) = 452
683        assert_eq!(circuit.witness(y_1)?, F::from(452u32));
684        // 3 * (1 * -8) + 5 * (1 * 1)= -19
685        assert_eq!(circuit.witness(y_2)?, -F::from(19u32));
686        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
687        *circuit.witness_mut(y_1) = F::from(34u32);
688        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
689        // Check variable out of bound error.
690        assert!(circuit
691            .mul_add(&[0, 1, 1, circuit.num_vars()], &q_muls)
692            .is_err());
693
694        let circuit_1 =
695            build_mul_add_circuit([-F::from(98973u32), F::from(4u32), F::zero(), F::from(79u32)])?;
696        let circuit_2 =
697            build_mul_add_circuit([F::one(), F::zero(), F::from(6u32), -F::from(9u32)])?;
698        test_variable_independence_for_circuit(circuit_1, circuit_2)?;
699
700        Ok(())
701    }
702
703    fn build_mul_add_circuit<F: PrimeField>(
704        wires_in: [F; 4],
705    ) -> Result<PlonkCircuit<F>, CircuitError> {
706        let mut circuit = PlonkCircuit::new_turbo_plonk();
707        let wires_in: Vec<_> = wires_in
708            .iter()
709            .map(|val| circuit.create_variable(*val).unwrap())
710            .collect();
711        let q_muls = [F::from(3u32), F::from(5u32)];
712        circuit.mul_add(&wires_in.try_into().unwrap(), &q_muls)?;
713        circuit.finalize_for_arithmetization()?;
714        Ok(circuit)
715    }
716
717    #[test]
718    fn test_sum() -> Result<(), CircuitError> {
719        test_sum_helper::<FqEd254>()?;
720        test_sum_helper::<FqEd377>()?;
721        test_sum_helper::<FqEd381>()?;
722        test_sum_helper::<Fq377>()
723    }
724
725    fn test_sum_helper<F: PrimeField>() -> Result<(), CircuitError> {
726        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
727        let mut vars = vec![];
728        for i in 0..11 {
729            vars.push(circuit.create_variable(F::from(i as u32))?);
730        }
731
732        // sum over an empty array should be undefined behavior, thus fail
733        assert!(circuit.sum(&[]).is_err());
734
735        for until in 1..11 {
736            let expected_sum = F::from((0..until).sum::<u32>());
737            let sum = circuit.sum(&vars[..until as usize])?;
738            assert_eq!(circuit.witness(sum)?, expected_sum);
739        }
740        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
741        // if mess up the wire value, should fail
742        *circuit.witness_mut(vars[5]) = F::one();
743        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
744        // Check variable out of bound error.
745        assert!(circuit.sum(&[circuit.num_vars()]).is_err());
746
747        let circuit_1 = build_sum_circuit(vec![
748            -F::from(73u32),
749            F::from(4u32),
750            F::zero(),
751            F::from(79u32),
752            F::from(23u32),
753        ])?;
754        let circuit_2 = build_sum_circuit(vec![
755            F::one(),
756            F::zero(),
757            F::from(6u32),
758            -F::from(9u32),
759            F::one(),
760        ])?;
761        test_variable_independence_for_circuit(circuit_1, circuit_2)?;
762
763        Ok(())
764    }
765
766    fn build_sum_circuit<F: PrimeField>(vals: Vec<F>) -> Result<PlonkCircuit<F>, CircuitError> {
767        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
768        let mut vars = vec![];
769        for val in vals {
770            vars.push(circuit.create_variable(val)?);
771        }
772        circuit.sum(&vars[..])?;
773        circuit.finalize_for_arithmetization()?;
774        Ok(circuit)
775    }
776
777    #[test]
778    fn test_power_11_gen_gate() -> Result<(), CircuitError> {
779        test_power_11_gen_gate_helper::<FqEd254>()?;
780        test_power_11_gen_gate_helper::<FqEd377>()?;
781        test_power_11_gen_gate_helper::<FqEd381>()?;
782        test_power_11_gen_gate_helper::<Fq377>()
783    }
784    fn test_power_11_gen_gate_helper<F: PrimeField>() -> Result<(), CircuitError> {
785        let mut rng = test_rng();
786        let x = F::rand(&mut rng);
787        let y = F::rand(&mut rng);
788        let x11 = x.pow([11]);
789
790        // Create a satisfied circuit
791        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
792
793        let x_var = circuit.create_variable(x)?;
794        let x_to_11_var = circuit.create_variable(x11)?;
795
796        let x_to_11_var_rec = circuit.power_11_gen(x_var)?;
797        circuit.enforce_equal(x_to_11_var, x_to_11_var_rec)?;
798        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
799
800        // Create an unsatisfied circuit
801        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
802
803        let y_var = circuit.create_variable(y)?;
804        let x_to_11_var = circuit.create_variable(x11)?;
805
806        let x_to_11_var_rec = circuit.power_11_gen(y_var)?;
807        circuit.enforce_equal(x_to_11_var, x_to_11_var_rec)?;
808        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
809
810        // Create an unsatisfied circuit
811        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
812        let x_var = circuit.create_variable(x)?;
813        let y_var = circuit.create_variable(y)?;
814
815        let x_to_11_var_rec = circuit.power_11_gen(x_var)?;
816        circuit.enforce_equal(y_var, x_to_11_var_rec)?;
817        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
818
819        Ok(())
820    }
821
822    #[test]
823    fn test_power_11_gate() -> Result<(), CircuitError> {
824        test_power_11_gate_helper::<FqEd254>()?;
825        test_power_11_gate_helper::<FqEd377>()?;
826        test_power_11_gate_helper::<FqEd381>()?;
827        test_power_11_gate_helper::<Fq377>()
828    }
829    fn test_power_11_gate_helper<F: PrimeField>() -> Result<(), CircuitError> {
830        let mut rng = test_rng();
831        let x = F::rand(&mut rng);
832        let y = F::rand(&mut rng);
833        let x11 = x.pow([11]);
834
835        // Create a satisfied circuit
836        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
837        let x_var = circuit.create_variable(x)?;
838        let x_to_11_var = circuit.create_variable(x11)?;
839
840        circuit.power_11_gate(x_var, x_to_11_var)?;
841        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
842
843        // Create an unsatisfied circuit
844        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
845        let y_var = circuit.create_variable(y)?;
846        let x_to_11_var = circuit.create_variable(x11)?;
847
848        circuit.power_11_gate(y_var, x_to_11_var)?;
849        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
850
851        // Create an unsatisfied circuit
852        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
853        let x_var = circuit.create_variable(x)?;
854        let y = circuit.create_variable(y)?;
855
856        circuit.power_11_gate(x_var, y)?;
857        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
858
859        Ok(())
860    }
861
862    #[test]
863    fn test_truncation_gate() -> Result<(), CircuitError> {
864        test_truncation_gate_helper::<FqEd254>()?;
865        test_truncation_gate_helper::<FqEd377>()?;
866        test_truncation_gate_helper::<FqEd381>()?;
867        test_truncation_gate_helper::<Fq377>()
868    }
869    fn test_truncation_gate_helper<F: PrimeField>() -> Result<(), CircuitError> {
870        let mut rng = test_rng();
871        let x = F::rand(&mut rng);
872        let x_uint: BigUint = x.into();
873
874        // Create a satisfied circuit
875        for len in [80, 100, 201, 248] {
876            let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(16);
877            let x_var = circuit.create_variable(x)?;
878            let modulus = F::from(2u8).pow([len as u64]);
879            let modulus_uint: BigUint = modulus.into();
880            let y_var = circuit.truncate(x_var, len)?;
881            assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
882            let y = circuit.witness(y_var)?;
883            assert!(y < modulus);
884            assert_eq!(y, F::from(&x_uint % &modulus_uint))
885        }
886
887        // more tests
888        for minus_len in 1..=16 {
889            let len = F::MODULUS_BIT_SIZE as usize - minus_len;
890            let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(16);
891            let x_var = circuit.create_variable(x)?;
892            let modulus = F::from(2u8).pow([len as u64]);
893            let modulus_uint: BigUint = modulus.into();
894            let y_var = circuit.truncate(x_var, len)?;
895            assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
896            let y = circuit.witness(y_var)?;
897            assert!(y < modulus);
898            assert_eq!(y, F::from(&x_uint % &modulus_uint))
899        }
900
901        // Bad path: b > 2^bit_len
902        {
903            let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(16);
904            let x = F::rand(&mut rng);
905            let x_var = circuit.create_variable(x)?;
906            let y = F::rand(&mut rng);
907            let y_var = circuit.create_variable(y)?;
908
909            assert!(circuit.truncate_gate(x_var, y_var, 16).is_err());
910        }
911
912        // Bad path: b!= a % 2^bit_len
913        {
914            let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(16);
915            let x = F::rand(&mut rng);
916            let x_var = circuit.create_variable(x)?;
917            let y = F::one();
918            let y_var = circuit.create_variable(y)?;
919            circuit.truncate_gate(x_var, y_var, 192)?;
920            assert!(circuit.check_circuit_satisfiability(&[]).is_err());
921        }
922
923        // Bad path: bit_len = F::MODULUS_BIT_SIZE
924        {
925            let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(16);
926            let x = F::rand(&mut rng);
927            let x_var = circuit.create_variable(x)?;
928            let y = F::one();
929            let y_var = circuit.create_variable(y)?;
930            assert!(circuit
931                .truncate_gate(x_var, y_var, F::MODULUS_BIT_SIZE as usize)
932                .is_err());
933        }
934
935        Ok(())
936    }
937
938    #[test]
939    fn test_arithmetization() -> Result<(), CircuitError> {
940        test_arithmetization_helper::<FqEd254>()?;
941        test_arithmetization_helper::<FqEd377>()?;
942        test_arithmetization_helper::<FqEd381>()?;
943        test_arithmetization_helper::<Fq377>()
944    }
945
946    fn test_arithmetization_helper<F: PrimeField>() -> Result<(), CircuitError> {
947        // Create the circuit
948        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
949        // is_equal gate
950        let val = F::from(31415u32);
951        let a = circuit.create_variable(val)?;
952        let b = circuit.create_variable(val)?;
953        circuit.is_equal(a, b)?;
954
955        // lc gate
956        let wire_in: Vec<_> = [
957            F::from(23u32),
958            F::from(8u32),
959            F::from(1u32),
960            -F::from(20u32),
961        ]
962        .iter()
963        .map(|val| circuit.create_variable(*val).unwrap())
964        .collect();
965        let coeffs = [F::from(2u32), F::from(3u32), F::from(5u32), F::from(2u32)];
966        circuit.lc(&wire_in.try_into().unwrap(), &coeffs)?;
967
968        // conditional select gate
969        let bit_true = circuit.create_boolean_variable(true)?;
970        let x_0 = circuit.create_variable(F::from(23u32))?;
971        let x_1 = circuit.create_variable(F::from(24u32))?;
972        circuit.conditional_select(bit_true, x_0, x_1)?;
973
974        // range gate
975        let b = circuit.create_variable(F::from(1023u32))?;
976        circuit.enforce_in_range(b, 10)?;
977
978        // sum gate
979        let mut vars = vec![];
980        for i in 0..11 {
981            vars.push(circuit.create_variable(F::from(i as u32))?);
982        }
983        circuit.sum(&vars[..vars.len()])?;
984
985        // Finalize the circuit
986        circuit.finalize_for_arithmetization()?;
987        let pub_inputs = vec![];
988        crate::constraint_system::test::test_arithmetization_for_circuit(circuit, pub_inputs)?;
989        Ok(())
990    }
991}