jf_poseidon2/
internal.rs

//! Generic implementation of the Poseidon2 internal (partial) layers.
2
3use ark_ff::PrimeField;
4
5use crate::s_box;
6
7/// Matrix multiplication in the internal layers
8/// Given a vector v compute the matrix vector product (1 + diag(v))*state
9/// with 1 denoting the constant matrix of ones.
10// @credit: `matmul_internal()` in zkhash and in plonky3
11#[inline(always)]
12fn matmul_internal<F: PrimeField, const T: usize>(
13    state: &mut [F; T],
14    mat_diag_minus_1: &'static [F; T],
15) {
16    match T {
17        // for 2 and 3, since we know the constants, we hardcode it
18        2 => {
19            // [2, 1]
20            // [1, 3]
21            let mut sum = state[0];
22            sum += state[1];
23            state[0] += sum;
24            state[1].double_in_place();
25            state[1] += sum;
26        },
27        3 => {
28            // [2, 1, 1]
29            // [1, 2, 1]
30            // [1, 1, 3]
31            let mut sum = state[0];
32            sum += state[1];
33            sum += state[2];
34            state[0] += sum;
35            state[1] += sum;
36            state[2].double_in_place();
37            state[2] += sum;
38        },
39        _ => {
40            let sum: F = state.iter().sum();
41            for i in 0..T {
42                state[i] *= mat_diag_minus_1[i];
43                state[i] += sum;
44            }
45        },
46    }
47}
48
49/// One internal round
50// @credit `internal_permute_state()` in plonky3
51#[inline(always)]
52pub(crate) fn permute_state<F: PrimeField, const T: usize>(
53    state: &mut [F; T],
54    rc: F,
55    d: usize,
56    mat_diag_minus_1: &'static [F; T],
57) {
58    state[0] += rc;
59    s_box(&mut state[0], d);
60    matmul_internal(state, mat_diag_minus_1);
61}