Skip to content

Commit

Permalink
fix: catch error in case of action fetch_block
Browse files Browse the repository at this point in the history
  • Loading branch information
zitsen committed Sep 8, 2022
1 parent 834a110 commit 562039b
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 250 deletions.
7 changes: 1 addition & 6 deletions taos-query/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,12 +583,7 @@ mod r#async {
async fn topics(&self) -> Result<Vec<Topic>, Self::Error> {
let sql = "SELECT * FROM information_schema.ins_topics";
log::debug!("query one with sql: {sql}");
Ok(self
.query(sql)
.await?
.deserialize()
.try_collect()
.await?)
Ok(self.query(sql).await?.deserialize().try_collect().await?)
}

/// Get table meta information.
Expand Down
2 changes: 1 addition & 1 deletion taos-ws-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ unsafe fn connect_with_dsn(dsn: *const c_char) -> WsTaos {
let dsn = CStr::from_ptr(dsn).to_str()?;
let builder = TaosBuilder::from_dsn(dsn)?;
let mut taos = builder.build()?;

builder.ping(&mut taos)?;
Ok(taos)
}
Expand Down
3 changes: 2 additions & 1 deletion taos-ws/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ description = "TDengine connector with websocket protocol"
anyhow = "1"
async-trait = { version = "0.1.56" }
bytes = "1.1.0"
derive_more = "0.99"
futures = { version = "0.3" }
itertools = "0.10.3"
log = "0.4"
once_cell = "1"
parse_duration = "2.1"
scc = "0.8"
serde = { version = "1", features = ["derive"] }
serde_json = { version = "1" }
Expand All @@ -26,7 +28,6 @@ taos-query = { path = "../taos-query", version = "0.*" }
thiserror = "1"
tokio = { version = "1", features = ["full"] }
tokio-tungstenite = { version = "0.17", features = ["native-tls"] }
derive_more = "0.99"
[dev-dependencies]
pretty_env_logger = "0.4.0"

Expand Down
11 changes: 10 additions & 1 deletion taos-ws/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![recursion_limit = "256"]
use std::fmt::{Debug, Display};
use std::time::Duration;

use once_cell::sync::OnceCell;

Expand Down Expand Up @@ -31,6 +32,7 @@ pub struct TaosBuilder {
addr: String,
auth: WsAuth,
database: Option<String>,
timeout: Duration,
}

#[derive(Debug, thiserror::Error)]
Expand Down Expand Up @@ -123,13 +125,19 @@ impl TaosBuilder {
None => "localhost:6041".to_string(),
};

let timeout = dsn
.params
.remove("timeout")
.and_then(|s| parse_duration::parse(&s).ok())
.unwrap_or(Duration::from_secs(60 * 5)); // default to 5m

if let Some(token) = token {
// dbg!(&token);
Ok(TaosBuilder {
scheme,
addr,
auth: WsAuth::Token(token),
database: dsn.database,
timeout,
})
} else {
let username = dsn.username.unwrap_or_else(|| "root".to_string());
Expand All @@ -139,6 +147,7 @@ impl TaosBuilder {
addr,
auth: WsAuth::Plain(username, password),
database: dsn.database,
timeout,
})
}
}
Expand Down
167 changes: 85 additions & 82 deletions taos-ws/src/query/asyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ use tokio::net::TcpStream;
use tokio::sync::watch;

use tokio::time;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::{
connect_async_with_config, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream,
};

use super::{infra::*, TaosBuilder};

Expand Down Expand Up @@ -61,7 +63,7 @@ struct WsQuerySender {
results: Arc<QueryResMapper>,
sender: WsSender,
queries: QueryAgent,
// timeout: Duration,
timeout: Duration,
}

