Skip to content

Latest commit

 

History

History
337 lines (277 loc) · 16.1 KB

10_zkStark_PartTwo_AIR_Rescue-Prime.md

File metadata and controls

337 lines (277 loc) · 16.1 KB

ETAAcademy-ZKMeme: 10. zkStark AIR and Rescue-Prime

| title | tags |
| --- | --- |
| 10. zkStark AIR and Rescue-Prime | zk-meme basic quick_read zkStark_AIR_Rescue-Prime |

Github | Twitter | ETA-ZK-Meme

Authors: Eta (looking forward to you joining!)

More Detail for zkStark:
Part I: STARK Overview
Part II: Basic Tools
Part III: FRI
Part IV: The STARK Polynomial IOP
Part V: A Rescue-Prime STARK
Part VI: Speeding Things Up

AIR: Arithmetic Constraint System

The arithmetic intermediate representation (AIR) is central to STARKs: it frames a computation as an execution trace together with constraints on that trace, and both are interpolated into polynomials. The trace can be viewed as a sequence of state transitions of a set of registers. To reduce the prover's time and space costs, the resulting polynomials are combined, with random weights, into a single polynomial, so that only one low-degree test (FRI) has to be run. Verification relies on symbolic evaluation of the constraints and on division by zerofier polynomials to ensure computational integrity. In the end, the arithmetic constraint system yields two kinds of witness polynomials: trace polynomials, interpolated from the execution trace of the entire program being proven, and quotient polynomials, derived from the constraints. These polynomials are then used in the polynomial commitments and the FRI protocol described in zkStark Part One.

  • Prove

    def prove( self, trace, transition_constraints, boundary, proof_stream=None ):
        """Produce a serialized STARK proof that `trace` satisfies the AIR.

        Parameters:
            trace: list of rows, one per cycle, each a list of
                `self.num_registers` field elements. The caller's list is not
                modified; randomizer rows are appended to a new list.
            transition_constraints: multivariate polynomials over the variables
                (x, current state..., next state...).
            boundary: boundary conditions, passed unchanged to
                `self.boundary_interpolants` / `self.boundary_zerofiers` /
                `self.boundary_quotient_degree_bounds`.
            proof_stream: optional ProofStream to write into; a fresh
                `ProofStream()` is created when None.

        Returns:
            `proof_stream.serialize()` after all commitments, the FRI proof and
            the Merkle openings have been pushed.
        """
        # create proof stream object if necessary
        if proof_stream is None:
            proof_stream = ProofStream()

        # concatenate randomizer rows (builds a new list; the caller's trace is untouched)
        trace = trace + [
            [self.field.sample(os.urandom(17)) for _ in range(self.num_registers)]
            for _ in range(self.num_randomizers)
        ]

        # interpolate each register column over the powers of omicron
        trace_domain = [self.omicron^i for i in range(len(trace))]
        trace_polynomials = [
            Polynomial.interpolate_domain(trace_domain, [row[s] for row in trace])
            for s in range(self.num_registers)
        ]

        # subtract boundary interpolants and divide out boundary zerofiers;
        # both lists are computed once rather than once per register
        boundary_interpolants = self.boundary_interpolants(boundary)
        boundary_zerofiers = self.boundary_zerofiers(boundary)
        boundary_quotients = [
            (trace_polynomials[s] - boundary_interpolants[s]) / boundary_zerofiers[s]
            for s in range(self.num_registers)
        ]

        # commit to boundary quotients
        fri_domain = self.fri.eval_domain()
        boundary_quotient_codewords = []
        for bq in boundary_quotients:
            codeword = bq.evaluate_domain(fri_domain)
            boundary_quotient_codewords.append(codeword)
            proof_stream.push(Merkle.commit(codeword))

        # symbolically evaluate transition constraints at
        # (x, trace polynomials, trace polynomials scaled by omicron)
        x = Polynomial([self.field.zero(), self.field.one()])
        point = [x] + trace_polynomials + [tp.scale(self.omicron) for tp in trace_polynomials]
        transition_polynomials = [a.evaluate_symbolic(point) for a in transition_constraints]

        # divide out the transition zerofier (computed once)
        transition_zerofier = self.transition_zerofier()
        transition_quotients = [tp / transition_zerofier for tp in transition_polynomials]

        # commit to randomizer polynomial
        max_degree = self.max_degree(transition_constraints)
        randomizer_polynomial = Polynomial([self.field.sample(os.urandom(17)) for _ in range(max_degree + 1)])
        randomizer_codeword = randomizer_polynomial.evaluate_domain(fri_domain)
        proof_stream.push(Merkle.commit(randomizer_codeword))

        # get weights for nonlinear combination
        #  - 1 randomizer
        #  - 2 for every transition quotient
        #  - 2 for every boundary quotient
        weights = self.sample_weights(1 + 2*len(transition_quotients) + 2*len(boundary_quotients), proof_stream.prover_fiat_shamir())

        # internal sanity check on the prover's own computation
        transition_degree_bounds = self.transition_quotient_degree_bounds(transition_constraints)
        assert [tq.degree() for tq in transition_quotients] == transition_degree_bounds, "transition quotient degrees do not match with expectation"

        # compute terms of nonlinear combination polynomial: every quotient appears
        # as-is and multiplied by x^shift, where shift = max_degree - its degree bound
        # (len(trace) already includes the randomizer rows)
        boundary_degree_bounds = self.boundary_quotient_degree_bounds(len(trace), boundary)
        terms = [randomizer_polynomial]
        for i, tq in enumerate(transition_quotients):
            shift = max_degree - transition_degree_bounds[i]
            terms += [tq, (x^shift) * tq]
        for i in range(self.num_registers):
            bq = boundary_quotients[i]
            shift = max_degree - boundary_degree_bounds[i]
            terms += [bq, (x^shift) * bq]

        # take weighted sum: combination = sum(weights[i] * terms[i] for all i)
        combination = Polynomial([])
        for i, term in enumerate(terms):
            combination = combination + Polynomial([weights[i]]) * term

        # compute matching codeword
        combined_codeword = combination.evaluate_domain(fri_domain)

        # prove low degree of combination polynomial
        indices = sorted(self.fri.prove(combined_codeword, proof_stream))
        # each index is paired with the position offset by expansion_factor,
        # which the verifier reads as the "next" row for the transition check
        duplicated_indices = indices + [(i + self.expansion_factor) % self.fri.domain_length for i in indices]

        # open indicated positions in the boundary quotient codewords
        for bqc in boundary_quotient_codewords:
            for i in duplicated_indices:
                proof_stream.push(bqc[i])
                proof_stream.push(Merkle.open(i, bqc))

        # ... as well as in the randomizer
        for i in indices:
            proof_stream.push(randomizer_codeword[i])
            proof_stream.push(Merkle.open(i, randomizer_codeword))

        # the final proof is just the serialized stream
        return proof_stream.serialize()
  • Verify

    def verify( self, proof, transition_constraints, boundary, proof_stream=None ):
        """Check a serialized STARK proof against the AIR.

        Parameters:
            proof: serialized proof, as produced by `prove`.
            transition_constraints: the same multivariate constraints the
                prover used.
            boundary: iterable of 3-tuples (c, r, v); `c` is treated as a cycle
                index when inferring the trace length.
            proof_stream: optional ProofStream whose `deserialize` is used; a
                fresh `ProofStream()` is created when None.

        Returns:
            True when the FRI check, every Merkle opening and every combination
            check pass; False as soon as any of them fails.
        """
        # infer trace length from boundary conditions
        # NOTE(review): assumes the last cycle carries a boundary condition;
        # an empty `boundary` raises ValueError from max().
        original_trace_length = 1 + max(c for c, r, v in boundary)
        randomized_trace_length = original_trace_length + self.num_randomizers

        # deserialize with right proof stream
        if proof_stream is None:
            proof_stream = ProofStream()
        proof_stream = proof_stream.deserialize(proof)

        # get Merkle roots of boundary quotient codewords
        boundary_quotient_roots = [proof_stream.pull() for _ in range(self.num_registers)]

        # get Merkle root of randomizer polynomial
        randomizer_root = proof_stream.pull()

        # boundary data is needed repeatedly below; compute it once
        boundary_interpolants = self.boundary_interpolants(boundary)
        boundary_zerofiers = self.boundary_zerofiers(boundary)

        # get weights for nonlinear combination
        weights = self.sample_weights(1 + 2*len(transition_constraints) + 2*len(boundary_interpolants), proof_stream.verifier_fiat_shamir())

        # verify low degree of combination polynomial
        polynomial_values = []
        if not self.fri.verify(proof_stream, polynomial_values):
            return False
        polynomial_values.sort(key=lambda iv: iv[0])

        indices = [i for i, v in polynomial_values]
        values = [v for i, v in polynomial_values]

        # read and verify leafs, which are elements of boundary quotient codewords
        duplicated_indices = indices + [(i + self.expansion_factor) % self.fri.domain_length for i in indices]
        leafs = []
        for root in boundary_quotient_roots:
            register_leafs = dict()
            for i in duplicated_indices:
                register_leafs[i] = proof_stream.pull()
                path = proof_stream.pull()
                if not Merkle.verify(root, i, path, register_leafs[i]):
                    return False
            leafs.append(register_leafs)

        # read and verify randomizer leafs (fail fast, like the checks above)
        randomizer = dict()
        for i in indices:
            randomizer[i] = proof_stream.pull()
            path = proof_stream.pull()
            if not Merkle.verify(randomizer_root, i, path, randomizer[i]):
                return False

        # loop-invariant quantities for the combination check
        max_degree = self.max_degree(transition_constraints)
        transition_degree_bounds = self.transition_quotient_degree_bounds(transition_constraints)
        boundary_degree_bounds = self.boundary_quotient_degree_bounds(randomized_trace_length, boundary)
        transition_zerofier = self.transition_zerofier()

        # verify leafs of combination polynomial
        for current_index, value in zip(indices, values):
            # get trace values by applying a correction to the boundary quotient
            # values (which are the leafs): trace = quotient * zerofier + interpolant
            domain_current_index = self.generator * (self.omega^current_index)
            next_index = (current_index + self.expansion_factor) % self.fri.domain_length
            domain_next_index = self.generator * (self.omega^next_index)
            current_trace = [
                leafs[s][current_index] * boundary_zerofiers[s].evaluate(domain_current_index)
                + boundary_interpolants[s].evaluate(domain_current_index)
                for s in range(self.num_registers)
            ]
            next_trace = [
                leafs[s][next_index] * boundary_zerofiers[s].evaluate(domain_next_index)
                + boundary_interpolants[s].evaluate(domain_next_index)
                for s in range(self.num_registers)
            ]

            point = [domain_current_index] + current_trace + next_trace
            transition_constraints_values = [tc.evaluate(point) for tc in transition_constraints]

            # compute nonlinear combination, in the same term order as the prover
            zerofier_value = transition_zerofier.evaluate(domain_current_index)
            terms = [randomizer[current_index]]
            for s, tcv in enumerate(transition_constraints_values):
                quotient = tcv / zerofier_value
                shift = max_degree - transition_degree_bounds[s]
                terms += [quotient, quotient * (domain_current_index^shift)]
            for s in range(self.num_registers):
                bqv = leafs[s][current_index]  # boundary quotient value
                shift = max_degree - boundary_degree_bounds[s]
                terms += [bqv, bqv * (domain_current_index^shift)]
            combination = self.field.zero()
            for j, term in enumerate(terms):
                combination = combination + term * weights[j]

            # verify against combination polynomial value
            if not (combination == value):
                return False

        return True

