1use ark_ff::{FftField, Field};
10use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
11use ark_std::{
12 format,
13 string::{String, ToString},
14 vec,
15 vec::Vec,
16};
17use core::borrow::Borrow;
18use displaydoc::Display;
19
20#[derive(Display, Debug)]
22pub struct RSCodeError(String);
23impl ark_std::error::Error for RSCodeError {}
24
25pub fn reed_solomon_erasure_encode<F, D>(data: D, parity_size: usize) -> impl Iterator<Item = F>
48where
49 F: Field,
50 D: IntoIterator,
51 D::Item: Borrow<F>,
52 D::IntoIter: ExactSizeIterator + Clone,
53{
54 let data_iter = data.into_iter();
55 let num_shares = data_iter.len() + parity_size;
56
57 (1..=num_shares).map(move |index| {
60 let mut value = F::zero();
61 let mut x = F::one();
62 data_iter.clone().for_each(|coef| {
63 value += x * coef.borrow();
64 x *= F::from(index as u64);
65 });
66 value
67 })
68}
69
70pub fn reed_solomon_erasure_decode<F, D, T1, T2>(
76 shares: D,
77 data_size: usize,
78) -> Result<Vec<F>, RSCodeError>
79where
80 F: Field,
81 T1: Borrow<F>,
82 T2: Borrow<F>,
83 D: IntoIterator,
84 D::Item: Borrow<(T1, T2)>,
85 D::IntoIter: ExactSizeIterator + Clone,
86{
87 let shares_iter = shares.into_iter().take(data_size);
88 if shares_iter.len() < data_size {
89 return Err(RSCodeError(format!(
90 "Insufficient evaluation points: got {} expected at least {}",
91 shares_iter.len(),
92 data_size
93 )));
94 }
95
96 let x = shares_iter
104 .clone()
105 .map(|share| *share.borrow().0.borrow())
106 .collect::<Vec<_>>();
107 let mut l = vec![F::zero(); data_size + 1];
109 l[0] = F::one();
110 for i in 1..data_size + 1 {
111 l[i] = F::one();
112 for j in (1..i).rev() {
113 l[j] = l[j - 1] - x[i - 1] * l[j];
114 }
115 l[0] = -x[i - 1] * l[0];
116 }
117 let w = (0..data_size)
119 .map(|i| {
120 let mut ret = F::one();
121 for j in 0..data_size {
122 if i != j {
123 let denom = x[i] - x[j];
124 if denom.is_zero() {
125 return Err(RSCodeError(format!(
126 "duplicate input point {} at indices {}, {}",
127 x[i], i, j
128 )));
129 }
130 ret /= denom;
131 }
132 }
133 Ok(ret)
134 })
135 .collect::<Result<Vec<_>, _>>()?;
136 let mut f = vec![F::zero(); data_size];
138 for (i, share) in shares_iter.enumerate() {
140 let mut li = vec![F::zero(); data_size];
141 li[data_size - 1] = F::one();
142 for j in (0..data_size - 1).rev() {
143 li[j] = l[j + 1] + x[i] * li[j + 1];
144 }
145 let weight = w[i] * share.borrow().1.borrow();
146 for j in 0..data_size {
147 f[j] += weight * li[j];
148 }
149 }
150 Ok(f)
151}
152
153pub fn reed_solomon_erasure_decode_rou<F, D>(
159 shares: D,
160 data_size: usize,
161 domain: &Radix2EvaluationDomain<F>,
162) -> Result<Vec<F>, RSCodeError>
163where
164 F: FftField,
165 D: IntoIterator,
166 D::Item: Borrow<(usize, F)>,
167 D::IntoIter: ExactSizeIterator + Clone,
168{
169 let shares_iter = shares.into_iter();
170
171 let max_index = shares_iter
173 .clone()
174 .max_by_key(|s| s.borrow().0)
175 .ok_or_else(|| RSCodeError("empty shares".to_string()))?
176 .borrow()
177 .0;
178 if max_index >= domain.size() {
179 return Err(RSCodeError(format!(
180 "share index {} out of bounds for domain size {}",
181 max_index,
182 domain.size()
183 )));
184 }
185
186 let domain_elements: Vec<_> = domain.elements().collect();
190
191 let domain_shares = shares_iter.map(|share| {
192 let &(index, eval) = share.borrow();
193 (domain_elements[index], eval)
195 });
196 reed_solomon_erasure_decode(domain_shares, data_size)
197}
198
#[cfg(test)]
mod test {
    use ark_bls12_377::Fr as Fr377;
    use ark_bls12_381::Fr as Fr381;
    use ark_bn254::Fr as Fr254;
    use ark_ff::{FftField, Field};
    use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
    use ark_std::{vec, vec::Vec};

