Skip to content

Commit d46ada0

Browse files
committed
replication: define DecodingPlugin trait for logical decoding
When using a logical decoding stream the payload of each message is encoded using a plugin. Each implementor of the trait is responsible for providing its name and options, as well as decoding the incoming bytes into a more appropriate data structure. Signed-off-by: Petros Angelatos <petrosagg@gmail.com>
1 parent 89cc819 commit d46ada0

File tree

4 files changed

+65
-37
lines changed

4 files changed

+65
-37
lines changed

postgres-protocol/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub mod authentication;
2020
pub mod escape;
2121
pub mod message;
2222
pub mod password;
23+
pub mod replication;
2324
pub mod types;
2425

2526
/// A Postgres OID.

postgres-protocol/src/message/backend.rs

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#![allow(missing_docs)]
22

33
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
4-
use bytes::{Bytes, BytesMut};
4+
use bytes::buf::Reader as BufReader;
5+
use bytes::{Buf, Bytes, BytesMut};
56
use fallible_iterator::FallibleIterator;
67
use memchr::memchr;
78
use std::cmp;
@@ -39,6 +40,20 @@ pub const READY_FOR_QUERY_TAG: u8 = b'Z';
3940
pub const XLOG_DATA_TAG: u8 = b'w';
4041
pub const PRIMARY_KEEPALIVE_TAG: u8 = b'k';
4142

