jf_poseidon2/external.rs
1//! Generic implementation for external layers
2
3use ark_ff::PrimeField;
4
5use crate::{add_rcs, s_box};
6
7/// The fastest 4x4 MDS matrix.
8/// [ 2 3 1 1 ]
9/// [ 1 2 3 1 ]
10/// [ 1 1 2 3 ]
11/// [ 3 1 1 2 ]
12///
13/// NOTE: we use plonky3's matrix instead of that in the original paper
14/// HorizenLab's ref: <https://github.com/Plonky3/Plonky3/blob/main/poseidon2/src/external.rs#L34>
15///
16/// This requires 7 additions and 2 doubles to compute.
17/// credit: Plonky3
18#[derive(Clone, Default)]
19struct MDSMat4;
20
21impl MDSMat4 {
22 /// x := M4 * x where M4 is the 4x4 MDS matrix
23 #[inline(always)]
24 fn matmul<F: PrimeField>(x: &mut [F; 4]) {
25 let t01 = x[0] + x[1];
26 let t23 = x[2] + x[3];
27 let t0123 = t01 + t23;
28 let t01123 = t0123 + x[1];
29 let t01233 = t0123 + x[3];
30 // The order here is important. Need to overwrite x[0] and x[2] after x[1] and
31 // x[3].
32 x[3] = t01233 + x[0].double(); // 3*x[0] + x[1] + x[2] + 2*x[3]
33 x[1] = t01123 + x[2].double(); // x[0] + 2*x[1] + 3*x[2] + x[3]
34 x[0] = t01123 + t01; // 2*x[0] + 3*x[1] + x[2] + x[3]
35 x[2] = t01233 + t23; // x[0] + x[1] + 2*x[2] + 3*x[3]
36 }
37}
38
39#[inline(always)]
40/// Matrix multiplication in the external layers
41// @credit: `matmul_external` in zkhash, `mds_light_permutation` in plonky3
42pub(super) fn matmul_external<F: PrimeField, const T: usize>(state: &mut [F; T]) {
43 match T {
44 2 => {
45 let sum = state[0] + state[1];
46 state[0] += sum;
47 state[1] += sum;
48 },
49
50 3 => {
51 let sum = state[0] + state[1] + state[2];
52 state[0] += sum;
53 state[1] += sum;
54 state[2] += sum;
55 },
56
57 // NOTE: matching plonky3's behavior, differs from the Horizen Labs reference implementation
58 // and the paper's description (Sec 5.1 of https://eprint.iacr.org/2023/323.pdf), in which for T=4,
59 // the circulant matrix is not applied.
60
61 // Given a 4x4 MDS matrix M, we multiply by the `4N x 4N` matrix
62 // `[[2M M ... M], [M 2M ... M], ..., [M M ... 2M]]`.
63 4 | 8 | 12 | 16 | 20 | 24 => {
64 // First, we apply M_4 to each consecutive four elements of the state.
65 // In Appendix B's terminology, this replaces each x_i with x_i'.
66 for chunk in state.chunks_exact_mut(4) {
67 MDSMat4::matmul(chunk.try_into().unwrap());
68 }
69 // Now, we apply the outer circulant matrix (to compute the y_i values).
70
71 // We first precompute the four sums of every four elements.
72 let sums: [F; 4] =
73 core::array::from_fn(|k| (0..T).step_by(4).map(|j| state[j + k]).sum::<F>());
74
75 // The formula for each y_i involves 2x_i' term and x_j' terms for each j that
76 // equals i mod 4. In other words, we can add a single copy of x_i'
77 // to the appropriate one of our precomputed sums
78 state
79 .iter_mut()
80 .enumerate()
81 .for_each(|(i, elem)| *elem += sums[i % 4]);
82 },
83
84 _ => {
85 panic!("Unsupported state size");
86 },
87 }
88}
89
90#[inline(always)]
91/// One external round
92// @credit `external_terminal_permute_state` in plonky3
93pub(crate) fn permute_state<F: PrimeField, const T: usize>(
94 state: &mut [F; T],
95 rc: &'static [F; T],
96 d: usize,
97) {
98 add_rcs(state, rc);
99 for s in state.iter_mut() {
100 s_box(s, d);
101 }
102 matmul_external(state);
103}