1use crate::{Poseidon2, Poseidon2Params};
4use ark_ff::PrimeField;
5use ark_std::marker::PhantomData;
6use nimue::{hash::sponge::Sponge, Unit};
7use zeroize::Zeroize;
8
9pub trait Poseidon2Sponge {}
11impl<F, const N: usize, const R: usize, P> Poseidon2Sponge for Poseidon2SpongeState<F, N, R, P>
12where
13 F: PrimeField,
14 P: Poseidon2Params<F, N>,
15{
16}
17
18#[derive(Clone, Debug)]
32pub struct Poseidon2SpongeState<
33 F: PrimeField,
34 const N: usize,
35 const R: usize,
36 P: Poseidon2Params<F, N>,
37> {
38 pub(crate) state: [F; N],
40 _rate: PhantomData<[(); R]>,
41 _p: PhantomData<P>,
42}
43
44impl<F, const N: usize, const R: usize, P> Sponge for Poseidon2SpongeState<F, N, R, P>
45where
46 F: PrimeField + Unit,
47 P: Poseidon2Params<F, N>,
48{
49 type U = F;
50 const N: usize = N;
51 const R: usize = R;
52
53 fn new(iv: [u8; 32]) -> Self {
54 assert!(N >= 2 && R > 0 && N > R);
55 assert!((N - R) as u32 * <F as PrimeField>::MODULUS_BIT_SIZE >= 200);
58
59 let mut state = [F::default(); N];
61 state[R] = F::from_be_bytes_mod_order(&iv);
62 Self {
63 state,
64 _rate: PhantomData,
65 _p: PhantomData,
66 }
67 }
68
69 fn permute(&mut self) {
70 Poseidon2::permute_mut::<P, N>(&mut self.state);
71 }
72}
73impl<F, const N: usize, const R: usize, P> Default for Poseidon2SpongeState<F, N, R, P>
74where
75 F: PrimeField,
76 P: Poseidon2Params<F, N>,
77{
78 fn default() -> Self {
79 Self {
80 state: [F::default(); N],
81 _rate: PhantomData,
82 _p: PhantomData,
83 }
84 }
85}
86
87impl<F, const N: usize, const R: usize, P> AsRef<[F]> for Poseidon2SpongeState<F, N, R, P>
88where
89 F: PrimeField,
90 P: Poseidon2Params<F, N>,
91{
92 fn as_ref(&self) -> &[F] {
93 &self.state
94 }
95}
96impl<F, const N: usize, const R: usize, P> AsMut<[F]> for Poseidon2SpongeState<F, N, R, P>
97where
98 F: PrimeField,
99 P: Poseidon2Params<F, N>,
100{
101 fn as_mut(&mut self) -> &mut [F] {
102 &mut self.state
103 }
104}
105
106impl<F, const N: usize, const R: usize, P> Zeroize for Poseidon2SpongeState<F, N, R, P>
107where
108 F: PrimeField,
109 P: Poseidon2Params<F, N>,
110{
111 fn zeroize(&mut self) {
112 self.state.zeroize();
113 }
114}
115
/// Sponge instantiations over the BLS12-381 scalar field.
///
/// Naming scheme: `N<width>R<rate>` mirrors the `N` and `R` const arguments of
/// [`Poseidon2SpongeState`]; the `...State...` alias is the raw permutation
/// state and the alias without `State` wraps it in nimue's `DuplexSponge`.
#[cfg(feature = "bls12-381")]
// NOTE(review): presumably needed because not every alias is used inside this crate.
#[allow(dead_code)]
mod bls12_381 {
    use super::Poseidon2SpongeState;
    use crate::constants::bls12_381::*;
    use ark_bls12_381::Fr;
    use nimue::hash::sponge::DuplexSponge;

    // Width 2, rate 1.
    pub type Poseidon2SpongeStateBlsN2R1 = Poseidon2SpongeState<Fr, 2, 1, Poseidon2ParamsBls2>;
    pub type Poseidon2SpongeBlsN2R1 = DuplexSponge<Poseidon2SpongeStateBlsN2R1>;

    // Width 3, rate 1.
    pub type Poseidon2SpongeStateBlsN3R1 = Poseidon2SpongeState<Fr, 3, 1, Poseidon2ParamsBls3>;
    pub type Poseidon2SpongeBlsN3R1 = DuplexSponge<Poseidon2SpongeStateBlsN3R1>;

    // Width 3, rate 2.
    pub type Poseidon2SpongeStateBlsN3R2 = Poseidon2SpongeState<Fr, 3, 2, Poseidon2ParamsBls3>;
    pub type Poseidon2SpongeBlsN3R2 = DuplexSponge<Poseidon2SpongeStateBlsN3R2>;

