1#![cfg_attr(not(feature = "std"), no_std)]
10#![allow(warnings)]
12#![deny(missing_docs)]
13#[cfg(test)]
14extern crate std;
15
16#[macro_use]
17extern crate derivative;
18
19#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
20#[doc(hidden)]
21extern crate alloc;
22
23#[cfg(feature = "gadgets")]
24pub mod gadgets;
25
26use ark_ec::{
27 twisted_edwards::{Affine, Projective, TECurveConfig as Config},
28 AffineRepr, CurveGroup, Group,
29};
30use ark_ff::{Field, UniformRand};
31use ark_serialize::*;
32use ark_std::{
33 hash::{Hash, Hasher},
34 rand::{CryptoRng, Rng, RngCore},
35 string::{String, ToString},
36 vec,
37 vec::Vec,
38};
39use displaydoc::Display;
40use jf_rescue::{Permutation, RescueParameter, RescueVector, PRP, STATE_SIZE};
41#[cfg(feature = "parallel")]
42use rayon::prelude::*;
43use zeroize::Zeroize;
44
45#[derive(Display, Debug)]
47pub struct ParameterError(String);
48
49#[derive(CanonicalSerialize, CanonicalDeserialize, Zeroize, Derivative)]
54#[derivative(
55 Debug(bound = "P: Config"),
56 Clone(bound = "P: Config"),
57 Eq(bound = "P: Config"),
58 Default(bound = "P: Config")
59)]
60pub struct EncKey<P>
61where
62 P: Config,
63{
64 pub(crate) key: Projective<P>,
65}
66
67impl<P: Config> Hash for EncKey<P> {
68 fn hash<H: Hasher>(&self, state: &mut H) {
69 Hash::hash(&self.key.into_affine(), state)
70 }
71}
72
73impl<P: Config> PartialEq for EncKey<P> {
74 fn eq(&self, other: &Self) -> bool {
75 self.key.into_affine() == other.key.into_affine()
76 }
77}
78
79impl<P> From<&EncKey<P>> for (P::BaseField, P::BaseField)
80where
81 P: Config,
82{
83 fn from(pk: &EncKey<P>) -> Self {
84 let point = pk.key.into_affine();
85 (point.x, point.y)
86 }
87}
88
89#[derive(Zeroize, CanonicalSerialize, CanonicalDeserialize, Derivative)]
94#[derivative(
95 Debug(bound = "P: Config"),
96 Clone(bound = "P: Config"),
97 PartialEq(bound = "P: Config")
98)]
99pub(crate) struct DecKey<P>
100where
101 P: Config,
102{
103 key: P::ScalarField,
104}
105
106impl<P: Config> Drop for DecKey<P> {
107 fn drop(&mut self) {
108 self.key.zeroize();
109 }
110}
111
112#[derive(CanonicalSerialize, CanonicalDeserialize, Derivative)]
117#[derivative(
118 Debug(bound = "P: Config"),
119 Clone(bound = "P: Config"),
120 PartialEq(bound = "P: Config")
121)]
122pub struct KeyPair<P>
124where
125 P: Config,
126{
127 pub(crate) enc: EncKey<P>,
128 dec: DecKey<P>,
129}
130
131#[derive(CanonicalSerialize, CanonicalDeserialize, Derivative)]
136#[derivative(
137 Debug(bound = "P: Config"),
138 Clone(bound = "P: Config"),
139 PartialEq(bound = "P: Config"),
140 Eq(bound = "P: Config"),
141 Hash(bound = "P: Config")
142)]
143pub struct Ciphertext<P>
144where
145 P: Config,
146{
147 pub(crate) ephemeral: EncKey<P>,
148 pub(crate) data: Vec<P::BaseField>,
149}
150
151impl<P> Ciphertext<P>
152where
153 P: Config,
154{
155 pub fn to_scalars(&self) -> Vec<P::BaseField> {
157 let mut result = vec![];
158 let (x, y) = (&self.ephemeral).into();
159 result.push(x);
160 result.push(y);
161 result.extend_from_slice(&self.data);
162 result
163 }
164
165 pub fn from_scalars(scalars: &[P::BaseField]) -> Result<Self, ParameterError> {
167 if scalars.len() < 2 {
168 return Err(ParameterError(
169 "At least 2 scalars in length for ciphertext".to_string(),
170 ));
171 }
172 let key = Affine::new(scalars[0], scalars[1]);
173
174 let ephemeral = EncKey {
175 key: key.into_group(),
176 };
177 let mut data = vec![];
178 data.extend_from_slice(&scalars[2..]);
179 Ok(Self { ephemeral, data })
180 }
181}
182
183impl<P> KeyPair<P>
188where
189 P: Config,
190{
191 pub fn generate<R: CryptoRng + RngCore>(rng: &mut R) -> KeyPair<P> {
193 let dec = DecKey {
194 key: P::ScalarField::rand(rng),
195 };
196 let enc = EncKey::from(&dec);
197 KeyPair { enc, dec }
198 }
199
200 pub(crate) fn dec_key_ref(&self) -> &DecKey<P> {
202 &self.dec
203 }
204
205 pub fn enc_key(&self) -> EncKey<P> {
207 self.enc.clone()
208 }
209
210 pub fn enc_key_ref(&self) -> &EncKey<P> {
212 &self.enc
213 }
214}
215
216impl<P> From<DecKey<P>> for KeyPair<P>
217where
218 P: Config,
219{
220 fn from(dec: DecKey<P>) -> Self {
221 let enc = EncKey::from(&dec);
222 KeyPair { enc, dec }
223 }
224}
225
226impl<P: Config> UniformRand for EncKey<P> {
228 fn rand<R>(rng: &mut R) -> Self
229 where
230 R: Rng + RngCore + ?Sized,
231 {
232 EncKey {
233 key: Projective::<P>::rand(rng),
234 }
235 }
236}
237
238impl<F, P> EncKey<P>
239where
240 F: RescueParameter,
241 P: Config<BaseField = F>,
242{
243 fn compute_cipher_text_from_ephemeral_key_pair(
244 &self,
245 ephemeral_key_pair: KeyPair<P>,
246 msg: &[F],
247 ) -> Ciphertext<P> {
248 let shared_key = (self.key * ephemeral_key_pair.dec_key_ref().key).into_affine();
249 let perm = Permutation::default();
250 let key = perm.eval(&RescueVector::from(&[
253 shared_key.x,
254 shared_key.y,
255 F::zero(),
256 F::zero(),
257 ]));
258 Ciphertext {
260 ephemeral: ephemeral_key_pair.enc_key(),
261 data: apply_counter_mode_stream::<F>(&key, msg, &F::zero(), Direction::Encrypt),
262 }
263 }
264
265 pub fn deterministic_encrypt(&self, r: P::ScalarField, msg: &[F]) -> Ciphertext<P> {
270 let ephemeral_key_pair = KeyPair::from(DecKey { key: r });
271 self.compute_cipher_text_from_ephemeral_key_pair(ephemeral_key_pair, msg)
272 }
273
274 pub fn encrypt<R: CryptoRng + RngCore>(
276 &self,
277 prng: &mut R,
278 msg: &[P::BaseField],
279 ) -> Ciphertext<P> {
280 let ephemeral_key_pair = KeyPair::generate(prng);
281 self.compute_cipher_text_from_ephemeral_key_pair(ephemeral_key_pair, msg)
282 }
283}
284
285impl<F, P> DecKey<P>
286where
287 F: RescueParameter,
288 P: Config<BaseField = F>,
289{
290 fn decrypt(&self, ctext: &Ciphertext<P>) -> Vec<P::BaseField> {
292 let perm = Permutation::default();
293 let shared_key = (ctext.ephemeral.key * self.key).into_affine();
294 let key = perm.eval(&RescueVector::from(&[
295 shared_key.x,
296 shared_key.y,
297 F::zero(),
298 F::zero(),
299 ]));
300 apply_counter_mode_stream::<F>(&key, ctext.data.as_slice(), &F::zero(), Direction::Decrypt)
302 }
303}
304
305impl<P> From<&DecKey<P>> for EncKey<P>
306where
307 P: Config,
308{
309 fn from(dec_key: &DecKey<P>) -> Self {
310 let mut point = Projective::<P>::generator();
311 point *= dec_key.key;
312 Self { key: point }
313 }
314}
315
316impl<F, P> KeyPair<P>
317where
318 F: RescueParameter,
319 P: Config<BaseField = F>,
320{
321 pub fn decrypt(&self, ctext: &Ciphertext<P>) -> Vec<F> {
323 self.dec.decrypt(ctext)
324 }
325}
326
/// Selects how [`apply_counter_mode_stream`] combines data with the keystream.
pub(crate) enum Direction {
    /// Add the keystream to the data.
    Encrypt,
    /// Subtract the keystream from the data.
    Decrypt,
}
331
332pub(crate) fn apply_counter_mode_stream<F>(
333 key: &RescueVector<F>,
334 data: &[F],
335 nonce: &F,
336 direction: Direction,
337) -> Vec<F>
338where
339 F: RescueParameter,
340{
341 let prp = PRP::default();
342 let round_keys = prp.key_schedule(key);
343 let mut output = data.to_vec();
345 pad_with_zeros(&mut output, STATE_SIZE);
347
348 let round_fn = |(idx, output_chunk): (usize, &mut [F])| {
349 let stream_chunk = prp.prp_with_round_keys(
350 &round_keys,
351 &RescueVector::from(&[
352 nonce.add(F::from(idx as u64)),
353 F::zero(),
354 F::zero(),
355 F::zero(),
356 ]),
357 );
358 for (output_elem, stream_elem) in output_chunk.iter_mut().zip(stream_chunk.elems().iter()) {
359 match direction {
360 Direction::Encrypt => output_elem.add_assign(stream_elem),
361 Direction::Decrypt => output_elem.sub_assign(stream_elem),
362 }
363 }
364 };
365 #[cfg(feature = "parallel")]
366 {
367 output
368 .par_chunks_exact_mut(STATE_SIZE)
369 .enumerate()
370 .for_each(round_fn);
371 }
372 #[cfg(not(feature = "parallel"))]
373 {
374 output
375 .chunks_exact_mut(STATE_SIZE)
376 .enumerate()
377 .for_each(round_fn);
378 }
379 output.truncate(data.len());
381 output
382}
383
384#[inline]
385fn pad_with_zeros<F: Field>(vec: &mut Vec<F>, multiple: usize) {
386 let len = vec.len();
387 let new_len = compute_len_to_next_multiple(len, multiple);
388 vec.resize(new_len, F::zero())
389}
390
/// Rounds `len` up to the nearest multiple of `multiple`.
///
/// Panics if `multiple` is zero (remainder by zero).
#[inline]
fn compute_len_to_next_multiple(len: usize, multiple: usize) -> usize {
    match len % multiple {
        0 => len,
        remainder => len + multiple - remainder,
    }
}
399
#[cfg(test)]
mod test {
    use super::{Ciphertext, DecKey, EncKey, KeyPair, UniformRand};
    use ark_ed_on_bls12_377::{EdwardsConfig as ParamEd377, Fq as FqEd377, Fr as FrEd377};
    use ark_ed_on_bls12_381::{EdwardsConfig as ParamEd381, Fq as FqEd381, Fr as FrEd381};
    use ark_ed_on_bls12_381_bandersnatch::{
        EdwardsConfig as ParamEd381b, Fq as FqEd381b, Fr as FrEd381b,
    };
    use ark_ed_on_bn254::{EdwardsConfig as ParamEd254, Fq as FqEd254, Fr as FrEd254};
    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
    use ark_std::{vec, vec::Vec};

