1use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit, Variable};
10use ark_ff::{BigInteger, PrimeField};
11
12impl<F: PrimeField> PlonkCircuit<F> {
13 pub fn enforce_lt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
15 where
16 F: PrimeField,
17 {
18 self.check_var_bound(a)?;
19 self.check_var_bound(b)?;
20 self.enforce_lt_internal(a, b)
21 }
22
23 pub fn enforce_leq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
25 where
26 F: PrimeField,
27 {
28 let c = self.is_lt(b, a)?;
29 self.enforce_constant(c.0, F::zero())
30 }
31
32 pub fn enforce_gt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
34 where
35 F: PrimeField,
36 {
37 self.enforce_lt(b, a)
38 }
39
40 pub fn enforce_geq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
42 where
43 F: PrimeField,
44 {
45 let c = self.is_lt(a, b)?;
46 self.enforce_constant(c.into(), F::zero())
47 }
48
49 pub fn is_lt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
51 where
52 F: PrimeField,
53 {
54 self.check_var_bound(a)?;
55 self.check_var_bound(b)?;
56 self.is_lt_internal(a, b)
57 }
58
59 pub fn is_gt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
61 where
62 F: PrimeField,
63 {
64 self.is_lt(b, a)
65 }
66
67 pub fn is_leq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
69 where
70 F: PrimeField,
71 {
72 self.check_var_bound(a)?;
73 self.check_var_bound(b)?;
74 let c = self.is_lt_internal(b, a)?;
75 self.logic_neg(c)
76 }
77
78 pub fn is_geq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
80 where
81 F: PrimeField,
82 {
83 self.check_var_bound(a)?;
84 self.check_var_bound(b)?;
85 let c = self.is_lt_internal(a, b)?;
86 self.logic_neg(c)
87 }
88
89 pub fn is_lt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
92 where
93 F: PrimeField,
94 {
95 self.check_var_bound(a)?;
96 let b = self.create_constant_variable(val)?;
97 self.is_lt(a, b)
98 }
99
100 pub fn is_leq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
103 where
104 F: PrimeField,
105 {
106 self.check_var_bound(a)?;
107 let b = self.create_constant_variable(val)?;
108 self.is_leq(a, b)
109 }
110
111 pub fn is_gt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
114 where
115 F: PrimeField,
116 {
117 self.check_var_bound(a)?;
118 self.is_gt_constant_internal(a, &val)
119 }
120
121 pub fn is_geq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
124 where
125 F: PrimeField,
126 {
127 self.check_var_bound(a)?;
128 let b = self.create_constant_variable(val)?;
129 self.is_geq(a, b)
130 }
131
132 pub fn enforce_lt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
135 where
136 F: PrimeField,
137 {
138 self.check_var_bound(a)?;
139 let b = self.create_constant_variable(val)?;
140 self.enforce_lt(a, b)
141 }
142
143 pub fn enforce_leq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
146 where
147 F: PrimeField,
148 {
149 self.check_var_bound(a)?;
150 let b = self.create_constant_variable(val)?;
151 self.enforce_leq(a, b)
152 }
153
154 pub fn enforce_gt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
157 where
158 F: PrimeField,
159 {
160 self.check_var_bound(a)?;
161 let b = self.create_constant_variable(val)?;
162 self.enforce_gt(a, b)
163 }
164
165 pub fn enforce_geq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
168 where
169 F: PrimeField,
170 {
171 self.check_var_bound(a)?;
172 let b = self.create_constant_variable(val)?;
173 self.enforce_geq(a, b)
174 }
175}
176
177impl<F: PrimeField> PlonkCircuit<F> {
179 fn msb_check_internal(
184 &mut self,
185 a: Variable,
186 b: Variable,
187 ) -> Result<(BoolVar, BoolVar), CircuitError> {
188 let a_gt_const = self.is_gt_constant_internal(a, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
189 let b_gt_const = self.is_gt_constant_internal(b, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
190 let a_leq_const = self.logic_neg(a_gt_const)?;
191 let msb_check = self.logic_and(a_leq_const, b_gt_const)?;
193 let msb_eq = self.is_equal(a_gt_const.into(), b_gt_const.into())?;
196 Ok((msb_check, msb_eq))
197 }
198
199 fn is_lt_internal(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError> {
201 let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
202 let c = self.sub(a, b)?;
204 let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
205 let cmp_result = self.logic_and(msb_eq, cmp_result)?;
206
207 self.logic_or(msb_check, cmp_result)
208 }
209
210 fn enforce_lt_internal(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError> {
212 let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
213 let c = self.sub(a, b)?;
215 let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
216 let cmp_result = self.logic_and(msb_eq, cmp_result)?;
217
218 self.logic_or_gate(msb_check, cmp_result)
219 }
220
221 fn is_gt_constant_internal(
225 &mut self,
226 a: Variable,
227 constant: &F,
228 ) -> Result<BoolVar, CircuitError> {
229 let a_bits_le = self.unpack(a, F::MODULUS_BIT_SIZE as usize)?;
230 let const_bits_le = constant.into_bigint().to_bits_le();
231
232 let mut zipped = const_bits_le
235 .into_iter()
236 .chain(ark_std::iter::repeat(false))
237 .take(a_bits_le.len())
238 .zip(a_bits_le.iter())
239 .skip_while(|(b, _)| *b);
240 if let Some((_, &var)) = zipped.next() {
241 zipped.try_fold(var, |current, (b, a)| -> Result<BoolVar, CircuitError> {
242 if b {
243 self.logic_and(*a, current)
244 } else {
245 self.logic_or(*a, current)
246 }
247 })
248 } else {
249 Ok(BoolVar(self.zero()))
251 }
252 }
253}
254
#[cfg(test)]
mod test {
    use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit};
    use ark_bls12_377::Fq as Fq377;
    use ark_ed_on_bls12_377::Fq as FqEd377;
    use ark_ed_on_bls12_381::Fq as FqEd381;
    use ark_ed_on_bn254::Fq as FqEd254;
    use ark_ff::PrimeField;
    use ark_std::cmp::Ordering;
    use itertools::iproduct;

    /// Runs the comparison gadget tests over several base fields.
    #[test]
    fn test_cmp_gates() -> Result<(), CircuitError> {
        test_cmp_helper::<FqEd254>()?;
        test_cmp_helper::<FqEd377>()?;
        test_cmp_helper::<FqEd381>()?;
        test_cmp_helper::<Fq377>()
    }

    /// Exercises every comparison gadget on a fixed list of operand pairs.
    ///
    /// The pairs cover: equal values, two small values, one value on each
    /// side of `(q-1)/2`, and two values both above `(q-1)/2`.
    fn test_cmp_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let list = [
            (F::from(5u32), F::from(5u32)),
            (F::from(1u32), F::from(2u32)),
            (
                F::from(F::MODULUS_MINUS_ONE_DIV_TWO).add(F::one()),
                F::from(2u32),
            ),
            (
                F::from(F::MODULUS_MINUS_ONE_DIV_TWO).add(F::one()),
                F::from(F::MODULUS_MINUS_ONE_DIV_TWO).mul(F::from(2u32)),
            ),
        ];
        // Full Cartesian product: a lock-step zip would stop at the shortest
        // input (length 2) and never reach the pairs that straddle (q-1)/2
        // or most of the flag combinations.
        for ((a, b), ordering, should_also_check_equality, is_b_constant) in iproduct!(
            list,
            [Ordering::Less, Ordering::Greater],
            [false, true],
            [false, true]
        ) {
            test_enforce_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?;
            test_enforce_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant)?;
            test_is_cmp_helper(&a, &b, ordering, should_also_check_equality, is_b_constant)?;
            test_is_cmp_helper(&b, &a, ordering, should_also_check_equality, is_b_constant)?;
        }
        Ok(())
    }

    /// Checks that the `is_*` gadget selected by `ordering`,
    /// `should_also_check_equality` and `is_b_constant` outputs the same
    /// truth value as comparing `a` and `b` with `Ord::cmp`, and that the
    /// resulting circuit is satisfiable.
    fn test_is_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let expected_result = if a.cmp(b) == ordering
            || (a.cmp(b) == Ordering::Equal && should_also_check_equality)
        {
            F::one()
        } else {
            F::zero()
        };
        let a = circuit.create_variable(*a)?;
        let c: BoolVar = if is_b_constant {
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.is_leq_constant(a, *b)?
                    } else {
                        circuit.is_lt_constant(a, *b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.is_geq_constant(a, *b)?
                    } else {
                        circuit.is_gt_constant(a, *b)?
                    }
                },
                // Never passed by the driver; kept so the match is exhaustive.
                Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.is_leq(a, b)?
                    } else {
                        circuit.is_lt(a, b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.is_geq(a, b)?
                    } else {
                        circuit.is_gt(a, b)?
                    }
                },
                // Never passed by the driver; kept so the match is exhaustive.
                Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        };
        assert_eq!(circuit.witness(c.into())?, expected_result);
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        Ok(())
    }

    /// Checks that the `enforce_*` gadget selected by `ordering`,
    /// `should_also_check_equality` and `is_b_constant` yields a satisfiable
    /// circuit exactly when the relation holds between `a` and `b` according
    /// to `Ord::cmp`.
    fn test_enforce_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let expected_result =
            a.cmp(b) == ordering || (a.cmp(b) == Ordering::Equal && should_also_check_equality);
        let a = circuit.create_variable(*a)?;
        if is_b_constant {
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.enforce_leq_constant(a, *b)?
                    } else {
                        circuit.enforce_lt_constant(a, *b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.enforce_geq_constant(a, *b)?
                    } else {
                        circuit.enforce_gt_constant(a, *b)?
                    }
                },
                // Never passed by the driver; kept so the match is exhaustive.
                Ordering::Equal => (),
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.enforce_leq(a, b)?
                    } else {
                        circuit.enforce_lt(a, b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.enforce_geq(a, b)?
                    } else {
                        circuit.enforce_gt(a, b)?
                    }
                },
                // Never passed by the driver; kept so the match is exhaustive.
                Ordering::Equal => (),
            }
        }
        // Satisfiable if and only if the enforced relation actually holds.
        assert_eq!(
            circuit.check_circuit_satisfiability(&[]).is_ok(),
            expected_result
        );
        Ok(())
    }
}