use crate::{
constants::{compute_coset_representatives, GATE_WIDTH, N_MUL_SELECTORS},
gates::*,
CircuitError,
CircuitError::*,
};
use ark_ff::{FftField, Field, PrimeField};
use ark_poly::{
domain::Radix2EvaluationDomain, univariate::DensePolynomial, DenseUVPolynomial,
EvaluationDomain,
};
use ark_std::{boxed::Box, cmp::max, format, string::ToString, vec, vec::Vec};
use hashbrown::{HashMap, HashSet};
use jf_utils::par_utils::parallelizable_slice_iter;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Index of a gate (row) in the circuit's gate list.
pub type GateId = usize;
/// Index of a wire type (column) in `wire_variables`.
pub type WireId = usize;
/// Index of a variable in the circuit's witness vector.
pub type Variable = usize;
/// A circuit variable intended to hold a boolean (`0` or `1`) value.
///
/// The wrapper carries no proof of booleanity by itself; whether a boolean
/// constraint exists depends on how the variable was created (see
/// [`BoolVar::new_unchecked`]).
#[derive(Debug, Clone, Copy)]
pub struct BoolVar(pub usize);

impl From<BoolVar> for Variable {
    /// Unwraps the boolean variable into its underlying variable index.
    fn from(bool_var: BoolVar) -> Self {
        let BoolVar(index) = bool_var;
        index
    }
}

impl BoolVar {
    /// Wraps `inner` as a `BoolVar` without adding any boolean constraint.
    pub(crate) fn new_unchecked(inner: usize) -> Self {
        BoolVar(inner)
    }
}
/// Flavor of PLONK arithmetization a circuit targets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PlonkType {
    /// No lookup or range-check support.
    TurboPlonk,
    /// Adds lookup gates and an extra wire for range-checked variables.
    UltraPlonk,
}
/// Layout used when preparing a circuit to be merged with another one.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MergeableCircuitType {
    /// The circuit keeps its rows in the first half of the doubled domain.
    TypeA,
    /// The circuit is reversed so that it occupies the second half of the
    /// doubled domain.
    TypeB,
}
pub trait Circuit<F: Field> {
fn num_gates(&self) -> usize;
fn num_vars(&self) -> usize;
fn num_inputs(&self) -> usize;
fn num_wire_types(&self) -> usize;
fn public_input(&self) -> Result<Vec<F>, CircuitError>;
fn check_circuit_satisfiability(&self, pub_input: &[F]) -> Result<(), CircuitError>;
fn create_constant_variable(&mut self, val: F) -> Result<Variable, CircuitError>;
fn create_variable(&mut self, val: F) -> Result<Variable, CircuitError>;
fn create_boolean_variable(&mut self, val: bool) -> Result<BoolVar, CircuitError> {
let val_scalar = if val { F::one() } else { F::zero() };
let var = self.create_variable(val_scalar)?;
self.enforce_bool(var)?;
Ok(BoolVar(var))
}
fn create_public_variable(&mut self, val: F) -> Result<Variable, CircuitError>;
fn create_public_boolean_variable(&mut self, val: bool) -> Result<BoolVar, CircuitError> {
let val_scalar = if val { F::one() } else { F::zero() };
let var = self.create_public_variable(val_scalar)?;
Ok(BoolVar(var))
}
fn set_variable_public(&mut self, var: Variable) -> Result<(), CircuitError>;
fn zero(&self) -> Variable;
fn one(&self) -> Variable;
fn false_var(&self) -> BoolVar {
BoolVar::new_unchecked(self.zero())
}
fn true_var(&self) -> BoolVar {
BoolVar::new_unchecked(self.one())
}
fn witness(&self, idx: Variable) -> Result<F, CircuitError>;
fn enforce_constant(&mut self, var: Variable, constant: F) -> Result<(), CircuitError>;
fn add_gate(&mut self, a: Variable, b: Variable, c: Variable) -> Result<(), CircuitError>;
fn add(&mut self, a: Variable, b: Variable) -> Result<Variable, CircuitError>;
fn sub_gate(&mut self, a: Variable, b: Variable, c: Variable) -> Result<(), CircuitError>;
fn sub(&mut self, a: Variable, b: Variable) -> Result<Variable, CircuitError>;
fn mul_gate(&mut self, a: Variable, b: Variable, c: Variable) -> Result<(), CircuitError>;
fn mul(&mut self, a: Variable, b: Variable) -> Result<Variable, CircuitError>;
fn enforce_bool(&mut self, a: Variable) -> Result<(), CircuitError>;
fn enforce_equal(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError>;
fn pad_gates(&mut self, n: usize);
fn support_lookup(&self) -> bool;
}
/// Result of sorting the lookup vector: the sorted vector plus two polynomials
/// derived from it. NOTE(review): presumably the interpolations of its two
/// overlapping halves (h1, h2) — confirm against
/// `compute_lookup_sorted_vec_polynomials`.
pub(crate) type SortedLookupVecAndPolys<F> = (Vec<F>, DensePolynomial<F>, DensePolynomial<F>);
pub trait Arithmetization<F: FftField>: Circuit<F> {
fn srs_size(&self) -> Result<usize, CircuitError>;
fn eval_domain_size(&self) -> Result<usize, CircuitError>;
fn compute_selector_polynomials(&self) -> Result<Vec<DensePolynomial<F>>, CircuitError>;
fn compute_extended_permutation_polynomials(
&self,
) -> Result<Vec<DensePolynomial<F>>, CircuitError>;
fn compute_prod_permutation_polynomial(
&self,
beta: &F,
gamma: &F,
) -> Result<DensePolynomial<F>, CircuitError>;
fn compute_wire_polynomials(&self) -> Result<Vec<DensePolynomial<F>>, CircuitError>;
fn compute_pub_input_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError>;
fn compute_range_table_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
Err(CircuitError::LookupUnsupported)
}
fn compute_key_table_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
Err(CircuitError::LookupUnsupported)
}
fn compute_table_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
Err(CircuitError::LookupUnsupported)
}
fn compute_q_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
Err(CircuitError::LookupUnsupported)
}
fn compute_merged_lookup_table(&self, _tau: F) -> Result<Vec<F>, CircuitError> {
Err(CircuitError::LookupUnsupported)
}
fn compute_lookup_sorted_vec_polynomials(
&self,
_tau: F,
_lookup_table: &[F],
) -> Result<SortedLookupVecAndPolys<F>, CircuitError> {
Err(CircuitError::LookupUnsupported)
}
fn compute_lookup_prod_polynomial(
&self,
_tau: &F,
_beta: &F,
_gamma: &F,
_lookup_table: &[F],
_sorted_vec: &[F],
) -> Result<DensePolynomial<F>, CircuitError> {
Err(CircuitError::LookupUnsupported)
}
}
/// Wire index holding range-checked variables (UltraPlonk only).
const RANGE_WIRE_ID: usize = 5;
/// Wire index holding the key of a lookup gate.
const LOOKUP_KEY_WIRE_ID: usize = 0;
/// Wire index holding the first looked-up value.
const LOOKUP_VAL_1_WIRE_ID: usize = 1;
/// Wire index holding the second looked-up value.
const LOOKUP_VAL_2_WIRE_ID: usize = 2;
/// Wire index holding the first value of a table row.
const TABLE_VAL_1_WIRE_ID: usize = 3;
/// Wire index holding the second value of a table row.
const TABLE_VAL_2_WIRE_ID: usize = 4;
/// Construction-time parameters of a `PlonkCircuit`.
#[derive(Debug, Clone, Copy)]
struct PlonkParams {
    /// PLONK flavor of the circuit.
    plonk_type: PlonkType,
    /// Range-check bit length: always `None` for TurboPlonk, always `Some` for
    /// UltraPlonk (enforced by `init`).
    range_bit_len: Option<usize>,
}

impl PlonkParams {
    /// Validates and builds circuit parameters.
    ///
    /// For TurboPlonk any supplied `range_bit_len` is discarded.
    ///
    /// # Errors
    /// `ParameterError` if `plonk_type` is UltraPlonk and `range_bit_len` is `None`.
    fn init(plonk_type: PlonkType, range_bit_len: Option<usize>) -> Result<Self, CircuitError> {
        match (plonk_type, range_bit_len) {
            (PlonkType::TurboPlonk, _) => Ok(Self {
                plonk_type,
                range_bit_len: None,
            }),
            (PlonkType::UltraPlonk, None) => Err(ParameterError(
                "range bit len cannot be none for UltraPlonk".to_string(),
            )),
            (PlonkType::UltraPlonk, Some(_)) => Ok(Self {
                plonk_type,
                range_bit_len,
            }),
        }
    }
}
/// A PLONK constraint system over the FFT-friendly field `F`.
#[derive(Debug, Clone)]
pub struct PlonkCircuit<F>
where
    F: FftField,
{
    /// Number of allocated variables (length of `witness`).
    num_vars: usize,
    /// One gate per row.
    gates: Vec<Box<dyn Gate<F>>>,
    /// Per-wire lists of variable ids. Indices `0..=GATE_WIDTH` are gate wires
    /// (one entry per gate, appended by `insert_gate`); index `RANGE_WIRE_ID`
    /// lists range-checked variables.
    wire_variables: [Vec<Variable>; GATE_WIDTH + 2],
    /// Gate ids of the I/O gates, one per public input.
    pub_input_gate_ids: Vec<GateId>,
    /// Witness value of every variable, indexed by `Variable`.
    witness: Vec<F>,
    /// Copy-constraint permutation computed at finalization: the entry at
    /// `wire_id * n + gate_id` is the next `(wire, gate)` position holding the
    /// same variable.
    wire_permutation: Vec<(WireId, GateId)>,
    /// Labels `k_i * ω^j` of every `(wire i, row j)` position, flattened as `i * n + j`.
    extended_id_permutation: Vec<F>,
    /// `GATE_WIDTH + 1`, plus one more for UltraPlonk.
    num_wire_types: usize,
    /// Evaluation domain; a size of 1 means the circuit is not finalized yet.
    eval_domain: Radix2EvaluationDomain<F>,
    /// Flavor and range-check parameters.
    plonk_params: PlonkParams,
    /// Total number of lookup-table entries.
    num_table_elems: usize,
    /// `(first gate id, table size)` of each lookup table's gates.
    table_gate_ids: Vec<(GateId, usize)>,
}
impl<F: FftField> Default for PlonkCircuit<F> {
    /// An empty TurboPlonk circuit, identical to `PlonkCircuit::new_turbo_plonk()`.
    fn default() -> Self {
        Self::new_turbo_plonk()
    }
}
impl<F: FftField> PlonkCircuit<F> {
    /// Builds an empty circuit for `plonk_params`.
    ///
    /// Variables 0 and 1 are pre-allocated with witnesses `0` and `1` and pinned
    /// by constant gates, so `zero()` and `one()` are usable immediately.
    fn new(plonk_params: PlonkParams) -> Self {
        let zero = F::zero();
        let one = F::one();
        let num_wire_types = GATE_WIDTH
            + 1
            + match plonk_params.plonk_type {
                PlonkType::TurboPlonk => 0,
                // Extra wire for range-checked variables.
                PlonkType::UltraPlonk => 1,
            };
        let mut circuit = Self {
            num_vars: 2,
            witness: vec![zero, one],
            gates: vec![],
            wire_variables: [vec![], vec![], vec![], vec![], vec![], vec![]],
            pub_input_gate_ids: vec![],
            wire_permutation: vec![],
            extended_id_permutation: vec![],
            num_wire_types,
            // A size-1 domain marks the circuit as not yet finalized (see `is_finalized`).
            eval_domain: Radix2EvaluationDomain::new(1).unwrap(),
            plonk_params,
            num_table_elems: 0,
            table_gate_ids: vec![],
        };
        // Cannot fail: variables 0 and 1 exist and the circuit is not finalized.
        circuit.enforce_constant(0, zero).unwrap();
        circuit.enforce_constant(1, one).unwrap();
        circuit
    }

