1#![cfg_attr(not(feature = "std"), no_std)]
14#![deny(missing_docs)]
15#[cfg(test)]
16extern crate std;
17
18#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
19#[doc(hidden)]
20extern crate alloc;
21
22use ark_serialize::*;
23use ark_std::{
24 fmt,
25 ops::{Deref, DerefMut},
26 rand::{CryptoRng, RngCore},
27 vec::Vec,
28};
29use chacha20poly1305::{
30 aead::{Aead, AeadCore, Payload},
31 KeyInit, XChaCha20Poly1305, XNonce,
32};
33use displaydoc::Display;
34use serde::{Deserialize, Deserializer, Serialize};
35
36#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Hash)]
37pub struct EncKey(crypto_kx::PublicKey);
39
40impl From<[u8; 32]> for EncKey {
41 fn from(bytes: [u8; 32]) -> Self {
42 Self(crypto_kx::PublicKey::from(bytes))
43 }
44}
45impl From<EncKey> for [u8; 32] {
46 fn from(enc_key: EncKey) -> Self {
47 *enc_key.0.as_ref()
48 }
49}
50impl From<DecKey> for EncKey {
51 fn from(dec_key: DecKey) -> Self {
52 let enc_key = *crypto_kx::Keypair::from(dec_key.0).public();
53 Self(enc_key)
54 }
55}
56impl Default for EncKey {
57 fn default() -> Self {
58 Self(crypto_kx::PublicKey::from(
59 [0u8; crypto_kx::PublicKey::BYTES],
60 ))
61 }
62}
63impl fmt::Debug for EncKey {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_tuple("aead::EncKey")
66 .field(self.0.as_ref())
67 .finish()
68 }
69}
70
71#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Display)]
74pub struct AEADError;
75
76impl ark_std::error::Error for AEADError {}
77
78impl EncKey {
79 pub fn encrypt(
86 &self,
87 mut rng: impl RngCore + CryptoRng,
88 message: &[u8],
89 aad: &[u8],
90 ) -> Result<Ciphertext, AEADError> {
91 let ephemeral_keypair = crypto_kx::Keypair::generate(&mut rng);
93 let shared_secret = ephemeral_keypair.session_keys_to(&self.0).tx;
97 let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
98 let nonce = XChaCha20Poly1305::generate_nonce(&mut rng);
99
100 let ct = cipher
102 .encrypt(&nonce, Payload { msg: message, aad })
103 .map_err(|_| AEADError)?;
104
105 Ok(Ciphertext {
106 nonce: Nonce(nonce),
107 ct,
108 ephemeral_pk: EncKey(*ephemeral_keypair.public()),
109 })
110 }
111}
112
113#[derive(Clone, Serialize, Deserialize)]
116struct DecKey(crypto_kx::SecretKey);
117
118impl From<[u8; 32]> for DecKey {
119 fn from(bytes: [u8; 32]) -> Self {
120 Self(crypto_kx::SecretKey::from(bytes))
121 }
122}
123impl From<DecKey> for [u8; 32] {
124 fn from(dec_key: DecKey) -> Self {
125 dec_key.0.to_bytes()
126 }
127}
128
129impl Default for DecKey {
130 fn default() -> Self {
131 Self(crypto_kx::SecretKey::from([0; crypto_kx::SecretKey::BYTES]))
132 }
133}
134impl fmt::Debug for DecKey {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_tuple("aead::DecKey")
137 .field(&self.0.to_bytes())
138 .finish()
139 }
140}
141
142#[derive(
144 Clone, Debug, Default, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize,
145)]
146pub struct KeyPair {
147 enc_key: EncKey,
148 dec_key: DecKey,
149}
150
151impl PartialEq for KeyPair {
152 fn eq(&self, other: &KeyPair) -> bool {
153 self.enc_key == other.enc_key
154 }
155}
156
157impl KeyPair {
158 pub fn generate<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
160 let (enc_key, dec_key) = crypto_kx::Keypair::generate(rng).split();
161 Self {
162 enc_key: EncKey(enc_key),
163 dec_key: DecKey(dec_key),
164 }
165 }
166
167 pub fn enc_key(&self) -> EncKey {
169 self.enc_key.clone()
170 }
171
172 pub fn enc_key_ref(&self) -> &EncKey {
174 &self.enc_key
175 }
176
177 pub fn decrypt(&self, ciphertext: &Ciphertext, aad: &[u8]) -> Result<Vec<u8>, AEADError> {
181 let shared_secret = crypto_kx::Keypair::from(self.dec_key.0.clone())
182 .session_keys_from(&ciphertext.ephemeral_pk.0)
183 .rx;
184 let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
185 let plaintext = cipher
186 .decrypt(
187 &ciphertext.nonce,
188 Payload {
189 msg: &ciphertext.ct,
190 aad,
191 },
192 )
193 .map_err(|_| AEADError)?;
194 Ok(plaintext)
195 }
196}
197#[derive(Clone, Debug, PartialEq, Eq, Hash)]
200struct Nonce(XNonce);
201
202impl Serialize for Nonce {
203 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
204 where
205 S: serde::Serializer,
206 {
207 serializer.serialize_bytes(self.0.as_slice())
208 }
209}
210
211impl<'de> Deserialize<'de> for Nonce {
212 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
213 where
214 D: Deserializer<'de>,
215 {
216 struct NonceVisitor;
217
218 impl<'de> serde::de::Visitor<'de> for NonceVisitor {
219 type Value = Nonce;
220
221 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
222 formatter.write_str("byte array")
223 }
224
225 fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
226 where
227 E: serde::de::Error,
228 {
229 Ok(Nonce(*XNonce::from_slice(&v)))
230 }
231
232 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
233 where
234 A: serde::de::SeqAccess<'de>,
235 {
236 let bytes: Vec<u8> = seq
237 .next_element()?
238 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
239 Ok(Nonce(*XNonce::from_slice(&bytes)))
240 }
241 }
242
243 deserializer.deserialize_byte_buf(NonceVisitor)
244 }
245}
246
247impl Deref for Nonce {
249 type Target = XNonce;
250 fn deref(&self) -> &Self::Target {
251 &self.0
252 }
253}
254impl DerefMut for Nonce {
255 fn deref_mut(&mut self) -> &mut Self::Target {
256 &mut self.0
257 }
258}
259
260#[derive(
262 Clone,
263 Debug,
264 PartialEq,
265 Eq,
266 Hash,
267 Serialize,
268 Deserialize,
269 CanonicalSerialize,
270 CanonicalDeserialize,
271)]
272pub struct Ciphertext {
273 nonce: Nonce,
274 ct: Vec<u8>,
275 ephemeral_pk: EncKey,
276}
277
278mod canonical_serde {
282 use super::*;
283
284 impl CanonicalSerialize for EncKey {
285 fn serialize_with_mode<W: Write>(
286 &self,
287 mut writer: W,
288 _compress: Compress,
289 ) -> Result<(), SerializationError> {
290 let bytes: [u8; crypto_kx::PublicKey::BYTES] = self.clone().into();
291 writer.write_all(&bytes)?;
292 Ok(())
293 }
294 fn serialized_size(&self, _compress: Compress) -> usize {
295 crypto_kx::PublicKey::BYTES
296 }
297 }
298
299 impl CanonicalDeserialize for EncKey {
300 fn deserialize_with_mode<R: Read>(
301 mut reader: R,
302 _compress: Compress,
303 _validate: Validate,
304 ) -> Result<Self, SerializationError> {
305 let mut result = [0u8; crypto_kx::PublicKey::BYTES];
306 reader.read_exact(&mut result)?;
307 Ok(EncKey(crypto_kx::PublicKey::from(result)))
308 }
309 }
310
311 impl Valid for EncKey {
312 fn check(&self) -> Result<(), SerializationError> {
313 Ok(())
314 }
315 }
316
317 impl CanonicalSerialize for DecKey {
318 fn serialize_with_mode<W: Write>(
319 &self,
320 mut writer: W,
321 _compress: Compress,
322 ) -> Result<(), SerializationError> {
323 let bytes: [u8; crypto_kx::SecretKey::BYTES] = self.clone().into();
324 writer.write_all(&bytes)?;
325 Ok(())
326 }
327 fn serialized_size(&self, _compress: Compress) -> usize {
328 crypto_kx::SecretKey::BYTES
329 }
330 }
331
332 impl CanonicalDeserialize for DecKey {
333 fn deserialize_with_mode<R: Read>(
334 mut reader: R,
335 _compress: Compress,
336 _validate: Validate,
337 ) -> Result<Self, SerializationError> {
338 let mut result = [0u8; crypto_kx::SecretKey::BYTES];
339 reader.read_exact(&mut result)?;
340 Ok(DecKey(crypto_kx::SecretKey::from(result)))
341 }
342 }
343 impl Valid for DecKey {
344 fn check(&self) -> Result<(), SerializationError> {
345 Ok(())
346 }
347 }
348
349 impl CanonicalSerialize for Nonce {
350 fn serialize_with_mode<W: Write>(
351 &self,
352 mut writer: W,
353 _compress: Compress,
354 ) -> Result<(), SerializationError> {
355 writer.write_all(self.0.as_slice())?;
356 Ok(())
357 }
358 fn serialized_size(&self, _compress: Compress) -> usize {
359 24
361 }
362 }
363
364 impl CanonicalDeserialize for Nonce {
365 fn deserialize_with_mode<R: Read>(
366 mut reader: R,
367 _compress: Compress,
368 _validate: Validate,
369 ) -> Result<Self, SerializationError> {
370 let mut result = [0u8; 24];
371 reader.read_exact(&mut result)?;
372 Ok(Nonce(XNonce::from(result)))
373 }
374 }
375 impl Valid for Nonce {
376 fn check(&self) -> Result<(), SerializationError> {
377 Ok(())
378 }
379 }
380}
381
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    #[test]
    fn test_aead_encryption() -> Result<(), AEADError> {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let keypair1 = KeyPair::generate(&mut rng);
        let keypair2 = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();

        // Round trip with the intended recipient (`?` fails the test on error).
        let ct1 = keypair1.enc_key.encrypt(&mut rng, &msg, &aad)?;
        let plaintext1 = keypair1.decrypt(&ct1, &aad)?;
        assert_eq!(msg, plaintext1);

        // Wrong recipient and wrong associated data must both fail authentication.
        assert!(keypair2.decrypt(&ct1, &aad).is_err());
        assert!(keypair1.decrypt(&ct1, b"wrong associated data").is_err());

        // A different message decrypts to that message, not to `msg`.
        let ct2 = keypair1.enc_key.encrypt(&mut rng, b"wrong message", &aad)?;
        let plaintext2 = keypair1.decrypt(&ct2, &aad)?;
        assert_ne!(msg, plaintext2);
        assert_eq!(plaintext2, b"wrong message".to_vec());

        // `encrypt` also accepts the RNG by value.
        let rng = ChaCha20Rng::from_seed([1u8; 32]);
        let ct3 = keypair1.enc_key.encrypt(rng, &msg, &aad)?;
        let plaintext3 = keypair1.decrypt(&ct3, &aad)?;
        assert_eq!(msg, plaintext3);

        Ok(())
    }

    #[test]
    fn test_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // KeyPair round trip; a truncated encoding must be rejected.
        let bytes = bincode::serialize(&keypair).unwrap();
        assert_eq!(keypair, bincode::deserialize(&bytes).unwrap());
        assert!(bincode::deserialize::<KeyPair>(&bytes[1..]).is_err());

        // EncKey round trip.
        let bytes = bincode::serialize(keypair.enc_key_ref()).unwrap();
        assert_eq!(
            keypair.enc_key_ref(),
            &bincode::deserialize(&bytes).unwrap()
        );
        assert!(bincode::deserialize::<EncKey>(&bytes[1..]).is_err());

        // DecKey has no PartialEq, so compare the raw secret bytes.
        let bytes = bincode::serialize(&keypair.dec_key).unwrap();
        assert_eq!(
            keypair.dec_key.0.to_bytes(),
            bincode::deserialize::<DecKey>(&bytes).unwrap().0.to_bytes()
        );
        assert!(bincode::deserialize::<DecKey>(&bytes[1..]).is_err());

        // Ciphertext round trip.
        let bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(&ciphertext, &bincode::deserialize(&bytes).unwrap());
        assert!(bincode::deserialize::<Ciphertext>(&bytes[1..]).is_err());
    }

    #[test]
    fn test_canonical_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // KeyPair round trip; dropping the last byte must be rejected.
        let mut bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&keypair, &mut bytes).unwrap();
        assert_eq!(
            keypair,
            KeyPair::deserialize_compressed(&bytes[..]).unwrap()
        );
        assert!(KeyPair::deserialize_compressed(&bytes[..bytes.len() - 1]).is_err());

        // Ciphertext round trip; dropping the last byte must be rejected.
        let mut bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&ciphertext, &mut bytes).unwrap();
        assert_eq!(
            ciphertext,
            Ciphertext::deserialize_compressed(&bytes[..]).unwrap()
        );
        assert!(Ciphertext::deserialize_compressed(&bytes[..bytes.len() - 1]).is_err());
    }
}