Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,14 +448,19 @@ pub(crate) fn gen_define_handling<'ll>(
transfer.iter().map(|m| m.intersection(valid_begin_mappings).bits()).collect();
let transfer_from: Vec<u64> =
transfer.iter().map(|m| m.intersection(MappingFlags::FROM).bits()).collect();
let valid_kernel_mappings = MappingFlags::LITERAL | MappingFlags::IMPLICIT;
// FIXME(offload): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
let transfer_kernel = vec![MappingFlags::TARGET_PARAM.bits(); transfer_to.len()];
let transfer_kernel: Vec<u64> = transfer
.iter()
.map(|m| (m.intersection(valid_kernel_mappings) | MappingFlags::TARGET_PARAM).bits())
.collect();

let actual_sizes = sizes
.iter()
.map(|s| match s {
OffloadSize::Static(sz) => *sz,
OffloadSize::Dynamic => 0,
// NOTE(Sa4dUs): set `.offload_sizes` entry to 0 for sizes that we determine at runtime, just like clang
_ => 0,
})
.collect::<Vec<_>>();
let offload_sizes =
Expand Down Expand Up @@ -542,12 +547,20 @@ pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
}

fn get_runtime_size<'ll, 'tcx>(
_cx: &CodegenCx<'ll, 'tcx>,
_val: &'ll Value,
_meta: &OffloadMetadata,
builder: &mut Builder<'_, 'll, 'tcx>,
args: &[&'ll Value],
index: usize,
meta: &OffloadMetadata,
) -> &'ll Value {
// FIXME(Sa4dUs): handle dynamic-size data (e.g. slices)
bug!("offload does not support dynamic sizes yet");
match meta.payload_size {
OffloadSize::Slice { element_size } => {
let length_idx = index + 1;
let length = args[length_idx];
let length_i64 = builder.intcast(length, builder.cx.type_i64(), false);
builder.mul(length_i64, builder.cx.get_const_i64(element_size))
}
_ => bug!("unexpected offload size {:?}", meta.payload_size),
}
}

// For each kernel *call*, we now use some of our previous declared globals to move data to and from
Expand Down Expand Up @@ -588,7 +601,7 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims;

let has_dynamic = metadata.iter().any(|m| matches!(m.payload_size, OffloadSize::Dynamic));
let has_dynamic = metadata.iter().any(|m| !matches!(m.payload_size, OffloadSize::Static(_)));

let tgt_decl = offload_globals.launcher_fn;
let tgt_target_kernel_ty = offload_globals.launcher_ty;
Expand Down Expand Up @@ -683,9 +696,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
builder.store(geps[i as usize], gep2, Align::EIGHT);