    /// Creates an empty TurboPlonk circuit.
    pub fn new_turbo_plonk() -> Self {
        // `init` only fails for UltraPlonk without a range bit length.
        let plonk_params = PlonkParams::init(PlonkType::TurboPlonk, None).unwrap();
        Self::new(plonk_params)
    }

    /// Creates an empty UltraPlonk circuit whose range gates accept values in
    /// `[0, 2^range_bit_len)`.
    pub fn new_ultra_plonk(range_bit_len: usize) -> Self {
        // `init` cannot fail when a range bit length is supplied.
        let plonk_params =
            PlonkParams::init(PlonkType::UltraPlonk, Some(range_bit_len)).unwrap();
        Self::new(plonk_params)
    }

    /// Appends `gate` with its `GATE_WIDTH + 1` wire variables.
    ///
    /// Variables are not bound-checked here; callers are expected to do so.
    ///
    /// # Errors
    /// `ModifyFinalizedCircuit` if the circuit is already finalized.
    pub fn insert_gate(
        &mut self,
        wire_vars: &[Variable; GATE_WIDTH + 1],
        gate: Box<dyn Gate<F>>,
    ) -> Result<(), CircuitError> {
        self.check_finalize_flag(false)?;
        for (wire_var, wire_variable) in wire_vars
            .iter()
            .zip(self.wire_variables.iter_mut().take(GATE_WIDTH + 1))
        {
            wire_variable.push(*wire_var)
        }
        self.gates.push(gate);
        Ok(())
    }

    /// Registers `var` on the range wire so it is checked to lie in
    /// `[0, range_size)`.
    ///
    /// # Errors
    /// `WrongPlonkType` for non-UltraPlonk circuits, `ModifyFinalizedCircuit`
    /// after finalization, `VarIndexOutOfBound` for an unknown variable.
    pub fn add_range_check_variable(&mut self, var: Variable) -> Result<(), CircuitError> {
        self.check_plonk_type(PlonkType::UltraPlonk)?;
        self.check_finalize_flag(false)?;
        self.check_var_bound(var)?;
        self.wire_variables[RANGE_WIRE_ID].push(var);
        Ok(())
    }

    /// Returns `VarIndexOutOfBound` unless `var < num_vars`.
    #[inline]
    pub fn check_var_bound(&self, var: Variable) -> Result<(), CircuitError> {
        if var >= self.num_vars {
            return Err(VarIndexOutOfBound(var, self.num_vars));
        }
        Ok(())
    }

    /// Bound-checks every variable in `vars`, stopping at the first failure.
    pub fn check_vars_bound(&self, vars: &[Variable]) -> Result<(), CircuitError> {
        vars.iter().try_for_each(|&var| self.check_var_bound(var))
    }

    /// Mutable access to the witness value of `idx`.
    ///
    /// # Panics
    /// If `idx` is not an allocated variable.
    pub fn witness_mut(&mut self, idx: Variable) -> &mut F {
        &mut self.witness[idx]
    }

    /// Mutable access to the `(first gate id, table size)` list of lookup tables.
    pub(crate) fn table_gate_ids_mut(&mut self) -> &mut Vec<(GateId, usize)> {
        &mut self.table_gate_ids
    }

    /// Mutable access to the total number of lookup-table entries.
    pub(crate) fn num_table_elems_mut(&mut self) -> &mut usize {
        &mut self.num_table_elems
    }

    /// Total number of lookup-table entries.
    pub(crate) fn num_table_elems(&self) -> usize {
        self.num_table_elems
    }

    /// Bit length of range checks.
    ///
    /// # Errors
    /// `ParameterError` for non-UltraPlonk circuits.
    pub fn range_bit_len(&self) -> Result<usize, CircuitError> {
        if self.plonk_params.plonk_type != PlonkType::UltraPlonk {
            return Err(ParameterError(
                "call range_bit_len() with non-ultraplonk circuit".to_string(),
            ));
        }
        Ok(self
            .plonk_params
            .range_bit_len
            .expect("PlonkParams::init guarantees UltraPlonk has a range bit length"))
    }

    /// Size of the range-check table, `2^range_bit_len`.
    ///
    /// # Errors
    /// - `ParameterError` for non-UltraPlonk circuits.
    /// - `ParameterError` if `2^range_bit_len` does not fit in a `usize`. A plain
    ///   shift would panic in debug builds and silently wrap in release builds.
    pub fn range_size(&self) -> Result<usize, CircuitError> {
        let bit_len = self.range_bit_len()?;
        if bit_len >= usize::BITS as usize {
            return Err(ParameterError(format!(
                "range bit len {bit_len} is too large: 2^{bit_len} does not fit in usize"
            )));
        }
        Ok(1 << bit_len)
    }

    /// Allocates a variable holding `a` and wraps it as a `BoolVar` without
    /// adding a boolean constraint; the caller must ensure booleanity.
    pub(crate) fn create_boolean_variable_unchecked(
        &mut self,
        a: F,
    ) -> Result<BoolVar, CircuitError> {
        let var = self.create_variable(a)?;
        Ok(BoolVar::new_unchecked(var))
    }
}
impl<F: FftField> Circuit<F> for PlonkCircuit<F> {
    fn num_gates(&self) -> usize {
        self.gates.len()
    }

    fn num_vars(&self) -> usize {
        self.num_vars
    }

    /// Number of public inputs, i.e. the number of I/O gates.
    fn num_inputs(&self) -> usize {
        self.pub_input_gate_ids.len()
    }

    fn num_wire_types(&self) -> usize {
        self.num_wire_types
    }

    /// Witness values of the public inputs, in declaration order.
    ///
    /// Each I/O gate carries its public variable on output wire `GATE_WIDTH`.
    fn public_input(&self) -> Result<Vec<F>, CircuitError> {
        self.pub_input_gate_ids
            .iter()
            .map(|&gate_id| self.witness(self.wire_variables[GATE_WIDTH][gate_id]))
            .collect()
    }

    /// Checks every gate (and, for UltraPlonk, every range and lookup
    /// constraint) against the current witness.
    ///
    /// # Errors
    /// `PubInputLenMismatch` if `pub_input` has the wrong length, otherwise the
    /// first `GateCheckFailure` / `IndexError` encountered.
    fn check_circuit_satisfiability(&self, pub_input: &[F]) -> Result<(), CircuitError> {
        if pub_input.len() != self.num_inputs() {
            return Err(PubInputLenMismatch(
                pub_input.len(),
                self.pub_input_gate_ids.len(),
            ));
        }
        // I/O gates are checked against the supplied public inputs...
        for (&gate_id, pi) in self.pub_input_gate_ids.iter().zip(pub_input) {
            self.check_gate(gate_id, pi)?;
        }
        // ...every other gate against a zero public input.
        let zero = F::zero();
        for gate_id in 0..self.num_gates() {
            if !self.is_io_gate(gate_id) {
                self.check_gate(gate_id, &zero)?;
            }
        }
        if self.plonk_params.plonk_type == PlonkType::UltraPlonk {
            for idx in 0..self.wire_variables[RANGE_WIRE_ID].len() {
                self.check_range_gate(idx)?
            }
            // Allowed (domain separator, key, value1, value2) tuples; the
            // all-zero tuple is always allowed.
            let mut key_val_table = HashSet::new();
            key_val_table.insert((F::zero(), F::zero(), F::zero(), F::zero()));
            let q_lookup_vec = self.q_lookup();
            let q_dom_sep_vec = self.q_dom_sep();
            let table_key_vec = self.table_key_vec();
            let table_dom_sep_vec = self.table_dom_sep_vec();
            for (gate_id, ((&q_lookup, &table_dom_sep), &table_key)) in q_lookup_vec
                .iter()
                .zip(table_dom_sep_vec.iter())
                .zip(table_key_vec.iter())
                .enumerate()
            {
                if q_lookup != F::zero() {
                    let val0 = self.witness(self.wire_variable(TABLE_VAL_1_WIRE_ID, gate_id))?;
                    let val1 = self.witness(self.wire_variable(TABLE_VAL_2_WIRE_ID, gate_id))?;
                    key_val_table.insert((table_dom_sep, table_key, val0, val1));
                }
            }
            // Every lookup gate's (dom_sep, key, value1, value2) must be in the table.
            for (gate_id, (&q_lookup, &q_dom_sep)) in
                q_lookup_vec.iter().zip(q_dom_sep_vec.iter()).enumerate()
            {
                if q_lookup != F::zero() {
                    let key = self.witness(self.wire_variable(LOOKUP_KEY_WIRE_ID, gate_id))?;
                    let val0 = self.witness(self.wire_variable(LOOKUP_VAL_1_WIRE_ID, gate_id))?;
                    let val1 = self.witness(self.wire_variable(LOOKUP_VAL_2_WIRE_ID, gate_id))?;
                    if !key_val_table.contains(&(q_dom_sep, key, val0, val1)) {
                        return Err(GateCheckFailure(
                            gate_id,
                            format!(
                                "Lookup gate failed: ({q_dom_sep}, {key}, {val0}, {val1}) not in the table",
                            ),
                        ));
                    }
                }
            }
        }
        Ok(())
    }

    fn create_constant_variable(&mut self, val: F) -> Result<Variable, CircuitError> {
        let var = self.create_variable(val)?;
        self.enforce_constant(var, val)?;
        Ok(var)
    }

    /// Allocates a new variable holding `val`.
    ///
    /// # Errors
    /// `ModifyFinalizedCircuit` if the circuit is finalized.
    fn create_variable(&mut self, val: F) -> Result<Variable, CircuitError> {
        self.check_finalize_flag(false)?;
        self.witness.push(val);
        self.num_vars += 1;
        Ok(self.num_vars - 1)
    }

    fn create_public_variable(&mut self, val: F) -> Result<Variable, CircuitError> {
        let var = self.create_variable(val)?;
        self.set_variable_public(var)?;
        Ok(var)
    }

    /// Adds an I/O gate exposing `var` as the next public input.
    ///
    /// # Errors
    /// `ModifyFinalizedCircuit` if finalized; `VarIndexOutOfBound` if `var` was
    /// never allocated. Both are checked before any state is modified.
    fn set_variable_public(&mut self, var: Variable) -> Result<(), CircuitError> {
        self.check_finalize_flag(false)?;
        // Same bound check as every other gate-inserting method; without it an
        // invalid index would be recorded as a public input and panic later.
        self.check_var_bound(var)?;
        self.pub_input_gate_ids.push(self.num_gates());
        let wire_vars = &[0, 0, 0, 0, var];
        self.insert_gate(wire_vars, Box::new(IoGate))?;
        Ok(())
    }

    fn zero(&self) -> Variable {
        0
    }

    fn one(&self) -> Variable {
        1
    }