    /// Round-trips messages of every length from 0 to 16 through randomized
    /// and deterministic encryption, decrypting via both `KeyPair` and `DecKey`.
    macro_rules! test_enc_and_dec {
        ($param:ident, $base_field:ident, $scalar_field:ident) => {
            let mut rng = jf_utils::test_rng();
            let keypair: KeyPair<$param> = KeyPair::generate(&mut rng);
            let pub_key = keypair.enc_key_ref();
            let mut data: Vec<$base_field> = Vec::new();

            for _ in 0..17 {
                let randomized = pub_key.encrypt(&mut rng, &data);
                assert_eq!(data, keypair.decrypt(&randomized));
                assert_eq!(data, keypair.dec_key_ref().decrypt(&randomized));

                let deterministic =
                    pub_key.deterministic_encrypt($scalar_field::rand(&mut rng), &data);
                assert_eq!(data, keypair.decrypt(&deterministic));

                data.push($base_field::rand(&mut rng));
            }
        };
    }

    #[test]
    fn test_enc_and_dec() {
        test_enc_and_dec!(ParamEd254, FqEd254, FrEd254);
        test_enc_and_dec!(ParamEd377, FqEd377, FrEd377);
        test_enc_and_dec!(ParamEd381, FqEd381, FrEd381);
        test_enc_and_dec!(ParamEd381b, FqEd381b, FrEd381b);
    }

