jf_utils/
reed_solomon_code.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
//! Reed-Solomon erasure coding over a field: encoding by polynomial evaluation and decoding by polynomial interpolation.
8
9use ark_ff::{FftField, Field};
10use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
11use ark_std::{
12    format,
13    string::{String, ToString},
14    vec,
15    vec::Vec,
16};
17use core::borrow::Borrow;
18use displaydoc::Display;
19
20/// Erasure code error
21#[derive(Display, Debug)]
22pub struct RSCodeError(String);
23impl ark_std::error::Error for RSCodeError {}
24
25/// Erasure-encode `data` into `data.len() + parity_size` shares.
26///
27/// Treating the input data as the coefficients of a polynomial,
28/// Returns the evaluations of this polynomial over [1, data.len() +
29/// parity_size].
30///
31/// If `F` is a [`FftField`], the encoding can be done using FFT on
32/// a `GeneralEvaluationDomain` (E.g. when num_shares = 3):
33/// ```
34/// use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
35/// use ark_bn254::Fr as F;
36/// use ark_std::{vec, One, Zero};
37/// use jf_utils::reed_solomon_code::reed_solomon_erasure_decode;
38///
39/// let domain = GeneralEvaluationDomain::<F>::new(3).unwrap();
40/// let input = vec![F::one(), F::one()];
41/// let mut result = domain.fft(&input); // FFT encoding
42/// let mut eval_points = domain.elements().collect::<Vec<_>>(); // Evaluation points
43/// // test decoding
44/// let output = reed_solomon_erasure_decode(eval_points.iter().zip(result).take(2), 2).unwrap();
45/// assert_eq!(input, output);
46/// ```
47pub fn reed_solomon_erasure_encode<F, D>(data: D, parity_size: usize) -> impl Iterator<Item = F>
48where
49    F: Field,
50    D: IntoIterator,
51    D::Item: Borrow<F>,
52    D::IntoIter: ExactSizeIterator + Clone,
53{
54    let data_iter = data.into_iter();
55    let num_shares = data_iter.len() + parity_size;
56
57    // view `data` as coefficients of a polynomial
58    // make shares by evaluating this polynomial at 1..=num_shares
59    (1..=num_shares).map(move |index| {
60        let mut value = F::zero();
61        let mut x = F::one();
62        data_iter.clone().for_each(|coef| {
63            value += x * coef.borrow();
64            x *= F::from(index as u64);
65        });
66        value
67    })
68}
69
70/// Decode into `data_size` data elements via polynomial interpolation.
71/// The degree of the interpolated polynomial is `data_size - 1`.
72/// First part of the share is the evaluation point, second part is its
73/// evaluation. Returns a data vector of length `data_size`.
74/// Time complexity of O(n^2).
75pub fn reed_solomon_erasure_decode<F, D, T1, T2>(
76    shares: D,
77    data_size: usize,
78) -> Result<Vec<F>, RSCodeError>
79where
80    F: Field,
81    T1: Borrow<F>,
82    T2: Borrow<F>,
83    D: IntoIterator,
84    D::Item: Borrow<(T1, T2)>,
85    D::IntoIter: ExactSizeIterator + Clone,
86{
87    let shares_iter = shares.into_iter().take(data_size);
88    if shares_iter.len() < data_size {
89        return Err(RSCodeError(format!(
90            "Insufficient evaluation points: got {} expected at least {}",
91            shares_iter.len(),
92            data_size
93        )));
94    }
95
96    // Lagrange interpolation:
97    // Given a list of points (x_1, y_1) ... (x_n, y_n)
98    //  1. Define l(x) = \prod (x - x_i)
99    //  2. Calculate the barycentric weight w_i = \prod_{j \neq i} 1 / (x_i -
100    // x_j)
101    //  3. Calculate l_i(x) = w_i * l(x) / (x - x_i)
102    //  4. Return f(x) = \sum_i y_i * l_i(x)
103    let x = shares_iter
104        .clone()
105        .map(|share| *share.borrow().0.borrow())
106        .collect::<Vec<_>>();
107    // Calculating l(x) = \prod (x - x_i)
108    let mut l = vec![F::zero(); data_size + 1];
109    l[0] = F::one();
110    for i in 1..data_size + 1 {
111        l[i] = F::one();
112        for j in (1..i).rev() {
113            l[j] = l[j - 1] - x[i - 1] * l[j];
114        }
115        l[0] = -x[i - 1] * l[0];
116    }
117    // Calculate the barycentric weight w_i
118    let w = (0..data_size)
119        .map(|i| {
120            let mut ret = F::one();
121            for j in 0..data_size {
122                if i != j {
123                    let denom = x[i] - x[j];
124                    if denom.is_zero() {
125                        return Err(RSCodeError(format!(
126                            "duplicate input point {} at indices {}, {}",
127                            x[i], i, j
128                        )));
129                    }
130                    ret /= denom;
131                }
132            }
133            Ok(ret)
134        })
135        .collect::<Result<Vec<_>, _>>()?;
136    // Calculate f(x) = \sum_i l_i(x)
137    let mut f = vec![F::zero(); data_size];
138    // for i in 0..shares.len() {
139    for (i, share) in shares_iter.enumerate() {
140        let mut li = vec![F::zero(); data_size];
141        li[data_size - 1] = F::one();
142        for j in (0..data_size - 1).rev() {
143            li[j] = l[j + 1] + x[i] * li[j + 1];
144        }
145        let weight = w[i] * share.borrow().1.borrow();
146        for j in 0..data_size {
147            f[j] += weight * li[j];
148        }
149    }
150    Ok(f)
151}
152
153/// Like [`reed_solomon_erasure_decode`] except input points are drawn from the
154/// given FFT domain.
155///
156/// Differences from [`reed_solomon_erasure_decode`]:
157/// - First part of the share is an index into `domain`
158pub fn reed_solomon_erasure_decode_rou<F, D>(
159    shares: D,
160    data_size: usize,
161    domain: &Radix2EvaluationDomain<F>,
162) -> Result<Vec<F>, RSCodeError>
163where
164    F: FftField,
165    D: IntoIterator,
166    D::Item: Borrow<(usize, F)>,
167    D::IntoIter: ExactSizeIterator + Clone,
168{
169    let shares_iter = shares.into_iter();
170
171    // check arguments
172    let max_index = shares_iter
173        .clone()
174        .max_by_key(|s| s.borrow().0)
175        .ok_or_else(|| RSCodeError("empty shares".to_string()))?
176        .borrow()
177        .0;
178    if max_index >= domain.size() {
179        return Err(RSCodeError(format!(
180            "share index {} out of bounds for domain size {}",
181            max_index,
182            domain.size()
183        )));
184    }
185
186    // We need random access to domain elements
187    // but we are given only an iterator.
188    // The least bad solution is to collect all elements.
189    let domain_elements: Vec<_> = domain.elements().collect();
190
191    let domain_shares = shares_iter.map(|share| {
192        let &(index, eval) = share.borrow();
193        // index cannot panic, we already checked
194        (domain_elements[index], eval)
195    });
196    reed_solomon_erasure_decode(domain_shares, data_size)
197}
198
#[cfg(test)]
mod test {
    use ark_bls12_377::Fr as Fr377;
    use ark_bls12_381::Fr as Fr381;
    use ark_bn254::Fr as Fr254;
    use ark_ff::{FftField, Field};
    use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
    use ark_std::{vec, vec::Vec};