    /// Witness value of `idx`; `VarIndexOutOfBound` for an unknown variable.
    fn witness(&self, idx: Variable) -> Result<F, CircuitError> {
        self.check_var_bound(idx)?;
        Ok(self.witness[idx])
    }

    fn enforce_constant(&mut self, var: Variable, constant: F) -> Result<(), CircuitError> {
        self.check_var_bound(var)?;
        let wire_vars = &[0, 0, 0, 0, var];
        self.insert_gate(wire_vars, Box::new(ConstantGate(constant)))?;
        Ok(())
    }

    fn add_gate(&mut self, a: Variable, b: Variable, c: Variable) -> Result<(), CircuitError> {
        self.check_var_bound(a)?;
        self.check_var_bound(b)?;
        self.check_var_bound(c)?;
        let wire_vars = &[a, b, 0, 0, c];
        self.insert_gate(wire_vars, Box::new(AdditionGate))?;
        Ok(())
    }

    fn add(&mut self, a: Variable, b: Variable) -> Result<Variable, CircuitError> {
        self.check_var_bound(a)?;
        self.check_var_bound(b)?;
        let val = self.witness(a)? + self.witness(b)?;
        let c = self.create_variable(val)?;
        self.add_gate(a, b, c)?;
        Ok(c)
    }

    fn sub_gate(&mut self, a: Variable, b: Variable, c: Variable) -> Result<(), CircuitError> {
        self.check_var_bound(a)?;
        self.check_var_bound(b)?;
        self.check_var_bound(c)?;
        let wire_vars = &[a, b, 0, 0, c];
        self.insert_gate(wire_vars, Box::new(SubtractionGate))?;
        Ok(())
    }

    fn sub(&mut self, a: Variable, b: Variable) -> Result<Variable, CircuitError> {
        self.check_var_bound(a)?;
        self.check_var_bound(b)?;
        let val = self.witness(a)? - self.witness(b)?;
        let c = self.create_variable(val)?;
        self.sub_gate(a, b, c)?;
        Ok(c)
    }

    fn mul_gate(&mut self, a: Variable, b: Variable, c: Variable) -> Result<(), CircuitError> {
        self.check_var_bound(a)?;
        self.check_var_bound(b)?;
        self.check_var_bound(c)?;
        let wire_vars = &[a, b, 0, 0, c];
        self.insert_gate(wire_vars, Box::new(MultiplicationGate))?;
        Ok(())
    }

    fn mul(&mut self, a: Variable, b: Variable) -> Result<Variable, CircuitError> {
        self.check_var_bound(a)?;
        self.check_var_bound(b)?;
        let val = self.witness(a)? * self.witness(b)?;
        let c = self.create_variable(val)?;
        self.mul_gate(a, b, c)?;
        Ok(c)
    }

    fn enforce_bool(&mut self, a: Variable) -> Result<(), CircuitError> {
        self.check_var_bound(a)?;
        let wire_vars = &[a, a, 0, 0, a];
        self.insert_gate(wire_vars, Box::new(BoolGate))?;
        Ok(())
    }

    fn enforce_equal(&mut self, a: Variable, b: Variable) -> Result<(), CircuitError> {
        self.check_var_bound(a)?;
        self.check_var_bound(b)?;
        let wire_vars = &[a, b, 0, 0, 0];
        self.insert_gate(wire_vars, Box::new(EqualityGate))?;
        Ok(())
    }

    /// Appends `n` equality gates over the zero variable.
    ///
    /// # Panics
    /// If the circuit is already finalized.
    fn pad_gates(&mut self, n: usize) {
        let wire_vars = &[self.zero(), self.zero(), 0, 0, 0];
        for _ in 0..n {
            self.insert_gate(wire_vars, Box::new(EqualityGate)).unwrap();
        }
    }

    /// Lookups are supported exactly for UltraPlonk circuits.
    fn support_lookup(&self) -> bool {
        self.plonk_params.plonk_type == PlonkType::UltraPlonk
    }
}
impl<F: FftField> PlonkCircuit<F> {
fn check_range_gate(&self, idx: usize) -> Result<(), CircuitError> {
self.check_plonk_type(PlonkType::UltraPlonk)?;
if idx >= self.wire_variables[RANGE_WIRE_ID].len() {
return Err(IndexError);
}
let range_size = self.range_size()?;
if self.witness[self.wire_variables[RANGE_WIRE_ID][idx]] >= F::from(range_size as u32) {
return Err(GateCheckFailure(
idx,
format!(
"Range gate failed: {} >= {}",
self.witness[self.wire_variables[RANGE_WIRE_ID][idx]], range_size
),
));
}
Ok(())
}
fn is_finalized(&self) -> bool {
self.eval_domain.size() != 1
}
fn rearrange_gates(&mut self) -> Result<(), CircuitError> {
self.check_finalize_flag(true)?;
for (gate_id, io_gate_id) in self.pub_input_gate_ids.iter_mut().enumerate() {
if *io_gate_id > gate_id {
self.gates.swap(gate_id, *io_gate_id);
for i in 0..GATE_WIDTH + 1 {
self.wire_variables[i].swap(gate_id, *io_gate_id);
}
*io_gate_id = gate_id;
}
}
if self.support_lookup() {
let n = self.eval_domain.size();
let mut cur_gate_id = n - 2;
for &(table_gate_id, table_size) in self.table_gate_ids.iter().rev() {
for gate_id in (table_gate_id..table_gate_id + table_size).rev() {
if gate_id < cur_gate_id {
self.gates.swap(gate_id, cur_gate_id);
for j in 0..GATE_WIDTH + 1 {
self.wire_variables[j].swap(gate_id, cur_gate_id);
}
cur_gate_id -= 1;
}
}
}
}
Ok(())
}
fn is_io_gate(&self, gate_id: GateId) -> bool {
self.gates[gate_id].as_any().is::<IoGate>()
}
fn pad(&mut self) -> Result<(), CircuitError> {
self.check_finalize_flag(true)?;
let n = self.eval_domain.size();
for _ in self.num_gates()..n {
self.gates.push(Box::new(PaddingGate));
}
for wire_id in 0..self.num_wire_types() {
self.wire_variables[wire_id].resize(n, self.zero());
}
Ok(())
}
fn check_gate(&self, gate_id: Variable, pub_input: &F) -> Result<(), CircuitError> {
let w_vals: Vec<F> = (0..GATE_WIDTH + 1)
.map(|i| self.witness[self.wire_variables[i][gate_id]])
.collect();
let q_lc: [F; GATE_WIDTH] = self.gates[gate_id].q_lc();
let q_mul: [F; N_MUL_SELECTORS] = self.gates[gate_id].q_mul();
let q_hash: [F; GATE_WIDTH] = self.gates[gate_id].q_hash();
let q_c = self.gates[gate_id].q_c();
let q_o = self.gates[gate_id].q_o();
let q_ecc = self.gates[gate_id].q_ecc();
let expected_gate_output = *pub_input
+ q_lc[0] * w_vals[0]
+ q_lc[1] * w_vals[1]
+ q_lc[2] * w_vals[2]
+ q_lc[3] * w_vals[3]
+ q_mul[0] * w_vals[0] * w_vals[1]
+ q_mul[1] * w_vals[2] * w_vals[3]
+ q_ecc * w_vals[0] * w_vals[1] * w_vals[2] * w_vals[3] * w_vals[4]
+ q_hash[0] * w_vals[0].pow([5])
+ q_hash[1] * w_vals[1].pow([5])
+ q_hash[2] * w_vals[2].pow([5])
+ q_hash[3] * w_vals[3].pow([5])
+ q_c;
let gate_output = q_o * w_vals[4];
if expected_gate_output != gate_output {
return Err(
GateCheckFailure(
gate_id,
format!(
"gate: {:?}, wire values: {:?}, pub_input: {}, expected_gate_output: {}, gate_output: {}",
self.gates[gate_id],
w_vals,
pub_input,
expected_gate_output,
gate_output
)
));
}
Ok(())
}
#[inline]
fn compute_wire_permutation(&mut self) {
assert!(self.is_finalized());
let n = self.eval_domain.size();
let m = self.num_vars();
let mut variable_wires_map = vec![vec![]; m];
for (gate_wire_id, variables) in self
.wire_variables
.iter()
.take(self.num_wire_types())
.enumerate()
{
for (gate_id, &var) in variables.iter().enumerate() {
variable_wires_map[var].push((gate_wire_id, gate_id));
}
}
self.wire_permutation = vec![(0usize, 0usize); self.num_wire_types * n];
for wires_vec in variable_wires_map.iter_mut() {
if !wires_vec.is_empty() {
wires_vec.push(wires_vec[0]);
for window in wires_vec.windows(2) {
self.wire_permutation[window[0].0 * n + window[0].1] = window[1];
}
wires_vec.pop();
}
}
}
#[inline]
fn check_finalize_flag(&self, expect_finalized: bool) -> Result<(), CircuitError> {
if !self.is_finalized() && expect_finalized {
return Err(UnfinalizedCircuit);
}
if self.is_finalized() && !expect_finalized {
return Err(ModifyFinalizedCircuit);
}
Ok(())
}
#[inline]
fn check_plonk_type(&self, expect_type: PlonkType) -> Result<(), CircuitError> {
if self.plonk_params.plonk_type != expect_type {
return Err(WrongPlonkType);
}
Ok(())
}
#[inline]
fn wire_variable(&self, i: WireId, j: GateId) -> Variable {
match j < self.wire_variables[i].len() {
true => self.wire_variables[i][j],
false => self.zero(),
}
}
#[inline]
fn q_lc(&self) -> [Vec<F>; GATE_WIDTH] {
let mut result = [vec![], vec![], vec![], vec![]];
for gate in &self.gates {
let q_lc_vec = gate.q_lc();
result[0].push(q_lc_vec[0]);
result[1].push(q_lc_vec[1]);
result[2].push(q_lc_vec[2]);
result[3].push(q_lc_vec[3]);
}
result
}
#[inline]
fn q_mul(&self) -> [Vec<F>; N_MUL_SELECTORS] {
let mut result = [vec![], vec![]];
for gate in &self.gates {
let q_mul_vec = gate.q_mul();
result[0].push(q_mul_vec[0]);
result[1].push(q_mul_vec[1]);
}
result
}
#[inline]
fn q_hash(&self) -> [Vec<F>; GATE_WIDTH] {
let mut result = [vec![], vec![], vec![], vec![]];
for gate in &self.gates {
let q_hash_vec = gate.q_hash();
result[0].push(q_hash_vec[0]);
result[1].push(q_hash_vec[1]);
result[2].push(q_hash_vec[2]);
result[3].push(q_hash_vec[3]);
}
result
}
#[inline]
fn q_o(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_o()).collect()
}
#[inline]
fn q_c(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_c()).collect()
}
#[inline]
fn q_ecc(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_ecc()).collect()
}
#[inline]
fn q_lookup(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_lookup()).collect()
}
#[inline]
fn q_dom_sep(&self) -> Vec<F> {
self.gates.iter().map(|g| g.q_dom_sep()).collect()
}
#[inline]
fn table_key_vec(&self) -> Vec<F> {
self.gates.iter().map(|g| g.table_key()).collect()
}
#[inline]
fn table_dom_sep_vec(&self) -> Vec<F> {
self.gates.iter().map(|g| g.table_dom_sep()).collect()
}
#[inline]
fn all_selectors(&self) -> Vec<Vec<F>> {
let mut selectors = vec![];
self.q_lc()
.as_ref()
.iter()
.chain(self.q_mul().as_ref().iter())
.chain(self.q_hash().as_ref().iter())
.for_each(|s| selectors.push(s.clone()));
selectors.push(self.q_o());
selectors.push(self.q_c());
selectors.push(self.q_ecc());
if self.support_lookup() {
selectors.push(self.q_lookup());
}
selectors
}
}
impl<F: PrimeField> PlonkCircuit<F> {
    /// Fills `extended_id_permutation` with `k_i * ω_j` for every wire type `i`
    /// and row `j`, stored at `i * n + j`.
    ///
    /// Here `k_i` are the coset representatives and `ω_j` the domain elements.
    ///
    /// # Panics
    /// If the circuit is not finalized.
    #[inline]
    fn compute_extended_id_permutation(&mut self) {
        assert!(self.is_finalized());
        let domain_size = self.eval_domain.size();
        let coset_reprs: Vec<F> =
            compute_coset_representatives(self.num_wire_types, Some(domain_size));
        let domain_elems: Vec<F> = self.eval_domain.elements().collect();
        let mut id_perm = vec![F::zero(); self.num_wire_types * domain_size];
        for (wire_id, &coset_repr) in coset_reprs.iter().enumerate() {
            let row = &mut id_perm[wire_id * domain_size..(wire_id + 1) * domain_size];
            for (slot, &elem) in row.iter_mut().zip(domain_elems.iter()) {
                *slot = coset_repr * elem;
            }
        }
        self.extended_id_permutation = id_perm;
    }

