1use ark_ff::FftField;
5use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
6use ark_std::{ops::Range, vec, vec::Vec};
7
8#[allow(dead_code)]
25pub(crate) trait LagrangeCoeffs<F: FftField> {
26 fn first_lagrange_coeff(&self, tau: F) -> F;
28 fn last_lagrange_coeff(&self, tau: F) -> F;
30 fn first_and_last_lagrange_coeffs(&self, tau: F) -> (F, F) {
32 (
33 self.first_lagrange_coeff(tau),
34 self.last_lagrange_coeff(tau),
35 )
36 }
37 fn lagrange_coeffs_for_range(&self, range: Range<usize>, tau: F) -> Vec<F>;
39}
40
41impl<F: FftField> LagrangeCoeffs<F> for Radix2EvaluationDomain<F> {
42 fn first_lagrange_coeff(&self, tau: F) -> F {
46 let offset = self.coset_offset();
47 if tau == offset {
48 return F::one();
50 }
51
52 let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
53 if z_h_at_tau.is_zero() {
54 F::zero()
57 } else {
58 let offset_pow_size_minus_one = self.coset_offset_pow_size() / offset;
59 let denominator =
60 self.size_as_field_element() * offset_pow_size_minus_one * (tau - offset);
61 z_h_at_tau * denominator.inverse().unwrap()
62 }
63 }
64
65 fn last_lagrange_coeff(&self, tau: F) -> F {
68 let offset = self.coset_offset();
69 if tau == self.group_gen_inv() * offset {
70 return F::one();
71 }
72
73 let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
74 if z_h_at_tau.is_zero() {
75 F::zero()
78 } else {
79 let offset_pow_size_minus_one = self.coset_offset_pow_size() / offset;
80 let denominator = self.size_as_field_element()
81 * offset_pow_size_minus_one
82 * (tau - offset * self.group_gen_inv());
83 z_h_at_tau * self.group_gen_inv() * denominator.inverse().unwrap()
84 }
85 }
86
87 fn first_and_last_lagrange_coeffs(&self, tau: F) -> (F, F) {
90 let offset = self.coset_offset();
91 let group_gen_inv = self.group_gen_inv();
92 if tau == offset {
93 return (F::one(), F::zero());
94 }
95 if tau == group_gen_inv * offset {
96 return (F::zero(), F::one());
97 }
98
99 let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
100 if z_h_at_tau.is_zero() {
101 (F::zero(), F::zero())
102 } else {
103 let offset_pow_size_minus_one = self.coset_offset_pow_size() / offset;
104 let first_denominator =
105 self.size_as_field_element() * offset_pow_size_minus_one * (tau - offset);
106 let last_denominator = self.size_as_field_element()
107 * offset_pow_size_minus_one
108 * (tau - offset * group_gen_inv);
109
110 (
111 z_h_at_tau / first_denominator,
112 z_h_at_tau * group_gen_inv / last_denominator,
113 )
114 }
115 }
116
117 fn lagrange_coeffs_for_range(&self, range: Range<usize>, tau: F) -> Vec<F> {
122 if range.end > self.size() {
123 panic!("Out of range: domain size smaller than range.end");
124 }
125 let size = range.end - range.start;
126 let z_h_at_tau = self.evaluate_vanishing_polynomial(tau);
127 let offset = self.coset_offset();
128 let group_gen = self.group_gen();
129 let group_start = group_gen.pow([range.start as u64]);
130
131 if z_h_at_tau.is_zero() {
132 let mut u = vec![F::zero(); size];
137 let mut omega_i = offset * group_start;
138 for u_i in u.iter_mut().take(size) {
139 if omega_i == tau {
140 *u_i = F::one();
141 break;
142 }
143 omega_i *= &group_gen;
144 }
145 u
146 } else {
147 let group_gen_inv = self.group_gen_inv();
157 let start = range.start as u64;
158
159 let v_0_inv = self.size_as_field_element() * self.coset_offset_pow_size() / offset;
161 let mut l_i = z_h_at_tau.inverse().unwrap() * v_0_inv * group_gen_inv.pow([start]);
162
163 let mut negative_cur_elem = -offset * group_start;
164 let mut lagrange_coefficients_inverse = vec![F::zero(); size];
165 for coeff in lagrange_coefficients_inverse.iter_mut() {
166 let r_i = tau + negative_cur_elem;
167 *coeff = l_i * r_i;
168 l_i *= &group_gen_inv;
170 negative_cur_elem *= &group_gen;
171 }
172 ark_ff::fields::batch_inversion(lagrange_coefficients_inverse.as_mut_slice());
173 lagrange_coefficients_inverse
174 }
175 }
176}
177
#[cfg(test)]
mod test {
    use super::*;
    use ark_bls12_381::Fr;
    use ark_std::{rand::Rng, One, UniformRand, Zero};

