jf_aead/
lib.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7//! Use `crypto_kx` to derive shared session secrets and use symmetric AEAD
8//! (`xchacha20poly1305`) for authenticated encryption with associated data.
9//!
10//! We only provide an ultra-thin wrapper for stable APIs for jellyfish users,
11//! independent of RustCrypto's upstream changes.
12
13#![cfg_attr(not(feature = "std"), no_std)]
14// Temporarily allow warning for nightly compilation with [`displaydoc`].
15#![allow(warnings)]
16#![deny(missing_docs)]
17#[cfg(test)]
18extern crate std;
19
20#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
21#[doc(hidden)]
22extern crate alloc;
23
24use ark_serialize::*;
25use ark_std::{
26    fmt, format,
27    ops::{Deref, DerefMut},
28    rand::{CryptoRng, RngCore},
29    vec::Vec,
30};
31use chacha20poly1305::{
32    aead::{Aead, AeadCore, Payload},
33    KeyInit, XChaCha20Poly1305, XNonce,
34};
35use derivative::Derivative;
36use displaydoc::Display;
37use serde::{Deserialize, Deserializer, Serialize};
38
39#[derive(Clone, Eq, Derivative, Serialize, Deserialize)]
40#[derivative(PartialEq, Hash)]
41/// Public/encryption key for AEAD
42pub struct EncKey(crypto_kx::PublicKey);
43
44impl From<[u8; 32]> for EncKey {
45    fn from(bytes: [u8; 32]) -> Self {
46        Self(crypto_kx::PublicKey::from(bytes))
47    }
48}
49impl From<EncKey> for [u8; 32] {
50    fn from(enc_key: EncKey) -> Self {
51        *enc_key.0.as_ref()
52    }
53}
54impl From<DecKey> for EncKey {
55    fn from(dec_key: DecKey) -> Self {
56        let enc_key = *crypto_kx::Keypair::from(dec_key.0).public();
57        Self(enc_key)
58    }
59}
60impl Default for EncKey {
61    fn default() -> Self {
62        Self(crypto_kx::PublicKey::from(
63            [0u8; crypto_kx::PublicKey::BYTES],
64        ))
65    }
66}
67impl fmt::Debug for EncKey {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        f.debug_tuple("aead::EncKey")
70            .field(self.0.as_ref())
71            .finish()
72    }
73}
74
75/// AEAD Error.
76// This type is deliberately opaque as in `crypto_kx`.
77#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Display)]
78pub struct AEADError;
79
80impl ark_std::error::Error for AEADError {}
81
82impl EncKey {
83    /// Encrypt a message with authenticated associated data which is an
84    /// optional bytestring which is not encrypted, but is authenticated
85    /// along with the message. Failure to pass the same AAD that was used
86    /// during encryption will cause decryption to fail, which is useful if you
87    /// would like to "bind" the ciphertext to some identifier, like a
88    /// digital signature key.
89    pub fn encrypt(
90        &self,
91        mut rng: impl RngCore + CryptoRng,
92        message: &[u8],
93        aad: &[u8],
94    ) -> Result<Ciphertext, AEADError> {
95        // generate an ephemeral key pair as the virtual sender to derive the crypto box
96        let ephemeral_keypair = crypto_kx::Keypair::generate(&mut rng);
97        // `crypto_kx` generates a pair of shared secrets, see <https://libsodium.gitbook.io/doc/key_exchange>
98        // we use the transmission key of the ephemeral sender (equals to the receiving
99        // key of the server) as the shared secret.
100        let shared_secret = ephemeral_keypair.session_keys_to(&self.0).tx;
101        let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
102        let nonce = XChaCha20Poly1305::generate_nonce(&mut rng);
103
104        // encrypt the message and associated data using crypto box
105        let ct = cipher
106            .encrypt(&nonce, Payload { msg: message, aad })
107            .map_err(|_| AEADError)?;
108
109        Ok(Ciphertext {
110            nonce: Nonce(nonce),
111            ct,
112            ephemeral_pk: EncKey(*ephemeral_keypair.public()),
113        })
114    }
115}
116
117/// Private/decryption key for AEAD
118// look into zeroization logic from aead lib
119#[derive(Clone, Serialize, Deserialize)]
120struct DecKey(crypto_kx::SecretKey);
121
122impl From<[u8; 32]> for DecKey {
123    fn from(bytes: [u8; 32]) -> Self {
124        Self(crypto_kx::SecretKey::from(bytes))
125    }
126}
127impl From<DecKey> for [u8; 32] {
128    fn from(dec_key: DecKey) -> Self {
129        dec_key.0.to_bytes()
130    }
131}
132
133impl Default for DecKey {
134    fn default() -> Self {
135        Self(crypto_kx::SecretKey::from([0; crypto_kx::SecretKey::BYTES]))
136    }
137}
138impl fmt::Debug for DecKey {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        f.debug_tuple("aead::DecKey")
141            .field(&self.0.to_bytes())
142            .finish()
143    }
144}
145
146/// Keypair for Authenticated Encryption with Associated Data
147#[derive(
148    Clone, Debug, Default, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize,
149)]
150pub struct KeyPair {
151    enc_key: EncKey,
152    dec_key: DecKey,
153}
154
155impl PartialEq for KeyPair {
156    fn eq(&self, other: &KeyPair) -> bool {
157        self.enc_key == other.enc_key
158    }
159}
160
161impl KeyPair {
162    /// Randomly sample a key pair.
163    pub fn generate<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
164        let (enc_key, dec_key) = crypto_kx::Keypair::generate(rng).split();
165        Self {
166            enc_key: EncKey(enc_key),
167            dec_key: DecKey(dec_key),
168        }
169    }
170
171    /// Getter for the public/encryption key
172    pub fn enc_key(&self) -> EncKey {
173        self.enc_key.clone()
174    }
175
176    /// Getter for reference to the public/encryption key
177    pub fn enc_key_ref(&self) -> &EncKey {
178        &self.enc_key
179    }
180
181    /// Decrypt a ciphertext with authenticated associated data provided.
182    /// If the associated data is different that that used during encryption,
183    /// then decryption will fail.
184    pub fn decrypt(&self, ciphertext: &Ciphertext, aad: &[u8]) -> Result<Vec<u8>, AEADError> {
185        let shared_secret = crypto_kx::Keypair::from(self.dec_key.0.clone())
186            .session_keys_from(&ciphertext.ephemeral_pk.0)
187            .rx;
188        let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
189        let plaintext = cipher
190            .decrypt(
191                &ciphertext.nonce,
192                Payload {
193                    msg: &ciphertext.ct,
194                    aad,
195                },
196            )
197            .map_err(|_| AEADError)?;
198        Ok(plaintext)
199    }
200}
201// newtype for `chacha20poly1305::XNonce` for easier serde support for
202// `Ciphertext`.
203#[derive(Clone, Debug, PartialEq, Eq, Hash)]
204struct Nonce(XNonce);
205
206impl Serialize for Nonce {
207    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208    where
209        S: serde::Serializer,
210    {
211        serializer.serialize_bytes(self.0.as_slice())
212    }
213}
214
215impl<'de> Deserialize<'de> for Nonce {
216    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
217    where
218        D: Deserializer<'de>,
219    {
220        struct NonceVisitor;
221
222        impl<'de> serde::de::Visitor<'de> for NonceVisitor {
223            type Value = Nonce;
224
225            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
226                formatter.write_str("byte array")
227            }
228
229            fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
230            where
231                E: serde::de::Error,
232            {
233                Ok(Nonce(*XNonce::from_slice(&v)))
234            }
235
236            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
237            where
238                A: serde::de::SeqAccess<'de>,
239            {
240                let bytes: Vec<u8> = seq
241                    .next_element()?
242                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
243                Ok(Nonce(*XNonce::from_slice(&bytes)))
244            }
245        }
246
247        deserializer.deserialize_byte_buf(NonceVisitor)
248    }
249}
250
251// Deref for newtype which acts like a smart pointer
252impl Deref for Nonce {
253    type Target = XNonce;
254    fn deref(&self) -> &Self::Target {
255        &self.0
256    }
257}
258impl DerefMut for Nonce {
259    fn deref_mut(&mut self) -> &mut Self::Target {
260        &mut self.0
261    }
262}
263
264/// The ciphertext produced by AEAD encryption
265#[derive(
266    Clone,
267    Debug,
268    PartialEq,
269    Eq,
270    Hash,
271    Serialize,
272    Deserialize,
273    CanonicalSerialize,
274    CanonicalDeserialize,
275)]
276pub struct Ciphertext {
277    nonce: Nonce,
278    ct: Vec<u8>,
279    ephemeral_pk: EncKey,
280}
281
282// TODO: (alex) Temporarily add CanonicalSerde back to these structs due to the
283// limitations of `tagged` proc macro and requests from downstream usage.
284// Tracking issue: <https://github.com/EspressoSystems/jellyfish/issues/288>
285mod canonical_serde {
286    use super::*;
287
288    impl CanonicalSerialize for EncKey {
289        fn serialize_with_mode<W: Write>(
290            &self,
291            mut writer: W,
292            _compress: Compress,
293        ) -> Result<(), SerializationError> {
294            let bytes: [u8; crypto_kx::PublicKey::BYTES] = self.clone().into();
295            writer.write_all(&bytes)?;
296            Ok(())
297        }
298        fn serialized_size(&self, _compress: Compress) -> usize {
299            crypto_kx::PublicKey::BYTES
300        }
301    }
302
303    impl CanonicalDeserialize for EncKey {
304        fn deserialize_with_mode<R: Read>(
305            mut reader: R,
306            _compress: Compress,
307            _validate: Validate,
308        ) -> Result<Self, SerializationError> {
309            let mut result = [0u8; crypto_kx::PublicKey::BYTES];
310            reader.read_exact(&mut result)?;
311            Ok(EncKey(crypto_kx::PublicKey::from(result)))
312        }
313    }
314
315    impl Valid for EncKey {
316        fn check(&self) -> Result<(), SerializationError> {
317            Ok(())
318        }
319    }
320
321    impl CanonicalSerialize for DecKey {
322        fn serialize_with_mode<W: Write>(
323            &self,
324            mut writer: W,
325            _compress: Compress,
326        ) -> Result<(), SerializationError> {
327            let bytes: [u8; crypto_kx::SecretKey::BYTES] = self.clone().into();
328            writer.write_all(&bytes)?;
329            Ok(())
330        }
331        fn serialized_size(&self, _compress: Compress) -> usize {
332            crypto_kx::SecretKey::BYTES
333        }
334    }
335
336    impl CanonicalDeserialize for DecKey {
337        fn deserialize_with_mode<R: Read>(
338            mut reader: R,
339            _compress: Compress,
340            _validate: Validate,
341        ) -> Result<Self, SerializationError> {
342            let mut result = [0u8; crypto_kx::SecretKey::BYTES];
343            reader.read_exact(&mut result)?;
344            Ok(DecKey(crypto_kx::SecretKey::from(result)))
345        }
346    }
347    impl Valid for DecKey {
348        fn check(&self) -> Result<(), SerializationError> {
349            Ok(())
350        }
351    }
352
353    impl CanonicalSerialize for Nonce {
354        fn serialize_with_mode<W: Write>(
355            &self,
356            mut writer: W,
357            _compress: Compress,
358        ) -> Result<(), SerializationError> {
359            writer.write_all(self.0.as_slice())?;
360            Ok(())
361        }
362        fn serialized_size(&self, _compress: Compress) -> usize {
363            // see <https://docs.rs/chacha20poly1305/0.10.1/chacha20poly1305/type.XNonce.html>
364            24
365        }
366    }
367
368    impl CanonicalDeserialize for Nonce {
369        fn deserialize_with_mode<R: Read>(
370            mut reader: R,
371            _compress: Compress,
372            _validate: Validate,
373        ) -> Result<Self, SerializationError> {
374            let mut result = [0u8; 24];
375            reader.read_exact(&mut result)?;
376            Ok(Nonce(XNonce::from(result)))
377        }
378    }
379    impl Valid for Nonce {
380        fn check(&self) -> Result<(), SerializationError> {
381            Ok(())
382        }
383    }
384}
385
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    /// Round-trip, wrong-key, wrong-AAD and RNG-independence checks.
    #[test]
    fn test_aead_encryption() -> Result<(), AEADError> {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let receiver = KeyPair::generate(&mut rng);
        let stranger = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();

        // check correctness
        let ct1 = receiver.enc_key.encrypt(&mut rng, &msg, &aad)?;
        assert!(receiver.decrypt(&ct1, &aad).is_ok());
        let recovered1 = receiver.decrypt(&ct1, &aad)?;
        assert_eq!(msg, recovered1);

        // check soundness
        assert!(stranger.decrypt(&ct1, &aad).is_err());
        assert!(receiver.decrypt(&ct1, b"wrong associated data").is_err());
        let ct2 = receiver
            .enc_key
            .encrypt(&mut rng, b"wrong message", &aad)?;
        let recovered2 = receiver.decrypt(&ct2, &aad)?;
        assert_ne!(msg, recovered2);

        // rng or nonce shouldn't affect decryption
        let other_rng = ChaCha20Rng::from_seed([1u8; 32]);
        let ct3 = receiver.enc_key.encrypt(other_rng, &msg, &aad)?;
        assert!(receiver.decrypt(&ct3, &aad).is_ok());
        let recovered3 = receiver.decrypt(&ct3, &aad)?;
        assert_eq!(msg, recovered3);

        Ok(())
    }

    /// `bincode` round trips for every serde-enabled type, plus rejection of
    /// inputs with the first byte dropped.
    #[test]
    fn test_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // serde for Keypair
        let keypair_bytes = bincode::serialize(&keypair).unwrap();
        assert_eq!(keypair, bincode::deserialize(&keypair_bytes).unwrap());
        // wrong byte length
        assert!(bincode::deserialize::<KeyPair>(&keypair_bytes[1..]).is_err());

        // serde for EncKey
        let enc_key_bytes = bincode::serialize(keypair.enc_key_ref()).unwrap();
        assert_eq!(
            keypair.enc_key_ref(),
            &bincode::deserialize(&enc_key_bytes).unwrap()
        );
        // wrong byte length
        assert!(bincode::deserialize::<EncKey>(&enc_key_bytes[1..]).is_err());

        // serde for DecKey
        let dec_key_bytes = bincode::serialize(&keypair.dec_key).unwrap();
        let decoded_dec_key = bincode::deserialize::<DecKey>(&dec_key_bytes).unwrap();
        assert_eq!(keypair.dec_key.0.to_bytes(), decoded_dec_key.0.to_bytes());
        // wrong byte length
        assert!(bincode::deserialize::<DecKey>(&dec_key_bytes[1..]).is_err());

        // serde for Ciphertext
        let ct_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(&ciphertext, &bincode::deserialize(&ct_bytes).unwrap());
        // wrong byte length
        assert!(bincode::deserialize::<Ciphertext>(&ct_bytes[1..]).is_err());
    }

    /// Canonical (ark) serialization round trips, plus rejection of inputs
    /// with the first byte dropped.
    #[test]
    fn test_canonical_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // when testing keypair, already tests serde on pk and sk
        let mut keypair_bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&keypair, &mut keypair_bytes).unwrap();
        let decoded_keypair = KeyPair::deserialize_compressed(&keypair_bytes[..]).unwrap();
        assert_eq!(keypair, decoded_keypair);
        assert!(KeyPair::deserialize_compressed(&keypair_bytes[1..]).is_err());

        let mut ct_bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&ciphertext, &mut ct_bytes).unwrap();
        let decoded_ct = Ciphertext::deserialize_compressed(&ct_bytes[..]).unwrap();
        assert_eq!(ciphertext, decoded_ct);
        assert!(Ciphertext::deserialize_compressed(&ct_bytes[1..]).is_err());
    }
}