jf_elgamal/
lib.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the Jellyfish library.
3
4// You should have received a copy of the MIT License
5// along with the Jellyfish library. If not, see <https://mit-license.org/>.
6
7//! Implements the ElGamal encryption scheme.
8
9#![cfg_attr(not(feature = "std"), no_std)]
10// Temporarily allow warning for nightly compilation with [`displaydoc`].
11#![allow(warnings)]
12#![deny(missing_docs)]
13#[cfg(test)]
14extern crate std;
15
16#[macro_use]
17extern crate derivative;
18
19#[cfg(any(not(feature = "std"), target_has_atomic = "ptr"))]
20#[doc(hidden)]
21extern crate alloc;
22
23#[cfg(feature = "gadgets")]
24pub mod gadgets;
25
26use ark_ec::{
27    twisted_edwards::{Affine, Projective, TECurveConfig as Config},
28    AffineRepr, CurveGroup, Group,
29};
30use ark_ff::{Field, UniformRand};
31use ark_serialize::*;
32use ark_std::{
33    hash::{Hash, Hasher},
34    rand::{CryptoRng, Rng, RngCore},
35    string::{String, ToString},
36    vec,
37    vec::Vec,
38};
39use displaydoc::Display;
40use jf_rescue::{Permutation, RescueParameter, RescueVector, PRP, STATE_SIZE};
41#[cfg(feature = "parallel")]
42use rayon::prelude::*;
43use zeroize::Zeroize;
44
45/// Parameter error: {0}
46#[derive(Display, Debug)]
47pub struct ParameterError(String);
48
49// =====================================================
50// encrypt key
51// =====================================================
52/// Encryption key for encryption scheme
53#[derive(CanonicalSerialize, CanonicalDeserialize, Zeroize, Derivative)]
54#[derivative(
55    Debug(bound = "P: Config"),
56    Clone(bound = "P: Config"),
57    Eq(bound = "P: Config"),
58    Default(bound = "P: Config")
59)]
60pub struct EncKey<P>
61where
62    P: Config,
63{
64    pub(crate) key: Projective<P>,
65}
66
67impl<P: Config> Hash for EncKey<P> {
68    fn hash<H: Hasher>(&self, state: &mut H) {
69        Hash::hash(&self.key.into_affine(), state)
70    }
71}
72
73impl<P: Config> PartialEq for EncKey<P> {
74    fn eq(&self, other: &Self) -> bool {
75        self.key.into_affine() == other.key.into_affine()
76    }
77}
78
79impl<P> From<&EncKey<P>> for (P::BaseField, P::BaseField)
80where
81    P: Config,
82{
83    fn from(pk: &EncKey<P>) -> Self {
84        let point = pk.key.into_affine();
85        (point.x, point.y)
86    }
87}
88
89// =====================================================
90// decrypt key
91// =====================================================
92/// Decryption key for encryption scheme
93#[derive(Zeroize, CanonicalSerialize, CanonicalDeserialize, Derivative)]
94#[derivative(
95    Debug(bound = "P: Config"),
96    Clone(bound = "P: Config"),
97    PartialEq(bound = "P: Config")
98)]
99pub(crate) struct DecKey<P>
100where
101    P: Config,
102{
103    key: P::ScalarField,
104}
105
106impl<P: Config> Drop for DecKey<P> {
107    fn drop(&mut self) {
108        self.key.zeroize();
109    }
110}
111
112// =====================================================
113// key pair
114// =====================================================
115
116#[derive(CanonicalSerialize, CanonicalDeserialize, Derivative)]
117#[derivative(
118    Debug(bound = "P: Config"),
119    Clone(bound = "P: Config"),
120    PartialEq(bound = "P: Config")
121)]
122/// KeyPair structure for encryption scheme
123pub struct KeyPair<P>
124where
125    P: Config,
126{
127    pub(crate) enc: EncKey<P>,
128    dec: DecKey<P>,
129}
130
131// =====================================================
132// ciphertext
133// =====================================================
134/// Public encryption cipher text
135#[derive(CanonicalSerialize, CanonicalDeserialize, Derivative)]
136#[derivative(
137    Debug(bound = "P: Config"),
138    Clone(bound = "P: Config"),
139    PartialEq(bound = "P: Config"),
140    Eq(bound = "P: Config"),
141    Hash(bound = "P: Config")
142)]
143pub struct Ciphertext<P>
144where
145    P: Config,
146{
147    pub(crate) ephemeral: EncKey<P>,
148    pub(crate) data: Vec<P::BaseField>,
149}
150
151impl<P> Ciphertext<P>
152where
153    P: Config,
154{
155    /// Flatten out the ciphertext into a vector of scalars
156    pub fn to_scalars(&self) -> Vec<P::BaseField> {
157        let mut result = vec![];
158        let (x, y) = (&self.ephemeral).into();
159        result.push(x);
160        result.push(y);
161        result.extend_from_slice(&self.data);
162        result
163    }
164
165    /// Reconstruct the ciphertext from a list of scalars.
166    pub fn from_scalars(scalars: &[P::BaseField]) -> Result<Self, ParameterError> {
167        if scalars.len() < 2 {
168            return Err(ParameterError(
169                "At least 2 scalars in length for ciphertext".to_string(),
170            ));
171        }
172        let key = Affine::new(scalars[0], scalars[1]);
173
174        let ephemeral = EncKey {
175            key: key.into_group(),
176        };
177        let mut data = vec![];
178        data.extend_from_slice(&scalars[2..]);
179        Ok(Self { ephemeral, data })
180    }
181}
182
183// =====================================================
184// end of definitions
185// =====================================================
186
187impl<P> KeyPair<P>
188where
189    P: Config,
190{
191    /// Key generation algorithm for public key encryption scheme
192    pub fn generate<R: CryptoRng + RngCore>(rng: &mut R) -> KeyPair<P> {
193        let dec = DecKey {
194            key: P::ScalarField::rand(rng),
195        };
196        let enc = EncKey::from(&dec);
197        KeyPair { enc, dec }
198    }
199
200    /// Get decryption key reference
201    pub(crate) fn dec_key_ref(&self) -> &DecKey<P> {
202        &self.dec
203    }
204
205    /// Get encryption key
206    pub fn enc_key(&self) -> EncKey<P> {
207        self.enc.clone()
208    }
209
210    /// Get encryption key reference
211    pub fn enc_key_ref(&self) -> &EncKey<P> {
212        &self.enc
213    }
214}
215
216impl<P> From<DecKey<P>> for KeyPair<P>
217where
218    P: Config,
219{
220    fn from(dec: DecKey<P>) -> Self {
221        let enc = EncKey::from(&dec);
222        KeyPair { enc, dec }
223    }
224}
225
226/// Sample a random public key with unknown associated secret key
227impl<P: Config> UniformRand for EncKey<P> {
228    fn rand<R>(rng: &mut R) -> Self
229    where
230        R: Rng + RngCore + ?Sized,
231    {
232        EncKey {
233            key: Projective::<P>::rand(rng),
234        }
235    }
236}
237
238impl<F, P> EncKey<P>
239where
240    F: RescueParameter,
241    P: Config<BaseField = F>,
242{
243    fn compute_cipher_text_from_ephemeral_key_pair(
244        &self,
245        ephemeral_key_pair: KeyPair<P>,
246        msg: &[F],
247    ) -> Ciphertext<P> {
248        let shared_key = (self.key * ephemeral_key_pair.dec_key_ref().key).into_affine();
249        let perm = Permutation::default();
250        // TODO check if ok to use (x,y,0,0) as a key, since
251        // key = perm(x,y,0,0) doesn't buy us anything.
252        let key = perm.eval(&RescueVector::from(&[
253            shared_key.x,
254            shared_key.y,
255            F::zero(),
256            F::zero(),
257        ]));
258        // since key was just sampled and to be used only once, we can allow NONCE = 0
259        Ciphertext {
260            ephemeral: ephemeral_key_pair.enc_key(),
261            data: apply_counter_mode_stream::<F>(&key, msg, &F::zero(), Direction::Encrypt),
262        }
263    }
264
265    /// Public key encryption function with pre-sampled randomness
266    /// * `r` - randomness
267    /// * `msg` - plaintext
268    /// * `returns` - Ciphertext
269    pub fn deterministic_encrypt(&self, r: P::ScalarField, msg: &[F]) -> Ciphertext<P> {
270        let ephemeral_key_pair = KeyPair::from(DecKey { key: r });
271        self.compute_cipher_text_from_ephemeral_key_pair(ephemeral_key_pair, msg)
272    }
273
274    /// Public key encryption function
275    pub fn encrypt<R: CryptoRng + RngCore>(
276        &self,
277        prng: &mut R,
278        msg: &[P::BaseField],
279    ) -> Ciphertext<P> {
280        let ephemeral_key_pair = KeyPair::generate(prng);
281        self.compute_cipher_text_from_ephemeral_key_pair(ephemeral_key_pair, msg)
282    }
283}
284
285impl<F, P> DecKey<P>
286where
287    F: RescueParameter,
288    P: Config<BaseField = F>,
289{
290    /// Decryption function
291    fn decrypt(&self, ctext: &Ciphertext<P>) -> Vec<P::BaseField> {
292        let perm = Permutation::default();
293        let shared_key = (ctext.ephemeral.key * self.key).into_affine();
294        let key = perm.eval(&RescueVector::from(&[
295            shared_key.x,
296            shared_key.y,
297            F::zero(),
298            F::zero(),
299        ]));
300        // since key was just samples and to be used only once, we can have NONCE = 0
301        apply_counter_mode_stream::<F>(&key, ctext.data.as_slice(), &F::zero(), Direction::Decrypt)
302    }
303}
304
305impl<P> From<&DecKey<P>> for EncKey<P>
306where
307    P: Config,
308{
309    fn from(dec_key: &DecKey<P>) -> Self {
310        let mut point = Projective::<P>::generator();
311        point *= dec_key.key;
312        Self { key: point }
313    }
314}
315
316impl<F, P> KeyPair<P>
317where
318    F: RescueParameter,
319    P: Config<BaseField = F>,
320{
321    /// Decryption function
322    pub fn decrypt(&self, ctext: &Ciphertext<P>) -> Vec<F> {
323        self.dec.decrypt(ctext)
324    }
325}
326
/// Chooses whether `apply_counter_mode_stream` adds the keystream to the
/// data (encryption) or subtracts it (decryption).
pub(crate) enum Direction {
    /// Add the keystream element-wise.
    Encrypt,
    /// Subtract the keystream element-wise.
    Decrypt,
}
331
332pub(crate) fn apply_counter_mode_stream<F>(
333    key: &RescueVector<F>,
334    data: &[F],
335    nonce: &F,
336    direction: Direction,
337) -> Vec<F>
338where
339    F: RescueParameter,
340{
341    let prp = PRP::default();
342    let round_keys = prp.key_schedule(key);
343    // compute stream
344    let mut output = data.to_vec();
345    // temporarily append dummy padding element
346    pad_with_zeros(&mut output, STATE_SIZE);
347
348    let round_fn = |(idx, output_chunk): (usize, &mut [F])| {
349        let stream_chunk = prp.prp_with_round_keys(
350            &round_keys,
351            &RescueVector::from(&[
352                nonce.add(F::from(idx as u64)),
353                F::zero(),
354                F::zero(),
355                F::zero(),
356            ]),
357        );
358        for (output_elem, stream_elem) in output_chunk.iter_mut().zip(stream_chunk.elems().iter()) {
359            match direction {
360                Direction::Encrypt => output_elem.add_assign(stream_elem),
361                Direction::Decrypt => output_elem.sub_assign(stream_elem),
362            }
363        }
364    };
365    #[cfg(feature = "parallel")]
366    {
367        output
368            .par_chunks_exact_mut(STATE_SIZE)
369            .enumerate()
370            .for_each(round_fn);
371    }
372    #[cfg(not(feature = "parallel"))]
373    {
374        output
375            .chunks_exact_mut(STATE_SIZE)
376            .enumerate()
377            .for_each(round_fn);
378    }
379    // remove dummy padding elements
380    output.truncate(data.len());
381    output
382}
383
384#[inline]
385fn pad_with_zeros<F: Field>(vec: &mut Vec<F>, multiple: usize) {
386    let len = vec.len();
387    let new_len = compute_len_to_next_multiple(len, multiple);
388    vec.resize(new_len, F::zero())
389}
390
/// Rounds `len` up to the nearest multiple of `multiple`.
///
/// # Panics
///
/// Panics if `multiple` is zero.
#[inline]
fn compute_len_to_next_multiple(len: usize, multiple: usize) -> usize {
    match len % multiple {
        0 => len,
        remainder => len + multiple - remainder,
    }
}
399
#[cfg(test)]
mod test {
    use super::{Ciphertext, DecKey, EncKey, KeyPair, UniformRand};
    use ark_ed_on_bls12_377::{EdwardsConfig as ParamEd377, Fq as FqEd377, Fr as FrEd377};
    use ark_ed_on_bls12_381::{EdwardsConfig as ParamEd381, Fq as FqEd381, Fr as FrEd381};
    use ark_ed_on_bls12_381_bandersnatch::{
        EdwardsConfig as ParamEd381b, Fq as FqEd381b, Fr as FrEd381b,
    };
    use ark_ed_on_bn254::{EdwardsConfig as ParamEd254, Fq as FqEd254, Fr as FrEd254};
    use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
    use ark_std::{vec, vec::Vec};

