jf_poseidon2/
external.rs

//! Generic implementation of the Poseidon2 external layers.
2
3use ark_ff::PrimeField;
4
5use crate::{add_rcs, s_box};
6
7/// The fastest 4x4 MDS matrix.
8/// [ 2 3 1 1 ]
9/// [ 1 2 3 1 ]
10/// [ 1 1 2 3 ]
11/// [ 3 1 1 2 ]
12///
13/// NOTE: we use plonky3's matrix instead of that in the original paper
14/// HorizenLab's ref: <https://github.com/Plonky3/Plonky3/blob/main/poseidon2/src/external.rs#L34>
15///
16/// This requires 7 additions and 2 doubles to compute.
17/// credit: Plonky3
18#[derive(Clone, Default)]
19struct MDSMat4;
20
21impl MDSMat4 {
22    /// x := M4 * x where M4 is the 4x4 MDS matrix
23    #[inline(always)]
24    fn matmul<F: PrimeField>(x: &mut [F; 4]) {
25        let t01 = x[0] + x[1];
26        let t23 = x[2] + x[3];
27        let t0123 = t01 + t23;
28        let t01123 = t0123 + x[1];
29        let t01233 = t0123 + x[3];
30        // The order here is important. Need to overwrite x[0] and x[2] after x[1] and
31        // x[3].
32        x[3] = t01233 + x[0].double(); // 3*x[0] + x[1] + x[2] + 2*x[3]
33        x[1] = t01123 + x[2].double(); // x[0] + 2*x[1] + 3*x[2] + x[3]
34        x[0] = t01123 + t01; // 2*x[0] + 3*x[1] + x[2] + x[3]
35        x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3]
36    }
37}
38
39#[inline(always)]
40/// Matrix multiplication in the external layers
41// @credit: `matmul_external` in zkhash, `mds_light_permutation` in plonky3
42pub(super) fn matmul_external<F: PrimeField, const T: usize>(state: &mut [F; T]) {
43    match T {
44        2 => {
45            let sum = state[0] + state[1];
46            state[0] += sum;
47            state[1] += sum;
48        },
49
50        3 => {
51            let sum = state[0] + state[1] + state[2];
52            state[0] += sum;
53            state[1] += sum;
54            state[2] += sum;
55        },
56
57        // NOTE: matching plonky3's behavior, differs from the Horizen Labs reference implementation
58        // and the paper's description (Sec 5.1 of https://eprint.iacr.org/2023/323.pdf), in which for T=4,
59        // the circulant matrix is not applied.
60
61        // Given a 4x4 MDS matrix M, we multiply by the `4N x 4N` matrix
62        // `[[2M M  ... M], [M  2M ... M], ..., [M  M ... 2M]]`.
63        4 | 8 | 12 | 16 | 20 | 24 => {
64            // First, we apply M_4 to each consecutive four elements of the state.
65            // In Appendix B's terminology, this replaces each x_i with x_i'.
66            for chunk in state.chunks_exact_mut(4) {
67                MDSMat4::matmul(chunk.try_into().unwrap());
68            }
69            // Now, we apply the outer circulant matrix (to compute the y_i values).
70
71            // We first precompute the four sums of every four elements.
72            let sums: [F; 4] =
73                core::array::from_fn(|k| (0..T).step_by(4).map(|j| state[j + k]).sum::<F>());
74
75            // The formula for each y_i involves 2x_i' term and x_j' terms for each j that
76            // equals i mod 4. In other words, we can add a single copy of x_i'
77            // to the appropriate one of our precomputed sums
78            state
79                .iter_mut()
80                .enumerate()
81                .for_each(|(i, elem)| *elem += sums[i % 4]);
82        },
83
84        _ => {
85            panic!("Unsupported state size");
86        },
87    }
88}
89
90#[inline(always)]
91/// One external round
92// @credit `external_terminal_permute_state` in plonky3
93pub(crate) fn permute_state<F: PrimeField, const T: usize>(
94    state: &mut [F; T],
95    rc: &'static [F; T],
96    d: usize,
97) {
98    add_rcs(state, rc);
99    for s in state.iter_mut() {
100        s_box(s, d);
101    }
102    matmul_external(state);
103}