Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove config and add generic extension type #176

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 51 additions & 35 deletions extension/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use constants::*;
use frame_support::{
dispatch::{GetDispatchInfo, PostDispatchInfo},
pallet_prelude::*,
traits::{Contains, OriginTrait},
traits::OriginTrait,
};
use frame_system::RawOrigin;
use pallet_contracts::chain_extension::{
Expand All @@ -22,30 +22,42 @@ use sp_std::vec::Vec;

type ContractSchedule<T> = <T as pallet_contracts::Config>::Schedule;

/// Trait for the Pop API chain extension configuration.
pub trait Config:
frame_system::Config<RuntimeCall: GetDispatchInfo + Dispatchable<PostInfo = PostDispatchInfo>>
{
/// A query of runtime state.
type RuntimeRead: Decode;
/// Something to read runtime states.
type StateReader: ReadState<Self>;
/// Allowlisted runtime calls and read state calls.
type AllowedApiCalls: Contains<Self::RuntimeCall> + Contains<Self::RuntimeRead>;
/// Handles the query from the chain extension environment for state reads.
pub trait ReadState {
type StateQuery: Decode;

/// Allowed state queries from the API.
fn contains(c: &Self::StateQuery) -> bool;

/// Reads state using the provided query, returning the result as a byte vector.
fn read(read: Self::StateQuery) -> Vec<u8>;

/// Decodes parameters into state query.
fn decode(params: &mut &[u8]) -> Result<Self::StateQuery, DispatchError> {
decode_checked(params)
}
}

/// Trait for handling parameters from the chain extension environment during state read operations.
pub trait ReadState<T: Config> {
fn read(read: T::RuntimeRead) -> Vec<u8>;
/// Handles the query from the chain extension environment for dispatch calls.
pub trait CallFilter {
type Call: Decode;

/// Allowed runtime calls from the API.
fn contains(t: &Self::Call) -> bool;
}

#[derive(Default)]
pub struct ApiExtension;
pub struct ApiExtension<I>(PhantomData<I>);

impl<T> ChainExtension<T> for ApiExtension
impl<T, I> ChainExtension<T> for ApiExtension<I>
where
T: Config + pallet_contracts::Config,
T: pallet_contracts::Config
+ frame_system::Config<
RuntimeCall: GetDispatchInfo + Dispatchable<PostInfo = PostDispatchInfo>,
>,
T::AccountId: UncheckedFrom<T::Hash> + AsRef<[u8]>,
// Bound the type by the two traits which need to be implemented by the runtime.
I: ReadState + CallFilter<Call = <T as frame_system::Config>::RuntimeCall> + 'static,
{
fn call<E: Ext<T = T>>(
&mut self,
Expand All @@ -72,10 +84,10 @@ where
log::debug!(target: LOG_TARGET, "Read input successfully");
match function_id {
FuncId::Dispatch => {
dispatch::<T, E>(&mut env, version, pallet_index, call_index, params)
dispatch::<T, E, I>(&mut env, version, pallet_index, call_index, params)
},
FuncId::ReadState => {
read_state::<T, E>(&mut env, version, pallet_index, call_index, params)
read_state::<T, E, I>(&mut env, version, pallet_index, call_index, params)
},
}
},
Expand Down Expand Up @@ -105,17 +117,13 @@ fn extract_env<T, E: Ext<T = T>>(env: &Environment<E, BufInBufOutState>) -> (u8,
(version, function_id, pallet_index, call_index)
}

fn read_state<T, E>(
fn read_state<T: frame_system::Config, E: Ext<T = T>, StateReader: ReadState>(
env: &mut Environment<E, BufInBufOutState>,
version: u8,
pallet_index: u8,
call_index: u8,
mut params: Vec<u8>,
) -> Result<(), DispatchError>
where
T: Config,
E: Ext<T = T>,
{
) -> Result<(), DispatchError> {
const LOG_PREFIX: &str = " read_state |";

// Prefix params with version, pallet, index to simplify decoding.
Expand All @@ -129,9 +137,9 @@ where
env.charge_weight(T::DbWeight::get().reads(1_u64))?;
let result = match version {
VersionedStateRead::V0 => {
let read = decode_checked::<T::RuntimeRead>(&mut encoded_read)?;
ensure!(T::AllowedApiCalls::contains(&read), UNKNOWN_CALL_ERROR);
T::StateReader::read(read)
let read = StateReader::decode(&mut encoded_read)?;
ensure!(StateReader::contains(&read), UNKNOWN_CALL_ERROR);
StateReader::read(read)
},
};
log::trace!(
Expand All @@ -141,16 +149,19 @@ where
env.write(&result, false, None)
}

fn dispatch<T, E>(
fn dispatch<T, E, Filter>(
env: &mut Environment<E, BufInBufOutState>,
version: u8,
pallet_index: u8,
call_index: u8,
mut params: Vec<u8>,
) -> Result<(), DispatchError>
where
T: Config,
T: frame_system::Config<
RuntimeCall: GetDispatchInfo + Dispatchable<PostInfo = PostDispatchInfo>,
>,
E: Ext<T = T>,
Filter: CallFilter<Call = <T as frame_system::Config>::RuntimeCall> + 'static,
{
const LOG_PREFIX: &str = " dispatch |";

Expand All @@ -162,7 +173,7 @@ where
// Contract is the origin by default.
let origin: T::RuntimeOrigin = RawOrigin::Signed(env.ext().address().clone()).into();
match call {
VersionedDispatch::V0(call) => dispatch_call::<T, E>(env, call, origin, LOG_PREFIX),
VersionedDispatch::V0(call) => dispatch_call::<T, E, Filter>(env, call, origin, LOG_PREFIX),
}
}

Expand All @@ -171,19 +182,22 @@ fn decode_checked<T: Decode>(params: &mut &[u8]) -> Result<T, DispatchError> {
T::decode(params).map_err(|_| DECODING_FAILED_ERROR)
}

fn dispatch_call<T, E>(
fn dispatch_call<T, E, Filter>(
env: &mut Environment<E, BufInBufOutState>,
call: T::RuntimeCall,
mut origin: T::RuntimeOrigin,
log_prefix: &str,
) -> Result<(), DispatchError>
where
T: Config,
T: frame_system::Config<
RuntimeCall: GetDispatchInfo + Dispatchable<PostInfo = PostDispatchInfo>,
>,
E: Ext<T = T>,
Filter: CallFilter<Call = <T as frame_system::Config>::RuntimeCall> + 'static,
{
let charged_dispatch_weight = env.charge_weight(call.get_dispatch_info().weight)?;
log::debug!(target:LOG_TARGET, "{} Inputted RuntimeCall: {:?}", log_prefix, call);
origin.add_filter(T::AllowedApiCalls::contains);
origin.add_filter(Filter::contains);
match call.dispatch(origin) {
Ok(info) => {
log::debug!(target:LOG_TARGET, "{} success, actual weight: {:?}", log_prefix, info.actual_weight);
Expand All @@ -210,7 +224,9 @@ enum VersionedStateRead {

/// Wrapper to enable versioning of runtime calls.
#[derive(Decode, Debug)]
enum VersionedDispatch<T: Config> {
enum VersionedDispatch<
T: frame_system::Config<RuntimeCall: GetDispatchInfo + Dispatchable<PostInfo = PostDispatchInfo>>,
chungquantin marked this conversation as resolved.
Show resolved Hide resolved
> {
/// Version zero of dispatch calls.
#[codec(index = 0)]
V0(T::RuntimeCall),
Expand Down
49 changes: 23 additions & 26 deletions runtime/devnet/src/config/api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{config::assets::TrustBackedAssetsInstance, fungibles, Runtime, RuntimeCall};
use codec::{Decode, Encode, MaxEncodedLen};
use frame_support::traits::Contains;
use pop_chain_extension::ReadState;
use pop_chain_extension::{CallFilter, ReadState};
use sp_std::vec::Vec;

/// A query of runtime state.
Expand All @@ -13,22 +12,36 @@ pub enum RuntimeRead {
Fungibles(fungibles::Read<Runtime>),
}

/// A struct that provides a state reading implementation for the Runtime.
pub struct StateReader;
impl ReadState<Runtime> for StateReader {
/// A struct that implement requirements for the Pop API chain extension.
#[derive(Default)]
pub struct Extension;
impl ReadState for Extension {
type StateQuery = RuntimeRead;

fn contains(c: &Self::StateQuery) -> bool {
use fungibles::Read::*;
matches!(
c,
RuntimeRead::Fungibles(
TotalSupply(..)
| BalanceOf { .. } | Allowance { .. }
| TokenName(..) | TokenSymbol(..)
| TokenDecimals(..) | AssetExists(..)
)
)
}

fn read(read: RuntimeRead) -> Vec<u8> {
match read {
RuntimeRead::Fungibles(key) => fungibles::Pallet::read_state(key),
}
}
}

/// A type to identify allowed calls to the Runtime from the API.
pub struct AllowedApiCalls;
impl CallFilter for Extension {
type Call = RuntimeCall;

impl Contains<RuntimeCall> for AllowedApiCalls {
/// Allowed runtime calls from the API.
fn contains(c: &RuntimeCall) -> bool {
fn contains(c: &Self::Call) -> bool {
use fungibles::Call::*;
matches!(
c,
Expand All @@ -46,22 +59,6 @@ impl Contains<RuntimeCall> for AllowedApiCalls {
}
}

impl Contains<RuntimeRead> for AllowedApiCalls {
/// Allowed state queries from the API.
fn contains(c: &RuntimeRead) -> bool {
use fungibles::Read::*;
matches!(
c,
RuntimeRead::Fungibles(
TotalSupply(..)
| BalanceOf { .. } | Allowance { .. }
| TokenName(..) | TokenSymbol(..)
| TokenDecimals(..) | AssetExists(..)
)
)
}
}

impl fungibles::Config for Runtime {
type AssetsInstance = TrustBackedAssetsInstance;
type WeightInfo = fungibles::weights::SubstrateWeight<Runtime>;
Expand Down
10 changes: 2 additions & 8 deletions runtime/devnet/src/config/contracts.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::api::{AllowedApiCalls, RuntimeRead, StateReader};
use super::api::Extension;
use crate::{
deposit, Balance, Balances, BalancesCall, Perbill, Runtime, RuntimeCall, RuntimeEvent,
RuntimeHoldReason, Timestamp,
Expand Down Expand Up @@ -45,12 +45,6 @@ parameter_types! {
pub const CodeHashLockupDepositPercent: Perbill = Perbill::from_percent(0);
}

impl pop_chain_extension::Config for Runtime {
type RuntimeRead = RuntimeRead;
type StateReader = StateReader;
type AllowedApiCalls = AllowedApiCalls;
}

impl pallet_contracts::Config for Runtime {
type Time = Timestamp;
type Randomness = DummyRandomness<Self>;
Expand All @@ -70,7 +64,7 @@ impl pallet_contracts::Config for Runtime {
type CallStack = [pallet_contracts::Frame<Self>; 23];
type WeightPrice = pallet_transaction_payment::Pallet<Self>;
type WeightInfo = pallet_contracts::weights::SubstrateWeight<Self>;
type ChainExtension = pop_chain_extension::ApiExtension;
type ChainExtension = pop_chain_extension::ApiExtension<Extension>;
type Schedule = Schedule;
type AddressGenerator = pallet_contracts::DefaultAddressGenerator;
// This node is geared towards development and testing of contracts.
Expand Down
Loading