if matches!(metadata[i as usize].payload_size, OffloadSize::Dynamic) {
if !matches!(metadata[i as usize].payload_size, OffloadSize::Static(_)) {
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
let size_val = get_runtime_size(cx, args[i as usize], &metadata[i as usize]);
let size_val = get_runtime_size(builder, args, i as usize, &metadata[i as usize]);
builder.store(size_val, gep3, Align::EIGHT);
}
}
Expand Down
15 changes: 13 additions & 2 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1813,9 +1813,20 @@ fn codegen_offload<'ll, 'tcx>(
let sig = tcx.instantiate_bound_regions_with_erased(sig);
let inputs = sig.inputs();

let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();
let fn_abi = cx.fn_abi_of_instance(fn_target, ty::List::empty());

let types = inputs.iter().map(|ty| cx.layout_of(*ty).llvm_type(cx)).collect::<Vec<_>>();
let mut metadata = Vec::new();
let mut types = Vec::new();

for (i, arg_abi) in fn_abi.args.iter().enumerate() {
let ty = inputs[i];
let decomposed = OffloadMetadata::handle_abi(cx, tcx, ty, arg_abi);

for (meta, entry_ty) in decomposed {
metadata.push(meta);
types.push(bx.cx.layout_of(entry_ty).llvm_type(bx.cx));
}
}

let offload_globals_ref = cx.offload_globals.borrow();
let offload_globals = match offload_globals_ref.as_ref() {
Expand Down
34 changes: 32 additions & 2 deletions compiler/rustc_middle/src/ty/offload_meta.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
use bitflags::bitflags;
use rustc_abi::{BackendRepr, TyAbiInterface};
use rustc_target::callconv::ArgAbi;

use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};

#[derive(Debug, Copy, Clone)]
pub struct OffloadMetadata {
pub payload_size: OffloadSize,
pub mode: MappingFlags,
}

#[derive(Debug, Copy, Clone)]
pub enum OffloadSize {
Dynamic,
Static(u64),
Slice { element_size: u64 },
}

bitflags! {
/// Mirrors `OpenMPOffloadMappingFlags` from Clang/OpenMP.
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct MappingFlags: u64 {
/// No flags.
Expand Down Expand Up @@ -62,11 +65,38 @@ impl OffloadMetadata {
mode: MappingFlags::from_ty(tcx, ty),
}
}

pub fn handle_abi<'tcx, C>(
cx: &C,
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
arg_abi: &ArgAbi<'tcx, Ty<'tcx>>,
) -> Vec<(Self, Ty<'tcx>)>
where
Ty<'tcx>: TyAbiInterface<'tcx, C>,
{
match arg_abi.layout.backend_repr {
BackendRepr::ScalarPair(_, _) => (0..2)
.map(|i| {
let ty = arg_abi.layout.field(cx, i).ty;
(OffloadMetadata::from_ty(tcx, ty), ty)
})
.collect(),
_ => vec![(OffloadMetadata::from_ty(tcx, ty), ty)],
}
}
}

// FIXME(Sa4dUs): implement a solid logic to determine the payload size
fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> OffloadSize {
match ty.kind() {
ty::Slice(elem_ty) => {
let layout = tcx.layout_of(PseudoCanonicalInput {
typing_env: TypingEnv::fully_monomorphized(),
value: *elem_ty,
});
OffloadSize::Slice { element_size: layout.unwrap().size.bytes() }
}
ty::RawPtr(inner, _) | ty::Ref(_, inner, _) => get_payload_size(tcx, *inner),
_ => OffloadSize::Static(
tcx.layout_of(PseudoCanonicalInput {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ use rustc_middle::hir::nested_filter;
use rustc_middle::ty::adjustment::{Adjust, Adjustment, AutoBorrow, DerefAdjustKind};
use rustc_middle::ty::print::{FmtPrinter, PrettyPrinter, Print, Printer};
use rustc_middle::ty::{
self, GenericArg, GenericArgKind, GenericArgsRef, InferConst, IsSuggestable, Term, TermKind,
Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt, TypeckResults,
self, GenericArg, GenericArgKind, GenericArgsRef, GenericParamDefKind, InferConst,
IsSuggestable, Term, TermKind, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
TypeVisitableExt, TypeckResults,
};
use rustc_span::{BytePos, DUMMY_SP, Ident, Span, sym};
use tracing::{debug, instrument, warn};
Expand Down Expand Up @@ -592,15 +593,19 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
(true, parent.prefix.to_string(), parent.name)
});

let param = &generics.own_params[argument_index];
let param_name = param.name.to_string();

infer_subdiags.push(SourceKindSubdiag::GenericLabel {
span,
is_type,
param_name: generics.own_params[argument_index].name.to_string(),
param_name: param_name.clone(),
parent_exists,
parent_prefix,
parent_name,
});

let mut used_fallback = false;
let args = if self.tcx.get_diagnostic_item(sym::iterator_collect_fn)
== Some(generics_def_id)
{
Expand Down Expand Up @@ -634,9 +639,9 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
let mut p = fmt_printer(self, Namespace::TypeNS);
p.comma_sep(generic_args.iter().copied().map(|arg| {
if arg.is_suggestable(self.tcx, true) {
used_fallback = true;
return arg;
}

match arg.kind() {
GenericArgKind::Lifetime(_) => bug!("unexpected lifetime"),
GenericArgKind::Type(_) => self.next_ty_var(DUMMY_SP).into(),
Expand All @@ -648,11 +653,31 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
};

if !have_turbofish {
infer_subdiags.push(SourceKindSubdiag::GenericSuggestion {
span: insert_span,
arg_count: generic_args.len(),
args,
});
if generic_args.len() == 1 && used_fallback {
match param.kind {
GenericParamDefKind::Type { .. } => {
infer_subdiags.push(SourceKindSubdiag::GenericTypeSuggestion {
span: insert_span,
param: param_name,
});
}
GenericParamDefKind::Const { .. } => {
infer_subdiags.push(SourceKindSubdiag::ConstGenericSuggestion {
span: insert_span,
param: param_name,
});
}
GenericParamDefKind::Lifetime => {
bug!("unexpected lifetime")
}
}
} else {
infer_subdiags.push(SourceKindSubdiag::GenericSuggestion {
span: insert_span,
arg_count: generic_args.len(),
args,
});
}
}
}
InferSourceKind::FullyQualifiedMethodCall { receiver, successor, args, def_id } => {
Expand Down
Loading
Loading