diff --git a/src/db3js/src/lib/db3.ts b/src/db3js/src/lib/db3.ts index 4a42a5e5..fb321966 100644 --- a/src/db3js/src/lib/db3.ts +++ b/src/db3js/src/lib/db3.ts @@ -196,12 +196,19 @@ export class DB3 { if (this.querySessionInfo) { return {} } + + const sessionRequest = new db3_node_pb.OpenSessionRequest() const header = window.crypto.getRandomValues(new Uint8Array(32)) - const [signature, public_key] = await sign(header) - sessionRequest.setHeader(header) + const payload = new db3_session_pb.OpenSessionPayload() + payload.setHeader(header.toString()) + payload.setStartTime(Math.floor(Date.now() / 1000)) + const payloadU8 = payload.serializeBinary() + const [signature, public_key] = await sign(payloadU8) + sessionRequest.setPayload(payloadU8) sessionRequest.setSignature(signature) sessionRequest.setPublicKey(public_key) + try { const res = await this.client.openQuerySession(sessionRequest, {}) this.sessionToken = res.getSessionToken() diff --git a/src/db3js/src/pkg/db3_node_pb.d.ts b/src/db3js/src/pkg/db3_node_pb.d.ts index b3445c22..ad50c9ac 100644 --- a/src/db3js/src/pkg/db3_node_pb.d.ts +++ b/src/db3js/src/pkg/db3_node_pb.d.ts @@ -360,10 +360,10 @@ export namespace GetSessionInfoRequest { } export class OpenSessionRequest extends jspb.Message { - getHeader(): Uint8Array | string; - getHeader_asU8(): Uint8Array; - getHeader_asB64(): string; - setHeader(value: Uint8Array | string): OpenSessionRequest; + getPayload(): Uint8Array | string; + getPayload_asU8(): Uint8Array; + getPayload_asB64(): string; + setPayload(value: Uint8Array | string): OpenSessionRequest; getSignature(): Uint8Array | string; getSignature_asU8(): Uint8Array; @@ -385,7 +385,7 @@ export class OpenSessionRequest extends jspb.Message { export namespace OpenSessionRequest { export type AsObject = { - header: Uint8Array | string, + payload: Uint8Array | string, signature: Uint8Array | string, publicKey: Uint8Array | string, } diff --git a/src/db3js/src/pkg/db3_node_pb.js b/src/db3js/src/pkg/db3_node_pb.js index a1ea7e6c..76362575 100644 --- a/src/db3js/src/pkg/db3_node_pb.js +++ b/src/db3js/src/pkg/db3_node_pb.js @@ -3743,88 +3743,94 @@ proto.db3_node_proto.OpenSessionRequest.prototype.serializeBinary = function() { * @param {!jspb.BinaryWriter} writer * @suppress {unusedLocalVariables} f is only used for nested messages */ -proto.db3_node_proto.OpenSessionRequest.serializeBinaryToWriter = function( - message, - writer, -) { - var f = undefined; - f = message.getHeader_asU8(); - if (f.length > 0) { - writer.writeBytes(1, f); - } - f = message.getSignature_asU8(); - if (f.length > 0) { - writer.writeBytes(2, f); - } - f = message.getPublicKey_asU8(); - if (f.length > 0) { - writer.writeBytes(3, f); - } +proto.db3_node_proto.OpenSessionRequest.serializeBinaryToWriter = function (message, writer) { + var f = undefined; + f = message.getPayload_asU8(); + if (f.length > 0) { + writer.writeBytes( + 1, + f + ); + } + f = message.getSignature_asU8(); + if (f.length > 0) { + writer.writeBytes( + 2, + f + ); + } + f = message.getPublicKey_asU8(); + if (f.length > 0) { + writer.writeBytes( + 3, + f + ); + } }; + /** - * optional bytes header = 1; + * optional bytes payload = 1; * @return {string} */ -proto.db3_node_proto.OpenSessionRequest.prototype.getHeader = function() { - return /** @type {string} */ (jspb.Message.getFieldWithDefault( - this, - 1, - "", - )); +proto.db3_node_proto.OpenSessionRequest.prototype.getPayload = function () { + return /** @type {string} */ (jspb.Message.getFieldWithDefault(this, 1, "")); }; + /** - * optional bytes header = 1; - * This is a type-conversion wrapper around `getHeader()` + * optional bytes payload = 1; + * This is a type-conversion wrapper around `getPayload()` * @return {string} */ -proto.db3_node_proto.OpenSessionRequest.prototype.getHeader_asB64 = function() { - return /** @type {string} */ (jspb.Message.bytesAsB64(this.getHeader())); +proto.db3_node_proto.OpenSessionRequest.prototype.getPayload_asB64 = function () { + return /** @type {string} */ (jspb.Message.bytesAsB64( + this.getPayload())); }; + /** - * optional bytes header = 1; + * optional bytes payload = 1; * Note that Uint8Array is not supported on all browsers. * @see http://caniuse.com/Uint8Array - * This is a type-conversion wrapper around `getHeader()` + * This is a type-conversion wrapper around `getPayload()` * @return {!Uint8Array} */ -proto.db3_node_proto.OpenSessionRequest.prototype.getHeader_asU8 = function() { - return /** @type {!Uint8Array} */ (jspb.Message.bytesAsU8( - this.getHeader(), - )); +proto.db3_node_proto.OpenSessionRequest.prototype.getPayload_asU8 = function () { + return /** @type {!Uint8Array} */ (jspb.Message.bytesAsU8( + this.getPayload())); }; + /** * @param {!(string|Uint8Array)} value * @return {!proto.db3_node_proto.OpenSessionRequest} returns this */ -proto.db3_node_proto.OpenSessionRequest.prototype.setHeader = function(value) { - return jspb.Message.setProto3BytesField(this, 1, value); +proto.db3_node_proto.OpenSessionRequest.prototype.setPayload = function (value) { + return jspb.Message.setProto3BytesField(this, 1, value); }; + /** * optional bytes signature = 2; * @return {string} */ -proto.db3_node_proto.OpenSessionRequest.prototype.getSignature = function() { - return /** @type {string} */ (jspb.Message.getFieldWithDefault( - this, - 2, - "", - )); +proto.db3_node_proto.OpenSessionRequest.prototype.getSignature = function () { + return /** @type {string} */ (jspb.Message.getFieldWithDefault(this, 2, "")); }; + /** * optional bytes signature = 2; * This is a type-conversion wrapper around `getSignature()` * @return {string} */ -proto.db3_node_proto.OpenSessionRequest.prototype.getSignature_asB64 = function() { - return /** @type {string} */ (jspb.Message.bytesAsB64(this.getSignature())); +proto.db3_node_proto.OpenSessionRequest.prototype.getSignature_asB64 = function () { + return /** @type {string} */ (jspb.Message.bytesAsB64( + this.getSignature())); }; + /** * optional bytes signature = 2; * Note that Uint8Array is not supported on all browsers. diff --git a/src/db3js/src/pkg/db3_session_pb.d.ts b/src/db3js/src/pkg/db3_session_pb.d.ts index e430a215..ced904cd 100644 --- a/src/db3js/src/pkg/db3_session_pb.d.ts +++ b/src/db3js/src/pkg/db3_session_pb.d.ts @@ -57,6 +57,28 @@ export namespace CloseSessionPayload { } } +export class OpenSessionPayload extends jspb.Message { + getHeader(): string; + setHeader(value: string): OpenSessionPayload; + + getStartTime(): number; + setStartTime(value: number): OpenSessionPayload; + + serializeBinary(): Uint8Array; + toObject(includeInstance?: boolean): OpenSessionPayload.AsObject; + static toObject(includeInstance: boolean, msg: OpenSessionPayload): OpenSessionPayload.AsObject; + static serializeBinaryToWriter(message: OpenSessionPayload, writer: jspb.BinaryWriter): void; + static deserializeBinary(bytes: Uint8Array): OpenSessionPayload; + static deserializeBinaryFromReader(message: OpenSessionPayload, reader: jspb.BinaryReader): OpenSessionPayload; +} + +export namespace OpenSessionPayload { + export type AsObject = { + header: string, + startTime: number, + } +} + export class QuerySession extends jspb.Message { getNonce(): number; setNonce(value: number): QuerySession; diff --git a/src/db3js/src/pkg/db3_session_pb.js b/src/db3js/src/pkg/db3_session_pb.js index 2e0f8230..6fb354b0 100644 --- a/src/db3js/src/pkg/db3_session_pb.js +++ b/src/db3js/src/pkg/db3_session_pb.js @@ -24,6 +24,11 @@ globalThis.__hack = goog.exportSymbol( null, global, ); +globalThis.__hack = goog.exportSymbol( + "proto.db3_session_proto.OpenSessionPayload", + null, + global, +); globalThis.__hack = goog.exportSymbol( "proto.db3_session_proto.QuerySession", null, @@ -72,16 +77,36 @@ if (goog.DEBUG && !COMPILED) { * @constructor */ proto.db3_session_proto.CloseSessionPayload = function(opt_data) { - jspb.Message.initialize(this, opt_data, 0, -1, null, null); + jspb.Message.initialize(this, opt_data, 0, -1, null, null); }; goog.inherits(proto.db3_session_proto.CloseSessionPayload, jspb.Message); if (goog.DEBUG && !COMPILED) { - /** - * @public - * @override - */ - proto.db3_session_proto.CloseSessionPayload.displayName = - "proto.db3_session_proto.CloseSessionPayload"; + /** + * @public + * @override + */ + proto.db3_session_proto.CloseSessionPayload.displayName = 'proto.db3_session_proto.CloseSessionPayload'; +} +/** + * Generated by JsPbCodeGenerator. + * @param {Array=} opt_data Optional initial data array, typically from a + * server response, or constructed directly in Javascript. The array is used + * in place and becomes part of the constructed object. It is not cloned. + * If no data is provided, the constructed object will be empty, but still + * valid. + * @extends {jspb.Message} + * @constructor + */ +proto.db3_session_proto.OpenSessionPayload = function(opt_data) { + jspb.Message.initialize(this, opt_data, 0, -1, null, null); +}; +goog.inherits(proto.db3_session_proto.OpenSessionPayload, jspb.Message); +if (goog.DEBUG && !COMPILED) { + /** + * @public + * @override + */ + proto.db3_session_proto.OpenSessionPayload.displayName = 'proto.db3_session_proto.OpenSessionPayload'; } /** * Generated by JsPbCodeGenerator. @@ -522,16 +547,179 @@ proto.db3_session_proto.CloseSessionPayload.prototype.getSessionToken = function )); }; + /** * @param {string} value * @return {!proto.db3_session_proto.CloseSessionPayload} returns this */ -proto.db3_session_proto.CloseSessionPayload.prototype.setSessionToken = function( - value, -) { - return jspb.Message.setProto3StringField(this, 2, value); +proto.db3_session_proto.CloseSessionPayload.prototype.setSessionToken = function(value) { + return jspb.Message.setProto3StringField(this, 2, value); +}; + + + + + +if (jspb.Message.GENERATE_TO_OBJECT) { +/** + * Creates an object representation of this proto. + * Field names that are reserved in JavaScript and will be renamed to pb_name. + * Optional fields that are not set will be set to undefined. + * To access a reserved field use, foo.pb_, eg, foo.pb_default. + * For the list of reserved names please see: + * net/proto2/compiler/js/internal/generator.cc#kKeyword. + * @param {boolean=} opt_includeInstance Deprecated. whether to include the + * JSPB instance for transitional soy proto support: + * http://goto/soy-param-migration + * @return {!Object} + */ +proto.db3_session_proto.OpenSessionPayload.prototype.toObject = function(opt_includeInstance) { + return proto.db3_session_proto.OpenSessionPayload.toObject(opt_includeInstance, this); +}; + + +/** + * Static version of the {@see toObject} method. + * @param {boolean|undefined} includeInstance Deprecated. Whether to include + * the JSPB instance for transitional soy proto support: + * http://goto/soy-param-migration + * @param {!proto.db3_session_proto.OpenSessionPayload} msg The msg instance to transform. + * @return {!Object} + * @suppress {unusedLocalVariables} f is only used for nested messages + */ +proto.db3_session_proto.OpenSessionPayload.toObject = function(includeInstance, msg) { + var f, obj = { + header: jspb.Message.getFieldWithDefault(msg, 1, ""), + startTime: jspb.Message.getFieldWithDefault(msg, 2, 0) + }; + + if (includeInstance) { + obj.$jspbMessageInstance = msg; + } + return obj; +}; +} + + +/** + * Deserializes binary data (in protobuf wire format). + * @param {jspb.ByteSource} bytes The bytes to deserialize. + * @return {!proto.db3_session_proto.OpenSessionPayload} + */ +proto.db3_session_proto.OpenSessionPayload.deserializeBinary = function(bytes) { + var reader = new jspb.BinaryReader(bytes); + var msg = new proto.db3_session_proto.OpenSessionPayload; + return proto.db3_session_proto.OpenSessionPayload.deserializeBinaryFromReader(msg, reader); +}; + + +/** + * Deserializes binary data (in protobuf wire format) from the + * given reader into the given message object. + * @param {!proto.db3_session_proto.OpenSessionPayload} msg The message object to deserialize into. + * @param {!jspb.BinaryReader} reader The BinaryReader to use. + * @return {!proto.db3_session_proto.OpenSessionPayload} + */ +proto.db3_session_proto.OpenSessionPayload.deserializeBinaryFromReader = function(msg, reader) { + while (reader.nextField()) { + if (reader.isEndGroup()) { + break; + } + var field = reader.getFieldNumber(); + switch (field) { + case 1: + var value = /** @type {string} */ (reader.readString()); + msg.setHeader(value); + break; + case 2: + var value = /** @type {number} */ (reader.readInt64()); + msg.setStartTime(value); + break; + default: + reader.skipField(); + break; + } + } + return msg; +}; + + +/** + * Serializes the message to binary data (in protobuf wire format). + * @return {!Uint8Array} + */ +proto.db3_session_proto.OpenSessionPayload.prototype.serializeBinary = function() { + var writer = new jspb.BinaryWriter(); + proto.db3_session_proto.OpenSessionPayload.serializeBinaryToWriter(this, writer); + return writer.getResultBuffer(); +}; + + +/** + * Serializes the given message to binary data (in protobuf wire + * format), writing to the given BinaryWriter. + * @param {!proto.db3_session_proto.OpenSessionPayload} message + * @param {!jspb.BinaryWriter} writer + * @suppress {unusedLocalVariables} f is only used for nested messages + */ +proto.db3_session_proto.OpenSessionPayload.serializeBinaryToWriter = function(message, writer) { + var f = undefined; + f = message.getHeader(); + if (f.length > 0) { + writer.writeString( + 1, + f + ); + } + f = message.getStartTime(); + if (f !== 0) { + writer.writeInt64( + 2, + f + ); + } +}; + + +/** + * optional string header = 1; + * @return {string} + */ +proto.db3_session_proto.OpenSessionPayload.prototype.getHeader = function() { + return /** @type {string} */ (jspb.Message.getFieldWithDefault(this, 1, "")); }; + +/** + * @param {string} value + * @return {!proto.db3_session_proto.OpenSessionPayload} returns this + */ +proto.db3_session_proto.OpenSessionPayload.prototype.setHeader = function(value) { + return jspb.Message.setProto3StringField(this, 1, value); +}; + + +/** + * optional int64 start_time = 2; + * @return {number} + */ +proto.db3_session_proto.OpenSessionPayload.prototype.getStartTime = function() { + return /** @type {number} */ (jspb.Message.getFieldWithDefault(this, 2, 0)); +}; + + +/** + * @param {number} value + * @return {!proto.db3_session_proto.OpenSessionPayload} returns this + */ +proto.db3_session_proto.OpenSessionPayload.prototype.setStartTime = function(value) { + return jspb.Message.setProto3IntField(this, 2, value); +}; + + + + + if (jspb.Message.GENERATE_TO_OBJECT) { /** * Creates an object representation of this proto. diff --git a/src/node/src/storage_node_impl.rs b/src/node/src/storage_node_impl.rs index 1efa4e97..423ce5d6 100644 --- a/src/node/src/storage_node_impl.rs +++ b/src/node/src/storage_node_impl.rs @@ -27,7 +27,9 @@ use db3_proto::db3_node_proto::{ GetSessionInfoResponse, OpenSessionRequest, OpenSessionResponse, QueryBillRequest, QueryBillResponse, }; -use db3_proto::db3_session_proto::{CloseSessionPayload, QuerySession, QuerySessionInfo}; +use db3_proto::db3_session_proto::{ + CloseSessionPayload, OpenSessionPayload, QuerySession, QuerySessionInfo, +}; use db3_session::query_session_verifier; use db3_session::session_manager::DEFAULT_SESSION_PERIOD; use db3_session::session_manager::DEFAULT_SESSION_QUERY_LIMIT; @@ -163,15 +165,18 @@ impl StorageNode for StorageNodeImpl { ) -> std::result::Result, Status> { let r = request.into_inner(); let account_id = Verifier::verify( - r.header.as_ref(), + r.payload.as_ref(), r.signature.as_ref(), r.public_key.as_ref(), ) .map_err(|e| Status::internal(format!("{:?}", e)))?; + let payload = OpenSessionPayload::decode(r.payload.as_ref()) + .map_err(|_| Status::internal("fail to decode open session request ".to_string()))?; + let header = payload.header; match self.context.node_store.lock() { Ok(mut node_store) => { let sess_store = node_store.get_session_store(); - match sess_store.add_new_session(account_id.addr) { + match sess_store.add_new_session(&header, payload.start_time, account_id.addr) { Ok((session_token, query_session_info)) => { // Takes a reference and returns Option<&V> Ok(Response::new(OpenSessionResponse { diff --git a/src/proto/proto/db3_node.proto b/src/proto/proto/db3_node.proto index 34945c9e..d44c994c 100644 --- a/src/proto/proto/db3_node.proto +++ b/src/proto/proto/db3_node.proto @@ -96,7 +96,7 @@ message GetSessionInfoRequest { SessionIdentifier session_identifier = 1; } message OpenSessionRequest { - bytes header = 1; + bytes payload = 1; bytes signature = 2; bytes public_key = 3; } diff --git a/src/proto/proto/db3_session.proto b/src/proto/proto/db3_session.proto index 23679d9c..21c9b22d 100644 --- a/src/proto/proto/db3_session.proto +++ b/src/proto/proto/db3_session.proto @@ -39,7 +39,10 @@ message CloseSessionPayload { QuerySessionInfo session_info = 1; string session_token = 2; } - +message OpenSessionPayload { + string header = 1; + int64 start_time = 2; +} message QuerySession { // the counter of account uint64 nonce = 1; @@ -56,3 +59,4 @@ message QuerySession { // client public key bytes client_public_key = 7; } + diff --git a/src/sdk/Cargo.toml b/src/sdk/Cargo.toml index 41a69644..05327593 100644 --- a/src/sdk/Cargo.toml +++ b/src/sdk/Cargo.toml @@ -22,8 +22,15 @@ prost = "0.11" prost-types = "0.11" ethereum-types = { version = "0.14.0", default-features = false } subtle-encoding = { version = "0.5", default-features = false, features = ["bech32-preview"] } - +chrono = "0.4.22" [dev-dependencies] rand = "0.8.5" db3-base={path="../base", version="0.1.0"} db3-cmd={path="../cmd", version="0.1.0"} +[dependencies.uuid] +version = "1.2.2" +features = [ + "v4", # Lets you generate random UUIDs + "fast-rng", # Use a faster (but still sufficiently random) RNG + "macro-diagnostics", # Enable better diagnostics for compile-time UUIDs +] \ No newline at end of file diff --git a/src/sdk/src/store_sdk.rs b/src/sdk/src/store_sdk.rs index b703ac76..fe36bfb8 100644 --- a/src/sdk/src/store_sdk.rs +++ b/src/sdk/src/store_sdk.rs @@ -16,6 +16,7 @@ // use bytes::BytesMut; +use chrono::Utc; use db3_crypto::signer::Db3Signer; use db3_proto::db3_account_proto::Account; use db3_proto::db3_bill_proto::Bill; @@ -25,14 +26,14 @@ use db3_proto::db3_node_proto::{ OpenSessionRequest, OpenSessionResponse, QueryBillKey, QueryBillRequest, Range as DB3Range, RangeKey, RangeValue, SessionIdentifier, }; -use db3_proto::db3_session_proto::{CloseSessionPayload, QuerySessionInfo}; +use db3_proto::db3_session_proto::{CloseSessionPayload, OpenSessionPayload, QuerySessionInfo}; use db3_session::session_manager::SessionPool; use ethereum_types::Address as AccountAddress; use prost::Message; use std::sync::Arc; use subtle_encoding::base64; use tonic::Status; - +use uuid::Uuid; pub struct StoreSDK { client: Arc>, signer: Db3Signer, @@ -52,13 +53,21 @@ impl StoreSDK { } pub async fn open_session(&mut self) -> std::result::Result { - let buf = "Header".as_bytes(); + let payload = OpenSessionPayload { + header: Uuid::new_v4().to_string(), + start_time: Utc::now().timestamp(), + }; + let mut buf = BytesMut::with_capacity(1024 * 8); + payload + .encode(&mut buf) + .map_err(|e| Status::internal(format!("{}", e)))?; + let buf = buf.freeze(); let (signature, public_key) = self .signer .sign(buf.as_ref()) .map_err(|e| Status::internal(format!("{:?}", e)))?; let r = OpenSessionRequest { - header: buf.to_vec(), + payload: buf.as_ref().to_vec(), signature: signature.as_ref().to_vec(), public_key: public_key.as_ref().to_vec(), }; @@ -71,7 +80,7 @@ impl StoreSDK { .insert_session_with_token(&result.query_session_info.unwrap(), &result.session_token) { Ok(_) => Ok(response.clone()), - Err(e) => Err(Status::internal(format!("Fail to create session {}", e))), + Err(e) => Err(Status::internal(format!("Fail to open session {}", e))), } } /// close session @@ -262,15 +271,21 @@ impl StoreSDK { mod tests { use super::Db3Signer; use super::StoreSDK; + use super::*; use crate::mutation_sdk::MutationSDK; + use bytes::BytesMut; + use chrono::Utc; use db3_base::{get_a_random_nonce, get_a_static_keypair, get_address_from_pk}; use db3_proto::db3_base_proto::{ChainId, ChainRole}; use db3_proto::db3_mutation_proto::KvPair; use db3_proto::db3_mutation_proto::{Mutation, MutationAction}; use db3_proto::db3_node_proto::storage_node_client::StorageNodeClient; + use db3_proto::db3_node_proto::OpenSessionRequest; + use db3_proto::db3_session_proto::OpenSessionPayload; use std::sync::Arc; use std::time; use tonic::transport::Endpoint; + use uuid::Uuid; #[tokio::test] async fn it_get_bills() { @@ -519,4 +534,73 @@ mod tests { "query session verify fail. expect query count 1 but 101" ); } + + #[tokio::test] + async fn open_session_replay_attach() { + let mut rng = rand::thread_rng(); + let nonce = get_a_random_nonce(); + + let ep = "http://127.0.0.1:26659"; + let rpc_endpoint = Endpoint::new(ep.to_string()).unwrap(); + let channel = rpc_endpoint.connect_lazy(); + let mut client = StorageNodeClient::new(channel); + let kp = get_a_static_keypair(); + let signer = Db3Signer::new(kp); + let payload = OpenSessionPayload { + header: Uuid::new_v4().to_string(), + start_time: Utc::now().timestamp(), + }; + let mut buf = BytesMut::with_capacity(1024 * 8); + payload.encode(&mut buf); + let buf = buf.freeze(); + let (signature, public_key) = signer + .sign(buf.as_ref()) + .map_err(|e| Status::internal(format!("{:?}", e))) + .unwrap(); + let r = OpenSessionRequest { + payload: buf.as_ref().to_vec(), + signature: signature.as_ref().to_vec(), + public_key: public_key.as_ref().to_vec(), + }; + let request = tonic::Request::new(r.clone()); + let response = client.open_query_session(request).await; + assert!(response.is_ok()); + + // duplicate header + std::thread::sleep(time::Duration::from_millis(1000)); + let request = tonic::Request::new(r.clone()); + let response = client.open_query_session(request).await; + assert!(response.is_err()); + } + #[tokio::test] + async fn open_session_ttl_expiered() { + let mut rng = rand::thread_rng(); + let nonce = get_a_random_nonce(); + + let ep = "http://127.0.0.1:26659"; + let rpc_endpoint = Endpoint::new(ep.to_string()).unwrap(); + let channel = rpc_endpoint.connect_lazy(); + let mut client = StorageNodeClient::new(channel); + let kp = get_a_static_keypair(); + let signer = Db3Signer::new(kp); + let payload = OpenSessionPayload { + header: Uuid::new_v4().to_string(), + start_time: Utc::now().timestamp() - 6, + }; + let mut buf = BytesMut::with_capacity(1024 * 8); + payload.encode(&mut buf); + let buf = buf.freeze(); + let (signature, public_key) = signer + .sign(buf.as_ref()) + .map_err(|e| Status::internal(format!("{:?}", e))) + .unwrap(); + let r = OpenSessionRequest { + payload: buf.as_ref().to_vec(), + signature: signature.as_ref().to_vec(), + public_key: public_key.as_ref().to_vec(), + }; + let request = tonic::Request::new(r.clone()); + let response = client.open_query_session(request).await; + assert!(response.is_err()); + } } diff --git a/src/session/src/session_manager.rs b/src/session/src/session_manager.rs index 97374c25..aa16d642 100644 --- a/src/session/src/session_manager.rs +++ b/src/session/src/session_manager.rs @@ -18,7 +18,7 @@ use chrono::Utc; use db3_proto::db3_session_proto::{QuerySessionInfo, SessionStatus}; use ethereum_types::Address; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use uuid::Uuid; // retry generate token @@ -33,6 +33,9 @@ pub const DEFAULT_SESSION_POOL_SIZE_LIMIT: usize = 1000; // default session clean period 1 min pub const DEFAULT_CLEANUP_SESSION_PERIOD: i64 = 60; +// default session ttl 5s +pub const DEFAULT_SESSION_TTL: i64 = 5; + pub struct SessionPool { session_pool: HashMap, last_cleanup_time: i64, @@ -63,6 +66,7 @@ impl SessionPool { &mut self, sid: i32, token: &str, + start_time: i64, ) -> Result<(String, QuerySessionInfo), String> { if self.need_cleanup() { self.cleanup_session(); @@ -74,7 +78,7 @@ impl SessionPool { DEFAULT_SESSION_POOL_SIZE_LIMIT )); } - let sess = SessionManager::create_session(sid); + let sess = SessionManager::create_session(sid, start_time); self.session_pool.insert(token.to_string(), sess.clone()); return Ok((token.to_string(), sess.session_info)); } @@ -118,6 +122,7 @@ impl SessionPool { pub struct SessionStore { session_pools: HashMap, token_account_map: HashMap, + open_session_headers: HashSet, sid: i32, } @@ -126,6 +131,7 @@ impl SessionStore { SessionStore { session_pools: HashMap::new(), token_account_map: HashMap::new(), + open_session_headers: HashSet::new(), sid: 0, } } @@ -143,21 +149,46 @@ impl SessionStore { Err(format!("Fail to generate unique token after retry")) } + fn is_session_header_exit(&self, header: &String) -> bool { + self.open_session_headers.contains(header) + } + fn add_session_header(&mut self, header: &String) { + self.open_session_headers.insert(header.clone()); + } + fn is_ttl_expired(&self, ts: i64) -> bool { + Utc::now().timestamp() - ts >= DEFAULT_SESSION_TTL + } /// Add session into pool - pub fn add_new_session(&mut self, addr: Address) -> Result<(String, QuerySessionInfo), String> { + pub fn add_new_session( + &mut self, + header: &String, + start_time: i64, + addr: Address, + ) -> Result<(String, QuerySessionInfo), String> { + if self.is_ttl_expired(start_time) { + return Err(format!("Session HEADER {} ttl is expired", header)); + } + if self.is_session_header_exit(header) { + return Err(format!("Session HEADER {} already exist", header)); + } self.sid += 1; let token = self.generate_unique_token().map_err(|e| e)?; match self.session_pools.get_mut(&addr) { Some(sess_pool) => { self.token_account_map.insert(token.clone(), addr); - sess_pool.create_new_session(self.sid, &token) + let res = sess_pool.create_new_session(self.sid, &token, start_time); + if res.is_ok() { + self.add_session_header(header); + } + res } None => { let mut sess_pool = SessionPool::new(); - let res = sess_pool.create_new_session(self.sid, &token); + let res = sess_pool.create_new_session(self.sid, &token, start_time); if res.is_ok() { self.token_account_map.insert(token.clone(), addr); self.session_pools.insert(addr, sess_pool); + self.add_session_header(header); } res } @@ -209,10 +240,9 @@ pub struct SessionManager { impl SessionManager { pub fn new() -> Self { - Self::create_session(0) + Self::create_session(0, Utc::now().timestamp()) } - pub fn create_session(id: i32) -> Self { - let start_time = Utc::now().timestamp(); + pub fn create_session(id: i32, start_time: i64) -> Self { SessionManager { session_info: QuerySessionInfo { id, @@ -264,7 +294,7 @@ mod tests { use db3_base::get_a_static_keypair; use db3_base::get_address_from_pk; use db3_proto::db3_session_proto::SessionStatus; - + use uuid::Uuid; #[test] fn test_new_session() { let mut session = SessionManager::new(); @@ -300,11 +330,14 @@ mod tests { let mut sess_store = SessionStore::new(); let kp = get_a_static_keypair(); let addr = get_address_from_pk(&kp.public); + let ts = Utc::now().timestamp(); for _ in 0..DEFAULT_SESSION_POOL_SIZE_LIMIT { - assert!(sess_store.add_new_session(addr).is_ok()) + assert!(sess_store + .add_new_session(&Uuid::new_v4().to_string(), ts, addr) + .is_ok()) } - let res = sess_store.add_new_session(addr); + let res = sess_store.add_new_session(&Uuid::new_v4().to_string(), ts, addr); assert!(res.is_err()); assert_eq!( "Fail to create new session since session pool size exceed limit 1000", @@ -317,12 +350,13 @@ mod tests { let mut sess_store = SessionStore::new(); let kp = get_a_static_keypair(); let addr = get_address_from_pk(&kp.public); + let ts = Utc::now().timestamp(); // add session and create new session pool - let res = sess_store.add_new_session(addr); + let res = sess_store.add_new_session(&Uuid::new_v4().to_string(), ts, addr); assert!(res.is_ok()); let token1 = res.unwrap().0; assert_eq!(token1.len(), 36); - let res = sess_store.add_new_session(addr); + let res = sess_store.add_new_session(&Uuid::new_v4().to_string(), ts, addr); assert!(res.is_ok()); let token2 = res.unwrap().0; assert_ne!(token1, token2); @@ -332,18 +366,32 @@ mod tests { let res = sess_store.get_session_mut(&"token_unknow".to_string()); assert!(res.is_none()); } - + #[test] + fn add_session_wrong_path_duplicate_header() { + let mut sess_store = SessionStore::new(); + let kp = get_a_static_keypair(); + let addr = get_address_from_pk(&kp.public); + let header = Uuid::new_v4().to_string(); + let ts = Utc::now().timestamp(); + // add session and create new session pool + let res = sess_store.add_new_session(&header, ts, addr); + assert!(res.is_ok()); + let token1 = res.unwrap().0; + assert_eq!(token1.len(), 36); + let res = sess_store.add_new_session(&header, ts, addr); + assert!(res.is_err()); + } #[test] fn remove_session_test() { let mut sess_store = SessionStore::new(); let kp = get_a_static_keypair(); let addr = get_address_from_pk(&kp.public); - - let res = sess_store.add_new_session(addr); + let ts = Utc::now().timestamp(); + let res = sess_store.add_new_session(&Uuid::new_v4().to_string(), ts, addr); assert!(res.is_ok()); let token1 = res.unwrap().0; assert_eq!(token1.len(), 36); - let res = sess_store.add_new_session(addr); + let res = sess_store.add_new_session(&Uuid::new_v4().to_string(), ts, addr); assert!(res.is_ok()); let token2 = res.unwrap().0; assert_ne!(token1, token2); @@ -359,8 +407,11 @@ mod tests { let mut sess_store = SessionStore::new(); let kp = get_a_static_keypair(); let addr = get_address_from_pk(&kp.public); + let ts = Utc::now().timestamp(); for i in 0..100 { - let (token, _) = sess_store.add_new_session(addr).unwrap(); + let (token, _) = sess_store + .add_new_session(&Uuid::new_v4().to_string(), ts, addr) + .unwrap(); // convert session with even id into blocked status if i % 2 == 0 { @@ -388,4 +439,11 @@ mod tests { 50 ); } + #[test] + fn is_ttl_expired_test() { + let mut sess_store = SessionStore::new(); + assert!(!sess_store.is_ttl_expired(Utc::now().timestamp() - 1)); + assert!(sess_store.is_ttl_expired(Utc::now().timestamp() - 5)); + assert!(sess_store.is_ttl_expired(Utc::now().timestamp() - 10)); + } }