    /// Checks compressed serialization round-trips for the key pair, both
    /// keys, and a ciphertext.
    macro_rules! test_serdes {
        ($param:ident, $base_field:ident, $scalar_field:ident) => {
            let mut rng = jf_utils::test_rng();
            let keypair = KeyPair::<$param>::generate(&mut rng);
            let msg = vec![$base_field::rand(&mut rng)];
            let ct = keypair.enc_key().encrypt(&mut rng, &msg[..]);

            let mut bytes: Vec<u8> = Vec::new();
            keypair.serialize_compressed(&mut bytes).unwrap();
            let restored: KeyPair<$param> = KeyPair::deserialize_compressed(&bytes[..]).unwrap();
            assert_eq!(restored, keypair);

            let mut bytes: Vec<u8> = Vec::new();
            keypair.enc.serialize_compressed(&mut bytes).unwrap();
            let restored: EncKey<$param> = EncKey::deserialize_compressed(&bytes[..]).unwrap();
            assert_eq!(keypair.enc, restored);

            let mut bytes: Vec<u8> = Vec::new();
            keypair.dec.serialize_compressed(&mut bytes).unwrap();
            let restored: DecKey<$param> = DecKey::deserialize_compressed(&bytes[..]).unwrap();
            assert_eq!(keypair.dec, restored);

            let mut bytes: Vec<u8> = Vec::new();
            ct.serialize_compressed(&mut bytes).unwrap();
            let restored: Ciphertext<$param> =
                Ciphertext::deserialize_compressed(&bytes[..]).unwrap();
            assert_eq!(ct, restored);
        };
    }

    #[test]
    fn test_serde() {
        test_serdes!(ParamEd254, FqEd254, FrEd254);
        test_serdes!(ParamEd377, FqEd377, FrEd377);
        test_serdes!(ParamEd381, FqEd381, FrEd381);
        test_serdes!(ParamEd381b, FqEd381b, FrEd381b);
    }
}