impl WsQuerySender {
Expand All @@ -70,6 +72,15 @@ impl WsQuerySender {
.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
}
async fn send_recv(&self, msg: WsSend) -> Result<WsRecvData> {
self.send_recv_timeout(msg, self.timeout).await
}
async fn send_only(&self, msg: WsSend) -> Result<()> {
let send_timeout = Duration::from_millis(1000);
self.sender.send_timeout(msg.to_msg(), send_timeout).await?;
Ok(())
}

async fn send_recv_timeout(&self, msg: WsSend, timeout: Duration) -> Result<WsRecvData> {
let send_timeout = Duration::from_millis(1000);
let req_id = msg.req_id();
let (tx, rx) = query_channel();
Expand All @@ -80,7 +91,10 @@ impl WsQuerySender {
WsSend::FetchBlock(args) => {
log::debug!("prepare req_id: {req_id} with message: {msg:?}");
if self.results.contains(&args.id) {
panic!("there's a result with id {}", args.id);
Err(RawError::from_any(format!(
"there's a result with id {}",
args.id
)))?;
}
self.results.insert(args.id, args.req_id).unwrap();

Expand All @@ -97,27 +111,14 @@ impl WsQuerySender {
self.sender.send_timeout(msg.to_msg(), send_timeout).await?;
}
}
Ok(block_in_place_or_global(rx).unwrap()?)
Ok(block_in_place_or_global(tokio::time::timeout(timeout, rx))
.map_err(|err| {
RawError::from_any(format!(
"Timeout when retrieving message: {err} ({timeout:?})"
))
})?
.unwrap()?)
}
async fn send_only(&self, msg: WsSend) -> Result<()> {
let send_timeout = Duration::from_millis(1000);
self.sender.send_timeout(msg.to_msg(), send_timeout).await?;
Ok(())
}
// async fn send_recv_timeout(&self, msg: WsSend, timeout: Duration) -> Result<WsRecvData> {
// let sleep = tokio::time::sleep(timeout);
// tokio::pin!(sleep);
// let data = tokio::select! {
// _ = &mut sleep, if !sleep.is_elapsed() => {
// log::debug!("poll timed out");
// Err(Error::QueryTimeout("poll".to_string()))?
// }
// message = self.send_recv(msg) => {
// message?
// }
// };
// Ok(data)
// }
}

#[derive(Debug)]
Expand Down Expand Up @@ -189,6 +190,8 @@ pub enum Error {
WsError(#[from] WsError),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("Websocket has been closed: {0}")]
WsClosed(String),
}

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -284,6 +287,15 @@ async fn read_queries(
log::warn!("req_id {req_id} not detected, message might be lost");
}
}
WsRecvData::FetchBlock => {
assert!(ok.is_err());
if let Some((_, sender)) = queries_sender.remove(&req_id)
{
sender.send(ok.map(|_| data)).unwrap();
} else {
log::warn!("req_id {req_id} not detected, message might be lost");
}
}
WsRecvData::WriteMeta => {
if let Some((_, sender)) = queries_sender.remove(&req_id)
{
Expand Down Expand Up @@ -331,7 +343,6 @@ async fn read_queries(
// v3
if let Some((_, sender)) = queries_sender.remove(&req_id) {
log::debug!("send data to fetches with id {}", res_id);
// let raw = slice.read_inlinable::<RawBlock>().unwrap();
sender.send(Ok(WsRecvData::Block { timing, raw: block[offset..].to_vec() })).unwrap();
} else {
log::warn!("req_id {res_id} not detected, message might be lost");
Expand All @@ -340,7 +351,6 @@ async fn read_queries(
// v2
if let Some((_, sender)) = queries_sender.remove(&req_id) {
log::debug!("send data to fetches with id {}", res_id);
// let raw = slice.read_inlinable::<RawBlock>().unwrap();
sender.send(Ok(WsRecvData::BlockV2 { timing, raw: block[offset..].to_vec() })).unwrap();
} else {
log::warn!("req_id {res_id} not detected, message might be lost");
Expand All @@ -350,8 +360,31 @@ async fn read_queries(
log::warn!("result id {res_id} not found");
}
}
Message::Close(_) => {
log::warn!("websocket connection is closed (unexpected?)");
Message::Close(close) => {
if let Some(close) = close {
log::warn!("websocket received close frame: {close:?}");

let mut keys = Vec::new();
queries_sender.for_each_async(|k, _| {
keys.push(*k);
}).await;
for k in keys {
if let Some((_, sender)) = queries_sender.remove(&k) {
let _ = sender.send(Err(RawError::from_any(close.reason.to_string())));
}
}
} else {
log::warn!("websocket connection is closed normally");
let mut keys = Vec::new();
queries_sender.for_each_async(|k, _| {
keys.push(*k);
}).await;
for k in keys {
if let Some((_, sender)) = queries_sender.remove(&k) {
let _ = sender.send(Err(RawError::from_any("websocket connection is closed")));
}
}
}
break;
}
Message::Ping(bytes) => {
Expand Down Expand Up @@ -379,6 +412,21 @@ async fn read_queries(
}
}
}
if queries_sender.is_empty() {
return;
}

let mut keys = Vec::new();
queries_sender
.for_each_async(|k, _| {
keys.push(*k);
})
.await;
for k in keys {
if let Some((_, sender)) = queries_sender.remove(&k) {
let _ = sender.send(Err(RawError::from_any("websocket connection is closed")));
}
}
}

impl WsTaos {
Expand All @@ -394,7 +442,10 @@ impl WsTaos {
Self::from_wsinfo(&info).await
}
pub(crate) async fn from_wsinfo(info: &TaosBuilder) -> Result<Self> {
let (ws, _) = connect_async(info.to_query_url()).await?;
let mut config = WebSocketConfig::default();
config.max_frame_size = Some(1024 * 1024 * 16);

let (ws, _) = connect_async_with_config(info.to_query_url(), Some(config)).await?;
let req_id = 0;
let (mut sender, mut reader) = ws.split();

Expand Down Expand Up @@ -494,7 +545,7 @@ impl WsTaos {
sender: ws_cloned,
queries: queries2_cloned,
results,
// timeout: Duration::from_secs(60 * 5),
timeout: info.timeout,
},
})
}
Expand All @@ -509,39 +560,15 @@ impl WsTaos {
meta.write_u64_le(message_id)?;
meta.write_u64_le(raw_meta_message as u64)?;
meta.write_all(&raw.as_bytes())?;
let len = meta.len();

log::debug!(
"write meta with req_id: {}, message_id: {}, raw data length: {:?}",
req_id,
message_id,
meta.len()
);
log::debug!("write meta with req_id: {req_id}, raw data length: {len}",);

match self.sender.send_recv(WsSend::Binary(meta)).await? {
WsRecvData::WriteMeta => Ok(()),
WsRecvData::WriteRaw => Ok(()),
_ => unreachable!(),
}

// let (tx, rx) = oneshot::channel();
// {
// self.queries.insert(req_id, tx).unwrap();
// self.ws
// .send_timeout(Message::Binary(meta), self.timeout)
// .await?;
// }
// let sleep = tokio::time::sleep(self.timeout);
// tokio::pin!(sleep);
// let _resp = tokio::select! {
// _ = &mut sleep, if !sleep.is_elapsed() => {
// log::debug!("get server version timed out");
// Err(Error::QueryTimeout("write meta".to_string()))?
// }
// message = rx => {
// message??
// }
// };
// Ok(())
}
async fn s_write_raw_block(&self, raw: &RawBlock) -> Result<()> {
let req_id = self.sender.req_id();
Expand All @@ -555,37 +582,13 @@ impl WsTaos {
meta.write_u32_le(raw.nrows() as u32)?;
meta.write_inlined_str::<2>(raw.table_name().unwrap())?;
meta.write_all(raw.as_raw_bytes())?;

log::debug!(
"write meta with req_id: {}, message_id: {}, raw data len: {:?}",
req_id,
message_id,
meta.len()
);
let len = meta.len();
log::debug!("write block with req_id: {req_id}, raw data len: {len}",);

match self.sender.send_recv(WsSend::Binary(meta)).await? {
WsRecvData::WriteRawBlock => Ok(()),
_ => unreachable!(),
}
// let (tx, rx) = oneshot::channel();
// {
// self.queries.insert(req_id, tx).unwrap();
// self.ws
// .send_timeout(Message::Binary(meta), self.timeout)
// .await?;
// }
// let sleep = tokio::time::sleep(self.timeout);
// tokio::pin!(sleep);
// let _resp = tokio::select! {
// _ = &mut sleep, if !sleep.is_elapsed() => {
// log::debug!("get server version timed out");
// Err(Error::QueryTimeout("write meta".to_string()))?
// }
// message = rx => {
// message??
// }
// };
// Ok(())
}

pub async fn s_query(&self, sql: &str) -> Result<ResultSet> {
Expand Down
5 changes: 5 additions & 0 deletions taos-ws/src/query/infra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,11 @@ pub struct WsQueryResp {
#[derive(Debug, Deserialize, Clone)]
pub struct WsFetchResp {
pub id: ResId,
#[serde(default)]
pub completed: bool,
#[serde(default)]
pub lengths: Option<Vec<u32>>,
#[serde(default)]
pub rows: usize,
#[serde(default)]
#[serde_as(as = "serde_with::DurationNanoSeconds")]
Expand All @@ -145,6 +148,8 @@ pub enum WsRecvData {
},
Query(WsQueryResp),
Fetch(WsFetchResp),
/// Will only produced by error
FetchBlock,
Block {
#[serde(default)]
#[serde_as(as = "serde_with::DurationNanoSeconds")]
Expand Down
Loading

0 comments on commit 562039b

Please sign in to comment.