diff --git a/db/schema.sql b/db/schema.sql index 9faeb1a45..a797ab5cf 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -15,10 +15,12 @@ CREATE TABLE tasks( task_id BYTEA UNIQUE NOT NULL, -- 32-byte TaskID as defined by the DAP specification aggregator_role AGGREGATOR_ROLE NOT NULL, -- the role of this aggregator for this task aggregator_endpoints TEXT[] NOT NULL, -- aggregator HTTPS endpoints, leader first + query_type JSON NOT NULL, -- the query type in use for this task, along with its parameters vdaf JSON NOT NULL, -- the VDAF instance in use for this task, along with its parameters - max_batch_lifetime BIGINT NOT NULL, -- the maximum number of times a given batch may be collected + max_batch_query_count BIGINT NOT NULL, -- the maximum number of times a given batch may be collected + task_expiration TIMESTAMP NOT NULL, -- the time after which client reports are no longer accepted min_batch_size BIGINT NOT NULL, -- the minimum number of reports in a batch to allow it to be collected - min_batch_duration BIGINT NOT NULL, -- the minimum duration in seconds of a single batch interval + time_precision BIGINT NOT NULL, -- the duration to which clients are expected to round their report timestamps, in seconds tolerable_clock_skew BIGINT NOT NULL, -- the maximum acceptable clock skew to allow between client and aggregator, in seconds collector_hpke_config BYTEA NOT NULL -- the HPKE config of the collector (encoded HpkeConfig message) ); diff --git a/integration_tests/src/daphne.rs b/integration_tests/src/daphne.rs index c51f99283..3470f5379 100644 --- a/integration_tests/src/daphne.rs +++ b/integration_tests/src/daphne.rs @@ -60,7 +60,7 @@ impl<'a> Daphne<'a> { // Aes128Gcm); this is checked in `DaphneHpkeConfig::from`. let dap_hpke_receiver_config_list = serde_json::to_string( &task - .hpke_keys + .hpke_keys() .values() .map(|(hpke_config, private_key)| DaphneHpkeReceiverConfig { config: DaphneHpkeConfig::from(hpke_config.clone()), @@ -80,16 +80,16 @@ impl<'a> Daphne<'a> { let dap_collect_id_key: [u8; 16] = random(); let dap_task_list = serde_json::to_string(&HashMap::from([( - hex::encode(task.id.as_ref()), + hex::encode(task.id().as_ref()), DaphneDapTaskConfig { version: "v01".to_string(), - leader_url: task.aggregator_url(Role::Leader).unwrap().clone(), - helper_url: task.aggregator_url(Role::Helper).unwrap().clone(), - min_batch_duration: task.min_batch_duration.as_seconds(), - min_batch_size: task.min_batch_size, - vdaf: daphne_vdaf_config_from_janus_vdaf(&task.vdaf), - vdaf_verify_key: hex::encode(task.vdaf_verify_keys().first().unwrap().as_bytes()), - collector_hpke_config: DaphneHpkeConfig::from(task.collector_hpke_config.clone()), + leader_url: task.aggregator_url(&Role::Leader).unwrap().clone(), + helper_url: task.aggregator_url(&Role::Helper).unwrap().clone(), + min_batch_duration: task.time_precision().as_seconds(), // TODO(#493): this field will likely need to be renamed + min_batch_size: task.min_batch_size(), + vdaf: daphne_vdaf_config_from_janus_vdaf(task.vdaf()), + vdaf_verify_key: hex::encode(task.vdaf_verify_keys().first().unwrap().as_ref()), + collector_hpke_config: DaphneHpkeConfig::from(task.collector_hpke_config().clone()), }, )])) .unwrap(); @@ -97,11 +97,11 @@ impl<'a> Daphne<'a> { // Daphne currently only supports one auth token per task. Janus supports multiple tokens // per task to allow rotation; we supply Daphne with the "primary" token. 
let aggregator_bearer_token_list = json!({ - hex::encode(task.id.as_ref()): String::from_utf8(task.primary_aggregator_auth_token().as_bytes().to_vec()).unwrap() + hex::encode(task.id().as_ref()): String::from_utf8(task.primary_aggregator_auth_token().as_bytes().to_vec()).unwrap() }).to_string(); - let collector_bearer_token_list = if task.role == Role::Leader { + let collector_bearer_token_list = if task.role() == &Role::Leader { json!({ - hex::encode(task.id.as_ref()): String::from_utf8(task.primary_collector_auth_token().as_bytes().to_vec()).unwrap() + hex::encode(task.id().as_ref()): String::from_utf8(task.primary_collector_auth_token().as_bytes().to_vec()).unwrap() }).to_string() } else { String::new() @@ -160,7 +160,7 @@ impl<'a> Daphne<'a> { // Start the Daphne test container running. let port = pick_unused_port().expect("Couldn't pick unused port"); - let endpoint = task.aggregator_url(task.role).unwrap(); + let endpoint = task.aggregator_url(task.role()).unwrap(); let args = [ ( @@ -179,7 +179,7 @@ impl<'a> Daphne<'a> { ), ( "DAP_AGGREGATOR_ROLE".to_string(), - task.role.as_str().to_string(), + task.role().as_str().to_string(), ), ( "DAP_GLOBAL_CONFIG".to_string(), @@ -233,7 +233,7 @@ impl<'a> Daphne<'a> { task::spawn({ let http_client = reqwest::Client::default(); let mut request_url = task - .aggregator_url(task.role) + .aggregator_url(task.role()) .unwrap() .join("/internal/process") .unwrap(); @@ -266,7 +266,7 @@ impl<'a> Daphne<'a> { Self { daphne_container, - role: task.role, + role: *task.role(), start_shutdown_sender: Some(start_shutdown_sender), shutdown_complete_receiver: Some(shutdown_complete_receiver), } diff --git a/integration_tests/src/janus.rs b/integration_tests/src/janus.rs index a6419a7aa..4fc8ff36d 100644 --- a/integration_tests/src/janus.rs +++ b/integration_tests/src/janus.rs @@ -53,7 +53,7 @@ impl<'a> Janus<'a> { task: &Task, ) -> Janus<'a> { // Start the Janus interop aggregator container running. - let endpoint = task.aggregator_url(task.role).unwrap(); + let endpoint = task.aggregator_url(task.role()).unwrap(); let container = container_client.run( RunnableImage::from(Aggregator::default()) .with_network(network) @@ -77,7 +77,7 @@ impl<'a> Janus<'a> { assert_eq!(resp.get("status"), Some(&Some("success".to_string()))); Self::Container { - role: task.role, + role: *task.role(), container, } } diff --git a/integration_tests/tests/common/mod.rs b/integration_tests/tests/common/mod.rs index de6d76255..281a95e08 100644 --- a/integration_tests/tests/common/mod.rs +++ b/integration_tests/tests/common/mod.rs @@ -7,69 +7,36 @@ use janus_collector::{ use janus_core::{ hpke::{test_util::generate_test_hpke_config_and_private_key, HpkePrivateKey}, retries::test_http_request_exponential_backoff, - task::{AuthenticationToken, VdafInstance}, + task::VdafInstance, time::{Clock, RealClock, TimeExt}, }; -use janus_messages::{Duration, HpkeConfig, Interval, Role}; -use janus_server::{ - messages::DurationExt, - task::{test_util::generate_auth_token, Task, PRIO3_AES128_VERIFY_KEY_LENGTH}, - SecretBytes, -}; +use janus_messages::{Duration, Interval, Role}; +use janus_server::task::{test_util::TaskBuilder, QueryType, Task}; use prio::vdaf::prio3::Prio3; use rand::random; use reqwest::Url; use std::iter; use tokio::time; -// Returns (leader_task, helper_task). -pub fn create_test_tasks(collector_hpke_config: &HpkeConfig) -> (Task, Task) { - // Generate parameters. 
- let task_id = random(); - let buf: [u8; 4] = random(); - let endpoints = Vec::from([ - Url::parse(&format!("http://leader-{}:8080/", hex::encode(buf))).unwrap(), - Url::parse(&format!("http://helper-{}:8080/", hex::encode(buf))).unwrap(), - ]); - let vdaf_verify_key: [u8; PRIO3_AES128_VERIFY_KEY_LENGTH] = random(); - let vdaf_verify_keys = Vec::from([SecretBytes::new(vdaf_verify_key.to_vec())]); - let aggregator_auth_tokens = Vec::from([generate_auth_token()]); - - // Create tasks & return. - let leader_task = Task::new( - task_id, - endpoints.clone(), +// Returns (collector_private_key, leader_task, helper_task). +pub fn test_task_builders() -> (HpkePrivateKey, TaskBuilder, TaskBuilder) { + let endpoint_random_value = hex::encode(random::<[u8; 4]>()); + let (collector_hpke_config, collector_private_key) = + generate_test_hpke_config_and_private_key(); + let leader_task = TaskBuilder::new( + QueryType::TimeInterval, VdafInstance::Prio3Aes128Count.into(), Role::Leader, - vdaf_verify_keys.clone(), - 1, - 46, - Duration::from_hours(8).unwrap(), - Duration::from_minutes(10).unwrap(), - collector_hpke_config.clone(), - aggregator_auth_tokens.clone(), - Vec::from([generate_auth_token()]), - Vec::from([generate_test_hpke_config_and_private_key()]), ) - .unwrap(); - let helper_task = Task::new( - task_id, - endpoints, - VdafInstance::Prio3Aes128Count.into(), - Role::Helper, - vdaf_verify_keys, - 1, - 46, - Duration::from_hours(8).unwrap(), - Duration::from_minutes(10).unwrap(), - collector_hpke_config.clone(), - aggregator_auth_tokens, - Vec::new(), - Vec::from([generate_test_hpke_config_and_private_key()]), - ) - .unwrap(); + .with_aggregator_endpoints(Vec::from([ + Url::parse(&format!("http://leader-{endpoint_random_value}:8080/")).unwrap(), + Url::parse(&format!("http://helper-{endpoint_random_value}:8080/")).unwrap(), + ])) + .with_min_batch_size(46) + .with_collector_hpke_config(collector_hpke_config); + let helper_task = leader_task.clone().with_role(Role::Helper); - (leader_task, helper_task) + (collector_private_key, leader_task, helper_task) } pub fn translate_url_for_external_access(url: &Url, external_port: u16) -> Url { @@ -86,33 +53,32 @@ pub async fn submit_measurements_and_verify_aggregate( ) { // Translate aggregator endpoints for our perspective outside the container network. let aggregator_endpoints: Vec<_> = leader_task - .aggregator_endpoints + .aggregator_endpoints() .iter() .zip([leader_port, helper_port]) .map(|(url, port)| translate_url_for_external_access(url, port)) .collect(); // Create client. - let task_id = leader_task.id; let vdaf = Prio3::new_aes128_count(2).unwrap(); let client_parameters = ClientParameters::new( - task_id, + *leader_task.id(), aggregator_endpoints.clone(), - leader_task.min_batch_duration, + *leader_task.time_precision(), ); let http_client = janus_client::default_http_client().unwrap(); let leader_report_config = janus_client::aggregator_hpke_config( &client_parameters, - Role::Leader, - task_id, + &Role::Leader, + leader_task.id(), &http_client, ) .await .unwrap(); let helper_report_config = janus_client::aggregator_hpke_config( &client_parameters, - Role::Helper, - task_id, + &Role::Helper, + leader_task.id(), &http_client, ) .await @@ -132,7 +98,7 @@ pub async fn submit_measurements_and_verify_aggregate( // We generate exactly one batch's worth of measurement uploads to work around an issue in // Daphne at time of writing. 
let clock = RealClock::default(); - let total_measurements: usize = leader_task.min_batch_size.try_into().unwrap(); + let total_measurements: usize = leader_task.min_batch_size().try_into().unwrap(); let num_nonzero_measurements = total_measurements / 2; let num_zero_measurements = total_measurements - num_nonzero_measurements; assert!(num_nonzero_measurements > 0 && num_zero_measurements > 0); @@ -147,23 +113,18 @@ pub async fn submit_measurements_and_verify_aggregate( // Send a collect request. let batch_interval = Interval::new( before_timestamp - .to_batch_unit_interval_start(leader_task.min_batch_duration) + .to_batch_unit_interval_start(leader_task.time_precision()) .unwrap(), - // Use two minimum batch durations as the interval duration in order to avoid a race - // condition if this test happens to run very close to the end of a batch window. - Duration::from_seconds(2 * leader_task.min_batch_duration.as_seconds()), + // Use two time precisions as the interval duration in order to avoid a race condition if + // this test happens to run very close to the end of a batch window. + Duration::from_seconds(2 * leader_task.time_precision().as_seconds()), ) .unwrap(); let collector_params = CollectorParameters::new( - task_id, + *leader_task.id(), aggregator_endpoints[Role::Leader.index().unwrap()].clone(), - AuthenticationToken::from( - leader_task - .primary_collector_auth_token() - .as_bytes() - .to_vec(), - ), - leader_task.collector_hpke_config.clone(), + leader_task.primary_collector_auth_token().clone(), + leader_task.collector_hpke_config().clone(), collector_private_key.clone(), ) .with_http_request_backoff(test_http_request_exponential_backoff()) diff --git a/integration_tests/tests/daphne.rs b/integration_tests/tests/daphne.rs index 84daa9870..aaadd79af 100644 --- a/integration_tests/tests/daphne.rs +++ b/integration_tests/tests/daphne.rs @@ -1,13 +1,11 @@ #![cfg(feature = "daphne")] -use common::{create_test_tasks, submit_measurements_and_verify_aggregate}; +use common::{submit_measurements_and_verify_aggregate, test_task_builders}; use integration_tests::{daphne::Daphne, janus::Janus}; use interop_binaries::test_util::generate_network_name; -use janus_core::{ - hpke::test_util::generate_test_hpke_config_and_private_key, - test_util::{install_test_trace_subscriber, testcontainers::container_client}, -}; +use janus_core::test_util::{install_test_trace_subscriber, testcontainers::container_client}; use janus_messages::Role; +use janus_server::task::Task; mod common; @@ -18,17 +16,19 @@ async fn daphne_janus() { // Start servers. let network = generate_network_name(); - let (collector_hpke_config, collector_private_key) = - generate_test_hpke_config_and_private_key(); - let (mut leader_task, mut helper_task) = create_test_tasks(&collector_hpke_config); + let (collector_private_key, leader_task, helper_task) = test_task_builders(); // Daphne is hardcoded to serve from a path starting with /v01/. 
-    for task in [&mut leader_task, &mut helper_task] {
-        task.aggregator_endpoints
-            .get_mut(Role::Leader.index().unwrap())
-            .unwrap()
-            .set_path("/v01/");
-    }
+    let [leader_task, helper_task]: [Task; 2] = [leader_task, helper_task]
+        .into_iter()
+        .map(|task| {
+            let mut endpoints = task.aggregator_endpoints().to_vec();
+            endpoints[Role::Leader.index().unwrap()].set_path("/v01/");
+            task.with_aggregator_endpoints(endpoints).build()
+        })
+        .collect::<Vec<_>>()
+        .try_into()
+        .unwrap();
     let container_client = container_client();
     let leader = Daphne::new(&container_client, &network, &leader_task).await;
@@ -50,17 +50,19 @@ async fn janus_daphne() {
     // Start servers.
     let network = generate_network_name();
-    let (collector_hpke_config, collector_private_key) =
-        generate_test_hpke_config_and_private_key();
-    let (mut leader_task, mut helper_task) = create_test_tasks(&collector_hpke_config);
+    let (collector_private_key, leader_task, helper_task) = test_task_builders();
     // Daphne is hardcoded to serve from a path starting with /v01/.
-    for task in [&mut leader_task, &mut helper_task] {
-        task.aggregator_endpoints
-            .get_mut(Role::Helper.index().unwrap())
-            .unwrap()
-            .set_path("/v01/");
-    }
+    let [leader_task, helper_task]: [Task; 2] = [leader_task, helper_task]
+        .into_iter()
+        .map(|task| {
+            let mut endpoints = task.aggregator_endpoints().to_vec();
+            endpoints[Role::Helper.index().unwrap()].set_path("/v01/");
+            task.with_aggregator_endpoints(endpoints).build()
+        })
+        .collect::<Vec<_>>()
+        .try_into()
+        .unwrap();
     let container_client = container_client();
     let leader = Janus::new_in_container(&container_client, &network, &leader_task).await;
diff --git a/integration_tests/tests/janus.rs b/integration_tests/tests/janus.rs
index 8722556b7..d8a9871ec 100644
--- a/integration_tests/tests/janus.rs
+++ b/integration_tests/tests/janus.rs
@@ -1,8 +1,8 @@
-use common::{create_test_tasks, submit_measurements_and_verify_aggregate};
+use common::{submit_measurements_and_verify_aggregate, test_task_builders};
 use integration_tests::janus::Janus;
 use interop_binaries::test_util::generate_network_name;
 use janus_core::{
-    hpke::{test_util::generate_test_hpke_config_and_private_key, HpkePrivateKey},
+    hpke::HpkePrivateKey,
     test_util::{install_test_trace_subscriber, testcontainers::container_client},
 };
 use janus_server::task::Task;
@@ -48,12 +48,10 @@ impl<'a> JanusPair<'a> {
     /// - `JANUS_E2E_LEADER_NAMESPACE`: The Kubernetes namespace where the DAP leader is deployed.
     /// - `JANUS_E2E_HELPER_NAMESPACE`: The Kubernetes namespace where the DAP helper is deployed.
     pub async fn new(container_client: &'a Cli) -> JanusPair<'a> {
-        let (collector_hpke_config, collector_private_key) =
-            generate_test_hpke_config_and_private_key();
-        let (mut leader_task, mut helper_task) = create_test_tasks(&collector_hpke_config);
+        let (collector_private_key, leader_task, helper_task) = test_task_builders();
         // The environment variables should either all be present, or all be absent
-        let (leader, helper) = match (
+        let (leader_task, leader, helper) = match (
            env::var("JANUS_E2E_KUBE_CONFIG_PATH"),
            env::var("JANUS_E2E_KUBECTL_CONTEXT_NAME"),
            env::var("JANUS_E2E_LEADER_NAMESPACE"),
            env::var("JANUS_E2E_HELPER_NAMESPACE"),
        ) {
@@ -71,8 +69,14 @@ impl<'a> JanusPair<'a> {
                // and so they need the in-cluster DNS name of the other aggregator. However, since
                // aggregators use the endpoint URLs in the task to construct collect job URIs, we
                // must only fix the _peer_ aggregator's endpoint.
- leader_task.aggregator_endpoints[1] = - Self::in_cluster_aggregator_url(&helper_namespace); + let leader_endpoints = { + let mut endpoints = leader_task.aggregator_endpoints().to_vec(); + endpoints[1] = Self::in_cluster_aggregator_url(&helper_namespace); + endpoints + }; + let leader_task = leader_task + .with_aggregator_endpoints(leader_endpoints) + .build(); let leader = Janus::new_with_kubernetes_cluster( &kubeconfig_path, &kubectl_context_name, @@ -81,8 +85,14 @@ impl<'a> JanusPair<'a> { ) .await; - helper_task.aggregator_endpoints[0] = - Self::in_cluster_aggregator_url(&leader_namespace); + let helper_endpoints = { + let mut endpoints = helper_task.aggregator_endpoints().to_vec(); + endpoints[0] = Self::in_cluster_aggregator_url(&leader_namespace); + endpoints + }; + let helper_task = helper_task + .with_aggregator_endpoints(helper_endpoints) + .build(); let helper = Janus::new_with_kubernetes_cluster( &kubeconfig_path, &kubectl_context_name, @@ -91,7 +101,7 @@ impl<'a> JanusPair<'a> { ) .await; - (leader, helper) + (leader_task, leader, helper) } ( Err(VarError::NotPresent), @@ -99,12 +109,13 @@ impl<'a> JanusPair<'a> { Err(VarError::NotPresent), Err(VarError::NotPresent), ) => { + let leader_task = leader_task.build(); let network = generate_network_name(); let leader = Janus::new_in_container(container_client, &network, &leader_task).await; let helper = - Janus::new_in_container(container_client, &network, &helper_task).await; - (leader, helper) + Janus::new_in_container(container_client, &network, &helper_task.build()).await; + (leader_task, leader, helper) } _ => panic!("unexpected environment variables"), }; diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs index c4c5e4458..9e93508a8 100644 --- a/interop_binaries/src/bin/janus_interop_aggregator.rs +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -7,7 +7,7 @@ use interop_binaries::{ AddTaskResponse, AggregatorAddTaskRequest, HpkeConfigRegistry, }; use janus_core::{task::AuthenticationToken, time::RealClock, TokioRuntime}; -use janus_messages::{Duration, HpkeConfig, Role, TaskId}; +use janus_messages::{Duration, HpkeConfig, Role, TaskId, Time}; use janus_server::{ aggregator::{ aggregate_share::CollectJobDriver, aggregation_job_creator::AggregationJobCreator, @@ -54,7 +54,7 @@ async fn handle_add_task( base64::decode_config(request.verify_key, URL_SAFE_NO_PAD) .context("invalid base64url content in \"verifyKey\"")?, ); - let min_batch_duration = Duration::from_seconds(request.min_batch_duration); + let time_precision = Duration::from_seconds(request.time_precision); let collector_hpke_config_bytes = base64::decode_config(request.collector_hpke_config, URL_SAFE_NO_PAD) .context("invalid base64url content in \"collectorHpkeConfig\"")?; @@ -70,9 +70,9 @@ async fn handle_add_task( } (0, Some(collector_authentication_token)) => ( Role::Leader, - vec![AuthenticationToken::from( + Vec::from([AuthenticationToken::from( collector_authentication_token.into_bytes(), - )], + )]), ), (1, _) => (Role::Helper, Vec::new()), _ => return Err(anyhow::anyhow!("invalid \"aggregator_id\" value")), @@ -82,18 +82,20 @@ async fn handle_add_task( let task = Task::new( task_id, - vec![request.leader, request.helper], + Vec::from([request.leader, request.helper]), + request.query_type, vdaf, role, - vec![verify_key], - request.max_batch_lifetime, + Vec::from([verify_key]), + request.max_batch_query_count, + Time::from_seconds_since_epoch(request.task_expiration), 
request.min_batch_size, - min_batch_duration, + time_precision, // We can be strict about clock skew since this executable is only intended for use with // other aggregators running on the same host. Duration::from_seconds(1), collector_hpke_config, - vec![leader_authentication_token], + Vec::from([leader_authentication_token]), collector_authentication_tokens, [(hpke_config, private_key)], ) @@ -241,7 +243,7 @@ async fn main() -> anyhow::Result<()> { // Run the aggregation job creator. let pool = database_pool(&db_config, None).await?; let datastore_key = LessSafeKey::new(UnboundKey::new(&AES_128_GCM, &key_bytes).unwrap()); - let crypter = Crypter::new(vec![datastore_key]); + let crypter = Crypter::new(Vec::from([datastore_key])); let aggregation_job_creator = Arc::new(AggregationJobCreator::new( Datastore::new(pool, crypter, clock), clock, diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs index f1fd60dd0..3f51073ca 100644 --- a/interop_binaries/src/bin/janus_interop_client.rs +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -61,7 +61,7 @@ struct UploadRequest { measurement: Measurement, #[serde(default, rename = "nonceTime")] timestamp: Option, - min_batch_duration: u64, + time_precision: u64, } #[derive(Debug, Serialize)] @@ -83,25 +83,25 @@ where let task_id_bytes = base64::decode_config(request.task_id, URL_SAFE_NO_PAD) .context("invalid base64url content in \"taskId\"")?; let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?; - let min_batch_duration = Duration::from_seconds(request.min_batch_duration); + let time_precision = Duration::from_seconds(request.time_precision); let client_parameters = ClientParameters::new( task_id, - vec![request.leader, request.helper], - min_batch_duration, + Vec::::from([request.leader, request.helper]), + time_precision, ); let leader_hpke_config = janus_client::aggregator_hpke_config( &client_parameters, - Role::Leader, - task_id, + &Role::Leader, + &task_id, http_client, ) .await .context("failed to fetch leader's HPKE configuration")?; let helper_hpke_config = janus_client::aggregator_hpke_config( &client_parameters, - Role::Helper, - task_id, + &Role::Helper, + &task_id, http_client, ) .await diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index 1c13d087f..44ff50c66 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -1,7 +1,7 @@ use base64::URL_SAFE_NO_PAD; use janus_core::hpke::{generate_hpke_config_and_private_key, HpkePrivateKey}; use janus_messages::{HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, Role}; -use janus_server::task::{Task, VdafInstance}; +use janus_server::task::{QueryType, Task, VdafInstance}; use prio::codec::Encode; use rand::random; use serde::{de::Visitor, Deserialize, Serialize}; @@ -143,15 +143,17 @@ pub struct AggregatorAddTaskRequest { pub task_id: String, // in unpadded base64url pub leader: Url, pub helper: Url, + pub query_type: QueryType, pub vdaf: VdafObject, pub leader_authentication_token: String, #[serde(default)] pub collector_authentication_token: Option, pub aggregator_id: u8, pub verify_key: String, // in unpadded base64url - pub max_batch_lifetime: u64, + pub max_batch_query_count: u64, + pub task_expiration: u64, // in seconds since the epoch pub min_batch_size: u64, - pub min_batch_duration: u64, // in seconds + pub time_precision: u64, // in seconds pub collector_hpke_config: String, // in unpadded base64url } @@ -165,32 +167,34 @@ pub 
struct AddTaskResponse {
 impl From<Task> for AggregatorAddTaskRequest {
     fn from(task: Task) -> Self {
         Self {
-            task_id: base64::encode_config(task.id.as_ref(), URL_SAFE_NO_PAD),
-            leader: task.aggregator_url(Role::Leader).unwrap().clone(),
-            helper: task.aggregator_url(Role::Helper).unwrap().clone(),
-            vdaf: task.vdaf.clone().into(),
+            task_id: base64::encode_config(task.id().as_ref(), URL_SAFE_NO_PAD),
+            leader: task.aggregator_url(&Role::Leader).unwrap().clone(),
+            helper: task.aggregator_url(&Role::Helper).unwrap().clone(),
+            query_type: *task.query_type(),
+            vdaf: task.vdaf().clone().into(),
             leader_authentication_token: String::from_utf8(
-                task.aggregator_auth_tokens
-                    .first()
-                    .unwrap()
-                    .as_bytes()
-                    .to_vec(),
+                task.primary_aggregator_auth_token().as_bytes().to_vec(),
             )
             .unwrap(),
-            collector_authentication_token: task
-                .collector_auth_tokens
-                .first()
-                .map(|t| String::from_utf8(t.as_bytes().to_vec()).unwrap()),
-            aggregator_id: task.role.index().unwrap().try_into().unwrap(),
+            collector_authentication_token: if task.role() == &Role::Leader {
+                Some(
+                    String::from_utf8(task.primary_collector_auth_token().as_bytes().to_vec())
+                        .unwrap(),
+                )
+            } else {
+                None
+            },
+            aggregator_id: task.role().index().unwrap().try_into().unwrap(),
             verify_key: base64::encode_config(
-                task.vdaf_verify_keys().first().unwrap().as_bytes(),
+                task.vdaf_verify_keys().first().unwrap().as_ref(),
                 URL_SAFE_NO_PAD,
             ),
-            max_batch_lifetime: task.max_batch_lifetime,
-            min_batch_size: task.min_batch_size,
-            min_batch_duration: task.min_batch_duration.as_seconds(),
+            max_batch_query_count: task.max_batch_query_count(),
+            task_expiration: task.task_expiration().as_seconds_since_epoch(),
+            min_batch_size: task.min_batch_size(),
+            time_precision: task.time_precision().as_seconds(),
             collector_hpke_config: base64::encode_config(
-                &task.collector_hpke_config.get_encoded(),
+                &task.collector_hpke_config().get_encoded(),
                 URL_SAFE_NO_PAD,
             ),
         }
diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs
index 8f628a8a1..387c084f4 100644
--- a/interop_binaries/tests/end_to_end.rs
+++ b/interop_binaries/tests/end_to_end.rs
@@ -19,12 +19,13 @@ use std::time::Duration as StdDuration;
 use testcontainers::RunnableImage;
 const JSON_MEDIA_TYPE: &str = "application/json";
-const MIN_BATCH_DURATION: u64 = 3600;
+const TIME_PRECISION: u64 = 3600;
 /// Take a VDAF description and a list of measurements, perform an entire aggregation using
 /// interoperation test binaries, and return the aggregate result. This follows the outline of
 /// the "Test Runner Operation" section in draft-dcook-ppm-dap-interop-test-design-01.
async fn run( + query_type: serde_json::Value, vdaf_object: serde_json::Value, measurements: &[serde_json::Value], aggregation_parameter: &[u8], @@ -233,14 +234,16 @@ async fn run( "taskId": task_id_encoded, "leader": internal_leader_endpoint, "helper": internal_helper_endpoint, + "queryType": query_type, "vdaf": vdaf_object, "leaderAuthenticationToken": aggregator_auth_token, "collectorAuthenticationToken": collector_auth_token, "aggregatorId": 0, "verifyKey": verify_key_encoded, - "maxBatchLifetime": 1, + "maxBatchQueryCount": 1, + "taskExpiration": u64::MAX, "minBatchSize": 1, - "minBatchDuration": MIN_BATCH_DURATION, + "timePrecision": TIME_PRECISION, "collectorHpkeConfig": collector_hpke_config_encoded, })) .send() @@ -278,13 +281,15 @@ async fn run( "taskId": task_id_encoded, "leader": internal_leader_endpoint, "helper": internal_helper_endpoint, + "queryType": query_type, "vdaf": vdaf_object, "leaderAuthenticationToken": aggregator_auth_token, "aggregatorId": 1, "verifyKey": verify_key_encoded, - "maxBatchLifetime": 1, + "maxBatchQueryCount": 1, + "taskExpiration": u64::MAX, "minBatchSize": 1, - "minBatchDuration": MIN_BATCH_DURATION, + "timePrecision": TIME_PRECISION, "collectorHpkeConfig": collector_hpke_config_encoded, })) .send() @@ -315,12 +320,12 @@ async fn run( // determine what batch time to start the aggregation at. let start_timestamp = RealClock::default().now(); let batch_interval_start = start_timestamp - .to_batch_unit_interval_start(Duration::from_seconds(MIN_BATCH_DURATION)) + .to_batch_unit_interval_start(&Duration::from_seconds(TIME_PRECISION)) .unwrap() .as_seconds_since_epoch(); - // Span the aggregation over two minimum batch durations, just in case our - // measurements spilled over a batch boundary. - let batch_interval_duration = MIN_BATCH_DURATION * 2; + // Span the aggregation over two time precisions, just in case our measurements spilled over a + // batch boundary. + let batch_interval_duration = TIME_PRECISION * 2; // Send one or more /internal/test/upload requests to the client. for measurement in measurements { @@ -332,7 +337,7 @@ async fn run( "helper": internal_helper_endpoint, "vdaf": vdaf_object, "measurement": measurement, - "minBatchDuration": MIN_BATCH_DURATION, + "timePrecision": TIME_PRECISION, })) .send() .await @@ -452,6 +457,7 @@ async fn run( #[tokio::test] async fn e2e_prio3_count() { let result = run( + json!("TimeInterval"), json!({"type": "Prio3Aes128Count"}), &[ json!("0"), @@ -482,6 +488,7 @@ async fn e2e_prio3_count() { #[tokio::test] async fn e2e_prio3_sum() { let result = run( + json!("TimeInterval"), json!({"type": "Prio3Aes128Sum", "bits": "64"}), &[ json!("0"), @@ -501,6 +508,7 @@ async fn e2e_prio3_sum() { #[tokio::test] async fn e2e_prio3_histogram() { let result = run( + json!("TimeInterval"), json!({ "type": "Prio3Aes128Histogram", "buckets": ["0", "1", "10", "100", "1000", "10000", "100000"], @@ -526,6 +534,7 @@ async fn e2e_prio3_histogram() { #[tokio::test] async fn e2e_prio3_count_vec() { let result = run( + json!("TimeInterval"), json!({"type": "Prio3Aes128CountVec", "length": "4"}), &[ json!(["0", "0", "0", "1"]), diff --git a/janus_client/src/lib.rs b/janus_client/src/lib.rs index 06d48d481..d78f4ca81 100644 --- a/janus_client/src/lib.rs +++ b/janus_client/src/lib.rs @@ -59,24 +59,20 @@ pub struct ClientParameters { /// entry is the leader's. #[derivative(Debug(format_with = "fmt_vector_of_urls"))] aggregator_endpoints: Vec, - /// The minimum batch duration of the task. 
This value is shared by all - /// parties in the protocol, and is used to compute report timestamps. - min_batch_duration: Duration, + /// The time precision of the task. This value is shared by all parties in the protocol, and is + /// used to compute report timestamps. + time_precision: Duration, /// Parameters to use when retrying HTTP requests. http_request_retry_parameters: ExponentialBackoff, } impl ClientParameters { /// Creates a new set of client task parameters. - pub fn new( - task_id: TaskId, - aggregator_endpoints: Vec, - min_batch_duration: Duration, - ) -> Self { + pub fn new(task_id: TaskId, aggregator_endpoints: Vec, time_precision: Duration) -> Self { Self::new_with_backoff( task_id, aggregator_endpoints, - min_batch_duration, + time_precision, http_request_exponential_backoff(), ) } @@ -85,7 +81,7 @@ impl ClientParameters { pub fn new_with_backoff( task_id: TaskId, mut aggregator_endpoints: Vec, - min_batch_duration: Duration, + time_precision: Duration, http_request_retry_parameters: ExponentialBackoff, ) -> Self { // Ensure provided aggregator endpoints end with a slash, as we will be joining additional @@ -98,14 +94,14 @@ impl ClientParameters { Self { task_id, aggregator_endpoints, - min_batch_duration, + time_precision, http_request_retry_parameters, } } /// The URL relative to which the API endpoints for the aggregator may be /// found, if the role is an aggregator, or an error otherwise. - fn aggregator_endpoint(&self, role: Role) -> Result<&Url, Error> { + fn aggregator_endpoint(&self, role: &Role) -> Result<&Url, Error> { Ok(&self.aggregator_endpoints[role .index() .ok_or(Error::InvalidParameter("role is not an aggregator"))?]) @@ -113,14 +109,14 @@ impl ClientParameters { /// URL from which the HPKE configuration for the server filling `role` may /// be fetched per draft-gpew-priv-ppm §4.3.1 - fn hpke_config_endpoint(&self, role: Role) -> Result { + fn hpke_config_endpoint(&self, role: &Role) -> Result { Ok(self.aggregator_endpoint(role)?.join("hpke_config")?) } /// URL to which reports may be uploaded by clients per draft-gpew-priv-ppm /// §4.3.2 fn upload_endpoint(&self) -> Result { - Ok(self.aggregator_endpoint(Role::Leader)?.join("upload")?) + Ok(self.aggregator_endpoint(&Role::Leader)?.join("upload")?) 
} } @@ -137,12 +133,12 @@ fn fmt_vector_of_urls(urls: &Vec, f: &mut Formatter<'_>) -> fmt::Result { #[tracing::instrument(err)] pub async fn aggregator_hpke_config( client_parameters: &ClientParameters, - aggregator_role: Role, - task_id: TaskId, + aggregator_role: &Role, + task_id: &TaskId, http_client: &reqwest::Client, ) -> Result { let mut request_url = client_parameters.hpke_config_endpoint(aggregator_role)?; - request_url.set_query(Some(&format!("task_id={}", task_id))); + request_url.set_query(Some(&format!("task_id={task_id}"))); let hpke_config_response = retry_http_request( client_parameters.http_request_retry_parameters.clone(), || async { http_client.get(request_url.clone()).send().await }, @@ -215,10 +211,8 @@ where let time = self .clock .now() - .to_batch_unit_interval_start(self.parameters.min_batch_duration) - .map_err(|_| { - Error::InvalidParameter("couldn't round time down to min_batch_duration") - })?; + .to_batch_unit_interval_start(&self.parameters.time_precision) + .map_err(|_| Error::InvalidParameter("couldn't round time down to time_precision"))?; let report_metadata = ReportMetadata::new( random(), time, @@ -226,21 +220,21 @@ where ); let public_share = public_share.get_encoded(); let associated_data = associated_data_for_report_share( - self.parameters.task_id, + &self.parameters.task_id, &report_metadata, &public_share, ); let encrypted_input_shares: Vec = [ - (&self.leader_hpke_config, Role::Leader), - (&self.helper_hpke_config, Role::Helper), + (&self.leader_hpke_config, &Role::Leader), + (&self.helper_hpke_config, &Role::Helper), ] .into_iter() .zip(input_shares) .map(|((hpke_config, receiver_role), input_share)| { Ok(hpke::seal( hpke_config, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, receiver_role), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, receiver_role), &input_share.get_encoded(), &associated_data, )?) 
@@ -290,17 +284,18 @@ where #[cfg(test)] mod tests { - use super::*; + use crate::{default_http_client, Client, ClientParameters, Error}; use assert_matches::assert_matches; - use http::StatusCode; + use http::{header::CONTENT_TYPE, StatusCode}; use janus_core::{ hpke::test_util::generate_test_hpke_config_and_private_key, retries::test_http_request_exponential_backoff, test_util::install_test_trace_subscriber, time::MockClock, }; - use janus_messages::Time; + use janus_messages::{Duration, Report, Time}; use mockito::mock; - use prio::vdaf::prio3::Prio3; + use prio::vdaf::{self, prio3::Prio3}; + use rand::random; use url::Url; fn setup_client(vdaf_client: V) -> Client @@ -427,10 +422,11 @@ mod tests { } #[tokio::test] - async fn upload_bad_min_batch_duration() { + async fn upload_bad_time_precision() { install_test_trace_subscriber(); - let client_parameters = ClientParameters::new(random(), vec![], Duration::from_seconds(0)); + let client_parameters = + ClientParameters::new(random(), Vec::new(), Duration::from_seconds(0)); let client = Client::new( client_parameters, Prio3::new_aes128_count(2).unwrap(), @@ -449,7 +445,7 @@ mod tests { let vdaf = Prio3::new_aes128_count(2).unwrap(); let mut client = setup_client(vdaf); - client.parameters.min_batch_duration = Duration::from_seconds(100); + client.parameters.time_precision = Duration::from_seconds(100); client.clock = MockClock::new(Time::from_seconds_since_epoch(101)); assert_eq!( client.prepare_report(&1).unwrap().metadata().time(), diff --git a/janus_collector/src/lib.rs b/janus_collector/src/lib.rs index fef14743c..f620309ec 100644 --- a/janus_collector/src/lib.rs +++ b/janus_collector/src/lib.rs @@ -429,12 +429,12 @@ where let aggregate_shares_bytes = collect_response .encrypted_aggregate_shares() .iter() - .zip([Role::Leader, Role::Helper]) + .zip(&[Role::Leader, Role::Helper]) .map(|(encrypted_aggregate_share, role)| { hpke::open( &self.parameters.hpke_config, &self.parameters.hpke_private_key, - &HpkeApplicationInfo::new(hpke::Label::AggregateShare, role, Role::Collector), + &HpkeApplicationInfo::new(&hpke::Label::AggregateShare, role, &Role::Collector), encrypted_aggregate_share, &associated_data, ) @@ -556,21 +556,37 @@ pub mod test_util { #[cfg(test)] mod tests { - use super::*; + use crate::{ + default_http_client, CollectJob, Collection, Collector, CollectorParameters, Error, + PollResult, + }; use assert_matches::assert_matches; use chrono::{TimeZone, Utc}; use janus_core::{ - hpke::{test_util::generate_test_hpke_config_and_private_key, Label}, + hpke::{ + self, associated_data_for_aggregate_share, + test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, + }, retries::test_http_request_exponential_backoff, + task::AuthenticationToken, test_util::{install_test_trace_subscriber, run_vdaf, VdafTranscript}, }; - use janus_messages::{Duration, HpkeCiphertext, PartialBatchSelector, Time}; + use janus_messages::{ + query_type::TimeInterval, CollectReq, CollectResp, Duration, HpkeCiphertext, Interval, + PartialBatchSelector, Role, Time, + }; use mockito::mock; use prio::{ + codec::Encode, field::Field64, - vdaf::{prio3::Prio3, AggregateShare}, + vdaf::{self, prio3::Prio3, AggregateShare}, }; use rand::random; + use reqwest::{ + header::{CONTENT_TYPE, LOCATION}, + StatusCode, Url, + }; + use retry_after::RetryAfter; fn setup_collector(vdaf_collector: V) -> Collector where @@ -609,22 +625,30 @@ mod tests { CollectResp::new( PartialBatchSelector::new_time_interval(), 1, - vec![ + Vec::::from([ hpke::seal( 
                &parameters.hpke_config,
-                &HpkeApplicationInfo::new(Label::AggregateShare, Role::Leader, Role::Collector),
+                &HpkeApplicationInfo::new(
+                    &Label::AggregateShare,
+                    &Role::Leader,
+                    &Role::Collector,
+                ),
                 &<Vec<u8>>::from(&transcript.aggregate_shares[0]),
                 &associated_data,
             )
             .unwrap(),
             hpke::seal(
                 &parameters.hpke_config,
-                &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector),
+                &HpkeApplicationInfo::new(
+                    &Label::AggregateShare,
+                    &Role::Helper,
+                    &Role::Collector,
+                ),
                 &<Vec<u8>>::from(&transcript.aggregate_shares[1]),
                 &associated_data,
             )
             .unwrap(),
-        ],
+        ]),
     )
 }
@@ -963,8 +987,16 @@ mod tests {
             PartialBatchSelector::new_time_interval(),
             1,
             Vec::from([
-                HpkeCiphertext::new(*collector.parameters.hpke_config.id(), vec![], vec![]),
-                HpkeCiphertext::new(*collector.parameters.hpke_config.id(), vec![], vec![]),
+                HpkeCiphertext::new(
+                    *collector.parameters.hpke_config.id(),
+                    Vec::new(),
+                    Vec::new(),
+                ),
+                HpkeCiphertext::new(
+                    *collector.parameters.hpke_config.id(),
+                    Vec::new(),
+                    Vec::new(),
+                ),
             ]),
         )
         .get_encoded(),
@@ -987,14 +1019,22 @@ mod tests {
            Vec::from([
                hpke::seal(
                    &collector.parameters.hpke_config,
-                    &HpkeApplicationInfo::new(Label::AggregateShare, Role::Leader, Role::Collector),
+                    &HpkeApplicationInfo::new(
+                        &Label::AggregateShare,
+                        &Role::Leader,
+                        &Role::Collector,
+                    ),
                    b"bad",
                    &associated_data,
                )
                .unwrap(),
                hpke::seal(
                    &collector.parameters.hpke_config,
-                    &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector),
+                    &HpkeApplicationInfo::new(
+                        &Label::AggregateShare,
+                        &Role::Helper,
+                        &Role::Collector,
+                    ),
                    b"bad",
                    &associated_data,
                )
@@ -1022,18 +1062,26 @@ mod tests {
            Vec::from([
                hpke::seal(
                    &collector.parameters.hpke_config,
-                    &HpkeApplicationInfo::new(Label::AggregateShare, Role::Leader, Role::Collector),
-                    &<Vec<u8>>::from(&AggregateShare::from(vec![Field64::from(0)])),
+                    &HpkeApplicationInfo::new(
+                        &Label::AggregateShare,
+                        &Role::Leader,
+                        &Role::Collector,
+                    ),
+                    &<Vec<u8>>::from(&AggregateShare::from(Vec::from([Field64::from(0)]))),
                    &associated_data,
                )
                .unwrap(),
                hpke::seal(
                    &collector.parameters.hpke_config,
-                    &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector),
-                    &<Vec<u8>>::from(&AggregateShare::from(vec![
+                    &HpkeApplicationInfo::new(
+                        &Label::AggregateShare,
+                        &Role::Helper,
+                        &Role::Collector,
+                    ),
+                    &<Vec<u8>>::from(&AggregateShare::from(Vec::from([
                        Field64::from(0),
                        Field64::from(0),
-                    ])),
+                    ]))),
                    &associated_data,
                )
                .unwrap(),
@@ -1108,7 +1156,7 @@ mod tests {
            .create();
        assert_matches!(
            collector.poll_once(&job).await.unwrap(),
-            PollResult::NextAttempt(Some(RetryAfter::Delay(duration))) => assert_eq!(duration, StdDuration::from_secs(60))
+            PollResult::NextAttempt(Some(RetryAfter::Delay(duration))) => assert_eq!(duration, std::time::Duration::from_secs(60))
        );
        mock_collect_poll_retry_after_60s.assert();
@@ -1137,7 +1185,7 @@ mod tests {
        collector
            .parameters
            .collect_poll_wait_parameters
-            .max_elapsed_time = Some(StdDuration::from_secs(3));
+            .max_elapsed_time = Some(std::time::Duration::from_secs(3));
        let collect_job_url = format!("{}/collect_job/1", mockito::server_url());
        let batch_interval = Interval::new(
@@ -1165,7 +1213,7 @@ mod tests {
        mock_collect_poll_retry_after_10s.assert();
        let near_future =
-            Utc::now() + chrono::Duration::from_std(StdDuration::from_secs(1)).unwrap();
+            Utc::now() + chrono::Duration::from_std(std::time::Duration::from_secs(1)).unwrap();
        let near_future_formatted = near_future.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
        let mock_collect_poll_retry_after_near_future = mock("GET",
"/collect_job/1") .with_status(202) @@ -1194,11 +1242,11 @@ mod tests { collector .parameters .collect_poll_wait_parameters - .max_elapsed_time = Some(StdDuration::from_millis(15)); + .max_elapsed_time = Some(std::time::Duration::from_millis(15)); collector .parameters .collect_poll_wait_parameters - .initial_interval = StdDuration::from_millis(10); + .initial_interval = std::time::Duration::from_millis(10); let mock_collect_poll_no_retry_after = mock("GET", "/collect_job/1") .with_status(202) .expect_at_least(1) diff --git a/janus_core/src/hpke.rs b/janus_core/src/hpke.rs index 2c282f990..ead546d1a 100644 --- a/janus_core/src/hpke.rs +++ b/janus_core/src/hpke.rs @@ -37,7 +37,7 @@ fn hpke_dispatch_config_from_hpke_config( /// Construct the HPKE associated data for sealing or opening data enciphered for a report or report /// share, per §4.3.2 and 4.4.1.3 of draft-ietf-ppm-dap-02 pub fn associated_data_for_report_share( - task_id: TaskId, + task_id: &TaskId, report_metadata: &ReportMetadata, public_share: &[u8], ) -> Vec { @@ -83,12 +83,12 @@ pub struct HpkeApplicationInfo(Vec); impl HpkeApplicationInfo { /// Construct HPKE application info from the provided label and participant roles. - pub fn new(label: Label, sender_role: Role, recipient_role: Role) -> Self { + pub fn new(label: &Label, sender_role: &Role, recipient_role: &Role) -> Self { Self( [ label.as_bytes(), - &[sender_role as u8], - &[recipient_role as u8], + &[*sender_role as u8], + &[*recipient_role as u8], ] .concat(), ) @@ -235,7 +235,7 @@ mod tests { fn exchange_message() { let (hpke_config, hpke_private_key) = generate_test_hpke_config_and_private_key(); let application_info = - HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); + HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader); let message = b"a message that is secret"; let associated_data = b"message associated data"; @@ -257,7 +257,7 @@ mod tests { fn wrong_private_key() { let (hpke_config, _) = generate_test_hpke_config_and_private_key(); let application_info = - HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); + HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader); let message = b"a message that is secret"; let associated_data = b"message associated data"; @@ -280,14 +280,14 @@ mod tests { fn wrong_application_info() { let (hpke_config, hpke_private_key) = generate_test_hpke_config_and_private_key(); let application_info = - HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); + HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader); let message = b"a message that is secret"; let associated_data = b"message associated data"; let ciphertext = seal(&hpke_config, &application_info, message, associated_data).unwrap(); let wrong_application_info = - HpkeApplicationInfo::new(Label::AggregateShare, Role::Client, Role::Leader); + HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Client, &Role::Leader); open( &hpke_config, &hpke_private_key, @@ -302,7 +302,7 @@ mod tests { fn wrong_associated_data() { let (hpke_config, hpke_private_key) = generate_test_hpke_config_and_private_key(); let application_info = - HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); + HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader); let message = b"a message that is secret"; let associated_data = b"message associated data"; @@ -339,7 +339,7 @@ mod tests { ); let hpke_private_key = HpkePrivateKey::new(private_key); 
let application_info = - HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); + HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader); let ciphertext = seal(&hpke_config, &application_info, MESSAGE, ASSOCIATED_DATA).unwrap(); let plaintext = open( diff --git a/janus_core/src/time.rs b/janus_core/src/time.rs index b90cd57be..0016d60d9 100644 --- a/janus_core/src/time.rs +++ b/janus_core/src/time.rs @@ -82,7 +82,7 @@ pub trait TimeExt: Sized { /// Compute the start of the batch interval containing this Time, given the batch unit duration. fn to_batch_unit_interval_start( &self, - min_batch_duration: Duration, + time_precision: &Duration, ) -> Result; } @@ -90,11 +90,11 @@ impl TimeExt for Time { /// Compute the start of the batch interval containing this Time, given the batch unit duration. fn to_batch_unit_interval_start( &self, - min_batch_duration: Duration, + time_precision: &Duration, ) -> Result { let rem = self .as_seconds_since_epoch() - .checked_rem(min_batch_duration.as_seconds()) + .checked_rem(time_precision.as_seconds()) .ok_or(janus_messages::Error::IllegalTimeArithmetic( "remainder would overflow/underflow", ))?; diff --git a/janus_messages/src/lib.rs b/janus_messages/src/lib.rs index 0449f8e19..f72d77b35 100644 --- a/janus_messages/src/lib.rs +++ b/janus_messages/src/lib.rs @@ -65,7 +65,7 @@ impl Display for Duration { } /// DAP protocol message representing an instant in time with a resolution of seconds. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Time(u64); impl Time { @@ -836,7 +836,7 @@ impl ReportMetadata { } /// Retrieve the report ID from this report metadata. - pub fn report_id(&self) -> &ReportId { + pub fn id(&self) -> &ReportId { &self.report_id } diff --git a/janus_server/src/aggregator.rs b/janus_server/src/aggregator.rs index 9211113cf..42e8bde18 100644 --- a/janus_server/src/aggregator.rs +++ b/janus_server/src/aggregator.rs @@ -8,7 +8,7 @@ pub mod aggregation_job_driver; use crate::{ aggregator::{ accumulator::Accumulator, - aggregate_share::{compute_aggregate_share, validate_batch_lifetime_for_collect}, + aggregate_share::{compute_aggregate_share, validate_batch_query_count_for_collect}, }, datastore::{ self, @@ -39,9 +39,9 @@ use janus_core::{ use janus_messages::{ query_type::TimeInterval, AggregateContinueReq, AggregateContinueResp, AggregateInitializeReq, AggregateInitializeResp, AggregateShareReq, AggregateShareResp, AggregationJobId, CollectReq, - CollectResp, HpkeConfig, HpkeConfigId, Interval, PartialBatchSelector, PrepareStep, - PrepareStepResult, Report, ReportId, ReportIdChecksum, ReportShare, ReportShareError, Role, - TaskId, Time, + CollectResp, HpkeCiphertext, HpkeConfig, HpkeConfigId, Interval, PartialBatchSelector, + PrepareStep, PrepareStepResult, Report, ReportId, ReportIdChecksum, ReportShare, + ReportShareError, Role, TaskId, Time, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter, Unit}, @@ -328,7 +328,7 @@ impl Aggregator { let report = Report::get_decoded(report_bytes)?; let task_aggregator = self.task_aggregator_for(report.task_id()).await?; - if task_aggregator.task.role != Role::Leader { + if task_aggregator.task.role() != &Role::Leader { return Err(Error::UnrecognizedTask(*report.task_id())); } task_aggregator @@ -350,7 +350,7 @@ impl Aggregator { // This assumes that the task ID is at the start of the message content. 
let task_id = TaskId::decode(&mut Cursor::new(req_bytes))?; let task_aggregator = self.task_aggregator_for(&task_id).await?; - if task_aggregator.task.role != Role::Helper { + if task_aggregator.task.role() != &Role::Helper { return Err(Error::UnrecognizedTask(task_id)); } if !auth_token @@ -382,7 +382,7 @@ impl Aggregator { // This assumes that the task ID is at the start of the message content. let task_id = TaskId::decode(&mut Cursor::new(req_bytes))?; let task_aggregator = self.task_aggregator_for(&task_id).await?; - if task_aggregator.task.role != Role::Helper { + if task_aggregator.task.role() != &Role::Helper { return Err(Error::UnrecognizedTask(task_id)); } if !auth_token @@ -417,7 +417,7 @@ impl Aggregator { // This assumes that the task ID is at the start of the message content. let task_id = TaskId::decode(&mut Cursor::new(req_bytes))?; let task_aggregator = self.task_aggregator_for(&task_id).await?; - if task_aggregator.task.role != Role::Leader { + if task_aggregator.task.role() != &Role::Leader { return Err(Error::UnrecognizedTask(task_id)); } if !auth_token @@ -455,7 +455,7 @@ impl Aggregator { .ok_or(Error::UnrecognizedCollectJob(collect_job_id))?; let task_aggregator = self.task_aggregator_for(&task_id).await?; - if task_aggregator.task.role != Role::Leader { + if task_aggregator.task.role() != &Role::Leader { return Err(Error::UnrecognizedTask(task_id)); } if !auth_token @@ -488,7 +488,7 @@ impl Aggregator { .ok_or(Error::UnrecognizedCollectJob(collect_job_id))?; let task_aggregator = self.task_aggregator_for(&task_id).await?; - if task_aggregator.task.role != Role::Leader { + if task_aggregator.task.role() != &Role::Leader { return Err(Error::UnrecognizedTask(task_id)); } if !auth_token @@ -521,7 +521,7 @@ impl Aggregator { // This assumes that the task ID is at the start of the message content. let task_id = TaskId::decode(&mut Cursor::new(req_bytes))?; let task_aggregator = self.task_aggregator_for(&task_id).await?; - if task_aggregator.task.role != Role::Helper { + if task_aggregator.task.role() != &Role::Helper { return Err(Error::UnrecognizedTask(task_id)); } if !auth_token @@ -588,7 +588,7 @@ impl TaskAggregator { /// Create a new aggregator. `report_recipient` is used to decrypt reports received by this /// aggregator. fn new(task: Task) -> Result { - let vdaf_ops = match &task.vdaf { + let vdaf_ops = match task.vdaf() { VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count) => { let vdaf = Prio3::new_aes128_count(2)?; let verify_key = task.primary_vdaf_verify_key()?; @@ -639,7 +639,7 @@ impl TaskAggregator { ))) } - _ => panic!("VDAF {:?} is not yet supported", task.vdaf), + _ => panic!("VDAF {:?} is not yet supported", task.vdaf()), }; Ok(Self { @@ -653,7 +653,7 @@ impl TaskAggregator { // config/key -- right now it's the one with the maximal config ID, but that will run into // trouble if we ever need to wrap-around, which we may since config IDs are effectively a u8. 
self.task - .hpke_keys + .hpke_keys() .iter() .max_by_key(|(&id, _)| id) .unwrap() @@ -687,7 +687,12 @@ impl TaskAggregator { req: AggregateInitializeReq, ) -> Result { self.vdaf_ops - .handle_aggregate_init(datastore, aggregate_step_failure_counter, &self.task, req) + .handle_aggregate_init( + datastore, + aggregate_step_failure_counter, + Arc::clone(&self.task), + req, + ) .await } @@ -698,7 +703,12 @@ impl TaskAggregator { req: AggregateContinueReq, ) -> Result { self.vdaf_ops - .handle_aggregate_continue(datastore, aggregate_step_failure_counter, &self.task, req) + .handle_aggregate_continue( + datastore, + aggregate_step_failure_counter, + Arc::clone(&self.task), + Arc::new(req), + ) .await } @@ -709,12 +719,12 @@ impl TaskAggregator { ) -> Result { let collect_job_id = self .vdaf_ops - .handle_collect(datastore, &self.task, &req) + .handle_collect(datastore, Arc::clone(&self.task), Arc::new(req)) .await?; Ok(self .task - .aggregator_url(Role::Leader)? + .aggregator_url(&Role::Leader)? .join("collect_jobs/")? .join(&collect_job_id.to_string())?) } @@ -725,7 +735,7 @@ impl TaskAggregator { collect_job_id: Uuid, ) -> Result>, Error> { self.vdaf_ops - .handle_get_collect_job(datastore, &self.task, collect_job_id) + .handle_get_collect_job(datastore, &self.task, Arc::new(collect_job_id)) .await } @@ -750,13 +760,13 @@ impl TaskAggregator { .validate_batch_interval(req.batch_selector().batch_interval()) { return Err(Error::BatchInvalid( - self.task.id, + *self.task.id(), *req.batch_selector().batch_interval(), )); } self.vdaf_ops - .handle_aggregate_share(datastore, Arc::clone(&self.task), req) + .handle_aggregate_share(datastore, Arc::clone(&self.task), Arc::new(req)) .await } } @@ -864,7 +874,7 @@ impl VdafOps { &self, datastore: &Datastore, aggregate_step_failure_counter: &Counter, - task: &Task, + task: Arc, req: AggregateInitializeReq, ) -> Result { match self { @@ -949,8 +959,8 @@ impl VdafOps { &self, datastore: &Datastore, aggregate_step_failure_counter: &Counter, - task: &Task, - req: AggregateContinueReq, + task: Arc, + req: Arc, ) -> Result { match self { VdafOps::Prio3Aes128Count(vdaf, _) => { @@ -1049,19 +1059,19 @@ impl VdafOps { // §4.2.2: verify that the report's HPKE config ID is known let (hpke_config, hpke_private_key) = task - .hpke_keys + .hpke_keys() .get(leader_report.config_id()) .ok_or_else(|| { Error::OutdatedHpkeConfig(*report.task_id(), *leader_report.config_id()) })?; - let report_deadline = clock.now().add(&task.tolerable_clock_skew)?; + let report_deadline = clock.now().add(task.tolerable_clock_skew())?; // §4.2.4: reject reports from too far in the future if report.metadata().time().is_after(&report_deadline) { return Err(Error::ReportTooEarly( *report.task_id(), - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), )); } @@ -1073,10 +1083,10 @@ impl VdafOps { if let Err(error) = hpke::open( hpke_config, hpke_private_key, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, task.role), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, task.role()), leader_report, &associated_data_for_report_share( - *report.task_id(), + report.task_id(), report.metadata(), report.public_share(), ), @@ -1091,7 +1101,7 @@ impl VdafOps { let report = report.clone(); Box::pin(async move { let (existing_client_report, conflicting_collect_jobs) = try_join!( - tx.get_client_report(report.task_id(), report.metadata().report_id()), + tx.get_client_report(report.task_id(), report.metadata().id()), 
tx.get_collect_jobs_including_time::( report.task_id(), report.metadata().time() @@ -1104,7 +1114,7 @@ impl VdafOps { return Err(datastore::Error::User( Error::ReportTooLate( *report.task_id(), - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), ) .into(), @@ -1117,7 +1127,7 @@ impl VdafOps { return Err(datastore::Error::User( Error::ReportTooLate( *report.task_id(), - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), ) .into(), @@ -1139,7 +1149,7 @@ impl VdafOps { datastore: &Datastore, vdaf: &A, aggregate_step_failure_counter: &Counter, - task: &Task, + task: Arc, verify_key: &VerifyKey, req: AggregateInitializeReq, ) -> Result @@ -1154,16 +1164,13 @@ impl VdafOps { A::OutputShare: Send + Sync, for<'a> &'a A::OutputShare: Into>, { - let task_id = task.id; - let min_batch_duration = task.min_batch_duration; - // If two ReportShare messages have the same report ID, then the helper MUST abort with // error "unrecognizedMessage". (§4.4.4.1) let mut seen_report_ids = HashSet::with_capacity(req.report_shares().len()); for share in req.report_shares() { - if !seen_report_ids.insert(share.metadata().report_id()) { + if !seen_report_ids.insert(share.metadata().id()) { return Err(Error::UnrecognizedMessage( - Some(task_id), + Some(*task.id()), "aggregate request contains duplicate report IDs", )); } @@ -1183,7 +1190,7 @@ impl VdafOps { let agg_param = A::AggregationParam::get_decoded(req.aggregation_parameter())?; for report_share in req.report_shares() { let hpke_key = task - .hpke_keys + .hpke_keys() .get(report_share.encrypted_input_share().config_id()) .ok_or_else(|| { info!( @@ -1203,17 +1210,17 @@ impl VdafOps { hpke::open( hpke_config, hpke_private_key, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Helper), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), report_share.encrypted_input_share(), &associated_data_for_report_share( - task_id, + task.id(), report_share.metadata(), report_share.public_share(), ), ) .map_err(|error| { info!( - ?task_id, + task_id = %task.id(), metadata = ?report_share.metadata(), %error, "Couldn't decrypt helper's report share" @@ -1235,14 +1242,14 @@ impl VdafOps { let input_share = plaintext.and_then(|plaintext| { A::InputShare::get_decoded_with_param(&(vdaf, Role::Helper.index().unwrap()), &plaintext) .map_err(|error| { - info!(?task_id, metadata = ?report_share.metadata(), %error, "Couldn't decode helper's input share"); + info!(task_id = %task.id(), metadata = ?report_share.metadata(), %error, "Couldn't decode helper's input share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "input_share_decode_failure")]); ReportShareError::VdafPrepError }) }); let public_share = A::PublicShare::get_decoded_with_param(&vdaf, report_share.public_share()).map_err(|error|{ - info!(?task_id, metadata = ?report_share.metadata(), %error, "Couldn't decode public share"); + info!(task_id = %task.id(), metadata = ?report_share.metadata(), %error, "Couldn't decode public share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "public_share_decode_failure")]); ReportShareError::VdafPrepError }); @@ -1258,12 +1265,12 @@ impl VdafOps { verify_key.as_bytes(), Role::Helper.index().unwrap(), &agg_param, - &report_share.metadata().report_id().get_encoded(), + &report_share.metadata().id().get_encoded(), &public_share, &input_share, ) .map_err(|error| { - info!(?task_id, report_id = 
%report_share.metadata().report_id(), %error, "Couldn't prepare_init report share"); + info!(task_id = %task.id(), report_id = %report_share.metadata().id(), %error, "Couldn't prepare_init report share"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_init_failure")]); ReportShareError::VdafPrepError }) @@ -1290,7 +1297,7 @@ impl VdafOps { // Store data to datastore. let req = Arc::new(req); let aggregation_job = Arc::new(AggregationJob::::new( - task_id, + *task.id(), *req.job_id(), agg_param, if saw_continue { @@ -1302,7 +1309,8 @@ impl VdafOps { let report_share_data = Arc::new(report_share_data); let prep_steps = datastore .run_tx(|tx| { - let (req, aggregation_job, report_share_data) = ( + let (task, req, aggregation_job, report_share_data) = ( + Arc::clone(&task), Arc::clone(&req), Arc::clone(&aggregation_job), Arc::clone(&report_share_data), @@ -1313,9 +1321,8 @@ impl VdafOps { tx.put_aggregation_job(&aggregation_job).await?; let mut accumulator = Accumulator::::new( - task_id, - min_batch_duration, - aggregation_job.aggregation_parameter(), + Arc::clone(&task), + aggregation_job.aggregation_parameter().clone(), ); let mut prep_steps = Vec::new(); @@ -1324,36 +1331,36 @@ impl VdafOps { // isn't for a batch interval that has already started collection. let (report_share_exists, conflicting_aggregate_share_jobs) = try_join!( tx.check_report_share_exists( - &task_id, - share_data.report_share.metadata().report_id() + task.id(), + share_data.report_share.metadata().id() ), tx.get_aggregate_share_jobs_including_time::( - &task_id, + task.id(), share_data.report_share.metadata().time() ), )?; if report_share_exists { prep_steps.push(PrepareStep::new( - *share_data.report_share.metadata().report_id(), + *share_data.report_share.metadata().id(), PrepareStepResult::Failed(ReportShareError::ReportReplayed), )); continue; } if !conflicting_aggregate_share_jobs.is_empty() { prep_steps.push(PrepareStep::new( - *share_data.report_share.metadata().report_id(), + *share_data.report_share.metadata().id(), PrepareStepResult::Failed(ReportShareError::BatchCollected), )); continue; } // Write client report & report aggregation. - tx.put_report_share(&task_id, &share_data.report_share) + tx.put_report_share(task.id(), &share_data.report_share) .await?; tx.put_report_aggregation(&ReportAggregation::::new( - task_id, + *task.id(), *req.job_id(), - *share_data.report_share.metadata().report_id(), + *share_data.report_share.metadata().id(), *share_data.report_share.metadata().time(), ord as i64, share_data.agg_state.clone(), @@ -1366,12 +1373,12 @@ impl VdafOps { accumulator.update( output_share, share_data.report_share.metadata().time(), - share_data.report_share.metadata().report_id(), + share_data.report_share.metadata().id(), )?; } prep_steps.push(PrepareStep::new( - *share_data.report_share.metadata().report_id(), + *share_data.report_share.metadata().id(), share_data.prep_result.clone(), )); } @@ -1390,8 +1397,8 @@ impl VdafOps { datastore: &Datastore, vdaf: Arc, aggregate_step_failure_counter: &Counter, - task: &Task, - req: AggregateContinueReq, + task: Arc, + req: Arc, ) -> Result where A: 'static + Send + Sync, @@ -1405,35 +1412,31 @@ impl VdafOps { A::OutputShare: Send + Sync + for<'a> TryFrom<&'a [u8]>, for<'a> &'a A::OutputShare: Into>, { - let task_id = task.id; - let min_batch_duration = task.min_batch_duration; - let req = Arc::new(req); - // TODO(#224): don't hold DB transaction open while computing VDAF updates? 
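// The hunks above switch these handlers from `&Task` to `Arc<Task>` so the whole
// task can be moved into the `run_tx` closure, instead of copying fields such as
// `task_id` and `min_batch_duration` out of it first. A minimal sketch of that
// capture pattern (illustrative only, not part of the patch; `do_work` is a
// hypothetical stand-in for the per-transaction logic, and generic parameters are
// approximated where the hunks elide them):
async fn sketch_arc_capture<C: Clock>(
    datastore: &Datastore<C>,
    task: Arc<Task>,
    req: Arc<AggregateContinueReq>,
) -> Result<(), Error> {
    datastore
        .run_tx(|tx| {
            // Cheap pointer clones let each invocation of the closure hand an
            // owned handle to its `async move` block.
            let (task, req) = (Arc::clone(&task), Arc::clone(&req));
            Box::pin(async move { do_work(tx, &task, &req).await })
        })
        .await?;
    Ok(())
}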
// TODO(#224): don't do O(n) network round-trips (where n is the number of prepare steps) Ok(datastore .run_tx(|tx| { - let (vdaf, req, aggregate_step_failure_counter) = - (Arc::clone(&vdaf), Arc::clone(&req), aggregate_step_failure_counter.clone()); + let (vdaf, aggregate_step_failure_counter, task, req) = + (Arc::clone(&vdaf), aggregate_step_failure_counter.clone(), Arc::clone(&task), Arc::clone(&req)); Box::pin(async move { // Read existing state. let (aggregation_job, report_aggregations) = try_join!( - tx.get_aggregation_job::(&task_id, req.job_id()), + tx.get_aggregation_job::(task.id(), req.job_id()), tx.get_report_aggregations_for_aggregation_job( vdaf.as_ref(), &Role::Helper, - &task_id, + task.id(), req.job_id(), ), )?; - let aggregation_job = aggregation_job.ok_or_else(|| datastore::Error::User(Error::UnrecognizedAggregationJob(task_id, *req.job_id()).into()))?; + let aggregation_job = aggregation_job.ok_or_else(|| datastore::Error::User(Error::UnrecognizedAggregationJob(*task.id(), *req.job_id()).into()))?; // Handle each transition in the request. let mut report_aggregations = report_aggregations.into_iter(); let (mut saw_continue, mut saw_finish) = (false, false); let mut response_prep_steps = Vec::new(); - let mut accumulator = Accumulator::::new(task_id, min_batch_duration, aggregation_job.aggregation_parameter()); + let mut accumulator = Accumulator::::new(Arc::clone(&task), aggregation_job.aggregation_parameter().clone()); for prep_step in req.prepare_steps().iter() { // Match preparation step received from leader to stored report aggregation, @@ -1441,7 +1444,7 @@ impl VdafOps { let report_aggregation = loop { let report_agg = report_aggregations.next().ok_or_else(|| { datastore::Error::User(Error::UnrecognizedMessage( - Some(task_id), + Some(*task.id()), "leader sent unexpected, duplicate, or out-of-order prepare steps", ).into()) })?; @@ -1459,7 +1462,7 @@ impl VdafOps { // Make sure this report isn't in an interval that has already started // collection. 
- let conflicting_aggregate_share_jobs = tx.get_aggregate_share_jobs_including_time::(&task_id, report_aggregation.time()).await?; + let conflicting_aggregate_share_jobs = tx.get_aggregate_share_jobs_including_time::(task.id(), report_aggregation.time()).await?; if !conflicting_aggregate_share_jobs.is_empty() { response_prep_steps.push(PrepareStep::new( *prep_step.report_id(), @@ -1475,7 +1478,7 @@ impl VdafOps { _ => { return Err(datastore::Error::User( Error::UnrecognizedMessage( - Some(task_id), + Some(*task.id()), "leader sent prepare step for non-WAITING report aggregation", ).into() )); @@ -1493,7 +1496,7 @@ impl VdafOps { _ => { return Err(datastore::Error::User( Error::UnrecognizedMessage( - Some(task_id), + Some(*task.id()), "leader sent non-Continued prepare step", ).into() )); @@ -1522,7 +1525,7 @@ impl VdafOps { } Err(error) => { - info!(?task_id, job_id = %req.job_id(), report_id = %prep_step.report_id(), %error, "Prepare step failed"); + info!(task_id = %task.id(), job_id = %req.job_id(), report_id = %prep_step.report_id(), %error, "Prepare step failed"); aggregate_step_failure_counter.add(&Context::current(), 1, &[KeyValue::new("type", "prepare_step_failure")]); response_prep_steps.push(PrepareStep::new( *prep_step.report_id(), @@ -1569,8 +1572,8 @@ impl VdafOps { async fn handle_collect( &self, datastore: &Datastore, - task: &Task, - collect_req: &CollectReq, + task: Arc, + collect_req: Arc>, ) -> Result { match self { VdafOps::Prio3Aes128Count(_, _) => { @@ -1622,8 +1625,8 @@ impl VdafOps { #[tracing::instrument(skip(datastore), err)] async fn handle_collect_generic, C: Clock>( datastore: &Datastore, - task: &Task, - req: &CollectReq, + task: Arc, + req: Arc>, ) -> Result where A::AggregationParam: Send + Sync, @@ -1634,7 +1637,10 @@ impl VdafOps { // Check that the batch interval is valid for the task // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.5.6.1.1 if !task.validate_batch_interval(req.query().batch_interval()) { - return Err(Error::BatchInvalid(task.id, *req.query().batch_interval())); + return Err(Error::BatchInvalid( + *task.id(), + *req.query().batch_interval(), + )); } Ok(datastore @@ -1647,7 +1653,7 @@ impl VdafOps { if let Some(collect_job_id) = tx .get_collect_job_id::( - &task.id, + task.id(), req.query().batch_interval(), &aggregation_param, ) @@ -1658,7 +1664,7 @@ impl VdafOps { } debug!(collect_request = ?req, "Cache miss, creating new collect job UUID"); - validate_batch_lifetime_for_collect::( + validate_batch_query_count_for_collect::( tx, &task, *req.query().batch_interval(), @@ -1686,7 +1692,7 @@ impl VdafOps { &self, datastore: &Datastore, task: &Task, - collect_job_id: Uuid, + collect_job_id: Arc, ) -> Result>, Error> { match self { VdafOps::Prio3Aes128Count(_, _) => { @@ -1738,7 +1744,7 @@ impl VdafOps { async fn handle_get_collect_job_generic, C: Clock>( datastore: &Datastore, task: &Task, - collect_job_id: Uuid, + collect_job_id: Arc, ) -> Result>, Error> where A::AggregationParam: Send + Sync, @@ -1746,15 +1752,15 @@ impl VdafOps { Vec: for<'a> From<&'a A::AggregateShare>, for<'a> >::Error: std::fmt::Display, { - let task_id = task.id; let collect_job = datastore - .run_tx(move |tx| { + .run_tx(|tx| { + let collect_job_id = Arc::clone(&collect_job_id); Box::pin(async move { tx.get_collect_job::(&collect_job_id) .await? 
.ok_or_else(|| { datastore::Error::User( - Error::UnrecognizedCollectJob(collect_job_id).into(), + Error::UnrecognizedCollectJob(*collect_job_id).into(), ) }) }) @@ -1763,7 +1769,7 @@ impl VdafOps { match collect_job.state() { CollectJobState::Start => { - debug!(?collect_job_id, ?task_id, "Collect job has not run yet"); + debug!(%collect_job_id, task_id = %task.id(), "Collect job has not run yet"); Ok(None) } @@ -1783,8 +1789,8 @@ impl VdafOps { // been long enough since the encrypted helper share was cached -- tricky thing is // deciding what "long enough" is. debug!( - ?collect_job_id, - ?task_id, + %collect_job_id, + task_id = %task.id(), "Serving cached collect job response" ); let associated_data = associated_data_for_aggregate_share::( @@ -1792,8 +1798,12 @@ impl VdafOps { collect_job.batch_interval(), ); let encrypted_leader_aggregate_share = hpke::seal( - &task.collector_hpke_config, - &HpkeApplicationInfo::new(Label::AggregateShare, Role::Leader, Role::Collector), + task.collector_hpke_config(), + &HpkeApplicationInfo::new( + &Label::AggregateShare, + &Role::Leader, + &Role::Collector, + ), &>::from(leader_aggregate_share), &associated_data, )?; @@ -1801,24 +1811,24 @@ impl VdafOps { Ok(Some(CollectResp::new( PartialBatchSelector::new_time_interval(), *report_count, - vec![ + Vec::::from([ encrypted_leader_aggregate_share, encrypted_helper_aggregate_share.clone(), - ], + ]), ))) } CollectJobState::Abandoned => { // TODO(#248): decide how to respond for abandoned collect jobs. warn!( - ?collect_job_id, - ?task_id, + %collect_job_id, + task_id = %task.id(), "Attempting to collect abandoned collect job" ); Ok(None) } - CollectJobState::Deleted => Err(Error::DeletedCollectJob(collect_job_id)), + CollectJobState::Deleted => Err(Error::DeletedCollectJob(*collect_job_id)), } } @@ -1914,7 +1924,7 @@ impl VdafOps { &self, datastore: &Datastore, task: Arc, - aggregate_share_req: AggregateShareReq, + aggregate_share_req: Arc>, ) -> Result { match self { VdafOps::Prio3Aes128Count(_, _) => { @@ -1966,7 +1976,7 @@ impl VdafOps { async fn handle_aggregate_share_generic, C: Clock>( datastore: &Datastore, task: Arc, - aggregate_share_req: AggregateShareReq, + aggregate_share_req: Arc>, ) -> Result where A::AggregationParam: Send + Sync, @@ -1974,7 +1984,6 @@ impl VdafOps { Vec: for<'a> From<&'a A::AggregateShare>, for<'a> >::Error: std::fmt::Display, { - let aggregate_share_req = Arc::new(aggregate_share_req); let aggregate_share_job = datastore .run_tx(|tx| { let (task, aggregate_share_req) = @@ -2007,11 +2016,11 @@ impl VdafOps { )?; let (batch_unit_aggregations, _) = try_join!( tx.get_batch_unit_aggregations_for_task_in_interval::( - &task.id, + task.id(), aggregate_share_req.batch_selector().batch_interval(), &aggregation_param, ), - validate_batch_lifetime_for_collect::( + validate_batch_query_count_for_collect::( tx, &task, *aggregate_share_req.batch_selector().batch_interval(), @@ -2023,10 +2032,11 @@ impl VdafOps { .await .map_err(|e| datastore::Error::User(e.into()))?; - // Now that we are satisfied that the request is serviceable, we consume batch lifetime by - // recording the aggregate share request parameters and the result. + // Now that we are satisfied that the request is serviceable, we consume + // a query by recording the aggregate share request parameters and the + // result. 
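// Recap of the collector-share encryption used above for the leader's share
// (sketch only, not part of the patch): direct field access becomes an accessor
// call (`task.collector_hpke_config()`), and `HpkeApplicationInfo::new` now
// borrows its label and roles instead of taking them by value. Here
// `leader_share_bytes` and `associated_data` are hypothetical placeholders for
// the encoded aggregate share and the AAD from `associated_data_for_aggregate_share`.
let encrypted_leader_aggregate_share = hpke::seal(
    task.collector_hpke_config(),
    &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector),
    &leader_share_bytes,
    &associated_data,
)?;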
let aggregate_share_job = AggregateShareJob::::new( - task.id, + *task.id(), *aggregate_share_req.batch_selector().batch_interval(), aggregation_param, helper_aggregate_share, @@ -2065,8 +2075,8 @@ impl VdafOps { // config valid when the current AggregateShareReq was made, and not whatever was valid at // the time the aggregate share was first computed. let encrypted_aggregate_share = hpke::seal( - &task.collector_hpke_config, - &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector), + task.collector_hpke_config(), + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), &>::from(aggregate_share_job.helper_aggregate_share()), &associated_data_for_aggregate_share::( aggregate_share_req.task_id(), @@ -2766,44 +2776,73 @@ async fn post_to_helper( #[cfg(test)] mod tests { - use super::*; use crate::{ + aggregator::{ + aggregator_filter, error_handler, post_to_helper, Aggregator, DapProblemType, + DapProblemTypeParseError, Error, + }, datastore::{ - models::BatchUnitAggregation, + models::{ + AggregateShareJob, AggregationJob, AggregationJobState, BatchUnitAggregation, + CollectJob, CollectJobState, ReportAggregation, ReportAggregationState, + }, test_util::{ephemeral_datastore, DbHandle}, + Datastore, }, messages::{DurationExt, TimeExt}, - task::{test_util::generate_auth_token, VdafInstance}, + task::{ + test_util::{generate_auth_token, TaskBuilder}, + QueryType, Task, VerifyKey, PRIO3_AES128_VERIFY_KEY_LENGTH, + }, }; use assert_matches::assert_matches; - use http::Method; + use http::{ + header::{CACHE_CONTROL, CONTENT_TYPE, LOCATION}, + Method, StatusCode, + }; use hyper::body; use janus_core::{ hpke::associated_data_for_report_share, hpke::{ - associated_data_for_aggregate_share, - test_util::generate_test_hpke_config_and_private_key, HpkePrivateKey, Label, + self, associated_data_for_aggregate_share, + test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, + HpkePrivateKey, Label, }, report_id::ReportIdChecksumExt, + task::{AuthenticationToken, VdafInstance}, test_util::{dummy_vdaf, install_test_trace_subscriber, run_vdaf}, - time::{MockClock, RealClock, TimeExt as CoreTimeExt}, + time::{Clock, MockClock, RealClock, TimeExt as _}, }; use janus_messages::{ - BatchSelector, Duration, HpkeCiphertext, HpkeConfig, Query, ReportId, ReportMetadata, - TaskId, Time, + query_type::TimeInterval, AggregateContinueReq, AggregateContinueResp, + AggregateInitializeReq, AggregateInitializeResp, AggregateShareReq, AggregateShareResp, + BatchSelector, CollectReq, CollectResp, Duration, HpkeCiphertext, HpkeConfig, HpkeConfigId, + Interval, PartialBatchSelector, PrepareStep, PrepareStepResult, Query, Report, ReportId, + ReportIdChecksum, ReportMetadata, ReportShare, ReportShareError, Role, TaskId, Time, }; use mockito::mock; use opentelemetry::global::meter; use prio::{ - codec::Decode, + codec::{Decode, Encode}, field::Field64, - vdaf::{prio3::Prio3Aes128Count, AggregateShare, Aggregator as _}, + vdaf::{ + self, + prio3::{Prio3, Prio3Aes128Count}, + AggregateShare, Aggregator as _, PrepareTransition, + }, }; use rand::random; + use reqwest::Client; use serde_json::json; - use std::{collections::HashMap, io::Cursor}; + use std::{collections::HashMap, io::Cursor, sync::Arc}; + use url::Url; use uuid::Uuid; - use warp::{cors::CorsForbidden, reply::Reply, Rejection}; + use warp::{ + cors::CorsForbidden, + filters::BoxedFilter, + reply::{Reply, Response}, + Filter, Rejection, + }; const DUMMY_VERIFY_KEY_LENGTH: usize = 
dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; @@ -2811,19 +2850,19 @@ mod tests { async fn hpke_config() { install_test_trace_subscriber(); - let task_id = random(); - let unknown_task_id: TaskId = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); + ) + .build(); + let unknown_task_id: TaskId = random(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; datastore.put_task(&task).await.unwrap(); - let want_hpke_key = current_hpke_key(&task.hpke_keys).clone(); + let want_hpke_key = current_hpke_key(task.hpke_keys()).clone(); let filter = aggregator_filter(Arc::new(datastore), clock).unwrap(); @@ -2879,7 +2918,7 @@ mod tests { // Recognized task ID provided let response = warp::test::request() - .path(&format!("/hpke_config?task_id={task_id}")) + .path(&format!("/hpke_config?task_id={}", task.id())) .method("GET") .filter(&filter) .await @@ -2901,7 +2940,7 @@ mod tests { assert_eq!(hpke_config, want_hpke_key.0); let application_info = - HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); + HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader); let message = b"this is a message"; let associated_data = b"some associated data"; @@ -2922,12 +2961,12 @@ mod tests { async fn hpke_config_cors_headers() { install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -2938,7 +2977,7 @@ mod tests { // Check for appropriate CORS headers in response to a preflight request. let response = warp::test::request() .method("OPTIONS") - .path(&format!("/hpke_config?task_id={}", task_id)) + .path(&format!("/hpke_config?task_id={}", task.id())) .header("origin", "https://example.com/") .header("access-control-request-method", "GET") .filter(&filter) @@ -2957,7 +2996,7 @@ mod tests { // Check for appropriate CORS headers with a simple GET request. 
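// The tests in this module now build tasks through `TaskBuilder` (from
// `crate::task::test_util`, per the imports above) instead of `Task::new_dummy`,
// and read the builder-chosen task ID back through `task.id()` rather than
// threading a locally generated `task_id` around. A minimal sketch of that
// setup (sketch only, not part of the patch):
let task = TaskBuilder::new(
    QueryType::TimeInterval,
    VdafInstance::Prio3Aes128Count.into(),
    Role::Leader,
)
.build();
datastore.put_task(&task).await.unwrap();
let path = format!("/hpke_config?task_id={}", task.id());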
let response = warp::test::request() .method("GET") - .path(&format!("/hpke_config?task_id={}", task_id)) + .path(&format!("/hpke_config?task_id={}", task.id())) .header("origin", "https://example.com/") .filter(&filter) .await @@ -2980,37 +3019,37 @@ mod tests { ) -> Report { datastore.put_task(task).await.unwrap(); - let hpke_key = current_hpke_key(&task.hpke_keys); + let hpke_key = current_hpke_key(task.hpke_keys()); let report_metadata = ReportMetadata::new( random(), - clock.now().sub(&task.tolerable_clock_skew).unwrap(), - vec![], + clock.now().sub(task.tolerable_clock_skew()).unwrap(), + Vec::new(), ); let public_share = b"public share".to_vec(); let message = b"this is a message"; let associated_data = - associated_data_for_report_share(task.id, &report_metadata, &public_share); + associated_data_for_report_share(task.id(), &report_metadata, &public_share); let leader_ciphertext = hpke::seal( &hpke_key.0, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader), message, &associated_data, ) .unwrap(); let helper_ciphertext = hpke::seal( &hpke_key.0, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader), message, &associated_data, ) .unwrap(); Report::new( - task.id, + *task.id(), report_metadata, public_share, - vec![leader_ciphertext, helper_ciphertext], + Vec::from([leader_ciphertext, helper_ciphertext]), ) } @@ -3035,11 +3074,12 @@ mod tests { async fn upload_filter() { install_test_trace_subscriber(); - let task = Task::new_dummy( - random(), - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -3148,7 +3188,7 @@ mod tests { let bad_report = Report::new( *report.task_id(), ReportMetadata::new( - *report.metadata().report_id(), + *report.metadata().id(), bad_report_time, report.metadata().extensions().to_vec(), ), @@ -3211,7 +3251,7 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ), @@ -3239,12 +3279,12 @@ mod tests { async fn upload_filter_helper() { install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Helper, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -3274,7 +3314,7 @@ mod tests { "title": "An endpoint received a message with an unknown task ID.", "detail": "An endpoint received a message with an unknown task ID.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!( @@ -3296,17 +3336,18 @@ mod tests { Arc>, DbHandle, ) { - let task = Task::new_dummy( - random(), - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, db_handle) = ephemeral_datastore(clock.clone()).await; let 
datastore = Arc::new(datastore); let report = setup_report(&task, &datastore, &clock).await; - let aggregator = Aggregator::new(datastore.clone(), clock, meter("janus_server")); + let aggregator = Aggregator::new(Arc::clone(&datastore), clock, meter("janus_server")); (aggregator, task, report, datastore, db_handle) } @@ -3324,7 +3365,7 @@ mod tests { let got_report = datastore .run_tx(|tx| { - let (task_id, report_id) = (*report.task_id(), *report.metadata().report_id()); + let (task_id, report_id) = (*report.task_id(), *report.metadata().id()); Box::pin(async move { tx.get_client_report(&task_id, &report_id).await }) }) .await @@ -3335,7 +3376,7 @@ mod tests { // TODO(#34): change this error type. assert_matches!(aggregator.handle_upload(&report.get_encoded()).await, Err(Error::ReportTooLate(task_id, stale_report_id, stale_time)) => { assert_eq!(&task_id, report.task_id()); - assert_eq!(report.metadata().report_id(), &stale_report_id); + assert_eq!(report.metadata().id(), &stale_report_id); assert_eq!(report.metadata().time(), &stale_time); }); } @@ -3367,7 +3408,7 @@ mod tests { let unused_hpke_config_id = (0..) .map(HpkeConfigId::from) - .find(|id| !task.hpke_keys.contains_key(id)) + .find(|id| !task.hpke_keys().contains_key(id)) .unwrap(); let report = Report::new( @@ -3395,14 +3436,14 @@ mod tests { fn reencrypt_report(report: Report, hpke_config: &HpkeConfig) -> Report { let message = b"this is a message"; let associated_data = associated_data_for_report_share( - *report.task_id(), + report.task_id(), report.metadata(), report.public_share(), ); let leader_ciphertext = hpke::seal( hpke_config, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader), message, &associated_data, ) @@ -3410,7 +3451,7 @@ mod tests { let helper_ciphertext = hpke::seal( hpke_config, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Helper), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), message, &associated_data, ) @@ -3435,18 +3476,18 @@ mod tests { Report::new( *report.task_id(), ReportMetadata::new( - *report.metadata().report_id(), + *report.metadata().id(), aggregator .clock .now() - .add(&task.tolerable_clock_skew) + .add(task.tolerable_clock_skew()) .unwrap(), report.metadata().extensions().to_vec(), ), report.public_share().to_vec(), report.encrypted_input_shares().to_vec(), ), - &task.hpke_keys.values().next().unwrap().0, + &task.hpke_keys().values().next().unwrap().0, ); aggregator .handle_upload(&report.get_encoded()) @@ -3455,7 +3496,7 @@ mod tests { let got_report = datastore .run_tx(|tx| { - let (task_id, report_id) = (*report.task_id(), *report.metadata().report_id()); + let (task_id, report_id) = (*report.task_id(), *report.metadata().id()); Box::pin(async move { tx.get_client_report(&task_id, &report_id).await }) }) .await @@ -3467,11 +3508,11 @@ mod tests { Report::new( *report.task_id(), ReportMetadata::new( - *report.metadata().report_id(), + *report.metadata().id(), aggregator .clock .now() - .add(&task.tolerable_clock_skew) + .add(task.tolerable_clock_skew()) .unwrap() .add(&Duration::from_seconds(1)) .unwrap(), @@ -3480,11 +3521,11 @@ mod tests { report.public_share().to_vec(), report.encrypted_input_shares().to_vec(), ), - &task.hpke_keys.values().next().unwrap().0, + &task.hpke_keys().values().next().unwrap().0, ); assert_matches!(aggregator.handle_upload(&report.get_encoded()).await, Err(Error::ReportTooEarly(task_id, report_id, time)) 
=> { assert_eq!(&task_id, report.task_id()); - assert_eq!(report.metadata().report_id(), &report_id); + assert_eq!(report.metadata().id(), &report_id); assert_eq!(report.metadata().time(), &time); }); } @@ -3494,23 +3535,23 @@ mod tests { install_test_trace_subscriber(); let (aggregator, task, report, datastore, _db_handle) = setup_upload_test().await; - let task_id = task.id; // Insert a collect job for the batch interval including our report. let batch_interval = Interval::new( report .metadata() .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), - task.min_batch_duration, + *task.time_precision(), ) .unwrap(); datastore .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), batch_interval, (), @@ -3524,8 +3565,8 @@ mod tests { // Try to upload the report, verify that we get the expected error. assert_matches!(aggregator.handle_upload(&report.get_encoded()).await.unwrap_err(), Error::ReportTooLate(err_task_id, err_report_id, err_time) => { - assert_eq!(task_id, err_task_id); - assert_eq!(report.metadata().report_id(), &err_report_id); + assert_eq!(report.task_id(), &err_task_id); + assert_eq!(report.metadata().id(), &err_report_id); assert_eq!(report.metadata().time(), &err_time); }); } @@ -3534,19 +3575,19 @@ mod tests { async fn aggregate_leader() { install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; datastore.put_task(&task).await.unwrap(); let request = AggregateInitializeReq::new( - task_id, + *task.id(), random(), Vec::new(), PartialBatchSelector::new_time_interval(), @@ -3584,7 +3625,7 @@ mod tests { "title": "An endpoint received a message with an unknown task ID.", "detail": "An endpoint received a message with an unknown task ID.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!( @@ -3618,19 +3659,19 @@ mod tests { async fn aggregate_wrong_agg_auth_token() { install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Helper, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; datastore.put_task(&task).await.unwrap(); let request = AggregateInitializeReq::new( - task_id, + *task.id(), random(), Vec::new(), PartialBatchSelector::new_time_interval(), @@ -3665,7 +3706,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, parts.status.as_u16()); @@ -3695,7 +3736,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, parts.status.as_u16()); @@ -3706,26 +3747,26 @@ 
mod tests { // Prepare datastore & request. install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Helper, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; let vdaf = Prio3::new_aes128_count(2).unwrap(); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); - let hpke_key = current_hpke_key(&task.hpke_keys); + let hpke_key = current_hpke_key(task.hpke_keys()); // report_share_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -3733,12 +3774,12 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_0.report_id(), + report_metadata_0.id(), &0, ); let input_share = transcript.input_shares[1].clone(); let report_share_0 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_0, &hpke_key.0, &transcript.public_share, @@ -3750,7 +3791,7 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -3775,14 +3816,14 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); let mut input_share_bytes = input_share.get_encoded(); input_share_bytes.push(0); // can no longer be decoded. let aad = - associated_data_for_report_share(task_id, &report_metadata_2, &encoded_public_share); + associated_data_for_report_share(task.id(), &report_metadata_2, &encoded_public_share); let report_share_2 = generate_helper_report_share_for_plaintext( report_metadata_2, &hpke_key.0, @@ -3796,19 +3837,19 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); let wrong_hpke_config = loop { let hpke_config = generate_test_hpke_config_and_private_key().0; - if task.hpke_keys.contains_key(hpke_config.id()) { + if task.hpke_keys().contains_key(hpke_config.id()) { continue; } break hpke_config; }; let report_share_3 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_3, &wrong_hpke_config, &transcript.public_share, @@ -3820,7 +3861,7 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -3828,12 +3869,12 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_4.report_id(), + report_metadata_4.id(), &0, ); let input_share = transcript.input_shares[1].clone(); let report_share_4 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_4, &hpke_key.0, &transcript.public_share, @@ -3842,13 +3883,13 @@ mod tests { // report_share_5 falls into a batch unit that has already been collected. 
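// Why report_share_5 conflicts: a timestamp of half a `time_precision` rounds
// down to the batch unit starting at time 0, which is exactly the interval the
// aggregate share job below is recorded over. Sketch of the rounding (not part
// of the patch; the concrete precision value comes from the task):
let rounded = Time::from_seconds_since_epoch(task.time_precision().as_seconds() / 2)
    .to_batch_unit_interval_start(task.time_precision())
    .unwrap();
assert_eq!(rounded, Time::from_seconds_since_epoch(0));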
let past_clock = MockClock::new(Time::from_seconds_since_epoch( - task.min_batch_duration.as_seconds() / 2, + task.time_precision().as_seconds() / 2, )); let report_metadata_5 = ReportMetadata::new( random(), past_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -3856,12 +3897,12 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_5.report_id(), + report_metadata_5.id(), &0, ); let input_share = transcript.input_shares[1].clone(); let report_share_5 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_5, &hpke_key.0, &transcript.public_share, @@ -3874,11 +3915,11 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); - let aad = associated_data_for_report_share(task_id, &report_metadata_6, &public_share_6); + let aad = associated_data_for_report_share(task.id(), &report_metadata_6, &public_share_6); let report_share_6 = generate_helper_report_share_for_plaintext( report_metadata_6, &hpke_key.0, @@ -3892,13 +3933,13 @@ mod tests { let (task, report_share_4) = (task.clone(), report_share_4.clone()); Box::pin(async move { tx.put_task(&task).await?; - tx.put_report_share(&task_id, &report_share_4).await?; + tx.put_report_share(task.id(), &report_share_4).await?; tx.put_aggregate_share_job::( &AggregateShareJob::new( - task_id, + *task.id(), Interval::new( Time::from_seconds_since_epoch(0), - task.min_batch_duration, + *task.time_precision(), ) .unwrap(), (), @@ -3914,7 +3955,7 @@ mod tests { .unwrap(); let request = AggregateInitializeReq::new( - task_id, + *task.id(), random(), Vec::new(), PartialBatchSelector::new_time_interval(), @@ -3960,67 +4001,46 @@ mod tests { assert_eq!(aggregate_resp.prepare_steps().len(), 7); let prepare_step_0 = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!( - prepare_step_0.report_id(), - report_share_0.metadata().report_id() - ); + assert_eq!(prepare_step_0.report_id(), report_share_0.metadata().id()); assert_matches!(prepare_step_0.result(), &PrepareStepResult::Continued(..)); let prepare_step_1 = aggregate_resp.prepare_steps().get(1).unwrap(); - assert_eq!( - prepare_step_1.report_id(), - report_share_1.metadata().report_id() - ); + assert_eq!(prepare_step_1.report_id(), report_share_1.metadata().id()); assert_matches!( prepare_step_1.result(), &PrepareStepResult::Failed(ReportShareError::HpkeDecryptError) ); let prepare_step_2 = aggregate_resp.prepare_steps().get(2).unwrap(); - assert_eq!( - prepare_step_2.report_id(), - report_share_2.metadata().report_id() - ); + assert_eq!(prepare_step_2.report_id(), report_share_2.metadata().id()); assert_matches!( prepare_step_2.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) ); let prepare_step_6 = aggregate_resp.prepare_steps().get(6).unwrap(); - assert_eq!( - prepare_step_6.report_id(), - report_share_6.metadata().report_id() - ); + assert_eq!(prepare_step_6.report_id(), report_share_6.metadata().id()); assert_matches!( prepare_step_6.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) ); let prepare_step_3 = aggregate_resp.prepare_steps().get(3).unwrap(); - assert_eq!( - prepare_step_3.report_id(), - report_share_3.metadata().report_id() - ); + assert_eq!(prepare_step_3.report_id(), report_share_3.metadata().id()); assert_matches!( prepare_step_3.result(), 
&PrepareStepResult::Failed(ReportShareError::HpkeUnknownConfigId) ); let prepare_step_4 = aggregate_resp.prepare_steps().get(4).unwrap(); - assert_eq!( - prepare_step_4.report_id(), - report_share_4.metadata().report_id() - ); + assert_eq!(prepare_step_4.report_id(), report_share_4.metadata().id()); assert_eq!( prepare_step_4.result(), &PrepareStepResult::Failed(ReportShareError::ReportReplayed) ); let prepare_step_5 = aggregate_resp.prepare_steps().get(5).unwrap(); - assert_eq!( - prepare_step_5.report_id(), - report_share_5.metadata().report_id() - ); + assert_eq!(prepare_step_5.report_id(), report_share_5.metadata().id()); assert_eq!( prepare_step_5.result(), &PrepareStepResult::Failed(ReportShareError::BatchCollected) @@ -4032,21 +4052,25 @@ mod tests { // Prepare datastore & request. install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy(task_id, VdafInstance::FakeFailsPrepInit, Role::Helper); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::FakeFailsPrepInit, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; - let hpke_key = current_hpke_key(&task.hpke_keys); + let hpke_key = current_hpke_key(task.hpke_keys()); datastore.put_task(&task).await.unwrap(); let report_share = generate_helper_report_share::( - task_id, + task.id(), &ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ), @@ -4055,7 +4079,7 @@ mod tests { &(), ); let request = AggregateInitializeReq::new( - task_id, + *task.id(), random(), dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), @@ -4093,10 +4117,7 @@ mod tests { assert_eq!(aggregate_resp.prepare_steps().len(), 1); let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!( - prepare_step.report_id(), - report_share.metadata().report_id() - ); + assert_eq!(prepare_step.report_id(), report_share.metadata().id()); assert_matches!( prepare_step.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) @@ -4108,21 +4129,25 @@ mod tests { // Prepare datastore & request. 
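// Sketch of the helper-side init request these fake-VDAF tests send (same
// constructors as used elsewhere in this file; sketch only, not part of the
// patch, with `report_share` standing for whatever `generate_helper_report_share`
// produced):
let request = AggregateInitializeReq::new(
    *task.id(),
    random(),
    dummy_vdaf::AggregationParam(0).get_encoded(),
    PartialBatchSelector::new_time_interval(),
    Vec::from([report_share.clone()]),
);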
install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy(task_id, VdafInstance::FakeFailsPrepInit, Role::Helper); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::FakeFailsPrepInit, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; - let hpke_key = current_hpke_key(&task.hpke_keys); + let hpke_key = current_hpke_key(task.hpke_keys()); datastore.put_task(&task).await.unwrap(); let report_share = generate_helper_report_share::( - task_id, + task.id(), &ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ), @@ -4131,7 +4156,7 @@ mod tests { &(), ); let request = AggregateInitializeReq::new( - task_id, + *task.id(), random(), dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), @@ -4169,10 +4194,7 @@ mod tests { assert_eq!(aggregate_resp.prepare_steps().len(), 1); let prepare_step = aggregate_resp.prepare_steps().get(0).unwrap(); - assert_eq!( - prepare_step.report_id(), - report_share.metadata().report_id() - ); + assert_eq!(prepare_step.report_id(), report_share.metadata().id()); assert_matches!( prepare_step.result(), &PrepareStepResult::Failed(ReportShareError::VdafPrepError) @@ -4183,8 +4205,12 @@ mod tests { async fn aggregate_init_duplicated_report_id() { install_test_trace_subscriber(); - let task_id = random(); - let task = Task::new_dummy(task_id, VdafInstance::FakeFailsPrepInit, Role::Helper); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::FakeFailsPrepInit, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -4206,7 +4232,7 @@ mod tests { ); let request = AggregateInitializeReq::new( - task_id, + *task.id(), random(), dummy_vdaf::AggregationParam(0).get_encoded(), PartialBatchSelector::new_time_interval(), @@ -4244,7 +4270,7 @@ mod tests { "title": "The message type for a response was incorrect or the payload was malformed.", "detail": "The message type for a response was incorrect or the payload was malformed.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, parts.status.as_u16()); @@ -4255,13 +4281,13 @@ mod tests { // Prepare datastore & request. install_test_trace_subscriber(); - let task_id = random(); let aggregation_job_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Helper, - ); + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; let datastore = Arc::new(datastore); @@ -4269,14 +4295,14 @@ mod tests { let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); - let hpke_key = current_hpke_key(&task.hpke_keys); + let hpke_key = current_hpke_key(task.hpke_keys()); // report_share_0 is a "happy path" report. 
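// These helper tests derive the helper's input share and preparation state from
// a VDAF transcript. Sketch of that pattern as used below (index 1 selects the
// helper's entries, matching `Role::Helper.index()`; sketch only, not part of
// the patch):
let transcript = run_vdaf(
    vdaf.as_ref(),
    verify_key.as_bytes(),
    &(),
    report_metadata_0.id(),
    &0,
);
let input_share = transcript.input_shares[1].clone();
let prep_state = assert_matches!(
    &transcript.prepare_transitions[1][0],
    PrepareTransition::Continue(prep_state, _) => prep_state.clone()
);
let prep_msg = transcript.prepare_messages[0].clone();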
let report_metadata_0 = ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4284,7 +4310,7 @@ mod tests { vdaf.as_ref(), verify_key.as_bytes(), &(), - report_metadata_0.report_id(), + report_metadata_0.id(), &0, ); let prep_state_0 = assert_matches!( @@ -4299,7 +4325,7 @@ mod tests { ); let prep_msg_0 = transcript_0.prepare_messages[0].clone(); let report_share_0 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_0, &hpke_key.0, &transcript_0.public_share, @@ -4311,7 +4337,7 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4319,12 +4345,12 @@ mod tests { vdaf.as_ref(), verify_key.as_bytes(), &(), - report_metadata_1.report_id(), + report_metadata_1.id(), &0, ); let prep_state_1 = assert_matches!(&transcript_1.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let report_share_1 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_1, &hpke_key.0, &transcript_1.public_share, @@ -4333,13 +4359,13 @@ mod tests { // report_share_2 falls into a batch unit that has already been collected. let past_clock = MockClock::new(Time::from_seconds_since_epoch( - task.min_batch_duration.as_seconds() / 2, + task.time_precision().as_seconds() / 2, )); let report_metadata_2 = ReportMetadata::new( random(), past_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4347,13 +4373,13 @@ mod tests { vdaf.as_ref(), verify_key.as_bytes(), &(), - report_metadata_2.report_id(), + report_metadata_2.id(), &0, ); let prep_state_2 = assert_matches!(&transcript_2.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let prep_msg_2 = transcript_2.prepare_messages[0].clone(); let report_share_2 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_2, &hpke_key.0, &transcript_2.public_share, @@ -4382,15 +4408,15 @@ mod tests { Box::pin(async move { tx.put_task(&task).await?; - tx.put_report_share(&task_id, &report_share_0).await?; - tx.put_report_share(&task_id, &report_share_1).await?; - tx.put_report_share(&task_id, &report_share_2).await?; + tx.put_report_share(task.id(), &report_share_0).await?; + tx.put_report_share(task.id(), &report_share_1).await?; + tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -4399,9 +4425,9 @@ mod tests { tx.put_report_aggregation::( &ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata_0.report_id(), + *report_metadata_0.id(), *report_metadata_0.time(), 0, ReportAggregationState::Waiting(prep_state_0, None), @@ -4410,9 +4436,9 @@ mod tests { .await?; tx.put_report_aggregation::( &ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata_1.report_id(), + *report_metadata_1.id(), *report_metadata_1.time(), 1, ReportAggregationState::Waiting(prep_state_1, None), @@ -4421,9 +4447,9 @@ mod tests { .await?; tx.put_report_aggregation::( &ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - 
*report_metadata_2.report_id(), + *report_metadata_2.id(), *report_metadata_2.time(), 2, ReportAggregationState::Waiting(prep_state_2, None), @@ -4433,10 +4459,10 @@ mod tests { tx.put_aggregate_share_job::( &AggregateShareJob::new( - task_id, + *task.id(), Interval::new( Time::from_seconds_since_epoch(0), - task.min_batch_duration, + *task.time_precision(), ) .unwrap(), (), @@ -4452,15 +4478,15 @@ mod tests { .unwrap(); let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id, Vec::from([ PrepareStep::new( - *report_metadata_0.report_id(), + *report_metadata_0.id(), PrepareStepResult::Continued(prep_msg_0.get_encoded()), ), PrepareStep::new( - *report_metadata_2.report_id(), + *report_metadata_2.id(), PrepareStepResult::Continued(prep_msg_2.get_encoded()), ), ]), @@ -4494,9 +4520,9 @@ mod tests { assert_eq!( aggregate_resp, AggregateContinueResp::new(Vec::from([ - PrepareStep::new(*report_metadata_0.report_id(), PrepareStepResult::Finished), + PrepareStep::new(*report_metadata_0.id(), PrepareStepResult::Finished), PrepareStep::new( - *report_metadata_2.report_id(), + *report_metadata_2.id(), PrepareStepResult::Failed(ReportShareError::BatchCollected), ) ])) @@ -4505,12 +4531,11 @@ mod tests { // Validate datastore. let (aggregation_job, report_aggregations) = datastore .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); Box::pin(async move { let aggregation_job = tx .get_aggregation_job::( - &task_id, + task.id(), &aggregation_job_id, ) .await?; @@ -4518,7 +4543,7 @@ mod tests { .get_report_aggregations_for_aggregation_job( vdaf.as_ref(), &Role::Helper, - &task_id, + task.id(), &aggregation_job_id, ) .await?; @@ -4531,7 +4556,7 @@ mod tests { assert_eq!( aggregation_job, Some(AggregationJob::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::Finished, @@ -4541,25 +4566,25 @@ mod tests { report_aggregations, Vec::from([ ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata_0.report_id(), + *report_metadata_0.id(), *report_metadata_0.time(), 0, ReportAggregationState::Finished(out_share_0.clone()), ), ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata_1.report_id(), + *report_metadata_1.id(), *report_metadata_1.time(), 1, ReportAggregationState::Failed(ReportShareError::ReportDropped), ), ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata_2.report_id(), + *report_metadata_2.id(), *report_metadata_2.time(), 2, ReportAggregationState::Failed(ReportShareError::BatchCollected), @@ -4572,35 +4597,35 @@ mod tests { async fn aggregate_continue_accumulate_batch_unit_aggregation() { install_test_trace_subscriber(); - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Helper, + ) + .build(); let aggregation_job_id_0 = random(); let aggregation_job_id_1 = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Helper, - ); let (datastore, _db_handle) = ephemeral_datastore(MockClock::default()).await; let datastore = Arc::new(datastore); let first_batch_unit_interval_clock = MockClock::default(); let second_batch_unit_interval_clock = MockClock::new( first_batch_unit_interval_clock .now() - .add(&task.min_batch_duration) + .add(task.time_precision()) .unwrap(), ); let vdaf = Prio3::new_aes128_count(2).unwrap(); let verify_key: VerifyKey = 
task.primary_vdaf_verify_key().unwrap(); - let hpke_key = current_hpke_key(&task.hpke_keys); + let hpke_key = current_hpke_key(task.hpke_keys()); // report_share_0 is a "happy path" report. let report_metadata_0 = ReportMetadata::new( random(), first_batch_unit_interval_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4608,14 +4633,14 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_0.report_id(), + report_metadata_0.id(), &0, ); let prep_state_0 = assert_matches!(&transcript_0.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let out_share_0 = assert_matches!(&transcript_0.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); let prep_msg_0 = transcript_0.prepare_messages[0].clone(); let report_share_0 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_0, &hpke_key.0, &transcript_0.public_share, @@ -4628,7 +4653,7 @@ mod tests { random(), first_batch_unit_interval_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4636,14 +4661,14 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_1.report_id(), + report_metadata_1.id(), &0, ); let prep_state_1 = assert_matches!(&transcript_1.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let out_share_1 = assert_matches!(&transcript_1.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); let prep_msg_1 = transcript_1.prepare_messages[0].clone(); let report_share_1 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_1, &hpke_key.0, &transcript_1.public_share, @@ -4655,7 +4680,7 @@ mod tests { random(), second_batch_unit_interval_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4663,14 +4688,14 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_2.report_id(), + report_metadata_2.id(), &0, ); let prep_state_2 = assert_matches!(&transcript_2.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let out_share_2 = assert_matches!(&transcript_2.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); let prep_msg_2 = transcript_2.prepare_messages[0].clone(); let report_share_2 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_2, &hpke_key.0, &transcript_2.public_share, @@ -4699,15 +4724,15 @@ mod tests { Box::pin(async move { tx.put_task(&task).await?; - tx.put_report_share(&task_id, &report_share_0).await?; - tx.put_report_share(&task_id, &report_share_1).await?; - tx.put_report_share(&task_id, &report_share_2).await?; + tx.put_report_share(task.id(), &report_share_0).await?; + tx.put_report_share(task.id(), &report_share_1).await?; + tx.put_report_share(task.id(), &report_share_2).await?; tx.put_aggregation_job(&AggregationJob::< PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_0, (), AggregationJobState::InProgress, @@ -4718,9 +4743,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_0, - *report_metadata_0.report_id(), + *report_metadata_0.id(), *report_metadata_0.time(), 
0, ReportAggregationState::Waiting(prep_state_0, None), @@ -4730,9 +4755,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_0, - *report_metadata_1.report_id(), + *report_metadata_1.id(), *report_metadata_1.time(), 1, ReportAggregationState::Waiting(prep_state_1, None), @@ -4742,9 +4767,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_0, - *report_metadata_2.report_id(), + *report_metadata_2.id(), *report_metadata_2.time(), 2, ReportAggregationState::Waiting(prep_state_2, None), @@ -4758,19 +4783,19 @@ mod tests { .unwrap(); let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id_0, Vec::from([ PrepareStep::new( - *report_metadata_0.report_id(), + *report_metadata_0.id(), PrepareStepResult::Continued(prep_msg_0.get_encoded()), ), PrepareStep::new( - *report_metadata_1.report_id(), + *report_metadata_1.id(), PrepareStepResult::Continued(prep_msg_1.get_encoded()), ), PrepareStep::new( - *report_metadata_2.report_id(), + *report_metadata_2.id(), PrepareStepResult::Continued(prep_msg_2.get_encoded()), ), ]), @@ -4801,17 +4826,17 @@ mod tests { let batch_unit_aggregations = datastore .run_tx(|tx| { - let report_metadata_0 = report_metadata_0.clone(); + let (task, report_metadata_0) = (task.clone(), report_metadata_0.clone()); Box::pin(async move { tx.get_batch_unit_aggregations_for_task_in_interval::( - &task_id, + task.id(), &Interval::new( report_metadata_0 .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), // Make interval big enough to capture both batch unit aggregations - Duration::from_seconds(task.min_batch_duration.as_seconds() * 2), + Duration::from_seconds(task.time_precision().as_seconds() * 2), ) .unwrap(), &(), @@ -4825,17 +4850,17 @@ mod tests { let aggregate_share = vdaf .aggregate(&(), [out_share_0.clone(), out_share_1.clone()]) .unwrap(); - let checksum = ReportIdChecksum::for_report_id(report_metadata_0.report_id()) - .updated_with(report_metadata_1.report_id()); + let checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) + .updated_with(report_metadata_1.id()); assert_eq!( batch_unit_aggregations, Vec::from([ BatchUnitAggregation::::new( - task_id, + *task.id(), report_metadata_0 .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), (), aggregate_share, @@ -4843,15 +4868,15 @@ mod tests { checksum, ), BatchUnitAggregation::::new( - task_id, + *task.id(), report_metadata_2 .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), (), AggregateShare::from(out_share_2.clone()), 1, - ReportIdChecksum::for_report_id(report_metadata_2.report_id()), + ReportIdChecksum::for_report_id(report_metadata_2.id()), ), ]) ); @@ -4863,7 +4888,7 @@ mod tests { random(), first_batch_unit_interval_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4871,14 +4896,14 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_3.report_id(), + report_metadata_3.id(), &0, ); let prep_state_3 = assert_matches!(&transcript_3.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let out_share_3 = 
assert_matches!(&transcript_3.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); let prep_msg_3 = transcript_3.prepare_messages[0].clone(); let report_share_3 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_3, &hpke_key.0, &transcript_3.public_share, @@ -4890,7 +4915,7 @@ mod tests { random(), second_batch_unit_interval_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4898,14 +4923,14 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_4.report_id(), + report_metadata_4.id(), &0, ); let prep_state_4 = assert_matches!(&transcript_4.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let out_share_4 = assert_matches!(&transcript_4.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); let prep_msg_4 = transcript_4.prepare_messages[0].clone(); let report_share_4 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_4, &hpke_key.0, &transcript_4.public_share, @@ -4917,7 +4942,7 @@ mod tests { random(), second_batch_unit_interval_clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -4925,14 +4950,14 @@ mod tests { &vdaf, verify_key.as_bytes(), &(), - report_metadata_5.report_id(), + report_metadata_5.id(), &0, ); let prep_state_5 = assert_matches!(&transcript_5.prepare_transitions[1][0], PrepareTransition::::Continue(prep_state, _) => prep_state.clone()); let out_share_5 = assert_matches!(&transcript_5.prepare_transitions[1][1], PrepareTransition::::Finish(out_share) => out_share.clone()); let prep_msg_5 = transcript_5.prepare_messages[0].clone(); let report_share_5 = generate_helper_report_share::( - task_id, + task.id(), &report_metadata_5, &hpke_key.0, &transcript_5.public_share, @@ -4941,6 +4966,7 @@ mod tests { datastore .run_tx(|tx| { + let task = task.clone(); let (report_share_3, report_share_4, report_share_5) = ( report_share_3.clone(), report_share_4.clone(), @@ -4958,15 +4984,15 @@ mod tests { ); Box::pin(async move { - tx.put_report_share(&task_id, &report_share_3).await?; - tx.put_report_share(&task_id, &report_share_4).await?; - tx.put_report_share(&task_id, &report_share_5).await?; + tx.put_report_share(task.id(), &report_share_3).await?; + tx.put_report_share(task.id(), &report_share_4).await?; + tx.put_report_share(task.id(), &report_share_5).await?; tx.put_aggregation_job(&AggregationJob::< PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_1, (), AggregationJobState::InProgress, @@ -4977,9 +5003,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_1, - *report_metadata_3.report_id(), + *report_metadata_3.id(), *report_metadata_3.time(), 3, ReportAggregationState::Waiting(prep_state_3, None), @@ -4989,9 +5015,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_1, - *report_metadata_4.report_id(), + *report_metadata_4.id(), *report_metadata_4.time(), 4, ReportAggregationState::Waiting(prep_state_4, None), @@ -5001,9 +5027,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id_1, - *report_metadata_5.report_id(), + *report_metadata_5.id(), 
*report_metadata_5.time(), 5, ReportAggregationState::Waiting(prep_state_5, None), @@ -5017,19 +5043,19 @@ mod tests { .unwrap(); let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id_1, Vec::from([ PrepareStep::new( - *report_metadata_3.report_id(), + *report_metadata_3.id(), PrepareStepResult::Continued(prep_msg_3.get_encoded()), ), PrepareStep::new( - *report_metadata_4.report_id(), + *report_metadata_4.id(), PrepareStepResult::Continued(prep_msg_4.get_encoded()), ), PrepareStep::new( - *report_metadata_5.report_id(), + *report_metadata_5.id(), PrepareStepResult::Continued(prep_msg_5.get_encoded()), ), ]), @@ -5059,17 +5085,17 @@ mod tests { let batch_unit_aggregations = datastore .run_tx(|tx| { - let report_metadata_0 = report_metadata_0.clone(); + let (task, report_metadata_0) = (task.clone(), report_metadata_0.clone()); Box::pin(async move { tx.get_batch_unit_aggregations_for_task_in_interval::( - &task_id, + task.id(), &Interval::new( report_metadata_0 .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), // Make interval big enough to capture both batch unit aggregations - Duration::from_seconds(task.min_batch_duration.as_seconds() * 2), + Duration::from_seconds(task.time_precision().as_seconds() * 2), ) .unwrap(), &(), @@ -5083,25 +5109,25 @@ mod tests { let first_aggregate_share = vdaf .aggregate(&(), [out_share_0, out_share_1, out_share_3]) .unwrap(); - let first_checksum = ReportIdChecksum::for_report_id(report_metadata_0.report_id()) - .updated_with(report_metadata_1.report_id()) - .updated_with(report_metadata_3.report_id()); + let first_checksum = ReportIdChecksum::for_report_id(report_metadata_0.id()) + .updated_with(report_metadata_1.id()) + .updated_with(report_metadata_3.id()); let second_aggregate_share = vdaf .aggregate(&(), [out_share_2, out_share_4, out_share_5]) .unwrap(); - let second_checksum = ReportIdChecksum::for_report_id(report_metadata_2.report_id()) - .updated_with(report_metadata_4.report_id()) - .updated_with(report_metadata_5.report_id()); + let second_checksum = ReportIdChecksum::for_report_id(report_metadata_2.id()) + .updated_with(report_metadata_4.id()) + .updated_with(report_metadata_5.id()); assert_eq!( batch_unit_aggregations, Vec::from([ BatchUnitAggregation::::new( - task_id, + *task.id(), report_metadata_0 .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), (), first_aggregate_share, @@ -5109,10 +5135,10 @@ mod tests { first_checksum, ), BatchUnitAggregation::::new( - task_id, + *task.id(), report_metadata_2 .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), (), second_aggregate_share, @@ -5129,14 +5155,18 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. 
- let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Helper, + ) + .build(); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(54321), Vec::new(), ); - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Helper); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; let datastore = Arc::new(datastore); @@ -5148,7 +5178,7 @@ mod tests { Box::pin(async move { tx.put_task(&task).await?; tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( report_metadata.clone(), Vec::from("Public Share"), @@ -5165,7 +5195,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, dummy_vdaf::AggregationParam(0), AggregationJobState::InProgress, @@ -5175,9 +5205,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata.report_id(), + *report_metadata.id(), *report_metadata.time(), 0, ReportAggregationState::Waiting((), None), @@ -5190,10 +5220,10 @@ mod tests { // Make request. let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id, Vec::from([PrepareStep::new( - *report_metadata.report_id(), + *report_metadata.id(), PrepareStepResult::Finished, )]), ); @@ -5227,7 +5257,7 @@ mod tests { "title": "The message type for a response was incorrect or the payload was malformed.", "detail": "The message type for a response was incorrect or the payload was malformed.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -5238,14 +5268,18 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::FakeFailsPrepStep, + Role::Helper, + ) + .build(); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(54321), - vec![], + Vec::new(), ); - let task = Task::new_dummy(task_id, VdafInstance::FakeFailsPrepStep, Role::Helper); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; let datastore = Arc::new(datastore); @@ -5258,7 +5292,7 @@ mod tests { Box::pin(async move { tx.put_task(&task).await?; tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( report_metadata.clone(), Vec::from("public share"), @@ -5274,7 +5308,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, dummy_vdaf::AggregationParam(0), AggregationJobState::InProgress, @@ -5284,9 +5318,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata.report_id(), + *report_metadata.id(), *report_metadata.time(), 0, ReportAggregationState::Waiting((), None), @@ -5299,10 +5333,10 @@ mod tests { // Make request. 
let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id, Vec::from([PrepareStep::new( - *report_metadata.report_id(), + *report_metadata.id(), PrepareStepResult::Continued(Vec::new()), )]), ); @@ -5335,7 +5369,7 @@ mod tests { assert_eq!( aggregate_resp, AggregateContinueResp::new(Vec::from([PrepareStep::new( - *report_metadata.report_id(), + *report_metadata.id(), PrepareStepResult::Failed(ReportShareError::VdafPrepError), )]),) ); @@ -5343,11 +5377,11 @@ mod tests { // Check datastore state. let (aggregation_job, report_aggregation) = datastore .run_tx(|tx| { - let report_metadata = report_metadata.clone(); + let (task, report_metadata) = (task.clone(), report_metadata.clone()); Box::pin(async move { let aggregation_job = tx .get_aggregation_job::( - &task_id, + task.id(), &aggregation_job_id, ) .await?; @@ -5355,9 +5389,9 @@ mod tests { .get_report_aggregation( &dummy_vdaf::Vdaf::default(), &Role::Helper, - &task_id, + task.id(), &aggregation_job_id, - report_metadata.report_id(), + report_metadata.id(), ) .await?; Ok((aggregation_job, report_aggregation)) @@ -5369,7 +5403,7 @@ mod tests { assert_eq!( aggregation_job, Some(AggregationJob::new( - task_id, + *task.id(), aggregation_job_id, dummy_vdaf::AggregationParam(0), AggregationJobState::Finished, @@ -5378,9 +5412,9 @@ mod tests { assert_eq!( report_aggregation, Some(ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata.report_id(), + *report_metadata.id(), *report_metadata.time(), 0, ReportAggregationState::Failed(ReportShareError::VdafPrepError), @@ -5394,14 +5428,18 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Helper, + ) + .build(); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(54321), - vec![], + Vec::new(), ); - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Helper); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -5413,7 +5451,7 @@ mod tests { Box::pin(async move { tx.put_task(&task).await?; tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( report_metadata.clone(), Vec::from("PUBLIC"), @@ -5429,7 +5467,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, dummy_vdaf::AggregationParam(0), AggregationJobState::InProgress, @@ -5439,9 +5477,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata.report_id(), + *report_metadata.id(), *report_metadata.time(), 0, ReportAggregationState::Waiting((), None), @@ -5454,7 +5492,7 @@ mod tests { // Make request. let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id, Vec::from([PrepareStep::new( ReportId::from( @@ -5493,7 +5531,7 @@ mod tests { "title": "The message type for a response was incorrect or the payload was malformed.", "detail": "The message type for a response was incorrect or the payload was malformed.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -5504,20 +5542,24 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. 
- let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Helper, + ) + .build(); let aggregation_job_id = random(); let report_metadata_0 = ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(54321), - vec![], + Vec::new(), ); let report_metadata_1 = ReportMetadata::new( ReportId::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]), Time::from_seconds_since_epoch(54321), - vec![], + Vec::new(), ); - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Helper); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -5534,7 +5576,7 @@ mod tests { tx.put_task(&task).await?; tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( report_metadata_0.clone(), Vec::from("public"), @@ -5547,7 +5589,7 @@ mod tests { ) .await?; tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( report_metadata_1.clone(), Vec::from("public"), @@ -5564,7 +5606,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, dummy_vdaf::AggregationParam(0), AggregationJobState::InProgress, @@ -5575,9 +5617,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata_0.report_id(), + *report_metadata_0.id(), *report_metadata_0.time(), 0, ReportAggregationState::Waiting((), None), @@ -5587,9 +5629,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata_1.report_id(), + *report_metadata_1.id(), *report_metadata_1.time(), 1, ReportAggregationState::Waiting((), None), @@ -5602,16 +5644,16 @@ mod tests { // Make request. let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id, Vec::from([ // Report IDs are in opposite order to what was stored in the datastore. PrepareStep::new( - *report_metadata_1.report_id(), + *report_metadata_1.id(), PrepareStepResult::Continued(Vec::new()), ), PrepareStep::new( - *report_metadata_0.report_id(), + *report_metadata_0.id(), PrepareStepResult::Continued(Vec::new()), ), ]), @@ -5646,7 +5688,7 @@ mod tests { "title": "The message type for a response was incorrect or the payload was malformed.", "detail": "The message type for a response was incorrect or the payload was malformed.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -5657,15 +5699,19 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. 
- let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Helper, + ) + .build(); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(54321), - vec![], + Vec::new(), ); - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Helper); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -5673,11 +5719,10 @@ mod tests { datastore .run_tx(|tx| { let (task, report_metadata) = (task.clone(), report_metadata.clone()); - Box::pin(async move { tx.put_task(&task).await?; tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( report_metadata.clone(), Vec::from("public share"), @@ -5693,7 +5738,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, dummy_vdaf::AggregationParam(0), AggregationJobState::InProgress, @@ -5703,9 +5748,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata.report_id(), + *report_metadata.id(), *report_metadata.time(), 0, ReportAggregationState::Invalid, @@ -5718,7 +5763,7 @@ mod tests { // Make request. let request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id, Vec::from([PrepareStep::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), @@ -5755,7 +5800,7 @@ mod tests { "title": "The message type for a response was incorrect or the payload was malformed.", "detail": "The message type for a response was incorrect or the payload was malformed.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -5765,9 +5810,12 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); - - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Helper); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -5776,9 +5824,9 @@ mod tests { let filter = aggregator_filter(Arc::new(datastore), clock).unwrap(); let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval( - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(), + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(), ), Vec::new(), ); @@ -5806,7 +5854,7 @@ mod tests { "title": "An endpoint received a message with an unknown task ID.", "detail": "An endpoint received a message with an unknown task ID.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -5816,9 +5864,12 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. 
- let task_id = random(); - - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -5827,12 +5878,12 @@ mod tests { let filter = aggregator_filter(Arc::new(datastore), clock).unwrap(); let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), // Collect request will be rejected because batch interval is too small - Duration::from_seconds(task.min_batch_duration.as_seconds() - 1), + Duration::from_seconds(task.time_precision().as_seconds() - 1), ) .unwrap(), ), @@ -5865,7 +5916,7 @@ mod tests { "title": "The batch implied by the query is invalid.", "detail": "The batch implied by the query is invalid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -5877,9 +5928,12 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); - - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -5888,11 +5942,11 @@ mod tests { let filter = aggregator_filter(Arc::new(datastore), clock).unwrap(); let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), - Duration::from_seconds(task.min_batch_duration.as_seconds()), + Duration::from_seconds(task.time_precision().as_seconds()), ) .unwrap(), ), @@ -5926,7 +5980,7 @@ mod tests { "title": "The number of reports included in the batch is invalid.", "detail": "The number of reports included in the batch is invalid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -5936,14 +5990,14 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. 
- let task_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); + ) + .build(); let batch_interval = - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(); + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -5954,7 +6008,7 @@ mod tests { let filter = aggregator_filter(Arc::clone(&datastore), clock).unwrap(); let req = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval(batch_interval), Vec::new(), ); @@ -5982,7 +6036,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, response.status()); @@ -6013,7 +6067,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, response.status()); @@ -6040,7 +6094,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, response.status()); @@ -6051,19 +6105,14 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); - let mut task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); - task.aggregator_endpoints = Vec::from([ - "https://leader.endpoint".parse().unwrap(), - "https://helper.endpoint".parse().unwrap(), - ]); - task.max_batch_lifetime = 1; + ) + .build(); let batch_interval = - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(); + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -6074,7 +6123,7 @@ mod tests { let filter = aggregator_filter(Arc::clone(&datastore), clock).unwrap(); let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval(batch_interval), Vec::new(), ); @@ -6124,7 +6173,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, response.status()); @@ -6153,7 +6202,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, response.status()); @@ -6178,7 +6227,7 @@ mod tests { "title": "The request's authorization is not valid.", "detail": "The request's authorization is not valid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); assert_eq!(want_status, response.status()); @@ -6189,25 +6238,20 @@ mod 
tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); - let mut task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - ); - task.aggregator_endpoints = vec![ - "https://leader.endpoint".parse().unwrap(), - "https://helper.endpoint".parse().unwrap(), - ]; - task.max_batch_lifetime = 1; - let batch_interval = - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(); let (collector_hpke_config, collector_hpke_recipient) = generate_test_hpke_config_and_private_key(); - task.collector_hpke_config = collector_hpke_config; + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .with_collector_hpke_config(collector_hpke_config) + .build(); + let batch_interval = + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(); - let leader_aggregate_share = AggregateShare::from(vec![Field64::from(64)]); - let helper_aggregate_share = AggregateShare::from(vec![Field64::from(32)]); + let leader_aggregate_share = AggregateShare::from(Vec::from([Field64::from(64)])); + let helper_aggregate_share = AggregateShare::from(Vec::from([Field64::from(32)])); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -6218,7 +6262,7 @@ mod tests { let filter = aggregator_filter(Arc::clone(&datastore), clock).unwrap(); let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval(batch_interval), Vec::new(), ); @@ -6263,20 +6307,20 @@ mod tests { // Update the collect job with the aggregate shares. Collect job should now be complete. datastore .run_tx(|tx| { - let collector_hpke_config = task.collector_hpke_config.clone(); + let task = task.clone(); let helper_aggregate_share_bytes: Vec = (&helper_aggregate_share).into(); let leader_aggregate_share = leader_aggregate_share.clone(); Box::pin(async move { let encrypted_helper_aggregate_share = hpke::seal( - &collector_hpke_config, + task.collector_hpke_config(), &HpkeApplicationInfo::new( - Label::AggregateShare, - Role::Helper, - Role::Collector, + &Label::AggregateShare, + &Role::Helper, + &Role::Collector, ), &helper_aggregate_share_bytes, &associated_data_for_aggregate_share::( - &task.id, + task.id(), &batch_interval, ), ) @@ -6329,11 +6373,11 @@ mod tests { assert_eq!(collect_resp.encrypted_aggregate_shares().len(), 2); let decrypted_leader_aggregate_share = hpke::open( - &task.collector_hpke_config, + task.collector_hpke_config(), &collector_hpke_recipient, - &HpkeApplicationInfo::new(Label::AggregateShare, Role::Leader, Role::Collector), + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Leader, &Role::Collector), &collect_resp.encrypted_aggregate_shares()[0], - &associated_data_for_aggregate_share::(&task_id, &batch_interval), + &associated_data_for_aggregate_share::(task.id(), &batch_interval), ) .unwrap(); assert_eq!( @@ -6342,11 +6386,11 @@ mod tests { ); let decrypted_helper_aggregate_share = hpke::open( - &task.collector_hpke_config, + task.collector_hpke_config(), &collector_hpke_recipient, - &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector), + &HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Helper, &Role::Collector), &collect_resp.encrypted_aggregate_shares()[1], - &associated_data_for_aggregate_share::(&task_id, &batch_interval), + &associated_data_for_aggregate_share::(task.id(), &batch_interval), ) .unwrap(); assert_eq!( @@ 
-6381,16 +6425,18 @@ mod tests { async fn collect_request_batch_queried_too_many_times() { install_test_trace_subscriber(); - let task_id = random(); - let mut task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader); - task.max_batch_lifetime = 1; + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .build(); let (datastore, _db_handle) = ephemeral_datastore(MockClock::default()).await; datastore .run_tx(|tx| { let task = task.clone(); - Box::pin(async move { tx.put_task(&task).await?; @@ -6398,7 +6444,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task.id, + *task.id(), Time::from_seconds_since_epoch(0), dummy_vdaf::AggregationParam(0), dummy_vdaf::AggregateShare(0), @@ -6413,11 +6459,11 @@ mod tests { let filter = aggregator_filter(Arc::new(datastore), MockClock::default()).unwrap(); - // Sending this request will consume the lifetime for [0, min_batch_duration). + // Sending this request will consume a query for [0, time_precision). let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval( - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(), + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(), ), dummy_vdaf::AggregationParam(0).get_encoded(), ); @@ -6438,11 +6484,11 @@ mod tests { assert_eq!(response.status(), StatusCode::SEE_OTHER); - // This request will not be allowed due to the batch lifetime already being consumed. + // This request will not be allowed due to the query count already being consumed. let invalid_request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval( - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(), + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(), ), dummy_vdaf::AggregationParam(1).get_encoded(), ); @@ -6472,7 +6518,7 @@ mod tests { "title": "The batch described by the query has been queried too many times.", "detail": "The batch described by the query has been queried too many times.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -6481,16 +6527,18 @@ mod tests { async fn collect_request_batch_overlap() { install_test_trace_subscriber(); - let task_id = random(); - let mut task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader); - task.max_batch_lifetime = 1; + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .build(); let (datastore, _db_handle) = ephemeral_datastore(MockClock::default()).await; datastore .run_tx(|tx| { let task = task.clone(); - Box::pin(async move { tx.put_task(&task).await?; @@ -6498,7 +6546,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task.id, + *task.id(), Time::from_seconds_since_epoch(0), dummy_vdaf::AggregationParam(0), dummy_vdaf::AggregateShare(0), @@ -6513,14 +6561,14 @@ mod tests { let filter = aggregator_filter(Arc::new(datastore), MockClock::default()).unwrap(); - // Sending this request will consume the lifetime for [0, 2 * min_batch_duration). + // Sending this request will consume a query for [0, 2 * time_precision). 
let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), Duration::from_microseconds( - 2 * task.min_batch_duration.as_microseconds().unwrap(), + 2 * task.time_precision().as_microseconds().unwrap(), ), ) .unwrap(), @@ -6546,13 +6594,13 @@ mod tests { // This request will not be allowed due to overlapping with the previous request. let invalid_request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0) - .add(&task.min_batch_duration) + .add(task.time_precision()) .unwrap(), - task.min_batch_duration, + *task.time_precision(), ) .unwrap(), ), @@ -6584,7 +6632,7 @@ mod tests { "title": "The queried batch overlaps with a previously queried batch.", "detail": "The queried batch overlaps with a previously queried batch.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -6594,19 +6642,14 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); - let mut task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); - task.aggregator_endpoints = vec![ - "https://leader.endpoint".parse().unwrap(), - "https://helper.endpoint".parse().unwrap(), - ]; - task.max_batch_lifetime = 1; + ) + .build(); let batch_interval = - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(); + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -6632,7 +6675,7 @@ mod tests { // Create a collect job let request = CollectReq::new( - task_id, + *task.id(), Query::new_time_interval(batch_interval), Vec::new(), ); @@ -6696,8 +6739,12 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. - let task_id = random(); - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -6706,9 +6753,9 @@ mod tests { let filter = aggregator_filter(Arc::new(datastore), clock).unwrap(); let request = AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(), + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(), ), Vec::new(), 0, @@ -6741,7 +6788,7 @@ mod tests { "title": "An endpoint received a message with an unknown task ID.", "detail": "An endpoint received a message with an unknown task ID.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -6751,8 +6798,12 @@ mod tests { install_test_trace_subscriber(); // Prepare parameters. 
- let task_id = random(); - let task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Helper); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Helper, + ) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -6761,12 +6812,12 @@ mod tests { let filter = aggregator_filter(Arc::new(datastore), clock).unwrap(); let request = AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), // Collect request will be rejected because batch interval is too small - Duration::from_seconds(task.min_batch_duration.as_seconds() - 1), + Duration::from_seconds(task.time_precision().as_seconds() - 1), ) .unwrap(), ), @@ -6801,7 +6852,7 @@ mod tests { "title": "The batch implied by the query is invalid.", "detail": "The batch implied by the query is invalid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -6810,15 +6861,18 @@ mod tests { async fn aggregate_share_request() { install_test_trace_subscriber(); - let task_id = random(); let (collector_hpke_config, collector_hpke_recipient) = generate_test_hpke_config_and_private_key(); - - let mut task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Helper); - task.max_batch_lifetime = 1; - task.min_batch_duration = Duration::from_seconds(500); - task.min_batch_size = 10; - task.collector_hpke_config = collector_hpke_config.clone(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Helper, + ) + .with_max_batch_query_count(1) + .with_time_precision(Duration::from_seconds(500)) + .with_min_batch_size(10) + .with_collector_hpke_config(collector_hpke_config.clone()) + .build(); let clock = MockClock::default(); let (datastore, _db_handle) = ephemeral_datastore(clock.clone()).await; @@ -6830,9 +6884,9 @@ mod tests { // There are no batch unit_aggregations in the datastore yet let request = AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( - Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(), + Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()).unwrap(), ), dummy_vdaf::AggregationParam(0).get_encoded(), 0, @@ -6865,13 +6919,14 @@ mod tests { "title": "The number of reports included in the batch is invalid.", "detail": "The number of reports included in the batch is invalid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); // Put some batch unit aggregations in the DB. 
datastore .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { for aggregation_param in [ dummy_vdaf::AggregationParam(0), @@ -6881,7 +6936,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(500), aggregation_param, dummy_vdaf::AggregateShare(64), @@ -6894,7 +6949,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(1500), aggregation_param, dummy_vdaf::AggregateShare(128), @@ -6907,7 +6962,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(2000), aggregation_param, dummy_vdaf::AggregateShare(256), @@ -6920,7 +6975,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(2500), aggregation_param, dummy_vdaf::AggregateShare(512), @@ -6938,7 +6993,7 @@ mod tests { // Specified interval includes too few reports. let request = AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -6976,7 +7031,7 @@ mod tests { "title": "The number of reports included in the batch is invalid.", "detail": "The number of reports included in the batch is invalid.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); @@ -6984,7 +7039,7 @@ mod tests { for misaligned_request in [ // Interval is big enough, but checksum doesn't match. AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -6998,7 +7053,7 @@ mod tests { ), // Interval is big enough, but report count doesn't match. AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(2000), @@ -7037,7 +7092,7 @@ mod tests { "title": "Leader and helper disagree on reports aggregated in a batch.", "detail": "Leader and helper disagree on reports aggregated in a batch.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -7048,7 +7103,7 @@ mod tests { ( "first and second batch units", AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -7065,7 +7120,7 @@ mod tests { ( "third and fourth batch units", AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(2000), @@ -7082,7 +7137,7 @@ mod tests { ), ] { // Request the aggregate share multiple times. If the request parameters don't change, - // then there is no batch lifetime violation and all requests should succeed. + // then there is no query count violation and all requests should succeed. for iteration in 0..3 { let (parts, body) = warp::test::request() .method("POST") @@ -7119,7 +7174,11 @@ mod tests { let aggregate_share = hpke::open( &collector_hpke_config, &collector_hpke_recipient, - &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector), + &HpkeApplicationInfo::new( + &Label::AggregateShare, + &Role::Helper, + &Role::Collector, + ), aggregate_share_resp.encrypted_aggregate_share(), &associated_data_for_aggregate_share::( request.task_id(), @@ -7142,7 +7201,7 @@ mod tests { // Requests for collection intervals that overlap with but are not identical to previous // collection intervals fail. 
let all_batch_unit_request = AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -7178,16 +7237,16 @@ mod tests { "title": "The queried batch overlaps with a previously queried batch.", "detail": "The queried batch overlaps with a previously queried batch.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }), ); - // Previous sequence of aggregate share requests should have consumed the batch lifetime for - // all the batch units. Further requests for any batch units will cause batch lifetime + // Previous sequence of aggregate share requests should have consumed the available queries + // for all the batch units. Further requests for any batch units will cause query count // violations. - for batch_lifetime_violation_request in [ + for query_count_violation_request in [ AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(0), @@ -7200,7 +7259,7 @@ mod tests { ReportIdChecksum::get_decoded(&[3 ^ 2; 32]).unwrap(), ), AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval( Interval::new( Time::from_seconds_since_epoch(2000), @@ -7221,7 +7280,7 @@ mod tests { task.primary_aggregator_auth_token().as_bytes(), ) .header(CONTENT_TYPE, AggregateShareReq::::MEDIA_TYPE) - .body(batch_lifetime_violation_request.get_encoded()) + .body(query_count_violation_request.get_encoded()) .filter(&filter) .await .unwrap() @@ -7237,7 +7296,7 @@ mod tests { "title": "The batch described by the query has been queried too many times.", "detail": "The batch described by the query has been queried too many times.", "instance": "..", - "taskid": format!("{}", task_id), + "taskid": format!("{}", task.id()), }) ); } @@ -7253,7 +7312,7 @@ mod tests { } fn generate_helper_report_share( - task_id: TaskId, + task_id: &TaskId, report_metadata: &ReportMetadata, cfg: &HpkeConfig, public_share: &V::PublicShare, @@ -7286,7 +7345,7 @@ mod tests { encoded_public_share, hpke::seal( cfg, - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Helper), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Helper), plaintext, associated_data, ) diff --git a/janus_server/src/aggregator/accumulator.rs b/janus_server/src/aggregator/accumulator.rs index dd2421528..740dd1950 100644 --- a/janus_server/src/aggregator/accumulator.rs +++ b/janus_server/src/aggregator/accumulator.rs @@ -1,15 +1,18 @@ //! In-memory accumulation of output shares. 
use super::Error; -use crate::datastore::{self, models::BatchUnitAggregation, Transaction}; +use crate::{ + datastore::{self, models::BatchUnitAggregation, Transaction}, + task::Task, +}; use derivative::Derivative; use janus_core::{ report_id::ReportIdChecksumExt, time::{Clock, TimeExt}, }; -use janus_messages::{Duration, Interval, ReportId, ReportIdChecksum, TaskId, Time}; +use janus_messages::{Interval, ReportId, ReportIdChecksum, Time}; use prio::vdaf::{self, Aggregatable}; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use tracing::debug; #[derive(Derivative)] @@ -47,28 +50,22 @@ pub(super) struct Accumulator> where for<'a> &'a A::AggregateShare: Into>, { - task_id: TaskId, - min_batch_duration: Duration, + task: Arc, #[derivative(Debug = "ignore")] aggregation_param: A::AggregationParam, accumulations: HashMap>, } -impl> Accumulator +impl<'t, const L: usize, A: vdaf::Aggregator> Accumulator where for<'a> &'a A::AggregateShare: Into>, for<'a> >::Error: std::fmt::Display, { /// Create a new accumulator - pub(super) fn new( - task_id: TaskId, - min_batch_duration: Duration, - aggregation_param: &A::AggregationParam, - ) -> Self { + pub(super) fn new(task: Arc, aggregation_param: A::AggregationParam) -> Self { Self { - task_id, - min_batch_duration, - aggregation_param: aggregation_param.clone(), + task, + aggregation_param, accumulations: HashMap::new(), } } @@ -81,7 +78,7 @@ where report_id: &ReportId, ) -> Result<(), datastore::Error> { let key = report_time - .to_batch_unit_interval_start(self.min_batch_duration) + .to_batch_unit_interval_start(self.task.time_precision()) .map_err(|e| datastore::Error::User(e.into()))?; if let Some(accumulation) = self.accumulations.get_mut(&key) { accumulation @@ -110,11 +107,11 @@ where tx: &Transaction<'_, C>, ) -> Result<(), datastore::Error> { for (unit_interval_start, accumulation) in &self.accumulations { - let unit_interval = Interval::new(*unit_interval_start, self.min_batch_duration)?; + let unit_interval = Interval::new(*unit_interval_start, *self.task.time_precision())?; let batch_unit_aggregations = tx .get_batch_unit_aggregations_for_task_in_interval::( - &self.task_id, + self.task.id(), &unit_interval, &self.aggregation_param, ) @@ -124,7 +121,7 @@ where return Err(datastore::Error::DbState(format!( "found {} batch unit aggregation rows for task {}, interval {unit_interval}", batch_unit_aggregations.len(), - self.task_id, + self.task.id(), ))); } @@ -145,7 +142,7 @@ where "inserting new batch_unit_aggregation row", ); tx.put_batch_unit_aggregation(&BatchUnitAggregation::::new( - self.task_id, + *self.task.id(), *unit_interval.start(), self.aggregation_param.clone(), accumulation.aggregate_share.clone(), diff --git a/janus_server/src/aggregator/aggregate_share.rs b/janus_server/src/aggregator/aggregate_share.rs index b386172c6..6cb2f502c 100644 --- a/janus_server/src/aggregator/aggregate_share.rs +++ b/janus_server/src/aggregator/aggregate_share.rs @@ -229,7 +229,7 @@ impl CollectJobDriver { let batch_unit_aggregations = tx .get_batch_unit_aggregations_for_task_in_interval::( - &task.id, + task.id(), collect_job.batch_interval(), collect_job.aggregation_parameter(), ) @@ -255,7 +255,7 @@ impl CollectJobDriver { // Send an aggregate share request to the helper. 
let req = AggregateShareReq::new( - task.id, + *task.id(), BatchSelector::new_time_interval(*collect_job.batch_interval()), collect_job.aggregation_parameter().get_encoded(), report_count, @@ -264,7 +264,8 @@ impl CollectJobDriver { let resp_bytes = post_to_helper( &self.http_client, - task.aggregator_url(Role::Helper)?.join("aggregate_share")?, + task.aggregator_url(&Role::Helper)? + .join("aggregate_share")?, AggregateShareReq::::MEDIA_TYPE, req, task.primary_aggregator_auth_token(), @@ -481,7 +482,7 @@ impl CollectJobDriver { /// Computes the aggregate share over the provided batch unit aggregations. /// The assumption is that all aggregation jobs contributing to those batch unit aggregations have -/// been driven to completion, and that the batch lifetime requirements have been validated for the +/// been driven to completion, and that the query count requirements have been validated for the /// included batch units. #[tracing::instrument(err)] pub(crate) async fn compute_aggregate_share>( @@ -534,14 +535,14 @@ where // Only happens if there were no batch unit aggregations, which would get caught by the // min_batch_size check below, but we have to unwrap the option. - let total_aggregate_share = - total_aggregate_share.ok_or(Error::InvalidBatchSize(task.id, total_report_count))?; + let total_aggregate_share = total_aggregate_share + .ok_or_else(|| Error::InvalidBatchSize(*task.id(), total_report_count))?; // Refuse to service time-interval aggregate share requests if there are too few reports // included. // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.5.6.1.1 - if total_report_count < task.min_batch_size { - return Err(Error::InvalidBatchSize(task.id, total_report_count)); + if total_report_count < task.min_batch_size() { + return Err(Error::InvalidBatchSize(*task.id(), total_report_count)); } // TODO(#468): This should check against the task's max batch size for fixed size queries @@ -551,10 +552,10 @@ where /// Check whether this collect interval has been included in enough collect jobs (for `task.role` == /// [`Role::Leader`]) or aggregate share jobs (for `task.role` == [`Role::Helper`]) to violate the -/// task's maximum batch lifetime, and that this collect interval does not partially overlap with +/// task's maximum batch query count, and that this collect interval does not partially overlap with /// an already-observed collect interval. // TODO(#468): This only handles time-interval queries -pub(crate) async fn validate_batch_lifetime_for_collect< +pub(crate) async fn validate_batch_query_count_for_collect< const L: usize, C: Clock, A: vdaf::Aggregator, @@ -568,23 +569,24 @@ where for<'a> &'a A::AggregateShare: Into>, { // Check how many rows in the relevant table have an intersecting batch interval. - // Each such row consumes one unit of batch lifetime (§4.6). - let intersecting_intervals: Vec<_> = match task.role { + // Each such row consumes one unit of query count. + // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.5.6 + let intersecting_intervals: Vec<_> = match task.role() { Role::Leader => tx - .get_collect_jobs_jobs_intersecting_interval::(&task.id, &collect_interval) + .get_collect_jobs_jobs_intersecting_interval::(task.id(), &collect_interval) .await? .into_iter() .map(|job| *job.batch_interval()) .collect(), Role::Helper => tx - .get_aggregate_share_jobs_intersecting_interval::(&task.id, &collect_interval) + .get_aggregate_share_jobs_intersecting_interval::(task.id(), &collect_interval) .await? 
.into_iter() .map(|job| *job.batch_interval()) .collect(), - _ => panic!("Unexpected task role {:?}", task.role), + _ => panic!("Unexpected task role {:?}", task.role()), }; // Check that all intersecting collect intervals are equal to this collect interval. @@ -594,31 +596,31 @@ where .any(|interval| interval != &collect_interval) { return Err(datastore::Error::User( - Error::BatchOverlap(task.id, collect_interval).into(), + Error::BatchOverlap(*task.id(), collect_interval).into(), )); } // Check that the batch query count is being consumed appropriately. // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.5.6 - let max_batch_lifetime: usize = task.max_batch_lifetime.try_into()?; - if intersecting_intervals.len() == max_batch_lifetime { + let max_batch_query_count: usize = task.max_batch_query_count().try_into()?; + if intersecting_intervals.len() == max_batch_query_count { debug!( - task_id = ?task.id, ?collect_interval, - "Refusing aggregate share request because batch lifetime has been consumed" + task_id = %task.id(), ?collect_interval, + "Refusing aggregate share request because query count has been consumed" ); return Err(datastore::Error::User( - Error::BatchQueriedTooManyTimes(task.id, intersecting_intervals.len() as u64).into(), + Error::BatchQueriedTooManyTimes(*task.id(), intersecting_intervals.len() as u64).into(), )); } - if intersecting_intervals.len() > max_batch_lifetime { + if intersecting_intervals.len() > max_batch_query_count { error!( - task_id = ?task.id, ?collect_interval, - "Batch lifetime has been consumed more times than task allows" + task_id = %task.id(), ?collect_interval, + "query count has been consumed more times than task allows" ); // We return an internal error since this should be impossible. 
return Err(datastore::Error::User( - Error::Internal("batch lifetime overconsumed".to_string()).into(), + Error::Internal("query count overconsumed".to_string()).into(), )); } Ok(()) @@ -626,41 +628,45 @@ where #[cfg(test)] mod tests { - use super::*; use crate::{ - aggregator::DapProblemType, + aggregator::{aggregate_share::CollectJobDriver, DapProblemType, Error}, binary_utils::job_driver::JobDriver, datastore::{ models::{ - AggregationJob, AggregationJobState, CollectJob, CollectJobState, - ReportAggregation, ReportAggregationState, + AcquiredCollectJob, AggregationJob, AggregationJobState, BatchUnitAggregation, + CollectJob, CollectJobState, Lease, ReportAggregation, ReportAggregationState, }, test_util::ephemeral_datastore, + Datastore, }, messages::TimeExt, - task::VdafInstance, + task::{test_util::TaskBuilder, QueryType, VdafInstance}, }; use assert_matches::assert_matches; use http::{header::CONTENT_TYPE, StatusCode}; use janus_core::{ test_util::{ - dummy_vdaf::{AggregateShare, AggregationParam, OutputShare}, + dummy_vdaf::{self, AggregateShare, AggregationParam, OutputShare}, install_test_trace_subscriber, runtime::TestRuntimeManager, }, - time::{MockClock, TimeExt as CoreTimeExt}, + time::{Clock, MockClock, TimeExt as CoreTimeExt}, Runtime, }; use janus_messages::{ - Duration, HpkeCiphertext, HpkeConfigId, Interval, Report, ReportMetadata, Role, + query_type::TimeInterval, AggregateShareReq, AggregateShareResp, BatchSelector, Duration, + HpkeCiphertext, HpkeConfigId, Interval, Report, ReportIdChecksum, ReportMetadata, Role, }; use mockito::mock; use opentelemetry::global::meter; + use prio::codec::{Decode, Encode}; use rand::random; - use std::str; + use std::{str, sync::Arc}; use url::Url; use uuid::Uuid; + use super::DUMMY_VERIFY_KEY_LENGTH; + async fn setup_collect_job_test_case( clock: MockClock, datastore: Arc>, @@ -669,19 +675,19 @@ mod tests { Option>, CollectJob, ) { - let task_id = random(); - let mut task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader); - task.aggregator_endpoints = vec![ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&mockito::server_url()).unwrap(), - ]; - task.min_batch_duration = Duration::from_seconds(500); - task.min_batch_size = 10; + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) + .with_aggregator_endpoints(Vec::from([ + Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter + Url::parse(&mockito::server_url()).unwrap(), + ])) + .with_time_precision(Duration::from_seconds(500)) + .with_min_batch_size(10) + .build(); let batch_interval = Interval::new(clock.now(), Duration::from_seconds(2000)).unwrap(); let aggregation_param = AggregationParam(0); let collect_job = CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), batch_interval, aggregation_param, @@ -690,9 +696,7 @@ mod tests { let lease = datastore .run_tx(|tx| { - let clock = clock.clone(); - let task = task.clone(); - let collect_job = collect_job.clone(); + let (clock, task, collect_job) = (clock.clone(), task.clone(), collect_job.clone()); Box::pin(async move { tx.put_task(&task).await?; @@ -704,7 +708,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, aggregation_param, AggregationJobState::Finished, @@ -715,12 +719,12 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); 
tx.put_client_report(&Report::new( - task_id, + *task.id(), report_metadata.clone(), Vec::new(), Vec::new(), @@ -731,9 +735,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata.report_id(), + *report_metadata.id(), *report_metadata.time(), 0, ReportAggregationState::Finished(OutputShare()), @@ -744,7 +748,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), clock.now(), aggregation_param, AggregateShare(0), @@ -756,7 +760,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), clock.now().add(&Duration::from_seconds(1000)).unwrap(), aggregation_param, AggregateShare(0), @@ -770,7 +774,7 @@ mod tests { .acquire_incomplete_collect_jobs(&Duration::from_seconds(100), 1) .await? .remove(0); - assert_eq!(&task_id, lease.leased().task_id()); + assert_eq!(task.id(), lease.leased().task_id()); assert_eq!( collect_job.collect_job_id(), lease.leased().collect_job_id() @@ -794,28 +798,27 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); - let task_id = random(); - let mut task = Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader); - task.aggregator_endpoints = vec![ - Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter - Url::parse(&mockito::server_url()).unwrap(), - ]; - task.min_batch_duration = Duration::from_seconds(500); - task.min_batch_size = 10; + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) + .with_aggregator_endpoints(Vec::from([ + Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter + Url::parse(&mockito::server_url()).unwrap(), + ])) + .with_time_precision(Duration::from_seconds(500)) + .with_min_batch_size(10) + .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let batch_interval = Interval::new(clock.now(), Duration::from_seconds(2000)).unwrap(); let aggregation_param = AggregationParam(0); let (collect_job_id, lease) = ds .run_tx(|tx| { - let clock = clock.clone(); - let task = task.clone(); + let (clock, task) = (clock.clone(), task.clone()); Box::pin(async move { tx.put_task(&task).await?; let collect_job_id = Uuid::new_v4(); tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), collect_job_id, batch_interval, aggregation_param, @@ -828,7 +831,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, aggregation_param, AggregationJobState::Finished, @@ -839,12 +842,12 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); tx.put_client_report(&Report::new( - task_id, + *task.id(), report_metadata.clone(), Vec::new(), Vec::new(), @@ -855,9 +858,9 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, - *report_metadata.report_id(), + *report_metadata.id(), *report_metadata.time(), 0, ReportAggregationState::Finished(OutputShare()), @@ -870,7 +873,7 @@ mod tests { .remove(0), ); - assert_eq!(&task_id, lease.leased().task_id()); + assert_eq!(task.id(), lease.leased().task_id()); assert_eq!(&collect_job_id, lease.leased().collect_job_id()); Ok((collect_job_id, lease)) }) @@ -889,18 +892,18 @@ mod tests { .await .unwrap_err(); assert_matches!(error, Error::InvalidBatchSize(error_task_id, 0) => { - assert_eq!(task_id, error_task_id) + 
assert_eq!(task.id(), &error_task_id) }); // Put some batch unit aggregations in the DB ds.run_tx(|tx| { - let clock = clock.clone(); + let (clock, task) = (clock.clone(), task.clone()); Box::pin(async move { tx.put_batch_unit_aggregation(&BatchUnitAggregation::< DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), clock.now(), aggregation_param, AggregateShare(0), @@ -913,7 +916,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), clock.now().add(&Duration::from_seconds(1000)).unwrap(), aggregation_param, AggregateShare(0), @@ -929,7 +932,7 @@ mod tests { .unwrap(); let leader_request = AggregateShareReq::new( - task_id, + *task.id(), BatchSelector::new_time_interval(batch_interval), aggregation_param.get_encoded(), 10, diff --git a/janus_server/src/aggregator/aggregation_job_creator.rs b/janus_server/src/aggregator/aggregation_job_creator.rs index 254c06a82..14ed62b26 100644 --- a/janus_server/src/aggregator/aggregation_job_creator.rs +++ b/janus_server/src/aggregator/aggregation_job_creator.rs @@ -122,8 +122,8 @@ impl AggregationJobCreator { let tasks = match tasks { Ok(tasks) => tasks .into_iter() - .filter_map(|task| match task.role { - Role::Leader => Some((task.id, task)), + .filter_map(|task| match task.role() { + Role::Leader => Some((*task.id(), task)), _ => None, }) .collect::>(), @@ -155,14 +155,14 @@ impl AggregationJobCreator { if job_creation_task_shutdown_handles.contains_key(&task_id) { continue; } - info!(?task_id, "Starting job creation worker"); + info!(%task_id, "Starting job creation worker"); let (tx, rx) = oneshot::channel(); job_creation_task_shutdown_handles.insert(task_id, tx); tokio::task::spawn({ let (this, job_creation_time_histogram) = (Arc::clone(&self), job_creation_time_histogram.clone()); async move { - this.run_for_task(rx, job_creation_time_histogram, task) + this.run_for_task(rx, job_creation_time_histogram, Arc::new(task)) .await } }); @@ -181,9 +181,9 @@ impl AggregationJobCreator { &self, mut shutdown: Receiver<()>, job_creation_time_histogram: Histogram, - task: Task, + task: Arc, ) { - debug!(task_id = ?task.id, "Job creation worker started"); + debug!(task_id = %task.id(), "Job creation worker started"); let first_tick_instant = Instant::now() + Duration::from_secs( thread_rng().gen_range(0..self.aggregation_job_creation_interval.as_secs()), @@ -194,17 +194,17 @@ impl AggregationJobCreator { loop { select! 
{ _ = aggregation_job_creation_ticker.tick() => { - info!(task_id = ?task.id, "Creating aggregation jobs for task"); + info!(task_id = %task.id(), "Creating aggregation jobs for task"); let (start, mut status) = (Instant::now(), "success"); - if let Err(error) = self.create_aggregation_jobs_for_task(&task).await { - error!(task_id = ?task.id, %error, "Couldn't create aggregation jobs for task"); + if let Err(error) = self.create_aggregation_jobs_for_task(Arc::clone(&task)).await { + error!(task_id = %task.id(), %error, "Couldn't create aggregation jobs for task"); status = "error"; } job_creation_time_histogram.record(&Context::current(), start.elapsed().as_secs_f64(), &[KeyValue::new("status", status)]); } _ = &mut shutdown => { - debug!(task_id = ?task.id, "Job creation worker stopped"); + debug!(task_id = %task.id(), "Job creation worker stopped"); return; } } @@ -212,8 +212,8 @@ impl AggregationJobCreator { } #[tracing::instrument(skip(self), err)] - async fn create_aggregation_jobs_for_task(&self, task: &Task) -> anyhow::Result<()> { - match task.vdaf { + async fn create_aggregation_jobs_for_task(&self, task: Arc) -> anyhow::Result<()> { + match task.vdaf() { VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count) => { self.create_aggregation_jobs_for_task_no_param::(task) .await @@ -237,8 +237,8 @@ impl AggregationJobCreator { } _ => { - error!(vdaf = ?task.vdaf, "VDAF is not yet supported"); - panic!("VDAF {:?} is not yet supported", task.vdaf); + error!(vdaf = ?task.vdaf(), "VDAF is not yet supported"); + panic!("VDAF {:?} is not yet supported", task.vdaf()); } } } @@ -249,7 +249,7 @@ impl AggregationJobCreator { A: vdaf::Aggregator, >( &self, - task: &Task, + task: Arc, ) -> anyhow::Result<()> where for<'a> &'a A::AggregateShare: Into>, @@ -258,12 +258,10 @@ impl AggregationJobCreator { A::OutputShare: Send + Sync, for<'a> &'a A::OutputShare: Into>, { - let task_id = task.id; - let min_batch_duration = task.min_batch_duration; let current_batch_unit_start = self .clock .now() - .to_batch_unit_interval_start(min_batch_duration)?; + .to_batch_unit_interval_start(task.time_precision())?; let min_aggregation_job_size = self.min_aggregation_job_size; let max_aggregation_job_size = self.max_aggregation_job_size; @@ -271,14 +269,15 @@ impl AggregationJobCreator { Ok(self .datastore .run_tx(|tx| { + let task = Arc::clone(&task); Box::pin(async move { // Find some unaggregated client reports, and group them by their batch unit. let report_ids_by_batch_unit = tx - .get_unaggregated_client_report_ids_for_task(&task_id) + .get_unaggregated_client_report_ids_for_task(task.id()) .await? 
.into_iter() .map(|(report_id, time)| { - time.to_batch_unit_interval_start(min_batch_duration) + time.to_batch_unit_interval_start(task.time_precision()) .map(|rounded_time| (rounded_time, (report_id, time))) .map_err(datastore::Error::from) }) @@ -300,13 +299,13 @@ impl AggregationJobCreator { let aggregation_job_id = random(); debug!( - ?task_id, - ?aggregation_job_id, + task_id = %task.id(), + %aggregation_job_id, report_count = agg_job_reports.len(), "Creating aggregation job" ); agg_jobs.push(AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -314,7 +313,7 @@ impl AggregationJobCreator { for (ord, (report_id, time)) in agg_job_reports.iter().enumerate() { report_aggs.push(ReportAggregation::::new( - task_id, + *task.id(), aggregation_job_id, *report_id, *time, @@ -353,7 +352,7 @@ impl AggregationJobCreator { #[tracing::instrument(skip(self), err)] async fn create_aggregation_jobs_for_task_with_param( &self, - task: &Task, + task: Arc, ) -> anyhow::Result<()> where A: vdaf::Aggregator + VdafHasAggregationParameter, @@ -364,22 +363,21 @@ impl AggregationJobCreator { for<'a> &'a A::OutputShare: Into>, A::AggregationParam: Send + Sync + Eq + Hash, { - let task_id = task.id; - let min_batch_duration = task.min_batch_duration; let max_aggregation_job_size = self.max_aggregation_job_size; self.datastore .run_tx(|tx| { + let task = Arc::clone(&task); Box::pin(async move { // Find some client reports that are covered by a collect request, // but haven't been aggregated yet, and group them by their batch unit. let result_vec = tx - .get_unaggregated_client_report_ids_by_collect_for_task::(&task_id) + .get_unaggregated_client_report_ids_by_collect_for_task::(task.id()) .await? .into_iter() .map(|(report_id, report_time, aggregation_param)| { report_time - .to_batch_unit_interval_start(min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .map(|rounded_time| { ((rounded_time, aggregation_param), (report_id, report_time)) }) @@ -398,13 +396,13 @@ impl AggregationJobCreator { { let aggregation_job_id = random(); debug!( - ?task_id, - ?aggregation_job_id, + task_id = %task.id(), + %aggregation_job_id, report_count = agg_job_reports.len(), "Creating aggregation job" ); agg_jobs.push(AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, aggregation_param.clone(), AggregationJobState::InProgress, @@ -412,7 +410,7 @@ impl AggregationJobCreator { for (ord, (report_id, time)) in agg_job_reports.iter().enumerate() { report_aggs.push(ReportAggregation::::new( - task_id, + *task.id(), aggregation_job_id, *report_id, *time, @@ -457,7 +455,7 @@ mod tests { }, messages::test_util::new_dummy_report, messages::TimeExt, - task::{Task, PRIO3_AES128_VERIFY_KEY_LENGTH}, + task::{test_util::TaskBuilder, QueryType, PRIO3_AES128_VERIFY_KEY_LENGTH}, }; use futures::{future::try_join_all, TryFutureExt}; use janus_core::{ @@ -476,7 +474,6 @@ mod tests { Aggregator, Vdaf, }, }; - use rand::random; use std::{ collections::{HashMap, HashSet}, iter, @@ -504,22 +501,21 @@ mod tests { // even if the main test loops on calling yield_now(). 
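// --- Editor's note: illustrative sketch, not part of this patch ---
// The `to_batch_unit_interval_start(task.time_precision())` calls above (replacing the old
// `min_batch_duration`-based rounding) conceptually round a report timestamp down to the
// start of its batch unit, i.e. to a multiple of the task's `time_precision`. A minimal
// sketch of that arithmetic over raw seconds (the real code operates on `Time`/`Duration`
// and returns a `Result`); the helper name here is hypothetical:
fn batch_unit_interval_start(timestamp_secs: u64, time_precision_secs: u64) -> u64 {
    // Round down to the nearest multiple of the precision.
    timestamp_secs - (timestamp_secs % time_precision_secs)
}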
let report_time = Time::from_seconds_since_epoch(0); - - let leader_task_id = random(); - let leader_task = Task::new_dummy( - leader_task_id, + let leader_task = TaskBuilder::new( + QueryType::TimeInterval, VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ); - let leader_report = new_dummy_report(leader_task_id, report_time); + ) + .build(); + let leader_report = new_dummy_report(*leader_task.id(), report_time); - let helper_task_id = random(); - let helper_task = Task::new_dummy( - helper_task_id, + let helper_task = TaskBuilder::new( + QueryType::TimeInterval, VdafInstance::Prio3Aes128Count.into(), Role::Helper, - ); - let helper_report = new_dummy_report(helper_task_id, report_time); + ) + .build(); + let helper_report = new_dummy_report(*helper_task.id(), report_time); ds.run_tx(|tx| { let (leader_task, helper_task) = (leader_task.clone(), helper_task.clone()); @@ -558,16 +554,17 @@ mod tests { job_creator .datastore .run_tx(|tx| { + let (leader_task, helper_task) = (leader_task.clone(), helper_task.clone()); Box::pin(async move { let leader_agg_jobs = read_aggregate_jobs_for_task_prio3_count::< HashSet<_>, _, - >(tx, leader_task_id) + >(tx, leader_task.id()) .await; let helper_agg_jobs = read_aggregate_jobs_for_task_prio3_count::< HashSet<_>, _, - >(tx, helper_task_id) + >(tx, helper_task.id()) .await; Ok((leader_agg_jobs, helper_agg_jobs)) }) @@ -581,7 +578,7 @@ mod tests { report_times_and_ids, HashSet::from([( *leader_report.metadata().time(), - *leader_report.metadata().report_id() + *leader_report.metadata().id() )]) ); } @@ -604,35 +601,41 @@ mod tests { assert!(MAX_AGGREGATION_JOB_SIZE < usize::MAX); // we can add 1 safely } - let task_id = random(); - let task = Task::new_dummy(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); + let task = Arc::new( + TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(), + ); let current_batch_unit = clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(); // In the current batch unit, create MIN_AGGREGATION_JOB_SIZE reports. We expect an // aggregation job to be created containing these reports. let report_time = clock.now(); let cur_batch_unit_reports: Vec = - iter::repeat_with(|| new_dummy_report(task_id, report_time)) + iter::repeat_with(|| new_dummy_report(*task.id(), report_time)) .take(MIN_AGGREGATION_JOB_SIZE) .collect(); // In a previous "small" batch unit, create fewer than MIN_AGGREGATION_JOB_SIZE reports. // Since the minimum aggregation job size applies only to the current batch window, we // expect an aggregation job to be created for these reports. - let report_time = report_time.sub(&task.min_batch_duration).unwrap(); + let report_time = report_time.sub(task.time_precision()).unwrap(); let small_batch_unit_reports: Vec = - iter::repeat_with(|| new_dummy_report(task_id, report_time)) + iter::repeat_with(|| new_dummy_report(*task.id(), report_time)) .take(MIN_AGGREGATION_JOB_SIZE - 1) .collect(); // In a (separate) previous "big" batch unit, create more than MAX_AGGREGATION_JOB_SIZE // reports. We expect these reports will be split into more than one aggregation job. 
- let report_time = report_time.sub(&task.min_batch_duration).unwrap(); + let report_time = report_time.sub(task.time_precision()).unwrap(); let big_batch_unit_reports: Vec = - iter::repeat_with(|| new_dummy_report(task_id, report_time)) + iter::repeat_with(|| new_dummy_report(*task.id(), report_time)) .take(MAX_AGGREGATION_JOB_SIZE + 1) .collect(); @@ -640,7 +643,7 @@ mod tests { .iter() .chain(&small_batch_unit_reports) .chain(&big_batch_unit_reports) - .map(|report| *report.metadata().report_id()) + .map(|report| *report.metadata().id()) .collect(); ds.run_tx(|tx| { @@ -675,7 +678,7 @@ mod tests { max_aggregation_job_size: MAX_AGGREGATION_JOB_SIZE, }; job_creator - .create_aggregation_jobs_for_task(&task) + .create_aggregation_jobs_for_task(Arc::clone(&task)) .await .unwrap(); @@ -683,8 +686,9 @@ mod tests { let agg_jobs = job_creator .datastore .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { - Ok(read_aggregate_jobs_for_task_prio3_count::, _>(tx, task_id).await) + Ok(read_aggregate_jobs_for_task_prio3_count::, _>(tx, task.id()).await) }) }) .await @@ -695,7 +699,7 @@ mod tests { let batch_units: HashSet::OutputShare: TryFrom<&'a [u8]>, { try_join_all( - tx.get_aggregation_jobs_for_task_id::(&task_id) + tx.get_aggregation_jobs_for_task_id::(task_id) .await .unwrap() .into_iter() .map(|agg_job| async { - let agg_job_id = *agg_job.aggregation_job_id(); + let agg_job_id = *agg_job.id(); tx.get_report_aggregations_for_aggregation_job( - &vdaf, + vdaf, &Role::Leader, - &task_id, + task_id, &agg_job_id, ) .map_ok(move |report_aggs| { ( - *agg_job.aggregation_job_id(), + *agg_job.id(), report_aggs .into_iter() .map(|ra| (*ra.time(), *ra.report_id())) diff --git a/janus_server/src/aggregator/aggregation_job_driver.rs b/janus_server/src/aggregator/aggregation_job_driver.rs index 3bc156425..730567c3d 100644 --- a/janus_server/src/aggregator/aggregation_job_driver.rs +++ b/janus_server/src/aggregator/aggregation_job_driver.rs @@ -214,7 +214,7 @@ impl AggregationJobDriver { .map_err(|err| datastore::Error::User(err.into()))?; Ok(( - task, + Arc::new(task), aggregation_job, report_aggregations, client_reports, @@ -253,7 +253,7 @@ impl AggregationJobDriver { datastore: &Datastore, vdaf: &A, lease: Arc>, - task: Task, + task: Arc, aggregation_job: AggregationJob, report_aggregations: Vec>, client_reports: Vec, @@ -288,7 +288,7 @@ impl AggregationJobDriver { assert_eq!(report_aggregation.task_id(), client_report.task_id()); assert_eq!( report_aggregation.report_id(), - client_report.metadata().report_id() + client_report.metadata().id() ); } @@ -337,7 +337,7 @@ impl AggregationJobDriver { // Decrypt leader input share & transform into our first transition. 
let (hpke_config, hpke_private_key) = match task - .hpke_keys + .hpke_keys() .get(leader_encrypted_input_share.config_id()) { Some((hpke_config, hpke_private_key)) => (hpke_config, hpke_private_key), @@ -355,9 +355,12 @@ impl AggregationJobDriver { } }; let hpke_application_info = - HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); - let associated_data = - associated_data_for_report_share(task.id, report.metadata(), report.public_share()); + HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader); + let associated_data = associated_data_for_report_share( + task.id(), + report.metadata(), + report.public_share(), + ); let leader_input_share_bytes = match hpke::open( hpke_config, hpke_private_key, @@ -421,7 +424,7 @@ impl AggregationJobDriver { verify_key.as_bytes(), Role::Leader.index().unwrap(), aggregation_job.aggregation_parameter(), - &report.metadata().report_id().get_encoded(), + &report.metadata().id().get_encoded(), &public_share, &leader_input_share, ) { @@ -455,8 +458,8 @@ impl AggregationJobDriver { // TODO(#235): abandon work immediately on "terminal" failures from helper, or other // unexepected cases such as unknown/unexpected content type. let req = AggregateInitializeReq::new( - task.id, - *aggregation_job.aggregation_job_id(), + *task.id(), + *aggregation_job.id(), aggregation_job.aggregation_parameter().get_encoded(), PartialBatchSelector::new_time_interval(), report_shares, @@ -464,7 +467,7 @@ impl AggregationJobDriver { let resp_bytes = post_to_helper( &self.http_client, - task.aggregator_url(Role::Helper)?.join("aggregate")?, + task.aggregator_url(&Role::Helper)?.join("aggregate")?, AggregateInitializeReq::::MEDIA_TYPE, req, task.primary_aggregator_auth_token(), @@ -495,7 +498,7 @@ impl AggregationJobDriver { datastore: &Datastore, vdaf: &A, lease: Arc>, - task: Task, + task: Arc, aggregation_job: AggregationJob, report_aggregations: Vec>, ) -> Result<()> @@ -556,15 +559,11 @@ impl AggregationJobDriver { // Construct request, send it to the helper, and process the response. // TODO(#235): abandon work immediately on "terminal" failures from helper, or other // unexepected cases such as unknown/unexpected content type. 
- let req = AggregateContinueReq::new( - task.id, - *aggregation_job.aggregation_job_id(), - prepare_steps, - ); + let req = AggregateContinueReq::new(*task.id(), *aggregation_job.id(), prepare_steps); let resp_bytes = post_to_helper( &self.http_client, - task.aggregator_url(Role::Helper)?.join("aggregate")?, + task.aggregator_url(&Role::Helper)?.join("aggregate")?, AggregateContinueReq::MEDIA_TYPE, req, task.primary_aggregator_auth_token(), @@ -592,7 +591,7 @@ impl AggregationJobDriver { datastore: &Datastore, vdaf: &A, lease: Arc>, - task: Task, + task: Arc, aggregation_job: AggregationJob, stepped_aggregations: &[SteppedAggregation], mut report_aggregations_to_write: Vec>, @@ -616,9 +615,8 @@ impl AggregationJobDriver { )); } let mut accumulator = Accumulator::::new( - task.id, - task.min_batch_duration, - aggregation_job.aggregation_parameter(), + Arc::clone(&task), + aggregation_job.aggregation_parameter().clone(), ); for (stepped_aggregation, helper_prep_step) in stepped_aggregations.iter().zip(prep_steps) { let (report_aggregation, leader_transition) = ( @@ -912,7 +910,7 @@ mod tests { }, test_util::ephemeral_datastore, }, - task::{Task, VerifyKey, PRIO3_AES128_VERIFY_KEY_LENGTH}, + task::{test_util::TaskBuilder, QueryType, VerifyKey, PRIO3_AES128_VERIFY_KEY_LENGTH}, }; use assert_matches::assert_matches; use http::{header::CONTENT_TYPE, StatusCode}; @@ -961,20 +959,22 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); - - let task_id = random(); - let mut task = - Task::new_dummy(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); - task.aggregator_endpoints = vec![ + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .with_aggregator_endpoints(Vec::from([ Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), - ]; + ])) + .build(); let report_metadata = ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -985,15 +985,15 @@ mod tests { vdaf.as_ref(), verify_key.as_bytes(), &(), - report_metadata.report_id(), + report_metadata.id(), &0, ); let agg_auth_token = task.primary_aggregator_auth_token().clone(); - let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; + let (leader_hpke_config, _) = task.hpke_keys().iter().next().unwrap().1; let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( - task_id, + task.id(), &report_metadata, &[leader_hpke_config, &helper_hpke_config], &transcript.public_share, @@ -1012,7 +1012,7 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -1022,9 +1022,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Start, @@ -1044,7 +1044,7 @@ mod tests { AggregateInitializeReq::::MEDIA_TYPE, AggregateInitializeResp::MEDIA_TYPE, AggregateInitializeResp::new(Vec::from([PrepareStep::new( - *report.metadata().report_id(), + *report.metadata().id(), PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), )])) .get_encoded(), @@ -1053,7 
+1053,7 @@ mod tests { AggregateContinueReq::MEDIA_TYPE, AggregateContinueResp::MEDIA_TYPE, AggregateContinueResp::new(Vec::from([PrepareStep::new( - *report.metadata().report_id(), + *report.metadata().id(), PrepareStepResult::Finished, )])) .get_encoded(), @@ -1111,7 +1111,7 @@ mod tests { let want_aggregation_job = AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::Finished, @@ -1121,9 +1121,9 @@ mod tests { PrepareTransition::Finish(leader_output_share) => leader_output_share.clone()); let want_report_aggregation = ReportAggregation::::new( - task_id, + *task.id(), aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Finished(leader_output_share), @@ -1131,12 +1131,12 @@ mod tests { let (got_aggregation_job, got_report_aggregation) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let report_id = *report.metadata().report_id(); + let (vdaf, task, report_id) = + (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx .get_aggregation_job::( - &task_id, + task.id(), &aggregation_job_id, ) .await? @@ -1145,7 +1145,7 @@ mod tests { .get_report_aggregation( vdaf.as_ref(), &Role::Leader, - &task_id, + task.id(), &aggregation_job_id, &report_id, ) @@ -1170,19 +1170,22 @@ mod tests { let ds = Arc::new(ds); let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); - let task_id = random(); - let mut task = - Task::new_dummy(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); - task.aggregator_endpoints = vec![ + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .with_aggregator_endpoints(Vec::from([ Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), - ]; + ])) + .build(); let report_metadata = ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -1193,15 +1196,15 @@ mod tests { vdaf.as_ref(), verify_key.as_bytes(), &(), - report_metadata.report_id(), + report_metadata.id(), &0, ); let agg_auth_token = task.primary_aggregator_auth_token(); - let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; + let (leader_hpke_config, _) = task.hpke_keys().iter().next().unwrap().1; let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( - task_id, + task.id(), &report_metadata, &[leader_hpke_config, &helper_hpke_config], &transcript.public_share, @@ -1220,7 +1223,7 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -1230,9 +1233,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Start, @@ -1247,7 +1250,7 @@ mod tests { }) .await .unwrap(); - assert_eq!(lease.leased().task_id(), &task_id); + assert_eq!(lease.leased().task_id(), task.id()); assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); // Setup: prepare mocked HTTP response. 
(first an error response, then a success) @@ -1255,7 +1258,7 @@ mod tests { // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) let leader_request = AggregateInitializeReq::new( - task_id, + *task.id(), aggregation_job_id, ().get_encoded(), PartialBatchSelector::new_time_interval(), @@ -1273,7 +1276,7 @@ mod tests { &transcript.prepare_transitions[Role::Helper.index().unwrap()][0], PrepareTransition::Continue(_, prep_share) => prep_share); let helper_response = AggregateInitializeResp::new(Vec::from([PrepareStep::new( - *report.metadata().report_id(), + *report.metadata().id(), PrepareStepResult::Continued(helper_vdaf_msg.get_encoded()), )])); let mocked_aggregate_failure = mock("POST", "/aggregate") @@ -1321,7 +1324,7 @@ mod tests { let want_aggregation_job = AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -1332,9 +1335,9 @@ mod tests { let prep_msg = transcript.prepare_messages[0].clone(); let want_report_aggregation = ReportAggregation::::new( - task_id, + *task.id(), aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), @@ -1342,12 +1345,12 @@ mod tests { let (got_aggregation_job, got_report_aggregation) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let report_id = *report.metadata().report_id(); + let (vdaf, task, report_id) = + (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx .get_aggregation_job::( - &task_id, + task.id(), &aggregation_job_id, ) .await? @@ -1356,7 +1359,7 @@ mod tests { .get_report_aggregation( vdaf.as_ref(), &Role::Leader, - &task_id, + task.id(), &aggregation_job_id, &report_id, ) @@ -1382,18 +1385,21 @@ mod tests { let ds = Arc::new(ds); let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); - let task_id = random(); - let mut task = - Task::new_dummy(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); - task.aggregator_endpoints = vec![ + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .with_aggregator_endpoints(Vec::from([ Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), - ]; + ])) + .build(); let report_metadata = ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -1404,15 +1410,15 @@ mod tests { vdaf.as_ref(), verify_key.as_bytes(), &(), - report_metadata.report_id(), + report_metadata.id(), &0, ); let agg_auth_token = task.primary_aggregator_auth_token(); - let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; + let (leader_hpke_config, _) = task.hpke_keys().iter().next().unwrap().1; let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( - task_id, + task.id(), &report_metadata, &[leader_hpke_config, &helper_hpke_config], &transcript.public_share, @@ -1445,7 +1451,7 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -1455,9 +1461,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), 
aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Waiting(leader_prep_state, Some(prep_msg)), @@ -1472,7 +1478,7 @@ mod tests { }) .await .unwrap(); - assert_eq!(lease.leased().task_id(), &task_id); + assert_eq!(lease.leased().task_id(), task.id()); assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); // Setup: prepare mocked HTTP responses. (first an error response, then a success) @@ -1480,15 +1486,15 @@ mod tests { // It would be nicer to retrieve the request bytes from the mock, then do our own parsing & // verification -- but mockito does not expose this functionality at time of writing.) let leader_request = AggregateContinueReq::new( - task_id, + *task.id(), aggregation_job_id, Vec::from([PrepareStep::new( - *report.metadata().report_id(), + *report.metadata().id(), PrepareStepResult::Continued(prep_msg.get_encoded()), )]), ); let helper_response = AggregateContinueResp::new(Vec::from([PrepareStep::new( - *report.metadata().report_id(), + *report.metadata().id(), PrepareStepResult::Finished, )])); let mocked_aggregate_failure = mock("POST", "/aggregate") @@ -1533,7 +1539,7 @@ mod tests { let want_aggregation_job = AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::Finished, @@ -1543,9 +1549,9 @@ mod tests { PrepareTransition::Finish(leader_output_share) => leader_output_share.clone()); let want_report_aggregation = ReportAggregation::::new( - task_id, + *task.id(), aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Finished(leader_output_share), @@ -1553,28 +1559,27 @@ mod tests { let batch_interval_start = report .metadata() .time() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(); let want_batch_unit_aggregations = Vec::from([BatchUnitAggregation::< PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), batch_interval_start, (), leader_aggregate_share, 1, - ReportIdChecksum::for_report_id(report.metadata().report_id()), + ReportIdChecksum::for_report_id(report.metadata().id()), )]); let (got_aggregation_job, got_report_aggregation, got_batch_unit_aggregations) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let report_metadata = report.metadata().clone(); + let (vdaf, task, report_metadata) = (Arc::clone(&vdaf), task.clone(), report.metadata().clone()); Box::pin(async move { let aggregation_job = tx .get_aggregation_job::( - &task_id, + task.id(), &aggregation_job_id, ) .await? @@ -1583,18 +1588,18 @@ mod tests { .get_report_aggregation( vdaf.as_ref(), &Role::Leader, - &task_id, + task.id(), &aggregation_job_id, - report_metadata.report_id(), + report_metadata.id(), ) .await? 
.unwrap(); let batch_unit_aggregations = tx .get_batch_unit_aggregations_for_task_in_interval::( - &task_id, + task.id(), &Interval::new( - report_metadata.time().to_batch_unit_interval_start(task.min_batch_duration).unwrap(), - task.min_batch_duration).unwrap(), + report_metadata.time().to_batch_unit_interval_start(task.time_precision()).unwrap(), + *task.time_precision()).unwrap(), &()) .await .unwrap(); @@ -1618,18 +1623,21 @@ mod tests { let ds = Arc::new(ds); let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); - let task_id = random(); - let mut task = - Task::new_dummy(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); - task.aggregator_endpoints = vec![ + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .with_aggregator_endpoints(Vec::from([ Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), - ]; + ])) + .build(); let report_metadata = ReportMetadata::new( random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); @@ -1640,14 +1648,14 @@ mod tests { vdaf.as_ref(), verify_key.as_bytes(), &(), - report_metadata.report_id(), + report_metadata.id(), &0, ); - let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; + let (leader_hpke_config, _) = task.hpke_keys().iter().next().unwrap().1; let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( - task_id, + task.id(), &report_metadata, &[leader_hpke_config, &helper_hpke_config], &transcript.public_share, @@ -1657,16 +1665,16 @@ mod tests { let aggregation_job = AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, ); let report_aggregation = ReportAggregation::::new( - task_id, + *task.id(), aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Start, @@ -1694,7 +1702,7 @@ mod tests { }) .await .unwrap(); - assert_eq!(lease.leased().task_id(), &task_id); + assert_eq!(lease.leased().task_id(), task.id()); assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); // Run: create an aggregation job driver & cancel the aggregation job. @@ -1714,12 +1722,12 @@ mod tests { let (got_aggregation_job, got_report_aggregation, got_leases) = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); - let report_id = *report.metadata().report_id(); + let (vdaf, task, report_id) = + (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx .get_aggregation_job::( - &task_id, + task.id(), &aggregation_job_id, ) .await? @@ -1728,7 +1736,7 @@ mod tests { .get_report_aggregation( vdaf.as_ref(), &Role::Leader, - &task_id, + task.id(), &aggregation_job_id, &report_id, ) @@ -1750,7 +1758,7 @@ mod tests { /// Returns a report with the given task ID & metadata values and encrypted input shares /// corresponding to the given HPKE configs & input shares. 
fn generate_report( - task_id: TaskId, + task_id: &TaskId, report_metadata: &ReportMetadata, hpke_configs: &[&HpkeConfig], public_share: &P, @@ -1766,7 +1774,7 @@ mod tests { .map(|role| { hpke::seal( hpke_configs.get(role.index().unwrap()).unwrap(), - &HpkeApplicationInfo::new(Label::InputShare, Role::Client, role), + &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &role), &input_shares .get(role.index().unwrap()) .unwrap() @@ -1778,7 +1786,7 @@ mod tests { .unwrap(); Report::new( - task_id, + *task_id, report_metadata.clone(), public_share, encrypted_input_shares, @@ -1793,19 +1801,22 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let ds = Arc::new(ds); - let task_id = random(); - let mut task = - Task::new_dummy(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); - task.aggregator_endpoints = vec![ + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .with_aggregator_endpoints(Vec::from([ Url::parse("http://irrelevant").unwrap(), // leader URL doesn't matter Url::parse(&mockito::server_url()).unwrap(), - ]; + ])) + .build(); let agg_auth_token = task.primary_aggregator_auth_token(); let aggregation_job_id = random(); let verify_key: VerifyKey = task.primary_vdaf_verify_key().unwrap(); - let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; + let (leader_hpke_config, _) = task.hpke_keys().iter().next().unwrap().1; let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let vdaf = Prio3::new_aes128_count(2).unwrap(); @@ -1813,19 +1824,13 @@ mod tests { random(), clock .now() - .to_batch_unit_interval_start(task.min_batch_duration) + .to_batch_unit_interval_start(task.time_precision()) .unwrap(), Vec::new(), ); - let transcript = run_vdaf( - &vdaf, - verify_key.as_bytes(), - &(), - report_metadata.report_id(), - &0, - ); + let transcript = run_vdaf(&vdaf, verify_key.as_bytes(), &(), report_metadata.id(), &0); let report = generate_report( - task_id, + task.id(), &report_metadata, &[leader_hpke_config, &helper_hpke_config], &transcript.public_share, @@ -1847,7 +1852,7 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -1858,9 +1863,9 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, - *report.metadata().report_id(), + *report.metadata().id(), *report.metadata().time(), 0, ReportAggregationState::Start, @@ -1946,9 +1951,10 @@ mod tests { // Confirm in the database that the job was abandoned. let aggregation_job_after = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { tx.get_aggregation_job::( - &task_id, + task.id(), &aggregation_job_id, ) .await @@ -1960,7 +1966,7 @@ mod tests { assert_eq!( aggregation_job_after, AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::Abandoned, diff --git a/janus_server/src/bin/janus_cli.rs b/janus_server/src/bin/janus_cli.rs index 305a27152..cdcde1f5f 100644 --- a/janus_server/src/bin/janus_cli.rs +++ b/janus_server/src/bin/janus_cli.rs @@ -176,7 +176,7 @@ async fn provision_tasks(datastore: &Datastore, tasks_file: &Path) for task in tasks.iter() { // We attempt to delete the task, but ignore "task not found" errors since // the task not existing is an OK outcome too. 
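// (Editor's note, not part of this patch: in the match below, the first arm treats both a
// successful delete and `MutationTargetNotFound` as acceptable outcomes, which in effect
// makes re-running task provisioning against the same tasks file idempotent; the catch-all
// `err => err?` arm propagates any other datastore error.)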
- match tx.delete_task(&task.id).await { + match tx.delete_task(task.id()).await { Ok(_) | Err(datastore::Error::MutationTargetNotFound) => (), err => err?, } @@ -359,9 +359,8 @@ mod tests { }, config::CommonConfig, datastore::test_util::{ephemeral_datastore, ephemeral_db_handle}, - task::Task, + task::{test_util::TaskBuilder, QueryType}, }; - use rand::random; use ring::aead::{UnboundKey, AES_128_GCM}; use std::{ collections::HashMap, @@ -390,7 +389,7 @@ mod tests { .unwrap(); let expected_datastore_keys = - vec!["datastore-key-1".to_string(), "datastore-key-2".to_string()]; + Vec::from(["datastore-key-1".to_string(), "datastore-key-2".to_string()]); // Keys provided at command line, not present in k8s let mut binary_options = CommonBinaryOptions::default(); @@ -461,16 +460,18 @@ mod tests { #[tokio::test] async fn provision_tasks() { let tasks = Vec::from([ - Task::new_dummy( - random(), + TaskBuilder::new( + QueryType::TimeInterval, VdafInstance::Prio3Aes128Count.into(), Role::Leader, - ), - Task::new_dummy( - random(), + ) + .build(), + TaskBuilder::new( + QueryType::TimeInterval, VdafInstance::Prio3Aes128Sum { bits: 64 }.into(), Role::Helper, - ), + ) + .build(), ]); let (ds, _db_handle) = ephemeral_datastore(RealClock::default()).await; @@ -486,13 +487,13 @@ mod tests { super::provision_tasks(&ds, &tasks_path).await.unwrap(); // Verify that the expected tasks were written. - let want_tasks: HashMap<_, _> = tasks.into_iter().map(|task| (task.id, task)).collect(); + let want_tasks: HashMap<_, _> = tasks.into_iter().map(|task| (*task.id(), task)).collect(); let got_tasks = ds .run_tx(|tx| Box::pin(async move { tx.get_tasks().await })) .await .unwrap() .into_iter() - .map(|task| (task.id, task)) + .map(|task| (*task.id(), task)) .collect(); assert_eq!(want_tasks, got_tasks); } diff --git a/janus_server/src/datastore.rs b/janus_server/src/datastore.rs index 28d07f8db..873118a79 100644 --- a/janus_server/src/datastore.rs +++ b/janus_server/src/datastore.rs @@ -9,7 +9,7 @@ use self::models::{ use crate::aggregator::aggregation_job_creator::VdafHasAggregationParameter; use crate::{ messages::{IntervalExt, TimeExt}, - task::{self, Task, VdafInstance}, + task::{self, QueryType, Task, VdafInstance}, SecretBytes, }; use anyhow::anyhow; @@ -167,33 +167,40 @@ impl Transaction<'_, C> { /// Writes a task into the datastore. #[tracing::instrument(skip(self), err)] pub async fn put_task(&self, task: &Task) -> Result<(), Error> { - let endpoints: Vec<_> = task.aggregator_endpoints.iter().map(Url::as_str).collect(); + let endpoints: Vec<_> = task + .aggregator_endpoints() + .iter() + .map(Url::as_str) + .collect(); // Main task insert. 
let stmt = self .tx .prepare_cached( - "INSERT INTO tasks (task_id, aggregator_role, aggregator_endpoints, vdaf, - max_batch_lifetime, min_batch_size, min_batch_duration, tolerable_clock_skew, - collector_hpke_config) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", + "INSERT INTO tasks (task_id, aggregator_role, aggregator_endpoints, query_type, + vdaf, max_batch_query_count, task_expiration, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)", ) .await?; self.tx .execute( &stmt, &[ - /* task_id */ &task.id.as_ref(), - /* aggregator_role */ &AggregatorRole::from_role(task.role)?, + /* task_id */ &task.id().as_ref(), + /* aggregator_role */ &AggregatorRole::from_role(*task.role())?, /* aggregator_endpoints */ &endpoints, - /* vdaf */ &Json(&task.vdaf), - /* max_batch_lifetime */ &i64::try_from(task.max_batch_lifetime)?, - /* min_batch_size */ &i64::try_from(task.min_batch_size)?, - /* min_batch_duration */ - &i64::try_from(task.min_batch_duration.as_seconds())?, + /* query_type */ &Json(task.query_type()), + /* vdaf */ &Json(task.vdaf()), + /* max_batch_query_count */ + &i64::try_from(task.max_batch_query_count())?, + /* task_expiration */ &task.task_expiration().as_naive_date_time(), + /* min_batch_size */ &i64::try_from(task.min_batch_size())?, + /* time_precision */ + &i64::try_from(task.time_precision().as_seconds())?, /* tolerable_clock_skew */ - &i64::try_from(task.tolerable_clock_skew.as_seconds())?, - /* collector_hpke_config */ &task.collector_hpke_config.get_encoded(), + &i64::try_from(task.tolerable_clock_skew().as_seconds())?, + /* collector_hpke_config */ &task.collector_hpke_config().get_encoded(), ], ) .await?; @@ -201,11 +208,11 @@ impl Transaction<'_, C> { // Aggregator auth tokens. let mut aggregator_auth_token_ords = Vec::new(); let mut aggregator_auth_tokens = Vec::new(); - for (ord, token) in task.aggregator_auth_tokens.iter().enumerate() { + for (ord, token) in task.aggregator_auth_tokens().iter().enumerate() { let ord = i64::try_from(ord)?; let mut row_id = [0; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task.id.as_ref()); + row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); let encrypted_aggregator_auth_token = self.crypter.encrypt( @@ -224,7 +231,7 @@ impl Transaction<'_, C> { ) .await?; let aggregator_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id.as_ref(), + /* task_id */ &task.id().as_ref(), /* ords */ &aggregator_auth_token_ords, /* tokens */ &aggregator_auth_tokens, ]; @@ -233,11 +240,11 @@ impl Transaction<'_, C> { // Collector auth tokens. 
let mut collector_auth_token_ords = Vec::new(); let mut collector_auth_tokens = Vec::new(); - for (ord, token) in task.collector_auth_tokens.iter().enumerate() { + for (ord, token) in task.collector_auth_tokens().iter().enumerate() { let ord = i64::try_from(ord)?; let mut row_id = [0; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task.id.as_ref()); + row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); row_id[TaskId::LEN..].copy_from_slice(&ord.to_be_bytes()); let encrypted_collector_auth_token = self.crypter.encrypt( @@ -255,7 +262,7 @@ impl Transaction<'_, C> { SELECT (SELECT id FROM tasks WHERE task_id = $1), * FROM UNNEST($2::BIGINT[], $3::BYTEA[])" ).await?; let collector_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id.as_ref(), + /* task_id */ &task.id().as_ref(), /* ords */ &collector_auth_token_ords, /* tokens */ &collector_auth_tokens, ]; @@ -265,9 +272,9 @@ impl Transaction<'_, C> { let mut hpke_config_ids: Vec = Vec::new(); let mut hpke_configs: Vec> = Vec::new(); let mut hpke_private_keys: Vec> = Vec::new(); - for (hpke_config, hpke_private_key) in task.hpke_keys.values() { + for (hpke_config, hpke_private_key) in task.hpke_keys().values() { let mut row_id = [0u8; TaskId::LEN + size_of::()]; - row_id[..TaskId::LEN].copy_from_slice(task.id.as_ref()); + row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref()); row_id[TaskId::LEN..].copy_from_slice(&u8::from(*hpke_config.id()).to_be_bytes()); let encrypted_hpke_private_key = self.crypter.encrypt( @@ -289,7 +296,7 @@ impl Transaction<'_, C> { ) .await?; let hpke_configs_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id.as_ref(), + /* task_id */ &task.id().as_ref(), /* config_id */ &hpke_config_ids, /* configs */ &hpke_configs, /* private_keys */ &hpke_private_keys, @@ -301,9 +308,9 @@ impl Transaction<'_, C> { for vdaf_verify_key in task.vdaf_verify_keys() { let encrypted_vdaf_verify_key = self.crypter.encrypt( "task_vdaf_verify_keys", - task.id.as_ref(), + task.id().as_ref(), "vdaf_verify_key", - vdaf_verify_key.as_bytes(), + vdaf_verify_key.as_ref(), )?; vdaf_verify_keys.push(encrypted_vdaf_verify_key); } @@ -315,7 +322,7 @@ impl Transaction<'_, C> { ) .await?; let vdaf_verify_keys_params: &[&(dyn ToSql + Sync)] = &[ - /* task_id */ &task.id.as_ref(), + /* task_id */ &task.id().as_ref(), /* vdaf_verify_keys */ &vdaf_verify_keys, ]; let vdaf_verify_keys_future = self.tx.execute(&stmt, vdaf_verify_keys_params); @@ -396,8 +403,9 @@ impl Transaction<'_, C> { let stmt = self .tx .prepare_cached( - "SELECT aggregator_role, aggregator_endpoints, vdaf, max_batch_lifetime, - min_batch_size, min_batch_duration, tolerable_clock_skew, collector_hpke_config + "SELECT aggregator_role, aggregator_endpoints, query_type, vdaf, + max_batch_query_count, task_expiration, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config FROM tasks WHERE task_id = $1", ) .await?; @@ -472,9 +480,9 @@ impl Transaction<'_, C> { let stmt = self .tx .prepare_cached( - "SELECT task_id, aggregator_role, aggregator_endpoints, vdaf, - max_batch_lifetime, min_batch_size, min_batch_duration, - tolerable_clock_skew, collector_hpke_config + "SELECT task_id, aggregator_role, aggregator_endpoints, query_type, vdaf, + max_batch_query_count, task_expiration, min_batch_size, time_precision, + tolerable_clock_skew, collector_hpke_config FROM tasks", ) .await?; @@ -607,16 +615,17 @@ impl Transaction<'_, C> { ) -> Result { // Scalar task parameters. 
let aggregator_role: AggregatorRole = row.get("aggregator_role"); - let endpoints: Vec = row.get("aggregator_endpoints"); - let endpoints = endpoints + let endpoints = row + .get::<_, Vec>("aggregator_endpoints") .into_iter() .map(|endpoint| Ok(Url::parse(&endpoint)?)) .collect::>()?; + let query_type = row.try_get::<_, Json>("query_type")?.0; let vdaf = row.try_get::<_, Json>("vdaf")?.0; - let max_batch_lifetime = row.get_bigint_and_convert("max_batch_lifetime")?; + let max_batch_query_count = row.get_bigint_and_convert("max_batch_query_count")?; + let task_expiration = Time::from_naive_date_time(&row.get("task_expiration")); let min_batch_size = row.get_bigint_and_convert("min_batch_size")?; - let min_batch_duration = - Duration::from_seconds(row.get_bigint_and_convert("min_batch_duration")?); + let time_precision = Duration::from_seconds(row.get_bigint_and_convert("time_precision")?); let tolerable_clock_skew = Duration::from_seconds(row.get_bigint_and_convert("tolerable_clock_skew")?); let collector_hpke_config = HpkeConfig::get_decoded(row.get("collector_hpke_config"))?; @@ -693,12 +702,14 @@ impl Transaction<'_, C> { Ok(Task::new( *task_id, endpoints, + query_type, vdaf, aggregator_role.as_role(), vdaf_verify_keys, - max_batch_lifetime, + max_batch_query_count, + task_expiration, min_batch_size, - min_batch_duration, + time_precision, tolerable_clock_skew, collector_hpke_config, aggregator_auth_tokens, @@ -896,7 +907,7 @@ impl Transaction<'_, C> { &stmt, &[ /* task_id */ &report.task_id().get_encoded(), - /* report_id */ &report.metadata().report_id().as_ref(), + /* report_id */ &report.metadata().id().as_ref(), /* client_timestamp */ &report.metadata().time().as_naive_date_time(), /* extensions */ &encoded_extensions, /* public_share */ &report.public_share(), @@ -962,7 +973,7 @@ impl Transaction<'_, C> { &stmt, &[ /* task_id */ &task_id.get_encoded(), - /* report_id */ &report_share.metadata().report_id().as_ref(), + /* report_id */ &report_share.metadata().id().as_ref(), /* client_timestamp */ &report_share.metadata().time().as_naive_date_time(), ], @@ -1168,7 +1179,7 @@ impl Transaction<'_, C> { &stmt, &[ /* task_id */ &aggregation_job.task_id().as_ref(), - /* aggregation_job_id */ &aggregation_job.aggregation_job_id().as_ref(), + /* aggregation_job_id */ &aggregation_job.id().as_ref(), /* aggregation_param */ &aggregation_job.aggregation_parameter().get_encoded(), /* state */ &aggregation_job.state(), @@ -1204,7 +1215,7 @@ impl Transaction<'_, C> { /* state */ &aggregation_job.state(), /* task_id */ &aggregation_job.task_id().as_ref(), /* aggregation_job_id */ - &aggregation_job.aggregation_job_id().as_ref(), + &aggregation_job.id().as_ref(), ], ) .await?, @@ -2069,13 +2080,13 @@ ORDER BY id DESC let stmt = self .tx .prepare_cached( - "WITH tasks AS (SELECT id, min_batch_duration FROM tasks WHERE task_id = $1) + "WITH tasks AS (SELECT id, time_precision FROM tasks WHERE task_id = $1) SELECT unit_interval_start, aggregate_share, report_count, checksum FROM batch_unit_aggregations WHERE task_id = (SELECT id FROM tasks) AND unit_interval_start >= $2 - AND (unit_interval_start + (SELECT min_batch_duration FROM tasks) * interval '1 second') <= $3 + AND (unit_interval_start + (SELECT time_precision FROM tasks) * interval '1 second') <= $3 AND aggregation_param = $4", ) .await?; @@ -2676,7 +2687,7 @@ pub mod models { } /// Returns the aggregation job ID associated with this aggregation job. 
- pub fn aggregation_job_id(&self) -> &AggregationJobId { + pub fn id(&self) -> &AggregationJobId { &self.aggregation_job_id } @@ -3138,7 +3149,7 @@ pub mod models { /// BatchUnitAggregation corresponds to a row in the `batch_unit_aggregations` table and /// represents the possibly-ongoing aggregation of the set of input shares that fall within the - /// interval defined by `unit_interval_start` and the relevant task's `min_batch_duration`. + /// interval defined by `unit_interval_start` and the relevant task's `time_precision`. /// This is the finest-grained possible aggregate share we can emit for this task, hence "batch /// unit". The aggregate share constructed to service a collect or aggregate share request /// consists of one or more `BatchUnitAggregation`s merged together. @@ -3152,7 +3163,7 @@ pub mod models { task_id: TaskId, /// This is an aggregation over report shares whose timestamp falls within the interval /// starting at this time and of duration equal to the corresponding task's - /// `min_batch_duration`. `unit_interval_start` is aligned to `min_batch_duration`. + /// `time_precision`. `unit_interval_start` is a multiple of `time_precision`. unit_interval_start: Time, /// The VDAF aggregation parameter used to prepare and accumulate input shares. #[derivative(Debug = "ignore")] @@ -3892,33 +3903,41 @@ pub mod test_util { #[cfg(test)] mod tests { - use super::*; use crate::{ datastore::{ - models::{AggregationJobState, CollectJobState}, + models::{ + AcquiredAggregationJob, AggregateShareJob, AggregationJob, AggregationJobState, + BatchUnitAggregation, CollectJob, CollectJobState, Lease, ReportAggregation, + ReportAggregationState, SqlInterval, + }, test_util::{ephemeral_datastore, generate_aead_key}, + Crypter, Error, }, messages::{test_util::new_dummy_report, DurationExt, TimeExt}, - task::{VdafInstance, PRIO3_AES128_VERIFY_KEY_LENGTH}, + task::{test_util::TaskBuilder, QueryType, Task, PRIO3_AES128_VERIFY_KEY_LENGTH}, }; use assert_matches::assert_matches; use chrono::NaiveDate; use futures::future::try_join_all; use janus_core::{ hpke::{self, associated_data_for_aggregate_share, HpkeApplicationInfo, Label}, + task::VdafInstance, test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, }, - time::{MockClock, TimeExt as CoreTimeExt}, + time::{Clock, MockClock, TimeExt as CoreTimeExt}, }; use janus_messages::{ - query_type::TimeInterval, Duration, ExtensionType, HpkeConfigId, Interval, - ReportShareError, Role, Time, + query_type::TimeInterval, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, + Interval, Report, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, + ReportShareError, Role, TaskId, Time, }; use prio::{ + codec::{Decode, Encode}, field::{Field128, Field64}, vdaf::{ + self, poplar1::{IdpfInput, Poplar1, ToyIdpf}, prg::PrgAes128, prio3::{Prio3, Prio3Aes128Count}, @@ -3931,6 +3950,9 @@ mod tests { iter, sync::Arc, }; + use uuid::Uuid; + + use super::{models::AcquiredCollectJob, Datastore}; const DUMMY_VERIFY_KEY_LENGTH: usize = dummy_vdaf::Vdaf::VERIFY_KEY_LENGTH; @@ -3939,72 +3961,52 @@ mod tests { install_test_trace_subscriber(); let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; - let values = [ - ( - random(), - janus_core::task::VdafInstance::Prio3Aes128Count, - Role::Leader, - ), + // Insert tasks, check that they can be retrieved by ID. 
+ let mut want_tasks = HashMap::new(); + for (vdaf, role) in [ + (VdafInstance::Prio3Aes128Count, Role::Leader), ( - random(), - janus_core::task::VdafInstance::Prio3Aes128CountVec { length: 8 }, + VdafInstance::Prio3Aes128CountVec { length: 8 }, Role::Leader, ), ( - random(), - janus_core::task::VdafInstance::Prio3Aes128CountVec { length: 64 }, - Role::Helper, - ), - ( - random(), - janus_core::task::VdafInstance::Prio3Aes128Sum { bits: 64 }, + VdafInstance::Prio3Aes128CountVec { length: 64 }, Role::Helper, ), + (VdafInstance::Prio3Aes128Sum { bits: 64 }, Role::Helper), + (VdafInstance::Prio3Aes128Sum { bits: 32 }, Role::Helper), ( - random(), - janus_core::task::VdafInstance::Prio3Aes128Sum { bits: 32 }, - Role::Helper, - ), - ( - random(), - janus_core::task::VdafInstance::Prio3Aes128Histogram { - buckets: vec![0, 100, 200, 400], + VdafInstance::Prio3Aes128Histogram { + buckets: Vec::from([0, 100, 200, 400]), }, Role::Leader, ), ( - random(), - janus_core::task::VdafInstance::Prio3Aes128Histogram { - buckets: vec![0, 25, 50, 75, 100], + VdafInstance::Prio3Aes128Histogram { + buckets: Vec::from([0, 25, 50, 75, 100]), }, Role::Leader, ), - ( - random(), - janus_core::task::VdafInstance::Poplar1 { bits: 8 }, - Role::Helper, - ), - ( - random(), - janus_core::task::VdafInstance::Poplar1 { bits: 64 }, - Role::Helper, - ), - ]; - - // Insert tasks, check that they can be retrieved by ID. - let mut want_tasks = HashMap::new(); - for (task_id, vdaf, role) in values { - let task = Task::new_dummy(task_id, vdaf.into(), role); - want_tasks.insert(task_id, task.clone()); + (VdafInstance::Poplar1 { bits: 8 }, Role::Helper), + (VdafInstance::Poplar1 { bits: 64 }, Role::Helper), + ] { + let task = TaskBuilder::new(QueryType::TimeInterval, vdaf.into(), role).build(); + want_tasks.insert(*task.id(), task.clone()); let err = ds - .run_tx(|tx| Box::pin(async move { tx.delete_task(&task_id).await })) + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.delete_task(task.id()).await }) + }) .await .unwrap_err(); assert_matches!(err, Error::MutationTargetNotFound); let retrieved_task = ds - .run_tx(|tx| Box::pin(async move { tx.get_task(&task_id).await })) + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.get_task(task.id()).await }) + }) .await .unwrap(); assert_eq!(None, retrieved_task); @@ -4012,23 +4014,35 @@ mod tests { ds.put_task(&task).await.unwrap(); let retrieved_task = ds - .run_tx(|tx| Box::pin(async move { tx.get_task(&task_id).await })) + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.get_task(task.id()).await }) + }) .await .unwrap(); assert_eq!(Some(&task), retrieved_task.as_ref()); - ds.run_tx(|tx| Box::pin(async move { tx.delete_task(&task_id).await })) - .await - .unwrap(); + ds.run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.delete_task(task.id()).await }) + }) + .await + .unwrap(); let retrieved_task = ds - .run_tx(|tx| Box::pin(async move { tx.get_task(&task_id).await })) + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.get_task(task.id()).await }) + }) .await .unwrap(); assert_eq!(None, retrieved_task); let err = ds - .run_tx(|tx| Box::pin(async move { tx.delete_task(&task_id).await })) + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.delete_task(task.id()).await }) + }) .await .unwrap_err(); assert_matches!(err, Error::MutationTargetNotFound); @@ -4039,10 +4053,13 @@ mod tests { ds.put_task(&task).await.unwrap(); let retrieved_task = ds - .run_tx(|tx| 
Box::pin(async move { tx.get_task(&task_id).await })) + .run_tx(|tx| { + let task = task.clone(); + Box::pin(async move { tx.get_task(task.id()).await }) + }) .await .unwrap(); - assert_eq!(Some(&task), retrieved_task.as_ref()); + assert_eq!(Some(task), retrieved_task); } let got_tasks: HashMap = ds @@ -4050,7 +4067,7 @@ mod tests { .await .unwrap() .into_iter() - .map(|task| (task.id, task)) + .map(|task| (*task.id(), task)) .collect(); assert_eq!(want_tasks, got_tasks); } @@ -4060,15 +4077,21 @@ mod tests { install_test_trace_subscriber(); let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let report = Report::new( - random(), + *task.id(), ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), Time::from_seconds_since_epoch(12345), - vec![ + Vec::from([ Extension::new(ExtensionType::Tbd, Vec::from("extension_data_0")), Extension::new(ExtensionType::Tbd, Vec::from("extension_data_1")), - ], + ]), ), Vec::from("public_share"), Vec::from([ @@ -4086,14 +4109,9 @@ mod tests { ); ds.run_tx(|tx| { - let report = report.clone(); + let (task, report) = (task.clone(), report.clone()); Box::pin(async move { - tx.put_task(&Task::new_dummy( - *report.task_id(), - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; tx.put_client_report(&report).await }) }) @@ -4102,9 +4120,11 @@ mod tests { let retrieved_report = ds .run_tx(|tx| { - let task_id = *report.task_id(); - let report_id = *report.metadata().report_id(); - Box::pin(async move { tx.get_client_report(&task_id, &report_id).await }) + let report = report.clone(); + Box::pin(async move { + tx.get_client_report(report.task_id(), report.metadata().id()) + .await + }) }) .await .unwrap(); @@ -4140,31 +4160,42 @@ mod tests { let when = MockClock::default() .now() - .to_batch_unit_interval_start(Duration::from_seconds(1000)) + .to_batch_unit_interval_start(&Duration::from_seconds(1000)) .unwrap(); - let task_id = random(); - let unrelated_task_id = random(); + + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); + let unrelated_task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let first_unaggregated_report = Report::new( - task_id, + *task.id(), ReportMetadata::new(random(), when, Vec::new()), Vec::new(), Vec::new(), ); let second_unaggregated_report = Report::new( - task_id, + *task.id(), ReportMetadata::new(random(), when, Vec::new()), Vec::new(), Vec::new(), ); let aggregated_report = Report::new( - task_id, + *task.id(), ReportMetadata::new(random(), when, Vec::new()), Vec::new(), Vec::new(), ); let unrelated_report = Report::new( - unrelated_task_id, + *unrelated_task.id(), ReportMetadata::new(random(), when, Vec::new()), Vec::new(), Vec::new(), @@ -4173,11 +4204,15 @@ mod tests { // Set up state. 
ds.run_tx(|tx| { let ( + task, + unrelated_task, first_unaggregated_report, second_unaggregated_report, aggregated_report, unrelated_report, ) = ( + task.clone(), + unrelated_task.clone(), first_unaggregated_report.clone(), second_unaggregated_report.clone(), aggregated_report.clone(), @@ -4185,18 +4220,8 @@ mod tests { ); Box::pin(async move { - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; - tx.put_task(&Task::new_dummy( - unrelated_task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; + tx.put_task(&unrelated_task).await?; tx.put_client_report(&first_unaggregated_report).await?; tx.put_client_report(&second_unaggregated_report).await?; @@ -4208,7 +4233,7 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -4217,9 +4242,9 @@ mod tests { tx .put_report_aggregation( &ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, - *aggregated_report.metadata().report_id(), + *aggregated_report.metadata().id(), *aggregated_report.metadata().time(), 0, ReportAggregationState::< @@ -4237,8 +4262,9 @@ mod tests { // Run query & verify results. let got_reports = HashSet::from_iter( ds.run_tx(|tx| { + let task = task.clone(); Box::pin(async move { - tx.get_unaggregated_client_report_ids_for_task(&task_id) + tx.get_unaggregated_client_report_ids_for_task(task.id()) .await }) }) @@ -4250,11 +4276,11 @@ mod tests { got_reports, HashSet::from([ ( - *first_unaggregated_report.metadata().report_id(), + *first_unaggregated_report.metadata().id(), *first_unaggregated_report.metadata().time(), ), ( - *second_unaggregated_report.metadata().report_id(), + *second_unaggregated_report.metadata().id(), *second_unaggregated_report.metadata().time(), ), ]), @@ -4266,29 +4292,39 @@ mod tests { install_test_trace_subscriber(); let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; - let task_id = random(); - let unrelated_task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .build(); + let unrelated_task = TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .build(); let first_unaggregated_report = Report::new( - task_id, + *task.id(), ReportMetadata::new(random(), Time::from_seconds_since_epoch(12345), Vec::new()), Vec::new(), Vec::new(), ); let second_unaggregated_report = Report::new( - task_id, + *task.id(), ReportMetadata::new(random(), Time::from_seconds_since_epoch(12346), Vec::new()), Vec::new(), Vec::new(), ); let aggregated_report = Report::new( - task_id, + *task.id(), ReportMetadata::new(random(), Time::from_seconds_since_epoch(12347), Vec::new()), Vec::new(), Vec::new(), ); let unrelated_report = Report::new( - unrelated_task_id, + *unrelated_task.id(), ReportMetadata::new(random(), Time::from_seconds_since_epoch(12348), Vec::new()), Vec::new(), Vec::new(), @@ -4297,11 +4333,15 @@ mod tests { // Set up state. 
ds.run_tx(|tx| { let ( + task, + unrelated_task, first_unaggregated_report, second_unaggregated_report, aggregated_report, unrelated_report, ) = ( + task.clone(), + unrelated_task.clone(), first_unaggregated_report.clone(), second_unaggregated_report.clone(), aggregated_report.clone(), @@ -4309,14 +4349,8 @@ mod tests { ); Box::pin(async move { - tx.put_task(&Task::new_dummy(task_id, VdafInstance::Fake, Role::Leader)) - .await?; - tx.put_task(&Task::new_dummy( - unrelated_task_id, - VdafInstance::Fake, - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; + tx.put_task(&unrelated_task).await?; tx.put_client_report(&first_unaggregated_report).await?; tx.put_client_report(&second_unaggregated_report).await?; @@ -4326,7 +4360,7 @@ mod tests { // There are no client reports submitted under this task, so we shouldn't see // this aggregation parameter at all. tx.put_collect_job(&CollectJob::new( - unrelated_task_id, + *unrelated_task.id(), Uuid::new_v4(), Interval::new( Time::from_seconds_since_epoch(0), @@ -4346,8 +4380,9 @@ mod tests { // collect requests. let got_reports = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { - tx.get_unaggregated_client_report_ids_by_collect_for_task::(&task_id) + tx.get_unaggregated_client_report_ids_by_collect_for_task::(task.id()) .await }) }) @@ -4357,11 +4392,14 @@ mod tests { // Add collect jobs, and mark one report as having already been aggregated once. ds.run_tx(|tx| { - let aggregated_report_time = *aggregated_report.metadata().time(); - let aggregated_report_id = *aggregated_report.metadata().report_id(); + let (task, aggregated_report_id, aggregated_report_time) = ( + task.clone(), + *aggregated_report.metadata().id(), + *aggregated_report.metadata().time(), + ); Box::pin(async move { tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), Interval::new( Time::from_seconds_since_epoch(0), @@ -4373,7 +4411,7 @@ mod tests { )) .await?; tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), Interval::new( Time::from_seconds_since_epoch(0), @@ -4387,7 +4425,7 @@ mod tests { // No reports fall in this interval, so we shouldn't see it's aggregation // parameter at all. tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), Interval::new( Time::from_seconds_since_epoch(8 * 3600), @@ -4402,7 +4440,7 @@ mod tests { let aggregation_job_id = random(); tx.put_aggregation_job( &AggregationJob::::new( - task_id, + *task.id(), aggregation_job_id, AggregationParam(0), AggregationJobState::InProgress, @@ -4413,7 +4451,7 @@ mod tests { DUMMY_VERIFY_KEY_LENGTH, dummy_vdaf::Vdaf, >::new( - task_id, + *task.id(), aggregation_job_id, aggregated_report_id, aggregated_report_time, @@ -4430,8 +4468,9 @@ mod tests { // and three with another. 
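// Concretely: AggregationParam(0) should return the two reports not yet aggregated under it,
// while AggregationParam(1) should return all three of this task's reports, since the third
// report has so far only been aggregated under AggregationParam(0). That is five
// (report ID, time, aggregation parameter) tuples in total, listed in expected_reports below.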
let mut got_reports = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { - tx.get_unaggregated_client_report_ids_by_collect_for_task::(&task_id) + tx.get_unaggregated_client_report_ids_by_collect_for_task::(task.id()) .await }) }) @@ -4440,27 +4479,27 @@ mod tests { let mut expected_reports = Vec::from([ ( - *first_unaggregated_report.metadata().report_id(), + *first_unaggregated_report.metadata().id(), *first_unaggregated_report.metadata().time(), AggregationParam(0), ), ( - *first_unaggregated_report.metadata().report_id(), + *first_unaggregated_report.metadata().id(), *first_unaggregated_report.metadata().time(), AggregationParam(1), ), ( - *second_unaggregated_report.metadata().report_id(), + *second_unaggregated_report.metadata().id(), *second_unaggregated_report.metadata().time(), AggregationParam(0), ), ( - *second_unaggregated_report.metadata().report_id(), + *second_unaggregated_report.metadata().id(), *second_unaggregated_report.metadata().time(), AggregationParam(1), ), ( - *aggregated_report.metadata().report_id(), + *aggregated_report.metadata().id(), *aggregated_report.metadata().time(), AggregationParam(1), ), @@ -4472,9 +4511,10 @@ mod tests { // Add overlapping collect jobs with repeated aggregation parameters. Make sure we don't // repeat result tuples, which could lead to double counting in batch unit aggregations. ds.run_tx(|tx| { + let task = task.clone(); Box::pin(async move { tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), Interval::new( Time::from_seconds_since_epoch(0), @@ -4486,7 +4526,7 @@ mod tests { )) .await?; tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), Interval::new( Time::from_seconds_since_epoch(0), @@ -4506,8 +4546,9 @@ mod tests { // Verify that we get the same result. 
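// That is, the same five (report ID, time, aggregation parameter) tuples as in expected_reports
// above: the overlapping collect jobs with repeated aggregation parameters must not introduce
// duplicate tuples, which would otherwise double count reports in batch unit aggregations.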
let mut got_reports = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { - tx.get_unaggregated_client_report_ids_by_collect_for_task::(&task_id) + tx.get_unaggregated_client_report_ids_by_collect_for_task::(task.id()) .await }) }) @@ -4522,7 +4563,12 @@ mod tests { install_test_trace_subscriber(); let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let report_share = ReportShare::new( ReportMetadata::new( ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), @@ -4542,18 +4588,13 @@ mod tests { let got_report_share_exists = ds .run_tx(|tx| { - let report_share = report_share.clone(); + let (task, report_share) = (task.clone(), report_share.clone()); Box::pin(async move { - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; let report_share_exists = tx - .check_report_share_exists(&task_id, report_share.metadata().report_id()) + .check_report_share_exists(task.id(), report_share.metadata().id()) .await?; - tx.put_report_share(&task_id, &report_share).await?; + tx.put_report_share(task.id(), &report_share).await?; Ok(report_share_exists) }) }) @@ -4563,9 +4604,9 @@ mod tests { let (got_report_share_exists, got_task_id, got_extensions, got_input_shares) = ds .run_tx(|tx| { - let report_share_metadata = report_share.metadata().clone(); + let (task, report_share_metadata) = (task.clone(), report_share.metadata().clone()); Box::pin(async move { - let report_share_exists = tx.check_report_share_exists(&task_id, report_share_metadata.report_id()).await?; + let report_share_exists = tx.check_report_share_exists(task.id(), report_share_metadata.id()).await?; let row = tx .tx .query_one( @@ -4573,7 +4614,7 @@ mod tests { FROM client_reports JOIN tasks ON tasks.id = client_reports.task_id WHERE report_id = $1 AND client_timestamp = $2", &[ - /* report_id */ &report_share_metadata.report_id().as_ref(), + /* report_id */ &report_share_metadata.id().as_ref(), /* client_timestamp */ &report_share_metadata.time().as_naive_date_time(), ], ) @@ -4591,7 +4632,7 @@ mod tests { .unwrap(); assert!(got_report_share_exists); - assert_eq!(task_id, got_task_id); + assert_eq!(task.id(), &got_task_id); assert!(got_extensions.is_none()); assert!(got_input_shares.is_none()); } @@ -4605,8 +4646,14 @@ mod tests { // better exercising the serialization/deserialization roundtrip of the aggregation_param. 
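// Unlike Prio3, whose aggregation parameter is the unit type (), Poplar1's aggregation parameter
// here is a BTreeSet of IdpfInput values, so writing and re-reading the aggregation job genuinely
// exercises the parameter's encode/decode path.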
const PRG_SEED_SIZE: usize = 16; type ToyPoplar1 = Poplar1, PrgAes128, PRG_SEED_SIZE>; + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 64 }.into(), + Role::Leader, + ) + .build(); let aggregation_job = AggregationJob::::new( - random(), + *task.id(), random(), BTreeSet::from([ IdpfInput::new("abc".as_bytes(), 0).unwrap(), @@ -4616,14 +4663,9 @@ mod tests { ); ds.run_tx(|tx| { - let aggregation_job = aggregation_job.clone(); + let (task, aggregation_job) = (task.clone(), aggregation_job.clone()); Box::pin(async move { - tx.put_task(&Task::new_dummy( - *aggregation_job.task_id(), - janus_core::task::VdafInstance::Poplar1 { bits: 64 }.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; tx.put_aggregation_job(&aggregation_job).await }) }) @@ -4634,11 +4676,8 @@ mod tests { .run_tx(|tx| { let aggregation_job = aggregation_job.clone(); Box::pin(async move { - tx.get_aggregation_job( - aggregation_job.task_id(), - aggregation_job.aggregation_job_id(), - ) - .await + tx.get_aggregation_job(aggregation_job.task_id(), aggregation_job.id()) + .await }) }) .await @@ -4659,11 +4698,8 @@ mod tests { .run_tx(|tx| { let aggregation_job = aggregation_job.clone(); Box::pin(async move { - tx.get_aggregation_job( - aggregation_job.task_id(), - aggregation_job.aggregation_job_id(), - ) - .await + tx.get_aggregation_job(aggregation_job.task_id(), aggregation_job.id()) + .await }) }) .await @@ -4679,7 +4715,12 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; const AGGREGATION_JOB_COUNT: usize = 10; - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let mut aggregation_job_ids: Vec<_> = thread_rng() .sample_iter(Standard) .take(AGGREGATION_JOB_COUNT) @@ -4687,22 +4728,17 @@ mod tests { aggregation_job_ids.sort(); ds.run_tx(|tx| { - let aggregation_job_ids = aggregation_job_ids.clone(); + let (task, aggregation_job_ids) = (task.clone(), aggregation_job_ids.clone()); Box::pin(async move { // Write a few aggregation jobs we expect to be able to retrieve with // acquire_incomplete_aggregation_jobs(). - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; for aggregation_job_id in aggregation_job_ids { tx.put_aggregation_job(&AggregationJob::< PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -4715,24 +4751,24 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, random(), (), AggregationJobState::Finished + *task.id(), random(), (), AggregationJobState::Finished )) .await?; // Write an aggregation job for a task that we are taking on the helper role for. // We don't want to retrieve this one, either. 
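// Presumably this is because job acquisition only considers tasks for which this aggregator is
// the leader; the helper does not drive aggregation jobs itself.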
- let helper_task_id = random(); - tx.put_task(&Task::new_dummy( - helper_task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let helper_task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Helper, - )) - .await?; + ) + .build(); + tx.put_task(&helper_task).await?; tx.put_aggregation_job(&AggregationJob::< PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - helper_task_id, + *helper_task.id(), random(), (), AggregationJobState::InProgress, @@ -4784,9 +4820,9 @@ mod tests { .map(|&agg_job_id| { ( AcquiredAggregationJob::new( - task_id, + *task.id(), agg_job_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + VdafInstance::Prio3Aes128Count.into(), ), want_expiry_time, ) @@ -4866,9 +4902,9 @@ mod tests { .map(|&job_id| { ( AcquiredAggregationJob::new( - task_id, + *task.id(), job_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + VdafInstance::Prio3Aes128Count.into(), ), want_expiry_time, ) @@ -4983,9 +5019,14 @@ mod tests { // better exercising the serialization/deserialization roundtrip of the aggregation_param. const PRG_SEED_SIZE: usize = 16; type ToyPoplar1 = Poplar1, PrgAes128, PRG_SEED_SIZE>; - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 64 }.into(), + Role::Leader, + ) + .build(); let first_aggregation_job = AggregationJob::::new( - task_id, + *task.id(), random(), BTreeSet::from([ IdpfInput::new("abc".as_bytes(), 0).unwrap(), @@ -4994,7 +5035,7 @@ mod tests { AggregationJobState::InProgress, ); let second_aggregation_job = AggregationJob::::new( - task_id, + *task.id(), random(), BTreeSet::from([ IdpfInput::new("ghi".as_bytes(), 2).unwrap(), @@ -5004,31 +5045,27 @@ mod tests { ); ds.run_tx(|tx| { - let (first_aggregation_job, second_aggregation_job) = ( + let (task, first_aggregation_job, second_aggregation_job) = ( + task.clone(), first_aggregation_job.clone(), second_aggregation_job.clone(), ); Box::pin(async move { - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Poplar1 { bits: 64 }.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; tx.put_aggregation_job(&first_aggregation_job).await?; tx.put_aggregation_job(&second_aggregation_job).await?; // Also write an unrelated aggregation job with a different task ID to check that it // is not returned. - let unrelated_task_id = random(); - tx.put_task(&Task::new_dummy( - unrelated_task_id, - janus_core::task::VdafInstance::Poplar1 { bits: 64 }.into(), + let unrelated_task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Poplar1 { bits: 64 }.into(), Role::Leader, - )) - .await?; + ) + .build(); + tx.put_task(&unrelated_task).await?; tx.put_aggregation_job(&AggregationJob::::new( - unrelated_task_id, + *unrelated_task.id(), random(), BTreeSet::from([ IdpfInput::new("foo".as_bytes(), 10).unwrap(), @@ -5043,15 +5080,16 @@ mod tests { .unwrap(); // Run. 
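// Both the expected and retrieved job lists are sorted by aggregation job ID before comparison,
// since the datastore presumably makes no ordering guarantee for get_aggregation_jobs_for_task_id.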
- let mut want_agg_jobs = vec![first_aggregation_job, second_aggregation_job]; - want_agg_jobs.sort_by_key(|agg_job| *agg_job.aggregation_job_id()); + let mut want_agg_jobs = Vec::from([first_aggregation_job, second_aggregation_job]); + want_agg_jobs.sort_by_key(|agg_job| *agg_job.id()); let mut got_agg_jobs = ds .run_tx(|tx| { - Box::pin(async move { tx.get_aggregation_jobs_for_task_id(&task_id).await }) + let task = task.clone(); + Box::pin(async move { tx.get_aggregation_jobs_for_task_id(task.id()).await }) }) .await .unwrap(); - got_agg_jobs.sort_by_key(|agg_job| *agg_job.aggregation_job_id()); + got_agg_jobs.sort_by_key(|agg_job| *agg_job.id()); // Verify. assert_eq!(want_agg_jobs, got_agg_jobs); @@ -5073,36 +5111,36 @@ mod tests { ReportAggregationState::Failed(ReportShareError::VdafPrepError), ReportAggregationState::Invalid, ] - .iter() + .into_iter() .enumerate() { - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let aggregation_job_id = random(); let time = Time::from_seconds_since_epoch(12345); let report_id = ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); let report_aggregation = ds .run_tx(|tx| { - let state = state.clone(); + let (task, state) = (task.clone(), state.clone()); Box::pin(async move { - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, )) .await?; tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( ReportMetadata::new(report_id, time, Vec::new()), Vec::from("public_share"), @@ -5116,7 +5154,7 @@ mod tests { .await?; let report_aggregation = ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, report_id, time, @@ -5132,12 +5170,12 @@ mod tests { let got_report_aggregation = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), &Role::Leader, - &task_id, + task.id(), &aggregation_job_id, &report_id, ) @@ -5165,12 +5203,12 @@ mod tests { let got_report_aggregation = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); Box::pin(async move { tx.get_report_aggregation( vdaf.as_ref(), &Role::Leader, - &task_id, + task.id(), &aggregation_job_id, &report_id, ) @@ -5236,27 +5274,29 @@ mod tests { let vdaf = Arc::new(Prio3::new_aes128_count(2).unwrap()); let (prep_state, prep_msg, output_share) = generate_vdaf_values(vdaf.as_ref(), (), 0); - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let aggregation_job_id = random(); let report_aggregations = ds .run_tx(|tx| { - let prep_msg = prep_msg.clone(); - let prep_state = prep_state.clone(); - let output_share = output_share.clone(); - + let (task, prep_msg, prep_state, output_share) = ( + task.clone(), + prep_msg.clone(), + prep_state.clone(), + output_share.clone(), + ); Box::pin(async move { - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; tx.put_aggregation_job(&AggregationJob::< 
PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), aggregation_job_id, (), AggregationJobState::InProgress, @@ -5282,7 +5322,7 @@ mod tests { let time = Time::from_seconds_since_epoch(12345); let report_id = ReportId::from((ord as u128).to_be_bytes()); tx.put_report_share( - &task_id, + task.id(), &ReportShare::new( ReportMetadata::new(report_id, time, Vec::new()), Vec::from("public_share"), @@ -5296,7 +5336,7 @@ mod tests { .await?; let report_aggregation = ReportAggregation::new( - task_id, + *task.id(), aggregation_job_id, report_id, time, @@ -5314,12 +5354,12 @@ mod tests { let got_report_aggregations = ds .run_tx(|tx| { - let vdaf = Arc::clone(&vdaf); + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); Box::pin(async move { tx.get_report_aggregations_for_aggregation_job( vdaf.as_ref(), &Role::Leader, - &task_id, + task.id(), &aggregation_job_id, ) .await @@ -5332,7 +5372,7 @@ mod tests { #[tokio::test] async fn crypter() { - let crypter = Crypter::new(vec![generate_aead_key(), generate_aead_key()]); + let crypter = Crypter::new(Vec::from([generate_aead_key(), generate_aead_key()])); let bad_key = generate_aead_key(); const TABLE: &str = "some_table"; @@ -5374,7 +5414,12 @@ mod tests { async fn lookup_collect_job() { install_test_trace_subscriber(); - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let batch_interval = Interval::new( Time::from_seconds_since_epoch(100), Duration::from_seconds(100), @@ -5390,23 +5435,18 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; ds.run_tx(|tx| { - Box::pin(async move { - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await - }) + let task = task.clone(); + Box::pin(async move { tx.put_task(&task).await }) }) .await .unwrap(); let collect_job_id = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { tx.get_collect_job_id::( - &task_id, + task.id(), &batch_interval, &(), ) @@ -5419,9 +5459,10 @@ mod tests { let collect_job_id = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { let collect_job = CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), batch_interval, (), @@ -5436,9 +5477,10 @@ mod tests { let same_collect_job_id = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { tx.get_collect_job_id::( - &task_id, + task.id(), &batch_interval, &(), ) @@ -5455,11 +5497,12 @@ mod tests { // Check that we can find the collect job by timestamp. 
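// The lookup is exercised two ways below: get_collect_jobs_including_time with a single
// timestamp, and get_collect_jobs_jobs_intersecting_interval with an interval, both scoped to the
// task ID.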
let (collect_jobs_by_time, collect_jobs_by_interval) = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { let collect_jobs_by_time = tx.get_collect_jobs_including_time:: - (&task_id, ×tamp).await?; + (task.id(), ×tamp).await?; let collect_jobs_by_interval = tx.get_collect_jobs_jobs_intersecting_interval:: - (&task_id, &interval).await?; + (task.id(), &interval).await?; Ok((collect_jobs_by_time, collect_jobs_by_interval)) }) }) @@ -5470,7 +5513,7 @@ mod tests { PRIO3_AES128_VERIFY_KEY_LENGTH, Prio3Aes128Count, >::new( - task_id, + *task.id(), collect_job_id, batch_interval, (), @@ -5501,10 +5544,11 @@ mod tests { .unwrap(); let different_collect_job_id = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { let collect_job_id = Uuid::new_v4(); tx.put_collect_job(&CollectJob::new( - task_id, + *task.id(), collect_job_id, different_batch_interval, (), @@ -5538,11 +5582,12 @@ mod tests { // Check that we can find both collect jobs by timestamp. let (mut collect_jobs_by_time, mut collect_jobs_by_interval) = ds .run_tx(|tx| { + let task = task.clone(); Box::pin(async move { let collect_jobs_by_time = tx.get_collect_jobs_including_time:: - (&task_id, ×tamp).await?; + (task.id(), ×tamp).await?; let collect_jobs_by_interval = tx.get_collect_jobs_jobs_intersecting_interval:: - (&task_id, &interval).await?; + (task.id(), &interval).await?; Ok((collect_jobs_by_time, collect_jobs_by_interval)) }) }) @@ -5553,14 +5598,14 @@ mod tests { let mut want_collect_jobs = Vec::from([ CollectJob::::new( - task_id, + *task.id(), collect_job_id, batch_interval, (), CollectJobState::Start, ), CollectJob::::new( - task_id, + *task.id(), different_collect_job_id, different_batch_interval, (), @@ -5577,8 +5622,18 @@ mod tests { async fn get_collect_job_task_id() { install_test_trace_subscriber(); - let first_task_id = random(); - let second_task_id = random(); + let first_task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); + let second_task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let batch_interval = Interval::new( Time::from_seconds_since_epoch(100), Duration::from_seconds(100), @@ -5588,26 +5643,14 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; ds.run_tx(|tx| { + let (first_task, second_task) = (first_task.clone(), second_task.clone()); Box::pin(async move { - tx.put_task(&Task::new_dummy( - first_task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await - .unwrap(); - - tx.put_task(&Task::new_dummy( - second_task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await - .unwrap(); + tx.put_task(&first_task).await.unwrap(); + tx.put_task(&second_task).await.unwrap(); let first_collect_job_id = Uuid::new_v4(); tx.put_collect_job(&CollectJob::new( - first_task_id, + *first_task.id(), first_collect_job_id, batch_interval, (), @@ -5618,7 +5661,7 @@ mod tests { let second_collect_job_id = Uuid::new_v4(); tx.put_collect_job(&CollectJob::new( - second_task_id, + *second_task.id(), second_collect_job_id, batch_interval, (), @@ -5628,16 +5671,18 @@ mod tests { .unwrap(); assert_eq!( - Some(first_task_id), + Some(first_task.id()), tx.get_collect_job_task_id(&first_collect_job_id) .await .unwrap() + .as_ref() ); assert_eq!( - Some(second_task_id), + Some(second_task.id()), tx.get_collect_job_task_id(&second_collect_job_id) .await 
.unwrap() + .as_ref() ); assert_eq!( None, @@ -5655,7 +5700,12 @@ mod tests { async fn get_collect_job() { install_test_trace_subscriber(); - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let first_batch_interval = Interval::new( Time::from_seconds_since_epoch(100), Duration::from_seconds(100), @@ -5670,16 +5720,12 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; ds.run_tx(|tx| { + let task = task.clone(); Box::pin(async move { - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - ); tx.put_task(&task).await.unwrap(); let first_collect_job = CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), first_batch_interval, (), @@ -5688,7 +5734,7 @@ mod tests { tx.put_collect_job(&first_collect_job).await.unwrap(); let second_collect_job = CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), second_batch_interval, (), @@ -5714,14 +5760,18 @@ mod tests { .unwrap(); assert_eq!(second_collect_job, second_collect_job_again); - let leader_aggregate_share = AggregateShare::from(vec![Field64::from(1)]); + let leader_aggregate_share = AggregateShare::from(Vec::from([Field64::from(1)])); let encrypted_helper_aggregate_share = hpke::seal( - &task.collector_hpke_config, - &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector), + task.collector_hpke_config(), + &HpkeApplicationInfo::new( + &Label::AggregateShare, + &Role::Helper, + &Role::Collector, + ), &[0, 1, 2, 3, 4, 5], &associated_data_for_aggregate_share::( - &task.id, + task.id(), &first_batch_interval, ), ) @@ -5762,7 +5812,12 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; - let task_id = random(); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); let abandoned_batch_interval = Interval::new( Time::from_seconds_since_epoch(100), Duration::from_seconds(100), @@ -5775,16 +5830,12 @@ mod tests { .unwrap(); ds.run_tx(|tx| { + let task = task.clone(); Box::pin(async move { - tx.put_task(&Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&task).await?; let abandoned_collect_job = CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), abandoned_batch_interval, (), @@ -5793,7 +5844,7 @@ mod tests { tx.put_collect_job(&abandoned_collect_job).await?; let deleted_collect_job = CollectJob::new( - task_id, + *task.id(), Uuid::new_v4(), deleted_batch_interval, (), @@ -5895,8 +5946,16 @@ mod tests { let mut test_case = test_case.clone(); Box::pin(async move { for task_id in &test_case.task_ids { - tx.put_task(&Task::new_dummy(*task_id, VdafInstance::Fake, Role::Leader)) - .await?; + tx.put_task( + &TaskBuilder::new( + QueryType::TimeInterval, + crate::task::VdafInstance::Fake, + Role::Leader, + ) + .with_id(*task_id) + .build(), + ) + .await?; } for report in &test_case.reports { @@ -5923,8 +5982,8 @@ mod tests { report_count: 1, encrypted_helper_aggregate_share: HpkeCiphertext::new( HpkeConfigId::from(0), - vec![], - vec![], + Vec::new(), + Vec::new(), ), leader_aggregate_share: dummy_vdaf::AggregateShare(0), }, @@ -5973,7 +6032,7 @@ mod tests { AcquiredCollectJob::new( c.task_id, c.collect_job_id.unwrap(), - VdafInstance::Fake, + crate::task::VdafInstance::Fake, ), clock.now().add(&Duration::from_seconds(100)).unwrap(), 
) @@ -6014,7 +6073,7 @@ mod tests { >::new( task_id, aggregation_job_id, - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, // Doesn't matter what state the report aggregation is in @@ -6036,7 +6095,7 @@ mod tests { let collect_job_leases = run_collect_job_acquire_test_case( &ds, CollectJobAcquireTestCase { - task_ids: vec![task_id], + task_ids: Vec::from([task_id]), reports, aggregation_jobs, report_aggregations, @@ -6151,10 +6210,10 @@ mod tests { run_collect_job_acquire_test_case( &ds, CollectJobAcquireTestCase { - task_ids: vec![task_id, other_task_id], - reports: vec![], + task_ids: Vec::from([task_id, other_task_id]), + reports: Vec::new(), aggregation_jobs, - report_aggregations: vec![], + report_aggregations: Vec::new(), collect_job_test_cases, }, ) @@ -6168,7 +6227,7 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let task_id = random(); - let reports = vec![new_dummy_report(task_id, Time::from_seconds_since_epoch(0))]; + let reports = Vec::from([new_dummy_report(task_id, Time::from_seconds_since_epoch(0))]); let aggregation_jobs = Vec::from([AggregationJob::< DUMMY_VERIFY_KEY_LENGTH, @@ -6197,10 +6256,10 @@ mod tests { run_collect_job_acquire_test_case( &ds, CollectJobAcquireTestCase { - task_ids: vec![task_id], + task_ids: Vec::from([task_id]), reports, aggregation_jobs, - report_aggregations: vec![], + report_aggregations: Vec::new(), collect_job_test_cases, }, ) @@ -6214,12 +6273,12 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(clock.clone()).await; let task_id = random(); - let reports = vec![new_dummy_report( + let reports = Vec::from([new_dummy_report( task_id, // Report associated with the aggregation job is outside the collect job's batch // interval Time::from_seconds_since_epoch(200), - )]; + )]); let aggregation_job_id = random(); let aggregation_jobs = Vec::from([AggregationJob::< DUMMY_VERIFY_KEY_LENGTH, @@ -6236,7 +6295,7 @@ mod tests { >::new( task_id, aggregation_job_id, - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, // Shouldn't matter what state the report aggregation is in @@ -6258,7 +6317,7 @@ mod tests { run_collect_job_acquire_test_case( &ds, CollectJobAcquireTestCase { - task_ids: vec![task_id], + task_ids: Vec::from([task_id]), reports, aggregation_jobs, report_aggregations, @@ -6293,7 +6352,7 @@ mod tests { >::new( task_id, aggregation_job_id, - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, @@ -6359,7 +6418,7 @@ mod tests { ReportAggregation::::new( task_id, aggregation_job_ids[0], - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, @@ -6367,7 +6426,7 @@ mod tests { ReportAggregation::::new( task_id, aggregation_job_ids[1], - *reports[1].metadata().report_id(), + *reports[1].metadata().id(), *reports[1].metadata().time(), 0, ReportAggregationState::Start, @@ -6390,7 +6449,7 @@ mod tests { run_collect_job_acquire_test_case( &ds, CollectJobAcquireTestCase { - task_ids: vec![task_id], + task_ids: Vec::from([task_id]), reports, aggregation_jobs, report_aggregations, @@ -6427,7 +6486,7 @@ mod tests { ReportAggregation::::new( task_id, aggregation_job_ids[0], - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, 
ReportAggregationState::Start, @@ -6435,7 +6494,7 @@ mod tests { ReportAggregation::::new( task_id, aggregation_job_ids[1], - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, @@ -6472,7 +6531,7 @@ mod tests { let test_case = setup_collect_job_acquire_test_case( &ds, CollectJobAcquireTestCase { - task_ids: vec![task_id], + task_ids: Vec::from([task_id]), reports, aggregation_jobs, report_aggregations, @@ -6514,7 +6573,7 @@ mod tests { AcquiredCollectJob::new( c.task_id, c.collect_job_id.unwrap(), - VdafInstance::Fake, + crate::task::VdafInstance::Fake, ), clock.now().add(&Duration::from_seconds(100)).unwrap(), ) @@ -6564,7 +6623,7 @@ mod tests { ReportAggregation::::new( task_id, aggregation_job_ids[0], - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, @@ -6572,7 +6631,7 @@ mod tests { ReportAggregation::::new( task_id, aggregation_job_ids[1], - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, @@ -6580,14 +6639,14 @@ mod tests { ReportAggregation::::new( task_id, aggregation_job_ids[2], - *reports[0].metadata().report_id(), + *reports[0].metadata().id(), *reports[0].metadata().time(), 0, ReportAggregationState::Start, ), ]); - let collect_job_test_cases = vec![ + let collect_job_test_cases = Vec::from([ CollectJobTestCase { should_be_acquired: true, task_id, @@ -6624,12 +6683,12 @@ mod tests { collect_job_id: None, state: CollectJobTestCaseState::Deleted, }, - ]; + ]); setup_collect_job_acquire_test_case( &ds, CollectJobAcquireTestCase { - task_ids: vec![task_id], + task_ids: Vec::from([task_id]), reports, aggregation_jobs, report_aggregations, @@ -6660,9 +6719,20 @@ mod tests { const PRG_SEED_SIZE: usize = 16; type ToyPoplar1 = Poplar1, PrgAes128, PRG_SEED_SIZE>; - let task_id = random(); - let other_task_id = random(); - let aggregate_share = AggregateShare::from(vec![Field64::from(17)]); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .with_time_precision(Duration::from_seconds(100)) + .build(); + let other_task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); + let aggregate_share = AggregateShare::from(Vec::from([Field64::from(17)])); let aggregation_param = BTreeSet::from([ IdpfInput::new("abc".as_bytes(), 0).unwrap(), IdpfInput::new("def".as_bytes(), 1).unwrap(), @@ -6671,27 +6741,19 @@ mod tests { let (ds, _db_handle) = ephemeral_datastore(MockClock::default()).await; ds.run_tx(|tx| { - let (aggregate_share, aggregation_param) = - (aggregate_share.clone(), aggregation_param.clone()); + let (task, other_task, aggregate_share, aggregation_param) = ( + task.clone(), + other_task.clone(), + aggregate_share.clone(), + aggregation_param.clone(), + ); Box::pin(async move { - let mut task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - ); - task.min_batch_duration = Duration::from_seconds(100); tx.put_task(&task).await?; - - tx.put_task(&Task::new_dummy( - other_task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), - Role::Leader, - )) - .await?; + tx.put_task(&other_task).await?; let first_batch_unit_aggregation = BatchUnitAggregation::::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(100), aggregation_param.clone(), 
aggregate_share.clone(), @@ -6701,7 +6763,7 @@ mod tests { let second_batch_unit_aggregation = BatchUnitAggregation::::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(150), aggregation_param.clone(), aggregate_share.clone(), @@ -6711,7 +6773,7 @@ mod tests { let third_batch_unit_aggregation = BatchUnitAggregation::::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(200), aggregation_param.clone(), aggregate_share.clone(), @@ -6722,7 +6784,7 @@ mod tests { // Start of this aggregation's interval is before the interval queried below. tx.put_batch_unit_aggregation( &BatchUnitAggregation::::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(25), aggregation_param.clone(), aggregate_share.clone(), @@ -6744,7 +6806,7 @@ mod tests { // Aggregation parameter differs from the one queried below. tx.put_batch_unit_aggregation( &BatchUnitAggregation::::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(100), BTreeSet::from([ IdpfInput::new("gh".as_bytes(), 2).unwrap(), @@ -6760,7 +6822,7 @@ mod tests { // End of this aggregation's interval is after the interval queried below. tx.put_batch_unit_aggregation( &BatchUnitAggregation::::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(250), aggregation_param.clone(), aggregate_share.clone(), @@ -6773,7 +6835,7 @@ mod tests { // Start of this aggregation's interval is after the interval queried below. tx.put_batch_unit_aggregation( &BatchUnitAggregation::::new( - task_id, + *task.id(), Time::from_seconds_since_epoch(400), aggregation_param.clone(), aggregate_share.clone(), @@ -6786,7 +6848,7 @@ mod tests { // Task ID differs from that queried below. tx.put_batch_unit_aggregation( &BatchUnitAggregation::::new( - other_task_id, + *other_task.id(), Time::from_seconds_since_epoch(200), aggregation_param.clone(), aggregate_share.clone(), @@ -6798,7 +6860,7 @@ mod tests { let batch_unit_aggregations = tx .get_batch_unit_aggregations_for_task_in_interval::( - &task_id, + task.id(), &Interval::new( Time::from_seconds_since_epoch(50), Duration::from_seconds(250), @@ -6831,7 +6893,7 @@ mod tests { *first_batch_unit_aggregation.task_id(), *first_batch_unit_aggregation.unit_interval_start(), first_batch_unit_aggregation.aggregation_parameter().clone(), - AggregateShare::from(vec![Field64::from(25)]), + AggregateShare::from(Vec::from([Field64::from(25)])), 1, ReportIdChecksum::get_decoded(&[1; 32]).unwrap(), ); @@ -6841,7 +6903,7 @@ mod tests { let batch_unit_aggregations = tx .get_batch_unit_aggregations_for_task_in_interval::( - &task_id, + task.id(), &Interval::new( Time::from_seconds_since_epoch(50), Duration::from_seconds(250), @@ -6884,15 +6946,14 @@ mod tests { ds.run_tx(|tx| { Box::pin(async move { - let task_id = random(); - let task = Task::new_dummy( - task_id, - janus_core::task::VdafInstance::Prio3Aes128Count.into(), + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), Role::Helper, - ); + ).build(); tx.put_task(&task).await?; - let aggregate_share = AggregateShare::from(vec![Field64::from(17)]); + let aggregate_share = AggregateShare::from(Vec::from([Field64::from(17)])); let batch_interval = Interval::new( Time::from_seconds_since_epoch(100), Duration::from_seconds(100), @@ -6907,7 +6968,7 @@ mod tests { let checksum = ReportIdChecksum::get_decoded(&[1; 32]).unwrap(); let aggregate_share_job = AggregateShareJob::new( - task_id, + *task.id(), batch_interval, (), aggregate_share.clone(), @@ -6923,7 +6984,7 @@ mod tests { let aggregate_share_job_again = tx 
.get_aggregate_share_job::( - &task_id, + task.id(), &batch_interval, &().get_encoded(), ) @@ -6935,7 +6996,7 @@ mod tests { assert!(tx .get_aggregate_share_job::( - &task_id, + task.id(), &other_batch_interval, &().get_encoded(), ) @@ -6946,12 +7007,12 @@ mod tests { let want_aggregate_share_jobs = Vec::from([aggregate_share_job]); let got_aggregate_share_jobs = tx.get_aggregate_share_jobs_including_time::( - &task_id, &Time::from_seconds_since_epoch(150)).await?; + task.id(), &Time::from_seconds_since_epoch(150)).await?; assert_eq!(got_aggregate_share_jobs, want_aggregate_share_jobs); let got_aggregate_share_jobs = tx.get_aggregate_share_jobs_intersecting_interval:: ( - &task_id, + task.id(), &Interval::new( Time::from_seconds_since_epoch(145), Duration::from_seconds(10)) diff --git a/janus_server/src/lib.rs b/janus_server/src/lib.rs index 0d798ac4b..b7b7e2764 100644 --- a/janus_server/src/lib.rs +++ b/janus_server/src/lib.rs @@ -15,11 +15,13 @@ pub mod trace; pub struct SecretBytes(Vec); impl SecretBytes { - pub fn new(buf: Vec) -> SecretBytes { - SecretBytes(buf) + pub fn new(buf: Vec) -> Self { + Self(buf) } +} - pub fn as_bytes(&self) -> &[u8] { +impl AsRef<[u8]> for SecretBytes { + fn as_ref(&self) -> &[u8] { &self.0 } } diff --git a/janus_server/src/task.rs b/janus_server/src/task.rs index ecbe7691b..0160f7269 100644 --- a/janus_server/src/task.rs +++ b/janus_server/src/task.rs @@ -1,4 +1,4 @@ -//! Shared parameters for a PPM task. +//! Shared parameters for a DAP task. use crate::SecretBytes; use base64::URL_SAFE_NO_PAD; @@ -9,7 +9,7 @@ use janus_core::{ }; use janus_messages::{ Duration, HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, Interval, - Role, TaskId, + Role, TaskId, Time, }; use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer}; use std::{ @@ -30,6 +30,20 @@ pub enum Error { AggregatorVerifyKeySize, } +/// Identifiers for query types used by a task, along with query-type specific configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum QueryType { + /// Time-interval: used to support a collection style based on fixed time intervals. + TimeInterval, + + /// Fixed-size: used to support collection of batches as quickly as possible, without aligning + /// to a fixed batch window. + FixedSize { + /// The maximum number of reports in a batch to allow it to be collected. + max_batch_size: u64, + }, +} + /// Identifiers for VDAFs supported by this aggregator, corresponding to /// definitions in [draft-irtf-cfrg-vdaf-03][1] and implementations in /// [`prio::vdaf::prio3`]. @@ -170,43 +184,47 @@ impl TryFrom<&SecretBytes> for VerifyKey { } } -/// The parameters for a PPM task, corresponding to draft-gpew-priv-ppm §4.2. +/// The parameters for a DAP task, corresponding to draft-gpew-priv-ppm §4.2. #[derive(Clone, Derivative, PartialEq, Eq)] #[derivative(Debug)] pub struct Task { /// Unique identifier for the task. - pub id: TaskId, + task_id: TaskId, /// URLs relative to which aggregator API endpoints are found. The first /// entry is the leader's. #[derivative(Debug(format_with = "fmt_vector_of_urls"))] - pub aggregator_endpoints: Vec, + aggregator_endpoints: Vec, + /// The query type this task uses to generate batches. + query_type: QueryType, /// The VDAF this task executes. - pub vdaf: VdafInstance, + vdaf: VdafInstance, /// The role performed by the aggregator. - pub role: Role, + role: Role, /// Secret verification keys shared by the aggregators. 
#[derivative(Debug = "ignore")] vdaf_verify_keys: Vec, /// The maximum number of times a given batch may be collected. - pub max_batch_lifetime: u64, + max_batch_query_count: u64, + /// The time after which the task is considered invalid. + task_expiration: Time, /// The minimum number of reports in a batch to allow it to be collected. - pub min_batch_size: u64, - /// The minimum batch interval for a collect request. Batch intervals must - /// be multiples of this duration. - pub min_batch_duration: Duration, + min_batch_size: u64, + /// The duration to which clients should round their reported timestamps to. For time-interval + /// tasks, batch intervals must be multiples of this duration. + time_precision: Duration, /// How much clock skew to allow between client and aggregator. Reports from /// farther than this duration into the future will be rejected. - pub tolerable_clock_skew: Duration, + tolerable_clock_skew: Duration, /// HPKE configuration for the collector. - pub collector_hpke_config: HpkeConfig, + collector_hpke_config: HpkeConfig, /// Tokens used to authenticate messages sent to or received from the other aggregator. #[derivative(Debug = "ignore")] - pub aggregator_auth_tokens: Vec, + aggregator_auth_tokens: Vec, /// Tokens used to authenticate messages sent to or received from the collector. #[derivative(Debug = "ignore")] - pub collector_auth_tokens: Vec, + collector_auth_tokens: Vec, /// HPKE configurations & private keys used by this aggregator to decrypt client reports. - pub hpke_keys: HashMap, + hpke_keys: HashMap, } impl Task { @@ -214,19 +232,21 @@ impl Task { pub fn new>( task_id: TaskId, mut aggregator_endpoints: Vec, + query_type: QueryType, vdaf: VdafInstance, role: Role, vdaf_verify_keys: Vec, - max_batch_lifetime: u64, + max_batch_query_count: u64, + task_expiration: Time, min_batch_size: u64, - min_batch_duration: Duration, + time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: HpkeConfig, aggregator_auth_tokens: Vec, collector_auth_tokens: Vec, hpke_keys: I, ) -> Result { - // PPM currently only supports configurations of exactly two aggregators. + // DAP currently only supports configurations of exactly two aggregators. if aggregator_endpoints.len() != 2 { return Err(Error::InvalidParameter("aggregator_endpoints")); } @@ -253,43 +273,120 @@ impl Task { } // Compute hpke_configs mapping cfg.id -> (cfg, key). - let hpke_configs: HashMap = hpke_keys + let hpke_keys: HashMap = hpke_keys .into_iter() .map(|(cfg, key)| (*cfg.id(), (cfg, key))) .collect(); - if hpke_configs.is_empty() { + if hpke_keys.is_empty() { return Err(Error::InvalidParameter("hpke_configs")); } Ok(Self { - id: task_id, + task_id, aggregator_endpoints, + query_type, vdaf, role, vdaf_verify_keys, - max_batch_lifetime, + max_batch_query_count, + task_expiration, min_batch_size, - min_batch_duration, + time_precision, tolerable_clock_skew, collector_hpke_config, aggregator_auth_tokens, collector_auth_tokens, - hpke_keys: hpke_configs, + hpke_keys, }) } + /// Retrieves the task ID associated with this task. + pub fn id(&self) -> &TaskId { + &self.task_id + } + + /// Retrieves the aggregator endpoints associated with this task in natural order. + pub fn aggregator_endpoints(&self) -> &[Url] { + &self.aggregator_endpoints + } + + /// Retrieves the query type associated with this task. + pub fn query_type(&self) -> &QueryType { + &self.query_type + } + + /// Retrieves the VDAF associated with this task. 
+ pub fn vdaf(&self) -> &VdafInstance { + &self.vdaf + } + + /// Retrieves the role associated with this task. + pub fn role(&self) -> &Role { + &self.role + } + + /// Retrieves the VDAF verification keys associated with this task. + pub fn vdaf_verify_keys(&self) -> &[SecretBytes] { + &self.vdaf_verify_keys + } + + /// Retrieves the max batch query count parameter associated with this task. + pub fn max_batch_query_count(&self) -> u64 { + self.max_batch_query_count + } + + /// Retrieves the task expiration associated with this task. + pub fn task_expiration(&self) -> &Time { + &self.task_expiration + } + + /// Retrieves the min batch size parameter associated with this task. + pub fn min_batch_size(&self) -> u64 { + self.min_batch_size + } + + /// Retrieves the time precision parameter associated with this task. + pub fn time_precision(&self) -> &Duration { + &self.time_precision + } + + /// Retrieves the tolerable clock skew parameter associated with this task. + pub fn tolerable_clock_skew(&self) -> &Duration { + &self.tolerable_clock_skew + } + + /// Retrieves the collector HPKE config associated with this task. + pub fn collector_hpke_config(&self) -> &HpkeConfig { + &self.collector_hpke_config + } + + /// Retrieves the aggregator authentication tokens associated with this task. + pub fn aggregator_auth_tokens(&self) -> &[AuthenticationToken] { + &self.aggregator_auth_tokens + } + + /// Retrieves the collector authentication tokens associated with this task. + pub fn collector_auth_tokens(&self) -> &[AuthenticationToken] { + &self.collector_auth_tokens + } + + /// Retrieves the HPKE keys in use associated with this task. + pub fn hpke_keys(&self) -> &HashMap { + &self.hpke_keys + } + /// Returns true if `batch_interval` is valid, per §4.6 of draft-gpew-priv-ppm. pub(crate) fn validate_batch_interval(&self, batch_interval: &Interval) -> bool { - // Batch interval should be greater than task's minimum batch duration - batch_interval.duration().as_seconds() >= self.min_batch_duration.as_seconds() - // Batch interval start must be a multiple of minimum batch duration - && batch_interval.start().as_seconds_since_epoch() % self.min_batch_duration.as_seconds() == 0 - // Batch interval duration must be a multiple of minimum batch duration - && batch_interval.duration().as_seconds() % self.min_batch_duration.as_seconds() == 0 + // Batch interval should be greater than task's time precision + batch_interval.duration().as_seconds() >= self.time_precision.as_seconds() + // Batch interval start must be a multiple of time precision + && batch_interval.start().as_seconds_since_epoch() % self.time_precision.as_seconds() == 0 + // Batch interval duration must be a multiple of time precision + && batch_interval.duration().as_seconds() % self.time_precision.as_seconds() == 0 } /// Returns the [`Url`] relative to which the server performing `role` serves its API. - pub fn aggregator_url(&self, role: Role) -> Result<&Url, Error> { + pub fn aggregator_url(&self, role: &Role) -> Result<&Url, Error> { let index = role.index().ok_or(Error::InvalidParameter(role.as_str()))?; Ok(&self.aggregator_endpoints[index]) } @@ -337,11 +434,6 @@ impl Task { let secret_bytes = self.vdaf_verify_keys.first().unwrap(); VerifyKey::try_from(secret_bytes).map_err(|_| Error::AggregatorVerifyKeySize) } - - /// Returns the secret VDAF verification keys for this task. 
- pub fn vdaf_verify_keys(&self) -> &[SecretBytes] { - &self.vdaf_verify_keys - } } fn fmt_vector_of_urls(urls: &Vec, f: &mut Formatter<'_>) -> fmt::Result { @@ -356,14 +448,16 @@ fn fmt_vector_of_urls(urls: &Vec, f: &mut Formatter<'_>) -> fmt::Result { /// Deserialize traits. #[derive(Serialize, Deserialize)] struct SerializedTask { - id: String, // in unpadded base64url + task_id: String, // in unpadded base64url aggregator_endpoints: Vec, + query_type: QueryType, vdaf: VdafInstance, role: Role, vdaf_verify_keys: Vec, // in unpadded base64url - max_batch_lifetime: u64, + max_batch_query_count: u64, + task_expiration: Time, min_batch_size: u64, - min_batch_duration: Duration, + time_precision: Duration, tolerable_clock_skew: Duration, collector_hpke_config: SerializedHpkeConfig, aggregator_auth_tokens: Vec, // in unpadded base64url @@ -373,11 +467,11 @@ struct SerializedTask { impl Serialize for Task { fn serialize(&self, serializer: S) -> Result { - let id = base64::encode_config(self.id.as_ref(), URL_SAFE_NO_PAD); + let task_id = base64::encode_config(self.task_id.as_ref(), URL_SAFE_NO_PAD); let vdaf_verify_keys: Vec<_> = self .vdaf_verify_keys .iter() - .map(|key| base64::encode_config(key.as_bytes(), URL_SAFE_NO_PAD)) + .map(|key| base64::encode_config(key.as_ref(), URL_SAFE_NO_PAD)) .collect(); let aggregator_auth_tokens = self .aggregator_auth_tokens @@ -396,14 +490,16 @@ impl Serialize for Task { .collect(); SerializedTask { - id, + task_id, aggregator_endpoints: self.aggregator_endpoints.clone(), + query_type: self.query_type, vdaf: self.vdaf.clone(), role: self.role, vdaf_verify_keys, - max_batch_lifetime: self.max_batch_lifetime, + max_batch_query_count: self.max_batch_query_count, + task_expiration: self.task_expiration, min_batch_size: self.min_batch_size, - min_batch_duration: self.min_batch_duration, + time_precision: self.time_precision, tolerable_clock_skew: self.tolerable_clock_skew, collector_hpke_config: self.collector_hpke_config.clone().into(), aggregator_auth_tokens, @@ -421,7 +517,7 @@ impl<'de> Deserialize<'de> for Task { // task_id let task_id_bytes: [u8; TaskId::LEN] = - base64::decode_config(serialized_task.id, URL_SAFE_NO_PAD) + base64::decode_config(serialized_task.task_id, URL_SAFE_NO_PAD) .map_err(D::Error::custom)? 
.try_into() .map_err(|_| D::Error::custom("task_id length incorrect"))?; @@ -476,12 +572,14 @@ impl<'de> Deserialize<'de> for Task { Task::new( task_id, serialized_task.aggregator_endpoints, + serialized_task.query_type, serialized_task.vdaf, serialized_task.role, vdaf_verify_keys, - serialized_task.max_batch_lifetime, + serialized_task.max_batch_query_count, + serialized_task.task_expiration, serialized_task.min_batch_size, - serialized_task.min_batch_duration, + serialized_task.time_precision, serialized_task.tolerable_clock_skew, collector_hpke_config, aggregator_auth_tokens, @@ -561,12 +659,14 @@ impl TryFrom for (HpkeConfig, HpkePrivateKey) { #[cfg(feature = "test-util")] pub mod test_util { use super::{ - AuthenticationToken, SecretBytes, Task, VdafInstance, PRIO3_AES128_VERIFY_KEY_LENGTH, + AuthenticationToken, QueryType, SecretBytes, Task, VdafInstance, + PRIO3_AES128_VERIFY_KEY_LENGTH, }; use crate::messages::DurationExt; use janus_core::hpke::test_util::generate_test_hpke_config_and_private_key; - use janus_messages::{Duration, HpkeConfig, HpkeConfigId, Role, TaskId}; + use janus_messages::{Duration, HpkeConfig, HpkeConfigId, Role, TaskId, Time}; use rand::{distributions::Standard, random, thread_rng, Rng}; + use url::Url; impl VdafInstance { /// Returns the expected length of a VDAF verification key for a VDAF of this type. @@ -584,11 +684,15 @@ pub mod test_util { } } - impl Task { - /// Create a dummy [`Task`] from the provided [`TaskId`], with - /// dummy values for the other fields. This is pub because it is needed for - /// integration tests. - pub fn new_dummy(task_id: TaskId, vdaf: VdafInstance, role: Role) -> Task { + /// TaskBuilder is a testing utility allowing tasks to be built based on a template. + #[derive(Clone)] + pub struct TaskBuilder(Task); + + impl TaskBuilder { + /// Create a [`TaskBuilder`] from the provided values, with arbitrary values for the other + /// task parameters. + pub fn new(query_type: QueryType, vdaf: VdafInstance, role: Role) -> Self { + let task_id = random(); let (aggregator_config_0, aggregator_private_key_0) = generate_test_hpke_config_and_private_key(); let (mut aggregator_config_1, aggregator_private_key_1) = @@ -614,28 +718,111 @@ pub mod test_util { Vec::new() }; - Task::new( - task_id, - Vec::from([ - "http://leader_endpoint".parse().unwrap(), - "http://helper_endpoint".parse().unwrap(), - ]), - vdaf, - role, - Vec::from([vdaf_verify_key]), - 1, - 0, - Duration::from_hours(8).unwrap(), - Duration::from_minutes(10).unwrap(), - generate_test_hpke_config_and_private_key().0, - Vec::from([generate_auth_token(), generate_auth_token()]), - collector_auth_tokens, - Vec::from([ - (aggregator_config_0, aggregator_private_key_0), - (aggregator_config_1, aggregator_private_key_1), - ]), + Self( + Task::new( + task_id, + Vec::from([ + "https://leader.endpoint".parse().unwrap(), + "https://helper.endpoint".parse().unwrap(), + ]), + query_type, + vdaf, + role, + Vec::from([vdaf_verify_key]), + 1, + Time::from_seconds_since_epoch(u64::MAX), + 0, + Duration::from_hours(8).unwrap(), + Duration::from_minutes(10).unwrap(), + generate_test_hpke_config_and_private_key().0, + Vec::from([generate_auth_token(), generate_auth_token()]), + collector_auth_tokens, + Vec::from([ + (aggregator_config_0, aggregator_private_key_0), + (aggregator_config_1, aggregator_private_key_1), + ]), + ) + .unwrap(), ) - .unwrap() + } + + /// Associates the eventual task with the given task ID. 
+ pub fn with_id(self, task_id: TaskId) -> Self { + Self(Task { task_id, ..self.0 }) + } + + /// Associates the eventual task with the given aggregator endpoints. + pub fn with_aggregator_endpoints(self, aggregator_endpoints: Vec) -> Self { + Self(Task { + aggregator_endpoints, + ..self.0 + }) + } + + /// Retrieves the aggregator endpoints associated with this task builder. + pub fn aggregator_endpoints(&self) -> &[Url] { + self.0.aggregator_endpoints() + } + + /// Associates the eventual task with the given aggregator role. + pub fn with_role(self, role: Role) -> Self { + Self(Task { role, ..self.0 }) + } + + /// Associates the eventual task with the given VDAF verification keys. + pub fn with_vdaf_verify_keys(self, vdaf_verify_keys: Vec) -> Self { + Self(Task { + vdaf_verify_keys, + ..self.0 + }) + } + + /// Associates the eventual task with the given max batch query count parameter. + pub fn with_max_batch_query_count(self, max_batch_query_count: u64) -> Self { + Self(Task { + max_batch_query_count, + ..self.0 + }) + } + + /// Associates the eventual task with the given min batch size parameter. + pub fn with_min_batch_size(self, min_batch_size: u64) -> Self { + Self(Task { + min_batch_size, + ..self.0 + }) + } + + /// Associates the eventual task with the given time precision parameter. + pub fn with_time_precision(self, time_precision: Duration) -> Self { + Self(Task { + time_precision, + ..self.0 + }) + } + + /// Associates the eventual task with the given collector HPKE config. + pub fn with_collector_hpke_config(self, collector_hpke_config: HpkeConfig) -> Self { + Self(Task { + collector_hpke_config, + ..self.0 + }) + } + + /// Associates the eventual task with the given aggregator authentication tokens. + pub fn with_aggregator_auth_tokens( + self, + aggregator_auth_tokens: Vec, + ) -> Self { + Self(Task { + aggregator_auth_tokens, + ..self.0 + }) + } + + /// Consumes this task builder & produces a [`Task`] with the given specifications. 
+ pub fn build(self) -> Task { + self.0 } } @@ -652,7 +839,11 @@ mod tests { use super::{ test_util::generate_auth_token, SecretBytes, Task, PRIO3_AES128_VERIFY_KEY_LENGTH, }; - use crate::{config::test_util::roundtrip_encoding, messages::DurationExt, task::VdafInstance}; + use crate::{ + config::test_util::roundtrip_encoding, + messages::DurationExt, + task::{test_util::TaskBuilder, QueryType, VdafInstance}, + }; use janus_core::hpke::test_util::generate_test_hpke_config_and_private_key; use janus_messages::{Duration, Interval, Role, Time}; use rand::random; @@ -660,9 +851,10 @@ mod tests { #[test] fn validate_batch_interval() { - let mut task = Task::new_dummy(random(), VdafInstance::Fake, Role::Leader); - let min_batch_duration_secs = 3600; - task.min_batch_duration = Duration::from_seconds(min_batch_duration_secs); + let time_precision_secs = 3600; + let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader) + .with_time_precision(Duration::from_seconds(time_precision_secs)) + .build(); struct TestCase { name: &'static str, @@ -670,12 +862,12 @@ mod tests { expected: bool, } - let test_cases = vec![ + for test_case in Vec::from([ TestCase { name: "same duration as minimum", input: Interval::new( - Time::from_seconds_since_epoch(min_batch_duration_secs), - Duration::from_seconds(min_batch_duration_secs), + Time::from_seconds_since_epoch(time_precision_secs), + Duration::from_seconds(time_precision_secs), ) .unwrap(), expected: true, @@ -683,8 +875,8 @@ mod tests { TestCase { name: "interval too short", input: Interval::new( - Time::from_seconds_since_epoch(min_batch_duration_secs), - Duration::from_seconds(min_batch_duration_secs - 1), + Time::from_seconds_since_epoch(time_precision_secs), + Duration::from_seconds(time_precision_secs - 1), ) .unwrap(), expected: false, @@ -692,8 +884,8 @@ mod tests { TestCase { name: "interval larger than minimum", input: Interval::new( - Time::from_seconds_since_epoch(min_batch_duration_secs), - Duration::from_seconds(min_batch_duration_secs * 2), + Time::from_seconds_since_epoch(time_precision_secs), + Duration::from_seconds(time_precision_secs * 2), ) .unwrap(), expected: true, @@ -701,8 +893,8 @@ mod tests { TestCase { name: "interval duration not aligned with minimum", input: Interval::new( - Time::from_seconds_since_epoch(min_batch_duration_secs), - Duration::from_seconds(min_batch_duration_secs + 1800), + Time::from_seconds_since_epoch(time_precision_secs), + Duration::from_seconds(time_precision_secs + 1800), ) .unwrap(), expected: false, @@ -711,14 +903,12 @@ mod tests { name: "interval start not aligned with minimum", input: Interval::new( Time::from_seconds_since_epoch(1800), - Duration::from_seconds(min_batch_duration_secs), + Duration::from_seconds(time_precision_secs), ) .unwrap(), expected: false, }, - ]; - - for test_case in test_cases { + ]) { assert_eq!( test_case.expected, task.validate_batch_interval(&test_case.input), @@ -767,7 +957,7 @@ mod tests { ); assert_tokens( &VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Histogram { - buckets: vec![0, 100, 200, 400], + buckets: Vec::from([0, 100, 200, 400]), }), &[ Token::StructVariant { @@ -823,11 +1013,14 @@ mod tests { #[test] fn task_serialization() { - roundtrip_encoding(Task::new_dummy( - random(), - VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count), - Role::Leader, - )); + roundtrip_encoding( + TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count), + 
Role::Leader, + ) + .build(), + ); } #[test] @@ -839,10 +1032,12 @@ mod tests { "http://leader_endpoint".parse().unwrap(), "http://helper_endpoint".parse().unwrap(), ]), + QueryType::TimeInterval, VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count), Role::Leader, Vec::from([SecretBytes::new([0; PRIO3_AES128_VERIFY_KEY_LENGTH].into())]), 0, + Time::from_seconds_since_epoch(u64::MAX), 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), @@ -860,10 +1055,12 @@ mod tests { "http://leader_endpoint".parse().unwrap(), "http://helper_endpoint".parse().unwrap(), ]), + QueryType::TimeInterval, VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count), Role::Leader, Vec::from([SecretBytes::new([0; PRIO3_AES128_VERIFY_KEY_LENGTH].into())]), 0, + Time::from_seconds_since_epoch(u64::MAX), 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), @@ -881,10 +1078,12 @@ mod tests { "http://leader_endpoint".parse().unwrap(), "http://helper_endpoint".parse().unwrap(), ]), + QueryType::TimeInterval, VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count), Role::Helper, Vec::from([SecretBytes::new([0; PRIO3_AES128_VERIFY_KEY_LENGTH].into())]), 0, + Time::from_seconds_since_epoch(u64::MAX), 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), @@ -902,10 +1101,12 @@ mod tests { "http://leader_endpoint".parse().unwrap(), "http://helper_endpoint".parse().unwrap(), ]), + QueryType::TimeInterval, VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count), Role::Helper, Vec::from([SecretBytes::new([0; PRIO3_AES128_VERIFY_KEY_LENGTH].into())]), 0, + Time::from_seconds_since_epoch(u64::MAX), 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), @@ -925,10 +1126,12 @@ mod tests { "http://leader_endpoint/foo/bar".parse().unwrap(), "http://helper_endpoint".parse().unwrap(), ]), + QueryType::TimeInterval, VdafInstance::Real(janus_core::task::VdafInstance::Prio3Aes128Count), Role::Leader, Vec::from([SecretBytes::new([0; PRIO3_AES128_VERIFY_KEY_LENGTH].into())]), 0, + Time::from_seconds_since_epoch(u64::MAX), 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), diff --git a/janus_server/tests/graceful_shutdown.rs b/janus_server/tests/graceful_shutdown.rs index 4a884f537..bf39406c3 100644 --- a/janus_server/tests/graceful_shutdown.rs +++ b/janus_server/tests/graceful_shutdown.rs @@ -5,8 +5,10 @@ use janus_core::{task::VdafInstance, test_util::install_test_trace_subscriber, time::RealClock}; use janus_messages::Role; -use janus_server::{datastore::test_util::ephemeral_datastore, task::Task}; -use rand::random; +use janus_server::{ + datastore::test_util::ephemeral_datastore, + task::{test_util::TaskBuilder, QueryType}, +}; use reqwest::Url; use serde_yaml::Mapping; use std::{ @@ -120,8 +122,12 @@ async fn graceful_shutdown(binary: &Path, mut config: Mapping) { format!("{}", health_check_listen_address).into(), ); - let task_id = random(); - let task = Task::new_dummy(task_id, VdafInstance::Prio3Aes128Count.into(), Role::Leader); + let task = TaskBuilder::new( + QueryType::TimeInterval, + VdafInstance::Prio3Aes128Count.into(), + Role::Leader, + ) + .build(); datastore.put_task(&task).await.unwrap(); // Save the above configuration to a temporary file, so that we can pass