Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: http handler support set variable. #16239

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/query/service/src/servers/http/v1/query/execute_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;

Expand All @@ -22,6 +23,7 @@ use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::DataBlock;
use databend_common_expression::DataSchemaRef;
use databend_common_expression::Scalar;
use databend_common_io::prelude::FormatSettings;
use databend_common_settings::Settings;
use databend_storages_common_txn::TxnManagerRef;
Expand Down Expand Up @@ -147,6 +149,7 @@ pub struct ExecutorSessionState {
pub secondary_roles: Option<Vec<String>>,
pub settings: Arc<Settings>,
pub txn_manager: TxnManagerRef,
pub variables: HashMap<String, Scalar>,
}

impl ExecutorSessionState {
Expand All @@ -157,6 +160,7 @@ impl ExecutorSessionState {
secondary_roles: session.get_secondary_roles(),
settings: session.get_settings(),
txn_manager: session.txn_mgr(),
variables: session.get_all_variables(),
}
}
}
Expand Down
97 changes: 96 additions & 1 deletion src/query/service/src/servers/http/v1/query/http_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::BTreeMap;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
Expand All @@ -29,6 +30,7 @@ use databend_common_base::runtime::TrySpawn;
use databend_common_catalog::table_context::StageAttachment;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::Scalar;
use databend_common_io::prelude::FormatSettings;
use databend_common_metrics::http::metrics_incr_http_response_errors_count;
use databend_common_settings::ScopeLevel;
Expand All @@ -39,7 +41,9 @@ use log::warn;
use poem::web::Json;
use poem::IntoResponse;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;

use super::HttpQueryContext;
use super::RemoveReason;
Expand Down Expand Up @@ -181,6 +185,75 @@ pub struct ServerInfo {
pub start_time: String,
}

#[derive(Deserialize, Serialize, Debug, Default, Clone, Eq, PartialEq)]
pub struct HttpSessionStateInternal {
/// value is JSON of Scalar
variables: Vec<(String, String)>,
}

impl HttpSessionStateInternal {
fn new(variables: &HashMap<String, Scalar>) -> Self {
let variables = variables
.iter()
.map(|(k, v)| {
(
k.clone(),
serde_json::to_string(&v).expect("fail to serialize Scalar"),
)
})
.collect();
Self { variables }
}

pub fn get_variables(&self) -> Result<HashMap<String, Scalar>> {
let mut vars = HashMap::with_capacity(self.variables.len());
for (k, v) in self.variables.iter() {
match serde_json::from_str::<Scalar>(v) {
Ok(s) => {
vars.insert(k.to_string(), s);
}
Err(e) => {
return Err(ErrorCode::BadBytes(format!(
"fail decode scalar from string '{v}', error: {e}"
)));
}
}
}
Ok(vars)
}
}

fn serialize_as_json_string<S>(
value: &Option<HttpSessionStateInternal>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(complex_value) => {
let json_string =
serde_json::to_string(complex_value).map_err(serde::ser::Error::custom)?;
serializer.serialize_some(&json_string)
}
None => serializer.serialize_none(),
}
}

fn deserialize_from_json_string<'de, D>(
deserializer: D,
) -> Result<Option<HttpSessionStateInternal>, D::Error>
where D: Deserializer<'de> {
let json_string: Option<String> = Option::deserialize(deserializer)?;
match json_string {
Some(s) => {
let complex_value = serde_json::from_str(&s).map_err(serde::de::Error::custom)?;
Ok(Some(complex_value))
}
None => Ok(None),
}
}

#[derive(Deserialize, Serialize, Debug, Default, Clone, Eq, PartialEq)]
pub struct HttpSessionConf {
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -189,6 +262,7 @@ pub struct HttpSessionConf {
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub secondary_roles: Option<Vec<String>>,
// todo: remove this later
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_server_session_secs: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -198,9 +272,19 @@ pub struct HttpSessionConf {
// used to check if the session is still on the same server
#[serde(skip_serializing_if = "Option::is_none")]
pub last_server_info: Option<ServerInfo>,
// last_query_ids[0] is the last query id, last_query_ids[1] is the second last query id, etc.
/// last_query_ids[0] is the last query id, last_query_ids[1] is the second last query id, etc.
#[serde(default)]
pub last_query_ids: Vec<String>,
/// hide state not useful to clients
/// so client only need to know there is a String field `internal`,
/// which need to carry with session/conn
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(
serialize_with = "serialize_as_json_string",
deserialize_with = "deserialize_from_json_string"
)]
pub internal: Option<HttpSessionStateInternal>,
}

impl HttpSessionConf {}
Expand Down Expand Up @@ -360,6 +444,11 @@ impl HttpQuery {
})?;
}
}
if let Some(state) = &session_conf.internal {
if !state.variables.is_empty() {
session.set_all_variables(state.get_variables()?)
}
}
try_set_txn(&ctx.query_id, &session, session_conf, &http_query_manager)?;
};