    // Round-trips messages of length 0 through 16 (starting with the empty
    // message) through randomized and deterministic encryption, both
    // decryption entry points, and the `to_scalars` / `from_scalars`
    // flattening.
    macro_rules! test_enc_and_dec {
        ($param: tt, $base_field:tt, $scalar_field: tt) => {
            let mut rng = jf_utils::test_rng();
            let keypair: KeyPair<$param> = KeyPair::generate(&mut rng);
            let pub_key = keypair.enc_key_ref();
            let mut data = vec![];

            for _ in 0..17 {
                let ctext1 = pub_key.encrypt(&mut rng, &data);
                let decrypted1 = keypair.decrypt(&ctext1);
                assert_eq!(&data, decrypted1.as_slice());
                let decrypted1 = keypair.dec_key_ref().decrypt(&ctext1);
                assert_eq!(&data, decrypted1.as_slice());

                // Flattening to scalars and back must reproduce the ciphertext.
                let rebuilt = Ciphertext::<$param>::from_scalars(&ctext1.to_scalars()).unwrap();
                assert_eq!(rebuilt, ctext1);

                let ctext2 = pub_key.deterministic_encrypt($scalar_field::rand(&mut rng), &data);
                let decrypted2 = keypair.decrypt(&ctext2);
                assert_eq!(&data, decrypted2.as_slice());

                data.push($base_field::rand(&mut rng));
            }
        };
    }

