Skip to content
135 changes: 135 additions & 0 deletions qdp/qdp-core/examples/distributed_multigpu_q34_probe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::time::Instant;

use qdp_core::gpu::LocalCollectiveCommunicator;
use qdp_core::{
DistributedExecutionContext, DistributionMode, MahoutError, PlacementRequest, Precision,
QdpEngine, ShardPolicy,
};

/// Convert a raw byte count into GiB for human-readable log output.
fn gib(bytes: usize) -> f64 {
    // 1 GiB = 2^30 bytes.
    const BYTES_PER_GIB: f64 = (1usize << 30) as f64;
    bytes as f64 / BYTES_PER_GIB
}

/// Parse the `GPU_IDS` environment variable (comma-separated CUDA device IDs,
/// defaulting to "0,1,2,3,4,5") into a list of device indices.
///
/// Empty segments (e.g. trailing commas) are skipped; a non-numeric segment or
/// a fully empty list is reported as `MahoutError::InvalidInput`.
fn parse_device_ids() -> Result<Vec<usize>, MahoutError> {
    let raw = std::env::var("GPU_IDS").unwrap_or_else(|_| "0,1,2,3,4,5".to_string());

    let ids = raw
        .split(',')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            segment.parse::<usize>().map_err(|err| {
                MahoutError::InvalidInput(format!("Invalid GPU ID '{segment}': {err}"))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    if ids.is_empty() {
        return Err(MahoutError::InvalidInput(
            "GPU_IDS must contain at least one CUDA device ID".to_string(),
        ));
    }

    Ok(ids)
}

fn main() -> Result<(), MahoutError> {
let num_qubits = std::env::var("QUBITS")
.ok()
.and_then(|value| value.parse::<usize>().ok())
.unwrap_or(34);
let host_len = std::env::var("HOST_LEN")
.ok()
.and_then(|value| value.parse::<usize>().ok())
.unwrap_or(1);
let precision = match std::env::var("PRECISION").ok().as_deref() {
Some("f64") | Some("float64") => Precision::Float64,
_ => Precision::Float32,
};
let shard_policy = match std::env::var("SHARD_POLICY").ok().as_deref() {
Some("equal") => ShardPolicy::Equal,
_ => ShardPolicy::BalancedUneven,
};
let device_ids = parse_device_ids()?;
let request =
PlacementRequest::new(num_qubits, DistributionMode::ShardedCapacity, shard_policy);
let host_data = vec![1.0f64; host_len];

println!(
"Starting distributed amplitude probe: qubits={}, host_len={}, gpus={:?}, precision={:?}, shard_policy={:?}, collectives=in-process",
num_qubits, host_len, device_ids, precision, shard_policy
);

let collectives = LocalCollectiveCommunicator;
let execution = DistributedExecutionContext::single_process(device_ids.clone(), &collectives)?;

let prepare_start = Instant::now();
let prepared = QdpEngine::prepare_distributed_amplitude_on(
&execution,
&host_data,
num_qubits,
precision,
Some(request.clone()),
)?;
let prepare_elapsed = prepare_start.elapsed();

println!(
"Prepared in {:.3}s; global_len={}; shards={}; max_local_len={}; estimated_max_shard_gib={:.2}; gather_device={:?}",
prepare_elapsed.as_secs_f64(),
prepared.plan.global_len,
prepared.layout.num_shards(),
prepared.plan.max_local_len(),
gib(prepared.plan.estimated_max_shard_bytes(precision)?),
prepared.layout.recommended_gather_device_id()
);

for shard in &prepared.layout.shards {
let shard_bytes = match precision {
Precision::Float32 => shard.local_len * 8,
Precision::Float64 => shard.local_len * 16,
};
println!(
" shard {} -> cuda:{} range=[{}, {}) local_len={} (~{:.2} GiB)",
shard.shard_id,
shard.device_id,
shard.start_idx,
shard.end_idx,
shard.local_len,
gib(shard_bytes)
);
}

let encode_start = Instant::now();
let state = QdpEngine::encode_distributed_amplitude_to_shards_on(
&execution,
&host_data,
num_qubits,
precision,
Some(request),
)?;
let encode_elapsed = encode_start.elapsed();

println!(
"Encoded in {:.3}s; state_shards={}; placement={:?}",
encode_elapsed.as_secs_f64(),
state.num_shards(),
state.recommended_placement_device_ids()
);

Ok(())
}
44 changes: 44 additions & 0 deletions qdp/qdp-core/src/gpu/communicator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::error::{MahoutError, Result};

/// Abstracts cross-shard collective operations.
///
/// The current implementation executes collectives inside one process. A future
/// MPI-backed implementation can provide the same interface while mapping the
/// partial contributions to rank-local shards and performing a real all-reduce.
///
/// The `Send + Sync` supertrait bounds allow one communicator instance to be
/// shared across threads.
pub trait CollectiveCommunicator: Send + Sync {
    /// Sum one set of per-shard partial contributions into one global scalar.
    ///
    /// `values` holds one partial contribution per shard; the result is their
    /// total. Implementations may reject an empty slice.
    fn all_reduce_sum_f64(&self, values: &[f64]) -> Result<f64>;
}

/// In-process collective implementation for the current single-process
/// distributed path.
#[derive(Default, Debug, Clone, Copy)]
pub struct LocalCollectiveCommunicator;

impl CollectiveCommunicator for LocalCollectiveCommunicator {
    /// Sum the per-shard partials locally. An empty input slice is rejected
    /// as `InvalidInput` rather than silently reduced to `0.0`.
    fn all_reduce_sum_f64(&self, values: &[f64]) -> Result<f64> {
        match values {
            [] => Err(MahoutError::InvalidInput(
                "Collective reduction requires at least one partial contribution".to_string(),
            )),
            partials => Ok(partials.iter().sum()),
        }
    }
}
8 changes: 7 additions & 1 deletion qdp/qdp-core/src/gpu/cuda_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ unsafe extern "C" {
ptr: *const c_void,
) -> i32;

pub(crate) fn cudaGetDevice(device: *mut i32) -> i32;
pub(crate) fn cudaSetDevice(device: i32) -> i32;
pub(crate) fn cudaMemGetInfo(free: *mut usize, total: *mut usize) -> i32;
pub(crate) fn cudaDeviceCanAccessPeer(
can_access_peer: *mut i32,
device: i32,
peer_device: i32,
) -> i32;

pub(crate) fn cudaMemcpyAsync(
dst: *mut c_void,
Expand All @@ -63,7 +70,6 @@ unsafe extern "C" {
kind: u32,
stream: *mut c_void,
) -> i32;

pub(crate) fn cudaEventCreateWithFlags(event: *mut *mut c_void, flags: u32) -> i32;
pub(crate) fn cudaEventRecord(event: *mut c_void, stream: *mut c_void) -> i32;
pub(crate) fn cudaEventDestroy(event: *mut c_void) -> i32;
Expand Down
168 changes: 168 additions & 0 deletions qdp/qdp-core/src/gpu/distributed/amplitude.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::error::{MahoutError, Result};
use crate::gpu::distributed::{
DistributionMode, PlacementPlan, PlacementPlanner, PlacementRequest, ShardPlacement,
ShardPolicy,
};
use crate::gpu::memory::Precision;
use crate::gpu::topology::DeviceMesh;

/// Shared planning math for amplitude-sharded state construction.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DistributedAmplitudePlan {
    // Original placement request this plan was derived from.
    pub request: PlacementRequest,
    // Concrete per-shard placements produced by `PlacementPlanner::plan`.
    pub placement: PlacementPlan,
    // Number of qubits in the requested global state.
    pub num_qubits: usize,
    // Global state length, copied from the placement plan.
    pub global_len: usize,
    // Number of devices in the mesh at planning time.
    pub num_devices: usize,
    // log2(num_devices) under ShardPolicy::Equal (device count is required
    // to be a power of two there); None under BalancedUneven.
    pub shard_bits: Option<usize>,
    // Common per-shard length under ShardPolicy::Equal; None under
    // BalancedUneven, where shard lengths may differ.
    pub uniform_shard_len: Option<usize>,
}

/// Result of preparing a distributed amplitude encode without yet allocating
/// concrete shard buffers. This fixes the public API surface for later PRs that
/// will populate `state` with real device allocations.
#[derive(Clone)]
pub struct PreparedDistributedAmplitudeEncode {
    // Device mesh the plan was computed against.
    pub mesh: DeviceMesh,
    // Derived shard math for the request.
    pub plan: DistributedAmplitudePlan,
    // Normalization factor for the encode (presumably 1/||input|| — confirm
    // against the materialization path that consumes it).
    pub inv_norm: f64,
    // Per-shard layout: amplitude ranges and device assignments.
    pub layout: super::layout::DistributedStateLayout,
}

impl DistributedAmplitudePlan {
    /// Validate one distributed amplitude request and derive the shard math used
    /// by later layout and materialization steps.
    ///
    /// # Errors
    /// Returns `MahoutError::InvalidInput` when the request has zero qubits,
    /// the mesh has no devices, the distribution mode is not
    /// `ShardedCapacity`, the Equal policy is used with a non-power-of-two
    /// device count or more shard bits than qubits, or the derived shard
    /// shape is empty / would overflow a byte estimate.
    pub fn for_request(mesh: &DeviceMesh, request: PlacementRequest) -> Result<Self> {
        if request.num_qubits == 0 {
            return Err(MahoutError::InvalidInput(
                "Number of qubits must be at least 1 for distributed amplitude planning"
                    .to_string(),
            ));
        }
        if mesh.num_devices() == 0 {
            return Err(MahoutError::InvalidInput(
                "Distributed amplitude planning requires at least one device".to_string(),
            ));
        }
        if request.mode != DistributionMode::ShardedCapacity {
            return Err(MahoutError::InvalidInput(format!(
                "Distributed amplitude planning currently supports only {:?}, got {:?}",
                DistributionMode::ShardedCapacity,
                request.mode
            )));
        }

        let num_devices = mesh.num_devices();
        let placement = PlacementPlanner::plan(mesh, &request)?;
        Self::validate_local_shard_shape(request.num_qubits, &placement)?;
        let global_len = placement.global_len;
        let num_qubits = request.num_qubits;
        let (shard_bits, uniform_shard_len) = match request.shard_policy {
            ShardPolicy::Equal => {
                // Equal sharding derives shard_bits from trailing_zeros, which
                // is only meaningful for a power-of-two device count. Enforce
                // this as a real error: the previous debug_assert would have
                // silently computed a wrong shard_bits in release builds.
                if !num_devices.is_power_of_two() {
                    return Err(MahoutError::InvalidInput(format!(
                        "ShardPolicy::Equal requires a power-of-two device count, got {}",
                        num_devices
                    )));
                }
                let shard_bits = num_devices.trailing_zeros() as usize;
                if shard_bits > request.num_qubits {
                    return Err(MahoutError::InvalidInput(format!(
                        "Cannot shard {} qubits across {} devices: shard bits {} exceed qubit count",
                        request.num_qubits, num_devices, shard_bits
                    )));
                }
                (Some(shard_bits), Some(placement.shard_len()?))
            }
            ShardPolicy::BalancedUneven => (None, None),
        };

        Ok(Self {
            request,
            placement,
            num_qubits,
            global_len,
            num_devices,
            shard_bits,
            uniform_shard_len,
        })
    }

    /// Logical half-open amplitude range covered by one shard ID.
    ///
    /// # Errors
    /// Returns `InvalidInput` when `shard_id` is out of range for this plan.
    pub fn shard_range(&self, shard_id: usize) -> Result<(usize, usize)> {
        let placement = self.placement.placements.get(shard_id).ok_or_else(|| {
            MahoutError::InvalidInput(format!(
                "Shard ID {} out of range for {} devices",
                shard_id, self.num_devices
            ))
        })?;
        Ok((placement.start_idx, placement.end_idx))
    }

    /// Largest local shard length across the current placement
    /// (0 for an empty placement list).
    pub fn max_local_len(&self) -> usize {
        self.placement
            .placements
            .iter()
            .map(ShardPlacement::local_len)
            .max()
            .unwrap_or(0)
    }

    /// Estimated bytes required by the largest local shard at one target precision.
    pub fn estimated_max_shard_bytes(&self, precision: Precision) -> Result<usize> {
        estimated_amplitude_bytes(self.max_local_len(), precision)
    }

    /// Reject placements whose largest shard is empty or whose byte size
    /// would overflow at either supported precision.
    fn validate_local_shard_shape(num_qubits: usize, placement: &PlacementPlan) -> Result<()> {
        let required_local_len = placement
            .placements
            .iter()
            .map(ShardPlacement::local_len)
            .max()
            .ok_or_else(|| {
                MahoutError::InvalidInput(
                    "Placement plan must contain at least one shard".to_string(),
                )
            })?;

        if required_local_len == 0 {
            return Err(MahoutError::InvalidInput(format!(
                "Distributed amplitude request for {} qubits produced an empty local shard",
                num_qubits
            )));
        }

        // Probe both precisions up front so an overflow surfaces at planning
        // time rather than during a later byte estimate.
        let _ = estimated_amplitude_bytes(required_local_len, Precision::Float32)?;
        let _ = estimated_amplitude_bytes(required_local_len, Precision::Float64)?;

        Ok(())
    }
}

/// Bytes needed to hold `local_len` amplitudes at `precision`: 8 bytes per
/// amplitude for Float32 and 16 for Float64 (presumably one re/im pair per
/// amplitude — TODO confirm against the materialization code).
///
/// Uses `checked_mul` so an oversized shard reports an error instead of
/// silently wrapping.
fn estimated_amplitude_bytes(local_len: usize, precision: Precision) -> Result<usize> {
    let bytes_per_amplitude: usize = match precision {
        Precision::Float32 => 8,
        Precision::Float64 => 16,
    };

    match local_len.checked_mul(bytes_per_amplitude) {
        Some(total) => Ok(total),
        None => Err(MahoutError::InvalidInput(format!(
            "Distributed amplitude shard byte estimate overflowed for local_len={} and precision={:?}",
            local_len, precision
        ))),
    }
}
Loading
Loading