Expand Down Expand Up @@ -548,6 +637,11 @@ impl HttpQuery {
let role = session_state.current_role.clone();
let secondary_roles = session_state.secondary_roles.clone();
let txn_state = session_state.txn_manager.lock().state();
let internal = if !session_state.variables.is_empty() {
Some(HttpSessionStateInternal::new(&session_state.variables))
} else {
None
};
if txn_state != TxnState::AutoCommit
&& !self.is_txn_mgr_saved.load(Ordering::Relaxed)
&& matches!(executor.state, ExecuteState::Stopped(_))
Expand All @@ -573,6 +667,7 @@ impl HttpQuery {
txn_state: Some(txn_state),
last_server_info: Some(HttpQueryManager::instance().server_info.clone()),
last_query_ids: vec![self.id.clone()],
internal,
}
}

Expand Down
10 changes: 10 additions & 0 deletions src/query/service/src/sessions/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

Expand All @@ -20,6 +21,7 @@ use databend_common_catalog::cluster_info::Cluster;
use databend_common_config::GlobalConfig;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::Scalar;
use databend_common_io::prelude::FormatSettings;
use databend_common_meta_app::principal::GrantObject;
use databend_common_meta_app::principal::OwnershipObject;
Expand Down Expand Up @@ -352,6 +354,14 @@ impl Session {
Some(x) => x.get_query_profiles(),
}
}

pub fn get_all_variables(&self) -> HashMap<String, Scalar> {
self.session_ctx.get_all_variables()
}

pub fn set_all_variables(&self, variables: HashMap<String, Scalar>) {
self.session_ctx.set_all_variables(variables)
}
}

impl Drop for Session {
Expand Down
6 changes: 6 additions & 0 deletions src/query/service/src/sessions/session_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,10 @@ impl SessionContext {
pub fn get_variable(&self, key: &str) -> Option<Scalar> {
self.variables.read().get(key).cloned()
}
pub fn get_all_variables(&self) -> HashMap<String, Scalar> {
self.variables.read().clone()
}
pub fn set_all_variables(&self, variables: HashMap<String, Scalar>) {
*self.variables.write() = variables
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,7 @@ async fn test_affect() -> Result<()> {
txn_state: Some(TxnState::AutoCommit),
last_server_info: None,
last_query_ids: vec![],
internal: None,
}),
),
(
Expand All @@ -1415,6 +1416,7 @@ async fn test_affect() -> Result<()> {
txn_state: Some(TxnState::AutoCommit),
last_server_info: None,
last_query_ids: vec![],
internal: None,
}),
),
(
Expand All @@ -1432,6 +1434,7 @@ async fn test_affect() -> Result<()> {
txn_state: Some(TxnState::AutoCommit),
last_server_info: None,
last_query_ids: vec![],
internal: None,
}),
),
(
Expand All @@ -1451,6 +1454,7 @@ async fn test_affect() -> Result<()> {
txn_state: Some(TxnState::AutoCommit),
last_server_info: None,
last_query_ids: vec![],
internal: None,
}),
),
(
Expand All @@ -1472,6 +1476,7 @@ async fn test_affect() -> Result<()> {
txn_state: Some(TxnState::AutoCommit),
last_server_info: None,
last_query_ids: vec![],
internal: None,
}),
),
];
Expand Down
1 change: 1 addition & 0 deletions tests/sqllogictests/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub struct HttpSessionConf {
pub last_server_info: Option<ServerInfo>,
#[serde(default)]
pub last_query_ids: Vec<String>,
pub internal: Option<String>,
}

pub fn parser_rows(rows: &Value) -> Result<Vec<Vec<String>>> {
Expand Down
8 changes: 1 addition & 7 deletions tests/sqllogictests/suites/query/set.test
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ select value, default = value from system.settings where name in ('max_threads'
4 0
56 0

onlyif mysql
statement ok
set variable (a, b) = (select 3, 55)

onlyif mysql

statement ok
SET GLOBAL (max_threads, storage_io_min_bytes_for_seek) = select $a + 1, $b + 1;

Expand All @@ -30,25 +29,20 @@ select default = value from system.settings where name in ('max_threads', 'stor
1
1

onlyif mysql
statement ok
set variable a = 1;

onlyif mysql
statement ok
set variable (b, c) = ('yy', 'zz');

onlyif mysql
query ITT
select $a + getvariable('a') + $a, getvariable('b'), getvariable('c'), getvariable('d')
----
3 yy zz NULL

onlyif mysql
statement ok
unset variable (a, b)

onlyif mysql
query ITT
select getvariable('a'), getvariable('b'), 'xx' || 'yy' || getvariable('c') , getvariable('d')
----
Expand Down
Loading