jf_relation/gadgets/
cmp.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
//! Comparison gadgets for circuits.
8
9use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit, Variable};
10use ark_ff::{BigInteger, PrimeField};
11
12impl<F: PrimeField> PlonkCircuit<F> {
13    /// Constrain that `a` < `b`.
14    pub fn enforce_lt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
15    where
16        F: PrimeField,
17    {
18        self.check_var_bound(a)?;
19        self.check_var_bound(b)?;
20        self.enforce_lt_internal(a, b)
21    }
22
23    /// Constrain that `a` <= `b`
24    pub fn enforce_leq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
25    where
26        F: PrimeField,
27    {
28        let c = self.is_lt(b, a)?;
29        self.enforce_constant(c.0, F::zero())
30    }
31
32    /// Constrain that `a` > `b`.
33    pub fn enforce_gt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
34    where
35        F: PrimeField,
36    {
37        self.enforce_lt(b, a)
38    }
39
40    /// Constrain that `a` >= `b`.
41    pub fn enforce_geq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
42    where
43        F: PrimeField,
44    {
45        let c = self.is_lt(a, b)?;
46        self.enforce_constant(c.into(), F::zero())
47    }
48
49    /// Returns a `BoolVar` indicating whether `a` < `b`.
50    pub fn is_lt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
51    where
52        F: PrimeField,
53    {
54        self.check_var_bound(a)?;
55        self.check_var_bound(b)?;
56        self.is_lt_internal(a, b)
57    }
58
59    /// Returns a `BoolVar` indicating whether `a` > `b`.
60    pub fn is_gt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
61    where
62        F: PrimeField,
63    {
64        self.is_lt(b, a)
65    }
66
67    /// Returns a `BoolVar` indicating whether `a` <= `b`.
68    pub fn is_leq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
69    where
70        F: PrimeField,
71    {
72        self.check_var_bound(a)?;
73        self.check_var_bound(b)?;
74        let c = self.is_lt_internal(b, a)?;
75        self.logic_neg(c)
76    }
77
78    /// Returns a `BoolVar` indicating whether `a` >= `b`.
79    pub fn is_geq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
80    where
81        F: PrimeField,
82    {
83        self.check_var_bound(a)?;
84        self.check_var_bound(b)?;
85        let c = self.is_lt_internal(a, b)?;
86        self.logic_neg(c)
87    }
88
89    /// Returns a `BoolVar` indicating whether the variable `a` is less than a
90    /// given constant `val`.
91    pub fn is_lt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
92    where
93        F: PrimeField,
94    {
95        self.check_var_bound(a)?;
96        let b = self.create_constant_variable(val)?;
97        self.is_lt(a, b)
98    }
99
100    /// Returns a `BoolVar` indicating whether the variable `a` is less than or
101    /// equal to a given constant `val`.
102    pub fn is_leq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
103    where
104        F: PrimeField,
105    {
106        self.check_var_bound(a)?;
107        let b = self.create_constant_variable(val)?;
108        self.is_leq(a, b)
109    }
110
111    /// Returns a `BoolVar` indicating whether the variable `a` is greater than
112    /// a given constant `val`.
113    pub fn is_gt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
114    where
115        F: PrimeField,
116    {
117        self.check_var_bound(a)?;
118        self.is_gt_constant_internal(a, &val)
119    }
120
121    /// Returns a `BoolVar` indicating whether the variable `a` is greater than
122    /// or equal a given constant `val`.
123    pub fn is_geq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
124    where
125        F: PrimeField,
126    {
127        self.check_var_bound(a)?;
128        let b = self.create_constant_variable(val)?;
129        self.is_geq(a, b)
130    }
131
132    /// Enforce the variable `a` to be less than a
133    /// given constant `val`.
134    pub fn enforce_lt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
135    where
136        F: PrimeField,
137    {
138        self.check_var_bound(a)?;
139        let b = self.create_constant_variable(val)?;
140        self.enforce_lt(a, b)
141    }
142
143    /// Enforce the variable `a` to be less than or
144    /// equal to a given constant `val`.
145    pub fn enforce_leq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
146    where
147        F: PrimeField,
148    {
149        self.check_var_bound(a)?;
150        let b = self.create_constant_variable(val)?;
151        self.enforce_leq(a, b)
152    }
153
154    /// Enforce the variable `a` to be greater than
155    /// a given constant `val`.
156    pub fn enforce_gt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
157    where
158        F: PrimeField,
159    {
160        self.check_var_bound(a)?;
161        let b = self.create_constant_variable(val)?;
162        self.enforce_gt(a, b)
163    }
164
165    /// Enforce the variable `a` to be greater than
166    /// or equal a given constant `val`.
167    pub fn enforce_geq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
168    where
169        F: PrimeField,
170    {
171        self.check_var_bound(a)?;
172        let b = self.create_constant_variable(val)?;
173        self.enforce_geq(a, b)
174    }
175}
176
177/// Private helper functions for comparison gate
178impl<F: PrimeField> PlonkCircuit<F> {
179    /// Returns 2 `BoolVar`s.
180    /// First indicates whether `a` <= (q-1)/2 and `b` > (q-1)/2.
181    /// Second indicates whether `a` and `b` are both <= (q-1)/2
182    /// or both > (q-1)/2.
183    fn msb_check_internal(
184        &mut self,
185        a: Variable,
186        b: Variable,
187    ) -> Result<(BoolVar, BoolVar), CircuitError> {
188        let a_gt_const = self.is_gt_constant_internal(a, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
189        let b_gt_const = self.is_gt_constant_internal(b, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
190        let a_leq_const = self.logic_neg(a_gt_const)?;
191        // Check whether `a` <= (q-1)/2 and `b` > (q-1)/2
192        let msb_check = self.logic_and(a_leq_const, b_gt_const)?;
193        // Check whether `a` and `b` are both <= (q-1)/2 or
194        // are both > (q-1)/2
195        let msb_eq = self.is_equal(a_gt_const.into(), b_gt_const.into())?;
196        Ok((msb_check, msb_eq))
197    }
198
199    /// Return a variable indicating whether `a` < `b`.
200    fn is_lt_internal(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError> {
201        let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
202        // check whether (a-b) > (q-1)/2
203        let c = self.sub(a, b)?;
204        let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
205        let cmp_result = self.logic_and(msb_eq, cmp_result)?;
206
207        self.logic_or(msb_check, cmp_result)
208    }
209
210    /// Constrain that `a` < `b`
211    fn enforce_lt_internal(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError> {
212        let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
213        // check whether (a-b) <= (q-1)/2
214        let c = self.sub(a, b)?;
215        let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
216        let cmp_result = self.logic_and(msb_eq, cmp_result)?;
217
218        self.logic_or_gate(msb_check, cmp_result)
219    }
220
221    /// Helper function to check whether `a` is greater than a given
222    /// constant. Let N = F::MODULUS_BIT_SIZE, it assumes that the
223    /// constant < 2^N. And it uses at most N AND/OR gates.
224    fn is_gt_constant_internal(
225        &mut self,
226        a: Variable,
227        constant: &F,
228    ) -> Result<BoolVar, CircuitError> {
229        let a_bits_le = self.unpack(a, F::MODULUS_BIT_SIZE as usize)?;
230        let const_bits_le = constant.into_bigint().to_bits_le();
231
232        // Iterating from LSB to MSB. Skip the front consecutive 1's.
233        // Put an OR gate for bit 0 and an AND gate for bit 1.
234        let mut zipped = const_bits_le
235            .into_iter()
236            .chain(ark_std::iter::repeat(false))
237            .take(a_bits_le.len())
238            .zip(a_bits_le.iter())
239            .skip_while(|(b, _)| *b);
240        if let Some((_, &var)) = zipped.next() {
241            zipped.try_fold(var, |current, (b, a)| -> Result<BoolVar, CircuitError> {
242                if b {
243                    self.logic_and(*a, current)
244                } else {
245                    self.logic_or(*a, current)
246                }
247            })
248        } else {
249            // the constant is all one
250            Ok(BoolVar(self.zero()))
251        }
252    }
253}
254
#[cfg(test)]
mod test {

    use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit};
    use ark_bls12_377::Fq as Fq377;
    use ark_ed_on_bls12_377::Fq as FqEd377;
    use ark_ed_on_bls12_381::Fq as FqEd381;
    use ark_ed_on_bn254::Fq as FqEd254;
    use ark_ff::PrimeField;
    use ark_std::cmp::Ordering;
    use itertools::iproduct;

    /// Runs the comparison gadget tests over several base fields.
    #[test]
    fn test_cmp_gates() -> Result<(), CircuitError> {
        test_cmp_helper::<FqEd254>()?;
        test_cmp_helper::<FqEd377>()?;
        test_cmp_helper::<FqEd381>()?;
        test_cmp_helper::<Fq377>()
    }

    /// Checks every comparison gadget on every value pair, in both operand
    /// orders, for each combination of ordering, strictness and
    /// constant/variable right-hand side.
    fn test_cmp_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let half = F::from(F::MODULUS_MINUS_ONE_DIV_TWO);
        let pairs = [
            // equal operands
            (F::from(5u32), F::from(5u32)),
            // both operands <= (q-1)/2
            (F::from(1u32), F::from(2u32)),
            // operands on opposite sides of (q-1)/2
            (half + F::one(), F::from(2u32)),
            // both operands > (q-1)/2
            (half + F::one(), half * F::from(2u32)),
        ];
        // Take the full cartesian product: zipping the four lists in lockstep
        // would stop at the shortest one and skip most combinations.
        for ((a, b), ordering, should_also_check_equality, is_b_constant) in iproduct!(
            pairs,
            [Ordering::Less, Ordering::Greater],
            [false, true],
            [false, true]
        ) {
            test_enforce_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?;
            test_enforce_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant)?;
            test_is_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?;
            test_is_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant)?;
        }
        Ok(())
    }

    /// Builds a circuit with the `is_*` gadget selected by `ordering` and
    /// `should_also_check_equality`, and checks that its output witness
    /// matches the native comparison of `a` and `b`.
    ///
    /// When `is_b_constant` is set, the `*_constant` variant of the gadget is
    /// used with `b` as the constant.
    fn test_is_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let native_cmp = a.cmp(b);
        let holds =
            native_cmp == ordering || (native_cmp == Ordering::Equal && should_also_check_equality);
        let expected_result = if holds { F::one() } else { F::zero() };
        let a = circuit.create_variable(*a)?;
        let c: BoolVar = if is_b_constant {
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.is_leq_constant(a, *b)?,
                (Ordering::Less, false) => circuit.is_lt_constant(a, *b)?,
                (Ordering::Greater, true) => circuit.is_geq_constant(a, *b)?,
                (Ordering::Greater, false) => circuit.is_gt_constant(a, *b)?,
                // Equality test will be handled elsewhere, comparison gate test will not enter
                // here
                (Ordering::Equal, _) => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.is_leq(a, b)?,
                (Ordering::Less, false) => circuit.is_lt(a, b)?,
                (Ordering::Greater, true) => circuit.is_geq(a, b)?,
                (Ordering::Greater, false) => circuit.is_gt(a, b)?,
                // Equality test will be handled elsewhere, comparison gate test will not enter
                // here
                (Ordering::Equal, _) => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        };
        assert_eq!(circuit.witness(c.into())?, expected_result);
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        Ok(())
    }

    /// Builds a circuit with the `enforce_*` gadget selected by `ordering`
    /// and `should_also_check_equality`, and checks that the circuit is
    /// satisfiable exactly when the native comparison of `a` and `b` holds.
    ///
    /// When `is_b_constant` is set, the `*_constant` variant of the gadget is
    /// used with `b` as the constant.
    fn test_enforce_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let native_cmp = a.cmp(b);
        let expected_result =
            native_cmp == ordering || (native_cmp == Ordering::Equal && should_also_check_equality);
        let a = circuit.create_variable(*a)?;
        if is_b_constant {
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.enforce_leq_constant(a, *b)?,
                (Ordering::Less, false) => circuit.enforce_lt_constant(a, *b)?,
                (Ordering::Greater, true) => circuit.enforce_geq_constant(a, *b)?,
                (Ordering::Greater, false) => circuit.enforce_gt_constant(a, *b)?,
                // Equality test will be handled elsewhere, comparison gate test will not enter
                // here
                (Ordering::Equal, _) => (),
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.enforce_leq(a, b)?,
                (Ordering::Less, false) => circuit.enforce_lt(a, b)?,
                (Ordering::Greater, true) => circuit.enforce_geq(a, b)?,
                (Ordering::Greater, false) => circuit.enforce_gt(a, b)?,
                // Equality test will be handled elsewhere, comparison gate test will not enter
                // here
                (Ordering::Equal, _) => (),
            }
        };
        // The circuit must be satisfiable exactly when the relation holds.
        assert_eq!(
            circuit.check_circuit_satisfiability(&[]).is_ok(),
            expected_result
        );
        Ok(())
    }
}