    /// Maps every `wire_permutation` target to its label in
    /// `extended_id_permutation`.
    ///
    /// Sentinel wire ids (`>= num_wire_types`, used for mergeable-circuit
    /// padding) map to zero.
    ///
    /// # Panics
    /// If the circuit is not finalized.
    #[inline]
    fn compute_extended_permutation(&self) -> Result<Vec<F>, CircuitError> {
        assert!(self.is_finalized());
        let domain_size = self.eval_domain.size();
        let mut extended_perm = Vec::with_capacity(self.wire_permutation.len());
        for &(wire_id, gate_id) in &self.wire_permutation {
            let label = if wire_id < self.num_wire_types {
                self.extended_id_permutation[wire_id * domain_size + gate_id]
            } else {
                F::zero()
            };
            extended_perm.push(label);
        }
        if extended_perm.len() == self.num_wire_types * domain_size {
            Ok(extended_perm)
        } else {
            Err(ParameterError(
                "Length of the extended permutation vector should be number of gate \
                 (including padded dummy gates) * number of wire types"
                    .to_string(),
            ))
        }
    }
}
impl<F: PrimeField> PlonkCircuit<F> {
    /// Finalizes the circuit so it can be arithmetized; a no-op if already
    /// finalized.
    ///
    /// The function runs these steps in order:
    /// 1. Sizes the evaluation domain to fit the gates. For lookup circuits it
    ///    must also fit the range table (or all range-checked variables, if more),
    ///    every lookup-table entry, and one extra slot.
    /// 2. Pads the gates.
    /// 3. Rearranges the gates.
    /// 4. Computes the permutations.
    ///
    /// # Errors
    /// `DomainCreationError` if no domain of the required size exists, or range
    /// errors from `range_size`.
    pub fn finalize_for_arithmetization(&mut self) -> Result<(), CircuitError> {
        if self.is_finalized() {
            return Ok(());
        }
        let num_slots_needed = if self.support_lookup() {
            max(
                self.num_gates(),
                max(self.range_size()?, self.wire_variables[RANGE_WIRE_ID].len())
                    + self.num_table_elems()
                    + 1,
            )
        } else {
            self.num_gates()
        };
        self.eval_domain = Radix2EvaluationDomain::new(num_slots_needed)
            .ok_or(CircuitError::DomainCreationError)?;
        self.pad()?;
        self.rearrange_gates()?;
        self.compute_wire_permutation();
        self.compute_extended_id_permutation();
        Ok(())
    }

    /// Finalizes a TurboPlonk circuit so it can later be combined via `merge`.
    ///
    /// The domain is doubled to `2n` and `n` padding gates are appended.
    /// - `TypeA` keeps the circuit in rows `0..n`.
    /// - `TypeB` reverses gates, wires, I/O ids and the permutation so the
    ///   circuit occupies rows `n..2n`.
    ///
    /// Permutation entries of the empty half point to the sentinel wire id
    /// `num_wire_types`.
    ///
    /// NOTE(review): calling this twice doubles the domain again, because
    /// `finalize_for_arithmetization` is a no-op on a finalized circuit.
    ///
    /// # Errors
    /// `WrongPlonkType` for non-TurboPlonk circuits, or finalization errors.
    pub fn finalize_for_mergeable_circuit(
        &mut self,
        circuit_type: MergeableCircuitType,
    ) -> Result<(), CircuitError> {
        if self.plonk_params.plonk_type != PlonkType::TurboPlonk {
            return Err(WrongPlonkType);
        }
        self.finalize_for_arithmetization()?;
        let n = self.eval_domain_size()?;
        self.eval_domain =
            Radix2EvaluationDomain::new(2 * n).ok_or(CircuitError::DomainCreationError)?;
        for _ in 0..n {
            self.gates.push(Box::new(PaddingGate));
        }
        for wire_id in 0..self.num_wire_types() {
            self.wire_variables[wire_id].resize(2 * n, self.zero());
        }
        if circuit_type == MergeableCircuitType::TypeA {
            let mut wire_perm = vec![(self.num_wire_types, 0usize); self.num_wire_types * 2 * n];
            for i in 0..self.num_wire_types {
                for j in 0..n {
                    wire_perm[i * 2 * n + j] = self.wire_permutation[i * n + j];
                }
            }
            self.wire_permutation = wire_perm;
        } else {
            self.gates.reverse();
            for wire_id in 0..self.num_wire_types() {
                self.wire_variables[wire_id].reverse();
            }
            for io_gate in self.pub_input_gate_ids.iter_mut() {
                *io_gate = 2 * n - 1 - *io_gate;
            }
            let mut wire_perm = vec![(self.num_wire_types, 0usize); self.num_wire_types * 2 * n];
            for i in 0..self.num_wire_types {
                for j in 0..n {
                    let (wire_id, gate_id) = self.wire_permutation[i * n + j];
                    let gate_id = 2 * n - 1 - gate_id;
                    wire_perm[i * 2 * n + 2 * n - 1 - j] = (wire_id, gate_id);
                }
            }
            self.wire_permutation = wire_perm;
        }
        self.compute_extended_id_permutation();
        Ok(())
    }

