jf_relation/gadgets/ecc/
glv.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7use crate::{
8    gadgets::ecc::{MultiScalarMultiplicationCircuit, PointVariable},
9    BoolVar, Circuit, CircuitError, PlonkCircuit, Variable,
10};
11use ark_ec::{
12    twisted_edwards::{Projective, TECurveConfig},
13    CurveGroup,
14};
15use ark_ff::{PrimeField, Zero};
16use jf_utils::field_switching;
17use num_bigint::{BigInt, BigUint};
18
19use super::TEPoint;
20
// phi(P) = lambda * P for all P.
// The constants below are used to compute phi(P) and to decompose scalars for
// the GLV method; see <https://eprint.iacr.org/2021/1152>.
// Every byte array is a little-endian encoding (they are all read with
// `from_le_bytes_mod_order`).

/// Coefficient `b` of the endomorphism maps `g(y) = b(y^2 + b)` and
/// `h(y) = y^2 - b`.
const COEFF_B: [u8; 32] = [
    180, 16, 37, 23, 77, 1, 15, 238, 214, 244, 154, 13, 119, 18, 167, 46, 136, 26, 81, 99, 58, 13,
    240, 97, 165, 38, 132, 130, 139, 242, 201, 82,
];

/// Coefficient `c` of the endomorphism map `f(y) = c(1 - y^2)`.
const COEFF_C: [u8; 32] = [
    61, 11, 101, 223, 108, 128, 92, 81, 233, 244, 54, 255, 207, 171, 86, 132, 7, 209, 23, 108, 253,
    110, 124, 169, 195, 87, 84, 134, 207, 36, 198, 108,
];

/// The lambda parameter for decomposition, i.e. `phi(P) = lambda * P`.
const LAMBDA: [u8; 32] = [
    5, 223, 131, 135, 64, 33, 61, 209, 110, 5, 165, 112, 185, 157, 196, 207, 43, 199, 56, 43, 86,
    73, 248, 237, 147, 164, 57, 74, 220, 243, 180, 19,
];

/// Lower 128 bits of `LAMBDA`, s.t. `LAMBDA = LAMBDA_1 + 2^128 * LAMBDA_2`.
const LAMBDA_1: [u8; 32] = [
    5, 223, 131, 135, 64, 33, 61, 209, 110, 5, 165, 112, 185, 157, 196, 207, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0,
];

/// Higher bits of `LAMBDA`, s.t. `LAMBDA = LAMBDA_1 + 2^128 * LAMBDA_2`.
const LAMBDA_2: [u8; 32] = [
    43, 199, 56, 43, 86, 73, 248, 237, 147, 164, 57, 74, 220, 243, 180, 19, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
];

/// Lower 128 bits of the scalar-field modulus `r`, s.t. `r = R1 + 2^128 * R2`.
const R1: [u8; 32] = [
    225, 231, 118, 40, 181, 6, 253, 116, 113, 4, 25, 116, 0, 135, 143, 255, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
];

/// Higher bits of the scalar-field modulus `r`, s.t. `r = R1 + 2^128 * R2`.
const R2: [u8; 32] = [
    0, 118, 104, 2, 2, 118, 206, 12, 82, 95, 103, 202, 212, 105, 251, 28, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
];

/// Entry `N11` of the lattice basis `N = [[N11, N12], [N21, N22]]` used by
/// `scalar_decomposition`.
const COEFF_N11: [u8; 32] = [
    31, 24, 137, 151, 74, 249, 2, 75, 142, 146, 230, 75, 0, 226, 95, 85, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0,
];

/// Entry `N12` of the lattice basis used by `scalar_decomposition`.
const COEFF_N12: [u8; 32] = [
    68, 31, 214, 35, 26, 89, 226, 248, 93, 143, 94, 229, 238, 179, 20, 8, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
];

/// Entry `N21` of the lattice basis used by `scalar_decomposition`.
const COEFF_N21: [u8; 32] = [
    136, 62, 172, 71, 52, 178, 196, 241, 187, 30, 189, 202, 221, 103, 41, 16, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0,
];

