diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 8eb411a55..88b17318b 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -52,25 +52,21 @@ jobs: run: cargo doc janus_server_docker: runs-on: ubuntu-latest + env: + DOCKER_BUILDKIT: 1 steps: - uses: actions/checkout@v3 - - env: - DOCKER_BUILDKIT: 1 - run: docker build --tag janus_server . - - env: - DOCKER_BUILDKIT: 1 - run: docker build --tag janus_aggregation_job_creator --build-arg BINARY=aggregation_job_creator . - - env: - DOCKER_BUILDKIT: 1 - run: docker build --tag janus_aggregation_job_driver --build-arg BINARY=aggregation_job_driver . - - env: - DOCKER_BUILDKIT: 1 - run: docker build --tag janus_collect_job_driver --build-arg BINARY=collect_job_driver . - - env: - DOCKER_BUILDKIT: 1 - run: docker build --tag janus_cli --build-arg BINARY=janus_cli . + - run: docker build --tag janus_server . + - run: docker build --tag janus_aggregation_job_creator --build-arg BINARY=aggregation_job_creator . + - run: docker build --tag janus_aggregation_job_driver --build-arg BINARY=aggregation_job_driver . + - run: docker build --tag janus_collect_job_driver --build-arg BINARY=collect_job_driver . + - run: docker build --tag janus_cli --build-arg BINARY=janus_cli . - run: docker run --rm janus_server --help - run: docker run --rm janus_aggregation_job_creator --help - run: docker run --rm janus_aggregation_job_driver --help - run: docker run --rm janus_collect_job_driver --help - run: docker run --rm janus_cli --help + + - run: docker build --tag janus_interop_client --build-arg BINARY=janus_interop_client -f Dockerfile.interop . + - run: docker build --tag janus_interop_aggregator -f Dockerfile.interop_aggregator . + - run: docker build --tag janus_interop_collector --build-arg BINARY=janus_interop_collector -f Dockerfile.interop . diff --git a/.github/workflows/push-docker-images-release.yml b/.github/workflows/push-docker-images-release.yml index c66180dd9..7e74b8e7d 100644 --- a/.github/workflows/push-docker-images-release.yml +++ b/.github/workflows/push-docker-images-release.yml @@ -72,3 +72,24 @@ jobs: . - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_cli:latest - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_cli:${{ steps.get_version.outputs.VERSION }} + + - run: |- + docker build --tag us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_client:latest \ + --tag us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_client:${{ steps.get_version.outputs.VERSION }} \ + --build-arg BINARY=janus_interop_client \ + -f Dockerfile.interop . + - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_client:latest + - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_client:${{ steps.get_version.outputs.VERSION }} + - run: |- + docker build --tag us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_aggregator:latest \ + --tag us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_aggregator:${{ steps.get_version.outputs.VERSION }} \ + -f Dockerfile.interop_aggregator . + - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_aggregator:latest + - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_aggregator:${{ steps.get_version.outputs.VERSION }} + - run: |- + docker build --tag us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_collector:latest \ + --tag us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_collector:${{ steps.get_version.outputs.VERSION }} \ + --build-arg BINARY=janus_interop_collector \ + -f Dockerfile.interop . + - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_collector:latest + - run: docker push us-west2-docker.pkg.dev/janus-artifacts/janus/janus_interop_collector:${{ steps.get_version.outputs.VERSION }} diff --git a/Cargo.lock b/Cargo.lock index 5ffd0288d..e92d07a05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1569,6 +1569,34 @@ version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e85a1509a128c855368e135cffcde7eac17d8e1083f41e2b98c58bc1a5074be" +[[package]] +name = "interop_binaries" +version = "0.1.0" +dependencies = [ + "anyhow", + "base64", + "clap 3.2.16", + "janus_client", + "janus_core", + "janus_server", + "lazy_static", + "opentelemetry", + "portpicker", + "prio", + "rand", + "reqwest", + "ring", + "serde", + "serde_json", + "testcontainers", + "tokio", + "tracing", + "tracing-log", + "tracing-subscriber", + "url", + "warp", +] + [[package]] name = "ipnet" version = "2.4.0" diff --git a/Cargo.toml b/Cargo.toml index 2a9e9255f..391d39d53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,3 @@ [workspace] -members = ["janus_core", "janus_client", "janus_server", "monolithic_integration_test"] +members = ["interop_binaries", "janus_core", "janus_client", "janus_server", "monolithic_integration_test"] resolver = "2" diff --git a/Dockerfile b/Dockerfile index 43f50e89e..cfc443f81 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,6 +9,7 @@ COPY janus_core /src/janus_core COPY janus_client /src/janus_client COPY janus_server /src/janus_server COPY monolithic_integration_test /src/monolithic_integration_test +COPY interop_binaries /src/interop_binaries COPY db/schema.sql /src/db/schema.sql RUN --mount=type=cache,target=/usr/local/cargo/registry --mount=type=cache,target=/src/target cargo build --release --bin $BINARY --features=prometheus && cp /src/target/release/$BINARY /$BINARY diff --git a/Dockerfile.interop b/Dockerfile.interop new file mode 100644 index 000000000..229e02603 --- /dev/null +++ b/Dockerfile.interop @@ -0,0 +1,26 @@ +FROM rust:1.62.1-alpine as builder +ARG BINARY +RUN apk add libc-dev + +WORKDIR /src +COPY Cargo.toml /src/Cargo.toml +COPY Cargo.lock /src/Cargo.lock +COPY janus_core /src/janus_core +COPY janus_client /src/janus_client +COPY janus_server /src/janus_server +COPY monolithic_integration_test /src/monolithic_integration_test +COPY interop_binaries /src/interop_binaries +COPY db/schema.sql /src/db/schema.sql +RUN --mount=type=cache,target=/usr/local/cargo/registry --mount=type=cache,target=/src/target cargo build --release --bin $BINARY && cp /src/target/release/$BINARY /$BINARY + +FROM alpine:3.16.1 +ARG BINARY +RUN mkdir /logs +COPY --from=builder /src/db/schema.sql /db/schema.sql +COPY --from=builder /$BINARY /$BINARY +EXPOSE 8080 +# Store the build argument in an environment variable so we can reference it +# from the ENTRYPOINT at runtime. +ENV BINARY=$BINARY +ENV RUST_LOG=info +ENTRYPOINT ["/bin/sh", "-c", "exec /$BINARY \"$@\" >/logs/stdout.log 2>/logs/stderr.log"] diff --git a/Dockerfile.interop_aggregator b/Dockerfile.interop_aggregator new file mode 100644 index 000000000..a053d704a --- /dev/null +++ b/Dockerfile.interop_aggregator @@ -0,0 +1,22 @@ +FROM rust:1.62.1-alpine as builder +RUN apk add libc-dev + +WORKDIR /src +COPY Cargo.toml /src/Cargo.toml +COPY Cargo.lock /src/Cargo.lock +COPY janus_core /src/janus_core +COPY janus_client /src/janus_client +COPY janus_server /src/janus_server +COPY monolithic_integration_test /src/monolithic_integration_test +COPY interop_binaries /src/interop_binaries +COPY db/schema.sql /src/db/schema.sql +RUN --mount=type=cache,target=/usr/local/cargo/registry --mount=type=cache,target=/src/target cargo build --release --bin janus_interop_aggregator && cp /src/target/release/janus_interop_aggregator /janus_interop_aggregator + +FROM postgres:14-bullseye +RUN mkdir /logs +RUN apt-get update && apt-get install -y supervisor +COPY interop_binaries/supervisord.conf /supervisord.conf +COPY --from=builder /src/db/schema.sql /db/schema.sql +COPY --from=builder /janus_interop_aggregator /janus_interop_aggregator +EXPOSE 8080 +ENTRYPOINT ["/usr/bin/supervisord", "-c", "/supervisord.conf"] diff --git a/interop_binaries/Cargo.toml b/interop_binaries/Cargo.toml new file mode 100644 index 000000000..ffcf6f24a --- /dev/null +++ b/interop_binaries/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "interop_binaries" +version = "0.1.0" +edition = "2021" +license = "MPL-2.0" +rust-version = "1.60" +publish = false + +[dependencies] +anyhow = "1" +base64 = "0.13.0" +clap = "3.2.16" +janus_client = { path = "../janus_client" } +janus_core = { path = "../janus_core" } +janus_server = { path = "../janus_server" } +opentelemetry = { version = "0.17", features = ["metrics"] } +prio = "0.8.2" +rand = "0.8" +reqwest = { version = "0.11.4", default-features = false, features = ["rustls-tls"] } +ring = "0.16.20" +serde = { version = "1.0.141", features = ["derive"] } +serde_json = "1.0.82" +tokio = { version = "^1.20", features = ["full", "tracing"] } +tracing = "0.1.36" +tracing-log = "0.1.3" +tracing-subscriber = { version = "0.3", features = ["std", "env-filter", "fmt"] } +url = { version = "2.2.2", features = ["serde"] } +warp = "^0.3" + +[dev-dependencies] +lazy_static = "1" +portpicker = "0.1" +reqwest = { version = "0.11.4", default-features = false, features = ["json"] } +testcontainers = "0.14.0" diff --git a/interop_binaries/src/bin/janus_interop_aggregator.rs b/interop_binaries/src/bin/janus_interop_aggregator.rs new file mode 100644 index 000000000..3d5a0ace5 --- /dev/null +++ b/interop_binaries/src/bin/janus_interop_aggregator.rs @@ -0,0 +1,309 @@ +use anyhow::Context; +use base64::URL_SAFE_NO_PAD; +use clap::{Arg, Command}; +use interop_binaries::{ + install_tracing_subscriber, + status::{ERROR, SUCCESS}, + HpkeConfigRegistry, VdafObject, +}; +use janus_core::{ + message::{Duration, HpkeConfig, Role, TaskId}, + time::RealClock, + TokioRuntime, +}; +use janus_server::{ + aggregator::{ + aggregate_share::CollectJobDriver, aggregation_job_creator::AggregationJobCreator, + aggregation_job_driver::AggregationJobDriver, + }, + binary_utils::{database_pool, job_driver::JobDriver}, + config::DbConfig, + datastore::{Crypter, Datastore}, + task::{AuthenticationToken, Task}, +}; +use opentelemetry::global::meter; +use prio::codec::Decode; +use rand::{thread_rng, Rng}; +use ring::aead::{LessSafeKey, UnboundKey, AES_128_GCM}; +use serde::{Deserialize, Serialize}; +use std::{ + net::{Ipv4Addr, SocketAddr}, + sync::Arc, + time::Duration as StdDuration, +}; +use tokio::sync::Mutex; +use url::Url; +use warp::{hyper::StatusCode, reply::Response, Filter, Reply}; + +#[derive(Debug, Serialize)] +struct EndpointResponse { + status: &'static str, + endpoint: &'static str, +} + +static ENDPOINT_RESPONSE: EndpointResponse = EndpointResponse { + status: "success", + endpoint: "/", +}; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct AddTaskRequest { + task_id: String, + leader: Url, + helper: Url, + vdaf: VdafObject, + leader_authentication_token: String, + #[serde(default)] + collector_authentication_token: Option, + aggregator_id: u8, + verify_key: String, + max_batch_lifetime: u64, + min_batch_size: u64, + min_batch_duration: u64, + collector_hpke_config: String, +} + +#[derive(Debug, Serialize)] +struct AddTaskResponse { + status: &'static str, + #[serde(default)] + error: Option, +} + +async fn handle_add_task( + datastore: &Datastore, + keyring: &Mutex, + request: AddTaskRequest, +) -> anyhow::Result<()> { + let task_id_bytes = base64::decode_config(request.task_id, base64::URL_SAFE_NO_PAD) + .context("invalid base64url content in \"taskId\"")?; + let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?; + let vdaf: janus_core::task::VdafInstance = request.vdaf.into(); + let vdaf: janus_server::task::VdafInstance = vdaf.into(); + let leader_authentication_token = + AuthenticationToken::from(request.leader_authentication_token.into_bytes()); + let verify_key = base64::decode_config(request.verify_key, URL_SAFE_NO_PAD) + .context("invalid base64url content in \"verifyKey\"")?; + let min_batch_duration = Duration::from_seconds(request.min_batch_duration); + let collector_hpke_config_bytes = + base64::decode_config(request.collector_hpke_config, URL_SAFE_NO_PAD) + .context("invalid base64url content in \"collectorHpkeConfig\"")?; + let collector_hpke_config = HpkeConfig::get_decoded(&collector_hpke_config_bytes) + .context("could not parse collector HPKE configuration")?; + + let (role, collector_authentication_tokens) = match ( + request.aggregator_id, + request.collector_authentication_token, + ) { + (0, None) => { + return Err(anyhow::anyhow!("collector authentication is missing")); + } + (0, Some(collector_authentication_token)) => ( + Role::Leader, + vec![AuthenticationToken::from( + collector_authentication_token.into_bytes(), + )], + ), + (1, _) => (Role::Helper, Vec::new()), + _ => return Err(anyhow::anyhow!("invalid \"aggregator_id\" value")), + }; + + let (hpke_config, private_key) = keyring.lock().await.get_random_keypair(); + + let task = Task::new( + task_id, + vec![request.leader, request.helper], + vdaf, + role, + vec![verify_key], + request.max_batch_lifetime, + request.min_batch_size, + min_batch_duration, + // We can be strict about clock skew since this executable is only intended for use with + // other aggregators running on the same host. + Duration::from_seconds(1), + collector_hpke_config, + vec![leader_authentication_token], + collector_authentication_tokens, + [(hpke_config, private_key)], + ) + .context("error constructing task")?; + + datastore + .run_tx(move |tx| { + let task = task.clone(); + Box::pin(async move { tx.put_task(&task).await }) + }) + .await + .context("error adding task to database") +} + +fn make_filter( + datastore: Arc>, +) -> anyhow::Result + Clone> { + let keyring = Arc::new(Mutex::new(HpkeConfigRegistry::new())); + let clock = janus_core::time::RealClock::default(); + let dap_filter = janus_server::aggregator::aggregator_filter(Arc::clone(&datastore), clock)?; + + let endpoint_filter = warp::path!("endpoint_for_task").map(|| { + warp::reply::with_status(warp::reply::json(&ENDPOINT_RESPONSE), StatusCode::OK) + .into_response() + }); + let add_task_filter = + warp::path!("add_task") + .and(warp::body::json()) + .then(move |request: AddTaskRequest| { + let datastore = Arc::clone(&datastore); + let keyring = Arc::clone(&keyring); + async move { + let response = match handle_add_task(&datastore, &keyring, request).await { + Ok(()) => AddTaskResponse { + status: SUCCESS, + error: None, + }, + Err(e) => AddTaskResponse { + status: ERROR, + error: Some(format!("{:?}", e)), + }, + }; + warp::reply::with_status(warp::reply::json(&response), StatusCode::OK) + .into_response() + } + }); + + Ok(warp::path!("internal" / "test" / ..) + .and(warp::post()) + .and(endpoint_filter.or(add_task_filter).unify()) + .or(dap_filter.map(Reply::into_response)) + .unify()) +} + +fn app() -> clap::Command<'static> { + Command::new("Janus interoperation test aggregator") + .arg( + Arg::new("port") + .long("port") + .short('p') + .default_value("8080") + .help("Port number to listen on."), + ) + .arg( + Arg::new("postgres-url") + .long("postgres-url") + .default_value("postgres://postgres@127.0.0.1:5432/postgres") + .help("PostgreSQL database connection URL."), + ) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + install_tracing_subscriber()?; + let matches = app().get_matches(); + let http_port = matches.value_of_t::("port")?; + let postgres_url = matches.value_of_t::("postgres-url")?; + + // Make an ephemeral datastore key. + let mut key_bytes = [0u8; 16]; + thread_rng().fill(&mut key_bytes); + let datastore_key = LessSafeKey::new(UnboundKey::new(&AES_128_GCM, &key_bytes).unwrap()); + let crypter = Crypter::new(vec![datastore_key]); + + // Connect to database, apply schema, and set up datastore. + let db_config = DbConfig { + url: postgres_url, + connection_pool_timeouts_secs: 30, + }; + let pool = database_pool(&db_config, &None).await?; + let clock = janus_core::time::RealClock::default(); + let client = pool.get().await?; + client + .batch_execute(include_str!("../../../db/schema.sql")) + .await?; + // Return the database connection we used to deploy the schema back to the pool, so it can be + // reused. + drop(client); + let datastore = Arc::new(Datastore::new(pool, crypter, clock)); + + // Run an HTTP server with both the DAP aggregator endpoints and the interoperation test + // endpoints. + let filter = make_filter(Arc::clone(&datastore))?; + let server = warp::serve(filter); + let aggregator_future = server.bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, http_port))); + + // Run the aggregation job creator. + let pool = database_pool(&db_config, &None).await?; + let datastore_key = LessSafeKey::new(UnboundKey::new(&AES_128_GCM, &key_bytes).unwrap()); + let crypter = Crypter::new(vec![datastore_key]); + let aggregation_job_creator = Arc::new(AggregationJobCreator::new( + Datastore::new(pool, crypter, clock), + clock, + StdDuration::from_secs(5), + StdDuration::from_secs(1), + 1, + 100, + )); + let aggregation_job_creator_future = aggregation_job_creator.run(); + + // Run the aggregation job driver. + let aggregation_job_driver_meter = meter("aggregation_job_driver"); + let aggregation_job_driver = Arc::new(AggregationJobDriver::new( + reqwest::Client::new(), + &aggregation_job_driver_meter, + )); + let aggregation_job_driver = Arc::new(JobDriver::new( + clock, + TokioRuntime, + aggregation_job_driver_meter, + Duration::from_seconds(1), + Duration::from_seconds(5), + 10, + Duration::from_seconds(1), + aggregation_job_driver.make_incomplete_job_acquirer_callback( + Arc::clone(&datastore), + Duration::from_seconds(10), + ), + aggregation_job_driver.make_job_stepper_callback(Arc::clone(&datastore), 3), + )); + let aggregation_job_driver_future = aggregation_job_driver.run(); + + // Run the collect job driver. + let collect_job_driver_meter = meter("collect_job_driver"); + let collect_job_driver = Arc::new(CollectJobDriver::new( + reqwest::Client::new(), + &collect_job_driver_meter, + )); + let collect_job_driver = Arc::new(JobDriver::new( + clock, + TokioRuntime, + collect_job_driver_meter, + Duration::from_seconds(1), + Duration::from_seconds(5), + 10, + Duration::from_seconds(1), + collect_job_driver.make_incomplete_job_acquirer_callback( + Arc::clone(&datastore), + Duration::from_seconds(10), + ), + collect_job_driver.make_job_stepper_callback(Arc::clone(&datastore), 3), + )); + let collect_job_driver_future = collect_job_driver.run(); + + tokio::spawn(aggregation_job_creator_future); + tokio::spawn(aggregation_job_driver_future); + tokio::spawn(collect_job_driver_future); + + aggregator_future.await; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::app; + + #[test] + fn verify_clap_app() { + app().debug_assert(); + } +} diff --git a/interop_binaries/src/bin/janus_interop_client.rs b/interop_binaries/src/bin/janus_interop_client.rs new file mode 100644 index 000000000..fd540cf8f --- /dev/null +++ b/interop_binaries/src/bin/janus_interop_client.rs @@ -0,0 +1,192 @@ +use anyhow::Context; +use base64::URL_SAFE_NO_PAD; +use clap::{Arg, Command}; +use interop_binaries::{ + install_tracing_subscriber, + status::{ERROR, SUCCESS}, + VdafObject, +}; +use janus_client::ClientParameters; +use janus_core::{ + message::{Duration, Role, TaskId, Time}, + time::{MockClock, RealClock}, +}; +use prio::{ + codec::Decode, + vdaf::{prio3::Prio3, Vdaf}, +}; +use serde::{Deserialize, Serialize}; +use std::net::{Ipv4Addr, SocketAddr}; +use url::Url; +use warp::{hyper::StatusCode, reply::Response, Filter, Reply}; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct UploadRequest { + task_id: String, + leader: Url, + helper: Url, + vdaf: VdafObject, + measurement: u64, + #[serde(default)] + nonce_time: Option, + min_batch_duration: u64, +} + +#[derive(Debug, Serialize)] +struct UploadResponse { + status: &'static str, + #[serde(default)] + error: Option, +} + +async fn handle_upload_generic( + http_client: &reqwest::Client, + vdaf_client: V, + request: UploadRequest, + measurement: V::Measurement, +) -> anyhow::Result<()> +where + for<'a> Vec: From<&'a ::AggregateShare>, +{ + let task_id_bytes = base64::decode_config(request.task_id, URL_SAFE_NO_PAD) + .context("invalid base64url content in \"taskId\"")?; + let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?; + let min_batch_duration = Duration::from_seconds(request.min_batch_duration); + let client_parameters = ClientParameters::new( + task_id, + vec![request.leader, request.helper], + min_batch_duration, + ); + + let leader_hpke_config = janus_client::aggregator_hpke_config( + &client_parameters, + Role::Leader, + task_id, + http_client, + ) + .await + .context("failed to fetch leader's HPKE configuration")?; + let helper_hpke_config = janus_client::aggregator_hpke_config( + &client_parameters, + Role::Helper, + task_id, + http_client, + ) + .await + .context("failed to fetch helper's HPKE configuration")?; + + match request.nonce_time { + Some(nonce_time) => { + let clock = MockClock::new(Time::from_seconds_since_epoch(nonce_time)); + let client = janus_client::Client::new( + client_parameters, + vdaf_client, + clock, + http_client, + leader_hpke_config, + helper_hpke_config, + ); + client + .upload(&measurement) + .await + .context("report generation and upload failed") + } + None => { + let client = janus_client::Client::new( + client_parameters, + vdaf_client, + RealClock::default(), + http_client, + leader_hpke_config, + helper_hpke_config, + ); + client + .upload(&measurement) + .await + .context("report generation and upload failed") + } + } +} + +async fn handle_upload( + http_client: &reqwest::Client, + request: UploadRequest, +) -> anyhow::Result<()> { + let measurement = request.measurement; + match request.vdaf { + VdafObject::Prio3Aes128Count {} => { + let vdaf_client = + Prio3::new_aes128_count(2).context("failed to construct Prio3Aes128Count VDAF")?; + handle_upload_generic(http_client, vdaf_client, request, measurement).await?; + } + VdafObject::Prio3Aes128Sum { bits } => { + let vdaf_client = Prio3::new_aes128_sum(2, bits) + .context("failed to construct Prio3Aes128Sum VDAF")?; + handle_upload_generic(http_client, vdaf_client, request, measurement.into()).await?; + } + VdafObject::Prio3Aes128Histogram { ref buckets } => { + let vdaf_client = Prio3::new_aes128_histogram(2, buckets) + .context("failed to construct Prio3Aes128Histogram VDAF")?; + handle_upload_generic(http_client, vdaf_client, request, measurement.into()).await?; + } + } + Ok(()) +} + +fn make_filter() -> anyhow::Result + Clone> { + let http_client = janus_client::default_http_client()?; + Ok(warp::path!("internal" / "test" / "upload") + .and(warp::post()) + .and(warp::body::json()) + .then(move |request: UploadRequest| { + let http_client = http_client.clone(); + async move { + let response = match handle_upload(&http_client, request).await { + Ok(()) => UploadResponse { + status: SUCCESS, + error: None, + }, + Err(e) => UploadResponse { + status: ERROR, + error: Some(format!("{:?}", e)), + }, + }; + warp::reply::with_status(warp::reply::json(&response), StatusCode::OK) + .into_response() + } + })) +} + +fn app() -> clap::Command<'static> { + Command::new("Janus interoperation test client").arg( + Arg::new("port") + .long("port") + .short('p') + .default_value("8080") + .help("Port number to listen on."), + ) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + install_tracing_subscriber()?; + let matches = app().get_matches(); + let port = matches.value_of_t::("port")?; + let filter = make_filter()?; + let server = warp::serve(filter); + server + .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, port))) + .await; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::app; + + #[test] + fn verify_clap_app() { + app().debug_assert(); + } +} diff --git a/interop_binaries/src/bin/janus_interop_collector.rs b/interop_binaries/src/bin/janus_interop_collector.rs new file mode 100644 index 000000000..c9487016a --- /dev/null +++ b/interop_binaries/src/bin/janus_interop_collector.rs @@ -0,0 +1,490 @@ +use anyhow::Context; +use base64::URL_SAFE_NO_PAD; +use clap::{Arg, Command}; +use interop_binaries::{ + install_tracing_subscriber, + status::{COMPLETE, ERROR, IN_PROGRESS, SUCCESS}, + HpkeConfigRegistry, VdafObject, +}; +use janus_core::{ + hpke::{self, associated_data_for_aggregate_share, HpkeApplicationInfo, HpkePrivateKey, Label}, + message::{Duration, HpkeConfig, Interval, Role, TaskId, Time}, +}; +use janus_server::{ + message::{CollectReq, CollectResp}, + task::DAP_AUTH_HEADER, +}; +use prio::{ + codec::{Decode, Encode}, + field::{Field128, Field64}, + vdaf::{ + prio3::{Prio3, Prio3Aes128Count, Prio3Aes128Histogram, Prio3Aes128Sum}, + AggregateShare, Collector, Vdaf, + }, +}; +use reqwest::{ + header::{CONTENT_TYPE, LOCATION}, + Url, +}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::{hash_map::Entry, HashMap}, + net::{Ipv4Addr, SocketAddr}, + sync::Arc, +}; +use tokio::sync::Mutex; +use warp::{hyper::StatusCode, reply::Response, Filter, Reply}; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct AddTaskRequest { + task_id: String, + leader: Url, + vdaf: VdafObject, + collector_authentication_token: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct AddTaskResponse { + status: &'static str, + #[serde(default)] + error: Option, + collector_hpke_config: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct CollectStartRequest { + task_id: String, + agg_param: String, + batch_interval_start: u64, + batch_interval_duration: u64, +} + +#[derive(Debug, Serialize)] +struct CollectStartResponse { + status: &'static str, + #[serde(default)] + error: Option, + #[serde(default)] + handle: Option, +} + +#[derive(Debug, Deserialize)] +struct CollectPollRequest { + handle: String, +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +enum AggregationResult { + Number(u64), + NumberArray(Vec), +} + +#[derive(Debug, Serialize)] +struct CollectPollResponse { + status: &'static str, + #[serde(default)] + error: Option, + #[serde(default)] + result: Option, +} + +struct TaskState { + private_key: HpkePrivateKey, + hpke_config: HpkeConfig, + leader_url: Url, + vdaf: VdafObject, + auth_token: String, +} + +/// A collect job handle. +#[derive(Clone, PartialEq, Eq, Hash)] +struct Handle(String); + +impl Handle { + fn generate() -> Handle { + let randomness = rand::random::<[u8; 32]>(); + Handle(base64::encode_config(randomness, URL_SAFE_NO_PAD)) + } +} + +struct CollectJobState { + task_id: TaskId, + url: Url, + batch_interval: Interval, + agg_param: Vec, +} + +async fn handle_add_task( + tasks: &Mutex>, + keyring: &Mutex, + request: AddTaskRequest, +) -> anyhow::Result { + let task_id_bytes = base64::decode_config(request.task_id, base64::URL_SAFE_NO_PAD) + .context("invalid base64url content in \"taskId\"")?; + let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?; + + let mut tasks_guard = tasks.lock().await; + let entry = tasks_guard.entry(task_id); + if let Entry::Occupied(_) = &entry { + return Err(anyhow::anyhow!("cannot add a task with a duplicate ID")); + } + + let (hpke_config, private_key) = keyring.lock().await.get_random_keypair(); + + entry.or_insert(TaskState { + private_key, + hpke_config: hpke_config.clone(), + leader_url: request.leader, + vdaf: request.vdaf, + auth_token: request.collector_authentication_token, + }); + + Ok(hpke_config) +} + +async fn handle_collect_start( + http_client: &reqwest::Client, + tasks: &Mutex>, + collect_jobs: &Mutex>, + request: CollectStartRequest, +) -> anyhow::Result { + let task_id_bytes = base64::decode_config(request.task_id, URL_SAFE_NO_PAD) + .context("invalid base64url content in \"taskId\"")?; + let task_id = TaskId::get_decoded(&task_id_bytes).context("invalid length of TaskId")?; + let agg_param = base64::decode_config(request.agg_param, URL_SAFE_NO_PAD) + .context("invalid base64url content in \"aggParam\"")?; + let batch_interval = Interval::new( + Time::from_seconds_since_epoch(request.batch_interval_start), + Duration::from_seconds(request.batch_interval_duration), + ) + .context("invalid batch interval specification")?; + + let dap_collect_request = CollectReq { + task_id, + batch_interval, + agg_param: agg_param.clone(), + }; + + let tasks_guard = tasks.lock().await; + let task_state = tasks_guard + .get(&task_id) + .context("task was not added before being used in a collect request")?; + + let response = http_client + .post(task_state.leader_url.join("collect")?) + .header(CONTENT_TYPE, CollectReq::MEDIA_TYPE) + .header(DAP_AUTH_HEADER, &task_state.auth_token) + .body(dap_collect_request.get_encoded()) + .send() + .await + .context("error sending collect request to the leader")?; + let status = response.status(); + if status != StatusCode::SEE_OTHER { + return Err(anyhow::anyhow!(format!( + "collect request got status code {}", + status, + ))); + } + let collect_job_url = Url::parse( + response + .headers() + .get(LOCATION) + .context("response to collect request did not include a Location header")? + .to_str() + .context("collect response Location header contained invalid characters")?, + ) + .context("collect response Location header contained an invalid URL")?; + + let mut collect_jobs_guard = collect_jobs.lock().await; + let handle = loop { + let handle = Handle::generate(); + match collect_jobs_guard.entry(handle.clone()) { + Entry::Occupied(_) => continue, + entry @ Entry::Vacant(_) => { + entry.or_insert(CollectJobState { + task_id, + url: collect_job_url, + batch_interval, + agg_param, + }); + break handle; + } + } + }; + + Ok(handle) +} + +async fn handle_collect_poll( + http_client: &reqwest::Client, + tasks: &Mutex>, + collect_jobs: &Mutex>, + request: CollectPollRequest, +) -> anyhow::Result> { + let tasks_guard = tasks.lock().await; + let collect_jobs_guard = collect_jobs.lock().await; + let collect_job_state = collect_jobs_guard + .get(&Handle(request.handle)) + .context("did not recognize handle in collect_poll request")?; + let task_id = collect_job_state.task_id; + let task_state = tasks_guard + .get(&task_id) + .context("could not look up task information while polling")?; + + let response = http_client + .get(collect_job_state.url.clone()) + .header(DAP_AUTH_HEADER, &task_state.auth_token) + .send() + .await + .context("error fetching collect job from leader")?; + let status = response.status(); + if status == StatusCode::ACCEPTED { + return Ok(None); + } else if status != StatusCode::OK { + return Err(anyhow::anyhow!(format!( + "collect job fetch got status code {}", + status + ))); + } + + let dap_collect_response = CollectResp::get_decoded( + &response + .bytes() + .await + .context("error reading collect response")?, + ) + .context("could not decode collect response")?; + + if dap_collect_response.encrypted_agg_shares.len() != 2 { + return Err(anyhow::anyhow!( + "collect response does not have two ciphertexts" + )); + } + let associated_data = + associated_data_for_aggregate_share(task_id, collect_job_state.batch_interval); + let leader_aggregate_share_bytes = hpke::open( + &task_state.hpke_config, + &task_state.private_key, + &HpkeApplicationInfo::new(Label::AggregateShare, Role::Leader, Role::Collector), + &dap_collect_response.encrypted_agg_shares[0], + &associated_data, + ) + .context("could not decrypt aggregate share from the leader")?; + let helper_aggregate_share_bytes = hpke::open( + &task_state.hpke_config, + &task_state.private_key, + &HpkeApplicationInfo::new(Label::AggregateShare, Role::Helper, Role::Collector), + &dap_collect_response.encrypted_agg_shares[1], + &associated_data, + ) + .context("could not decrypt aggregate share from the helper")?; + + match task_state.vdaf { + VdafObject::Prio3Aes128Count {} => { + let leader_aggregate_share = + AggregateShare::::try_from(leader_aggregate_share_bytes.as_ref()) + .context("could not decode leader's aggregate share")?; + let helper_aggregate_share = + AggregateShare::::try_from(helper_aggregate_share_bytes.as_ref()) + .context("could not decode helper's aggregate share")?; + <::AggregationParam>::get_decoded( + &collect_job_state.agg_param, + ) + .context("could not decode aggregation parameter")?; + let vdaf = + Prio3::new_aes128_count(2).context("failed to construct Prio3Aes128Count VDAF")?; + let aggregate_result = vdaf + .unshard(&(), [leader_aggregate_share, helper_aggregate_share]) + .context("could not unshard aggregate result")?; + Ok(Some(AggregationResult::Number(aggregate_result))) + } + VdafObject::Prio3Aes128Sum { bits } => { + let leader_aggregate_share = + AggregateShare::::try_from(leader_aggregate_share_bytes.as_ref()) + .context("could not decode leader's aggregate share")?; + let helper_aggregate_share = + AggregateShare::::try_from(helper_aggregate_share_bytes.as_ref()) + .context("could not decode helper's aggregate share")?; + <::AggregationParam>::get_decoded(&collect_job_state.agg_param) + .context("could not decode aggregation parameter")?; + let vdaf = Prio3::new_aes128_sum(2, bits) + .context("failed to construct Prio3Aes128Sum VDAF")?; + let aggregate_result = vdaf + .unshard(&(), [leader_aggregate_share, helper_aggregate_share]) + .context("could not unshard aggregate result")?; + Ok(Some(AggregationResult::Number( + aggregate_result + .try_into() + .context("aggregate result was too large to represent natively in JSON")?, + ))) + } + VdafObject::Prio3Aes128Histogram { ref buckets } => { + let leader_aggregate_share = + AggregateShare::::try_from(leader_aggregate_share_bytes.as_ref()) + .context("could not decode leader's aggregate share")?; + let helper_aggregate_share = + AggregateShare::::try_from(helper_aggregate_share_bytes.as_ref()) + .context("could not decode helper's aggregate share")?; + <::AggregationParam>::get_decoded( + &collect_job_state.agg_param, + ) + .context("could not decode aggregation parameter")?; + let vdaf = Prio3::new_aes128_histogram(2, buckets) + .context("failed to construct Prio3Aes128Histogram VDAF")?; + let aggregate_result = vdaf + .unshard(&(), [leader_aggregate_share, helper_aggregate_share]) + .context("could not unshard aggregate result")?; + let converted = aggregate_result + .into_iter() + .map(|counter| { + u64::try_from(counter).context( + "entry in aggregate result was too large to represent natively in JSON", + ) + }) + .collect::, _>>()?; + Ok(Some(AggregationResult::NumberArray(converted))) + } + } +} + +fn make_filter() -> anyhow::Result + Clone> { + let http_client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build()?; + let tasks: Arc>> = Arc::new(Mutex::new(HashMap::new())); + let collect_jobs: Arc>> = + Arc::new(Mutex::new(HashMap::new())); + let keyring = Arc::new(Mutex::new(HpkeConfigRegistry::new())); + + let add_task_filter = warp::path!("add_task").and(warp::body::json()).then({ + let tasks = Arc::clone(&tasks); + let keyring = Arc::clone(&keyring); + move |request: AddTaskRequest| { + let tasks = Arc::clone(&tasks); + let keyring = Arc::clone(&keyring); + async move { + let response = match handle_add_task(&tasks, &keyring, request).await { + Ok(collector_hpke_config) => AddTaskResponse { + status: SUCCESS, + error: None, + collector_hpke_config: Some(base64::encode_config( + collector_hpke_config.get_encoded(), + URL_SAFE_NO_PAD, + )), + }, + Err(e) => AddTaskResponse { + status: ERROR, + error: Some(format!("{:?}", e)), + collector_hpke_config: None, + }, + }; + warp::reply::with_status(warp::reply::json(&response), StatusCode::OK) + .into_response() + } + } + }); + let collect_start_filter = + warp::path!("collect_start").and(warp::body::json()).then({ + let http_client = http_client.clone(); + let tasks = Arc::clone(&tasks); + let collect_jobs = Arc::clone(&collect_jobs); + move |request: CollectStartRequest| { + let http_client = http_client.clone(); + let tasks = Arc::clone(&tasks); + let collect_jobs = Arc::clone(&collect_jobs); + async move { + let response = + match handle_collect_start(&http_client, &tasks, &collect_jobs, request) + .await + { + Ok(handle) => CollectStartResponse { + status: SUCCESS, + error: None, + handle: Some(handle.0), + }, + Err(e) => CollectStartResponse { + status: ERROR, + error: Some(format!("{:?}", e)), + handle: None, + }, + }; + warp::reply::with_status(warp::reply::json(&response), StatusCode::OK) + .into_response() + } + } + }); + let collect_poll_filter = warp::path!("collect_poll").and(warp::body::json()).then({ + move |request: CollectPollRequest| { + let http_client = http_client.clone(); + let tasks = Arc::clone(&tasks); + let collect_jobs = Arc::clone(&collect_jobs); + async move { + let response = + match handle_collect_poll(&http_client, &tasks, &collect_jobs, request).await { + Ok(Some(result)) => CollectPollResponse { + status: COMPLETE, + error: None, + result: Some(result), + }, + Ok(None) => CollectPollResponse { + status: IN_PROGRESS, + error: None, + result: None, + }, + Err(e) => CollectPollResponse { + status: ERROR, + error: Some(format!("{:?}", e)), + result: None, + }, + }; + warp::reply::with_status(warp::reply::json(&response), StatusCode::OK) + .into_response() + } + } + }); + + Ok(warp::path!("internal" / "test" / ..).and(warp::post()).and( + add_task_filter + .or(collect_start_filter) + .unify() + .or(collect_poll_filter) + .unify(), + )) +} + +fn app() -> clap::Command<'static> { + Command::new("Janus interoperation test collector").arg( + Arg::new("port") + .long("port") + .short('p') + .default_value("8080") + .help("Port number to listen on."), + ) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + install_tracing_subscriber()?; + let matches = app().get_matches(); + let port = matches.value_of_t::("port")?; + let filter = make_filter()?; + let server = warp::serve(filter); + server + .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, port))) + .await; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::app; + + #[test] + fn verify_clap_app() { + app().debug_assert(); + } +} diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs new file mode 100644 index 000000000..3d693a72a --- /dev/null +++ b/interop_binaries/src/lib.rs @@ -0,0 +1,90 @@ +use janus_core::{ + hpke::{generate_hpke_config_and_private_key, HpkePrivateKey}, + message::{HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId}, + task::VdafInstance, +}; +use rand::{thread_rng, Rng}; +use serde::Deserialize; +use std::collections::HashMap; +use tracing_log::LogTracer; +use tracing_subscriber::{prelude::*, EnvFilter, Registry}; + +pub mod status { + pub static SUCCESS: &str = "success"; + pub static ERROR: &str = "error"; + pub static COMPLETE: &str = "complete"; + pub static IN_PROGRESS: &str = "in progress"; +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum VdafObject { + Prio3Aes128Count {}, + Prio3Aes128Sum { bits: u32 }, + Prio3Aes128Histogram { buckets: Vec }, +} + +impl From for VdafInstance { + fn from(object: VdafObject) -> VdafInstance { + match object { + VdafObject::Prio3Aes128Count {} => VdafInstance::Prio3Aes128Count, + VdafObject::Prio3Aes128Sum { bits } => VdafInstance::Prio3Aes128Sum { bits }, + VdafObject::Prio3Aes128Histogram { buckets } => { + VdafInstance::Prio3Aes128Histogram { buckets } + } + } + } +} + +pub fn install_tracing_subscriber() -> anyhow::Result<()> { + let stdout_filter = EnvFilter::from_default_env(); + let layer = tracing_subscriber::fmt::layer() + .with_thread_ids(true) + .with_level(true) + .with_target(true) + .with_file(true) + .with_line_number(true) + .pretty(); + let subscriber = Registry::default().with(stdout_filter.and_then(layer)); + tracing::subscriber::set_global_default(subscriber)?; + + LogTracer::init()?; + + Ok(()) +} + +/// This registry lazily generates up to 256 HPKE key pairs, one with each possible +/// [`HpkeConfigId`]. +#[derive(Default)] +pub struct HpkeConfigRegistry { + keypairs: HashMap, +} + +impl HpkeConfigRegistry { + pub fn new() -> HpkeConfigRegistry { + Default::default() + } + + /// Get the keypair associated with a given ID. + pub fn fetch_keypair(&mut self, id: HpkeConfigId) -> (HpkeConfig, HpkePrivateKey) { + self.keypairs + .entry(id) + .or_insert_with(|| { + generate_hpke_config_and_private_key( + id, + // These algorithms should be broadly compatible with other DAP implementations, since they + // are required by section 6 of draft-ietf-ppm-dap-01. + HpkeKemId::X25519HkdfSha256, + HpkeKdfId::HkdfSha256, + HpkeAeadId::Aes128Gcm, + ) + }) + .clone() + } + + /// Choose a random [`HpkeConfigId`], and then get the keypair associated with that ID. + pub fn get_random_keypair(&mut self) -> (HpkeConfig, HpkePrivateKey) { + let id = HpkeConfigId::from(thread_rng().gen::()); + self.fetch_keypair(id) + } +} diff --git a/interop_binaries/supervisord.conf b/interop_binaries/supervisord.conf new file mode 100644 index 000000000..84cb8203d --- /dev/null +++ b/interop_binaries/supervisord.conf @@ -0,0 +1,15 @@ +[supervisord] +nodaemon=true +user=root + +[program:janus_interop_aggregator] +command=/janus_interop_aggregator +environment=RUST_LOG=info +stdout_logfile=/logs/aggregator_stdout.log +stderr_logfile=/logs/aggregator_stderr.log + +[program:postgres] +command=/usr/local/bin/docker-entrypoint.sh postgres +environment=POSTGRES_DB="postgres",POSTGRES_HOST_AUTH_METHOD="trust" +stdout_logfile=/logs/postgres_stdout.log +stderr_logfile=/logs/postgres_stderr.log diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs new file mode 100644 index 000000000..54f59ffef --- /dev/null +++ b/interop_binaries/tests/end_to_end.rs @@ -0,0 +1,592 @@ +use anyhow::Context; +use base64::URL_SAFE_NO_PAD; +use janus_core::{ + message::{Duration, TaskId}, + time::{Clock, RealClock}, +}; +use janus_server::task::PRIO3_AES128_VERIFY_KEY_LENGTH; +use lazy_static::lazy_static; +use portpicker::pick_unused_port; +use prio::codec::Encode; +use reqwest::{header::CONTENT_TYPE, StatusCode}; +use serde_json::{json, Value}; +use std::{ + collections::BTreeSet, + env, + io::{self, ErrorKind}, + net::{Ipv4Addr, SocketAddr}, + process::{Child, Command, Stdio}, + time::Duration as StdDuration, +}; +use testcontainers::{images::postgres::Postgres, RunnableImage}; +use tokio::{ + io::{AsyncBufReadExt, BufReader}, + net::TcpStream, + process::{ChildStderr, ChildStdout}, + time::sleep, +}; + +static JSON_MEDIA_TYPE: &str = "application/json"; +static MIN_BATCH_DURATION: u64 = 3600; + +lazy_static! { + static ref CONTAINER_CLIENT: testcontainers::clients::Cli = + testcontainers::clients::Cli::default(); +} + +/// Wait for a TCP server to begin listening on the given port. +async fn wait_for_tcp_server(port: u16) -> anyhow::Result<()> { + for _ in 0..100 { + if TcpStream::connect(SocketAddr::from((Ipv4Addr::LOCALHOST, port))) + .await + .is_ok() + { + return Ok(()); + } + sleep(StdDuration::from_millis(200)).await; + } + Err(anyhow::anyhow!( + "timed out waiting for a server to accept on port {}", + port, + )) +} + +/// RAII guard to ensure that child processes are cleaned up during test failures. +struct ChildProcessCleanupDropGuard(Child); + +impl Drop for ChildProcessCleanupDropGuard { + fn drop(&mut self) { + match self.0.kill() { + Ok(_) => {} + Err(e) if e.kind() == ErrorKind::InvalidInput => {} + Err(e) => panic!("failed to kill child process: {:?}", e), + } + } +} + +/// Pass output from a child process's stdout pipe to print!(), so that it can be captured and +/// stored by the test harness. +async fn forward_stdout(stdout: ChildStdout) -> io::Result<()> { + let mut reader = BufReader::new(stdout); + let mut line = String::new(); + loop { + line.clear(); + let count = reader.read_line(&mut line).await?; + if count == 0 { + return Ok(()); + } + print!("{}", line); + } +} + +/// Pass output from a child process's stderr pipe to eprint!(), so that it can be captured and +/// stored by the test harness. +async fn forward_stderr(stderr: ChildStderr) -> io::Result<()> { + let mut reader = BufReader::new(stderr); + let mut line = String::new(); + loop { + line.clear(); + let count = reader.read_line(&mut line).await?; + if count == 0 { + return Ok(()); + } + eprint!("{}", line); + } +} + +/// Take a VDAF description and a list of measurements, perform an entire aggregation using +/// interoperation test binaries, and return the aggregate result. This follows the outline of +/// section 4.7 of draft-dcook-ppm-dap-interop-test-design-00. +async fn run( + vdaf_object: serde_json::Value, + measurements: &[serde_json::Value], + aggregation_parameter: &[u8], +) -> anyhow::Result { + // Start up a database testcontainer for each aggregator directly, and don't set up the schema. + let leader_db_container = + CONTAINER_CLIENT.run(RunnableImage::from(Postgres::default()).with_tag("14-alpine")); + let leader_postgres_port = leader_db_container.get_host_port_ipv4(5432); + let helper_db_container = + CONTAINER_CLIENT.run(RunnableImage::from(Postgres::default()).with_tag("14-alpine")); + let helper_postgres_port = helper_db_container.get_host_port_ipv4(5432); + + // Pick four ports for HTTP servers. + let client_port = pick_unused_port().context("couldn't pick a port for the client")?; + let leader_port = pick_unused_port().context("couldn't pick a port for the leader")?; + let helper_port = pick_unused_port().context("couldn't pick a port for the helper")?; + let collector_port = pick_unused_port().context("couldn't pick a port for the collector")?; + assert_eq!( + BTreeSet::from([client_port, leader_port, helper_port, collector_port]).len(), + 4, + "Ports selected for HTTP servers were not unique", + ); + + // Create and start containers. (here, we just run the binaries instead) + // We use std::process instead of tokio::process so that we can kill the child processes from + // a Drop implementation. tokio::process::Child::kill() is async, and could not be called from + // there. + let mut client_command = Command::new(env!("CARGO_BIN_EXE_janus_interop_client")); + client_command.arg("--port").arg(format!("{}", client_port)); + let mut leader_command = Command::new(env!("CARGO_BIN_EXE_janus_interop_aggregator")); + leader_command.arg("--port").arg(format!("{}", leader_port)); + leader_command.arg("--postgres-url").arg(format!( + "postgres://postgres@127.0.0.1:{}/postgres", + leader_postgres_port + )); + let mut helper_command = Command::new(env!("CARGO_BIN_EXE_janus_interop_aggregator")); + helper_command.arg("--port").arg(format!("{}", helper_port)); + helper_command.arg("--postgres-url").arg(format!( + "postgres://postgres@127.0.0.1:{}/postgres", + helper_postgres_port + )); + let mut collector_command = Command::new(env!("CARGO_BIN_EXE_janus_interop_collector")); + collector_command + .arg("--port") + .arg(format!("{}", collector_port)); + let commands = [ + client_command, + leader_command, + helper_command, + collector_command, + ]; + let mut drop_guards = Vec::with_capacity(commands.len()); + for mut command in commands { + let mut drop_guard = ChildProcessCleanupDropGuard( + command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?, + ); + tokio::spawn(forward_stdout(ChildStdout::from_std( + drop_guard.0.stdout.take().unwrap(), + )?)); + tokio::spawn(forward_stderr(ChildStderr::from_std( + drop_guard.0.stderr.take().unwrap(), + )?)); + drop_guards.push(drop_guard); + } + + // Try opening a TCP connection to each container's port, and retry until it succeeds. + for port in [client_port, leader_port, helper_port, collector_port] { + wait_for_tcp_server(port).await?; + } + + // Generate a random TaskId, random authentication tokens, and a VDAF verification key. + let task_id = TaskId::random(); + let aggregator_auth_token = base64::encode_config(rand::random::<[u8; 16]>(), URL_SAFE_NO_PAD); + let collector_auth_token = base64::encode_config(rand::random::<[u8; 16]>(), URL_SAFE_NO_PAD); + let verify_key = rand::random::<[u8; PRIO3_AES128_VERIFY_KEY_LENGTH]>(); + + let task_id_encoded = base64::encode_config(&task_id.get_encoded(), URL_SAFE_NO_PAD); + let verify_key_encoded = base64::encode_config(&verify_key, URL_SAFE_NO_PAD); + let leader_endpoint = format!("http://127.0.0.1:{}/", leader_port); + let helper_endpoint = format!("http://127.0.0.1:{}/", helper_port); + + let http_client = reqwest::Client::new(); + + // Send a /internal/test/endpoint_for_task request to the leader. + let leader_endpoint_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/endpoint_for_task", + leader_port, + )) + .json(&json!({ + "taskId": task_id_encoded, + "aggregatorId": 0, + "hostnameAndPort": format!("127.0.0.1:{}", leader_port), + })) + .send() + .await?; + assert_eq!(leader_endpoint_response.status(), StatusCode::OK); + assert_eq!( + leader_endpoint_response + .headers() + .get(CONTENT_TYPE) + .unwrap(), + JSON_MEDIA_TYPE, + ); + let leader_endpoint_response_body = leader_endpoint_response.json::().await?; + let leader_endpoint_response_object = leader_endpoint_response_body + .as_object() + .context("endpoint_for_task response is not an object")?; + assert_eq!( + leader_endpoint_response_object + .get("status") + .context("endpoint_for_task response is missing \"status\"")?, + "success", + "error: {:?}", + leader_endpoint_response_object.get("error"), + ); + assert_eq!( + leader_endpoint_response_object + .get("endpoint") + .context("endpoint_for_task response is missing \"endpoint\"")?, + "/", + ); + + // Send a /internal/test/endpoint_for_task request to the helper. + let helper_endpoint_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/endpoint_for_task", + helper_port, + )) + .json(&json!({ + "taskId": task_id_encoded, + "aggregatorId": 1, + "hostnameAndPort": format!("127.0.0.1:{}", leader_port), + })) + .send() + .await?; + assert_eq!(helper_endpoint_response.status(), StatusCode::OK); + assert_eq!( + helper_endpoint_response + .headers() + .get(CONTENT_TYPE) + .unwrap(), + JSON_MEDIA_TYPE, + ); + let helper_endpoint_response_body = helper_endpoint_response.json::().await?; + let helper_endpoint_response_object = helper_endpoint_response_body + .as_object() + .context("endpoint_for_task response is not an object")?; + assert_eq!( + helper_endpoint_response_object + .get("status") + .context("endpoint_for_task response is missing \"status\"")?, + "success", + "error: {:?}", + helper_endpoint_response_object.get("error"), + ); + assert_eq!( + helper_endpoint_response_object + .get("endpoint") + .context("endpoint_for_task response is missing \"endpoint\"")?, + "/", + ); + + // Send a /internal/test/add_task request to the collector. + let collector_add_task_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/add_task", + collector_port, + )) + .json(&json!({ + "taskId": task_id_encoded, + "leader": leader_endpoint, + "vdaf": vdaf_object, + "collectorAuthenticationToken": collector_auth_token, + })) + .send() + .await?; + assert_eq!(collector_add_task_response.status(), StatusCode::OK); + assert_eq!( + collector_add_task_response + .headers() + .get(CONTENT_TYPE) + .unwrap(), + JSON_MEDIA_TYPE, + ); + let collector_add_task_response_body = collector_add_task_response.json::().await?; + let collector_add_task_response_object = collector_add_task_response_body + .as_object() + .context("collector add_task response is not an object")?; + assert_eq!( + collector_add_task_response_object + .get("status") + .context("collector add_task response is missing \"status\"")?, + "success", + "error: {:?}", + collector_add_task_response_object.get("error"), + ); + let collector_hpke_config_encoded = collector_add_task_response_object + .get("collectorHpkeConfig") + .context("collector add_task response is missing \"collectorHpkeConfig\"")? + .as_str() + .context("\"collectorHpkeConfig\" value is not a string")?; + + // Send a /internal/test/add_task request to the leader. + let leader_add_task_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/add_task", + leader_port, + )) + .json(&json!({ + "taskId": task_id_encoded, + "leader": leader_endpoint, + "helper": helper_endpoint, + "vdaf": vdaf_object, + "leaderAuthenticationToken": aggregator_auth_token, + "collectorAuthenticationToken": collector_auth_token, + "aggregatorId": 0, + "verifyKey": verify_key_encoded, + "maxBatchLifetime": 1, + "minBatchSize": 1, + "minBatchDuration": MIN_BATCH_DURATION, + "collectorHpkeConfig": collector_hpke_config_encoded, + })) + .send() + .await?; + assert_eq!(leader_add_task_response.status(), StatusCode::OK); + assert_eq!( + leader_add_task_response + .headers() + .get(CONTENT_TYPE) + .unwrap(), + JSON_MEDIA_TYPE, + ); + let leader_add_task_response_body = leader_add_task_response.json::().await?; + let leader_add_task_response_object = leader_add_task_response_body + .as_object() + .context("leader add_task response is not an object")?; + assert_eq!( + leader_add_task_response_object + .get("status") + .context("leader add_task response is missing \"status\"")?, + "success", + "error: {:?}", + leader_add_task_response_object.get("error"), + ); + + // Send a /internal/test/add_task request to the helper. + let helper_add_task_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/add_task", + helper_port, + )) + .json(&json!({ + "taskId": task_id_encoded, + "leader": leader_endpoint, + "helper": helper_endpoint, + "vdaf": vdaf_object, + "leaderAuthenticationToken": aggregator_auth_token, + "aggregatorId": 1, + "verifyKey": verify_key_encoded, + "maxBatchLifetime": 1, + "minBatchSize": 1, + "minBatchDuration": MIN_BATCH_DURATION, + "collectorHpkeConfig": collector_hpke_config_encoded, + })) + .send() + .await?; + assert_eq!(helper_add_task_response.status(), StatusCode::OK); + assert_eq!( + helper_add_task_response + .headers() + .get(CONTENT_TYPE) + .unwrap(), + JSON_MEDIA_TYPE, + ); + let helper_add_task_response_body = helper_add_task_response.json::().await?; + let helper_add_task_response_object = helper_add_task_response_body + .as_object() + .context("helper add_task response is not an object")?; + assert_eq!( + helper_add_task_response_object + .get("status") + .context("helper add_task response is missing \"status\"")?, + "success", + "error: {:?}", + helper_add_task_response_object.get("error"), + ); + + // Record the time before generating reports, and round it down to + // determine what batch time to start the aggregation at. + let start_timestamp = RealClock::default().now(); + let batch_interval_start = start_timestamp + .to_batch_unit_interval_start(Duration::from_seconds(MIN_BATCH_DURATION))? + .as_seconds_since_epoch(); + // Span the aggregation over two minimum batch durations, just in case our + // measurements spilled over a batch boundary. + let batch_interval_duration = MIN_BATCH_DURATION * 2; + + // Send one or more /internal/test/upload requests to the client. + for measurement in measurements { + let upload_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/upload", + client_port, + )) + .json(&json!({ + "taskId": task_id_encoded, + "leader": leader_endpoint, + "helper": helper_endpoint, + "vdaf": vdaf_object, + "measurement": measurement, + "minBatchDuration": MIN_BATCH_DURATION, + })) + .send() + .await?; + assert_eq!(upload_response.status(), StatusCode::OK); + assert_eq!( + upload_response.headers().get(CONTENT_TYPE).unwrap(), + JSON_MEDIA_TYPE, + ); + let upload_response_body = upload_response.json::().await?; + let upload_response_object = upload_response_body + .as_object() + .context("upload response is not an object")?; + assert_eq!( + upload_response_object + .get("status") + .context("upload response is missing \"status\"")?, + "success", + "error: {:?}", + upload_response_object.get("error"), + ); + } + + // Send a /internal/test/collect_start request to the collector. + let collect_start_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/collect_start", + collector_port, + )) + .json(&json!({ + "taskId": task_id_encoded, + "aggParam": base64::encode_config(aggregation_parameter, URL_SAFE_NO_PAD), + "batchIntervalStart": batch_interval_start, + "batchIntervalDuration": batch_interval_duration, + })) + .send() + .await?; + assert_eq!(collect_start_response.status(), StatusCode::OK); + assert_eq!( + collect_start_response.headers().get(CONTENT_TYPE).unwrap(), + JSON_MEDIA_TYPE, + ); + let collect_start_response_body = collect_start_response.json::().await?; + let collect_start_response_object = collect_start_response_body + .as_object() + .context("collect_start response is not an object")?; + assert_eq!( + collect_start_response_object + .get("status") + .context("collect_start response is missing \"status\"")?, + "success", + "error: {:?}", + collect_start_response_object.get("error"), + ); + let collect_job_handle = collect_start_response_object + .get("handle") + .context("collect_start response is missing \"handle\"")? + .as_str() + .context("\"handle\" value is not a string")?; + + // Send /internal/test/collect_poll requests to the collector, polling until it is completed. + for _ in 0..30 { + let collect_poll_response = http_client + .post(format!( + "http://127.0.0.1:{}/internal/test/collect_poll", + collector_port, + )) + .json(&json!({ + "handle": collect_job_handle, + })) + .send() + .await?; + assert_eq!(collect_poll_response.status(), StatusCode::OK); + assert_eq!( + collect_poll_response.headers().get(CONTENT_TYPE).unwrap(), + JSON_MEDIA_TYPE, + ); + let collect_poll_response_body = collect_poll_response.json::().await?; + let collect_poll_response_object = collect_poll_response_body + .as_object() + .context("collect_poll response is not an object")?; + let status = collect_poll_response_object + .get("status") + .context("collect_poll response is missing \"status\"")? + .as_str() + .context("\"status\" value is not a string")?; + if status == "in progress" { + tokio::time::sleep(StdDuration::from_millis(500)).await; + continue; + } + assert_eq!( + status, + "complete", + "error: {:?}", + collect_poll_response_object.get("error"), + ); + return collect_poll_response_object + .get("result") + .context("completed collect_poll response is missing \"result\"") + .cloned(); + } + + Err(anyhow::anyhow!("timed out fetching aggregation result")) +} + +#[tokio::test] +async fn e2e_prio3_count() { + let result = run( + json!({"type": "Prio3Aes128Count"}), + &[ + json!(0), + json!(1), + json!(1), + json!(0), + json!(1), + json!(0), + json!(1), + json!(0), + json!(1), + json!(1), + json!(0), + json!(1), + json!(0), + json!(1), + json!(0), + json!(0), + json!(0), + json!(0), + ], + b"", + ) + .await + .unwrap(); + assert_eq!(result, json!(8)); +} + +#[tokio::test] +async fn e2e_prio3_sum() { + let result = run( + json!({"type": "Prio3Aes128Sum", "bits": 64}), + &[ + json!(0), + json!(10), + json!(9), + json!(21), + json!(8), + json!(12), + json!(14), + ], + b"", + ) + .await + .unwrap(); + assert_eq!(result, json!(74)); +} + +#[tokio::test] +async fn e2e_prio3_histogram() { + let result = run( + json!({"type": "Prio3Aes128Histogram", "buckets": [0, 1, 10, 100, 1_000, 10_000, 100_000]}), + &[ + json!(1), + json!(4), + json!(16), + json!(64), + json!(256), + json!(1024), + json!(4096), + json!(16384), + json!(65536), + json!(262144), + ], + b"", + ) + .await + .unwrap(); + assert_eq!(result, json!([0, 1, 1, 2, 1, 2, 2, 1])); +} diff --git a/janus_client/src/lib.rs b/janus_client/src/lib.rs index 21af31d86..1d9b5db02 100644 --- a/janus_client/src/lib.rs +++ b/janus_client/src/lib.rs @@ -233,10 +233,10 @@ mod tests { use super::*; use assert_matches::assert_matches; use janus_core::{ - hpke::test_util::generate_hpke_config_and_private_key, + hpke::test_util::generate_test_hpke_config_and_private_key, message::{TaskId, Time}, test_util::install_test_trace_subscriber, - time::test_util::MockClock, + time::MockClock, }; use mockito::mock; use prio::vdaf::prio3::Prio3; @@ -256,8 +256,8 @@ mod tests { vdaf_client, MockClock::default(), &default_http_client().unwrap(), - generate_hpke_config_and_private_key().0, - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, ) } @@ -338,8 +338,8 @@ mod tests { Prio3::new_aes128_count(2).unwrap(), MockClock::default(), &default_http_client().unwrap(), - generate_hpke_config_and_private_key().0, - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, ); let result = client.upload(&1).await; assert_matches!(result, Err(Error::InvalidParameter(_))); diff --git a/janus_core/src/hpke.rs b/janus_core/src/hpke.rs index b163211d0..63c9d27dc 100644 --- a/janus_core/src/hpke.rs +++ b/janus_core/src/hpke.rs @@ -1,7 +1,10 @@ //! Encryption and decryption of messages using HPKE (RFC 9180). -use crate::message::{Extension, HpkeCiphertext, HpkeConfig, Interval, Nonce, Role, TaskId}; -use hpke_dispatch::HpkeError; +use crate::message::{ + Extension, HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, + HpkePublicKey, Interval, Nonce, Role, TaskId, +}; +use hpke_dispatch::{HpkeError, Kem, Keypair}; use prio::codec::{encode_u16_items, Encode}; use std::str::FromStr; @@ -170,38 +173,52 @@ pub fn open( .map_err(Into::into) } +/// Generate a new HPKE keypair and return it as an HpkeConfig (public portion) and +/// HpkePrivateKey (private portion). +pub fn generate_hpke_config_and_private_key( + hpke_config_id: HpkeConfigId, + kem_id: HpkeKemId, + kdf_id: HpkeKdfId, + aead_id: HpkeAeadId, +) -> (HpkeConfig, HpkePrivateKey) { + let Keypair { + private_key, + public_key, + } = match kem_id { + HpkeKemId::X25519HkdfSha256 => Kem::X25519HkdfSha256.gen_keypair(), + HpkeKemId::P256HkdfSha256 => Kem::DhP256HkdfSha256.gen_keypair(), + }; + ( + HpkeConfig::new( + hpke_config_id, + kem_id, + kdf_id, + aead_id, + HpkePublicKey::new(public_key), + ), + HpkePrivateKey::new(private_key), + ) +} + #[cfg(feature = "test-util")] pub mod test_util { - use crate::{ - hpke::HpkePrivateKey, - message::{HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey}, - }; - use hpke_dispatch::{Kem, Keypair}; + use super::{generate_hpke_config_and_private_key, HpkePrivateKey}; + use crate::message::{HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId}; use rand::{thread_rng, Rng}; - /// Generate a new HPKE keypair and return it as an HpkeConfig (public portion) and - /// HpkePrivateKey (private portion). - pub fn generate_hpke_config_and_private_key() -> (HpkeConfig, HpkePrivateKey) { - let Keypair { - private_key, - public_key, - } = Kem::X25519HkdfSha256.gen_keypair(); - ( - HpkeConfig::new( - HpkeConfigId::from(thread_rng().gen::()), - HpkeKemId::X25519HkdfSha256, - HpkeKdfId::HkdfSha256, - HpkeAeadId::Aes128Gcm, - HpkePublicKey::new(public_key), - ), - HpkePrivateKey::new(private_key), + pub fn generate_test_hpke_config_and_private_key() -> (HpkeConfig, HpkePrivateKey) { + generate_hpke_config_and_private_key( + HpkeConfigId::from(thread_rng().gen::()), + HpkeKemId::X25519HkdfSha256, + HpkeKdfId::HkdfSha256, + HpkeAeadId::Aes128Gcm, ) } } #[cfg(test)] mod tests { - use super::{test_util::generate_hpke_config_and_private_key, HpkeApplicationInfo, Label}; + use super::{test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label}; use crate::{ hpke::{open, seal, HpkePrivateKey}, message::{ @@ -215,7 +232,7 @@ mod tests { #[test] fn exchange_message() { - let (hpke_config, hpke_private_key) = generate_hpke_config_and_private_key(); + let (hpke_config, hpke_private_key) = generate_test_hpke_config_and_private_key(); let application_info = HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); let message = b"a message that is secret"; @@ -237,7 +254,7 @@ mod tests { #[test] fn wrong_private_key() { - let (hpke_config, _) = generate_hpke_config_and_private_key(); + let (hpke_config, _) = generate_test_hpke_config_and_private_key(); let application_info = HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); let message = b"a message that is secret"; @@ -246,7 +263,8 @@ mod tests { let ciphertext = seal(&hpke_config, &application_info, message, associated_data).unwrap(); // Attempt to decrypt with different private key, and verify this fails. - let (wrong_hpke_config, wrong_hpke_private_key) = generate_hpke_config_and_private_key(); + let (wrong_hpke_config, wrong_hpke_private_key) = + generate_test_hpke_config_and_private_key(); open( &wrong_hpke_config, &wrong_hpke_private_key, @@ -259,7 +277,7 @@ mod tests { #[test] fn wrong_application_info() { - let (hpke_config, hpke_private_key) = generate_hpke_config_and_private_key(); + let (hpke_config, hpke_private_key) = generate_test_hpke_config_and_private_key(); let application_info = HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); let message = b"a message that is secret"; @@ -281,7 +299,7 @@ mod tests { #[test] fn wrong_associated_data() { - let (hpke_config, hpke_private_key) = generate_hpke_config_and_private_key(); + let (hpke_config, hpke_private_key) = generate_test_hpke_config_and_private_key(); let application_info = HpkeApplicationInfo::new(Label::InputShare, Role::Client, Role::Leader); let message = b"a message that is secret"; diff --git a/janus_core/src/time.rs b/janus_core/src/time.rs index 5fe7af761..c2ffdd6de 100644 --- a/janus_core/src/time.rs +++ b/janus_core/src/time.rs @@ -1,8 +1,11 @@ //! Utilities for timestamps and durations. -use crate::message::Time; +use crate::message::{Duration, Time}; use chrono::Utc; -use std::fmt::{Debug, Formatter}; +use std::{ + fmt::{Debug, Formatter}, + sync::{Arc, Mutex}, +}; /// A clock knows what time it currently is. pub trait Clock: 'static + Clone + Debug + Sync + Send { @@ -32,49 +35,40 @@ impl Debug for RealClock { } } -#[cfg(feature = "test-util")] -pub mod test_util { - use crate::{ - message::{Duration, Time}, - time::Clock, - }; - use std::sync::{Arc, Mutex}; - - /// A mock clock for use in testing. Clones are identical: all clones of a given MockClock will - /// be controlled by a controller retrieved from any of the clones. - #[derive(Clone, Debug)] - #[non_exhaustive] - pub struct MockClock { - /// The time that this clock will return from [`Self::now`]. - current_time: Arc>, - } +/// A mock clock for use in testing. Clones are identical: all clones of a given MockClock will +/// be controlled by a controller retrieved from any of the clones. +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct MockClock { + /// The time that this clock will return from [`Self::now`]. + current_time: Arc>, +} - impl MockClock { - pub fn new(when: Time) -> MockClock { - MockClock { - current_time: Arc::new(Mutex::new(when)), - } +impl MockClock { + pub fn new(when: Time) -> MockClock { + MockClock { + current_time: Arc::new(Mutex::new(when)), } + } - pub fn advance(&self, dur: Duration) { - let mut current_time = self.current_time.lock().unwrap(); - *current_time = current_time.add(dur).unwrap(); - } + pub fn advance(&self, dur: Duration) { + let mut current_time = self.current_time.lock().unwrap(); + *current_time = current_time.add(dur).unwrap(); } +} - impl Clock for MockClock { - fn now(&self) -> Time { - let current_time = self.current_time.lock().unwrap(); - *current_time - } +impl Clock for MockClock { + fn now(&self) -> Time { + let current_time = self.current_time.lock().unwrap(); + *current_time } +} - impl Default for MockClock { - fn default() -> Self { - Self { - // Sunday, September 9, 2001 1:46:40 AM UTC - current_time: Arc::new(Mutex::new(Time::from_seconds_since_epoch(1000000000))), - } +impl Default for MockClock { + fn default() -> Self { + Self { + // Sunday, September 9, 2001 1:46:40 AM UTC + current_time: Arc::new(Mutex::new(Time::from_seconds_since_epoch(1000000000))), } } } diff --git a/janus_server/src/aggregator.rs b/janus_server/src/aggregator.rs index 49b29f496..cb286a955 100644 --- a/janus_server/src/aggregator.rs +++ b/janus_server/src/aggregator.rs @@ -2293,15 +2293,15 @@ mod tests { use janus_core::{ hpke::associated_data_for_report_share, hpke::{ - associated_data_for_aggregate_share, test_util::generate_hpke_config_and_private_key, - HpkePrivateKey, Label, + associated_data_for_aggregate_share, + test_util::generate_test_hpke_config_and_private_key, HpkePrivateKey, Label, }, message::{Duration, HpkeCiphertext, HpkeConfig, TaskId, Time}, test_util::{ dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, run_vdaf, }, - time::test_util::MockClock, + time::MockClock, }; use opentelemetry::global::meter; use prio::{ @@ -3183,7 +3183,7 @@ mod tests { // report_share_3 has an unknown HPKE config ID. let nonce_3 = Nonce::generate(&clock, task.min_batch_duration).unwrap(); let wrong_hpke_config = loop { - let hpke_config = generate_hpke_config_and_private_key().0; + let hpke_config = generate_test_hpke_config_and_private_key().0; if task.hpke_keys.contains_key(&hpke_config.id()) { continue; } @@ -5309,7 +5309,7 @@ mod tests { let batch_interval = Interval::new(Time::from_seconds_since_epoch(0), task.min_batch_duration).unwrap(); let (collector_hpke_config, collector_hpke_recipient) = - generate_hpke_config_and_private_key(); + generate_test_hpke_config_and_private_key(); task.collector_hpke_config = collector_hpke_config; let leader_aggregate_share = AggregateShare::from(vec![Field64::from(64)]); @@ -5718,7 +5718,7 @@ mod tests { let task_id = TaskId::random(); let (collector_hpke_config, collector_hpke_recipient) = - generate_hpke_config_and_private_key(); + generate_test_hpke_config_and_private_key(); let mut task = new_dummy_task(task_id, VdafInstance::Fake, Role::Helper); task.max_batch_lifetime = 1; diff --git a/janus_server/src/aggregator/aggregate_share.rs b/janus_server/src/aggregator/aggregate_share.rs index 427c4cc3e..5b648f875 100644 --- a/janus_server/src/aggregator/aggregate_share.rs +++ b/janus_server/src/aggregator/aggregate_share.rs @@ -458,7 +458,7 @@ mod tests { install_test_trace_subscriber, runtime::TestRuntimeManager, }, - time::test_util::MockClock, + time::MockClock, Runtime, }; use mockito::mock; diff --git a/janus_server/src/aggregator/aggregation_job_creator.rs b/janus_server/src/aggregator/aggregation_job_creator.rs index 9af49729c..4d1276a4d 100644 --- a/janus_server/src/aggregator/aggregation_job_creator.rs +++ b/janus_server/src/aggregator/aggregation_job_creator.rs @@ -452,7 +452,7 @@ mod tests { dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, }, - time::{test_util::MockClock, Clock}, + time::{Clock, MockClock}, }; use prio::{ codec::ParameterizedDecode, diff --git a/janus_server/src/aggregator/aggregation_job_driver.rs b/janus_server/src/aggregator/aggregation_job_driver.rs index 8377b1bd9..9b68688e8 100644 --- a/janus_server/src/aggregator/aggregation_job_driver.rs +++ b/janus_server/src/aggregator/aggregation_job_driver.rs @@ -853,12 +853,12 @@ mod tests { use janus_core::{ hpke::{ self, associated_data_for_report_share, - test_util::generate_hpke_config_and_private_key, HpkeApplicationInfo, Label, + test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label, }, message::{Duration, HpkeConfig, Interval, Nonce, NonceChecksum, Report, Role, TaskId}, task::VdafInstance, test_util::{install_test_trace_subscriber, run_vdaf, runtime::TestRuntimeManager}, - time::test_util::MockClock, + time::MockClock, Runtime, }; use mockito::mock; @@ -908,7 +908,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token().clone(); let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; - let (helper_hpke_config, _) = generate_hpke_config_and_private_key(); + let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( task_id, nonce, @@ -1105,7 +1105,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; - let (helper_hpke_config, _) = generate_hpke_config_and_private_key(); + let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( task_id, nonce, @@ -1292,7 +1292,7 @@ mod tests { let agg_auth_token = task.primary_aggregator_auth_token(); let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; - let (helper_hpke_config, _) = generate_hpke_config_and_private_key(); + let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( task_id, nonce, @@ -1506,7 +1506,7 @@ mod tests { let input_shares = run_vdaf(vdaf.as_ref(), &verify_key, &(), nonce, &0).input_shares; let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; - let (helper_hpke_config, _) = generate_hpke_config_and_private_key(); + let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let report = generate_report( task_id, nonce, @@ -1656,7 +1656,7 @@ mod tests { let verify_key = task.vdaf_verify_keys[0].clone().try_into().unwrap(); let (leader_hpke_config, _) = task.hpke_keys.iter().next().unwrap().1; - let (helper_hpke_config, _) = generate_hpke_config_and_private_key(); + let (helper_hpke_config, _) = generate_test_hpke_config_and_private_key(); let vdaf = Prio3::new_aes128_count(2).unwrap(); let nonce = Nonce::generate(&clock, task.min_batch_duration).unwrap(); diff --git a/janus_server/src/binary_utils/job_driver.rs b/janus_server/src/binary_utils/job_driver.rs index e161bcb39..63cdc5e5c 100644 --- a/janus_server/src/binary_utils/job_driver.rs +++ b/janus_server/src/binary_utils/job_driver.rs @@ -253,7 +253,7 @@ mod tests { use janus_core::{ message::TaskId, test_util::{install_test_trace_subscriber, runtime::TestRuntimeManager}, - time::test_util::MockClock, + time::MockClock, }; use opentelemetry::global::meter; use tokio::sync::Mutex; diff --git a/janus_server/src/datastore.rs b/janus_server/src/datastore.rs index ba6f0c30c..2b40d7a03 100644 --- a/janus_server/src/datastore.rs +++ b/janus_server/src/datastore.rs @@ -3292,7 +3292,7 @@ mod tests { dummy_vdaf::{self, AggregationParam}, install_test_trace_subscriber, }, - time::test_util::MockClock, + time::MockClock, }; use prio::{ field::{Field128, Field64}, diff --git a/janus_server/src/task.rs b/janus_server/src/task.rs index 808029c33..13f433ad7 100644 --- a/janus_server/src/task.rs +++ b/janus_server/src/task.rs @@ -532,7 +532,7 @@ pub mod test_util { use super::{AuthenticationToken, Task, VdafInstance, PRIO3_AES128_VERIFY_KEY_LENGTH}; use janus_core::{ - hpke::test_util::generate_hpke_config_and_private_key, + hpke::test_util::generate_test_hpke_config_and_private_key, message::{Duration, HpkeConfig, HpkeConfigId, Role, TaskId}, }; use rand::{thread_rng, Rng}; @@ -558,9 +558,9 @@ pub mod test_util { /// integration tests. pub fn new_dummy_task(task_id: TaskId, vdaf: VdafInstance, role: Role) -> Task { let (aggregator_config_0, aggregator_private_key_0) = - generate_hpke_config_and_private_key(); + generate_test_hpke_config_and_private_key(); let (mut aggregator_config_1, aggregator_private_key_1) = - generate_hpke_config_and_private_key(); + generate_test_hpke_config_and_private_key(); aggregator_config_1 = HpkeConfig::new( HpkeConfigId::from(1), aggregator_config_1.kem_id(), @@ -592,7 +592,7 @@ pub mod test_util { 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, Vec::from([generate_auth_token(), generate_auth_token()]), collector_auth_tokens, Vec::from([ @@ -620,7 +620,7 @@ mod tests { }; use crate::{config::test_util::roundtrip_encoding, task::VdafInstance}; use janus_core::{ - hpke::test_util::generate_hpke_config_and_private_key, + hpke::test_util::generate_test_hpke_config_and_private_key, message::{Duration, Interval, Role, TaskId, Time}, }; use serde_test::{assert_tokens, Token}; @@ -800,10 +800,10 @@ mod tests { 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, Vec::from([generate_auth_token()]), Vec::new(), - Vec::from([generate_hpke_config_and_private_key()]), + Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap_err(); @@ -821,10 +821,10 @@ mod tests { 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, Vec::from([generate_auth_token()]), Vec::from([generate_auth_token()]), - Vec::from([generate_hpke_config_and_private_key()]), + Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -842,10 +842,10 @@ mod tests { 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, Vec::from([generate_auth_token()]), Vec::new(), - Vec::from([generate_hpke_config_and_private_key()]), + Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); @@ -863,10 +863,10 @@ mod tests { 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, Vec::from([generate_auth_token()]), Vec::from([generate_auth_token()]), - Vec::from([generate_hpke_config_and_private_key()]), + Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap_err(); } @@ -886,10 +886,10 @@ mod tests { 0, Duration::from_hours(8).unwrap(), Duration::from_minutes(10).unwrap(), - generate_hpke_config_and_private_key().0, + generate_test_hpke_config_and_private_key().0, Vec::from([generate_auth_token()]), Vec::from([generate_auth_token()]), - Vec::from([generate_hpke_config_and_private_key()]), + Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); diff --git a/monolithic_integration_test/tests/common/mod.rs b/monolithic_integration_test/tests/common/mod.rs index f6946ff45..80835c8c1 100644 --- a/monolithic_integration_test/tests/common/mod.rs +++ b/monolithic_integration_test/tests/common/mod.rs @@ -8,8 +8,9 @@ use itertools::Itertools; use janus_client::{Client, ClientParameters}; use janus_core::{ hpke::{ - self, associated_data_for_aggregate_share, test_util::generate_hpke_config_and_private_key, - HpkeApplicationInfo, HpkePrivateKey, Label, + self, associated_data_for_aggregate_share, + test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, HpkePrivateKey, + Label, }, message::{Duration, HpkeConfig, Interval, Role, TaskId}, task::VdafInstance, @@ -71,7 +72,7 @@ pub fn create_test_tasks( collector_hpke_config.clone(), aggregator_auth_tokens.clone(), Vec::from([generate_auth_token()]), - Vec::from([generate_hpke_config_and_private_key()]), + Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); let helper_task = Task::new( @@ -87,7 +88,7 @@ pub fn create_test_tasks( collector_hpke_config.clone(), aggregator_auth_tokens, Vec::new(), - Vec::from([generate_hpke_config_and_private_key()]), + Vec::from([generate_test_hpke_config_and_private_key()]), ) .unwrap(); diff --git a/monolithic_integration_test/tests/daphne.rs b/monolithic_integration_test/tests/daphne.rs index 3fa487a37..21f112d7c 100644 --- a/monolithic_integration_test/tests/daphne.rs +++ b/monolithic_integration_test/tests/daphne.rs @@ -1,6 +1,7 @@ use common::{create_test_tasks, pick_two_unused_ports, submit_measurements_and_verify_aggregate}; use janus_core::{ - hpke::test_util::generate_hpke_config_and_private_key, test_util::install_test_trace_subscriber, + hpke::test_util::generate_test_hpke_config_and_private_key, + test_util::install_test_trace_subscriber, }; use monolithic_integration_test::{daphne::Daphne, janus::Janus}; @@ -13,7 +14,8 @@ async fn daphne_janus() { // Start servers. let (daphne_port, janus_port) = pick_two_unused_ports(); - let (collector_hpke_config, collector_private_key) = generate_hpke_config_and_private_key(); + let (collector_hpke_config, collector_private_key) = + generate_test_hpke_config_and_private_key(); let (daphne_task, janus_task) = create_test_tasks(daphne_port, janus_port, &collector_hpke_config); @@ -36,7 +38,8 @@ async fn janus_daphne() { // Start servers. let (janus_port, daphne_port) = pick_two_unused_ports(); - let (collector_hpke_config, collector_private_key) = generate_hpke_config_and_private_key(); + let (collector_hpke_config, collector_private_key) = + generate_test_hpke_config_and_private_key(); let (janus_task, daphne_task) = create_test_tasks(janus_port, daphne_port, &collector_hpke_config); diff --git a/monolithic_integration_test/tests/janus.rs b/monolithic_integration_test/tests/janus.rs index 804256655..7c5b2f769 100644 --- a/monolithic_integration_test/tests/janus.rs +++ b/monolithic_integration_test/tests/janus.rs @@ -1,6 +1,7 @@ use common::{create_test_tasks, pick_two_unused_ports, submit_measurements_and_verify_aggregate}; use janus_core::{ - hpke::test_util::generate_hpke_config_and_private_key, test_util::install_test_trace_subscriber, + hpke::test_util::generate_test_hpke_config_and_private_key, + test_util::install_test_trace_subscriber, }; use monolithic_integration_test::janus::Janus; @@ -13,7 +14,8 @@ async fn janus_janus() { // Start servers. let (janus_leader_port, janus_helper_port) = pick_two_unused_ports(); - let (collector_hpke_config, collector_private_key) = generate_hpke_config_and_private_key(); + let (collector_hpke_config, collector_private_key) = + generate_test_hpke_config_and_private_key(); let (mut janus_leader_task, mut janus_helper_task) = create_test_tasks(janus_leader_port, janus_helper_port, &collector_hpke_config);