//! Authenticated encryption with associated data (AEAD) to a public key.
//!
//! A sender derives a shared session key with `crypto_kx` from a fresh
//! ephemeral key pair and the recipient's [`EncKey`], then encrypts with
//! `XChaCha20Poly1305`. The recipient re-derives the same key from its secret
//! key and the ephemeral public key carried inside the [`Ciphertext`].
#![cfg_attr(not(feature = "std"), no_std)]
// NOTE(review): `allow(warnings)` silences every lint except those raised
// again below (e.g. `missing_docs`); confirm this blanket allow is intended.
#![allow(warnings)]
#![deny(missing_docs)]
// Makes `std` available to the unit tests, even when the crate is `no_std`.
#[cfg(test)]
extern crate std;

#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
#[doc(hidden)]
extern crate alloc;
24use ark_serialize::*;
25use ark_std::{
26 fmt, format,
27 ops::{Deref, DerefMut},
28 rand::{CryptoRng, RngCore},
29 vec::Vec,
30};
31use chacha20poly1305::{
32 aead::{Aead, AeadCore, Payload},
33 KeyInit, XChaCha20Poly1305, XNonce,
34};
35use derivative::Derivative;
36use displaydoc::Display;
37use serde::{Deserialize, Deserializer, Serialize};
38
39#[derive(Clone, Eq, Derivative, Serialize, Deserialize)]
40#[derivative(PartialEq, Hash)]
41pub struct EncKey(crypto_kx::PublicKey);
43
44impl From<[u8; 32]> for EncKey {
45 fn from(bytes: [u8; 32]) -> Self {
46 Self(crypto_kx::PublicKey::from(bytes))
47 }
48}
49impl From<EncKey> for [u8; 32] {
50 fn from(enc_key: EncKey) -> Self {
51 *enc_key.0.as_ref()
52 }
53}
54impl From<DecKey> for EncKey {
55 fn from(dec_key: DecKey) -> Self {
56 let enc_key = *crypto_kx::Keypair::from(dec_key.0).public();
57 Self(enc_key)
58 }
59}
60impl Default for EncKey {
61 fn default() -> Self {
62 Self(crypto_kx::PublicKey::from(
63 [0u8; crypto_kx::PublicKey::BYTES],
64 ))
65 }
66}
67impl fmt::Debug for EncKey {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 f.debug_tuple("aead::EncKey")
70 .field(self.0.as_ref())
71 .finish()
72 }
73}
74
75#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Display)]
78pub struct AEADError;
79
80impl ark_std::error::Error for AEADError {}
81
82impl EncKey {
83 pub fn encrypt(
90 &self,
91 mut rng: impl RngCore + CryptoRng,
92 message: &[u8],
93 aad: &[u8],
94 ) -> Result<Ciphertext, AEADError> {
95 let ephemeral_keypair = crypto_kx::Keypair::generate(&mut rng);
97 let shared_secret = ephemeral_keypair.session_keys_to(&self.0).tx;
101 let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
102 let nonce = XChaCha20Poly1305::generate_nonce(&mut rng);
103
104 let ct = cipher
106 .encrypt(&nonce, Payload { msg: message, aad })
107 .map_err(|_| AEADError)?;
108
109 Ok(Ciphertext {
110 nonce: Nonce(nonce),
111 ct,
112 ephemeral_pk: EncKey(*ephemeral_keypair.public()),
113 })
114 }
115}
116
117#[derive(Clone, Serialize, Deserialize)]
120struct DecKey(crypto_kx::SecretKey);
121
122impl From<[u8; 32]> for DecKey {
123 fn from(bytes: [u8; 32]) -> Self {
124 Self(crypto_kx::SecretKey::from(bytes))
125 }
126}
127impl From<DecKey> for [u8; 32] {
128 fn from(dec_key: DecKey) -> Self {
129 dec_key.0.to_bytes()
130 }
131}
132
133impl Default for DecKey {
134 fn default() -> Self {
135 Self(crypto_kx::SecretKey::from([0; crypto_kx::SecretKey::BYTES]))
136 }
137}
138impl fmt::Debug for DecKey {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 f.debug_tuple("aead::DecKey")
141 .field(&self.0.to_bytes())
142 .finish()
143 }
144}
145
146#[derive(
148 Clone, Debug, Default, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize,
149)]
150pub struct KeyPair {
151 enc_key: EncKey,
152 dec_key: DecKey,
153}
154
155impl PartialEq for KeyPair {
156 fn eq(&self, other: &KeyPair) -> bool {
157 self.enc_key == other.enc_key
158 }
159}
160
161impl KeyPair {
162 pub fn generate<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
164 let (enc_key, dec_key) = crypto_kx::Keypair::generate(rng).split();
165 Self {
166 enc_key: EncKey(enc_key),
167 dec_key: DecKey(dec_key),
168 }
169 }
170
171 pub fn enc_key(&self) -> EncKey {
173 self.enc_key.clone()
174 }
175
176 pub fn enc_key_ref(&self) -> &EncKey {
178 &self.enc_key
179 }
180
181 pub fn decrypt(&self, ciphertext: &Ciphertext, aad: &[u8]) -> Result<Vec<u8>, AEADError> {
185 let shared_secret = crypto_kx::Keypair::from(self.dec_key.0.clone())
186 .session_keys_from(&ciphertext.ephemeral_pk.0)
187 .rx;
188 let cipher = XChaCha20Poly1305::new(shared_secret.as_ref().into());
189 let plaintext = cipher
190 .decrypt(
191 &ciphertext.nonce,
192 Payload {
193 msg: &ciphertext.ct,
194 aad,
195 },
196 )
197 .map_err(|_| AEADError)?;
198 Ok(plaintext)
199 }
200}
201#[derive(Clone, Debug, PartialEq, Eq, Hash)]
204struct Nonce(XNonce);
205
206impl Serialize for Nonce {
207 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208 where
209 S: serde::Serializer,
210 {
211 serializer.serialize_bytes(self.0.as_slice())
212 }
213}
214
215impl<'de> Deserialize<'de> for Nonce {
216 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
217 where
218 D: Deserializer<'de>,
219 {
220 struct NonceVisitor;
221
222 impl<'de> serde::de::Visitor<'de> for NonceVisitor {
223 type Value = Nonce;
224
225 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
226 formatter.write_str("byte array")
227 }
228
229 fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
230 where
231 E: serde::de::Error,
232 {
233 Ok(Nonce(*XNonce::from_slice(&v)))
234 }
235
236 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
237 where
238 A: serde::de::SeqAccess<'de>,
239 {
240 let bytes: Vec<u8> = seq
241 .next_element()?
242 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
243 Ok(Nonce(*XNonce::from_slice(&bytes)))
244 }
245 }
246
247 deserializer.deserialize_byte_buf(NonceVisitor)
248 }
249}
250
251impl Deref for Nonce {
253 type Target = XNonce;
254 fn deref(&self) -> &Self::Target {
255 &self.0
256 }
257}
258impl DerefMut for Nonce {
259 fn deref_mut(&mut self) -> &mut Self::Target {
260 &mut self.0
261 }
262}
263
264#[derive(
266 Clone,
267 Debug,
268 PartialEq,
269 Eq,
270 Hash,
271 Serialize,
272 Deserialize,
273 CanonicalSerialize,
274 CanonicalDeserialize,
275)]
276pub struct Ciphertext {
277 nonce: Nonce,
278 ct: Vec<u8>,
279 ephemeral_pk: EncKey,
280}
281
282mod canonical_serde {
286 use super::*;
287
288 impl CanonicalSerialize for EncKey {
289 fn serialize_with_mode<W: Write>(
290 &self,
291 mut writer: W,
292 _compress: Compress,
293 ) -> Result<(), SerializationError> {
294 let bytes: [u8; crypto_kx::PublicKey::BYTES] = self.clone().into();
295 writer.write_all(&bytes)?;
296 Ok(())
297 }
298 fn serialized_size(&self, _compress: Compress) -> usize {
299 crypto_kx::PublicKey::BYTES
300 }
301 }
302
303 impl CanonicalDeserialize for EncKey {
304 fn deserialize_with_mode<R: Read>(
305 mut reader: R,
306 _compress: Compress,
307 _validate: Validate,
308 ) -> Result<Self, SerializationError> {
309 let mut result = [0u8; crypto_kx::PublicKey::BYTES];
310 reader.read_exact(&mut result)?;
311 Ok(EncKey(crypto_kx::PublicKey::from(result)))
312 }
313 }
314
315 impl Valid for EncKey {
316 fn check(&self) -> Result<(), SerializationError> {
317 Ok(())
318 }
319 }
320
321 impl CanonicalSerialize for DecKey {
322 fn serialize_with_mode<W: Write>(
323 &self,
324 mut writer: W,
325 _compress: Compress,
326 ) -> Result<(), SerializationError> {
327 let bytes: [u8; crypto_kx::SecretKey::BYTES] = self.clone().into();
328 writer.write_all(&bytes)?;
329 Ok(())
330 }
331 fn serialized_size(&self, _compress: Compress) -> usize {
332 crypto_kx::SecretKey::BYTES
333 }
334 }
335
336 impl CanonicalDeserialize for DecKey {
337 fn deserialize_with_mode<R: Read>(
338 mut reader: R,
339 _compress: Compress,
340 _validate: Validate,
341 ) -> Result<Self, SerializationError> {
342 let mut result = [0u8; crypto_kx::SecretKey::BYTES];
343 reader.read_exact(&mut result)?;
344 Ok(DecKey(crypto_kx::SecretKey::from(result)))
345 }
346 }
347 impl Valid for DecKey {
348 fn check(&self) -> Result<(), SerializationError> {
349 Ok(())
350 }
351 }
352
353 impl CanonicalSerialize for Nonce {
354 fn serialize_with_mode<W: Write>(
355 &self,
356 mut writer: W,
357 _compress: Compress,
358 ) -> Result<(), SerializationError> {
359 writer.write_all(self.0.as_slice())?;
360 Ok(())
361 }
362 fn serialized_size(&self, _compress: Compress) -> usize {
363 24
365 }
366 }
367
368 impl CanonicalDeserialize for Nonce {
369 fn deserialize_with_mode<R: Read>(
370 mut reader: R,
371 _compress: Compress,
372 _validate: Validate,
373 ) -> Result<Self, SerializationError> {
374 let mut result = [0u8; 24];
375 reader.read_exact(&mut result)?;
376 Ok(Nonce(XNonce::from(result)))
377 }
378 }
379 impl Valid for Nonce {
380 fn check(&self) -> Result<(), SerializationError> {
381 Ok(())
382 }
383 }
384}
385
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    /// Round-trips messages, checks that the wrong recipient or wrong
    /// associated data is rejected, and exercises both borrowed and owned RNGs.
    #[test]
    fn test_aead_encryption() -> Result<(), AEADError> {
        let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
        let alice = KeyPair::generate(&mut rng);
        let bob = KeyPair::generate(&mut rng);
        let message = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();

        // Encrypt to alice using a borrowed RNG.
        let first = alice.enc_key.encrypt(&mut rng, &message, &aad)?;
        assert!(alice.decrypt(&first, &aad).is_ok());
        assert_eq!(alice.decrypt(&first, &aad)?, message);

        // Only the intended recipient, with the original AAD, can decrypt.
        assert!(bob.decrypt(&first, &aad).is_err());
        assert!(alice.decrypt(&first, b"wrong associated data").is_err());

        // A different plaintext does not decrypt to `message`.
        let second = alice.enc_key.encrypt(&mut rng, b"wrong message", &aad)?;
        assert_ne!(alice.decrypt(&second, &aad)?, message);

        // `encrypt` also accepts the RNG by value.
        let owned_rng = ChaCha20Rng::from_seed([1u8; 32]);
        let third = alice.enc_key.encrypt(owned_rng, &message, &aad)?;
        assert!(alice.decrypt(&third, &aad).is_ok());
        assert_eq!(alice.decrypt(&third, &aad)?, message);

        Ok(())
    }

    /// serde (bincode) round-trips for every public type; truncated input must fail.
    #[test]
    fn test_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        // Key pair.
        let encoded = bincode::serialize(&keypair).unwrap();
        let decoded: KeyPair = bincode::deserialize(&encoded).unwrap();
        assert_eq!(keypair, decoded);
        assert!(bincode::deserialize::<KeyPair>(&encoded[1..]).is_err());

        // Public key.
        let encoded = bincode::serialize(keypair.enc_key_ref()).unwrap();
        let decoded: EncKey = bincode::deserialize(&encoded).unwrap();
        assert_eq!(keypair.enc_key_ref(), &decoded);
        assert!(bincode::deserialize::<EncKey>(&encoded[1..]).is_err());

        // Secret key: `DecKey` has no `PartialEq`, so compare the raw bytes.
        let encoded = bincode::serialize(&keypair.dec_key).unwrap();
        let decoded: DecKey = bincode::deserialize(&encoded).unwrap();
        assert_eq!(keypair.dec_key.0.to_bytes(), decoded.0.to_bytes());
        assert!(bincode::deserialize::<DecKey>(&encoded[1..]).is_err());

        // Ciphertext.
        let encoded = bincode::serialize(&ciphertext).unwrap();
        let decoded: Ciphertext = bincode::deserialize(&encoded).unwrap();
        assert_eq!(ciphertext, decoded);
        assert!(bincode::deserialize::<Ciphertext>(&encoded[1..]).is_err());
    }

    /// ark-serialize round-trips for the key pair and ciphertext.
    #[test]
    fn test_canonical_serde() {
        let mut rng = jf_utils::test_rng();
        let keypair = KeyPair::generate(&mut rng);
        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
        let aad = b"my associated data".to_vec();
        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();

        let mut encoded = Vec::new();
        keypair.serialize_compressed(&mut encoded).unwrap();
        let decoded = KeyPair::deserialize_compressed(&encoded[..]).unwrap();
        assert_eq!(keypair, decoded);
        assert!(KeyPair::deserialize_compressed(&encoded[1..]).is_err());

        let mut encoded = Vec::new();
        ciphertext.serialize_compressed(&mut encoded).unwrap();
        let decoded = Ciphertext::deserialize_compressed(&encoded[..]).unwrap();
        assert_eq!(ciphertext, decoded);
        assert!(Ciphertext::deserialize_compressed(&encoded[1..]).is_err());
    }
}