43+
pub trait Parse: Sized {
44+
fn parse(buf: &Bytes) -> io::Result<Self> {
45+
Self::parse_reader(&mut buf.clone().reader())
46+
}
47+
48+
fn parse_reader(buf: &mut BufReader<Bytes>) -> io::Result<Self>;
49+
}
50+
51+
impl Parse for Bytes {
52+
fn parse_reader(buf: &mut BufReader<Bytes>) -> io::Result<Self> {
53+
Ok(buf.get_ref().clone())
54+
}
55+
}
56+
4257
#[derive(Debug, Copy, Clone)]
4358
pub struct Header {
4459
tag: u8,
@@ -297,31 +312,26 @@ impl Message {
297312

298313
/// An enum representing Postgres backend replication messages.
299314
#[non_exhaustive]
300-
pub enum ReplicationMessage {
301-
XLogData(XLogDataBody),
315+
pub enum ReplicationMessage<D> {
316+
XLogData(XLogDataBody<D>),
302317
PrimaryKeepAlive(PrimaryKeepAliveBody),
303318
}
304319

305-
impl ReplicationMessage {
306-
pub fn parse(bytes: &Bytes) -> io::Result<ReplicationMessage> {
307-
let mut buf = Buffer {
308-
bytes: bytes.clone(),
309-
idx: 0,
310-
};
311-
320+
impl<D: Parse> Parse for ReplicationMessage<D> {
321+
fn parse_reader(buf: &mut BufReader<Bytes>) -> io::Result<Self> {
312322
let tag = buf.read_u8()?;
313323

314324
let replication_message = match tag {
315325
XLOG_DATA_TAG => {
316326
let wal_start = buf.read_u64::<BigEndian>()?;
317327
let wal_end = buf.read_u64::<BigEndian>()?;
318328
let timestamp = buf.read_i64::<BigEndian>()?;
319-
let storage = buf.read_all();
329+
let data = D::parse(&buf.get_mut().clone())?;
320330
ReplicationMessage::XLogData(XLogDataBody {
321331
wal_start,
322332
wal_end,
323333
timestamp,
324-
storage,
334+
data,
325335
})
326336
}
327337
PRIMARY_KEEPALIVE_TAG => {
@@ -865,14 +875,14 @@ impl RowDescriptionBody {
865875
}
866876
}
867877

868-
pub struct XLogDataBody {
878+
pub struct XLogDataBody<D> {
869879
wal_start: u64,
870880
wal_end: u64,
871881
timestamp: i64,
872-
storage: Bytes,
882+
data: D,
873883
}
874884

875-
impl XLogDataBody {
885+
impl<D> XLogDataBody<D> {
876886
#[inline]
877887
pub fn wal_start(&self) -> u64 {
878888
self.wal_start
@@ -889,13 +899,13 @@ impl XLogDataBody {
889899
}
890900

891901
#[inline]
892-
pub fn data(&self) -> &[u8] {
893-
&self.storage
902+
pub fn data(&self) -> &D {
903+
&self.data
894904
}
895905

896906
#[inline]
897-
pub fn into_bytes(self) -> Bytes {
898-
self.storage
907+
pub fn into_data(self) -> D {
908+
self.data
899909
}
900910
}
901911

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
use std::collections::HashMap;
2+
3+
use crate::message::backend::Parse;
4+
5+
pub trait DecodingPlugin {
6+
type Message: Parse;
7+
8+
fn name(&self) -> &str;
9+
fn options(&self) -> HashMap<String, String>;
10+
}

tokio-postgres/src/replication_client.rs

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,20 @@ use crate::copy_out;
128128
use crate::connection::RequestMessages;
129129
use crate::types::Type;
130130
use crate::{simple_query, Client, Error, SimpleQueryStream, SimpleQueryMessage};
131-
use bytes::BytesMut;
131+
use bytes::{Bytes, BytesMut};
132132
use fallible_iterator::FallibleIterator;
133133
use futures::{ready, Stream, TryStreamExt};
134134
use pin_project::{pin_project, pinned_drop};
135135
use postgres_types::PgLsn;
136136
use postgres_protocol::escape::{escape_identifier, escape_literal};
137-
use postgres_protocol::message::backend::{Message, ReplicationMessage, RowDescriptionBody};
137+
use postgres_protocol::message::backend::{
138+
Message, Parse, ReplicationMessage, RowDescriptionBody,
139+
};
138140
use postgres_protocol::message::frontend;
141+
use postgres_protocol::replication::DecodingPlugin;
142+
139143
use std::io;
140-
use std::marker::PhantomPinned;
144+
use std::marker::{PhantomData, PhantomPinned};
141145
use std::path::{Path, PathBuf};
142146
use std::pin::Pin;
143147
use std::str::from_utf8;
@@ -495,11 +499,11 @@ impl ReplicationClient {
495499
/// Plugins](https://www.postgresql.org/docs/current/logicaldecoding-output-plugin.html)).
496500
/// * `snapshot_mode`: Decides what to do with the snapshot
497501
/// created during logical slot initialization.
498-
pub async fn create_logical_replication_slot(
502+
pub async fn create_logical_replication_slot<P: DecodingPlugin>(
499503
&mut self,
500504
slot_name: &str,
501505
temporary: bool,
502-
plugin_name: &str,
506+
plugin: &P,
503507
snapshot_mode: Option<SnapshotMode>,
504508
) -> Result<CreateReplicationSlotResponse, Error> {
505509
let temporary_str = if temporary { " TEMPORARY" } else { "" };
@@ -512,7 +516,7 @@ impl ReplicationClient {
512516
"CREATE_REPLICATION_SLOT {}{} LOGICAL {}{}",
513517
escape_identifier(slot_name),
514518
temporary_str,
515-
escape_identifier(plugin_name),
519+
escape_identifier(plugin.name()),
516520
snapshot_str
517521
);
518522
let mut responses = self.send(&command).await?;
@@ -584,7 +588,7 @@ impl ReplicationClient {
584588
slot_name: Option<&str>,
585589
lsn: PgLsn,
586590
timeline_id: Option<u32>,
587-
) -> Result<Pin<Box<ReplicationStream<'a>>>, Error> {
591+
) -> Result<Pin<Box<ReplicationStream<'a, Bytes>>>, Error> {
588592
let slot = match slot_name {
589593
Some(name) => format!(" SLOT {}", escape_identifier(name)),
590594
None => String::from(""),
@@ -614,13 +618,14 @@ impl ReplicationClient {
614618
/// * `lsn`: The starting WAL location.
615619
/// * `options`: (name, value) pairs of options passed to the
616620
/// slot's logical decoding plugin.
617-
pub async fn start_logical_replication<'a>(
621+
pub async fn start_logical_replication<'a, P: DecodingPlugin>(
618622
&'a mut self,
619623
slot_name: &str,
620624
lsn: PgLsn,
621-
options: &[(&str, &str)],
622-
) -> Result<Pin<Box<ReplicationStream<'a>>>, Error> {
625+
plugin: &P,
626+
) -> Result<Pin<Box<ReplicationStream<'a, P::Message>>>, Error> {
623627
let slot = format!(" SLOT {}", escape_identifier(slot_name));
628+
let options = plugin.options();
624629
let options_string = if !options.is_empty() {
625630
format!(
626631
" ({})",
@@ -678,10 +683,10 @@ impl ReplicationClient {
678683
Ok(responses)
679684
}
680685

681-
async fn start_replication<'a>(
686+
async fn start_replication<'a, D>(
682687
&'a mut self,
683688
command: String,
684-
) -> Result<Pin<Box<ReplicationStream<'a>>>, Error> {
689+
) -> Result<Pin<Box<ReplicationStream<'a, D>>>, Error> {
685690
let mut copyboth_received = false;
686691
let mut replication_response: Option<ReplicationResponse> = None;
687692
let mut responses = self.send(&command).await?;
@@ -717,6 +722,7 @@ impl ReplicationClient {
717722
copydone_sent: false,
718723
copydone_received: false,
719724
replication_response: replication_response,
725+
_phantom_data: PhantomData,
720726
_phantom_pinned: PhantomPinned,
721727
}))
722728
}
@@ -752,18 +758,19 @@ impl ReplicationClient {
752758
/// [stop_replication()](ReplicationStream::stop_replication()) will
753759
/// return a response tuple.
754760
#[pin_project(PinnedDrop)]
755-
pub struct ReplicationStream<'a> {
761+
pub struct ReplicationStream<'a, D> {
756762
rclient: &'a mut ReplicationClient,
757763
responses: Responses,
758764
copyboth_received: bool,
759765
copydone_sent: bool,
760766
copydone_received: bool,
761767
replication_response: Option<ReplicationResponse>,
768+
_phantom_data: PhantomData<D>,
762769
#[pin]
763770
_phantom_pinned: PhantomPinned,
764771
}
765772

766-
impl<'a> ReplicationStream<'a> {
773+
impl<'a, D> ReplicationStream<'a, D> {
767774
/// Send standby update to server.
768775
pub async fn standby_status_update(
769776
self: Pin<&mut Self>,
@@ -855,8 +862,8 @@ impl<'a> ReplicationStream<'a> {
855862
}
856863
}
857864

858-
impl Stream for ReplicationStream<'_> {
859-
type Item = Result<ReplicationMessage, Error>;
865+
impl<D: Parse> Stream for ReplicationStream<'_, D> {
866+
type Item = Result<ReplicationMessage<D>, Error>;
860867

861868
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
862869
let this = self.project();
@@ -872,7 +879,7 @@ impl Stream for ReplicationStream<'_> {
872879
assert!(!*this.copydone_received);
873880
match ready!(this.responses.poll_next(cx)?) {
874881
Message::CopyData(body) => {
875-
let r = ReplicationMessage::parse(&body.into_bytes());
882+
let r = ReplicationMessage::parse(&mut body.into_bytes());
876883
Poll::Ready(Some(r.map_err(Error::parse)))
877884
}
878885
Message::CopyDone => {
@@ -887,7 +894,7 @@ impl Stream for ReplicationStream<'_> {
887894
}
888895

889896
#[pinned_drop]
890-
impl PinnedDrop for ReplicationStream<'_> {
897+
impl<D> PinnedDrop for ReplicationStream<'_, D> {
891898
fn drop(mut self: Pin<&mut Self>) {
892899
let this = self.project();
893900
if *this.copyboth_received && !*this.copydone_sent {

0 commit comments

Comments
 (0)