jf_aead/
lib.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7//! Use `crypto_kx` to derive shared session secrets and use symmetric AEAD
8//! (`xchacha20poly1305`) for authenticated encryption with associated data.
9//!
10//! We only provide an ultra-thin wrapper for stable APIs for jellyfish users,
11//! independent of RustCrypto's upstream changes.
12
13#![cfg_attr(not(feature = "std"), no_std)]
14// Temporarily allow warning for nightly compilation with [`displaydoc`].
15#![allow(warnings)]
16#![deny(missing_docs)]
17#[cfg(test)]
18extern crate std;
19
20#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
21#[doc(hidden)]
22extern crate alloc;
23
24use ark_serialize::*;
25use ark_std::{
26    fmt, format,
27    ops::{Deref, DerefMut},
28    rand::{CryptoRng, RngCore},
29    vec::Vec,
30};
31use chacha20poly1305::{
32    aead::{Aead, AeadCore, Payload},
33    KeyInit, XChaCha20Poly1305, XNonce,
34};
35use displaydoc::Display;
36use serde::{Deserialize, Deserializer, Serialize};
37
38#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Hash)]
39/// Public/encryption key for AEAD
40pub struct EncKey(crypto_kx::PublicKey);
41
42impl From<[u8; 32]> for EncKey {
43    fn from(bytes: [u8; 32]) -> Self {
44        Self(crypto_kx::PublicKey::from(bytes))
45    }
46}
47impl From<EncKey> for [u8; 32] {
48    fn from(enc_key: EncKey) -> Self {
49        *enc_key.0.as_ref()
50    }
51}
52impl From<DecKey> for EncKey {
53    fn from(dec_key: DecKey) -> Self {
54        let enc_key = *crypto_kx::Keypair::from(dec_key.0).public();
55        Self(enc_key)
56    }
57}
58impl Default for EncKey {
59    fn default() -> Self {
60        Self(crypto_kx::PublicKey::from(
61            [0u8; crypto_kx::PublicKey::BYTES],
62        ))
63    }
64}
65impl fmt::Debug for EncKey {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        f.debug_tuple("aead::EncKey")
68            .field(self.0.as_ref())
69            .finish()
70    }
71}
72
73/// AEAD Error.
74// This type is deliberately opaque as in `crypto_kx`.
75#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Display)]
76pub struct AEADError;
77
78impl ark_std::error::Error for AEADError {}
79
80impl EncKey {
81    /// Encrypt a message with authenticated associated data which is an
82    /// optional bytestring which is not encrypted, but is authenticated
83    /// along with the message. Failure to pass the same AAD that was used
84    /// during encryption will cause decryption to fail, which is useful if you
85    /// would like to "bind" the ciphertext to some identifier, like a
86    /// digital signature key.
87    pub fn encrypt(
88        &self,
89        mut rng: impl RngCore + CryptoRng,
90        message: &[u8],
91        aad: &[u8],
92    ) -> Result<Ciphertext, AEADError> {
93        // generate an ephemeral key pair as the virtual sender to derive the crypto box
94        let ephemeral_keypair = crypto_kx::Keypair::generate(&mut rng);
95        // `crypto_kx` generates a pair of shared secrets, see <https://libsodium.gitbook.io/doc/key_exchange>
96        // we use the transmission key of the ephemeral sender (equals to the receiving
97        // key of the server) as the shared secret.
98        let shared_secret = ephemeral_keypair.session_keys_to(&self.0).tx;
99        let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
100        let nonce = XChaCha20Poly1305::generate_nonce(&mut rng);
101
102        // encrypt the message and associated data using crypto box
103        let ct = cipher
104            .encrypt(&nonce, Payload { msg: message, aad })
105            .map_err(|_| AEADError)?;
106
107        Ok(Ciphertext {
108            nonce: Nonce(nonce),
109            ct,
110            ephemeral_pk: EncKey(*ephemeral_keypair.public()),
111        })
112    }
113}
114
115/// Private/decryption key for AEAD
116// look into zeroization logic from aead lib
117#[derive(Clone, Serialize, Deserialize)]
118struct DecKey(crypto_kx::SecretKey);
119
120impl From<[u8; 32]> for DecKey {
121    fn from(bytes: [u8; 32]) -> Self {
122        Self(crypto_kx::SecretKey::from(bytes))
123    }
124}
125impl From<DecKey> for [u8; 32] {
126    fn from(dec_key: DecKey) -> Self {
127        dec_key.0.to_bytes()
128    }
129}
130
131impl Default for DecKey {
132    fn default() -> Self {
133        Self(crypto_kx::SecretKey::from([0; crypto_kx::SecretKey::BYTES]))
134    }
135}
136impl fmt::Debug for DecKey {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        f.debug_tuple("aead::DecKey")
139            .field(&self.0.to_bytes())
140            .finish()
141    }
142}
143
144/// Keypair for Authenticated Encryption with Associated Data
145#[derive(
146    Clone, Debug, Default, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize,
147)]
148pub struct KeyPair {
149    enc_key: EncKey,
150    dec_key: DecKey,
151}
152
153impl PartialEq for KeyPair {
154    fn eq(&self, other: &KeyPair) -> bool {
155        self.enc_key == other.enc_key
156    }
157}
158
159impl KeyPair {
160    /// Randomly sample a key pair.
161    pub fn generate<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
162        let (enc_key, dec_key) = crypto_kx::Keypair::generate(rng).split();
163        Self {
164            enc_key: EncKey(enc_key),
165            dec_key: DecKey(dec_key),
166        }
167    }
168
169    /// Getter for the public/encryption key
170    pub fn enc_key(&self) -> EncKey {
171        self.enc_key.clone()
172    }
173
174    /// Getter for reference to the public/encryption key
175    pub fn enc_key_ref(&self) -> &EncKey {
176        &self.enc_key
177    }
178
179    /// Decrypt a ciphertext with authenticated associated data provided.
180    /// If the associated data is different that that used during encryption,
181    /// then decryption will fail.
182    pub fn decrypt(&self, ciphertext: &Ciphertext, aad: &[u8]) -> Result<Vec<u8>, AEADError> {
183        let shared_secret = crypto_kx::Keypair::from(self.dec_key.0.clone())
184            .session_keys_from(&ciphertext.ephemeral_pk.0)
185            .rx;
186        let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
187        let plaintext = cipher
188            .decrypt(
189                &ciphertext.nonce,
190                Payload {
191                    msg: &ciphertext.ct,
192                    aad,
193                },
194            )
195            .map_err(|_| AEADError)?;
196        Ok(plaintext)
197    }
198}
199// newtype for `chacha20poly1305::XNonce` for easier serde support for
200// `Ciphertext`.
201#[derive(Clone, Debug, PartialEq, Eq, Hash)]
202struct Nonce(XNonce);
203
204impl Serialize for Nonce {
205    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206    where
207        S: serde::Serializer,
208    {
209        serializer.serialize_bytes(self.0.as_slice())
210    }
211}
212
213impl<'de> Deserialize<'de> for Nonce {
214    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215    where
216        D: Deserializer<'de>,
217    {
218        struct NonceVisitor;
219
220        impl<'de> serde::de::Visitor<'de> for NonceVisitor {
221            type Value = Nonce;
222
223            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
224                formatter.write_str("byte array")
225            }
226
227            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
228            where
229                E: serde::de::Error,
230            {
231                Ok(Nonce(*XNonce::from_slice(&v)))
232            }
233
234            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
235            where
236                A: serde::de::SeqAccess<'de>,
237            {
238                let bytes: Vec<u8> = seq
239                    .next_element()?
240                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
241                Ok(Nonce(*XNonce::from_slice(&bytes)))
242            }
243        }
244
245        deserializer.deserialize_byte_buf(NonceVisitor)
246    }
247}
248
249// Deref for newtype which acts like a smart pointer
250impl Deref for Nonce {
251    type Target = XNonce;
252    fn deref(&self) -> &Self::Target {
253        &self.0
254    }
255}
256impl DerefMut for Nonce {
257    fn deref_mut(&mut self) -> &mut Self::Target {
258        &mut self.0
259    }
260}
261
262/// The ciphertext produced by AEAD encryption
263#[derive(
264    Clone,
265    Debug,
266    PartialEq,
267    Eq,
268    Hash,
269    Serialize,
270    Deserialize,
271    CanonicalSerialize,
272    CanonicalDeserialize,
273)]
274pub struct Ciphertext {
275    nonce: Nonce,
276    ct: Vec<u8>,
277    ephemeral_pk: EncKey,
278}
279
280// TODO: (alex) Temporarily add CanonicalSerde back to these structs due to the
281// limitations of `tagged` proc macro and requests from downstream usage.
282// Tracking issue: <https://github.com/EspressoSystems/jellyfish/issues/288>
283mod canonical_serde {
284    use super::*;
285
286    impl CanonicalSerialize for EncKey {
287        fn serialize_with_mode<W: Write>(
288            &self,
289            mut writer: W,
290            _compress: Compress,
291        ) -> Result<(), SerializationError> {
292            let bytes: [u8; crypto_kx::PublicKey::BYTES] = self.clone().into();
293            writer.write_all(&bytes)?;
294            Ok(())
295        }
296        fn serialized_size(&self, _compress: Compress) -> usize {
297            crypto_kx::PublicKey::BYTES
298        }
299    }
300
301    impl CanonicalDeserialize for EncKey {
302        fn deserialize_with_mode<R: Read>(
303            mut reader: R,
304            _compress: Compress,
305            _validate: Validate,
306        ) -> Result<Self, SerializationError> {
307            let mut result = [0u8; crypto_kx::PublicKey::BYTES];
308            reader.read_exact(&mut result)?;
309            Ok(EncKey(crypto_kx::PublicKey::from(result)))
310        }
311    }
312
313    impl Valid for EncKey {
314        fn check(&self) -> Result<(), SerializationError> {
315            Ok(())
316        }
317    }
318
319    impl CanonicalSerialize for DecKey {
320        fn serialize_with_mode<W: Write>(
321            &self,
322            mut writer: W,
323            _compress: Compress,
324        ) -> Result<(), SerializationError> {
325            let bytes: [u8; crypto_kx::SecretKey::BYTES] = self.clone().into();
326            writer.write_all(&bytes)?;
327            Ok(())
328        }
329        fn serialized_size(&self, _compress: Compress) -> usize {
330            crypto_kx::SecretKey::BYTES
331        }
332    }
333
334    impl CanonicalDeserialize for DecKey {
335        fn deserialize_with_mode<R: Read>(
336            mut reader: R,
337            _compress: Compress,
338            _validate: Validate,
339        ) -> Result<Self, SerializationError> {
340            let mut result = [0u8; crypto_kx::SecretKey::BYTES];
341            reader.read_exact(&mut result)?;
342            Ok(DecKey(crypto_kx::SecretKey::from(result)))
343        }
344    }
345    impl Valid for DecKey {
346        fn check(&self) -> Result<(), SerializationError> {
347            Ok(())
348        }
349    }
350
351    impl CanonicalSerialize for Nonce {
352        fn serialize_with_mode<W: Write>(
353            &self,
354            mut writer: W,
355            _compress: Compress,
356        ) -> Result<(), SerializationError> {
357            writer.write_all(self.0.as_slice())?;
358            Ok(())
359        }
360        fn serialized_size(&self, _compress: Compress) -> usize {
361            // see <https://docs.rs/chacha20poly1305/0.10.1/chacha20poly1305/type.XNonce.html>
362            24
363        }
364    }
365
366    impl CanonicalDeserialize for Nonce {
367        fn deserialize_with_mode<R: Read>(
368            mut reader: R,
369            _compress: Compress,
370            _validate: Validate,
371        ) -> Result<Self, SerializationError> {
372            let mut result = [0u8; 24];
373            reader.read_exact(&mut result)?;
374            Ok(Nonce(XNonce::from(result)))
375        }
376    }
377    impl Valid for Nonce {
378        fn check(&self) -> Result<(), SerializationError> {
379            Ok(())
380        }
381    }
382}
383
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    #[test]
    fn test_aead_encryption() -> Result<(), AEADError> {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let keypair1 = KeyPair::generate(&mut rng);
        let keypair2 = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();

        // check correctness (`?` already fails the test on a decryption error)
        let ct1 = keypair1.enc_key.encrypt(&mut rng, &msg, &aad)?;
        let plaintext1 = keypair1.decrypt(&ct1, &aad)?;
        assert_eq!(msg, plaintext1);

        // check soundness
        assert!(keypair2.decrypt(&ct1, &aad).is_err());
        assert!(keypair1.decrypt(&ct1, b"wrong associated data").is_err());
        let ct2 = keypair1.enc_key.encrypt(&mut rng, b"wrong message", &aad)?;
        let plaintext2 = keypair1.decrypt(&ct2, &aad)?;
        assert_ne!(msg, plaintext2);
        assert_eq!(plaintext2, b"wrong message");

        // rng or nonce shouldn't affect decryption
        let rng = ChaCha20Rng::from_seed([1u8; 32]);
        let ct3 = keypair1.enc_key.encrypt(rng, &msg, &aad)?;
        let plaintext3 = keypair1.decrypt(&ct3, &aad)?;
        assert_eq!(msg, plaintext3);

        Ok(())
    }

    #[test]
    fn test_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // serde for Keypair
        let bytes = bincode::serialize(&keypair).unwrap();
        assert_eq!(keypair, bincode::deserialize(&bytes).unwrap());
        // wrong byte length
        assert!(bincode::deserialize::<KeyPair>(&bytes[1..]).is_err());

        // serde for EncKey
        let bytes = bincode::serialize(keypair.enc_key_ref()).unwrap();
        assert_eq!(
            keypair.enc_key_ref(),
            &bincode::deserialize(&bytes).unwrap()
        );
        // wrong byte length
        assert!(bincode::deserialize::<EncKey>(&bytes[1..]).is_err());

        // serde for DecKey
        let bytes = bincode::serialize(&keypair.dec_key).unwrap();
        assert_eq!(
            keypair.dec_key.0.to_bytes(),
            bincode::deserialize::<DecKey>(&bytes).unwrap().0.to_bytes()
        );
        // wrong byte length
        assert!(bincode::deserialize::<DecKey>(&bytes[1..]).is_err());

        // serde for Ciphertext
        let bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(&ciphertext, &bincode::deserialize(&bytes).unwrap());
        // wrong byte length
        assert!(bincode::deserialize::<Ciphertext>(&bytes[1..]).is_err());
    }

    #[test]
    fn test_canonical_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // when testing keypair, already tests serde on pk and sk
        let mut bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&keypair, &mut bytes).unwrap();
        assert_eq!(
            keypair,
            KeyPair::deserialize_compressed(&bytes[..]).unwrap()
        );
        assert!(KeyPair::deserialize_compressed(&bytes[..bytes.len() - 1]).is_err());

        let mut bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&ciphertext, &mut bytes).unwrap();
        assert_eq!(
            ciphertext,
            Ciphertext::deserialize_compressed(&bytes[..]).unwrap()
        );
        assert!(Ciphertext::deserialize_compressed(&bytes[..bytes.len() - 1]).is_err());
    }
}