jf_plonk/
lagrange.rs

1//! Utilities for Lagrange interpolations, evaluations, coefficients for a
2//! polynomial
3
4use ark_ff::FftField;
5use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
6use ark_std::{ops::Range, vec, vec::Vec};
7
8// TODO: (alex) include these APIs upstream in arkworks directly
9
10/// A helper trait for computing Lagrange coefficients of an evaluation domain
11///
12/// from arkworks:
13/// Evaluate all Lagrange polynomials at tau to get the lagrange coefficients.
14/// Define the following as
15/// - H: The coset we are in, with generator g and offset h
16/// - n: The size of the coset H
17/// - Z_H: The vanishing polynomial for H. Z_H(x) = prod_{i in n} (x - hg^i) =
18///   x^n - h^n
19/// - v_i: A sequence of values, where v_0 = 1/(n * h^(n-1)), and v_{i + 1} = g
20///   * v_i
21///
22/// We then compute L_{i,H}(tau) as `L_{i,H}(tau) = Z_H(tau) * v_i / (tau - h *
23/// g^i)`
24#[allow(dead_code)]
25pub(crate) trait LagrangeCoeffs<F: FftField> {
26    /// Returns the first coefficient: `L_{0, Domain}(tau)`
27    fn first_lagrange_coeff(&self, tau: F) -> F;
28    /// Returns the last coefficient: `L_{n-1, Domain}(tau)`
29    fn last_lagrange_coeff(&self, tau: F) -> F;
30    /// Returns (first, last) lagrange coeffs
31    fn first_and_last_lagrange_coeffs(&self, tau: F) -> (F, F) {
32        (
33            self.first_lagrange_coeff(tau),
34            self.last_lagrange_coeff(tau),
35        )
36    }
37    /// Return a list of coefficients for `L_{range, Domain}(tau)`
38    fn lagrange_coeffs_for_range(&self, range: Range<usize>, tau: F) -> Vec<F>;
39}
40
41impl<F: FftField> LagrangeCoeffs<F> for Radix2EvaluationDomain<F> {
42    // L_0(tau) = Z_H(tau) * g^0 / (n * h^(n-1) * (tau - h * g^0))
43    // with g^0 = 1
44    // special care when tau in H, as both numerator and denominator is zero
45    fn first_lagrange_coeff(&self, tau: F) -> F {
46        let offset = self.coset_offset();
47        if tau == offset {
48            // when tau = g^0 * offset
49            return F::one();
50        }
51
52        let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
53        if z_h_at_tau.is_zero() {
54            // the case where tau is the first element in the coset
55            // already early-return
56            F::zero()
57        } else {
58            let offset_pow_size_minus_one = self.coset_offset_pow_size() / offset;
59            let denominator =
60                self.size_as_field_element() * offset_pow_size_minus_one * (tau - offset);
61            z_h_at_tau * denominator.inverse().unwrap()
62        }
63    }
64
65    // L_n-1(tau) = Z_H(tau) * g^-1 / (n * h^(n-1) * (tau - h * g^-1))
66    // with g^n-1 = g^-1
67    fn last_lagrange_coeff(&self, tau: F) -> F {
68        let offset = self.coset_offset();
69        if tau == self.group_gen_inv() * offset {
70            return F::one();
71        }
72
73        let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
74        if z_h_at_tau.is_zero() {
75            // the case where tau is the last element in the coset
76            // already early-return
77            F::zero()
78        } else {
79            let offset_pow_size_minus_one = self.coset_offset_pow_size() / offset;
80            let denominator = self.size_as_field_element()
81                * offset_pow_size_minus_one
82                * (tau - offset * self.group_gen_inv());
83            z_h_at_tau * self.group_gen_inv() * denominator.inverse().unwrap()
84        }
85    }
86
87    // a slightly cheaper implementation of the generic default
88    // saving repeated work when computing two coeffs separately
89    fn first_and_last_lagrange_coeffs(&self, tau: F) -> (F, F) {
90        let offset = self.coset_offset();
91        let group_gen_inv = self.group_gen_inv();
92        if tau == offset {
93            return (F::one(), F::zero());
94        }
95        if tau == group_gen_inv * offset {
96            return (F::zero(), F::one());
97        }
98
99        let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
100        if z_h_at_tau.is_zero() {
101            (F::zero(), F::zero())
102        } else {
103            let offset_pow_size_minus_one = self.coset_offset_pow_size() / offset;
104            let first_denominator =
105                self.size_as_field_element() * offset_pow_size_minus_one * (tau - offset);
106            let last_denominator = self.size_as_field_element()
107                * offset_pow_size_minus_one
108                * (tau - offset * group_gen_inv);
109
110            (
111                z_h_at_tau / first_denominator,
112                z_h_at_tau * group_gen_inv / last_denominator,
113            )
114        }
115    }
116
117    // similar to `EvaluationDomain::evaluate_all_lagrange_coefficients()`
118    //
119    // # Panic
120    // if `range` exceeds the `self.size()`
121    fn lagrange_coeffs_for_range(&self, range: Range<usize>, tau: F) -> Vec<F> {
122        if range.end > self.size() {
123            panic!("Out of range: domain size smaller than range.end");
124        }
125        let size = range.end - range.start;
126        let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
127        let offset = self.coset_offset();
128        let group_gen = self.group_gen();
129        let group_start = group_gen.pow([range.start as u64]);
130
131        if z_h_at_tau.is_zero() {
132            // In this case, we know that tau = hg^i, for some value i.
133            // Then i-th lagrange coefficient in this case is then simply 1,
134            // and all other lagrange coefficients are 0.
135            // Thus we find i by brute force.
136            let mut u = vec![F::zero(); size];
137            let mut omega_i = offset * group_start;
138            for u_i in u.iter_mut().take(size) {
139                if omega_i == tau {
140                    *u_i = F::one();
141                    break;
142                }
143                omega_i *= &group_gen;
144            }
145            u
146        } else {
147            // In this case we have to compute `Z_H(tau) * v_i / (tau - h g^i)`
148            // for i in start..end
149            // We actually compute this by computing (Z_H(tau) * v_i)^{-1} * (tau - h g^i)
150            // and then batch inverting to get the correct lagrange coefficients.
151            // We let `l_i = (Z_H(tau) * v_i)^-1` and `r_i = tau - h g^i`
152            // Notice that since Z_H(tau) is i-independent,
153            // and v_i = g * v_{i-1}, it follows that
154            // l_i = g^-1 * l_{i-1}
155
156            let group_gen_inv = self.group_gen_inv();
157            let start = range.start as u64;
158
159            // v_0_inv = n * h^(n-1)
160            let v_0_inv = self.size_as_field_element() * self.coset_offset_pow_size() / offset;
161            let mut l_i = z_h_at_tau.inverse().unwrap() * v_0_inv * group_gen_inv.pow([start]);
162
163            let mut negative_cur_elem = -offset * group_start;
164            let mut lagrange_coefficients_inverse = vec![F::zero(); size];
165            for coeff in lagrange_coefficients_inverse.iter_mut() {
166                let r_i = tau + negative_cur_elem;
167                *coeff = l_i * r_i;
168                // Increment l_i and negative_cur_elem
169                l_i *= &group_gen_inv;
170                negative_cur_elem *= &group_gen;
171            }
172            ark_ff::fields::batch_inversion(lagrange_coefficients_inverse.as_mut_slice());
173            lagrange_coefficients_inverse
174        }
175    }
176}
177
#[cfg(test)]
mod test {
    use super::*;
    use ark_bls12_381::Fr;
    use ark_std::{rand::Rng, One, UniformRand, Zero};