    /// Merges a TypeA circuit (`self`) with a TypeB circuit (`other`) of equal
    /// domain size into one circuit.
    ///
    /// The result takes the first half of `self` and the second half of
    /// `other`; `other`'s variables are shifted by `self.num_vars`.
    ///
    /// # Errors
    /// `UnfinalizedCircuit` if either input is not finalized. Otherwise
    /// `ParameterError` when:
    /// - the domain sizes differ,
    /// - either circuit is not TurboPlonk,
    /// - the numbers of public inputs differ,
    /// - either circuit does not have the expected type, including circuits with
    ///   no public inputs, which previously panicked.
    #[allow(dead_code)]
    pub fn merge(&self, other: &Self) -> Result<Self, CircuitError> {
        self.check_finalize_flag(true)?;
        other.check_finalize_flag(true)?;
        let domain_size = self.eval_domain_size()?;
        let other_domain_size = other.eval_domain_size()?;
        if domain_size != other_domain_size {
            return Err(ParameterError(format!(
                "cannot merge circuits with different domain sizes: {}, {}",
                domain_size, other_domain_size
            )));
        }
        if self.plonk_params.plonk_type != PlonkType::TurboPlonk
            || other.plonk_params.plonk_type != PlonkType::TurboPlonk
        {
            return Err(ParameterError(
                "do not support merging non-TurboPlonk circuits.".to_string(),
            ));
        }
        if self.num_inputs() != other.num_inputs() {
            return Err(ParameterError(format!(
                "self.num_inputs = {} different from other.num_inputs = {}",
                self.num_inputs(),
                other.num_inputs()
            )));
        }
        // A TypeA circuit has its first I/O gate at row 0. A (reversed) TypeB
        // circuit has it at the last row. `first()` makes a circuit without
        // public inputs an error rather than an index panic.
        if self.pub_input_gate_ids.first() != Some(&0) {
            return Err(ParameterError(
                "the first circuit is not type A".to_string(),
            ));
        }
        if other.pub_input_gate_ids.first() != Some(&(other_domain_size - 1)) {
            return Err(ParameterError(
                "the second circuit is not type B".to_string(),
            ));
        }
        let num_vars = self.num_vars + other.num_vars;
        let witness: Vec<F> = [self.witness.as_slice(), other.witness.as_slice()].concat();
        let pub_input_gate_ids: Vec<usize> = [
            self.pub_input_gate_ids.as_slice(),
            other.pub_input_gate_ids.as_slice(),
        ]
        .concat();
        let n = domain_size / 2;
        let mut gates = vec![];
        let mut wire_variables = [vec![], vec![], vec![], vec![], vec![], vec![]];
        for (j, gate) in self.gates.iter().take(n).enumerate() {
            gates.push((*gate).clone());
            for (i, wire_vars) in wire_variables
                .iter_mut()
                .enumerate()
                .take(self.num_wire_types)
            {
                wire_vars.push(self.wire_variable(i, j));
            }
        }
        for (j, gate) in other.gates.iter().skip(n).enumerate() {
            gates.push((*gate).clone());
            for (i, wire_vars) in wire_variables
                .iter_mut()
                .enumerate()
                .take(self.num_wire_types)
            {
                // Shift `other`'s variable ids past all of `self`'s variables.
                wire_vars.push(other.wire_variable(i, n + j) + self.num_vars);
            }
        }
        let mut wire_permutation = vec![(0usize, 0usize); self.num_wire_types * 2 * n];
        for i in 0..self.num_wire_types {
            for j in 0..n {
                wire_permutation[i * 2 * n + j] = self.wire_permutation[i * 2 * n + j];
                wire_permutation[i * 2 * n + n + j] = other.wire_permutation[i * 2 * n + n + j];
            }
        }
        Ok(Self {
            num_vars,
            witness,
            gates,
            wire_variables,
            pub_input_gate_ids,
            wire_permutation,
            extended_id_permutation: self.extended_id_permutation.clone(),
            num_wire_types: self.num_wire_types,
            eval_domain: self.eval_domain,
            plonk_params: self.plonk_params,
            num_table_elems: 0,
            table_gate_ids: vec![],
        })
    }
}
impl<F> Arithmetization<F> for PlonkCircuit<F>
where
F: PrimeField,
{
/// SRS size required by this circuit: the evaluation-domain size plus 2.
///
/// # Errors
/// `UnfinalizedCircuit` if the circuit is not finalized.
fn srs_size(&self) -> Result<usize, CircuitError> {
    self.eval_domain_size().map(|domain_size| domain_size + 2)
}
/// Size of the evaluation domain.
///
/// # Errors
/// `UnfinalizedCircuit` if the circuit is not finalized.
fn eval_domain_size(&self) -> Result<usize, CircuitError> {
    self.check_finalize_flag(true)
        .map(|()| self.eval_domain.size())
}
/// Interpolates every selector vector (see `all_selectors`) over the domain.
///
/// # Errors
/// `UnfinalizedCircuit` if not finalized, or `ParameterError` if the domain
/// is smaller than the number of gates.
fn compute_selector_polynomials(&self) -> Result<Vec<DensePolynomial<F>>, CircuitError> {
    self.check_finalize_flag(true)?;
    let domain = &self.eval_domain;
    if self.num_gates() > domain.size() {
        return Err(ParameterError(
            "Domain size should be bigger than number of constraint".to_string(),
        ));
    }
    let selectors = self.all_selectors();
    Ok(parallelizable_slice_iter(&selectors)
        .map(|evals| DensePolynomial::from_coefficients_vec(domain.ifft(evals)))
        .collect())
}
/// Interpolates each wire type's slice of the extended permutation over the
/// domain, giving one polynomial per wire type.
///
/// # Errors
/// `UnfinalizedCircuit` if not finalized, or errors from
/// `compute_extended_permutation`.
fn compute_extended_permutation_polynomials(
    &self,
) -> Result<Vec<DensePolynomial<F>>, CircuitError> {
    self.check_finalize_flag(true)?;
    let domain = &self.eval_domain;
    let domain_size = domain.size();
    let extended_perm = self.compute_extended_permutation()?;
    // Materialized as a slice so it can be iterated in parallel when enabled.
    let wire_ids: Vec<usize> = (0..self.num_wire_types).collect();
    let polys = parallelizable_slice_iter(&wire_ids)
        .map(|&wire_id| {
            let evals = &extended_perm[wire_id * domain_size..(wire_id + 1) * domain_size];
            DensePolynomial::from_coefficients_vec(domain.ifft(evals))
        })
        .collect::<Vec<DensePolynomial<F>>>();
    Ok(polys)
}
/// Grand-product polynomial of the permutation argument.
///
/// Evaluation `0` is `1`. For each row `j < n - 1`, the running product is
/// multiplied by
/// `Π_i (w_ij + γ + β·id_ij) / (w_ij + γ + β·σ_ij)`,
/// where `id` and `σ` are the identity and permuted labels. The evaluations are
/// then interpolated.
///
/// # Errors
/// `UnfinalizedCircuit` if the circuit is not finalized.
fn compute_prod_permutation_polynomial(
    &self,
    beta: &F,
    gamma: &F,
) -> Result<DensePolynomial<F>, CircuitError> {
    self.check_finalize_flag(true)?;
    let domain = &self.eval_domain;
    let n = domain.size();
    let mut running = F::one();
    let mut evals = Vec::with_capacity(n);
    evals.push(running);
    for row in 0..(n - 1) {
        let mut numerator = F::one();
        let mut denominator = F::one();
        for wire in 0..self.num_wire_types {
            let shifted = self.witness[self.wire_variable(wire, row)] + gamma;
            let id_label = self.extended_id_permutation[wire * n + row];
            let (perm_wire, perm_row) = self.wire_permutation[wire * n + row];
            let perm_label = self.extended_id_permutation[perm_wire * n + perm_row];
            numerator *= shifted + *beta * id_label;
            denominator *= shifted + *beta * perm_label;
        }
        running = running * numerator / denominator;
        evals.push(running);
    }
    domain.ifft_in_place(&mut evals);
    Ok(DensePolynomial::from_coefficients_vec(evals))
}
fn compute_wire_polynomials(&self) -> Result<Vec<DensePolynomial<F>>, CircuitError> {
self.check_finalize_flag(true)?;
let domain = &self.eval_domain;
if domain.size() < self.num_gates() {
return Err(ParameterError(format!(
"Domain size {} should be bigger than number of constraint {}",
domain.size(),
self.num_gates()
)));
}
let witness = &self.witness;
let wire_polys: Vec<DensePolynomial<F>> = parallelizable_slice_iter(&self.wire_variables)
.take(self.num_wire_types())
.map(|wire_vars| {
let mut wire_vec: Vec<F> = wire_vars.iter().map(|&var| witness[var]).collect();
domain.ifft_in_place(&mut wire_vec);
DensePolynomial::from_coefficients_vec(wire_vec)
})
.collect();
assert_eq!(wire_polys.len(), self.num_wire_types());
Ok(wire_polys)
}
/// Interpolates the public-input vector.
///
/// The vector is zero everywhere except at I/O gate rows, which hold the
/// witness on wire `GATE_WIDTH`.
///
/// # Errors
/// `UnfinalizedCircuit` if the circuit is not finalized.
fn compute_pub_input_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
    self.check_finalize_flag(true)?;
    let domain = &self.eval_domain;
    let mut evals = vec![F::zero(); domain.size()];
    for &gate_id in self.pub_input_gate_ids.iter() {
        let io_var = self.wire_variables[GATE_WIDTH][gate_id];
        evals[gate_id] = self.witness[io_var];
    }
    domain.ifft_in_place(&mut evals);
    Ok(DensePolynomial::from_coefficients_vec(evals))
}
/// Interpolates the range table over the evaluation domain.
///
/// NOTE(review): no plonk-type/finalization check happens here. Presumably
/// `compute_range_table` (not visible in this chunk) performs it — confirm.
fn compute_range_table_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
    let range_table = self.compute_range_table()?;
    let coeffs = self.eval_domain.ifft(&range_table);
    Ok(DensePolynomial::from_coefficients_vec(coeffs))
}
/// Interpolates the per-gate table keys.
///
/// # Errors
/// `WrongPlonkType` for non-UltraPlonk circuits; `UnfinalizedCircuit` if not
/// finalized.
fn compute_key_table_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
    self.check_plonk_type(PlonkType::UltraPlonk)?;
    self.check_finalize_flag(true)?;
    let evals = self.table_key_vec();
    let coeffs = self.eval_domain.ifft(&evals);
    Ok(DensePolynomial::from_coefficients_vec(coeffs))
}
/// Interpolates the per-gate table domain separators.
///
/// # Errors
/// `WrongPlonkType` for non-UltraPlonk circuits; `UnfinalizedCircuit` if not
/// finalized.
fn compute_table_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
    self.check_plonk_type(PlonkType::UltraPlonk)?;
    self.check_finalize_flag(true)?;
    let evals = self.table_dom_sep_vec();
    let coeffs = self.eval_domain.ifft(&evals);
    Ok(DensePolynomial::from_coefficients_vec(coeffs))
}
/// Interpolates the per-gate lookup domain-separation selectors.
///
/// # Errors
/// `WrongPlonkType` for non-UltraPlonk circuits; `UnfinalizedCircuit` if not
/// finalized.
fn compute_q_dom_sep_polynomial(&self) -> Result<DensePolynomial<F>, CircuitError> {
    self.check_plonk_type(PlonkType::UltraPlonk)?;
    self.check_finalize_flag(true)?;
    let evals = self.q_dom_sep();
    let coeffs = self.eval_domain.ifft(&evals);
    Ok(DensePolynomial::from_coefficients_vec(coeffs))
}
/// Builds the merged lookup table: one `merged_table_value` per domain row,
/// compressed with challenge `tau`.
///
/// # Errors
/// Errors from `compute_range_table`, `eval_domain_size`, or the first failing
/// `merged_table_value`.
fn compute_merged_lookup_table(&self, tau: F) -> Result<Vec<F>, CircuitError> {
    let range_table = self.compute_range_table()?;
    let table_key_vec = self.table_key_vec();
    let table_dom_sep_vec = self.table_dom_sep_vec();
    let q_lookup_vec = self.q_lookup();
    let domain_size = self.eval_domain_size()?;
    (0..domain_size)
        .map(|row| {
            self.merged_table_value(
                tau,
                &range_table,
                &table_key_vec,
                &table_dom_sep_vec,
                &q_lookup_vec,
                row,
            )
        })
        .collect()
}
/// Grand-product polynomial of the lookup argument.
///
/// `sorted_vec` has length `2n - 1` and is read as two overlapping halves:
/// `h1[j] = sorted_vec[j]` and `h2[j] = sorted_vec[n - 1 + j]`.
///
/// Evaluation `0` is `1`. Rows `1..n-1` accumulate
/// `(1+β)(γ+f_j)(γ(1+β)+t_j+β t_{j+1}) / ((γ(1+β)+h1_j+β h1_{j+1})(γ(1+β)+h2_j+β h2_{j+1}))`.
/// The final evaluation is set to `1`.
///
/// # Errors
/// - `WrongPlonkType` or `UnfinalizedCircuit`.
/// - `ParameterError` on a length mismatch of the range wire, the merged table,
///   or `sorted_vec`.
/// - Errors from `merged_lookup_wire_value`.
fn compute_lookup_prod_polynomial(
    &self,
    tau: &F,
    beta: &F,
    gamma: &F,
    merged_lookup_table: &[F],
    sorted_vec: &[F],
) -> Result<DensePolynomial<F>, CircuitError> {
    self.check_plonk_type(PlonkType::UltraPlonk)?;
    self.check_finalize_flag(true)?;
    let domain = &self.eval_domain;
    let n = domain.size();
    if self.wire_variables[RANGE_WIRE_ID].len() != n {
        return Err(ParameterError(
            "Domain size should match the size of the padded lookup variables vector"
                .to_string(),
        ));
    }
    if merged_lookup_table.len() != n {
        return Err(ParameterError(
            "Domain size should match the size of the padded lookup table".to_string(),
        ));
    }
    if sorted_vec.len() != 2 * n - 1 {
        return Err(ParameterError(
            "The sorted vector has wrong length".to_string(),
        ));
    }
    let (beta, gamma) = (*beta, *gamma);
    let one_plus_beta = F::one() + beta;
    let gamma_times_one_plus_beta = gamma * one_plus_beta;
    let q_lookup_vec = self.q_lookup();
    let q_dom_sep_vec = self.q_dom_sep();
    let mut running = F::one();
    let mut evals = Vec::with_capacity(n);
    evals.push(running);
    for j in 0..(n - 2) {
        let wire_val = self.merged_lookup_wire_value(*tau, j, &q_lookup_vec, &q_dom_sep_vec)?;
        let (t_cur, t_next) = (merged_lookup_table[j], merged_lookup_table[j + 1]);
        let (h1_cur, h1_next) = (sorted_vec[j], sorted_vec[j + 1]);
        let (h2_cur, h2_next) = (sorted_vec[n - 1 + j], sorted_vec[n + j]);
        let numerator = one_plus_beta
            * (gamma + wire_val)
            * (gamma_times_one_plus_beta + t_cur + beta * t_next);
        let denominator = (gamma_times_one_plus_beta + h1_cur + beta * h1_next)
            * (gamma_times_one_plus_beta + h2_cur + beta * h2_next);
        running = running * numerator / denominator;
        evals.push(running);
    }
    evals.push(F::one());
    domain.ifft_in_place(&mut evals);
    Ok(DensePolynomial::from_coefficients_vec(evals))
}
/// Build the sorted lookup vector and its two interpolated halves.
///
/// The merged lookup wire values of the first `n - 1` rows are counted.
/// Then `merged_lookup_table` is walked in order. The first occurrence of
/// each table element is emitted `1 + count` times and any later duplicate is
/// emitted once. The result is `sorted_vec` of length `2n - 1`. It is split
/// into `h1 = sorted_vec[..n]` and `h2 = sorted_vec[n-1..]`, which overlap in
/// one element, and each half is interpolated over `self.eval_domain`.
///
/// # Errors
/// - if the circuit is not a finalized UltraPlonk circuit;
/// - `ParameterError` if the range wire or the merged table length differs
///   from the domain size;
/// - `ParameterError` if the sorted vector does not reach length `2n - 1`,
///   which happens when some lookup value is absent from the table;
/// - errors from witness lookups in `merged_lookup_wire_value`.
fn compute_lookup_sorted_vec_polynomials(
    &self,
    tau: F,
    merged_lookup_table: &[F],
) -> Result<SortedLookupVecAndPolys<F>, CircuitError> {
    self.check_plonk_type(PlonkType::UltraPlonk)?;
    self.check_finalize_flag(true)?;
    let domain = &self.eval_domain;
    let n = domain.size();
    if n != self.wire_variables[RANGE_WIRE_ID].len() {
        return Err(ParameterError(
            "Domain size should match the size of the padded lookup variables vector"
                .to_string(),
        ));
    }
    if n != merged_lookup_table.len() {
        return Err(ParameterError(
            "Domain size should match the size of the padded lookup table".to_string(),
        ));
    }
    // Multiplicity of each merged lookup witness value over the first n-1 rows.
    let mut lookup_map = HashMap::<F, usize>::new();
    let q_lookup_vec = self.q_lookup();
    let q_dom_sep_vec = self.q_dom_sep();
    for i in 0..(n - 1) {
        let elem = self.merged_lookup_wire_value(tau, i, &q_lookup_vec, &q_dom_sep_vec)?;
        *lookup_map.entry(elem).or_insert(0) += 1;
    }
    // At most n table elements plus n-1 looked-up copies.
    let mut sorted_vec = Vec::with_capacity(2 * n - 1);
    for &elem in merged_lookup_table.iter() {
        // A single `remove` both fetches the multiplicity and consumes it. Later
        // duplicates of the same table element then contribute only themselves.
        let n_lookups = lookup_map.remove(&elem).unwrap_or(0);
        sorted_vec.extend(core::iter::repeat(elem).take(1 + n_lookups));
    }
    if sorted_vec.len() != 2 * n - 1 {
        return Err(ParameterError("The sorted vector has wrong length, some lookup variables might be outside the table".to_string()));
    }
    let h1_poly = DensePolynomial::from_coefficients_vec(domain.ifft(&sorted_vec[..n]));
    let h2_poly = DensePolynomial::from_coefficients_vec(domain.ifft(&sorted_vec[n - 1..]));
    Ok((sorted_vec, h1_poly, h2_poly))
}
}
impl<F: PrimeField> PlonkCircuit<F> {
    /// Build the range table `[0, 1, ..., range_size - 1]`, zero-padded to
    /// the evaluation domain size.
    ///
    /// # Errors
    /// - if the circuit is not an UltraPlonk circuit or is not finalized;
    /// - `ParameterError` if the domain is smaller than the range size;
    /// - errors from `range_size`.
    #[inline]
    fn compute_range_table(&self) -> Result<Vec<F>, CircuitError> {
        self.check_plonk_type(PlonkType::UltraPlonk)?;
        self.check_finalize_flag(true)?;
        let domain = &self.eval_domain;
        let range_size = self.range_size()?;
        if domain.size() < range_size {
            return Err(ParameterError(format!(
                "Domain size {} < range size {}",
                domain.size(),
                range_size
            )));
        }
        // Convert through `u64` (lossless for `usize` on supported targets) so
        // indices above `u32::MAX` cannot silently wrap.
        let mut range_table: Vec<F> = (0..range_size).map(|i| F::from(i as u64)).collect();
        range_table.resize(domain.size(), F::zero());
        Ok(range_table)
    }

