//! Authenticated encryption with associated data (AEAD) addressed to a
//! recipient's public key.
//!
//! Encryption generates an ephemeral `crypto_kx` keypair, performs a key
//! exchange against the recipient's [`EncKey`] to derive a shared session key,
//! and encrypts with `XChaCha20Poly1305` (see [`EncKey::encrypt`]). The
//! recipient re-derives the same key from its secret key and the ephemeral
//! public key carried in the [`Ciphertext`] (see [`KeyPair::decrypt`]).
#![cfg_attr(not(feature = "std"), no_std)]
// NOTE(review): blanket allow of all warnings; the reason is not visible here
// (possibly lint noise from a derive macro) — consider narrowing it.
#![allow(warnings)]
// Re-enables `missing_docs` as a hard error after the blanket allow above:
// every public item must carry a doc comment.
#![deny(missing_docs)]
#[cfg(test)]
extern crate std;

#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
#[doc(hidden)]
extern crate alloc;
23
24use ark_serialize::*;
25use ark_std::{
26 fmt, format,
27 ops::{Deref, DerefMut},
28 rand::{CryptoRng, RngCore},
29 vec::Vec,
30};
31use chacha20poly1305::{
32 aead::{Aead, AeadCore, Payload},
33 KeyInit, XChaCha20Poly1305, XNonce,
34};
35use displaydoc::Display;
36use serde::{Deserialize, Deserializer, Serialize};
37
38#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Hash)]
39pub struct EncKey(crypto_kx::PublicKey);
41
42impl From<[u8; 32]> for EncKey {
43 fn from(bytes: [u8; 32]) -> Self {
44 Self(crypto_kx::PublicKey::from(bytes))
45 }
46}
47impl From<EncKey> for [u8; 32] {
48 fn from(enc_key: EncKey) -> Self {
49 *enc_key.0.as_ref()
50 }
51}
52impl From<DecKey> for EncKey {
53 fn from(dec_key: DecKey) -> Self {
54 let enc_key = *crypto_kx::Keypair::from(dec_key.0).public();
55 Self(enc_key)
56 }
57}
58impl Default for EncKey {
59 fn default() -> Self {
60 Self(crypto_kx::PublicKey::from(
61 [0u8; crypto_kx::PublicKey::BYTES],
62 ))
63 }
64}
65impl fmt::Debug for EncKey {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 f.debug_tuple("aead::EncKey")
68 .field(self.0.as_ref())
69 .finish()
70 }
71}
72
73#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Display)]
76pub struct AEADError;
77
78impl ark_std::error::Error for AEADError {}
79
80impl EncKey {
81 pub fn encrypt(
88 &self,
89 mut rng: impl RngCore + CryptoRng,
90 message: &[u8],
91 aad: &[u8],
92 ) -> Result<Ciphertext, AEADError> {
93 let ephemeral_keypair = crypto_kx::Keypair::generate(&mut rng);
95 let shared_secret = ephemeral_keypair.session_keys_to(&self.0).tx;
99 let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
100 let nonce = XChaCha20Poly1305::generate_nonce(&mut rng);
101
102 let ct = cipher
104 .encrypt(&nonce, Payload { msg: message, aad })
105 .map_err(|_| AEADError)?;
106
107 Ok(Ciphertext {
108 nonce: Nonce(nonce),
109 ct,
110 ephemeral_pk: EncKey(*ephemeral_keypair.public()),
111 })
112 }
113}
114
115#[derive(Clone, Serialize, Deserialize)]
118struct DecKey(crypto_kx::SecretKey);
119
120impl From<[u8; 32]> for DecKey {
121 fn from(bytes: [u8; 32]) -> Self {
122 Self(crypto_kx::SecretKey::from(bytes))
123 }
124}
125impl From<DecKey> for [u8; 32] {
126 fn from(dec_key: DecKey) -> Self {
127 dec_key.0.to_bytes()
128 }
129}
130
131impl Default for DecKey {
132 fn default() -> Self {
133 Self(crypto_kx::SecretKey::from([0; crypto_kx::SecretKey::BYTES]))
134 }
135}
136impl fmt::Debug for DecKey {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_tuple("aead::DecKey")
139 .field(&self.0.to_bytes())
140 .finish()
141 }
142}
143
144#[derive(
146 Clone, Debug, Default, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize,
147)]
148pub struct KeyPair {
149 enc_key: EncKey,
150 dec_key: DecKey,
151}
152
153impl PartialEq for KeyPair {
154 fn eq(&self, other: &KeyPair) -> bool {
155 self.enc_key == other.enc_key
156 }
157}
158
159impl KeyPair {
160 pub fn generate<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
162 let (enc_key, dec_key) = crypto_kx::Keypair::generate(rng).split();
163 Self {
164 enc_key: EncKey(enc_key),
165 dec_key: DecKey(dec_key),
166 }
167 }
168
169 pub fn enc_key(&self) -> EncKey {
171 self.enc_key.clone()
172 }
173
174 pub fn enc_key_ref(&self) -> &EncKey {
176 &self.enc_key
177 }
178
179 pub fn decrypt(&self, ciphertext: &Ciphertext, aad: &[u8]) -> Result<Vec<u8>, AEADError> {
183 let shared_secret = crypto_kx::Keypair::from(self.dec_key.0.clone())
184 .session_keys_from(&ciphertext.ephemeral_pk.0)
185 .rx;
186 let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
187 let plaintext = cipher
188 .decrypt(
189 &ciphertext.nonce,
190 Payload {
191 msg: &ciphertext.ct,
192 aad,
193 },
194 )
195 .map_err(|_| AEADError)?;
196 Ok(plaintext)
197 }
198}
199#[derive(Clone, Debug, PartialEq, Eq, Hash)]
202struct Nonce(XNonce);
203
204impl Serialize for Nonce {
205 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206 where
207 S: serde::Serializer,
208 {
209 serializer.serialize_bytes(self.0.as_slice())
210 }
211}
212
213impl<'de> Deserialize<'de> for Nonce {
214 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215 where
216 D: Deserializer<'de>,
217 {
218 struct NonceVisitor;
219
220 impl<'de> serde::de::Visitor<'de> for NonceVisitor {
221 type Value = Nonce;
222
223 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
224 formatter.write_str("byte array")
225 }
226
227 fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
228 where
229 E: serde::de::Error,
230 {
231 Ok(Nonce(*XNonce::from_slice(&v)))
232 }
233
234 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
235 where
236 A: serde::de::SeqAccess<'de>,
237 {
238 let bytes: Vec<u8> = seq
239 .next_element()?
240 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
241 Ok(Nonce(*XNonce::from_slice(&bytes)))
242 }
243 }
244
245 deserializer.deserialize_byte_buf(NonceVisitor)
246 }
247}
248
249impl Deref for Nonce {
251 type Target = XNonce;
252 fn deref(&self) -> &Self::Target {
253 &self.0
254 }
255}
256impl DerefMut for Nonce {
257 fn deref_mut(&mut self) -> &mut Self::Target {
258 &mut self.0
259 }
260}
261
262#[derive(
264 Clone,
265 Debug,
266 PartialEq,
267 Eq,
268 Hash,
269 Serialize,
270 Deserialize,
271 CanonicalSerialize,
272 CanonicalDeserialize,
273)]
274pub struct Ciphertext {
275 nonce: Nonce,
276 ct: Vec<u8>,
277 ephemeral_pk: EncKey,
278}
279
280mod canonical_serde {
284 use super::*;
285
286 impl CanonicalSerialize for EncKey {
287 fn serialize_with_mode<W: Write>(
288 &self,
289 mut writer: W,
290 _compress: Compress,
291 ) -> Result<(), SerializationError> {
292 let bytes: [u8; crypto_kx::PublicKey::BYTES] = self.clone().into();
293 writer.write_all(&bytes)?;
294 Ok(())
295 }
296 fn serialized_size(&self, _compress: Compress) -> usize {
297 crypto_kx::PublicKey::BYTES
298 }
299 }
300
301 impl CanonicalDeserialize for EncKey {
302 fn deserialize_with_mode<R: Read>(
303 mut reader: R,
304 _compress: Compress,
305 _validate: Validate,
306 ) -> Result<Self, SerializationError> {
307 let mut result = [0u8; crypto_kx::PublicKey::BYTES];
308 reader.read_exact(&mut result)?;
309 Ok(EncKey(crypto_kx::PublicKey::from(result)))
310 }
311 }
312
313 impl Valid for EncKey {
314 fn check(&self) -> Result<(), SerializationError> {
315 Ok(())
316 }
317 }
318
319 impl CanonicalSerialize for DecKey {
320 fn serialize_with_mode<W: Write>(
321 &self,
322 mut writer: W,
323 _compress: Compress,
324 ) -> Result<(), SerializationError> {
325 let bytes: [u8; crypto_kx::SecretKey::BYTES] = self.clone().into();
326 writer.write_all(&bytes)?;
327 Ok(())
328 }
329 fn serialized_size(&self, _compress: Compress) -> usize {
330 crypto_kx::SecretKey::BYTES
331 }
332 }
333
334 impl CanonicalDeserialize for DecKey {
335 fn deserialize_with_mode<R: Read>(
336 mut reader: R,
337 _compress: Compress,
338 _validate: Validate,
339 ) -> Result<Self, SerializationError> {
340 let mut result = [0u8; crypto_kx::SecretKey::BYTES];
341 reader.read_exact(&mut result)?;
342 Ok(DecKey(crypto_kx::SecretKey::from(result)))
343 }
344 }
345 impl Valid for DecKey {
346 fn check(&self) -> Result<(), SerializationError> {
347 Ok(())
348 }
349 }
350
351 impl CanonicalSerialize for Nonce {
352 fn serialize_with_mode<W: Write>(
353 &self,
354 mut writer: W,
355 _compress: Compress,
356 ) -> Result<(), SerializationError> {
357 writer.write_all(self.0.as_slice())?;
358 Ok(())
359 }
360 fn serialized_size(&self, _compress: Compress) -> usize {
361 24
363 }
364 }
365
366 impl CanonicalDeserialize for Nonce {
367 fn deserialize_with_mode<R: Read>(
368 mut reader: R,
369 _compress: Compress,
370 _validate: Validate,
371 ) -> Result<Self, SerializationError> {
372 let mut result = [0u8; 24];
373 reader.read_exact(&mut result)?;
374 Ok(Nonce(XNonce::from(result)))
375 }
376 }
377 impl Valid for Nonce {
378 fn check(&self) -> Result<(), SerializationError> {
379 Ok(())
380 }
381 }
382}
383
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    /// Encrypt/decrypt round trips, plus rejection of a foreign key and of
    /// mismatched associated data.
    #[test]
    fn test_aead_encryption() -> Result<(), AEADError> {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let recipient = KeyPair::generate(&mut rng);
        let outsider = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();

        // The intended recipient recovers the original message.
        let ct1 = recipient.enc_key.encrypt(&mut rng, &msg, &aad)?;
        assert!(recipient.decrypt(&ct1, &aad).is_ok());
        let plaintext1 = recipient.decrypt(&ct1, &aad)?;
        assert_eq!(msg, plaintext1);

        // Neither another keypair nor mismatched associated data can open it.
        assert!(outsider.decrypt(&ct1, &aad).is_err());
        assert!(recipient.decrypt(&ct1, b"wrong associated data").is_err());

        // A different message decrypts to something other than `msg`.
        let ct2 = recipient.enc_key.encrypt(&mut rng, b"wrong message", &aad)?;
        let plaintext2 = recipient.decrypt(&ct2, &aad)?;
        assert_ne!(msg, plaintext2);

        // `encrypt` also accepts an RNG passed by value.
        let owned_rng = ChaCha20Rng::from_seed([1u8; 32]);
        let ct3 = recipient.enc_key.encrypt(owned_rng, &msg, &aad)?;
        assert!(recipient.decrypt(&ct3, &aad).is_ok());
        let plaintext3 = recipient.decrypt(&ct3, &aad)?;
        assert_eq!(msg, plaintext3);

        Ok(())
    }

    /// `serde` (via `bincode`) round trips for each serializable type, with
    /// truncated input rejected.
    #[test]
    fn test_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        let keypair_bytes = bincode::serialize(&keypair).unwrap();
        assert_eq!(keypair, bincode::deserialize(&keypair_bytes).unwrap());
        assert!(bincode::deserialize::<KeyPair>(&keypair_bytes[1..]).is_err());

        let enc_key_bytes = bincode::serialize(keypair.enc_key_ref()).unwrap();
        assert_eq!(
            keypair.enc_key_ref(),
            &bincode::deserialize(&enc_key_bytes).unwrap()
        );
        assert!(bincode::deserialize::<EncKey>(&enc_key_bytes[1..]).is_err());

        // `DecKey` has no `PartialEq`, so compare the raw secret bytes instead.
        let dec_key_bytes = bincode::serialize(&keypair.dec_key).unwrap();
        let decoded_dec_key = bincode::deserialize::<DecKey>(&dec_key_bytes).unwrap();
        assert_eq!(keypair.dec_key.0.to_bytes(), decoded_dec_key.0.to_bytes());
        assert!(bincode::deserialize::<DecKey>(&dec_key_bytes[1..]).is_err());

        let ct_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(&ciphertext, &bincode::deserialize(&ct_bytes).unwrap());
        assert!(bincode::deserialize::<Ciphertext>(&ct_bytes[1..]).is_err());
    }

    /// `ark_serialize` round trips for `KeyPair` and `Ciphertext`, with
    /// truncated input rejected.
    #[test]
    fn test_canonical_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        let mut keypair_bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&keypair, &mut keypair_bytes).unwrap();
        assert_eq!(
            keypair,
            KeyPair::deserialize_compressed(&keypair_bytes[..]).unwrap()
        );
        let truncated = &keypair_bytes[..keypair_bytes.len() - 1];
        assert!(KeyPair::deserialize_compressed(truncated).is_err());

        let mut ct_bytes = Vec::new();
        CanonicalSerialize::serialize_compressed(&ciphertext, &mut ct_bytes).unwrap();
        assert_eq!(
            ciphertext,
            Ciphertext::deserialize_compressed(&ct_bytes[..]).unwrap()
        );
        let truncated = &ct_bytes[..ct_bytes.len() - 1];
        assert!(Ciphertext::deserialize_compressed(truncated).is_err());
    }
}