1use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit, Variable};
10use ark_ff::{BigInteger, PrimeField};
11
12impl<F: PrimeField> PlonkCircuit<F> {
13 pub fn enforce_lt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
15 where
16 F: PrimeField,
17 {
18 self.check_var_bound(a)?;
19 self.check_var_bound(b)?;
20 self.enforce_lt_internal(a, b)
21 }
22
23 pub fn enforce_leq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
25 where
26 F: PrimeField,
27 {
28 let c = self.is_lt(b, a)?;
29 self.enforce_constant(c.0, F::zero())
30 }
31
32 pub fn enforce_gt(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
34 where
35 F: PrimeField,
36 {
37 self.enforce_lt(b, a)
38 }
39
40 pub fn enforce_geq(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>
42 where
43 F: PrimeField,
44 {
45 let c = self.is_lt(a, b)?;
46 self.enforce_constant(c.into(), F::zero())
47 }
48
49 pub fn is_lt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
51 where
52 F: PrimeField,
53 {
54 self.check_var_bound(a)?;
55 self.check_var_bound(b)?;
56 self.is_lt_internal(a, b)
57 }
58
59 pub fn is_gt(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
61 where
62 F: PrimeField,
63 {
64 self.is_lt(b, a)
65 }
66
67 pub fn is_leq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
69 where
70 F: PrimeField,
71 {
72 self.check_var_bound(a)?;
73 self.check_var_bound(b)?;
74 let c = self.is_lt_internal(b, a)?;
75 self.logic_neg(c)
76 }
77
78 pub fn is_geq(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError>
80 where
81 F: PrimeField,
82 {
83 self.check_var_bound(a)?;
84 self.check_var_bound(b)?;
85 let c = self.is_lt_internal(a, b)?;
86 self.logic_neg(c)
87 }
88
89 pub fn is_lt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
92 where
93 F: PrimeField,
94 {
95 self.check_var_bound(a)?;
96 let b = self.create_constant_variable(val)?;
97 self.is_lt(a, b)
98 }
99
100 pub fn is_leq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
103 where
104 F: PrimeField,
105 {
106 self.check_var_bound(a)?;
107 let b = self.create_constant_variable(val)?;
108 self.is_leq(a, b)
109 }
110
111 pub fn is_gt_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
114 where
115 F: PrimeField,
116 {
117 self.check_var_bound(a)?;
118 self.is_gt_constant_internal(a, &val)
119 }
120
121 pub fn is_geq_constant(&mut self, a: Variable, val: F) -> Result<BoolVar, CircuitError>
124 where
125 F: PrimeField,
126 {
127 self.check_var_bound(a)?;
128 let b = self.create_constant_variable(val)?;
129 self.is_geq(a, b)
130 }
131
132 pub fn enforce_lt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
135 where
136 F: PrimeField,
137 {
138 self.check_var_bound(a)?;
139 let b = self.create_constant_variable(val)?;
140 self.enforce_lt(a, b)
141 }
142
143 pub fn enforce_leq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
146 where
147 F: PrimeField,
148 {
149 self.check_var_bound(a)?;
150 let b = self.create_constant_variable(val)?;
151 self.enforce_leq(a, b)
152 }
153
154 pub fn enforce_gt_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
157 where
158 F: PrimeField,
159 {
160 self.check_var_bound(a)?;
161 let b = self.create_constant_variable(val)?;
162 self.enforce_gt(a, b)
163 }
164
165 pub fn enforce_geq_constant(&mut self, a: Variable, val: F) -> Result<(), CircuitError>
168 where
169 F: PrimeField,
170 {
171 self.check_var_bound(a)?;
172 let b = self.create_constant_variable(val)?;
173 self.enforce_geq(a, b)
174 }
175}
176
177impl<F: PrimeField> PlonkCircuit<F> {
179 fn msb_check_internal(
184 &mut self,
185 a: Variable,
186 b: Variable,
187 ) -> Result<(BoolVar, BoolVar), CircuitError> {
188 let a_gt_const = self.is_gt_constant_internal(a, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
189 let b_gt_const = self.is_gt_constant_internal(b, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
190 let a_leq_const = self.logic_neg(a_gt_const)?;
191 let msb_check = self.logic_and(a_leq_const, b_gt_const)?;
193 let msb_eq = self.is_equal(a_gt_const.into(), b_gt_const.into())?;
196 Ok((msb_check, msb_eq))
197 }
198
199 fn is_lt_internal(&mut self, a: Variable, b: Variable) -> Result<BoolVar, CircuitError> {
201 let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
202 let c = self.sub(a, b)?;
204 let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
205 let cmp_result = self.logic_and(msb_eq, cmp_result)?;
206
207 self.logic_or(msb_check, cmp_result)
208 }
209
210 fn enforce_lt_internal(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError> {
212 let (msb_check, msb_eq) = self.msb_check_internal(a, b)?;
213 let c = self.sub(a, b)?;
215 let cmp_result = self.is_gt_constant_internal(c, &F::from(F::MODULUS_MINUS_ONE_DIV_TWO))?;
216 let cmp_result = self.logic_and(msb_eq, cmp_result)?;
217
218 self.logic_or_gate(msb_check, cmp_result)
219 }
220
221 fn is_gt_constant_internal(
225 &mut self,
226 a: Variable,
227 constant: &F,
228 ) -> Result<BoolVar, CircuitError> {
229 let a_bits_le = self.unpack(a, F::MODULUS_BIT_SIZE as usize)?;
230 let const_bits_le = constant.into_bigint().to_bits_le();
231
232 let mut zipped = const_bits_le
235 .into_iter()
236 .chain(ark_std::iter::repeat(false))
237 .take(a_bits_le.len())
238 .zip(a_bits_le.iter())
239 .skip_while(|(b, _)| *b);
240 if let Some((_, &var)) = zipped.next() {
241 zipped.try_fold(var, |current, (b, a)| -> Result<BoolVar, CircuitError> {
242 if b {
243 self.logic_and(*a, current)
244 } else {
245 self.logic_or(*a, current)
246 }
247 })
248 } else {
249 Ok(BoolVar(self.zero()))
251 }
252 }
253}
254
#[cfg(test)]
mod test {

    use crate::{BoolVar, Circuit, CircuitError, PlonkCircuit};
    use ark_bls12_377::Fq as Fq377;
    use ark_ed_on_bls12_377::Fq as FqEd377;
    use ark_ed_on_bls12_381::Fq as FqEd381;
    use ark_ed_on_bn254::Fq as FqEd254;
    use ark_ff::PrimeField;
    use ark_std::cmp::Ordering;

    #[test]
    fn test_cmp_gates() -> Result<(), CircuitError> {
        test_cmp_helper::<FqEd254>()?;
        test_cmp_helper::<FqEd377>()?;
        test_cmp_helper::<FqEd381>()?;
        test_cmp_helper::<Fq377>()
    }

    /// Runs every comparison gadget on each value pair. It covers both argument
    /// orders, both orderings, strict and non-strict comparison, and a
    /// variable or constant right-hand side, in the `enforce_*` and `is_*`
    /// flavours.
    fn test_cmp_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let half = F::from(F::MODULUS_MINUS_ONE_DIV_TWO);
        let list = [
            (F::from(5u32), F::from(5u32)),
            (F::from(1u32), F::from(2u32)),
            (half + F::one(), F::from(2u32)),
            // Both in the upper half; `half * 2 == p - 1`.
            (half + F::one(), half * F::from(2u32)),
            // Straddles the lower/upper half boundary used by the gadget.
            (half, half + F::one()),
            // Extreme canonical values `0` and `p - 1`.
            (F::zero(), -F::one()),
        ];
        // Iterate over the full cartesian product of cases. A lock-step zip
        // would stop at the shortest input and silently skip most of them.
        for (a, b) in list {
            for ordering in [Ordering::Less, Ordering::Greater] {
                for or_equal in [false, true] {
                    for is_b_constant in [false, true] {
                        for (x, y) in [(&a, &b), (&b, &a)] {
                            test_enforce_cmp_helper(x, y, ordering, or_equal, is_b_constant)?;
                            test_is_cmp_helper(x, y, ordering, or_equal, is_b_constant)?;
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Checks that the `is_*` gadget selected by `ordering`,
    /// `should_also_check_equality` and `is_b_constant` outputs the expected
    /// boolean, and that the resulting circuit is satisfied.
    fn test_is_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let expected_result = if a.cmp(b) == ordering
            || (a.cmp(b) == Ordering::Equal && should_also_check_equality)
        {
            F::one()
        } else {
            F::zero()
        };
        let a = circuit.create_variable(*a)?;
        let c: BoolVar = if is_b_constant {
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.is_leq_constant(a, *b)?
                    } else {
                        circuit.is_lt_constant(a, *b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.is_geq_constant(a, *b)?
                    } else {
                        circuit.is_gt_constant(a, *b)?
                    }
                },
                Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.is_leq(a, b)?
                    } else {
                        circuit.is_lt(a, b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.is_geq(a, b)?
                    } else {
                        circuit.is_gt(a, b)?
                    }
                },
                Ordering::Equal => circuit.create_boolean_variable_unchecked(expected_result)?,
            }
        };
        assert_eq!(circuit.witness(c.into())?, expected_result);
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        Ok(())
    }

    /// Checks that the `enforce_*` gadget selected by `ordering`,
    /// `should_also_check_equality` and `is_b_constant` yields a satisfiable
    /// circuit exactly when the relation holds.
    fn test_enforce_cmp_helper<F: PrimeField>(
        a: &F,
        b: &F,
        ordering: Ordering,
        should_also_check_equality: bool,
        is_b_constant: bool,
    ) -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let expected_result =
            a.cmp(b) == ordering || (a.cmp(b) == Ordering::Equal && should_also_check_equality);
        let a = circuit.create_variable(*a)?;
        if is_b_constant {
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.enforce_leq_constant(a, *b)?
                    } else {
                        circuit.enforce_lt_constant(a, *b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.enforce_geq_constant(a, *b)?
                    } else {
                        circuit.enforce_gt_constant(a, *b)?
                    }
                },
                Ordering::Equal => (),
            }
        } else {
            let b = circuit.create_variable(*b)?;
            match ordering {
                Ordering::Less => {
                    if should_also_check_equality {
                        circuit.enforce_leq(a, b)?
                    } else {
                        circuit.enforce_lt(a, b)?
                    }
                },
                Ordering::Greater => {
                    if should_also_check_equality {
                        circuit.enforce_geq(a, b)?
                    } else {
                        circuit.enforce_gt(a, b)?
                    }
                },
                Ordering::Equal => (),
            }
        };
        if expected_result {
            assert!(circuit.check_circuit_satisfiability(&[]).is_ok())
        } else {
            assert!(circuit.check_circuit_satisfiability(&[]).is_err());
        }
        Ok(())
    }
}