    /// Merged lookup-table value at row `i`:
    /// `range[i] + q_lookup[i] * tau * (dom_sep[i] + tau * (key[i]
    ///   + tau * (table_val_1 + tau * table_val_2)))`.
    /// Here `table_val_1` and `table_val_2` are the witnesses on the
    /// `TABLE_VAL_1_WIRE_ID` / `TABLE_VAL_2_WIRE_ID` wires of row `i`.
    ///
    /// Panics if `i` is out of bounds for any of the input slices.
    ///
    /// # Errors
    /// Propagates witness-lookup errors.
    #[inline]
    fn merged_table_value(
        &self,
        tau: F,
        range_table: &[F],
        table_key_vec: &[F],
        table_dom_sep_vec: &[F],
        q_lookup_vec: &[F],
        i: usize,
    ) -> Result<F, CircuitError> {
        let range_val = range_table[i];
        let key_val = table_key_vec[i];
        let dom_sep_val = table_dom_sep_vec[i];
        let q_lookup_val = q_lookup_vec[i];
        let table_val_1 = self.witness(self.wire_variable(TABLE_VAL_1_WIRE_ID, i))?;
        let table_val_2 = self.witness(self.wire_variable(TABLE_VAL_2_WIRE_ID, i))?;
        Ok(range_val
            + q_lookup_val
                * tau
                * (dom_sep_val + tau * (key_val + tau * (table_val_1 + tau * table_val_2))))
    }

    /// Merged lookup-wire value at row `i`:
    /// `w_range + q_lookup[i] * tau * (q_dom_sep[i] + tau * (lookup_key
    ///   + tau * (lookup_val_1 + tau * lookup_val_2)))`.
    /// All four witnesses are read from the range, lookup-key and two
    /// lookup-value wires of row `i`.
    ///
    /// Panics if `i` is out of bounds for `q_lookup_vec` or `q_dom_sep_vec`.
    ///
    /// # Errors
    /// Propagates witness-lookup errors.
    #[inline]
    fn merged_lookup_wire_value(
        &self,
        tau: F,
        i: usize,
        q_lookup_vec: &[F],
        q_dom_sep_vec: &[F],
    ) -> Result<F, CircuitError> {
        let w_range_val = self.witness(self.wire_variable(RANGE_WIRE_ID, i))?;
        let lookup_key = self.witness(self.wire_variable(LOOKUP_KEY_WIRE_ID, i))?;
        let lookup_val_1 = self.witness(self.wire_variable(LOOKUP_VAL_1_WIRE_ID, i))?;
        let lookup_val_2 = self.witness(self.wire_variable(LOOKUP_VAL_2_WIRE_ID, i))?;
        let q_lookup_val = q_lookup_vec[i];
        let q_dom_sep_val = q_dom_sep_vec[i];
        Ok(w_range_val
            + q_lookup_val
                * tau
                * (q_dom_sep_val + tau * (lookup_key + tau * (lookup_val_1 + tau * lookup_val_2))))
    }
}
#[cfg(test)]
pub(crate) mod test {
    // Unit tests for circuit construction, satisfiability checking and
    // arithmetization (selector, wire, permutation and lookup polynomials).
    use super::{Arithmetization, Circuit, PlonkCircuit};
    use crate::{constants::compute_coset_representatives, CircuitError};
    use ark_bls12_377::Fq as Fq377;
    use ark_ed_on_bls12_377::Fq as FqEd377;
    use ark_ed_on_bls12_381::Fq as FqEd381;
    use ark_ed_on_bn254::Fq as FqEd254;
    use ark_ff::PrimeField;
    use ark_poly::{domain::Radix2EvaluationDomain, univariate::DensePolynomial, EvaluationDomain};
    use ark_std::{vec, vec::Vec};
    use jf_utils::test_rng;

    #[test]
    fn test_circuit_trait() -> Result<(), CircuitError> {
        test_circuit_trait_helper::<FqEd254>()?;
        test_circuit_trait_helper::<FqEd377>()?;
        test_circuit_trait_helper::<FqEd381>()?;
        test_circuit_trait_helper::<Fq377>()
    }