/// Entry `N22` of the lattice basis used by `scalar_decomposition`.
const COEFF_N22: [u8; 32] = [
    194, 207, 237, 144, 106, 13, 250, 41, 227, 113, 50, 40, 0, 165, 47, 170, 0, 118, 104, 2, 2,
    118, 206, 12, 82, 95, 103, 202, 212, 105, 251, 28,
];
80
81// GLV related gates
82impl<F> PlonkCircuit<F>
83where
84    F: PrimeField,
85{
86    /// Perform GLV multiplication in circuit (which costs a few less
87    /// constraints).
88    pub fn glv_mul<P: TECurveConfig<BaseField = F>>(
89        &mut self,
90        scalar: Variable,
91        base: &PointVariable,
92    ) -> Result<PointVariable, CircuitError> {
93        self.check_var_bound(scalar)?;
94        self.check_point_var_bound(base)?;
95
96        let (s1_var, s2_var, s2_sign_var) =
97            scalar_decomposition_gate::<P::BaseField, P::ScalarField>(self, &scalar)?;
98
99        let endo_base_var = endomorphism_circuit::<_, P>(self, base)?;
100        multi_scalar_mul_circuit::<_, P>(self, base, s1_var, &endo_base_var, s2_var, s2_sign_var)
101    }
102}
103
104/// The circuit for 2 base scalar multiplication with scalar bit length 128.
105fn multi_scalar_mul_circuit<F, P>(
106    circuit: &mut PlonkCircuit<F>,
107    base: &PointVariable,
108    scalar_1: Variable,
109    endo_base: &PointVariable,
110    scalar_2: Variable,
111    scalar_2_sign_var: BoolVar,
112) -> Result<PointVariable, CircuitError>
113where
114    F: PrimeField,
115    P: TECurveConfig<BaseField = F>,
116{
117    let endo_base_neg = circuit.inverse_point(endo_base)?;
118    let endo_base =
119        circuit.binary_point_vars_select(scalar_2_sign_var, endo_base, &endo_base_neg)?;
120
121    MultiScalarMultiplicationCircuit::<F, P>::msm_with_var_scalar_length(
122        circuit,
123        &[*base, endo_base],
124        &[scalar_1, scalar_2],
125        128,
126    )
127}
128
129/// Mapping a point G to phi(G):= lambda G where phi is the endomorphism
130fn endomorphism<F, P>(base: &TEPoint<F>) -> TEPoint<F>
131where
132    F: PrimeField,
133    P: TECurveConfig<BaseField = F>,
134{
135    let x = base.get_x();
136    let y = base.get_y();
137    let b = F::from_le_bytes_mod_order(COEFF_B.as_ref());
138    let c = F::from_le_bytes_mod_order(COEFF_C.as_ref());
139
140    let xy = x * y;
141    let y_square = y * y;
142    let f_y = c * (F::one() - y_square);
143    let g_y = b * (y_square + b);
144    let h_y = y_square - b;
145
146    Projective::<P>::new(f_y * h_y, g_y * xy, F::one(), h_y * xy)
147        .into_affine()
148        .into()
149}
150
151/// The circuit for computing the point endomorphism.
152fn endomorphism_circuit<F, P>(
153    circuit: &mut PlonkCircuit<F>,
154    point_var: &PointVariable,
155) -> Result<PointVariable, CircuitError>
156where
157    F: PrimeField,
158    P: TECurveConfig<BaseField = F>,
159{
160    let base = circuit.point_witness(point_var)?;
161    let endo_point = endomorphism::<_, P>(&base);
162    let endo_point_var = circuit.create_point_variable(endo_point)?;
163
164    let b = F::from_le_bytes_mod_order(COEFF_B.as_ref());
165    let c = F::from_le_bytes_mod_order(COEFF_C.as_ref());
166    let b_square = b * b;
167
168    let x_var = point_var.get_x();
169    let y_var = point_var.get_y();
170
171    // xy = x * y
172    let xy_var = circuit.mul(x_var, y_var)?;
173
174    // f(y) = c(1 - y^2)
175    let wire = [y_var, y_var, circuit.zero(), circuit.zero()];
176    let coeff = [F::zero(), F::zero(), F::zero(), F::zero()];
177    let q_mul = [-c, F::zero()];
178    let q_c = c;
179    let f_y_var = circuit.gen_quad_poly(&wire, &coeff, &q_mul, q_c)?;
180
181    // g(y) = b(y^2 + b)
182    let wire = [y_var, y_var, circuit.zero(), circuit.zero()];
183    let coeff = [F::zero(), F::zero(), F::zero(), F::zero()];
184    let q_mul = [b, F::zero()];
185    let q_c = b_square;
186    let g_y_var = circuit.gen_quad_poly(&wire, &coeff, &q_mul, q_c)?;
187
188    // h(y) = y^2 - b
189    let wire = [y_var, y_var, circuit.zero(), circuit.zero()];
190    let coeff = [F::zero(), F::zero(), F::zero(), F::zero()];
191    let q_mul = [F::one(), F::zero()];
192    let q_c = -b;
193    let h_y_var = circuit.gen_quad_poly(&wire, &coeff, &q_mul, q_c)?;
194
195    // res_x = f(y) / (xy)
196    circuit.mul_gate(endo_point_var.get_x(), xy_var, f_y_var)?;
197    // res_y = g(y) / h(y)
198    circuit.mul_gate(endo_point_var.get_y(), h_y_var, g_y_var)?;
199
200    Ok(endo_point_var)
201}
202
203/// Decompose a scalar s into k1, k2, s.t.
204///     scalar = k1 - k2_sign * k2 * lambda
205/// via a Babai's nearest plane algorithm
206/// Guarantees that k1 and k2 are less than 128 bits.
207fn scalar_decomposition<F: PrimeField>(scalar: &F) -> (F, F, bool) {
208    let scalar_z: BigUint = (*scalar).into();
209
210    let tmp = F::from_le_bytes_mod_order(COEFF_N11.as_ref());
211    let n11: BigUint = tmp.into();
212
213    let tmp = F::from_le_bytes_mod_order(COEFF_N12.as_ref());
214    let n12: BigUint = tmp.into();
215
216    let tmp = F::from_le_bytes_mod_order(COEFF_N21.as_ref());
217    let n21: BigUint = tmp.into();
218
219    let tmp = F::from_le_bytes_mod_order(COEFF_N22.as_ref());
220    let n22: BigUint = tmp.into();
221
222    let r: BigUint = F::MODULUS.into();
223    let r_over_2 = &r / BigUint::from(2u8);
224
225    // beta = vector([n,0]) * self.curve.N_inv
226    let beta_1 = &scalar_z * &n11;
227    let beta_2 = &scalar_z * &n12;
228
229    let beta_1 = &beta_1 / &r;
230    let beta_2 = &beta_2 / &r;
231
232    // b = vector([int(beta[0]), int(beta[1])]) * self.curve.N
233    let b1: BigUint = &beta_1 * &n11 + &beta_2 * &n21;
234    let b2: BigUint = (&beta_1 * &n12 + &beta_2 * &n22) % r;
235
236    let k1 = F::from(scalar_z - b1);
237    let is_k2_pos = b2 < r_over_2;
238
239    let k2 = if is_k2_pos { F::from(b2) } else { -F::from(b2) };
240
241    (k1, k2, is_k2_pos)
242}
243
/// Lift a field element into a `BigInt` via its canonical integer
/// representation (`into_bigint`). The result is always non-negative.
macro_rules! fq_to_big_int {
    ($fq: expr) => {
        <BigInt as From<BigUint>>::from($fq.into_bigint().into())
    };
}