    /// Checks all coefficient methods at points of the domain (and of a coset of
    /// it), where exactly one Lagrange basis polynomial evaluates to one.
    #[test]
    fn test_in_domain_lagrange_coeff() {
        let mut rng = jf_utils::test_rng();
        for domain_log_size in 4..9 {
            let domain_size = 1 << domain_log_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();
            for (i, (x, coset_x)) in domain.elements().zip(coset_domain.elements()).enumerate() {
                if i == 0 {
                    assert_eq!(domain.first_lagrange_coeff(x), Fr::one());
                    assert_eq!(domain.last_lagrange_coeff(x), Fr::zero());
                    assert_eq!(coset_domain.first_lagrange_coeff(coset_x), Fr::one());
                    assert_eq!(coset_domain.last_lagrange_coeff(coset_x), Fr::zero());
                }
                if i == domain.size() - 1 {
                    assert_eq!(domain.last_lagrange_coeff(x), Fr::one());
                    assert_eq!(domain.first_lagrange_coeff(x), Fr::zero());
                    assert_eq!(coset_domain.last_lagrange_coeff(coset_x), Fr::one());
                    assert_eq!(coset_domain.first_lagrange_coeff(coset_x), Fr::zero());
                }

                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs =
                    coset_domain.evaluate_all_lagrange_coefficients(coset_x);

                // Every single-coefficient method must agree with the full evaluation
                // at every in-domain point, not only at the first/last elements.
                assert_eq!(domain.first_lagrange_coeff(x), lagrange_coeffs[0]);
                assert_eq!(
                    domain.last_lagrange_coeff(x),
                    lagrange_coeffs[domain_size - 1]
                );
                assert_eq!(
                    domain.first_and_last_lagrange_coeffs(x),
                    (lagrange_coeffs[0], lagrange_coeffs[domain_size - 1])
                );
                assert_eq!(
                    coset_domain.first_lagrange_coeff(coset_x),
                    coset_lagrange_coeffs[0]
                );
                assert_eq!(
                    coset_domain.last_lagrange_coeff(coset_x),
                    coset_lagrange_coeffs[domain_size - 1]
                );
                assert_eq!(
                    coset_domain.first_and_last_lagrange_coeffs(coset_x),
                    (
                        coset_lagrange_coeffs[0],
                        coset_lagrange_coeffs[domain_size - 1]
                    )
                );

                for _ in 0..10 {
                    let start = rng.gen_range(0..i + 1);
                    let end = rng.gen_range(start..domain_size + 1);
                    assert_eq!(
                        domain.lagrange_coeffs_for_range(start..end, x),
                        lagrange_coeffs[start..end]
                    );
                    assert_eq!(
                        coset_domain.lagrange_coeffs_for_range(start..end, coset_x),
                        coset_lagrange_coeffs[start..end]
                    );
                }
            }
        }
    }

    /// Checks all coefficient methods at random points, which lie outside the
    /// domain with overwhelming probability.
    #[test]
    fn test_random_lagrange_coeff() {
        let mut rng = jf_utils::test_rng();
        for domain_log_size in 4..9 {
            let domain_size = 1 << domain_log_size;
            let domain = Radix2EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let coset_domain = domain.get_coset(Fr::GENERATOR).unwrap();

            for _ in 0..10 {
                let x = Fr::rand(&mut rng);
                let lagrange_coeffs = domain.evaluate_all_lagrange_coefficients(x);
                let coset_lagrange_coeffs = coset_domain.evaluate_all_lagrange_coefficients(x);

                assert_eq!(domain.first_lagrange_coeff(x), lagrange_coeffs[0]);
                assert_eq!(
                    domain.last_lagrange_coeff(x),
                    lagrange_coeffs[domain_size - 1]
                );
                assert_eq!(
                    domain.first_and_last_lagrange_coeffs(x),
                    (lagrange_coeffs[0], lagrange_coeffs[domain_size - 1])
                );
                assert_eq!(
                    coset_domain.first_lagrange_coeff(x),
                    coset_lagrange_coeffs[0]
                );
                assert_eq!(
                    coset_domain.last_lagrange_coeff(x),
                    coset_lagrange_coeffs[domain_size - 1]
                );
                assert_eq!(
                    coset_domain.first_and_last_lagrange_coeffs(x),
                    (
                        coset_lagrange_coeffs[0],
                        coset_lagrange_coeffs[domain_size - 1]
                    )
                );

                for _ in 0..10 {
                    let start = rng.gen_range(0..domain_size);
                    let end = rng.gen_range(start..domain_size + 1);
                    assert_eq!(
                        domain.lagrange_coeffs_for_range(start..end, x),
                        lagrange_coeffs[start..end]
                    );
                    assert_eq!(
                        coset_domain.lagrange_coeffs_for_range(start..end, x),
                        coset_lagrange_coeffs[start..end]
                    );
                }
            }
        }
    }
}