Skip to content

chore: reuse SVMProvider type #1025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions src/arch/svm/SpokeUtils.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import assert from "assert";
import { Logger } from "winston";
import { Rpc, SolanaRpcApi, Address, fetchEncodedAccounts, fetchEncodedAccount } from "@solana/kit";
import { Address, fetchEncodedAccounts, fetchEncodedAccount } from "@solana/kit";
import { fetchState, decodeFillStatusAccount } from "@across-protocol/contracts/dist/src/svm/clients/SvmSpoke";

import { SvmCpiEventsClient } from "./eventsClient";
import { Deposit, FillStatus, FillWithBlock, RelayData } from "../../interfaces";
import { BigNumber, chainIsSvm, chunk, isUnsafeDepositId } from "../../utils";
import { getFillStatusPda, unwrapEventData } from "./utils";
import { SVMEventNames } from "./types";

type Provider = Rpc<SolanaRpcApi>;
import { SVMEventNames, SVMProvider } from "./types";

/**
* @param spokePool SpokePool Contract instance.
Expand Down Expand Up @@ -39,7 +37,7 @@ export function getTimeAt(_spokePool: unknown, _blockNumber: number): Promise<nu
* @note This should be the same as getTimeAt() but can differ in test. These two functions should be consolidated.
* @returns The chain time at the specified slot.
*/
export async function getTimestampForSlot(provider: Provider, slotNumber: number): Promise<number> {
export async function getTimestampForSlot(provider: SVMProvider, slotNumber: number): Promise<number> {
const block = await provider.getBlock(BigInt(slotNumber)).send();
let timestamp: number;
if (!block?.blockTime) {
Expand All @@ -58,7 +56,7 @@ export async function getTimestampForSlot(provider: Provider, slotNumber: number
* @param statePda Spoke Pool's State PDA
* @returns fill deadline buffer
*/
export async function getFillDeadline(provider: Provider, statePda: Address): Promise<number> {
export async function getFillDeadline(provider: SVMProvider, statePda: Address): Promise<number> {
const state = await fetchState(provider, statePda);
return state.data.fillDeadlineBuffer;
}
Expand Down Expand Up @@ -102,12 +100,11 @@ export async function relayFillStatus(
programId: Address,
relayData: RelayData,
destinationChainId: number,
provider: Provider,
svmEventsClient: SvmCpiEventsClient,
atHeight?: number
): Promise<FillStatus> {
assert(chainIsSvm(destinationChainId), "Destination chain must be an SVM chain");

const provider = svmEventsClient.getRpc();
// Get fill status PDA using relayData
const fillStatusPda = await getFillStatusPda(programId, relayData, destinationChainId);
const currentSlot = await provider.getSlot({ commitment: "confirmed" }).send();
Expand Down Expand Up @@ -152,13 +149,12 @@ export async function fillStatusArray(
programId: Address,
relayData: RelayData[],
destinationChainId: number,
provider: Provider,
svmEventsClient: SvmCpiEventsClient,
atHeight?: number,
logger?: Logger
): Promise<(FillStatus | undefined)[]> {
assert(chainIsSvm(destinationChainId), "Destination chain must be an SVM chain");

const provider = svmEventsClient.getRpc();
const chunkSize = 100;
const chunkedRelayData = chunk(relayData, chunkSize);

Expand Down Expand Up @@ -325,7 +321,7 @@ async function resolveFillStatusFromPdaEvents(
* @param relayData An array of relay data from which the fill status PDAs were derived.
*/
async function fetchBatchFillStatusFromPdaAccounts(
provider: Provider,
provider: SVMProvider,
fillStatusPdas: Address[],
relayDataArray: RelayData[]
): Promise<(FillStatus | undefined)[]> {
Expand Down
32 changes: 7 additions & 25 deletions src/arch/svm/eventsClient.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
import { Idl } from "@coral-xyz/anchor";
import { getDeployedAddress, SvmSpokeIdl } from "@across-protocol/contracts";
import { getSolanaChainId } from "@across-protocol/contracts/dist/src/svm/web3-v1";
import web3, {
Address,
Commitment,
GetSignaturesForAddressApi,
GetTransactionApi,
Rpc,
RpcTransport,
Signature,
SolanaRpcApiFromTransport,
} from "@solana/kit";
import web3, { Address, Commitment, GetSignaturesForAddressApi, GetTransactionApi, Signature } from "@solana/kit";
import { bs58 } from "../../utils";
import { EventName, EventWithData } from "./types";
import { EventName, EventWithData, SVMProvider } from "./types";
import { decodeEvent, isDevnet } from "./utils";

// Utility type to extract the return type for the JSON encoding overload. We only care about the overload where the
Expand All @@ -29,20 +20,15 @@ type GetSignaturesForAddressTransaction = ReturnType<GetSignaturesForAddressApi[
type GetSignaturesForAddressApiResponse = readonly GetSignaturesForAddressTransaction[];

export class SvmCpiEventsClient {
private rpc: web3.Rpc<web3.SolanaRpcApiFromTransport<RpcTransport>>;
private rpc: SVMProvider;
private programAddress: Address;
private programEventAuthority: Address;
private idl: Idl;

/**
* Protected constructor. Use the async create() method to instantiate.
*/
protected constructor(
rpc: web3.Rpc<web3.SolanaRpcApiFromTransport<RpcTransport>>,
address: Address,
eventAuthority: Address,
idl: Idl
) {
protected constructor(rpc: SVMProvider, address: Address, eventAuthority: Address, idl: Idl) {
this.rpc = rpc;
this.programAddress = address;
this.programEventAuthority = eventAuthority;
Expand All @@ -52,18 +38,14 @@ export class SvmCpiEventsClient {
/**
* Factory method to asynchronously create an instance of SvmSpokeEventsClient.
*/
public static async create(rpc: web3.Rpc<web3.SolanaRpcApiFromTransport<RpcTransport>>): Promise<SvmCpiEventsClient> {
public static async create(rpc: SVMProvider): Promise<SvmCpiEventsClient> {
const isTestnet = await isDevnet(rpc);
const programId = getDeployedAddress("SvmSpoke", getSolanaChainId(isTestnet ? "devnet" : "mainnet").toString());
if (!programId) throw new Error("Program not found");
return this.createFor(rpc, programId, SvmSpokeIdl);
}

public static async createFor(
rpc: web3.Rpc<web3.SolanaRpcApiFromTransport<RpcTransport>>,
programId: string,
idl: Idl
): Promise<SvmCpiEventsClient> {
public static async createFor(rpc: SVMProvider, programId: string, idl: Idl): Promise<SvmCpiEventsClient> {
const address = web3.address(programId);
const [eventAuthority] = await web3.getProgramDerivedAddress({
programAddress: address,
Expand Down Expand Up @@ -233,7 +215,7 @@ export class SvmCpiEventsClient {
return this.programAddress;
}

public getRpc(): Rpc<SolanaRpcApiFromTransport<RpcTransport>> {
public getRpc(): SVMProvider {
return this.rpc;
}
}
4 changes: 2 additions & 2 deletions src/arch/svm/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Signature, Address, UnixTimestamp, SolanaRpcApi, Rpc } from "@solana/kit";
import { Signature, Address, UnixTimestamp, Rpc, SolanaRpcApiFromTransport, RpcTransport } from "@solana/kit";
import { SvmSpokeClient } from "@across-protocol/contracts";

export type EventData =
Expand Down Expand Up @@ -46,4 +46,4 @@ export type EventWithData = {
program: Address;
};

export type SVMProvider = Rpc<SolanaRpcApi>;
export type SVMProvider = Rpc<SolanaRpcApiFromTransport<RpcTransport>>;
6 changes: 3 additions & 3 deletions src/arch/svm/utils.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import { BN, BorshEventCoder, Idl } from "@coral-xyz/anchor";
import { BigNumber, getRelayDataHash, isUint8Array, SvmAddress } from "../../utils";
import web3, { address, isAddress, RpcTransport, getProgramDerivedAddress, getU64Encoder, Address } from "@solana/kit";
import { address, isAddress, getProgramDerivedAddress, getU64Encoder, Address } from "@solana/kit";

import { EventName, SVMEventNames } from "./types";
import { EventName, SVMEventNames, SVMProvider } from "./types";
import { FillType, RelayData } from "../../interfaces";

/**
* Helper to determine if the current RPC network is devnet.
*/
export async function isDevnet(rpc: web3.Rpc<web3.SolanaRpcApiFromTransport<RpcTransport>>): Promise<boolean> {
export async function isDevnet(rpc: SVMProvider): Promise<boolean> {
const genesisHash = await rpc.getGenesisHash().send();
return genesisHash === "EtWTRABZaYq6iMfeYKouRu166VU2xqa1wcaWoxPkrZBG";
}
Expand Down
19 changes: 2 additions & 17 deletions src/clients/SpokePoolClient/SVMSpokePoolClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,7 @@ export class SvmSpokePoolClient extends SpokePoolClient {
destinationChainId?: number
): Promise<FillStatus> {
destinationChainId ??= this.chainId;
return relayFillStatus(
this.programId,
relayData,
destinationChainId,
this.svmEventsClient.getRpc(),
this.svmEventsClient,
atHeight
);
return relayFillStatus(this.programId, relayData, destinationChainId, this.svmEventsClient, atHeight);
}

/**
Expand All @@ -244,14 +237,6 @@ export class SvmSpokePoolClient extends SpokePoolClient {
): Promise<(FillStatus | undefined)[]> {
// @note: deploymentBlock actually refers to the deployment slot. Also, blockTag should be a slot number.
destinationChainId ??= this.chainId;
return fillStatusArray(
this.programId,
relayData,
destinationChainId,
this.svmEventsClient.getRpc(),
this.svmEventsClient,
atHeight,
this.logger
);
return fillStatusArray(this.programId, relayData, destinationChainId, this.svmEventsClient, atHeight, this.logger);
}
}
14 changes: 4 additions & 10 deletions test/mocks/MockSolanaEventClient.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import { Address, RpcTransport, SolanaRpcApiFromTransport, Rpc } from "@solana/kit";
import { Address } from "@solana/kit";
import { SvmCpiEventsClient } from "../../src/arch/svm/eventsClient";
import { Idl } from "@coral-xyz/anchor";
import { EventName, EventWithData } from "../../src/arch/svm";
import { EventName, EventWithData, SVMProvider } from "../../src/arch/svm";
import { MockSolanaRpcFactory } from "./MockSolanaRpcFactory";

export class MockSolanaEventClient extends SvmCpiEventsClient {
private events: Record<EventName, EventWithData[]> = {} as Record<EventName, EventWithData[]>;
private slotHeight: bigint;

constructor(programId = "JAZWcGrpSWNPTBj8QtJ9UyQqhJCDhG9GJkDeMf5NQBiq") {
super(
null as unknown as Rpc<SolanaRpcApiFromTransport<RpcTransport>>,
programId as Address,
null as unknown as Address,
null as unknown as Idl,
null as unknown as Address
);
super(null as unknown as SVMProvider, programId as Address, null as unknown as Address, null as unknown as Idl);
}

public setSlotHeight(slotHeight: bigint) {
Expand Down Expand Up @@ -47,7 +41,7 @@ export class MockSolanaEventClient extends SvmCpiEventsClient {
);
}

public override getRpc(): Rpc<SolanaRpcApiFromTransport<RpcTransport>> {
public override getRpc(): SVMProvider {
const client = new MockSolanaRpcFactory("https://test.com", 1234567890);
client.setResult("getSlot", [], this.slotHeight);
return client.createRpcClient();
Expand Down