/// Map a `BigInt` into the field `F` that is in scope at the call site,
/// reducing its magnitude modulo `F`'s order. A negative input maps to the field
/// negation of its magnitude, so the sign is preserved rather than silently
/// dropped.
macro_rules! int_to_fq {
    ($in: expr) => {{
        let (sign, magnitude) = $in.to_bytes_le();
        let value = F::from_le_bytes_mod_order(&magnitude);
        if sign == num_bigint::Sign::Minus {
            -value
        } else {
            value
        }
    }};
}
255
256// Input a scalar s as in Fq wires,
257// compute k1, k2 and a k2_sign s.t.
258//  s = k1 - k2_sign * k2 * lambda mod |Fr|
259// where
260// * s ~ 253 bits, private input
261// * lambda ~ 253 bits, public input
262// * k1, k2 each ~ 128 bits, private inputs
263// * k2_sign - Boolean, private inputs
264// Return the variables for k1 and k2
265// and sign bit for k2.
266#[allow(clippy::type_complexity)]
267fn scalar_decomposition_gate<F, S>(
268    circuit: &mut PlonkCircuit<F>,
269    s_var: &Variable,
270) -> Result<(Variable, Variable, BoolVar), CircuitError>
271where
272    F: PrimeField,
273    S: PrimeField,
274{
275    // the order of scalar field
276    // r = 13108968793781547619861935127046491459309155893440570251786403306729687672801 < 2^253
277    // q = 52435875175126190479447740508185965837690552500527637822603658699938581184513 < 2^255
278
279    // for an input scalar s,
280    // we need to prove the following statement over ZZ
281    //
282    // (0) lambda * k2_sign * k2 + s = t * Fr::modulus + k1
283    //
284    // for some t, where
285    // * t < (k2 + 1) < 2^128
286    // * k1, k2 < sqrt{2r} < 2^128
287    // * lambda, s, modulus are ~253 bits
288    //
289    // which becomes
290    // (1) lambda_1 * k2_sign * k2 + 2^128 lambda_2 * k2_sign * k2 + s
291    //        - t * r1 - t *2^128 r2 - k1 = 0
292    // where
293    // (2) lambda = lambda_1 + 2^128 lambda_2   <- public info
294    // (3) Fr::modulus = r1 + 2^128 r2          <- public info
295    // with
296    //  lambda_1 and r1 < 2^128
297    //  lambda_2 and r2 < 2^125
298    //
299    // reorganizing (1) gives us
300    // (4)          lambda_1 * k2_sign * k2 + s - t * r1 - k1
301    //     + 2^128 (lambda_2 * k2_sign * k2 - t * r2)
302    //     = 0
303    //
304    // Now set
305    // (5) tmp = lambda_1 * k2_sign * k2 + s - t * r1 - k1
306    // with
307    // (6) tmp = tmp1 + 2^128 tmp2
308    // for tmp1 < 2^128 and tmp2 < 2^128
309    //
310    // that is
311    // tmp1 will be the lower 128 bits of
312    //     lambda * k2_sign * k2 + s - t * Fr::modulus + k1
313    // which will be 0 due to (0).
314    // (7) tmp1 =  (lambda_1 * k2_sign * k2 + s - t * r1 - k1) % 2^128 = 0
315    // note that t * r1 < 2^254
316    //
317    // i.e. tmp2 will be the carrier overflowing 2^128,
318    // and on the 2^128 term, we have
319    // (8) tmp2 + lambda_2 * k2_sign * k2 - t * r2 = 0
320    // also due to (0).
321    //
322    // the concrete statements that we need to prove (0) are
323    //  (a) k1 < 2^128
324    //  (b) k2 < 2^128
325    //  (c) tmp1 = 0
326    //  (d) tmp2 < 2^128
327    //  (e) tmp = tmp1 + 2^128 tmp2
328    //  (f) tmp =  lambda_1 * k2_sign * k2 + s - t * r1 - k1
329    //  (g) tmp2 + lambda_2 * k2_sign * k2   = t * r2
330    // which can all be evaluated over Fq without overflow
331
332    // ============================================
333    // step 1: build integers
334    // ============================================
335    // 2^128
336    let two_to_128 = BigInt::from(2u64).pow(128);
337
338    // s
339    let s = circuit.witness(*s_var)?;
340    let s_int = fq_to_big_int!(s);
341    let s_fr = field_switching::<_, S>(&s);
342
343    // lambda = lambda_1 + 2^128 lambda_2
344    let lambda = F::from_le_bytes_mod_order(LAMBDA.as_ref());
345    let lambda_1 = F::from_le_bytes_mod_order(LAMBDA_1.as_ref());
346
347    let lambda_int = fq_to_big_int!(lambda);
348    let lambda_1_int = fq_to_big_int!(lambda_1);
349    let lambda_2 = F::from_le_bytes_mod_order(LAMBDA_2.as_ref());
350
351    // s = k1 - lambda * k2 * k2_sign
352    let (k1, k2, is_k2_positive) = scalar_decomposition(&s_fr);
353    let k1_int = fq_to_big_int!(k1);
354    let k2_int = fq_to_big_int!(k2);
355    let k2_sign = if is_k2_positive {
356        BigInt::from(1)
357    } else {
358        BigInt::from(-1)
359    };
360    let k2_with_sign = &k2_int * &k2_sign;
361
362    // fr_order = r1 + 2^128 r2
363    let fr_order_uint: BigUint = S::MODULUS.into();
364    let fr_order_int: BigInt = fr_order_uint.into();
365    let r1 = F::from_le_bytes_mod_order(R1.as_ref());
366    let r1_int = fq_to_big_int!(r1);
367    let r2 = F::from_le_bytes_mod_order(R2.as_ref());
368
369    // t * t_sign = (lambda * k2 * k2_sign + s - k1) / fr_order
370    let mut t_int = (&lambda_int * &k2_with_sign + &s_int - &k1_int) / &fr_order_int;
371    let t_int_sign = if t_int < BigInt::zero() {
372        t_int = -t_int;
373        BigInt::from(-1)
374    } else {
375        BigInt::from(1)
376    };
377    let t_int_with_sign = &t_int * &t_int_sign;
378
379    // tmp = tmp1 + 2^128 tmp2 =  lambda_1 * k2 * k2_sign + s - t * t_sign * r1 - k1
380    let tmp_int = &lambda_1_int * &k2_with_sign + &s_int - &t_int_with_sign * &r1_int - &k1_int;
381    let tmp2_int = &tmp_int / &two_to_128;
382
383    #[cfg(test)]
384    {
385        use ark_ff::BigInteger;
386
387        let fq_uint: BigUint = F::MODULUS.into();
388        let fq_int: BigInt = fq_uint.into();
389
390        let tmp1_int = &tmp_int % &two_to_128;
391
392        let lambda_2_int = fq_to_big_int!(lambda_2);
393        let r2_int = fq_to_big_int!(r2);
394        // sanity checks
395        // equation (0): lambda * k2_sign * k2 + s = t * t_sign * Fr::modulus + k1
396        assert_eq!(
397            &s_int + &lambda_int * &k2_with_sign,
398            &k1_int + &t_int_with_sign * &fr_order_int
399        );
400
401        // equation (4)
402        //              lambda_1 * k2_sign * k2 + s - t * t_sign * r1 - k1
403        //     + 2^128 (lambda_2 * k2_sign * k2 - t * r2)
404        //     = 0
405        assert_eq!(
406            &lambda_1_int * &k2_with_sign + &s_int - &t_int_with_sign * &r1_int - &k1_int
407                + &two_to_128 * (&lambda_2_int * &k2_with_sign - &t_int_with_sign * &r2_int),
408            BigInt::zero()
409        );
410
411        //  (a) k1 < 2^128
412        //  (b) k2 < 2^128
413        let k1_bits = get_bits(&k1.into_bigint().to_bits_le());
414        let k2_bits = get_bits(&k1.into_bigint().to_bits_le());
415
416        assert!(k1_bits < 128, "k1 bits {}", k1_bits);
417        assert!(k2_bits < 128, "k2 bits {}", k1_bits);
418
419        //  (c) tmp1 = 0
420        //  (d) tmp2 < 2^128
421        //  (e) tmp = tmp1 + 2^128 tmp2
422        assert!(tmp1_int == BigInt::from(0));
423        let tmp2_fq = F::from_le_bytes_mod_order(&tmp2_int.to_bytes_le().1);
424        let tmp2_bits = get_bits(&tmp2_fq.into_bigint().to_bits_le());
425        assert!(tmp1_int == BigInt::from(0));
426        assert!(tmp2_bits < 128, "tmp2 bits {}", tmp2_bits);
427
428        // equation (f): tmp1 + 2^128 tmp2 =  lambda_1 * k2_sign * k2 + s - t * t_sign *
429        // r1 - k1
430        assert_eq!(
431            &tmp1_int + &two_to_128 * &tmp2_int,
432            &lambda_1_int * &k2_with_sign + &s_int - &t_int_with_sign * &r1_int - &k1_int
433        );
434        assert!(&tmp_int + &t_int_with_sign * &r1_int + &k1_int < fq_int);
435
436        assert!(&lambda_1_int * &k2_int + &s_int < fq_int);
437
438        // equation (g) tmp2 + lambda_2 * k2_sign * k2 + s2  = t * t_sign * r2
439        assert_eq!(
440            &tmp2_int + &lambda_2_int * &k2_with_sign,
441            &t_int_with_sign * &r2_int
442        );
443
444        // all intermediate data are positive
445        assert!(k1_int >= BigInt::zero());
446        assert!(k2_int >= BigInt::zero());
447        assert!(t_int >= BigInt::zero());
448        assert!(tmp_int >= BigInt::zero());
449        assert!(tmp2_int >= BigInt::zero());
450
451        // t and k2 has a same sign
452        assert_eq!(t_int_sign, k2_sign);
453    }
454
455    // ============================================
456    // step 2. build the variables
457    // ============================================
458    let two_to_128 = F::from(BigUint::from(2u64).pow(128));
459
460    let k1_var = circuit.create_variable(int_to_fq!(k1_int))?;
461    let k2_var = circuit.create_variable(int_to_fq!(k2_int))?;
462    let k2_sign_var = circuit.create_boolean_variable(is_k2_positive)?;
463
464    let t_var = circuit.create_variable(int_to_fq!(t_int))?;
465
466    let tmp_var = circuit.create_variable(int_to_fq!(tmp_int))?;
467
468    let tmp2_var = circuit.create_variable(int_to_fq!(tmp2_int))?;
469
470    // ============================================
471    // step 3. range proofs
472    // ============================================
473    //  (a) k1 < 2^128
474    //  (b) k2 < 2^128
475    circuit.enforce_in_range(k1_var, 128)?;
476    circuit.enforce_in_range(k2_var, 128)?;
477
478    //  (c) tmp1 = 0        <- implied by tmp = 2^128 * tmp2
479    //  (d) tmp2 < 2^128
480    //  (e) tmp = tmp1 + 2^128 tmp2
481    circuit.mul_constant_gate(tmp2_var, two_to_128, tmp_var)?;
482    circuit.enforce_in_range(tmp2_var, 128)?;
483
484    // ============================================
485    // step 4. equality proofs
486    // ============================================
487    //  (f) tmp + t * k2_sign * r1 + k1 =  lambda_1 * k2_sign * k2 + s
488    //  (note that we cannot do subtraction because subtraction is over Fq)
489    let k2_is_pos_sat = {
490        //  (f.1) if k2_sign = 1, then, we prove over Z
491        //      tmp + t * r1 + k1 =  lambda_1 * k2 + s
492        let left_wire = [tmp_var, t_var, k1_var, circuit.zero()];
493        let left_coeff = [F::one(), r1, F::one(), F::zero()];
494        let left_var = circuit.lc(&left_wire, &left_coeff)?;
495
496        let right_wire = [k2_var, *s_var, circuit.zero(), circuit.zero()];
497        let right_coeff = [lambda_1, F::one(), F::zero(), F::zero()];
498        let right_var = circuit.lc(&right_wire, &right_coeff)?;
499
500        circuit.is_equal(left_var, right_var)?
501    };
502
503    let k2_is_neg_sat = {
504        //  (f.2) if k2_sign = -1, then, we prove over Z
505        //    lambda_1 * k2 +  tmp + k1 =   s  + t * r1
506        let left_wire = [k2_var, tmp_var, k1_var, circuit.zero()];
507        let left_coeff = [lambda_1, F::one(), F::one(), F::zero()];
508        let left_var = circuit.lc(&left_wire, &left_coeff)?;
509
510        let right_wire = [*s_var, t_var, circuit.zero(), circuit.zero()];
511        let right_coeff = [F::one(), r1, F::zero(), F::zero()];
512        let right_var = circuit.lc(&right_wire, &right_coeff)?;
513        circuit.is_equal(left_var, right_var)?
514    };
515
516    //  (f.3) either f.1 or f.2 is satisfied
517    let sat =
518        circuit.conditional_select(k2_sign_var, k2_is_neg_sat.into(), k2_is_pos_sat.into())?;
519    circuit.enforce_true(sat)?;
520
521    //  (g) tmp2 + lambda_2 * k2_sign * k2 + s2  = t * t_sign * r2
522
523    let k2_is_pos_sat = {
524        //  (g.1) if k2_sign = 1 then
525        //      tmp2 + lambda_2 * k_2_var = t * r2
526        let left_wire = [tmp2_var, k2_var, circuit.zero(), circuit.zero()];
527        let left_coeff = [F::one(), lambda_2, F::zero(), F::zero()];
528        let left_var = circuit.lc(&left_wire, &left_coeff)?;
529
530        let right_var = circuit.mul_constant(t_var, &r2)?;
531
532        circuit.is_equal(left_var, right_var)?
533    };
534
535    let k2_is_neg_sat = {
536        //  (g.2) if k2_sign = -1 then
537        //      tmp2  + t * r2 = lambda_2 * k_2_var
538        let left_wire = [tmp2_var, t_var, circuit.zero(), circuit.zero()];
539        let left_coeff = [F::one(), r2, F::zero(), F::zero()];
540        let left_var = circuit.lc(&left_wire, &left_coeff)?;
541
542        let right_var = circuit.mul_constant(k2_var, &lambda_2)?;
543
544        circuit.is_equal(left_var, right_var)?
545    };
546
547    //  (g.3) either g.1 or g.2 is satisfied
548    let sat =
549        circuit.conditional_select(k2_sign_var, k2_is_neg_sat.into(), k2_is_pos_sat.into())?;
550    circuit.enforce_true(sat)?;
551
552    // extract the output
553    Ok((k1_var, k2_var, k2_sign_var))
554}
555
#[cfg(test)]
/// Return the bit length of a little-endian bit string: the 1-based position of
/// its highest set bit, or 0 if no bit is set.
///
/// The result is measured against the slice's own length; earlier code assumed
/// exactly 256 bits.
fn get_bits(a: &[bool]) -> u16 {
    // Bit strings here come from field elements (a few hundred bits), so the
    // index always fits in u16.
    a.iter()
        .rposition(|&bit| bit)
        .map_or(0, |idx| idx as u16 + 1)
}
569
#[cfg(test)]
mod tests {
    use super::*;
    use ark_ec::{
        twisted_edwards::{Affine, TECurveConfig as Config},
        CurveConfig,
    };
    use ark_ed_on_bls12_381_bandersnatch::{EdwardsAffine, EdwardsConfig, Fq, Fr};
    use ark_ff::{BigInteger, MontFp, One, UniformRand};
    use jf_utils::{fr_to_fq, test_rng};

