jf_relation/gadgets/
cmp.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
//! Comparison gadgets for circuits.
8
9use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit, Variable};
10use ark_ff::{BigInteger, PrimeField};
11
12impl<F: PrimeField> PlonkCircuit<F> {
13    /// Constrain that `a` < `b`.
14    pub fn enforce_lt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
15    where
16        F: PrimeField,
17    {
18        self.check_var_bound(a)?;
19        self.check_var_bound(b)?;
20        self.enforce_lt_internal(a, b)
21    }
22
23    /// Constrain that `a` <= `b`
24    pub fn enforce_leq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
25    where
26        F: PrimeField,
27    {
28        let c = self.is_lt(b, a)?;
29        self.enforce_constant(c.0, F::zero())
30    }
31
32    /// Constrain that `a` > `b`.
33    pub fn enforce_gt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
34    where
35        F: PrimeField,
36    {
37        self.enforce_lt(b, a)
38    }
39
40    /// Constrain that `a` >= `b`.
41    pub fn enforce_geq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
42    where
43        F: PrimeField,
44    {
45        let c = self.is_lt(a, b)?;
46        self.enforce_constant(c.into(), F::zero())
47    }
48
49    /// Returns a `BoolVar` indicating whether `a` < `b`.
50    pub fn is_lt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
51    where
52        F: PrimeField,
53    {
54        self.check_var_bound(a)?;
55        self.check_var_bound(b)?;
56        self.is_lt_internal(a, b)
57    }
58
59    /// Returns a `BoolVar` indicating whether `a` > `b`.
60    pub fn is_gt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
61    where
62        F: PrimeField,
63    {
64        self.is_lt(b, a)
65    }
66
67    /// Returns a `BoolVar` indicating whether `a` <= `b`.
68    pub fn is_leq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
69    where
70        F: PrimeField,
71    {
72        self.check_var_bound(a)?;
73        self.check_var_bound(b)?;
74        let c = self.is_lt_internal(b, a)?;
75        self.logic_neg(c)
76    }
77
78    /// Returns a `BoolVar` indicating whether `a` >= `b`.
79    pub fn is_geq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
80    where
81        F: PrimeField,
82    {
83        self.check_var_bound(a)?;
84        self.check_var_bound(b)?;
85        let c = self.is_lt_internal(a, b)?;
86        self.logic_neg(c)
87    }
88
89    /// Returns a `BoolVar` indicating whether the variable `a` is less than a
90    /// given constant `val`.
91    pub fn is_lt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
92    where
93        F: PrimeField,
94    {
95        self.check_var_bound(a)?;
96        let b = self.create_constant_variable(val)?;
97        self.is_lt(a, b)
98    }
99
100    /// Returns a `BoolVar` indicating whether the variable `a` is less than or
101    /// equal to a given constant `val`.
102    pub fn is_leq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
103    where
104        F: PrimeField,
105    {
106        self.check_var_bound(a)?;
107        let b = self.create_constant_variable(val)?;
108        self.is_leq(a, b)
109    }
110
111    /// Returns a `BoolVar` indicating whether the variable `a` is greater than
112    /// a given constant `val`.
113    pub fn is_gt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
114    where
115        F: PrimeField,
116    {
117        self.check_var_bound(a)?;
118        self.is_gt_constant_internal(a, &val)
119    }
120
121    /// Returns a `BoolVar` indicating whether the variable `a` is greater than
122    /// or equal a given constant `val`.
123    pub fn is_geq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
124    where
125        F: PrimeField,
126    {
127        self.check_var_bound(a)?;
128        let b = self.create_constant_variable(val)?;
129        self.is_geq(a, b)
130    }
131
132    /// Enforce the variable `a` to be less than a
133    /// given constant `val`.
134    pub fn enforce_lt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
135    where
136        F: PrimeField,
137    {
138        self.check_var_bound(a)?;
139        let b = self.create_constant_variable(val)?;
140        self.enforce_lt(a, b)
141    }
142
143    /// Enforce the variable `a` to be less than or
144    /// equal to a given constant `val`.
145    pub fn enforce_leq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
146    where
147        F: PrimeField,
148    {
149        self.check_var_bound(a)?;
150        let b = self.create_constant_variable(val)?;
151        self.enforce_leq(a, b)
152    }
153
154    /// Enforce the variable `a` to be greater than
155    /// a given constant `val`.
156    pub fn enforce_gt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
157    where
158        F: PrimeField,
159    {
160        self.check_var_bound(a)?;
161        let b = self.create_constant_variable(val)?;
162        self.enforce_gt(a, b)
163    }
164
165    /// Enforce the variable `a` to be greater than
166    /// or equal a given constant `val`.
167    pub fn enforce_geq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
168    where
169        F: PrimeField,
170    {
171        self.check_var_bound(a)?;
172        let b = self.create_constant_variable(val)?;
173        self.enforce_geq(a, b)
174    }
175}
176
177/// Private helper functions for comparison gate
178impl<F: PrimeField> PlonkCircuit<F> {
179    /// Returns 2 `BoolVar`s.
180    /// First indicates whether `a` <= (q-1)/2 and `b` > (q-1)/2.
181    /// Second indicates whether `a` and `b` are both <= (q-1)/2
182    /// or both > (q-1)/2.
183    fn msb_check_internal(
184        &mut self,
185        a: Variable,
186        b: Variable,
187    ) -> Result<(BoolVar, BoolVar), CircuitError> {
188        let a_gt_const = self.is_gt_constant_internal(a, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
189        let b_gt_const = self.is_gt_constant_internal(b, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
190        let a_leq_const = self.logic_neg(a_gt_const)?;
191        // Check whether `a` <= (q-1)/2 and `b` > (q-1)/2
192        let msb_check = self.logic_and(a_leq_const, b_gt_const)?;
193        // Check whether `a` and `b` are both <= (q-1)/2 or
194        // are both > (q-1)/2
195        let msb_eq = self.is_equal(a_gt_const.into(), b_gt_const.into())?;
196        Ok((msb_check, msb_eq))
197    }
198
199    /// Return a variable indicating whether `a` < `b`.
200    fn is_lt_internal(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError> {
201        let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
202        // check whether (a-b) > (q-1)/2
203        let c = self.sub(a, b)?;
204        let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
205        let cmp_result = self.logic_and(msb_eq, cmp_result)?;
206
207        self.logic_or(msb_check, cmp_result)
208    }
209
210    /// Constrain that `a` < `b`
211    fn enforce_lt_internal(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError> {
212        let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
213        // check whether (a-b) <= (q-1)/2
214        let c = self.sub(a, b)?;
215        let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
216        let cmp_result = self.logic_and(msb_eq, cmp_result)?;
217
218        self.logic_or_gate(msb_check, cmp_result)
219    }
220
221    /// Helper function to check whether `a` is greater than a given
222    /// constant. Let N = F::MODULUS_BIT_SIZE, it assumes that the
223    /// constant < 2^N. And it uses at most N AND/OR gates.
224    fn is_gt_constant_internal(
225        &mut self,
226        a: Variable,
227        constant: &F,
228    ) -> Result<BoolVar, CircuitError> {
229        let a_bits_le = self.unpack(a, F::MODULUS_BIT_SIZE as usize)?;
230        let const_bits_le = constant.into_bigint().to_bits_le();
231
232        // Iterating from LSB to MSB. Skip the front consecutive 1's.
233        // Put an OR gate for bit 0 and an AND gate for bit 1.
234        let mut zipped = const_bits_le
235            .into_iter()
236            .chain(ark_std::iter::repeat(false))
237            .take(a_bits_le.len())
238            .zip(a_bits_le.iter())
239            .skip_while(|(b, _)| *b);
240        if let Some((_, &var)) = zipped.next() {
241            zipped.try_fold(var, |current, (b, a)| -> Result<BoolVar, CircuitError> {
242                if b {
243                    self.logic_and(*a, current)
244                } else {
245                    self.logic_or(*a, current)
246                }
247            })
248        } else {
249            // the constant is all one
250            Ok(BoolVar(self.zero()))
251        }
252    }
253}
254
#[cfg(test)]
mod test {
    use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit};
    use ark_bls12_377::Fq as Fq377;
    use ark_ed_on_bls12_377::Fq as FqEd377;
    use ark_ed_on_bls12_381::Fq as FqEd381;
    use ark_ed_on_bn254::Fq as FqEd254;
    use ark_ff::PrimeField;
    use ark_std::cmp::Ordering;

    #[test]
    fn test_cmp_gates() -> Result<(), CircuitError> {
        test_cmp_helper::<FqEd254>()?;
        test_cmp_helper::<FqEd377>()?;
        test_cmp_helper::<FqEd381>()?;
        test_cmp_helper::<Fq377>()
    }

    /// Runs every comparison gadget on a fixed set of operand pairs. It covers
    /// both argument orders, strict and non-strict comparisons, and variable
    /// vs. constant right-hand sides.
    fn test_cmp_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let half = F::from(F::MODULUS_MINUS_ONE_DIV_TWO);
        let list = [
            // Equal operands.
            (F::from(5u32), F::from(5u32)),
            // Both operands in the lower half.
            (F::from(1u32), F::from(2u32)),
            // Upper half vs. lower half.
            (half + F::one(), F::from(2u32)),
            // Both operands in the upper half; the second one is q - 1.
            (half + F::one(), half * F::from(2u32)),
        ];
        // Take the full Cartesian product of the parameters. A `multizip` over
        // these arrays would stop at the shortest one, so it would only run the
        // first two operand pairs with two fixed parameter combinations.
        for (a, b) in list {
            for ordering in [Ordering::Less, Ordering::Greater] {
                for check_eq in [false, true] {
                    for b_const in [false, true] {
                        test_enforce_cmp_helper(&a, &b, ordering, check_eq, b_const)?;
                        test_enforce_cmp_helper(&b, &a, ordering, check_eq, b_const)?;
                        test_is_cmp_helper(&a, &b, ordering, check_eq, b_const)?;
                        test_is_cmp_helper(&b, &a, ordering, check_eq, b_const)?;
                    }
                }
            }
        }
        Ok(())
    }

    /// Builds a circuit that applies the `is_*` gadget selected by `ordering`,
    /// `should_also_check_equality` and `is_b_constant` to `a` and `b`. It
    /// checks both the output witness and circuit satisfiability.
    fn test_is_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let actual = a.cmp(b);
        let expected_result =
            if actual == ordering || (actual == Ordering::Equal && should_also_check_equality) {
                F::one()
            } else {
                F::zero()
            };
        let a = circuit.create_variable(*a)?;
        let c: BoolVar = if is_b_constant {
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.is_leq_constant(a, *b)?,
                (Ordering::Less, false) => circuit.is_lt_constant(a, *b)?,
                (Ordering::Greater, true) => circuit.is_geq_constant(a, *b)?,
                (Ordering::Greater, false) => circuit.is_gt_constant(a, *b)?,
                // Equality is tested elsewhere; `test_cmp_helper` never passes it.
                (Ordering::Equal, _) => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.is_leq(a, b)?,
                (Ordering::Less, false) => circuit.is_lt(a, b)?,
                (Ordering::Greater, true) => circuit.is_geq(a, b)?,
                (Ordering::Greater, false) => circuit.is_gt(a, b)?,
                // Equality is tested elsewhere; `test_cmp_helper` never passes it.
                (Ordering::Equal, _) => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        };
        assert_eq!(circuit.witness(c.into())?, expected_result);
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        Ok(())
    }

    /// Builds a circuit that applies the `enforce_*` gadget selected by
    /// `ordering`, `should_also_check_equality` and `is_b_constant` to `a` and
    /// `b`. It checks that the circuit is satisfiable iff the relation holds.
    fn test_enforce_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let actual = a.cmp(b);
        let expected_result =
            actual == ordering || (actual == Ordering::Equal && should_also_check_equality);
        let a = circuit.create_variable(*a)?;
        if is_b_constant {
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.enforce_leq_constant(a, *b)?,
                (Ordering::Less, false) => circuit.enforce_lt_constant(a, *b)?,
                (Ordering::Greater, true) => circuit.enforce_geq_constant(a, *b)?,
                (Ordering::Greater, false) => circuit.enforce_gt_constant(a, *b)?,
                // Equality is tested elsewhere; `test_cmp_helper` never passes it.
                (Ordering::Equal, _) => (),
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match (ordering, should_also_check_equality) {
                (Ordering::Less, true) => circuit.enforce_leq(a, b)?,
                (Ordering::Less, false) => circuit.enforce_lt(a, b)?,
                (Ordering::Greater, true) => circuit.enforce_geq(a, b)?,
                (Ordering::Greater, false) => circuit.enforce_gt(a, b)?,
                // Equality is tested elsewhere; `test_cmp_helper` never passes it.
                (Ordering::Equal, _) => (),
            }
        }
        assert_eq!(
            circuit.check_circuit_satisfiability(&[]).is_ok(),
            expected_result
        );
        Ok(())
    }
}