Rescue-Prime: a STARK Proof of AIR and Signature

The Rescue-Prime STARK is a concretely useful STARK proof system that serves both as a post-quantum signature scheme and as a proof of correct evaluation of the Rescue-Prime hash function. Rescue-Prime is an arithmetization-oriented hash function built on a sponge construction with multiple almost-identical rounds. A single round consists of a forward S-box, a matrix multiplication, and the addition of round constants, followed by a backward (inverse) S-box, another matrix multiplication, and another addition of round constants. Arithmetizing Rescue-Prime requires transition constraints and boundary constraints, and the witness data (the execution trace) is obtained by running the permutation and recording its state after every round.

  • Rescue-Prime AIR

    def round_constants_polynomials( self, omicron ):
        """Interpolate the Rescue-Prime round constants over the powers of omicron.

        For register i and round r in range(N):
          - the first-step polynomial takes round_constants[2*r*m + i] at omicron^r;
          - the second-step polynomial takes round_constants[2*r*m + m + i] at omicron^r.
        Each univariate interpolant is lifted into a multivariate polynomial in
        variable 0, which is the cycle-index variable of the transition constraints.

        Returns:
            (first_step_constants, second_step_constants): two lists of m lifted
            polynomials.
        """
        # the interpolation domain is identical for every register; build it once
        domain = [omicron^r for r in range(self.N)]

        def lift_interpolant(offset):
            # offset i selects the first-step constants, m + i the second-step ones
            values = [self.round_constants[2*r*self.m + offset] for r in range(self.N)]
            univariate = Polynomial.interpolate_domain(domain, values)
            return MPolynomial.lift(univariate, 0)

        first_step_constants = [lift_interpolant(i) for i in range(self.m)]
        second_step_constants = [lift_interpolant(self.m + i) for i in range(self.m)]

        return first_step_constants, second_step_constants
    
    def transition_constraints( self, omicron ):
        """Arithmetize one Rescue-Prime round as m transition constraints.

        Variable layout: index 0 is the cycle index, indices 1..m hold the
        previous state and indices m+1..2m the next state. Constraint i is

            sum_k MDS[i][k] * prev[k]^alpha + c1_i
          - (sum_k MDSinv[i][k] * (next[k] - c2_k))^alpha

        where c1/c2 are the round-constant polynomials from
        `round_constants_polynomials(omicron)`. Assuming MDSinv is the inverse
        of MDS, this is zero exactly when `next` is one full round applied to
        `prev`: forward S-box, MDS, c1, then backward S-box, MDS, c2.

        Returns:
            list of m MPolynomials.
        """
        # get polynomials that interpolate through the round constants
        first_step_constants, second_step_constants = self.round_constants_polynomials(omicron)

        # arithmetize one round of Rescue-Prime; variables[0] (the cycle index)
        # enters only through the round-constant polynomials, which are lifted into it
        variables = MPolynomial.variables(1 + 2*self.m, self.field)
        previous_state = variables[1:(1+self.m)]
        next_state = variables[(1+self.m):(1+2*self.m)]
        air = []
        for i in range(self.m):
            # left hand side: forward half-round applied to the previous state
            lhs = MPolynomial.constant(self.field.zero())
            for k in range(self.m):
                lhs = lhs + MPolynomial.constant(self.MDS[i][k]) * (previous_state[k]^self.alpha)
            lhs = lhs + first_step_constants[i]

            # right hand side: backward half-round undone from the next state
            # (raising to alpha cancels the alphainv S-box)
            rhs = MPolynomial.constant(self.field.zero())
            for k in range(self.m):
                rhs = rhs + MPolynomial.constant(self.MDSinv[i][k]) * (next_state[k] - second_step_constants[k])
            rhs = rhs^self.alpha

            # equate left and right hand sides
            air.append(lhs - rhs)

        return air
    
        def trace( self, input_element ):
            """Run the Rescue-Prime permutation on one input and record every state.

            The width-m state starts as [input_element, 0, ..., 0]. Each of the N
            rounds applies the S-box x^alpha, the MDS matrix and the constants
            round_constants[2*r*m : 2*r*m + m]. It then applies the S-box
            x^alphainv, the MDS matrix and the constants
            round_constants[2*r*m + m : 2*r*m + 2*m].

            Returns:
                list of N + 1 rows (the initial state, then the state after each
                round), each an independent list of m field elements.
            """
            width = self.m

            def mds_then_constants(vector, offset):
                # matrix-vector product with MDS, then add the round constants at offset
                result = []
                for row in range(width):
                    acc = self.field.zero()
                    for col in range(width):
                        acc = acc + self.MDS[row][col] * vector[col]
                    result.append(acc + self.round_constants[offset + row])
                return result

            # absorb
            state = [input_element] + [self.field.zero()] * (width - 1)
            rows = [list(state)]

            # permutation
            for rnd in range(self.N):
                base = 2 * rnd * width
                # forward half-round
                state = mds_then_constants([state[k]^self.alpha for k in range(width)], base)
                # backward half-round
                state = mds_then_constants([state[k]^self.alphainv for k in range(width)], base + width)
                rows.append(list(state))

            return rows