    use crate::reed_solomon_code::{
        reed_solomon_erasure_decode, reed_solomon_erasure_decode_rou, reed_solomon_erasure_encode,
    };

    fn test_rs_code_helper<F: Field>() {
        // Encoded as a polynomial 2x + 1
        let data = vec![F::from(1u64), F::from(2u64)];
        // Evaluation of the above polynomial on (1, 2, 3) is (3, 5, 7)
        let expected = vec![F::from(3u64), F::from(5u64), F::from(7u64)];
        let code: Vec<F> = reed_solomon_erasure_encode(data.iter(), 1).collect();
        assert_eq!(code, expected);

        for to_be_removed in 0..code.len() {
            let mut indices = vec![F::from(1u64), F::from(2u64), F::from(3u64)];
            let mut new_code = code.clone();
            indices.remove(to_be_removed);
            new_code.remove(to_be_removed);
            let output = reed_solomon_erasure_decode(indices.iter().zip(new_code), 2).unwrap();
            assert_eq!(data, output);
        }
    }

    #[test]
    fn test_rs_code() {
        test_rs_code_helper::<Fr254>();
        test_rs_code_helper::<Fr377>();
        test_rs_code_helper::<Fr381>();
    }

    fn test_rs_code_fft_helper<F: FftField>() {
        let domain = Radix2EvaluationDomain::<F>::new(3).unwrap();
        let input = vec![F::from(1u64), F::from(2u64)];

        // manually encode via FFT, then decode by explicitly supplying roots of unity
        {
            let mut code = domain.fft(&input);
            let mut eval_points = domain.elements().collect::<Vec<_>>();
            eval_points.remove(1);
            code.remove(1);
            let output = reed_solomon_erasure_decode(eval_points.iter().zip(code), 2).unwrap();
            assert_eq!(input, output);
        }

        // manually encode via FFT, then decode via reed_solomon_erasure_decode_rou
        {
            let mut code = domain.fft(&input);
            code.remove(1);
            let output =
                reed_solomon_erasure_decode_rou([0, 2].into_iter().zip(code), 2, &domain).unwrap();
            assert_eq!(input, output);
        }
    }