    use crate::reed_solomon_code::{
        reed_solomon_erasure_decode, reed_solomon_erasure_decode_rou, reed_solomon_erasure_encode,
    };

    /// Encodes `1 + 2X` with one parity share and checks that dropping any
    /// single share still decodes to the data. Also checks that surplus shares
    /// are accepted and that too few shares are rejected.
    fn test_rs_code_helper<F: Field>() {
        let data = vec![F::from(1u64), F::from(2u64)];
        // p(X) = 1 + 2X evaluated at X = 1, 2, 3.
        let expected = vec![F::from(3u64), F::from(5u64), F::from(7u64)];
        let code: Vec<F> = reed_solomon_erasure_encode(data.iter(), 1).collect();
        assert_eq!(code, expected);

        for to_be_removed in 0..code.len() {
            let mut indices = vec![F::from(1u64), F::from(2u64), F::from(3u64)];
            let mut new_code = code.clone();
            indices.remove(to_be_removed);
            new_code.remove(to_be_removed);
            let output = reed_solomon_erasure_decode(indices.iter().zip(new_code), 2).unwrap();
            assert_eq!(data, output);
        }

        // More shares than `data_size`: the surplus is ignored.
        let all_points = vec![F::from(1u64), F::from(2u64), F::from(3u64)];
        let output = reed_solomon_erasure_decode(all_points.iter().zip(code.iter()), 2).unwrap();
        assert_eq!(data, output);

        // Fewer shares than `data_size` is an error, not a panic.
        assert!(reed_solomon_erasure_decode::<F, _, _, _>(
            all_points.iter().zip(code.iter()).take(1),
            2
        )
        .is_err());
    }

    #[test]
    fn test_rs_code() {
        test_rs_code_helper::<Fr254>();
        test_rs_code_helper::<Fr377>();
        test_rs_code_helper::<Fr381>();
    }

    /// Decodes FFT-produced codewords, both with explicit points and via
    /// domain indices, and checks the index-based decoder's error paths.
    fn test_rs_code_fft_helper<F: FftField>() {
        let domain = Radix2EvaluationDomain::<F>::new(3).unwrap();
        let input = vec![F::from(1u64), F::from(2u64)];

        {
            // Explicit evaluation points, with the share at index 1 erased.
            let mut code = domain.fft(&input);
            let mut eval_points = domain.elements().collect::<Vec<_>>();
            eval_points.remove(1);
            code.remove(1);
            let output = reed_solomon_erasure_decode(eval_points.iter().zip(code), 2).unwrap();
            assert_eq!(input, output);
        }

        {
            // Domain indices, with the share at index 1 erased.
            let mut code = domain.fft(&input);
            code.remove(1);
            let output =
                reed_solomon_erasure_decode_rou([0, 2].into_iter().zip(code), 2, &domain).unwrap();
            assert_eq!(input, output);
        }

        {
            // An index outside the domain is rejected.
            let code = domain.fft(&input);
            let shares = [0, domain.size()].into_iter().zip(code);
            assert!(reed_solomon_erasure_decode_rou(shares, 2, &domain).is_err());
        }

        {
            // No shares at all is rejected.
            let empty: Vec<(usize, F)> = Vec::new();
            assert!(reed_solomon_erasure_decode_rou(empty, 2, &domain).is_err());
        }
    }

    #[test]
    fn test_rs_code_fft() {
        test_rs_code_fft_helper::<Fr254>();
        test_rs_code_fft_helper::<Fr377>();
        test_rs_code_fft_helper::<Fr381>();
    }

    /// Round-trips `3 + X + 4X^2` and checks that a repeated evaluation point
    /// is reported as an error.
    fn duplicate_inputs_helper<F: Field>() {
        let payload = [3u64, 1, 4].map(|x| F::from(x));
        // p(X) = 3 + X + 4X^2 evaluated at X = 1..=5.
        let expected = [8u64, 21, 42, 71, 108].map(|x| F::from(x));
        let code: Vec<F> = reed_solomon_erasure_encode(payload.iter(), 2).collect();
        assert_eq!(code, expected);

        let mut points = [1u64, 2, 3].map(|x| F::from(x));
        let recovered_payload: Vec<F> =
            reed_solomon_erasure_decode(points.iter().zip(&code[..3]), payload.len()).unwrap();
        assert_eq!(recovered_payload, payload);

        // Two shares at the same point cannot determine the polynomial.
        points[1] = points[0];
        assert!(reed_solomon_erasure_decode::<F, _, _, _>(
            points.iter().zip(&code[..3]),
            payload.len()
        )
        .is_err());
    }

    #[test]
    fn duplicate_inputs() {
        duplicate_inputs_helper::<Fr254>();
        duplicate_inputs_helper::<Fr377>();
        duplicate_inputs_helper::<Fr381>();
    }
}