    #[test]
    fn test_glv() -> Result<(), CircuitError> {
        test_glv_helper::<Fq, EdwardsConfig>()
    }

    /// Compare `variable_base_scalar_mul` and `glv_mul` against native scalar
    /// multiplication on both TurboPlonk and UltraPlonk circuits.
    fn test_glv_helper<F, P>() -> Result<(), CircuitError>
    where
        F: PrimeField,
        P: Config<BaseField = F>,
    {
        let mut rng = test_rng();

        // One randomized case: sample a base point and a scalar, build the
        // requested circuit flavour and gadget, then check the output witness
        // and circuit satisfiability.
        let mut run_case = |use_ultra_plonk: bool, use_glv: bool| -> Result<(), CircuitError> {
            let base = Affine::<P>::rand(&mut rng);
            let s = P::ScalarField::rand(&mut rng);
            let mut circuit: PlonkCircuit<F> = if use_ultra_plonk {
                PlonkCircuit::new_ultra_plonk(16)
            } else {
                PlonkCircuit::new_turbo_plonk()
            };

            let s_var = circuit.create_variable(fr_to_fq::<F, P>(&s))?;
            let base_var = circuit.create_point_variable(TEPoint::from(base))?;
            let result = if use_glv {
                circuit.glv_mul::<P>(s_var, &base_var)?
            } else {
                circuit.variable_base_scalar_mul::<P>(s_var, &base_var)?
            };

            let expected: Affine<P> = (base * s).into();
            assert_eq!(TEPoint::from(expected), circuit.point_witness(&result)?);
            assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
            Ok(())
        };

        for _ in 0..100 {
            // (ultra_plonk, glv): plain TurboPlonk, plain UltraPlonk,
            // GLV TurboPlonk, GLV UltraPlonk.
            for &(use_ultra_plonk, use_glv) in
                &[(false, false), (true, false), (false, true), (true, true)]
            {
                run_case(use_ultra_plonk, use_glv)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_endomorphism() {
        let base_point = EdwardsAffine::new_unchecked(
            MontFp!(
                "29627151942733444043031429156003786749302466371339015363120350521834195802525"
            ),
            MontFp!(
                "27488387519748396681411951718153463804682561779047093991696427532072116857978"
            ),
        );
        let endo_point = EdwardsAffine::new_unchecked(
            MontFp!("3995099504672814451457646880854530097687530507181962222512229786736061793535"),
            MontFp!(
                "33370049900732270411777328808452912493896532385897059012214433666611661340894"
            ),
        );
        let base_point: TEPoint<Fq> = base_point.into();
        let endo_point: TEPoint<Fq> = endo_point.into();

        let t = endomorphism::<_, EdwardsConfig>(&base_point);
        assert_eq!(t, endo_point);

        let mut circuit: PlonkCircuit<Fq> = PlonkCircuit::new_turbo_plonk();
        let point_var = circuit.create_point_variable(base_point).unwrap();
        let endo_var = endomorphism_circuit::<_, EdwardsConfig>(&mut circuit, &point_var).unwrap();
        let endo_point_rec = circuit.point_witness(&endo_var).unwrap();
        assert_eq!(endo_point_rec, endo_point);
        // The gadget's constraints must also accept the honest witness.
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
    }

    #[test]
    fn test_decomposition() {
        let mut rng = test_rng();
        let lambda: Fr = Fr::from_le_bytes_mod_order(LAMBDA.as_ref());

        for _ in 0..100 {
            let scalar = Fr::rand(&mut rng);
            let (k1, k2, is_k2_pos) = scalar_decomposition(&scalar);
            assert!(get_bits(&k1.into_bigint().to_bits_le()) <= 128);
            assert!(get_bits(&k2.into_bigint().to_bits_le()) <= 128);
            let k2 = if is_k2_pos { k2 } else { -k2 };

            assert_eq!(k1 - k2 * lambda, scalar);

            let mut circuit: PlonkCircuit<Fq> = PlonkCircuit::new_ultra_plonk(16);
            let scalar_var = circuit.create_variable(field_switching(&scalar)).unwrap();
            let (k1_var, k2_var, k2_sign_var) = scalar_decomposition_gate::<
                <EdwardsConfig as CurveConfig>::BaseField,
                <EdwardsConfig as CurveConfig>::ScalarField,
            >(&mut circuit, &scalar_var)
            .unwrap();

            let k1_rec = circuit.witness(k1_var).unwrap();
            assert_eq!(field_switching::<_, Fq>(&k1), k1_rec);

            let k2_rec = circuit.witness(k2_var).unwrap();
            let k2_sign = circuit.witness(k2_sign_var.into()).unwrap();
            let k2_with_sign_rec = if k2_sign == Fq::one() {
                field_switching::<_, Fr>(&k2_rec)
            } else {
                -field_switching::<_, Fr>(&k2_rec)
            };

            assert_eq!(k2, k2_with_sign_rec);

            // Without this, the decomposition gadget's constraints are never
            // exercised by this test.
            assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        }
    }
}