    #[test]
    fn test_rs_code_fft() {
        test_rs_code_fft_helper::<Fr254>();
        test_rs_code_fft_helper::<Fr377>();
        test_rs_code_fft_helper::<Fr381>();
    }

    fn duplicate_inputs_helper<F: Field>() {
        // Encoded as a polynomial 4x^2 + x + 3
        let payload = [3u64, 1, 4].map(|x| F::from(x));
        // Evaluation of the above polynomial on (1, 2, 3, 4, 5) is (8, 21, 42, 71, 108)
        let expected = [8u64, 21, 42, 71, 108].map(|x| F::from(x));
        let code: Vec<F> = reed_solomon_erasure_encode(payload.iter(), 2).collect();
        assert_eq!(code, expected);

        let mut points = [1u64, 2, 3].map(|x| F::from(x));
        let recovered_payload: Vec<F> =
            reed_solomon_erasure_decode(points.iter().zip(&code[..3]), payload.len()).unwrap();
        assert_eq!(recovered_payload, payload);

        points[1] = points[0]; // duplicate input point
        assert!(reed_solomon_erasure_decode::<F, _, _, _>(
            points.iter().zip(&code[..3]),
            payload.len()
        )
        .is_err());
    }

    #[test]
    fn duplicate_inputs() {
        // Run on every field exercised by the other tests.
        duplicate_inputs_helper::<Fr254>();
        duplicate_inputs_helper::<Fr377>();
        duplicate_inputs_helper::<Fr381>();
    }

    /// Exercise the error paths (and the trivial empty decode) of both decoders.
    fn decode_errors_helper<F: FftField>() {
        let data = vec![F::from(1u64), F::from(2u64)];
        let code: Vec<F> = reed_solomon_erasure_encode(data.iter(), 1).collect();
        let points = [F::from(1u64), F::from(2u64), F::from(3u64)];

        // fewer shares than requested data elements
        assert!(reed_solomon_erasure_decode::<F, _, _, _>(
            points.iter().zip(&code[..1]),
            data.len()
        )
        .is_err());

        // decoding zero data elements needs no shares and yields nothing
        let empty: Vec<F> =
            reed_solomon_erasure_decode::<F, _, _, _>(points.iter().zip(&code[..0]), 0).unwrap();
        assert!(empty.is_empty());

        let domain = Radix2EvaluationDomain::<F>::new(3).unwrap();
        let one = F::from(1u64);

        // no shares at all
        assert!(reed_solomon_erasure_decode_rou(Vec::<(usize, F)>::new(), 2, &domain).is_err());

        // a share index outside the domain
        assert!(
            reed_solomon_erasure_decode_rou([(0, one), (domain.size(), one)], 2, &domain)
                .is_err()
        );
    }

    #[test]
    fn decode_errors() {
        decode_errors_helper::<Fr254>();
        decode_errors_helper::<Fr377>();
        decode_errors_helper::<Fr381>();
    }
}