    fn test_circuit_trait_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(3u32))?;
        let b = circuit.create_variable(F::from(1u32))?;
        circuit.enforce_constant(a, F::from(3u32))?;
        circuit.enforce_bool(b)?;
        let c = circuit.add(a, b)?;
        let d = circuit.sub(a, b)?;
        let e = circuit.mul(c, d)?;
        let f = circuit.create_public_variable(F::from(8u32))?;
        circuit.enforce_equal(e, f)?;
        assert_eq!(circuit.num_gates(), 9);
        assert_eq!(circuit.num_vars(), 8);
        assert_eq!(circuit.num_inputs(), 1);
        let pub_input = &[F::from(8u32)];
        let verify = circuit.check_circuit_satisfiability(pub_input);
        assert!(verify.is_ok(), "{:?}", verify.unwrap_err());
        let bad_pub_input = &[F::from(0u32)];
        assert!(circuit.check_circuit_satisfiability(bad_pub_input).is_err());
        let bad_pub_input = &[F::from(8u32), F::from(8u32)];
        assert!(circuit.check_circuit_satisfiability(bad_pub_input).is_err());
        Ok(())
    }

    #[test]
    fn test_add() -> Result<(), CircuitError> {
        test_add_helper::<FqEd254>()?;
        test_add_helper::<FqEd377>()?;
        test_add_helper::<FqEd381>()?;
        test_add_helper::<Fq377>()
    }

    fn test_add_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(3u32))?;
        let b = circuit.create_variable(F::from(1u32))?;
        let c = circuit.add(a, b)?;
        assert_eq!(circuit.witness(c)?, F::from(4u32));
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        *circuit.witness_mut(c) = F::from(1u32);
        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
        assert!(circuit.add(circuit.num_vars(), a).is_err());
        Ok(())
    }

    #[test]
    fn test_sub() -> Result<(), CircuitError> {
        test_sub_helper::<FqEd254>()?;
        test_sub_helper::<FqEd377>()?;
        test_sub_helper::<FqEd381>()?;
        test_sub_helper::<Fq377>()
    }

    fn test_sub_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(3u32))?;
        let b = circuit.create_variable(F::from(1u32))?;
        let c = circuit.sub(a, b)?;
        assert_eq!(circuit.witness(c)?, F::from(2u32));
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        *circuit.witness_mut(c) = F::from(1u32);
        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
        assert!(circuit.sub(circuit.num_vars(), a).is_err());
        Ok(())
    }

    #[test]
    fn test_mul() -> Result<(), CircuitError> {
        test_mul_helper::<FqEd254>()?;
        test_mul_helper::<FqEd377>()?;
        test_mul_helper::<FqEd381>()?;
        test_mul_helper::<Fq377>()
    }

    fn test_mul_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(3u32))?;
        let b = circuit.create_variable(F::from(2u32))?;
        let c = circuit.mul(a, b)?;
        assert_eq!(circuit.witness(c)?, F::from(6u32));
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        *circuit.witness_mut(c) = F::from(1u32);
        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
        assert!(circuit.mul(circuit.num_vars(), a).is_err());
        Ok(())
    }

    #[test]
    fn test_equal_gate() -> Result<(), CircuitError> {
        test_equal_gate_helper::<FqEd254>()?;
        test_equal_gate_helper::<FqEd377>()?;
        test_equal_gate_helper::<FqEd381>()?;
        test_equal_gate_helper::<Fq377>()
    }

    fn test_equal_gate_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(3u32))?;
        let b = circuit.create_variable(F::from(3u32))?;
        circuit.enforce_equal(a, b)?;
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        *circuit.witness_mut(b) = F::from(1u32);
        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
        assert!(circuit.enforce_equal(circuit.num_vars(), a).is_err());
        Ok(())
    }

    #[test]
    fn test_bool() -> Result<(), CircuitError> {
        test_bool_helper::<FqEd254>()?;
        test_bool_helper::<FqEd377>()?;
        test_bool_helper::<FqEd381>()?;
        test_bool_helper::<Fq377>()
    }

    fn test_bool_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(0u32))?;
        circuit.enforce_bool(a)?;
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        *circuit.witness_mut(a) = F::from(2u32);
        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
        assert!(circuit.enforce_bool(circuit.num_vars()).is_err());
        Ok(())
    }

    #[test]
    fn test_constant() -> Result<(), CircuitError> {
        test_constant_helper::<FqEd254>()?;
        test_constant_helper::<FqEd377>()?;
        test_constant_helper::<FqEd381>()?;
        test_constant_helper::<Fq377>()
    }

    fn test_constant_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(10u32))?;
        circuit.enforce_constant(a, F::from(10u32))?;
        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
        *circuit.witness_mut(a) = F::from(2u32);
        assert!(circuit.check_circuit_satisfiability(&[]).is_err());
        assert!(circuit
            .enforce_constant(circuit.num_vars(), F::from(0u32))
            .is_err());
        Ok(())
    }

    #[test]
    fn test_io_gate() -> Result<(), CircuitError> {
        test_io_gate_helper::<FqEd254>()?;
        test_io_gate_helper::<FqEd377>()?;
        test_io_gate_helper::<FqEd381>()?;
        test_io_gate_helper::<Fq377>()
    }

    fn test_io_gate_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let b = circuit.create_variable(F::from(0u32))?;
        let a = circuit.create_public_variable(F::from(1u32))?;
        circuit.enforce_bool(a)?;
        circuit.enforce_bool(b)?;
        circuit.set_variable_public(b)?;
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(1u32), F::from(0u32)])
            .is_ok());
        *circuit.witness_mut(a) = F::from(0u32);
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(0u32), F::from(0u32)])
            .is_ok());
        *circuit.witness_mut(b) = F::from(1u32);
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(0u32), F::from(1u32)])
            .is_ok());
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(2u32), F::from(1u32)])
            .is_err());
        *circuit.witness_mut(a) = F::from(2u32);
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(2u32), F::from(1u32)])
            .is_err());
        *circuit.witness_mut(a) = F::from(0u32);
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(0u32), F::from(2u32)])
            .is_err());
        *circuit.witness_mut(b) = F::from(2u32);
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(0u32), F::from(2u32)])
            .is_err());
        Ok(())
    }

    #[test]
    fn test_io_gate_multi_inputs() -> Result<(), CircuitError> {
        test_io_gate_multi_inputs_helper::<FqEd254>()?;
        test_io_gate_multi_inputs_helper::<FqEd377>()?;
        test_io_gate_multi_inputs_helper::<FqEd381>()?;
        test_io_gate_multi_inputs_helper::<Fq377>()
    }

    fn test_io_gate_multi_inputs_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit = PlonkCircuit::<F>::new_turbo_plonk();
        let a = circuit.create_public_variable(F::from(1u32))?;
        let b = circuit.create_public_variable(F::from(2u32))?;
        let c = circuit.create_public_variable(F::from(3u32))?;
        circuit.add_gate(a, b, c)?;
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(1u32), F::from(2u32), F::from(3u32)])
            .is_ok());
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(2u32), F::from(1u32), F::from(3u32)])
            .is_err());
        *circuit.witness_mut(a) = F::from(4u32);
        *circuit.witness_mut(b) = F::from(8u32);
        *circuit.witness_mut(c) = F::from(12u32);
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(4u32), F::from(8u32), F::from(12u32)])
            .is_ok());
        *circuit.witness_mut(a) = F::from(2u32);
        assert!(circuit
            .check_circuit_satisfiability(&[F::from(2u32), F::from(8u32), F::from(12u32)])
            .is_err());
        Ok(())
    }

    /// Build a small satisfiable TurboPlonk circuit with two public inputs
    /// (`1` and `8`) and return it together with those public inputs.
    fn create_turbo_plonk_instance<F: PrimeField>(
    ) -> Result<(PlonkCircuit<F>, Vec<F>), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(3u32))?;
        let b = circuit.create_public_variable(F::from(1u32))?;
        circuit.enforce_constant(a, F::from(3u32))?;
        circuit.enforce_bool(b)?;
        let c = circuit.add(a, b)?;
        let d = circuit.sub(a, b)?;
        let e = circuit.mul(c, d)?;
        let f = circuit.create_public_variable(F::from(8u32))?;
        circuit.enforce_equal(e, f)?;
        Ok((circuit, vec![F::from(1u32), F::from(8u32)]))
    }

    /// Build a small satisfiable UltraPlonk circuit (4-bit range checks and
    /// a table lookup) with two public inputs (`1` and `8`). Return it
    /// together with those public inputs.
    fn create_ultra_plonk_instance<F: PrimeField>(
    ) -> Result<(PlonkCircuit<F>, Vec<F>), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(4);
        let a = circuit.create_variable(F::from(3u32))?;
        let b = circuit.create_public_variable(F::from(1u32))?;
        circuit.enforce_constant(a, F::from(3u32))?;
        circuit.enforce_bool(b)?;
        let c = circuit.add(a, b)?;
        let d = circuit.sub(a, b)?;
        let e = circuit.mul(c, d)?;
        let f = circuit.create_public_variable(F::from(8u32))?;
        circuit.enforce_equal(e, f)?;
        circuit.add_range_check_variable(b)?;
        circuit.add_range_check_variable(c)?;
        circuit.add_range_check_variable(e)?;
        circuit.add_range_check_variable(f)?;
        circuit.add_range_check_variable(circuit.zero())?;
        let table_vars = [(a, b), (c, d), (e, f)];
        let x = circuit.create_variable(F::from(3u8))?;
        let y = circuit.create_variable(F::from(8u8))?;
        let key1 = circuit.create_variable(F::from(2u8))?;
        let lookup_vars = [(circuit.zero(), x, circuit.one()), (key1, y, y)];
        circuit.create_table_and_lookup_variables(&lookup_vars, &table_vars)?;
        Ok((circuit, vec![F::from(1u32), F::from(8u32)]))
    }

    #[test]
    fn test_compute_extended_permutation() -> Result<(), CircuitError> {
        test_compute_extended_permutation_helper::<FqEd254>()?;
        test_compute_extended_permutation_helper::<FqEd377>()?;
        test_compute_extended_permutation_helper::<FqEd381>()?;
        test_compute_extended_permutation_helper::<Fq377>()
    }

    fn test_compute_extended_permutation_helper<F: PrimeField>() -> Result<(), CircuitError> {
        // TurboPlonk instance: must actually be checked before it is shadowed
        // below, otherwise its construction is dead code.
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        let a = circuit.create_variable(F::from(2u32))?;
        let b = circuit.create_public_variable(F::from(3u32))?;
        let c = circuit.add(a, b)?;
        let d = circuit.add(circuit.one(), a)?;
        let _ = circuit.mul(c, d)?;
        check_wire_permutation_and_extended_id_permutation(&mut circuit)?;

        // UltraPlonk instance.
        let (mut circuit, _) = create_ultra_plonk_instance::<F>()?;
        check_wire_permutation_and_extended_id_permutation(&mut circuit)?;
        Ok(())
    }

    /// Pad `circuit` to a radix-2 domain and compute its wire permutation and
    /// extended identity permutation, then check both.
    ///
    /// For the wire permutation, every cycle must visit only wires holding the
    /// same variable, must touch each wire at most once, and each variable
    /// must own at most one cycle. For the extended identity permutation,
    /// entry `(i, j)` must equal `k_i * ω^j`.
    fn check_wire_permutation_and_extended_id_permutation<F: PrimeField>(
        circuit: &mut PlonkCircuit<F>,
    ) -> Result<(), CircuitError> {
        let domain = Radix2EvaluationDomain::<F>::new(circuit.num_gates())
            .ok_or(CircuitError::DomainCreationError)?;
        let n = domain.size();
        circuit.eval_domain = domain;
        circuit.pad()?;
        circuit.compute_wire_permutation();
        let mut visit_wire = vec![false; circuit.num_wire_types * n];
        let mut visit_variable = vec![false; circuit.num_vars()];
        for i in 0..circuit.num_wire_types {
            for j in 0..n {
                if visit_wire[i * n + j] {
                    continue;
                }
                let cycle_var = circuit.wire_variable(i, j);
                assert!(!visit_variable[cycle_var]);
                visit_variable[cycle_var] = true;
                let mut wire_id = i;
                let mut gate_id = j;
                visit_wire[i * n + j] = true;
                loop {
                    let (next_wire_id, next_gate_id) =
                        circuit.wire_permutation[wire_id * n + gate_id];
                    if next_wire_id == i && next_gate_id == j {
                        break;
                    }
                    let next_var = circuit.wire_variable(next_wire_id, next_gate_id);
                    assert_eq!(cycle_var, next_var);
                    assert!(!visit_wire[next_wire_id * n + next_gate_id]);
                    visit_wire[next_wire_id * n + next_gate_id] = true;
                    wire_id = next_wire_id;
                    gate_id = next_gate_id;
                }
            }
        }
        circuit.compute_extended_id_permutation();
        let k: Vec<F> = compute_coset_representatives(circuit.num_wire_types, Some(n));
        let group_elems: Vec<F> = domain.elements().collect();
        (0..circuit.num_wire_types).for_each(|i| {
            (0..n).for_each(|j| {
                assert_eq!(
                    k[i] * group_elems[j],
                    circuit.extended_id_permutation[i * n + j]
                )
            });
        });
        Ok(())
    }

    #[test]
    fn test_ultra_plonk_flag() -> Result<(), CircuitError> {
        test_ultra_plonk_flag_helper::<FqEd254>()?;
        test_ultra_plonk_flag_helper::<FqEd377>()?;
        test_ultra_plonk_flag_helper::<FqEd381>()?;
        test_ultra_plonk_flag_helper::<Fq377>()
    }

    fn test_ultra_plonk_flag_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        assert!(circuit.add_range_check_variable(0).is_err());
        circuit.finalize_for_arithmetization()?;
        assert!(circuit.compute_range_table_polynomial().is_err());
        assert!(circuit.compute_key_table_polynomial().is_err());
        assert!(circuit.compute_table_dom_sep_polynomial().is_err());
        assert!(circuit.compute_q_dom_sep_polynomial().is_err());
        assert!(circuit.compute_merged_lookup_table(F::one()).is_err());
        assert!(circuit
            .compute_lookup_sorted_vec_polynomials(F::one(), &[])
            .is_err());
        assert!(circuit
            .compute_lookup_prod_polynomial(&F::one(), &F::one(), &F::one(), &[], &[])
            .is_err());
        Ok(())
    }

    #[test]
    fn test_finalized_flag() -> Result<(), CircuitError> {
        test_finalized_flag_helper::<FqEd254>()?;
        test_finalized_flag_helper::<FqEd377>()?;
        test_finalized_flag_helper::<FqEd381>()?;
        test_finalized_flag_helper::<Fq377>()
    }

    fn test_finalized_flag_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_turbo_plonk();
        assert!(circuit.compute_selector_polynomials().is_err());
        assert!(circuit.compute_extended_permutation_polynomials().is_err());
        assert!(circuit.compute_pub_input_polynomial().is_err());
        assert!(circuit.compute_wire_polynomials().is_err());
        assert!(circuit
            .compute_prod_permutation_polynomial(&F::one(), &F::one())
            .is_err());
        circuit.finalize_for_arithmetization()?;
        assert!(circuit.create_variable(F::one()).is_err());
        assert!(circuit.create_public_variable(F::one()).is_err());
        assert!(circuit.add_gate(0, 0, 0).is_err());
        assert!(circuit.sub_gate(0, 0, 0).is_err());
        assert!(circuit.mul_gate(0, 0, 0).is_err());
        assert!(circuit.enforce_constant(0, F::one()).is_err());
        assert!(circuit.enforce_bool(0).is_err());
        assert!(circuit.enforce_equal(0, 0).is_err());
        Ok(())
    }

    #[test]
    fn test_ultra_plonk_finalized_flag() -> Result<(), CircuitError> {
        test_ultra_plonk_finalized_flag_helper::<FqEd254>()?;
        test_ultra_plonk_finalized_flag_helper::<FqEd377>()?;
        test_ultra_plonk_finalized_flag_helper::<FqEd381>()?;
        test_ultra_plonk_finalized_flag_helper::<Fq377>()
    }

    fn test_ultra_plonk_finalized_flag_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let mut circuit: PlonkCircuit<F> = PlonkCircuit::new_ultra_plonk(1);
        assert!(circuit.compute_selector_polynomials().is_err());
        assert!(circuit.compute_extended_permutation_polynomials().is_err());
        assert!(circuit.compute_pub_input_polynomial().is_err());
        assert!(circuit.compute_wire_polynomials().is_err());
        assert!(circuit
            .compute_prod_permutation_polynomial(&F::one(), &F::one())
            .is_err());
        assert!(circuit.compute_range_table_polynomial().is_err());
        assert!(circuit.compute_key_table_polynomial().is_err());
        assert!(circuit.compute_table_dom_sep_polynomial().is_err());
        assert!(circuit.compute_q_dom_sep_polynomial().is_err());
        assert!(circuit.compute_merged_lookup_table(F::one()).is_err());
        assert!(circuit
            .compute_lookup_sorted_vec_polynomials(F::one(), &[])
            .is_err());
        assert!(circuit
            .compute_lookup_prod_polynomial(&F::one(), &F::one(), &F::one(), &[], &[])
            .is_err());
        circuit.finalize_for_arithmetization()?;
        assert!(circuit.create_variable(F::one()).is_err());
        assert!(circuit.create_public_variable(F::one()).is_err());
        assert!(circuit.add_gate(0, 0, 0).is_err());
        assert!(circuit.sub_gate(0, 0, 0).is_err());
        assert!(circuit.mul_gate(0, 0, 0).is_err());
        assert!(circuit.enforce_constant(0, F::one()).is_err());
        assert!(circuit.enforce_bool(0).is_err());
        assert!(circuit.enforce_equal(0, 0).is_err());
        assert!(circuit.add_range_check_variable(0).is_err());
        Ok(())
    }

    #[test]
    fn test_arithmetization() -> Result<(), CircuitError> {
        test_arithmetization_helper::<FqEd254>()?;
        test_arithmetization_helper::<FqEd377>()?;
        test_arithmetization_helper::<FqEd381>()?;
        test_arithmetization_helper::<Fq377>()
    }

    fn test_arithmetization_helper<F: PrimeField>() -> Result<(), CircuitError> {
        let (mut circuit, pub_inputs) = create_turbo_plonk_instance::<F>()?;
        circuit.finalize_for_arithmetization()?;
        test_arithmetization_for_circuit(circuit, pub_inputs)?;
        let (mut circuit, pub_inputs) = create_ultra_plonk_instance::<F>()?;
        circuit.finalize_for_arithmetization()?;
        test_arithmetization_for_lookup_circuit(&circuit)?;
        test_arithmetization_for_circuit(circuit, pub_inputs)?;
        Ok(())
    }

    /// Assert that `poly` evaluates to `evals` over the radix-2 domain of size
    /// `evals.len()` (rounded up to a power of two). Only the overlapping
    /// prefix is compared.
    fn check_polynomial<F: PrimeField>(poly: &DensePolynomial<F>, evals: &[F]) {
        let domain = Radix2EvaluationDomain::new(evals.len())
            .expect("evaluation vector length must admit a radix-2 domain");
        let poly_eval = poly.evaluate_over_domain_by_ref(domain);
        for (&a, &b) in poly_eval.evals.iter().zip(evals.iter()) {
            assert_eq!(a, b);
        }
    }

    /// Check every lookup-related polynomial of a finalized UltraPlonk circuit
    /// against the evaluation vectors it should interpolate. This covers
    /// the range, key, table-domain-separation and `q_dom_sep` columns, the
    /// sorted vector (h1/h2) and the lookup grand product.
    pub(crate) fn test_arithmetization_for_lookup_circuit<F: PrimeField>(
        circuit: &PlonkCircuit<F>,
    ) -> Result<(), CircuitError> {
        let n = circuit.eval_domain.size();
        let range_table_poly = circuit.compute_range_table_polynomial()?;
        let range_table = circuit.compute_range_table()?;
        check_polynomial(&range_table_poly, &range_table);
        let key_table_poly = circuit.compute_key_table_polynomial()?;
        let key_table = circuit.table_key_vec();
        check_polynomial(&key_table_poly, &key_table);
        let table_dom_sep_poly = circuit.compute_table_dom_sep_polynomial()?;
        let table_dom_sep = circuit.table_dom_sep_vec();
        check_polynomial(&table_dom_sep_poly, &table_dom_sep);
        let q_dom_sep_poly = circuit.compute_q_dom_sep_polynomial()?;
        let q_dom_sep_evals = circuit.q_dom_sep();
        check_polynomial(&q_dom_sep_poly, &q_dom_sep_evals);
        let rng = &mut test_rng();
        let tau = F::rand(rng);
        let merged_lookup_table = circuit.compute_merged_lookup_table(tau)?;
        let (sorted_vec, h1_poly, h2_poly) =
            circuit.compute_lookup_sorted_vec_polynomials(tau, &merged_lookup_table)?;
        assert_eq!(sorted_vec.len(), 2 * n - 1);
        // The sorted vector must list table elements in table order, with
        // repetitions only between consecutive equal entries.
        assert_eq!(sorted_vec[0], merged_lookup_table[0]);
        let mut ptr = 1;
        for slice in sorted_vec.windows(2) {
            if slice[0] == slice[1] {
                continue;
            }
            while ptr < n && merged_lookup_table[ptr] == merged_lookup_table[ptr - 1] {
                ptr += 1;
            }
            assert!(ptr < n);
            assert_eq!(merged_lookup_table[ptr], slice[1]);
            ptr += 1;
        }
        assert_eq!(ptr, n);
        check_polynomial(&h1_poly, &sorted_vec[..n]);
        check_polynomial(&h2_poly, &sorted_vec[n - 1..]);
        let beta = F::rand(rng);
        let gamma = F::rand(rng);
        let prod_poly = circuit.compute_lookup_prod_polynomial(
            &tau,
            &beta,
            &gamma,
            &merged_lookup_table,
            &sorted_vec,
        )?;
        let mut prod_evals = vec![F::one()];
        let one_plus_beta = F::one() + beta;
        let gamma_mul_one_plus_beta = gamma * one_plus_beta;
        let q_lookup_vec = circuit.q_lookup();
        let q_dom_sep = circuit.q_dom_sep();
        for j in 0..(n - 2) {
            let lookup_wire_val =
                circuit.merged_lookup_wire_value(tau, j, &q_lookup_vec, &q_dom_sep)?;
            let table_val = merged_lookup_table[j];
            let table_next_val = merged_lookup_table[j + 1];
            let h1_val = sorted_vec[j];
            let h1_next_val = sorted_vec[j + 1];
            let h2_val = sorted_vec[n - 1 + j];
            let h2_next_val = sorted_vec[n + j];
            let a = one_plus_beta
                * (gamma + lookup_wire_val)
                * (gamma_mul_one_plus_beta + table_val + beta * table_next_val);
            let b = (gamma_mul_one_plus_beta + h1_val + beta * h1_next_val)
                * (gamma_mul_one_plus_beta + h2_val + beta * h2_next_val);
            let prod = prod_evals[j] * a / b;
            prod_evals.push(prod);
        }
        prod_evals.push(F::one());
        check_polynomial(&prod_poly, &prod_evals);
        Ok(())
    }

    /// Check the selector, wire, public-input, extended-permutation and
    /// permutation grand-product polynomials of a finalized circuit against
    /// their expected evaluations.
    pub(crate) fn test_arithmetization_for_circuit<F: PrimeField>(
        circuit: PlonkCircuit<F>,
        pub_inputs: Vec<F>,
    ) -> Result<(), CircuitError> {
        let n = circuit.eval_domain.size();
        let selector_polys = circuit.compute_selector_polynomials()?;
        selector_polys
            .iter()
            .zip(circuit.all_selectors().iter())
            .for_each(|(poly, evals)| check_polynomial(poly, evals));
        let wire_polys = circuit.compute_wire_polynomials()?;
        for (poly, wire_vars) in wire_polys
            .iter()
            .zip(circuit.wire_variables.iter().take(circuit.num_wire_types()))
        {
            let wire_evals: Vec<F> = wire_vars.iter().map(|&var| circuit.witness[var]).collect();
            check_polynomial(poly, &wire_evals);
        }
        let pi_poly = circuit.compute_pub_input_polynomial()?;
        // Zero-pad the public inputs to the full domain size. The padding is
        // derived from the actual input count, not hard-coded.
        let mut pi_evals = pub_inputs;
        assert!(pi_evals.len() <= n, "more public inputs than domain rows");
        pi_evals.resize(n, F::zero());
        check_polynomial(&pi_poly, &pi_evals);
        let sigma_polys = circuit.compute_extended_permutation_polynomials()?;
        let extended_perm: Vec<F> = circuit
            .wire_permutation
            .iter()
            .map(|&(i, j)| circuit.extended_id_permutation[i * n + j])
            .collect();
        for (i, poly) in sigma_polys.iter().enumerate() {
            check_polynomial(poly, &extended_perm[i * n..(i + 1) * n]);
        }
        let rng = &mut test_rng();
        let beta = F::rand(rng);
        let gamma = F::rand(rng);
        let prod_poly = circuit.compute_prod_permutation_polynomial(&beta, &gamma)?;
        let mut prod_evals = vec![F::one()];
        for j in 0..(n - 1) {
            let mut a = F::one();
            let mut b = F::one();
            for i in 0..circuit.num_wire_types {
                let wire_value = circuit.witness[circuit.wire_variable(i, j)];
                a *= wire_value + beta * circuit.extended_id_permutation[i * n + j] + gamma;
                b *= wire_value + beta * extended_perm[i * n + j] + gamma;
            }
            let prod = prod_evals[j] * a / b;
            prod_evals.push(prod);
        }
        check_polynomial(&prod_poly, &prod_evals);
        Ok(())
    }
}