    /// Asserts that the specialised `first_and_last_lagrange_coeffs` agrees with
    /// calling `first_lagrange_coeff` and `last_lagrange_coeff` separately.
    fn assert_first_and_last_consistent(domain: &Radix2EvaluationDomain<Fr>, tau: Fr) {
        assert_eq!(
            domain.first_and_last_lagrange_coeffs(tau),
            (
                domain.first_lagrange_coeff(tau),
                domain.last_lagrange_coeff(tau)
            )
        );
    }

    /// Test that for points in the domain, coefficients are computed correctly
    #[test]
    fn test_in_domain_lagrange_coeff() {
        let mut rng = jf_utils::test_rng();
        for domain_log_size in 4..9 {
            let domain_size = 1 << domain_log_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                if i == 0 {
                    assert_eq!(domain.first_lagrange_coeff(x), Fr::one());
                    assert_eq!(domain.last_lagrange_coeff(x), Fr::zero());
                    assert_eq!(coset_domain.first_lagrange_coeff(coset_x), Fr::one());
                    assert_eq!(coset_domain.last_lagrange_coeff(coset_x), Fr::zero());
                }
                if i == domain.size() - 1 {
                    assert_eq!(domain.last_lagrange_coeff(x), Fr::one());
                    assert_eq!(domain.first_lagrange_coeff(x), Fr::zero());
                    assert_eq!(coset_domain.last_lagrange_coeff(coset_x), Fr::one());
                    assert_eq!(coset_domain.first_lagrange_coeff(coset_x), Fr::zero());
                }
                assert_first_and_last_consistent(&domain, x);
                assert_first_and_last_consistent(&coset_domain, coset_x);

                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs =
                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);
                for _ in 0..10 {
                    let start = rng.gen_range(0..i + 1);
                    let end = rng.gen_range(start..domain_size + 1);
                    assert_eq!(
                        domain.lagrange_coeffs_for_range(start..end, x),
                        lagrange_coeffs[start..end]
                    );
                    assert_eq!(
                        coset_domain.lagrange_coeffs_for_range(start..end, coset_x),
                        coset_lagrange_coeffs[start..end]
                    );
                }
            }
        }
    }

    /// Test that for random points (almost surely outside the domain),
    /// coefficients match arkworks' `evaluate_all_lagrange_coefficients`
    #[test]
    fn test_random_lagrange_coeff() {
        let mut rng = jf_utils::test_rng();
        for domain_log_size in 4..9 {
            let domain_size = 1 << domain_log_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

            for _ in 0..10 {
                let x = Fr::rand(&mut rng);
                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(x);

                assert_eq!(domain.first_lagrange_coeff(x), lagrange_coeffs[0]);
                assert_eq!(
                    domain.last_lagrange_coeff(x),
                    lagrange_coeffs[domain_size - 1]
                );
                assert_eq!(
                    coset_domain.first_lagrange_coeff(x),
                    coset_lagrange_coeffs[0]
                );
                assert_eq!(
                    coset_domain.last_lagrange_coeff(x),
                    coset_lagrange_coeffs[domain_size - 1]
                );
                assert_eq!(
                    domain.first_and_last_lagrange_coeffs(x),
                    (lagrange_coeffs[0], lagrange_coeffs[domain_size - 1])
                );
                assert_eq!(
                    coset_domain.first_and_last_lagrange_coeffs(x),
                    (
                        coset_lagrange_coeffs[0],
                        coset_lagrange_coeffs[domain_size - 1]
                    )
                );

                for _ in 0..10 {
                    let start = rng.gen_range(0..domain_size);
                    let end = rng.gen_range(start..domain_size + 1);
                    assert_eq!(
                        domain.lagrange_coeffs_for_range(start..end, x),
                        lagrange_coeffs[start..end]
                    );
                    assert_eq!(
                        coset_domain.lagrange_coeffs_for_range(start..end, x),
                        coset_lagrange_coeffs[start..end]
                    );
                }
            }
        }
    }
}