    #[test]
    fn test_enc_and_dec() {
        test_enc_and_dec!(ParamEd254, FqEd254, FrEd254);
        test_enc_and_dec!(ParamEd377, FqEd377, FrEd377);
        test_enc_and_dec!(ParamEd381, FqEd381, FrEd381);
        test_enc_and_dec!(ParamEd381b, FqEd381b, FrEd381b);
    }

    // Serialization round trips for key pairs, both key halves and ciphertexts.
    macro_rules! test_serdes {
        ($param: tt, $base_field:tt, $scalar_field: tt) => {
            let mut rng = jf_utils::test_rng();
            let keypair = KeyPair::<$param>::generate(&mut rng);
            let msg = vec![$base_field::rand(&mut rng)];
            let ct = keypair.enc_key().encrypt(&mut rng, &msg[..]);

            let mut ser_bytes: Vec<u8> = Vec::new();
            keypair.serialize_compressed(&mut ser_bytes).unwrap();
            let de: KeyPair<$param> = KeyPair::deserialize_compressed(&ser_bytes[..]).unwrap();
            assert_eq!(de, keypair);

            let mut ser_bytes: Vec<u8> = Vec::new();
            keypair.enc.serialize_compressed(&mut ser_bytes).unwrap();
            let de: EncKey<$param> = EncKey::deserialize_compressed(&ser_bytes[..]).unwrap();
            assert_eq!(keypair.enc, de);

            let mut ser_bytes: Vec<u8> = Vec::new();
            keypair.dec.serialize_compressed(&mut ser_bytes).unwrap();
            let de: DecKey<$param> = DecKey::deserialize_compressed(&ser_bytes[..]).unwrap();
            assert_eq!(keypair.dec, de);

            let mut ser_bytes: Vec<u8> = Vec::new();
            ct.serialize_compressed(&mut ser_bytes).unwrap();
            let de: Ciphertext<$param> =
                Ciphertext::deserialize_compressed(&ser_bytes[..]).unwrap();
            assert_eq!(ct, de);
        };
    }

    #[test]
    fn test_serde() {
        test_serdes!(ParamEd254, FqEd254, FrEd254);
        test_serdes!(ParamEd377, FqEd377, FrEd377);
        test_serdes!(ParamEd381, FqEd381, FrEd381);
        test_serdes!(ParamEd381b, FqEd381b, FrEd381b);
    }
}