jf_aead/
lib.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
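//! Authenticated encryption with associated data (AEAD) for Jellyfish.
//!
//! A shared session secret is derived with `crypto_kx` between the recipient's
//! public key and a fresh ephemeral key pair, and the message is then encrypted
//! with the symmetric AEAD cipher XChaCha20-Poly1305 (`chacha20poly1305`).
//!
//! This crate is deliberately an ultra-thin wrapper: it gives Jellyfish users a
//! stable API that is insulated from upstream changes in RustCrypto's crates.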
12
13#![cfg_attr(not(feature = "std"), no_std)]
14#![deny(missing_docs)]
15#[cfg(test)]
16extern crate std;
17
18#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
19#[doc(hidden)]
20extern crate alloc;
21
22use ark_serialize::*;
23use ark_std::{
24    fmt,
25    ops::{Deref, DerefMut},
26    rand::{CryptoRng, RngCore},
27    vec::Vec,
28};
29use chacha20poly1305::{
30    aead::{Aead, AeadCore, Payload},
31    KeyInit, XChaCha20Poly1305, XNonce,
32};
33use displaydoc::Display;
34use serde::{Deserialize, Deserializer, Serialize};
35
36#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Hash)]
37/// Public/encryption key for AEAD
38pub struct EncKey(crypto_kx::PublicKey);
39
40impl From<[u8; 32]> for EncKey {
41    fn from(bytes: [u8; 32]) -> Self {
42        Self(crypto_kx::PublicKey::from(bytes))
43    }
44}
45impl From<EncKey> for [u8; 32] {
46    fn from(enc_key: EncKey) -> Self {
47        *enc_key.0.as_ref()
48    }
49}
50impl From<DecKey> for EncKey {
51    fn from(dec_key: DecKey) -> Self {
52        let enc_key = *crypto_kx::Keypair::from(dec_key.0).public();
53        Self(enc_key)
54    }
55}
56impl Default for EncKey {
57    fn default() -> Self {
58        Self(crypto_kx::PublicKey::from(
59            [0u8; crypto_kx::PublicKey::BYTES],
60        ))
61    }
62}
63impl fmt::Debug for EncKey {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_tuple("aead::EncKey")
66            .field(self.0.as_ref())
67            .finish()
68    }
69}
70
71/// AEAD Error.
72// This type is deliberately opaque as in `crypto_kx`.
73#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Display)]
74pub struct AEADError;
75
76impl ark_std::error::Error for AEADError {}
77
78impl EncKey {
79    /// Encrypt a message with authenticated associated data which is an
80    /// optional bytestring which is not encrypted, but is authenticated
81    /// along with the message. Failure to pass the same AAD that was used
82    /// during encryption will cause decryption to fail, which is useful if you
83    /// would like to "bind" the ciphertext to some identifier, like a
84    /// digital signature key.
85    pub fn encrypt(
86        &self,
87        mut rng: impl RngCore + CryptoRng,
88        message: &[u8],
89        aad: &[u8],
90    ) -> Result<Ciphertext, AEADError> {
91        // generate an ephemeral key pair as the virtual sender to derive the crypto box
92        let ephemeral_keypair = crypto_kx::Keypair::generate(&mut rng);
93        // `crypto_kx` generates a pair of shared secrets, see <https://libsodium.gitbook.io/doc/key_exchange>
94        // we use the transmission key of the ephemeral sender (equals to the receiving
95        // key of the server) as the shared secret.
96        let shared_secret = ephemeral_keypair.session_keys_to(&self.0).tx;
97        let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
98        let nonce = XChaCha20Poly1305::generate_nonce(&mut rng);
99
100        // encrypt the message and associated data using crypto box
101        let ct = cipher
102            .encrypt(&nonce, Payload { msg: message, aad })
103            .map_err(|_| AEADError)?;
104
105        Ok(Ciphertext {
106            nonce: Nonce(nonce),
107            ct,
108            ephemeral_pk: EncKey(*ephemeral_keypair.public()),
109        })
110    }
111}
112
113/// Private/decryption key for AEAD
114// look into zeroization logic from aead lib
115#[derive(Clone, Serialize, Deserialize)]
116struct DecKey(crypto_kx::SecretKey);
117
118impl From<[u8; 32]> for DecKey {
119    fn from(bytes: [u8; 32]) -> Self {
120        Self(crypto_kx::SecretKey::from(bytes))
121    }
122}
123impl From<DecKey> for [u8; 32] {
124    fn from(dec_key: DecKey) -> Self {
125        dec_key.0.to_bytes()
126    }
127}
128
129impl Default for DecKey {
130    fn default() -> Self {
131        Self(crypto_kx::SecretKey::from([0; crypto_kx::SecretKey::BYTES]))
132    }
133}
134impl fmt::Debug for DecKey {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        f.debug_tuple("aead::DecKey")
137            .field(&self.0.to_bytes())
138            .finish()
139    }
140}
141
142/// Keypair for Authenticated Encryption with Associated Data
143#[derive(
144    Clone, Debug, Default, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize,
145)]
146pub struct KeyPair {
147    enc_key: EncKey,
148    dec_key: DecKey,
149}
150
151impl PartialEq for KeyPair {
152    fn eq(&self, other: &KeyPair) -> bool {
153        self.enc_key == other.enc_key
154    }
155}
156
157impl KeyPair {
158    /// Randomly sample a key pair.
159    pub fn generate<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
160        let (enc_key, dec_key) = crypto_kx::Keypair::generate(rng).split();
161        Self {
162            enc_key: EncKey(enc_key),
163            dec_key: DecKey(dec_key),
164        }
165    }
166
167    /// Getter for the public/encryption key
168    pub fn enc_key(&self) -> EncKey {
169        self.enc_key.clone()
170    }
171
172    /// Getter for reference to the public/encryption key
173    pub fn enc_key_ref(&self) -> &EncKey {
174        &self.enc_key
175    }
176
177    /// Decrypt a ciphertext with authenticated associated data provided.
178    /// If the associated data is different that that used during encryption,
179    /// then decryption will fail.
180    pub fn decrypt(&self, ciphertext: &Ciphertext, aad: &[u8]) -> Result<Vec<u8>, AEADError> {
181        let shared_secret = crypto_kx::Keypair::from(self.dec_key.0.clone())
182            .session_keys_from(&ciphertext.ephemeral_pk.0)
183            .rx;
184        let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
185        let plaintext = cipher
186            .decrypt(
187                &ciphertext.nonce,
188                Payload {
189                    msg: &ciphertext.ct,
190                    aad,
191                },
192            )
193            .map_err(|_| AEADError)?;
194        Ok(plaintext)
195    }
196}
197// newtype for `chacha20poly1305::XNonce` for easier serde support for
198// `Ciphertext`.
199#[derive(Clone, Debug, PartialEq, Eq, Hash)]
200struct Nonce(XNonce);
201
202impl Serialize for Nonce {
203    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
204    where
205        S: serde::Serializer,
206    {
207        serializer.serialize_bytes(self.0.as_slice())
208    }
209}
210
211impl<'de> Deserialize<'de> for Nonce {
212    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
213    where
214        D: Deserializer<'de>,
215    {
216        struct NonceVisitor;
217
218        impl<'de> serde::de::Visitor<'de> for NonceVisitor {
219            type Value = Nonce;
220
221            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
222                formatter.write_str("byte array")
223            }
224
225            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
226            where
227                E: serde::de::Error,
228            {
229                Ok(Nonce(*XNonce::from_slice(&v)))
230            }
231
232            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
233            where
234                A: serde::de::SeqAccess<'de>,
235            {
236                let bytes: Vec<u8> = seq
237                    .next_element()?
238                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
239                Ok(Nonce(*XNonce::from_slice(&bytes)))
240            }
241        }
242
243        deserializer.deserialize_byte_buf(NonceVisitor)
244    }
245}
246
247// Deref for newtype which acts like a smart pointer
248impl Deref for Nonce {
249    type Target = XNonce;
250    fn deref(&self) -> &Self::Target {
251        &self.0
252    }
253}
254impl DerefMut for Nonce {
255    fn deref_mut(&mut self) -> &mut Self::Target {
256        &mut self.0
257    }
258}
259
260/// The ciphertext produced by AEAD encryption
261#[derive(
262    Clone,
263    Debug,
264    PartialEq,
265    Eq,
266    Hash,
267    Serialize,
268    Deserialize,
269    CanonicalSerialize,
270    CanonicalDeserialize,
271)]
272pub struct Ciphertext {
273    nonce: Nonce,
274    ct: Vec<u8>,
275    ephemeral_pk: EncKey,
276}
277
278// TODO: (alex) Temporarily add CanonicalSerde back to these structs due to the
279// limitations of `tagged` proc macro and requests from downstream usage.
280// Tracking issue: <https://github.com/EspressoSystems/jellyfish/issues/288>
281mod canonical_serde {
282    use super::*;
283
284    impl CanonicalSerialize for EncKey {
285        fn serialize_with_mode<W: Write>(
286            &self,
287            mut writer: W,
288            _compress: Compress,
289        ) -> Result<(), SerializationError> {
290            let bytes: [u8; crypto_kx::PublicKey::BYTES] = self.clone().into();
291            writer.write_all(&bytes)?;
292            Ok(())
293        }
294        fn serialized_size(&self, _compress: Compress) -> usize {
295            crypto_kx::PublicKey::BYTES
296        }
297    }
298
299    impl CanonicalDeserialize for EncKey {
300        fn deserialize_with_mode<R: Read>(
301            mut reader: R,
302            _compress: Compress,
303            _validate: Validate,
304        ) -> Result<Self, SerializationError> {
305            let mut result = [0u8; crypto_kx::PublicKey::BYTES];
306            reader.read_exact(&mut result)?;
307            Ok(EncKey(crypto_kx::PublicKey::from(result)))
308        }
309    }
310
311    impl Valid for EncKey {
312        fn check(&self) -> Result<(), SerializationError> {
313            Ok(())
314        }
315    }
316
317    impl CanonicalSerialize for DecKey {
318        fn serialize_with_mode<W: Write>(
319            &self,
320            mut writer: W,
321            _compress: Compress,
322        ) -> Result<(), SerializationError> {
323            let bytes: [u8; crypto_kx::SecretKey::BYTES] = self.clone().into();
324            writer.write_all(&bytes)?;
325            Ok(())
326        }
327        fn serialized_size(&self, _compress: Compress) -> usize {
328            crypto_kx::SecretKey::BYTES
329        }
330    }
331
332    impl CanonicalDeserialize for DecKey {
333        fn deserialize_with_mode<R: Read>(
334            mut reader: R,
335            _compress: Compress,
336            _validate: Validate,
337        ) -> Result<Self, SerializationError> {
338            let mut result = [0u8; crypto_kx::SecretKey::BYTES];
339            reader.read_exact(&mut result)?;
340            Ok(DecKey(crypto_kx::SecretKey::from(result)))
341        }
342    }
343    impl Valid for DecKey {
344        fn check(&self) -> Result<(), SerializationError> {
345            Ok(())
346        }
347    }
348
349    impl CanonicalSerialize for Nonce {
350        fn serialize_with_mode<W: Write>(
351            &self,
352            mut writer: W,
353            _compress: Compress,
354        ) -> Result<(), SerializationError> {
355            writer.write_all(self.0.as_slice())?;
356            Ok(())
357        }
358        fn serialized_size(&self, _compress: Compress) -> usize {
359            // see <https://docs.rs/chacha20poly1305/0.10.1/chacha20poly1305/type.XNonce.html>
360            24
361        }
362    }
363
364    impl CanonicalDeserialize for Nonce {
365        fn deserialize_with_mode<R: Read>(
366            mut reader: R,
367            _compress: Compress,
368            _validate: Validate,
369        ) -> Result<Self, SerializationError> {
370            let mut result = [0u8; 24];
371            reader.read_exact(&mut result)?;
372            Ok(Nonce(XNonce::from(result)))
373        }
374    }
375    impl Valid for Nonce {
376        fn check(&self) -> Result<(), SerializationError> {
377            Ok(())
378        }
379    }
380}
381
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    // Encryption/decryption round trips, plus rejection of a foreign key pair
    // and of mismatched associated data.
    #[test]
    fn test_aead_encryption() -> Result<(), AEADError> {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let recipient = KeyPair::generate(&mut rng);
        let outsider = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();

        // correctness: the intended recipient recovers the message
        let ct1 = recipient.enc_key.encrypt(&mut rng, &msg, &aad)?;
        assert!(recipient.decrypt(&ct1, &aad).is_ok());
        assert_eq!(recipient.decrypt(&ct1, &aad)?, msg);

        // soundness: another key pair or different associated data is rejected
        assert!(outsider.decrypt(&ct1, &aad).is_err());
        assert!(recipient.decrypt(&ct1, b"wrong associated data").is_err());
        // ...and a different message decrypts to a different plaintext
        let ct2 = recipient.enc_key.encrypt(&mut rng, b"wrong message", &aad)?;
        assert_ne!(recipient.decrypt(&ct2, &aad)?, msg);

        // a different rng (hence a different nonce and ephemeral key) must not
        // affect decryption
        let fresh_rng = ChaCha20Rng::from_seed([1u8; 32]);
        let ct3 = recipient.enc_key.encrypt(fresh_rng, &msg, &aad)?;
        assert!(recipient.decrypt(&ct3, &aad).is_ok());
        assert_eq!(recipient.decrypt(&ct3, &aad)?, msg);

        Ok(())
    }

    // serde (bincode) round trips for every serializable type, plus rejection
    // of truncated input.
    #[test]
    fn test_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // KeyPair
        let encoded = bincode::serialize(&keypair).unwrap();
        let decoded: KeyPair = bincode::deserialize(&encoded).unwrap();
        assert_eq!(keypair, decoded);
        assert!(bincode::deserialize::<KeyPair>(&encoded[1..]).is_err());

        // EncKey
        let encoded = bincode::serialize(keypair.enc_key_ref()).unwrap();
        let decoded: EncKey = bincode::deserialize(&encoded).unwrap();
        assert_eq!(keypair.enc_key_ref(), &decoded);
        assert!(bincode::deserialize::<EncKey>(&encoded[1..]).is_err());

        // DecKey has no `PartialEq`, so compare the raw key bytes
        let encoded = bincode::serialize(&keypair.dec_key).unwrap();
        let decoded: DecKey = bincode::deserialize(&encoded).unwrap();
        assert_eq!(keypair.dec_key.0.to_bytes(), decoded.0.to_bytes());
        assert!(bincode::deserialize::<DecKey>(&encoded[1..]).is_err());

        // Ciphertext
        let encoded = bincode::serialize(&ciphertext).unwrap();
        let decoded: Ciphertext = bincode::deserialize(&encoded).unwrap();
        assert_eq!(ciphertext, decoded);
        assert!(bincode::deserialize::<Ciphertext>(&encoded[1..]).is_err());
    }

    // Canonical (ark-serialize) round trips, plus rejection of truncated
    // input.
    #[test]
    fn test_canonical_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // a KeyPair round trip also exercises EncKey and DecKey
        let mut encoded = Vec::new();
        CanonicalSerialize::serialize_compressed(&keypair, &mut encoded).unwrap();
        let decoded = KeyPair::deserialize_compressed(&encoded[..]).unwrap();
        assert_eq!(keypair, decoded);
        let truncated = &encoded[..encoded.len() - 1];
        assert!(KeyPair::deserialize_compressed(truncated).is_err());

        let mut encoded = Vec::new();
        CanonicalSerialize::serialize_compressed(&ciphertext, &mut encoded).unwrap();
        let decoded = Ciphertext::deserialize_compressed(&encoded[..]).unwrap();
        assert_eq!(ciphertext, decoded);
        let truncated = &encoded[..encoded.len() - 1];
        assert!(Ciphertext::deserialize_compressed(truncated).is_err());
    }
}