    /// Runs the shared sponge round-trip check on every BLS12-381 instantiation.
    #[test]
    fn test_bls_sponge() {
        // Imported here rather than at module level: `tests` only exists under `cfg(test)`.
        use super::tests::test_sponge;

        test_sponge::<Fr, Poseidon2SpongeBlsN2R1>();
        test_sponge::<Fr, Poseidon2SpongeBlsN3R1>();
        test_sponge::<Fr, Poseidon2SpongeBlsN3R2>();
    }
}
146
/// Sponge instantiations over the BN254 scalar field.
///
/// Naming scheme: `N<width>R<rate>` mirrors the `N` and `R` const arguments of
/// [`Poseidon2SpongeState`]; the `...State...` alias is the raw permutation
/// state and the alias without `State` wraps it in nimue's `DuplexSponge`.
#[cfg(feature = "bn254")]
// NOTE(review): presumably needed because not every alias is used inside this crate.
#[allow(dead_code)]
mod bn254 {
    use super::Poseidon2SpongeState;
    use crate::constants::bn254::*;
    use ark_bn254::Fr;
    use nimue::hash::sponge::DuplexSponge;

    // Width 3, rate 1.
    pub type Poseidon2SpongeStateBnN3R1 = Poseidon2SpongeState<Fr, 3, 1, Poseidon2ParamsBn3>;
    pub type Poseidon2SpongeBnN3R1 = DuplexSponge<Poseidon2SpongeStateBnN3R1>;

    // Width 3, rate 2.
    pub type Poseidon2SpongeStateBnN3R2 = Poseidon2SpongeState<Fr, 3, 2, Poseidon2ParamsBn3>;
    pub type Poseidon2SpongeBnN3R2 = DuplexSponge<Poseidon2SpongeStateBnN3R2>;

    /// Runs the shared sponge round-trip check on every BN254 instantiation.
    #[test]
    fn test_bn_sponge() {
        // Imported here rather than at module level: `tests` only exists under `cfg(test)`.
        use super::tests::test_sponge;

        test_sponge::<Fr, Poseidon2SpongeBnN3R1>();
        test_sponge::<Fr, Poseidon2SpongeBnN3R2>();
    }
}
171
/// Shared test helpers for the sponge instantiations defined in this file.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use ark_ff::BigInteger;
    use ark_std::vec::Vec;
    use nimue::{DuplexHash, IOPattern, UnitTranscript};

    /// Round-trip and rough-uniformity check for a duplex hash `H` over field `F`.
    ///
    /// 1. Absorbs one field element on the Merlin side and squeezes 2048
    ///    challenges.
    /// 2. Replays the resulting transcript on the Arthur side and asserts that
    ///    the same 2048 challenges come out.
    /// 3. Serializes the challenges to little-endian bytes and asserts that the
    ///    count of every byte value lies strictly between half and twice the
    ///    mean count.
    ///
    /// # Panics
    ///
    /// Panics if any transcript operation fails or if either assertion above
    /// does not hold.
    pub(crate) fn test_sponge<F: PrimeField + Unit, H: DuplexHash<F>>() {
        /// Number of challenge field elements squeezed on each side.
        const NUM_CHALLENGES: usize = 2048;
        /// Number of distinct values a byte can take.
        const BYTE_VALUES: usize = 256;

        let io = IOPattern::<H, F>::new("test")
            .absorb(1, "in")
            .squeeze(NUM_CHALLENGES, "out");

        // Merlin side: write one element, then draw the challenges.
        let mut merlin = io.to_merlin();
        merlin.add_units(&[F::from(42u32)]).unwrap();

        let mut merlin_challenges = [F::default(); NUM_CHALLENGES];
        merlin.fill_challenge_units(&mut merlin_challenges).unwrap();

        // Arthur side: read the element back from the transcript (its value is
        // discarded), then draw the challenges again.
        let mut arthur = io.to_arthur(merlin.transcript());
        arthur.fill_next_units(&mut [F::default()]).unwrap();
        let mut arthur_challenges = [F::default(); NUM_CHALLENGES];
        arthur.fill_challenge_units(&mut arthur_challenges).unwrap();

        assert_eq!(merlin_challenges, arthur_challenges);

        let chal_bytes: Vec<u8> = merlin_challenges
            .iter()
            .flat_map(|c| c.into_bigint().to_bytes_le())
            .collect();

        // One pass over the bytes instead of one full scan per byte value.
        let mut frequencies = [0usize; BYTE_VALUES];
        for &byte in &chal_bytes {
            frequencies[usize::from(byte)] += 1;
        }

        // Derive the mean from the bytes actually produced. `MODULUS_BIT_SIZE / 8`
        // rounds down (e.g. 255 / 8 = 31), whereas each serialized element
        // occupies whole limbs, so that formula underestimates the true mean.
        let expected_mean = chal_bytes.len() / BYTE_VALUES;
        assert!(
            frequencies
                .iter()
                .all(|&x| x < expected_mean * 2 && x > expected_mean / 2),
            "Counts for each value shouldn't be too far away from mean: {:?}",
            frequencies
        );
    }
}