diff --git a/compiler/rustc_attr_parsing/src/attributes/doc.rs b/compiler/rustc_attr_parsing/src/attributes/doc.rs index a5a9596a9e715..09d474a227fb7 100644 --- a/compiler/rustc_attr_parsing/src/attributes/doc.rs +++ b/compiler/rustc_attr_parsing/src/attributes/doc.rs @@ -15,7 +15,8 @@ use super::{AcceptMapping, AttributeParser}; use crate::context::{AcceptContext, FinalizeContext, Stage}; use crate::errors::{ DocAliasDuplicated, DocAutoCfgExpectsHideOrShow, DocAutoCfgHideShowExpectsList, - DocAutoCfgHideShowUnexpectedItem, DocUnknownInclude, IllFormedAttributeInput, + DocAutoCfgHideShowUnexpectedItem, DocAutoCfgWrongLiteral, DocUnknownAny, DocUnknownInclude, + DocUnknownPasses, DocUnknownPlugins, DocUnknownSpotlight, IllFormedAttributeInput, }; use crate::parser::{ArgParser, MetaItemOrLitParser, MetaItemParser, OwnedPathParser}; use crate::session_diagnostics::{ @@ -442,9 +443,9 @@ impl DocParser { ArgParser::NameValue(nv) => { let MetaItemLit { kind: LitKind::Bool(bool_value), span, .. 
} = nv.value_as_lit() else { - cx.emit_lint( + cx.emit_dyn_lint( rustc_session::lint::builtin::INVALID_DOC_ATTRIBUTES, - AttributeLintKind::DocAutoCfgWrongLiteral, + move |dcx, level| DocAutoCfgWrongLiteral.into_diag(dcx, level), nv.value_span, ); return; @@ -613,10 +614,11 @@ impl DocParser { } } Some(sym::spotlight) => { - cx.emit_lint( + let span = path.span(); + cx.emit_dyn_lint( rustc_session::lint::builtin::INVALID_DOC_ATTRIBUTES, - AttributeLintKind::DocUnknownSpotlight { span: path.span() }, - path.span(), + move |dcx, level| DocUnknownSpotlight { sugg_span: span }.into_diag(dcx, level), + span, ); } Some(sym::include) if let Some(nv) = args.name_value() => { @@ -640,32 +642,37 @@ impl DocParser { ); } Some(name @ (sym::passes | sym::no_default_passes)) => { - cx.emit_lint( + let span = path.span(); + cx.emit_dyn_lint( rustc_session::lint::builtin::INVALID_DOC_ATTRIBUTES, - AttributeLintKind::DocUnknownPasses { name, span: path.span() }, - path.span(), + move |dcx, level| { + DocUnknownPasses { name, note_span: span }.into_diag(dcx, level) + }, + span, ); } Some(sym::plugins) => { - cx.emit_lint( + let span = path.span(); + cx.emit_dyn_lint( rustc_session::lint::builtin::INVALID_DOC_ATTRIBUTES, - AttributeLintKind::DocUnknownPlugins { span: path.span() }, - path.span(), + move |dcx, level| DocUnknownPlugins { label_span: span }.into_diag(dcx, level), + span, ); } Some(name) => { - cx.emit_lint( + cx.emit_dyn_lint( rustc_session::lint::builtin::INVALID_DOC_ATTRIBUTES, - AttributeLintKind::DocUnknownAny { name }, + move |dcx, level| DocUnknownAny { name }.into_diag(dcx, level), path.span(), ); } None => { let full_name = path.segments().map(|s| s.as_str()).intersperse("::").collect::(); - cx.emit_lint( + let name = Symbol::intern(&full_name); + cx.emit_dyn_lint( rustc_session::lint::builtin::INVALID_DOC_ATTRIBUTES, - AttributeLintKind::DocUnknownAny { name: Symbol::intern(&full_name) }, + move |dcx, level| DocUnknownAny { name }.into_diag(dcx, level), 
path.span(), ); } diff --git a/compiler/rustc_attr_parsing/src/errors.rs b/compiler/rustc_attr_parsing/src/errors.rs index 2ba50439c4e3c..ca57c25f25a0e 100644 --- a/compiler/rustc_attr_parsing/src/errors.rs +++ b/compiler/rustc_attr_parsing/src/errors.rs @@ -205,3 +205,50 @@ pub(crate) struct DocUnknownInclude { )] pub sugg: (Span, Applicability), } + +#[derive(Diagnostic)] +#[diag("unknown `doc` attribute `spotlight`")] +#[note("`doc(spotlight)` was renamed to `doc(notable_trait)`")] +#[note("`doc(spotlight)` is now a no-op")] +pub(crate) struct DocUnknownSpotlight { + #[suggestion( + "use `notable_trait` instead", + style = "short", + applicability = "machine-applicable", + code = "notable_trait" + )] + pub sugg_span: Span, +} + +#[derive(Diagnostic)] +#[diag("unknown `doc` attribute `{$name}`")] +#[note( + "`doc` attribute `{$name}` no longer functions; see issue #44136 " +)] +#[note("`doc({$name})` is now a no-op")] +pub(crate) struct DocUnknownPasses { + pub name: Symbol, + #[label("no longer functions")] + pub note_span: Span, +} + +#[derive(Diagnostic)] +#[diag("unknown `doc` attribute `plugins`")] +#[note( + "`doc` attribute `plugins` no longer functions; see issue #44136 and CVE-2018-1000622 " +)] +#[note("`doc(plugins)` is now a no-op")] +pub(crate) struct DocUnknownPlugins { + #[label("no longer functions")] + pub label_span: Span, +} + +#[derive(Diagnostic)] +#[diag("unknown `doc` attribute `{$name}`")] +pub(crate) struct DocUnknownAny { + pub name: Symbol, +} + +#[derive(Diagnostic)] +#[diag("expected boolean for `#[doc(auto_cfg = ...)]`")] +pub(crate) struct DocAutoCfgWrongLiteral; diff --git a/compiler/rustc_lint/src/early/diagnostics.rs b/compiler/rustc_lint/src/early/diagnostics.rs index a6a065040420e..4a0320cbaf80d 100644 --- a/compiler/rustc_lint/src/early/diagnostics.rs +++ b/compiler/rustc_lint/src/early/diagnostics.rs @@ -43,26 +43,6 @@ impl<'a> Diagnostic<'a, ()> for DecorateAttrLint<'_, '_, '_> { .into_diag(dcx, level) } - 
&AttributeLintKind::DocUnknownSpotlight { span } => { - lints::DocUnknownSpotlight { sugg_span: span }.into_diag(dcx, level) - } - - &AttributeLintKind::DocUnknownPasses { name, span } => { - lints::DocUnknownPasses { name, note_span: span }.into_diag(dcx, level) - } - - &AttributeLintKind::DocUnknownPlugins { span } => { - lints::DocUnknownPlugins { label_span: span }.into_diag(dcx, level) - } - - &AttributeLintKind::DocUnknownAny { name } => { - lints::DocUnknownAny { name }.into_diag(dcx, level) - } - - &AttributeLintKind::DocAutoCfgWrongLiteral => { - lints::DocAutoCfgWrongLiteral.into_diag(dcx, level) - } - &AttributeLintKind::DocTestTakesList => lints::DocTestTakesList.into_diag(dcx, level), &AttributeLintKind::DocTestUnknown { name } => { diff --git a/compiler/rustc_lint/src/lints.rs b/compiler/rustc_lint/src/lints.rs index 8912c8b03fbdd..19fabc51ae536 100644 --- a/compiler/rustc_lint/src/lints.rs +++ b/compiler/rustc_lint/src/lints.rs @@ -3303,53 +3303,6 @@ pub(crate) struct ExpectedNoArgs; )] pub(crate) struct ExpectedNameValue; -#[derive(Diagnostic)] -#[diag("unknown `doc` attribute `spotlight`")] -#[note("`doc(spotlight)` was renamed to `doc(notable_trait)`")] -#[note("`doc(spotlight)` is now a no-op")] -pub(crate) struct DocUnknownSpotlight { - #[suggestion( - "use `notable_trait` instead", - style = "short", - applicability = "machine-applicable", - code = "notable_trait" - )] - pub sugg_span: Span, -} - -#[derive(Diagnostic)] -#[diag("unknown `doc` attribute `{$name}`")] -#[note( - "`doc` attribute `{$name}` no longer functions; see issue #44136 " -)] -#[note("`doc({$name})` is now a no-op")] -pub(crate) struct DocUnknownPasses { - pub name: Symbol, - #[label("no longer functions")] - pub note_span: Span, -} - -#[derive(Diagnostic)] -#[diag("unknown `doc` attribute `plugins`")] -#[note( - "`doc` attribute `plugins` no longer functions; see issue #44136 and CVE-2018-1000622 " -)] -#[note("`doc(plugins)` is now a no-op")] -pub(crate) struct 
DocUnknownPlugins { - #[label("no longer functions")] - pub label_span: Span, -} - -#[derive(Diagnostic)] -#[diag("unknown `doc` attribute `{$name}`")] -pub(crate) struct DocUnknownAny { - pub name: Symbol, -} - -#[derive(Diagnostic)] -#[diag("expected boolean for `#[doc(auto_cfg = ...)]`")] -pub(crate) struct DocAutoCfgWrongLiteral; - #[derive(Diagnostic)] #[diag("`#[doc(test(...)]` takes a list of attributes")] pub(crate) struct DocTestTakesList; diff --git a/compiler/rustc_lint_defs/src/lib.rs b/compiler/rustc_lint_defs/src/lib.rs index 540e4afc50ff0..96c7dec3d8188 100644 --- a/compiler/rustc_lint_defs/src/lib.rs +++ b/compiler/rustc_lint_defs/src/lib.rs @@ -656,11 +656,6 @@ pub enum DeprecatedSinceKind { pub enum AttributeLintKind { UnexpectedCfgName((Symbol, Span), Option<(Symbol, Span)>), UnexpectedCfgValue((Symbol, Span), Option<(Symbol, Span)>), - DocUnknownSpotlight { span: Span }, - DocUnknownPasses { name: Symbol, span: Span }, - DocUnknownPlugins { span: Span }, - DocUnknownAny { name: Symbol }, - DocAutoCfgWrongLiteral, DocTestTakesList, DocTestUnknown { name: Symbol }, DocTestLiteral, diff --git a/compiler/rustc_parse/src/parser/expr.rs b/compiler/rustc_parse/src/parser/expr.rs index 437102d549e79..a0601cc71d032 100644 --- a/compiler/rustc_parse/src/parser/expr.rs +++ b/compiler/rustc_parse/src/parser/expr.rs @@ -3464,7 +3464,9 @@ impl<'a> Parser<'a> { } pub(crate) fn eat_metavar_guard(&mut self) -> Option> { - self.eat_metavar_seq(MetaVarKind::Guard, |this| this.parse_match_arm_guard()).flatten() + self.eat_metavar_seq(MetaVarKind::Guard, |this| { + this.expect_match_arm_guard(ForceCollect::Yes) + }) } fn parse_match_arm_guard(&mut self) -> PResult<'a, Option>> { diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index c255d546e3933..81ccfcc434547 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -2641,7 +2641,7 @@ pub fn build_session_options(early_dcx: &mut 
EarlyDiagCtxt, matches: &getopts::M if unstable_opts.retpoline_external_thunk { unstable_opts.retpoline = true; collected_options.target_modifiers.insert( - OptionsTargetModifiers::UnstableOptions(UnstableOptionsTargetModifiers::retpoline), + OptionsTargetModifiers::UnstableOptions(UnstableOptionsTargetModifiers::Retpoline), "true".to_string(), ); } diff --git a/compiler/rustc_session/src/lib.rs b/compiler/rustc_session/src/lib.rs index 04e12f1afce68..c5b7a5e8450da 100644 --- a/compiler/rustc_session/src/lib.rs +++ b/compiler/rustc_session/src/lib.rs @@ -5,6 +5,7 @@ #![feature(default_field_values)] #![feature(iter_intersperse)] #![feature(macro_derive)] +#![feature(macro_metavar_expr)] #![feature(rustc_attrs)] // To generate CodegenOptionsTargetModifiers and UnstableOptionsTargetModifiers enums // with macro_rules, it is necessary to use recursive mechanic ("Incremental TT Munchers"). diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 9580642ba72bd..ef3e061e78a5f 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -141,10 +141,10 @@ impl TargetModifier { assert!(other.is_none() || self.opt == other.unwrap().opt); match self.opt { OptionsTargetModifiers::UnstableOptions(unstable) => match unstable { - UnstableOptionsTargetModifiers::sanitizer => { + UnstableOptionsTargetModifiers::Sanitizer => { return target_modifier_consistency_check::sanitizer(self, other); } - UnstableOptionsTargetModifiers::sanitizer_cfi_normalize_integers => { + UnstableOptionsTargetModifiers::SanitizerCfiNormalizeIntegers => { return target_modifier_consistency_check::sanitizer_cfi_normalize_integers( sess, self, other, ); @@ -170,164 +170,52 @@ fn tmod_push_impl( } } -macro_rules! 
tmod_push { - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr) => { - if *$opt_expr != $init { - tmod_push_impl( - OptionsTargetModifiers::$struct_name($tmod_enum_name::$opt_name), - $tmod_vals, - $mods, - ); - } - }; -} - -macro_rules! gather_tmods { - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - [SUBSTRUCT], [TARGET_MODIFIER]) => { - compile_error!("SUBSTRUCT can't be target modifier"); - }; - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - [UNTRACKED], [TARGET_MODIFIER]) => { - tmod_push!($struct_name, $tmod_enum_name, $opt_name, $opt_expr, $init, $mods, $tmod_vals) - }; - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - [TRACKED], [TARGET_MODIFIER]) => { - tmod_push!($struct_name, $tmod_enum_name, $opt_name, $opt_expr, $init, $mods, $tmod_vals) - }; - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - [TRACKED_NO_CRATE_HASH], [TARGET_MODIFIER]) => { - tmod_push!($struct_name, $tmod_enum_name, $opt_name, $opt_expr, $init, $mods, $tmod_vals) - }; - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - [SUBSTRUCT], [$(MITIGATION)?]) => { - $opt_expr.gather_target_modifiers($mods, $tmod_vals); - }; - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - [UNTRACKED], [$(MITIGATION)?]) => {{}}; - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - [TRACKED], [$(MITIGATION)?]) => {{}}; - ($struct_name:ident, $tmod_enum_name:ident, $opt_name:ident, $opt_expr:expr, $init:expr, $mods:expr, $tmod_vals:expr, - 
[TRACKED_NO_CRATE_HASH], [$(MITIGATION)?]) => {{}}; -} - -macro_rules! gather_tmods_top_level { - ($_opt_name:ident, $opt_expr:expr, $mods:expr, $tmod_vals:expr, [SUBSTRUCT $substruct_enum:ident]) => { - $opt_expr.gather_target_modifiers($mods, $tmod_vals); - }; - ($opt_name:ident, $opt_expr:expr, $mods:expr, $tmod_vals:expr, [$non_substruct:ident TARGET_MODIFIER]) => { - compile_error!("Top level option can't be target modifier"); - }; - ($opt_name:ident, $opt_expr:expr, $mods:expr, $tmod_vals:expr, [$non_substruct:ident $(MITIGATION)?]) => {}; -} - -/// Macro for generating OptionsTargetsModifiers top-level enum with impl. -/// Will generate something like: -/// ```rust,ignore (illustrative) -/// pub enum OptionsTargetModifiers { -/// CodegenOptions(CodegenOptionsTargetModifiers), -/// UnstableOptions(UnstableOptionsTargetModifiers), -/// } -/// impl OptionsTargetModifiers { -/// pub fn reparse(&self, user_value: &str) -> ExtendedTargetModifierInfo { -/// match self { -/// Self::CodegenOptions(v) => v.reparse(user_value), -/// Self::UnstableOptions(v) => v.reparse(user_value), -/// } -/// } -/// pub fn is_target_modifier(flag_name: &str) -> bool { -/// CodegenOptionsTargetModifiers::is_target_modifier(flag_name) || -/// UnstableOptionsTargetModifiers::is_target_modifier(flag_name) -/// } -/// } -/// ``` -macro_rules! top_level_tmod_enum { - ($( {$($optinfo:tt)*} ),* $(,)*) => { - top_level_tmod_enum! { @parse {}, (user_value){}; $($($optinfo)*|)* } - }; - // Termination +macro_rules! top_level_options { ( - @parse - {$($variant:tt($substruct_enum:tt))*}, - ($user_value:ident){$($pout:tt)*}; + $(#[$top_level_attr:meta])* + pub struct Options { + $( + $(#[$attr:meta])* + $opt:ident : $t:ty + [$dep_tracking_marker:ident] + $( { TARGET_MODIFIER: $tmod_variant:ident($tmod_enum:ident) } )? 
+ , + )* + } ) => { - #[allow(non_camel_case_types)] #[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone, Encodable, BlobDecodable)] pub enum OptionsTargetModifiers { - $($variant($substruct_enum)),* + $( + $( + $tmod_variant($tmod_enum), + )? + )* } + impl OptionsTargetModifiers { - #[allow(unused_variables)] - pub fn reparse(&self, $user_value: &str) -> ExtendedTargetModifierInfo { - #[allow(unreachable_patterns)] + pub fn reparse(&self, user_value: &str) -> ExtendedTargetModifierInfo { match self { - $($pout)* - _ => panic!("unknown target modifier option: {:?}", *self) + $( + $( + Self::$tmod_variant(v) => v.reparse(user_value), + )? + )* + #[allow(unreachable_patterns)] + _ => panic!("unknown target modifier option: {self:?}"), } } - pub fn is_target_modifier(flag_name: &str) -> bool { - $($substruct_enum::is_target_modifier(flag_name))||* - } - } - }; - // Adding SUBSTRUCT option group into $eout - ( - @parse {$($eout:tt)*}, ($puser_value:ident){$($pout:tt)*}; - [SUBSTRUCT $substruct_enum:ident $variant:ident] | - $($tail:tt)* - ) => { - top_level_tmod_enum! { - @parse - { - $($eout)* - $variant($substruct_enum) - }, - ($puser_value){ - $($pout)* - Self::$variant(v) => v.reparse($puser_value), - }; - $($tail)* - } - }; - // Skipping non-target-modifier and non-substruct - ( - @parse {$($eout:tt)*}, ($puser_value:ident){$($pout:tt)*}; - [$non_substruct:ident] | - $($tail:tt)* - ) => { - top_level_tmod_enum! { - @parse - { - $($eout)* - }, - ($puser_value){ - $($pout)* - }; - $($tail)* - } - }; -} -macro_rules! top_level_options { - ( - $(#[$top_level_attr:meta])* - pub struct Options { - $( - $(#[$attr:meta])* - $opt:ident : $t:ty [ - $dep_tracking_marker:ident - $( $tmod:ident $variant:ident )? - ], - )* - } - ) => { - top_level_tmod_enum!( - { + pub fn is_target_modifier(flag_name: &str) -> bool { $( - [$dep_tracking_marker $($tmod $variant),*] - )|* + $( + if $tmod_enum::is_target_modifier(flag_name) { + return true + } + )? 
+ )* + false } - ); + } #[derive(Clone)] $(#[$top_level_attr])* @@ -375,13 +263,11 @@ macro_rules! top_level_options { pub fn gather_target_modifiers(&self) -> Vec { let mut mods = Vec::::new(); $( - gather_tmods_top_level!( - $opt, - &self.$opt, - &mut mods, - &self.target_modifiers, - [$dep_tracking_marker $($tmod),*] - ); + $( + // Only expand for flags that have `TARGET_MODIFIER`. + ${ignore($tmod_enum)} + self.$opt.gather_target_modifiers(&mut mods, &self.target_modifiers); + )? )* mods.sort_by(|a, b| a.opt.cmp(&b.opt)); mods @@ -451,9 +337,9 @@ top_level_options!( #[rustc_lint_opt_deny_field_access("should only be used via `Config::track_state`")] untracked_state_hash: Hash64 [TRACKED_NO_CRATE_HASH], - unstable_opts: UnstableOptions [SUBSTRUCT UnstableOptionsTargetModifiers UnstableOptions], + unstable_opts: UnstableOptions [SUBSTRUCT] { TARGET_MODIFIER: UnstableOptions(UnstableOptionsTargetModifiers) }, prints: Vec [UNTRACKED], - cg: CodegenOptions [SUBSTRUCT CodegenOptionsTargetModifiers CodegenOptions], + cg: CodegenOptions [SUBSTRUCT] { TARGET_MODIFIER: CodegenOptions(CodegenOptionsTargetModifiers) }, externs: Externs [UNTRACKED], crate_name: Option [TRACKED], /// Indicates how the compiler should treat unstable features. @@ -530,108 +416,6 @@ top_level_options!( } ); -macro_rules! mitigation_enum_opt { - ($opt:ident, MITIGATION) => { - Some(mitigation_coverage::DeniedPartialMitigationKind::$opt) - }; - ($opt:ident, $(TARGET_MODIFIER)?) => { - None - }; -} - -macro_rules! tmod_enum_opt { - ($struct_name:ident, $tmod_enum_name:ident, $opt:ident, TARGET_MODIFIER) => { - Some(OptionsTargetModifiers::$struct_name($tmod_enum_name::$opt)) - }; - ($struct_name:ident, $tmod_enum_name:ident, $opt:ident, $(MITIGATION)?) => { - None - }; -} - -macro_rules! tmod_enum { - ($tmod_enum_name:ident, $prefix:expr, $( {$($optinfo:tt)*} ),* $(,)*) => { - tmod_enum! 
{ $tmod_enum_name, $prefix, @parse {}, (user_value){}; $($($optinfo)*|)* } - }; - // Termination - ( - $tmod_enum_name:ident, $prefix:expr, - @parse - {$($eout:tt)*}, - ($user_value:ident){$($pout:tt)*}; - ) => { - #[allow(non_camel_case_types)] - #[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone, Encodable, BlobDecodable)] - pub enum $tmod_enum_name { - $($eout),* - } - impl $tmod_enum_name { - #[allow(unused_variables)] - pub fn reparse(&self, $user_value: &str) -> ExtendedTargetModifierInfo { - #[allow(unreachable_patterns)] - match self { - $($pout)* - _ => panic!("unknown target modifier option: {:?}", *self) - } - } - pub fn is_target_modifier(flag_name: &str) -> bool { - match flag_name.replace('-', "_").as_str() { - $(stringify!($eout) => true,)* - _ => false, - } - } - } - }; - // Adding target-modifier option into $eout - ( - $tmod_enum_name:ident, $prefix:expr, - @parse {$($eout:tt)*}, ($puser_value:ident){$($pout:tt)*}; - $opt:ident, $parse:ident, $t:ty, [TARGET_MODIFIER] | - $($tail:tt)* - ) => { - tmod_enum! { - $tmod_enum_name, $prefix, - @parse - { - $($eout)* - $opt - }, - ($puser_value){ - $($pout)* - Self::$opt => { - let mut parsed : $t = Default::default(); - let val = if $puser_value.is_empty() { None } else { Some($puser_value) }; - parse::$parse(&mut parsed, val); - ExtendedTargetModifierInfo { - prefix: $prefix.to_string(), - name: stringify!($opt).to_string().replace('_', "-"), - tech_value: format!("{:?}", parsed), - } - }, - }; - $($tail)* - } - }; - // Skipping non-target-modifier - ( - $tmod_enum_name:ident, $prefix:expr, - @parse {$($eout:tt)*}, ($puser_value:ident){$($pout:tt)*}; - $opt:ident, $parse:ident, $t:ty, [$(MITIGATION)?] | - $($tail:tt)* - ) => { - tmod_enum! { - $tmod_enum_name, $prefix, - @parse - { - $($eout)* - }, - ($puser_value){ - $($pout)* - }; - $($tail)* - } - }; -} - #[derive(Default)] pub struct CollectedOptions { pub target_modifiers: BTreeMap, @@ -684,7 +468,7 @@ macro_rules! 
setter_for { macro_rules! options { ( $struct_name:ident, - $tmod_enum_name:ident, + $tmod_enum:ident, $stat:ident, $optmod:ident, $prefix:expr, @@ -695,8 +479,11 @@ macro_rules! options { $opt:ident : $t:ty = ( $init:expr, $parse:ident, - [$dep_tracking_marker:ident $( $modifier_kind:ident )?], - $desc:expr + [$dep_tracking_marker:ident] + $( { TARGET_MODIFIER: $tmod_variant:ident } )? + $( { MITIGATION: $mitigation_variant:ident } )? + , + $desc:literal $(, removed: $removed:ident )? ), )* @@ -710,15 +497,49 @@ macro_rules! options { )* } - tmod_enum!( - $tmod_enum_name, - $prefix, - { - $( - $opt, $parse, $t, [$($modifier_kind),*] - )|* + #[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone, Encodable, BlobDecodable)] + pub enum $tmod_enum { + $( + $( $tmod_variant, )? + )* + } + + impl $tmod_enum { + pub fn reparse(&self, _user_value: &str) -> ExtendedTargetModifierInfo { + match self { + $( + $( + Self::$tmod_variant => { + let mut parsed: $t = Default::default(); + let val = if _user_value.is_empty() { None } else { Some(_user_value) }; + parse::$parse(&mut parsed, val); + ExtendedTargetModifierInfo { + prefix: $prefix.to_string(), + name: stringify!($opt).to_string().replace('_', "-"), + tech_value: format!("{:?}", parsed), + } + } + )? + )* + + #[allow(unreachable_patterns)] + _ => panic!("unknown target modifier option: {:?}", *self) + } } - ); + + pub fn is_target_modifier(flag_name: &str) -> bool { + match flag_name.replace('-', "_").as_str() { + $( + $( + // Only expand for flags that have `TARGET_MODIFIER`. + ${ignore($tmod_variant)} + stringify!($opt) => true, + )? + )* + _ => false, + } + } + } impl Default for $struct_name { fn default() -> $struct_name { @@ -770,17 +591,15 @@ macro_rules! 
options { _tmod_vals: &BTreeMap, ) { $( - gather_tmods!( - $struct_name, - $tmod_enum_name, - $opt, - &self.$opt, - $init, - _mods, - _tmod_vals, - [$dep_tracking_marker], - [$($modifier_kind),*] - ); + $( + if self.$opt != $init { + tmod_push_impl( + OptionsTargetModifiers::$struct_name($tmod_enum::$tmod_variant), + _tmod_vals, + _mods, + ); + } + )? )* } } @@ -793,8 +612,12 @@ macro_rules! options { type_desc: desc::$parse, desc: $desc, removed: None $( .or(Some(RemovedOption::$removed)) )?, - tmod: tmod_enum_opt!($struct_name, $tmod_enum_name, $opt, $($modifier_kind),*), - mitigation: mitigation_enum_opt!($opt, $($modifier_kind),*), + tmod: None $( .or(Some( + OptionsTargetModifiers::$struct_name($tmod_enum::$tmod_variant) + )))?, + mitigation: None $( .or(Some( + mitigation_coverage::DeniedPartialMitigationKind::$mitigation_variant + )))?, }, )* ]; @@ -2230,7 +2053,7 @@ options! { collapse_macro_debuginfo: CollapseMacroDebuginfo = (CollapseMacroDebuginfo::Unspecified, parse_collapse_macro_debuginfo, [TRACKED], "set option to collapse debuginfo for macros"), - control_flow_guard: CFGuard = (CFGuard::Disabled, parse_cfguard, [TRACKED MITIGATION], + control_flow_guard: CFGuard = (CFGuard::Disabled, parse_cfguard, [TRACKED] { MITIGATION: ControlFlowGuard }, "use Windows Control Flow Guard (default: no)"), debug_assertions: Option = (None, parse_opt_bool, [TRACKED], "explicitly enable the `cfg(debug_assertions)` directive"), @@ -2409,7 +2232,7 @@ options! 
{ (default: no)"), box_noalias: bool = (true, parse_bool, [TRACKED], "emit noalias metadata for box (default: yes)"), - branch_protection: Option = (None, parse_branch_protection, [TRACKED TARGET_MODIFIER], + branch_protection: Option = (None, parse_branch_protection, [TRACKED] { TARGET_MODIFIER: BranchProtection }, "set options for branch target identification and pointer authentication on AArch64"), build_sdylib_interface: bool = (false, parse_bool, [UNTRACKED], "whether the stable interface is being built"), @@ -2510,7 +2333,7 @@ options! { fewer_names: Option = (None, parse_opt_bool, [TRACKED], "reduce memory use by retaining fewer names within compilation artifacts (LLVM-IR) \ (default: no)"), - fixed_x18: bool = (false, parse_bool, [TRACKED TARGET_MODIFIER], + fixed_x18: bool = (false, parse_bool, [TRACKED] { TARGET_MODIFIER: FixedX18 }, "make the x18 register reserved on AArch64 (default: no)"), flatten_format_args: bool = (true, parse_bool, [TRACKED], "flatten nested format_args!() and literals into a simplified format_args!() call \ @@ -2554,7 +2377,7 @@ options! { - hashes of green query instances - hash collisions of query keys - hash collisions when creating dep-nodes"), - indirect_branch_cs_prefix: bool = (false, parse_bool, [TRACKED TARGET_MODIFIER], + indirect_branch_cs_prefix: bool = (false, parse_bool, [TRACKED] { TARGET_MODIFIER: IndirectBranchCsPrefix }, "add `cs` prefix to `call` and `jmp` to indirect thunks (default: no)"), inline_llvm: bool = (true, parse_bool, [TRACKED], "enable LLVM inlining (default: yes)"), @@ -2749,10 +2572,10 @@ options! 
{ "enable queries of the dependency graph for regression testing (default: no)"), randomize_layout: bool = (false, parse_bool, [TRACKED], "randomize the layout of types (default: no)"), - reg_struct_return: bool = (false, parse_bool, [TRACKED TARGET_MODIFIER], + reg_struct_return: bool = (false, parse_bool, [TRACKED] { TARGET_MODIFIER: RegStructReturn }, "On x86-32 targets, it overrides the default ABI to return small structs in registers. It is UNSOUND to link together crates that use different values for this flag!"), - regparm: Option = (None, parse_opt_number, [TRACKED TARGET_MODIFIER], + regparm: Option = (None, parse_opt_number, [TRACKED] { TARGET_MODIFIER: Regparm }, "On x86-32 targets, setting this to N causes the compiler to pass N arguments \ in registers EAX, EDX, and ECX instead of on the stack for\ \"C\", \"cdecl\", and \"stdcall\" fn.\ @@ -2764,19 +2587,19 @@ options! { remark_dir: Option = (None, parse_opt_pathbuf, [UNTRACKED], "directory into which to write optimization remarks (if not specified, they will be \ written to standard error output)"), - retpoline: bool = (false, parse_bool, [TRACKED TARGET_MODIFIER], + retpoline: bool = (false, parse_bool, [TRACKED] { TARGET_MODIFIER: Retpoline }, "enables retpoline-indirect-branches and retpoline-indirect-calls target features (default: no)"), - retpoline_external_thunk: bool = (false, parse_bool, [TRACKED TARGET_MODIFIER], + retpoline_external_thunk: bool = (false, parse_bool, [TRACKED] { TARGET_MODIFIER: RetpolineExternalThunk }, "enables retpoline-external-thunk, retpoline-indirect-branches and retpoline-indirect-calls \ target features (default: no)"), #[rustc_lint_opt_deny_field_access("use `Session::sanitizers()` instead of this field")] - sanitizer: SanitizerSet = (SanitizerSet::empty(), parse_sanitizers, [TRACKED TARGET_MODIFIER], + sanitizer: SanitizerSet = (SanitizerSet::empty(), parse_sanitizers, [TRACKED] { TARGET_MODIFIER: Sanitizer }, "use a sanitizer"), 
sanitizer_cfi_canonical_jump_tables: Option = (Some(true), parse_opt_bool, [TRACKED], "enable canonical jump tables (default: yes)"), sanitizer_cfi_generalize_pointers: Option = (None, parse_opt_bool, [TRACKED], "enable generalizing pointer types (default: no)"), - sanitizer_cfi_normalize_integers: Option = (None, parse_opt_bool, [TRACKED TARGET_MODIFIER], + sanitizer_cfi_normalize_integers: Option = (None, parse_opt_bool, [TRACKED] { TARGET_MODIFIER: SanitizerCfiNormalizeIntegers }, "enable normalizing integer types (default: no)"), sanitizer_dataflow_abilist: Vec = (Vec::new(), parse_comma_list, [TRACKED], "additional ABI list files that control how shadow parameters are passed (comma separated)"), @@ -2836,7 +2659,7 @@ written to standard error output)"), src_hash_algorithm: Option = (None, parse_src_file_hash, [TRACKED], "hash algorithm of source files in debug info (`md5`, `sha1`, or `sha256`)"), #[rustc_lint_opt_deny_field_access("use `Session::stack_protector` instead of this field")] - stack_protector: StackProtector = (StackProtector::None, parse_stack_protector, [TRACKED MITIGATION], + stack_protector: StackProtector = (StackProtector::None, parse_stack_protector, [TRACKED] { MITIGATION: StackProtector }, "control stack smash protection strategy (`rustc --print stack-protector-strategies` for details)"), staticlib_allow_rdylib_deps: bool = (false, parse_bool, [TRACKED], "allow staticlibs to have rust dylib dependencies"), diff --git a/compiler/rustc_session/src/options/mitigation_coverage.rs b/compiler/rustc_session/src/options/mitigation_coverage.rs index f396392cd2638..dbe989100d567 100644 --- a/compiler/rustc_session/src/options/mitigation_coverage.rs +++ b/compiler/rustc_session/src/options/mitigation_coverage.rs @@ -133,7 +133,6 @@ macro_rules! intersperse { macro_rules! 
denied_partial_mitigations { ([$self:ident] enum $kind:ident {$(($name:ident, $text:expr, $since:ident, $code:expr)),*}) => { - #[allow(non_camel_case_types)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Encodable, BlobDecodable)] pub enum DeniedPartialMitigationKind { $($name),* @@ -204,8 +203,8 @@ denied_partial_mitigations! { enum DeniedPartialMitigationKind { // The mitigation name should match the option name in rustc_session::options, // to allow for resetting the mitigation - (stack_protector, "stack-protector", EditionFuture, self.stack_protector()), - (control_flow_guard, "control-flow-guard", EditionFuture, self.opts.cg.control_flow_guard == CFGuard::Checks) + (StackProtector, "stack-protector", EditionFuture, self.stack_protector()), + (ControlFlowGuard, "control-flow-guard", EditionFuture, self.opts.cg.control_flow_guard == CFGuard::Checks) } } diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs index 7cead434bdad3..bf4c4287c433a 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs @@ -2748,6 +2748,11 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> { obligation.param_env, obligation.cause.code(), ); + self.suggest_borrow_for_unsized_closure_return( + obligation.cause.body_id, + err, + obligation.predicate, + ); self.suggest_unsized_bound_if_applicable(err, obligation); if let Some(span) = err.span.primary_span() && let Some(mut diag) = diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs index 66eaa49cbd5d4..9e45023144def 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/traits/suggestions.rs @@ 
-2196,6 +2196,60 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> { false } + pub(super) fn suggest_borrow_for_unsized_closure_return( + &self, + body_id: LocalDefId, + err: &mut Diag<'_, G>, + predicate: ty::Predicate<'tcx>, + ) { + let Some(pred) = predicate.as_trait_clause() else { + return; + }; + if !self.tcx.is_lang_item(pred.def_id(), LangItem::Sized) { + return; + } + + let Some(span) = err.span.primary_span() else { + return; + }; + let Some(node_body_id) = self.tcx.hir_node_by_def_id(body_id).body_id() else { + return; + }; + let body = self.tcx.hir_body(node_body_id); + let mut expr_finder = FindExprBySpan::new(span, self.tcx); + expr_finder.visit_expr(body.value); + let Some(expr) = expr_finder.result else { + return; + }; + + let closure = match expr.kind { + hir::ExprKind::Call(_, args) => args.iter().find_map(|arg| match arg.kind { + hir::ExprKind::Closure(closure) => Some(closure), + _ => None, + }), + hir::ExprKind::MethodCall(_, _, args, _) => { + args.iter().find_map(|arg| match arg.kind { + hir::ExprKind::Closure(closure) => Some(closure), + _ => None, + }) + } + _ => None, + }; + let Some(closure) = closure else { + return; + }; + if !matches!(closure.fn_decl.output, hir::FnRetTy::DefaultReturn(_)) { + return; + } + + err.span_suggestion_verbose( + self.tcx.hir_body(closure.body).value.span.shrink_to_lo(), + "consider borrowing the value", + "&", + Applicability::MaybeIncorrect, + ); + } + pub(super) fn return_type_span(&self, obligation: &PredicateObligation<'tcx>) -> Option { let hir::Node::Item(hir::Item { kind: hir::ItemKind::Fn { sig, .. }, .. 
}) = self.tcx.hir_node_by_def_id(obligation.cause.body_id) diff --git a/src/tools/miri/Cargo.lock b/src/tools/miri/Cargo.lock index 630a4b5e3e0f9..25005693117ea 100644 --- a/src/tools/miri/Cargo.lock +++ b/src/tools/miri/Cargo.lock @@ -1162,9 +1162,9 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rand" -version = "0.9.2" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166" dependencies = [ "rand_chacha", "rand_core", diff --git a/src/tools/miri/cargo-miri/src/phases.rs b/src/tools/miri/cargo-miri/src/phases.rs index 567d51c0b31e7..8a4d09b8348bf 100644 --- a/src/tools/miri/cargo-miri/src/phases.rs +++ b/src/tools/miri/cargo-miri/src/phases.rs @@ -87,7 +87,9 @@ pub fn phase_cargo_miri(mut args: impl Iterator) { println!("`cargo miri {verb}` supports the same flags as `cargo {verb}`:\n"); let mut cmd = cargo(); cmd.arg(verb); - cmd.arg("--help"); + // Forward all arguments (some of them can influence the help output, e.g. + // the nextest verb). + cmd.args(args); exec(cmd); } _ => { diff --git a/src/tools/miri/genmc-sys/build.rs b/src/tools/miri/genmc-sys/build.rs index 04a8e6854fc08..4a5cc585139e3 100644 --- a/src/tools/miri/genmc-sys/build.rs +++ b/src/tools/miri/genmc-sys/build.rs @@ -28,7 +28,7 @@ mod downloading { /// The GenMC repository the we get our commit from. pub(crate) const GENMC_GITHUB_URL: &str = "https://github.com/MPI-SWS/genmc.git"; /// The GenMC commit we depend on. It must be available on the specified GenMC repository. - pub(crate) const GENMC_COMMIT: &str = "22d3d0b44dedb4e8e1aae3330e546465e4664529"; + pub(crate) const GENMC_COMMIT: &str = "29b03a66402c4453fc77901ef3be90bb55707cd4"; /// Ensure that a local GenMC repo is present and set to the correct commit. /// Return the path of the GenMC repo clone. 
@@ -159,6 +159,7 @@ fn compile_cpp_dependencies(genmc_path: &Path) { .out_dir(genmc_build_dir) .profile(GENMC_CMAKE_PROFILE) .define("BUILD_LLI", "OFF") + .define("EMIT_NA_LABELS", "OFF") .define("GENMC_DEBUG", if enable_genmc_debug { "ON" } else { "OFF" }); // The actual compilation happens here: @@ -172,7 +173,7 @@ fn compile_cpp_dependencies(genmc_path: &Path) { // Part 2: // Compile the cxx_bridge (the link between the Rust and C++ code). - let genmc_include_dir = genmc_install_dir.join("include").join("genmc"); + let genmc_include_dir = genmc_install_dir.join("include"); // These are all the C++ files we need to compile, which needs to be updated if more C++ files are added to Miri. // We use absolute paths since relative paths can confuse IDEs when attempting to go-to-source on a path in a compiler error. @@ -181,10 +182,6 @@ fn compile_cpp_dependencies(genmc_path: &Path) { .map(|file| std::path::absolute(cpp_files_base_path.join(file)).unwrap()); let mut bridge = cxx_build::bridge("src/lib.rs"); - // FIXME(genmc,cmake): Remove once the GenMC debug setting is available in the config.h file. 
- if enable_genmc_debug { - bridge.define("ENABLE_GENMC_DEBUG", None); - } bridge .opt_level(2) .debug(true) // Same settings that GenMC uses (default for cmake `RelWithDebInfo`) diff --git a/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp b/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp index 8110c8d24c593..7c89c630c396f 100644 --- a/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp +++ b/src/tools/miri/genmc-sys/cpp/include/MiriInterface.hpp @@ -5,17 +5,17 @@ #include "rust/cxx.h" // GenMC generated headers: -#include "config.h" +#include "genmc/config.h" // Miri `genmc-sys/src_cpp` headers: #include "ResultHandling.hpp" // GenMC headers: -#include "ExecutionGraph/EventLabel.hpp" -#include "Support/MemOrdering.hpp" -#include "Support/RMWOps.hpp" -#include "Verification/Config.hpp" -#include "Verification/GenMCDriver.hpp" +#include "genmc/Execution/EventLabel.hpp" +#include "genmc/Support/MemOrdering.hpp" +#include "genmc/Support/RMWOps.hpp" +#include "genmc/Verification/Config.hpp" +#include "genmc/Verification/GenMCDriver.hpp" // C++ headers: #include @@ -36,6 +36,7 @@ struct StoreResult; struct ReadModifyWriteResult; struct CompareExchangeResult; struct MutexLockResult; +struct MallocResult; // GenMC uses `int` for its thread IDs. 
using ThreadId = int; @@ -86,13 +87,15 @@ struct MiriGenmcShim : private GenMCDriver { /**** Memory access handling ****/ - [[nodiscard]] LoadResult handle_load( + [[nodiscard]] LoadResult handle_atomic_load( ThreadId thread_id, uint64_t address, uint64_t size, MemOrdering ord, GenmcScalar old_val ); + [[nodiscard]] LoadResult + handle_non_atomic_load(ThreadId thread_id, uint64_t address, uint64_t size); [[nodiscard]] ReadModifyWriteResult handle_read_modify_write( ThreadId thread_id, uint64_t address, @@ -113,7 +116,7 @@ struct MiriGenmcShim : private GenMCDriver { MemOrdering fail_load_ordering, bool can_fail_spuriously ); - [[nodiscard]] StoreResult handle_store( + [[nodiscard]] StoreResult handle_atomic_store( ThreadId thread_id, uint64_t address, uint64_t size, @@ -121,12 +124,14 @@ struct MiriGenmcShim : private GenMCDriver { GenmcScalar old_val, MemOrdering ord ); + [[nodiscard]] StoreResult + handle_non_atomic_store(ThreadId thread_id, uint64_t address, uint64_t size); void handle_fence(ThreadId thread_id, MemOrdering ord); /**** Memory (de)allocation ****/ - auto handle_malloc(ThreadId thread_id, uint64_t size, uint64_t alignment) -> uint64_t; + auto handle_malloc(ThreadId thread_id, uint64_t size, uint64_t alignment) -> MallocResult; /** Returns null on success, or an error string if an error occurs. */ auto handle_free(ThreadId thread_id, uint64_t address) -> std::unique_ptr; @@ -203,33 +208,15 @@ struct MiriGenmcShim : private GenMCDriver { auto get_estimation_results() const -> EstimationResult; private: - /** Increment the event index in the given thread by 1 and return the new event. */ - [[nodiscard]] inline auto inc_pos(ThreadId tid) -> Event { + /** Returns the current event for a given thread. 
*/ + inline auto curr_pos(ThreadId tid) -> Event { ERROR_ON(tid >= threads_action_.size(), "ThreadId out of bounds"); - return ++threads_action_[tid].event; + return threads_action_[tid].event; } - /** Decrement the event index in the given thread by 1 and return the new event. */ - inline auto dec_pos(ThreadId tid) -> Event { + /** Increment the event index in the given thread by `count`. */ + inline void inc_pos(ThreadId tid, unsigned int count) { ERROR_ON(tid >= threads_action_.size(), "ThreadId out of bounds"); - return --threads_action_[tid].event; - } - - /** - * Helper function for loads that need to reset the event counter when no value is returned. - * Same syntax as `GenMCDriver::handleLoad`, but this takes a thread id instead of an Event. - * Automatically calls `inc_pos` and `dec_pos` where needed for the given thread. - */ - template - auto handle_load_reset_if_none(ThreadId tid, std::optional old_val, Ts&&... params) - -> HandleResult { - const auto pos = inc_pos(tid); - const auto ret = - GenMCDriver::handleLoad(nullptr, pos, old_val, std::forward(params)...); - // If we didn't get a value, we have to reset the index of the current thread. 
- if (!std::holds_alternative(ret)) { - dec_pos(tid); - } - return ret; + threads_action_[tid].event.index += count; } /** @@ -293,40 +280,55 @@ inline std::optional try_to_sval(GenmcScalar scalar) { namespace LoadResultExt { inline LoadResult no_value() { return LoadResult { + .invalid = false, .error = std::unique_ptr(nullptr), - .has_value = false, .read_value = GenmcScalarExt::uninit(), }; } inline LoadResult from_value(SVal read_value) { - return LoadResult { .error = std::unique_ptr(nullptr), - .has_value = true, + return LoadResult { .invalid = false, + .error = std::unique_ptr(nullptr), .read_value = GenmcScalarExt::from_sval(read_value) }; } inline LoadResult from_error(std::unique_ptr error) { - return LoadResult { .error = std::move(error), - .has_value = false, + return LoadResult { .invalid = false, + .error = std::move(error), .read_value = GenmcScalarExt::uninit() }; } + +inline LoadResult from_invalid() { + return LoadResult { .invalid = true, .error = nullptr, .read_value = GenmcScalarExt::uninit() }; +} + } // namespace LoadResultExt namespace StoreResultExt { inline StoreResult ok(bool is_coherence_order_maximal_write) { - return StoreResult { /* error: */ std::unique_ptr(nullptr), - is_coherence_order_maximal_write }; + return StoreResult { .invalid = false, + .error = std::unique_ptr(nullptr), + .is_coherence_order_maximal_write = is_coherence_order_maximal_write }; } inline StoreResult from_error(std::unique_ptr error) { - return StoreResult { .error = std::move(error), .is_coherence_order_maximal_write = false }; + return StoreResult { .invalid = false, + .error = std::move(error), + .is_coherence_order_maximal_write = false }; +} + +inline StoreResult from_invalid() { + return StoreResult { .invalid = true, + .error = nullptr, + .is_coherence_order_maximal_write = false }; } } // namespace StoreResultExt namespace ReadModifyWriteResultExt { inline ReadModifyWriteResult ok(SVal old_value, SVal new_value, bool is_coherence_order_maximal_write) 
{ - return ReadModifyWriteResult { .error = std::unique_ptr(nullptr), + return ReadModifyWriteResult { .invalid = false, + .error = std::unique_ptr(nullptr), .old_value = GenmcScalarExt::from_sval(old_value), .new_value = GenmcScalarExt::from_sval(new_value), .is_coherence_order_maximal_write = @@ -334,7 +336,16 @@ ok(SVal old_value, SVal new_value, bool is_coherence_order_maximal_write) { } inline ReadModifyWriteResult from_error(std::unique_ptr error) { - return ReadModifyWriteResult { .error = std::move(error), + return ReadModifyWriteResult { .invalid = false, + .error = std::move(error), + .old_value = GenmcScalarExt::uninit(), + .new_value = GenmcScalarExt::uninit(), + .is_coherence_order_maximal_write = false }; +} + +inline ReadModifyWriteResult from_invalid() { + return ReadModifyWriteResult { .invalid = true, + .error = nullptr, .old_value = GenmcScalarExt::uninit(), .new_value = GenmcScalarExt::uninit(), .is_coherence_order_maximal_write = false }; @@ -343,7 +354,8 @@ inline ReadModifyWriteResult from_error(std::unique_ptr error) { namespace CompareExchangeResultExt { inline CompareExchangeResult success(SVal old_value, bool is_coherence_order_maximal_write) { - return CompareExchangeResult { .error = nullptr, + return CompareExchangeResult { .invalid = false, + .error = nullptr, .old_value = GenmcScalarExt::from_sval(old_value), .is_success = true, .is_coherence_order_maximal_write = @@ -351,14 +363,24 @@ inline CompareExchangeResult success(SVal old_value, bool is_coherence_order_max } inline CompareExchangeResult failure(SVal old_value) { - return CompareExchangeResult { .error = nullptr, + return CompareExchangeResult { .invalid = false, + .error = nullptr, .old_value = GenmcScalarExt::from_sval(old_value), .is_success = false, .is_coherence_order_maximal_write = false }; } inline CompareExchangeResult from_error(std::unique_ptr error) { - return CompareExchangeResult { .error = std::move(error), + return CompareExchangeResult { .invalid = false, + 
.error = std::move(error), + .old_value = GenmcScalarExt::uninit(), + .is_success = false, + .is_coherence_order_maximal_write = false }; +} + +inline CompareExchangeResult from_invalid() { + return CompareExchangeResult { .invalid = true, + .error = nullptr, .old_value = GenmcScalarExt::uninit(), .is_success = false, .is_coherence_order_maximal_write = false }; @@ -367,20 +389,42 @@ inline CompareExchangeResult from_error(std::unique_ptr error) { namespace MutexLockResultExt { inline MutexLockResult ok(bool is_lock_acquired) { - return MutexLockResult { /* error: */ nullptr, /* is_reset: */ false, is_lock_acquired }; + return MutexLockResult { .invalid = false, + .error = nullptr, + .is_reset = false, + .is_lock_acquired = is_lock_acquired }; } inline MutexLockResult reset() { - return MutexLockResult { /* error: */ nullptr, - /* is_reset: */ true, - /* is_lock_acquired: */ false }; + return MutexLockResult { .invalid = false, + .error = nullptr, + .is_reset = true, + .is_lock_acquired = false }; } inline MutexLockResult from_error(std::unique_ptr error) { - return MutexLockResult { /* error: */ std::move(error), - /* is_reset: */ false, - /* is_lock_acquired: */ false }; + return MutexLockResult { .invalid = false, + .error = std::move(error), + .is_reset = false, + .is_lock_acquired = false }; +} + +inline MutexLockResult from_invalid() { + return MutexLockResult { .invalid = true, + .error = nullptr, + .is_reset = false, + .is_lock_acquired = false }; } } // namespace MutexLockResultExt +namespace MallocResultExt { +inline MallocResult ok(SVal addr) { + return MallocResult { .error = nullptr, .address = addr.get() }; +} + +inline MallocResult from_error(std::unique_ptr error) { + return MallocResult { .error = std::move(error), .address = 0UL }; +} +} // namespace MallocResultExt + #endif /* GENMC_MIRI_INTERFACE_HPP */ diff --git a/src/tools/miri/genmc-sys/cpp/include/ResultHandling.hpp b/src/tools/miri/genmc-sys/cpp/include/ResultHandling.hpp index 
cb5f49c179b05..6df3a3af84fb4 100644 --- a/src/tools/miri/genmc-sys/cpp/include/ResultHandling.hpp +++ b/src/tools/miri/genmc-sys/cpp/include/ResultHandling.hpp @@ -5,7 +5,7 @@ #include "rust/cxx.h" // GenMC headers: -#include "Verification/VerificationError.hpp" +#include "genmc/Verification/VerificationError.hpp" #include #include diff --git a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Exploration.cpp b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Exploration.cpp index d5a3833e2e837..1c5186bfff1f0 100644 --- a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Exploration.cpp +++ b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Exploration.cpp @@ -7,22 +7,22 @@ #include "genmc-sys/src/lib.rs.h" // GenMC headers: -#include "ADT/value_ptr.hpp" -#include "ExecutionGraph/EventLabel.hpp" -#include "ExecutionGraph/LoadAnnotation.hpp" -#include "Runtime/InterpreterEnumAPI.hpp" -#include "Static/ModuleID.hpp" -#include "Support/ASize.hpp" -#include "Support/Error.hpp" -#include "Support/Logger.hpp" -#include "Support/MemAccess.hpp" -#include "Support/RMWOps.hpp" -#include "Support/SAddr.hpp" -#include "Support/SVal.hpp" -#include "Support/ThreadInfo.hpp" -#include "Support/Verbosity.hpp" -#include "Verification/GenMCDriver.hpp" -#include "Verification/MemoryModel.hpp" +#include "genmc/ADT/value_ptr.hpp" +#include "genmc/Execution/EventLabel.hpp" +#include "genmc/Execution/LoadAnnotation.hpp" +#include "genmc/Support/ASize.hpp" +#include "genmc/Support/ActionEnums.hpp" +#include "genmc/Support/Error.hpp" +#include "genmc/Support/Logger.hpp" +#include "genmc/Support/MemAccess.hpp" +#include "genmc/Support/ModuleVarID.hpp" +#include "genmc/Support/RMWOps.hpp" +#include "genmc/Support/SAddr.hpp" +#include "genmc/Support/SVal.hpp" +#include "genmc/Support/ThreadInfo.hpp" +#include "genmc/Support/Verbosity.hpp" +#include "genmc/Verification/GenMCDriver.hpp" +#include "genmc/Verification/MemoryModel.hpp" // C++ headers: #include @@ -47,13 +47,13 @@ auto 
MiriGenmcShim::schedule_next( [](auto&& arg) { using T = std::decay_t; if constexpr (std::is_same_v) - return SchedulingResult { ExecutionState::Ok, static_cast(arg) }; + return SchedulingResult { ExecutionStatus::Ok, static_cast(arg) }; else if constexpr (std::is_same_v) - return SchedulingResult { ExecutionState::Blocked, 0 }; + return SchedulingResult { ExecutionStatus::Blocked, 0 }; else if constexpr (std::is_same_v) - return SchedulingResult { ExecutionState::Error, 0 }; + return SchedulingResult { ExecutionStatus::Error, 0 }; else if constexpr (std::is_same_v) - return SchedulingResult { ExecutionState::Finished, 0 }; + return SchedulingResult { ExecutionStatus::Finished, 0 }; else static_assert(false, "non-exhaustive visitor!"); }, @@ -75,39 +75,66 @@ auto MiriGenmcShim::handle_execution_end() -> std::unique_ptr { /**** Blocking instructions ****/ void MiriGenmcShim::handle_assume_block(ThreadId thread_id, AssumeType assume_type) { - BUG_ON(getExec().getGraph().isThreadBlocked(thread_id)); - GenMCDriver::handleAssume(nullptr, inc_pos(thread_id), assume_type); + auto ret = GenMCDriver::handleAssume(nullptr, curr_pos(thread_id), assume_type); + inc_pos(thread_id, ret.count); } /**** Memory access handling ****/ -[[nodiscard]] auto MiriGenmcShim::handle_load( +[[nodiscard]] auto MiriGenmcShim::handle_atomic_load( ThreadId thread_id, uint64_t address, uint64_t size, MemOrdering ord, GenmcScalar old_val ) -> LoadResult { - // `type` is only used for printing. 
- const auto type = AType::Unsigned; - const auto ret = handle_load_reset_if_none( - thread_id, + const auto ret = GenMCDriver::handleRead( + nullptr, + curr_pos(thread_id), GenmcScalarExt::try_to_sval(old_val), ord, SAddr(address), ASize(size), - type + nullptr, + std::nullopt, + EventDeps() ); - - if (const auto* err = std::get_if(&ret)) + inc_pos(thread_id, ret.count); + if (const auto* err = std::get_if(&ret.result)) return LoadResultExt::from_error(format_error(*err)); - const auto* ret_val = std::get_if(&ret); - // FIXME(genmc): handle `HandleResult::{Invalid, Reset}` return values. - ERROR_ON(!ret_val, "Unimplemented: load returned unexpected result."); + if (std::holds_alternative(ret.result)) + return LoadResultExt::from_invalid(); + const auto* ret_val = std::get_if(&ret.result); + // FIXME(genmc): handle `HandleResult::Reset` return value. + ERROR_ON(!ret_val, "Unimplemented: atomic load returned unexpected result."); return LoadResultExt::from_value(*ret_val); } -[[nodiscard]] auto MiriGenmcShim::handle_store( +[[nodiscard]] auto +MiriGenmcShim::handle_non_atomic_load(ThreadId thread_id, uint64_t address, uint64_t size) + -> LoadResult { + const auto ret = GenMCDriver::handleNALoad( + nullptr, + curr_pos(thread_id), + SAddr(address), + ASize(size), + EventDeps() + ); + inc_pos(thread_id, ret.count); + + if (const auto* err = std::get_if(&ret.result)) + return LoadResultExt::from_error(format_error(*err)); + if (std::holds_alternative(ret.result)) + return LoadResultExt::from_invalid(); + // FIXME(genmc): handle `HandleResult::Reset` return value. + ERROR_ON( + !std::holds_alternative(ret.result), + "Unimplemented: non-atomic load returned unexpected result." 
+ ); + return LoadResultExt::no_value(); +} + +[[nodiscard]] auto MiriGenmcShim::handle_atomic_store( ThreadId thread_id, uint64_t address, uint64_t size, @@ -115,31 +142,57 @@ void MiriGenmcShim::handle_assume_block(ThreadId thread_id, AssumeType assume_ty GenmcScalar old_val, MemOrdering ord ) -> StoreResult { - const auto pos = inc_pos(thread_id); - const auto ret = GenMCDriver::handleStore( + const auto ret = GenMCDriver::handleWrite( nullptr, - pos, + curr_pos(thread_id), GenmcScalarExt::try_to_sval(old_val), ord, SAddr(address), ASize(size), - /* type */ AType::Unsigned, // `type` is only used for printing. GenmcScalarExt::to_sval(value), + WriteAttr(), EventDeps() ); - if (const auto* err = std::get_if(&ret)) + inc_pos(thread_id, ret.count); + if (const auto* err = std::get_if(&ret.result)) return StoreResultExt::from_error(format_error(*err)); + if (std::holds_alternative(ret.result)) + return StoreResultExt::from_invalid(); - const auto* is_co_max = std::get_if(&ret); - // FIXME(genmc): handle `HandleResult::{Invalid, Reset}` return values. - ERROR_ON(!is_co_max, "Unimplemented: Store returned unexpected result."); + const auto* is_co_max = std::get_if(&ret.result); + // FIXME(genmc): handle `HandleResult::Reset` return value. + ERROR_ON(!is_co_max, "Unimplemented: atomic store returned unexpected result."); return StoreResultExt::ok(*is_co_max); } +[[nodiscard]] auto +MiriGenmcShim::handle_non_atomic_store(ThreadId thread_id, uint64_t address, uint64_t size) + -> StoreResult { + const auto ret = GenMCDriver::handleNAStore( + nullptr, + curr_pos(thread_id), + SAddr(address), + ASize(size), + EventDeps() + ); + inc_pos(thread_id, ret.count); + + if (const auto* err = std::get_if(&ret.result)) + return StoreResultExt::from_error(format_error(*err)); + if (std::holds_alternative(ret.result)) + return StoreResultExt::from_invalid(); + // FIXME(genmc): handle `HandleResult::Reset` return value. 
+ ERROR_ON( + !std::holds_alternative(ret.result), + "Unimplemented: non-atomic store returned unexpected result." + ); + return StoreResultExt::ok(true); +} + void MiriGenmcShim::handle_fence(ThreadId thread_id, MemOrdering ord) { - const auto pos = inc_pos(thread_id); - GenMCDriver::handleFence(nullptr, pos, ord, EventDeps()); + auto ret = GenMCDriver::handleFence(nullptr, curr_pos(thread_id), ord, EventDeps()); + inc_pos(thread_id, ret.count); } [[nodiscard]] auto MiriGenmcShim::handle_read_modify_write( @@ -155,45 +208,52 @@ void MiriGenmcShim::handle_fence(ThreadId thread_id, MemOrdering ord) { // into a load and a store component. This means we can have for example `AcqRel` loads and // stores, but this is intended for RMW operations. - // Somewhat confusingly, the GenMC term for RMW read/write labels is - // `FaiRead` and `FaiWrite`. - const auto load_ret = handle_load_reset_if_none( - thread_id, + const auto load_ret = GenMCDriver::handleFaiRead( + nullptr, + curr_pos(thread_id), GenmcScalarExt::try_to_sval(old_val), ordering, SAddr(address), ASize(size), - AType::Unsigned, // The type is only used for printing. rmw_op, GenmcScalarExt::to_sval(rhs_value), + WriteAttr(), + nullptr, + std::nullopt, EventDeps() ); - if (const auto* err = std::get_if(&load_ret)) + inc_pos(thread_id, load_ret.count); + if (const auto* err = std::get_if(&load_ret.result)) return ReadModifyWriteResultExt::from_error(format_error(*err)); + if (std::holds_alternative(load_ret.result)) + return ReadModifyWriteResultExt::from_invalid(); - const auto* ret_val = std::get_if(&load_ret); - // FIXME(genmc): handle `HandleResult::{Invalid, Reset}` return values. + const auto* ret_val = std::get_if(&load_ret.result); + // FIXME(genmc): handle `HandleResult::Reset` return values. 
ERROR_ON(!ret_val, "Unimplemented: read-modify-write returned unexpected result."); const auto read_old_val = *ret_val; const auto new_value = executeRMWBinOp(read_old_val, GenmcScalarExt::to_sval(rhs_value), size, rmw_op); - const auto storePos = inc_pos(thread_id); - const auto store_ret = GenMCDriver::handleStore( + const auto store_ret = GenMCDriver::handleFaiWrite( nullptr, - storePos, + curr_pos(thread_id), GenmcScalarExt::try_to_sval(old_val), ordering, SAddr(address), ASize(size), - AType::Unsigned, // The type is only used for printing. - new_value + new_value, + WriteAttr(), + EventDeps() ); - if (const auto* err = std::get_if(&store_ret)) + inc_pos(thread_id, store_ret.count); + if (const auto* err = std::get_if(&store_ret.result)) return ReadModifyWriteResultExt::from_error(format_error(*err)); + if (std::holds_alternative(store_ret.result)) + return ReadModifyWriteResultExt::from_invalid(); - const auto* is_co_max = std::get_if(&store_ret); - // FIXME(genmc): handle `HandleResult::{Invalid, Reset}` return values. + const auto* is_co_max = std::get_if(&store_ret.result); + // FIXME(genmc): handle `HandleResult::Reset` return values. ERROR_ON(!is_co_max, "Unimplemented: RMW store returned unexpected result."); return ReadModifyWriteResultExt::ok( /* old_value: */ read_old_val, @@ -222,20 +282,28 @@ void MiriGenmcShim::handle_fence(ThreadId thread_id, MemOrdering ord) { auto expectedVal = GenmcScalarExt::to_sval(expected_value); auto new_val = GenmcScalarExt::to_sval(new_value); - const auto load_ret = handle_load_reset_if_none( - thread_id, + const auto load_ret = GenMCDriver::handleCasRead( + nullptr, + curr_pos(thread_id), GenmcScalarExt::try_to_sval(old_val), success_ordering, SAddr(address), ASize(size), - AType::Unsigned, // The type is only used for printing. 
expectedVal, - new_val + new_val, + WriteAttr(), + nullptr, + std::nullopt, + EventDeps() ); - if (const auto* err = std::get_if(&load_ret)) + inc_pos(thread_id, load_ret.count); + if (const auto* err = std::get_if(&load_ret.result)) return CompareExchangeResultExt::from_error(format_error(*err)); - const auto* ret_val = std::get_if(&load_ret); - // FIXME(genmc): handle `HandleResult::{Invalid, Reset}` return values. + if (std::holds_alternative(load_ret.result)) + return CompareExchangeResultExt::from_invalid(); + + const auto* ret_val = std::get_if(&load_ret.result); + // FIXME(genmc): handle `HandleResult::Reset` return values. ERROR_ON(nullptr == ret_val, "Unimplemented: load returned unexpected result."); const auto read_old_val = *ret_val; if (read_old_val != expectedVal) @@ -243,21 +311,25 @@ void MiriGenmcShim::handle_fence(ThreadId thread_id, MemOrdering ord) { // FIXME(GenMC): Add support for modelling spurious failures. - const auto storePos = inc_pos(thread_id); - const auto store_ret = GenMCDriver::handleStore( + const auto store_ret = GenMCDriver::handleCasWrite( nullptr, - storePos, + curr_pos(thread_id), GenmcScalarExt::try_to_sval(old_val), success_ordering, SAddr(address), ASize(size), - AType::Unsigned, // The type is only used for printing. - new_val + new_val, + WriteAttr(), + EventDeps() ); - if (const auto* err = std::get_if(&store_ret)) + inc_pos(thread_id, store_ret.count); + if (const auto* err = std::get_if(&store_ret.result)) return CompareExchangeResultExt::from_error(format_error(*err)); - const auto* is_co_max = std::get_if(&store_ret); - // FIXME(genmc): handle `HandleResult::{Invalid, Reset}` return values. + if (std::holds_alternative(store_ret.result)) + return CompareExchangeResultExt::from_invalid(); + + const auto* is_co_max = std::get_if(&store_ret.result); + // FIXME(genmc): handle `HandleResult::Reset` return values. 
ERROR_ON(!is_co_max, "Unimplemented: compare-exchange store returned unexpected result."); return CompareExchangeResultExt::success(read_old_val, *is_co_max); } @@ -265,33 +337,45 @@ void MiriGenmcShim::handle_fence(ThreadId thread_id, MemOrdering ord) { /**** Memory (de)allocation ****/ auto MiriGenmcShim::handle_malloc(ThreadId thread_id, uint64_t size, uint64_t alignment) - -> uint64_t { - const auto pos = inc_pos(thread_id); - + -> MallocResult { // These are only used for printing and features Miri-GenMC doesn't support (yet). const auto storage_duration = StorageDuration::SD_Heap; // Volatile, as opposed to "persistent" (i.e., non-volatile memory that persists over reboots) const auto storage_type = StorageType::ST_Volatile; const auto address_space = AddressSpace::AS_User; - const SVal ret_val = GenMCDriver::handleMalloc( + const auto ret = GenMCDriver::handleMalloc( nullptr, - pos, + curr_pos(thread_id), size, alignment, storage_duration, storage_type, address_space, + nullptr, + "", EventDeps() ); - return ret_val.get(); + inc_pos(thread_id, ret.count); + if (const auto* err = std::get_if(&ret.result)) + return MallocResultExt::from_error(format_error(*err)); + const auto* addr = std::get_if(&ret.result); + ERROR_ON(!addr, "Unimplemented: malloc returned unexpected result."); + return MallocResultExt::ok(*addr); } auto MiriGenmcShim::handle_free(ThreadId thread_id, uint64_t address) -> std::unique_ptr { - auto pos = inc_pos(thread_id); - auto ret = GenMCDriver::handleFree(nullptr, pos, SAddr(address), EventDeps()); - return ret.has_value() ? format_error(*ret) : nullptr; + auto ret = GenMCDriver::handleFree(nullptr, curr_pos(thread_id), SAddr(address), EventDeps()); + inc_pos(thread_id, ret.count); + if (const auto* err = std::get_if(&ret.result)) + return format_error(*err); + + ERROR_ON( + !std::holds_alternative(ret.result), + "Unimplemented: free returned unexpected result." 
+ ); + return nullptr; } /**** Estimation mode result ****/ @@ -325,12 +409,12 @@ auto MiriGenmcShim::handle_mutex_lock(ThreadId thread_id, uint64_t address, uint const auto annot = std::move(Annotation( AssumeType::Spinloop, Annotation::ExprVP( - NeExpr::create( + NeExpr::create( // `RegisterExpr` marks the value of the current expression, i.e., the loaded value. // The `id` is ignored by GenMC; it is only used by the LLI frontend to substitute // other variables from previous expressions that may be used here. - RegisterExpr::create(size_bits, /* id */ 0), - ConcreteExpr::create(size_bits, MutexState::LOCKED) + RegisterExpr::create(size_bits, /* id */ 0), + ConcreteExpr::create(size_bits, MutexState::LOCKED) ) .release() ) @@ -340,26 +424,34 @@ auto MiriGenmcShim::handle_mutex_lock(ThreadId thread_id, uint64_t address, uint // access, if there previously was a non-atomic initializing access. We set the initial state of // a mutex to be "unlocked". const auto old_val = MutexState::UNLOCKED; - const auto load_ret = handle_load_reset_if_none( - thread_id, + const auto load_ret = GenMCDriver::handleLockCasRead( + nullptr, + curr_pos(thread_id), old_val, address, size, annot, EventDeps() ); - if (const auto* err = std::get_if(&load_ret)) + inc_pos(thread_id, load_ret.count); + if (const auto* err = std::get_if(&load_ret.result)) return MutexLockResultExt::from_error(format_error(*err)); + if (std::holds_alternative(load_ret.result)) + return MutexLockResultExt::from_invalid(); // If we get a `Reset`, GenMC decided that this lock operation should not yet run, since it // would not acquire the mutex. Like the handling of the case further down where we read a `1` // ("Mutex already locked"), Miri should call the handle function again once the current thread // is scheduled by GenMC the next time. 
- if (std::holds_alternative(load_ret)) + if (std::holds_alternative(load_ret.result)) return MutexLockResultExt::reset(); - const auto* ret_val = std::get_if(&load_ret); + const auto* ret_val = std::get_if(&load_ret.result); ERROR_ON(!ret_val, "Unimplemented: mutex lock returned unexpected result."); - ERROR_ON(!MutexState::isValid(*ret_val), "Mutex read value was neither 0 nor 1"); + ERROR_ON( + !MutexState::isValid(*ret_val), + "Mutex read value was neither 0 nor 1 ({})", + std::to_string(ret_val->get()) + ); if (*ret_val == MutexState::LOCKED) { // We did not acquire the mutex, so we tell GenMC to block the thread until we can acquire // it. GenMC determines this based on the annotation we pass with the load further up in @@ -368,69 +460,72 @@ auto MiriGenmcShim::handle_mutex_lock(ThreadId thread_id, uint64_t address, uint return MutexLockResultExt::ok(false); } - const auto store_ret = GenMCDriver::handleStore( - nullptr, - inc_pos(thread_id), - old_val, - address, - size, - EventDeps() - ); - if (const auto* err = std::get_if(&store_ret)) + const auto store_ret = + GenMCDriver::handleLockCasWrite(nullptr, curr_pos(thread_id), address, size, EventDeps()); + inc_pos(thread_id, store_ret.count); + if (const auto* err = std::get_if(&store_ret.result)) return MutexLockResultExt::from_error(format_error(*err)); + if (std::holds_alternative(store_ret.result)) + return MutexLockResultExt::from_invalid(); // We don't update Miri's memory for this operation so we don't need to know if the store // was the co-maximal store, but we still check that we at least get a boolean as the result // of the store. 
- const auto* is_co_max = std::get_if(&store_ret); + const auto* is_co_max = std::get_if(&store_ret.result); ERROR_ON(!is_co_max, "Unimplemented: mutex_try_lock store returned unexpected result."); return MutexLockResultExt::ok(true); } auto MiriGenmcShim::handle_mutex_try_lock(ThreadId thread_id, uint64_t address, uint64_t size) -> MutexLockResult { - auto& currPos = threads_action_[thread_id].event; // As usual, we need to tell GenMC which value was stored at this location before this atomic // access, if there previously was a non-atomic initializing access. We set the initial state of // a mutex to be "unlocked". const auto old_val = MutexState::UNLOCKED; - const auto load_ret = GenMCDriver::handleLoad( + const auto load_ret = GenMCDriver::handleTrylockCasRead( nullptr, - ++currPos, + curr_pos(thread_id), old_val, SAddr(address), - ASize(size) + ASize(size), + std::nullopt, + EventDeps() ); - if (const auto* err = std::get_if(&load_ret)) + inc_pos(thread_id, load_ret.count); + if (const auto* err = std::get_if(&load_ret.result)) return MutexLockResultExt::from_error(format_error(*err)); - const auto* ret_val = std::get_if(&load_ret); + if (std::holds_alternative(load_ret.result)) + return MutexLockResultExt::from_invalid(); + const auto* ret_val = std::get_if(&load_ret.result); ERROR_ON(!ret_val, "Unimplemented: mutex trylock load returned unexpected result."); ERROR_ON(!MutexState::isValid(*ret_val), "Mutex read value was neither 0 nor 1"); if (*ret_val == MutexState::LOCKED) return MutexLockResultExt::ok(false); /* Lock already held. 
*/ - const auto store_ret = GenMCDriver::handleStore( + const auto store_ret = GenMCDriver::handleTrylockCasWrite( nullptr, - ++currPos, - old_val, + curr_pos(thread_id), SAddr(address), - ASize(size) + ASize(size), + EventDeps() ); - if (const auto* err = std::get_if(&store_ret)) + inc_pos(thread_id, store_ret.count); + if (const auto* err = std::get_if(&store_ret.result)) return MutexLockResultExt::from_error(format_error(*err)); + if (std::holds_alternative(store_ret.result)) + return MutexLockResultExt::from_invalid(); // We don't update Miri's memory for this operation so we don't need to know if the store was // co-maximal, but we still check that we get a boolean result. - const auto* is_co_max = std::get_if(&store_ret); + const auto* is_co_max = std::get_if(&store_ret.result); ERROR_ON(!is_co_max, "Unimplemented: store part of mutex try_lock returned unexpected result."); return MutexLockResultExt::ok(true); } auto MiriGenmcShim::handle_mutex_unlock(ThreadId thread_id, uint64_t address, uint64_t size) -> StoreResult { - const auto pos = inc_pos(thread_id); - const auto ret = GenMCDriver::handleStore( + const auto ret = GenMCDriver::handleUnlockWrite( nullptr, - pos, + curr_pos(thread_id), // As usual, we need to tell GenMC which value was stored at this location before this // atomic access, if there previously was a non-atomic initializing access. We set the // initial state of a mutex to be "unlocked". 
@@ -438,13 +533,16 @@ auto MiriGenmcShim::handle_mutex_unlock(ThreadId thread_id, uint64_t address, ui MemOrdering::Release, SAddr(address), ASize(size), - AType::Signed, /* store_value */ MutexState::UNLOCKED, + WriteAttr(), EventDeps() ); - if (const auto* err = std::get_if(&ret)) + inc_pos(thread_id, ret.count); + if (const auto* err = std::get_if(&ret.result)) return StoreResultExt::from_error(format_error(*err)); - const auto* is_co_max = std::get_if(&ret); + if (std::holds_alternative(ret.result)) + return StoreResultExt::from_invalid(); + const auto* is_co_max = std::get_if(&ret.result); ERROR_ON(!is_co_max, "Unimplemented: store part of mutex unlock returned unexpected result."); return StoreResultExt::ok(*is_co_max); } @@ -452,40 +550,62 @@ auto MiriGenmcShim::handle_mutex_unlock(ThreadId thread_id, uint64_t address, ui /** Thread creation/joining */ void MiriGenmcShim::handle_thread_create(ThreadId thread_id, ThreadId parent_id) { - // NOTE: The threadCreate event happens in the parent: - const auto pos = inc_pos(parent_id); // FIXME(genmc): for supporting symmetry reduction, these will need to be properly set: const unsigned fun_id = 0; const SVal arg = SVal(0); const ThreadInfo child_info = ThreadInfo { thread_id, parent_id, fun_id, arg, "unknown thread" }; - const auto child_tid = GenMCDriver::handleThreadCreate(nullptr, pos, child_info, EventDeps()); + // NOTE: The threadCreate event happens in the parent: + const auto ret = + GenMCDriver::handleThreadCreate(nullptr, curr_pos(parent_id), child_info, EventDeps()); + inc_pos(parent_id, ret.count); + ERROR_ON( + !std::holds_alternative(ret.result), + "Unimplemented: unexpected return value for thread create" + ); + auto child_tid = std::get(ret.result); + // Sanity check the thread id, which is the index in the `threads_action_` array. 
- BUG_ON(child_tid != thread_id || child_tid <= 0 || child_tid != threads_action_.size()); + VERIFY(child_tid == thread_id && child_tid > 0 && child_tid == threads_action_.size()); threads_action_.push_back(Action(ActionKind::Load, Event(child_tid, 0))); } void MiriGenmcShim::handle_thread_join(ThreadId thread_id, ThreadId child_id) { // The thread join event happens in the parent. - const auto pos = inc_pos(thread_id); - - const auto ret = GenMCDriver::handleThreadJoin(nullptr, pos, child_id, EventDeps()); - // If the join failed, decrease the event index again: - if (!std::holds_alternative(ret)) { - dec_pos(thread_id); - } - // FIXME(genmc): handle `HandleResult::{Invalid, Reset, VerificationError}` return values. + const auto ret = + GenMCDriver::handleThreadJoin(nullptr, curr_pos(thread_id), child_id, EventDeps()); + inc_pos(thread_id, ret.count); + // FIXME(genmc): handle `HandleResult::{Invalid, VerificationError}` return values. + ERROR_ON( + !std::holds_alternative(ret.result) && !std::holds_alternative(ret.result), + "Unimplemented: unexpected return value for thread join" + ); + // FIXME(genmc): Here Reset{} is silently accepted. Double-check why that is. + // The reason is likely that, although GenMC wants to re-run the join instruction, + // when GenMC deems that the join has executed, it will also deem it successful, + // i.e., the return value is guaranteed to be 0 (or at least we assume that). + // In this case, it doesn't matter that we don't re-run the instruction, since + // Miri sets the correct return value, and GenMC will only schedule this thread + // when it knows the child has terminated. // NOTE: Thread return value is ignored, since Miri doesn't need it. 
} void MiriGenmcShim::handle_thread_finish(ThreadId thread_id, uint64_t ret_val) { - const auto pos = inc_pos(thread_id); - GenMCDriver::handleThreadFinish(nullptr, pos, SVal(ret_val)); + auto ret = GenMCDriver::handleThreadFinish(nullptr, curr_pos(thread_id), SVal(ret_val)); + inc_pos(thread_id, ret.count); + ERROR_ON( + !std::holds_alternative(ret.result), + "Unimplemented: unexpected return value for thread finish" + ); } void MiriGenmcShim::handle_thread_kill(ThreadId thread_id) { - const auto pos = inc_pos(thread_id); - GenMCDriver::handleThreadKill(nullptr, pos); + auto ret = GenMCDriver::handleThreadKill(nullptr, curr_pos(thread_id)); + inc_pos(thread_id, ret.count); + ERROR_ON( + !std::holds_alternative(ret.result), + "Unimplemented: unexpected return value for thread kill" + ); } diff --git a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp index 20c827221a92a..a0c7ca6cd9fa1 100644 --- a/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp +++ b/src/tools/miri/genmc-sys/cpp/src/MiriInterface/Setup.cpp @@ -7,9 +7,8 @@ #include "genmc-sys/src/lib.rs.h" // GenMC headers: -#include "Support/Error.hpp" -#include "Support/Verbosity.hpp" -#include "Verification/InterpreterCallbacks.hpp" +#include "genmc/Support/Error.hpp" +#include "genmc/Support/Verbosity.hpp" // C++ headers: #include @@ -113,6 +112,13 @@ static auto to_genmc_verbosity_level(const LogLevel log_level) -> VerbosityLevel // value written by the skipped thread. conf->replayCompletedThreads = true; + // Initialization checking is done by Miri; GenMC's checks are incorrect for Rust. + conf->disableInitializationChecks = true; + + // Don't check static-address validity as it's incompatible with Miri's + // dynamic discovery of static variables. + conf->disableStaticValidityChecks = true; + // FIXME(genmc): implement symmetry reduction. 
ERROR_ON( params.do_symmetry_reduction, @@ -160,45 +166,5 @@ static auto to_genmc_verbosity_level(const LogLevel log_level) -> VerbosityLevel // Create the actual driver and Miri-GenMC communication shim. auto driver = std::make_unique(std::move(conf), mode); - - // FIXME(genmc,HACK): Until a proper solution is implemented in GenMC, these callbacks will - // allow Miri to return information about global allocations and override uninitialized memory - // checks for non-atomic loads (Miri handles those without GenMC, so the error would be wrong). - auto interpreter_callbacks = InterpreterCallbacks { - // Miri already ensures that memory accesses are valid, so this check doesn't matter. - // We check that the address is static, but skip checking if it is part of an actual - // allocation. - .isStaticallyAllocated = [](SAddr addr) { return addr.isStatic(); }, - // FIXME(genmc,error reporting): Once a proper a proper API for passing such information is - // implemented in GenMC, Miri should use it to improve the produced error messages. - .getStaticName = [](SAddr addr) { return "[UNKNOWN STATIC]"; }, - // This function is called to get the initial value stored at the given address. - // - // From a Miri perspective, this API doesn't work very well: most memory starts out - // "uninitialized"; - // only statics have an initial value. And their initial value is just a sequence of bytes, - // but GenMC expect this to be already split into separate atomic variables. So we return a - // dummy value. - // This value should never be visible to the interpreted program. - // GenMC does not understand uninitialized memory the same way Miri does, which may cause - // this function to be called. The returned value can be visible to Miri or the user: - // - Printing the execution graph may contain this value in place of uninitialized values. - // FIXME(genmc): NOTE: printing the execution graph is not yet implemented. 
- // - Non-atomic loads may return this value, but Miri ignores values of non-atomic loads. - // - Atomic loads will *not* see this value once mixed atomic-non-atomic support is added. - // Currently, atomic loads can see this value, unless initialized by an *atomic* store. - // FIXME(genmc): update this comment once mixed atomic-non-atomic support is added. - // - // FIXME(genmc): implement proper support for uninitialized memory in GenMC. - // Ideally, the initial value getter would return an `optional`, since the memory - // location may be uninitialized. - .initValGetter = [](const AAccess& a) { return SVal(0xDEAD); }, - // Miri serves non-atomic loads from its own memory and these GenMC checks are wrong in that - // case. This should no longer be required with proper mixed-size access support. - .skipUninitLoadChecks = [](const MemAccessLabel* access_label - ) { return access_label->getOrdering() == MemOrdering::NotAtomic; }, - }; - driver->setInterpCallbacks(std::move(interpreter_callbacks)); - return driver; } diff --git a/src/tools/miri/genmc-sys/src/lib.rs b/src/tools/miri/genmc-sys/src/lib.rs index 26de80f295d31..84250007ed684 100644 --- a/src/tools/miri/genmc-sys/src/lib.rs +++ b/src/tools/miri/genmc-sys/src/lib.rs @@ -140,15 +140,15 @@ mod ffi { /// Log errors, warnings and tips. Tip, /// Debug print considered revisits. - /// Downgraded to `Tip` if `GENMC_DEBUG` is not enabled. + /// Downgraded to `Tip` if `ENABLE_GENMC_DEBUG` is not enabled. Debug1Revisits, /// Print the execution graph after every memory access. /// Also includes the previous debug log level. - /// Downgraded to `Tip` if `GENMC_DEBUG` is not enabled. + /// Downgraded to `Tip` if `ENABLE_GENMC_DEBUG` is not enabled. Debug2MemoryAccesses, /// Print reads-from values considered by GenMC. /// Also includes the previous debug log level. - /// Downgraded to `Tip` if `GENMC_DEBUG` is not enabled. + /// Downgraded to `Tip` if `ENABLE_GENMC_DEBUG` is not enabled. 
Debug3ReadsFrom, } @@ -182,7 +182,7 @@ mod ffi { #[must_use] #[derive(Debug, Clone, Copy)] - enum ExecutionState { + enum ExecutionStatus { Ok, Error, Blocked, @@ -192,7 +192,7 @@ mod ffi { #[must_use] #[derive(Debug)] struct SchedulingResult { - exec_state: ExecutionState, + exec_status: ExecutionStatus, next_thread: i32, } @@ -212,10 +212,10 @@ mod ffi { #[must_use] #[derive(Debug)] struct LoadResult { + /// If `true`, exploration should be dropped, **and all other fields are invalid**. + invalid: bool, /// If not null, contains the error encountered during the handling of the load. error: UniquePtr, - /// Indicates whether a value was read or not. - has_value: bool, /// The value that was read. Should not be used if `has_value` is `false`. read_value: GenmcScalar, } @@ -223,6 +223,8 @@ mod ffi { #[must_use] #[derive(Debug)] struct StoreResult { + /// If `true`, exploration should be dropped, **and all other fields are invalid**. + invalid: bool, /// If not null, contains the error encountered during the handling of the store. error: UniquePtr, /// `true` if the write should also be reflected in Miri's memory representation. @@ -232,6 +234,8 @@ mod ffi { #[must_use] #[derive(Debug)] struct ReadModifyWriteResult { + /// If `true`, exploration should be dropped, **and all other fields are invalid**. + invalid: bool, /// If there was an error, it will be stored in `error`, otherwise it is `None`. error: UniquePtr, /// The value that was read by the RMW operation as the left operand. @@ -245,6 +249,8 @@ mod ffi { #[must_use] #[derive(Debug)] struct CompareExchangeResult { + /// If `true`, exploration should be dropped, **and all other fields are invalid**. + invalid: bool, /// If there was an error, it will be stored in `error`, otherwise it is `None`. error: UniquePtr, /// The value that was read by the compare-exchange. 
@@ -258,6 +264,8 @@ mod ffi { #[must_use] #[derive(Debug)] struct MutexLockResult { + /// If `true`, exploration should be dropped, **and all other fields are invalid**. + invalid: bool, /// If there was an error, it will be stored in `error`, otherwise it is `None`. error: UniquePtr, /// If true, GenMC determined that we should retry the mutex lock operation once the thread attempting to lock is scheduled again. @@ -266,6 +274,15 @@ mod ffi { is_lock_acquired: bool, } + #[must_use] + #[derive(Debug)] + struct MallocResult { + /// If not null, contains the error encountered during the handling of malloc. + error: UniquePtr, + /// The allocated address. + address: u64, + } + /**** These are GenMC types that we have to copy-paste here since cxx does not support "importing" externally defined C++ types. ****/ @@ -385,7 +402,7 @@ mod ffi { /***** Functions for handling events encountered during program execution. *****/ /**** Memory access handling ****/ - fn handle_load( + fn handle_atomic_load( self: Pin<&mut MiriGenmcShim>, thread_id: i32, address: u64, @@ -393,6 +410,12 @@ mod ffi { memory_ordering: MemOrdering, old_value: GenmcScalar, ) -> LoadResult; + fn handle_non_atomic_load( + self: Pin<&mut MiriGenmcShim>, + thread_id: i32, + address: u64, + size: u64, + ) -> LoadResult; fn handle_read_modify_write( self: Pin<&mut MiriGenmcShim>, thread_id: i32, @@ -415,7 +438,7 @@ mod ffi { fail_load_ordering: MemOrdering, can_fail_spuriously: bool, ) -> CompareExchangeResult; - fn handle_store( + fn handle_atomic_store( self: Pin<&mut MiriGenmcShim>, thread_id: i32, address: u64, @@ -424,6 +447,12 @@ mod ffi { old_value: GenmcScalar, memory_ordering: MemOrdering, ) -> StoreResult; + fn handle_non_atomic_store( + self: Pin<&mut MiriGenmcShim>, + thread_id: i32, + address: u64, + size: u64, + ) -> StoreResult; fn handle_fence( self: Pin<&mut MiriGenmcShim>, thread_id: i32, @@ -436,7 +465,7 @@ mod ffi { thread_id: i32, size: u64, alignment: u64, - ) -> u64; + ) -> 
MallocResult; /// Returns true if an error was found. fn handle_free( self: Pin<&mut MiriGenmcShim>, diff --git a/src/tools/miri/rust-version b/src/tools/miri/rust-version index 38f153f78d026..e9fc6c4cd023e 100644 --- a/src/tools/miri/rust-version +++ b/src/tools/miri/rust-version @@ -1 +1 @@ -4c4205163abcbd08948b3efab796c543ba1ea687 +e22c616e4e87914135c1db261a03e0437255335e diff --git a/src/tools/miri/src/alloc_addresses/mod.rs b/src/tools/miri/src/alloc_addresses/mod.rs index 579c9e1165d43..bcaa97c4a5c5a 100644 --- a/src/tools/miri/src/alloc_addresses/mod.rs +++ b/src/tools/miri/src/alloc_addresses/mod.rs @@ -170,7 +170,9 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> { { let fn_sig = this.tcx.instantiate_bound_regions_with_erased( this.tcx - .fn_sig(instance.def_id()).instantiate(*this.tcx, instance.args).skip_norm_wip(), + .fn_sig(instance.def_id()) + .instantiate(*this.tcx, instance.args) + .skip_norm_wip(), ); let fn_ptr = crate::shims::native_lib::build_libffi_closure(this, fn_sig)?; diff --git a/src/tools/miri/src/concurrency/genmc/helper.rs b/src/tools/miri/src/concurrency/genmc/helper.rs index f539e783fd3f5..e2b4c261c4fec 100644 --- a/src/tools/miri/src/concurrency/genmc/helper.rs +++ b/src/tools/miri/src/concurrency/genmc/helper.rs @@ -4,7 +4,6 @@ use rustc_const_eval::interpret::{InterpResult, interp_ok}; use rustc_middle::mir; use rustc_middle::mir::interpret; use rustc_middle::ty::ScalarInt; -use tracing::debug; use super::GenmcScalar; use crate::alloc_addresses::EvalContextExt as _; @@ -13,33 +12,6 @@ use crate::*; /// Maximum size memory access in bytes that GenMC supports. pub(super) const MAX_ACCESS_SIZE: u64 = 8; - -/// This function is used to split up a large memory access into aligned, non-overlapping chunks of a limited size. -/// Returns an iterator over the chunks, yielding `(base address, size)` of each chunk, ordered by address. 
-pub fn split_access(address: Size, size: Size) -> impl Iterator { - let start_address = address.bytes(); - let end_address = start_address + size.bytes(); - - let start_address_aligned = start_address.next_multiple_of(MAX_ACCESS_SIZE); - let end_address_aligned = (end_address / MAX_ACCESS_SIZE) * MAX_ACCESS_SIZE; // prev_multiple_of - - debug!( - "GenMC: splitting NA memory access into {MAX_ACCESS_SIZE} byte chunks: {}B + {} * {MAX_ACCESS_SIZE}B + {}B = {size:?}", - start_address_aligned - start_address, - (end_address_aligned - start_address_aligned) / MAX_ACCESS_SIZE, - end_address - end_address_aligned, - ); - - // FIXME(genmc): could make remaining accesses powers-of-2, instead of 1 byte. - let start_chunks = (start_address..start_address_aligned).map(|address| (address, 1)); - let aligned_chunks = (start_address_aligned..end_address_aligned) - .step_by(MAX_ACCESS_SIZE.try_into().unwrap()) - .map(|address| (address, MAX_ACCESS_SIZE)); - let end_chunks = (end_address_aligned..end_address).map(|address| (address, 1)); - - start_chunks.chain(aligned_chunks).chain(end_chunks) -} - /// Inverse function to `scalar_to_genmc_scalar`. /// /// Convert a Miri `Scalar` to a `GenmcScalar`. 
diff --git a/src/tools/miri/src/concurrency/genmc/mod.rs b/src/tools/miri/src/concurrency/genmc/mod.rs index 092fc7294d15d..b2a4b2dd85383 100644 --- a/src/tools/miri/src/concurrency/genmc/mod.rs +++ b/src/tools/miri/src/concurrency/genmc/mod.rs @@ -19,7 +19,6 @@ use self::helper::{ }; use self::run::GenmcMode; use self::thread_id_map::ThreadIdMap; -use crate::concurrency::genmc::helper::split_access; use crate::diagnostics::SpanDedupDiagnostic; use crate::intrinsics::AtomicRmwOp; use crate::*; @@ -267,8 +266,13 @@ impl GenmcCtx { } else { GenmcScalar::UNINIT }; - let read_value = - self.handle_load(&ecx.machine, address, size, ordering.to_genmc(), genmc_old_value)?; + let read_value = self.handle_atomic_load( + &ecx.machine, + address, + size, + ordering.to_genmc(), + genmc_old_value, + )?; genmc_scalar_to_scalar(ecx, self, read_value, size) } @@ -292,7 +296,7 @@ impl GenmcCtx { } else { GenmcScalar::UNINIT }; - self.handle_store( + self.handle_atomic_store( &ecx.machine, address, size, @@ -447,6 +451,9 @@ impl GenmcCtx { can_fail_spuriously, ); + if cas_result.invalid { + throw_machine_stop!(TerminationInfo::GenmcSkip); + } if let Some(error) = cas_result.error.as_ref() { // FIXME(genmc): error handling throw_ub_format!("{}", error.to_string_lossy()); @@ -488,32 +495,7 @@ impl GenmcCtx { return interp_ok(()); } - let handle_load = |address, size| { - // NOTE: Values loaded non-atomically are still handled by Miri, so we discard whatever we get from GenMC - let _read_value = self.handle_load( - machine, - address, - size, - MemOrdering::NotAtomic, - // This value is used to update the co-maximal store event to the same location. - // We don't need to update that store, since if it is ever read by any atomic loads, the value will be updated then. - // We use uninit for lack of a better value, since we don't know whether the location we currently load from is initialized or not. 
- GenmcScalar::UNINIT, - )?; - interp_ok(()) - }; - - // This load is small enough so GenMC can handle it. - if size.bytes() <= MAX_ACCESS_SIZE { - return handle_load(address, size); - } - - // This load is too big to be a single GenMC access, we have to split it. - // FIXME(genmc): This will misbehave if there are non-64bit-atomics in there. - // Needs proper support on the GenMC side for large and mixed atomic accesses. - for (address, size) in split_access(address, size) { - handle_load(Size::from_bytes(address), Size::from_bytes(size))?; - } + self.handle_non_atomic_load(machine, address, size)?; interp_ok(()) } @@ -540,40 +522,7 @@ impl GenmcCtx { return interp_ok(()); } - let handle_store = |address, size| { - // We always write the the stored values to Miri's memory, whether GenMC says the write is co-maximal or not. - // The GenMC scheduler ensures that replaying an execution happens in porf-respecting order (po := program order, rf: reads-from order). - // This means that for any non-atomic read Miri performs, the corresponding write has already been replayed. - let _is_co_max_write = self.handle_store( - machine, - address, - size, - // We don't know the value that this store will write, but GenMC expects that we give it an actual value. - // Unfortunately, there are situations where this value can actually become visible - // to the program: when there is an atomic load reading from a non-atomic store. - // FIXME(genmc): update once mixed atomic-non-atomic support is added. Afterwards, this value should never be readable. - GenmcScalar::from_u64(0xDEADBEEF), - // This value is used to update the co-maximal store event to the same location. - // This old value cannot be read anymore by any future loads, since we are doing another non-atomic store to the same location. - // Any future load will either see the store we are adding now, or we have a data race (there can only be one possible non-atomic value to read from at any time). 
- // We use uninit for lack of a better value, since we don't know whether the location we currently write to is initialized or not. - GenmcScalar::UNINIT, - MemOrdering::NotAtomic, - )?; - interp_ok(()) - }; - - // This store is small enough so GenMC can handle it. - if size.bytes() <= MAX_ACCESS_SIZE { - return handle_store(address, size); - } - - // This store is too big to be a single GenMC access, we have to split it. - // FIXME(genmc): This will misbehave if there are non-64bit-atomics in there. - // Needs proper support on the GenMC side for large and mixed atomic accesses. - for (address, size) in split_access(address, size) { - handle_store(Size::from_bytes(address), Size::from_bytes(size))?; - } + self.handle_non_atomic_store(machine, address, size)?; interp_ok(()) } @@ -599,14 +548,15 @@ impl GenmcCtx { } // GenMC doesn't support ZSTs, so we set the minimum size to 1 byte let genmc_size = size.bytes().max(1); - let chosen_address = self.handle.borrow_mut().pin_mut().handle_malloc( + let malloc_result = self.handle.borrow_mut().pin_mut().handle_malloc( self.active_thread_genmc_tid(machine), genmc_size, alignment.bytes(), ); - if chosen_address == 0 { + if let Some(_error) = malloc_result.error.as_ref() { throw_exhaust!(AddressSpaceFull); } + let chosen_address = malloc_result.address; // Non-global addresses should not be in the global address space. assert_eq!(0, chosen_address & GENMC_GLOBAL_ADDRESSES_MASK); @@ -735,9 +685,9 @@ impl GenmcCtx { } impl GenmcCtx { - /// Inform GenMC about a load (atomic or non-atomic). + /// Inform GenMC about an atomic load. /// Returns the value that GenMC wants this load to read. 
- fn handle_load<'tcx>( + fn handle_atomic_load<'tcx>( &self, machine: &MiriMachine<'tcx>, address: Size, @@ -758,7 +708,7 @@ impl GenmcCtx { "GenMC: load, address: {addr} == {addr:#x}, size: {size:?}, ordering: {memory_ordering:?}, old_value: {genmc_old_value:x?}", addr = address.bytes() ); - let load_result = self.handle.borrow_mut().pin_mut().handle_load( + let load_result = self.handle.borrow_mut().pin_mut().handle_atomic_load( self.active_thread_genmc_tid(machine), address.bytes(), size.bytes(), @@ -766,23 +716,51 @@ impl GenmcCtx { genmc_old_value, ); + if load_result.invalid { + throw_machine_stop!(TerminationInfo::GenmcSkip); + } if let Some(error) = load_result.error.as_ref() { // FIXME(genmc): error handling throw_ub_format!("{}", error.to_string_lossy()); } - if !load_result.has_value { - // FIXME(GenMC): Implementing certain GenMC optimizations will lead to this. - unimplemented!("GenMC: load returned no value."); - } - debug!("GenMC: load returned value: {:?}", load_result.read_value); interp_ok(load_result.read_value) } - /// Inform GenMC about a store (atomic or non-atomic). + /// Inform GenMC about a non-atomic load. + fn handle_non_atomic_load<'tcx>( + &self, + machine: &MiriMachine<'tcx>, + address: Size, + size: Size, + ) -> InterpResult<'tcx> { + assert!(size.bytes() != 0); + debug!( + "GenMC: NA load, address: {addr} == {addr:#x}, size: {size:?}", + addr = address.bytes() + ); + let load_result = self.handle.borrow_mut().pin_mut().handle_non_atomic_load( + self.active_thread_genmc_tid(machine), + address.bytes(), + size.bytes(), + ); + + if load_result.invalid { + throw_machine_stop!(TerminationInfo::GenmcSkip); + } + if let Some(error) = load_result.error.as_ref() { + // FIXME(genmc): error handling + throw_ub_format!("{}", error.to_string_lossy()); + } + // `load_result.read_value` is just a dummy for non-atomic loads. And anyway Miri doesn't + // give us a chance to change the value here, it'll always use the one from its memory. 
+ interp_ok(()) + } + + /// Inform GenMC about an atomic store. /// Returns true if the store is co-maximal, i.e., it should be written to Miri's memory too. - fn handle_store<'tcx>( + fn handle_atomic_store<'tcx>( &self, machine: &MiriMachine<'tcx>, address: Size, @@ -804,7 +782,7 @@ impl GenmcCtx { "GenMC: store, address: {addr} = {addr:#x}, size: {size:?}, ordering {memory_ordering:?}, value: {genmc_value:?}", addr = address.bytes() ); - let store_result = self.handle.borrow_mut().pin_mut().handle_store( + let store_result = self.handle.borrow_mut().pin_mut().handle_atomic_store( self.active_thread_genmc_tid(machine), address.bytes(), size.bytes(), @@ -813,6 +791,9 @@ impl GenmcCtx { memory_ordering, ); + if store_result.invalid { + throw_machine_stop!(TerminationInfo::GenmcSkip); + } if let Some(error) = store_result.error.as_ref() { // FIXME(genmc): error handling throw_ub_format!("{}", error.to_string_lossy()); @@ -821,6 +802,36 @@ impl GenmcCtx { interp_ok(store_result.is_coherence_order_maximal_write) } + /// Inform GenMC about a non-atomic store. + fn handle_non_atomic_store<'tcx>( + &self, + machine: &MiriMachine<'tcx>, + address: Size, + size: Size, + ) -> InterpResult<'tcx> { + assert!(size.bytes() != 0); + debug!( + "GenMC: NA store, address: {addr} = {addr:#x}, size: {size:?}", + addr = address.bytes() + ); + let store_result = self.handle.borrow_mut().pin_mut().handle_non_atomic_store( + self.active_thread_genmc_tid(machine), + address.bytes(), + size.bytes(), + ); + + if store_result.invalid { + throw_machine_stop!(TerminationInfo::GenmcSkip); + } + if let Some(error) = store_result.error.as_ref() { + // FIXME(genmc): error handling + throw_ub_format!("{}", error.to_string_lossy()); + } + // Miri will always write non-atomic stores to memory. Make sure GenMC agrees with that. + assert!(store_result.is_coherence_order_maximal_write); + interp_ok(()) + } + /// Inform GenMC about an atomic read-modify-write operation. 
/// This includes atomic swap (also often called "exchange"), but does *not* /// include compare-exchange (see `RMWBinOp` for full list of operations). @@ -859,6 +870,9 @@ impl GenmcCtx { genmc_old_value, ); + if rmw_result.invalid { + throw_machine_stop!(TerminationInfo::GenmcSkip); + } if let Some(error) = rmw_result.error.as_ref() { // FIXME(genmc): error handling throw_ub_format!("{}", error.to_string_lossy()); diff --git a/src/tools/miri/src/concurrency/genmc/scheduling.rs b/src/tools/miri/src/concurrency/genmc/scheduling.rs index 54e87c05818de..f2e9d9204c92a 100644 --- a/src/tools/miri/src/concurrency/genmc/scheduling.rs +++ b/src/tools/miri/src/concurrency/genmc/scheduling.rs @@ -1,4 +1,4 @@ -use genmc_sys::{ActionKind, ExecutionState}; +use genmc_sys::{ActionKind, ExecutionStatus}; use rustc_data_structures::either::Either; use rustc_middle::mir::TerminatorKind; use rustc_middle::ty::{self, Ty}; @@ -117,9 +117,9 @@ impl GenmcCtx { let result = self.handle.borrow_mut().pin_mut().schedule_next(genmc_tid, atomic_kind); // Depending on the exec_state, we either schedule the given thread, or we are finished with this execution. - match result.exec_state { - ExecutionState::Ok => interp_ok(Some(thread_infos.get_miri_tid(result.next_thread))), - ExecutionState::Blocked => { + match result.exec_status { + ExecutionStatus::Ok => interp_ok(Some(thread_infos.get_miri_tid(result.next_thread))), + ExecutionStatus::Blocked => { // This execution doesn't need further exploration. We treat this as "success, no // leak check needed", which makes it a NOP in the big outer loop. 
throw_machine_stop!(TerminationInfo::Exit { @@ -127,7 +127,7 @@ impl GenmcCtx { leak_check: false, }); } - ExecutionState::Finished => { + ExecutionStatus::Finished => { let exit_status = self.exec_state.exit_status.get().expect( "If the execution is finished, we should have a return value from the program.", ); @@ -136,7 +136,7 @@ impl GenmcCtx { leak_check: matches!(exit_status.exit_type, super::ExitType::MainThreadFinish), }); } - ExecutionState::Error => { + ExecutionStatus::Error => { // GenMC found an error in one of the `handle_*` functions, but didn't return the detected error from the function immediately. // This is still an bug in the user program, so we print the error string. panic!( diff --git a/src/tools/miri/src/data_structures/mono_hash_map.rs b/src/tools/miri/src/data_structures/mono_hash_map.rs index 220233f8ff5f0..63edfdac9d605 100644 --- a/src/tools/miri/src/data_structures/mono_hash_map.rs +++ b/src/tools/miri/src/data_structures/mono_hash_map.rs @@ -96,10 +96,7 @@ impl AllocMap for MonoHashMap { /// Read-only lookup (avoid read-acquiring the RefCell). fn get(&self, k: K) -> Option<&V> { - let val: *const V = match self.0.borrow().get(&k) { - Some(v) => &**v, - None => return None, - }; + let val: *const V = &**self.0.borrow().get(&k)?; // This is safe because `val` points into a `Box`, that we know will not move and // will also not be dropped as long as the shared reference `self` is live. unsafe { Some(&*val) } diff --git a/src/tools/miri/src/diagnostics.rs b/src/tools/miri/src/diagnostics.rs index 9d93edcaa3445..14beafc6a34ba 100644 --- a/src/tools/miri/src/diagnostics.rs +++ b/src/tools/miri/src/diagnostics.rs @@ -18,7 +18,7 @@ pub enum TerminationInfo { leak_check: bool, }, Abort(String), - /// Miri was interrupted by a Ctrl+C from the user + /// Miri was interrupted by a Ctrl+C from the user. 
Interrupted, UnsupportedInIsolation(String), StackedBorrowsUb { @@ -32,6 +32,8 @@ pub enum TerminationInfo { history: tree_diagnostics::HistoryData, }, Int2PtrWithStrictProvenance, + /// GenMC determined that the execution should stop. + GenmcSkip, /// All threads are blocked. GlobalDeadlock, /// Some thread discovered a deadlock condition (e.g. in a mutex with reentrancy checking). @@ -81,6 +83,7 @@ impl fmt::Display for TerminationInfo { TreeBorrowsUb { title, .. } => write!(f, "{title}"), GlobalDeadlock => write!(f, "the evaluated program deadlocked"), LocalDeadlock => write!(f, "a thread deadlocked"), + GenmcSkip => write!(f, "GenMC wants to skip this execution"), MultipleSymbolDefinitions { link_name, .. } => write!(f, "multiple definitions of symbol `{link_name}`"), SymbolShimClashing { link_name, .. } => @@ -240,6 +243,10 @@ pub fn report_result<'tcx>( Some("unsupported operation"), StackedBorrowsUb { .. } | TreeBorrowsUb { .. } | DataRace { .. } => Some("Undefined Behavior"), + GenmcSkip => { + assert!(ecx.machine.data_race.as_genmc_ref().is_some()); + return Some((0, false)); + } LocalDeadlock => { labels.push(format!("thread got stuck here")); None diff --git a/src/tools/miri/src/provenance_gc.rs b/src/tools/miri/src/provenance_gc.rs index f2c750a7577f4..02353411eb944 100644 --- a/src/tools/miri/src/provenance_gc.rs +++ b/src/tools/miri/src/provenance_gc.rs @@ -21,6 +21,10 @@ macro_rules! 
no_provenance { } no_provenance!(i8 i16 i32 i64 isize u8 u16 u32 u64 usize bool ThreadId); +impl VisitProvenance for &'static str { + fn visit_provenance(&self, _visit: &mut VisitWith<'_>) {} +} + impl VisitProvenance for Option { fn visit_provenance(&self, visit: &mut VisitWith<'_>) { if let Some(x) = self { diff --git a/src/tools/miri/src/shims/unix/fd.rs b/src/tools/miri/src/shims/unix/fd.rs index 065f040cd3e1d..5e5f7d6bc3b84 100644 --- a/src/tools/miri/src/shims/unix/fd.rs +++ b/src/tools/miri/src/shims/unix/fd.rs @@ -62,6 +62,20 @@ pub trait UnixFileDescription: FileDescription { throw_unsup_format!("cannot flock {}", self.name()); } + /// Modifies device parameters. + /// `op` is the device-dependent operation code. It's either a `c_long` or `c_int`, depending on + /// the target and whether it uses glibc or musl. + /// `arg` is the optional third argument which exists depending on the operation code. It's either + /// an integer or a pointer. + fn ioctl<'tcx>( + &self, + _op: Scalar, + _arg: Option<&OpTy<'tcx>>, + _ecx: &mut MiriInterpCx<'tcx>, + ) -> InterpResult<'tcx, i32> { + throw_unsup_format!("cannot use ioctl on {}", self.name()); + } + /// Return which epoll events are currently active. fn epoll_active_events<'tcx>(&self) -> InterpResult<'tcx, EpollEvents> { throw_unsup_format!("{}: epoll does not support this file description", self.name()); @@ -129,6 +143,39 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { interp_ok(Scalar::from_i32(this.try_unwrap_io_result(result)?)) } + fn ioctl( + &mut self, + fd: &OpTy<'tcx>, + op: &OpTy<'tcx>, + varargs: &[OpTy<'tcx>], + ) -> InterpResult<'tcx, Scalar> { + let this = self.eval_context_mut(); + + let fd = this.read_scalar(fd)?.to_i32()?; + let op = this.read_scalar(op)?; + // There is at most one relevant variadic argument. + // It exists depending on the device and the opcode and thus we can't + // use `check_min_vararg_count` here. 
+ let arg = varargs.first(); + + let Some(fd) = this.machine.fds.get(fd) else { + return this.set_last_error_and_return_i32(LibcError("EBADF")); + }; + + // Handle common opcodes. + let fioclex = this.eval_libc("FIOCLEX"); + let fionclex = this.eval_libc("FIONCLEX"); + if op == fioclex || op == fionclex { + // Since we don't support `exec`, those are NOPs. + return interp_ok(Scalar::from_i32(0)); + } + + // Since some ioctl operations use the return value as an output parameter, we cannot strictly use the convention of + // zero indicating success and -1 indicating an error. + let return_value = fd.as_unix(this).ioctl(op, arg, this)?; + interp_ok(Scalar::from_i32(return_value)) + } + fn fcntl( &mut self, fd_num: &OpTy<'tcx>, diff --git a/src/tools/miri/src/shims/unix/foreign_items.rs b/src/tools/miri/src/shims/unix/foreign_items.rs index 2b366b699065d..ba12985e86fcc 100644 --- a/src/tools/miri/src/shims/unix/foreign_items.rs +++ b/src/tools/miri/src/shims/unix/foreign_items.rs @@ -307,6 +307,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let result = this.flock(fd, op)?; this.write_scalar(result, dest)?; } + "ioctl" => { + let ([fd, op], varargs) = + this.check_shim_sig_variadic_lenient(abi, CanonAbi::C, link_name, args)?; + let result = this.ioctl(fd, op, varargs)?; + this.write_scalar(result, dest)?; + } // File and file system access "open" => { @@ -658,8 +664,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { abi, args, )?; - let result = this.getpeername(socket, address, address_len)?; - this.write_scalar(result, dest)?; + this.getpeername(socket, address, address_len, dest)?; } // Time diff --git a/src/tools/miri/src/shims/unix/macos/foreign_items.rs b/src/tools/miri/src/shims/unix/macos/foreign_items.rs index c9e9c30ac2c7d..9ca487eac9ae9 100644 --- a/src/tools/miri/src/shims/unix/macos/foreign_items.rs +++ b/src/tools/miri/src/shims/unix/macos/foreign_items.rs @@ -80,12 +80,6 @@ pub trait EvalContextExt<'tcx>: 
crate::MiriInterpCxExt<'tcx> { let result = this.realpath(path, resolved_path)?; this.write_scalar(result, dest)?; } - "ioctl" => { - let ([fd_num, cmd], varargs) = - this.check_shim_sig_variadic_lenient(abi, CanonAbi::C, link_name, args)?; - let result = this.ioctl(fd_num, cmd, varargs)?; - this.write_scalar(result, dest)?; - } // Environment related shims "_NSGetEnviron" => { @@ -341,30 +335,4 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { interp_ok(EmulateItemResult::NeedsReturn) } - - fn ioctl( - &mut self, - fd_num: &OpTy<'tcx>, - cmd: &OpTy<'tcx>, - _varargs: &[OpTy<'tcx>], - ) -> InterpResult<'tcx, Scalar> { - let this = self.eval_context_mut(); - - let fioclex = this.eval_libc_u64("FIOCLEX"); - - let fd_num = this.read_scalar(fd_num)?.to_i32()?; - let cmd = this.read_scalar(cmd)?.to_u64()?; - - if cmd == fioclex { - // Since we don't support `exec`, this is a NOP. However, we want to - // return EBADF if the FD is invalid. - if this.machine.fds.is_fd_num(fd_num) { - interp_ok(Scalar::from_i32(0)) - } else { - this.set_last_error_and_return_i32(LibcError("EBADF")) - } - } else { - throw_unsup_format!("ioctl: unsupported command {cmd:#x}"); - } - } } diff --git a/src/tools/miri/src/shims/unix/socket.rs b/src/tools/miri/src/shims/unix/socket.rs index 41a510cfe9b85..9d7d5a32f127b 100644 --- a/src/tools/miri/src/shims/unix/socket.rs +++ b/src/tools/miri/src/shims/unix/socket.rs @@ -1,6 +1,7 @@ use std::cell::{Cell, RefCell}; use std::io::Read; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::time::Duration; use std::{io, iter}; use mio::Interest; @@ -12,8 +13,10 @@ use rustc_const_eval::interpret::{InterpResult, interp_ok}; use rustc_middle::throw_unsup_format; use rustc_target::spec::Os; +use crate::concurrency::blocking_io::InterestReceiver; use crate::shims::files::{EvalContextExt as _, FdId, FileDescription, FileDescriptionRef}; -use crate::{OpTy, Scalar, *}; +use 
crate::shims::unix::UnixFileDescription; +use crate::*; #[derive(Debug, PartialEq)] enum SocketFamily { @@ -23,22 +26,6 @@ enum SocketFamily { IPv6, } -enum SocketIoError { - /// The socket is not yet ready. Either EINPROGRESS or ENOTCONNECTED occurred. - NotReady, - /// Any other kind of I/O error. - Other(io::Error), -} - -impl From for SocketIoError { - fn from(value: io::Error) -> Self { - match value.kind() { - io::ErrorKind::InProgress | io::ErrorKind::NotConnected => Self::NotReady, - _ => Self::Other(value), - } - } -} - #[derive(Debug)] enum SocketState { /// No syscall after `socket` has been made. @@ -61,59 +48,6 @@ enum SocketState { Connected(TcpStream), } -impl SocketState { - /// If the socket is currently in [`SocketState::Connecting`], try to ensure - /// that the connection is established by first checking that [`TcpStream::take_error`] - /// doesn't return an error and then by checking that [`TcpStream::peer_addr`] - /// returns the address of the connected peer. - /// - /// If the connection is established or the socket is in any other state, - /// [`Ok`] is returned. - /// - /// **Important**: On Windows hosts this function can only be used to ensure a socket is connected - /// _after_ a [`Interest::WRITABLE`] event was received. - pub fn try_set_connected(&mut self) -> Result<(), SocketIoError> { - // Further explanation of the limitation on Windows hosts: - // Windows treats sockets which are connecting as connected until either the connection timeout hits - // or an error occurs. Thus, the [`TcpStream::peer_addr`] method returns [`Ok`] with the provided peer - // address even when the connection might not yet be established. - - let SocketState::Connecting(stream) = self else { return Ok(()) }; - - if let Ok(Some(e)) = stream.take_error() { - // There was an error whilst connecting. - let e = SocketIoError::from(e); - // We won't get EINPROGRESS or ENOTCONNECTED here - // so we need to reset the state. 
- assert!(matches!(e, SocketIoError::Other(_))); - // Go back to initial state as the only way of getting into the - // `Connecting` state is from the `Initial` state. - *self = SocketState::Initial; - return Err(e); - } - - if let Err(e) = stream.peer_addr() { - let e = SocketIoError::from(e); - if let SocketIoError::Other(_) = &e { - // All other errors are fatal for a socket and thus the state needs to be reset. - *self = SocketState::Initial; - } - return Err(e); - }; - - // We just read the peer address without an error so we can be - // sure that the connection is established. - - // Temporarily use dummy state to take ownership of the stream. - let SocketState::Connecting(stream) = std::mem::replace(self, SocketState::Initial) else { - // At the start of the function we ensured that we're currently connecting. - unreachable!() - }; - *self = SocketState::Connected(stream); - Ok(()) - } -} - #[derive(Debug)] struct Socket { /// Family of the socket, used to ensure socket only binds/connects to address of @@ -151,17 +85,40 @@ impl FileDescription for Socket { ) -> InterpResult<'tcx> { assert!(communicate_allowed, "cannot have `Socket` with isolation enabled!"); - if !matches!(&*self.state.borrow(), SocketState::Connected(_)) { - // We can only receive from connected sockets. For all other - // states we return a not connected error. - return finish.call(ecx, Err(LibcError("ENOTCONN"))); - } + let socket = self; + + ecx.ensure_connected( + socket.clone(), + !socket.is_non_block.get(), + "read", + callback!( + @capture<'tcx> { + socket: FileDescriptionRef, + ptr: Pointer, + len: usize, + finish: DynMachineCallback<'tcx, Result>, + } |this, result: Result<(), ()>| { + if result.is_err() { + return finish.call(this, Err(LibcError("ENOTCONN"))) + } - // Since `read` is the same as `recv` with no flags, we just treat - // the `read` as a `recv` here. 
- ecx.block_for_recv(self, ptr, len, /* should_peek */ false, finish); + // Since `read` is the same as `recv` with no flags, we just treat + // the `read` as a `recv` here. - interp_ok(()) + if socket.is_non_block.get() { + // We have a non-blocking socket and thus don't want to block until + // we can read. + let result = this.try_non_block_recv(&socket, ptr, len, /* should_peek */ false)?; + finish.call(this, result) + } else { + // The socket is in blocking mode and thus the read call should block + // until we can read some bytes from the socket. + this.block_for_recv(socket, ptr, len, /* should_peek */ false, finish); + interp_ok(()) + } + } + ), + ) } fn write<'tcx>( @@ -174,17 +131,40 @@ impl FileDescription for Socket { ) -> InterpResult<'tcx> { assert!(communicate_allowed, "cannot have `Socket` with isolation enabled!"); - if !matches!(&*self.state.borrow(), SocketState::Connected(_)) { - // We can only send with connected sockets. For all other - // states we return a not connected error. - return finish.call(ecx, Err(LibcError("ENOTCONN"))); - } + let socket = self; + + ecx.ensure_connected( + socket.clone(), + !socket.is_non_block.get(), + "write", + callback!( + @capture<'tcx> { + socket: FileDescriptionRef, + ptr: Pointer, + len: usize, + finish: DynMachineCallback<'tcx, Result> + } |this, result: Result<(), ()>| { + if result.is_err() { + return finish.call(this, Err(LibcError("ENOTCONN"))) + } - // Since `write` is the same as `send` with no flags, we just treat - // the `write` as a `send` here. - ecx.block_for_send(self, ptr, len, finish); + // Since `write` is the same as `send` with no flags, we just treat + // the `write` as a `send` here. - interp_ok(()) + if socket.is_non_block.get() { + // We have a non-blocking socket and thus don't want to block until + // we can write. 
+ let result = this.try_non_block_send(&socket, ptr, len)?; + return finish.call(this, result) + } else { + // The socket is in blocking mode and thus the write call should block + // until we can write some bytes into the socket. + this.block_for_send(socket, ptr, len, finish); + interp_ok(()) + } + } + ), + ) } fn short_fd_operations(&self) -> bool { @@ -192,6 +172,10 @@ impl FileDescription for Socket { true } + fn as_unix<'tcx>(&self, _ecx: &MiriInterpCx<'tcx>) -> &dyn UnixFileDescription { + self + } + fn get_flags<'tcx>(&self, ecx: &mut MiriInterpCx<'tcx>) -> InterpResult<'tcx, Scalar> { let mut flags = ecx.eval_libc_i32("O_RDWR"); @@ -204,10 +188,64 @@ impl FileDescription for Socket { fn set_flags<'tcx>( &self, - mut _flag: i32, - _ecx: &mut MiriInterpCx<'tcx>, + mut flag: i32, + ecx: &mut MiriInterpCx<'tcx>, ) -> InterpResult<'tcx, Scalar> { - throw_unsup_format!("fcntl: socket flags aren't supported") + let o_nonblock = ecx.eval_libc_i32("O_NONBLOCK"); + + // O_NONBLOCK flag can be set / unset by user. + if flag & o_nonblock == o_nonblock { + self.is_non_block.set(true); + flag &= !o_nonblock; + } else { + self.is_non_block.set(false); + } + + // Throw error if there is any unsupported flag. + if flag != 0 { + throw_unsup_format!("fcntl: only O_NONBLOCK is supported for sockets") + } + + interp_ok(Scalar::from_i32(0)) + } +} + +impl UnixFileDescription for Socket { + fn ioctl<'tcx>( + &self, + op: Scalar, + arg: Option<&OpTy<'tcx>>, + ecx: &mut MiriInterpCx<'tcx>, + ) -> InterpResult<'tcx, i32> { + assert!(ecx.machine.communicate(), "cannot have `Socket` with isolation enabled!"); + + let fionbio = ecx.eval_libc("FIONBIO"); + + if op == fionbio { + // On these OSes, Rust uses the ioctl, so we trust that it is reasonable and controls + // the same internal flag as fcntl. 
+ if !matches!(ecx.tcx.sess.target.os, Os::Linux | Os::Android | Os::MacOs | Os::FreeBsd) + { + // FIONBIO cannot be used to change the blocking mode of a socket on solarish targets: + // + // Since there might be more targets which do weird things with this option, we use + // an allowlist instead of just denying solarish targets. + throw_unsup_format!( + "ioctl: setting FIONBIO on sockets is unsupported on target {}", + ecx.tcx.sess.target.os + ); + } + + let Some(value_ptr) = arg else { + throw_ub_format!("ioctl: setting FIONBIO on sockets requires a third argument"); + }; + let value = ecx.deref_pointer_as(value_ptr, ecx.machine.layouts.i32)?; + let non_block = ecx.read_scalar(&value)?.to_i32()? != 0; + self.is_non_block.set(non_block); + return interp_ok(0); + } + + throw_unsup_format!("ioctl: unsupported operation {op:#x} on socket"); } } @@ -469,19 +507,35 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } if socket.is_non_block.get() { - throw_unsup_format!("accept4: non-blocking accept is unsupported") + // We have a non-blocking socket and thus don't want to block until + // we can accept an incoming connection. + match this.try_non_block_accept( + &socket, + address_ptr, + address_len_ptr, + is_client_sock_nonblock, + )? { + Ok(sockfd) => { + // We need to create the scalar using the destination size since + // `syscall(SYS_accept4, ...)` returns a long which doesn't match + // the int returned from the `accept`/`accept4` syscalls. + // See . + this.write_scalar(Scalar::from_int(sockfd, dest.layout.size), dest) + } + Err(e) => this.set_last_error_and_return(e, dest), + } + } else { + // The socket is in blocking mode and thus the accept call should block + // until an incoming connection is ready. 
+ this.block_for_accept( + socket, + address_ptr, + address_len_ptr, + is_client_sock_nonblock, + dest.clone(), + ); + interp_ok(()) } - - // The socket is in blocking mode and thus the accept call should block - // until an incoming connection is ready. - this.block_for_accept( - address_ptr, - address_len_ptr, - is_client_sock_nonblock, - socket, - dest.clone(), - ); - interp_ok(()) } fn connect( @@ -530,22 +584,44 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // Mio returns a potentially unconnected stream. // We can be ensured that the connection is established when // [`TcpStream::take_err`] and [`TcpStream::peer_addr`] both - // don't return errors. - // For non-blocking sockets we need to check that for every - // [`Interest::WRITEABLE`] event on the stream. + // don't return an error after receiving an [`Interest::WRITEABLE`] + // event on the stream. match TcpStream::connect(address) { Ok(stream) => *socket.state.borrow_mut() = SocketState::Connecting(stream), Err(e) => return this.set_last_error_and_return(e, dest), }; if socket.is_non_block.get() { - throw_unsup_format!("connect: non-blocking connect is unsupported"); - } + // We have a non-blocking socket and thus don't want to block until + // the connection is established. - // The socket is in blocking mode and thus the connect call should block - // until the connection with the server is established. - this.block_for_connect(socket, dest.clone()); - interp_ok(()) + // Since the [`TcpStream::connect`] function of mio hides the EINPROGRESS + // we just always return EINPROGRESS and check whether the connection succeeded + // once we want to use the connected socket. + this.set_last_error_and_return(LibcError("EINPROGRESS"), dest) + } else { + // The socket is in blocking mode and thus the connect call should block + // until the connection with the server is established. 
+ + let dest = dest.clone(); + + this.ensure_connected( + socket, + /* should_wait */ true, + "connect", + callback!( + @capture<'tcx> { + dest: MPlaceTy<'tcx> + } |this, result: Result<(), ()>| { + if result.is_err() { + this.set_last_error_and_return(LibcError("ENOTCONN"), &dest) + } else { + this.write_scalar(Scalar::from_i32(0), &dest) + } + } + ), + ) + } } fn send( @@ -576,12 +652,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { return this.set_last_error_and_return(LibcError("ENOTSOCK"), dest); }; - if !matches!(&*socket.state.borrow(), SocketState::Connected(_)) { - // We can only send with connected sockets. For all other - // states we return a not connected error. - return this.set_last_error_and_return(LibcError("ENOTCONN"), dest); - } - // Non-deterministically decide to further reduce the length, simulating a partial send. // We avoid reducing the write size to 0: the docs seem to be entirely fine with that, // but the standard library is not (https://github.com/rust-lang/rust/issues/145959). @@ -594,50 +664,86 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { length }; + let mut is_op_non_block = false; + // Interpret the flag. Every flag we recognize is "subtracted" from `flags`, so // if there is anything left at the end, that's an unsupported flag. if matches!( this.tcx.sess.target.os, Os::Linux | Os::Android | Os::FreeBsd | Os::Solaris | Os::Illumos ) { - // MSG_NOSIGNAL only exists on Linux, Android, FreeBSD, + // MSG_NOSIGNAL and MSG_DONTWAIT only exist on Linux, Android, FreeBSD, // Solaris, and Illumos targets. let msg_nosignal = this.eval_libc_i32("MSG_NOSIGNAL"); + let msg_dontwait = this.eval_libc_i32("MSG_DONTWAIT"); if flags & msg_nosignal == msg_nosignal { // This is only needed to ensure that no EPIPE signal is sent when // trying to send into a stream which is no longer connected. // Since we don't support signals, we can ignore this. 
flags &= !msg_nosignal; } + if flags & msg_dontwait == msg_dontwait { + flags &= !msg_dontwait; + is_op_non_block = true; + } } if flags != 0 { throw_unsup_format!( - "send: flag {flags:#x} is unsupported, only MSG_NOSIGNAL is allowed", + "send: flag {flags:#x} is unsupported, only MSG_NOSIGNAL and MSG_DONTWAIT are allowed", ); } + // If either the operation or the socket is non-blocking, we don't want + // to wait until the connection is established. + let should_wait = !is_op_non_block && !socket.is_non_block.get(); let dest = dest.clone(); - this.block_for_send( - socket, - buffer_ptr, - length, - callback!(@capture<'tcx> { - dest: MPlaceTy<'tcx> - } |this, result: Result| { - match result { - Ok(read_size) => { - let read_size: u64 = read_size.try_into().unwrap(); - let ssize_layout = this.libc_ty_layout("ssize_t"); - this.write_scalar(Scalar::from_int(read_size, ssize_layout.size), &dest) + this.ensure_connected( + socket.clone(), + should_wait, + "send", + callback!( + @capture<'tcx> { + socket: FileDescriptionRef, + flags: i32, + buffer_ptr: Pointer, + length: usize, + is_op_non_block: bool, + dest: MPlaceTy<'tcx>, + } |this, result: Result<(), ()>| { + if result.is_err() { + return this.set_last_error_and_return(LibcError("ENOTCONN"), &dest) } - Err(e) => this.set_last_error_and_return(e, &dest) - } - }), - ); - interp_ok(()) + if is_op_non_block || socket.is_non_block.get() { + // We have a non-blocking operation or a non-blocking socket and + // thus don't want to block until we can send. + match this.try_non_block_send(&socket, buffer_ptr, length)? { + Ok(size) => this.write_scalar(Scalar::from_target_isize(size.try_into().unwrap(), this), &dest), + Err(e) => this.set_last_error_and_return(e, &dest), + } + } else { + // The socket is in blocking mode and thus the send call should block + // until we can send some bytes into the socket. 
+ this.block_for_send( + socket, + buffer_ptr, + length, + callback!(@capture<'tcx> { + dest: MPlaceTy<'tcx> + } |this, result: Result| { + match result { + Ok(size) => this.write_scalar(Scalar::from_target_isize(size.try_into().unwrap(), this), &dest), + Err(e) => this.set_last_error_and_return(e, &dest) + } + }), + ); + interp_ok(()) + } + } + ), + ) } fn recv( @@ -668,12 +774,6 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { return this.set_last_error_and_return(LibcError("ENOTSOCK"), dest); }; - if !matches!(&*socket.state.borrow(), SocketState::Connected(_)) { - // We can only receive from connected sockets. For all other - // states we return a not connected error. - return this.set_last_error_and_return(LibcError("ENOTCONN"), dest); - } - // Non-deterministically decide to further reduce the length, simulating a partial receive. // We don't simulate partial receives for lengths < 2 because the man page states that a // return value of zero can only be returned in some special cases: @@ -690,6 +790,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { }; let mut should_peek = false; + let mut is_op_non_block = false; // Interpret the flag. Every flag we recognize is "subtracted" from `flags`, so // if there is anything left at the end, that's an unsupported flag. @@ -710,35 +811,77 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } } + if matches!( + this.tcx.sess.target.os, + Os::Linux | Os::Android | Os::FreeBsd | Os::Solaris | Os::Illumos + ) { + // MSG_DONTWAIT only exists on Linux, Android, FreeBSD, + // Solaris, and Illumos targets. 
+ let msg_dontwait = this.eval_libc_i32("MSG_DONTWAIT"); + if flags & msg_dontwait == msg_dontwait { + flags &= !msg_dontwait; + is_op_non_block = true; + } + } + if flags != 0 { throw_unsup_format!( - "recv: flag {flags:#x} is unsupported, only MSG_PEEK \ + "recv: flag {flags:#x} is unsupported, only MSG_PEEK, MSG_DONTWAIT \ and MSG_CMSG_CLOEXEC are allowed", ); } + // If either the operation or the socket is non-blocking, we don't want + // to wait until the connection is established. + let should_wait = !is_op_non_block && !socket.is_non_block.get(); let dest = dest.clone(); - this.block_for_recv( - socket, - buffer_ptr, - length, - should_peek, - callback!(@capture<'tcx> { - dest: MPlaceTy<'tcx> - } |this, result: Result| { - match result { - Ok(read_size) => { - let read_size: u64 = read_size.try_into().unwrap(); - let ssize_layout = this.libc_ty_layout("ssize_t"); - this.write_scalar(Scalar::from_int(read_size, ssize_layout.size), &dest) + this.ensure_connected( + socket.clone(), + should_wait, + "recv", + callback!( + @capture<'tcx> { + socket: FileDescriptionRef, + buffer_ptr: Pointer, + length: usize, + should_peek: bool, + is_op_non_block: bool, + dest: MPlaceTy<'tcx>, + } |this, result: Result<(), ()>| { + if result.is_err() { + return this.set_last_error_and_return(LibcError("ENOTCONN"), &dest) } - Err(e) => this.set_last_error_and_return(e, &dest) - } - }), - ); - interp_ok(()) + if is_op_non_block || socket.is_non_block.get() { + // We have a non-blocking operation or a non-blocking socket and + // thus don't want to block until we can receive. + match this.try_non_block_recv(&socket, buffer_ptr, length, should_peek)? { + Ok(size) => this.write_scalar(Scalar::from_target_isize(size.try_into().unwrap(), this), &dest), + Err(e) => this.set_last_error_and_return(e, &dest), + } + } else { + // The socket is in blocking mode and thus the receive call should block + // until we can receive some bytes from the socket. 
+ this.block_for_recv( + socket, + buffer_ptr, + length, + should_peek, + callback!(@capture<'tcx> { + dest: MPlaceTy<'tcx> + } |this, result: Result| { + match result { + Ok(size) => this.write_scalar(Scalar::from_target_isize(size.try_into().unwrap(), this), &dest), + Err(e) => this.set_last_error_and_return(e, &dest) + } + }), + ); + interp_ok(()) + } + } + ), + ) } fn setsockopt( @@ -871,7 +1014,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { socket: &OpTy<'tcx>, address: &OpTy<'tcx>, address_len: &OpTy<'tcx>, - ) -> InterpResult<'tcx, Scalar> { + // Location where the output scalar is written to. + dest: &MPlaceTy<'tcx>, + ) -> InterpResult<'tcx> { let this = self.eval_context_mut(); let socket = this.read_scalar(socket)?.to_i32()?; @@ -880,32 +1025,56 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // Get the file handle let Some(fd) = this.machine.fds.get(socket) else { - return this.set_last_error_and_return_i32(LibcError("EBADF")); + return this.set_last_error_and_return(LibcError("EBADF"), dest); }; let Some(socket) = fd.downcast::() else { // Man page specifies to return ENOTSOCK if `fd` is not a socket. - return this.set_last_error_and_return_i32(LibcError("ENOTSOCK")); + return this.set_last_error_and_return(LibcError("ENOTSOCK"), dest); }; assert!(this.machine.communicate(), "cannot have `Socket` with isolation enabled!"); - let state = socket.state.borrow(); + let dest = dest.clone(); - let SocketState::Connected(stream) = &*state else { - // We can only read the peer address of connected sockets. - return this.set_last_error_and_return_i32(LibcError("ENOTCONN")); - }; + // It's only safe to call [`TcpStream::peer_addr`] after the socket is connected since + // UNIX targets should return ENOTCONN when the connection is not yet established. 
+ this.ensure_connected( + socket.clone(), + /* should_wait */ false, + "getpeername", + callback!( + @capture<'tcx> { + socket: FileDescriptionRef, + address_ptr: Pointer, + address_len_ptr: Pointer, + dest: MPlaceTy<'tcx>, + } |this, result: Result<(), ()>| { + if result.is_err() { + return this.set_last_error_and_return(LibcError("ENOTCONN"), &dest) + }; - let address = match stream.peer_addr() { - Ok(address) => address, - Err(e) => return this.set_last_error_and_return_i32(e), - }; + let SocketState::Connected(stream) = &*socket.state.borrow() else { + unreachable!() + }; - match this.write_socket_address(&address, address_ptr, address_len_ptr, "getpeername")? { - Ok(_) => interp_ok(Scalar::from_i32(0)), - Err(e) => this.set_last_error_and_return_i32(e), - } + let address = match stream.peer_addr() { + Ok(address) => address, + Err(e) => return this.set_last_error_and_return(e, &dest), + }; + + match this.write_socket_address( + &address, + address_ptr, + address_len_ptr, + "getpeername", + )? { + Ok(_) => this.write_scalar(Scalar::from_i32(0), &dest), + Err(e) => this.set_last_error_and_return(e, &dest), + } + } + ), + ) } } @@ -1182,12 +1351,15 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { /// Block the thread until there's an incoming connection or an error occurred. /// /// This recursively calls itself should the operation still block for some reason. + /// + /// **Note**: This function is only safe to call when having previously ensured + /// that the socket is in [`SocketState::Listening`]. 
fn block_for_accept( &mut self, + socket: FileDescriptionRef, address_ptr: Pointer, address_len_ptr: Pointer, is_client_sock_nonblock: bool, - socket: FileDescriptionRef, dest: MPlaceTy<'tcx>, ) { let this = self.eval_context_mut(); @@ -1204,89 +1376,83 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } |this, kind: UnblockKind| { assert_eq!(kind, UnblockKind::Ready); - let state = socket.state.borrow(); - - let SocketState::Listening(listener) = &*state else { - // We checked that the socket is in listening state before blocking - // and since there is no outgoing transition from that state this - // should be unreachable. - unreachable!() - }; - - let (stream, addr) = match listener.accept() { - Ok(peer) => peer, - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - // We need to block the thread again as it would still block. - drop(state); - this.block_for_accept(address_ptr, address_len_ptr, is_client_sock_nonblock, socket, dest); - return interp_ok(()) + match this.try_non_block_accept(&socket, address_ptr, address_len_ptr, is_client_sock_nonblock)? { + Ok(sockfd) => { + // We need to create the scalar using the destination size since + // `syscall(SYS_accept4, ...)` returns a long which doesn't match + // the int returned from the `accept`/`accept4` syscalls. + // See . + this.write_scalar(Scalar::from_int(sockfd, dest.layout.size), &dest) }, - Err(e) => return this.set_last_error_and_return(e, &dest), - }; - - let family = match addr { - SocketAddr::V4(_) => SocketFamily::IPv4, - SocketAddr::V6(_) => SocketFamily::IPv6, - }; - - if address_ptr != Pointer::null() { - // We only attempt a write if the address pointer is not a null pointer. - // If the address pointer is a null pointer the user isn't interested in the - // address and we don't need to write anything. - if let Err(e) = this.write_socket_address(&addr, address_ptr, address_len_ptr, "accept4")? 
{ - return this.set_last_error_and_return(e, &dest); - }; + Err(IoError::HostError(e)) if e.kind() == io::ErrorKind::WouldBlock => { + // We need to block the thread again as it would still block. + this.block_for_accept(socket, address_ptr, address_len_ptr, is_client_sock_nonblock, dest); + interp_ok(()) + } + Err(e) => this.set_last_error_and_return(e, &dest), } - - let fd = this.machine.fds.new_ref(Socket { - family, - state: RefCell::new(SocketState::Connected(stream)), - is_non_block: Cell::new(is_client_sock_nonblock), - }); - let sockfd = this.machine.fds.insert(fd); - // We need to create the scalar using the destination size since - // `syscall(SYS_accept4, ...)` returns a long which doesn't match - // the int returned from the `accept`/`accept4` syscalls. - // See . - this.write_scalar(Scalar::from_int(sockfd, dest.layout.size), &dest) }), ); } - /// Block the thread until the stream is connected or an error occurred. - fn block_for_connect(&mut self, socket: FileDescriptionRef, dest: MPlaceTy<'tcx>) { + /// Attempt to accept an incoming connection on the listening socket in a + /// non-blocking manner. + /// + /// **Note**: This function is only safe to call when having previously ensured + /// that the socket is in [`SocketState::Listening`]. 
+ fn try_non_block_accept( + &mut self, + socket: &FileDescriptionRef, + address_ptr: Pointer, + address_len_ptr: Pointer, + is_client_sock_nonblock: bool, + ) -> InterpResult<'tcx, Result> { let this = self.eval_context_mut(); - this.block_thread_for_io( - socket.clone(), - Interest::WRITABLE, - None, - callback!(@capture<'tcx> { - socket: FileDescriptionRef, - dest: MPlaceTy<'tcx>, - } |this, kind: UnblockKind| { - assert_eq!(kind, UnblockKind::Ready); - let mut state = socket.state.borrow_mut(); + let state = socket.state.borrow(); + let SocketState::Listening(listener) = &*state else { + panic!( + "try_non_block_accept must only be called when socket is in `SocketState::Listening`" + ) + }; + + let (stream, addr) = match listener.accept() { + Ok(peer) => peer, + Err(e) => return interp_ok(Err(IoError::HostError(e))), + }; - // We received a "writable" event so `try_set_connected` is safe to call. - match state.try_set_connected() { - Ok(_) => this.write_scalar(Scalar::from_i32(0), &dest), - Err(SocketIoError::NotReady) => { - // We need to block the thread again as the connection is still not yet ready. - drop(state); - this.block_for_connect(socket, dest); - return interp_ok(()) - }, - Err(SocketIoError::Other(e)) => return this.set_last_error_and_return(e, &dest) - } - }), - ); + let family = match addr { + SocketAddr::V4(_) => SocketFamily::IPv4, + SocketAddr::V6(_) => SocketFamily::IPv6, + }; + + if address_ptr != Pointer::null() { + // We only attempt a write if the address pointer is not a null pointer. + // If the address pointer is a null pointer the user isn't interested in the + // address and we don't need to write anything. + if let Err(e) = + this.write_socket_address(&addr, address_ptr, address_len_ptr, "accept4")? 
+ { + return interp_ok(Err(e)); + }; + } + + let fd = this.machine.fds.new_ref(Socket { + family, + state: RefCell::new(SocketState::Connected(stream)), + is_non_block: Cell::new(is_client_sock_nonblock), + }); + let sockfd = this.machine.fds.insert(fd); + interp_ok(Ok(sockfd)) } /// Block the thread until we can send bytes into the connected socket /// or an error occurred. /// /// This recursively calls itself should the operation still block for some reason. + /// + /// **Note**: This function is only safe to call when having previously ensured + /// that the socket is in [`SocketState::Connected`]. fn block_for_send( &mut self, socket: FileDescriptionRef, @@ -1307,18 +1473,8 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } |this, kind: UnblockKind| { assert_eq!(kind, UnblockKind::Ready); - let mut state = socket.state.borrow_mut(); - let SocketState::Connected(stream) = &mut*state else { - // We ensured that the socket is connected before blocking. - unreachable!() - }; - - // This is a *non-blocking* write. - let result = this.write_to_host(stream, length, buffer_ptr)?; - match result { + match this.try_non_block_send(&socket, buffer_ptr, length)? { Err(IoError::HostError(e)) if e.kind() == io::ErrorKind::WouldBlock => { - // We need to block the thread again as it would still block. - drop(state); this.block_for_send(socket, buffer_ptr, length, finish); interp_ok(()) }, @@ -1328,10 +1484,41 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { ); } + /// Attempt to send bytes into the connected socket in a non-blocking manner. + /// + /// **Note**: This function is only safe to call when having previously ensured + /// that the socket is in [`SocketState::Connected`]. 
+ fn try_non_block_send( + &mut self, + socket: &FileDescriptionRef, + buffer_ptr: Pointer, + length: usize, + ) -> InterpResult<'tcx, Result> { + let this = self.eval_context_mut(); + + let SocketState::Connected(stream) = &mut *socket.state.borrow_mut() else { + panic!("try_non_block_send must only be called when the socket is connected") + }; + + // This is a *non-blocking* write. + let result = this.write_to_host(stream, length, buffer_ptr)?; + match result { + Err(IoError::HostError(e)) if e.kind() == io::ErrorKind::NotConnected => { + // On Windows hosts, `send` can return WSAENOTCONN where EAGAIN or EWOULDBLOCK + // would be returned on UNIX-like systems. We thus remap this error to an EWOULDBLOCK. + interp_ok(Err(IoError::HostError(io::ErrorKind::WouldBlock.into()))) + } + result => interp_ok(result), + } + } + /// Block the thread until we can receive bytes from the connected socket /// or an error occurred. /// /// This recursively calls itself should the operation still block for some reason. + /// + /// **Note**: This function is only safe to call when having previously ensured + /// that the socket is in [`SocketState::Connected`]. fn block_for_recv( &mut self, socket: FileDescriptionRef, @@ -1354,24 +1541,9 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { } |this, kind: UnblockKind| { assert_eq!(kind, UnblockKind::Ready); - let mut state = socket.state.borrow_mut(); - let SocketState::Connected(stream) = &mut*state else { - // We ensured that the socket is connected before blocking. - unreachable!() - }; - - // This is a *non-blocking* read/peek. - let result = this.read_from_host(|buf| { - if should_peek { - stream.peek(buf) - } else { - stream.read(buf) - } - }, length, buffer_ptr)?; - match result { + match this.try_non_block_recv(&socket, buffer_ptr, length, should_peek)? { Err(IoError::HostError(e)) if e.kind() == io::ErrorKind::WouldBlock => { // We need to block the thread again as it would still block. 
- drop(state); this.block_for_recv(socket, buffer_ptr, length, should_peek, finish); interp_ok(()) }, @@ -1380,6 +1552,178 @@ trait EvalContextPrivExt<'tcx>: crate::MiriInterpCxExt<'tcx> { }), ); } + + /// Attempt to receive bytes from the connected socket in a non-blocking manner. + /// + /// **Note**: This function is only safe to call when having previously ensured + /// that the socket is in [`SocketState::Connected`]. + fn try_non_block_recv( + &mut self, + socket: &FileDescriptionRef, + buffer_ptr: Pointer, + length: usize, + should_peek: bool, + ) -> InterpResult<'tcx, Result> { + let this = self.eval_context_mut(); + + let SocketState::Connected(stream) = &mut *socket.state.borrow_mut() else { + panic!("try_non_block_recv must only be called when the socket is connected") + }; + + // This is a *non-blocking* read/peek. + let result = this.read_from_host( + |buf| { + if should_peek { stream.peek(buf) } else { stream.read(buf) } + }, + length, + buffer_ptr, + )?; + match result { + Err(IoError::HostError(e)) if e.kind() == io::ErrorKind::NotConnected => { + // On Windows hosts, `recv` can return WSAENOTCONN where EAGAIN or EWOULDBLOCK + // would be returned on UNIX-like systems. We thus remap this error to an EWOULDBLOCK. + interp_ok(Err(IoError::HostError(io::ErrorKind::WouldBlock.into()))) + } + result => interp_ok(result), + } + } + + /// Execute the provided callback function when the socket is either in + /// [`SocketState::Connected`] or an error occurred. + /// If the socket is currently neither in the [`SocketState::Connecting`] nor + /// the [`SocketState::Connected`] state, an ENOTCONN error is returned. + /// When the callback function is called with `Ok(_)`, then we're guaranteed + /// that the socket is in the [`SocketState::Connected`] state. + /// + /// This function can optionally also block until either an error occurred or + /// the socket reached the [`SocketState::Connected`] state. 
+ fn ensure_connected( + &mut self, + socket: FileDescriptionRef, + should_wait: bool, + foreign_name: &'static str, + action: DynMachineCallback<'tcx, Result<(), ()>>, + ) -> InterpResult<'tcx> { + let this = self.eval_context_mut(); + + let state = socket.state.borrow(); + match &*state { + SocketState::Connecting(_) => { /* fall-through to below */ } + SocketState::Connected(_) => { + drop(state); + return action.call(this, Ok(())); + } + _ => { + drop(state); + return action.call(this, Err(())); + } + }; + + drop(state); + + // We're currently connecting. Since the underlying mio socket is non-blocking, + // the only way to determine whether we are done connecting is by polling. + // If we should wait until the connection is established, the timeout is `None`. + // Otherwise, we use a zero duration timeout, i.e. we return immediately + // (but we still go through the scheduler once -- which is fine). + let timeout = if should_wait { + None + } else { + Some((TimeoutClock::Monotonic, TimeoutAnchor::Absolute, Duration::ZERO)) + }; + + this.block_thread_for_io( + socket.clone(), + Interest::WRITABLE, + timeout, + callback!( + @capture<'tcx> { + socket: FileDescriptionRef, + should_wait: bool, + foreign_name: &'static str, + action: DynMachineCallback<'tcx, Result<(), ()>>, + } |this, kind: UnblockKind| { + if UnblockKind::TimedOut == kind { + // We can only time out when `should_wait` is false. + // This then means that the socket is not yet connected. + assert!(!should_wait); + this.machine.blocking_io.deregister(socket.id(), InterestReceiver::UnblockThread(this.active_thread())); + return action.call(this, Err(())) + } + + // The thread woke up because it's ready, indicating a writeable or error event. 
+ + let mut state = socket.state.borrow_mut(); + let stream = match &*state { + SocketState::Connecting(stream) => stream, + SocketState::Connected(_) => { + drop(state); + // This can happen because we blocked the thread: + // maybe another thread "upgraded" the connection in the meantime. + return action.call(this, Ok(())) + }, + _ => { + drop(state); + // We ensured that we only block when we're currently connecting. + // Since this thread just got rescheduled, it could be that another + // thread realized that the connection failed and we're thus in + // an "invalid state". + return action.call(this, Err(())) + } + }; + + // Manually check whether there were any errors since calling `connect`. + if let Ok(Some(_)) = stream.take_error() { + // There was an error during connecting and thus we + // return ENOTCONN. It's the program's responsibility + // to read SO_ERROR itself. + // + // Go back to initial state since the only way of getting into the + // `Connecting` state is from the `Initial` state and at this point + // we know that the connection won't be established anymore. + // + // FIXME: We're currently just dropping the error information. Eventually + // we'll have to store it so that it can be recovered by the user. + *state = SocketState::Initial; + drop(state); + return action.call(this, Err(())) + } + + // There was no error during connecting. We still need to ensure that + // the wakeup wasn't spurious. We do this by attempting to read the + // peer address of the socket (following the advice given by mio): + // + + match stream.peer_addr() { + Ok(_) => { /* fall-through to below */}, + Err(e) if matches!(e.kind(), io::ErrorKind::NotConnected | io::ErrorKind::InProgress) => { + // We received a spurious wakeup from the OS. This should be considered an OS bug: + // + panic!("{foreign_name}: received writable event from OS but socket is not yet connected") + }, + Err(_) => { + // For all other errors the socket is connected. 
Since we're not interested in the + // peer address and only want to know whether the socket is connected, we can ignore + // the error and continue. + } + } + + // The connection is established. + + // Temporarily use dummy state to take ownership of the stream. + let SocketState::Connecting(stream) = std::mem::replace(&mut*state, SocketState::Initial) else { + // At the start of the function we ensured that we're currently connecting. + unreachable!() + }; + *state = SocketState::Connected(stream); + drop(state); + action.call(this, Ok(())) + } + ), + ); + + interp_ok(()) + } } impl VisitProvenance for FileDescriptionRef { diff --git a/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_double_free.rs b/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_double_free.rs index c18675931719f..c1a1b1523d11f 100644 --- a/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_double_free.rs +++ b/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_double_free.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test that we can detect a double-free bug across two threads, which only shows up if the second thread reads an atomic pointer at a very specific moment. // GenMC can detect this error consistently, without having to run the buggy code with multiple RNG seeds or in a loop. diff --git a/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_invalid_provenance.rs b/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_invalid_provenance.rs index 87223e990bde9..06384c4308c99 100644 --- a/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_invalid_provenance.rs +++ b/src/tools/miri/tests/genmc/fail/atomics/atomic_ptr_invalid_provenance.rs @@ -1,5 +1,4 @@ //@revisions: send make -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows // Test that we can distinguish two pointers with the same address, but different provenance, after they are sent to GenMC and back. 
// We have two variants, one where we send such a pointer to GenMC, and one where we make it on the GenMC side. diff --git a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.dealloc.stderr b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.dealloc.stderr index 7534eaf8f37ec..aa51e1213d0c9 100644 --- a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.dealloc.stderr +++ b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.dealloc.stderr @@ -1,5 +1,5 @@ Running GenMC Verification... -error: Undefined Behavior: Attempt to access freed memory +error: Undefined Behavior: Attempt to access non-allocated memory --> tests/genmc/fail/data_race/atomic_ptr_alloc_race.rs:LL:CC | LL | dealloc(b as *mut u8, Layout::new::()); diff --git a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.rs b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.rs index e453c16b157d2..660060b2dd404 100644 --- a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.rs +++ b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_alloc_race.rs @@ -1,5 +1,5 @@ //@revisions: write dealloc -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-ignore-leaks +//@compile-flags: -Zmiri-ignore-leaks // Test that we can detect data races between an allocation and an unsynchronized action in another thread. // We have two variants, an alloc-dealloc race and an alloc-write race. diff --git a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_dealloc_write_race.rs b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_dealloc_write_race.rs index 10e0d8d854c25..2e5ae05f5c8d5 100644 --- a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_dealloc_write_race.rs +++ b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_dealloc_write_race.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test that use-after-free bugs involving atomic pointers are detected in GenMC mode. 
#![no_main] diff --git a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_write_dealloc_race.rs b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_write_dealloc_race.rs index e2d3057a5b0df..768f5e8b40fd9 100644 --- a/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_write_dealloc_race.rs +++ b/src/tools/miri/tests/genmc/fail/data_race/atomic_ptr_write_dealloc_race.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test that use-after-free bugs involving atomic pointers are detected in GenMC mode. // Compared to `atomic_ptr_dealloc_write_race.rs`, this variant checks that the data race is still detected, even if the write happens before the free. // diff --git a/src/tools/miri/tests/genmc/fail/data_race/mpu2_rels_rlx.rs b/src/tools/miri/tests/genmc/fail/data_race/mpu2_rels_rlx.rs index 32954a643b34b..c02895628a004 100644 --- a/src/tools/miri/tests/genmc/fail/data_race/mpu2_rels_rlx.rs +++ b/src/tools/miri/tests/genmc/fail/data_race/mpu2_rels_rlx.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test `wrong/racy/MPU2+rels+rlx`. // Test if Miri with GenMC can detect the data race on `X`. // The data race only occurs if thread 1 finishes, then threads 3 and 4 run, then thread 2. diff --git a/src/tools/miri/tests/genmc/fail/data_race/weak_orderings.rs b/src/tools/miri/tests/genmc/fail/data_race/weak_orderings.rs index 1568a302f85ad..2f2725e3742f1 100644 --- a/src/tools/miri/tests/genmc/fail/data_race/weak_orderings.rs +++ b/src/tools/miri/tests/genmc/fail/data_race/weak_orderings.rs @@ -1,4 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@revisions: rlx_rlx rlx_acq rel_rlx // Translated from GenMC's test `wrong/racy/MP+rel+rlx`, `MP+rlx+acq` and `MP+rlx+rlx`. 
diff --git a/src/tools/miri/tests/genmc/fail/loom/buggy_inc.rs b/src/tools/miri/tests/genmc/fail/loom/buggy_inc.rs index 2e614e6a360ba..f205582771c7e 100644 --- a/src/tools/miri/tests/genmc/fail/loom/buggy_inc.rs +++ b/src/tools/miri/tests/genmc/fail/loom/buggy_inc.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: Copyright (c) 2019 Carl Lerche diff --git a/src/tools/miri/tests/genmc/fail/loom/store_buffering.non_genmc.stderr b/src/tools/miri/tests/genmc/fail/loom/store_buffering.non_genmc.stderr deleted file mode 100644 index 487ab21f28b32..0000000000000 --- a/src/tools/miri/tests/genmc/fail/loom/store_buffering.non_genmc.stderr +++ /dev/null @@ -1,12 +0,0 @@ -error: abnormal termination: the program aborted execution - --> tests/genmc/fail/loom/store_buffering.rs:LL:CC - | -LL | std::process::abort(); - | ^^^^^^^^^^^^^^^^^^^^^ abnormal termination occurred here - | - = note: this is on thread `main` - -note: some details are omitted, run with `MIRIFLAGS=-Zmiri-backtrace=full` for a verbose backtrace - -error: aborting due to 1 previous error - diff --git a/src/tools/miri/tests/genmc/fail/loom/store_buffering.rs b/src/tools/miri/tests/genmc/fail/loom/store_buffering.rs index fc522dd013fa1..7560644e1df94 100644 --- a/src/tools/miri/tests/genmc/fail/loom/store_buffering.rs +++ b/src/tools/miri/tests/genmc/fail/loom/store_buffering.rs @@ -1,15 +1,9 @@ -//@ revisions: non_genmc genmc -//@[genmc] compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // SPDX-License-Identifier: MIT // SPDX-FileCopyrightText: Copyright (c) 2019 Carl Lerche // This is the test `store_buffering` from `loom/test/litmus.rs`, adapted for Miri-GenMC. // https://github.com/tokio-rs/loom/blob/dbf32b04bae821c64be44405a0bb72ca08741558/tests/litmus.rs -// This test shows the comparison between running Miri with or without GenMC. 
-// Without GenMC, Miri requires multiple iterations of the loop to detect the error. - #![no_main] #[path = "../../../utils/genmc.rs"] @@ -23,30 +17,27 @@ use crate::genmc::*; #[unsafe(no_mangle)] fn miri_start(_argc: isize, _argv: *const *const u8) -> isize { // For normal Miri, we need multiple repetitions, but GenMC should find the bug with only 1. - const REPS: usize = if cfg!(non_genmc) { 128 } else { 1 }; - for _ in 0..REPS { - // New atomics every iterations, so they don't influence each other. - let x = AtomicUsize::new(0); - let y = AtomicUsize::new(0); - - let mut a: usize = 1234; - let mut b: usize = 1234; - unsafe { - let ids = [ - spawn_pthread_closure(|| { - x.store(1, Relaxed); - a = y.load(Relaxed) - }), - spawn_pthread_closure(|| { - y.store(1, Relaxed); - b = x.load(Relaxed) - }), - ]; - join_pthreads(ids); - } - if (a, b) == (0, 0) { - std::process::abort(); //~ ERROR: abnormal termination - } + + let x = AtomicUsize::new(0); + let y = AtomicUsize::new(0); + + let mut a: usize = 1234; + let mut b: usize = 1234; + unsafe { + let ids = [ + spawn_pthread_closure(|| { + x.store(1, Relaxed); + a = y.load(Relaxed) + }), + spawn_pthread_closure(|| { + y.store(1, Relaxed); + b = x.load(Relaxed) + }), + ]; + join_pthreads(ids); + } + if (a, b) == (0, 0) { + std::process::abort(); //~ ERROR: abnormal termination } 0 diff --git a/src/tools/miri/tests/genmc/fail/loom/store_buffering.genmc.stderr b/src/tools/miri/tests/genmc/fail/loom/store_buffering.stderr similarity index 78% rename from src/tools/miri/tests/genmc/fail/loom/store_buffering.genmc.stderr rename to src/tools/miri/tests/genmc/fail/loom/store_buffering.stderr index 176ab6a573c87..3273c23ea39eb 100644 --- a/src/tools/miri/tests/genmc/fail/loom/store_buffering.genmc.stderr +++ b/src/tools/miri/tests/genmc/fail/loom/store_buffering.stderr @@ -2,8 +2,8 @@ Running GenMC Verification... 
error: abnormal termination: the program aborted execution --> tests/genmc/fail/loom/store_buffering.rs:LL:CC | -LL | std::process::abort(); - | ^^^^^^^^^^^^^^^^^^^^^ abnormal termination occurred here +LL | std::process::abort(); + | ^^^^^^^^^^^^^^^^^^^^^ abnormal termination occurred here | = note: this is on thread `main` diff --git a/src/tools/miri/tests/genmc/fail/shims/exit.rs b/src/tools/miri/tests/genmc/fail/shims/exit.rs index 4138f4e785bbf..8c0cc9b3b1c7c 100644 --- a/src/tools/miri/tests/genmc/fail/shims/exit.rs +++ b/src/tools/miri/tests/genmc/fail/shims/exit.rs @@ -1,5 +1,3 @@ -//@ compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - fn main() { std::thread::spawn(|| { unsafe { std::hint::unreachable_unchecked() }; //~ERROR: entering unreachable code diff --git a/src/tools/miri/tests/genmc/fail/shims/mutex_diff_thread_unlock.rs b/src/tools/miri/tests/genmc/fail/shims/mutex_diff_thread_unlock.rs index d2da722f1c02f..e499a26a9d765 100644 --- a/src/tools/miri/tests/genmc/fail/shims/mutex_diff_thread_unlock.rs +++ b/src/tools/miri/tests/genmc/fail/shims/mutex_diff_thread_unlock.rs @@ -1,4 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@error-in-other-file: Undefined Behavior // Test that GenMC throws an error if a `std::sync::Mutex` is unlocked from a different thread than the one that locked it. diff --git a/src/tools/miri/tests/genmc/fail/shims/mutex_double_unlock.rs b/src/tools/miri/tests/genmc/fail/shims/mutex_double_unlock.rs index 3daff38efbfdd..d1801fd0ee69c 100644 --- a/src/tools/miri/tests/genmc/fail/shims/mutex_double_unlock.rs +++ b/src/tools/miri/tests/genmc/fail/shims/mutex_double_unlock.rs @@ -1,4 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@error-in-other-file: Undefined Behavior // Test that GenMC can detect a double unlock of a mutex. 
diff --git a/src/tools/miri/tests/genmc/fail/simple/2w2w_weak.rs b/src/tools/miri/tests/genmc/fail/simple/2w2w_weak.rs index baf3584966ec8..9a2f5c78dac4a 100644 --- a/src/tools/miri/tests/genmc/fail/simple/2w2w_weak.rs +++ b/src/tools/miri/tests/genmc/fail/simple/2w2w_weak.rs @@ -1,4 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@revisions: sc3_rel1 release4 relaxed4 // The pass tests "2w2w_3sc_1rel.rs", "2w2w_4rel" and "2w2w_4sc" and the fail test "2w2w_weak.rs" are related. diff --git a/src/tools/miri/tests/genmc/fail/simple/alloc_large.rs b/src/tools/miri/tests/genmc/fail/simple/alloc_large.rs index 27d92bf66d424..da0a902ef4a49 100644 --- a/src/tools/miri/tests/genmc/fail/simple/alloc_large.rs +++ b/src/tools/miri/tests/genmc/fail/simple/alloc_large.rs @@ -1,5 +1,4 @@ //@revisions: single multiple -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@error-in-other-file: resource exhaustion // Ensure that we emit a proper error if GenMC fails to fulfill an allocation. diff --git a/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_ops.rs b/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_ops.rs index aa40e193dbfd0..75ea56a4277f6 100644 --- a/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_ops.rs +++ b/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_ops.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test several operations on atomic pointers. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_roundtrip.rs b/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_roundtrip.rs index d846a55cbc31a..1a3f17319b290 100644 --- a/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_roundtrip.rs +++ b/src/tools/miri/tests/genmc/pass/atomics/atomic_ptr_roundtrip.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test that we can send pointers with any alignment to GenMC and back, even across threads. 
// After a round-trip, the pointers should still work properly (no missing provenance). diff --git a/src/tools/miri/tests/genmc/pass/atomics/cas_failure_ord_racy_key_init.rs b/src/tools/miri/tests/genmc/pass/atomics/cas_failure_ord_racy_key_init.rs index e17a988cb3736..cf20482092ede 100644 --- a/src/tools/miri/tests/genmc/pass/atomics/cas_failure_ord_racy_key_init.rs +++ b/src/tools/miri/tests/genmc/pass/atomics/cas_failure_ord_racy_key_init.rs @@ -1,4 +1,4 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-ignore-leaks +//@compile-flags: -Zmiri-ignore-leaks // Adapted from: `impl LazyKey`, `fn lazy_init`: rust/library/std/src/sys/thread_local/key/racy.rs // Two threads race to initialize a key, which is just an index into an array in this test. diff --git a/src/tools/miri/tests/genmc/pass/atomics/cas_simple.rs b/src/tools/miri/tests/genmc/pass/atomics/cas_simple.rs index c19c81995d1bc..5ef1dd2352d68 100644 --- a/src/tools/miri/tests/genmc/pass/atomics/cas_simple.rs +++ b/src/tools/miri/tests/genmc/pass/atomics/cas_simple.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test the basic functionality of compare_exchange. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/atomics/mixed_atomic_non_atomic.rs b/src/tools/miri/tests/genmc/pass/atomics/mixed_atomic_non_atomic.rs index 7601b354b1c00..9e809e93af2c9 100644 --- a/src/tools/miri/tests/genmc/pass/atomics/mixed_atomic_non_atomic.rs +++ b/src/tools/miri/tests/genmc/pass/atomics/mixed_atomic_non_atomic.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test that we can read the value of a non-atomic store atomically and an of an atomic value non-atomically. 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/atomics/read_initial_value.rs b/src/tools/miri/tests/genmc/pass/atomics/read_initial_value.rs index 18e039fdd0dfe..55ae510b92c11 100644 --- a/src/tools/miri/tests/genmc/pass/atomics/read_initial_value.rs +++ b/src/tools/miri/tests/genmc/pass/atomics/read_initial_value.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Test that we can read the initial value of global, heap and stack allocations in GenMC mode. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/atomics/rmw_ops.rs b/src/tools/miri/tests/genmc/pass/atomics/rmw_ops.rs index 411207b79b7e6..9ebcdd9dfae86 100644 --- a/src/tools/miri/tests/genmc/pass/atomics/rmw_ops.rs +++ b/src/tools/miri/tests/genmc/pass/atomics/rmw_ops.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // This test check for correct handling of atomic read-modify-write operations for all integer sizes. // Atomic max and min should return the previous value, and store the result in the atomic. // Atomic addition and subtraction should have wrapping semantics. diff --git a/src/tools/miri/tests/genmc/pass/atomics/u64_as_u32_array.rs b/src/tools/miri/tests/genmc/pass/atomics/u64_as_u32_array.rs new file mode 100644 index 0000000000000..b38c855ae0ef1 --- /dev/null +++ b/src/tools/miri/tests/genmc/pass/atomics/u64_as_u32_array.rs @@ -0,0 +1,22 @@ +//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows + +// Tests mixed-size non-atomic accesses. + +#![no_main] + +use std::sync::atomic::*; + +#[unsafe(no_mangle)] +fn miri_start(_argc: isize, _argv: *const *const u8) -> isize { + let mut data = 0u64; + // Treat this like an array of two AtomicI32. 
+ let atomics = unsafe { &*(&raw mut data as *mut u64 as *mut [AtomicI32; 2]) }; + + atomics[0].load(Ordering::SeqCst); + atomics[1].store(-1, Ordering::SeqCst); + atomics[0].store(-1, Ordering::Relaxed); + + assert_eq!(data, u64::MAX); + + 0 +} diff --git a/src/tools/miri/tests/genmc/pass/atomics/u64_as_u32_array.stderr b/src/tools/miri/tests/genmc/pass/atomics/u64_as_u32_array.stderr new file mode 100644 index 0000000000000..7867be2dbe8ed --- /dev/null +++ b/src/tools/miri/tests/genmc/pass/atomics/u64_as_u32_array.stderr @@ -0,0 +1,2 @@ +Running GenMC Verification... +Verification complete with 1 executions. No errors found. diff --git a/src/tools/miri/tests/genmc/pass/data-structures/ms_queue_dynamic.rs b/src/tools/miri/tests/genmc/pass/data-structures/ms_queue_dynamic.rs index 934fc977366dc..e760966696d15 100644 --- a/src/tools/miri/tests/genmc/pass/data-structures/ms_queue_dynamic.rs +++ b/src/tools/miri/tests/genmc/pass/data-structures/ms_queue_dynamic.rs @@ -1,5 +1,5 @@ //@ revisions: default_R1W1 default_R1W2 spinloop_assume_R1W1 spinloop_assume_R1W2 -//@compile-flags: -Zmiri-ignore-leaks -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-genmc-verbose +//@compile-flags: -Zmiri-ignore-leaks -Zmiri-genmc-verbose //@normalize-stderr-test: "Verification took .*s" -> "Verification took [TIME]s" // This test is a translations of the GenMC test `ms-queue-dynamic`, but with all code related to GenMC's hazard pointer API removed. 
diff --git a/src/tools/miri/tests/genmc/pass/data-structures/treiber_stack_dynamic.rs b/src/tools/miri/tests/genmc/pass/data-structures/treiber_stack_dynamic.rs index 8bdd2a371f51c..2e64faea7f098 100644 --- a/src/tools/miri/tests/genmc/pass/data-structures/treiber_stack_dynamic.rs +++ b/src/tools/miri/tests/genmc/pass/data-structures/treiber_stack_dynamic.rs @@ -1,5 +1,5 @@ //@ revisions: default_R1W1 default_R1W2 default_R1W3 spinloop_assume_R1W1 spinloop_assume_R1W2 spinloop_assume_R1W3 -//@compile-flags: -Zmiri-ignore-leaks -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-genmc-verbose +//@compile-flags: -Zmiri-ignore-leaks -Zmiri-genmc-verbose //@normalize-stderr-test: "Verification took .*s" -> "Verification took [TIME]s" // This test is a translations of the GenMC test `treiber-stack-dynamic`, but with all code related to GenMC's hazard pointer API removed. diff --git a/src/tools/miri/tests/genmc/pass/litmus/2cowr.rs b/src/tools/miri/tests/genmc/pass/litmus/2cowr.rs index d9b582bb4362d..c14a300781e63 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/2cowr.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/2cowr.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "2CoWR". #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/2w2w_2sc_scf.rs b/src/tools/miri/tests/genmc/pass/litmus/2w2w_2sc_scf.rs index 3b3fca02285d6..a61a1e0c3164d 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/2w2w_2sc_scf.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/2w2w_2sc_scf.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "2+2W+2sc+scf". // It tests correct handling of SeqCst fences combined with relaxed accesses. 
diff --git a/src/tools/miri/tests/genmc/pass/litmus/2w2w_3sc_1rel.rs b/src/tools/miri/tests/genmc/pass/litmus/2w2w_3sc_1rel.rs index 22fe9524c37f5..d587c42de18be 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/2w2w_3sc_1rel.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/2w2w_3sc_1rel.rs @@ -1,4 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@revisions: release1 release2 // Translated from GenMC's test "2+2W+3sc+rel1" and "2+2W+3sc+rel2" (two variants that swap which store is `Release`). diff --git a/src/tools/miri/tests/genmc/pass/litmus/2w2w_4rel.rs b/src/tools/miri/tests/genmc/pass/litmus/2w2w_4rel.rs index f47f5a11c5c96..b0919dc6caacf 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/2w2w_4rel.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/2w2w_4rel.rs @@ -1,5 +1,4 @@ //@revisions: weak sc -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@[sc]compile-flags: -Zmiri-disable-weak-memory-emulation // Translated from GenMC's test "2+2W". diff --git a/src/tools/miri/tests/genmc/pass/litmus/2w2w_4sc.rs b/src/tools/miri/tests/genmc/pass/litmus/2w2w_4sc.rs index c5711ba04fce2..3c61adea643cb 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/2w2w_4sc.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/2w2w_4sc.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "2+2W+4c". // // The pass tests "2w2w_3sc_1rel.rs", "2w2w_4rel" and "2w2w_4sc" and the fail test "2w2w_weak.rs" are related. diff --git a/src/tools/miri/tests/genmc/pass/litmus/IRIW-acq-sc.rs b/src/tools/miri/tests/genmc/pass/litmus/IRIW-acq-sc.rs index 6d2dfd4f273e2..9a7c1c9ff5944 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/IRIW-acq-sc.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/IRIW-acq-sc.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/IRIW-acq-sc" test. 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/IRIWish.rs b/src/tools/miri/tests/genmc/pass/litmus/IRIWish.rs index 6f1d37962d10a..b9f1125233edf 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/IRIWish.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/IRIWish.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/IRIWish" test. // This test prints the values read by the different threads to check that we get all the values we expect. diff --git a/src/tools/miri/tests/genmc/pass/litmus/IRIWish.stderr b/src/tools/miri/tests/genmc/pass/litmus/IRIWish.stderr index 7ea2dd5085136..27a04677868ba 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/IRIWish.stderr +++ b/src/tools/miri/tests/genmc/pass/litmus/IRIWish.stderr @@ -1,25 +1,4 @@ Running GenMC Verification... -[1, 1, 1, 1, 1] -[1, 1, 1, 0, 1] -[1, 1, 1, 0, 0] -[1, 1, 0, 1, 1] -[1, 1, 0, 0, 1] -[1, 1, 0, 0, 0] -[1, 0, 1, 1, 1] -[1, 0, 1, 0, 1] -[1, 0, 1, 0, 0] -[1, 0, 0, 1, 1] -[1, 0, 0, 0, 1] -[1, 0, 0, 0, 0] -[0, 1, 0, 0, 1] -[0, 1, 0, 0, 0] -[0, 1, 0, 0, 1] -[0, 1, 0, 0, 0] -[0, 1, 0, 0, 1] -[0, 1, 0, 0, 0] -[0, 1, 0, 0, 1] -[0, 1, 0, 0, 0] -[0, 0, 0, 0, 1] [0, 0, 0, 0, 0] [0, 0, 0, 0, 1] [0, 0, 0, 0, 0] @@ -27,4 +6,25 @@ Running GenMC Verification... [0, 0, 0, 0, 0] [0, 0, 0, 0, 1] [0, 0, 0, 0, 0] +[0, 0, 0, 0, 1] +[0, 1, 0, 0, 0] +[0, 1, 0, 0, 1] +[0, 1, 0, 0, 0] +[0, 1, 0, 0, 1] +[0, 1, 0, 0, 0] +[0, 1, 0, 0, 1] +[0, 1, 0, 0, 0] +[0, 1, 0, 0, 1] +[1, 0, 0, 0, 0] +[1, 0, 0, 0, 1] +[1, 0, 0, 1, 1] +[1, 0, 1, 0, 0] +[1, 0, 1, 0, 1] +[1, 0, 1, 1, 1] +[1, 1, 0, 0, 0] +[1, 1, 0, 0, 1] +[1, 1, 0, 1, 1] +[1, 1, 1, 0, 0] +[1, 1, 1, 0, 1] +[1, 1, 1, 1, 1] Verification complete with 28 executions. No errors found. 
diff --git a/src/tools/miri/tests/genmc/pass/litmus/LB.rs b/src/tools/miri/tests/genmc/pass/litmus/LB.rs index 107121ef4e3c7..4cc1209326fff 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/LB.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/LB.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/LB" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/LB_incMPs.rs b/src/tools/miri/tests/genmc/pass/litmus/LB_incMPs.rs index e43d92fc6c55a..eea6bf62cc27a 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/LB_incMPs.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/LB_incMPs.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/LB+incMPs" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/MP.rs b/src/tools/miri/tests/genmc/pass/litmus/MP.rs index 5f9d1b01c37b0..a6ec6c2b29cc6 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/MP.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/MP.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/MP" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.rs b/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.rs index 6f812bf8a8ac9..b065c57698435 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/MPU2+rels+acqf" test. 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.stderr b/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.stderr index 29b59ce3bc1a3..ee111c2ce8b6e 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.stderr +++ b/src/tools/miri/tests/genmc/pass/litmus/MPU2_rels_acqf.stderr @@ -1,38 +1,38 @@ Running GenMC Verification... -X=1, Y=2, a=Err(1), b=Ok(1), c=2 -X=1, Y=2, a=Err(1), b=Ok(1), c=1 -X=1, Y=2, a=Err(1), b=Ok(1), c=0 -X=1, Y=2, a=Err(1), b=Ok(1), c=0 -X=2, Y=3, a=Ok(2), b=Ok(1), c=3 -X=1, Y=3, a=Ok(2), b=Ok(1), c=3 -X=1, Y=3, a=Ok(2), b=Ok(1), c=2 -X=1, Y=3, a=Ok(2), b=Ok(1), c=1 -X=1, Y=3, a=Ok(2), b=Ok(1), c=0 -X=1, Y=3, a=Ok(2), b=Ok(1), c=0 -X=1, Y=1, a=Err(1), b=Err(0), c=1 -X=1, Y=1, a=Err(1), b=Err(0), c=0 -X=1, Y=1, a=Err(1), b=Err(0), c=0 -X=1, Y=1, a=Err(1), b=Err(0), c=1 -X=1, Y=1, a=Err(1), b=Err(0), c=0 -X=1, Y=1, a=Err(1), b=Err(0), c=0 -X=1, Y=2, a=Err(0), b=Ok(1), c=2 -X=1, Y=2, a=Err(0), b=Ok(1), c=1 -X=1, Y=2, a=Err(0), b=Ok(1), c=0 -X=1, Y=2, a=Err(0), b=Ok(1), c=0 -X=1, Y=1, a=Err(0), b=Err(0), c=1 X=1, Y=1, a=Err(0), b=Err(0), c=0 X=1, Y=1, a=Err(0), b=Err(0), c=0 X=1, Y=1, a=Err(0), b=Err(0), c=1 X=1, Y=1, a=Err(0), b=Err(0), c=0 X=1, Y=1, a=Err(0), b=Err(0), c=0 -X=1, Y=2, a=Err(0), b=Ok(1), c=2 -X=1, Y=2, a=Err(0), b=Ok(1), c=1 +X=1, Y=1, a=Err(0), b=Err(0), c=1 X=1, Y=2, a=Err(0), b=Ok(1), c=0 X=1, Y=2, a=Err(0), b=Ok(1), c=0 -X=1, Y=1, a=Err(0), b=Err(0), c=1 +X=1, Y=2, a=Err(0), b=Ok(1), c=1 +X=1, Y=2, a=Err(0), b=Ok(1), c=2 X=1, Y=1, a=Err(0), b=Err(0), c=0 X=1, Y=1, a=Err(0), b=Err(0), c=0 X=1, Y=1, a=Err(0), b=Err(0), c=1 X=1, Y=1, a=Err(0), b=Err(0), c=0 X=1, Y=1, a=Err(0), b=Err(0), c=0 +X=1, Y=1, a=Err(0), b=Err(0), c=1 +X=1, Y=2, a=Err(0), b=Ok(1), c=0 +X=1, Y=2, a=Err(0), b=Ok(1), c=0 +X=1, Y=2, a=Err(0), b=Ok(1), c=1 +X=1, Y=2, a=Err(0), b=Ok(1), c=2 +X=1, Y=1, a=Err(1), b=Err(0), c=0 +X=1, Y=1, a=Err(1), b=Err(0), c=0 +X=1, Y=1, a=Err(1), b=Err(0), c=1 +X=1, Y=1, a=Err(1), 
b=Err(0), c=0 +X=1, Y=1, a=Err(1), b=Err(0), c=0 +X=1, Y=1, a=Err(1), b=Err(0), c=1 +X=1, Y=2, a=Err(1), b=Ok(1), c=0 +X=1, Y=2, a=Err(1), b=Ok(1), c=0 +X=1, Y=2, a=Err(1), b=Ok(1), c=1 +X=1, Y=2, a=Err(1), b=Ok(1), c=2 +X=1, Y=3, a=Ok(2), b=Ok(1), c=0 +X=1, Y=3, a=Ok(2), b=Ok(1), c=0 +X=1, Y=3, a=Ok(2), b=Ok(1), c=1 +X=1, Y=3, a=Ok(2), b=Ok(1), c=2 +X=2, Y=3, a=Ok(2), b=Ok(1), c=3 +X=1, Y=3, a=Ok(2), b=Ok(1), c=3 Verification complete with 36 executions. No errors found. diff --git a/src/tools/miri/tests/genmc/pass/litmus/MPU_rels_acq.rs b/src/tools/miri/tests/genmc/pass/litmus/MPU_rels_acq.rs index 4f20b2cf9def2..5347114342183 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/MPU_rels_acq.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/MPU_rels_acq.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/MPU+rels+acq" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/MP_incMPs.rs b/src/tools/miri/tests/genmc/pass/litmus/MP_incMPs.rs index a08b7de27d13c..143fc6352f476 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/MP_incMPs.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/MP_incMPs.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/MP+incMP" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/MP_rels_acqf.rs b/src/tools/miri/tests/genmc/pass/litmus/MP_rels_acqf.rs index 19065d3308f8c..3ad6916d6bf94 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/MP_rels_acqf.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/MP_rels_acqf.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/MP+rels+acqf" test. 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/SB.rs b/src/tools/miri/tests/genmc/pass/litmus/SB.rs index 74d45c22a2953..2d017f5f30e60 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/SB.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/SB.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/SB" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/SB_2sc_scf.rs b/src/tools/miri/tests/genmc/pass/litmus/SB_2sc_scf.rs index ffc44de1bc7cf..ff6597a38004c 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/SB_2sc_scf.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/SB_2sc_scf.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/SB+2sc+scf" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/Z6_U.rs b/src/tools/miri/tests/genmc/pass/litmus/Z6_U.rs index cbbaa82d6fb5a..9f59cbc293c10 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/Z6_U.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/Z6_U.rs @@ -1,5 +1,4 @@ //@revisions: weak sc -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@[sc]compile-flags: -Zmiri-disable-weak-memory-emulation // Translated from GenMC's "litmus/Z6.U" test. diff --git a/src/tools/miri/tests/genmc/pass/litmus/Z6_U.sc.stderr b/src/tools/miri/tests/genmc/pass/litmus/Z6_U.sc.stderr index c8fbb8951a386..b2f15c2b1d3f8 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/Z6_U.sc.stderr +++ b/src/tools/miri/tests/genmc/pass/litmus/Z6_U.sc.stderr @@ -1,20 +1,20 @@ Running GenMC Verification... 
+a=1, b=1, X=1, Y=3 +a=1, b=0, X=1, Y=1 +a=1, b=1, X=1, Y=1 +a=1, b=1, X=1, Y=3 +a=3, b=1, X=1, Y=3 +a=1, b=0, X=1, Y=1 +a=1, b=1, X=1, Y=1 +a=3, b=0, X=1, Y=1 +a=3, b=1, X=1, Y=1 a=2, b=1, X=1, Y=3 a=4, b=1, X=1, Y=4 a=3, b=1, X=1, Y=3 -a=2, b=1, X=1, Y=2 a=2, b=0, X=1, Y=2 -a=1, b=1, X=1, Y=1 -a=1, b=0, X=1, Y=1 -a=4, b=1, X=1, Y=1 +a=2, b=1, X=1, Y=2 a=4, b=0, X=1, Y=1 -a=1, b=1, X=1, Y=3 -a=3, b=1, X=1, Y=3 -a=1, b=1, X=1, Y=1 +a=4, b=1, X=1, Y=1 a=1, b=0, X=1, Y=1 -a=3, b=1, X=1, Y=1 -a=3, b=0, X=1, Y=1 -a=1, b=1, X=1, Y=3 a=1, b=1, X=1, Y=1 -a=1, b=0, X=1, Y=1 Verification complete with 18 executions. No errors found. diff --git a/src/tools/miri/tests/genmc/pass/litmus/Z6_U.weak.stderr b/src/tools/miri/tests/genmc/pass/litmus/Z6_U.weak.stderr index 72c59d33f77cf..d92a93ff02c21 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/Z6_U.weak.stderr +++ b/src/tools/miri/tests/genmc/pass/litmus/Z6_U.weak.stderr @@ -1,24 +1,24 @@ Running GenMC Verification... +a=1, b=0, X=1, Y=3 +a=1, b=1, X=1, Y=3 +a=1, b=0, X=1, Y=1 +a=1, b=1, X=1, Y=1 +a=1, b=0, X=1, Y=3 +a=1, b=1, X=1, Y=3 +a=3, b=0, X=1, Y=3 +a=3, b=1, X=1, Y=3 +a=1, b=0, X=1, Y=1 +a=1, b=1, X=1, Y=1 +a=3, b=0, X=1, Y=1 +a=3, b=1, X=1, Y=1 a=2, b=1, X=1, Y=3 -a=4, b=1, X=1, Y=4 a=4, b=0, X=1, Y=4 +a=4, b=1, X=1, Y=4 a=3, b=1, X=1, Y=3 -a=2, b=1, X=1, Y=2 a=2, b=0, X=1, Y=2 -a=1, b=1, X=1, Y=1 -a=1, b=0, X=1, Y=1 -a=4, b=1, X=1, Y=1 +a=2, b=1, X=1, Y=2 a=4, b=0, X=1, Y=1 -a=1, b=1, X=1, Y=3 -a=1, b=0, X=1, Y=3 -a=3, b=1, X=1, Y=3 -a=3, b=0, X=1, Y=3 -a=1, b=1, X=1, Y=1 +a=4, b=1, X=1, Y=1 a=1, b=0, X=1, Y=1 -a=3, b=1, X=1, Y=1 -a=3, b=0, X=1, Y=1 -a=1, b=1, X=1, Y=3 -a=1, b=0, X=1, Y=3 a=1, b=1, X=1, Y=1 -a=1, b=0, X=1, Y=1 Verification complete with 22 executions. No errors found. 
diff --git a/src/tools/miri/tests/genmc/pass/litmus/Z6_acq.rs b/src/tools/miri/tests/genmc/pass/litmus/Z6_acq.rs index b00f3a59ce672..8390db963d011 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/Z6_acq.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/Z6_acq.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/Z6+acq" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/atomicpo.rs b/src/tools/miri/tests/genmc/pass/litmus/atomicpo.rs index 75be89893dab8..ccbe91749b502 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/atomicpo.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/atomicpo.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "litmus/atomicpo". #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/casdep.rs b/src/tools/miri/tests/genmc/pass/litmus/casdep.rs index 8b8f6e793c1f7..f2f15d1afd555 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/casdep.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/casdep.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "litmus/casdep". #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/ccr.rs b/src/tools/miri/tests/genmc/pass/litmus/ccr.rs index 4537f3d6830ce..dc803aa44efa1 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/ccr.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/ccr.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "litmus/ccr". 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/cii.rs b/src/tools/miri/tests/genmc/pass/litmus/cii.rs index 18f56860f9604..761e1cd021605 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/cii.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/cii.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "litmus/cii". #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/corr.rs b/src/tools/miri/tests/genmc/pass/litmus/corr.rs index b586e2e0fa8a8..6baebeddac808 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/corr.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/corr.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "CoRR" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/corr0.rs b/src/tools/miri/tests/genmc/pass/litmus/corr0.rs index 856d566ca8bd3..285282726ace7 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/corr0.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/corr0.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "CoRR0" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/corr1.rs b/src/tools/miri/tests/genmc/pass/litmus/corr1.rs index ccd849802911e..03aaabcf0581b 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/corr1.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/corr1.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "CoRR1" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/corr2.rs b/src/tools/miri/tests/genmc/pass/litmus/corr2.rs index 36616bf36371f..e86cf78109e3d 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/corr2.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/corr2.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "CoRR2" test. 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/corw.rs b/src/tools/miri/tests/genmc/pass/litmus/corw.rs index 9216a4f8368f6..053fb66fa42e8 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/corw.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/corw.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "CoRW" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/cowr.rs b/src/tools/miri/tests/genmc/pass/litmus/cowr.rs index 1c51f23a09c66..0bb9461a6f7a9 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/cowr.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/cowr.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "CoWR" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/cumul-release.rs b/src/tools/miri/tests/genmc/pass/litmus/cumul-release.rs index 4034f7634e870..66490ed51b464 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/cumul-release.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/cumul-release.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "litmus/cumul-release". #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/default.rs b/src/tools/miri/tests/genmc/pass/litmus/default.rs index 0ab26dce419ad..6a4d1a9cefc70 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/default.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/default.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/default" test. 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/detour.rs b/src/tools/miri/tests/genmc/pass/litmus/detour.rs index 85c456d5c54e5..d59b5efa286f8 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/detour.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/detour.rs @@ -1,4 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@revisions: join no_join // Translated from GenMC's "litmus/detour" test. diff --git a/src/tools/miri/tests/genmc/pass/litmus/fr_w_w_w_reads.rs b/src/tools/miri/tests/genmc/pass/litmus/fr_w_w_w_reads.rs index c8d3d409cf04d..04723fb928159 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/fr_w_w_w_reads.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/fr_w_w_w_reads.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "fr+w+w+w+reads" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/inc2w.rs b/src/tools/miri/tests/genmc/pass/litmus/inc2w.rs index eb84304a1986e..fee9eea2d3802 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/inc2w.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/inc2w.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's test "litmus/inc2w". 
#![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/inc_inc_RR_W_RR.rs b/src/tools/miri/tests/genmc/pass/litmus/inc_inc_RR_W_RR.rs index 40ca486318598..185ea189a415f 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/inc_inc_RR_W_RR.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/inc_inc_RR_W_RR.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - #![no_main] #[path = "../../../utils/genmc.rs"] diff --git a/src/tools/miri/tests/genmc/pass/litmus/riwi.rs b/src/tools/miri/tests/genmc/pass/litmus/riwi.rs index 49564c8e4fe0a..8284ec0895387 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/riwi.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/riwi.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // Translated from GenMC's "litmus/riwi" test. #![no_main] diff --git a/src/tools/miri/tests/genmc/pass/litmus/viktor-relseq.rs b/src/tools/miri/tests/genmc/pass/litmus/viktor-relseq.rs index 3256c9f421193..822a93ca97390 100644 --- a/src/tools/miri/tests/genmc/pass/litmus/viktor-relseq.rs +++ b/src/tools/miri/tests/genmc/pass/litmus/viktor-relseq.rs @@ -1,4 +1,4 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-genmc-estimate +//@compile-flags: -Zmiri-genmc-estimate // Translated from GenMC's "litmus/viktor-relseq" test. // diff --git a/src/tools/miri/tests/genmc/pass/shims/mutex_deadlock.rs b/src/tools/miri/tests/genmc/pass/shims/mutex_deadlock.rs index df47fbfbc1676..e2337c2ed3cd0 100644 --- a/src/tools/miri/tests/genmc/pass/shims/mutex_deadlock.rs +++ b/src/tools/miri/tests/genmc/pass/shims/mutex_deadlock.rs @@ -1,4 +1,4 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-genmc-verbose +//@compile-flags: -Zmiri-genmc-verbose //@normalize-stderr-test: "Verification took .*s" -> "Verification took [TIME]s" // Test that we can detect a deadlock involving `std::sync::Mutex` in GenMC mode. 
diff --git a/src/tools/miri/tests/genmc/pass/shims/mutex_simple.rs b/src/tools/miri/tests/genmc/pass/shims/mutex_simple.rs index 1f8bc81d85eb5..19a955bfcf93e 100644 --- a/src/tools/miri/tests/genmc/pass/shims/mutex_simple.rs +++ b/src/tools/miri/tests/genmc/pass/shims/mutex_simple.rs @@ -1,11 +1,10 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-genmc-verbose +//@compile-flags: -Zmiri-genmc-verbose //@normalize-stderr-test: "Verification took .*s" -> "Verification took [TIME]s" // Test various features of the `std::sync::Mutex` API with GenMC. // Miri running with GenMC intercepts the Mutex functions `lock`, `try_lock` and `unlock`, instead of running their actual implementation. // This interception should not break any functionality. // -// FIXME(genmc): Once GenMC supports mixed size accesses, add stack/heap allocated Mutexes to the test. // FIXME(genmc): Once the actual implementation of mutexes can be used in GenMC mode and there is a setting to disable Mutex interception: Add test revision without interception. // // Miri provides annotations to GenMC for the condition required to unblock a thread blocked on a Mutex lock call. @@ -25,7 +24,6 @@ use crate::genmc::*; const REPS: u64 = 3; static LOCK: Mutex = Mutex::new(0); -static OTHER_LOCK: Mutex = Mutex::new(1234); #[unsafe(no_mangle)] fn miri_start(_argc: isize, _argv: *const *const u8) -> isize { @@ -35,7 +33,8 @@ fn miri_start(_argc: isize, _argv: *const *const u8) -> isize { fn main_() { // Two mutexes should not interfere, holding this guard does not affect the other mutex. - let other_guard = OTHER_LOCK.lock().unwrap(); + let other_lock = Mutex::new(1234); + let other_guard = other_lock.lock().unwrap(); let guard = LOCK.lock().unwrap(); // Trying to lock should fail if the mutex is already held. 
diff --git a/src/tools/miri/tests/genmc/pass/shims/spinloop_assume.rs b/src/tools/miri/tests/genmc/pass/shims/spinloop_assume.rs index cf19e92994421..5a4d05370cfcd 100644 --- a/src/tools/miri/tests/genmc/pass/shims/spinloop_assume.rs +++ b/src/tools/miri/tests/genmc/pass/shims/spinloop_assume.rs @@ -1,5 +1,5 @@ //@ revisions: bounded123 bounded321 replaced123 replaced321 -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows -Zmiri-genmc-verbose +//@compile-flags: -Zmiri-genmc-verbose //@normalize-stderr-test: "Verification took .*s" -> "Verification took [TIME]s" // This test uses GenMC assume statements to bound or replace spinloops. diff --git a/src/tools/miri/tests/genmc/pass/std/arc.rs b/src/tools/miri/tests/genmc/pass/std/arc.rs index addf6408c006f..52fb5f50d5087 100644 --- a/src/tools/miri/tests/genmc/pass/std/arc.rs +++ b/src/tools/miri/tests/genmc/pass/std/arc.rs @@ -1,4 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows //@revisions: check_count try_upgrade // Check that various operations on `std::sync::Arc` are handled properly in GenMC mode. diff --git a/src/tools/miri/tests/genmc/pass/std/empty_main.rs b/src/tools/miri/tests/genmc/pass/std/empty_main.rs index 2ffc3388fb36c..24a058c228529 100644 --- a/src/tools/miri/tests/genmc/pass/std/empty_main.rs +++ b/src/tools/miri/tests/genmc/pass/std/empty_main.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // A lot of code runs before main, which we should be able to handle in GenMC mode. 
fn main() {} diff --git a/src/tools/miri/tests/genmc/pass/std/spawn_std_threads.rs b/src/tools/miri/tests/genmc/pass/std/spawn_std_threads.rs index dadbee47b9860..9a8828fe7d25b 100644 --- a/src/tools/miri/tests/genmc/pass/std/spawn_std_threads.rs +++ b/src/tools/miri/tests/genmc/pass/std/spawn_std_threads.rs @@ -1,5 +1,3 @@ -//@compile-flags: -Zmiri-genmc -Zmiri-disable-stacked-borrows - // We should be able to spawn and join standard library threads in GenMC mode. // Since these threads do nothing, we should only explore 1 program execution. diff --git a/src/tools/miri/tests/genmc/pass/std/thread_locals.rs b/src/tools/miri/tests/genmc/pass/std/thread_locals.rs index d76975d2e92c2..21023f79641c5 100644 --- a/src/tools/miri/tests/genmc/pass/std/thread_locals.rs +++ b/src/tools/miri/tests/genmc/pass/std/thread_locals.rs @@ -1,4 +1,4 @@ -//@compile-flags: -Zmiri-ignore-leaks -Zmiri-genmc -Zmiri-disable-stacked-borrows +//@compile-flags: -Zmiri-ignore-leaks use std::alloc::{Layout, alloc}; use std::cell::Cell; diff --git a/src/tools/miri/tests/pass-dep/libc/libc-blocking-io-same-fd.rs b/src/tools/miri/tests/pass-dep/libc/libc-blocking-io-same-fd.rs index d4bae144f2213..01bda8a4f15f6 100644 --- a/src/tools/miri/tests/pass-dep/libc/libc-blocking-io-same-fd.rs +++ b/src/tools/miri/tests/pass-dep/libc/libc-blocking-io-same-fd.rs @@ -11,7 +11,7 @@ use libc_utils::*; // same fd at the same time. 
fn main() { - let (server_sockfd, addr) = net::make_listener_ipv4(0).unwrap(); + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; @@ -25,30 +25,43 @@ fn main() { let mut buffer = [22u8; 128]; let bytes_written = unsafe { - errno_result(net::send_all(peerfd, buffer.as_mut_ptr().cast(), buffer.len(), 0)) - .unwrap() + errno_result(libc_utils::write_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::NoRetry, + |buf, len| libc::send(peerfd, buf, len, 0), + )) + .unwrap() }; assert_eq!(bytes_written as usize, 128); }); - net::connect_ipv4(client_sockfd, addr); + net::connect_ipv4(client_sockfd, addr).unwrap(); let reader_thread = thread::spawn(move || { let mut buffer = [0u8; 8]; - let bytes_read = unsafe { - errno_result(net::recv_all(client_sockfd, buffer.as_mut_ptr().cast(), buffer.len(), 0)) - .unwrap() + unsafe { + errno_result(libc_utils::read_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::NoRetry, + |buf, count| libc::recv(client_sockfd, buf, count, 0), + )) + .unwrap() }; - assert_eq!(bytes_read, 8); assert_eq!(&buffer, &[22u8; 8]); }); let mut buffer = [0u8; 8]; - let bytes_read = unsafe { - errno_result(net::recv_all(client_sockfd, buffer.as_mut_ptr().cast(), buffer.len(), 0)) - .unwrap() + unsafe { + errno_result(libc_utils::read_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::NoRetry, + |buf, count| libc::recv(client_sockfd, buf, count, 0), + )) + .unwrap() }; - assert_eq!(bytes_read, 8); assert_eq!(&buffer, &[22u8; 8]); reader_thread.join().unwrap(); diff --git a/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking.rs b/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking.rs new file mode 100644 index 0000000000000..f43847733bc70 --- /dev/null +++ b/src/tools/miri/tests/pass-dep/libc/libc-socket-no-blocking.rs @@ -0,0 +1,662 @@ +//@ignore-target: windows 
+//@compile-flags: -Zmiri-disable-isolation +//@revisions: windows_host unix_host +//@[unix_host] ignore-host: windows +//@[windows_host] only-host: windows + +#![feature(io_error_inprogress)] + +#[path = "../../utils/libc.rs"] +mod libc_utils; + +use std::io::ErrorKind; +use std::thread; +use std::time::Duration; + +use libc_utils::*; + +const TEST_BYTES: &[u8] = b"these are some test bytes!"; + +fn main() { + test_fcntl_nonblock_opt(); + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "solaris", + target_os = "illumos" + ))] + test_sock_nonblock_opt(); + #[cfg(not(any(target_os = "solaris", target_os = "illumos")))] + test_ioctl_fionbio_op(); + + test_accept_nonblock(); + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "solaris", + target_os = "illumos" + ))] + test_accept4_sock_nonblock_opt(); + test_connect_nonblock(); + test_send_recv_nonblock(); + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "solaris", + target_os = "illumos" + ))] + test_send_recv_dontwait(); + test_write_read_nonblock(); + + test_getpeername_ipv4_nonblock(); + test_getpeername_ipv4_nonblock_no_peer(); +} + +/// Test that setting the O_NONBLOCK flag changes the blocking state of a socket. +fn test_fcntl_nonblock_opt() { + let sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + unsafe { + // Change socket to be non-blocking. + errno_check(libc::fcntl(sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + let flags = unsafe { errno_result(libc::fcntl(sockfd, libc::F_GETFL, 0)).unwrap() }; + // Ensure that socket is really non-blocking. + assert_eq!(flags & libc::O_NONBLOCK, libc::O_NONBLOCK); + + unsafe { + // Change socket back to be blocking. 
+ errno_check(libc::fcntl(sockfd, libc::F_SETFL, 0)); + } + + let flags = unsafe { errno_result(libc::fcntl(sockfd, libc::F_GETFL, 0)).unwrap() }; + // Ensure that socket is really blocking. + assert_eq!(flags & libc::O_NONBLOCK, 0); +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "solaris", + target_os = "illumos" +))] +/// Test creating a non-blocking socket by using the SOCK_NONBLOCK option +/// for the `socket` syscall. +fn test_sock_nonblock_opt() { + let sockfd = unsafe { + errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0)) + .unwrap() + }; + + let flags = unsafe { errno_result(libc::fcntl(sockfd, libc::F_GETFL, 0)).unwrap() }; + // Ensure that socket is really non-blocking. + assert_eq!(flags & libc::O_NONBLOCK, libc::O_NONBLOCK); +} + +#[cfg(not(any(target_os = "solaris", target_os = "illumos")))] +/// Test changing the blocking state of a socket using the `ioctl(fd, FIONBIO, ...)` +/// syscall. +fn test_ioctl_fionbio_op() { + let sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + unsafe { + // Change socket to be non-blocking. + let mut value = 1 as libc::c_int; + errno_check(libc::ioctl(sockfd, libc::FIONBIO, &mut value)); + } + + let flags = unsafe { errno_result(libc::fcntl(sockfd, libc::F_GETFL, 0)).unwrap() }; + // Ensure that socket is really non-blocking. + assert_eq!(flags & libc::O_NONBLOCK, libc::O_NONBLOCK); + + unsafe { + // Change socket back to be blocking. + let mut value = 0 as libc::c_int; + errno_check(libc::ioctl(sockfd, libc::FIONBIO, &mut value)); + } + + let flags = unsafe { errno_result(libc::fcntl(sockfd, libc::F_GETFL, 0)).unwrap() }; + // Ensure that socket is really blocking. + assert_eq!(flags & libc::O_NONBLOCK, 0); +} + +/// Test that nonblocking TCP server sockets return [`ErrorKind::WouldBlock`] when trying +/// to accept when no incoming connection exists. 
This also tests that nonblocking server sockets +/// are still able to accept incoming connections should they already exist before the `accept` or +/// `accept4` syscall is called. +fn test_accept_nonblock() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + unsafe { + // Change server socket to be non-blocking. + errno_check(libc::fcntl(server_sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + // This should fail as we don't have an incoming connection for this address. + let err = net::accept_ipv4(server_sockfd).unwrap_err(); + // Assert that either EAGAIN or EWOULDBLOCK was returned. + assert_eq!(err.kind(), ErrorKind::WouldBlock); + + // Spawn the server thread. + let server_thread = thread::spawn(move || { + // Instantly yield to main thread to ensure that the `connect` syscall + // was called before we call the `accept` on the server. + thread::sleep(Duration::from_millis(10)); + + net::accept_ipv4(server_sockfd).unwrap(); + }); + + net::connect_ipv4(client_sockfd, addr).unwrap(); + + server_thread.join().unwrap(); +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "solaris", + target_os = "illumos" +))] +/// Test that calling `accept4` with the SOCK_NONBLOCK flag produces +/// a non-blocking peer socket. +fn test_accept4_sock_nonblock_opt() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + // Spawn the server thread. + let server_thread = thread::spawn(move || { + let (peerfd, _) = net::sockname_ipv4(|storage, len| unsafe { + libc::accept4(server_sockfd, storage, len, libc::SOCK_NONBLOCK) + }) + .unwrap(); + + let flags = unsafe { errno_result(libc::fcntl(peerfd, libc::F_GETFL, 0)).unwrap() }; + + // Ensure that peer socket is non-blocking. 
+ assert_eq!(flags & libc::O_NONBLOCK, libc::O_NONBLOCK); + + let mut buffer = [0u8; 8]; + // Reading from a socket should return EWOULDBLOCK when there is no + // data written into it. + let err = unsafe { + errno_result(libc::read(peerfd, buffer.as_mut_ptr().cast(), buffer.len())).unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock); + }); + + net::connect_ipv4(client_sockfd, addr).unwrap(); + + server_thread.join().unwrap(); +} + +/// Test that connecting to a server socket works when the client +/// socket is non-blocking before the `connect` call. +fn test_connect_nonblock() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + unsafe { + // Change client socket to be non-blocking. + errno_check(libc::fcntl(client_sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + // Spawn the server thread. + let server_thread = thread::spawn(move || { + net::accept_ipv4(server_sockfd).unwrap(); + }); + + // Yield to server thread to ensure that it's currently accepting. + thread::sleep(Duration::from_millis(10)); + + // Non-blocking connects always "fail" with EINPROGRESS. + let err = net::connect_ipv4(client_sockfd, addr).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InProgress); + + loop { + let result = net::sockname_ipv4(|storage, len| unsafe { + libc::getpeername(client_sockfd, storage, len) + }); + match result { + Ok(_) => { + // The client is now connected. + break; + } + Err(err) if err.kind() == ErrorKind::NotConnected => { + // The client is still connecting. + thread::sleep(Duration::from_millis(10)); + } + Err(err) => panic!("unexpected error whilst ensuring connection: {err}"), + } + } + + server_thread.join().unwrap(); +} + +/// Test sending bytes into and receiving bytes from a connected stream without blocking. 
+fn test_send_recv_nonblock() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + // Spawn the server thread. + let server_thread = thread::spawn(move || { + let (peerfd, _) = net::accept_ipv4(server_sockfd).unwrap(); + // `peerfd` is a blocking socket now. But that's okay, the client still does non-blocking + // reads/writes. + + // Yield back to client so that it starts receiving before we start sending. + thread::sleep(Duration::from_millis(10)); + + unsafe { + errno_result(libc_utils::write_all_generic( + TEST_BYTES.as_ptr().cast(), + TEST_BYTES.len(), + libc_utils::NoRetry, + |buf, count| libc::send(peerfd, buf, count, 0), + )) + .unwrap() + }; + + // The buffer should contain `TEST_BYTES` at the beginning. + // This will block until the client sent us this data. + let mut buffer = [0; TEST_BYTES.len()]; + unsafe { + errno_result(libc_utils::read_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::NoRetry, + |buf, count| libc::recv(peerfd, buf, count, 0), + )) + .unwrap() + }; + assert_eq!(&buffer, TEST_BYTES); + }); + + net::connect_ipv4(client_sockfd, addr).unwrap(); + + unsafe { + // Change client socket to be non-blocking. + errno_check(libc::fcntl(client_sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + // We are connected and the server socket is not writing. + + let mut buffer = [0; TEST_BYTES.len()]; + // Receiving from a socket when the peer is not writing is + // not possible without blocking. + let err = unsafe { + errno_result(libc::recv(client_sockfd, buffer.as_mut_ptr().cast(), buffer.len(), 0)) + .unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock); + + // Try to receive bytes from the peer socket without blocking. + // Since the peer socket might do partial writes, we might need to + // sleep multiple times until we received everything. 
+ + unsafe { + errno_result(libc_utils::read_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::RetryAfter(Duration::from_millis(10)), + |buf, count| libc::recv(client_sockfd, buf, count, 0), + )) + .unwrap() + }; + assert_eq!(&buffer, TEST_BYTES); + + // Test non-blocking writing. + + // Sending into the empty buffer should succeed without blocking. + unsafe { + errno_result(libc_utils::write_all_generic( + TEST_BYTES.as_ptr().cast(), + TEST_BYTES.len(), + libc_utils::NoRetry, + |buf, count| libc::send(client_sockfd, buf, count, 0), + )) + .unwrap() + }; + + if !cfg!(windows_host) { + // Keep sending data until the buffer is full and we block. + // We cannot test this on Windows since there apparently the send buffer + // never fills up, at least for localhost connections. + + let fill_buf = [1u8; 5_000_000]; + // This fills the socket receive buffer and thus should start blocking. + let err = unsafe { + errno_result(libc_utils::write_all_generic( + fill_buf.as_ptr().cast(), + fill_buf.len(), + libc_utils::NoRetry, + |buf, count| libc::send(client_sockfd, buf, count, 0), + )) + .unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock) + } + + server_thread.join().unwrap(); +} + +#[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "solaris", + target_os = "illumos" +))] +/// Test sending bytes into and receiving bytes from a connected stream without blocking. +/// Instead of using non-blocking sockets, we test whether it works with blocking sockets +/// when passing the `libc::MSG_DONTWAIT` flag to the send and receive calls. +fn test_send_recv_dontwait() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + // Spawn the server thread. 
+ let server_thread = thread::spawn(move || { + let (peerfd, _) = net::accept_ipv4(server_sockfd).unwrap(); + // Similar to above we use blocking operations on the server side. + + // Yield back to client so that it starts receiving before we start sending. + thread::sleep(Duration::from_millis(10)); + + unsafe { + errno_result(libc_utils::write_all_generic( + TEST_BYTES.as_ptr().cast(), + TEST_BYTES.len(), + libc_utils::NoRetry, + |buf, count| libc::send(peerfd, buf, count, 0), + )) + .unwrap() + }; + + // The buffer should contain `TEST_BYTES` at the beginning. + // This will block until the client sent us this data. + let mut buffer = [0; TEST_BYTES.len()]; + unsafe { + errno_result(libc_utils::read_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::NoRetry, + |buf, count| libc::recv(peerfd, buf, count, 0), + )) + .unwrap() + }; + assert_eq!(&buffer, TEST_BYTES); + }); + + net::connect_ipv4(client_sockfd, addr).unwrap(); + + // We are connected and the server socket is not writing. + + let mut buffer = [0; TEST_BYTES.len()]; + // Receiving from a socket when the peer is not writing is + // not possible without blocking. + let err = unsafe { + errno_result(libc::recv( + client_sockfd, + buffer.as_mut_ptr().cast(), + buffer.len(), + libc::MSG_DONTWAIT, + )) + .unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock); + + // Try to receive bytes from the peer socket without blocking. + // Since the peer socket might do partial writes, we might need to + // sleep multiple times until we received everything. + + unsafe { + errno_result(libc_utils::read_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::RetryAfter(Duration::from_millis(10)), + |buf, count| libc::recv(client_sockfd, buf, count, libc::MSG_DONTWAIT), + )) + .unwrap() + }; + assert_eq!(&buffer, TEST_BYTES); + + // Test non-blocking writing. + + // Sending into the empty buffer should succeed without blocking. 
+ unsafe { + errno_result(libc_utils::write_all_generic( + TEST_BYTES.as_ptr().cast(), + TEST_BYTES.len(), + libc_utils::NoRetry, + |buf, count| libc::send(client_sockfd, buf, count, libc::MSG_DONTWAIT), + )) + .unwrap() + }; + + if !cfg!(windows_host) { + // Keep sending data until the buffer is full and we block. + // We cannot test this on Windows since there apparently the send buffer + // never fills up, at least for localhost connections. + + let fill_buf = [1u8; 5_000_000]; + // This fills the socket receive buffer and thus should start blocking. + let err = unsafe { + errno_result(libc_utils::write_all_generic( + fill_buf.as_ptr().cast(), + fill_buf.len(), + libc_utils::NoRetry, + |buf, count| libc::send(client_sockfd, buf, count, libc::MSG_DONTWAIT), + )) + .unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock) + } + + server_thread.join().unwrap(); +} + +/// Test writing bytes into and reading bytes from a connected stream without blocking. +fn test_write_read_nonblock() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + // Spawn the server thread. + let server_thread = thread::spawn(move || { + let (peerfd, _) = net::accept_ipv4(server_sockfd).unwrap(); + // Similar to above we use blocking operations on the server side. + + // Yield back to client so that it starts receiving before we start sending. + thread::sleep(Duration::from_millis(10)); + + let bytes_written = unsafe { + errno_result(libc_utils::write_all( + peerfd, + TEST_BYTES.as_ptr().cast(), + TEST_BYTES.len(), + )) + .unwrap() + }; + assert_eq!(bytes_written as usize, TEST_BYTES.len()); + + // The buffer should contain `TEST_BYTES` at the beginning. + // This will block until the client sent us this data. 
+ let mut buffer = [0; TEST_BYTES.len()]; + unsafe { + errno_result(libc_utils::read_all(peerfd, buffer.as_mut_ptr().cast(), buffer.len())) + .unwrap() + }; + assert_eq!(&buffer, TEST_BYTES); + }); + + net::connect_ipv4(client_sockfd, addr).unwrap(); + + unsafe { + // Change client socket to be non-blocking. + errno_check(libc::fcntl(client_sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + // We are connected and the server socket is not writing. + + let mut buffer = [0; TEST_BYTES.len()]; + // Reading from a socket when the peer is not writing is + // not possible without blocking. + let err = unsafe { + errno_result(libc::read( + client_sockfd, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + )) + .unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock); + + // Try to read bytes from the peer socket without blocking. + // Since the peer socket might do partial writes, we might need to + // sleep multiple times until we read everything. + + unsafe { + errno_result(libc_utils::read_all_generic( + buffer.as_mut_ptr().cast(), + buffer.len(), + libc_utils::RetryAfter(Duration::from_millis(10)), + |buf, count| libc::read(client_sockfd, buf, count), + )) + .unwrap() + }; + assert_eq!(&buffer, TEST_BYTES); + + // Now we test non-blocking writing. + + // Writing into the empty buffer should succeed without blocking. + let bytes_written = unsafe { + errno_result(libc_utils::write_all( + client_sockfd, + TEST_BYTES.as_ptr().cast(), + TEST_BYTES.len(), + )) + .unwrap() + }; + assert_eq!(bytes_written as usize, TEST_BYTES.len()); + + if !cfg!(windows_host) { + // Keep sending data until the buffer is full and we block. + // We cannot test this on Windows since there apparently the send buffer + // never fills up, at least for localhost connections. + + let fill_buf = [1u8; 5_000_000]; + // This fills the socket receive buffer and thus should start blocking. 
+ let err = unsafe { + errno_result(libc_utils::write_all_generic( + fill_buf.as_ptr().cast(), + fill_buf.len(), + libc_utils::NoRetry, + |buf, count| libc::write(client_sockfd, buf, count), + )) + .unwrap_err() + }; + assert_eq!(err.kind(), ErrorKind::WouldBlock) + } + + server_thread.join().unwrap(); +} + +/// Test that the `getpeername` syscall successfully returns the peer address +/// for a non-blocking IPv4 socket whose connection has been successfully +/// established before calling the syscall. +fn test_getpeername_ipv4_nonblock() { + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + unsafe { + // Change client socket to be non-blocking. + errno_check(libc::fcntl(client_sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + // Spawn the server thread. + let server_thread = thread::spawn(move || { + net::accept_ipv4(server_sockfd).unwrap(); + }); + + // Yield to server thread to ensure that it's currently accepting. + thread::sleep(Duration::from_millis(10)); + + // Non-blocking connects always "fail" with EINPROGRESS. + let err = net::connect_ipv4(client_sockfd, addr).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InProgress); + + loop { + let peername_result = net::sockname_ipv4(|storage, len| unsafe { + libc::getpeername(client_sockfd, storage, len) + }); + + match peername_result { + Ok((_, peer_addr)) => { + assert_eq!(addr.sin_family, peer_addr.sin_family); + assert_eq!(addr.sin_port, peer_addr.sin_port); + assert_eq!(addr.sin_addr.s_addr, peer_addr.sin_addr.s_addr); + break; + } + Err(err) if err.kind() == ErrorKind::NotConnected => { + // Connection is not yet established; wait and retry later. 
+ thread::sleep(Duration::from_millis(10)) + } + Err(err) => { + panic!("error whilst getting peername: {err}") + } + } + } + + server_thread.join().unwrap(); +} + +/// Test that the `getpeername` syscall returns ENOTCONN +/// for a non-blocking IPv4 socket which is stuck at +/// connecting to the remote address. +fn test_getpeername_ipv4_nonblock_no_peer() { + let client_sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + unsafe { + // Change client socket to be non-blocking. + errno_check(libc::fcntl(client_sockfd, libc::F_SETFL, libc::O_NONBLOCK)); + } + + // We cannot attempt to connect to a localhost address because + // it could be the case that a socket from another test is + // currently listening on `localhost:12321` because we bind to + // random ports everywhere. For `192.0.2.1` we know that nothing is + // listening because it's a blackhole address: + // + // The port `12321` is just a random non-zero port because Windows + // and Apple hosts return EADDRNOTAVAIL when attempting to connect to + // a zero port. + let addr = net::sock_addr_ipv4([192, 0, 2, 1], 12321); + + // Non-blocking connect should fail with EINPROGRESS. + let err = net::connect_ipv4(client_sockfd, addr).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InProgress); + + // Since we're never accepting the connection, the socket should never be + // successfully connected and thus we should be unable to read the peername. 
+ let Err(err) = net::sockname_ipv4(|storage, len| unsafe { + libc::getpeername(client_sockfd, storage, len) + }) else { + unreachable!() + }; + assert_eq!(err.kind(), ErrorKind::NotConnected); +} diff --git a/src/tools/miri/tests/pass-dep/libc/libc-socket.rs b/src/tools/miri/tests/pass-dep/libc/libc-socket.rs index 35830d9bf0622..64c1e8d4c3a6e 100644 --- a/src/tools/miri/tests/pass-dep/libc/libc-socket.rs +++ b/src/tools/miri/tests/pass-dep/libc/libc-socket.rs @@ -1,8 +1,6 @@ //@ignore-target: windows # No libc socket on Windows //@compile-flags: -Zmiri-disable-isolation -#![feature(io_error_inprogress)] - #[path = "../../utils/libc.rs"] mod libc_utils; #[path = "../../utils/mod.rs"] @@ -18,7 +16,7 @@ use utils::check_nondet; const TEST_BYTES: &[u8] = b"these are some test bytes!"; fn main() { - test_socket_close(); + test_create_close(); test_bind_ipv4(); test_bind_ipv4_reuseaddr(); test_set_reuseaddr_invalid_len(); @@ -50,11 +48,17 @@ fn main() { test_getpeername_ipv6(); } -fn test_socket_close() { - unsafe { - let sockfd = errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap(); - errno_check(libc::close(sockfd)); - } +/// Test creating a socket and then closing it afterwards. +fn test_create_close() { + let sockfd = + unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; + + let flags = unsafe { errno_result(libc::fcntl(sockfd, libc::F_GETFL, 0)).unwrap() }; + + // Ensure that socket is initially blocking. 
+ assert_eq!(flags & libc::O_NONBLOCK, 0); + + unsafe { errno_check(libc::close(sockfd)) }; } fn test_bind_ipv4() { @@ -193,13 +197,18 @@ fn test_listen() { /// - Connecting when the server is already accepting /// - Accepting when there is already an incoming connection fn test_accept_connect() { - let (server_sockfd, addr) = net::make_listener_ipv4(0).unwrap(); + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; // Spawn the server thread. let server_thread = thread::spawn(move || { - net::accept_ipv4(server_sockfd).unwrap(); + let (peerfd, _) = net::accept_ipv4(server_sockfd).unwrap(); + + let flags = unsafe { errno_result(libc::fcntl(peerfd, libc::F_GETFL, 0)).unwrap() }; + + // Ensure that peer socket is blocking. + assert_eq!(flags & libc::O_NONBLOCK, 0); // Yield back to the client thread to test whether calling `connect` first also // works. @@ -213,7 +222,7 @@ fn test_accept_connect() { thread::sleep(Duration::from_millis(10)); // Test connecting to an already accepting server. - net::connect_ipv4(client_sockfd, addr); + net::connect_ipv4(client_sockfd, addr).unwrap(); // Server thread should now be in its `sleep`. // Test connecting when there is no actively ongoing `accept`. @@ -221,7 +230,7 @@ fn test_accept_connect() { let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; - net::connect_ipv4(client_sockfd, addr); + net::connect_ipv4(client_sockfd, addr).unwrap(); server_thread.join().unwrap(); } @@ -231,7 +240,7 @@ fn test_accept_connect() { /// We especially want to test that the peeking doesn't remove the bytes from /// the queue. 
fn test_send_peek_recv() { - let (server_sockfd, addr) = net::make_listener_ipv4(0).unwrap(); + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; @@ -240,19 +249,18 @@ fn test_send_peek_recv() { let (peerfd, _) = net::accept_ipv4(server_sockfd).unwrap(); // Write the bytes into the stream. - let bytes_written = unsafe { - errno_result(libc_utils::net::send_all( - peerfd, + unsafe { + errno_result(libc_utils::write_all_generic( TEST_BYTES.as_ptr().cast(), TEST_BYTES.len(), - 0, + libc_utils::NoRetry, + |buf, count| libc::send(peerfd, buf, count, 0), )) .unwrap() }; - assert_eq!(bytes_written as usize, TEST_BYTES.len()); }); - net::connect_ipv4(client_sockfd, addr); + net::connect_ipv4(client_sockfd, addr).unwrap(); let mut buffer = [0; TEST_BYTES.len()]; let bytes_read = unsafe { @@ -273,17 +281,15 @@ fn test_send_peek_recv() { // able to read the same bytes again into a new buffer. let mut buffer = [0; TEST_BYTES.len()]; - let bytes_read = unsafe { - errno_result(libc_utils::net::recv_all( - client_sockfd, + unsafe { + errno_result(libc_utils::read_all_generic( buffer.as_mut_ptr().cast(), buffer.len(), - 0, + libc_utils::NoRetry, + |buf, count| libc::recv(client_sockfd, buf, count, 0), )) .unwrap() }; - - assert_eq!(bytes_read as usize, TEST_BYTES.len()); assert_eq!(&buffer, TEST_BYTES); server_thread.join().unwrap(); @@ -291,7 +297,7 @@ fn test_send_peek_recv() { /// Test that we actually do partial sends and partial receives for sockets. 
fn test_partial_send_recv() { - let (server_sockfd, addr) = net::make_listener_ipv4(0).unwrap(); + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; @@ -313,7 +319,7 @@ fn test_partial_send_recv() { }); }); - net::connect_ipv4(client_sockfd, addr); + net::connect_ipv4(client_sockfd, addr).unwrap(); // Ensure we sometimes do incomplete writes. check_nondet(|| { @@ -325,11 +331,10 @@ fn test_partial_send_recv() { let buffer = [0u8; 100_000]; // Write a lot of bytes into the socket such that we can test // incomplete reads. - let bytes_written = unsafe { + unsafe { errno_result(libc_utils::write_all(client_sockfd, buffer.as_ptr().cast(), buffer.len())) .unwrap() }; - assert_eq!(bytes_written as usize, buffer.len()); server_thread.join().unwrap(); } @@ -339,7 +344,7 @@ fn test_partial_send_recv() { /// We want to test this because `write` and `read` should be the same as /// `send` and `recv` with zero flags. fn test_write_read() { - let (server_sockfd, addr) = net::make_listener_ipv4(0).unwrap(); + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; @@ -359,15 +364,13 @@ fn test_write_read() { assert_eq!(bytes_written as usize, TEST_BYTES.len()); }); - net::connect_ipv4(client_sockfd, addr); + net::connect_ipv4(client_sockfd, addr).unwrap(); let mut buffer = [0; TEST_BYTES.len()]; - let bytes_read = unsafe { + unsafe { errno_result(libc_utils::read_all(client_sockfd, buffer.as_mut_ptr().cast(), buffer.len())) .unwrap() }; - - assert_eq!(bytes_read as usize, TEST_BYTES.len()); assert_eq!(&buffer, TEST_BYTES); server_thread.join().unwrap(); @@ -484,14 +487,14 @@ fn test_getsockname_ipv6() { /// For a connected socket, the `getpeername` syscall should /// return the same address as the socket was connected to. 
fn test_getpeername_ipv4() { - let (server_sockfd, addr) = net::make_listener_ipv4(0).unwrap(); + let (server_sockfd, addr) = net::make_listener_ipv4().unwrap(); let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0)).unwrap() }; // Spawn the server thread. let server_thread = thread::spawn(move || net::accept_ipv4(server_sockfd).unwrap()); - net::connect_ipv4(client_sockfd, addr); + net::connect_ipv4(client_sockfd, addr).unwrap(); let (_, peer_addr) = net::sockname_ipv4(|storage, len| unsafe { libc::getpeername(client_sockfd, storage, len) @@ -509,14 +512,14 @@ fn test_getpeername_ipv4() { /// For a connected socket, the `getpeername` syscall should /// return the same address as the socket was connected to. fn test_getpeername_ipv6() { - let (server_sockfd, addr) = net::make_listener_ipv6(0).unwrap(); + let (server_sockfd, addr) = net::make_listener_ipv6().unwrap(); let client_sockfd = unsafe { errno_result(libc::socket(libc::AF_INET6, libc::SOCK_STREAM, 0)).unwrap() }; // Spawn the server thread. 
let server_thread = thread::spawn(move || net::accept_ipv6(server_sockfd).unwrap()); - net::connect_ipv6(client_sockfd, addr); + net::connect_ipv6(client_sockfd, addr).unwrap(); let (_, peer_addr) = net::sockname_ipv6(|storage, len| unsafe { libc::getpeername(client_sockfd, storage, len) diff --git a/src/tools/miri/tests/pass/shims/socket-no-blocking.rs b/src/tools/miri/tests/pass/shims/socket-no-blocking.rs new file mode 100644 index 0000000000000..18106f113a171 --- /dev/null +++ b/src/tools/miri/tests/pass/shims/socket-no-blocking.rs @@ -0,0 +1,92 @@ +//@ignore-target: windows # No libc socket on Windows +//@compile-flags: -Zmiri-disable-isolation -Zmiri-fixed-schedule + +use std::io::{ErrorKind, Read, Write}; +use std::net::{TcpListener, TcpStream}; +use std::thread; + +const TEST_BYTES: &[u8] = b"these are some test bytes!"; + +fn main() { + test_accept_nonblock(); + test_send_recv_nonblock(); +} + +/// Test that nonblocking TCP server sockets return [`ErrorKind::WouldBlock`] when trying +/// to accept when no incoming connection exists. This also tests that nonblocking server sockets +/// are still able to accept incoming connections should they already exist before [`TcpListener::accept`] +/// is called. +fn test_accept_nonblock() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + // Make server non-blocking. + listener.set_nonblocking(true).unwrap(); + // Get local address with randomized port to know where + // we need to connect to. + let address = listener.local_addr().unwrap(); + + // Accepting when no incoming connecting exists should block. + let err = listener.accept().unwrap_err(); + assert_eq!(err.kind(), ErrorKind::WouldBlock); + + // Start server thread. + let handle = thread::spawn(move || { + // Accepting when there is an existing incoming connection should + // succeed without blocking. + + let (_stream, _peer_addr) = listener.accept().unwrap(); + }); + + // The connect is blocking and thus we yield to the server thread. 
+ let _stream = TcpStream::connect(address).unwrap(); + + handle.join().unwrap(); +} + +/// Test sending bytes into and receiving bytes from a connected stream without blocking. +fn test_send_recv_nonblock() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + // Get local address with randomized port to know where + // we need to connect to. + let address = listener.local_addr().unwrap(); + + // Start server thread. + let handle = thread::spawn(move || { + let (mut stream, _addr) = listener.accept().unwrap(); + + // Yield back to client thread to ensure that the first read + // is before we write anything into the socket. + thread::yield_now(); + + stream.write_all(TEST_BYTES).unwrap(); + }); + + // The connect is blocking and thus we yield to the server thread. + let mut stream = TcpStream::connect(address).unwrap(); + // Make client non-blocking. + stream.set_nonblocking(true).unwrap(); + let mut buffer = [0; TEST_BYTES.len()]; + // Reading when no data was written should return WouldBlock. + let err = stream.read_exact(&mut buffer).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::WouldBlock); + + // Try to read bytes from the peer socket without blocking. + // Since the peer socket might do partial writes, we might need to + // sleep multiple times until we read everything. + + let mut bytes_read = 0; + while bytes_read != TEST_BYTES.len() { + match stream.read(&mut buffer[bytes_read..]) { + Ok(read) => bytes_read += read, + Err(err) if err.kind() == ErrorKind::WouldBlock => { + // Not all data is written into the stream, yield to the server thread + // to write more into the stream. 
+ thread::yield_now(); + } + Err(err) => panic!("unexpected error whilst reading: {err}"), + } + } + + assert_eq!(&buffer, TEST_BYTES); + + handle.join().unwrap(); +} diff --git a/src/tools/miri/tests/pass/tree_borrows/vec_unique.default.stderr b/src/tools/miri/tests/pass/tree_borrows/vec_unique.default.stderr deleted file mode 100644 index 254eba061f45d..0000000000000 --- a/src/tools/miri/tests/pass/tree_borrows/vec_unique.default.stderr +++ /dev/null @@ -1,6 +0,0 @@ -────────────────────────────────────────────────── -Warning: this tree is indicative only. Some tags may have been hidden. -0.. 2 -| Act | └─┬── -| Res | └──── -────────────────────────────────────────────────── diff --git a/src/tools/miri/tests/ui.rs b/src/tools/miri/tests/ui.rs index 2a6151737d6c0..10064bc2bbb3f 100644 --- a/src/tools/miri/tests/ui.rs +++ b/src/tools/miri/tests/ui.rs @@ -212,6 +212,12 @@ fn run_tests( flag.push(native_lib.into_os_string()); config.program.args.push(flag); } + // For GenMC tests, add the relevant flags. + if path.starts_with("tests/genmc/") { + config.program.args.push("-Zmiri-genmc".into()); + // FIXME(genmc): remove this when GenMC and SB can be used together. + config.program.args.push("-Zmiri-disable-stacked-borrows".into()); + } eprintln!(" Compiler: {}", config.program.display()); ui_test::run_tests_generic( diff --git a/src/tools/miri/tests/utils/libc.rs b/src/tools/miri/tests/utils/libc.rs index 866b576edffd3..26797ee4c3cba 100644 --- a/src/tools/miri/tests/utils/libc.rs +++ b/src/tools/miri/tests/utils/libc.rs @@ -1,7 +1,19 @@ //! Utils that need libc. #![allow(dead_code)] -use std::{fmt, io}; +use std::{fmt, io, time}; + +pub enum Retry { + NoRetry, + RetryAfter(time::Duration), +} +pub use Retry::*; + +/// Return the last OS error. +pub fn errno() -> i32 { + // libc has no portable way to do this so we use std. + io::Error::last_os_error().raw_os_error().unwrap() +} /// Handles the usual libc function that returns `-1` to indicate an error. 
#[track_caller] @@ -19,16 +31,25 @@ pub fn errno_check + Ord + fmt::Debug>(ret: T) { assert_eq!(errno_result(ret).unwrap(), 0i8.into(), "wrong successful result"); } -pub unsafe fn read_all( - fd: libc::c_int, +/// Invoke the `read` function until `buf` is full. `retry` contols the behavior on EAGAIN. +pub unsafe fn read_all_generic( buf: *mut libc::c_void, count: libc::size_t, + retry: Retry, + read: impl Fn(*mut libc::c_void, libc::size_t) -> libc::ssize_t, ) -> libc::ssize_t { assert!(count > 0); let mut read_so_far = 0; while read_so_far < count { - let res = libc::read(fd, buf.add(read_so_far), count - read_so_far); + let res = read(buf.add(read_so_far), count - read_so_far); if res < 0 { + if let RetryAfter(duration) = retry { + if errno() == libc::EAGAIN { + // Emulate blocking behavior by sleeping a bit and then trying again. + std::thread::sleep(duration); + continue; + } + } return res; } if res == 0 { @@ -40,6 +61,15 @@ pub unsafe fn read_all( return read_so_far as libc::ssize_t; } +/// Read from `fd` until `buf` is full. Abort on first error. +pub unsafe fn read_all( + fd: libc::c_int, + buf: *mut libc::c_void, + count: libc::size_t, +) -> libc::ssize_t { + read_all_generic(buf, count, NoRetry, |buf, count| libc::read(fd, buf, count)) +} + /// Try to fill the given slice by reading from `fd`. Panic if that many bytes could not be read. #[track_caller] pub fn read_all_into_slice(fd: libc::c_int, buf: &mut [u8]) -> io::Result<()> { @@ -75,16 +105,25 @@ pub fn read_until_eof_into_slice( Ok(buf.split_at_mut(res as usize)) } -pub unsafe fn write_all( - fd: libc::c_int, +/// Invoke the `write` function until `buf` is full. `retry` controls the behavior on EAGAIN. 
+pub unsafe fn write_all_generic( buf: *const libc::c_void, count: libc::size_t, + retry: Retry, + write: impl Fn(*const libc::c_void, libc::size_t) -> libc::ssize_t, ) -> libc::ssize_t { assert!(count > 0); let mut written_so_far = 0; while written_so_far < count { - let res = libc::write(fd, buf.add(written_so_far), count - written_so_far); + let res = write(buf.add(written_so_far), count - written_so_far); if res < 0 { + if let RetryAfter(duration) = retry { + if errno() == libc::EAGAIN { + // Emulate blocking behavior by sleeping a bit and then trying again. + std::thread::sleep(duration); + continue; + } + } return res; } // Apparently a return value of 0 is just a short write, nothing special (unlike reads). @@ -93,6 +132,15 @@ pub unsafe fn write_all( return written_so_far as libc::ssize_t; } +/// Write to `fd` until `buf` is fully written. Abort on first error. +pub unsafe fn write_all( + fd: libc::c_int, + buf: *const libc::c_void, + count: libc::size_t, +) -> libc::ssize_t { + write_all_generic(buf, count, NoRetry, |buf, count| libc::write(fd, buf, count)) +} + /// Write the entire `buf` to `fd`. Panic if not all bytes could be written. #[track_caller] pub fn write_all_from_slice(fd: libc::c_int, buf: &[u8]) -> io::Result<()> { @@ -159,9 +207,7 @@ pub mod epoll { } pub mod net { - use std::io; - - use super::{errno_check, errno_result}; + use super::*; /// IPv4 localhost address bytes pub const IPV4_LOCALHOST: [u8; 4] = [127, 0, 0, 1]; @@ -211,11 +257,8 @@ pub mod net { /// Create an IPv4 TCP socket which listens on a random port at the localhost address. /// Returns the socket file descriptor and the actual socket address the socket is listening on. - pub fn make_listener_ipv4( - options: libc::c_int, - ) -> io::Result<(libc::c_int, libc::sockaddr_in)> { - let sockfd = - unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM | options, 0))? 
}; + pub fn make_listener_ipv4() -> io::Result<(libc::c_int, libc::sockaddr_in)> { + let sockfd = unsafe { errno_result(libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0))? }; // Turn address into socket address with a random free port. let addr = sock_addr_ipv4(IPV4_LOCALHOST, 0); unsafe { @@ -239,11 +282,8 @@ pub mod net { /// Create an IPv6 TCP socket which listens on a random port at the localhost address. /// Returns the socket file descriptor and the actual socket address the socket is listening on. - pub fn make_listener_ipv6( - options: libc::c_int, - ) -> io::Result<(libc::c_int, libc::sockaddr_in6)> { - let sockfd = - unsafe { errno_result(libc::socket(libc::AF_INET6, libc::SOCK_STREAM | options, 0))? }; + pub fn make_listener_ipv6() -> io::Result<(libc::c_int, libc::sockaddr_in6)> { + let sockfd = unsafe { errno_result(libc::socket(libc::AF_INET6, libc::SOCK_STREAM, 0))? }; // Turn address into socket address with a random free port. let addr = sock_addr_ipv6(IPV6_LOCALHOST, 0); unsafe { @@ -276,25 +316,27 @@ pub mod net { } /// Connect the socket to the specified IPv4 address. - pub fn connect_ipv4(sockfd: libc::c_int, addr: libc::sockaddr_in) { + pub fn connect_ipv4(sockfd: libc::c_int, addr: libc::sockaddr_in) -> io::Result<()> { unsafe { - errno_check(libc::connect( + errno_result(libc::connect( sockfd, (&addr as *const libc::sockaddr_in).cast(), size_of::() as libc::socklen_t, - )); + ))?; } + Ok(()) } /// Connect the socket to the specified IPv6 address. - pub fn connect_ipv6(sockfd: libc::c_int, addr: libc::sockaddr_in6) { + pub fn connect_ipv6(sockfd: libc::c_int, addr: libc::sockaddr_in6) -> io::Result<()> { unsafe { - errno_check(libc::connect( + errno_result(libc::connect( sockfd, (&addr as *const libc::sockaddr_in6).cast(), size_of::() as libc::socklen_t, - )); + ))?; } + Ok(()) } /// Set a socket option. 
It's the caller's responsibility to ensure that `T` is @@ -384,45 +426,4 @@ pub mod net { Ok((value, address)) } - - pub unsafe fn recv_all( - fd: libc::c_int, - buf: *mut libc::c_void, - count: libc::size_t, - flags: libc::c_int, - ) -> libc::ssize_t { - assert!(count > 0); - let mut read_so_far = 0; - while read_so_far < count { - let res = libc::recv(fd, buf.add(read_so_far), count - read_so_far, flags); - if res < 0 { - return res; - } - if res == 0 { - // EOF - break; - } - read_so_far += res as libc::size_t; - } - return read_so_far as libc::ssize_t; - } - - pub unsafe fn send_all( - fd: libc::c_int, - buf: *const libc::c_void, - count: libc::size_t, - flags: libc::c_int, - ) -> libc::ssize_t { - assert!(count > 0); - let mut written_so_far = 0; - while written_so_far < count { - let res = libc::send(fd, buf.add(written_so_far), count - written_so_far, flags); - if res < 0 { - return res; - } - // Apparently a return value of 0 is just a short write, nothing special (unlike reads). 
- written_so_far += res as libc::size_t; - } - return written_so_far as libc::ssize_t; - } } diff --git a/src/tools/rust-analyzer/.github/workflows/rustdoc.yaml b/src/tools/rust-analyzer/.github/workflows/rustdoc.yaml index 03fd083175017..c5588a29f6761 100644 --- a/src/tools/rust-analyzer/.github/workflows/rustdoc.yaml +++ b/src/tools/rust-analyzer/.github/workflows/rustdoc.yaml @@ -3,6 +3,8 @@ on: push: branches: - master + pull_request: + merge_group: env: CARGO_INCREMENTAL: 0 @@ -28,6 +30,7 @@ jobs: run: cargo doc --all --no-deps --document-private-items - name: Deploy Docs + if: github.event_name == 'push' && github.repository == 'rust-lang/rust-analyzer' && github.ref == 'refs/heads/master' uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 with: github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/src/tools/rust-analyzer/Cargo.lock b/src/tools/rust-analyzer/Cargo.lock index da530b3a9304d..e6575c28c1dd0 100644 --- a/src/tools/rust-analyzer/Cargo.lock +++ b/src/tools/rust-analyzer/Cargo.lock @@ -458,9 +458,9 @@ dependencies = [ [[package]] name = "derive-where" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef941ded77d15ca19b40374869ac6000af1c9f2a4c0f3d4c70926287e6364a8f" +checksum = "d08b3a0bcc0d079199cd476b2cae8435016ec11d1c0986c6901c5ac223041534" dependencies = [ "proc-macro2", "quote", @@ -786,6 +786,7 @@ dependencies = [ "hir-ty", "intern", "itertools 0.14.0", + "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "ra-ap-rustc_type_ir", "rustc-hash 2.1.1", "serde_json", @@ -1104,7 +1105,6 @@ dependencies = [ "nohash-hasher", "parser", "profile", - "query-group-macro", "rayon", "rustc-hash 2.1.1", "salsa", @@ -1975,9 +1975,9 @@ dependencies = [ [[package]] name = "protobuf" -version = "3.7.1" +version = "3.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a3a7c64d9bf75b1b8d981124c14c179074e8caa7dfe7b6a12e6222ddcd0c8f72" +checksum = "d65a1d4ddae7d8b5de68153b48f6aa3bba8cb002b243dbdbc55a5afbc98f99f4" dependencies = [ "once_cell", "protobuf-support", @@ -1986,9 +1986,9 @@ dependencies = [ [[package]] name = "protobuf-support" -version = "3.7.1" +version = "3.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b088fd20b938a875ea00843b6faf48579462630015c3788d397ad6a786663252" +checksum = "3e36c2f31e0a47f9280fb347ef5e461ffcd2c52dd520d8e216b52f93b0b0d7d6" dependencies = [ "thiserror 1.0.69", ] @@ -2048,9 +2048,9 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ra-ap-rustc_abi" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d49dbe5d570793b3c3227972a6ac85fc3e830f09b32c3cb3b68cfceebad3b0a" +checksum = "4b917ab47d7036977be4c984321af3e0de089229404d68ea9a286f50aa464697" dependencies = [ "bitflags 2.9.4", "ra-ap-rustc_hashes", @@ -2060,33 +2060,33 @@ dependencies = [ [[package]] name = "ra-ap-rustc_ast_ir" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd0956db62c264a899d15667993cbbd2e8f0b02108712217e2579c61ac30b94b" +checksum = "021d80bea67458b8c90cc25bfdca6f911ea818a41905e370c1f310cced1dd07e" [[package]] name = "ra-ap-rustc_hashes" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7df512084c24f4c96c8cc9a59cbd264301efbc8913d3759b065398024af316c9" +checksum = "8bb89395306ecfc980d252f77a4038d8b8bb578a25c856b545cbeeb3fde8358e" dependencies = [ "rustc-stable-hash", ] [[package]] name = "ra-ap-rustc_index" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bca3a49a928d38ba7927605e5909b6abe77d09ff359e4695c070c3f91d69cc8a" +checksum = 
"84219d028a1954c4340ddde11adffe93eb83e476e942718fe926f4d99637cbbe" dependencies = [ "ra-ap-rustc_index_macros", ] [[package]] name = "ra-ap-rustc_index_macros" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4463e908a62c64c2a65c1966c2f4995d0e1f8b7dfc85a8b8de2562edf3d89070" +checksum = "3908fdfa258c663d8ee407e6b4a205b0880e323b533c0df7edceafbd54a02fb6" dependencies = [ "proc-macro2", "quote", @@ -2095,20 +2095,20 @@ dependencies = [ [[package]] name = "ra-ap-rustc_lexer" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "228e01e1b237adb4bd8793487e1c37019c1e526a8f93716d99602301be267056" +checksum = "34b50f19d5856b8e2b36150e89b53a6102ab096e8044e1f55fd6fef977b10d85" dependencies = [ "memchr", + "unicode-ident", "unicode-properties", - "unicode-xid", ] [[package]] name = "ra-ap-rustc_next_trait_solver" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10d6f91143011d474bb844d268b0784c6a4c6db57743558b83f5ad34511627f1" +checksum = "76f83dcc451bcee8a99e284a583d5b3d82db5a200107a256a40ef132c4988f1b" dependencies = [ "derive-where", "ra-ap-rustc_index", @@ -2119,19 +2119,19 @@ dependencies = [ [[package]] name = "ra-ap-rustc_parse_format" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37fa8effbc436c0ddd9d7b1421aa3cccf8b94566c841c4e4aa3e09063b8f423f" +checksum = "f31236bdc6cbcae8af42d0b2db2fa8d812a8715b90a2ba5afb1132b37a4d0bbc" dependencies = [ "ra-ap-rustc_lexer", - "rustc-literal-escaper 0.0.5", + "rustc-literal-escaper 0.0.7", ] [[package]] name = "ra-ap-rustc_pattern_analysis" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "883c843fc27847ad03b8e772dd4a2d2728af4333a6d6821a22dfcfe7136dff3e" +checksum = 
"3fc4edac740e896fba4b3b4d9c423083e3eac49947732561ddfb2377e1f57829" dependencies = [ "ra-ap-rustc_index", "rustc-hash 2.1.1", @@ -2142,15 +2142,16 @@ dependencies = [ [[package]] name = "ra-ap-rustc_type_ir" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a86e33c46b2b261a173b23f207461a514812a8b2d2d7935bbc685f733eacce10" +checksum = "8efa119afc1bcadd821b27aa94332abf79c48ac0a972cb78932f63004ba4cdd9" dependencies = [ "arrayvec", "bitflags 2.9.4", "derive-where", "ena", "indexmap", + "ra-ap-rustc_abi", "ra-ap-rustc_ast_ir", "ra-ap-rustc_index", "ra-ap-rustc_type_ir_macros", @@ -2162,9 +2163,9 @@ dependencies = [ [[package]] name = "ra-ap-rustc_type_ir_macros" -version = "0.143.0" +version = "0.160.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15034c2fcaa5cf302aea6db20eda0f71fffeb0b372d6073cc50f940e974a2a47" +checksum = "e6b1dc03abfabc7179393c282892c73a3f0e4bbd5f0b6c87ff42c2b142e68f57" dependencies = [ "proc-macro2", "quote", @@ -2174,9 +2175,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.2" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166" dependencies = [ "rand_chacha", "rand_core", @@ -2385,9 +2386,9 @@ checksum = "ab03008eb631b703dd16978282ae36c73282e7922fe101a4bd072a40ecea7b8b" [[package]] name = "rustc-literal-escaper" -version = "0.0.5" +version = "0.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ee29da77c5a54f42697493cd4c9b9f31b74df666a6c04dfc4fde77abe0438b" +checksum = "8be87abb9e40db7466e0681dc8ecd9dcfd40360cb10b4c8fe24a7c4c3669b198" [[package]] name = "rustc-stable-hash" @@ -2453,9 +2454,9 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "salsa" -version = 
"0.25.2" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e2aa2fca57727371eeafc975acc8e6f4c52f8166a78035543f6ee1c74c2dcc" +checksum = "f77debccd43ba198e9cee23efd7f10330ff445e46a98a2b107fed9094a1ee676" dependencies = [ "boxcar", "crossbeam-queue", @@ -2478,15 +2479,15 @@ dependencies = [ [[package]] name = "salsa-macro-rules" -version = "0.25.2" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfc2a1e7bf06964105515451d728f2422dedc3a112383324a00b191a5c397a3" +checksum = "ea07adbf42d91cc076b7daf3b38bc8168c19eb362c665964118a89bc55ef19a5" [[package]] name = "salsa-macros" -version = "0.25.2" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d844c1aa34946da46af683b5c27ec1088a3d9d84a2b837a108223fd830220e1" +checksum = "d16d4d8b66451b9c75ddf740b7fc8399bc7b8ba33e854a5d7526d18708f67b05" dependencies = [ "proc-macro2", "quote", @@ -2505,9 +2506,9 @@ dependencies = [ [[package]] name = "scip" -version = "0.5.2" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb2b449a5e4660ce817676a0871cd1b4e2ff1023e33a1ac046670fa594b543a2" +checksum = "f4d0e81b39f590b2edbe2369760641511898ed34062aed2e18e6d05eead3d6b7" dependencies = [ "protobuf", ] @@ -2823,9 +2824,9 @@ checksum = "f18aa187839b2bdb1ad2fa35ead8c4c2976b64e4363c386d45ac0f7ee85c9233" [[package]] name = "thin-vec" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "144f754d318415ac792f9d69fc87abbbfc043ce2ef041c60f16ad828f638717d" +checksum = "259cdf8ed4e4aca6f1e9d011e10bd53f524a2d0637d7b28450f6c64ac298c4c6" [[package]] name = "thiserror" @@ -3139,21 +3140,15 @@ checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" [[package]] name = "unicode-ident" -version = "1.0.19" +version = "1.0.24" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-properties" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" - -[[package]] -name = "unicode-xid" -version = "0.2.6" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" [[package]] name = "url" diff --git a/src/tools/rust-analyzer/Cargo.toml b/src/tools/rust-analyzer/Cargo.toml index 3b3929df0dfbd..b8dedc6c50a01 100644 --- a/src/tools/rust-analyzer/Cargo.toml +++ b/src/tools/rust-analyzer/Cargo.toml @@ -86,14 +86,14 @@ vfs-notify = { path = "./crates/vfs-notify", version = "0.0.0" } vfs = { path = "./crates/vfs", version = "0.0.0" } edition = { path = "./crates/edition", version = "0.0.0" } -ra-ap-rustc_lexer = { version = "0.143", default-features = false } -ra-ap-rustc_parse_format = { version = "0.143", default-features = false } -ra-ap-rustc_index = { version = "0.143", default-features = false } -ra-ap-rustc_abi = { version = "0.143", default-features = false } -ra-ap-rustc_pattern_analysis = { version = "0.143", default-features = false } -ra-ap-rustc_ast_ir = { version = "0.143", default-features = false } -ra-ap-rustc_type_ir = { version = "0.143", default-features = false } -ra-ap-rustc_next_trait_solver = { version = "0.143", default-features = false } +ra-ap-rustc_lexer = { version = "0.160", default-features = false } +ra-ap-rustc_parse_format = { version = "0.160", default-features = false } +ra-ap-rustc_index = { version = "0.160", default-features = false } +ra-ap-rustc_abi = { version = "0.160", default-features 
= false } +ra-ap-rustc_pattern_analysis = { version = "0.160", default-features = false } +ra-ap-rustc_ast_ir = { version = "0.160", default-features = false } +ra-ap-rustc_type_ir = { version = "0.160", default-features = false } +ra-ap-rustc_next_trait_solver = { version = "0.160", default-features = false } # local crates that aren't published to crates.io. These should not have versions. @@ -127,7 +127,7 @@ object = { version = "0.36.7", default-features = false, features = [ "macho", "pe", ] } -postcard = {version = "1.1.3", features = ["alloc"]} +postcard = { version = "1.1.3", features = ["alloc"] } process-wrap = { version = "8.2.1", features = ["std"] } pulldown-cmark-to-cmark = "10.0.4" pulldown-cmark = { version = "0.9.6", default-features = false } @@ -135,13 +135,13 @@ rayon = "1.10.0" rowan = "=0.15.18" # Ideally we'd not enable the macros feature but unfortunately the `tracked` attribute does not work # on impls without it -salsa = { version = "0.25.2", default-features = false, features = [ +salsa = { version = "0.26", default-features = false, features = [ "rayon", "salsa_unstable", "macros", "inventory", ] } -salsa-macros = "0.25.2" +salsa-macros = "0.26" semver = "1.0.26" serde = { version = "1.0.219" } serde_derive = { version = "1.0.219" } @@ -170,7 +170,7 @@ tracing-subscriber = { version = "0.3.20", default-features = false, features = triomphe = { version = "0.1.14", default-features = false, features = ["std"] } url = "2.5.4" xshell = "0.2.7" -thin-vec = "0.2.14" +thin-vec = "0.2.16" petgraph = { version = "0.8.2", default-features = false } # We need to freeze the version of the crate, as the raw-api feature is considered unstable @@ -186,7 +186,10 @@ hashbrown = { version = "0.14.*", features = [ elided_lifetimes_in_paths = "warn" explicit_outlives_requirements = "warn" unsafe_op_in_unsafe_fn = "warn" -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(bootstrap)', "cfg(no_salsa_async_drops)"] } +unexpected_cfgs = { level = "warn", 
check-cfg = [ + 'cfg(bootstrap)', + "cfg(no_salsa_async_drops)", +] } unused_extern_crates = "warn" unused_lifetimes = "warn" unreachable_pub = "warn" diff --git a/src/tools/rust-analyzer/bench_data/glorious_old_parser b/src/tools/rust-analyzer/bench_data/glorious_old_parser index 8136daa8329fd..5022514924687 100644 --- a/src/tools/rust-analyzer/bench_data/glorious_old_parser +++ b/src/tools/rust-analyzer/bench_data/glorious_old_parser @@ -1,3 +1,4 @@ +//- minicore: fn use crate::ast::{AngleBracketedArgs, ParenthesizedArgs, AttrStyle, BareFnTy}; use crate::ast::{GenericBound, TraitBoundModifier}; use crate::ast::Unsafety; diff --git a/src/tools/rust-analyzer/crates/hir-def/src/attrs.rs b/src/tools/rust-analyzer/crates/hir-def/src/attrs.rs index b560d08492ff9..5cf5a9b6be847 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/attrs.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/attrs.rs @@ -159,7 +159,7 @@ fn match_attr_flags(attr_flags: &mut AttrFlags, attr: ast::Meta) -> ControlFlow< None => match &*first_segment { "deprecated" => attr_flags.insert(AttrFlags::IS_DEPRECATED), "doc" => extract_doc_tt_attr(attr_flags, tt), - "repr" => attr_flags.insert(AttrFlags::HAS_REPR), + "repr" | "rustc_scalable_vector" => attr_flags.insert(AttrFlags::HAS_REPR), "target_feature" => attr_flags.insert(AttrFlags::HAS_TARGET_FEATURE), "proc_macro_derive" | "rustc_builtin_macro" => { attr_flags.insert(AttrFlags::IS_DERIVE_OR_BUILTIN_MACRO) @@ -200,7 +200,7 @@ fn match_attr_flags(attr_flags: &mut AttrFlags, attr: ast::Meta) -> ControlFlow< let segment4 = segment4.and_then(|it| it.segment()?.name_ref()); segment1.text() == "test" && segment3.is_none_or(|it| it.text() == "prelude") - && segment4.is_none_or(|it| it.text() == "core") + && segment4.is_none_or(|it| matches!(&*it.text(), "core" | "std")) }); if is_test { attr_flags.insert(AttrFlags::IS_TEST); @@ -217,6 +217,7 @@ fn match_attr_flags(attr_flags: &mut AttrFlags, attr: ast::Meta) -> ControlFlow< 
"rustc_allow_incoherent_impl" => { attr_flags.insert(AttrFlags::RUSTC_ALLOW_INCOHERENT_IMPL) } + "rustc_scalable_vector" => attr_flags.insert(AttrFlags::HAS_REPR), "fundamental" => attr_flags.insert(AttrFlags::FUNDAMENTAL), "no_std" => attr_flags.insert(AttrFlags::IS_NO_STD), "may_dangle" => attr_flags.insert(AttrFlags::MAY_DANGLE), @@ -257,6 +258,9 @@ fn match_attr_flags(attr_flags: &mut AttrFlags, attr: ast::Meta) -> ControlFlow< Some(second_segment) => match &*first_segment { "rust_analyzer" => match &*second_segment { "skip" => attr_flags.insert(AttrFlags::RUST_ANALYZER_SKIP), + "prefer_underscore_import" => { + attr_flags.insert(AttrFlags::PREFER_UNDERSCORE_IMPORT) + } _ => {} }, _ => {} @@ -329,6 +333,8 @@ bitflags::bitflags! { const MACRO_STYLE_BRACES = 1 << 46; const MACRO_STYLE_BRACKETS = 1 << 47; const MACRO_STYLE_PARENTHESES = 1 << 48; + + const PREFER_UNDERSCORE_IMPORT = 1 << 49; } } @@ -724,14 +730,40 @@ impl AttrFlags { fn repr(db: &dyn DefDatabase, owner: AdtId) -> Option { let mut result = None; collect_attrs::(db, owner.into(), |attr| { - if let ast::Meta::TokenTreeMeta(attr) = attr - && attr.path().is1("repr") + let mut current = None; + if let ast::Meta::TokenTreeMeta(attr) = &attr + && let Some(path) = attr.path() && let Some(tt) = attr.token_tree() - && let Some(repr) = parse_repr_tt(&tt) { + if path.is1("repr") + && let Some(repr) = parse_repr_tt(&tt) + { + current = Some(repr); + } else if path.is1("rustc_scalable_vector") + && let mut tt = TokenTreeChildren::new(&tt) + && let Some(NodeOrToken::Token(scalable)) = tt.next() + && let Some(scalable) = ast::IntNumber::cast(scalable) + && let Ok(scalable) = scalable.value() + && let Ok(scalable) = scalable.try_into() + { + current = Some(ReprOptions { + scalable: Some(rustc_abi::ScalableElt::ElementCount(scalable)), + ..ReprOptions::default() + }); + } + } else if let ast::Meta::PathMeta(attr) = &attr + && attr.path().is1("rustc_scalable_vector") + { + current = Some(ReprOptions { + scalable: 
Some(rustc_abi::ScalableElt::Container), + ..ReprOptions::default() + }); + } + + if let Some(current) = current { match &mut result { - Some(existing) => merge_repr(existing, repr), - None => result = Some(repr), + Some(existing) => merge_repr(existing, current), + None => result = Some(current), } } ControlFlow::Continue(()) @@ -1076,10 +1108,45 @@ impl AttrFlags { }) } } + + pub fn unstable_feature(self, db: &dyn DefDatabase, owner: AttrDefId) -> Option { + if !self.contains(AttrFlags::IS_UNSTABLE) { + return None; + } + + return unstable_feature(db, owner); + + #[salsa::tracked] + fn unstable_feature(db: &dyn DefDatabase, owner: AttrDefId) -> Option { + collect_attrs(db, owner, |attr| { + if let ast::Meta::TokenTreeMeta(attr) = attr + && let path = attr.path() + && path.is1("unstable") + && let Some(tt) = attr.token_tree() + { + let mut tt = TokenTreeChildren::new(&tt); + // Technically the `feature = "..."` always comes first, but it's not a requirement. + while let Some(token) = tt.next() { + if let NodeOrToken::Token(token) = token + && token.text() == "feature" + && let Some(NodeOrToken::Token(eq)) = tt.next() + && eq.kind() == T![=] + && let Some(NodeOrToken::Token(feature)) = tt.next() + && let Some(feature) = ast::String::cast(feature) + && let Ok(feature) = feature.value() + { + return ControlFlow::Break(Symbol::intern(&feature)); + } + } + } + ControlFlow::Continue(()) + }) + } + } } fn merge_repr(this: &mut ReprOptions, other: ReprOptions) { - let ReprOptions { int, align, pack, flags, field_shuffle_seed: _ } = this; + let ReprOptions { int, align, pack, flags, scalable, field_shuffle_seed: _ } = this; flags.insert(other.flags); *align = (*align).max(other.align); *pack = match (*pack, other.pack) { @@ -1089,6 +1156,9 @@ fn merge_repr(this: &mut ReprOptions, other: ReprOptions) { if other.int.is_some() { *int = other.int; } + if other.scalable.is_some() { + *scalable = other.scalable; + } } fn parse_repr_tt(tt: &ast::TokenTree) -> Option { diff --git 
a/src/tools/rust-analyzer/crates/hir-def/src/expr_store.rs b/src/tools/rust-analyzer/crates/hir-def/src/expr_store.rs index 62a17168b18eb..497ed7d37f417 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/expr_store.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/expr_store.rs @@ -642,9 +642,7 @@ impl ExpressionStore { self.walk_exprs_in_pat(*pat, &mut f); f(*expr); } - Expr::Block { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Async { statements, tail, .. } => { + Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => { for stmt in statements.iter() { match stmt { Statement::Let { initializer, else_branch, pat, .. } => { @@ -677,6 +675,9 @@ impl ExpressionStore { f(*expr); arms.iter().for_each(|arm| { f(arm.expr); + if let Some(guard) = arm.guard { + f(guard); + } self.walk_exprs_in_pat(arm.pat, &mut f); }); } @@ -777,9 +778,7 @@ impl ExpressionStore { Expr::Let { expr, .. } => { f(*expr); } - Expr::Block { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Async { statements, tail, .. } => { + Expr::Block { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } => { for stmt in statements.iter() { match stmt { Statement::Let { initializer, else_branch, .. } => { @@ -923,6 +922,20 @@ impl ExpressionStore { None => const { &Arena::new() }.iter(), } } + + /// The coroutine associated with a coroutine closure. + #[inline] + pub fn coroutine_for_closure(coroutine_closure: ExprId) -> ExprId { + // We keep the async closure exactly one expr before. + ExprId::from_raw(la_arena::RawIdx::from_u32(coroutine_closure.into_raw().into_u32() - 1)) + } + + /// The opposite of [`Self::coroutine_for_closure()`]. + #[inline] + pub fn closure_for_coroutine(coroutine: ExprId) -> ExprId { + // We keep the async closure exactly one expr before. 
+ ExprId::from_raw(la_arena::RawIdx::from_u32(coroutine.into_raw().into_u32() + 1)) + } } impl Index for ExpressionStore { diff --git a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/body.rs b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/body.rs index 0c8320369f66d..6be3e49a70def 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/body.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/body.rs @@ -133,7 +133,7 @@ impl Body { expr: ExprId, edition: Edition, ) -> String { - pretty::print_expr_hir(db, self, owner, expr, edition) + pretty::print_expr_hir(db, self, owner.into(), expr, edition) } pub fn pretty_print_pat( @@ -144,7 +144,7 @@ impl Body { oneline: bool, edition: Edition, ) -> String { - pretty::print_pat_hir(db, self, owner, pat, oneline, edition) + pretty::print_pat_hir(db, self, owner.into(), pat, oneline, edition) } } diff --git a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/lower.rs b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/lower.rs index 7fe91a3d02dba..04437a59ac815 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/lower.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/lower.rs @@ -46,8 +46,9 @@ use crate::{ }, hir::{ Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind, - Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability, OffsetOf, Pat, PatId, - RecordFieldPat, RecordLitField, RecordSpread, Statement, generics::GenericParams, + CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability, + OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement, + generics::GenericParams, }, item_scope::BuiltinShadowMode, item_tree::FieldsShape, @@ -944,12 +945,19 @@ impl<'db> ExprCollector<'db> { }) } - /// An `async fn` needs to capture all parameters in the generated `async` block, even if they have - /// non-captured patterns such as wildcards (to ensure consistent drop 
order). - fn lower_async_fn(&mut self, params: &mut Vec, body: ExprId) -> ExprId { + /// Lowers a desugared coroutine body after moving all of the arguments + /// into the body. This is to make sure that the future actually owns the + /// arguments that are passed to the function, and to ensure things like + /// drop order are stable. + fn lower_async_block_with_moved_arguments( + &mut self, + params: &mut [PatId], + body: ExprId, + coroutine_source: CoroutineSource, + ) -> ExprId { let mut statements = Vec::new(); for param in params { - let name = match self.store.pats[*param] { + let (name, hygiene) = match self.store.pats[*param] { Pat::Bind { id, .. } if matches!( self.store.bindings[id].mode, @@ -961,14 +969,16 @@ impl<'db> ExprCollector<'db> { } Pat::Bind { id, .. } => { // If this is a `ref` binding, we can't leave it as is but we can at least reuse the name, for better display. - self.store.bindings[id].name.clone() + (self.store.bindings[id].name.clone(), self.store.bindings[id].hygiene) } - _ => self.generate_new_name(), + _ => (self.generate_new_name(), HygieneId::ROOT), }; - let binding_id = - self.alloc_binding(name.clone(), BindingAnnotation::Mutable, HygieneId::ROOT); + let binding_id = self.alloc_binding(name.clone(), BindingAnnotation::Mutable, hygiene); let pat_id = self.alloc_pat_desugared(Pat::Bind { id: binding_id, subpat: None }); let expr = self.alloc_expr_desugared(Expr::Path(name.into())); + if !hygiene.is_root() { + self.store.ident_hygiene.insert(expr.into(), hygiene); + } statements.push(Statement::Let { pat: *param, type_ref: None, @@ -978,23 +988,54 @@ impl<'db> ExprCollector<'db> { *param = pat_id; } - self.alloc_expr_desugared(Expr::Async { - id: None, - statements: statements.into_boxed_slice(), - tail: Some(body), - }) + let async_ = self.async_block( + coroutine_source, + // The default capture mode here is by-ref. 
Later on during upvar analysis, + // we will force the captured arguments to by-move, but for async closures, + // we want to make sure that we avoid unnecessarily moving captures, or else + // all async closures would default to `FnOnce` as their calling mode. + CaptureBy::Ref, + None, + statements.into_boxed_slice(), + Some(body), + ); + // It's important that this comes last, see the lowering of async closures for why. + self.alloc_expr_desugared(async_) + } + + fn async_block( + &mut self, + source: CoroutineSource, + capture_by: CaptureBy, + id: Option, + statements: Box<[Statement]>, + tail: Option, + ) -> Expr { + let block = self.alloc_expr_desugared(Expr::Block { label: None, id, statements, tail }); + Expr::Closure { + args: Box::default(), + arg_types: Box::default(), + ret_type: None, + body: block, + closure_kind: ClosureKind::AsyncBlock { source }, + capture_by, + } } fn collect( &mut self, - params: &mut Vec, + params: &mut [PatId], expr: Option, awaitable: Awaitable, ) -> ExprId { self.awaitable_context.replace(awaitable); self.with_label_rib(RibKind::Closure, |this| { let body = this.collect_expr_opt(expr); - if awaitable == Awaitable::Yes { this.lower_async_fn(params, body) } else { body } + if awaitable == Awaitable::Yes { + this.lower_async_block_with_moved_arguments(params, body, CoroutineSource::Fn) + } else { + body + } }) } @@ -1126,7 +1167,7 @@ impl<'db> ExprCollector<'db> { self.desugar_try_block(e, result_type) } Some(ast::BlockModifier::Unsafe(_)) => { - self.collect_block_(e, |id, statements, tail| Expr::Unsafe { + self.collect_block_(e, |_, id, statements, tail| Expr::Unsafe { id, statements, tail, @@ -1136,7 +1177,7 @@ impl<'db> ExprCollector<'db> { let label_hygiene = self.hygiene_id_for(label.syntax().text_range()); let label_id = self.collect_label(label); self.with_labeled_rib(label_id, label_hygiene, |this| { - this.collect_block_(e, |id, statements, tail| Expr::Block { + this.collect_block_(e, |_, id, statements, tail| 
Expr::Block { id, statements, tail, @@ -1145,12 +1186,18 @@ impl<'db> ExprCollector<'db> { }) } Some(ast::BlockModifier::Async(_)) => { + let capture_by = + if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref }; self.with_label_rib(RibKind::Closure, |this| { this.with_awaitable_block(Awaitable::Yes, |this| { - this.collect_block_(e, |id, statements, tail| Expr::Async { - id, - statements, - tail, + this.collect_block_(e, |this, id, statements, tail| { + this.async_block( + CoroutineSource::Block, + capture_by, + id, + statements, + tail, + ) }) }) }) @@ -1378,9 +1425,11 @@ impl<'db> ExprCollector<'db> { } } ast::Expr::ClosureExpr(e) => self.with_label_rib(RibKind::Closure, |this| { - this.with_binding_owner(|this| { + this.with_binding_owner_and_return(|this| { let mut args = Vec::new(); let mut arg_types = Vec::new(); + // For coroutine closures, the body, aka. the coroutine is the bindings owner, and not the closure. + let mut body_is_bindings_owner = false; if let Some(pl) = e.param_list() { let num_params = pl.params().count(); args.reserve_exact(num_params); @@ -1406,7 +1455,7 @@ impl<'db> ExprCollector<'db> { } else { Awaitable::No("non-async closure") }; - let body = this + let mut body = this .with_awaitable_block(awaitable, |this| this.collect_expr_opt(e.body())); let closure_kind = if this.is_lowering_coroutine { @@ -1417,7 +1466,16 @@ impl<'db> ExprCollector<'db> { }; ClosureKind::Coroutine(movability) } else if e.async_token().is_some() { - ClosureKind::Async + // It's important that this expr is allocated immediately before the closure. + // We rely on it for `coroutine_for_closure()`. 
+ body = this.lower_async_block_with_moved_arguments( + &mut args, + body, + CoroutineSource::Closure, + ); + body_is_bindings_owner = true; + + ClosureKind::AsyncClosure } else { ClosureKind::Closure }; @@ -1425,7 +1483,7 @@ impl<'db> ExprCollector<'db> { if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref }; this.is_lowering_coroutine = prev_is_lowering_coroutine; this.current_try_block = prev_try_block; - this.alloc_expr( + let closure = this.alloc_expr( Expr::Closure { args: args.into(), arg_types: arg_types.into(), @@ -1435,7 +1493,9 @@ impl<'db> ExprCollector<'db> { capture_by, }, syntax_ptr, - ) + ); + + (if body_is_bindings_owner { body } else { closure }, closure) }) }), ast::Expr::BinExpr(e) => { @@ -1737,13 +1797,24 @@ impl<'db> ExprCollector<'db> { } } - fn with_binding_owner(&mut self, create_expr: impl FnOnce(&mut Self) -> ExprId) -> ExprId { + /// The callback should return two exprs: the first is the bindings owner, the second is the expr to return. + fn with_binding_owner_and_return( + &mut self, + create_expr: impl FnOnce(&mut Self) -> (ExprId, ExprId), + ) -> ExprId { let prev_unowned_bindings_len = self.unowned_bindings.len(); - let expr_id = create_expr(self); + let (bindings_owner, expr_to_return) = create_expr(self); for binding in self.unowned_bindings.drain(prev_unowned_bindings_len..) 
{ - self.store.binding_owners.insert(binding, expr_id); + self.store.binding_owners.insert(binding, bindings_owner); } - expr_id + expr_to_return + } + + fn with_binding_owner(&mut self, create_expr: impl FnOnce(&mut Self) -> ExprId) -> ExprId { + self.with_binding_owner_and_return(move |this| { + let expr = create_expr(this); + (expr, expr) + }) } /// Desugar `try { ; }` into `': { ; ::std::ops::Try::from_output() }`, @@ -1762,7 +1833,7 @@ impl<'db> ExprCollector<'db> { let ptr = AstPtr::new(&e).upcast(); let (btail, expr_id) = self.with_labeled_rib(label, HygieneId::ROOT, |this| { let mut btail = None; - let block = this.collect_block_(e, |id, statements, tail| { + let block = this.collect_block_(e, |_, id, statements, tail| { btail = tail; Expr::Block { id, statements, tail, label: Some(label) } }); @@ -2220,7 +2291,7 @@ impl<'db> ExprCollector<'db> { } fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId { - self.collect_block_(block, |id, statements, tail| Expr::Block { + self.collect_block_(block, |_, id, statements, tail| Expr::Block { id, statements, tail, @@ -2231,7 +2302,7 @@ impl<'db> ExprCollector<'db> { fn collect_block_( &mut self, block: ast::BlockExpr, - mk_block: impl FnOnce(Option, Box<[Statement]>, Option) -> Expr, + mk_block: impl FnOnce(&mut Self, Option, Box<[Statement]>, Option) -> Expr, ) -> ExprId { let block_id = self.expander.ast_id_map().ast_id_for_block(&block).map(|file_local_id| { let ast_id = self.expander.in_file(file_local_id); @@ -2266,8 +2337,8 @@ impl<'db> ExprCollector<'db> { }); let syntax_node_ptr = AstPtr::new(&block.into()); - let expr_id = self - .alloc_expr(mk_block(block_id, statements.into_boxed_slice(), tail), syntax_node_ptr); + let expr = mk_block(self, block_id, statements.into_boxed_slice(), tail); + let expr_id = self.alloc_expr(expr, syntax_node_ptr); self.def_map = prev_def_map; self.module = prev_local_module; @@ -2693,17 +2764,17 @@ impl<'db> ExprCollector<'db> { for (rib_idx, rib) in 
self.label_ribs.iter().enumerate().rev() { match &rib.kind { - RibKind::Normal(label_name, id, label_hygiene) => { - if *label_name == name && *label_hygiene == hygiene_id { - return if self.is_label_valid_from_rib(rib_idx) { - Ok(Some(*id)) - } else { - Err(ExpressionStoreDiagnostics::UnreachableLabel { - name, - node: self.expander.in_file(AstPtr::new(&lifetime)), - }) - }; - } + RibKind::Normal(label_name, id, label_hygiene) + if *label_name == name && *label_hygiene == hygiene_id => + { + return if self.is_label_valid_from_rib(rib_idx) { + Ok(Some(*id)) + } else { + Err(ExpressionStoreDiagnostics::UnreachableLabel { + name, + node: self.expander.in_file(AstPtr::new(&lifetime)), + }) + }; } RibKind::MacroDef(macro_id) => { if let Some((parent_ctx, label_macro_id)) = hygiene_info diff --git a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/pretty.rs b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/pretty.rs index 9c9c4db3b208c..70ea54c734cad 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/pretty.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/pretty.rs @@ -9,6 +9,7 @@ use std::{ use hir_expand::{Lookup, mod_path::PathKind}; use itertools::Itertools; use span::Edition; +use stdx::never; use syntax::ast::{HasName, RangeOp}; use crate::{ @@ -400,7 +401,7 @@ fn print_generic_params(db: &dyn DefDatabase, generic_params: &GenericParams, p: pub fn print_expr_hir( db: &dyn DefDatabase, store: &ExpressionStore, - _owner: DefWithBodyId, + _owner: ExpressionStoreOwnerId, expr: ExprId, edition: Edition, ) -> String { @@ -419,7 +420,7 @@ pub fn print_expr_hir( pub fn print_pat_hir( db: &dyn DefDatabase, store: &ExpressionStore, - _owner: DefWithBodyId, + _owner: ExpressionStoreOwnerId, pat: PatId, oneline: bool, edition: Edition, @@ -760,14 +761,31 @@ impl Printer<'_> { w!(self, "]"); } Expr::Closure { args, arg_types, ret_type, body, closure_kind, capture_by } => { + let mut body = *body; + let mut print_pipes = true; match 
closure_kind { ClosureKind::Coroutine(Movability::Static) => { w!(self, "static "); } - ClosureKind::Async => { + ClosureKind::AsyncClosure => { + if let Expr::Closure { + body: inner_body, + closure_kind: ClosureKind::AsyncBlock { .. }, + .. + } = self.store[body] + { + body = inner_body; + } else { + never!("async closure should always have an async block body"); + } + w!(self, "async "); } - _ => (), + ClosureKind::AsyncBlock { .. } => { + w!(self, "async "); + print_pipes = false; + } + ClosureKind::Closure | ClosureKind::Coroutine(Movability::Movable) => (), } match capture_by { CaptureBy::Value => { @@ -775,24 +793,26 @@ impl Printer<'_> { } CaptureBy::Ref => (), } - w!(self, "|"); - for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() { - if i != 0 { - w!(self, ", "); + if print_pipes { + w!(self, "|"); + for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() { + if i != 0 { + w!(self, ", "); + } + self.print_pat(*pat); + if let Some(ty) = ty { + w!(self, ": "); + self.print_type_ref(*ty); + } } - self.print_pat(*pat); - if let Some(ty) = ty { - w!(self, ": "); - self.print_type_ref(*ty); + w!(self, "|"); + if let Some(ret_ty) = ret_type { + w!(self, " -> "); + self.print_type_ref(*ret_ty); } + self.whitespace(); } - w!(self, "|"); - if let Some(ret_ty) = ret_type { - w!(self, " -> "); - self.print_type_ref(*ret_ty); - } - self.whitespace(); - self.print_expr(*body); + self.print_expr(body); } Expr::Tuple { exprs } => { w!(self, "("); @@ -832,9 +852,6 @@ impl Printer<'_> { Expr::Unsafe { id: _, statements, tail } => { self.print_block(Some("unsafe "), statements, tail); } - Expr::Async { id: _, statements, tail } => { - self.print_block(Some("async "), statements, tail); - } Expr::Const(id) => { w!(self, "const {{ /* {id:?} */ }}"); } diff --git a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/scope.rs b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/scope.rs index 9738ac5c44c99..c6ba0241b7140 100644 --- 
a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/scope.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/scope.rs @@ -324,7 +324,7 @@ fn compute_expr_scopes( let mut scope = scopes.root_scope(); compute_expr_scopes(scopes, *id, &mut scope); } - Expr::Unsafe { id, statements, tail } | Expr::Async { id, statements, tail } => { + Expr::Unsafe { id, statements, tail } => { let mut scope = scopes.new_block_scope(*scope, *id, None); // Overwrite the old scope for the block expr, so that every block scope can be found // via the block itself (important for blocks that only contain items, no expressions). diff --git a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/tests/body.rs b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/tests/body.rs index 4e5f2ca89327e..db12775df95f6 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/expr_store/tests/body.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/expr_store/tests/body.rs @@ -652,12 +652,12 @@ fn async_fn_weird_param_patterns() { async fn main(&self, param1: i32, ref mut param2: i32, _: i32, param4 @ _: i32, 123: i32) {} "#, expect![[r#" - fn main(self, param1, mut param2, mut 0, param4 @ _, mut 1) async { - let ref mut param2 = param2; - let _ = 0; - let 123 = 1; - {} - }"#]], + fn main(self, param1, mut param2, mut 0, param4 @ _, mut 1) async { + let ref mut param2 = param2; + let _ = 0; + let 123 = 1; + {} + }"#]], ) } diff --git a/src/tools/rust-analyzer/crates/hir-def/src/hir.rs b/src/tools/rust-analyzer/crates/hir-def/src/hir.rs index 7781a8fe54ee0..9e51d0eac98a1 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/hir.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/hir.rs @@ -21,7 +21,7 @@ use std::fmt; use hir_expand::{MacroDefId, name::Name}; use intern::Symbol; use la_arena::Idx; -use rustc_apfloat::ieee::{Half as f16, Quad as f128}; +use rustc_apfloat::ieee::{Double, Half, Quad, Single}; use syntax::ast; use type_ref::TypeRefId; @@ -94,19 +94,19 @@ impl 
FloatTypeWrapper { Self(sym) } - pub fn to_f128(&self) -> f128 { + pub fn to_f128(&self) -> Quad { self.0.as_str().parse().unwrap_or_default() } - pub fn to_f64(&self) -> f64 { + pub fn to_f64(&self) -> Double { self.0.as_str().parse().unwrap_or_default() } - pub fn to_f32(&self) -> f32 { + pub fn to_f32(&self) -> Single { self.0.as_str().parse().unwrap_or_default() } - pub fn to_f16(&self) -> f16 { + pub fn to_f16(&self) -> Half { self.0.as_str().parse().unwrap_or_default() } } @@ -214,11 +214,6 @@ pub enum Expr { tail: Option, label: Option, }, - Async { - id: Option, - statements: Box<[Statement]>, - tail: Option, - }, Const(ExprId), // FIXME: Fold this into Block with an unsafe flag? Unsafe { @@ -339,7 +334,6 @@ impl Expr { | Expr::Block { .. } | Expr::Unsafe { .. } | Expr::Const(_) - | Expr::Async { .. } | Expr::If { .. } | Expr::Literal(_) | Expr::Loop { .. } @@ -534,7 +528,25 @@ pub enum InlineAsmRegOrRegClass { pub enum ClosureKind { Closure, Coroutine(Movability), - Async, + AsyncBlock { source: CoroutineSource }, + AsyncClosure, +} + +/// In the case of a coroutine created as part of an async/gen construct, +/// which kind of async/gen construct caused it to be created? +/// +/// This helps error messages but is also used to drive coercions in +/// type-checking (see #60424). +#[derive(Clone, PartialEq, Eq, Hash, Debug, Copy)] +pub enum CoroutineSource { + /// An explicit `async`/`gen` block written by the user. + Block, + + /// An explicit `async`/`gen` closure written by the user. + Closure, + + /// The `async`/`gen` block generated as the body of an async/gen function. 
+ Fn, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/src/tools/rust-analyzer/crates/hir-def/src/item_scope.rs b/src/tools/rust-analyzer/crates/hir-def/src/item_scope.rs index b11a8bcd9097d..fe7d1806857b1 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/item_scope.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/item_scope.rs @@ -680,20 +680,19 @@ impl ItemScope { changed = true; } Entry::Occupied(mut entry) - if !matches!(import, Some(ImportOrExternCrate::Glob(..))) => + if !matches!(import, Some(ImportOrExternCrate::Glob(..))) + && glob_imports.values.remove(&lookup) => { - if glob_imports.values.remove(&lookup) { - cov_mark::hit!(import_shadowed); - - let import = import.and_then(ImportOrExternCrate::import_or_glob); - let prev = std::mem::replace(&mut fld.import, import); - if let Some(import) = import { - self.use_imports_values - .insert(import, prev.map_or(ImportOrDef::Def(fld.def), Into::into)); - } - entry.insert(fld); - changed = true; + cov_mark::hit!(import_shadowed); + + let import = import.and_then(ImportOrExternCrate::import_or_glob); + let prev = std::mem::replace(&mut fld.import, import); + if let Some(import) = import { + self.use_imports_values + .insert(import, prev.map_or(ImportOrDef::Def(fld.def), Into::into)); } + entry.insert(fld); + changed = true; } _ => {} } @@ -720,20 +719,19 @@ impl ItemScope { changed = true; } Entry::Occupied(mut entry) - if !matches!(import, Some(ImportOrExternCrate::Glob(..))) => + if !matches!(import, Some(ImportOrExternCrate::Glob(..))) + && glob_imports.macros.remove(&lookup) => { - if glob_imports.macros.remove(&lookup) { - cov_mark::hit!(import_shadowed); - let prev = std::mem::replace(&mut fld.import, import); - if let Some(import) = import { - self.use_imports_macros.insert( - import, - prev.map_or_else(|| ImportOrDef::Def(fld.def.into()), Into::into), - ); - } - entry.insert(fld); - changed = true; + cov_mark::hit!(import_shadowed); + let prev = std::mem::replace(&mut fld.import, 
import); + if let Some(import) = import { + self.use_imports_macros.insert( + import, + prev.map_or_else(|| ImportOrDef::Def(fld.def.into()), Into::into), + ); } + entry.insert(fld); + changed = true; } _ => {} } diff --git a/src/tools/rust-analyzer/crates/hir-def/src/lang_item.rs b/src/tools/rust-analyzer/crates/hir-def/src/lang_item.rs index fef92c89b145a..37d70b1e33a91 100644 --- a/src/tools/rust-analyzer/crates/hir-def/src/lang_item.rs +++ b/src/tools/rust-analyzer/crates/hir-def/src/lang_item.rs @@ -306,6 +306,7 @@ language_item_table! { LangItems => /// Trait injected by `#[derive(Eq)]`, (i.e. "Total EQ"; no, I will not apologize). StructuralTeq, sym::structural_teq, TraitId; Copy, sym::copy, TraitId; + UseCloned, sym::use_cloned, TraitId; Clone, sym::clone, TraitId; TrivialClone, sym::trivial_clone, TraitId; Sync, sym::sync, TraitId; @@ -324,6 +325,7 @@ language_item_table! { LangItems => Drop, sym::drop, TraitId; Destruct, sym::destruct, TraitId; + BikeshedGuaranteedNoDrop,sym::bikeshed_guaranteed_no_drop, TraitId; CoerceUnsized, sym::coerce_unsized, TraitId; DispatchFromDyn, sym::dispatch_from_dyn, TraitId; @@ -373,6 +375,8 @@ language_item_table! { LangItems => AsyncFn, sym::async_fn, TraitId; AsyncFnMut, sym::async_fn_mut, TraitId; AsyncFnOnce, sym::async_fn_once, TraitId; + AsyncFnKindHelper, sym::async_fn_kind_helper,TraitId; + AsyncFnKindUpvars, sym::async_fn_kind_upvars,TypeAliasId; CallRefFuture, sym::call_ref_future, TypeAliasId; CallOnceFuture, sym::call_once_future, TypeAliasId; @@ -489,6 +493,8 @@ language_item_table! { LangItems => IntoIterIntoIter, sym::into_iter, FunctionId; IteratorNext, sym::next, FunctionId; Iterator, sym::iterator, TraitId; + FusedIterator, sym::fused_iterator, TraitId; + AsyncIterator, sym::async_iterator, TraitId; PinNewUnchecked, sym::new_unchecked, FunctionId; @@ -509,6 +515,10 @@ language_item_table! 
{ LangItems => CStr, sym::CStr, StructId; Ordering, sym::Ordering, EnumId; + Field, sym::field, TraitId; + FieldBase, sym::field_base, TypeAliasId; + FieldType, sym::field_type, TypeAliasId; + @non_lang_core_traits: core::default, Default; core::fmt, Debug; diff --git a/src/tools/rust-analyzer/crates/hir-expand/src/attrs.rs b/src/tools/rust-analyzer/crates/hir-expand/src/attrs.rs index 49baecb90cd50..d1f962f7ffd38 100644 --- a/src/tools/rust-analyzer/crates/hir-expand/src/attrs.rs +++ b/src/tools/rust-analyzer/crates/hir-expand/src/attrs.rs @@ -92,7 +92,7 @@ impl AstKeyValueMetaExt for ast::KeyValueMeta { } /// The callback is passed the attribute and the outermost `ast::Attr`. -/// Note that one node may map to multiple [`Meta`]s due to `cfg_attr`. +/// Note that one node may map to multiple [`ast::Meta`]s due to `cfg_attr`. /// /// `unsafe(attr)` are passed the inner attribute for now. #[inline] diff --git a/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs b/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs index f208203c931b7..8f513a2bcf666 100644 --- a/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs +++ b/src/tools/rust-analyzer/crates/hir-expand/src/builtin/derive_macro.rs @@ -22,8 +22,9 @@ use crate::{ use syntax::{ ast::{ self, AstNode, FieldList, HasAttrs, HasGenericArgs, HasGenericParams, HasModuleItem, - HasName, HasTypeBounds, edit_in_place::GenericParamsOwnerEdit, make, + HasName, HasTypeBounds, make, }, + syntax_editor::{GetOrCreateWhereClause, SyntaxEditor}, ted, }; @@ -1150,11 +1151,9 @@ fn coerce_pointee_expand( const ADDED_PARAM: &str = "__S"; - let where_clause = strukt.get_or_create_where_clause(); + let mut new_predicates: Vec = Vec::new(); { - let mut new_predicates = Vec::new(); - // # Rewrite generic parameter bounds // For each bound `U: ..` in `struct`, make a new bound with `__S` in place of `#[pointee]` // Example: @@ -1196,16 +1195,13 @@ fn coerce_pointee_expand( } else { 
make::name_ref(¶m_name.text()) }; - new_predicates.push( - make::where_pred( - Either::Right(make::ty_path(make::path_from_segments( - [make::path_segment(new_bounds_target)], - false, - ))), - new_bounds, - ) - .clone_for_update(), - ); + new_predicates.push(make::where_pred( + Either::Right(make::ty_path(make::path_from_segments( + [make::path_segment(new_bounds_target)], + false, + ))), + new_bounds, + )); } } @@ -1235,7 +1231,7 @@ fn coerce_pointee_expand( // // We should also write a few new `where` bounds from `#[pointee] T` to `__S` // as well as any bound that indirectly involves the `#[pointee] T` type. - for predicate in where_clause.predicates() { + for predicate in strukt.where_clause().into_iter().flat_map(|wc| wc.predicates()) { let predicate = predicate.clone_subtree().clone_for_update(); let Some(pred_target) = predicate.ty() else { continue }; @@ -1269,42 +1265,41 @@ fn coerce_pointee_expand( ); } } - - for new_predicate in new_predicates { - where_clause.add_predicate(new_predicate); - } } { // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location // // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it. 
- where_clause.add_predicate( - make::where_pred( - Either::Right(make::ty_path(make::path_from_segments( - [make::path_segment(make::name_ref(&pointee_param_name.text()))], - false, - ))), - [make::type_bound(make::ty_path(make::path_from_segments( - [ - make::path_segment(make::name_ref("core")), - make::path_segment(make::name_ref("marker")), - make::generic_ty_path_segment( - make::name_ref("Unsize"), - [make::type_arg(make::ty_path(make::path_from_segments( - [make::path_segment(make::name_ref(ADDED_PARAM))], - false, - ))) - .into()], - ), - ], - true, - )))], - ) - .clone_for_update(), - ); + new_predicates.push(make::where_pred( + Either::Right(make::ty_path(make::path_from_segments( + [make::path_segment(make::name_ref(&pointee_param_name.text()))], + false, + ))), + [make::type_bound(make::ty_path(make::path_from_segments( + [ + make::path_segment(make::name_ref("core")), + make::path_segment(make::name_ref("marker")), + make::generic_ty_path_segment( + make::name_ref("Unsize"), + [make::type_arg(make::ty_path(make::path_from_segments( + [make::path_segment(make::name_ref(ADDED_PARAM))], + false, + ))) + .into()], + ), + ], + true, + )))], + )); } + let (editor, strukt) = SyntaxEditor::with_ast_node(strukt); + strukt.get_or_create_where_clause(&editor, new_predicates.into_iter()); + let edit = editor.finish(); + let strukt = ast::Struct::cast(edit.new_root().clone()).unwrap(); + let adt = ast::Adt::Struct(strukt.clone()); + let self_for_traits = { // Replace the `#[pointee]` with `__S`. 
let mut type_param_idx = 0; diff --git a/src/tools/rust-analyzer/crates/hir-ty/Cargo.toml b/src/tools/rust-analyzer/crates/hir-ty/Cargo.toml index 238d1b08ae4fe..18426f3095b11 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/Cargo.toml +++ b/src/tools/rust-analyzer/crates/hir-ty/Cargo.toml @@ -55,7 +55,7 @@ hir-expand.workspace = true base-db.workspace = true syntax.workspace = true span.workspace = true -thin-vec = "0.2.14" +thin-vec = "0.2.16" [dev-dependencies] expect-test = "1.5.1" diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/builtin_derive.rs b/src/tools/rust-analyzer/crates/hir-ty/src/builtin_derive.rs index eb3922f4b6233..6a9b1671e7be1 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/builtin_derive.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/builtin_derive.rs @@ -20,8 +20,8 @@ use crate::{ GenericPredicates, db::HirDatabase, next_solver::{ - Clause, Clauses, DbInterner, EarlyBinder, GenericArgs, ParamEnv, StoredEarlyBinder, - StoredTy, TraitRef, Ty, TyKind, fold::fold_tys, generics::Generics, + AliasTy, Clause, Clauses, DbInterner, EarlyBinder, GenericArgs, ParamEnv, + StoredEarlyBinder, StoredTy, TraitRef, Ty, TyKind, fold::fold_tys, generics::Generics, }, }; @@ -342,7 +342,7 @@ fn extend_assoc_type_bounds<'db>( type Result = (); fn visit_ty(&mut self, t: Ty<'db>) -> Self::Result { - if let TyKind::Alias(AliasTyKind::Projection, _) = t.kind() { + if let TyKind::Alias(AliasTy { kind: AliasTyKind::Projection { .. }, .. 
}) = t.kind() { self.assoc_type_bounds.push( TraitRef::new_from_args( self.interner, @@ -546,49 +546,49 @@ struct WithGenerics<'a, T: Trait, const N: usize>(&'a [T; N], T::Assoc); Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Debug, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Debug, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. }): Debug, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Clone, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Clone, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. }): Clone, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Copy, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. 
}): Copy, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. }): Copy, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: PartialEq<[#1]>, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): PartialEq<[Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. })]>, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. }): PartialEq<[Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. })]>, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Eq, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Eq, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. 
}): Eq, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: PartialOrd<[#1]>, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): PartialOrd<[Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. })]>, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. }): PartialOrd<[Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. })]>, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Ord, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. }): Ord, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. }): Ord, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Trait, polarity:Positive), bound_vars: [] }) Clause(Binder { value: ConstArgHasType(#2, usize), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Sized, polarity:Positive), bound_vars: [] }) Clause(Binder { value: TraitPredicate(#1: Hash, polarity:Positive), bound_vars: [] }) - Clause(Binder { value: TraitPredicate(Alias(Projection, AliasTy { args: [#1], def_id: TypeAliasId("Assoc"), .. 
}): Hash, polarity:Positive), bound_vars: [] }) + Clause(Binder { value: TraitPredicate(Alias(AliasTy { args: [#1], kind: Projection { def_id: TypeAliasId("Assoc") }, .. }): Hash, polarity:Positive), bound_vars: [] }) "#]], ); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/consteval.rs b/src/tools/rust-analyzer/crates/hir-ty/src/consteval.rs index 928396c63aaf7..80e7e05d76fed 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/consteval.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/consteval.rs @@ -8,28 +8,30 @@ use hir_def::{ ConstId, EnumVariantId, ExpressionStoreOwnerId, GeneralConstId, GenericDefId, HasModule, StaticId, attrs::AttrFlags, - builtin_type::{BuiltinInt, BuiltinType, BuiltinUint}, expr_store::{Body, ExpressionStore}, hir::{Expr, ExprId, Literal}, }; use hir_expand::Lookup; +use rustc_abi::Size; +use rustc_apfloat::Float; use rustc_type_ir::inherent::IntoKind; +use stdx::never; use triomphe::Arc; use crate::{ - LifetimeElisionKind, MemoryMap, ParamEnvAndCrate, TyLoweringContext, + LifetimeElisionKind, ParamEnvAndCrate, TyLoweringContext, db::HirDatabase, display::DisplayTarget, infer::InferenceContext, - mir::{MirEvalError, MirLowerError}, + mir::{MirEvalError, MirLowerError, pad16}, next_solver::{ - Const, ConstBytes, ConstKind, DbInterner, ErrorGuaranteed, GenericArg, GenericArgs, - StoredConst, StoredGenericArgs, Ty, ValueConst, + Allocation, Const, ConstKind, Consts, DbInterner, ErrorGuaranteed, GenericArg, GenericArgs, + ScalarInt, StoredAllocation, StoredGenericArgs, Ty, TyKind, ValTreeKind, default_types, }, traits::StoredParamEnvAndCrate, }; -use super::mir::{interpret_mir, lower_body_to_mir, pad16}; +use super::mir::{interpret_mir, lower_body_to_mir}; pub fn unknown_const<'db>(_ty: Ty<'db>) -> Const<'db> { Const::new(DbInterner::conjure(), rustc_type_ir::ConstKind::Error(ErrorGuaranteed)) @@ -84,140 +86,87 @@ pub fn intern_const_ref<'a>( db: &'a dyn HirDatabase, value: &Literal, ty: Ty<'a>, - _krate: Crate, + krate: Crate, ) 
-> Const<'a> { let interner = DbInterner::new_no_crate(db); - let kind = match value { - &Literal::Uint(i, builtin_ty) - if builtin_ty.is_none() || ty.as_builtin() == builtin_ty.map(BuiltinType::Uint) => - { - let memory = match ty.as_builtin() { - Some(BuiltinType::Uint(builtin_uint)) => match builtin_uint { - BuiltinUint::U8 => Box::new([i as u8]) as Box<[u8]>, - BuiltinUint::U16 => Box::new((i as u16).to_le_bytes()), - BuiltinUint::U32 => Box::new((i as u32).to_le_bytes()), - BuiltinUint::U64 => Box::new((i as u64).to_le_bytes()), - BuiltinUint::U128 => Box::new((i).to_le_bytes()), - BuiltinUint::Usize => Box::new((i as usize).to_le_bytes()), - }, - _ => return Const::new(interner, rustc_type_ir::ConstKind::Error(ErrorGuaranteed)), - }; - rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { memory, memory_map: MemoryMap::default() }, - )) + let Ok(data_layout) = db.target_data_layout(krate) else { + return Const::error(interner); + }; + let valtree = match (ty.kind(), value) { + (TyKind::Uint(uint), Literal::Uint(value, _)) => { + let size = uint.bit_width().map(Size::from_bits).unwrap_or(data_layout.pointer_size()); + let scalar = ScalarInt::try_from_uint(*value, size).unwrap(); + ValTreeKind::Leaf(scalar) } - &Literal::Int(i, None) - if ty - .as_builtin() - .is_some_and(|builtin_ty| matches!(builtin_ty, BuiltinType::Uint(_))) => - { - let memory = match ty.as_builtin() { - Some(BuiltinType::Uint(builtin_uint)) => match builtin_uint { - BuiltinUint::U8 => Box::new([i as u8]) as Box<[u8]>, - BuiltinUint::U16 => Box::new((i as u16).to_le_bytes()), - BuiltinUint::U32 => Box::new((i as u32).to_le_bytes()), - BuiltinUint::U64 => Box::new((i as u64).to_le_bytes()), - BuiltinUint::U128 => Box::new((i as u128).to_le_bytes()), - BuiltinUint::Usize => Box::new((i as usize).to_le_bytes()), - }, - _ => return Const::new(interner, rustc_type_ir::ConstKind::Error(ErrorGuaranteed)), - }; - rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { 
memory, memory_map: MemoryMap::default() }, - )) + (TyKind::Uint(uint), Literal::Int(value, _)) => { + // `Literal::Int` is the default, so we also need to account for the type being uint. + let size = uint.bit_width().map(Size::from_bits).unwrap_or(data_layout.pointer_size()); + let scalar = ScalarInt::try_from_uint(*value as u128, size).unwrap(); + ValTreeKind::Leaf(scalar) } - &Literal::Int(i, builtin_ty) - if builtin_ty.is_none() || ty.as_builtin() == builtin_ty.map(BuiltinType::Int) => - { - let memory = match ty.as_builtin() { - Some(BuiltinType::Int(builtin_int)) => match builtin_int { - BuiltinInt::I8 => Box::new([i as u8]) as Box<[u8]>, - BuiltinInt::I16 => Box::new((i as i16).to_le_bytes()), - BuiltinInt::I32 => Box::new((i as i32).to_le_bytes()), - BuiltinInt::I64 => Box::new((i as i64).to_le_bytes()), - BuiltinInt::I128 => Box::new((i).to_le_bytes()), - BuiltinInt::Isize => Box::new((i as isize).to_le_bytes()), - }, - _ => return Const::new(interner, rustc_type_ir::ConstKind::Error(ErrorGuaranteed)), - }; - rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { memory, memory_map: MemoryMap::default() }, - )) + (TyKind::Int(int), Literal::Int(value, _)) => { + let size = int.bit_width().map(Size::from_bits).unwrap_or(data_layout.pointer_size()); + let scalar = ScalarInt::try_from_int(*value, size).unwrap(); + ValTreeKind::Leaf(scalar) } - Literal::Float(float_type_wrapper, builtin_float) - if builtin_float.is_none() - || ty.as_builtin() == builtin_float.map(BuiltinType::Float) => - { - let memory = match ty.as_builtin().unwrap() { - BuiltinType::Float(builtin_float) => match builtin_float { - // FIXME: - hir_def::builtin_type::BuiltinFloat::F16 => Box::new([0u8; 2]) as Box<[u8]>, - hir_def::builtin_type::BuiltinFloat::F32 => { - Box::new(float_type_wrapper.to_f32().to_le_bytes()) - } - hir_def::builtin_type::BuiltinFloat::F64 => { - Box::new(float_type_wrapper.to_f64().to_le_bytes()) - } - // FIXME: - 
hir_def::builtin_type::BuiltinFloat::F128 => Box::new([0; 16]), - }, - _ => unreachable!(), + (TyKind::Bool, Literal::Bool(value)) => ValTreeKind::Leaf(ScalarInt::from(*value)), + (TyKind::Char, Literal::Char(value)) => ValTreeKind::Leaf(ScalarInt::from(*value)), + (TyKind::Float(float), Literal::Float(value, _)) => { + let size = Size::from_bits(float.bit_width()); + let value = match float { + rustc_ast_ir::FloatTy::F16 => value.to_f16().to_bits(), + rustc_ast_ir::FloatTy::F32 => value.to_f32().to_bits(), + rustc_ast_ir::FloatTy::F64 => value.to_f64().to_bits(), + rustc_ast_ir::FloatTy::F128 => value.to_f128().to_bits(), }; - rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { memory, memory_map: MemoryMap::default() }, + let scalar = ScalarInt::try_from_uint(value, size).unwrap(); + ValTreeKind::Leaf(scalar) + } + (_, Literal::String(value)) => { + let u8_values = &interner.default_types().consts.u8_values; + ValTreeKind::Branch(Consts::new_from_iter( + interner, + value.as_str().as_bytes().iter().map(|&byte| u8_values[usize::from(byte)]), )) } - Literal::Bool(b) if ty.is_bool() => rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { memory: Box::new([*b as u8]), memory_map: MemoryMap::default() }, - )), - Literal::Char(c) if ty.is_char() => rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { - memory: (*c as u32).to_le_bytes().into(), - memory_map: MemoryMap::default(), - }, - )), - Literal::String(symbol) if ty.is_str() => rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { - memory: symbol.as_str().as_bytes().into(), - memory_map: MemoryMap::default(), - }, - )), - Literal::ByteString(items) if ty.as_slice().is_some_and(|ty| ty.is_u8()) => { - rustc_type_ir::ConstKind::Value(ValueConst::new( - ty, - ConstBytes { memory: items.clone(), memory_map: MemoryMap::default() }, + (_, Literal::ByteString(value)) => { + let u8_values = &interner.default_types().consts.u8_values; + 
ValTreeKind::Branch(Consts::new_from_iter( + interner, + value.iter().map(|&byte| u8_values[usize::from(byte)]), )) } - // FIXME - Literal::CString(_items) => rustc_type_ir::ConstKind::Error(ErrorGuaranteed), - _ => rustc_type_ir::ConstKind::Error(ErrorGuaranteed), + (_, Literal::CString(_)) => { + // FIXME: + return Const::error(interner); + } + _ => { + never!("mismatching type for literal"); + return Const::error(interner); + } }; - Const::new(interner, kind) + Const::new_valtree(interner, ty, valtree) } /// Interns a possibly-unknown target usize pub fn usize_const<'db>(db: &'db dyn HirDatabase, value: Option, krate: Crate) -> Const<'db> { - intern_const_ref( - db, - &match value { - Some(value) => Literal::Uint(value, Some(BuiltinUint::Usize)), - None => { - return Const::new( - DbInterner::new_no_crate(db), - rustc_type_ir::ConstKind::Error(ErrorGuaranteed), - ); - } - }, - Ty::new_uint(DbInterner::new_no_crate(db), rustc_type_ir::UintTy::Usize), - krate, - ) + let interner = DbInterner::new_no_crate(db); + let value = match value { + Some(value) => value, + None => { + return Const::error(interner); + } + }; + let Ok(data_layout) = db.target_data_layout(krate) else { + return Const::error(interner); + }; + let usize_ty = interner.default_types().types.usize; + let scalar = ScalarInt::try_from_uint(value, data_layout.pointer_size()).unwrap(); + Const::new_valtree(interner, usize_ty, ValTreeKind::Leaf(scalar)) +} + +pub fn allocation_as_usize(ec: Allocation<'_>) -> u128 { + u128::from_le_bytes(pad16(&ec.memory, false)) } pub fn try_const_usize<'db>(db: &'db dyn HirDatabase, c: Const<'db>) -> Option { @@ -230,20 +179,30 @@ pub fn try_const_usize<'db>(db: &'db dyn HirDatabase, c: Const<'db>) -> Option { let subst = unevaluated_const.args; let ec = db.const_eval(id, subst, None).ok()?; - try_const_usize(db, ec) + Some(allocation_as_usize(ec)) } GeneralConstId::StaticId(id) => { let ec = db.const_eval_static(id).ok()?; - try_const_usize(db, ec) + 
Some(allocation_as_usize(ec)) } GeneralConstId::AnonConstId(_) => None, }, - ConstKind::Value(val) => Some(u128::from_le_bytes(pad16(&val.value.inner().memory, false))), + ConstKind::Value(val) => { + if val.ty == default_types(db).types.usize { + Some(val.value.inner().to_leaf().to_uint_unchecked()) + } else { + None + } + } ConstKind::Error(_) => None, ConstKind::Expr(_) => None, } } +pub fn allocation_as_isize(ec: Allocation<'_>) -> i128 { + i128::from_le_bytes(pad16(&ec.memory, true)) +} + pub fn try_const_isize<'db>(db: &'db dyn HirDatabase, c: &Const<'db>) -> Option { match (*c).kind() { ConstKind::Param(_) => None, @@ -254,15 +213,21 @@ pub fn try_const_isize<'db>(db: &'db dyn HirDatabase, c: &Const<'db>) -> Option< GeneralConstId::ConstId(id) => { let subst = unevaluated_const.args; let ec = db.const_eval(id, subst, None).ok()?; - try_const_isize(db, &ec) + Some(allocation_as_isize(ec)) } GeneralConstId::StaticId(id) => { let ec = db.const_eval_static(id).ok()?; - try_const_isize(db, &ec) + Some(allocation_as_isize(ec)) } GeneralConstId::AnonConstId(_) => None, }, - ConstKind::Value(val) => Some(i128::from_le_bytes(pad16(&val.value.inner().memory, true))), + ConstKind::Value(val) => { + if val.ty == default_types(db).types.isize { + Some(val.value.inner().to_leaf().to_int_unchecked()) + } else { + None + } + } ConstKind::Error(_) => None, ConstKind::Expr(_) => None, } @@ -299,11 +264,7 @@ pub(crate) fn const_eval_discriminant_variant( .store(), )?; let c = interpret_mir(db, mir_body, false, None)?.0?; - let c = if is_signed { - try_const_isize(db, &c).unwrap() - } else { - try_const_usize(db, c).unwrap() as i128 - }; + let c = if is_signed { allocation_as_isize(c) } else { allocation_as_usize(c) as i128 }; Ok(c) } @@ -341,7 +302,11 @@ pub(crate) fn eval_to_const<'db>(expr: ExprId, ctx: &mut InferenceContext<'_, 'd lower_body_to_mir(ctx.db, body_owner, Body::of(ctx.db, body_owner), &infer, expr) && let Ok((Ok(result), _)) = interpret_mir(ctx.db, 
Arc::new(mir_body), true, None) { - return result; + return Const::new_from_allocation( + ctx.interner(), + &result, + ParamEnvAndCrate { param_env: ctx.table.param_env, krate: ctx.resolver.krate() }, + ); } Const::error(ctx.interner()) } @@ -359,7 +324,7 @@ pub(crate) fn const_eval<'db>( def: ConstId, subst: GenericArgs<'db>, trait_env: Option>, -) -> Result, ConstEvalError> { +) -> Result, ConstEvalError> { return match const_eval_query(db, def, subst.store(), trait_env.map(|env| env.store())) { Ok(konst) => Ok(konst.as_ref()), Err(err) => Err(err.clone()), @@ -371,7 +336,7 @@ pub(crate) fn const_eval<'db>( def: ConstId, subst: StoredGenericArgs, trait_env: Option, - ) -> Result { + ) -> Result { let body = db.monomorphized_mir_body( def.into(), subst, @@ -392,7 +357,7 @@ pub(crate) fn const_eval<'db>( _: ConstId, _: StoredGenericArgs, _: Option, - ) -> Result { + ) -> Result { Err(ConstEvalError::MirLowerError(MirLowerError::Loop)) } } @@ -400,7 +365,7 @@ pub(crate) fn const_eval<'db>( pub(crate) fn const_eval_static<'db>( db: &'db dyn HirDatabase, def: StaticId, -) -> Result, ConstEvalError> { +) -> Result, ConstEvalError> { return match const_eval_static_query(db, def) { Ok(konst) => Ok(konst.as_ref()), Err(err) => Err(err.clone()), @@ -410,7 +375,7 @@ pub(crate) fn const_eval_static<'db>( pub(crate) fn const_eval_static_query<'db>( db: &'db dyn HirDatabase, def: StaticId, - ) -> Result { + ) -> Result { let interner = DbInterner::new_no_crate(db); let body = db.monomorphized_mir_body( def.into(), @@ -430,7 +395,7 @@ pub(crate) fn const_eval_static<'db>( _: &dyn HirDatabase, _: salsa::Id, _: StaticId, - ) -> Result { + ) -> Result { Err(ConstEvalError::MirLowerError(MirLowerError::Loop)) } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/consteval/tests.rs b/src/tools/rust-analyzer/crates/hir-ty/src/consteval/tests.rs index aee27dcfdef9e..723fa0fc687a3 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/consteval/tests.rs +++ 
b/src/tools/rust-analyzer/crates/hir-ty/src/consteval/tests.rs @@ -5,17 +5,16 @@ use rustc_apfloat::{ Float, ieee::{Half as f16, Quad as f128}, }; -use rustc_type_ir::inherent::IntoKind; use test_fixture::WithFixture; use test_utils::skip_slow_tests; use crate::{ MemoryMap, - consteval::try_const_usize, + consteval::allocation_as_usize, db::HirDatabase, display::DisplayTarget, mir::pad16, - next_solver::{Const, ConstBytes, ConstKind, DbInterner, GenericArgs}, + next_solver::{Allocation, DbInterner, GenericArgs}, setup_tracing, test_db::TestDB, }; @@ -45,7 +44,11 @@ fn check_fail( crate::attach_db(&db, || match eval_goal(&db, file_id) { Ok(_) => panic!("Expected fail, but it succeeded"), Err(e) => { - assert!(error(simplify(e.clone())), "Actual error was: {}", pretty_print_err(e, &db)) + assert!( + error(simplify(e.clone())), + "Actual error was: {}\n{e:?}", + pretty_print_err(e.clone(), &db) + ) } }) } @@ -94,13 +97,7 @@ fn check_answer( panic!("Error in evaluating goal: {err}"); } }; - match r.kind() { - ConstKind::Value(value) => { - let ConstBytes { memory, memory_map } = value.value.inner(); - check(memory, memory_map); - } - _ => panic!("Expected number but found {r:?}"), - } + check(&r.memory, &r.memory_map); }); } @@ -121,7 +118,7 @@ fn pretty_print_err(e: ConstEvalError, db: &TestDB) -> String { err } -fn eval_goal(db: &TestDB, file_id: EditionedFileId) -> Result, ConstEvalError> { +fn eval_goal(db: &TestDB, file_id: EditionedFileId) -> Result, ConstEvalError> { let _tracing = setup_tracing(); let interner = DbInterner::new_no_crate(db); let module_id = db.module_for_file(file_id.file_id(db)); @@ -1795,14 +1792,14 @@ const GOAL: i32 = { fn closure_capture_unsized_type() { check_number( r#" - //- minicore: fn, copy, slice, index, coerce_unsized + //- minicore: fn, copy, slice, index, coerce_unsized, sized fn f(x: &::Ty) -> &::Ty { let c = || &*x; c() } trait A { - type Ty; + type Ty: ?Sized; } impl A for i32 { @@ -1813,7 +1810,7 @@ fn 
closure_capture_unsized_type() { let k: &[u8] = &[1, 2, 3]; let k = f::(k); k[0] + k[1] + k[2] - } + }; "#, 6, ); @@ -2524,7 +2521,7 @@ fn enums() { ); crate::attach_db(&db, || { let r = eval_goal(&db, file_id).unwrap(); - assert_eq!(try_const_usize(&db, r), Some(1)); + assert_eq!(allocation_as_usize(r), 1); }) } @@ -2537,7 +2534,15 @@ fn const_loop() { const F2: i32 = 2 * F1; const GOAL: i32 = F3; "#, - |e| e == ConstEvalError::MirLowerError(MirLowerError::Loop), + |e| { + if let ConstEvalError::MirEvalError(MirEvalError::ConstEvalError(_, inner)) = e + && let ConstEvalError::MirLowerError(MirLowerError::Loop) = *inner + { + true + } else { + false + } + }, ); } @@ -2940,6 +2945,14 @@ fn recursive_adt() { TAG_TREE }; "#, - |e| matches!(e, ConstEvalError::MirLowerError(MirLowerError::Loop)), + |e| { + if let ConstEvalError::MirEvalError(MirEvalError::ConstEvalError(_, inner)) = e + && let ConstEvalError::MirLowerError(MirLowerError::Loop) = *inner + { + true + } else { + false + } + }, ); } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/db.rs b/src/tools/rust-analyzer/crates/hir-ty/src/db.rs index a0fb75397a235..3bf2d9a6a60b4 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/db.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/db.rs @@ -7,7 +7,7 @@ use hir_def::{ AdtId, BuiltinDeriveImplId, CallableDefId, ConstId, ConstParamId, DefWithBodyId, EnumVariantId, ExpressionStoreOwnerId, FunctionId, GenericDefId, ImplId, LifetimeParamId, LocalFieldId, StaticId, TraitId, TypeAliasId, VariantId, builtin_derive::BuiltinDeriveImplMethod, - db::DefDatabase, hir::ExprId, layout::TargetDataLayout, + db::DefDatabase, expr_store::ExpressionStore, hir::ExprId, layout::TargetDataLayout, }; use la_arena::ArenaMap; use salsa::plumbing::AsId; @@ -21,8 +21,8 @@ use crate::{ lower::{Diagnostics, GenericDefaults}, mir::{BorrowckResult, MirBody, MirLowerError}, next_solver::{ - Const, EarlyBinder, GenericArgs, ParamEnv, PolyFnSig, StoredEarlyBinder, StoredGenericArgs, - 
StoredTy, TraitRef, Ty, VariancesOf, + Allocation, EarlyBinder, GenericArgs, ParamEnv, PolyFnSig, StoredEarlyBinder, + StoredGenericArgs, StoredTy, TraitRef, Ty, VariancesOf, }, traits::{ParamEnvAndCrate, StoredParamEnvAndCrate}, }; @@ -68,11 +68,11 @@ pub trait HirDatabase: DefDatabase + std::fmt::Debug { def: ConstId, subst: GenericArgs<'db>, trait_env: Option>, - ) -> Result, ConstEvalError>; + ) -> Result, ConstEvalError>; #[salsa::invoke(crate::consteval::const_eval_static)] #[salsa::transparent] - fn const_eval_static<'db>(&'db self, def: StaticId) -> Result, ConstEvalError>; + fn const_eval_static<'db>(&'db self, def: StaticId) -> Result, ConstEvalError>; #[salsa::invoke(crate::consteval::const_eval_discriminant_variant)] #[salsa::cycle(cycle_result = crate::consteval::const_eval_discriminant_cycle_result)] @@ -200,12 +200,6 @@ pub trait HirDatabase: DefDatabase + std::fmt::Debug { #[salsa::interned] fn intern_impl_trait_id(&self, id: ImplTraitId) -> InternedOpaqueTyId; - #[salsa::interned] - fn intern_closure(&self, id: InternedClosure) -> InternedClosureId; - - #[salsa::interned] - fn intern_coroutine(&self, id: InternedCoroutine) -> InternedCoroutineId; - #[salsa::invoke(crate::variance::variances_of)] #[salsa::transparent] fn variances_of<'db>(&'db self, def: GenericDefId) -> VariancesOf<'db>; @@ -238,17 +232,87 @@ pub struct InternedOpaqueTyId { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct InternedClosure(pub ExpressionStoreOwnerId, pub ExprId); -#[salsa_macros::interned(no_lifetime, debug, revisions = usize::MAX)] +#[salsa_macros::interned(constructor = new_impl, no_lifetime, debug, revisions = usize::MAX)] #[derive(PartialOrd, Ord)] pub struct InternedClosureId { pub loc: InternedClosure, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct InternedCoroutine(pub ExpressionStoreOwnerId, pub ExprId); +impl InternedClosureId { + #[inline] + pub fn new(db: &dyn HirDatabase, loc: 
InternedClosure) -> Self { + if cfg!(debug_assertions) { + let store = ExpressionStore::of(db, loc.0); + let expr = &store[loc.1]; + assert!( + matches!( + expr, + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::Closure, + .. + } + ), + "expected a closure, found {expr:?}" + ); + } + + Self::new_impl(db, loc) + } +} -#[salsa_macros::interned(no_lifetime, debug, revisions = usize::MAX)] +#[salsa_macros::interned(constructor = new_impl, no_lifetime, debug, revisions = usize::MAX)] #[derive(PartialOrd, Ord)] pub struct InternedCoroutineId { - pub loc: InternedCoroutine, + pub loc: InternedClosure, +} + +impl InternedCoroutineId { + #[inline] + pub fn new(db: &dyn HirDatabase, loc: InternedClosure) -> Self { + if cfg!(debug_assertions) { + let store = ExpressionStore::of(db, loc.0); + let expr = &store[loc.1]; + assert!( + matches!( + expr, + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::Coroutine(_) + | hir_def::hir::ClosureKind::AsyncBlock { .. }, + .. + } + ), + "expected a coroutine, found {expr:?}" + ); + } + + Self::new_impl(db, loc) + } +} + +#[salsa_macros::interned(constructor = new_impl, no_lifetime, debug, revisions = usize::MAX)] +#[derive(PartialOrd, Ord)] +pub struct InternedCoroutineClosureId { + pub loc: InternedClosure, +} + +impl InternedCoroutineClosureId { + #[inline] + pub fn new(db: &dyn HirDatabase, loc: InternedClosure) -> Self { + if cfg!(debug_assertions) { + let store = ExpressionStore::of(db, loc.0); + let expr = &store[loc.1]; + assert!( + matches!( + expr, + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::AsyncClosure, + .. 
+ } + ), + "expected a coroutine closure, found {expr:?}" + ); + } + + Self::new_impl(db, loc) + } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/expr.rs b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/expr.rs index 33d9dd538dd38..068118c7053d8 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/expr.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/expr.rs @@ -146,7 +146,7 @@ impl<'db> ExprValidator<'db> { Expr::If { .. } => { self.check_for_unnecessary_else(id, expr); } - Expr::Block { .. } | Expr::Async { .. } | Expr::Unsafe { .. } => { + Expr::Block { .. } | Expr::Unsafe { .. } => { self.validate_block(expr); } _ => {} @@ -238,8 +238,7 @@ impl<'db> ExprValidator<'db> { if (pat_ty == scrut_ty || scrut_ty .as_reference() - .map(|(match_expr_ty, ..)| match_expr_ty == pat_ty) - .unwrap_or(false)) + .is_none_or(|(match_expr_ty, ..)| match_expr_ty == pat_ty)) && types_of_subpatterns_do_match(arm.pat, self.body, self.infer) { // If we had a NotUsefulMatchArm diagnostic, we could @@ -325,10 +324,7 @@ impl<'db> ExprValidator<'db> { } fn validate_block(&mut self, expr: &Expr) { - let (Expr::Block { statements, .. } - | Expr::Async { statements, .. } - | Expr::Unsafe { statements, .. }) = expr - else { + let (Expr::Block { statements, .. } | Expr::Unsafe { statements, .. 
}) = expr else { return; }; let pattern_arena = Arena::new(); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/match_check/pat_util.rs b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/match_check/pat_util.rs index c6a26cdd1d0f8..0b39692e46a09 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/match_check/pat_util.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/match_check/pat_util.rs @@ -38,6 +38,12 @@ pub(crate) trait EnumerateAndAdjustIterator { } impl EnumerateAndAdjustIterator for T { + /// When there is a list of items with a gap of an unknown length inside, and another list + /// of item it should be zipped against, this operates on the list with the gap and returns, + /// for each item, the index it should match in the other list. + /// + /// When compiling Rust, such situation often occurs for tuple structs/tuples with a rest pattern + /// that should be matched against the fields. fn enumerate_and_adjust( self, expected_len: usize, diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs index 09c648139c458..ee33f7d1585e3 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/diagnostics/unsafe_check.rs @@ -406,7 +406,7 @@ impl<'db> UnsafeVisitor<'db> { }); return; } - Expr::Block { statements, .. } | Expr::Async { statements, .. } => { + Expr::Block { statements, .. } => { self.walk_pats_top( statements.iter().filter_map(|statement| match statement { &Statement::Let { pat, .. 
} => Some(pat), diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/display.rs b/src/tools/rust-analyzer/crates/hir-ty/src/display.rs index 0c4e34db7db0e..e4a8def4425a5 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/display.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/display.rs @@ -39,8 +39,7 @@ use rustc_apfloat::{ use rustc_ast_ir::FloatTy; use rustc_hash::FxHashSet; use rustc_type_ir::{ - AliasTyKind, BoundVarIndexKind, CoroutineArgsParts, CoroutineClosureArgsParts, RegionKind, - Upcast, + AliasTyKind, BoundVarIndexKind, CoroutineArgsParts, RegionKind, Upcast, inherent::{AdtDef, GenericArgs as _, IntoKind, Term as _, Ty as _, Tys as _}, }; use smallvec::SmallVec; @@ -48,16 +47,17 @@ use span::Edition; use stdx::never; use crate::{ - CallableDefId, FnAbi, ImplTraitId, InferenceResult, MemoryMap, ParamEnvAndCrate, consteval, - db::{HirDatabase, InternedClosure, InternedCoroutine}, + CallableDefId, FnAbi, ImplTraitId, MemoryMap, ParamEnvAndCrate, consteval, + db::{HirDatabase, InternedClosure}, generics::generics, layout::Layout, lower::GenericPredicates, mir::pad16, next_solver::{ - AliasTy, Clause, ClauseKind, Const, ConstKind, DbInterner, ExistentialPredicate, FnSig, - GenericArg, GenericArgKind, GenericArgs, ParamEnv, PolyFnSig, Region, SolverDefId, - StoredEarlyBinder, StoredTy, Term, TermKind, TraitRef, Ty, TyKind, TypingMode, + AliasTy, Allocation, Clause, ClauseKind, Const, ConstKind, DbInterner, + ExistentialPredicate, FnSig, GenericArg, GenericArgKind, GenericArgs, ParamEnv, PolyFnSig, + Region, SolverDefId, StoredEarlyBinder, StoredTy, Term, TermKind, TraitRef, Ty, TyKind, + TypingMode, ValTree, abi::Safety, infer::{DbInternerInferExt, traits::ObligationCause}, }, @@ -647,7 +647,7 @@ fn write_projection<'db>( ClauseKind::TypeOutlives(t) => t.0, _ => return false, }; - let TyKind::Alias(AliasTyKind::Projection, a) = ty.kind() else { + let TyKind::Alias(a) = ty.kind() else { return false; }; a == *alias @@ -658,7 +658,7 @@ fn 
write_projection<'db>( write_bounds_like_dyn_trait_with_prefix( f, "impl", - Either::Left(Ty::new_alias(f.interner, AliasTyKind::Projection, *alias)), + Either::Left(Ty::new_alias(f.interner, *alias)), &bounds, SizedByDefault::NotSized, needs_parens_if_multi, @@ -674,7 +674,7 @@ fn write_projection<'db>( write!( f, ">::{}", - TypeAliasSignature::of(f.db, alias.def_id.expect_type_alias()) + TypeAliasSignature::of(f.db, alias.kind.def_id().expect_type_alias()) .name .display(f.db, f.edition()) )?; @@ -692,6 +692,12 @@ impl<'db> HirDisplay<'db> for GenericArg<'db> { } } +impl<'db> HirDisplay<'db> for Allocation<'db> { + fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result { + render_const_scalar(f, &self.memory, &self.memory_map, self.ty) + } +} + impl<'db> HirDisplay<'db> for Const<'db> { fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result { match self.kind() { @@ -711,12 +717,7 @@ impl<'db> HirDisplay<'db> for Const<'db> { f.end_location_link(); Ok(()) } - ConstKind::Value(const_bytes) => render_const_scalar( - f, - &const_bytes.value.inner().memory, - &const_bytes.value.inner().memory_map, - const_bytes.ty, - ), + ConstKind::Value(value) => render_const_scalar_from_valtree(f, value.ty, value.value), ConstKind::Unevaluated(unev) => { let c = unev.def.0; write!(f, "{}", c.name(f.db))?; @@ -1006,7 +1007,7 @@ fn render_const_scalar_inner<'db>( TyKind::Pat(_, _) => f.write_str(""), TyKind::Error(..) | TyKind::Placeholder(_) - | TyKind::Alias(_, _) + | TyKind::Alias(..) 
| TyKind::Param(_) | TyKind::Bound(_, _) | TyKind::Infer(_) => f.write_str(""), @@ -1015,6 +1016,151 @@ fn render_const_scalar_inner<'db>( } } +fn render_const_scalar_from_valtree<'db>( + f: &mut HirFormatter<'_, 'db>, + ty: Ty<'db>, + valtree: ValTree<'db>, +) -> Result { + let param_env = ParamEnv::empty(); + let infcx = f.interner.infer_ctxt().build(TypingMode::PostAnalysis); + let ty = infcx.at(&ObligationCause::new(), param_env).deeply_normalize(ty).unwrap_or(ty); + render_const_scalar_from_valtree_inner(f, ty, valtree, param_env) +} + +fn render_const_scalar_from_valtree_inner<'db>( + f: &mut HirFormatter<'_, 'db>, + ty: Ty<'db>, + valtree: ValTree<'db>, + _param_env: ParamEnv<'db>, +) -> Result { + use TyKind; + match ty.kind() { + TyKind::Bool => write!(f, "{}", valtree.inner().to_leaf().try_to_bool().unwrap()), + TyKind::Char => { + let it = valtree.inner().to_leaf().to_u32(); + let Ok(c) = char::try_from(it) else { + return f.write_str(""); + }; + write!(f, "{c:?}") + } + TyKind::Int(_) => { + let it = valtree.inner().to_leaf().to_int_unchecked(); + write!(f, "{it}") + } + TyKind::Uint(_) => { + let it = valtree.inner().to_leaf().to_uint_unchecked(); + write!(f, "{it}") + } + TyKind::Float(fl) => match fl { + FloatTy::F16 => { + // FIXME(#17451): Replace with builtins once they are stabilised. + let it = f16::from_bits(valtree.inner().to_leaf().to_u16() as u128); + let s = it.to_string(); + if s.strip_prefix('-').unwrap_or(&s).chars().all(|c| c.is_ascii_digit()) { + // Match Rust debug formatting + write!(f, "{s}.0") + } else { + write!(f, "{s}") + } + } + FloatTy::F32 => { + let it = f32::from_bits(valtree.inner().to_leaf().to_u32()); + write!(f, "{it:?}") + } + FloatTy::F64 => { + let it = f64::from_bits(valtree.inner().to_leaf().to_u64()); + write!(f, "{it:?}") + } + FloatTy::F128 => { + // FIXME(#17451): Replace with builtins once they are stabilised. 
+ let it = f128::from_bits(valtree.inner().to_leaf().to_u128()); + let s = it.to_string(); + if s.strip_prefix('-').unwrap_or(&s).chars().all(|c| c.is_ascii_digit()) { + // Match Rust debug formatting + write!(f, "{s}.0") + } else { + write!(f, "{s}") + } + } + }, + TyKind::Ref(_, inner_ty, _) => { + render_const_scalar_from_valtree_inner(f, inner_ty, valtree, _param_env) + } + TyKind::Str => { + let bytes = valtree + .inner() + .to_branch() + .iter() + .map(|konst| match konst.kind() { + ConstKind::Value(value) => Some(value.value.inner().to_leaf().to_u8()), + _ => None, + }) + .collect::>>(); + let Some(bytes) = bytes else { return f.write_str("") }; + let s = std::str::from_utf8(&bytes).unwrap_or(""); + write!(f, "{s:?}") + } + TyKind::Slice(inner_ty) | TyKind::Array(inner_ty, _) => { + let mut first = true; + write!(f, "[")?; + for item in valtree.inner().to_branch() { + if !first { + write!(f, ", ")?; + } else { + first = false; + } + let ConstKind::Value(value) = item.kind() else { + return f.write_str(""); + }; + render_const_scalar_from_valtree_inner(f, inner_ty, value.value, _param_env)?; + } + write!(f, "]") + } + TyKind::Tuple(tys) => { + let mut first = true; + write!(f, "(")?; + for (inner_ty, item) in std::iter::zip(tys, valtree.inner().to_branch()) { + if !first { + write!(f, ", ")?; + } else { + first = false; + } + let ConstKind::Value(value) = item.kind() else { + return f.write_str(""); + }; + render_const_scalar_from_valtree_inner(f, inner_ty, value.value, _param_env)?; + } + write!(f, ")") + } + TyKind::Adt(..) => { + // FIXME: ADTs, requires `adt_const_params`. + f.write_str("") + } + TyKind::FnDef(..) 
=> ty.hir_fmt(f), + TyKind::FnPtr(_, _) | TyKind::RawPtr(_, _) => { + let it = valtree.inner().to_leaf().to_uint_unchecked(); + write!(f, "{it:#X} as ")?; + ty.hir_fmt(f) + } + TyKind::Never => f.write_str("!"), + TyKind::Closure(_, _) => f.write_str(""), + TyKind::Coroutine(_, _) => f.write_str(""), + TyKind::CoroutineWitness(_, _) => f.write_str(""), + TyKind::CoroutineClosure(_, _) => f.write_str(""), + TyKind::UnsafeBinder(_) => f.write_str(""), + // The below arms are unreachable, since const eval will bail out before here. + TyKind::Foreign(_) => f.write_str(""), + TyKind::Pat(_, _) => f.write_str(""), + TyKind::Error(..) + | TyKind::Placeholder(_) + | TyKind::Alias(..) + | TyKind::Param(_) + | TyKind::Bound(_, _) + | TyKind::Infer(_) => f.write_str(""), + TyKind::Dynamic(_, _) => f.write_str(""), + } +} + fn render_variant_after_name<'db>( data: &VariantFields, f: &mut HirFormatter<'_, 'db>, @@ -1281,7 +1427,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { hir_fmt_generics(f, parameters.as_slice(), Some(def.def_id().0.into()), None)?; } - TyKind::Alias(AliasTyKind::Projection, alias_ty) => { + TyKind::Alias(alias_ty @ AliasTy { kind: AliasTyKind::Projection { .. }, .. }) => { write_projection(f, &alias_ty, trait_bounds_need_parens)? } TyKind::Foreign(alias) => { @@ -1290,8 +1436,8 @@ impl<'db> HirDisplay<'db> for Ty<'db> { write!(f, "{}", type_alias.name.display(f.db, f.edition()))?; f.end_location_link(); } - TyKind::Alias(AliasTyKind::Opaque, alias_ty) => { - let opaque_ty_id = match alias_ty.def_id { + TyKind::Alias(alias_ty @ AliasTy { kind: AliasTyKind::Opaque { def_id }, .. 
}) => { + let opaque_ty_id = match def_id { SolverDefId::InternedOpaqueTyId(id) => id, _ => unreachable!(), }; @@ -1349,9 +1495,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { } let sig = interner.signature_unclosure(substs.as_closure().sig(), Safety::Safe); let sig = sig.skip_binder(); - let InternedClosure(owner, _) = db.lookup_intern_closure(id); - let infer = InferenceResult::of(db, owner); - let (_, kind) = infer.closure_info(id); + let kind = substs.as_closure().kind(); match f.closure_style { ClosureStyle::ImplFn => write!(f, "impl {kind:?}(")?, ClosureStyle::RANotation => write!(f, "|")?, @@ -1403,26 +1547,16 @@ impl<'db> HirDisplay<'db> for Ty<'db> { } _ => (), } - let CoroutineClosureArgsParts { closure_kind_ty, signature_parts_ty, .. } = - args.split_coroutine_closure_args(); - let kind = closure_kind_ty.to_opt_closure_kind().unwrap(); + let kind = args.as_coroutine_closure().kind(); let kind = match kind { rustc_type_ir::ClosureKind::Fn => "AsyncFn", rustc_type_ir::ClosureKind::FnMut => "AsyncFnMut", rustc_type_ir::ClosureKind::FnOnce => "AsyncFnOnce", }; - let TyKind::FnPtr(coroutine_sig, _) = signature_parts_ty.kind() else { - unreachable!("invalid coroutine closure signature"); - }; + let coroutine_sig = args.as_coroutine_closure().coroutine_closure_sig(); let coroutine_sig = coroutine_sig.skip_binder(); - let coroutine_inputs = coroutine_sig.inputs(); - let TyKind::Tuple(coroutine_inputs) = coroutine_inputs[1].kind() else { - unreachable!("invalid coroutine closure signature"); - }; - let TyKind::Tuple(coroutine_output) = coroutine_sig.output().kind() else { - unreachable!("invalid coroutine closure signature"); - }; - let coroutine_output = coroutine_output.as_slice()[1]; + let coroutine_inputs = coroutine_sig.tupled_inputs_ty.tuple_fields(); + let coroutine_output = coroutine_sig.return_ty; match f.closure_style { ClosureStyle::ImplFn => write!(f, "impl {kind}(")?, ClosureStyle::RANotation => write!(f, "async |")?, @@ -1536,17 +1670,16 @@ impl<'db> 
HirDisplay<'db> for Ty<'db> { } TyKind::Infer(..) => write!(f, "_")?, TyKind::Coroutine(coroutine_id, subst) => { - let InternedCoroutine(owner, expr_id) = coroutine_id.0.loc(db); + let InternedClosure(owner, expr_id) = coroutine_id.0.loc(db); let CoroutineArgsParts { resume_ty, yield_ty, return_ty, .. } = subst.split_coroutine_args(); let body = ExpressionStore::of(db, owner); let expr = &body[expr_id]; match expr { hir_def::hir::Expr::Closure { - closure_kind: hir_def::hir::ClosureKind::Async, + closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. }, .. - } - | hir_def::hir::Expr::Async { .. } => { + } => { let future_trait = f.lang_items().Future; let output = future_trait.and_then(|t| { t.trait_items(db) @@ -1597,7 +1730,7 @@ impl<'db> HirDisplay<'db> for Ty<'db> { TyKind::CoroutineWitness(..) => write!(f, "{{coroutine witness}}")?, TyKind::Pat(_, _) => write!(f, "{{pat}}")?, TyKind::UnsafeBinder(_) => write!(f, "{{unsafe binder}}")?, - TyKind::Alias(_, _) => write!(f, "{{alias}}")?, + TyKind::Alias(..) => write!(f, "{{alias}}")?, } Ok(()) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/drop.rs b/src/tools/rust-analyzer/crates/hir-ty/src/drop.rs index ddc4e4ce85ef7..0d25d7dbd1d13 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/drop.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/drop.rs @@ -1,15 +1,16 @@ //! Utilities for computing drop info about types. 
use hir_def::{ - AdtId, + AdtId, ImplId, signatures::{StructFlags, StructSignature}, }; use rustc_hash::FxHashSet; -use rustc_type_ir::inherent::{AdtDef, IntoKind}; +use rustc_type_ir::inherent::{AdtDef, GenericArgs as _, IntoKind}; use stdx::never; use crate::{ - InferenceResult, consteval, + consteval, + db::HirDatabase, method_resolution::TraitImpls, next_solver::{ DbInterner, ParamEnv, SimplifiedType, Ty, TyKind, @@ -18,24 +19,23 @@ use crate::{ }, }; -fn has_destructor(interner: DbInterner<'_>, adt: AdtId) -> bool { - let db = interner.db; +#[salsa::tracked] +pub fn destructor(db: &dyn HirDatabase, adt: AdtId) -> Option { let module = match adt { AdtId::EnumId(id) => db.lookup_intern_enum(id).container, AdtId::StructId(id) => db.lookup_intern_struct(id).container, AdtId::UnionId(id) => db.lookup_intern_union(id).container, }; - let Some(drop_trait) = interner.lang_items().Drop else { - return false; - }; + let interner = DbInterner::new_with(db, module.krate(db)); + let drop_trait = interner.lang_items().Drop?; let impls = match module.block(db) { Some(block) => match TraitImpls::for_block(db, block) { Some(it) => &**it, - None => return false, + None => return None, }, None => TraitImpls::for_crate(db, module.krate(db)), }; - !impls.for_trait_and_self_ty(drop_trait, &SimplifiedType::Adt(adt.into())).0.is_empty() + impls.for_trait_and_self_ty(drop_trait, &SimplifiedType::Adt(adt.into())).0.first().copied() } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -71,7 +71,7 @@ fn has_drop_glue_impl<'db>( match ty.kind() { TyKind::Adt(adt_def, subst) => { let adt_id = adt_def.def_id().0; - if has_destructor(infcx.interner, adt_id) { + if adt_def.destructor(infcx.interner).is_some() { return DropGlue::HasDropGlue; } match adt_id { @@ -132,21 +132,17 @@ fn has_drop_glue_impl<'db>( has_drop_glue_impl(infcx, ty, env, visited) } TyKind::Slice(ty) => has_drop_glue_impl(infcx, ty, env, visited), - TyKind::Closure(closure_id, subst) => { - let owner = 
db.lookup_intern_closure(closure_id.0).0; - let infer = InferenceResult::of(db, owner); - let (captures, _) = infer.closure_info(closure_id.0); - let env = db.trait_environment(owner); - captures - .iter() - .map(|capture| has_drop_glue_impl(infcx, capture.ty(db, subst), env, visited)) - .max() - .unwrap_or(DropGlue::None) + TyKind::Closure(_, args) => { + has_drop_glue_impl(infcx, args.as_closure().tupled_upvars_ty(), env, visited) } - // FIXME: Handle coroutines. - TyKind::Coroutine(..) | TyKind::CoroutineWitness(..) | TyKind::CoroutineClosure(..) => { - DropGlue::None + TyKind::Coroutine(_, args) => { + has_drop_glue_impl(infcx, args.as_coroutine().tupled_upvars_ty(), env, visited) + } + TyKind::CoroutineClosure(_, args) => { + has_drop_glue_impl(infcx, args.as_coroutine_closure().tupled_upvars_ty(), env, visited) } + // FIXME: Coroutine witness. + TyKind::CoroutineWitness(..) => DropGlue::None, TyKind::Ref(..) | TyKind::RawPtr(..) | TyKind::FnDef(..) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/dyn_compatibility.rs b/src/tools/rust-analyzer/crates/hir-ty/src/dyn_compatibility.rs index e70918f8e1125..ba63343d49351 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/dyn_compatibility.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/dyn_compatibility.rs @@ -21,8 +21,9 @@ use crate::{ db::{HirDatabase, InternedOpaqueTyId}, lower::{GenericPredicates, associated_ty_item_bounds}, next_solver::{ - Binder, Clause, Clauses, DbInterner, EarlyBinder, GenericArgs, Goal, ParamEnv, ParamTy, - SolverDefId, TraitPredicate, TraitRef, Ty, TypingMode, infer::DbInternerInferExt, mk_param, + AliasTy, Binder, Clause, Clauses, DbInterner, EarlyBinder, GenericArgs, Goal, ParamEnv, + ParamTy, SolverDefId, TraitPredicate, TraitRef, Ty, TypingMode, infer::DbInternerInferExt, + mk_param, }, traits::next_trait_solve_in_ctxt, }; @@ -239,30 +240,30 @@ fn contains_illegal_self_type_reference<'db, T: rustc_type_ir::TypeVisitable ControlFlow::Break(()), 
rustc_type_ir::TyKind::Param(_) => ControlFlow::Continue(()), - rustc_type_ir::TyKind::Alias(AliasTyKind::Projection, proj) => { - match self.allow_self_projection { - AllowSelfProjection::Yes => { - let trait_ = proj.trait_def_id(interner); - let trait_ = match trait_ { - SolverDefId::TraitId(id) => id, - _ => unreachable!(), - }; - if self.super_traits.is_none() { - self.super_traits = Some( - elaborate::supertrait_def_ids(interner, self.trait_.into()) - .map(|super_trait| super_trait.0) - .collect(), - ) - } - if self.super_traits.as_ref().is_some_and(|s| s.contains(&trait_)) { - ControlFlow::Continue(()) - } else { - ty.super_visit_with(self) - } + rustc_type_ir::TyKind::Alias( + proj @ AliasTy { kind: AliasTyKind::Projection { .. }, .. }, + ) => match self.allow_self_projection { + AllowSelfProjection::Yes => { + let trait_ = proj.trait_def_id(interner); + let trait_ = match trait_ { + SolverDefId::TraitId(id) => id, + _ => unreachable!(), + }; + if self.super_traits.is_none() { + self.super_traits = Some( + elaborate::supertrait_def_ids(interner, self.trait_.into()) + .map(|super_trait| super_trait.0) + .collect(), + ) + } + if self.super_traits.as_ref().is_some_and(|s| s.contains(&trait_)) { + ControlFlow::Continue(()) + } else { + ty.super_visit_with(self) } - AllowSelfProjection::No => ty.super_visit_with(self), } - } + AllowSelfProjection::No => ty.super_visit_with(self), + }, _ => ty.super_visit_with(self), } } @@ -503,8 +504,12 @@ fn contains_illegal_impl_trait_in_trait<'db>( &mut self, ty: as rustc_type_ir::Interner>::Ty, ) -> Self::Result { - if let rustc_type_ir::TyKind::Alias(AliasTyKind::Opaque, op) = ty.kind() { - let id = match op.def_id { + if let rustc_type_ir::TyKind::Alias(AliasTy { + kind: AliasTyKind::Opaque { def_id }, + .. 
+ }) = ty.kind() + { + let id = match def_id { SolverDefId::InternedOpaqueTyId(id) => id, _ => unreachable!(), }; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs index bd897113bf0e5..339ce7933af13 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer.rs @@ -14,6 +14,7 @@ //! the `ena` crate, which is extracted from rustc. mod autoderef; +mod callee; pub(crate) mod cast; pub(crate) mod closure; mod coerce; @@ -28,9 +29,9 @@ mod path; mod place_op; pub(crate) mod unify; -use std::{cell::OnceCell, convert::identity, iter}; +use std::{cell::OnceCell, convert::identity, fmt, iter, ops::Deref}; -use base_db::Crate; +use base_db::{Crate, FxIndexMap}; use either::Either; use hir_def::{ AdtId, AssocItemId, ConstId, ConstParamId, DefWithBodyId, ExpressionStoreOwnerId, FieldId, @@ -54,15 +55,22 @@ use rustc_type_ir::{ AliasTyKind, TypeFoldable, inherent::{AdtDef, IntoKind, Ty as _}, }; +use smallvec::SmallVec; use span::Edition; use stdx::never; use thin_vec::ThinVec; use crate::{ ImplTraitId, IncorrectGenericsLenKind, PathLoweringDiagnostic, TargetFeatures, + closure_analysis::PlaceBase, collect_type_inference_vars, - db::{HirDatabase, InternedClosureId, InternedOpaqueTyId}, + db::{HirDatabase, InternedOpaqueTyId}, infer::{ + callee::DeferredCallResolution, + closure::analysis::{ + BorrowKind, + expr_use_visitor::{FakeReadCause, Place}, + }, coerce::{CoerceMany, DynamicCoerceMany}, diagnostics::{Diagnostics, InferenceTyLoweringContext as TyLoweringContext}, expr::ExprIsRead, @@ -71,14 +79,12 @@ use crate::{ ImplTraitIdx, ImplTraitLoweringMode, LifetimeElisionKind, diagnostics::TyLoweringDiagnostic, }, method_resolution::{CandidateId, MethodResolutionUnstableFeatures}, - mir::MirSpan, next_solver::{ AliasTy, Const, DbInterner, ErrorGuaranteed, GenericArg, GenericArgs, Region, StoredGenericArgs, StoredTy, StoredTys, Ty, TyKind, Tys, abi::Safety, 
infer::{InferCtxt, ObligationInspector, traits::ObligationCause}, }, - traits::FnTrait, utils::TargetFeatureIsSafeInTarget, }; @@ -91,7 +97,6 @@ pub use coerce::could_coerce; pub use unify::{could_unify, could_unify_deeply}; use cast::{CastCheck, CastError}; -pub(crate) use closure::analysis::{CaptureKind, CapturedItem, CapturedItemWithoutTy}; /// The entry point of type inference. fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> InferenceResult { @@ -266,7 +271,10 @@ fn infer_finalize(mut ctx: InferenceContext<'_, '_>) -> InferenceResult { ctx.table.select_obligations_where_possible(); - ctx.infer_closures(); + // Closure and coroutine analysis may run after fallback + // because they don't constrain other type variables. + ctx.closure_analyze(); + assert!(ctx.deferred_call_resolutions.is_empty()); ctx.table.select_obligations_where_possible(); @@ -498,7 +506,7 @@ pub enum Adjust { /// The target type is `U` in both cases, with the region and mutability /// being those shared by both the receiver and the returned reference. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct OverloadedDeref(pub Option); +pub struct OverloadedDeref(pub Mutability); #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum AutoBorrowMutability { @@ -535,15 +543,6 @@ pub enum AutoBorrow { RawPtr(Mutability), } -impl AutoBorrow { - fn mutability(self) -> Mutability { - match self { - AutoBorrow::Ref(mutbl) => mutbl.into(), - AutoBorrow::RawPtr(mutbl) => mutbl, - } - } -} - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum PointerCast { /// Go from a fn-item type to a fn-pointer type. @@ -637,11 +636,226 @@ pub struct InferenceResult { /// the first `rest` has implicit `ref` binding mode, but the second `rest` binding mode is `move`. 
pub(crate) binding_modes: ArenaMap, - pub(crate) closure_info: FxHashMap, FnTrait)>, - // FIXME: remove this field - pub mutated_bindings_in_closure: FxHashSet, - pub(crate) coercion_casts: FxHashSet, + + pub closures_data: FxHashMap, +} + +#[derive(Clone, PartialEq, Eq, Debug, Default)] +pub struct ClosureData { + /// Tracks the minimum captures required for a closure; + /// see `MinCaptureInformationMap` for more details. + pub min_captures: RootVariableMinCaptureList, + + /// Tracks the fake reads required for a closure and the reason for the fake read. + /// When performing pattern matching for closures, there are times we don't end up + /// reading places that are mentioned in a closure (because of _ patterns). However, + /// to ensure the places are initialized, we introduce fake reads. + /// Consider these two examples: + /// ```ignore (discriminant matching with only wildcard arm) + /// let x: u8; + /// let c = || match x { _ => () }; + /// ``` + /// In this example, we don't need to actually read/borrow `x` in `c`, and so we don't + /// want to capture it. However, we do still want an error here, because `x` should have + /// to be initialized at the point where c is created. Therefore, we add a "fake read" + /// instead. + /// ```ignore (destructured assignments) + /// let c = || { + /// let (t1, t2) = t; + /// } + /// ``` + /// In the second example, we capture the disjoint fields of `t` (`t.0` & `t.1`), but + /// we never capture `t`. This becomes an issue when we build MIR as we require + /// information on `t` in order to create place `t.0` and `t.1`. We can solve this + /// issue by fake reading `t`. + pub fake_reads: Box<[(Place, FakeReadCause, SmallVec<[CaptureSourceStack; 2]>)]>, +} + +/// Part of `MinCaptureInformationMap`; Maps a root variable to the list of `CapturedPlace`. +/// Used to track the minimum set of `Place`s that need to be captured to support all +/// Places captured by the closure starting at a given root variable. 
+/// +/// This provides a convenient and quick way of checking if a variable being used within +/// a closure is a capture of a local variable. +pub(crate) type RootVariableMinCaptureList = FxIndexMap; + +/// Part of `MinCaptureInformationMap`; List of `CapturePlace`s. +pub(crate) type MinCaptureList = Vec; + +/// A composite describing a `Place` that is captured by a closure. +#[derive(Eq, PartialEq, Clone, Debug, Hash)] +pub struct CapturedPlace { + /// The `Place` that is captured. + pub place: Place, + + /// `CaptureKind` and expression(s) that resulted in such capture of `place`. + pub info: CaptureInfo, + + /// Represents if `place` can be mutated or not. + pub mutability: Mutability, +} + +impl CapturedPlace { + pub fn is_by_ref(&self) -> bool { + match self.info.capture_kind { + UpvarCapture::ByValue | UpvarCapture::ByUse => false, + UpvarCapture::ByRef(..) => true, + } + } + + pub fn captured_local(&self) -> BindingId { + match self.place.base { + PlaceBase::Upvar { var_id: local, .. } | PlaceBase::Local(local) => local, + PlaceBase::Rvalue | PlaceBase::StaticItem => { + unreachable!("only locals can be captured") + } + } + } + + /// The type of the capture stored in the closure, which is different from the type of the captured place + /// if we capture by reference. 
+ pub fn captured_ty<'db>(&self, db: &'db dyn HirDatabase) -> Ty<'db> { + let place_ty = self.place.ty(); + let make_ref = |mutbl| { + let interner = DbInterner::new_no_crate(db); + let region = Region::new_erased(interner); + Ty::new_ref(interner, region, place_ty, mutbl) + }; + match self.info.capture_kind { + UpvarCapture::ByUse | UpvarCapture::ByValue => place_ty, + UpvarCapture::ByRef(kind) => make_ref(kind.to_mutbl_lossy()), + } + } +} + +#[derive(Clone)] +pub struct CaptureSourceStack(CaptureSourceStackRepr); + +#[derive(Clone)] +enum CaptureSourceStackRepr { + One(ExprOrPatId), + Two([ExprOrPatId; 2]), + Many(ThinVec), +} + +impl PartialEq for CaptureSourceStack { + fn eq(&self, other: &Self) -> bool { + **self == **other + } +} + +impl Eq for CaptureSourceStack {} + +impl std::hash::Hash for CaptureSourceStack { + fn hash(&self, state: &mut H) { + (**self).hash(state); + } +} + +const _: () = assert!(size_of::() == 16); + +impl Deref for CaptureSourceStack { + type Target = [ExprOrPatId]; + + #[inline] + fn deref(&self) -> &Self::Target { + match &self.0 { + CaptureSourceStackRepr::One(it) => std::slice::from_ref(it), + CaptureSourceStackRepr::Two(it) => it, + CaptureSourceStackRepr::Many(it) => it, + } + } +} + +impl fmt::Debug for CaptureSourceStack { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("CaptureSourceStack").field(&&**self).finish() + } +} + +impl CaptureSourceStack { + #[inline] + pub fn len(&self) -> usize { + match &self.0 { + CaptureSourceStackRepr::One(_) => 1, + CaptureSourceStackRepr::Two(_) => 2, + CaptureSourceStackRepr::Many(it) => it.len(), + } + } + + #[inline] + pub(crate) fn from_single(id: ExprOrPatId) -> Self { + Self(CaptureSourceStackRepr::One(id)) + } + + #[inline] + pub fn final_source(&self) -> ExprOrPatId { + *self.last().expect("should always have a final source") + } + + pub fn push(&mut self, new_id: ExprOrPatId) { + match &mut self.0 { + CaptureSourceStackRepr::One(old_id) => { + self.0 = 
CaptureSourceStackRepr::Two([*old_id, new_id]) + } + CaptureSourceStackRepr::Two([old_id1, old_id2]) => { + self.0 = CaptureSourceStackRepr::Many(ThinVec::from([*old_id1, *old_id2, new_id])); + } + CaptureSourceStackRepr::Many(old_ids) => old_ids.push(new_id), + } + } + + pub fn truncate(&mut self, new_len: usize) { + debug_assert!(new_len > 0); + match &mut self.0 { + CaptureSourceStackRepr::One(_) => {} + CaptureSourceStackRepr::Two([first, _]) => { + if new_len == 1 { + self.0 = CaptureSourceStackRepr::One(*first) + } + } + CaptureSourceStackRepr::Many(ids) => ids.truncate(new_len), + } + } + + pub fn shrink_to_fit(&mut self) { + match &mut self.0 { + CaptureSourceStackRepr::One(_) | CaptureSourceStackRepr::Two(_) => {} + CaptureSourceStackRepr::Many(ids) => match **ids { + [one] => self.0 = CaptureSourceStackRepr::One(one), + [first, second] => self.0 = CaptureSourceStackRepr::Two([first, second]), + _ => ids.shrink_to_fit(), + }, + } + } +} + +/// Part of `MinCaptureInformationMap`; describes the capture kind (&, &mut, move) +/// for a particular capture as well as identifying the part of the source code +/// that triggered this capture to occur. +#[derive(Eq, PartialEq, Clone, Debug, Hash)] +pub struct CaptureInfo { + pub sources: SmallVec<[CaptureSourceStack; 2]>, + + /// Capture mode that was selected + pub capture_kind: UpvarCapture, +} + +/// Information describing the capture of an upvar. This is computed +/// during `typeck`, specifically by `regionck`. +#[derive(Eq, PartialEq, Clone, Debug, Copy, Hash)] +pub enum UpvarCapture { + /// Upvar is captured by value. This is always true when the + /// closure is labeled `move`, but can also be true in other cases + /// depending on inference. + ByValue, + + /// Upvar is captured by use. This is true when the closure is labeled `use`. + ByUse, + + /// Upvar is captured by reference. 
+ ByRef(BorrowKind), } #[salsa::tracked] @@ -699,9 +913,8 @@ impl InferenceResult { pat_adjustments: Default::default(), binding_modes: Default::default(), expr_adjustments: Default::default(), - closure_info: Default::default(), - mutated_bindings_in_closure: Default::default(), coercion_casts: Default::default(), + closures_data: Default::default(), } } @@ -771,9 +984,6 @@ impl InferenceResult { pub fn type_of_type_placeholder<'db>(&self, type_ref: TypeRefId) -> Option> { self.type_of_type_placeholder.get(&type_ref).map(|ty| ty.as_ref()) } - pub fn closure_info(&self, closure: InternedClosureId) -> &(Vec, FnTrait) { - self.closure_info.get(&closure).unwrap() - } pub fn type_of_expr_or_pat<'db>(&self, id: ExprOrPatId) -> Option> { match id { ExprOrPatId::ExprId(id) => self.type_of_expr.get(id).map(|it| it.as_ref()), @@ -870,6 +1080,26 @@ impl InferenceResult { pub fn binding_ty<'db>(&self, id: BindingId) -> Ty<'db> { self.type_of_binding.get(id).map_or(self.error_ty.as_ref(), |it| it.as_ref()) } + + /// This does not deduplicate, which means you'll get the types once per capture. + pub fn closure_captures_tys<'db>(&self, closure: ExprId) -> impl Iterator> { + self.closures_data[&closure] + .min_captures + .values() + .flat_map(|captures| captures.iter().map(|capture| capture.place.ty())) + } + + /// Like [`Self::closure_captures_tys()`], but using [`CapturedPlace::captured_ty()`]. + pub fn closure_captures_captured_tys<'db>( + &self, + db: &'db dyn HirDatabase, + closure: ExprId, + ) -> impl Iterator> { + self.closures_data[&closure] + .min_captures + .values() + .flat_map(|captures| captures.iter().map(|capture| capture.captured_ty(db))) + } } /// The inference context contains all information needed during type inference. @@ -913,19 +1143,8 @@ pub(crate) struct InferenceContext<'body, 'db> { deferred_cast_checks: Vec>, - // fields related to closure capture - current_captures: Vec, - /// A stack that has an entry for each projection in the current capture. 
- /// - /// For example, in `a.b.c`, we capture the spans of `a`, `a.b`, and `a.b.c`. - /// We do that because sometimes we truncate projections (when a closure captures - /// both `a.b` and `a.b.c`), and we want to provide accurate spans in this case. - current_capture_span_stack: Vec, - current_closure: Option, - /// Stores the list of closure ids that need to be analyzed before this closure. See the - /// comment on `InferenceContext::sort_closures` - closure_dependencies: FxHashMap>, - deferred_closures: FxHashMap, Ty<'db>, Vec>, ExprId)>>, + /// The key is an expression defining a closure or a coroutine closure. + deferred_call_resolutions: FxHashMap>>, diagnostics: Diagnostics, } @@ -1017,13 +1236,9 @@ impl<'body, 'db> InferenceContext<'body, 'db> { diverges: Diverges::Maybe, breakables: Vec::new(), deferred_cast_checks: Vec::new(), - current_captures: Vec::new(), - current_capture_span_stack: Vec::new(), - current_closure: None, - deferred_closures: FxHashMap::default(), - closure_dependencies: FxHashMap::default(), inside_assignment: false, diagnostics: Diagnostics::default(), + deferred_call_resolutions: FxHashMap::default(), } } @@ -1082,7 +1297,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> { // there is no problem in it being `pub(crate)`, remove this comment. fn resolve_all(self) -> InferenceResult { let InferenceContext { - mut table, mut result, tuple_field_accesses_rev, diagnostics, .. + mut table, + mut result, + tuple_field_accesses_rev, + diagnostics, + types, + .. 
} = self; let mut diagnostics = diagnostics.finish(); // Destructure every single field so whenever new fields are added to `InferenceResult` we @@ -1098,16 +1318,12 @@ impl<'body, 'db> InferenceContext<'body, 'db> { type_of_type_placeholder, type_of_opaque, type_mismatches, + closures_data, has_errors, error_ty: _, pat_adjustments, binding_modes: _, expr_adjustments, - // Types in `closure_info` have already been `resolve_completely()`'d during - // `InferenceContext::infer_closures()` (in `HirPlace::ty()` specifically), so no need - // to resolve them here. - closure_info: _, - mutated_bindings_in_closure: _, tuple_field_access_types: _, coercion_casts: _, diagnostics: _, @@ -1194,6 +1410,38 @@ impl<'body, 'db> InferenceContext<'body, 'db> { *has_errors = *has_errors || adjustment.as_ref().references_non_lt_error(); } pat_adjustments.shrink_to_fit(); + for closure_data in closures_data.values_mut() { + let ClosureData { min_captures, fake_reads } = closure_data; + let dummy_place = || Place { + base_ty: types.types.error.store(), + base: closure::analysis::expr_use_visitor::PlaceBase::Rvalue, + projections: Vec::new(), + }; + + for (place, _, sources) in fake_reads { + *place = table.resolve_completely(std::mem::replace(place, dummy_place())); + place.projections.shrink_to_fit(); + for source in &mut *sources { + source.shrink_to_fit(); + } + sources.shrink_to_fit(); + } + + for min_capture in min_captures.values_mut() { + for captured in &mut *min_capture { + let CapturedPlace { place, info, mutability: _ } = captured; + *place = table.resolve_completely(std::mem::replace(place, dummy_place())); + let CaptureInfo { sources, capture_kind: _ } = info; + for source in &mut *sources { + source.shrink_to_fit(); + } + sources.shrink_to_fit(); + } + min_capture.shrink_to_fit(); + } + min_captures.shrink_to_fit(); + } + closures_data.shrink_to_fit(); result.tuple_field_access_types = tuple_field_accesses_rev .into_iter() .map(|subst| 
table.resolve_completely(subst).store()) @@ -1387,6 +1635,21 @@ impl<'body, 'db> InferenceContext<'body, 'db> { self.diagnostics.push(diagnostic); } + fn record_deferred_call_resolution( + &mut self, + closure_def_id: ExprId, + r: DeferredCallResolution<'db>, + ) { + self.deferred_call_resolutions.entry(closure_def_id).or_default().push(r); + } + + fn remove_deferred_call_resolutions( + &mut self, + closure_def_id: ExprId, + ) -> Vec> { + self.deferred_call_resolutions.remove(&closure_def_id).unwrap_or_default() + } + fn with_ty_lowering( &mut self, store: &ExpressionStore, @@ -1440,9 +1703,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> { }); if placeholder_ids.len() == type_variables.len() { - for (placeholder_id, type_variable) in - placeholder_ids.into_iter().zip(type_variables.into_iter()) - { + for (placeholder_id, type_variable) in placeholder_ids.into_iter().zip(type_variables) { self.write_type_placeholder_ty(placeholder_id, type_variable); } } @@ -1648,6 +1909,23 @@ impl<'body, 'db> InferenceContext<'body, 'db> { result.unwrap_or(self.types.types.error) } + pub(crate) fn type_must_be_known_at_this_point( + &self, + _id: ExprOrPatId, + _ty: Ty<'db>, + ) -> Ty<'db> { + // FIXME: Emit an diagnostic. 
+ self.types.types.error + } + + pub(crate) fn require_type_is_sized(&mut self, ty: Ty<'db>) { + if !ty.references_non_lt_error() + && let Some(sized_trait) = self.lang_items.Sized + { + self.table.register_bound(ty, sized_trait, ObligationCause::new()); + } + } + fn expr_ty(&self, expr: ExprId) -> Ty<'db> { self.result.expr_ty(expr) } @@ -1674,10 +1952,9 @@ impl<'body, 'db> InferenceContext<'body, 'db> { Some(res_assoc_ty) => { let alias = Ty::new_alias( self.interner(), - AliasTyKind::Projection, AliasTy::new( self.interner(), - res_assoc_ty.into(), + AliasTyKind::Projection { def_id: res_assoc_ty.into() }, iter::once(inner_ty.into()).chain(params.iter().copied()), ), ); @@ -1728,8 +2005,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> { let args = self.infcx().fill_rest_fresh_args(assoc_type.into(), trait_ref.args); let alias = Ty::new_alias( self.interner(), - AliasTyKind::Projection, - AliasTy::new_from_args(self.interner(), assoc_type.into(), args), + AliasTy::new_from_args( + self.interner(), + AliasTyKind::Projection { def_id: assoc_type.into() }, + args, + ), ); ty = self.table.try_structurally_resolve_type(alias); segments = segments.skip(1); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/autoderef.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/autoderef.rs index d748c89e67759..a6c7b2dbb9c38 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/autoderef.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/autoderef.rs @@ -36,7 +36,7 @@ impl<'db, Ctx: AutoderefCtx<'db>> GeneralAutoderef<'db, Ctx> { .iter() .map(|&(_source, kind)| { if let AutoderefKind::Overloaded = kind { - Some(OverloadedDeref(Some(Mutability::Not))) + Some(OverloadedDeref(Mutability::Not)) } else { None } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/callee.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/callee.rs new file mode 100644 index 0000000000000..3d478912a3db2 --- /dev/null +++ 
b/src/tools/rust-analyzer/crates/hir-ty/src/infer/callee.rs @@ -0,0 +1,543 @@ +//! Inference of calls. + +use std::iter; + +use intern::sym; +use tracing::debug; + +use hir_def::{CallableDefId, hir::ExprId, signatures::FunctionSignature}; +use rustc_type_ir::{ + InferTy, Interner, + inherent::{GenericArgs as _, IntoKind, Ty as _}, +}; + +use crate::{ + Adjust, Adjustment, AutoBorrow, FnAbi, + autoderef::{GeneralAutoderef, InferenceContextAutoderef}, + infer::{ + AllowTwoPhase, AutoBorrowMutability, Expectation, InferenceContext, InferenceDiagnostic, + expr::{ExprIsRead, TupleArgumentsFlag}, + }, + method_resolution::{MethodCallee, TreatNotYetDefinedOpaques}, + next_solver::{ + FnSig, Ty, TyKind, + infer::{BoundRegionConversionTime, traits::ObligationCause}, + }, +}; + +#[derive(Debug)] +enum CallStep<'db> { + Builtin(Ty<'db>), + DeferredClosure(ExprId, FnSig<'db>), + /// Call overloading when callee implements one of the Fn* traits. + Overloaded(MethodCallee<'db>), +} + +impl<'db> InferenceContext<'_, 'db> { + pub(crate) fn infer_call( + &mut self, + call_expr: ExprId, + callee_expr: ExprId, + arg_exprs: &[ExprId], + expected: &Expectation<'db>, + ) -> Ty<'db> { + let original_callee_ty = self.infer_expr_no_expect(callee_expr, ExprIsRead::Yes); + + let expr_ty = self.table.try_structurally_resolve_type(original_callee_ty); + + let mut autoderef = GeneralAutoderef::new_from_inference_context(self, expr_ty); + let mut result = None; + while result.is_none() && autoderef.next().is_some() { + result = + Self::try_overloaded_call_step(call_expr, callee_expr, arg_exprs, &mut autoderef); + } + + // FIXME: rustc does some ABI checks here, but the ABI mapping is in rustc_target and we don't have access to that crate. 
+ + let obligations = autoderef.take_obligations(); + self.table.register_predicates(obligations); + + let output = match result { + None => { + // Check all of the arg expressions, but with no expectations + // since we don't have a signature to compare them to. + for &arg in arg_exprs { + self.infer_expr_no_expect(arg, ExprIsRead::Yes); + } + + self.push_diagnostic(InferenceDiagnostic::ExpectedFunction { + call_expr, + found: original_callee_ty.store(), + }); + + self.types.types.error + } + + Some(CallStep::Builtin(callee_ty)) => { + self.confirm_builtin_call(call_expr, callee_ty, arg_exprs, expected) + } + + Some(CallStep::DeferredClosure(_def_id, fn_sig)) => { + self.confirm_deferred_closure_call(call_expr, arg_exprs, expected, fn_sig) + } + + Some(CallStep::Overloaded(method_callee)) => { + self.confirm_overloaded_call(call_expr, arg_exprs, expected, method_callee) + } + }; + + // we must check that return type of called functions is WF: + self.table.register_wf_obligation(output.into(), ObligationCause::new()); + + output + } + + fn try_overloaded_call_step( + call_expr: ExprId, + callee_expr: ExprId, + arg_exprs: &[ExprId], + autoderef: &mut InferenceContextAutoderef<'_, '_, 'db>, + ) -> Option> { + let final_ty = autoderef.final_ty(); + let adjusted_ty = autoderef.ctx().table.try_structurally_resolve_type(final_ty); + + // If the callee is a function pointer or a closure, then we're all set. + match adjusted_ty.kind() { + TyKind::FnDef(..) | TyKind::FnPtr(..) => { + let adjust_steps = autoderef.adjust_steps_as_infer_ok(); + let adjustments = + autoderef.ctx().table.register_infer_ok(adjust_steps).into_boxed_slice(); + autoderef.ctx().write_expr_adj(callee_expr, adjustments); + return Some(CallStep::Builtin(adjusted_ty)); + } + + // Check whether this is a call to a closure where we + // haven't yet decided on whether the closure is fn vs + // fnmut vs fnonce. If so, we have to defer further processing. 
+ TyKind::Closure(def_id, args) + if autoderef.ctx().infcx().closure_kind(adjusted_ty).is_none() => + { + let closure_sig = args.as_closure().sig(); + let closure_sig = autoderef.ctx().infcx().instantiate_binder_with_fresh_vars( + BoundRegionConversionTime::FnCall, + closure_sig, + ); + let adjust_steps = autoderef.adjust_steps_as_infer_ok(); + let adjustments = autoderef.ctx().table.register_infer_ok(adjust_steps); + let def_id = def_id.0.loc(autoderef.ctx().db).1; + autoderef.ctx().record_deferred_call_resolution( + def_id, + DeferredCallResolution { + call_expr, + callee_expr, + closure_ty: adjusted_ty, + adjustments, + fn_sig: closure_sig, + }, + ); + return Some(CallStep::DeferredClosure(def_id, closure_sig)); + } + + // When calling a `CoroutineClosure` that is local to the body, we will + // not know what its `closure_kind` is yet. Instead, just fill in the + // signature with an infer var for the `tupled_upvars_ty` of the coroutine, + // and record a deferred call resolution which will constrain that var + // as part of `AsyncFn*` trait confirmation. + TyKind::CoroutineClosure(def_id, args) + if autoderef.ctx().infcx().closure_kind(adjusted_ty).is_none() => + { + let closure_args = args.as_coroutine_closure(); + let coroutine_closure_sig = + autoderef.ctx().infcx().instantiate_binder_with_fresh_vars( + BoundRegionConversionTime::FnCall, + closure_args.coroutine_closure_sig(), + ); + let tupled_upvars_ty = autoderef.ctx().table.next_ty_var(); + // We may actually receive a coroutine back whose kind is different + // from the closure that this dispatched from. This is because when + // we have no captures, we automatically implement `FnOnce`. This + // impl forces the closure kind to `FnOnce` i.e. `u8`. 
+ let kind_ty = autoderef.ctx().table.next_ty_var(); + let interner = autoderef.ctx().interner(); + let call_sig = interner.mk_fn_sig( + [coroutine_closure_sig.tupled_inputs_ty], + coroutine_closure_sig.to_coroutine( + interner, + closure_args.parent_args(), + kind_ty, + interner.coroutine_for_closure(def_id), + tupled_upvars_ty, + ), + coroutine_closure_sig.c_variadic, + coroutine_closure_sig.safety, + coroutine_closure_sig.abi, + ); + let adjust_steps = autoderef.adjust_steps_as_infer_ok(); + let adjustments = autoderef.ctx().table.register_infer_ok(adjust_steps); + let def_id = def_id.0.loc(autoderef.ctx().db).1; + autoderef.ctx().record_deferred_call_resolution( + def_id, + DeferredCallResolution { + call_expr, + callee_expr, + closure_ty: adjusted_ty, + adjustments, + fn_sig: call_sig, + }, + ); + return Some(CallStep::DeferredClosure(def_id, call_sig)); + } + + // Hack: we know that there are traits implementing Fn for &F + // where F:Fn and so forth. In the particular case of types + // like `f: &mut FnMut()`, if there is a call `f()`, we would + // normally translate to `FnMut::call_mut(&mut f, ())`, but + // that winds up potentially requiring the user to mark their + // variable as `mut` which feels unnecessary and unexpected. + // + // fn foo(f: &mut impl FnMut()) { f() } + // ^ without this hack `f` would have to be declared as mutable + // + // The simplest fix by far is to just ignore this case and deref again, + // so we wind up with `FnMut::call_mut(&mut *f, ())`. + TyKind::Ref(..) if autoderef.step_count() == 0 => { + return None; + } + + TyKind::Infer(InferTy::TyVar(vid)) + // If we end up with an inference variable which is not the hidden type of + // an opaque, emit an error. 
+ if !autoderef.ctx().infcx().has_opaques_with_sub_unified_hidden_type(vid) => { + autoderef + .ctx() + .type_must_be_known_at_this_point(callee_expr.into(), adjusted_ty); + return None; + } + + TyKind::Error(_) => { + return None; + } + + _ => {} + } + + // Now, we look for the implementation of a Fn trait on the object's type. + // We first do it with the explicit instruction to look for an impl of + // `Fn`, with the tuple `Tuple` having an arity corresponding + // to the number of call parameters. + // If that fails (or_else branch), we try again without specifying the + // shape of the tuple (hence the None). This allows to detect an Fn trait + // is implemented, and use this information for diagnostic. + autoderef + .ctx() + .try_overloaded_call_traits(adjusted_ty, Some(arg_exprs)) + .or_else(|| autoderef.ctx().try_overloaded_call_traits(adjusted_ty, None)) + .map(|(autoref, method)| { + let adjustments = autoderef.adjust_steps_as_infer_ok(); + let mut adjustments = autoderef.ctx().table.register_infer_ok(adjustments); + adjustments.extend(autoref); + autoderef.ctx().write_expr_adj(callee_expr, adjustments.into_boxed_slice()); + CallStep::Overloaded(method) + }) + } + + fn try_overloaded_call_traits( + &mut self, + adjusted_ty: Ty<'db>, + opt_arg_exprs: Option<&[ExprId]>, + ) -> Option<(Option, MethodCallee<'db>)> { + // HACK(async_closures): For async closures, prefer `AsyncFn*` + // over `Fn*`, since all async closures implement `FnOnce`, but + // choosing that over `AsyncFn`/`AsyncFnMut` would be more restrictive. + // For other callables, just prefer `Fn*` for perf reasons. + // + // The order of trait choices here is not that big of a deal, + // since it just guides inference (and our choice of autoref). 
+ // Though in the future, I'd like typeck to choose: + // `Fn > AsyncFn > FnMut > AsyncFnMut > FnOnce > AsyncFnOnce` + // ...or *ideally*, we just have `LendingFn`/`LendingFnMut`, which + // would naturally unify these two trait hierarchies in the most + // general way. + let call_trait_choices = if self.shallow_resolve(adjusted_ty).is_coroutine_closure() { + [ + (self.lang_items.AsyncFn, sym::async_call, true), + (self.lang_items.AsyncFnMut, sym::async_call_mut, true), + (self.lang_items.AsyncFnOnce, sym::async_call_once, false), + (self.lang_items.Fn, sym::call, true), + (self.lang_items.FnMut, sym::call_mut, true), + (self.lang_items.FnOnce, sym::call_once, false), + ] + } else { + [ + (self.lang_items.Fn, sym::call, true), + (self.lang_items.FnMut, sym::call_mut, true), + (self.lang_items.FnOnce, sym::call_once, false), + (self.lang_items.AsyncFn, sym::async_call, true), + (self.lang_items.AsyncFnMut, sym::async_call_mut, true), + (self.lang_items.AsyncFnOnce, sym::async_call_once, false), + ] + }; + + // Try the options that are least restrictive on the caller first. + for (opt_trait_def_id, method_name, borrow) in call_trait_choices { + let Some(trait_def_id) = opt_trait_def_id else { + continue; + }; + + let opt_input_type = opt_arg_exprs.map(|arg_exprs| { + Ty::new_tup_from_iter( + self.interner(), + arg_exprs.iter().map(|_| self.table.next_ty_var()), + ) + }); + + // We use `TreatNotYetDefinedOpaques::AsRigid` here so that if the `adjusted_ty` + // is `Box` we choose `FnOnce` instead of `Fn`. + // + // We try all the different call traits in order and choose the first + // one which may apply. So if we treat opaques as inference variables + // `Box: Fn` is considered ambiguous and chosen. 
+ if let Some(ok) = self.table.lookup_method_for_operator( + ObligationCause::new(), + method_name, + trait_def_id, + adjusted_ty, + opt_input_type, + TreatNotYetDefinedOpaques::AsRigid, + ) { + let method = self.table.register_infer_ok(ok); + let mut autoref = None; + if borrow { + // Check for &self vs &mut self in the method signature. Since this is either + // the Fn or FnMut trait, it should be one of those. + let TyKind::Ref(_, _, mutbl) = method.sig.inputs_and_output.inputs()[0].kind() + else { + panic!("Expected `FnMut`/`Fn` to take receiver by-ref/by-mut") + }; + + // For initial two-phase borrow + // deployment, conservatively omit + // overloaded function call ops. + let mutbl = AutoBorrowMutability::new(mutbl, AllowTwoPhase::No); + + autoref = Some(Adjustment { + kind: Adjust::Borrow(AutoBorrow::Ref(mutbl)), + target: method.sig.inputs_and_output.inputs()[0].store(), + }); + } + + return Some((autoref, method)); + } + } + + None + } + + /// Returns the argument indices to skip. 
+ fn check_legacy_const_generics( + &mut self, + callee: Option, + args: &[ExprId], + ) -> Box<[u32]> { + let func = match callee { + Some(CallableDefId::FunctionId(func)) => func, + _ => return Default::default(), + }; + + let data = FunctionSignature::of(self.db, func); + let Some(legacy_const_generics_indices) = data.legacy_const_generics_indices(self.db, func) + else { + return Default::default(); + }; + let mut legacy_const_generics_indices = Box::<[u32]>::from(legacy_const_generics_indices); + + // only use legacy const generics if the param count matches with them + if data.params.len() + legacy_const_generics_indices.len() != args.len() { + if args.len() <= data.params.len() { + return Default::default(); + } else { + // there are more parameters than there should be without legacy + // const params; use them + legacy_const_generics_indices.sort_unstable(); + return legacy_const_generics_indices; + } + } + + // check legacy const parameters + for arg_idx in legacy_const_generics_indices.iter().copied() { + if arg_idx >= args.len() as u32 { + continue; + } + let expected = Expectation::none(); // FIXME use actual const ty, when that is lowered correctly + self.infer_expr(args[arg_idx as usize], &expected, ExprIsRead::Yes); + // FIXME: evaluate and unify with the const + } + legacy_const_generics_indices.sort_unstable(); + legacy_const_generics_indices + } + + fn confirm_builtin_call( + &mut self, + call_expr: ExprId, + callee_ty: Ty<'db>, + arg_exprs: &[ExprId], + expected: &Expectation<'db>, + ) -> Ty<'db> { + let (fn_sig, def_id) = match callee_ty.kind() { + TyKind::FnDef(def_id, args) => { + let fn_sig = + self.db.callable_item_signature(def_id.0).instantiate(self.interner(), args); + (fn_sig, Some(def_id.0)) + } + + // FIXME(const_trait_impl): these arms should error because we can't enforce them + TyKind::FnPtr(sig_tys, hdr) => (sig_tys.with(hdr), None), + + _ => unreachable!(), + }; + + // Replace any late-bound regions that appear in the function + // 
signature with region variables. We also have to + // renormalize the associated types at this point, since they + // previously appeared within a `Binder<>` and hence would not + // have been normalized before. + let fn_sig = self + .infcx() + .instantiate_binder_with_fresh_vars(BoundRegionConversionTime::FnCall, fn_sig); + + let indices_to_skip = self.check_legacy_const_generics(def_id, arg_exprs); + self.check_call_arguments( + call_expr, + fn_sig.inputs(), + fn_sig.output(), + expected, + arg_exprs, + &indices_to_skip, + fn_sig.c_variadic, + TupleArgumentsFlag::DontTupleArguments, + ); + + if fn_sig.abi == FnAbi::RustCall + && let Some(ty) = fn_sig.inputs().last().copied() + && let Some(tuple_trait) = self.lang_items.Tuple + { + self.table.register_bound(ty, tuple_trait, ObligationCause::new()); + self.require_type_is_sized(ty); + } + + fn_sig.output() + } + + fn confirm_deferred_closure_call( + &mut self, + call_expr: ExprId, + arg_exprs: &[ExprId], + expected: &Expectation<'db>, + fn_sig: FnSig<'db>, + ) -> Ty<'db> { + // `fn_sig` is the *signature* of the closure being called. We + // don't know the full details yet (`Fn` vs `FnMut` etc), but we + // do know the types expected for each argument and the return + // type. 
+ self.check_call_arguments( + call_expr, + fn_sig.inputs(), + fn_sig.output(), + expected, + arg_exprs, + &[], + fn_sig.c_variadic, + TupleArgumentsFlag::TupleArguments, + ); + + fn_sig.output() + } + + fn confirm_overloaded_call( + &mut self, + call_expr: ExprId, + arg_exprs: &[ExprId], + expected: &Expectation<'db>, + method: MethodCallee<'db>, + ) -> Ty<'db> { + self.check_call_arguments( + call_expr, + &method.sig.inputs()[1..], + method.sig.output(), + expected, + arg_exprs, + &[], + method.sig.c_variadic, + TupleArgumentsFlag::TupleArguments, + ); + + self.write_method_resolution(call_expr, method.def_id, method.args); + + method.sig.output() + } +} + +#[derive(Debug, Clone)] +pub(crate) struct DeferredCallResolution<'db> { + call_expr: ExprId, + callee_expr: ExprId, + closure_ty: Ty<'db>, + adjustments: Vec, + fn_sig: FnSig<'db>, +} + +impl<'a, 'db> DeferredCallResolution<'db> { + pub(crate) fn resolve(self, ctx: &mut InferenceContext<'a, 'db>) { + debug!("DeferredCallResolution::resolve() {:?}", self); + + // we should not be invoked until the closure kind has been + // determined by upvar inference + assert!(ctx.infcx().closure_kind(self.closure_ty).is_some()); + + // We may now know enough to figure out fn vs fnmut etc. + match ctx.try_overloaded_call_traits(self.closure_ty, None) { + Some((autoref, method_callee)) => { + // One problem is that when we get here, we are going + // to have a newly instantiated function signature + // from the call trait. This has to be reconciled with + // the older function signature we had before. In + // principle we *should* be able to fn_sigs(), but we + // can't because of the annoying need for a TypeTrace. + // (This always bites me, should find a way to + // refactor it.) 
+ let method_sig = method_callee.sig; + + debug!("attempt_resolution: method_callee={:?}", method_callee); + + for (method_arg_ty, self_arg_ty) in + iter::zip(method_sig.inputs().iter().skip(1), self.fn_sig.inputs()) + { + _ = ctx.demand_eqtype(self.call_expr.into(), *self_arg_ty, *method_arg_ty); + } + + _ = ctx.demand_eqtype( + self.call_expr.into(), + method_sig.output(), + self.fn_sig.output(), + ); + + let mut adjustments = self.adjustments; + adjustments.extend(autoref); + ctx.write_expr_adj(self.callee_expr, adjustments.into_boxed_slice()); + + ctx.write_method_resolution( + self.call_expr, + method_callee.def_id, + method_callee.args, + ); + } + None => { + assert!( + ctx.lang_items.FnOnce.is_none(), + "Expected to find a suitable `Fn`/`FnMut`/`FnOnce` implementation for `{:?}`", + self.closure_ty + ) + } + } + } +} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs index b868f0234209e..2207bc37e8be7 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure.rs @@ -6,7 +6,7 @@ use std::{iter, mem, ops::ControlFlow}; use hir_def::{ TraitId, - hir::{ClosureKind, ExprId, PatId}, + hir::{ClosureKind, CoroutineSource, ExprId, PatId}, type_ref::TypeRefId, }; use rustc_type_ir::{ @@ -19,7 +19,7 @@ use tracing::debug; use crate::{ FnAbi, - db::{InternedClosure, InternedCoroutine}, + db::{InternedClosure, InternedClosureId, InternedCoroutineClosureId, InternedCoroutineId}, infer::{BreakableKind, Diverges, coerce::CoerceMany}, next_solver::{ AliasTy, Binder, ClauseKind, DbInterner, ErrorGuaranteed, FnSig, GenericArgs, PolyFnSig, @@ -30,7 +30,6 @@ use crate::{ traits::{ObligationCause, PredicateObligations}, }, }, - traits::FnTrait, }; use super::{Expectation, InferenceContext}; @@ -54,68 +53,40 @@ impl<'db> InferenceContext<'_, 'db> { ret_type: Option, arg_types: &[Option], closure_kind: ClosureKind, - tgt_expr: 
ExprId, + closure_expr: ExprId, expected: &Expectation<'db>, ) -> Ty<'db> { assert_eq!(args.len(), arg_types.len()); let interner = self.interner(); + // It's always helpful for inference if we know the kind of + // closure sooner rather than later, so first examine the expected + // type, and see if can glean a closure kind from there. let (expected_sig, expected_kind) = match expected.to_option(&mut self.table) { - Some(expected_ty) => self.deduce_closure_signature(expected_ty, closure_kind), + Some(ty) => { + let ty = self.table.try_structurally_resolve_type(ty); + self.deduce_closure_signature(ty, closure_kind) + } None => (None, None), }; - let ClosureSignatures { bound_sig, liberated_sig } = + let ClosureSignatures { bound_sig, mut liberated_sig } = self.sig_of_closure(arg_types, ret_type, expected_sig); - let body_ret_ty = bound_sig.output().skip_binder(); - let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); - // FIXME: Make this an infer var and infer it later. - let tupled_upvars_ty = self.types.types.unit; - let (id, ty, resume_yield_tys) = match closure_kind { - ClosureKind::Coroutine(_) => { - let yield_ty = self.table.next_ty_var(); - let resume_ty = - liberated_sig.inputs().first().copied().unwrap_or(self.types.types.unit); + debug!(?bound_sig, ?liberated_sig); - // FIXME: Infer the upvars later. 
- let parts = CoroutineArgsParts { - parent_args: parent_args.as_slice(), - kind_ty: self.types.types.unit, - resume_ty, - yield_ty, - return_ty: body_ret_ty, - tupled_upvars_ty, - }; + let parent_args = GenericArgs::identity_for_item(interner, self.generic_def.into()); - let coroutine_id = - self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); - let coroutine_ty = Ty::new_coroutine( - interner, - coroutine_id, - CoroutineArgs::new(interner, parts).args, - ); + let tupled_upvars_ty = self.table.next_ty_var(); - (None, coroutine_ty, Some((resume_ty, yield_ty))) - } + // FIXME: We could probably actually just unify this further -- + // instead of having a `FnSig` and a `Option`, + // we can have a `ClosureSignature { Coroutine { .. }, Closure { .. } }`, + // similar to how `ty::GenSig` is a distinct data structure. + let (closure_ty, resume_yield_tys) = match closure_kind { ClosureKind::Closure => { - let closure_id = self.db.intern_closure(InternedClosure(self.owner, tgt_expr)); - match expected_kind { - Some(kind) => { - self.result.closure_info.insert( - closure_id, - ( - Vec::new(), - match kind { - rustc_type_ir::ClosureKind::Fn => FnTrait::Fn, - rustc_type_ir::ClosureKind::FnMut => FnTrait::FnMut, - rustc_type_ir::ClosureKind::FnOnce => FnTrait::FnOnce, - }, - ), - ); - } - None => {} - }; + // Tuple up the arguments and insert the resulting function type into + // the `closures` table. let sig = bound_sig.map_bound(|sig| { interner.mk_fn_sig( [Ty::new_tup(interner, sig.inputs())], @@ -125,49 +96,91 @@ impl<'db> InferenceContext<'_, 'db> { sig.abi, ) }); - let sig_ty = Ty::new_fn_ptr(interner, sig); - // FIXME: Infer the kind later if needed. 
- let parts = ClosureArgsParts { - parent_args: parent_args.as_slice(), - closure_kind_ty: Ty::from_closure_kind( - interner, - expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), - ), - closure_sig_as_fn_ptr_ty: sig_ty, - tupled_upvars_ty, + + debug!(?sig, ?expected_kind); + + let closure_kind_ty = match expected_kind { + Some(kind) => Ty::from_closure_kind(interner, kind), + // Create a type variable (for now) to represent the closure kind. + // It will be unified during the upvar inference phase (`upvar.rs`) + None => self.table.next_ty_var(), }; - let closure_ty = Ty::new_closure( + + let closure_args = ClosureArgs::new( interner, - closure_id.into(), - ClosureArgs::new(interner, parts).args, + ClosureArgsParts { + parent_args: parent_args.as_slice(), + closure_kind_ty, + closure_sig_as_fn_ptr_ty: Ty::new_fn_ptr(interner, sig), + tupled_upvars_ty, + }, ); - self.deferred_closures.entry(closure_id).or_default(); - self.add_current_closure_dependency(closure_id); - (Some(closure_id), closure_ty, None) + + let closure_id = + InternedClosureId::new(self.db, InternedClosure(self.owner, closure_expr)); + + (Ty::new_closure(interner, closure_id.into(), closure_args.args), None) } - ClosureKind::Async => { - // async closures always return the type ascribed after the `->` (if present), - // and yield `()`. - let bound_return_ty = bound_sig.skip_binder().output(); - let bound_yield_ty = self.types.types.unit; - // rustc uses a special lang item type for the resume ty. I don't believe this can cause us problems. - let resume_ty = self.types.types.unit; + ClosureKind::Coroutine(_) | ClosureKind::AsyncBlock { .. } => { + let yield_ty = match closure_kind { + ClosureKind::Coroutine(_) => self.table.next_ty_var(), + ClosureKind::AsyncBlock { .. } => self.types.types.unit, + _ => unreachable!(), + }; - // FIXME: Infer the kind later if needed. 
- let closure_kind_ty = Ty::from_closure_kind( - interner, - expected_kind.unwrap_or(rustc_type_ir::ClosureKind::Fn), - ); + // Resume type defaults to `()` if the coroutine has no argument. + let resume_ty = + liberated_sig.inputs().first().copied().unwrap_or(self.types.types.unit); - // FIXME: Infer captures later. - // `for<'env> fn() -> ()`, for no captures. - let coroutine_captures_by_ref_ty = Ty::new_fn_ptr( + // Coroutines that come from coroutine closures have not yet determined + // their kind ty, so make a fresh infer var which will be constrained + // later during upvar analysis. Regular coroutines always have the kind + // ty of `().` + let kind_ty = match closure_kind { + ClosureKind::AsyncBlock { source: CoroutineSource::Closure } => { + self.table.next_ty_var() + } + _ => self.types.types.unit, + }; + + let coroutine_args = CoroutineArgs::new( interner, - Binder::bind_with_vars( - interner.mk_fn_sig_safe_rust_abi([], self.types.types.unit), - self.types.coroutine_captures_by_ref_bound_var_kinds, - ), + CoroutineArgsParts { + parent_args: parent_args.as_slice(), + kind_ty, + resume_ty, + yield_ty, + return_ty: liberated_sig.output(), + tupled_upvars_ty, + }, ); + + let coroutine_id = + InternedCoroutineId::new(self.db, InternedClosure(self.owner, closure_expr)); + + ( + Ty::new_coroutine(interner, coroutine_id.into(), coroutine_args.args), + Some((resume_ty, yield_ty)), + ) + } + ClosureKind::AsyncClosure => { + // async closures always return the type ascribed after the `->` (if present), + // and yield `()`. + let (bound_return_ty, bound_yield_ty) = + (bound_sig.skip_binder().output(), self.types.types.unit); + // Compute all of the variables that will be used to populate the coroutine. + let resume_ty = self.table.next_ty_var(); + + let closure_kind_ty = match expected_kind { + Some(kind) => Ty::from_closure_kind(interner, kind), + + // Create a type variable (for now) to represent the closure kind. 
+ // It will be unified during the upvar inference phase (`upvar.rs`) + None => self.table.next_ty_var(), + }; + + let coroutine_captures_by_ref_ty = self.table.next_ty_var(); + let closure_args = CoroutineClosureArgs::new( interner, CoroutineClosureArgsParts { @@ -177,7 +190,13 @@ impl<'db> InferenceContext<'_, 'db> { interner, bound_sig.map_bound(|sig| { interner.mk_fn_sig( - [resume_ty, Ty::new_tup(interner, sig.inputs())], + [ + resume_ty, + Ty::new_tup_from_iter( + interner, + sig.inputs().iter().copied(), + ), + ], Ty::new_tup(interner, &[bound_yield_ty, bound_return_ty]), sig.c_variadic, sig.safety, @@ -190,9 +209,55 @@ impl<'db> InferenceContext<'_, 'db> { }, ); - let coroutine_id = - self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); - (None, Ty::new_coroutine_closure(interner, coroutine_id, closure_args.args), None) + let coroutine_kind_ty = match expected_kind { + Some(kind) => Ty::from_coroutine_closure_kind(interner, kind), + + // Create a type variable (for now) to represent the closure kind. + // It will be unified during the upvar inference phase (`upvar.rs`) + None => self.table.next_ty_var(), + }; + + let coroutine_upvars_ty = self.table.next_ty_var(); + + let coroutine_closure_id = InternedCoroutineClosureId::new( + self.db, + InternedClosure(self.owner, closure_expr), + ); + + // We need to turn the liberated signature that we got from HIR, which + // looks something like `|Args...| -> T`, into a signature that is suitable + // for type checking the inner body of the closure, which always returns a + // coroutine. To do so, we use the `CoroutineClosureSignature` to compute + // the coroutine type, filling in the tupled_upvars_ty and kind_ty with infer + // vars which will get constrained during upvar analysis. 
+ let coroutine_output_ty = closure_args + .coroutine_closure_sig() + .map_bound(|sig| { + sig.to_coroutine( + interner, + parent_args.as_slice(), + coroutine_kind_ty, + interner.coroutine_for_closure(coroutine_closure_id.into()), + coroutine_upvars_ty, + ) + }) + .skip_binder(); + liberated_sig = interner.mk_fn_sig( + liberated_sig.inputs().iter().copied(), + coroutine_output_ty, + liberated_sig.c_variadic, + liberated_sig.safety, + liberated_sig.abi, + ); + + ( + Ty::new_coroutine_closure( + interner, + coroutine_closure_id.into(), + closure_args.args, + ), + None, + ) } }; @@ -203,9 +268,9 @@ impl<'db> InferenceContext<'_, 'db> { // FIXME: lift these out into a struct let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); - let prev_closure = mem::replace(&mut self.current_closure, id); - let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty); - let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(body_ret_ty)); + let prev_ret_ty = mem::replace(&mut self.return_ty, liberated_sig.output()); + let prev_ret_coercion = + self.return_coercion.replace(CoerceMany::new(liberated_sig.output())); let prev_resume_yield_tys = mem::replace(&mut self.resume_yield_tys, resume_yield_tys); self.with_breakable_ctx(BreakableKind::Border, None, None, |this| { @@ -215,10 +280,9 @@ impl<'db> InferenceContext<'_, 'db> { self.diverges = prev_diverges; self.return_ty = prev_ret_ty; self.return_coercion = prev_ret_coercion; - self.current_closure = prev_closure; self.resume_yield_tys = prev_resume_yield_tys; - ty + closure_ty } fn fn_trait_kind_from_def_id(&self, trait_id: TraitId) -> Option { @@ -256,7 +320,7 @@ impl<'db> InferenceContext<'_, 'db> { closure_kind: ClosureKind, ) -> (Option>, Option) { match expected_ty.kind() { - TyKind::Alias(rustc_type_ir::Opaque, AliasTy { def_id, args, .. }) => self + TyKind::Alias(AliasTy { kind: rustc_type_ir::Opaque { def_id }, args, .. 
}) => self .deduce_closure_signature_from_predicates( expected_ty, closure_kind, @@ -287,7 +351,9 @@ impl<'db> InferenceContext<'_, 'db> { let expected_sig = sig_tys.with(hdr); (Some(expected_sig), Some(rustc_type_ir::ClosureKind::Fn)) } - ClosureKind::Coroutine(_) | ClosureKind::Async => (None, None), + ClosureKind::Coroutine(_) + | ClosureKind::AsyncClosure + | ClosureKind::AsyncBlock { .. } => (None, None), }, _ => (None, None), } @@ -400,7 +466,7 @@ impl<'db> InferenceContext<'_, 'db> { if let Some(trait_def_id) = trait_def_id { let found_kind = match closure_kind { ClosureKind::Closure => self.fn_trait_kind_from_def_id(trait_def_id), - ClosureKind::Async => self + ClosureKind::AsyncClosure => self .async_fn_trait_kind_from_def_id(trait_def_id) .or_else(|| self.fn_trait_kind_from_def_id(trait_def_id)), _ => None, @@ -446,13 +512,13 @@ impl<'db> InferenceContext<'_, 'db> { ClosureKind::Closure if Some(def_id) == self.lang_items.FnOnceOutput => { self.extract_sig_from_projection(projection) } - ClosureKind::Async if Some(def_id) == self.lang_items.AsyncFnOnceOutput => { + ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.AsyncFnOnceOutput => { self.extract_sig_from_projection(projection) } // It's possible we've passed the closure to a (somewhat out-of-fashion) // `F: FnOnce() -> Fut, Fut: Future` style bound. Let's still // guide inference here, since it's beneficial for the user. - ClosureKind::Async if Some(def_id) == self.lang_items.FnOnceOutput => { + ClosureKind::AsyncClosure if Some(def_id) == self.lang_items.FnOnceOutput => { self.extract_sig_from_projection_and_future_bound(projection) } _ => None, @@ -561,8 +627,7 @@ impl<'db> InferenceContext<'_, 'db> { // that does not misuse a `FnSig` type, but that can be done separately. 
let return_ty = return_ty.unwrap_or_else(|| self.table.next_ty_var()); - let sig = - projection.rebind(self.interner().mk_fn_sig_safe_rust_abi(input_tys, return_ty)); + let sig = projection.rebind(self.interner().mk_fn_sig_safe_rust_abi(input_tys, return_ty)); Some(sig) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure/analysis.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure/analysis.rs index ce0ccfe82f27c..668d7496cd1b2 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure/analysis.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure/analysis.rs @@ -1,1310 +1,1676 @@ -//! Post-inference closure analysis: captures and closure kind. +//! ### Inferring borrow kinds for upvars +//! +//! Whenever there is a closure expression, we need to determine how each +//! upvar is used. We do this by initially assigning each upvar an +//! immutable "borrow kind" (see `BorrowKind` for details) and then +//! "escalating" the kind as needed. The borrow kind proceeds according to +//! the following lattice: +//! ```ignore (not-rust) +//! ty::ImmBorrow -> ty::UniqueImmBorrow -> ty::MutBorrow +//! ``` +//! So, for example, if we see an assignment `x = 5` to an upvar `x`, we +//! will promote its borrow kind to mutable borrow. If we see an `&mut x` +//! we'll do the same. Naturally, this applies not just to the upvar, but +//! to everything owned by `x`, so the result is the same for something +//! like `x.f = 5` and so on (presuming `x` is not a borrowed pointer to a +//! struct). These adjustments are performed in +//! `adjust_for_non_move_closure` (you can trace backwards through the code +//! from there). +//! +//! The fact that we are inferring borrow kinds as we go results in a +//! semi-hacky interaction with the way `ExprUseVisitor` is computing +//! `Place`s. In particular, it will query the current borrow kind as it +//! goes, and we'll return the *current* value, but this may get +//! adjusted later. 
Therefore, in this module, we generally ignore the +//! borrow kind (and derived mutabilities) that `ExprUseVisitor` returns +//! within `Place`s, since they may be inaccurate. (Another option +//! would be to use a unification scheme, where instead of returning a +//! concrete borrow kind like `ty::ImmBorrow`, we return a +//! `ty::InferBorrow(upvar_id)` or something like that, but this would +//! then mean that all later passes would have to check for these figments +//! and report an error, and it just seems like more mess in the end.) -use std::{cmp, mem}; +use std::{iter, mem}; -use base_db::Crate; use hir_def::{ - ExpressionStoreOwnerId, FieldId, HasModule, VariantId, - expr_store::{Body, ExpressionStore, path::Path}, + expr_store::ExpressionStore, hir::{ - Array, AsmOperand, BinaryOp, BindingId, CaptureBy, Expr, ExprId, ExprOrPatId, Pat, PatId, - RecordSpread, Statement, UnaryOp, + BindingAnnotation, BindingId, CaptureBy, CoroutineSource, Expr, ExprId, ExprOrPatId, Pat, + PatId, Statement, }, - item_tree::FieldsShape, resolver::ValueNs, - signatures::VariantFields, }; +use macros::{TypeFoldable, TypeVisitable}; use rustc_ast_ir::Mutability; -use rustc_hash::{FxHashMap, FxHashSet}; -use rustc_type_ir::inherent::{GenericArgs as _, IntoKind, Ty as _}; +use rustc_hash::{FxBuildHasher, FxHashMap}; +use rustc_type_ir::{ + BoundVar, ClosureKind, TypeVisitableExt as _, + inherent::{AdtDef as _, GenericArgs as _, IntoKind as _, Ty as _}, +}; use smallvec::{SmallVec, smallvec}; -use stdx::{format_to, never}; -use syntax::utils::is_raw_identifier; +use span::Edition; +use tracing::{debug, instrument}; use crate::{ - Adjust, Adjustment, BindingMode, - db::{HirDatabase, InternedClosure, InternedClosureId}, - display::{DisplayTarget, HirDisplay as _}, - infer::InferenceContext, - mir::{BorrowKind, MirSpan, MutBorrowKind}, + FnAbi, + infer::{ + CaptureInfo, CaptureSourceStack, CapturedPlace, InferenceContext, UpvarCapture, + closure::analysis::expr_use_visitor::{ + self as 
euv, FakeReadCause, Place, PlaceBase, PlaceWithOrigin, Projection, + ProjectionKind, + }, + }, next_solver::{ - DbInterner, ErrorGuaranteed, GenericArgs, ParamEnv, StoredEarlyBinder, StoredTy, Ty, - TyKind, - infer::{InferCtxt, traits::ObligationCause}, - obligation_ctxt::ObligationCtxt, + Binder, BoundRegion, BoundRegionKind, DbInterner, GenericArgs, Region, Ty, TyKind, + abi::Safety, infer::traits::ObligationCause, normalize, }, - traits::FnTrait, + upvars::{Upvars, UpvarsRef}, }; -// The below functions handle capture and closure kind (Fn, FnMut, ..) +pub(crate) mod expr_use_visitor; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub(crate) enum HirPlaceProjection { - Deref, - Field(FieldId), - TupleField(u32), +#[derive(Debug, Copy, Clone, TypeVisitable, TypeFoldable)] +enum UpvarArgs<'db> { + Closure(GenericArgs<'db>), + Coroutine(GenericArgs<'db>), + CoroutineClosure(GenericArgs<'db>), } -impl HirPlaceProjection { - fn projected_ty<'db>( - self, - infcx: &InferCtxt<'db>, - env: ParamEnv<'db>, - mut base: Ty<'db>, - krate: Crate, - ) -> Ty<'db> { - let interner = infcx.interner; - let db = interner.db; - if base.is_ty_error() { - return Ty::new_error(interner, ErrorGuaranteed); - } - - if matches!(base.kind(), TyKind::Alias(..)) { - let mut ocx = ObligationCtxt::new(infcx); - match ocx.structurally_normalize_ty(&ObligationCause::dummy(), env, base) { - Ok(it) => base = it, - Err(_) => return Ty::new_error(interner, ErrorGuaranteed), - } - } +impl<'db> UpvarArgs<'db> { + #[inline] + fn tupled_upvars_ty(self) -> Ty<'db> { match self { - HirPlaceProjection::Deref => match base.kind() { - TyKind::RawPtr(inner, _) | TyKind::Ref(_, inner, _) => inner, - TyKind::Adt(adt_def, subst) if adt_def.is_box() => subst.type_at(0), - _ => { - never!( - "Overloaded deref on type {} is not a projection", - base.display(db, DisplayTarget::from_crate(db, krate)) - ); - Ty::new_error(interner, ErrorGuaranteed) - } - }, - HirPlaceProjection::Field(f) => match base.kind() { - 
TyKind::Adt(_, subst) => { - db.field_types(f.parent)[f.local_id].get().instantiate(interner, subst) - } - ty => { - never!("Only adt has field, found {:?}", ty); - Ty::new_error(interner, ErrorGuaranteed) - } - }, - HirPlaceProjection::TupleField(idx) => match base.kind() { - TyKind::Tuple(subst) => { - subst.as_slice().get(idx as usize).copied().unwrap_or_else(|| { - never!("Out of bound tuple field"); - Ty::new_error(interner, ErrorGuaranteed) - }) - } - ty => { - never!("Only tuple has tuple field: {:?}", ty); - Ty::new_error(interner, ErrorGuaranteed) - } - }, + UpvarArgs::Closure(args) => args.as_closure().tupled_upvars_ty(), + UpvarArgs::Coroutine(args) => args.as_coroutine().tupled_upvars_ty(), + UpvarArgs::CoroutineClosure(args) => args.as_coroutine_closure().tupled_upvars_ty(), } } } -#[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)] -pub(crate) struct HirPlace { - pub(crate) local: BindingId, - pub(crate) projections: Vec, +#[derive(Eq, Clone, PartialEq, Debug, Copy, Hash)] +pub enum BorrowKind { + /// Data must be immutable and is aliasable. + Immutable, + + /// Data must be immutable but not aliasable. This kind of borrow + /// cannot currently be expressed by the user and is used only in + /// implicit closure bindings. It is needed when the closure + /// is borrowing or mutating a mutable referent, e.g.: + /// + /// ``` + /// let mut z = 3; + /// let x: &mut isize = &mut z; + /// let y = || *x += 5; + /// ``` + /// + /// If we were to try to translate this closure into a more explicit + /// form, we'd encounter an error with the code as written: + /// + /// ```compile_fail,E0594 + /// struct Env<'a> { x: &'a &'a mut isize } + /// let mut z = 3; + /// let x: &mut isize = &mut z; + /// let y = (&mut Env { x: &x }, fn_ptr); // Closure is pair of env and fn + /// fn fn_ptr(env: &mut Env) { **env.x += 5; } + /// ``` + /// + /// This is then illegal because you cannot mutate a `&mut` found + /// in an aliasable location. 
To solve, you'd have to translate with + /// an `&mut` borrow: + /// + /// ```compile_fail,E0596 + /// struct Env<'a> { x: &'a mut &'a mut isize } + /// let mut z = 3; + /// let x: &mut isize = &mut z; + /// let y = (&mut Env { x: &mut x }, fn_ptr); // changed from &x to &mut x + /// fn fn_ptr(env: &mut Env) { **env.x += 5; } + /// ``` + /// + /// Now the assignment to `**env.x` is legal, but creating a + /// mutable pointer to `x` is not because `x` is not mutable. We + /// could fix this by declaring `x` as `let mut x`. This is ok in + /// user code, if awkward, but extra weird for closures, since the + /// borrow is hidden. + /// + /// So we introduce a "unique imm" borrow -- the referent is + /// immutable, but not aliasable. This solves the problem. For + /// simplicity, we don't give users the way to express this + /// borrow, it's just used when translating closures. + /// + /// FIXME: Rename this to indicate the borrow is actually not immutable. + UniqueImmutable, + + /// Data is mutable and not aliasable. 
+ Mutable, } -impl HirPlace { - fn ty<'db>(&self, ctx: &mut InferenceContext<'_, 'db>) -> Ty<'db> { - let krate = ctx.krate(); - let mut ty = ctx.table.resolve_completely(ctx.result.binding_ty(self.local)); - for p in &self.projections { - ty = p.projected_ty(ctx.infcx(), ctx.table.param_env, ty, krate); +impl BorrowKind { + pub fn from_hir_mutbl(m: hir_def::hir::type_ref::Mutability) -> BorrowKind { + match m { + hir_def::hir::type_ref::Mutability::Mut => BorrowKind::Mutable, + hir_def::hir::type_ref::Mutability::Shared => BorrowKind::Immutable, } - ty } - fn capture_kind_of_truncated_place( - &self, - mut current_capture: CaptureKind, - len: usize, - ) -> CaptureKind { - if let CaptureKind::ByRef(BorrowKind::Mut { - kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow, - }) = current_capture - && self.projections[len..].contains(&HirPlaceProjection::Deref) - { - current_capture = - CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture }); + pub fn from_mutbl(m: Mutability) -> BorrowKind { + match m { + Mutability::Mut => BorrowKind::Mutable, + Mutability::Not => BorrowKind::Immutable, } - current_capture } -} -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum CaptureKind { - ByRef(BorrowKind), - ByValue, -} + /// Returns a mutability `m` such that an `&m T` pointer could be used to obtain this borrow + /// kind. Because borrow kinds are richer than mutabilities, we sometimes have to pick a + /// mutability that is stronger than necessary so that it at least *would permit* the borrow in + /// question. + pub fn to_mutbl_lossy(self) -> Mutability { + match self { + BorrowKind::Mutable => Mutability::Mut, + BorrowKind::Immutable => Mutability::Not, -#[derive(Debug, Clone, PartialEq, Eq, salsa::Update)] -pub struct CapturedItem { - pub(crate) place: HirPlace, - pub(crate) kind: CaptureKind, - /// The inner vec is the stacks; the outer vec is for each capture reference. 
- /// - /// Even though we always report only the last span (i.e. the most inclusive span), - /// we need to keep them all, since when a closure occurs inside a closure, we - /// copy all captures of the inner closure to the outer closure, and then we may - /// truncate them, and we want the correct span to be reported. - span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>, - pub(crate) ty: StoredEarlyBinder, + // We have no type corresponding to a unique imm borrow, so + // use `&mut`. It gives all the capabilities of a `&uniq` + // and hence is a safe "over approximation". + BorrowKind::UniqueImmutable => Mutability::Mut, + } + } } -impl CapturedItem { - pub fn local(&self) -> BindingId { - self.place.local - } +/// Describe the relationship between the paths of two places +/// eg: +/// - `foo` is ancestor of `foo.bar.baz` +/// - `foo.bar.baz` is an descendant of `foo.bar` +/// - `foo.bar` and `foo.baz` are divergent +enum PlaceAncestryRelation { + Ancestor, + Descendant, + SamePlace, + Divergent, +} - /// Returns whether this place has any field (aka. non-deref) projections. - pub fn has_field_projections(&self) -> bool { - self.place.projections.iter().any(|it| !matches!(it, HirPlaceProjection::Deref)) - } +/// Intermediate format to store a captured `Place` and associated `CaptureInfo` +/// during capture analysis. Information in this map feeds into the minimum capture +/// analysis pass. 
+type InferredCaptureInformation = Vec<(Place, CaptureInfo)>; - pub fn ty<'db>(&self, db: &'db dyn HirDatabase, subst: GenericArgs<'db>) -> Ty<'db> { - let interner = DbInterner::new_no_crate(db); - self.ty.get().instantiate(interner, subst.as_closure().parent_args()) - } +impl<'a, 'db> InferenceContext<'a, 'db> { + pub(crate) fn closure_analyze(&mut self) { + let upvars = crate::upvars::upvars_mentioned(self.db, self.owner) + .unwrap_or(const { &FxHashMap::with_hasher(FxBuildHasher) }); + for root_expr in self.store.expr_roots() { + self.analyze_closures_in_expr(root_expr, upvars); + } - pub fn kind(&self) -> CaptureKind { - self.kind + // it's our job to process these. + assert!(self.deferred_call_resolutions.is_empty()); } - pub fn spans(&self) -> SmallVec<[MirSpan; 3]> { - self.span_stacks.iter().map(|stack| *stack.last().expect("empty span stack")).collect() - } + fn analyze_closures_in_expr(&mut self, expr: ExprId, upvars: &'db FxHashMap) { + self.store.walk_child_exprs(expr, |expr| self.analyze_closures_in_expr(expr, upvars)); - /// Converts the place to a name that can be inserted into source code. 
- pub fn place_to_name(&self, owner: ExpressionStoreOwnerId, db: &dyn HirDatabase) -> String { - let krate = owner.krate(db); - let edition = krate.data(db).edition; - let mut result = match owner { - ExpressionStoreOwnerId::Signature(generic_def_id) => { - ExpressionStore::of(db, generic_def_id.into())[self.place.local] - .name - .display(db, edition) - .to_string() - } - ExpressionStoreOwnerId::Body(def_with_body_id) => Body::of(db, def_with_body_id) - [self.place.local] - .name - .display(db, edition) - .to_string(), - ExpressionStoreOwnerId::VariantFields(variant_id) => { - let fields = VariantFields::of(db, variant_id); - fields.store[self.place.local].name.display(db, edition).to_string() - } - }; - for proj in &self.place.projections { - match proj { - HirPlaceProjection::Deref => {} - HirPlaceProjection::Field(f) => { - let variant_data = f.parent.fields(db); - match variant_data.shape { - FieldsShape::Record => { - result.push('_'); - result.push_str(variant_data.fields()[f.local_id].name.as_str()) - } - FieldsShape::Tuple => { - let index = - variant_data.fields().iter().position(|it| it.0 == f.local_id); - if let Some(index) = index { - format_to!(result, "_{index}"); - } - } - FieldsShape::Unit => {} - } - } - HirPlaceProjection::TupleField(idx) => { - format_to!(result, "_{idx}") - } + match &self.store[expr] { + Expr::Closure { args, body, closure_kind, capture_by, .. 
} => { + self.analyze_closure( + expr, + args, + *body, + *capture_by, + *closure_kind, + upvars.get(&expr).map(|upvars| upvars.as_ref()).unwrap_or_default(), + ); } + _ => {} } - if is_raw_identifier(&result, owner.module(db).krate(db).data(db).edition) { - result.insert_str(0, "r#"); - } - result } - pub fn display_place_source_code( - &self, - owner: ExpressionStoreOwnerId, - db: &dyn HirDatabase, - ) -> String { - let krate = owner.krate(db); - let edition = krate.data(db).edition; - let mut result = match owner { - ExpressionStoreOwnerId::Signature(generic_def_id) => { - ExpressionStore::of(db, generic_def_id.into())[self.place.local] - .name - .display(db, edition) - .to_string() + /// Analysis starting point. + #[instrument(skip(self, body), level = "debug")] + fn analyze_closure( + &mut self, + closure_expr_id: ExprId, + params: &[PatId], + body: ExprId, + mut capture_clause: CaptureBy, + closure_kind: hir_def::hir::ClosureKind, + upvars: UpvarsRef<'db>, + ) { + // Extract the type of the closure. + let ty = self.expr_ty(closure_expr_id); + let (args, infer_kind) = match ty.kind() { + TyKind::Closure(_def_id, args) => { + (UpvarArgs::Closure(args), self.infcx().closure_kind(ty).is_none()) } - ExpressionStoreOwnerId::Body(def_with_body_id) => Body::of(db, def_with_body_id) - [self.place.local] - .name - .display(db, edition) - .to_string(), - ExpressionStoreOwnerId::VariantFields(variant_id) => { - let fields = VariantFields::of(db, variant_id); - fields.store[self.place.local].name.display(db, edition).to_string() + TyKind::CoroutineClosure(_def_id, args) => { + (UpvarArgs::CoroutineClosure(args), self.infcx().closure_kind(ty).is_none()) } - }; - for proj in &self.place.projections { - match proj { - // In source code autoderef kicks in. 
- HirPlaceProjection::Deref => {} - HirPlaceProjection::Field(f) => { - let variant_data = f.parent.fields(db); - match variant_data.shape { - FieldsShape::Record => format_to!( - result, - ".{}", - variant_data.fields()[f.local_id].name.display(db, edition) - ), - FieldsShape::Tuple => format_to!( - result, - ".{}", - variant_data - .fields() - .iter() - .position(|it| it.0 == f.local_id) - .unwrap_or_default() - ), - FieldsShape::Unit => {} - } - } - HirPlaceProjection::TupleField(idx) => { - format_to!(result, ".{idx}") - } + TyKind::Coroutine(_def_id, args) => (UpvarArgs::Coroutine(args), false), + TyKind::Error(_) => { + // #51714: skip analysis when we have already encountered type errors + return; } - } - let final_derefs_count = self - .place - .projections - .iter() - .rev() - .take_while(|proj| matches!(proj, HirPlaceProjection::Deref)) - .count(); - result.insert_str(0, &"*".repeat(final_derefs_count)); - result - } - - pub fn display_place(&self, owner: ExpressionStoreOwnerId, db: &dyn HirDatabase) -> String { - let krate = owner.krate(db); - let edition = krate.data(db).edition; - let mut result = match owner { - ExpressionStoreOwnerId::Signature(generic_def_id) => { - ExpressionStore::of(db, generic_def_id.into())[self.place.local] - .name - .display(db, edition) - .to_string() - } - ExpressionStoreOwnerId::Body(def_with_body_id) => Body::of(db, def_with_body_id) - [self.place.local] - .name - .display(db, edition) - .to_string(), - ExpressionStoreOwnerId::VariantFields(variant_id) => { - let fields = VariantFields::of(db, variant_id); - fields.store[self.place.local].name.display(db, edition).to_string() + _ => { + panic!("type of closure expr {:?} is not a closure {:?}", closure_expr_id, ty); } }; - let mut field_need_paren = false; - for proj in &self.place.projections { - match proj { - HirPlaceProjection::Deref => { - result = format!("*{result}"); - field_need_paren = true; - } - HirPlaceProjection::Field(f) => { - if field_need_paren { - result 
= format!("({result})"); - } - let variant_data = f.parent.fields(db); - let field = match variant_data.shape { - FieldsShape::Record => { - variant_data.fields()[f.local_id].name.as_str().to_owned() - } - FieldsShape::Tuple => variant_data - .fields() - .iter() - .position(|it| it.0 == f.local_id) - .unwrap_or_default() - .to_string(), - FieldsShape::Unit => "[missing field]".to_owned(), - }; - result = format!("{result}.{field}"); - field_need_paren = false; - } - HirPlaceProjection::TupleField(idx) => { - if field_need_paren { - result = format!("({result})"); - } - result = format!("{result}.{idx}"); - field_need_paren = false; - } + let args = self.infcx().resolve_vars_if_possible(args); + + let mut delegate = InferBorrowKind { + closure_def_id: closure_expr_id, + capture_information: Default::default(), + fake_reads: Default::default(), + }; + + let _ = euv::ExprUseVisitor::new(self, closure_expr_id, upvars, &mut delegate) + .consume_closure_body(params, body); + + // There are several curious situations with coroutine-closures where + // analysis is too aggressive with borrows when the coroutine-closure is + // marked `move`. Specifically: + // + // 1. If the coroutine-closure was inferred to be `FnOnce` during signature + // inference, then it's still possible that we try to borrow upvars from + // the coroutine-closure because they are not used by the coroutine body + // in a way that forces a move. See the test: + // `async-await/async-closures/force-move-due-to-inferred-kind.rs`. + // + // 2. If the coroutine-closure is forced to be `FnOnce` due to the way it + // uses its upvars (e.g. it consumes a non-copy value), but not *all* upvars + // would force the closure to `FnOnce`. + // See the test: `async-await/async-closures/force-move-due-to-actually-fnonce.rs`. + // + // This would lead to an impossible to satisfy situation, since `AsyncFnOnce` + // coroutine bodies can't borrow from their parent closure. 
To fix this, + // we force the inner coroutine to also be `move`. This only matters for + // coroutine-closures that are `move` since otherwise they themselves will + // be borrowing from the outer environment, so there's no self-borrows occurring. + if let UpvarArgs::Coroutine(..) = args + && let hir_def::hir::ClosureKind::AsyncBlock { source: CoroutineSource::Closure } = + closure_kind + && let parent_hir_id = ExpressionStore::closure_for_coroutine(closure_expr_id) + && let parent_ty = self.result.expr_ty(parent_hir_id) + && let Expr::Closure { capture_by: CaptureBy::Value, .. } = self.store[parent_hir_id] + { + // (1.) Closure signature inference forced this closure to `FnOnce`. + if let Some(ClosureKind::FnOnce) = self.infcx().closure_kind(parent_ty) { + capture_clause = CaptureBy::Value; + } + // (2.) The way that the closure uses its upvars means it's `FnOnce`. + else if self.coroutine_body_consumes_upvars(closure_expr_id, body, upvars) { + capture_clause = CaptureBy::Value; } } - result - } -} -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct CapturedItemWithoutTy { - pub(crate) place: HirPlace, - pub(crate) kind: CaptureKind, - /// The inner vec is the stacks; the outer vec is for each capture reference. - pub(crate) span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>, -} - -impl CapturedItemWithoutTy { - fn with_ty(self, ctx: &mut InferenceContext<'_, '_>) -> CapturedItem { - let ty = self.place.ty(ctx); - let ty = match &self.kind { - CaptureKind::ByValue => ty, - CaptureKind::ByRef(bk) => { - let m = match bk { - BorrowKind::Mut { .. } => Mutability::Mut, - _ => Mutability::Not, + // As noted in `lower_coroutine_body_with_moved_arguments`, we default the capture mode + // to `ByRef` for the `async {}` block internal to async fns/closure. This means + // that we would *not* be moving all of the parameters into the async block in all cases. 
+ // For example, when one of the arguments is `Copy`, we turn a consuming use into a copy of + // a reference, so for `async fn x(t: i32) {}`, we'd only take a reference to `t`. + // + // We force all of these arguments to be captured by move before we do expr use analysis. + // + // FIXME(async_closures): This could be cleaned up. It's a bit janky that we're just + // moving all of the `LocalSource::AsyncFn` locals here. + if let hir_def::hir::ClosureKind::AsyncBlock { + source: CoroutineSource::Fn | CoroutineSource::Closure, + } = closure_kind + { + let Expr::Block { statements, .. } = &self.store[body] else { + panic!(); + }; + for stmt in statements { + let Statement::Let { pat, initializer: Some(init), .. } = *stmt else { + panic!(); + }; + let Pat::Bind { .. } = self.store[pat] else { + // Complex pattern, skip the non-upvar local. + continue; + }; + let Expr::Path(path) = &self.store[init] else { + panic!(); + }; + let update_guard = self.resolver.update_to_inner_scope(self.db, self.owner, init); + let Some(ValueNs::LocalBinding(local_id)) = + self.resolver.resolve_path_in_value_ns_fully( + self.db, + path, + self.store.expr_path_hygiene(init), + ) + else { + panic!(); }; - Ty::new_ref(ctx.interner(), ctx.types.regions.error, ty, m) + self.resolver.reset_to_guard(update_guard); + let place = self.place_for_root_variable(closure_expr_id, local_id); + delegate.capture_information.push(( + place, + CaptureInfo { + sources: smallvec![CaptureSourceStack::from_single(init.into())], + capture_kind: UpvarCapture::ByValue, + }, + )); } - }; - CapturedItem { - place: self.place, - kind: self.kind, - span_stacks: self.span_stacks, - ty: StoredEarlyBinder::bind(ty.store()), } - } -} -impl<'db> InferenceContext<'_, 'db> { - fn place_of_expr(&mut self, tgt_expr: ExprId) -> Option { - let r = self.place_of_expr_without_adjust(tgt_expr)?; - let adjustments = - self.result.expr_adjustments.get(&tgt_expr).map(|it| &**it).unwrap_or_default(); - apply_adjusts_to_place(&mut 
self.current_capture_span_stack, r, adjustments) - } + debug!( + "For closure={:?}, capture_information={:#?}", + closure_expr_id, delegate.capture_information + ); - /// Pushes the span into `current_capture_span_stack`, *without clearing it first*. - fn path_place(&mut self, path: &Path, id: ExprOrPatId) -> Option { - if path.type_anchor().is_some() { - return None; - } - let hygiene = self.store.expr_or_pat_path_hygiene(id); - self.resolver.resolve_path_in_value_ns_fully(self.db, path, hygiene).and_then(|result| { - match result { - ValueNs::LocalBinding(binding) => { - let mir_span = match id { - ExprOrPatId::ExprId(id) => MirSpan::ExprId(id), - ExprOrPatId::PatId(id) => MirSpan::PatId(id), - }; - self.current_capture_span_stack.push(mir_span); - Some(HirPlace { local: binding, projections: Vec::new() }) - } - _ => None, - } - }) - } + let (capture_information, closure_kind, _origin) = self + .process_collected_capture_information(capture_clause, &delegate.capture_information); - /// Changes `current_capture_span_stack` to contain the stack of spans for this expr. - fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option { - self.current_capture_span_stack.clear(); - match &self.store[tgt_expr] { - Expr::Path(p) => { - let resolver_guard = - self.resolver.update_to_inner_scope(self.db, self.owner, tgt_expr); - let result = self.path_place(p, tgt_expr.into()); - self.resolver.reset_to_guard(resolver_guard); - return result; - } - Expr::Field { expr, name: _ } => { - let mut place = self.place_of_expr(*expr)?; - let field = self.result.field_resolution(tgt_expr)?; - self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); - place.projections.push(field.either(HirPlaceProjection::Field, |f| { - HirPlaceProjection::TupleField(f.index) - })); - return Some(place); - } - Expr::UnaryOp { expr, op: UnaryOp::Deref } => { - let is_builtin_deref = match self.expr_ty(*expr).kind() { - TyKind::Ref(..) | TyKind::RawPtr(..) 
=> true, - TyKind::Adt(adt_def, _) if adt_def.is_box() => true, - _ => false, + self.compute_min_captures(closure_expr_id, capture_information); + + // We now fake capture information for all variables that are mentioned within the closure + // We do this after handling migrations so that min_captures computes before + if !enable_precise_capture(self.edition) { + let mut capture_information: InferredCaptureInformation = Default::default(); + + for var_hir_id in upvars.iter() { + let place = Place { + base_ty: self.result.binding_ty(var_hir_id).store(), + base: PlaceBase::Upvar { closure: closure_expr_id, var_id: var_hir_id }, + projections: Vec::new(), }; - if is_builtin_deref { - let mut place = self.place_of_expr(*expr)?; - self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); - place.projections.push(HirPlaceProjection::Deref); - return Some(place); - } - } - _ => (), - } - None - } - fn push_capture(&mut self, place: HirPlace, kind: CaptureKind) { - self.current_captures.push(CapturedItemWithoutTy { - place, - kind, - span_stacks: smallvec![self.current_capture_span_stack.iter().copied().collect()], - }); - } + debug!("seed place {:?}", place); - fn truncate_capture_spans(&self, capture: &mut CapturedItemWithoutTy, mut truncate_to: usize) { - // The first span is the identifier, and it must always remain. - truncate_to += 1; - for span_stack in &mut capture.span_stacks { - let mut remained = truncate_to; - let mut actual_truncate_to = 0; - for &span in &*span_stack { - actual_truncate_to += 1; - if !span.is_ref_span(self.store) { - remained -= 1; - if remained == 0 { - break; - } - } - } - if actual_truncate_to < span_stack.len() - && span_stack[actual_truncate_to].is_ref_span(self.store) - { - // Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect. 
- actual_truncate_to += 1; + let capture_kind = self.init_capture_kind_for_place(&place, capture_clause); + let fake_info = CaptureInfo { sources: SmallVec::new(), capture_kind }; + + capture_information.push((place, fake_info)); } - span_stack.truncate(actual_truncate_to); - } - } - fn ref_expr(&mut self, expr: ExprId, place: Option) { - if let Some(place) = place { - self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared)); + // This will update the min captures based on this new fake information. + self.compute_min_captures(closure_expr_id, capture_information); } - self.walk_expr(expr); - } - fn add_capture(&mut self, place: HirPlace, kind: CaptureKind) { - if self.is_upvar(&place) { - self.push_capture(place, kind); + if infer_kind { + // Unify the (as yet unbound) type variable in the closure + // args with the kind we inferred. + let closure_kind_ty = match args { + UpvarArgs::Closure(args) => args.as_closure().kind_ty(), + UpvarArgs::CoroutineClosure(args) => args.as_coroutine_closure().kind_ty(), + UpvarArgs::Coroutine(_) => unreachable!("coroutines don't have an inferred kind"), + }; + _ = self.demand_eqtype( + closure_expr_id.into(), + Ty::from_closure_kind(self.interner(), closure_kind), + closure_kind_ty, + ); } - } - fn mutate_path_pat(&mut self, path: &Path, id: PatId) { - if let Some(place) = self.path_place(path, id.into()) { - self.add_capture( - place, - CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), + // For coroutine-closures, we additionally must compute the + // `coroutine_captures_by_ref_ty` type, which is used to generate the by-ref + // version of the coroutine-closure's output coroutine. 
+ if let UpvarArgs::CoroutineClosure(args) = args + && !args.references_error() + { + let closure_env_region: Region<'_> = Region::new_bound( + self.interner(), + rustc_type_ir::INNERMOST, + BoundRegion { var: BoundVar::ZERO, kind: BoundRegionKind::ClosureEnv }, ); - self.current_capture_span_stack.pop(); // Remove the pattern span. - } - } - fn mutate_expr(&mut self, expr: ExprId, place: Option) { - if let Some(place) = place { - self.add_capture( - place, - CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), + let num_args = args + .as_coroutine_closure() + .coroutine_closure_sig() + .skip_binder() + .tupled_inputs_ty + .tuple_fields() + .len(); + + let tupled_upvars_ty_for_borrow = Ty::new_tup_from_iter( + self.interner(), + analyze_coroutine_closure_captures( + self.closure_min_captures_flattened(closure_expr_id), + self.closure_min_captures_flattened(ExpressionStore::coroutine_for_closure( + closure_expr_id, + )) + // Skip the captures that are just moving the closure's args + // into the coroutine. These are always by move, and we append + // those later in the `CoroutineClosureSignature` helper functions. + .skip(num_args), + |(_, parent_capture), (_, child_capture)| { + // This is subtle. See documentation on function. + let needs_ref = should_reborrow_from_env_of_parent_coroutine_closure( + parent_capture, + child_capture, + ); + + let upvar_ty = child_capture.place.ty(); + let capture = child_capture.info.capture_kind; + // Not all upvars are captured by ref, so use + // `apply_capture_kind_on_capture_ty` to ensure that we + // compute the right captured type. 
+ apply_capture_kind_on_capture_ty( + self.interner(), + upvar_ty, + capture, + if needs_ref { closure_env_region } else { self.types.regions.erased }, + ) + }, + ), ); + let coroutine_captures_by_ref_ty = Ty::new_fn_ptr( + self.interner(), + Binder::bind_with_vars( + self.interner().mk_fn_sig( + [], + tupled_upvars_ty_for_borrow, + false, + Safety::Safe, + FnAbi::Rust, + ), + self.types.coroutine_captures_by_ref_bound_var_kinds, + ), + ); + _ = self.demand_eqtype( + closure_expr_id.into(), + args.as_coroutine_closure().coroutine_captures_by_ref_ty(), + coroutine_captures_by_ref_ty, + ); + + // Additionally, we can now constrain the coroutine's kind type. + // + // We only do this if `infer_kind`, because if we have constrained + // the kind from closure signature inference, the kind inferred + // for the inner coroutine may actually be more restrictive. + if infer_kind { + let TyKind::Coroutine(_, coroutine_args) = self.result.expr_ty(body).kind() else { + panic!(); + }; + _ = self.demand_eqtype( + closure_expr_id.into(), + coroutine_args.as_coroutine().kind_ty(), + Ty::from_coroutine_closure_kind(self.interner(), closure_kind), + ); + } } - self.walk_expr(expr); - } - fn consume_expr(&mut self, expr: ExprId) { - if let Some(place) = self.place_of_expr(expr) { - self.consume_place(place); + // Now that we've analyzed the closure, we know how each + // variable is borrowed, and we know what traits the closure + // implements (Fn vs FnMut etc). We now have some updates to do + // with that information. + // + // Note that no closure type C may have an upvar of type C + // (though it may reference itself via a trait object). This + // results from the desugaring of closures to a struct like + // `Foo<..., UV0...UVn>`. If one of those upvars referenced + // C, then the type would have infinite size (and the + // inference algorithm will reject it). + + // Equate the type variables for the upvars with the actual types. 
+ let final_upvar_tys = self.final_upvar_tys(closure_expr_id); + debug!(?closure_expr_id, ?args, ?final_upvar_tys); + + // Build a tuple (U0..Un) of the final upvar types U0..Un + // and unify the upvar tuple type in the closure with it: + let final_tupled_upvars_type = Ty::new_tup(self.interner(), &final_upvar_tys); + self.demand_suptype(args.tupled_upvars_ty(), final_tupled_upvars_type); + + let fake_reads = delegate.fake_reads; + + self.result.closures_data.entry(closure_expr_id).or_default().fake_reads = + fake_reads.into_boxed_slice(); + + // If we are also inferred the closure kind here, + // process any deferred resolutions. + let deferred_call_resolutions = self.remove_deferred_call_resolutions(closure_expr_id); + for deferred_call_resolution in deferred_call_resolutions { + deferred_call_resolution.resolve(self); } - self.walk_expr(expr); } - fn consume_place(&mut self, place: HirPlace) { - if self.is_upvar(&place) { - let ty = place.ty(self); - let kind = if self.is_ty_copy(ty) { - CaptureKind::ByRef(BorrowKind::Shared) - } else { - CaptureKind::ByValue - }; - self.push_capture(place, kind); - } + /// Determines whether the body of the coroutine uses its upvars in a way that + /// consumes (i.e. moves) the value, which would force the coroutine to `FnOnce`. + /// In a more detailed comment above, we care whether this happens, since if + /// this happens, we want to force the coroutine to move all of the upvars it + /// would've borrowed from the parent coroutine-closure. + /// + /// This only really makes sense to be called on the child coroutine of a + /// coroutine-closure. 
+ fn coroutine_body_consumes_upvars( + &mut self, + coroutine_def_id: ExprId, + body: ExprId, + upvars: UpvarsRef<'db>, + ) -> bool { + let mut delegate = InferBorrowKind { + closure_def_id: coroutine_def_id, + capture_information: Default::default(), + fake_reads: Default::default(), + }; + + let _ = euv::ExprUseVisitor::new(self, coroutine_def_id, upvars, &mut delegate) + .consume_expr(body); + + let (_, kind, _) = self + .process_collected_capture_information(CaptureBy::Ref, &delegate.capture_information); + + matches!(kind, ClosureKind::FnOnce) } - fn walk_expr_with_adjust(&mut self, tgt_expr: ExprId, adjustment: &[Adjustment]) { - if let Some((last, rest)) = adjustment.split_last() { - match &last.kind { - Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => { - self.walk_expr_with_adjust(tgt_expr, rest) - } - Adjust::Deref(Some(m)) => match m.0 { - Some(m) => { - self.ref_capture_with_adjusts(m, tgt_expr, rest); - } - None => unreachable!(), - }, - Adjust::Borrow(b) => { - self.ref_capture_with_adjusts(b.mutability(), tgt_expr, rest); - } - } - } else { - self.walk_expr_without_adjust(tgt_expr); - } + // Returns a list of `Ty`s for each upvar. 
+ fn final_upvar_tys(&self, closure_id: ExprId) -> Vec> { + self.closure_min_captures_flattened(closure_id) + .map(|captured_place| { + let upvar_ty = captured_place.place.ty(); + let capture = captured_place.info.capture_kind; + + debug!(?captured_place.place, ?upvar_ty, ?capture, ?captured_place.mutability); + + apply_capture_kind_on_capture_ty( + self.interner(), + upvar_ty, + capture, + self.types.regions.erased, + ) + }) + .collect() } - fn ref_capture_with_adjusts(&mut self, m: Mutability, tgt_expr: ExprId, rest: &[Adjustment]) { - let capture_kind = match m { - Mutability::Mut => CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }), - Mutability::Not => CaptureKind::ByRef(BorrowKind::Shared), - }; - if let Some(place) = self.place_of_expr_without_adjust(tgt_expr) - && let Some(place) = - apply_adjusts_to_place(&mut self.current_capture_span_stack, place, rest) - { - self.add_capture(place, capture_kind); - } - self.walk_expr_with_adjust(tgt_expr, rest); + /// Adjusts the closure capture information to ensure that the operations aren't unsafe, + /// and that the path can be captured with required capture kind (depending on use in closure, + /// move closure etc.) + /// + /// Returns the set of adjusted information along with the inferred closure kind and span + /// associated with the closure kind inference. + /// + /// Note that we *always* infer a minimal kind, even if + /// we don't always *use* that in the final result (i.e., sometimes + /// we've taken the closure kind from the expectations instead, and + /// for coroutines we don't even implement the closure traits + /// really). + /// + /// If we inferred that the closure needs to be FnMut/FnOnce, last element of the returned tuple + /// contains a `Some()` with the `Place` that caused us to do so. 
+ fn process_collected_capture_information( + &mut self, + capture_clause: CaptureBy, + capture_information: &InferredCaptureInformation, + ) -> (InferredCaptureInformation, ClosureKind, Option) { + let mut closure_kind = ClosureKind::LATTICE_BOTTOM; + let mut origin: Option = None; + + let processed = capture_information + .iter() + .cloned() + .map(|(place, mut capture_info)| { + // Apply rules for safety before inferring closure kind + let place = restrict_capture_precision(place, &mut capture_info); + + let place = truncate_capture_for_optimization(place, &mut capture_info); + + let updated = match capture_info.capture_kind { + UpvarCapture::ByValue => match closure_kind { + ClosureKind::Fn | ClosureKind::FnMut => { + (ClosureKind::FnOnce, Some(place.clone())) + } + // If closure is already FnOnce, don't update + ClosureKind::FnOnce => (closure_kind, origin.take()), + }, + + UpvarCapture::ByRef(BorrowKind::Mutable | BorrowKind::UniqueImmutable) => { + match closure_kind { + ClosureKind::Fn => (ClosureKind::FnMut, Some(place.clone())), + // Don't update the origin + ClosureKind::FnMut | ClosureKind::FnOnce => { + (closure_kind, origin.take()) + } + } + } + + _ => (closure_kind, origin.take()), + }; + + closure_kind = updated.0; + origin = updated.1; + + let place = match capture_clause { + CaptureBy::Value => adjust_for_move_closure(place, &mut capture_info), + CaptureBy::Ref => adjust_for_non_move_closure(place, &mut capture_info), + }; + + // This restriction needs to be applied after we have handled adjustments for `move` + // closures. We want to make sure any adjustment that might make us move the place into + // the closure gets handled. 
+ let place = restrict_precision_for_drop_types(self, place, &mut capture_info); + + (place, capture_info) + }) + .collect(); + + (processed, closure_kind, origin) } - fn walk_expr(&mut self, tgt_expr: ExprId) { - if let Some(it) = self.result.expr_adjustments.get_mut(&tgt_expr) { - // FIXME: this take is completely unneeded, and just is here to make borrow checker - // happy. Remove it if you can. - let x_taken = mem::take(it); - self.walk_expr_with_adjust(tgt_expr, &x_taken); - *self.result.expr_adjustments.get_mut(&tgt_expr).unwrap() = x_taken; - } else { - self.walk_expr_without_adjust(tgt_expr); + /// Analyzes the information collected by `InferBorrowKind` to compute the min number of + /// Places (and corresponding capture kind) that we need to keep track of to support all + /// the required captured paths. + /// + /// + /// Note: If this function is called multiple times for the same closure, it will update + /// the existing min_capture map that is stored in TypeckResults. + /// + /// Eg: + /// ``` + /// #[derive(Debug)] + /// struct Point { x: i32, y: i32 } + /// + /// let s = String::from("s"); // hir_id_s + /// let mut p = Point { x: 2, y: -2 }; // his_id_p + /// let c = || { + /// println!("{s:?}"); // L1 + /// p.x += 10; // L2 + /// println!("{}" , p.y); // L3 + /// println!("{p:?}"); // L4 + /// drop(s); // L5 + /// }; + /// ``` + /// and let hir_id_L1..5 be the expressions pointing to use of a captured variable on + /// the lines L1..5 respectively. + /// + /// InferBorrowKind results in a structure like this: + /// + /// ```ignore (illustrative) + /// { + /// Place(base: hir_id_s, projections: [], ....) -> { + /// capture_kind_expr: hir_id_L5, + /// path_expr_id: hir_id_L5, + /// capture_kind: ByValue + /// }, + /// Place(base: hir_id_p, projections: [Field(0, 0)], ...) -> { + /// capture_kind_expr: hir_id_L2, + /// path_expr_id: hir_id_L2, + /// capture_kind: ByValue + /// }, + /// Place(base: hir_id_p, projections: [Field(1, 0)], ...) 
-> { + /// capture_kind_expr: hir_id_L3, + /// path_expr_id: hir_id_L3, + /// capture_kind: ByValue + /// }, + /// Place(base: hir_id_p, projections: [], ...) -> { + /// capture_kind_expr: hir_id_L4, + /// path_expr_id: hir_id_L4, + /// capture_kind: ByValue + /// }, + /// } + /// ``` + /// + /// After the min capture analysis, we get: + /// ```ignore (illustrative) + /// { + /// hir_id_s -> [ + /// Place(base: hir_id_s, projections: [], ....) -> { + /// capture_kind_expr: hir_id_L5, + /// path_expr_id: hir_id_L5, + /// capture_kind: ByValue + /// }, + /// ], + /// hir_id_p -> [ + /// Place(base: hir_id_p, projections: [], ...) -> { + /// capture_kind_expr: hir_id_L2, + /// path_expr_id: hir_id_L4, + /// capture_kind: ByValue + /// }, + /// ], + /// } + /// ``` + #[instrument(level = "debug", skip(self))] + fn compute_min_captures( + &mut self, + closure_def_id: ExprId, + capture_information: InferredCaptureInformation, + ) { + if capture_information.is_empty() { + return; } - } - fn walk_expr_without_adjust(&mut self, tgt_expr: ExprId) { - match &self.store[tgt_expr] { - Expr::OffsetOf(_) => (), - Expr::InlineAsm(e) => e.operands.iter().for_each(|(_, op)| match op { - AsmOperand::In { expr, .. } - | AsmOperand::Out { expr: Some(expr), .. } - | AsmOperand::InOut { expr, .. } => self.walk_expr_without_adjust(*expr), - AsmOperand::SplitInOut { in_expr, out_expr, .. } => { - self.walk_expr_without_adjust(*in_expr); - if let Some(out_expr) = out_expr { - self.walk_expr_without_adjust(*out_expr); + let mut closure_data = + self.result.closures_data.remove(&closure_def_id).unwrap_or_default(); + let root_var_min_capture_list = &mut closure_data.min_captures; + let mut dedup_sources_scratch = FxHashMap::default(); + + for (mut place, capture_info) in capture_information.into_iter() { + let var_hir_id = match place.base { + PlaceBase::Upvar { var_id, .. 
} => var_id, + base => panic!("Expected upvar, found={:?}", base), + }; + + let Some(min_cap_list) = root_var_min_capture_list.get_mut(&var_hir_id) else { + let mutability = self.determine_capture_mutability(&place); + let min_cap_list = vec![CapturedPlace { place, info: capture_info, mutability }]; + root_var_min_capture_list.insert(var_hir_id, min_cap_list); + continue; + }; + + // Go through each entry in the current list of min_captures + // - if ancestor is found, update its capture kind to account for current place's + // capture information. + // + // - if descendant is found, remove it from the list, and update the current place's + // capture information to account for the descendant's capture kind. + // + // We can never be in a case where the list contains both an ancestor and a descendant + // Also there can only be ancestor but in case of descendants there might be + // multiple. + + let mut descendant_found = false; + let mut updated_capture_info = capture_info; + min_cap_list.retain(|possible_descendant| { + match determine_place_ancestry_relation(&place, &possible_descendant.place) { + // current place is ancestor of possible_descendant + PlaceAncestryRelation::Ancestor => { + descendant_found = true; + + let mut possible_descendant = possible_descendant.clone(); + + // Truncate the descendant (already in min_captures) to be same as the ancestor to handle any + // possible change in capture mode. + truncate_place_to_len_and_update_capture_kind( + &mut possible_descendant.place, + &mut possible_descendant.info, + place.projections.len(), + ); + + let backup_path_sources = determine_capture_sources( + &mut updated_capture_info, + &mut possible_descendant.info, + &mut dedup_sources_scratch, + ); + determine_capture_info( + &mut updated_capture_info, + &mut possible_descendant.info, + ); + + // we need to keep the ancestor's `path_expr_id` + updated_capture_info.sources = backup_path_sources; + false } + + _ => true, } - AsmOperand::Out { expr: None, .. 
} - | AsmOperand::Const(_) - | AsmOperand::Label(_) - | AsmOperand::Sym(_) => (), - }), - Expr::If { condition, then_branch, else_branch } => { - self.consume_expr(*condition); - self.consume_expr(*then_branch); - if let &Some(expr) = else_branch { - self.consume_expr(expr); - } - } - Expr::Async { statements, tail, .. } - | Expr::Unsafe { statements, tail, .. } - | Expr::Block { statements, tail, .. } => { - for s in statements.iter() { - match s { - Statement::Let { pat, type_ref: _, initializer, else_branch } => { - if let Some(else_branch) = else_branch { - self.consume_expr(*else_branch); - } - if let Some(initializer) = initializer { - if else_branch.is_some() { - self.consume_expr(*initializer); - } else { - self.walk_expr(*initializer); - } - if let Some(place) = self.place_of_expr(*initializer) { - self.consume_with_pat(place, *pat); - } - } + }); + + let mut ancestor_found = false; + if !descendant_found { + for possible_ancestor in min_cap_list.iter_mut() { + match determine_place_ancestry_relation(&place, &possible_ancestor.place) { + PlaceAncestryRelation::SamePlace => { + ancestor_found = true; + let backup_path_sources = determine_capture_sources( + &mut updated_capture_info, + &mut possible_ancestor.info, + &mut dedup_sources_scratch, + ); + determine_capture_info( + &mut possible_ancestor.info, + &mut updated_capture_info, + ); + possible_ancestor.info.sources = backup_path_sources; + + // Only one related place will be in the list. + break; } - Statement::Expr { expr, has_semi: _ } => { - self.consume_expr(*expr); + // current place is descendant of possible_ancestor + PlaceAncestryRelation::Descendant => { + ancestor_found = true; + + // Truncate the descendant (current place) to be same as the ancestor to handle any + // possible change in capture mode. 
+ truncate_place_to_len_and_update_capture_kind( + &mut place, + &mut updated_capture_info, + possible_ancestor.place.projections.len(), + ); + + let backup_path_sources = determine_capture_sources( + &mut updated_capture_info, + &mut possible_ancestor.info, + &mut dedup_sources_scratch, + ); + determine_capture_info( + &mut possible_ancestor.info, + &mut updated_capture_info, + ); + + // we need to keep the ancestor's `sources` + possible_ancestor.info.sources = backup_path_sources; + + // Only one related place will be in the list. + break; } - Statement::Item(_) => (), + _ => {} } } - if let Some(tail) = tail { - self.consume_expr(*tail); - } - } - Expr::Call { callee, args } => { - self.consume_expr(*callee); - self.consume_exprs(args.iter().copied()); } - Expr::MethodCall { receiver, args, .. } => { - self.consume_expr(*receiver); - self.consume_exprs(args.iter().copied()); + + // Only need to insert when we don't have an ancestor in the existing min capture list + if !ancestor_found { + let mutability = self.determine_capture_mutability(&place); + let captured_place = + CapturedPlace { place, info: updated_capture_info, mutability }; + min_cap_list.push(captured_place); } - Expr::Match { expr, arms } => { - for arm in arms.iter() { - self.consume_expr(arm.expr); - if let Some(guard) = arm.guard { - self.consume_expr(guard); - } - } - self.walk_expr(*expr); - if let Some(discr_place) = self.place_of_expr(*expr) - && self.is_upvar(&discr_place) - { - let mut capture_mode = None; - for arm in arms.iter() { - self.walk_pat(&mut capture_mode, arm.pat); - } - if let Some(c) = capture_mode { - self.push_capture(discr_place, c); + } + + debug!( + "For closure={:?}, min_captures before sorting={:?}", + closure_def_id, root_var_min_capture_list + ); + + // Now that we have the minimized list of captures, sort the captures by field id. + // This causes the closure to capture the upvars in the same order as the fields are + // declared which is also the drop order. 
Thus, in situations where we capture all the + // fields of some type, the observable drop order will remain the same as it previously + // was even though we're dropping each capture individually. + // See https://github.com/rust-lang/project-rfc-2229/issues/42 and + // `tests/ui/closures/2229_closure_analysis/preserve_field_drop_order.rs`. + for (_, captures) in &mut *root_var_min_capture_list { + captures.sort_by(|capture1, capture2| { + fn is_field(p: &&Projection) -> bool { + match p.kind { + ProjectionKind::Field { .. } => true, + ProjectionKind::Deref | ProjectionKind::UnwrapUnsafeBinder => false, + p @ (ProjectionKind::Subslice | ProjectionKind::Index) => { + panic!("ProjectionKind {:?} was unexpected", p) + } } } - } - Expr::Break { expr, label: _ } - | Expr::Return { expr } - | Expr::Yield { expr } - | Expr::Yeet { expr } => { - if let &Some(expr) = expr { - self.consume_expr(expr); - } - } - &Expr::Become { expr } => { - self.consume_expr(expr); - } - Expr::RecordLit { fields, spread, .. } => { - if let RecordSpread::Expr(expr) = *spread { - self.consume_expr(expr); - } - self.consume_exprs(fields.iter().map(|it| it.expr)); - } - Expr::Field { expr, name: _ } => self.select_from_expr(*expr), - Expr::UnaryOp { expr, op: UnaryOp::Deref } => { - if self.result.method_resolution(tgt_expr).is_some() { - // Overloaded deref. - match self.expr_ty_after_adjustments(*expr).kind() { - TyKind::Ref(_, _, mutability) => { - let place = self.place_of_expr(*expr); - match mutability { - Mutability::Mut => self.mutate_expr(*expr, place), - Mutability::Not => self.ref_expr(*expr, place), + + // Need to sort only by Field projections, so filter away others. 
+ // A previous implementation considered other projection types too + // but that caused ICE #118144 + let capture1_field_projections = capture1.place.projections.iter().filter(is_field); + let capture2_field_projections = capture2.place.projections.iter().filter(is_field); + + for (p1, p2) in capture1_field_projections.zip(capture2_field_projections) { + // We do not need to look at the `Projection.ty` fields here because at each + // step of the iteration, the projections will either be the same and therefore + // the types must be as well or the current projection will be different and + // we will return the result of comparing the field indexes. + match (p1.kind, p2.kind) { + ( + ProjectionKind::Field { field_idx: i1, .. }, + ProjectionKind::Field { field_idx: i2, .. }, + ) => { + // Compare only if paths are different. + // Otherwise continue to the next iteration + if i1 != i2 { + return i1.cmp(&i2); } } - // FIXME: Is this correct wrt. raw pointer derefs? - TyKind::RawPtr(..) => self.select_from_expr(*expr), - _ => never!("deref adjustments should include taking a mutable reference"), + // Given the filter above, this arm should never be hit + (l, r) => panic!("ProjectionKinds {:?} or {:?} were unexpected", l, r), } - } else { - self.select_from_expr(*expr); - } - } - Expr::Let { pat, expr } => { - self.walk_expr(*expr); - if let Some(place) = self.place_of_expr(*expr) { - self.consume_with_pat(place, *pat); - } - } - Expr::UnaryOp { expr, op: _ } - | Expr::Array(Array::Repeat { initializer: expr, repeat: _ }) - | Expr::Await { expr } - | Expr::Loop { body: expr, label: _ } - | Expr::Box { expr } - | Expr::Cast { expr, type_ref: _ } => { - self.consume_expr(*expr); - } - Expr::Ref { expr, rawness: _, mutability } => { - // We need to do this before we push the span so the order will be correct. 
- let place = self.place_of_expr(*expr); - self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr)); - match mutability { - hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr, place), - hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr, place), } - } - Expr::BinaryOp { lhs, rhs, op } => { - let Some(op) = op else { - return; - }; - if matches!(op, BinaryOp::Assignment { .. }) { - let place = self.place_of_expr(*lhs); - self.mutate_expr(*lhs, place); - self.consume_expr(*rhs); - return; - } - self.consume_expr(*lhs); - self.consume_expr(*rhs); - } - Expr::Range { lhs, rhs, range_type: _ } => { - if let &Some(expr) = lhs { - self.consume_expr(expr); - } - if let &Some(expr) = rhs { - self.consume_expr(expr); + + std::cmp::Ordering::Equal + }); + } + + debug!( + "For closure={:?}, min_captures after sorting={:#?}", + closure_def_id, root_var_min_capture_list + ); + self.result.closures_data.insert(closure_def_id, closure_data); + } + + fn normalize_capture_place(&self, place: Place) -> Place { + let mut place = self.infcx().resolve_vars_if_possible(place); + + // In the new solver, types in HIR `Place`s can contain unnormalized aliases, + // which can ICE later (e.g. when projecting fields for diagnostics). + let cause = ObligationCause::misc(); + let at = self.table.at(&cause); + match normalize::deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals( + at, + place.clone(), + vec![], + ) { + Ok((normalized, goals)) => { + if !goals.is_empty() { + // FIXME: Insert coroutine stalled predicates, this matters for MIR. + // let mut typeck_results = self.typeck_results.borrow_mut(); + // typeck_results.coroutine_stalled_predicates.extend( + // goals + // .into_iter() + // // FIXME: throwing away the param-env :( + // .map(|goal| (goal.predicate, self.misc(span))), + // ); } + normalized } - Expr::Index { base, index } => { - self.select_from_expr(*base); - self.consume_expr(*index); - } - Expr::Closure { .. 
} => { - let ty = self.expr_ty(tgt_expr); - let TyKind::Closure(id, _) = ty.kind() else { - never!("closure type is always closure"); - return; - }; - let (captures, _) = - self.result.closure_info.get(&id.0).expect( - "We sort closures, so we should always have data for inner closures", - ); - let mut cc = mem::take(&mut self.current_captures); - cc.extend(captures.iter().filter(|it| self.is_upvar(&it.place)).map(|it| { - CapturedItemWithoutTy { - place: it.place.clone(), - kind: it.kind, - span_stacks: it.span_stacks.clone(), - } - })); - self.current_captures = cc; - } - Expr::Array(Array::ElementList { elements: exprs }) | Expr::Tuple { exprs } => { - self.consume_exprs(exprs.iter().copied()) - } - &Expr::Assignment { target, value } => { - self.walk_expr(value); - let resolver_guard = - self.resolver.update_to_inner_scope(self.db, self.owner, tgt_expr); - match self.place_of_expr(value) { - Some(rhs_place) => { - self.inside_assignment = true; - self.consume_with_pat(rhs_place, target); - self.inside_assignment = false; - } - None => self.store.walk_pats(target, &mut |pat| match &self.store[pat] { - Pat::Path(path) => self.mutate_path_pat(path, pat), - &Pat::Expr(expr) => { - let place = self.place_of_expr(expr); - self.mutate_expr(expr, place); - } - _ => {} - }), + Err(_errors) => { + place.base_ty = self.types.types.error.store(); + for proj in &mut place.projections { + proj.ty = self.types.types.error.store(); } - self.resolver.reset_to_guard(resolver_guard); + place } + } + } + + fn closure_min_captures_flattened( + &self, + closure_expr_id: ExprId, + ) -> impl Iterator { + self.result + .closures_data + .get(&closure_expr_id) + .map(|closure_data| closure_data.min_captures.values().flatten()) + .into_iter() + .flatten() + } - Expr::Missing - | Expr::Continue { .. 
} - | Expr::Path(_) - | Expr::Literal(_) - | Expr::Const(_) - | Expr::Underscore => (), + fn init_capture_kind_for_place( + &self, + place: &Place, + capture_clause: CaptureBy, + ) -> UpvarCapture { + match capture_clause { + // In case of a move closure if the data is accessed through a reference we + // want to capture by ref to allow precise capture using reborrows. + // + // If the data will be moved out of this place, then the place will be truncated + // at the first Deref in `adjust_for_move_closure` and then moved into the closure. + // + // For example: + // + // struct Buffer<'a> { + // x: &'a String, + // y: Vec, + // } + // + // fn get<'a>(b: Buffer<'a>) -> impl Sized + 'a { + // let c = move || b.x; + // drop(b); + // c + // } + // + // Even though the closure is declared as move, when we are capturing borrowed data (in + // this case, *b.x) we prefer to capture by reference. + // Otherwise you'd get an error in 2021 immediately because you'd be trying to take + // ownership of the (borrowed) String or else you'd take ownership of b, as in 2018 and + // before, which is also an error. 
+ CaptureBy::Value if !place.deref_tys().any(Ty::is_ref) => UpvarCapture::ByValue, + CaptureBy::Value | CaptureBy::Ref => UpvarCapture::ByRef(BorrowKind::Immutable), } } - fn walk_pat(&mut self, result: &mut Option, pat: PatId) { - let mut update_result = |ck: CaptureKind| match result { - Some(r) => { - *r = cmp::max(*r, ck); - } - None => *result = Some(ck), + fn place_for_root_variable(&self, closure_def_id: ExprId, var_hir_id: BindingId) -> Place { + let place = Place { + base_ty: self.result.binding_ty(var_hir_id).store(), + base: PlaceBase::Upvar { closure: closure_def_id, var_id: var_hir_id }, + projections: Default::default(), }; - self.walk_pat_inner( - pat, - &mut update_result, - BorrowKind::Mut { kind: MutBorrowKind::Default }, - ); + // Normalize eagerly when inserting into `capture_information`, so all downstream + // capture analysis can assume a normalized `Place`. + self.normalize_capture_place(place) } - fn walk_pat_inner( - &mut self, - p: PatId, - update_result: &mut impl FnMut(CaptureKind), - mut for_mut: BorrowKind, - ) { - match &self.store[p] { - Pat::Ref { .. } - | Pat::Box { .. } - | Pat::Missing - | Pat::Wild - | Pat::Tuple { .. } - | Pat::Expr(_) - | Pat::Or(_) => (), - Pat::TupleStruct { .. } | Pat::Record { .. } => { - if let Some(variant) = self.result.variant_resolution_for_pat(p) { - let adt = variant.adt_id(self.db); - let is_multivariant = match adt { - hir_def::AdtId::EnumId(e) => e.enum_variants(self.db).variants.len() != 1, - _ => false, - }; - if is_multivariant { - update_result(CaptureKind::ByRef(BorrowKind::Shared)); - } - } - } - Pat::Slice { .. } - | Pat::ConstBlock(_) - | Pat::Path(_) - | Pat::Lit(_) - | Pat::Range { .. } => { - update_result(CaptureKind::ByRef(BorrowKind::Shared)); + /// A captured place is mutable if + /// 1. Projections don't include a Deref of an immut-borrow, **and** + /// 2. PlaceBase is mut or projections include a Deref of a mut-borrow. 
+ fn determine_capture_mutability(&mut self, place: &Place) -> Mutability { + let var_hir_id = match place.base { + PlaceBase::Upvar { var_id, .. } => var_id, + _ => unreachable!(), + }; + + let mut is_mutbl = if self.store[var_hir_id].mode == BindingAnnotation::Mutable { + Mutability::Mut + } else { + Mutability::Not + }; + + for pointer_ty in place.deref_tys() { + match self.table.structurally_resolve_type(pointer_ty).kind() { + // We don't capture derefs of raw ptrs + TyKind::RawPtr(_, _) => unreachable!(), + + // Dereferencing a mut-ref allows us to mut the Place if we don't deref + // an immut-ref after on top of this. + TyKind::Ref(.., Mutability::Mut) => is_mutbl = Mutability::Mut, + + // The place isn't mutable once we dereference an immutable reference. + TyKind::Ref(.., Mutability::Not) => return Mutability::Not, + + // Dereferencing a box doesn't change mutability + TyKind::Adt(def, ..) if def.is_box() => {} + + unexpected_ty => panic!("deref of unexpected pointer type {:?}", unexpected_ty), } - Pat::Bind { id, .. } => match self.result.binding_modes[p] { - crate::BindingMode::Move => { - if self.is_ty_copy(self.result.binding_ty(*id)) { - update_result(CaptureKind::ByRef(BorrowKind::Shared)); - } else { - update_result(CaptureKind::ByValue); - } + } + + is_mutbl + } +} + +/// Determines whether a child capture that is derived from a parent capture +/// should be borrowed with the lifetime of the parent coroutine-closure's env. +/// +/// There are two cases when this needs to happen: +/// +/// (1.) Are we borrowing data owned by the parent closure? We can determine if +/// that is the case by checking if the parent capture is by move, EXCEPT if we +/// apply a deref projection of an immutable reference, reborrows of immutable +/// references which aren't restricted to the LUB of the lifetimes of the deref +/// chain. This is why `&'short mut &'long T` can be reborrowed as `&'long T`. +/// +/// ```rust +/// let x = &1i32; // Let's call this lifetime `'1`. 
+/// let c = async move || { +/// println!("{:?}", *x); +/// // Even though the inner coroutine borrows by ref, we're only capturing `*x`, +/// // not `x`, so the inner closure is allowed to reborrow the data for `'1`. +/// }; +/// ``` +/// +/// (2.) If a coroutine is mutably borrowing from a parent capture, then that +/// mutable borrow cannot live for longer than either the parent *or* the borrow +/// that we have on the original upvar. Therefore we always need to borrow the +/// child capture with the lifetime of the parent coroutine-closure's env. +/// +/// ```rust +/// let mut x = 1i32; +/// let c = async || { +/// x = 1; +/// // The parent borrows `x` for some `&'1 mut i32`. +/// // However, when we call `c()`, we implicitly autoref for the signature of +/// // `AsyncFnMut::async_call_mut`. Let's call that lifetime `'call`. Since +/// // the maximum that `&'call mut &'1 mut i32` can be reborrowed is `&'call mut i32`, +/// // the inner coroutine should capture w/ the lifetime of the coroutine-closure. +/// }; +/// ``` +/// +/// If either of these cases apply, then we should capture the borrow with the +/// lifetime of the parent coroutine-closure's env. Luckily, if this function is +/// not correct, then the program is not unsound, since we still borrowck and validate +/// the choices made from this function -- the only side-effect is that the user +/// may receive unnecessary borrowck errors. +fn should_reborrow_from_env_of_parent_coroutine_closure( + parent_capture: &CapturedPlace, + child_capture: &CapturedPlace, +) -> bool { + // (1.) + (!parent_capture.is_by_ref() + // This is just inlined `place.deref_tys()` but truncated to just + // the child projections. Namely, look for a `&T` deref, since we + // can always extend `&'short mut &'long T` to `&'long T`. 
+ && !child_capture + .place + .projections + .iter() + .enumerate() + .skip(parent_capture.place.projections.len()) + .any(|(idx, proj)| { + matches!(proj.kind, ProjectionKind::Deref) + && matches!( + child_capture.place.ty_before_projection(idx).kind(), + TyKind::Ref(.., Mutability::Not) + ) + })) + // (2.) + || matches!(child_capture.info.capture_kind, UpvarCapture::ByRef(BorrowKind::Mutable)) +} + +/// Truncate the capture so that the place being borrowed is in accordance with RFC 1240, +/// which states that it's unsafe to take a reference into a struct marked `repr(packed)`. +fn restrict_repr_packed_field_ref_capture( + mut place: Place, + capture_info: &mut CaptureInfo, +) -> Place { + let pos = place.projections.iter().enumerate().position(|(i, p)| { + let ty = place.ty_before_projection(i); + + // Return true for fields of packed structs. + match p.kind { + ProjectionKind::Field { .. } => match ty.kind() { + TyKind::Adt(def, _) if def.repr().packed() => { + // We stop here regardless of field alignment. Field alignment can change as + // types change, including the types of private fields in other crates, and that + // shouldn't affect how we compute our captures. 
+ true } - crate::BindingMode::Ref(r) => match r { - Mutability::Mut => update_result(CaptureKind::ByRef(for_mut)), - Mutability::Not => update_result(CaptureKind::ByRef(BorrowKind::Shared)), - }, + + _ => false, }, + _ => false, } - if self.result.pat_adjustments.get(&p).is_some_and(|it| !it.is_empty()) { - for_mut = BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture }; - } - self.store.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut)); + }); + + if let Some(pos) = pos { + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, pos); } - fn is_upvar(&self, place: &HirPlace) -> bool { - if let Some(c) = self.current_closure { - let InternedClosure(_, root) = self.db.lookup_intern_closure(c); - return self.store.is_binding_upvar(place.local, root); - } - false + place +} + +/// Returns a Ty that applies the specified capture kind on the provided capture Ty +fn apply_capture_kind_on_capture_ty<'db>( + interner: DbInterner<'db>, + ty: Ty<'db>, + capture_kind: UpvarCapture, + region: Region<'db>, +) -> Ty<'db> { + match capture_kind { + UpvarCapture::ByValue | UpvarCapture::ByUse => ty, + UpvarCapture::ByRef(kind) => Ty::new_ref(interner, region, ty, kind.to_mutbl_lossy()), } +} - fn is_ty_copy(&mut self, ty: Ty<'db>) -> bool { - if let TyKind::Closure(id, _) = ty.kind() { - // FIXME: We handle closure as a special case, since chalk consider every closure as copy. We - // should probably let chalk know which closures are copy, but I don't know how doing it - // without creating query cycles. - return self - .result - .closure_info - .get(&id.0) - .map(|it| it.1 == FnTrait::Fn) - .unwrap_or(true); - } - let ty = self.table.resolve_completely(ty); - self.table.type_is_copy_modulo_regions(ty) +struct InferBorrowKind { + // The def-id of the closure whose kind and upvar accesses are being inferred. 
+ closure_def_id: ExprId, + + /// For each Place that is captured by the closure, we track the minimal kind of + /// access we need (ref, ref mut, move, etc) and the expression that resulted in such access. + /// + /// Consider closure where s.str1 is captured via an ImmutableBorrow and + /// s.str2 via a MutableBorrow + /// + /// ```rust,no_run + /// struct SomeStruct { str1: String, str2: String }; + /// + /// // Assume that the HirId for the variable definition is `V1` + /// let mut s = SomeStruct { str1: format!("s1"), str2: format!("s2") }; + /// + /// let fix_s = |new_s2| { + /// // Assume that the HirId for the expression `s.str1` is `E1` + /// println!("Updating SomeStruct with str1={0}", s.str1); + /// // Assume that the HirId for the expression `*s.str2` is `E2` + /// s.str2 = new_s2; + /// }; + /// ``` + /// + /// For closure `fix_s`, (at a high level) the map contains + /// + /// ```ignore (illustrative) + /// Place { V1, [ProjectionKind::Field(Index=0, Variant=0)] } : CaptureKind { E1, ImmutableBorrow } + /// Place { V1, [ProjectionKind::Field(Index=1, Variant=0)] } : CaptureKind { E2, MutableBorrow } + /// ``` + capture_information: InferredCaptureInformation, + fake_reads: Vec<(Place, FakeReadCause, SmallVec<[CaptureSourceStack; 2]>)>, +} + +impl<'db> euv::Delegate<'db> for InferBorrowKind { + #[instrument(skip(self), level = "debug")] + fn fake_read( + &mut self, + place_with_id: PlaceWithOrigin, + cause: FakeReadCause, + ctx: &mut InferenceContext<'_, 'db>, + ) { + let PlaceBase::Upvar { .. } = place_with_id.place.base else { return }; + + // We need to restrict Fake Read precision to avoid fake reading unsafe code, + // such as deref of a raw pointer. 
+ let dummy_capture_kind = UpvarCapture::ByRef(BorrowKind::Immutable); + let mut dummy_capture_info = + CaptureInfo { sources: SmallVec::new(), capture_kind: dummy_capture_kind }; + + let place = ctx.normalize_capture_place(place_with_id.place.clone()); + + let place = restrict_capture_precision(place, &mut dummy_capture_info); + + dummy_capture_info.capture_kind = dummy_capture_kind; + let place = restrict_repr_packed_field_ref_capture(place, &mut dummy_capture_info); + self.fake_reads.push((place, cause, place_with_id.origins)); } - fn select_from_expr(&mut self, expr: ExprId) { - self.walk_expr(expr); + #[instrument(skip(self), level = "debug")] + fn consume(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + let PlaceBase::Upvar { closure: upvar_closure, .. } = place_with_id.place.base else { + return; + }; + assert_eq!(self.closure_def_id, upvar_closure); + + let place = ctx.normalize_capture_place(place_with_id.place.clone()); + + self.capture_information.push(( + place, + CaptureInfo { sources: place_with_id.origins, capture_kind: UpvarCapture::ByValue }, + )); } - fn restrict_precision_for_unsafe(&mut self) { - // FIXME: Borrow checker problems without this. 
- let mut current_captures = std::mem::take(&mut self.current_captures); - for capture in &mut current_captures { - let mut ty = self.table.resolve_completely(self.result.binding_ty(capture.place.local)); - if ty.is_raw_ptr() || ty.is_union() { - capture.kind = CaptureKind::ByRef(BorrowKind::Shared); - self.truncate_capture_spans(capture, 0); - capture.place.projections.clear(); - continue; - } - for (i, p) in capture.place.projections.iter().enumerate() { - ty = p.projected_ty( - &self.table.infer_ctxt, - self.table.param_env, - ty, - self.owner.krate(self.db), - ); - if ty.is_raw_ptr() || ty.is_union() { - capture.kind = CaptureKind::ByRef(BorrowKind::Shared); - self.truncate_capture_spans(capture, i + 1); - capture.place.projections.truncate(i + 1); - break; - } - } - } - self.current_captures = current_captures; + #[instrument(skip(self), level = "debug")] + fn use_cloned(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + let PlaceBase::Upvar { closure: upvar_closure, .. } = place_with_id.place.base else { + return; + }; + assert_eq!(self.closure_def_id, upvar_closure); + + let place = ctx.normalize_capture_place(place_with_id.place.clone()); + + self.capture_information.push(( + place, + CaptureInfo { sources: place_with_id.origins, capture_kind: UpvarCapture::ByUse }, + )); } - fn adjust_for_move_closure(&mut self) { - // FIXME: Borrow checker won't allow without this. 
- let mut current_captures = std::mem::take(&mut self.current_captures); - for capture in &mut current_captures { - if let Some(first_deref) = - capture.place.projections.iter().position(|proj| *proj == HirPlaceProjection::Deref) - { - self.truncate_capture_spans(capture, first_deref); - capture.place.projections.truncate(first_deref); - } - capture.kind = CaptureKind::ByValue; + #[instrument(skip(self), level = "debug")] + fn borrow( + &mut self, + place_with_id: PlaceWithOrigin, + bk: BorrowKind, + ctx: &mut InferenceContext<'_, 'db>, + ) { + let PlaceBase::Upvar { closure: upvar_closure, .. } = place_with_id.place.base else { + return; + }; + assert_eq!(self.closure_def_id, upvar_closure); + + // The region here will get discarded/ignored + let capture_kind = UpvarCapture::ByRef(bk); + let mut capture_info = + CaptureInfo { sources: place_with_id.origins.iter().cloned().collect(), capture_kind }; + + let place = ctx.normalize_capture_place(place_with_id.place.clone()); + + // We only want repr packed restriction to be applied to reading references into a packed + // struct, and not when the data is being moved. Therefore we call this method here instead + // of in `restrict_capture_precision`. 
+ let place = restrict_repr_packed_field_ref_capture(place, &mut capture_info); + + // Raw pointers don't inherit mutability + if place.deref_tys().any(Ty::is_raw_ptr) { + capture_info.capture_kind = UpvarCapture::ByRef(BorrowKind::Immutable); } - self.current_captures = current_captures; + + self.capture_information.push((place, capture_info)); } - fn minimize_captures(&mut self) { - self.current_captures.sort_unstable_by_key(|it| it.place.projections.len()); - let mut hash_map = FxHashMap::::default(); - let result = mem::take(&mut self.current_captures); - for mut item in result { - let mut lookup_place = HirPlace { local: item.place.local, projections: vec![] }; - let mut it = item.place.projections.iter(); - let prev_index = loop { - if let Some(k) = hash_map.get(&lookup_place) { - break Some(*k); - } - match it.next() { - Some(it) => { - lookup_place.projections.push(*it); - } - None => break None, - } - }; - match prev_index { - Some(p) => { - let prev_projections_len = self.current_captures[p].place.projections.len(); - self.truncate_capture_spans(&mut item, prev_projections_len); - self.current_captures[p].span_stacks.extend(item.span_stacks); - let len = self.current_captures[p].place.projections.len(); - let kind_after_truncate = - item.place.capture_kind_of_truncated_place(item.kind, len); - self.current_captures[p].kind = - cmp::max(kind_after_truncate, self.current_captures[p].kind); - } - None => { - hash_map.insert(item.place.clone(), self.current_captures.len()); - self.current_captures.push(item); - } - } - } + #[instrument(skip(self), level = "debug")] + fn mutate(&mut self, assignee_place: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + self.borrow(assignee_place, BorrowKind::Mutable, ctx); } +} - fn consume_with_pat(&mut self, mut place: HirPlace, tgt_pat: PatId) { - let adjustments_count = - self.result.pat_adjustments.get(&tgt_pat).map(|it| it.len()).unwrap_or_default(); - place.projections.extend((0..adjustments_count).map(|_| 
HirPlaceProjection::Deref)); - self.current_capture_span_stack - .extend((0..adjustments_count).map(|_| MirSpan::PatId(tgt_pat))); - 'reset_span_stack: { - match &self.store[tgt_pat] { - Pat::Missing | Pat::Wild => (), - Pat::Tuple { args, ellipsis } => { - let (al, ar) = args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); - let field_count = match self.result.pat_ty(tgt_pat).kind() { - TyKind::Tuple(s) => s.len(), - _ => break 'reset_span_stack, - }; - let fields = 0..field_count; - let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); - for (&arg, i) in it { - let mut p = place.clone(); - self.current_capture_span_stack.push(MirSpan::PatId(arg)); - p.projections.push(HirPlaceProjection::TupleField(i as u32)); - self.consume_with_pat(p, arg); - self.current_capture_span_stack.pop(); - } - } - Pat::Or(pats) => { - for pat in pats.iter() { - self.consume_with_pat(place.clone(), *pat); - } - } - Pat::Record { args, .. } => { - let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else { - break 'reset_span_stack; - }; - match variant { - VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { - self.consume_place(place) - } - VariantId::StructId(s) => { - let vd = s.fields(self.db); - for field_pat in args.iter() { - let arg = field_pat.pat; - let Some(local_id) = vd.field(&field_pat.name) else { - continue; - }; - let mut p = place.clone(); - self.current_capture_span_stack.push(MirSpan::PatId(arg)); - p.projections.push(HirPlaceProjection::Field(FieldId { - parent: variant, - local_id, - })); - self.consume_with_pat(p, arg); - self.current_capture_span_stack.pop(); - } - } - } - } - Pat::Range { .. } | Pat::Slice { .. 
} | Pat::ConstBlock(_) | Pat::Lit(_) => { - self.consume_place(place) - } - Pat::Path(path) => { - if self.inside_assignment { - self.mutate_path_pat(path, tgt_pat); - } - self.consume_place(place); - } - &Pat::Bind { id, subpat: _ } => { - let mode = self.result.binding_modes[tgt_pat]; - let capture_kind = match mode { - BindingMode::Move => { - self.consume_place(place); - break 'reset_span_stack; - } - BindingMode::Ref(Mutability::Not) => BorrowKind::Shared, - BindingMode::Ref(Mutability::Mut) => { - BorrowKind::Mut { kind: MutBorrowKind::Default } - } - }; - self.current_capture_span_stack.push(MirSpan::BindingId(id)); - self.add_capture(place, CaptureKind::ByRef(capture_kind)); - self.current_capture_span_stack.pop(); - } - Pat::TupleStruct { path: _, args, ellipsis } => { - let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else { - break 'reset_span_stack; - }; - match variant { - VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { - self.consume_place(place) - } - VariantId::StructId(s) => { - let vd = s.fields(self.db); - let (al, ar) = - args.split_at(ellipsis.map_or(args.len(), |it| it as usize)); - let fields = vd.fields().iter(); - let it = al - .iter() - .zip(fields.clone()) - .chain(ar.iter().rev().zip(fields.rev())); - for (&arg, (i, _)) in it { - let mut p = place.clone(); - self.current_capture_span_stack.push(MirSpan::PatId(arg)); - p.projections.push(HirPlaceProjection::Field(FieldId { - parent: variant, - local_id: i, - })); - self.consume_with_pat(p, arg); - self.current_capture_span_stack.pop(); - } - } - } - } - Pat::Ref { pat, mutability: _ } => { - self.current_capture_span_stack.push(MirSpan::PatId(tgt_pat)); - place.projections.push(HirPlaceProjection::Deref); - self.consume_with_pat(place, *pat); - self.current_capture_span_stack.pop(); - } - Pat::Box { .. 
} => (), // not supported - &Pat::Expr(expr) => { - self.consume_place(place); - let pat_capture_span_stack = mem::take(&mut self.current_capture_span_stack); - let old_inside_assignment = mem::replace(&mut self.inside_assignment, false); - let lhs_place = self.place_of_expr(expr); - self.mutate_expr(expr, lhs_place); - self.inside_assignment = old_inside_assignment; - self.current_capture_span_stack = pat_capture_span_stack; +/// Rust doesn't permit moving fields out of a type that implements drop +#[instrument(skip(fcx), ret, level = "debug")] +fn restrict_precision_for_drop_types<'a, 'db>( + fcx: &mut InferenceContext<'a, 'db>, + mut place: Place, + capture_info: &mut CaptureInfo, +) -> Place { + let is_copy_type = fcx.infcx().type_is_copy_modulo_regions(fcx.table.param_env, place.ty()); + + if let (false, UpvarCapture::ByValue) = (is_copy_type, capture_info.capture_kind) { + for i in 0..place.projections.len() { + match place.ty_before_projection(i).kind() { + TyKind::Adt(def, _) if def.destructor(fcx.interner()).is_some() => { + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, i); + break; } + _ => {} } } - self.current_capture_span_stack - .truncate(self.current_capture_span_stack.len() - adjustments_count); } - fn consume_exprs(&mut self, exprs: impl Iterator) { - for expr in exprs { - self.consume_expr(expr); - } + place +} + +/// Truncate `place` so that an `unsafe` block isn't required to capture it. +/// - No projections are applied to raw pointers, since these require unsafe blocks. We capture +/// them completely. +/// - No projections are applied on top of Union ADTs, since these require unsafe blocks. 
+fn restrict_precision_for_unsafe(mut place: Place, capture_info: &mut CaptureInfo) -> Place { + if place.base_ty.as_ref().is_raw_ptr() { + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, 0); } - fn closure_kind(&self) -> FnTrait { - let mut r = FnTrait::Fn; - for it in &self.current_captures { - r = cmp::min( - r, - match &it.kind { - CaptureKind::ByRef(BorrowKind::Mut { .. }) => FnTrait::FnMut, - CaptureKind::ByRef(BorrowKind::Shallow | BorrowKind::Shared) => FnTrait::Fn, - CaptureKind::ByValue => FnTrait::FnOnce, - }, - ) - } - r + if place.base_ty.as_ref().is_union() { + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, 0); } - fn analyze_closure(&mut self, closure: InternedClosureId) -> FnTrait { - let InternedClosure(_, root) = self.db.lookup_intern_closure(closure); - self.current_closure = Some(closure); - let Expr::Closure { body, capture_by, .. } = &self.store[root] else { - unreachable!("Closure expression id is always closure"); - }; - self.consume_expr(*body); - for item in &self.current_captures { - if matches!( - item.kind, - CaptureKind::ByRef(BorrowKind::Mut { - kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow - }) - ) && !item.place.projections.contains(&HirPlaceProjection::Deref) - { - // FIXME: remove the `mutated_bindings_in_closure` completely and add proper fake reads in - // MIR. I didn't do that due duplicate diagnostics. - self.result.mutated_bindings_in_closure.insert(item.place.local); - } + for (i, proj) in place.projections.iter().enumerate() { + if proj.ty.as_ref().is_raw_ptr() { + // Don't apply any projections on top of a raw ptr. + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, i + 1); + break; } - self.restrict_precision_for_unsafe(); - // `closure_kind` should be done before adjust_for_move_closure - // If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does. 
- // rustc also does diagnostics here if the latter is not a subtype of the former. - let closure_kind = self - .result - .closure_info - .get(&closure) - .map_or_else(|| self.closure_kind(), |info| info.1); - match capture_by { - CaptureBy::Value => self.adjust_for_move_closure(), - CaptureBy::Ref => (), + + if proj.ty.as_ref().is_union() { + // Don't capture precise fields of a union. + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, i + 1); + break; } - self.minimize_captures(); - self.strip_captures_ref_span(); - let result = mem::take(&mut self.current_captures); - let captures = result.into_iter().map(|it| it.with_ty(self)).collect::>(); - self.result.closure_info.insert(closure, (captures, closure_kind)); - closure_kind } - fn strip_captures_ref_span(&mut self) { - // FIXME: Borrow checker won't allow without this. - let mut captures = std::mem::take(&mut self.current_captures); - for capture in &mut captures { - if matches!(capture.kind, CaptureKind::ByValue) { - for span_stack in &mut capture.span_stacks { - if span_stack[span_stack.len() - 1].is_ref_span(self.store) { - span_stack.truncate(span_stack.len() - 1); - } - } + place +} + +/// Truncate projections so that the following rules are obeyed by the captured `place`: +/// - No Index projections are captured, since arrays are captured completely. +/// - No unsafe block is required to capture `place`. +/// +/// Returns the truncated place and updated capture mode. 
+#[instrument(ret, level = "debug")] +fn restrict_capture_precision(place: Place, capture_info: &mut CaptureInfo) -> Place { + let mut place = restrict_precision_for_unsafe(place, capture_info); + + if place.projections.is_empty() { + // Nothing to do here + return place; + } + + for (i, proj) in place.projections.iter().enumerate() { + match proj.kind { + ProjectionKind::Index | ProjectionKind::Subslice => { + // Arrays are completely captured, so we drop Index and Subslice projections + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, i); + return place; } + ProjectionKind::Deref => {} + ProjectionKind::Field { .. } => {} + ProjectionKind::UnwrapUnsafeBinder => {} } - self.current_captures = captures; } - pub(crate) fn infer_closures(&mut self) { - let deferred_closures = self.sort_closures(); - for (closure, exprs) in deferred_closures.into_iter().rev() { - self.current_captures = vec![]; - let kind = self.analyze_closure(closure); - - for (derefed_callee, callee_ty, params, expr) in exprs { - if let &Expr::Call { callee, .. } = &self.store[expr] { - let mut adjustments = - self.result.expr_adjustments.remove(&callee).unwrap_or_default().into_vec(); - self.write_fn_trait_method_resolution( - kind, - derefed_callee, - &mut adjustments, - callee_ty, - ¶ms, - expr, - ); - self.result.expr_adjustments.insert(callee, adjustments.into_boxed_slice()); - } + place +} + +/// Truncate deref of any reference. +#[instrument(ret, level = "debug")] +fn adjust_for_move_closure(mut place: Place, capture_info: &mut CaptureInfo) -> Place { + let first_deref = place.projections.iter().position(|proj| proj.kind == ProjectionKind::Deref); + + if let Some(idx) = first_deref { + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, idx); + } + + capture_info.capture_kind = UpvarCapture::ByValue; + place +} + +/// Adjust closure capture just that if taking ownership of data, only move data +/// from enclosing stack frame. 
+#[instrument(ret, level = "debug")] +fn adjust_for_non_move_closure(mut place: Place, capture_info: &mut CaptureInfo) -> Place { + let contains_deref = + place.projections.iter().position(|proj| proj.kind == ProjectionKind::Deref); + + match capture_info.capture_kind { + UpvarCapture::ByValue | UpvarCapture::ByUse => { + if let Some(idx) = contains_deref { + truncate_place_to_len_and_update_capture_kind(&mut place, capture_info, idx); } } + + UpvarCapture::ByRef(..) => {} } - /// We want to analyze some closures before others, to have a correct analysis: - /// * We should analyze nested closures before the parent, since the parent should capture some of - /// the things that its children captures. - /// * If a closure calls another closure, we need to analyze the callee, to find out how we should - /// capture it (e.g. by move for FnOnce) - /// - /// These dependencies are collected in the main inference. We do a topological sort in this function. It - /// will consume the `deferred_closures` field and return its content in a sorted vector. - fn sort_closures( - &mut self, - ) -> Vec<(InternedClosureId, Vec<(Ty<'db>, Ty<'db>, Vec>, ExprId)>)> { - let mut deferred_closures = mem::take(&mut self.deferred_closures); - let mut dependents_count: FxHashMap = - deferred_closures.keys().map(|it| (*it, 0)).collect(); - for deps in self.closure_dependencies.values() { - for dep in deps { - *dependents_count.entry(*dep).or_default() += 1; - } + place +} + +/// At the end, `capture_info_a` will contain the selected info. +fn determine_capture_info(capture_info_a: &mut CaptureInfo, capture_info_b: &mut CaptureInfo) { + // If the capture kind is equivalent then, we don't need to escalate and can compare the + // expressions. 
+ let eq_capture_kind = match (capture_info_a.capture_kind, capture_info_b.capture_kind) { + (UpvarCapture::ByValue, UpvarCapture::ByValue) => true, + (UpvarCapture::ByUse, UpvarCapture::ByUse) => true, + (UpvarCapture::ByRef(ref_a), UpvarCapture::ByRef(ref_b)) => ref_a == ref_b, + (UpvarCapture::ByValue, _) | (UpvarCapture::ByUse, _) | (UpvarCapture::ByRef(_), _) => { + false } - let mut queue: Vec<_> = - deferred_closures.keys().copied().filter(|&it| dependents_count[&it] == 0).collect(); - let mut result = vec![]; - while let Some(it) = queue.pop() { - if let Some(d) = deferred_closures.remove(&it) { - result.push((it, d)); + }; + + let swap = if eq_capture_kind { + false + } else { + // We select the CaptureKind which ranks higher based the following priority order: + // (ByUse | ByValue) > MutBorrow > UniqueImmBorrow > ImmBorrow + match (capture_info_a.capture_kind, capture_info_b.capture_kind) { + (UpvarCapture::ByUse, UpvarCapture::ByValue) + | (UpvarCapture::ByValue, UpvarCapture::ByUse) => { + panic!("Same capture can't be ByUse and ByValue at the same time") } - for &dep in self.closure_dependencies.get(&it).into_iter().flat_map(|it| it.iter()) { - let cnt = dependents_count.get_mut(&dep).unwrap(); - *cnt -= 1; - if *cnt == 0 { - queue.push(dep); + (UpvarCapture::ByValue, UpvarCapture::ByValue) + | (UpvarCapture::ByUse, UpvarCapture::ByUse) + | (UpvarCapture::ByValue | UpvarCapture::ByUse, UpvarCapture::ByRef(_)) => false, + (UpvarCapture::ByRef(_), UpvarCapture::ByValue | UpvarCapture::ByUse) => true, + (UpvarCapture::ByRef(ref_a), UpvarCapture::ByRef(ref_b)) => { + match (ref_a, ref_b) { + // Take LHS: + (BorrowKind::UniqueImmutable | BorrowKind::Mutable, BorrowKind::Immutable) + | (BorrowKind::Mutable, BorrowKind::UniqueImmutable) => false, + + // Take RHS: + (BorrowKind::Immutable, BorrowKind::UniqueImmutable | BorrowKind::Mutable) + | (BorrowKind::UniqueImmutable, BorrowKind::Mutable) => true, + + (BorrowKind::Immutable, BorrowKind::Immutable) + | 
(BorrowKind::UniqueImmutable, BorrowKind::UniqueImmutable) + | (BorrowKind::Mutable, BorrowKind::Mutable) => { + panic!("Expected unequal capture kinds"); + } } } } - assert!(deferred_closures.is_empty(), "we should have analyzed all closures"); - result + }; + + if swap { + mem::swap(capture_info_a, capture_info_b); } +} - pub(crate) fn add_current_closure_dependency(&mut self, dep: InternedClosureId) { - if let Some(c) = self.current_closure - && !dep_creates_cycle(&self.closure_dependencies, &mut FxHashSet::default(), c, dep) - { - self.closure_dependencies.entry(c).or_default().push(dep); - } +fn determine_capture_sources( + capture_info_a: &mut CaptureInfo, + capture_info_b: &mut CaptureInfo, + dedup_sources_scratch: &mut FxHashMap, +) -> SmallVec<[CaptureSourceStack; 2]> { + dedup_sources_scratch.clear(); + dedup_sources_scratch.extend( + mem::take(&mut capture_info_a.sources).into_iter().map(|it| (it.final_source(), it)), + ); + dedup_sources_scratch.extend( + mem::take(&mut capture_info_b.sources).into_iter().map(|it| (it.final_source(), it)), + ); - fn dep_creates_cycle( - closure_dependencies: &FxHashMap>, - visited: &mut FxHashSet, - from: InternedClosureId, - to: InternedClosureId, - ) -> bool { - if !visited.insert(from) { - return false; - } + let mut result = mem::take(&mut capture_info_a.sources); + result.clear(); + result.extend(dedup_sources_scratch.values().cloned()); + result +} - if from == to { - return true; - } +/// Truncates `place` to have up to `len` projections. +/// `curr_mode` is the current required capture kind for the place. +/// Returns the truncated `place` and the updated required capture kind. +/// +/// Note: Capture kind changes from `MutBorrow` to `UniqueImmBorrow` if the truncated part of the `place` +/// contained `Deref` of `&mut`. 
+fn truncate_place_to_len_and_update_capture_kind( + place: &mut Place, + info: &mut CaptureInfo, + len: usize, +) { + let is_mut_ref = |ty: Ty<'_>| matches!(ty.kind(), TyKind::Ref(.., Mutability::Mut)); - if let Some(deps) = closure_dependencies.get(&to) { - for dep in deps { - if dep_creates_cycle(closure_dependencies, visited, from, *dep) { - return true; - } + // If the truncated part of the place contains `Deref` of a `&mut` then convert MutBorrow -> + // UniqueImmBorrow + // Note that if the place contained Deref of a raw pointer it would've not been MutBorrow, so + // we don't need to worry about that case here. + match info.capture_kind { + UpvarCapture::ByRef(BorrowKind::Mutable) => { + for i in len..place.projections.len() { + if place.projections[i].kind == ProjectionKind::Deref + && is_mut_ref(place.ty_before_projection(i)) + { + info.capture_kind = UpvarCapture::ByRef(BorrowKind::UniqueImmutable); + break; } } + } - false + UpvarCapture::ByRef(..) => {} + UpvarCapture::ByValue | UpvarCapture::ByUse => {} + } + + // Now fix the sources, to point at the smaller place. + for source in &mut info.sources { + // +1 because the first place is the base. + source.truncate(len + 1); + } + + place.projections.truncate(len); +} + +/// Determines the Ancestry relationship of Place A relative to Place B +/// +/// `PlaceAncestryRelation::Ancestor` implies Place A is ancestor of Place B +/// `PlaceAncestryRelation::Descendant` implies Place A is descendant of Place B +/// `PlaceAncestryRelation::Divergent` implies neither of them is the ancestor of the other. +fn determine_place_ancestry_relation(place_a: &Place, place_b: &Place) -> PlaceAncestryRelation { + // If Place A and Place B don't start off from the same root variable, they are divergent. 
+ if place_a.base != place_b.base { + return PlaceAncestryRelation::Divergent; + } + + // Assume of length of projections_a = n + let projections_a = &place_a.projections; + + // Assume of length of projections_b = m + let projections_b = &place_b.projections; + + let same_initial_projections = + iter::zip(projections_a, projections_b).all(|(proj_a, proj_b)| proj_a.kind == proj_b.kind); + + if same_initial_projections { + use std::cmp::Ordering; + + // First min(n, m) projections are the same + // Select Ancestor/Descendant + match projections_b.len().cmp(&projections_a.len()) { + Ordering::Greater => PlaceAncestryRelation::Ancestor, + Ordering::Equal => PlaceAncestryRelation::SamePlace, + Ordering::Less => PlaceAncestryRelation::Descendant, } + } else { + PlaceAncestryRelation::Divergent } } -/// Call this only when the last span in the stack isn't a split. -fn apply_adjusts_to_place( - current_capture_span_stack: &mut Vec, - mut r: HirPlace, - adjustments: &[Adjustment], -) -> Option { - let span = *current_capture_span_stack.last().expect("empty capture span stack"); - for adj in adjustments { - match &adj.kind { - Adjust::Deref(None) => { - current_capture_span_stack.push(span); - r.projections.push(HirPlaceProjection::Deref); - } - _ => return None, +/// Reduces the precision of the captured place when the precision doesn't yield any benefit from +/// borrow checking perspective, allowing us to save us on the size of the capture. +/// +/// +/// Fields that are read through a shared reference will always be read via a shared ref or a copy, +/// and therefore capturing precise paths yields no benefit. This optimization truncates the +/// rightmost deref of the capture if the deref is applied to a shared ref. 
+/// +/// Reason we only drop the last deref is because of the following edge case: +/// +/// ``` +/// # struct A { field_of_a: Box } +/// # struct B {} +/// # struct C<'a>(&'a i32); +/// struct MyStruct<'a> { +/// a: &'static A, +/// b: B, +/// c: C<'a>, +/// } +/// +/// fn foo<'a, 'b>(m: &'a MyStruct<'b>) -> impl FnMut() + 'static { +/// || drop(&*m.a.field_of_a) +/// // Here we really do want to capture `*m.a` because that outlives `'static` +/// +/// // If we capture `m`, then the closure no longer outlives `'static` +/// // it is constrained to `'a` +/// } +/// ``` +#[instrument(ret, level = "debug")] +fn truncate_capture_for_optimization(mut place: Place, info: &mut CaptureInfo) -> Place { + let is_shared_ref = |ty: Ty<'_>| matches!(ty.kind(), TyKind::Ref(.., Mutability::Not)); + + // Find the rightmost deref (if any). All the projections that come after this + // are fields or other "in-place pointer adjustments"; these refer therefore to + // data owned by whatever pointer is being dereferenced here. + let idx = place.projections.iter().rposition(|proj| ProjectionKind::Deref == proj.kind); + + match idx { + // If that pointer is a shared reference, then we don't need those fields. + Some(idx) if is_shared_ref(place.ty_before_projection(idx)) => { + truncate_place_to_len_and_update_capture_kind(&mut place, info, idx + 1) } + None | Some(_) => {} } - Some(r) + + place +} + +/// Precise capture is enabled if user is using Rust Edition 2021 or higher. +/// `span` is the span of the closure. +fn enable_precise_capture(edition: Edition) -> bool { + // FIXME: We should use the edition from the closure expr. 
+ edition.at_least_2021() +} + +fn analyze_coroutine_closure_captures<'a, T>( + parent_captures: impl IntoIterator, + child_captures: impl IntoIterator, + mut for_each: impl FnMut((usize, &'a CapturedPlace), (usize, &'a CapturedPlace)) -> T, +) -> impl Iterator { + let mut result = SmallVec::<[_; 10]>::new(); + + let mut child_captures = child_captures.into_iter().enumerate().peekable(); + + // One parent capture may correspond to several child captures if we end up + // refining the set of captures via edition-2021 precise captures. We want to + // match up any number of child captures with one parent capture, so we keep + // peeking off this `Peekable` until the child doesn't match anymore. + for (parent_field_idx, parent_capture) in parent_captures.into_iter().enumerate() { + // Make sure we use every field at least once, b/c why are we capturing something + // if it's not used in the inner coroutine. + let mut field_used_at_least_once = false; + + // A parent matches a child if they share the same prefix of projections. + // The child may have more, if it is capturing sub-fields out of + // something that is captured by-move in the parent closure. + while child_captures.peek().is_some_and(|(_, child_capture)| { + child_prefix_matches_parent_projections(parent_capture, child_capture) + }) { + let (child_field_idx, child_capture) = child_captures.next().unwrap(); + // This analysis only makes sense if the parent capture is a + // prefix of the child capture. + assert!( + child_capture.place.projections.len() >= parent_capture.place.projections.len(), + "parent capture ({parent_capture:#?}) expected to be prefix of \ + child capture ({child_capture:#?})" + ); + + result.push(for_each( + (parent_field_idx, parent_capture), + (child_field_idx, child_capture), + )); + + field_used_at_least_once = true; + } + + // Make sure the field was used at least once. 
+ assert!( + field_used_at_least_once, + "we captured {parent_capture:#?} but it was not used in the child coroutine?" + ); + } + assert_eq!(child_captures.next(), None, "leftover child captures?"); + + result.into_iter() +} + +fn child_prefix_matches_parent_projections( + parent_capture: &CapturedPlace, + child_capture: &CapturedPlace, +) -> bool { + let PlaceBase::Upvar { var_id: parent_base, .. } = parent_capture.place.base else { + panic!("expected capture to be an upvar"); + }; + let PlaceBase::Upvar { var_id: child_base, .. } = child_capture.place.base else { + panic!("expected capture to be an upvar"); + }; + + parent_base == child_base + && std::iter::zip(&child_capture.place.projections, &parent_capture.place.projections) + .all(|(child, parent)| child.kind == parent.kind) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure/analysis/expr_use_visitor.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure/analysis/expr_use_visitor.rs new file mode 100644 index 0000000000000..099fa18168b2c --- /dev/null +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/closure/analysis/expr_use_visitor.rs @@ -0,0 +1,1705 @@ +//! A different sort of visitor for walking fn bodies. Unlike the +//! normal visitor, which just walks the entire body in one shot, the +//! `ExprUseVisitor` determines how expressions are being used. +//! +//! This is only used for upvar inference. 
+ +use either::Either; +use hir_def::{ + AdtId, HasModule, VariantId, + attrs::AttrFlags, + hir::{ + Array, AsmOperand, BindingId, Expr, ExprId, ExprOrPatId, MatchArm, Pat, PatId, + RecordLitField, RecordSpread, Statement, + }, + resolver::ValueNs, +}; +use rustc_ast_ir::{try_visit, visit::VisitorResult}; +use rustc_type_ir::{ + FallibleTypeFolder, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor, + inherent::{AdtDef, IntoKind, Ty as _}, +}; +use smallvec::{SmallVec, smallvec}; +use syntax::ast::{BinaryOp, UnaryOp}; +use tracing::{debug, instrument}; + +use crate::{ + Adjust, Adjustment, AutoBorrow, BindingMode, + infer::{CaptureSourceStack, InferenceContext, UpvarCapture, closure::analysis::BorrowKind}, + method_resolution::CandidateId, + next_solver::{DbInterner, ErrorGuaranteed, StoredTy, Ty, TyKind}, + upvars::UpvarsRef, + utils::EnumerateAndAdjustIterator, +}; + +type Result = std::result::Result; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ProjectionKind { + /// A dereference of a pointer, reference or `Box` of the given type. + Deref, + + /// `B.F` where `B` is the base expression and `F` is + /// the field. The field is identified by which variant + /// it appears in along with a field index. The variant + /// is used for enums. + Field { field_idx: u32, variant_idx: u32 }, + + /// Some index like `B[x]`, where `B` is the base + /// expression. We don't preserve the index `x` because + /// we won't need it. + Index, + + /// A subslice covering a range of values like `B[x..y]`. + Subslice, + + /// `unwrap_binder!(expr)` + UnwrapUnsafeBinder, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum PlaceBase { + /// A temporary variable. + Rvalue, + /// A named `static` item. + StaticItem, + /// A named local variable. + Local(BindingId), + /// An upvar referenced by closure env. 
+ Upvar { closure: ExprId, var_id: BindingId }, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Projection { + /// Type after the projection is applied. + pub ty: StoredTy, + + /// Defines the kind of access made by the projection. + pub kind: ProjectionKind, +} + +/// A `Place` represents how a value is located in memory. This does not +/// always correspond to a syntactic place expression. For example, when +/// processing a pattern, a `Place` can be used to refer to the sub-value +/// currently being inspected. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Place { + /// The type of the `PlaceBase` + pub base_ty: StoredTy, + /// The "outermost" place that holds this value. + pub base: PlaceBase, + /// How this place is derived from the base place. + pub projections: Vec, +} + +impl<'db> TypeVisitable> for Place { + fn visit_with>>(&self, visitor: &mut V) -> V::Result { + let Self { base_ty, base: _, projections } = self; + try_visit!(base_ty.as_ref().visit_with(visitor)); + for proj in projections { + let Projection { ty, kind: _ } = proj; + try_visit!(ty.as_ref().visit_with(visitor)); + } + V::Result::output() + } +} + +impl<'db> TypeFoldable> for Place { + fn try_fold_with>>( + self, + folder: &mut F, + ) -> Result { + let Self { base_ty, base, projections } = self; + let base_ty = base_ty.as_ref().try_fold_with(folder)?.store(); + let projections = projections + .into_iter() + .map(|proj| { + let Projection { ty, kind } = proj; + let ty = ty.as_ref().try_fold_with(folder)?.store(); + Ok(Projection { ty, kind }) + }) + .collect::>()?; + Ok(Self { base_ty, base, projections }) + } + + fn fold_with>>(self, folder: &mut F) -> Self { + let Self { base_ty, base, projections } = self; + let base_ty = base_ty.as_ref().fold_with(folder).store(); + let projections = projections + .into_iter() + .map(|proj| { + let Projection { ty, kind } = proj; + let ty = ty.as_ref().fold_with(folder).store(); + Projection { ty, kind } + }) + .collect(); + 
Self { base_ty, base, projections } + } +} + +impl Place { + /// Returns an iterator of the types that have to be dereferenced to access + /// the `Place`. + /// + /// The types are in the reverse order that they are applied. So if + /// `x: &*const u32` and the `Place` is `**x`, then the types returned are + ///`*const u32` then `&*const u32`. + pub fn deref_tys<'db>(&self) -> impl Iterator> { + self.projections.iter().enumerate().rev().filter_map(move |(index, proj)| { + if ProjectionKind::Deref == proj.kind { + Some(self.ty_before_projection(index)) + } else { + None + } + }) + } + + /// Returns the type of this `Place` after all projections have been applied. + pub fn ty<'db>(&self) -> Ty<'db> { + self.projections.last().map_or(self.base_ty.as_ref(), |proj| proj.ty.as_ref()) + } + + /// Returns the type of this `Place` immediately before `projection_index`th projection + /// is applied. + pub fn ty_before_projection<'db>(&self, projection_index: usize) -> Ty<'db> { + assert!(projection_index < self.projections.len()); + if projection_index == 0 { + self.base_ty.as_ref() + } else { + self.projections[projection_index - 1].ty.as_ref() + } + } +} + +/// A `PlaceWithOrigin` represents how a value is located in memory. This does not +/// always correspond to a syntactic place expression. For example, when +/// processing a pattern, a `Place` can be used to refer to the sub-value +/// currently being inspected. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub(crate) struct PlaceWithOrigin { + /// `ExprId`s or `PatId`s of the expressions or patterns producing this value. + pub origins: SmallVec<[CaptureSourceStack; 2]>, + + /// Information about the `Place`. 
+ pub place: Place, +} + +impl PlaceWithOrigin { + fn new_no_projections<'db>( + origin: impl Into, + base_ty: Ty<'db>, + base: PlaceBase, + ) -> PlaceWithOrigin { + Self::new( + smallvec![CaptureSourceStack::from_single(origin.into())], + base_ty, + base, + Vec::new(), + ) + } + + fn new<'db>( + origins: SmallVec<[CaptureSourceStack; 2]>, + base_ty: Ty<'db>, + base: PlaceBase, + projections: Vec, + ) -> PlaceWithOrigin { + debug_assert!(origins.iter().all(|origin| origin.len() == projections.len() + 1)); + PlaceWithOrigin { origins, place: Place { base_ty: base_ty.store(), base, projections } } + } + + fn push_projection(&mut self, projection: Projection, origin: ExprOrPatId) { + self.place.projections.push(projection); + for origin_stack in &mut self.origins { + origin_stack.push(origin); + } + } +} + +/// The `FakeReadCause` describes the type of pattern why a FakeRead statement exists. +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +pub enum FakeReadCause { + /// A fake read injected into a match guard to ensure that the discriminants + /// that are being matched on aren't modified while the match guard is being + /// evaluated. + /// + /// At the beginning of each match guard, a fake borrow is + /// inserted for each discriminant accessed in the entire `match` statement. + /// + /// Then, at the end of the match guard, a `FakeRead(ForMatchGuard)` is + /// inserted to keep the fake borrows alive until that point. + /// + /// This should ensure that you cannot change the variant for an enum while + /// you are in the midst of matching on it. + ForMatchGuard, + + /// Fake read of the scrutinee of a `match` or destructuring `let` + /// (i.e. `let` with non-trivial pattern). + /// + /// In `match x { ... }`, we generate a `FakeRead(ForMatchedPlace, x)` + /// and insert it into the `otherwise_block` (which is supposed to be + /// unreachable for irrefutable pattern-matches like `match` or `let`). 
+ /// + /// This is necessary because `let x: !; match x {}` doesn't generate any + /// actual read of x, so we need to generate a `FakeRead` to check that it + /// is initialized. + /// + /// If the `FakeRead(ForMatchedPlace)` is being performed with a closure + /// that doesn't capture the required upvars, the `FakeRead` within the + /// closure is omitted entirely. + /// + /// To make sure that this is still sound, if a closure matches against + /// a Place starting with an Upvar, we hoist the `FakeRead` to the + /// definition point of the closure. + /// + /// If the `FakeRead` comes from being hoisted out of a closure like this, + /// we record the `ExprId` of the closure. Otherwise, the `Option` will be `None`. + // + // We can use LocalDefId here since fake read statements are removed + // before codegen in the `CleanupNonCodegenStatements` pass. + ForMatchedPlace(Option), + + /// A fake read injected into a match guard to ensure that the places + /// bound by the pattern are immutable for the duration of the match guard. + /// + /// Within a match guard, references are created for each place that the + /// pattern creates a binding for — this is known as the `RefWithinGuard` + /// version of the variables. To make sure that the references stay + /// alive until the end of the match guard, and properly prevent the + /// places in question from being modified, a `FakeRead(ForGuardBinding)` + /// is inserted at the end of the match guard. + /// + /// For details on how these references are created, see the extensive + /// documentation on `bind_matched_candidate_for_guard` in + /// `rustc_mir_build`. + ForGuardBinding, + + /// Officially, the semantics of + /// + /// `let pattern = ;` + /// + /// is that `` is evaluated into a temporary and then this temporary is + /// into the pattern. + /// + /// However, if we see the simple pattern `let var = `, we optimize this to + /// evaluate `` directly into the variable `var`. 
This is mostly unobservable, + /// but in some cases it can affect the borrow checker, as in #53695. + /// + /// Therefore, we insert a `FakeRead(ForLet)` immediately after each `let` + /// with a trivial pattern. + /// + /// FIXME: `ExprUseVisitor` has an entirely different opinion on what `FakeRead(ForLet)` + /// is supposed to mean. If it was accurate to what MIR lowering does, + /// would it even make sense to hoist these out of closures like + /// `ForMatchedPlace`? + ForLet(Option), + + /// Currently, index expressions overloaded through the `Index` trait + /// get lowered differently than index expressions with builtin semantics + /// for arrays and slices — the latter will emit code to perform + /// bound checks, and then return a MIR place that will only perform the + /// indexing "for real" when it gets incorporated into an instruction. + /// + /// This is observable in the fact that the following compiles: + /// + /// ``` + /// fn f(x: &mut [&mut [u32]], i: usize) { + /// x[i][x[i].len() - 1] += 1; + /// } + /// ``` + /// + /// However, we need to be careful to not let the user invalidate the + /// bound check with an expression like + /// + /// `(*x)[1][{ x = y; 4}]` + /// + /// Here, the first bounds check would be invalidated when we evaluate the + /// second index expression. To make sure that this doesn't happen, we + /// create a fake borrow of `x` and hold it while we evaluate the second + /// index. + /// + /// This borrow is kept alive by a `FakeRead(ForIndex)` at the end of its + /// scope. + ForIndex, +} + +/// This trait defines the callbacks you can expect to receive when +/// employing the ExprUseVisitor. +pub(crate) trait Delegate<'db> { + /// The value found at `place` is moved, depending + /// on `mode`. Where `diag_expr_id` is the id used for diagnostics for `place`. + /// + /// If the value is `Copy`, [`copy`][Self::copy] is called instead, which + /// by default falls back to [`borrow`][Self::borrow]. 
+ /// + /// The parameter `diag_expr_id` indicates the HIR id that ought to be used for + /// diagnostics. Around pattern matching such as `let pat = expr`, the diagnostic + /// id will be the id of the expression `expr` but the place itself will have + /// the id of the binding in the pattern `pat`. + fn consume(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>); + + /// The value found at `place` is used, depending + /// on `mode`. Where `diag_expr_id` is the id used for diagnostics for `place`. + /// + /// Use of a `Copy` type in a ByUse context is considered a use + /// by `ImmBorrow` and `borrow` is called instead. This is because + /// a shared borrow is the "minimum access" that would be needed + /// to perform a copy. + /// + /// + /// The parameter `diag_expr_id` indicates the HIR id that ought to be used for + /// diagnostics. Around pattern matching such as `let pat = expr`, the diagnostic + /// id will be the id of the expression `expr` but the place itself will have + /// the id of the binding in the pattern `pat`. + fn use_cloned(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>); + + /// The value found at `place` is being borrowed with kind `bk`. + /// `diag_expr_id` is the id used for diagnostics (see `consume` for more details). + fn borrow( + &mut self, + place_with_id: PlaceWithOrigin, + bk: BorrowKind, + ctx: &mut InferenceContext<'_, 'db>, + ); + + /// The value found at `place` is being copied. + /// `diag_expr_id` is the id used for diagnostics (see `consume` for more details). + /// + /// If an implementation is not provided, use of a `Copy` type in a ByValue context is instead + /// considered a use by `ImmBorrow` and `borrow` is called instead. This is because a shared + /// borrow is the "minimum access" that would be needed to perform a copy. 
+ fn copy(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + // In most cases, copying data from `x` is equivalent to doing `*&x`, so by default + // we treat a copy of `x` as a borrow of `x`. + self.borrow(place_with_id, BorrowKind::Immutable, ctx) + } + + /// The path at `assignee_place` is being assigned to. + /// `diag_expr_id` is the id used for diagnostics (see `consume` for more details). + fn mutate(&mut self, assignee_place: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>); + + /// The path at `binding_place` is a binding that is being initialized. + /// + /// This covers cases such as `let x = 42;` + fn bind(&mut self, binding_place: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + // Bindings can normally be treated as a regular assignment, so by default we + // forward this to the mutate callback. + self.mutate(binding_place, ctx) + } + + /// The `place` should be a fake read because of specified `cause`. + fn fake_read( + &mut self, + place_with_id: PlaceWithOrigin, + cause: FakeReadCause, + ctx: &mut InferenceContext<'_, 'db>, + ); +} + +impl<'db, D: Delegate<'db>> Delegate<'db> for &mut D { + fn consume(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + (**self).consume(place_with_id, ctx) + } + + fn use_cloned(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + (**self).use_cloned(place_with_id, ctx) + } + + fn borrow( + &mut self, + place_with_id: PlaceWithOrigin, + bk: BorrowKind, + ctx: &mut InferenceContext<'_, 'db>, + ) { + (**self).borrow(place_with_id, bk, ctx) + } + + fn copy(&mut self, place_with_id: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + (**self).copy(place_with_id, ctx) + } + + fn mutate(&mut self, assignee_place: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + (**self).mutate(assignee_place, ctx) + } + + fn bind(&mut self, binding_place: PlaceWithOrigin, ctx: &mut InferenceContext<'_, 'db>) { + 
(**self).bind(binding_place, ctx) + } + + fn fake_read( + &mut self, + place_with_id: PlaceWithOrigin, + cause: FakeReadCause, + ctx: &mut InferenceContext<'_, 'db>, + ) { + (**self).fake_read(place_with_id, cause, ctx) + } +} + +/// A visitor that reports how each expression is being used. +/// +/// See [module-level docs][self] and [`Delegate`] for details. +pub(crate) struct ExprUseVisitor<'a, 'b, 'db, D: Delegate<'db>> { + cx: &'a mut InferenceContext<'b, 'db>, + delegate: D, + closure_expr: ExprId, + upvars: UpvarsRef<'db>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PatWalkMode { + /// `let`, `match`. + Declaration, + /// Destructuring assignment. + Assignment, +} + +impl<'a, 'b, 'db, D: Delegate<'db>> ExprUseVisitor<'a, 'b, 'db, D> { + /// Creates the ExprUseVisitor, configuring it with the various options provided: + /// + /// - `delegate` -- who receives the callbacks + /// - `param_env` --- parameter environment for trait lookups (esp. pertaining to `Copy`) + /// - `typeck_results` --- typeck results for the code being analyzed + pub(crate) fn new( + cx: &'a mut InferenceContext<'b, 'db>, + closure_expr: ExprId, + upvars: UpvarsRef<'db>, + delegate: D, + ) -> Self { + ExprUseVisitor { delegate, closure_expr, upvars, cx } + } + + pub(crate) fn consume_closure_body(&mut self, params: &[PatId], body: ExprId) -> Result { + for ¶m in params { + let param_ty = self.pat_ty_adjusted(param)?; + debug!("consume_body: param_ty = {:?}", param_ty); + + let param_place = self.cat_rvalue(param.into(), param_ty); + + self.fake_read_scrutinee(param_place.clone(), false); + self.walk_pat(param_place, param, false, PatWalkMode::Declaration)?; + } + + self.consume_expr(body)?; + + Ok(()) + } + + #[instrument(skip(self), level = "debug")] + fn consume_or_copy(&mut self, place_with_id: PlaceWithOrigin) { + if self.cx.table.type_is_copy_modulo_regions(place_with_id.place.ty()) { + self.delegate.copy(place_with_id, self.cx); + } else { + 
self.delegate.consume(place_with_id, self.cx); + } + } + + #[instrument(skip(self), level = "debug")] + pub(crate) fn consume_clone_or_copy(&mut self, place_with_id: PlaceWithOrigin) { + // `x.use` will do one of the following + // * if it implements `Copy`, it will be a copy + // * if it implements `UseCloned`, it will be a call to `clone` + // * otherwise, it is a move + // + // we do a conservative approximation of this, treating it as a move unless we know that it implements copy or `UseCloned` + if self.cx.table.type_is_copy_modulo_regions(place_with_id.place.ty()) { + self.delegate.copy(place_with_id, self.cx); + } else if self.cx.table.type_is_use_cloned_modulo_regions(place_with_id.place.ty()) { + self.delegate.use_cloned(place_with_id, self.cx); + } else { + self.delegate.consume(place_with_id, self.cx); + } + } + + fn consume_exprs(&mut self, exprs: &[ExprId]) -> Result { + for &expr in exprs { + self.consume_expr(expr)?; + } + Ok(()) + } + + // FIXME: It's suspicious that this is public; clippy should probably use `walk_expr`. 
+ #[instrument(skip(self), level = "debug")] + pub(crate) fn consume_expr(&mut self, expr: ExprId) -> Result { + let place_with_id = self.cat_expr(expr)?; + self.consume_or_copy(place_with_id); + self.walk_expr(expr)?; + Ok(()) + } + + fn mutate_expr(&mut self, expr: ExprId) -> Result { + let place_with_id = self.cat_expr(expr)?; + self.delegate.mutate(place_with_id, self.cx); + self.walk_expr(expr)?; + Ok(()) + } + + #[instrument(skip(self), level = "debug")] + fn borrow_expr(&mut self, expr: ExprId, bk: BorrowKind) -> Result { + let place_with_id = self.cat_expr(expr)?; + self.delegate.borrow(place_with_id, bk, self.cx); + self.walk_expr(expr)?; + Ok(()) + } + + #[instrument(skip(self), level = "debug")] + pub(crate) fn walk_expr(&mut self, expr: ExprId) -> Result { + self.walk_adjustment(expr)?; + + match self.cx.store[expr] { + Expr::Path(_) => {} + + Expr::UnaryOp { op: UnaryOp::Deref, expr: base } => { + // *base + self.walk_expr(base)?; + } + + Expr::Field { expr: base, .. } => { + // base.f + self.walk_expr(base)?; + } + + Expr::Index { base: lhs, index: rhs } => { + // lhs[rhs] + self.walk_expr(lhs)?; + self.consume_expr(rhs)?; + } + + Expr::Call { callee, ref args } => { + // callee(args) + self.consume_expr(callee)?; + self.consume_exprs(args)?; + } + + Expr::MethodCall { receiver, ref args, .. } => { + // callee.m(args) + self.consume_expr(receiver)?; + self.consume_exprs(args)?; + } + + Expr::RecordLit { ref fields, spread, .. 
} => { + self.walk_struct_expr(fields, spread)?; + } + + Expr::Tuple { ref exprs } => { + self.consume_exprs(exprs)?; + } + + Expr::If { + condition: cond_expr, + then_branch: then_expr, + else_branch: opt_else_expr, + } => { + self.consume_expr(cond_expr)?; + self.consume_expr(then_expr)?; + if let Some(else_expr) = opt_else_expr { + self.consume_expr(else_expr)?; + } + } + + Expr::Let { pat, expr: init } => { + self.walk_local(init, pat, None, |this| { + this.borrow_expr(init, BorrowKind::Immutable) + })?; + } + + Expr::Match { expr: discr, ref arms } => { + let discr_place = self.cat_expr(discr)?; + self.fake_read_scrutinee(discr_place.clone(), true); + self.walk_expr(discr)?; + + for arm in arms { + self.walk_arm(discr_place.clone(), arm)?; + } + } + + Expr::Array(Array::ElementList { elements: ref exprs }) => { + self.consume_exprs(exprs)?; + } + + Expr::Ref { expr: base, mutability: m, .. } => { + // &base + // make sure that the thing we are pointing out stays valid + // for the lifetime `scope_r` of the resulting ptr: + let bk = BorrowKind::from_hir_mutbl(m); + self.borrow_expr(base, bk)?; + } + + Expr::InlineAsm(ref asm) => { + for (_, op) in &asm.operands { + match *op { + AsmOperand::In { expr, .. } => { + self.consume_expr(expr)?; + } + AsmOperand::Out { expr: Some(expr), .. } + | AsmOperand::InOut { expr, .. } => { + self.mutate_expr(expr)?; + } + AsmOperand::SplitInOut { in_expr, out_expr, .. } => { + self.consume_expr(in_expr)?; + if let Some(out_expr) = out_expr { + self.mutate_expr(out_expr)?; + } + } + AsmOperand::Out { expr: None, .. } + | AsmOperand::Const { .. } + | AsmOperand::Sym { .. } => {} + AsmOperand::Label(block) => { + self.walk_expr(block)?; + } + } + } + } + + Expr::Continue { .. } + | Expr::Literal(..) + | Expr::Const(..) + | Expr::OffsetOf(..) + | Expr::Missing + | Expr::Underscore => {} + + Expr::Loop { body: blk, .. } => { + self.walk_expr(blk)?; + } + + Expr::UnaryOp { expr: lhs, .. 
} => { + self.consume_expr(lhs)?; + } + + Expr::BinaryOp { + lhs, + rhs, + op: Some(BinaryOp::ArithOp(..) | BinaryOp::CmpOp(..) | BinaryOp::LogicOp(..)), + } => { + self.consume_expr(lhs)?; + self.consume_expr(rhs)?; + } + + Expr::Block { ref statements, tail, .. } + | Expr::Unsafe { ref statements, tail, .. } => { + for stmt in statements { + self.walk_stmt(stmt)?; + } + + if let Some(tail_expr) = tail { + self.consume_expr(tail_expr)?; + } + } + + Expr::Break { expr: opt_expr, .. } | Expr::Return { expr: opt_expr } => { + if let Some(expr) = opt_expr { + self.consume_expr(expr)?; + } + } + + Expr::Become { expr } | Expr::Await { expr } | Expr::Box { expr } => { + self.consume_expr(expr)?; + } + + Expr::Assignment { target, value } => { + self.walk_expr(value)?; + let expr_place = self.cat_expr(value)?; + let update_guard = + self.cx.resolver.update_to_inner_scope(self.cx.db, self.cx.owner, expr); + self.walk_pat(expr_place, target, false, PatWalkMode::Assignment)?; + self.cx.resolver.reset_to_guard(update_guard); + } + + Expr::Cast { expr: base, .. } => { + self.consume_expr(base)?; + } + + Expr::BinaryOp { lhs, rhs, op: None | Some(BinaryOp::Assignment { .. }) } => { + self.consume_expr(lhs)?; + self.consume_expr(rhs)?; + } + + Expr::Array(Array::Repeat { initializer: base, .. }) => { + self.consume_expr(base)?; + } + + Expr::Closure { .. } => { + self.walk_captures(expr); + } + + Expr::Yield { expr: value } | Expr::Yeet { expr: value } => { + if let Some(value) = value { + self.consume_expr(value)?; + } + } + + Expr::Range { lhs, rhs, .. } => { + if let Some(lhs) = lhs { + self.consume_expr(lhs)?; + } + if let Some(rhs) = rhs { + self.consume_expr(rhs)?; + } + } + } + Ok(()) + } + + fn walk_stmt(&mut self, stmt: &Statement) -> Result { + match *stmt { + Statement::Let { pat, initializer: Some(expr), else_branch: els, .. } => { + self.walk_local(expr, pat, els, |_| Ok(()))?; + } + + Statement::Let { .. 
} => {} + + Statement::Item(_) => { + // We don't visit nested items in this visitor, + // only the fn body we were given. + } + + Statement::Expr { expr, .. } => { + self.consume_expr(expr)?; + } + } + Ok(()) + } + + #[instrument(skip(self), level = "debug")] + fn fake_read_scrutinee(&mut self, discr_place: PlaceWithOrigin, refutable: bool) { + let closure_def_id = match discr_place.place.base { + PlaceBase::Upvar { closure, var_id: _ } => Some(closure), + _ => None, + }; + + let cause = if refutable { + FakeReadCause::ForMatchedPlace(closure_def_id) + } else { + FakeReadCause::ForLet(closure_def_id) + }; + + self.delegate.fake_read(discr_place, cause, self.cx); + } + + fn walk_local(&mut self, expr: ExprId, pat: PatId, els: Option, mut f: F) -> Result + where + F: FnMut(&mut Self) -> Result, + { + self.walk_expr(expr)?; + let expr_place = self.cat_expr(expr)?; + f(self)?; + self.fake_read_scrutinee(expr_place.clone(), els.is_some()); + self.walk_pat(expr_place, pat, false, PatWalkMode::Declaration)?; + if let Some(els) = els { + self.walk_expr(els)?; + } + Ok(()) + } + + fn walk_struct_expr(&mut self, fields: &[RecordLitField], spread: RecordSpread) -> Result { + // Consume the expressions supplying values for each field. + for field in fields { + self.consume_expr(field.expr)?; + } + + let RecordSpread::Expr(with_expr) = spread else { return Ok(()) }; + + let with_place = self.cat_expr(with_expr)?; + + // Select just those fields of the `with` + // expression that will actually be used + match self.cx.table.structurally_resolve_type(with_place.place.ty()).kind() { + TyKind::Adt(adt, args) if adt.is_struct() => { + let AdtId::StructId(adt) = adt.def_id().0 else { unreachable!() }; + let adt_fields = VariantId::from(adt).fields(self.cx.db).fields(); + let adt_field_types = self.cx.db.field_types(adt.into()); + // Consume those fields of the with expression that are needed. 
+ for (f_index, with_field) in adt_fields.iter() { + let is_mentioned = fields.iter().any(|f| f.name == with_field.name); + if !is_mentioned { + let field_place = self.cat_projection( + with_expr.into(), + with_place.clone(), + adt_field_types[f_index].get().instantiate(self.cx.interner(), args), + ProjectionKind::Field { + field_idx: f_index.into_raw().into_u32(), + variant_idx: 0, + }, + ); + self.consume_or_copy(field_place); + } + } + } + _ => {} + } + + // walk the with expression so that complex expressions + // are properly handled. + self.walk_expr(with_expr)?; + + Ok(()) + } + + fn expr_adjustments(&self, expr: ExprId) -> SmallVec<[Adjustment; 5]> { + // Due to borrowck problems, we cannot borrow the adjustments, unfortunately. + self.cx.result.expr_adjustment(expr).unwrap_or_default().into() + } + + /// Invoke the appropriate delegate calls for anything that gets + /// consumed or borrowed as part of the automatic adjustment + /// process. + fn walk_adjustment(&mut self, expr: ExprId) -> Result { + let adjustments = self.expr_adjustments(expr); + let mut place_with_id = self.cat_expr_unadjusted(expr)?; + for adjustment in &adjustments { + debug!("walk_adjustment expr={:?} adj={:?}", expr, adjustment); + match adjustment.kind { + Adjust::NeverToAny | Adjust::Pointer(_) => { + // Creating a closure/fn-pointer or unsizing consumes + // the input and stores it into the resulting rvalue. + self.consume_or_copy(place_with_id.clone()); + } + + Adjust::Deref(None) => {} + + // Autoderefs for overloaded Deref calls in fact reference + // their receiver. That is, if we have `(*x)` where `x` + // is of type `Rc`, then this in fact is equivalent to + // `x.deref()`. Since `deref()` is declared with `&self`, + // this is an autoref of `x`. 
+ Adjust::Deref(Some(ref deref)) => { + let bk = BorrowKind::from_mutbl(deref.0); + self.delegate.borrow(place_with_id.clone(), bk, self.cx); + } + + Adjust::Borrow(ref autoref) => { + self.walk_autoref(expr, place_with_id.clone(), autoref); + } + } + place_with_id = self.cat_expr_adjusted(expr, place_with_id, adjustment)?; + } + Ok(()) + } + + /// Walks the autoref `autoref` applied to the autoderef'd + /// `expr`. `base_place` is `expr` represented as a place, + /// after all relevant autoderefs have occurred. + fn walk_autoref(&mut self, expr: ExprId, base_place: PlaceWithOrigin, autoref: &AutoBorrow) { + debug!("walk_autoref(expr={:?} base_place={:?} autoref={:?})", expr, base_place, autoref); + + match *autoref { + AutoBorrow::Ref(m) => { + self.delegate.borrow(base_place, BorrowKind::from_mutbl(m.into()), self.cx); + } + + AutoBorrow::RawPtr(m) => { + debug!("walk_autoref: expr={:?} base_place={:?}", expr, base_place); + + self.delegate.borrow(base_place, BorrowKind::from_mutbl(m), self.cx); + } + } + } + + fn walk_arm(&mut self, discr_place: PlaceWithOrigin, arm: &MatchArm) -> Result { + self.walk_pat(discr_place, arm.pat, arm.guard.is_some(), PatWalkMode::Declaration)?; + + if let Some(e) = arm.guard { + self.consume_expr(e)?; + } + + self.consume_expr(arm.expr) + } + + /// The core driver for walking a pattern + /// + /// This should mirror how pattern-matching gets lowered to MIR, as + /// otherwise lowering will ICE when trying to resolve the upvars. + /// + /// However, it is okay to approximate it here by doing *more* accesses than + /// the actual MIR builder will, which is useful when some checks are too + /// cumbersome to perform here. For example, if after typeck it becomes + /// clear that only one variant of an enum is inhabited, and therefore a + /// read of the discriminant is not necessary, `walk_pat` will have + /// over-approximated the necessary upvar capture granularity. 
+ /// + /// Do note that discrepancies like these do still create obscure corners + /// in the semantics of the language, and should be avoided if possible. + #[instrument(skip(self), level = "debug")] + fn walk_pat( + &mut self, + discr_place: PlaceWithOrigin, + pat: PatId, + has_guard: bool, + mode: PatWalkMode, + ) -> Result { + self.cat_pattern(discr_place.clone(), pat, &mut |this, place, pat| { + debug!("walk_pat: pat.kind={:?}", this.cx.store[pat]); + let read_discriminant = { + let place = place.clone(); + |this: &mut Self| { + this.delegate.borrow(place, BorrowKind::Immutable, this.cx); + } + }; + + match this.cx.store[pat] { + Pat::Bind { id, .. } => { + debug!("walk_pat: binding place={:?} pat={:?}", place, pat); + let bm = this.cx.result.binding_modes[pat]; + debug!("walk_pat: pat.hir_id={:?} bm={:?}", pat, bm); + + // pat_ty: the type of the binding being produced. + let pat_ty = this.node_ty(pat.into())?; + debug!("walk_pat: pat_ty={:?}", pat_ty); + + if let Ok(binding_place) = this.cat_local(pat.into(), pat_ty, id) { + this.delegate.bind(binding_place, this.cx); + } + + // Subtle: MIR desugaring introduces immutable borrows for each pattern + // binding when lowering pattern guards to ensure that the guard does not + // modify the scrutinee. + if has_guard { + read_discriminant(this); + } + + // It is also a borrow or copy/move of the value being matched. + // In a cases of pattern like `let pat = upvar`, don't use the span + // of the pattern, as this just looks confusing, instead use the span + // of the discriminant. + match this.cx.result.binding_mode(pat) { + Some(BindingMode::Ref(m)) => { + let bk = BorrowKind::from_mutbl(m); + this.delegate.borrow(place, bk, this.cx); + } + None | Some(BindingMode::Move) => { + debug!("walk_pat binding consuming pat"); + this.consume_or_copy(place); + } + } + } + Pat::Path(ref path) => { + // A `Path` pattern is just a name like `Foo`. 
This is either a + // named constant or else it refers to an ADT variant + + let is_assoc_const = this + .cx + .result + .assoc_resolutions_for_pat(pat) + .is_some_and(|it| matches!(it.0, CandidateId::ConstId(_))); + let resolution = this.cx.resolver.resolve_path_in_value_ns_fully( + this.cx.db, + path, + this.cx.store.pat_path_hygiene(pat), + ); + let is_normal_const = matches!(resolution, Some(ValueNs::ConstId(_))); + if mode == PatWalkMode::Assignment + && let Some(ValueNs::LocalBinding(local)) = resolution + { + let pat_ty = this.pat_ty(pat)?; + let place = this.cat_local(pat.into(), pat_ty, local)?; + this.delegate.mutate(place, this.cx); + } else if is_assoc_const || is_normal_const { + // Named constants have to be equated with the value + // being matched, so that's a read of the value being matched. + // + // FIXME: Does the MIR code skip this read when matching on a ZST? + // If so, we can also skip it here. + read_discriminant(this); + } else if this.is_multivariant_adt(place.place.ty()) { + // Otherwise, this is a struct/enum variant, and so it's + // only a read if we need to read the discriminant. + read_discriminant(this); + } + } + Pat::Lit(_) | Pat::ConstBlock(_) | Pat::Range { .. } => { + // When matching against a literal or range, we need to + // borrow the place to compare it against the pattern. + // + // Note that we do this read even if the range matches all + // possible values, such as 0..=u8::MAX. This is because + // we don't want to depend on consteval here. + // + // FIXME: What if the type being matched only has one + // possible value? + read_discriminant(this); + } + Pat::Record { .. } | Pat::TupleStruct { .. 
} => { + if this.is_multivariant_adt(place.place.ty()) { + read_discriminant(this); + } + } + Pat::Slice { prefix: ref lhs, slice: wild, suffix: ref rhs } => { + // We don't need to test the length if the pattern is `[..]` + if matches!((&**lhs, wild, &**rhs), (&[], Some(_), &[])) + // Arrays have a statically known size, so + // there is no need to read their length + || place.place.ty().strip_references().is_array() + { + // No read necessary + } else { + read_discriminant(this); + } + } + Pat::Expr(expr) if mode == PatWalkMode::Assignment => { + // Destructuring assignment. + this.mutate_expr(expr)?; + } + Pat::Or(_) + | Pat::Box { .. } + | Pat::Ref { .. } + | Pat::Tuple { .. } + | Pat::Wild + | Pat::Missing => { + // If the PatKind is Or, Box, Ref, Guard, or Tuple, the relevant accesses + // are made later as these patterns contains subpatterns. + // If the PatKind is Missing, Wild or Err, any relevant accesses are made when processing + // the other patterns that are part of the match + } + Pat::Expr(_) => {} + } + + Ok(()) + }) + } + + /// Handle the case where the current body contains a closure. + /// + /// When the current body being handled is a closure, then we must make sure that + /// - The parent closure only captures Places from the nested closure that are not local to it. + /// + /// In the following example the closures `c` only captures `p.x` even though `incr` + /// is a capture of the nested closure + /// + /// ``` + /// struct P { x: i32 } + /// let mut p = P { x: 4 }; + /// let c = || { + /// let incr = 10; + /// let nested = || p.x += incr; + /// }; + /// ``` + /// + /// - When reporting the Place back to the Delegate, ensure that the UpvarId uses the enclosing + /// closure as the DefId. 
+ #[instrument(skip(self), level = "debug")] + fn walk_captures(&mut self, closure_expr: ExprId) { + fn upvar_is_local_variable(upvars: UpvarsRef<'_>, var_id: BindingId) -> bool { + upvars.contains(var_id) + } + + // If we have a nested closure, we want to include the fake reads present in the nested + // closure. + // `remove()` then re-insert and not `get()` due to borrowck errors. + if let Some(closure_data) = self.cx.result.closures_data.remove(&closure_expr) { + for (fake_read, cause, origins) in closure_data.fake_reads.iter() { + match fake_read.base { + PlaceBase::Upvar { var_id, closure: _ } => { + if upvar_is_local_variable(self.upvars, var_id) { + // The nested closure might be fake reading the current (enclosing) closure's local variables. + // The only places we want to fake read before creating the parent closure are the ones that + // are not local to it/ defined by it. + // + // ```rust,ignore(cannot-test-this-because-pseudo-code) + // let v1 = (0, 1); + // let c = || { // fake reads: v1 + // let v2 = (0, 1); + // let e = || { // fake reads: v1, v2 + // let (_, t1) = v1; + // let (_, t2) = v2; + // } + // } + // ``` + // This check is performed when visiting the body of the outermost closure (`c`) and ensures + // that we don't add a fake read of v2 in c. + continue; + } + } + _ => { + panic!( + "Do not know how to get ExprId out of Rvalue and StaticItem {:?}", + fake_read.base + ); + } + }; + self.delegate.fake_read( + PlaceWithOrigin { place: fake_read.clone(), origins: origins.clone() }, + *cause, + self.cx, + ); + } + + for (var_id, min_list) in closure_data.min_captures.iter() { + if !self.upvars.contains(*var_id) { + // The nested closure might be capturing the current (enclosing) closure's local variables. + // We check if the root variable is ever mentioned within the enclosing closure, if not + // then for the current body (if it's a closure) these aren't captures, we will ignore them. 
+ continue; + } + for captured_place in min_list { + let place = &captured_place.place; + let capture_info = &captured_place.info; + + // Mark the place to be captured by the enclosing closure + let place_base = + PlaceBase::Upvar { var_id: *var_id, closure: self.closure_expr }; + let place_with_id = PlaceWithOrigin::new( + capture_info.sources.clone(), + place.base_ty.as_ref(), + place_base, + place.projections.clone(), + ); + + match capture_info.capture_kind { + UpvarCapture::ByValue => { + self.consume_or_copy(place_with_id); + } + UpvarCapture::ByUse => { + self.consume_clone_or_copy(place_with_id); + } + UpvarCapture::ByRef(upvar_borrow) => { + self.delegate.borrow(place_with_id, upvar_borrow, self.cx); + } + } + } + } + + self.cx.result.closures_data.insert(closure_expr, closure_data); + } + } + + fn error_reported_in_ty(&self, ty: Ty<'db>) -> Result { + if ty.is_ty_error() { Err(ErrorGuaranteed) } else { Ok(()) } + } +} + +/// The job of the methods whose name starts with `cat_` is to analyze +/// expressions and construct the corresponding [`Place`]s. The `cat` +/// stands for "categorize", this is a leftover from long ago when +/// places were called "categorizations". +/// +/// Note that a [`Place`] differs somewhat from the expression itself. For +/// example, auto-derefs are explicit. Also, an index `a[b]` is decomposed into +/// two operations: a dereference to reach the array data and then an index to +/// jump forward to the relevant item. 
+impl<'db, D: Delegate<'db>> ExprUseVisitor<'_, '_, 'db, D> { + fn expect_and_resolve_type(&mut self, ty: Option>) -> Result> { + match ty { + Some(ty) => { + let ty = self.cx.infcx().resolve_vars_if_possible(ty); + self.error_reported_in_ty(ty)?; + Ok(ty) + } + None => Err(ErrorGuaranteed), + } + } + + fn node_ty(&mut self, id: ExprOrPatId) -> Result> { + self.expect_and_resolve_type(self.cx.result.type_of_expr_or_pat(id)) + } + + fn expr_ty(&mut self, expr: ExprId) -> Result> { + self.node_ty(expr.into()) + } + + fn pat_ty(&mut self, pat: PatId) -> Result> { + self.node_ty(pat.into()) + } + + fn expr_ty_adjusted(&mut self, expr: ExprId) -> Result> { + self.expect_and_resolve_type(self.cx.result.type_of_expr_with_adjust(expr)) + } + + /// Returns the type of value that this pattern matches against. + /// Some non-obvious cases: + /// + /// - a `ref x` binding matches against a value of type `T` and gives + /// `x` the type `&T`; we return `T`. + /// - a pattern with implicit derefs (thanks to default binding + /// modes #42640) may look like `Some(x)` but in fact have + /// implicit deref patterns attached (e.g., it is really + /// `&Some(x)`). In that case, we return the "outermost" type + /// (e.g., `&Option`). + fn pat_ty_adjusted(&mut self, pat: PatId) -> Result> { + // Check for implicit `&` types wrapping the pattern; note + // that these are never attached to binding patterns, so + // actually this is somewhat "disjoint" from the code below + // that aims to account for `ref x`. + if let Some(vec) = self.cx.result.pat_adjustments.get(&pat) + && let Some(first_adjust) = vec.first() + { + debug!("pat_ty(pat={:?}) found adjustment `{:?}`", pat, first_adjust); + return Ok(first_adjust.as_ref()); + } + self.pat_ty_unadjusted(pat) + } + + /// Like [`Self::pat_ty_adjusted`], but ignores implicit `&` patterns. 
+ fn pat_ty_unadjusted(&mut self, pat: PatId) -> Result> { + Ok(self.cx.result.pat_ty(pat)) + } + + fn cat_expr(&mut self, expr: ExprId) -> Result { + self.cat_expr_(expr, &self.expr_adjustments(expr)) + } + + /// This recursion helper avoids going through *too many* + /// adjustments, since *only* non-overloaded deref recurses. + fn cat_expr_(&mut self, expr: ExprId, adjustments: &[Adjustment]) -> Result { + match adjustments.split_last() { + None => self.cat_expr_unadjusted(expr), + Some((adjustment, previous)) => { + self.cat_expr_adjusted_with(expr, |this| this.cat_expr_(expr, previous), adjustment) + } + } + } + + fn cat_expr_adjusted( + &mut self, + expr: ExprId, + previous: PlaceWithOrigin, + adjustment: &Adjustment, + ) -> Result { + self.cat_expr_adjusted_with(expr, |_this| Ok(previous), adjustment) + } + + fn cat_expr_adjusted_with( + &mut self, + expr: ExprId, + previous: F, + adjustment: &Adjustment, + ) -> Result + where + F: FnOnce(&mut Self) -> Result, + { + let target = self.cx.infcx().resolve_vars_if_possible(adjustment.target.as_ref()); + match adjustment.kind { + Adjust::Deref(overloaded) => { + // Equivalent to *expr or something similar. + let base = if let Some(deref) = overloaded { + let ref_ty = Ty::new_ref( + self.cx.interner(), + self.cx.types.regions.erased, + target, + deref.0, + ); + self.cat_rvalue(expr.into(), ref_ty) + } else { + previous(self)? + }; + self.cat_deref(expr.into(), base) + } + + Adjust::NeverToAny | Adjust::Pointer(_) | Adjust::Borrow(_) => { + // Result is an rvalue. 
+ Ok(self.cat_rvalue(expr.into(), target)) + } + } + } + + fn cat_expr_unadjusted(&mut self, expr: ExprId) -> Result { + let expr_ty = self.expr_ty(expr)?; + match self.cx.store[expr] { + Expr::UnaryOp { expr: e_base, op: UnaryOp::Deref } => { + if self.cx.result.method_resolutions.contains_key(&expr) { + self.cat_overloaded_place(expr, e_base) + } else { + let base = self.cat_expr(e_base)?; + self.cat_deref(expr.into(), base) + } + } + + Expr::Field { expr: base, .. } => { + let base = self.cat_expr(base)?; + debug!(?base); + + let field_idx = self + .cx + .result + .field_resolutions + .get(&expr) + .map(|field| match *field { + Either::Left(field) => field.local_id.into_raw().into_u32(), + Either::Right(tuple_field) => tuple_field.index, + }) + .ok_or(ErrorGuaranteed)?; + + Ok(self.cat_projection( + expr.into(), + base, + expr_ty, + ProjectionKind::Field { field_idx, variant_idx: 0 }, + )) + } + + Expr::Index { base, index: _ } => { + // rustc checks if this is an overloaded index, but the check is buggy and treats any indexing + // as overloaded, see https://rust-lang.zulipchat.com/#narrow/channel/144729-t-types/topic/.E2.9C.94.20Is.20builtin.20indexing.20any.20special.20in.20typeck.3F/near/565881390. + // So that's what we do here. 
+ self.cat_overloaded_place(expr, base) + } + + Expr::Path(ref path) => { + let resolver_guard = + self.cx.resolver.update_to_inner_scope(self.cx.db, self.cx.owner, expr); + let resolution = self.cx.resolver.resolve_path_in_value_ns_fully( + self.cx.db, + path, + self.cx.store.expr_path_hygiene(expr), + ); + self.cx.resolver.reset_to_guard(resolver_guard); + match (resolution, self.cx.result.assoc_resolutions_for_expr(expr)) { + (_, Some((CandidateId::FunctionId(_) | CandidateId::ConstId(_), _))) + | ( + Some( + ValueNs::ConstId(_) + | ValueNs::GenericParam(_) + | ValueNs::FunctionId(_) + | ValueNs::ImplSelf(_) + | ValueNs::EnumVariantId(_) + | ValueNs::StructId(_), + ), + None, + ) => Ok(self.cat_rvalue(expr.into(), expr_ty)), + (Some(ValueNs::StaticId(_)), None) => Ok(PlaceWithOrigin::new_no_projections( + expr, + expr_ty, + PlaceBase::StaticItem, + )), + (Some(ValueNs::LocalBinding(var_id)), None) => { + self.cat_local(expr.into(), expr_ty, var_id) + } + (None, None) => Err(ErrorGuaranteed), + } + } + + _ => Ok(self.cat_rvalue(expr.into(), expr_ty)), + } + } + + fn cat_local( + &mut self, + id: ExprOrPatId, + expr_ty: Ty<'db>, + var_id: BindingId, + ) -> Result { + if self.upvars.contains(var_id) { + self.cat_upvar(id, var_id) + } else { + Ok(PlaceWithOrigin::new_no_projections(id, expr_ty, PlaceBase::Local(var_id))) + } + } + + /// Categorize an upvar. + /// + /// Note: the actual upvar access contains invisible derefs of closure + /// environment and upvar reference as appropriate. Only regionck cares + /// about these dereferences, so we let it compute them as needed. 
+ fn cat_upvar(&mut self, hir_id: ExprOrPatId, var_id: BindingId) -> Result { + let var_ty = self.expect_and_resolve_type( + self.cx.result.type_of_binding.get(var_id).map(|it| it.as_ref()), + )?; + + Ok(PlaceWithOrigin::new_no_projections( + hir_id, + var_ty, + PlaceBase::Upvar { closure: self.closure_expr, var_id }, + )) + } + + fn cat_rvalue(&self, hir_id: ExprOrPatId, expr_ty: Ty<'db>) -> PlaceWithOrigin { + PlaceWithOrigin::new_no_projections(hir_id, expr_ty, PlaceBase::Rvalue) + } + + fn cat_projection( + &self, + node: ExprOrPatId, + mut base_place: PlaceWithOrigin, + ty: Ty<'db>, + kind: ProjectionKind, + ) -> PlaceWithOrigin { + base_place.push_projection(Projection { kind, ty: ty.store() }, node); + base_place + } + + fn cat_overloaded_place(&mut self, expr: ExprId, base: ExprId) -> Result { + // Reconstruct the output assuming it's a reference with the + // same region and mutability as the receiver. This holds for + // `Deref(Mut)::Deref(_mut)` and `Index(Mut)::index(_mut)`. 
+ let place_ty = self.expr_ty(expr)?; + let base_ty = self.expr_ty_adjusted(base)?; + + let TyKind::Ref(region, _, mutbl) = self.cx.table.structurally_resolve_type(base_ty).kind() + else { + return Err(ErrorGuaranteed); + }; + let ref_ty = Ty::new_ref(self.cx.interner(), region, place_ty, mutbl); + + let base = self.cat_rvalue(expr.into(), ref_ty); + self.cat_deref(expr.into(), base) + } + + fn cat_deref( + &mut self, + node: ExprOrPatId, + mut base_place: PlaceWithOrigin, + ) -> Result { + let base_curr_ty = base_place.place.ty(); + let Some(deref_ty) = + self.cx.table.structurally_resolve_type(base_curr_ty).builtin_deref(true) + else { + debug!("explicit deref of non-derefable type: {:?}", base_curr_ty); + return Err(ErrorGuaranteed); + }; + base_place.push_projection( + Projection { kind: ProjectionKind::Deref, ty: deref_ty.store() }, + node, + ); + Ok(base_place) + } + + /// Returns the variant index for an ADT used within a Struct or TupleStruct pattern + /// Here `pat_hir_id` is the ExprId of the pattern itself. + fn variant_index_for_adt(&self, pat_id: PatId) -> Result<(u32, VariantId)> { + let variant = self.cx.result.variant_resolution_for_pat(pat_id).ok_or(ErrorGuaranteed)?; + let variant_idx = match variant { + VariantId::EnumVariantId(variant) => variant.loc(self.cx.db).index, + VariantId::StructId(_) | VariantId::UnionId(_) => 0, + }; + Ok((variant_idx, variant)) + } + + /// Returns the total number of fields in a tuple used within a Tuple pattern. + /// Here `pat_hir_id` is the ExprId of the pattern itself. + fn total_fields_in_tuple(&mut self, pat_id: PatId) -> usize { + let ty = self.cx.result.pat_ty(pat_id); + match self.cx.table.structurally_resolve_type(ty).kind() { + TyKind::Tuple(args) => args.len(), + _ => panic!("tuple pattern not applied to a tuple"), + } + } + + /// Here, `place` is the `PlaceWithId` being matched and pat is the pattern it + /// is being matched against. 
+ /// + /// In general, the way that this works is that we walk down the pattern, + /// constructing a `PlaceWithId` that represents the path that will be taken + /// to reach the value being matched. + fn cat_pattern( + &mut self, + mut place_with_id: PlaceWithOrigin, + pat: PatId, + op: &mut F, + ) -> Result + where + F: FnMut(&mut Self, PlaceWithOrigin, PatId) -> Result, + { + // If (pattern) adjustments are active for this pattern, adjust the `PlaceWithId` correspondingly. + // `PlaceWithId`s are constructed differently from patterns. For example, in + // + // ``` + // match foo { + // &&Some(x, ) => { ... }, + // _ => { ... }, + // } + // ``` + // + // the pattern `&&Some(x,)` is represented as `Ref { Ref { TupleStruct }}`. To build the + // corresponding `PlaceWithId` we start with the `PlaceWithId` for `foo`, and then, by traversing the + // pattern, try to answer the question: given the address of `foo`, how is `x` reached? + // + // `&&Some(x,)` `place_foo` + // `&Some(x,)` `deref { place_foo}` + // `Some(x,)` `deref { deref { place_foo }}` + // `(x,)` `field0 { deref { deref { place_foo }}}` <- resulting place + // + // The above example has no adjustments. If the code were instead the (after adjustments, + // equivalent) version + // + // ``` + // match foo { + // Some(x, ) => { ... }, + // _ => { ... }, + // } + // ``` + // + // Then we see that to get the same result, we must start with + // `deref { deref { place_foo }}` instead of `place_foo` since the pattern is now `Some(x,)` + // and not `&&Some(x,)`, even though its assigned type is that of `&&Some(x,)`. + let adjustments_len = self.cx.result.pat_adjustment(pat).map_or(0, |it| it.len()); + for _ in 0..adjustments_len { + debug!("applying adjustment to place_with_id={:?}", place_with_id); + // FIXME: We need to adjust this once we implement deref patterns (or pin ergonomics, for that matter). 
+ place_with_id = self.cat_deref(pat.into(), place_with_id)?; + } + let place_with_id = place_with_id; // lose mutability + debug!("applied adjustment derefs to get place_with_id={:?}", place_with_id); + + // Invoke the callback, but only now, after the `place_with_id` has adjusted. + // + // To see that this makes sense, consider `match &Some(3) { Some(x) => { ... }}`. In that + // case, the initial `place_with_id` will be that for `&Some(3)` and the pattern is `Some(x)`. We + // don't want to call `op` with these incompatible values. As written, what happens instead + // is that `op` is called with the adjusted place (that for `*&Some(3)`) and the pattern + // `Some(x)` (which matches). Recursing once more, `*&Some(3)` and the pattern `Some(x)` + // result in the place `Downcast(*&Some(3)).0` associated to `x` and invoke `op` with + // that (where the `ref` on `x` is implied). + op(self, place_with_id.clone(), pat)?; + + match self.cx.store[pat] { + Pat::Tuple { args: ref subpats, ellipsis: dots_pos } => { + // (p1, ..., pN) + let total_fields = self.total_fields_in_tuple(pat); + + for (i, &subpat) in subpats.iter().enumerate_and_adjust(total_fields, dots_pos) { + let subpat_ty = self.pat_ty_adjusted(subpat)?; + let projection_kind = + ProjectionKind::Field { field_idx: i as u32, variant_idx: 0 }; + let sub_place = self.cat_projection( + pat.into(), + place_with_id.clone(), + subpat_ty, + projection_kind, + ); + self.cat_pattern(sub_place, subpat, op)?; + } + } + + Pat::TupleStruct { args: ref subpats, ellipsis: dots_pos, .. 
} => { + // S(p1, ..., pN) + let (variant_index, variant) = self.variant_index_for_adt(pat)?; + let total_fields = variant.fields(self.cx.db).len(); + + for (i, &subpat) in subpats.iter().enumerate_and_adjust(total_fields, dots_pos) { + let subpat_ty = self.pat_ty_adjusted(subpat)?; + let projection_kind = + ProjectionKind::Field { variant_idx: variant_index, field_idx: i as u32 }; + let sub_place = self.cat_projection( + pat.into(), + place_with_id.clone(), + subpat_ty, + projection_kind, + ); + self.cat_pattern(sub_place, subpat, op)?; + } + } + + Pat::Record { args: ref field_pats, .. } => { + // S { f1: p1, ..., fN: pN } + + let (variant_index, variant) = self.variant_index_for_adt(pat)?; + let fields = variant.fields(self.cx.db); + + for fp in field_pats { + let field_ty = self.pat_ty_adjusted(fp.pat)?; + let field_index = fields.field(&fp.name).ok_or(ErrorGuaranteed)?; + + let field_place = self.cat_projection( + pat.into(), + place_with_id.clone(), + field_ty, + ProjectionKind::Field { + variant_idx: variant_index, + field_idx: field_index.into_raw().into_u32(), + }, + ); + self.cat_pattern(field_place, fp.pat, op)?; + } + } + + Pat::Or(ref pats) => { + for &pat in pats { + self.cat_pattern(place_with_id.clone(), pat, op)?; + } + } + + Pat::Bind { subpat: Some(subpat), .. } => { + self.cat_pattern(place_with_id, subpat, op)?; + } + + Pat::Box { inner: subpat } | Pat::Ref { pat: subpat, .. } => { + // box p1, &p1, &mut p1. we can ignore the mutability of + // PatKind::Ref since that information is already contained + // in the type. 
+ let subplace = self.cat_deref(pat.into(), place_with_id)?; + self.cat_pattern(subplace, subpat, op)?; + } + + Pat::Slice { prefix: ref before, slice, suffix: ref after } => { + let Some(element_ty) = self + .cx + .table + .structurally_resolve_type(place_with_id.place.ty()) + .builtin_index() + else { + debug!("explicit index of non-indexable type {:?}", place_with_id); + panic!("explicit index of non-indexable type"); + }; + let elt_place = self.cat_projection( + pat.into(), + place_with_id.clone(), + element_ty, + ProjectionKind::Index, + ); + for &before_pat in before { + self.cat_pattern(elt_place.clone(), before_pat, op)?; + } + if let Some(slice_pat) = slice { + let slice_pat_ty = self.pat_ty_adjusted(slice_pat)?; + let slice_place = self.cat_projection( + pat.into(), + place_with_id, + slice_pat_ty, + ProjectionKind::Subslice, + ); + self.cat_pattern(slice_place, slice_pat, op)?; + } + for &after_pat in after { + self.cat_pattern(elt_place.clone(), after_pat, op)?; + } + } + + Pat::Bind { subpat: None, .. } + | Pat::Expr(..) + | Pat::Path(_) + | Pat::Lit(..) + | Pat::ConstBlock(..) + | Pat::Range { .. } + | Pat::Missing + | Pat::Wild => { + // always ok + } + } + + Ok(()) + } + + /// Checks whether a type has multiple variants, and therefore, whether a + /// read of the discriminant might be necessary. Note that the actual MIR + /// builder code does a more specific check, filtering out variants that + /// happen to be uninhabited. + /// + /// Here, it is not practical to perform such a check, because inhabitedness + /// queries require typeck results, and typeck requires closure capture analysis. + /// + /// Moreover, the language is moving towards uninhabited variants still semantically + /// causing a discriminant read, so we *shouldn't* perform any such check. + /// + /// FIXME(never_patterns): update this comment once the aforementioned MIR builder + /// code is changed to be insensitive to inhabitedness.
+ #[instrument(skip(self), level = "debug")] + fn is_multivariant_adt(&mut self, ty: Ty<'db>) -> bool { + if let TyKind::Adt(def, _) = self.cx.table.structurally_resolve_type(ty).kind() { + // Note that if a non-exhaustive SingleVariant is defined in another crate, we need + // to assume that more cases will be added to the variant in the future. This means + // that we should handle non-exhaustive SingleVariant the same way we would handle + // a MultiVariant. + match def.def_id().0 { + AdtId::StructId(_) | AdtId::UnionId(_) => false, + AdtId::EnumId(did) => { + let has_foreign_non_exhaustive = || { + AttrFlags::query(self.cx.db, did.into()).contains(AttrFlags::NON_EXHAUSTIVE) + && did.krate(self.cx.db) != self.cx.krate() + }; + did.enum_variants(self.cx.db).variants.len() > 1 || has_foreign_non_exhaustive() + } + } + } else { + false + } + } +} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/coerce.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/coerce.rs index 47a70492487e1..732a583047494 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/coerce.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/coerce.rs @@ -35,22 +35,23 @@ //! // and are then unable to coerce `&7i32` to `&mut i32`. //! 
``` +use std::ops::ControlFlow; + use hir_def::{ - CallableDefId, + CallableDefId, TraitId, attrs::AttrFlags, hir::{ExprId, ExprOrPatId}, signatures::FunctionSignature, }; use rustc_ast_ir::Mutability; use rustc_type_ir::{ - BoundVar, DebruijnIndex, TyVid, TypeAndMut, TypeFoldable, TypeFolder, TypeSuperFoldable, - TypeVisitableExt, + BoundVar, DebruijnIndex, InferTy, TyVid, TypeAndMut, TypeFoldable, TypeFolder, + TypeSuperFoldable, TypeVisitableExt, error::TypeError, - inherent::{ - Const as _, GenericArg as _, GenericArgs as _, IntoKind, Safety as _, SliceLike, Ty as _, - }, + inherent::{Const as _, GenericArg as _, GenericArgs as _, IntoKind, Safety as _, Ty as _}, + solve::{Certainty, NoSolution}, }; -use smallvec::{SmallVec, smallvec}; +use smallvec::SmallVec; use tracing::{debug, instrument}; use crate::{ @@ -62,16 +63,16 @@ use crate::{ }, next_solver::{ Binder, BoundConst, BoundRegion, BoundRegionKind, BoundTy, BoundTyKind, CallableIdWrapper, - Canonical, ClauseKind, CoercePredicate, Const, ConstKind, DbInterner, ErrorGuaranteed, - GenericArgs, ParamEnv, PolyFnSig, PredicateKind, Region, RegionKind, TraitRef, Ty, TyKind, + Canonical, CoercePredicate, Const, ConstKind, DbInterner, ErrorGuaranteed, GenericArgs, + Goal, ParamEnv, PolyFnSig, PredicateKind, Region, RegionKind, TraitRef, Ty, TyKind, TypingMode, abi::Safety, infer::{ DbInternerInferExt, InferCtxt, InferOk, InferResult, relate::RelateResult, - select::{ImplSource, SelectionError}, - traits::{Obligation, ObligationCause, PredicateObligation, PredicateObligations}, + traits::{Obligation, ObligationCause, PredicateObligations}, }, + inspect::{InspectGoal, ProofTreeVisitor}, obligation_ctxt::ObligationCtxt, }, upvars::upvars_mentioned, @@ -85,9 +86,7 @@ trait CoerceDelegate<'db> { fn set_diverging(&mut self, diverging_ty: Ty<'db>); - fn set_tainted_by_errors(&mut self); - - fn type_var_is_sized(&mut self, var: TyVid) -> bool; + fn type_var_is_sized(&self, var: TyVid) -> bool; } struct Coerce { @@ 
-129,11 +128,6 @@ impl<'db, D> Coerce where D: CoerceDelegate<'db>, { - #[inline] - fn set_tainted_by_errors(&mut self) { - self.delegate.set_tainted_by_errors(); - } - #[inline] fn infcx(&self) -> &InferCtxt<'db> { self.delegate.infcx() @@ -680,125 +674,30 @@ where // Create an obligation for `Source: CoerceUnsized`. let cause = self.cause.clone(); - - // Use a FIFO queue for this custom fulfillment procedure. - // - // A Vec (or SmallVec) is not a natural choice for a queue. However, - // this code path is hot, and this queue usually has a max length of 1 - // and almost never more than 3. By using a SmallVec we avoid an - // allocation, at the (very small) cost of (occasionally) having to - // shift subsequent elements down when removing the front element. - let mut queue: SmallVec<[PredicateObligation<'db>; 4]> = smallvec![Obligation::new( + let pred = TraitRef::new( self.interner(), - cause, - self.param_env(), - TraitRef::new( - self.interner(), - coerce_unsized_did.into(), - [coerce_source, coerce_target] + coerce_unsized_did.into(), + [coerce_source, coerce_target], + ); + let obligation = Obligation::new(self.interner(), cause, self.param_env(), pred); + + coercion.obligations.push(obligation); + + if self + .delegate + .infcx() + .visit_proof_tree( + Goal::new(self.infcx().interner, self.param_env(), pred), + &mut CoerceVisitor { + delegate: &self.delegate, + errored: false, + unsize_did, + coerce_unsized_did, + }, ) - )]; - // Keep resolving `CoerceUnsized` and `Unsize` predicates to avoid - // emitting a coercion in cases like `Foo<$1>` -> `Foo<$2>`, where - // inference might unify those two inner type variables later. 
- let traits = [coerce_unsized_did, unsize_did]; - while !queue.is_empty() { - let obligation = queue.remove(0); - let trait_pred = match obligation.predicate.kind().no_bound_vars() { - Some(PredicateKind::Clause(ClauseKind::Trait(trait_pred))) - if traits.contains(&trait_pred.def_id().0) => - { - self.infcx().resolve_vars_if_possible(trait_pred) - } - // Eagerly process alias-relate obligations in new trait solver, - // since these can be emitted in the process of solving trait goals, - // but we need to constrain vars before processing goals mentioning - // them. - Some(PredicateKind::AliasRelate(..)) => { - let mut ocx = ObligationCtxt::new(self.infcx()); - ocx.register_obligation(obligation); - if !ocx.try_evaluate_obligations().is_empty() { - return Err(TypeError::Mismatch); - } - coercion.obligations.extend(ocx.into_pending_obligations()); - continue; - } - _ => { - coercion.obligations.push(obligation); - continue; - } - }; - debug!("coerce_unsized resolve step: {:?}", trait_pred); - match self.infcx().select(&obligation.with(self.interner(), trait_pred)) { - // Uncertain or unimplemented. - Ok(None) => { - if trait_pred.def_id().0 == unsize_did { - let self_ty = trait_pred.self_ty(); - let unsize_ty = trait_pred.trait_ref.args[1].expect_ty(); - debug!("coerce_unsized: ambiguous unsize case for {:?}", trait_pred); - match (self_ty.kind(), unsize_ty.kind()) { - (TyKind::Infer(rustc_type_ir::TyVar(v)), TyKind::Dynamic(..)) - if self.delegate.type_var_is_sized(v) => - { - debug!("coerce_unsized: have sized infer {:?}", v); - coercion.obligations.push(obligation); - // `$0: Unsize` where we know that `$0: Sized`, try going - // for unsizing. - } - _ => { - // Some other case for `$0: Unsize`. Note that we - // hit this case even if `Something` is a sized type, so just - // don't do the coercion. 
- debug!("coerce_unsized: ambiguous unsize"); - return Err(TypeError::Mismatch); - } - } - } else { - debug!("coerce_unsized: early return - ambiguous"); - if !coerce_source.references_non_lt_error() - && !coerce_target.references_non_lt_error() - { - // rustc always early-returns here, even when the types contains errors. However not bailing - // improves error recovery, and while we don't implement generic consts properly, it also helps - // correct code. - return Err(TypeError::Mismatch); - } - } - } - Err(SelectionError::Unimplemented) => { - debug!("coerce_unsized: early return - can't prove obligation"); - return Err(TypeError::Mismatch); - } - - Err(SelectionError::TraitDynIncompatible(_)) => { - // Dyn compatibility errors in coercion will *always* be due to the - // fact that the RHS of the coercion is a non-dyn compatible `dyn Trait` - // written in source somewhere (otherwise we will never have lowered - // the dyn trait from HIR to middle). - // - // There's no reason to emit yet another dyn compatibility error, - // especially since the span will differ slightly and thus not be - // deduplicated at all! - self.set_tainted_by_errors(); - } - Err(_err) => { - // FIXME: Report an error: - // let guar = self.err_ctxt().report_selection_error( - // obligation.clone(), - // &obligation, - // &err, - // ); - self.set_tainted_by_errors(); - // Treat this like an obligation and follow through - // with the unsizing - the lack of a coercion should - // be silent, as it causes a type mismatch later. 
- } - - Ok(Some(ImplSource::UserDefined(impl_source))) => { - queue.extend(impl_source.nested); - } - Ok(Some(impl_source)) => queue.extend(impl_source.nested_obligations()), - } + .is_break() + { + return Err(TypeError::Mismatch); } Ok(coercion) @@ -983,12 +882,7 @@ impl<'db> CoerceDelegate<'db> for InferenceCoercionDelegate<'_, '_, 'db> { } #[inline] - fn set_tainted_by_errors(&mut self) { - self.0.set_tainted_by_errors(); - } - - #[inline] - fn type_var_is_sized(&mut self, var: TyVid) -> bool { + fn type_var_is_sized(&self, var: TyVid) -> bool { self.0.table.type_var_is_sized(var) } } @@ -1560,8 +1454,7 @@ impl<'db> CoerceDelegate<'db> for HirCoercionDelegate<'_, 'db> { (self.target_features, TargetFeatureIsSafeInTarget::No) } fn set_diverging(&mut self, _diverging_ty: Ty<'db>) {} - fn set_tainted_by_errors(&mut self) {} - fn type_var_is_sized(&mut self, _var: TyVid) -> bool { + fn type_var_is_sized(&self, _var: TyVid) -> bool { false } } @@ -1663,7 +1556,7 @@ fn coerce<'db>( Const::new_bound( self.interner, self.debruijn, - BoundConst { var: BoundVar::from_usize(i) }, + BoundConst::new(BoundVar::from_usize(i)), ) }, ) @@ -1718,9 +1611,84 @@ fn coerce<'db>( fn is_capturing_closure(db: &dyn HirDatabase, closure: InternedClosureId) -> bool { let InternedClosure(owner, expr) = closure.loc(db); - let Some(body_owner) = owner.as_def_with_body() else { - return false; - }; - upvars_mentioned(db, body_owner) + upvars_mentioned(db, owner) .is_some_and(|upvars| upvars.get(&expr).is_some_and(|upvars| !upvars.is_empty())) } + +/// Recursively visit goals to decide whether an unsizing is possible. +/// `Break`s when it isn't, and an error should be raised. +/// `Continue`s when an unsizing is ok based on an implementation of the `Unsize` trait / lang item. +struct CoerceVisitor<'a, D> { + delegate: &'a D, + /// Whether the coercion is impossible. If so we sometimes still try to + /// coerce in these cases to emit better errors.
This changes the behavior + /// when hitting the recursion limit. + errored: bool, + unsize_did: TraitId, + coerce_unsized_did: TraitId, +} + +impl<'a, 'db, D: CoerceDelegate<'db>> ProofTreeVisitor<'db> for CoerceVisitor<'a, D> { + type Result = ControlFlow<()>; + + fn visit_goal(&mut self, goal: &InspectGoal<'_, 'db>) -> Self::Result { + let Some(pred) = goal.goal().predicate.as_trait_clause() else { + return ControlFlow::Continue(()); + }; + + // Make sure this predicate is referring to either an `Unsize` or `CoerceUnsized` trait, + // Otherwise there's nothing to do. + let def_id = pred.def_id().0; + if def_id != self.unsize_did && def_id != self.coerce_unsized_did { + return ControlFlow::Continue(()); + } + + match goal.result() { + // If we prove the `Unsize` or `CoerceUnsized` goal, continue recursing. + Ok(Certainty::Yes) => ControlFlow::Continue(()), + Err(NoSolution) => { + self.errored = true; + // Even if we find no solution, continue recursing if we find a single candidate + // for which we're shallowly certain it holds to get the right error source. + if let [only_candidate] = &goal.candidates()[..] + && only_candidate.shallow_certainty() == Certainty::Yes + { + only_candidate.visit_nested_no_probe(self) + } else { + ControlFlow::Break(()) + } + } + Ok(Certainty::Maybe { .. }) => { + // FIXME: structurally normalize? + if def_id == self.unsize_did + && let TyKind::Dynamic(..) = pred.skip_binder().trait_ref.args.type_at(1).kind() + && let TyKind::Infer(InferTy::TyVar(vid)) = pred.self_ty().skip_binder().kind() + && self.delegate.type_var_is_sized(vid) + { + // We get here when trying to unsize a type variable to a `dyn Trait`, + // knowing that that variable is sized. Unsizing definitely has to happen in that case. + // If the variable weren't sized, we may not need an unsizing coercion. + // In general, we don't want to add coercions too eagerly since it makes error messages much worse. 
+ ControlFlow::Continue(()) + } else if let Some(cand) = goal.unique_applicable_candidate() + && cand.shallow_certainty() == Certainty::Yes + { + cand.visit_nested_no_probe(self) + } else { + ControlFlow::Break(()) + } + } + } + } + + fn on_recursion_limit(&mut self) -> Self::Result { + if self.errored { + // This prevents accidentally committing unfulfilled unsized coercions while trying to + // find the error source for diagnostics. + // See https://github.com/rust-lang/trait-system-refactor-initiative/issues/266. + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } +} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs index ee34a30ebaaf0..d80ea71674775 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/expr.rs @@ -11,23 +11,20 @@ use hir_def::{ InlineAsmKind, LabelId, Literal, Pat, PatId, RecordSpread, Statement, UnaryOp, }, resolver::ValueNs, - signatures::{FunctionSignature, VariantFields}, + signatures::VariantFields, }; use hir_def::{FunctionId, hir::ClosureKind}; use hir_expand::name::Name; use rustc_ast_ir::Mutability; use rustc_type_ir::{ - CoroutineArgs, CoroutineArgsParts, InferTy, Interner, + InferTy, Interner, inherent::{AdtDef, GenericArgs as _, IntoKind, Ty as _}, }; use syntax::ast::RangeOp; use tracing::debug; use crate::{ - Adjust, Adjustment, CallableDefId, DeclContext, DeclOrigin, Rawness, - autoderef::InferenceContextAutoderef, - consteval, - db::InternedCoroutine, + Adjust, Adjustment, CallableDefId, DeclContext, DeclOrigin, Rawness, consteval, generics::generics, infer::{ AllowTwoPhase, BreakableKind, coerce::CoerceMany, find_continuable, @@ -36,7 +33,7 @@ use crate::{ lower::{GenericPredicates, lower_mutability}, method_resolution::{self, CandidateId, MethodCallee, MethodError}, next_solver::{ - ErrorGuaranteed, FnSig, GenericArg, GenericArgs, TraitRef, Ty, TyKind, TypeError, 
+ ClauseKind, FnSig, GenericArg, GenericArgs, TraitRef, Ty, TyKind, TypeError, infer::{ BoundRegionConversionTime, InferOk, traits::{Obligation, ObligationCause}, @@ -44,7 +41,6 @@ use crate::{ obligation_ctxt::ObligationCtxt, util::clauses_as_obligations, }, - traits::FnTrait, }; use super::{ @@ -244,7 +240,6 @@ impl<'db> InferenceContext<'_, 'db> { | Expr::Assignment { .. } | Expr::Yield { .. } | Expr::Cast { .. } - | Expr::Async { .. } | Expr::Unsafe { .. } | Expr::Await { .. } | Expr::Ref { .. } @@ -390,9 +385,6 @@ impl<'db> InferenceContext<'_, 'db> { }) .1 } - Expr::Async { id: _, statements, tail } => { - self.infer_async_block(tgt_expr, statements, tail) - } &Expr::Loop { body, label } => { // FIXME: should be: // let ty = expected.coercion_target_type(&mut self.table); @@ -1184,134 +1176,6 @@ impl<'db> InferenceContext<'_, 'db> { } oprnd_t } - - fn infer_async_block( - &mut self, - tgt_expr: ExprId, - statements: &[Statement], - tail: &Option, - ) -> Ty<'db> { - let ret_ty = self.table.next_ty_var(); - let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); - let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty); - let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(ret_ty)); - - // FIXME: We should handle async blocks like we handle closures - let expected = &Expectation::has_type(ret_ty); - let (_, inner_ty) = self.with_breakable_ctx(BreakableKind::Border, None, None, |this| { - let ty = this.infer_block(tgt_expr, statements, *tail, None, expected); - if let Some(target) = expected.only_has_type(&mut this.table) { - match this.coerce(tgt_expr.into(), ty, target, AllowTwoPhase::No, ExprIsRead::Yes) { - Ok(res) => res, - Err(_) => { - this.result.type_mismatches.get_or_insert_default().insert( - tgt_expr.into(), - TypeMismatch { expected: target.store(), actual: ty.store() }, - ); - target - } - } - } else { - ty - } - }); - - self.diverges = prev_diverges; - self.return_ty = prev_ret_ty; - self.return_coercion = 
prev_ret_coercion; - - self.lower_async_block_type_impl_trait(inner_ty, tgt_expr) - } - - pub(crate) fn lower_async_block_type_impl_trait( - &mut self, - inner_ty: Ty<'db>, - tgt_expr: ExprId, - ) -> Ty<'db> { - let coroutine_id = InternedCoroutine(self.owner, tgt_expr); - let coroutine_id = self.db.intern_coroutine(coroutine_id).into(); - let parent_args = GenericArgs::identity_for_item(self.interner(), self.generic_def.into()); - Ty::new_coroutine( - self.interner(), - coroutine_id, - CoroutineArgs::new( - self.interner(), - CoroutineArgsParts { - parent_args: parent_args.as_slice(), - kind_ty: self.types.types.unit, - // rustc uses a special lang item type for the resume ty. I don't believe this can cause us problems. - resume_ty: self.types.types.unit, - yield_ty: self.types.types.unit, - return_ty: inner_ty, - // FIXME: Infer upvars. - tupled_upvars_ty: self.types.types.unit, - }, - ) - .args, - ) - } - - pub(crate) fn write_fn_trait_method_resolution( - &mut self, - fn_x: FnTrait, - derefed_callee: Ty<'db>, - adjustments: &mut Vec, - callee_ty: Ty<'db>, - params: &[Ty<'db>], - tgt_expr: ExprId, - ) { - match fn_x { - FnTrait::FnOnce | FnTrait::AsyncFnOnce => (), - FnTrait::FnMut | FnTrait::AsyncFnMut => { - if let TyKind::Ref(lt, inner, Mutability::Mut) = derefed_callee.kind() { - if adjustments - .last() - .map(|it| matches!(it.kind, Adjust::Borrow(_))) - .unwrap_or(true) - { - // prefer reborrow to move - adjustments - .push(Adjustment { kind: Adjust::Deref(None), target: inner.store() }); - adjustments.push(Adjustment::borrow( - self.interner(), - Mutability::Mut, - inner, - lt, - )) - } - } else { - adjustments.push(Adjustment::borrow( - self.interner(), - Mutability::Mut, - derefed_callee, - self.table.next_region_var(), - )); - } - } - FnTrait::Fn | FnTrait::AsyncFn => { - if !matches!(derefed_callee.kind(), TyKind::Ref(_, _, Mutability::Not)) { - adjustments.push(Adjustment::borrow( - self.interner(), - Mutability::Not, - derefed_callee, - 
self.table.next_region_var(), - )); - } - } - } - let Some(trait_) = fn_x.get_id(self.lang_items) else { - return; - }; - let trait_data = trait_.trait_items(self.db); - if let Some(func) = trait_data.method_by_name(&fn_x.method_name()) { - let subst = GenericArgs::new_from_slice(&[ - callee_ty.into(), - Ty::new_tup(self.interner(), params).into(), - ]); - self.write_method_resolution(tgt_expr, func, subst); - } - } - fn infer_expr_array(&mut self, array: &Array, expected: &Expectation<'db>) -> Ty<'db> { let elem_ty = match expected .to_option(&mut self.table) @@ -1728,76 +1592,6 @@ impl<'db> InferenceContext<'_, 'db> { MethodCallee { def_id, args, sig } } - fn infer_call( - &mut self, - tgt_expr: ExprId, - callee: ExprId, - args: &[ExprId], - expected: &Expectation<'db>, - ) -> Ty<'db> { - let callee_ty = self.infer_expr(callee, &Expectation::none(), ExprIsRead::Yes); - let callee_ty = self.table.try_structurally_resolve_type(callee_ty); - let interner = self.interner(); - let mut derefs = InferenceContextAutoderef::new_from_inference_context(self, callee_ty); - let (res, derefed_callee) = loop { - let Some((callee_deref_ty, _)) = derefs.next() else { - break (None, callee_ty); - }; - if let Some(res) = derefs.ctx().table.callable_sig(callee_deref_ty, args.len()) { - break (Some(res), callee_deref_ty); - } - }; - // if the function is unresolved, we use is_varargs=true to - // suppress the arg count diagnostic here - let is_varargs = derefed_callee.callable_sig(interner).is_some_and(|sig| sig.c_variadic()) - || res.is_none(); - let (param_tys, ret_ty) = match res { - Some((func, params, ret_ty)) => { - let infer_ok = derefs.adjust_steps_as_infer_ok(); - let mut adjustments = self.table.register_infer_ok(infer_ok); - if let Some(fn_x) = func { - self.write_fn_trait_method_resolution( - fn_x, - derefed_callee, - &mut adjustments, - callee_ty, - ¶ms, - tgt_expr, - ); - } - if let TyKind::Closure(c, _) = self.table.resolve_completely(callee_ty).kind() { - 
self.add_current_closure_dependency(c.into()); - self.deferred_closures.entry(c.into()).or_default().push(( - derefed_callee, - callee_ty, - params.clone(), - tgt_expr, - )); - } - self.write_expr_adj(callee, adjustments.into_boxed_slice()); - (params, ret_ty) - } - None => { - self.push_diagnostic(InferenceDiagnostic::ExpectedFunction { - call_expr: tgt_expr, - found: callee_ty.store(), - }); - (Vec::new(), Ty::new_error(interner, ErrorGuaranteed)) - } - }; - let indices_to_skip = self.check_legacy_const_generics(derefed_callee, args); - self.check_call( - tgt_expr, - args, - callee_ty, - ¶m_tys, - ret_ty, - &indices_to_skip, - is_varargs, - expected, - ) - } - fn check_call( &mut self, tgt_expr: ExprId, @@ -1819,6 +1613,7 @@ impl<'db> InferenceContext<'_, 'db> { args, indices_to_skip, is_varargs, + TupleArgumentsFlag::DontTupleArguments, ); ret_ty } @@ -1949,13 +1744,22 @@ impl<'db> InferenceContext<'_, 'db> { }; let ret_ty = sig.output(); - self.check_call_arguments(tgt_expr, param_tys, ret_ty, expected, args, &[], sig.c_variadic); + self.check_call_arguments( + tgt_expr, + param_tys, + ret_ty, + expected, + args, + &[], + sig.c_variadic, + TupleArgumentsFlag::DontTupleArguments, + ); ret_ty } /// Generic function that factors out common logic from function calls, /// method calls and overloaded operators. 
- pub(in super::super) fn check_call_arguments( + pub(super) fn check_call_arguments( &mut self, call_expr: ExprId, // Types (as defined in the *signature* of the target function) @@ -1968,7 +1772,18 @@ impl<'db> InferenceContext<'_, 'db> { skip_indices: &[u32], // Whether the function is variadic, for example when imported from C c_variadic: bool, + // Whether the arguments have been bundled in a tuple (ex: closures) + tuple_arguments: TupleArgumentsFlag, ) { + let formal_input_tys: Vec<_> = formal_input_tys + .iter() + .map(|&ty| { + let generalized_ty = self.table.next_ty_var(); + let _ = self.demand_eqtype(call_expr.into(), ty, generalized_ty); + generalized_ty + }) + .collect(); + // First, let's unify the formal method signature with the expectation eagerly. // We use this to guide coercion inference; it's output is "fudged" which means // any remaining type variables are assigned to new, unrelated variables. This @@ -1988,29 +1803,68 @@ impl<'db> InferenceContext<'_, 'db> { // No argument expectations are produced if unification fails. let origin = ObligationCause::new(); ocx.sup(&origin, self.table.param_env, expected_output, formal_output)?; + + for &ty in &formal_input_tys { + ocx.register_obligation(Obligation::new( + self.interner(), + ObligationCause::new(), + self.table.param_env, + ClauseKind::WellFormed(ty.into()), + )); + } + if !ocx.try_evaluate_obligations().is_empty() { return Err(TypeError::Mismatch); } // Record all the argument types, with the args // produced from the above subtyping unification. 
- Ok(Some( - formal_input_tys - .iter() - .map(|&ty| self.table.infer_ctxt.resolve_vars_if_possible(ty)) - .collect(), - )) + Ok(Some(formal_input_tys.clone())) }) .ok() }) .unwrap_or_default(); + // If the arguments should be wrapped in a tuple (ex: closures), unwrap them here + let (formal_input_tys, expected_input_tys) = + if tuple_arguments == TupleArgumentsFlag::TupleArguments { + let tuple_type = self.table.structurally_resolve_type(formal_input_tys[0]); + match tuple_type.kind() { + // We expected a tuple and got a tuple + TyKind::Tuple(arg_types) => { + // Argument length differs + if arg_types.len() != provided_args.len() { + // FIXME: Emit an error. + } + let expected_input_tys = match expected_input_tys { + Some(expected_input_tys) => match expected_input_tys.first() { + Some(ty) => match ty.kind() { + TyKind::Tuple(tys) => Some(tys.iter().collect()), + _ => None, + }, + None => None, + }, + None => None, + }; + (arg_types.iter().collect(), expected_input_tys) + } + _ => { + // Otherwise, there's a mismatch, so clear out what we're expecting, and set + // our input types to err_args so we don't blow up the error messages + // FIXME: Emit an error. + (vec![self.types.types.error; provided_args.len()], None) + } + } + } else { + (formal_input_tys.to_vec(), expected_input_tys) + }; + // If there are no external expectations at the call site, just use the types from the function defn - let expected_input_tys = if let Some(expected_input_tys) = &expected_input_tys { + let expected_input_tys = if let Some(expected_input_tys) = expected_input_tys { assert_eq!(expected_input_tys.len(), formal_input_tys.len()); expected_input_tys } else { - formal_input_tys + formal_input_tys.clone() }; let minimum_input_count = expected_input_tys.len(); @@ -2183,51 +2037,6 @@ impl<'db> InferenceContext<'_, 'db> { } } - /// Returns the argument indices to skip. 
- fn check_legacy_const_generics(&mut self, callee: Ty<'db>, args: &[ExprId]) -> Box<[u32]> { - let (func, _subst) = match callee.kind() { - TyKind::FnDef(callable, subst) => { - let func = match callable.0 { - CallableDefId::FunctionId(f) => f, - _ => return Default::default(), - }; - (func, subst) - } - _ => return Default::default(), - }; - - let data = FunctionSignature::of(self.db, func); - let Some(legacy_const_generics_indices) = data.legacy_const_generics_indices(self.db, func) - else { - return Default::default(); - }; - let mut legacy_const_generics_indices = Box::<[u32]>::from(legacy_const_generics_indices); - - // only use legacy const generics if the param count matches with them - if data.params.len() + legacy_const_generics_indices.len() != args.len() { - if args.len() <= data.params.len() { - return Default::default(); - } else { - // there are more parameters than there should be without legacy - // const params; use them - legacy_const_generics_indices.sort_unstable(); - return legacy_const_generics_indices; - } - } - - // check legacy const parameters - for arg_idx in legacy_const_generics_indices.iter().copied() { - if arg_idx >= args.len() as u32 { - continue; - } - let expected = Expectation::none(); // FIXME use actual const ty, when that is lowered correctly - self.infer_expr(args[arg_idx as usize], &expected, ExprIsRead::Yes); - // FIXME: evaluate and unify with the const - } - legacy_const_generics_indices.sort_unstable(); - legacy_const_generics_indices - } - pub(super) fn with_breakable_ctx( &mut self, kind: BreakableKind, @@ -2243,3 +2052,28 @@ impl<'db> InferenceContext<'_, 'db> { (if ctx.may_break { ctx.coerce.map(|ctx| ctx.complete(self)) } else { None }, res) } } + +/// Controls whether the arguments are tupled. This is used for the call +/// operator. +/// +/// Tupling means that all call-side arguments are packed into a tuple and +/// passed as a single parameter. 
For example, if tupling is enabled, this +/// function: +/// ``` +/// fn f(x: (isize, isize)) {} +/// ``` +/// Can be called as: +/// ```ignore UNSOLVED (can this be done in user code?) +/// # fn f(x: (isize, isize)) {} +/// f(1, 2); +/// ``` +/// Instead of: +/// ``` +/// # fn f(x: (isize, isize)) {} +/// f((1, 2)); +/// ``` +#[derive(Copy, Clone, Eq, PartialEq)] +pub(super) enum TupleArgumentsFlag { + DontTupleArguments, + TupleArguments, +} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/mutability.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/mutability.rs index bfe43fc92827d..b2369f6a87e83 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/mutability.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/mutability.rs @@ -38,7 +38,7 @@ impl<'db> InferenceContext<'_, 'db> { ) { self.table.register_predicates(infer_ok.obligations); } - *d = OverloadedDeref(Some(mutability)); + *d = OverloadedDeref(mutability); } } Adjust::Borrow(b) => match b { @@ -86,7 +86,6 @@ impl<'db> InferenceContext<'_, 'db> { } Expr::Let { pat, expr } => self.infer_mut_expr(*expr, self.pat_bound_mutability(*pat)), Expr::Block { id: _, statements, tail, label: _ } - | Expr::Async { id: _, statements, tail } | Expr::Unsafe { id: _, statements, tail } => { for st in statements.iter() { match st { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/infer/unify.rs b/src/tools/rust-analyzer/crates/hir-ty/src/infer/unify.rs index d093412b42107..b0f916b8c0763 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/infer/unify.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/infer/unify.rs @@ -3,9 +3,7 @@ use std::fmt; use base_db::Crate; -use hir_def::{AdtId, ExpressionStoreOwnerId, GenericParamId}; -use hir_expand::name::Name; -use intern::sym; +use hir_def::{AdtId, ExpressionStoreOwnerId, GenericParamId, TraitId}; use rustc_hash::FxHashSet; use rustc_type_ir::{ TyVid, TypeFoldable, TypeVisitableExt, UpcastFrom, @@ -17,9 +15,9 @@ use smallvec::SmallVec; use 
crate::{ db::HirDatabase, next_solver::{ - AliasTy, Canonical, ClauseKind, Const, DbInterner, ErrorGuaranteed, GenericArg, - GenericArgs, Goal, ParamEnv, Predicate, PredicateKind, Region, SolverDefId, Term, TraitRef, - Ty, TyKind, TypingMode, + Canonical, ClauseKind, Const, DbInterner, ErrorGuaranteed, GenericArg, GenericArgs, Goal, + ParamEnv, Predicate, PredicateKind, Region, SolverDefId, Term, TraitRef, Ty, TyKind, + TypingMode, fulfill::{FulfillmentCtxt, NextSolverError}, infer::{ DbInternerInferExt, InferCtxt, InferOk, InferResult, @@ -31,7 +29,7 @@ use crate::{ obligation_ctxt::ObligationCtxt, }, traits::{ - FnTrait, NextTraitSolveResult, ParamEnvAndCrate, next_trait_solve_canonical_in_ctxt, + NextTraitSolveResult, ParamEnvAndCrate, next_trait_solve_canonical_in_ctxt, next_trait_solve_in_ctxt, }, }; @@ -174,6 +172,10 @@ impl<'db> InferenceTable<'db> { self.infer_ctxt.type_is_copy_modulo_regions(self.param_env, ty) } + pub(crate) fn type_is_use_cloned_modulo_regions(&self, ty: Ty<'db>) -> bool { + self.infer_ctxt.type_is_use_cloned_modulo_regions(self.param_env, ty) + } + pub(crate) fn type_var_is_sized(&self, self_ty: TyVid) -> bool { let Some(sized_did) = self.interner().lang_items().Sized else { return true; @@ -360,9 +362,6 @@ impl<'db> InferenceTable<'db> { /// in this case. pub(crate) fn try_structurally_resolve_type(&mut self, ty: Ty<'db>) -> Ty<'db> { if let TyKind::Alias(..) = ty.kind() { - // We need to use a separate variable here as otherwise the temporary for - // `self.fulfillment_cx.borrow_mut()` is alive in the `Err` branch, resulting - // in a reentrant borrow, causing an ICE. 
let result = self .infer_ctxt .at(&ObligationCause::misc(), self.param_env) @@ -445,6 +444,18 @@ impl<'db> InferenceTable<'db> { } } + pub(crate) fn register_bound(&mut self, ty: Ty<'db>, def_id: TraitId, cause: ObligationCause) { + if !ty.references_non_lt_error() { + let trait_ref = TraitRef::new(self.interner(), def_id.into(), [ty]); + self.register_predicate(Obligation::new( + self.interner(), + cause, + self.param_env, + trait_ref, + )); + } + } + pub(crate) fn register_infer_ok(&mut self, infer_ok: InferOk<'db, T>) -> T { let InferOk { value, obligations } = infer_ok; self.register_predicates(obligations); @@ -489,78 +500,6 @@ impl<'db> InferenceTable<'db> { } } - pub(crate) fn callable_sig( - &mut self, - ty: Ty<'db>, - num_args: usize, - ) -> Option<(Option, Vec>, Ty<'db>)> { - match ty.callable_sig(self.interner()) { - Some(sig) => { - let sig = sig.skip_binder(); - Some((None, sig.inputs_and_output.inputs().to_vec(), sig.output())) - } - None => { - let (f, args_ty, return_ty) = self.callable_sig_from_fn_trait(ty, num_args)?; - Some((Some(f), args_ty, return_ty)) - } - } - } - - fn callable_sig_from_fn_trait( - &mut self, - ty: Ty<'db>, - num_args: usize, - ) -> Option<(FnTrait, Vec>, Ty<'db>)> { - let lang_items = self.interner().lang_items(); - for (fn_trait_name, output_assoc_name, subtraits) in [ - (FnTrait::FnOnce, sym::Output, &[FnTrait::Fn, FnTrait::FnMut][..]), - (FnTrait::AsyncFnMut, sym::CallRefFuture, &[FnTrait::AsyncFn]), - (FnTrait::AsyncFnOnce, sym::CallOnceFuture, &[]), - ] { - let fn_trait = fn_trait_name.get_id(lang_items)?; - let trait_data = fn_trait.trait_items(self.db); - let output_assoc_type = - trait_data.associated_type_by_name(&Name::new_symbol_root(output_assoc_name))?; - - let mut arg_tys = Vec::with_capacity(num_args); - let arg_ty = Ty::new_tup_from_iter( - self.interner(), - std::iter::repeat_with(|| { - let ty = self.next_ty_var(); - arg_tys.push(ty); - ty - }) - .take(num_args), - ); - let args = 
GenericArgs::new_from_slice(&[ty.into(), arg_ty.into()]); - let trait_ref = TraitRef::new_from_args(self.interner(), fn_trait.into(), args); - - let proj_args = self.infer_ctxt.fill_rest_fresh_args(output_assoc_type.into(), args); - let projection = Ty::new_alias( - self.interner(), - rustc_type_ir::AliasTyKind::Projection, - AliasTy::new_from_args(self.interner(), output_assoc_type.into(), proj_args), - ); - - let pred = Predicate::upcast_from(trait_ref, self.interner()); - if !self.try_obligation(pred).no_solution() { - self.register_obligation(pred); - let return_ty = self.normalize_alias_ty(projection); - for &fn_x in subtraits { - let fn_x_trait = fn_x.get_id(lang_items)?; - let trait_ref = - TraitRef::new_from_args(self.interner(), fn_x_trait.into(), args); - let pred = Predicate::upcast_from(trait_ref, self.interner()); - if !self.try_obligation(pred).no_solution() { - return Some((fn_x, arg_tys, return_ty)); - } - } - return Some((fn_trait_name, arg_tys, return_ty)); - } - } - None - } - pub(super) fn insert_type_vars(&mut self, ty: T) -> T where T: TypeFoldable>, diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs b/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs index 54332122d0e40..798c62c192405 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/layout.rs @@ -21,7 +21,7 @@ use rustc_type_ir::{ use triomphe::Arc; use crate::{ - InferenceResult, ParamEnvAndCrate, + ParamEnvAndCrate, consteval::try_const_usize, db::HirDatabase, next_solver::{ @@ -331,25 +331,18 @@ pub fn layout_of_ty_query( ptr.valid_range_mut().start = 1; Layout::scalar(dl, ptr) } - TyKind::Closure(id, args) => { - let def = db.lookup_intern_closure(id.0); - let infer = InferenceResult::of(db, def.0); - let (captures, _) = infer.closure_info(id.0); - let fields = captures - .iter() - .map(|it| { - let ty = it.ty.get().instantiate(interner, args.as_closure().parent_args()); - db.layout_of_ty(ty.store(), 
trait_env.clone()) - }) - .collect::, _>>()?; - let fields = fields.iter().map(|it| &**it).collect::>(); - let fields = fields.iter().collect::>(); - cx.calc.univariant(&fields, &ReprOptions::default(), StructKind::AlwaysSized)? + TyKind::Closure(_, args) => { + return db.layout_of_ty(args.as_closure().tupled_upvars_ty().store(), trait_env); + } + TyKind::Coroutine(_, args) => { + return db.layout_of_ty(args.as_coroutine().tupled_upvars_ty().store(), trait_env); + } + TyKind::CoroutineClosure(_, args) => { + return db + .layout_of_ty(args.as_coroutine_closure().tupled_upvars_ty().store(), trait_env); } - TyKind::Coroutine(_, _) - | TyKind::CoroutineWitness(_, _) - | TyKind::CoroutineClosure(_, _) => { + TyKind::CoroutineWitness(_, _) => { return Err(LayoutError::NotImplemented); } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/layout/target.rs b/src/tools/rust-analyzer/crates/hir-ty/src/layout/target.rs index b0986c423b116..1752b56b0f6d9 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/layout/target.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/layout/target.rs @@ -2,7 +2,7 @@ use base_db::{Crate, target::TargetLoadError}; use hir_def::layout::TargetDataLayout; -use rustc_abi::{AddressSpace, AlignFromBytesError, TargetDataLayoutErrors}; +use rustc_abi::{AddressSpace, AlignFromBytesError, TargetDataLayoutError}; use triomphe::Arc; use crate::db::HirDatabase; @@ -16,30 +16,29 @@ pub fn target_data_layout_query( Ok(it) => Ok(Arc::new(it)), Err(e) => { Err(match e { - TargetDataLayoutErrors::InvalidAddressSpace { addr_space, cause, err } => { + TargetDataLayoutError::InvalidAddressSpace { addr_space, cause, err } => { format!( r#"invalid address space `{addr_space}` for `{cause}` in "data-layout": {err}"# ) } - TargetDataLayoutErrors::InvalidBits { kind, bit, cause, err } => format!(r#"invalid {kind} `{bit}` for `{cause}` in "data-layout": {err}"#), - TargetDataLayoutErrors::MissingAlignment { cause } => format!(r#"missing alignment for `{cause}` in 
"data-layout""#), - TargetDataLayoutErrors::InvalidAlignment { cause, err } => format!( - r#"invalid alignment for `{cause}` in "data-layout": `{align}` is {err_kind}"#, - align = err.align(), - err_kind = match err { - AlignFromBytesError::NotPowerOfTwo(_) => "not a power of two", - AlignFromBytesError::TooLarge(_) => "too large", - } - ), - TargetDataLayoutErrors::InconsistentTargetArchitecture { dl, target } => { + TargetDataLayoutError::InvalidBits { kind, bit, cause, err } => format!(r#"invalid {kind} `{bit}` for `{cause}` in "data-layout": {err}"#), + TargetDataLayoutError::MissingAlignment { cause } => format!(r#"missing alignment for `{cause}` in "data-layout""#), + TargetDataLayoutError::InvalidAlignment { cause, err } => { + let (align, err_kind) = match err { + AlignFromBytesError::NotPowerOfTwo(align) => (align, "not a power of two"), + AlignFromBytesError::TooLarge(align) => (align, "too large"), + }; + format!(r#"invalid alignment for `{cause}` in "data-layout": `{align}` is {err_kind}"#) + }, + TargetDataLayoutError::InconsistentTargetArchitecture { dl, target } => { format!(r#"inconsistent target specification: "data-layout" claims architecture is {dl}-endian, while "target-endian" is `{target}`"#) } - TargetDataLayoutErrors::InconsistentTargetPointerWidth { + TargetDataLayoutError::InconsistentTargetPointerWidth { pointer_size, target, } => format!(r#"inconsistent target specification: "data-layout" claims pointers are {pointer_size}-bit, while "target-pointer-width" is `{target}`"#), - TargetDataLayoutErrors::InvalidBitsSize { err } => err, - TargetDataLayoutErrors::UnknownPointerSpecification { err } => format!(r#"use of unknown pointer specifier in "data-layout": {err}"#), + TargetDataLayoutError::InvalidBitsSize { err } => err, + TargetDataLayoutError::UnknownPointerSpecification { err } => format!(r#"use of unknown pointer specifier in "data-layout": {err}"#), }.into()) } }, diff --git 
a/src/tools/rust-analyzer/crates/hir-ty/src/layout/tests/closure.rs b/src/tools/rust-analyzer/crates/hir-ty/src/layout/tests/closure.rs index 9e761aa98ff80..d214b708655a8 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/layout/tests/closure.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/layout/tests/closure.rs @@ -125,6 +125,8 @@ fn capture_specific_fields2() { #[test] fn capture_specific_fields() { size_and_align_expr! { + minicore: fn; + stmts: [] struct X(i64, i32, (u8, i128)); let y: X = X(2, 5, (7, 3)); move |x: i64| { @@ -132,6 +134,8 @@ fn capture_specific_fields() { } } size_and_align_expr! { + minicore: fn; + stmts: [] struct X(i64, i32, (u8, i128)); let y: X = X(2, 5, (7, 3)); move |x: i64| { @@ -140,7 +144,7 @@ fn capture_specific_fields() { } } size_and_align_expr! { - minicore: copy; + minicore: fn, copy; stmts: [ struct X(i64, i32, (u8, i128)); let y: X = X(2, 5, (7, 3)); @@ -151,6 +155,8 @@ fn capture_specific_fields() { } } size_and_align_expr! { + minicore: fn; + stmts: [] struct X(i64, i32, (u8, i128)); let y: X = X(2, 5, (7, 3)); move |x: i64| { @@ -159,6 +165,8 @@ fn capture_specific_fields() { } } size_and_align_expr! 
{ + minicore: fn; + stmts: [] struct X(i64, i32, (u8, i128)); let y = &&X(2, 5, (7, 3)); move |x: i64| { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs index e6b8329ca861a..d004b5e3ef1d6 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/lib.rs @@ -62,7 +62,7 @@ use std::{hash::Hash, ops::ControlFlow}; use hir_def::{ CallableDefId, ExpressionStoreOwnerId, GenericDefId, TypeAliasId, TypeOrConstParamId, - TypeParamId, hir::generics::GenericParams, resolver::TypeNs, type_ref::Rawness, + TypeParamId, resolver::TypeNs, type_ref::Rawness, }; use hir_expand::name::Name; use indexmap::{IndexMap, map::Entry}; @@ -84,7 +84,7 @@ use crate::{ lower::SupertraitsInfo, next_solver::{ AliasTy, Binder, BoundConst, BoundRegion, BoundRegionKind, BoundTy, BoundTyKind, Canonical, - CanonicalVarKind, CanonicalVars, ClauseKind, Const, ConstKind, DbInterner, FnSig, + CanonicalVarKind, CanonicalVarKinds, ClauseKind, Const, ConstKind, DbInterner, FnSig, GenericArgs, PolyFnSig, Predicate, Region, RegionKind, TraitRef, Ty, TyKind, Tys, abi, }, }; @@ -92,10 +92,8 @@ use crate::{ pub use autoderef::autoderef; pub use infer::{ Adjust, Adjustment, AutoBorrow, BindingMode, InferenceDiagnostic, InferenceResult, - InferenceTyDiagnosticSource, OverloadedDeref, PointerCast, - cast::CastError, - closure::analysis::{CaptureKind, CapturedItem}, - could_coerce, could_unify, could_unify_deeply, infer_query_with_inspect, + InferenceTyDiagnosticSource, OverloadedDeref, PointerCast, cast::CastError, could_coerce, + could_unify, could_unify_deeply, infer_query_with_inspect, }; pub use lower::{ GenericPredicates, ImplTraits, LifetimeElisionKind, TyDefId, TyLoweringContext, ValueTyDefId, @@ -109,6 +107,16 @@ pub use utils::{ is_fn_unsafe_to_call, target_feature_is_safe_in_target, }; +pub mod closure_analysis { + pub use crate::infer::{ + CaptureInfo, CaptureSourceStack, CapturedPlace, 
ClosureData, UpvarCapture, + closure::analysis::{ + BorrowKind, + expr_use_visitor::{FakeReadCause, Place, PlaceBase, Projection, ProjectionKind}, + }, + }; +} + /// A constant can have reference to other things. Memory map job is holding /// the necessary bits of memory of the const eval session to keep the constant /// meaningful. @@ -197,7 +205,7 @@ pub fn param_idx(db: &dyn HirDatabase, id: TypeOrConstParamId) -> Option generics::generics(db, id.parent).type_or_const_param_idx(id) } -#[derive(Debug, Copy, Clone, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum FnAbi { Aapcs, AapcsUnwind, @@ -239,21 +247,6 @@ pub enum FnAbi { Unknown, } -impl PartialEq for FnAbi { - fn eq(&self, _other: &Self) -> bool { - // FIXME: Proper equality breaks `coercion::two_closures_lub` test - true - } -} - -impl Hash for FnAbi { - fn hash(&self, state: &mut H) { - // Required because of the FIXME above and due to us implementing `Eq`, without this - // we would break the `Hash` + `Eq` contract - core::mem::discriminant(&Self::Unknown).hash(state); - } -} - impl FnAbi { #[rustfmt::skip] pub fn from_symbol(s: &Symbol) -> FnAbi { @@ -435,7 +428,7 @@ where ConstKind::Error(_) => { let var = rustc_type_ir::BoundVar::from_usize(self.vars.len()); self.vars.push(CanonicalVarKind::Const(rustc_type_ir::UniverseIndex::ZERO)); - Ok(Const::new_bound(self.interner, self.binder, BoundConst { var })) + Ok(Const::new_bound(self.interner, self.binder, BoundConst::new(var))) } ConstKind::Infer(_) => error(), ConstKind::Bound(BoundVarIndexKind::Bound(index), _) if index > self.binder => { @@ -479,7 +472,7 @@ where Canonical { value, max_universe: rustc_type_ir::UniverseIndex::ZERO, - variables: CanonicalVars::new_from_slice(&error_replacer.vars), + var_kinds: CanonicalVarKinds::new_from_slice(&error_replacer.vars), } } @@ -495,10 +488,7 @@ pub fn associated_type_shorthand_candidates( TypeNs::GenericParam(param) => (def, param), TypeNs::SelfType(impl_) => { let impl_trait = 
db.impl_trait(impl_)?.skip_binder().def_id.0; - let param = TypeParamId::from_unchecked(TypeOrConstParamId { - parent: impl_trait.into(), - local_id: GenericParams::SELF_PARAM_ID_IN_SELF, - }); + let param = TypeParamId::trait_self(impl_trait); (impl_trait.into(), param) } _ => return None, @@ -554,8 +544,11 @@ pub fn callable_sig_from_fn_trait<'db>( let trait_ref = TraitRef::new_from_args(table.interner(), fn_once_trait.into(), args); let projection = Ty::new_alias( table.interner(), - rustc_type_ir::AliasTyKind::Projection, - AliasTy::new_from_args(table.interner(), output_assoc_type.into(), args), + AliasTy::new_from_args( + table.interner(), + rustc_type_ir::Projection { def_id: output_assoc_type.into() }, + args, + ), ); let pred = Predicate::upcast_from(trait_ref, table.interner()); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/lower.rs b/src/tools/rust-analyzer/crates/hir-ty/src/lower.rs index 71a7db6559a86..335aff2c1df16 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/lower.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/lower.rs @@ -537,8 +537,11 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { let args = GenericArgs::identity_for_item(self.interner, opaque_ty_id); Ty::new_alias( self.interner, - AliasTyKind::Opaque, - AliasTy::new_from_args(self.interner, opaque_ty_id, args), + AliasTy::new_from_args( + self.interner, + AliasTyKind::Opaque { def_id: opaque_ty_id }, + args, + ), ) } ImplTraitLoweringMode::Disallowed => { @@ -1039,8 +1042,7 @@ impl<'db, 'a> TyLoweringContext<'db, 'a> { let args = GenericArgs::identity_for_item(interner, def_id); let self_ty = Ty::new_alias( self.interner, - rustc_type_ir::AliasTyKind::Opaque, - AliasTy::new_from_args(interner, def_id, args), + AliasTy::new_from_args(interner, rustc_type_ir::Opaque { def_id }, args), ); let (predicates, assoc_ty_bounds_start) = self.with_shifted_in(DebruijnIndex::from_u32(1), |ctx| { @@ -1869,10 +1871,14 @@ fn resolve_type_param_assoc_type_shorthand( .skip_binder(); let args 
= EarlyBinder::bind(args).instantiate(interner, bounded_trait_ref.args); let current_result = StoredEarlyBinder::bind((assoc_type, args.store())); - if let Some(this_trait_resolution) = this_trait_resolution { - return AssocTypeShorthandResolution::Ambiguous { - sub_trait_resolution: Some(this_trait_resolution), - }; + if let Some(this_trait_resolution) = &this_trait_resolution { + if *this_trait_resolution == current_result { + continue; + } else { + return AssocTypeShorthandResolution::Ambiguous { + sub_trait_resolution: Some(this_trait_resolution.clone()), + }; + } } else if let Some(prev_resolution) = &supertraits_resolution { if let AssocTypeShorthandResolution::Ambiguous { sub_trait_resolution: Some(prev_resolution), diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/lower/path.rs b/src/tools/rust-analyzer/crates/hir-ty/src/lower/path.rs index 889f0792d347b..4f707321782a2 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/lower/path.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/lower/path.rs @@ -214,10 +214,9 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { ); Ty::new_alias( self.ctx.interner, - AliasTyKind::Projection, AliasTy::new_from_args( self.ctx.interner, - associated_ty.into(), + AliasTyKind::Projection { def_id: associated_ty.into() }, args, ), ) @@ -949,10 +948,9 @@ impl<'a, 'b, 'db> PathLoweringContext<'a, 'b, 'db> { bound, Ty::new_alias( self.ctx.interner, - AliasTyKind::Projection, AliasTy::new_from_args( self.ctx.interner, - associated_ty.into(), + AliasTyKind::Projection { def_id: associated_ty.into() }, args, ), ), diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs b/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs index b18e48c1fed36..68c4833d81b01 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/method_resolution.rs @@ -219,7 +219,6 @@ impl<'a, 'db> InferenceContext<'a, 'db> { /// between multiple candidates. 
We otherwise treat them as ordinary inference /// variable to avoid rejecting otherwise correct code. #[derive(Debug)] -#[expect(dead_code)] pub(super) enum TreatNotYetDefinedOpaques { AsInfer, AsRigid, diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir.rs index a8865cd54e6a5..a8e06f3a2b586 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir.rs @@ -6,8 +6,7 @@ use base_db::Crate; use either::Either; use hir_def::{ DefWithBodyId, FieldId, StaticId, TupleFieldId, UnionId, VariantId, - expr_store::ExpressionStore, - hir::{BindingAnnotation, BindingId, Expr, ExprId, Ordering, PatId}, + hir::{BindingId, Expr, ExprId, Ordering, PatId}, }; use la_arena::{Arena, ArenaMap, Idx, RawIdx}; use rustc_ast_ir::Mutability; @@ -23,8 +22,8 @@ use crate::{ display::{DisplayTarget, HirDisplay}, infer::PointerCast, next_solver::{ - Const, DbInterner, ErrorGuaranteed, GenericArgs, ParamEnv, StoredConst, StoredGenericArgs, - StoredTy, Ty, TyKind, + Allocation, AllocationData, DbInterner, ErrorGuaranteed, GenericArgs, ParamEnv, + StoredAllocation, StoredConst, StoredGenericArgs, StoredTy, Ty, TyKind, infer::{InferCtxt, traits::ObligationCause}, obligation_ctxt::ObligationCtxt, }, @@ -107,7 +106,13 @@ pub enum OperandKind { /// [UCG#188]: https://github.com/rust-lang/unsafe-code-guidelines/issues/188 Move(Place), /// Constants are already semantically values, and remain unchanged. - Constant { konst: StoredConst, ty: StoredTy }, + Constant { + konst: StoredConst, + ty: StoredTy, + }, + Allocation { + allocation: StoredAllocation, + }, /// NON STANDARD: This kind of operand returns an immutable reference to that static memory. Rustc /// handles it with the `Constant` variant somehow. 
Static(StaticId), @@ -115,11 +120,10 @@ pub enum OperandKind { impl<'db> Operand { fn from_concrete_const(data: Box<[u8]>, memory_map: MemoryMap<'db>, ty: Ty<'db>) -> Self { - let interner = DbInterner::conjure(); Operand { - kind: OperandKind::Constant { - konst: Const::new_valtree(interner, ty, data, memory_map).store(), - ty: ty.store(), + kind: OperandKind::Allocation { + allocation: Allocation::new(AllocationData { ty, memory: data, memory_map }) + .store(), }, span: None, } @@ -163,7 +167,6 @@ impl ProjectionElem { infcx: &InferCtxt<'db>, env: ParamEnv<'db>, mut base: Ty<'db>, - closure_field: impl FnOnce(InternedClosureId, GenericArgs<'db>, usize) -> Ty<'db>, krate: Crate, ) -> Ty<'db> { let interner = infcx.interner; @@ -218,7 +221,7 @@ impl ProjectionElem { } }, ProjectionElem::ClosureField(f) => match base.kind() { - TyKind::Closure(id, subst) => closure_field(id.0, subst, *f), + TyKind::Closure(_, args) => args.as_closure().tupled_upvars_ty().tuple_fields()[*f], _ => { never!("Only closure has closure field"); Ty::new_error(interner, ErrorGuaranteed) @@ -706,19 +709,31 @@ pub enum MutBorrowKind { } impl BorrowKind { - fn from_hir(m: hir_def::type_ref::Mutability) -> Self { + fn from_hir_mutability(m: hir_def::type_ref::Mutability) -> Self { match m { hir_def::type_ref::Mutability::Shared => BorrowKind::Shared, hir_def::type_ref::Mutability::Mut => BorrowKind::Mut { kind: MutBorrowKind::Default }, } } - fn from_rustc(m: rustc_ast_ir::Mutability) -> Self { + fn from_rustc_mutability(m: rustc_ast_ir::Mutability) -> Self { match m { rustc_ast_ir::Mutability::Not => BorrowKind::Shared, rustc_ast_ir::Mutability::Mut => BorrowKind::Mut { kind: MutBorrowKind::Default }, } } + + fn from_hir(bk: crate::infer::closure::analysis::BorrowKind) -> Self { + match bk { + crate::closure_analysis::BorrowKind::Immutable => Self::Shared, + crate::closure_analysis::BorrowKind::UniqueImmutable => { + Self::Mut { kind: MutBorrowKind::ClosureCapture } + } + 
crate::closure_analysis::BorrowKind::Mutable => { + Self::Mut { kind: MutBorrowKind::Default } + } + } + } } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -1074,6 +1089,7 @@ pub struct MirBody { pub start_block: BasicBlockId, pub owner: DefWithBodyId, pub binding_locals: ArenaMap, + pub upvar_locals: FxHashMap>, pub param_locals: Vec, /// This field stores the closures directly owned by this body. It is used /// in traversing every mir body. @@ -1095,7 +1111,9 @@ impl MirBody { OperandKind::Copy(p) | OperandKind::Move(p) => { f(p, store); } - OperandKind::Constant { .. } | OperandKind::Static(_) => (), + OperandKind::Constant { .. } + | OperandKind::Static(_) + | OperandKind::Allocation { .. } => (), } } for (_, block) in self.basic_blocks.iter_mut() { @@ -1183,6 +1201,7 @@ impl MirBody { start_block: _, owner: _, binding_locals, + upvar_locals, param_locals, closures, projection_store, @@ -1191,6 +1210,7 @@ impl MirBody { basic_blocks.shrink_to_fit(); locals.shrink_to_fit(); binding_locals.shrink_to_fit(); + upvar_locals.shrink_to_fit(); param_locals.shrink_to_fit(); closures.shrink_to_fit(); for (_, b) in basic_blocks.iter_mut() { @@ -1208,20 +1228,6 @@ pub enum MirSpan { SelfParam, Unknown, } - -impl MirSpan { - pub fn is_ref_span(&self, store: &ExpressionStore) -> bool { - match *self { - MirSpan::ExprId(expr) => matches!(store[expr], Expr::Ref { .. }), - // FIXME: Figure out if this is correct wrt. match ergonomics. 
- MirSpan::BindingId(binding) => { - matches!(store[binding].mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) - } - MirSpan::PatId(_) | MirSpan::SelfParam | MirSpan::Unknown => false, - } - } -} - impl_from!(ExprId, PatId for MirSpan); impl From<&ExprId> for MirSpan { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir/borrowck.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir/borrowck.rs index 3ff2db15aaf56..17715d3fcd23b 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir/borrowck.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir/borrowck.rs @@ -8,17 +8,16 @@ use std::iter; use hir_def::{DefWithBodyId, ExpressionStoreOwnerId, HasModule}; use la_arena::ArenaMap; use rustc_hash::FxHashMap; -use rustc_type_ir::inherent::GenericArgs as _; use stdx::never; use triomphe::Arc; use crate::{ - InferenceResult, - db::{HirDatabase, InternedClosure, InternedClosureId}, + closure_analysis::ProjectionKind as HirProjectionKind, + db::{HirDatabase, InternedClosureId}, display::DisplayTarget, mir::OperandKind, next_solver::{ - DbInterner, GenericArgs, ParamEnv, StoredTy, Ty, TypingMode, + DbInterner, ParamEnv, StoredTy, Ty, TypingMode, infer::{DbInternerInferExt, InferCtxt}, }, }; @@ -68,25 +67,49 @@ pub struct BorrowckResult { fn all_mir_bodies( db: &dyn HirDatabase, def: DefWithBodyId, - mut cb: impl FnMut(Arc), -) -> Result<(), MirLowerError> { + mut cb: impl FnMut(Arc) -> BorrowckResult, + mut merge_from_closures: impl FnMut(&mut BorrowckResult, &BorrowckResult), +) -> Result, MirLowerError> { fn for_closure( db: &dyn HirDatabase, c: InternedClosureId, - cb: &mut impl FnMut(Arc), + results: &mut Vec, + cb: &mut impl FnMut(Arc) -> BorrowckResult, + merge_from_closures: &mut impl FnMut(&mut BorrowckResult, &BorrowckResult), ) -> Result<(), MirLowerError> { match db.mir_body_for_closure(c) { Ok(body) => { - cb(body.clone()); - body.closures.iter().try_for_each(|&it| for_closure(db, it, cb)) + let parent_index = results.len(); + 
results.push(cb(body.clone())); + body.closures + .iter() + .try_for_each(|&it| for_closure(db, it, results, cb, merge_from_closures))?; + merge(results, merge_from_closures, parent_index); + Ok(()) } Err(e) => Err(e), } } + + fn merge( + results: &mut [BorrowckResult], + merge: &mut impl FnMut(&mut BorrowckResult, &BorrowckResult), + parent_index: usize, + ) { + let (parent_and_before, children) = results.split_at_mut(parent_index + 1); + let parent = &mut parent_and_before[parent_and_before.len() - 1]; + children.iter().for_each(|child| merge(parent, child)); + } + + let mut results = Vec::new(); match db.mir_body(def) { Ok(body) => { - cb(body.clone()); - body.closures.iter().try_for_each(|&it| for_closure(db, it, &mut cb)) + results.push(cb(body.clone())); + body.closures.iter().try_for_each(|&it| { + for_closure(db, it, &mut results, &mut cb, &mut merge_from_closures) + })?; + merge(&mut results, &mut merge_from_closures, 0); + Ok(results.into()) } Err(e) => Err(e), } @@ -100,34 +123,50 @@ pub fn borrowck_query( let module = def.module(db); let interner = DbInterner::new_with(db, module.krate(db)); let env = db.trait_environment(ExpressionStoreOwnerId::from(def)); - let mut res = vec![]; // This calculates opaques defining scope which is a bit costly therefore is put outside `all_mir_bodies()`. let typing_mode = TypingMode::borrowck(interner, def.into()); - all_mir_bodies(db, def, |body| { - // FIXME(next-solver): Opaques. 
- let infcx = interner.infer_ctxt().build(typing_mode); - res.push(BorrowckResult { - mutability_of_locals: mutability_of_locals(&infcx, env, &body), - moved_out_of_ref: moved_out_of_ref(&infcx, env, &body), - partially_moved: partially_moved(&infcx, env, &body), - borrow_regions: borrow_regions(db, &body), - mir_body: body, - }); - })?; - Ok(res.into()) -} - -fn make_fetch_closure_field<'db>( - db: &'db dyn HirDatabase, -) -> impl FnOnce(InternedClosureId, GenericArgs<'db>, usize) -> Ty<'db> + use<'db> { - |c: InternedClosureId, subst: GenericArgs<'db>, f: usize| { - let InternedClosure(owner, _) = db.lookup_intern_closure(c); - let interner = DbInterner::new_no_crate(db); - let infer = InferenceResult::of(db, owner); - let (captures, _) = infer.closure_info(c); - let parent_subst = subst.as_closure().parent_args(); - captures.get(f).expect("broken closure field").ty.get().instantiate(interner, parent_subst) - } + let res = all_mir_bodies( + db, + def, + |body| { + // FIXME(next-solver): Opaques. + let infcx = interner.infer_ctxt().build(typing_mode); + BorrowckResult { + mutability_of_locals: mutability_of_locals(&infcx, env, &body), + moved_out_of_ref: moved_out_of_ref(&infcx, env, &body), + partially_moved: partially_moved(&infcx, env, &body), + borrow_regions: borrow_regions(db, &body), + mir_body: body, + } + }, + |parent, child| { + for (upvar, child_locals) in &child.mir_body.upvar_locals { + let Some(&parent_local) = parent.mir_body.binding_locals.get(*upvar) else { + continue; + }; + for (child_local, capture_place) in child_locals { + if !capture_place + .projections + .iter() + .any(|proj| matches!(proj.kind, HirProjectionKind::Deref)) + { + let parent_mol = &mut parent.mutability_of_locals[parent_local]; + match (&*parent_mol, &child.mutability_of_locals[*child_local]) { + (MutabilityReason::Mut { .. }, _) => {} + (_, MutabilityReason::Mut { .. }) => { + // FIXME: Fix the child spans. 
+ *parent_mol = MutabilityReason::Mut { spans: Vec::new() } + } + (MutabilityReason::Not, _) => {} + (_, MutabilityReason::Not) => *parent_mol = MutabilityReason::Not, + (MutabilityReason::Unused, MutabilityReason::Unused) => {} + } + } + } + } + }, + )?; + Ok(res) } fn moved_out_of_ref<'db>( @@ -145,13 +184,7 @@ fn moved_out_of_ref<'db>( if *proj == ProjectionElem::Deref && ty.as_reference().is_some() { is_dereference_of_ref = true; } - ty = proj.projected_ty( - infcx, - env, - ty, - make_fetch_closure_field(db), - body.owner.module(db).krate(db), - ); + ty = proj.projected_ty(infcx, env, ty, body.owner.module(db).krate(db)); } if is_dereference_of_ref && !infcx.type_is_copy_modulo_regions(env, ty) @@ -160,7 +193,7 @@ fn moved_out_of_ref<'db>( result.push(MovedOutOfRef { span: op.span.unwrap_or(span), ty: ty.store() }); } } - OperandKind::Constant { .. } | OperandKind::Static(_) => (), + OperandKind::Constant { .. } | OperandKind::Static(_) | OperandKind::Allocation { .. } => {} }; for (_, block) in body.basic_blocks.iter() { db.unwind_if_revision_cancelled(); @@ -242,19 +275,13 @@ fn partially_moved<'db>( OperandKind::Copy(p) | OperandKind::Move(p) => { let mut ty: Ty<'db> = body.locals[p.local].ty.as_ref(); for proj in p.projection.lookup(&body.projection_store) { - ty = proj.projected_ty( - infcx, - env, - ty, - make_fetch_closure_field(db), - body.owner.module(db).krate(db), - ); + ty = proj.projected_ty(infcx, env, ty, body.owner.module(db).krate(db)); } if !infcx.type_is_copy_modulo_regions(env, ty) && !ty.references_non_lt_error() { result.push(PartiallyMoved { span, ty: ty.store(), local: p.local }); } } - OperandKind::Constant { .. } | OperandKind::Static(_) => (), + OperandKind::Constant { .. } | OperandKind::Static(_) | OperandKind::Allocation { .. 
} => {} }; for (_, block) in body.basic_blocks.iter() { db.unwind_if_revision_cancelled(); @@ -397,13 +424,7 @@ fn place_case<'db>( } ProjectionElem::OpaqueCast(_) => (), } - ty = proj.projected_ty( - infcx, - env, - ty, - make_fetch_closure_field(db), - body.owner.module(db).krate(db), - ); + ty = proj.projected_ty(infcx, env, ty, body.owner.module(db).krate(db)); } if is_part_of { ProjectionCase::DirectPart } else { ProjectionCase::Direct } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval.rs index 505db1776f280..80e429c4c8232 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval.rs @@ -21,7 +21,7 @@ use hir_expand::{InFile, mod_path::path, name::Name}; use intern::sym; use la_arena::ArenaMap; use macros::GenericTypeVisitable; -use rustc_abi::TargetDataLayout; +use rustc_abi::{Size, TargetDataLayout}; use rustc_apfloat::{ Float, ieee::{Half as f16, Quad as f128}, @@ -40,14 +40,14 @@ use triomphe::Arc; use crate::{ CallableDefId, ComplexMemoryMap, InferenceResult, MemoryMap, ParamEnvAndCrate, consteval::{self, ConstEvalError, try_const_usize}, - db::{HirDatabase, InternedClosure, InternedClosureId}, + db::{HirDatabase, InternedClosureId}, display::{ClosureStyle, DisplayTarget, HirDisplay}, infer::PointerCast, layout::{Layout, LayoutError, RustcEnumVariantIdx}, method_resolution::{is_dyn_method, lookup_impl_const}, next_solver::{ - Const, ConstBytes, ConstKind, DbInterner, ErrorGuaranteed, GenericArgs, Region, - StoredConst, StoredTy, Ty, TyKind, TypingMode, UnevaluatedConst, ValueConst, + AliasTy, Allocation, AllocationData, Const, ConstKind, DbInterner, ErrorGuaranteed, + GenericArgs, Region, StoredTy, Ty, TyKind, TypingMode, UnevaluatedConst, ValTree, infer::{DbInternerInferExt, InferCtxt, traits::ObligationCause}, obligation_ctxt::ObligationCtxt, }, @@ -359,7 +359,7 @@ pub enum MirEvalError { 
MirLowerErrorForClosure(InternedClosureId, MirLowerError), TypeIsUnsized(StoredTy, &'static str), NotSupported(String), - InvalidConst(StoredConst), + InvalidConst, InFunction( Box, Vec<(Either, MirSpan, DefWithBodyId)>, @@ -484,7 +484,7 @@ impl MirEvalError { | MirEvalError::MirLowerErrorForClosure(_, _) | MirEvalError::TypeIsUnsized(_, _) | MirEvalError::NotSupported(_) - | MirEvalError::InvalidConst(_) + | MirEvalError::InvalidConst | MirEvalError::ExecutionLimitExceeded | MirEvalError::StackOverflow | MirEvalError::CoerceUnsizedError(_) @@ -537,7 +537,7 @@ impl std::fmt::Debug for MirEvalError { Self::InternalError(arg0) => f.debug_tuple("InternalError").field(arg0).finish(), Self::InvalidVTableId(arg0) => f.debug_tuple("InvalidVTableId").field(arg0).finish(), Self::NotSupported(arg0) => f.debug_tuple("NotSupported").field(arg0).finish(), - Self::InvalidConst(arg0) => f.debug_tuple("InvalidConst").field(&arg0).finish(), + Self::InvalidConst => f.write_str("InvalidConst"), Self::InFunction(e, stack) => { f.debug_struct("WithStack").field("error", e).field("stack", &stack).finish() } @@ -606,10 +606,10 @@ pub fn interpret_mir<'db>( // (and probably should) do better here, for example by excluding bindings outside of the target expression. 
assert_placeholder_ty_is_unused: bool, trait_env: Option>, -) -> Result<'db, (Result<'db, Const<'db>>, MirOutput)> { +) -> Result<'db, (Result<'db, Allocation<'db>>, MirOutput)> { let ty = body.locals[return_slot()].ty.as_ref(); let mut evaluator = Evaluator::new(db, body.owner, assert_placeholder_ty_is_unused, trait_env)?; - let it: Result<'db, Const<'db>> = (|| { + let it: Result<'db, Allocation<'db>> = (|| { if evaluator.ptr_size() != size_of::() { not_supported!("targets with different pointer size from host"); } @@ -620,7 +620,7 @@ pub fn interpret_mir<'db>( ty, &Locals { ptr: ArenaMap::new(), body, drop_flags: DropFlags::default() }, )?; - let bytes = bytes.into(); + let bytes = Box::from(bytes); let memory_map = if memory_map.memory.is_empty() && evaluator.vtable_map.is_empty() { MemoryMap::Empty } else { @@ -628,7 +628,7 @@ pub fn interpret_mir<'db>( memory_map.vtable.shrink_to_fit(); MemoryMap::Complex(Box::new(memory_map)) }; - Ok(Const::new_valtree(evaluator.interner(), ty, bytes, memory_map)) + Ok(Allocation::new(AllocationData { ty, memory: bytes, memory_map })) })(); Ok((it, MirOutput { stdout: evaluator.stdout, stderr: evaluator.stderr })) } @@ -731,24 +731,7 @@ impl<'db> Evaluator<'db> { return *r; } let (ty, proj) = pair; - let r = proj.projected_ty( - &self.infcx, - self.param_env.param_env, - ty, - |c, subst, f| { - let InternedClosure(owner, _) = self.db.lookup_intern_closure(c); - let infer = InferenceResult::of(self.db, owner); - let (captures, _) = infer.closure_info(c); - let parent_subst = subst.as_closure().parent_args(); - captures - .get(f) - .expect("broken closure field") - .ty - .get() - .instantiate(self.interner(), parent_subst) - }, - self.crate_id, - ); + let r = proj.projected_ty(&self.infcx, self.param_env.param_env, ty, self.crate_id); self.projected_ty_cache.borrow_mut().insert((ty, proj), r); r } @@ -898,6 +881,7 @@ impl<'db> Evaluator<'db> { Ok(match &o.kind { OperandKind::Copy(p) | OperandKind::Move(p) => self.place_ty(p, 
locals)?, OperandKind::Constant { konst: _, ty } => ty.as_ref(), + OperandKind::Allocation { allocation } => allocation.as_ref().ty, &OperandKind::Static(s) => { let ty = InferenceResult::of(self.db, DefWithBodyId::from(s)) .expr_ty(Body::of(self.db, s.into()).root_expr()); @@ -1927,19 +1911,152 @@ impl<'db> Evaluator<'db> { OperandKind::Constant { konst, .. } => { self.allocate_const_in_heap(locals, konst.as_ref())? } + OperandKind::Allocation { allocation } => { + self.allocate_allocation_in_heap(locals, allocation.as_ref())? + } }) } + fn allocate_valtree_in_heap( + &mut self, + ty: Ty<'db>, + valtree: ValTree<'db>, + ) -> Result<'db, Interval> { + match ty.kind() { + TyKind::Bool => { + let value = valtree.inner().to_leaf().try_to_bool().unwrap(); + let addr = self.heap_allocate(1, 1)?; + self.write_memory(addr, &[u8::from(value)])?; + Ok(Interval::new(addr, 1)) + } + TyKind::Char => { + let value = valtree.inner().to_leaf().to_u32(); + let addr = self.heap_allocate(4, 4)?; + self.write_memory(addr, &value.to_le_bytes())?; + Ok(Interval::new(addr, 4)) + } + TyKind::Int(int_ty) => { + let size = int_ty.bit_width().unwrap_or(self.ptr_size() as u64); + let value = valtree.inner().to_leaf().to_int(Size::from_bytes(size)); + let addr = self.heap_allocate(size as usize, size as usize)?; + self.write_memory(addr, &value.to_le_bytes()[..size as usize])?; + Ok(Interval::new(addr, size as usize)) + } + TyKind::Uint(uint_ty) => { + let size = uint_ty.bit_width().unwrap_or(self.ptr_size() as u64); + let value = valtree.inner().to_leaf().to_uint(Size::from_bytes(size)); + let addr = self.heap_allocate(size as usize, size as usize)?; + self.write_memory(addr, &value.to_le_bytes()[..size as usize])?; + Ok(Interval::new(addr, size as usize)) + } + TyKind::Float(float_ty) => { + let size = float_ty.bit_width(); + let value = valtree.inner().to_leaf().to_uint(Size::from_bytes(size)); + let addr = self.heap_allocate(size as usize, size as usize)?; + self.write_memory(addr, 
&value.to_le_bytes()[..size as usize])?; + Ok(Interval::new(addr, size as usize)) + } + TyKind::RawPtr(..) => { + let size = self.ptr_size(); + let value = valtree.inner().to_leaf().to_uint(Size::from_bytes(size)); + let addr = self.heap_allocate(size, size)?; + self.write_memory(addr, &value.to_le_bytes()[..size])?; + Ok(Interval::new(addr, size)) + } + TyKind::Ref(_, inner_ty, _) => match inner_ty.kind() { + TyKind::Str => { + let bytes = valtree + .inner() + .to_branch() + .iter() + .map(|konst| match konst.kind() { + ConstKind::Value(value) => Ok(value.value.inner().to_leaf().to_u8()), + _ => not_supported!("unsupported const"), + }) + .collect::>>()?; + let bytes_addr = self.heap_allocate(bytes.len(), 1)?; + self.write_memory(bytes_addr, &bytes)?; + let ref_addr = self.heap_allocate(self.ptr_size() * 2, self.ptr_size())?; + self.write_memory(ref_addr, &bytes_addr.to_bytes())?; + let mut len = [0; 16]; + len[..size_of::()].copy_from_slice(&bytes.len().to_le_bytes()); + self.write_memory(ref_addr.offset(self.ptr_size()), &len[..self.ptr_size()])?; + Ok(Interval::new(ref_addr, self.ptr_size() * 2)) + } + TyKind::Slice(inner_ty) => { + let item_layout = self.layout(inner_ty)?; + let items = valtree + .inner() + .to_branch() + .iter() + .map(|konst| match konst.kind() { + ConstKind::Value(value) => { + self.allocate_valtree_in_heap(value.ty, value.value) + } + _ => not_supported!("unsupported const"), + }) + .collect::>>()?; + let items_addr = self.heap_allocate( + items.len() * (item_layout.size.bits() as usize), + item_layout.align.bits_usize(), + )?; + for (i, item) in items.iter().enumerate() { + self.copy_from_interval( + items_addr.offset(i * (item_layout.size.bits() as usize)), + *item, + )?; + } + let ref_addr = self.heap_allocate(self.ptr_size() * 2, self.ptr_size())?; + self.write_memory(ref_addr, &items_addr.to_bytes())?; + let mut len = [0; 16]; + len[..size_of::()].copy_from_slice(&items.len().to_le_bytes()); + 
self.write_memory(ref_addr.offset(self.ptr_size()), &len[..self.ptr_size()])?; + Ok(Interval::new(ref_addr, self.ptr_size() * 2)) + } + TyKind::Dynamic(..) => not_supported!("`dyn Trait` consts not supported yet"), + _ => { + let inner_addr = self.allocate_valtree_in_heap(inner_ty, valtree)?; + let ref_addr = self.heap_allocate(self.ptr_size(), self.ptr_size())?; + self.write_memory(ref_addr, &inner_addr.addr.to_bytes())?; + Ok(Interval::new(ref_addr, self.ptr_size())) + } + }, + TyKind::Adt(_, _) | TyKind::Array(_, _) | TyKind::Tuple(_) => { + not_supported!( + "ADTs, arrays and tuples are unsupported in consts currently (requires `adt_const_params`)" + ) + } + TyKind::Pat(_, _) + | TyKind::Slice(_) + | TyKind::FnDef(_, _) + | TyKind::Foreign(_) + | TyKind::Dynamic(_, _) + | TyKind::UnsafeBinder(..) + | TyKind::FnPtr(..) + | TyKind::Closure(_, _) + | TyKind::CoroutineClosure(_, _) + | TyKind::Coroutine(_, _) + | TyKind::CoroutineWitness(_, _) + | TyKind::Never + | TyKind::Alias(..) + | TyKind::Param(_) + | TyKind::Bound(..) 
+ | TyKind::Placeholder(_) + | TyKind::Infer(_) + | TyKind::Str + | TyKind::Error(_) => not_supported!("unsupported const"), + } + } + #[allow(clippy::double_parens)] fn allocate_const_in_heap( &mut self, locals: &Locals, konst: Const<'db>, ) -> Result<'db, Interval> { - let result_owner; - let value = match konst.kind() { - ConstKind::Value(value) => value, - ConstKind::Unevaluated(UnevaluatedConst { def: const_id, args: subst }) => 'b: { + match konst.kind() { + ConstKind::Value(value) => self.allocate_valtree_in_heap(value.ty, value.value), + ConstKind::Unevaluated(UnevaluatedConst { def: const_id, args: subst }) => { let mut id = const_id.0; let mut subst = subst; if let hir_def::GeneralConstId::ConstId(c) = id { @@ -1947,7 +2064,7 @@ impl<'db> Evaluator<'db> { id = hir_def::GeneralConstId::ConstId(c); subst = s; } - result_owner = match id { + let allocation = match id { GeneralConstId::ConstId(const_id) => { self.db.const_eval(const_id, subst, Some(self.param_env)).map_err(|e| { let name = id.name(self.db); @@ -1964,21 +2081,24 @@ impl<'db> Evaluator<'db> { not_supported!("anonymous const evaluation") } }; - if let ConstKind::Value(value) = result_owner.kind() { - break 'b value; - } - not_supported!("unevaluatable constant"); + self.allocate_allocation_in_heap(locals, allocation) } _ => not_supported!("evaluating unknown const"), - }; - let ValueConst { ty, value } = value; - let ConstBytes { memory: v, memory_map } = value.inner(); + } + } + + fn allocate_allocation_in_heap( + &mut self, + locals: &Locals, + allocation: Allocation<'db>, + ) -> Result<'db, Interval> { + let AllocationData { ty, memory: ref v, ref memory_map } = *allocation; let patch_map = memory_map.transform_addresses(|b, align| { let addr = self.heap_allocate(b.len(), align)?; self.write_memory(addr, b)?; Ok(addr.to_usize()) })?; - let (size, align) = self.size_align_of(ty, locals)?.unwrap_or((v.len(), 1)); + let (size, align) = self.size_align_of(allocation.ty, 
locals)?.unwrap_or((v.len(), 1)); let v: Cow<'_, [u8]> = if size != v.len() { // Handle self enum if size == 16 && v.len() < 16 { @@ -1986,7 +2106,7 @@ impl<'db> Evaluator<'db> { } else if size < 16 && v.len() == 16 { Cow::Borrowed(&v[0..size]) } else { - return Err(MirEvalError::InvalidConst(konst.store())); + return Err(MirEvalError::InvalidConst); } } else { Cow::Borrowed(v) @@ -2211,10 +2331,10 @@ impl<'db> Evaluator<'db> { match size { Some((size, _)) => { let addr_usize = from_bytes!(usize, bytes); - mm.insert( - addr_usize, - this.read_memory(Address::from_usize(addr_usize), size)?.into(), - ) + let bytes = + this.read_memory(Address::from_usize(addr_usize), size)?.to_vec(); + mm.insert(addr_usize, bytes.clone().into()); + rec(this, &bytes, t, locals, mm, stack_depth_limit - 1)?; } None => { let mut check_inner = None; @@ -2340,7 +2460,7 @@ impl<'db> Evaluator<'db> { } AdtId::UnionId(_) => (), }, - TyKind::Alias(AliasTyKind::Projection, _) => { + TyKind::Alias(AliasTy { kind: AliasTyKind::Projection { .. }, .. }) => { let mut ocx = ObligationCtxt::new(&this.infcx); let ty = ocx .structurally_normalize_ty( @@ -2379,9 +2499,20 @@ impl<'db> Evaluator<'db> { match size { Some(_) => { let current = from_bytes!(usize, self.read_memory(addr, my_size)?); - if let Some(it) = patch_map.get(¤t) { - self.write_memory(addr, &it.to_le_bytes())?; - } + let patched = match patch_map.get(¤t) { + Some(it) => { + self.write_memory(addr, &it.to_le_bytes())?; + *it + } + None => current, + }; + self.patch_addresses( + patch_map, + ty_of_bytes, + Address::from_usize(patched), + t, + locals, + )?; } None => { let current = from_bytes!(usize, self.read_memory(addr, my_size / 2)?); @@ -2472,7 +2603,7 @@ impl<'db> Evaluator<'db> { | TyKind::Error(_) | TyKind::Placeholder(_) | TyKind::Dynamic(_, _) - | TyKind::Alias(_, _) + | TyKind::Alias(..) 
| TyKind::Bound(_, _) | TyKind::Infer(_) | TyKind::Pat(_, _) @@ -2829,10 +2960,10 @@ impl<'db> Evaluator<'db> { }; let static_data = StaticSignature::of(self.db, st); let result = if !static_data.flags.contains(StaticFlags::EXTERN) { - let konst = self.db.const_eval_static(st).map_err(|e| { + let allocation = self.db.const_eval_static(st).map_err(|e| { MirEvalError::ConstEvalError(static_data.name.as_str().to_owned(), Box::new(e)) })?; - self.allocate_const_in_heap(locals, konst)? + self.allocate_allocation_in_heap(locals, allocation)? } else { let ty = InferenceResult::of(self.db, DefWithBodyId::from(st)) .expr_ty(Body::of(self.db, st.into()).root_expr()); @@ -2992,7 +3123,7 @@ impl<'db> Evaluator<'db> { pub fn render_const_using_debug_impl<'db>( db: &'db dyn HirDatabase, owner: DefWithBodyId, - c: Const<'db>, + c: Allocation<'db>, ty: Ty<'db>, ) -> Result<'db, String> { let mut evaluator = Evaluator::new(db, owner, false, None)?; @@ -3003,7 +3134,7 @@ pub fn render_const_using_debug_impl<'db>( .map_err(|_| MirEvalError::NotSupported("unreachable".to_owned()))?, drop_flags: DropFlags::default(), }; - let data = evaluator.allocate_const_in_heap(locals, c)?; + let data = evaluator.allocate_allocation_in_heap(locals, c)?; let resolver = owner.resolver(db); let Some(TypeNs::TraitId(debug_trait)) = resolver.resolve_path_in_type_ns_fully( db, diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval/shim.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval/shim.rs index 2aed76ec90913..9586d38abc517 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval/shim.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir/eval/shim.rs @@ -6,18 +6,16 @@ use std::cmp::{self, Ordering}; use hir_def::{attrs::AttrFlags, signatures::FunctionSignature}; use hir_expand::name::Name; use intern::sym; -use rustc_type_ir::inherent::{AdtDef, IntoKind, SliceLike, Ty as _}; +use rustc_type_ir::inherent::{AdtDef, GenericArgs as _, IntoKind, SliceLike, Ty as _}; use stdx::never; 
use crate::{ - InferenceResult, display::DisplayTarget, drop::{DropGlue, has_drop_glue}, mir::eval::{ - Address, AdtId, Arc, Evaluator, FunctionId, GenericArgs, HasModule, HirDisplay, - InternedClosure, Interval, IntervalAndTy, IntervalOrOwned, ItemContainerId, Layout, Locals, - Lookup, MirEvalError, MirSpan, Mutability, Result, Ty, TyKind, from_bytes, not_supported, - pad16, + Address, AdtId, Arc, Evaluator, FunctionId, GenericArgs, HasModule, HirDisplay, Interval, + IntervalAndTy, IntervalOrOwned, ItemContainerId, Layout, Locals, Lookup, MirEvalError, + MirSpan, Mutability, Result, Ty, TyKind, from_bytes, not_supported, pad16, }, next_solver::Region, }; @@ -147,19 +145,14 @@ impl<'db> Evaluator<'db> { return destination .write_from_interval(self, Interval { addr, size: destination.size }); } - TyKind::Closure(id, subst) => { - let [arg] = args else { - not_supported!("wrong arg count for clone"); - }; - let addr = Address::from_bytes(arg.get(self)?)?; - let InternedClosure(owner, _) = self.db.lookup_intern_closure(id.0); - let infer = InferenceResult::of(self.db, owner); - let (captures, _) = infer.closure_info(id.0); - let layout = self.layout(self_ty)?; - let db = self.db; - let ty_iter = captures.iter().map(|c| c.ty(db, subst)); - self.exec_clone_for_fields(ty_iter, layout, addr, def, locals, destination, span)?; - } + TyKind::Closure(_, closure_args) => self.exec_clone( + def, + args, + closure_args.as_closure().tupled_upvars_ty(), + locals, + destination, + span, + )?, TyKind::Tuple(subst) => { let [arg] = args else { not_supported!("wrong arg count for clone"); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower.rs index 44785d948a49a..0f0ed729c930a 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower.rs @@ -8,8 +8,9 @@ use hir_def::{ HasModule, ItemContainerId, LocalFieldId, Lookup, TraitId, TupleId, 
expr_store::{Body, ExpressionStore, HygieneId, path::Path}, hir::{ - ArithOp, Array, BinaryOp, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, - Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, generics::GenericParams, + ArithOp, Array, BinaryOp, BindingAnnotation, BindingId, ClosureKind, ExprId, ExprOrPatId, + LabelId, Literal, MatchArm, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, + generics::GenericParams, }, item_tree::FieldsShape, lang_item::LangItems, @@ -17,10 +18,11 @@ use hir_def::{ signatures::{ConstSignature, EnumSignature, FunctionSignature, StaticSignature}, }; use hir_expand::name::Name; +use itertools::{EitherOrBoth, Itertools}; use la_arena::ArenaMap; use rustc_apfloat::Float; use rustc_hash::FxHashMap; -use rustc_type_ir::inherent::{Const as _, GenericArgs as _, IntoKind, Ty as _}; +use rustc_type_ir::inherent::{AdtDef, Const as _, GenericArgs as _, IntoKind, Ty as _}; use span::{Edition, FileId}; use syntax::TextRange; use triomphe::Arc; @@ -32,8 +34,11 @@ use crate::{ display::{DisplayTarget, HirDisplay, hir_display_with_store}, generics::generics, infer::{ - CaptureKind, CapturedItem, TypeMismatch, cast::CastTy, - closure::analysis::HirPlaceProjection, + CaptureSourceStack, CapturedPlace, TypeMismatch, UpvarCapture, + cast::CastTy, + closure::analysis::expr_use_visitor::{ + Place as HirPlace, PlaceBase as HirPlaceBase, ProjectionKind as HirProjectionKind, + }, }, inhabitedness::is_ty_uninhabited_from, layout::LayoutError, @@ -51,7 +56,6 @@ use crate::{ abi::Safety, infer::{DbInternerInferExt, InferCtxt}, }, - traits::FnTrait, }; use super::OperandKind; @@ -303,6 +307,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { locals, start_block, binding_locals, + upvar_locals: FxHashMap::default(), param_locals: vec![], owner, closures: vec![], @@ -439,7 +444,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { else { return Ok(None); }; - let bk = BorrowKind::from_rustc(m); + let bk = BorrowKind::from_rustc_mutability(m); 
self.push_assignment(current, place, Rvalue::Ref(bk, p), expr_id.into()); Ok(Some(current)) } @@ -956,7 +961,6 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { } Expr::Await { .. } => not_supported!("await"), Expr::Yeet { .. } => not_supported!("yeet"), - Expr::Async { .. } => not_supported!("async block"), &Expr::Const(_) => { // let subst = self.placeholder_subst(); // self.lower_const( @@ -996,7 +1000,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { let Some((p, current)) = self.lower_expr_as_place(current, *expr, true)? else { return Ok(None); }; - let bk = BorrowKind::from_hir(*mutability); + let bk = BorrowKind::from_hir_mutability(*mutability); self.push_assignment(current, place, Rvalue::Ref(bk, p), expr_id.into()); Ok(Some(current)) } @@ -1245,55 +1249,64 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { ); Ok(Some(current)) } - Expr::Closure { .. } => { + Expr::Closure { closure_kind: ClosureKind::Closure, .. } => { let ty = self.expr_ty_without_adjust(expr_id); let TyKind::Closure(id, _) = ty.kind() else { not_supported!("closure with non closure type"); }; self.result.closures.push(id.0); - let (captures, _) = self.infer.closure_info(id.0); - let mut operands = vec![]; - for capture in captures.iter() { - let p = Place { - local: self.binding_local(capture.place.local)?, - projection: self.result.projection_store.intern( - capture - .place - .projections - .clone() - .into_iter() - .map(|it| match it { - HirPlaceProjection::Deref => ProjectionElem::Deref, - HirPlaceProjection::Field(field_id) => { - ProjectionElem::Field(Either::Left(field_id)) - } - HirPlaceProjection::TupleField(idx) => { - ProjectionElem::Field(Either::Right(TupleFieldId { - tuple: TupleId(!0), // Dummy as it's unused - index: idx, - })) - } - }) - .collect(), - ), + let closure_data = &self.infer.closures_data[&id.0.loc(self.db).1]; + + let span = |sources: &[CaptureSourceStack]| match sources + .first() + .map(|it| it.final_source()) + { + Some(ExprOrPatId::ExprId(it)) => it.into(), + 
Some(ExprOrPatId::PatId(it)) => it.into(), + None => MirSpan::Unknown, + }; + let convert_place = |this: &mut Self, place: &HirPlace| { + let (HirPlaceBase::Local(local) | HirPlaceBase::Upvar { var_id: local, .. }) = + place.base + else { + not_supported!("non-local capture"); }; - match &capture.kind { - CaptureKind::ByRef(bk) => { - let tmp_ty = capture.ty.get().instantiate_identity(); + Ok(Place { + local: this.binding_local(local)?, + projection: this + .result + .projection_store + .intern(convert_closure_capture_projections(self.db, place).collect()), + }) + }; + + for (place, _, sources) in &closure_data.fake_reads { + let p = convert_place(self, place)?; + self.push_fake_read(current, p, span(sources)); + } + + let captures = closure_data.min_captures.values().flatten(); + let mut operands = vec![]; + for capture in captures { + let p = convert_place(self, &capture.place)?; + match capture.info.capture_kind { + UpvarCapture::ByRef(bk) => { + let tmp_ty = capture.captured_ty(self.db); // FIXME: Handle more than one span. - let capture_spans = capture.spans(); - let tmp: Place = self.temp(tmp_ty, current, capture_spans[0])?.into(); + let capture_span = span(&capture.info.sources); + let tmp: Place = self.temp(tmp_ty, current, capture_span)?.into(); self.push_assignment( current, tmp, - Rvalue::Ref(*bk, p), - capture_spans[0], + Rvalue::Ref(BorrowKind::from_hir(bk), p), + capture_span, ); operands.push(Operand { kind: OperandKind::Move(tmp), span: None }); } - CaptureKind::ByValue => { + UpvarCapture::ByValue => { operands.push(Operand { kind: OperandKind::Move(p), span: None }) } + UpvarCapture::ByUse => not_supported!("capture by use"), } } self.push_assignment( @@ -1304,6 +1317,7 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { ); Ok(Some(current)) } + Expr::Closure { closure_kind, .. 
} => not_supported!("{closure_kind:?} closure"), Expr::Tuple { exprs } => { let Some(values) = exprs .iter() @@ -1492,8 +1506,8 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { hir_def::hir::Literal::Uint(it, _) => Box::from(&it.to_le_bytes()[0..size()?]), hir_def::hir::Literal::Float(f, _) => match size()? { 16 => Box::new(f.to_f128().to_bits().to_le_bytes()), - 8 => Box::new(f.to_f64().to_le_bytes()), - 4 => Box::new(f.to_f32().to_le_bytes()), + 8 => Box::new(f.to_f64().to_bits().to_le_bytes()), + 4 => Box::new(f.to_f32().to_bits().to_le_bytes()), 2 => Box::new(u16::try_from(f.to_f16().to_bits()).unwrap().to_le_bytes()), _ => { return Err(MirLowerError::TypeError( @@ -1527,31 +1541,14 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { subst: GenericArgs<'db>, const_id: GeneralConstId, ) -> Result<'db, Operand> { - let konst = if !subst.is_empty() { - // We can't evaluate constant with substitution now, as generics are not monomorphized in lowering. - Const::new_unevaluated( - self.interner(), - UnevaluatedConst { def: const_id.into(), args: subst }, - ) - } else { - match const_id { - id @ GeneralConstId::ConstId(const_id) => { - self.db.const_eval(const_id, subst, None).map_err(|e| { - let name = id.name(self.db); - MirLowerError::ConstEvalError(name.into(), Box::new(e)) - })? - } - GeneralConstId::StaticId(static_id) => { - self.db.const_eval_static(static_id).map_err(|e| { - let name = const_id.name(self.db); - MirLowerError::ConstEvalError(name.into(), Box::new(e)) - })? 
- } - GeneralConstId::AnonConstId(_) => { - return Err(MirLowerError::IncompleteExpr); - } - } - }; + if matches!(const_id, GeneralConstId::AnonConstId(_)) { + // FIXME: + not_supported!("anon consts are not supported yet in const eval"); + } + let konst = Const::new_unevaluated( + self.interner(), + UnevaluatedConst { def: const_id.into(), args: subst }, + ); let ty = self .db .value_ty(match const_id { @@ -2084,6 +2081,44 @@ impl<'a, 'db> MirLowerCtx<'a, 'db> { } } +fn convert_closure_capture_projections( + db: &dyn HirDatabase, + place: &HirPlace, +) -> impl Iterator { + place.projections.iter().enumerate().map(|(i, proj)| match proj.kind { + HirProjectionKind::Deref => ProjectionElem::Deref, + HirProjectionKind::Field { field_idx, variant_idx } => { + let ty = place.ty_before_projection(i); + match ty.kind() { + TyKind::Tuple(_) => { + ProjectionElem::Field(Either::Right(TupleFieldId { + tuple: TupleId(!0), // Dummy as it's unused + index: field_idx, + })) + } + TyKind::Adt(adt_def, _) => { + let local_field_id = LocalFieldId::from_raw(RawIdx::from_u32(field_idx)); + let field = match adt_def.def_id().0 { + AdtId::StructId(id) => { + FieldId { parent: id.into(), local_id: local_field_id } + } + AdtId::UnionId(id) => { + FieldId { parent: id.into(), local_id: local_field_id } + } + AdtId::EnumId(id) => { + let variant = id.enum_variants(db).variants[variant_idx as usize].0; + FieldId { parent: variant.into(), local_id: local_field_id } + } + }; + ProjectionElem::Field(Either::Left(field)) + } + _ => panic!("unexpected type"), + } + } + _ => panic!("unexpected projection"), + }) +} + fn cast_kind<'db>( db: &'db dyn HirDatabase, source_ty: Ty<'db>, @@ -2110,7 +2145,7 @@ pub fn mir_body_for_closure_query<'db>( db: &'db dyn HirDatabase, closure: InternedClosureId, ) -> Result<'db, Arc> { - let InternedClosure(owner, expr) = db.lookup_intern_closure(closure); + let InternedClosure(owner, expr) = closure.loc(db); let body_owner = owner.as_def_with_body().expect("MIR 
lowering should only happen for body-owned closures"); let body = Body::of(db, body_owner); @@ -2121,20 +2156,22 @@ pub fn mir_body_for_closure_query<'db>( let crate::next_solver::TyKind::Closure(_, substs) = infer.expr_ty(expr).kind() else { implementation_error!("closure expression is not closure"); }; - let (captures, kind) = infer.closure_info(closure); + let kind = substs.as_closure().kind(); + let captures = infer.closures_data[&expr].min_captures.values().flatten(); let mut ctx = MirLowerCtx::new(db, body_owner, &body.store, infer); + // 0 is return local ctx.result.locals.alloc(Local { ty: infer.expr_ty(*root).store() }); let closure_local = ctx.result.locals.alloc(Local { ty: match kind { - FnTrait::FnOnce | FnTrait::AsyncFnOnce => infer.expr_ty(expr), - FnTrait::FnMut | FnTrait::AsyncFnMut => Ty::new_ref( + rustc_type_ir::ClosureKind::FnOnce => infer.expr_ty(expr), + rustc_type_ir::ClosureKind::FnMut => Ty::new_ref( ctx.interner(), Region::error(ctx.interner()), infer.expr_ty(expr), Mutability::Mut, ), - FnTrait::Fn | FnTrait::AsyncFn => Ty::new_ref( + rustc_type_ir::ClosureKind::Fn => Ty::new_ref( ctx.interner(), Region::error(ctx.interner()), infer.expr_ty(expr), @@ -2144,6 +2181,7 @@ pub fn mir_body_for_closure_query<'db>( .store(), }); ctx.result.param_locals.push(closure_local); + let sig = ctx.interner().signature_unclosure(substs.as_closure().sig(), Safety::Safe); let resolver_guard = ctx.resolver.update_to_inner_scope(db, body_owner, expr); let current = ctx.lower_params_and_bindings( @@ -2151,60 +2189,101 @@ pub fn mir_body_for_closure_query<'db>( None, |_| true, )?; + + // Push local for every upvar in the closure. rustc doesn't do that, but we have to so we have locals + // to associate with upvars for borrowck. 
+ let is_by_ref_closure = match kind { + rustc_type_ir::ClosureKind::Fn | rustc_type_ir::ClosureKind::FnMut => true, + rustc_type_ir::ClosureKind::FnOnce => false, + }; + let mut upvar_map: FxHashMap> = FxHashMap::default(); + for (capture_idx, capture) in captures.enumerate() { + let capture_local = ctx.result.locals.alloc(Local { ty: capture.captured_ty(db).store() }); + ctx.push_storage_live_for_local(capture_local, current, MirSpan::Unknown)?; + let mut projections = Vec::with_capacity(usize::from(is_by_ref_closure) + 1); + if is_by_ref_closure { + projections.push(ProjectionElem::Deref); + } + projections.push(ProjectionElem::ClosureField(capture_idx)); + let capture_param_place = Place { + local: closure_local, + projection: ctx.result.projection_store.intern(projections.into_boxed_slice()), + }; + let capture_local_place = Place { + local: capture_local, + projection: ctx.result.projection_store.intern(Box::new([])), + }; + let capture_local_rvalue = + Rvalue::Use(Operand { kind: OperandKind::Move(capture_param_place), span: None }); + ctx.push_assignment(current, capture_local_place, capture_local_rvalue, MirSpan::Unknown); + + let local = capture.captured_local(); + let local = ctx.binding_local(local)?; + upvar_map.entry(local).or_default().push((capture, capture_local)); + + ctx.result + .upvar_locals + .entry(capture.captured_local()) + .or_default() + .push((capture_local, capture.place.clone())); + } + ctx.resolver.reset_to_guard(resolver_guard); if let Some(current) = ctx.lower_expr_to_place(*root, return_slot().into(), current)? 
{ let current = ctx.pop_drop_scope_assert_finished(current, root.into())?; ctx.set_terminator(current, TerminatorKind::Return, (*root).into()); } - let mut upvar_map: FxHashMap> = FxHashMap::default(); - for (i, capture) in captures.iter().enumerate() { - let local = ctx.binding_local(capture.place.local)?; - upvar_map.entry(local).or_default().push((capture, i)); - } + let mut err = None; - let closure_local = ctx.result.locals.iter().nth(1).unwrap().0; - let closure_projection = match kind { - FnTrait::FnOnce | FnTrait::AsyncFnOnce => vec![], - FnTrait::FnMut | FnTrait::Fn | FnTrait::AsyncFnMut | FnTrait::AsyncFn => { - vec![ProjectionElem::Deref] - } - }; - ctx.result.walk_places(|p, store| { - if let Some(it) = upvar_map.get(&p.local) { - let r = it.iter().find(|it| { - if p.projection.lookup(store).len() < it.0.place.projections.len() { - return false; - } - for (it, y) in p.projection.lookup(store).iter().zip(it.0.place.projections.iter()) - { - match (it, y) { - (ProjectionElem::Deref, HirPlaceProjection::Deref) => (), - (ProjectionElem::Field(Either::Left(it)), HirPlaceProjection::Field(y)) - if it == y => {} - ( - ProjectionElem::Field(Either::Right(it)), - HirPlaceProjection::TupleField(y), - ) if it.index == *y => (), - _ => return false, + ctx.result.walk_places(|mir_place, store| { + let mir_projections = mir_place.projection.lookup(store); + if let Some(hir_places) = upvar_map.get(&mir_place.local) { + let projections = hir_places.iter().find_map(|hir_place| { + let iter = mir_projections + .iter() + .cloned() + .zip_longest(convert_closure_capture_projections(db, &hir_place.0.place)) + .enumerate(); + + for (idx, item) in iter { + match item { + EitherOrBoth::Both(mir, hir) => { + if mir != hir { + // Not this place. + return None; + } + } + EitherOrBoth::Right(_) => { + // FIXME: This can happen in fake reads. I believe this is a bug. So we change the fake read's meaning. 
+ // never!( + // "mir upvar place shorter than hir upvar place; this should not happen, \ + // capture analysis should have picked the shorter place" + // ); + // return None; + return Some((mir_projections.len(), hir_place)); + } + // This place, but truncated. + EitherOrBoth::Left(_) => return Some((idx, hir_place)), } } - true + // Exactly this place. + Some((hir_place.0.place.projections.len(), hir_place)) }); - match r { - Some(it) => { - p.local = closure_local; - let mut next_projs = closure_projection.clone(); - next_projs.push(PlaceElem::ClosureField(it.1)); - let prev_projs = p.projection; - if it.0.kind != CaptureKind::ByValue { - next_projs.push(ProjectionElem::Deref); - } - next_projs.extend( - prev_projs.lookup(store).iter().skip(it.0.place.projections.len()).cloned(), + match projections { + Some((skip_projections_up_to, (hir_place, upvar_local))) => { + mir_place.local = *upvar_local; + let mut result_projections = Vec::with_capacity( + usize::from(hir_place.is_by_ref()) + + (mir_projections.len() - skip_projections_up_to), ); - p.projection = store.intern(next_projs.into()); + if hir_place.is_by_ref() { + result_projections.push(ProjectionElem::Deref); + } + result_projections + .extend(mir_projections[skip_projections_up_to..].iter().cloned()); + mir_place.projection = store.intern(result_projections.into()); } - None => err = Some(*p), + None => err = Some(*mir_place), } } }); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower/as_place.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower/as_place.rs index 17dc95fb248a3..fb4a9add818f3 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower/as_place.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir/lower/as_place.rs @@ -10,12 +10,6 @@ use crate::{ next_solver::Region, }; -macro_rules! 
not_supported { - ($it: expr) => { - return Err(MirLowerError::NotSupported(format!($it))) - }; -} - impl<'db> MirLowerCtx<'_, 'db> { fn lower_expr_to_some_place_without_adjust( &mut self, @@ -98,11 +92,8 @@ impl<'db> MirLowerCtx<'_, 'db> { last.target.as_ref(), expr_id.into(), match od.0 { - Some(Mutability::Mut) => true, - Some(Mutability::Not) => false, - None => { - not_supported!("implicit overloaded deref with unknown mutability") - } + Mutability::Mut => true, + Mutability::Not => false, }, ) } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir/monomorphization.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir/monomorphization.rs index 5752a3d7fae4b..41044f00c2e96 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir/monomorphization.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir/monomorphization.rs @@ -16,7 +16,10 @@ use triomphe::Arc; use crate::{ ParamEnvAndCrate, - next_solver::{Const, ConstKind, Region, RegionKind, StoredConst, StoredGenericArgs, StoredTy}, + next_solver::{ + Allocation, AllocationData, Const, ConstKind, Region, RegionKind, StoredConst, + StoredGenericArgs, StoredTy, + }, traits::StoredParamEnvAndCrate, }; use crate::{ @@ -138,6 +141,18 @@ impl<'db> Filler<'db> { self.fill_const(konst)?; self.fill_ty(ty)?; } + OperandKind::Allocation { allocation } => { + let alloc = allocation.as_ref(); + let mut ty = alloc.ty.store(); + self.fill_ty(&mut ty)?; + *allocation = Allocation::new(AllocationData { + ty: ty.as_ref(), + memory: alloc.memory.clone(), + // FIXME: Do we need to fill the memory map too? 
+ memory_map: alloc.memory_map.clone(), + }) + .store(); + } OperandKind::Copy(_) | OperandKind::Move(_) | OperandKind::Static(_) => (), } Ok(()) diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/mir/pretty.rs b/src/tools/rust-analyzer/crates/hir-ty/src/mir/pretty.rs index 4b654a0fbe085..de5ee223a1487 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/mir/pretty.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/mir/pretty.rs @@ -387,6 +387,9 @@ impl<'a, 'db> MirPrettyCtx<'a, 'db> { w!(self, "Const({})", self.hir_display(&konst.as_ref())) } OperandKind::Static(s) => w!(self, "Static({:?})", s), + OperandKind::Allocation { allocation } => { + w!(self, "Allocation({})", self.hir_display(&allocation.as_ref())) + } } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver.rs index 605e31404c575..161a3142df556 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver.rs @@ -5,6 +5,7 @@ // incorrect lifetime here. 
pub mod abi; +mod allocation; mod binder; mod consts; mod def_id; @@ -29,6 +30,7 @@ pub mod util; use std::{mem::ManuallyDrop, sync::OnceLock}; +pub use allocation::*; pub use binder::*; pub use consts::*; pub use def_id::*; @@ -89,6 +91,7 @@ pub struct DefaultTypes<'db> { pub struct DefaultConsts<'db> { pub error: Const<'db>, + pub u8_values: [Const<'db>; 256], } pub struct DefaultRegions<'db> { @@ -101,7 +104,7 @@ pub struct DefaultEmpty<'db> { pub tys: Tys<'db>, pub generic_args: GenericArgs<'db>, pub bound_var_kinds: BoundVarKinds<'db>, - pub canonical_vars: CanonicalVars<'db>, + pub canonical_vars: CanonicalVarKinds<'db>, pub variances: VariancesOf<'db>, pub pat_list: PatList<'db>, pub predefined_opaques: PredefinedOpaques<'db>, @@ -109,6 +112,7 @@ pub struct DefaultEmpty<'db> { pub bound_existential_predicates: BoundExistentialPredicates<'db>, pub clauses: Clauses<'db>, pub region_assumptions: RegionAssumptions<'db>, + pub consts: Consts<'db>, } pub struct DefaultAny<'db> { @@ -167,7 +171,7 @@ pub fn default_types<'a, 'db>(db: &'db dyn HirDatabase) -> &'a DefaultAny<'db> { ty.as_ref() }; let create_canonical_vars = |slice| { - let ty = CanonicalVars::new_from_slice(slice); + let ty = CanonicalVarKinds::new_from_slice(slice); // We need to increase the refcount (forever), so that the types won't be freed. let ty = ManuallyDrop::new(ty.store()); ty.as_ref() @@ -220,15 +224,22 @@ pub fn default_types<'a, 'db>(db: &'db dyn HirDatabase) -> &'a DefaultAny<'db> { let ty = ManuallyDrop::new(ty.store()); ty.as_ref() }; + let create_consts = |slice| { + let ty = Consts::new_from_slice(slice); + // We need to increase the refcount (forever), so that the types won't be freed. 
+ let ty = ManuallyDrop::new(ty.store()); + ty.as_ref() + }; let str = create_ty(TyKind::Str); let statik = create_region(RegionKind::ReStatic); let empty_tys = create_tys(&[]); let unit = create_ty(TyKind::Tuple(empty_tys)); + let u8 = create_ty(TyKind::Uint(rustc_ast_ir::UintTy::U8)); DefaultAny { types: DefaultTypes { usize: create_ty(TyKind::Uint(rustc_ast_ir::UintTy::Usize)), - u8: create_ty(TyKind::Uint(rustc_ast_ir::UintTy::U8)), + u8, u16: create_ty(TyKind::Uint(rustc_ast_ir::UintTy::U16)), u32: create_ty(TyKind::Uint(rustc_ast_ir::UintTy::U32)), u64: create_ty(TyKind::Uint(rustc_ast_ir::UintTy::U64)), @@ -252,7 +263,15 @@ pub fn default_types<'a, 'db>(db: &'db dyn HirDatabase) -> &'a DefaultAny<'db> { static_str_ref: create_ty(TyKind::Ref(statik, str, rustc_ast_ir::Mutability::Not)), mut_unit_ptr: create_ty(TyKind::RawPtr(unit, rustc_ast_ir::Mutability::Mut)), }, - consts: DefaultConsts { error: create_const(ConstKind::Error(ErrorGuaranteed)) }, + consts: DefaultConsts { + error: create_const(ConstKind::Error(ErrorGuaranteed)), + u8_values: std::array::from_fn(|u8_value| { + create_const(ConstKind::Value(ValueConst { + ty: u8, + value: ValTree::new(ValTreeKind::Leaf(ScalarInt::from(u8_value as u8))), + })) + }), + }, regions: DefaultRegions { error: create_region(RegionKind::ReError(ErrorGuaranteed)), statik, @@ -270,11 +289,12 @@ pub fn default_types<'a, 'db>(db: &'db dyn HirDatabase) -> &'a DefaultAny<'db> { bound_existential_predicates: create_bound_existential_predicates(&[]), clauses: create_clauses(&[]), region_assumptions: create_region_assumptions(&[]), + consts: create_consts(&[]), }, one_invariant: create_variances_of(&[rustc_type_ir::Variance::Invariant]), one_covariant: create_variances_of(&[rustc_type_ir::Variance::Covariant]), coroutine_captures_by_ref_bound_var_kinds: create_bound_var_kinds(&[ - BoundVarKind::Region(BoundRegionKind::ClosureEnv), + BoundVariableKind::Region(BoundRegionKind::ClosureEnv), ]), } }) diff --git 
a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/allocation.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/allocation.rs new file mode 100644 index 0000000000000..d299c89c12eae --- /dev/null +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/allocation.rs @@ -0,0 +1,73 @@ +use std::{fmt, hash::Hash}; + +use intern::{Interned, InternedRef, impl_internable}; +use macros::GenericTypeVisitable; +use rustc_type_ir::GenericTypeVisitable; + +use crate::{ + MemoryMap, + next_solver::{Ty, impl_stored_interned}, +}; + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct Allocation<'db> { + interned: InternedRef<'db, AllocationInterned>, +} + +impl<'db> Allocation<'db> { + pub fn new(data: AllocationData<'db>) -> Self { + let data = + unsafe { std::mem::transmute::, AllocationData<'static>>(data) }; + Self { interned: Interned::new_gc(AllocationInterned(data)) } + } +} + +impl<'db> std::ops::Deref for Allocation<'db> { + type Target = AllocationData<'db>; + + #[inline] + fn deref(&self) -> &Self::Target { + let inner = &self.interned.0; + unsafe { std::mem::transmute::<&AllocationData<'static>, &AllocationData<'db>>(inner) } + } +} + +impl<'db, V: super::WorldExposer> GenericTypeVisitable for Allocation<'db> { + fn generic_visit_with(&self, visitor: &mut V) { + if visitor.on_interned(self.interned).is_continue() { + (**self).generic_visit_with(visitor); + } + } +} + +impl fmt::Debug for Allocation<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let AllocationData { ty, memory, memory_map } = &**self; + f.debug_struct("Allocation") + .field("ty", ty) + .field("memory", memory) + .field("memory_map", memory_map) + .finish() + } +} + +#[derive(PartialEq, Eq, Hash, GenericTypeVisitable)] +pub(super) struct AllocationInterned(AllocationData<'static>); + +#[derive(Debug, PartialEq, Eq, GenericTypeVisitable)] +pub struct AllocationData<'db> { + pub ty: Ty<'db>, + pub memory: Box<[u8]>, + pub memory_map: MemoryMap<'db>, +} + 
+impl<'db> Hash for AllocationData<'db> { + fn hash(&self, state: &mut H) { + let Self { ty, memory, memory_map: _ } = self; + ty.hash(state); + memory.hash(state); + } +} + +impl_internable!(gc; AllocationInterned); +impl_stored_interned!(AllocationInterned, Allocation, StoredAllocation); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/consts.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/consts.rs index 9643f1ba4c3a3..fa90e3d8a004b 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/consts.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/consts.rs @@ -1,5 +1,7 @@ //! Things related to consts in the next-trait-solver. +mod valtree; + use std::hash::Hash; use hir_def::ConstParamId; @@ -9,17 +11,19 @@ use rustc_ast_ir::visit::VisitorResult; use rustc_type_ir::{ BoundVar, BoundVarIndexKind, ConstVid, DebruijnIndex, FlagComputation, Flags, GenericTypeVisitable, InferConst, TypeFoldable, TypeSuperFoldable, TypeSuperVisitable, - TypeVisitable, TypeVisitableExt, WithCachedTypeInfo, - inherent::{IntoKind, ParamEnv as _, PlaceholderLike, SliceLike}, - relate::Relate, + TypeVisitable, WithCachedTypeInfo, inherent::IntoKind, relate::Relate, }; use crate::{ - MemoryMap, - next_solver::{ClauseKind, ParamEnv, impl_stored_interned}, + ParamEnvAndCrate, + next_solver::{ + AllocationData, impl_foldable_for_interned_slice, impl_stored_interned, interned_slice, + }, }; -use super::{BoundVarKind, DbInterner, ErrorGuaranteed, GenericArgs, Placeholder, Ty}; +use super::{DbInterner, ErrorGuaranteed, GenericArgs, Ty}; + +pub use self::valtree::*; pub type ConstKind<'db> = rustc_type_ir::ConstKind>; pub type UnevaluatedConst<'db> = rustc_type_ir::UnevaluatedConst>; @@ -71,26 +75,33 @@ impl<'db> Const<'db> { Const::new(interner, ConstKind::Param(param)) } - pub fn new_placeholder(interner: DbInterner<'db>, placeholder: PlaceholderConst) -> Self { + pub fn new_placeholder(interner: DbInterner<'db>, placeholder: 
PlaceholderConst<'db>) -> Self { Const::new(interner, ConstKind::Placeholder(placeholder)) } - pub fn new_bound(interner: DbInterner<'db>, index: DebruijnIndex, bound: BoundConst) -> Self { + pub fn new_bound( + interner: DbInterner<'db>, + index: DebruijnIndex, + bound: BoundConst<'db>, + ) -> Self { Const::new(interner, ConstKind::Bound(BoundVarIndexKind::Bound(index), bound)) } - pub fn new_valtree( + pub fn new_valtree(interner: DbInterner<'db>, ty: Ty<'db>, kind: ValTreeKind<'db>) -> Self { + Const::new(interner, ConstKind::Value(ValueConst { ty, value: ValTree::new(kind) })) + } + + pub fn new_from_allocation( interner: DbInterner<'db>, - ty: Ty<'db>, - memory: Box<[u8]>, - memory_map: MemoryMap<'db>, + allocation: &AllocationData<'db>, + param_env: ParamEnvAndCrate<'db>, ) -> Self { - Const::new( + allocation_to_const( interner, - ConstKind::Value(ValueConst { - ty, - value: Valtree::new(ConstBytes { memory, memory_map }), - }), + allocation.ty, + &allocation.memory, + &allocation.memory_map, + param_env, ) } @@ -120,7 +131,7 @@ impl<'db> std::fmt::Debug for Const<'db> { } } -pub type PlaceholderConst = Placeholder; +pub type PlaceholderConst<'db> = rustc_type_ir::PlaceholderConst>; #[derive(Copy, Clone, Hash, Eq, PartialEq)] pub struct ParamConst { @@ -135,126 +146,6 @@ impl std::fmt::Debug for ParamConst { } } -impl ParamConst { - pub fn find_const_ty_from_env<'db>(self, env: ParamEnv<'db>) -> Ty<'db> { - let mut candidates = env.caller_bounds().iter().filter_map(|clause| { - // `ConstArgHasType` are never desugared to be higher ranked. - match clause.kind().skip_binder() { - ClauseKind::ConstArgHasType(param_ct, ty) => { - assert!(!(param_ct, ty).has_escaping_bound_vars()); - - match param_ct.kind() { - ConstKind::Param(param_ct) if param_ct.index == self.index => Some(ty), - _ => None, - } - } - _ => None, - } - }); - - // N.B. 
it may be tempting to fix ICEs by making this function return - // `Option>` instead of `Ty<'db>`; however, this is generally - // considered to be a bandaid solution, since it hides more important - // underlying issues with how we construct generics and predicates of - // items. It's advised to fix the underlying issue rather than trying - // to modify this function. - let ty = candidates.next().unwrap_or_else(|| { - panic!("cannot find `{self:?}` in param-env: {env:#?}"); - }); - assert!( - candidates.next().is_none(), - "did not expect duplicate `ConstParamHasTy` for `{self:?}` in param-env: {env:#?}" - ); - ty - } -} - -/// A type-level constant value. -/// -/// Represents a typed, fully evaluated constant. -#[derive( - Debug, Copy, Clone, Eq, PartialEq, Hash, TypeFoldable, TypeVisitable, GenericTypeVisitable, -)] -pub struct ValueConst<'db> { - pub ty: Ty<'db>, - // FIXME: Should we ignore this for TypeVisitable, TypeFoldable? - #[type_visitable(ignore)] - #[type_foldable(identity)] - pub value: Valtree<'db>, -} - -impl<'db> ValueConst<'db> { - pub fn new(ty: Ty<'db>, bytes: ConstBytes<'db>) -> Self { - let value = Valtree::new(bytes); - ValueConst { ty, value } - } -} - -impl<'db> rustc_type_ir::inherent::ValueConst> for ValueConst<'db> { - fn ty(self) -> Ty<'db> { - self.ty - } - - fn valtree(self) -> Valtree<'db> { - self.value - } -} - -#[derive(Debug, Clone, PartialEq, Eq, GenericTypeVisitable)] -pub struct ConstBytes<'db> { - pub memory: Box<[u8]>, - pub memory_map: MemoryMap<'db>, -} - -impl Hash for ConstBytes<'_> { - fn hash(&self, state: &mut H) { - self.memory.hash(state) - } -} - -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub struct Valtree<'db> { - interned: InternedRef<'db, ValtreeInterned>, -} - -impl<'db, V: super::WorldExposer> GenericTypeVisitable for Valtree<'db> { - fn generic_visit_with(&self, visitor: &mut V) { - if visitor.on_interned(self.interned).is_continue() { - self.inner().generic_visit_with(visitor); - } - } -} - 
-#[derive(Debug, PartialEq, Eq, Hash, GenericTypeVisitable)] -pub(super) struct ValtreeInterned(ConstBytes<'static>); - -impl_internable!(gc; ValtreeInterned); - -const _: () = { - const fn is_copy() {} - is_copy::>(); -}; - -impl<'db> Valtree<'db> { - #[inline] - pub fn new(bytes: ConstBytes<'db>) -> Self { - let bytes = unsafe { std::mem::transmute::, ConstBytes<'static>>(bytes) }; - Self { interned: Interned::new_gc(ValtreeInterned(bytes)) } - } - - #[inline] - pub fn inner(&self) -> &ConstBytes<'db> { - let inner = &self.interned.0; - unsafe { std::mem::transmute::<&ConstBytes<'static>, &ConstBytes<'db>>(inner) } - } -} - -impl std::fmt::Debug for Valtree<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.interned.fmt(f) - } -} - #[derive( Copy, Clone, Debug, Hash, PartialEq, Eq, TypeVisitable, TypeFoldable, GenericTypeVisitable, )] @@ -388,25 +279,22 @@ impl<'db> rustc_type_ir::inherent::Const> for Const<'db> { Const::new(interner, ConstKind::Infer(InferConst::Var(var))) } - fn new_bound(interner: DbInterner<'db>, debruijn: DebruijnIndex, var: BoundConst) -> Self { + fn new_bound(interner: DbInterner<'db>, debruijn: DebruijnIndex, var: BoundConst<'db>) -> Self { Const::new(interner, ConstKind::Bound(BoundVarIndexKind::Bound(debruijn), var)) } fn new_anon_bound(interner: DbInterner<'db>, debruijn: DebruijnIndex, var: BoundVar) -> Self { Const::new( interner, - ConstKind::Bound(BoundVarIndexKind::Bound(debruijn), BoundConst { var }), + ConstKind::Bound(BoundVarIndexKind::Bound(debruijn), BoundConst::new(var)), ) } fn new_canonical_bound(interner: DbInterner<'db>, var: BoundVar) -> Self { - Const::new(interner, ConstKind::Bound(BoundVarIndexKind::Canonical, BoundConst { var })) + Const::new(interner, ConstKind::Bound(BoundVarIndexKind::Canonical, BoundConst::new(var))) } - fn new_placeholder( - interner: DbInterner<'db>, - param: as rustc_type_ir::Interner>::PlaceholderConst, - ) -> Self { + fn new_placeholder(interner: 
DbInterner<'db>, param: PlaceholderConst<'db>) -> Self { Const::new(interner, ConstKind::Placeholder(param)) } @@ -426,43 +314,7 @@ impl<'db> rustc_type_ir::inherent::Const> for Const<'db> { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -pub struct BoundConst { - pub var: BoundVar, -} - -impl<'db> rustc_type_ir::inherent::BoundVarLike> for BoundConst { - fn var(self) -> BoundVar { - self.var - } - - fn assert_eq(self, var: BoundVarKind) { - var.expect_const() - } -} - -impl<'db> PlaceholderLike> for PlaceholderConst { - type Bound = BoundConst; - - fn universe(self) -> rustc_type_ir::UniverseIndex { - self.universe - } - - fn var(self) -> rustc_type_ir::BoundVar { - self.bound.var - } - - fn with_updated_universe(self, ui: rustc_type_ir::UniverseIndex) -> Self { - Placeholder { universe: ui, bound: self.bound } - } - - fn new(ui: rustc_type_ir::UniverseIndex, var: BoundConst) -> Self { - Placeholder { universe: ui, bound: var } - } - fn new_anon(ui: rustc_type_ir::UniverseIndex, var: rustc_type_ir::BoundVar) -> Self { - Placeholder { universe: ui, bound: BoundConst { var } } - } -} +pub type BoundConst<'db> = rustc_type_ir::BoundConst>; impl<'db> Relate> for ExprConst { fn relate>>( @@ -483,3 +335,6 @@ impl<'db> rustc_type_ir::inherent::ExprConst> for ExprConst { GenericArgs::default() } } + +interned_slice!(ConstsStorage, Consts, StoredConsts, consts, Const<'db>, Const<'static>); +impl_foldable_for_interned_slice!(Consts); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/consts/valtree.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/consts/valtree.rs new file mode 100644 index 0000000000000..b856ee5a85a76 --- /dev/null +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/consts/valtree.rs @@ -0,0 +1,712 @@ +use std::{fmt, hash::Hash, num::NonZero}; + +use intern::{Interned, InternedRef, impl_internable}; +use macros::{GenericTypeVisitable, TypeFoldable, TypeVisitable}; +use rustc_abi::{Size, TargetDataLayout}; +use 
rustc_type_ir::{GenericTypeVisitable, TypeFoldable, TypeVisitable, inherent::IntoKind}; +use stdx::never; + +use crate::{ + MemoryMap, ParamEnvAndCrate, consteval, + mir::pad16, + next_solver::{Const, Consts, TyKind, WorldExposer}, +}; + +use super::{DbInterner, Ty}; + +pub type ValTreeKind<'db> = rustc_type_ir::ValTreeKind>; + +/// A type-level constant value. +/// +/// Represents a typed, fully evaluated constant. +#[derive( + Debug, Copy, Clone, Eq, PartialEq, Hash, TypeFoldable, TypeVisitable, GenericTypeVisitable, +)] +pub struct ValueConst<'db> { + pub ty: Ty<'db>, + pub value: ValTree<'db>, +} + +impl<'db> ValueConst<'db> { + pub fn new(ty: Ty<'db>, kind: ValTreeKind<'db>) -> Self { + let value = ValTree::new(kind); + ValueConst { ty, value } + } +} + +pub(super) fn allocation_to_const<'db>( + interner: DbInterner<'db>, + ty: Ty<'db>, + memory: &[u8], + memory_map: &MemoryMap<'db>, + param_env: ParamEnvAndCrate<'db>, +) -> Const<'db> { + let Ok(data_layout) = interner.db.target_data_layout(param_env.krate) else { + return Const::error(interner); + }; + let valtree = match ty.kind() { + TyKind::Bool => ValTreeKind::Leaf(ScalarInt::from(memory[0] != 0)), + TyKind::Char => { + let it = u128::from_le_bytes(pad16(memory, false)) as u32; + let Ok(c) = char::try_from(it) else { + return Const::error(interner); + }; + ValTreeKind::Leaf(ScalarInt::from(c)) + } + TyKind::Int(int) => { + let it = i128::from_le_bytes(pad16(memory, true)); + let size = int.bit_width().map(Size::from_bits).unwrap_or(data_layout.pointer_size()); + let scalar = ScalarInt::try_from_int(it, size).unwrap(); + ValTreeKind::Leaf(scalar) + } + TyKind::Uint(uint) => { + let it = u128::from_le_bytes(pad16(memory, false)); + let size = uint.bit_width().map(Size::from_bits).unwrap_or(data_layout.pointer_size()); + let scalar = ScalarInt::try_from_uint(it, size).unwrap(); + ValTreeKind::Leaf(scalar) + } + TyKind::Float(float) => { + let scalar = match float { + rustc_ast_ir::FloatTy::F16 => { + 
ScalarInt::from(u16::from_le_bytes(memory.try_into().unwrap())) + } + rustc_ast_ir::FloatTy::F32 => { + ScalarInt::from(u32::from_le_bytes(memory.try_into().unwrap())) + } + rustc_ast_ir::FloatTy::F64 => { + ScalarInt::from(u64::from_le_bytes(memory.try_into().unwrap())) + } + rustc_ast_ir::FloatTy::F128 => { + ScalarInt::from(u128::from_le_bytes(memory.try_into().unwrap())) + } + }; + ValTreeKind::Leaf(scalar) + } + TyKind::Ref(_, t, _) => match t.kind() { + TyKind::Str => { + let addr = usize::from_le_bytes(memory[0..memory.len() / 2].try_into().unwrap()); + let size = usize::from_le_bytes(memory[memory.len() / 2..].try_into().unwrap()); + let Some(bytes) = memory_map.get(addr, size) else { + return Const::error(interner); + }; + let u8_values = &interner.default_types().consts.u8_values; + ValTreeKind::Branch(Consts::new_from_iter( + interner, + bytes.iter().map(|&byte| u8_values[usize::from(byte)]), + )) + } + TyKind::Slice(ty) => { + let addr = usize::from_le_bytes(memory[0..memory.len() / 2].try_into().unwrap()); + let count = usize::from_le_bytes(memory[memory.len() / 2..].try_into().unwrap()); + let Ok(layout) = interner.db.layout_of_ty(ty.store(), param_env.store()) else { + return Const::error(interner); + }; + let size_one = layout.size.bytes_usize(); + let Some(bytes) = memory_map.get(addr, size_one * count) else { + return Const::error(interner); + }; + let expected_len = count * size_one; + if bytes.len() < expected_len { + never!( + "Memory map size is too small. 
Expected {expected_len}, got {}", + bytes.len(), + ); + return Const::error(interner); + } + let items = (0..count).map(|i| { + let offset = size_one * i; + let bytes = &bytes[offset..offset + size_one]; + allocation_to_const(interner, ty, bytes, memory_map, param_env) + }); + ValTreeKind::Branch(Consts::new_from_iter(interner, items)) + } + TyKind::Dynamic(_, _) => { + let addr = usize::from_le_bytes(memory[0..memory.len() / 2].try_into().unwrap()); + let ty_id = usize::from_le_bytes(memory[memory.len() / 2..].try_into().unwrap()); + let Ok(t) = memory_map.vtable_ty(ty_id) else { + return Const::error(interner); + }; + let Ok(layout) = interner.db.layout_of_ty(t.store(), param_env.store()) else { + return Const::error(interner); + }; + let size = layout.size.bytes_usize(); + let Some(bytes) = memory_map.get(addr, size) else { + return Const::error(interner); + }; + return allocation_to_const(interner, t, bytes, memory_map, param_env); + } + TyKind::Adt(..) if memory.len() == 2 * size_of::() => { + // FIXME: Unsized ADT. 
+ return Const::error(interner); + } + _ => { + let addr = usize::from_le_bytes(match memory.try_into() { + Ok(b) => b, + Err(_) => { + never!( + "tried rendering ty {:?} in const ref with incorrect byte count {}", + t, + memory.len() + ); + return Const::error(interner); + } + }); + let Ok(layout) = interner.db.layout_of_ty(t.store(), param_env.store()) else { + return Const::error(interner); + }; + let size = layout.size.bytes_usize(); + let Some(bytes) = memory_map.get(addr, size) else { + return Const::error(interner); + }; + return allocation_to_const(interner, t, bytes, memory_map, param_env); + } + }, + TyKind::Tuple(tys) => { + let Ok(layout) = interner.db.layout_of_ty(ty.store(), param_env.store()) else { + return Const::error(interner); + }; + let items = tys.iter().enumerate().map(|(id, ty)| { + let offset = layout.fields.offset(id).bytes_usize(); + let Ok(layout) = interner.db.layout_of_ty(ty.store(), param_env.store()) else { + return Const::error(interner); + }; + let size = layout.size.bytes_usize(); + allocation_to_const( + interner, + ty, + &memory[offset..offset + size], + memory_map, + param_env, + ) + }); + ValTreeKind::Branch(Consts::new_from_iter(interner, items)) + } + TyKind::Adt(..) => { + // FIXME: This requires `adt_const_params`. + return Const::error(interner); + } + TyKind::FnDef(..) => { + // FIXME: Fn items. + return Const::error(interner); + } + TyKind::FnPtr(_, _) | TyKind::RawPtr(_, _) => { + let it = u128::from_le_bytes(pad16(memory, false)); + // FIXME: Unsized pointers. 
+ let scalar = ScalarInt::try_from_uint(it, data_layout.pointer_size()).unwrap(); + ValTreeKind::Leaf(scalar) + } + TyKind::Array(ty, len) => { + let Some(len) = consteval::try_const_usize(interner.db, len) else { + return Const::error(interner); + }; + let Ok(layout) = interner.db.layout_of_ty(ty.store(), param_env.store()) else { + return Const::error(interner); + }; + let size_one = layout.size.bytes_usize(); + let items = (0..len as usize).map(|i| { + let offset = size_one * i; + allocation_to_const( + interner, + ty, + &memory[offset..offset + size_one], + memory_map, + param_env, + ) + }); + ValTreeKind::Branch(Consts::new_from_iter(interner, items)) + } + TyKind::Never => return Const::error(interner), + // FIXME: + TyKind::Closure(_, _) + | TyKind::Coroutine(_, _) + | TyKind::CoroutineWitness(_, _) + | TyKind::CoroutineClosure(_, _) + | TyKind::UnsafeBinder(_) => return Const::error(interner), + // The below arms are unreachable, since const eval will bail out before here. + TyKind::Foreign(_) => return Const::error(interner), + TyKind::Pat(_, _) => return Const::error(interner), + TyKind::Error(..) + | TyKind::Placeholder(_) + | TyKind::Alias(..) + | TyKind::Param(_) + | TyKind::Bound(_, _) + | TyKind::Infer(_) => return Const::error(interner), + // The below arms are unreachable, since we handled them in ref case. 
+ TyKind::Slice(_) | TyKind::Str | TyKind::Dynamic(_, _) => { + return Const::error(interner); + } + }; + Const::new_valtree(interner, ty, valtree) +} + +impl<'db> rustc_type_ir::inherent::ValueConst> for ValueConst<'db> { + fn ty(self) -> Ty<'db> { + self.ty + } + + fn valtree(self) -> ValTree<'db> { + self.value + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct ValTree<'db> { + interned: InternedRef<'db, ValTreeInterned>, +} + +impl<'db, V: WorldExposer> GenericTypeVisitable for ValTree<'db> { + fn generic_visit_with(&self, visitor: &mut V) { + if visitor.on_interned(self.interned).is_continue() { + self.inner().generic_visit_with(visitor); + } + } +} + +impl<'db> TypeVisitable> for ValTree<'db> { + fn visit_with>>( + &self, + visitor: &mut V, + ) -> V::Result { + self.inner().visit_with(visitor) + } +} + +impl<'db> TypeFoldable> for ValTree<'db> { + fn try_fold_with>>( + self, + folder: &mut F, + ) -> Result { + self.inner().try_fold_with(folder).map(ValTree::new) + } + + fn fold_with>>(self, folder: &mut F) -> Self { + ValTree::new(self.inner().fold_with(folder)) + } +} + +#[derive(Debug, PartialEq, Eq, Hash, GenericTypeVisitable)] +pub(in super::super) struct ValTreeInterned(ValTreeKind<'static>); + +impl_internable!(gc; ValTreeInterned); + +const _: () = { + const fn is_copy() {} + is_copy::>(); +}; + +impl<'db> IntoKind for ValTree<'db> { + type Kind = ValTreeKind<'db>; + + fn kind(self) -> Self::Kind { + *self.inner() + } +} + +impl<'db> ValTree<'db> { + #[inline] + pub fn new(kind: ValTreeKind<'db>) -> Self { + let kind = unsafe { std::mem::transmute::, ValTreeKind<'static>>(kind) }; + Self { interned: Interned::new_gc(ValTreeInterned(kind)) } + } + + #[inline] + pub fn inner(&self) -> &ValTreeKind<'db> { + let inner = &self.interned.0; + unsafe { std::mem::transmute::<&ValTreeKind<'static>, &ValTreeKind<'db>>(inner) } + } +} + +impl std::fmt::Debug for ValTree<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
self.interned.fmt(f) + } +} + +/// The raw bytes of a simple value. +/// +/// This is a packed struct in order to allow this type to be optimally embedded in enums +/// (like Scalar). +#[derive(Clone, Copy, Eq, PartialEq, Hash)] +#[repr(Rust, packed)] +pub struct ScalarInt { + /// The first `size` bytes of `data` are the value. + /// Do not try to read less or more bytes than that. The remaining bytes must be 0. + data: u128, + size: NonZero, +} + +impl ScalarInt { + pub const TRUE: ScalarInt = ScalarInt { data: 1_u128, size: NonZero::new(1).unwrap() }; + pub const FALSE: ScalarInt = ScalarInt { data: 0_u128, size: NonZero::new(1).unwrap() }; + + fn raw(data: u128, size: Size) -> Self { + Self { data, size: NonZero::new(size.bytes() as u8).unwrap() } + } + + #[inline] + pub fn size(self) -> Size { + Size::from_bytes(self.size.get()) + } + + /// Make sure the `data` fits in `size`. + /// This is guaranteed by all constructors here, but having had this check saved us from + /// bugs many times in the past, so keeping it around is definitely worth it. + #[inline(always)] + fn check_data(self) { + // Using a block `{self.data}` here to force a copy instead of using `self.data` + // directly, because `debug_assert_eq` takes references to its arguments and formatting + // arguments and would thus borrow `self.data`. Since `Self` + // is a packed struct, that would create a possibly unaligned reference, which + // is UB. + debug_assert_eq!( + self.size().truncate(self.data), + { self.data }, + "Scalar value {:#x} exceeds size of {} bytes", + { self.data }, + self.size + ); + } + + #[inline] + pub fn null(size: Size) -> Self { + Self::raw(0, size) + } + + #[inline] + pub fn is_null(self) -> bool { + self.data == 0 + } + + #[inline] + pub fn try_from_uint(i: impl Into, size: Size) -> Option { + let (r, overflow) = Self::truncate_from_uint(i, size); + if overflow { None } else { Some(r) } + } + + /// Returns the truncated result, and whether truncation changed the value. 
+ #[inline] + pub fn truncate_from_uint(i: impl Into, size: Size) -> (Self, bool) { + let data = i.into(); + let r = Self::raw(size.truncate(data), size); + (r, r.data != data) + } + + #[inline] + pub fn try_from_int(i: impl Into, size: Size) -> Option { + let (r, overflow) = Self::truncate_from_int(i, size); + if overflow { None } else { Some(r) } + } + + /// Returns the truncated result, and whether truncation changed the value. + #[inline] + pub fn truncate_from_int(i: impl Into, size: Size) -> (Self, bool) { + let data = i.into(); + // `into` performed sign extension, we have to truncate + let r = Self::raw(size.truncate(data as u128), size); + (r, size.sign_extend(r.data) != data) + } + + #[inline] + pub fn try_from_target_usize( + i: impl Into, + data_layout: &TargetDataLayout, + ) -> Option { + Self::try_from_uint(i, data_layout.pointer_size()) + } + + /// Try to convert this ScalarInt to the raw underlying bits. + /// Fails if the size is wrong. Generally a wrong size should lead to a panic, + /// but Miri sometimes wants to be resilient to size mismatches, + /// so the interpreter will generally use this `try` method. + #[inline] + pub fn try_to_bits(self, target_size: Size) -> Result { + assert_ne!(target_size.bytes(), 0, "you should never look at the bits of a ZST"); + if target_size.bytes() == u64::from(self.size.get()) { + self.check_data(); + Ok(self.data) + } else { + Err(self.size()) + } + } + + #[inline] + pub fn to_bits(self, target_size: Size) -> u128 { + self.try_to_bits(target_size).unwrap_or_else(|size| { + panic!("expected int of size {}, but got size {}", target_size.bytes(), size.bytes()) + }) + } + + /// Extracts the bits from the scalar without checking the size. + #[inline] + pub fn to_bits_unchecked(self) -> u128 { + self.check_data(); + self.data + } + + /// Converts the `ScalarInt` to an unsigned integer of the given size. + /// Panics if the size of the `ScalarInt` is not equal to `size`. 
+    #[inline]
+    pub fn to_uint(self, size: Size) -> u128 {
+        self.to_bits(size)
+    }
+
+    #[inline]
+    pub fn to_uint_unchecked(self) -> u128 {
+        self.data
+    }
+
+    /// Converts the `ScalarInt` to `u8`.
+    /// Panics if the `size` of the `ScalarInt` is not equal to 1 byte.
+    #[inline]
+    pub fn to_u8(self) -> u8 {
+        self.to_uint(Size::from_bits(8)).try_into().unwrap()
+    }
+
+    /// Converts the `ScalarInt` to `u16`.
+    /// Panics if the size of the `ScalarInt` is not equal to 2 bytes.
+    #[inline]
+    pub fn to_u16(self) -> u16 {
+        self.to_uint(Size::from_bits(16)).try_into().unwrap()
+    }
+
+    /// Converts the `ScalarInt` to `u32`.
+    /// Panics if the `size` of the `ScalarInt` is not equal to 4 bytes.
+    #[inline]
+    pub fn to_u32(self) -> u32 {
+        self.to_uint(Size::from_bits(32)).try_into().unwrap()
+    }
+
+    /// Converts the `ScalarInt` to `u64`.
+    /// Panics if the `size` of the `ScalarInt` is not equal to 8 bytes.
+    #[inline]
+    pub fn to_u64(self) -> u64 {
+        self.to_uint(Size::from_bits(64)).try_into().unwrap()
+    }
+
+    /// Converts the `ScalarInt` to `u128`.
+    /// Panics if the `size` of the `ScalarInt` is not equal to 16 bytes.
+    #[inline]
+    pub fn to_u128(self) -> u128 {
+        self.to_uint(Size::from_bits(128))
+    }
+
+    #[inline]
+    pub fn to_target_usize(&self, data_layout: &TargetDataLayout) -> u64 {
+        self.to_uint(data_layout.pointer_size()).try_into().unwrap()
+    }
+
+    /// Converts the `ScalarInt` to `bool`.
+    /// Panics if the `size` of the `ScalarInt` is not equal to 1 byte.
+    /// Errors if it is not a valid `bool`.
+    #[inline]
+    pub fn try_to_bool(self) -> Result<bool, ()> {
+        match self.to_u8() {
+            0 => Ok(false),
+            1 => Ok(true),
+            _ => Err(()),
+        }
+    }
+
+    /// Converts the `ScalarInt` to a signed integer of the given size.
+    /// Panics if the size of the `ScalarInt` is not equal to `size`.
+ #[inline] + pub fn to_int(self, size: Size) -> i128 { + let b = self.to_bits(size); + size.sign_extend(b) + } + + #[inline] + pub fn to_int_unchecked(self) -> i128 { + self.size().sign_extend(self.data) + } + + /// Converts the `ScalarInt` to i8. + /// Panics if the size of the `ScalarInt` is not equal to 1 byte. + pub fn to_i8(self) -> i8 { + self.to_int(Size::from_bits(8)).try_into().unwrap() + } + + /// Converts the `ScalarInt` to i16. + /// Panics if the size of the `ScalarInt` is not equal to 2 bytes. + pub fn to_i16(self) -> i16 { + self.to_int(Size::from_bits(16)).try_into().unwrap() + } + + /// Converts the `ScalarInt` to i32. + /// Panics if the size of the `ScalarInt` is not equal to 4 bytes. + pub fn to_i32(self) -> i32 { + self.to_int(Size::from_bits(32)).try_into().unwrap() + } + + /// Converts the `ScalarInt` to i64. + /// Panics if the size of the `ScalarInt` is not equal to 8 bytes. + pub fn to_i64(self) -> i64 { + self.to_int(Size::from_bits(64)).try_into().unwrap() + } + + /// Converts the `ScalarInt` to i128. + /// Panics if the size of the `ScalarInt` is not equal to 16 bytes. + pub fn to_i128(self) -> i128 { + self.to_int(Size::from_bits(128)) + } + + #[inline] + pub fn to_target_isize(&self, data_layout: &TargetDataLayout) -> i64 { + self.to_int(data_layout.pointer_size()).try_into().unwrap() + } +} + +macro_rules! from_x_for_scalar_int { + ($($ty:ty),*) => { + $( + impl From<$ty> for ScalarInt { + #[inline] + fn from(u: $ty) -> Self { + Self { + data: u128::from(u), + size: NonZero::new(size_of::<$ty>() as u8).unwrap(), + } + } + } + )* + } +} + +macro_rules! from_scalar_int_for_x { + ($($ty:ty),*) => { + $( + impl From for $ty { + #[inline] + fn from(int: ScalarInt) -> Self { + // The `unwrap` cannot fail because to_uint (if it succeeds) + // is guaranteed to return a value that fits into the size. 
+ int.to_uint(Size::from_bytes(size_of::<$ty>())) + .try_into().unwrap() + } + } + )* + } +} + +from_x_for_scalar_int!(u8, u16, u32, u64, u128, bool); +from_scalar_int_for_x!(u8, u16, u32, u64, u128); + +impl TryFrom for bool { + type Error = (); + #[inline] + fn try_from(int: ScalarInt) -> Result { + int.try_to_bool() + } +} + +impl From for ScalarInt { + #[inline] + fn from(c: char) -> Self { + (c as u32).into() + } +} + +macro_rules! from_x_for_scalar_int_signed { + ($($ty:ty),*) => { + $( + impl From<$ty> for ScalarInt { + #[inline] + fn from(u: $ty) -> Self { + Self { + data: u128::from(u.cast_unsigned()), // go via the unsigned type of the same size + size: NonZero::new(size_of::<$ty>() as u8).unwrap(), + } + } + } + )* + } +} + +macro_rules! from_scalar_int_for_x_signed { + ($($ty:ty),*) => { + $( + impl From for $ty { + #[inline] + fn from(int: ScalarInt) -> Self { + // The `unwrap` cannot fail because to_int (if it succeeds) + // is guaranteed to return a value that fits into the size. + int.to_int(Size::from_bytes(size_of::<$ty>())) + .try_into().unwrap() + } + } + )* + } +} + +from_x_for_scalar_int_signed!(i8, i16, i32, i64, i128); +from_scalar_int_for_x_signed!(i8, i16, i32, i64, i128); + +impl From for ScalarInt { + #[inline] + fn from(c: std::cmp::Ordering) -> Self { + // Here we rely on `cmp::Ordering` having the same values in host and target! + ScalarInt::from(c as i8) + } +} + +/// Error returned when a conversion from ScalarInt to char fails. +#[derive(Debug)] +pub struct CharTryFromScalarInt; + +impl TryFrom for char { + type Error = CharTryFromScalarInt; + + #[inline] + fn try_from(int: ScalarInt) -> Result { + match char::from_u32(int.to_u32()) { + Some(c) => Ok(c), + None => Err(CharTryFromScalarInt), + } + } +} + +impl fmt::Debug for ScalarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Dispatch to LowerHex below. 
+ write!(f, "0x{self:x}") + } +} + +impl fmt::LowerHex for ScalarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.check_data(); + if f.alternate() { + // Like regular ints, alternate flag adds leading `0x`. + write!(f, "0x")?; + } + // Format as hex number wide enough to fit any value of the given `size`. + // So data=20, size=1 will be "0x14", but with size=4 it'll be "0x00000014". + // Using a block `{self.data}` here to force a copy instead of using `self.data` + // directly, because `write!` takes references to its formatting arguments and + // would thus borrow `self.data`. Since `Self` + // is a packed struct, that would create a possibly unaligned reference, which + // is UB. + write!(f, "{:01$x}", { self.data }, self.size.get() as usize * 2) + } +} + +impl fmt::UpperHex for ScalarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.check_data(); + // Format as hex number wide enough to fit any value of the given `size`. + // So data=20, size=1 will be "0x14", but with size=4 it'll be "0x00000014". + // Using a block `{self.data}` here to force a copy instead of using `self.data` + // directly, because `write!` takes references to its formatting arguments and + // would thus borrow `self.data`. Since `Self` + // is a packed struct, that would create a possibly unaligned reference, which + // is UB. 
+ write!(f, "{:01$X}", { self.data }, self.size.get() as usize * 2) + } +} + +impl fmt::Display for ScalarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.check_data(); + write!(f, "{}", { self.data }) + } +} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/def_id.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/def_id.rs index 00161d6d08250..542eca3ded243 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/def_id.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/def_id.rs @@ -12,7 +12,9 @@ use hir_def::{ use rustc_type_ir::inherent; use stdx::impl_from; -use crate::db::{InternedClosureId, InternedCoroutineId, InternedOpaqueTyId}; +use crate::db::{ + InternedClosureId, InternedCoroutineClosureId, InternedCoroutineId, InternedOpaqueTyId, +}; use super::DbInterner; @@ -35,6 +37,7 @@ pub enum SolverDefId { TypeAliasId(TypeAliasId), InternedClosureId(InternedClosureId), InternedCoroutineId(InternedCoroutineId), + InternedCoroutineClosureId(InternedCoroutineClosureId), InternedOpaqueTyId(InternedOpaqueTyId), EnumVariantId(EnumVariantId), Ctor(Ctor), @@ -80,6 +83,9 @@ impl std::fmt::Debug for SolverDefId { SolverDefId::InternedCoroutineId(id) => { f.debug_tuple("InternedCoroutineId").field(&id).finish() } + SolverDefId::InternedCoroutineClosureId(id) => { + f.debug_tuple("InternedCoroutineClosureId").field(&id).finish() + } SolverDefId::InternedOpaqueTyId(id) => { f.debug_tuple("InternedOpaqueTyId").field(&id).finish() } @@ -123,6 +129,7 @@ impl_from!( TypeAliasId, InternedClosureId, InternedCoroutineId, + InternedCoroutineClosureId, InternedOpaqueTyId, EnumVariantId, Ctor @@ -206,6 +213,7 @@ impl TryFrom for AttrDefId { SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::AnonConstId(_) => Err(()), } @@ -229,6 +237,7 @@ impl TryFrom for 
DefWithBodyId { | SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::Ctor(Ctor::Struct(_)) | SolverDefId::AnonConstId(_) | SolverDefId::AdtId(_) => return Err(()), @@ -251,6 +260,7 @@ impl TryFrom for GenericDefId { SolverDefId::TypeAliasId(type_alias_id) => GenericDefId::TypeAliasId(type_alias_id), SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::EnumVariantId(_) | SolverDefId::BuiltinDeriveImplId(_) @@ -348,6 +358,7 @@ declare_id_wrapper!(TraitIdWrapper, TraitId); declare_id_wrapper!(TypeAliasIdWrapper, TypeAliasId); declare_id_wrapper!(ClosureIdWrapper, InternedClosureId); declare_id_wrapper!(CoroutineIdWrapper, InternedCoroutineId); +declare_id_wrapper!(CoroutineClosureIdWrapper, InternedCoroutineClosureId); declare_id_wrapper!(AdtIdWrapper, AdtId); #[derive(Clone, Copy, PartialEq, Eq, Hash)] diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fold.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fold.rs index 7836419e8b751..af823aa005d08 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fold.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fold.rs @@ -17,28 +17,28 @@ use super::{ /// gets mapped to the same result. `BoundVarReplacer` caches by using /// a `DelayedMap` which does not cache the first few types it encounters. 
pub trait BoundVarReplacerDelegate<'db> { - fn replace_region(&mut self, br: BoundRegion) -> Region<'db>; - fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db>; - fn replace_const(&mut self, bv: BoundConst) -> Const<'db>; + fn replace_region(&mut self, br: BoundRegion<'db>) -> Region<'db>; + fn replace_ty(&mut self, bt: BoundTy<'db>) -> Ty<'db>; + fn replace_const(&mut self, bv: BoundConst<'db>) -> Const<'db>; } /// A simple delegate taking 3 mutable functions. The used functions must /// always return the same result for each bound variable, no matter how /// frequently they are called. pub struct FnMutDelegate<'db, 'a> { - pub regions: &'a mut (dyn FnMut(BoundRegion) -> Region<'db> + 'a), - pub types: &'a mut (dyn FnMut(BoundTy) -> Ty<'db> + 'a), - pub consts: &'a mut (dyn FnMut(BoundConst) -> Const<'db> + 'a), + pub regions: &'a mut (dyn FnMut(BoundRegion<'db>) -> Region<'db> + 'a), + pub types: &'a mut (dyn FnMut(BoundTy<'db>) -> Ty<'db> + 'a), + pub consts: &'a mut (dyn FnMut(BoundConst<'db>) -> Const<'db> + 'a), } impl<'db, 'a> BoundVarReplacerDelegate<'db> for FnMutDelegate<'db, 'a> { - fn replace_region(&mut self, br: BoundRegion) -> Region<'db> { + fn replace_region(&mut self, br: BoundRegion<'db>) -> Region<'db> { (self.regions)(br) } - fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> { + fn replace_ty(&mut self, bt: BoundTy<'db>) -> Ty<'db> { (self.types)(bt) } - fn replace_const(&mut self, bv: BoundConst) -> Const<'db> { + fn replace_const(&mut self, bv: BoundConst<'db>) -> Const<'db> { (self.consts)(bv) } } @@ -177,13 +177,13 @@ impl<'db> DbInterner<'db> { self, value: Binder<'db, T>, mut fld_r: F, - ) -> (T, FxIndexMap>) + ) -> (T, FxIndexMap, Region<'db>>) where - F: FnMut(BoundRegion) -> Region<'db>, + F: FnMut(BoundRegion<'db>) -> Region<'db>, T: TypeFoldable>, { let mut region_map = FxIndexMap::default(); - let real_fld_r = |br: BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br)); + let real_fld_r = |br: BoundRegion<'db>| 
*region_map.entry(br).or_insert_with(|| fld_r(br)); let value = self.instantiate_bound_regions_uncached(value, real_fld_r); (value, region_map) } @@ -194,7 +194,7 @@ impl<'db> DbInterner<'db> { mut replace_regions: F, ) -> T where - F: FnMut(BoundRegion) -> Region<'db>, + F: FnMut(BoundRegion<'db>) -> Region<'db>, T: TypeFoldable>, { let value = value.skip_binder(); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fulfill.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fulfill.rs index a8bff44a02583..6739795a0060d 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fulfill.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fulfill.rs @@ -1,7 +1,5 @@ //! Fulfill loop for next-solver. -mod errors; - use std::ops::ControlFlow; use rustc_hash::FxHashSet; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fulfill/errors.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fulfill/errors.rs deleted file mode 100644 index 0e8218b33aaa2..0000000000000 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/fulfill/errors.rs +++ /dev/null @@ -1,1173 +0,0 @@ -//! Trait solving error diagnosis and reporting. -//! -//! This code isn't used by rust-analyzer (it should, but then it'll probably be better to re-port it from rustc). -//! It's only there because without it, debugging trait solver errors is a nightmare. 
- -use std::{fmt::Debug, ops::ControlFlow}; - -use rustc_next_trait_solver::solve::{GoalEvaluation, SolverDelegateEvalExt}; -use rustc_type_ir::{ - AliasRelationDirection, AliasTermKind, HostEffectPredicate, Interner, PredicatePolarity, - error::ExpectedFound, - inherent::{IntoKind, Span as _}, - lang_items::SolverTraitLangItem, - solve::{Certainty, GoalSource, MaybeCause, NoSolution}, -}; -use tracing::{instrument, trace}; - -use crate::next_solver::{ - AliasTerm, Binder, ClauseKind, Const, ConstKind, DbInterner, PolyTraitPredicate, PredicateKind, - SolverContext, Span, Term, TraitPredicate, Ty, TyKind, TypeError, - fulfill::NextSolverError, - infer::{ - InferCtxt, - select::SelectionError, - traits::{Obligation, ObligationCause, PredicateObligation, PredicateObligations}, - }, - inspect::{self, ProofTreeVisitor}, - normalize::deeply_normalize_for_diagnostics, -}; - -#[derive(Debug)] -pub struct FulfillmentError<'db> { - pub obligation: PredicateObligation<'db>, - pub code: FulfillmentErrorCode<'db>, - /// Diagnostics only: the 'root' obligation which resulted in - /// the failure to process `obligation`. 
This is the obligation - /// that was initially passed to `register_predicate_obligation` - pub root_obligation: PredicateObligation<'db>, -} - -impl<'db> FulfillmentError<'db> { - pub fn new( - obligation: PredicateObligation<'db>, - code: FulfillmentErrorCode<'db>, - root_obligation: PredicateObligation<'db>, - ) -> FulfillmentError<'db> { - FulfillmentError { obligation, code, root_obligation } - } - - pub fn is_true_error(&self) -> bool { - match self.code { - FulfillmentErrorCode::Select(_) - | FulfillmentErrorCode::Project(_) - | FulfillmentErrorCode::Subtype(_, _) - | FulfillmentErrorCode::ConstEquate(_, _) => true, - FulfillmentErrorCode::Cycle(_) | FulfillmentErrorCode::Ambiguity { overflow: _ } => { - false - } - } - } -} - -#[derive(Debug, Clone)] -pub enum FulfillmentErrorCode<'db> { - /// Inherently impossible to fulfill; this trait is implemented if and only - /// if it is already implemented. - Cycle(PredicateObligations<'db>), - Select(SelectionError<'db>), - Project(MismatchedProjectionTypes<'db>), - Subtype(ExpectedFound>, TypeError<'db>), // always comes from a SubtypePredicate - ConstEquate(ExpectedFound>, TypeError<'db>), - Ambiguity { - /// Overflow is only `Some(suggest_recursion_limit)` when using the next generation - /// trait solver `-Znext-solver`. With the old solver overflow is eagerly handled by - /// emitting a fatal error instead. 
- overflow: Option, - }, -} - -#[derive(Debug, Clone)] -pub struct MismatchedProjectionTypes<'db> { - pub err: TypeError<'db>, -} - -pub(super) fn fulfillment_error_for_no_solution<'db>( - infcx: &InferCtxt<'db>, - root_obligation: PredicateObligation<'db>, -) -> FulfillmentError<'db> { - let obligation = find_best_leaf_obligation(infcx, &root_obligation, false); - - let code = match obligation.predicate.kind().skip_binder() { - PredicateKind::Clause(ClauseKind::Projection(_)) => { - FulfillmentErrorCode::Project( - // FIXME: This could be a `Sorts` if the term is a type - MismatchedProjectionTypes { err: TypeError::Mismatch }, - ) - } - PredicateKind::Clause(ClauseKind::ConstArgHasType(ct, expected_ty)) => { - let ct_ty = match ct.kind() { - ConstKind::Unevaluated(uv) => { - infcx.interner.type_of(uv.def.into()).instantiate(infcx.interner, uv.args) - } - ConstKind::Param(param_ct) => param_ct.find_const_ty_from_env(obligation.param_env), - ConstKind::Value(cv) => cv.ty, - kind => panic!( - "ConstArgHasWrongType failed but we don't know how to compute type for {kind:?}" - ), - }; - FulfillmentErrorCode::Select(SelectionError::ConstArgHasWrongType { - ct, - ct_ty, - expected_ty, - }) - } - PredicateKind::NormalizesTo(..) 
=> { - FulfillmentErrorCode::Project(MismatchedProjectionTypes { err: TypeError::Mismatch }) - } - PredicateKind::AliasRelate(_, _, _) => { - FulfillmentErrorCode::Project(MismatchedProjectionTypes { err: TypeError::Mismatch }) - } - PredicateKind::Subtype(pred) => { - let (a, b) = infcx.enter_forall_and_leak_universe( - obligation.predicate.kind().rebind((pred.a, pred.b)), - ); - let expected_found = ExpectedFound::new(a, b); - FulfillmentErrorCode::Subtype(expected_found, TypeError::Sorts(expected_found)) - } - PredicateKind::Coerce(pred) => { - let (a, b) = infcx.enter_forall_and_leak_universe( - obligation.predicate.kind().rebind((pred.a, pred.b)), - ); - let expected_found = ExpectedFound::new(b, a); - FulfillmentErrorCode::Subtype(expected_found, TypeError::Sorts(expected_found)) - } - PredicateKind::Clause(_) | PredicateKind::DynCompatible(_) | PredicateKind::Ambiguous => { - FulfillmentErrorCode::Select(SelectionError::Unimplemented) - } - PredicateKind::ConstEquate(..) => { - panic!("unexpected goal: {obligation:?}") - } - }; - - FulfillmentError { obligation, code, root_obligation } -} - -pub(super) fn fulfillment_error_for_stalled<'db>( - infcx: &InferCtxt<'db>, - root_obligation: PredicateObligation<'db>, -) -> FulfillmentError<'db> { - let (code, refine_obligation) = infcx.probe(|_| { - match <&SolverContext<'db>>::from(infcx).evaluate_root_goal( - root_obligation.as_goal(), - Span::dummy(), - None, - ) { - Ok(GoalEvaluation { - certainty: Certainty::Maybe { cause: MaybeCause::Ambiguity, .. }, - .. - }) => (FulfillmentErrorCode::Ambiguity { overflow: None }, true), - Ok(GoalEvaluation { - certainty: - Certainty::Maybe { - cause: - MaybeCause::Overflow { suggest_increasing_limit, keep_constraints: _ }, - .. - }, - .. - }) => ( - FulfillmentErrorCode::Ambiguity { overflow: Some(suggest_increasing_limit) }, - // Don't look into overflows because we treat overflows weirdly anyways. 
- // We discard the inference constraints from overflowing goals, so - // recomputing the goal again during `find_best_leaf_obligation` may apply - // inference guidance that makes other goals go from ambig -> pass, for example. - // - // FIXME: We should probably just look into overflows here. - false, - ), - Ok(GoalEvaluation { certainty: Certainty::Yes, .. }) => { - panic!( - "did not expect successful goal when collecting ambiguity errors for `{:?}`", - infcx.resolve_vars_if_possible(root_obligation.predicate), - ) - } - Err(_) => { - panic!( - "did not expect selection error when collecting ambiguity errors for `{:?}`", - infcx.resolve_vars_if_possible(root_obligation.predicate), - ) - } - } - }); - - FulfillmentError { - obligation: if refine_obligation { - find_best_leaf_obligation(infcx, &root_obligation, true) - } else { - root_obligation.clone() - }, - code, - root_obligation, - } -} - -pub(super) fn fulfillment_error_for_overflow<'db>( - infcx: &InferCtxt<'db>, - root_obligation: PredicateObligation<'db>, -) -> FulfillmentError<'db> { - FulfillmentError { - obligation: find_best_leaf_obligation(infcx, &root_obligation, true), - code: FulfillmentErrorCode::Ambiguity { overflow: Some(true) }, - root_obligation, - } -} - -#[instrument(level = "debug", skip(infcx), ret)] -fn find_best_leaf_obligation<'db>( - infcx: &InferCtxt<'db>, - obligation: &PredicateObligation<'db>, - consider_ambiguities: bool, -) -> PredicateObligation<'db> { - let obligation = infcx.resolve_vars_if_possible(obligation.clone()); - // FIXME: we use a probe here as the `BestObligation` visitor does not - // check whether it uses candidates which get shadowed by where-bounds. - // - // We should probably fix the visitor to not do so instead, as this also - // means the leaf obligation may be incorrect. 
- let obligation = infcx - .fudge_inference_if_ok(|| { - infcx - .visit_proof_tree( - obligation.as_goal(), - &mut BestObligation { obligation: obligation.clone(), consider_ambiguities }, - ) - .break_value() - .ok_or(()) - }) - .unwrap_or(obligation); - deeply_normalize_for_diagnostics(infcx, obligation.param_env, obligation) -} - -struct BestObligation<'db> { - obligation: PredicateObligation<'db>, - consider_ambiguities: bool, -} - -impl<'db> BestObligation<'db> { - fn with_derived_obligation( - &mut self, - derived_obligation: PredicateObligation<'db>, - and_then: impl FnOnce(&mut Self) -> >::Result, - ) -> >::Result { - let old_obligation = std::mem::replace(&mut self.obligation, derived_obligation); - let res = and_then(self); - self.obligation = old_obligation; - res - } - - /// Filter out the candidates that aren't interesting to visit for the - /// purposes of reporting errors. For ambiguities, we only consider - /// candidates that may hold. For errors, we only consider candidates that - /// *don't* hold and which have impl-where clauses that also don't hold. - fn non_trivial_candidates<'a>( - &self, - goal: &'a inspect::InspectGoal<'a, 'db>, - ) -> Vec> { - let mut candidates = goal.candidates(); - match self.consider_ambiguities { - true => { - // If we have an ambiguous obligation, we must consider *all* candidates - // that hold, or else we may guide inference causing other goals to go - // from ambig -> pass/fail. - candidates.retain(|candidate| candidate.result().is_ok()); - } - false => { - // We always handle rigid alias candidates separately as we may not add them for - // aliases whose trait bound doesn't hold. - candidates.retain(|c| !matches!(c.kind(), inspect::ProbeKind::RigidAlias { .. })); - // If we have >1 candidate, one may still be due to "boring" reasons, like - // an alias-relate that failed to hold when deeply evaluated. We really - // don't care about reasons like this. 
- if candidates.len() > 1 { - candidates.retain(|candidate| { - goal.infcx().probe(|_| { - candidate.instantiate_nested_goals().iter().any(|nested_goal| { - matches!( - nested_goal.source(), - GoalSource::ImplWhereBound - | GoalSource::AliasBoundConstCondition - | GoalSource::AliasWellFormed - ) && nested_goal.result().is_err() - }) - }) - }); - } - } - } - - candidates - } - - /// HACK: We walk the nested obligations for a well-formed arg manually, - /// since there's nontrivial logic in `wf.rs` to set up an obligation cause. - /// Ideally we'd be able to track this better. - fn visit_well_formed_goal( - &mut self, - candidate: &inspect::InspectCandidate<'_, 'db>, - term: Term<'db>, - ) -> ControlFlow> { - let infcx = candidate.goal().infcx(); - let param_env = candidate.goal().goal().param_env; - - for obligation in wf::unnormalized_obligations(infcx, param_env, term).into_iter().flatten() - { - let nested_goal = candidate - .instantiate_proof_tree_for_nested_goal(GoalSource::Misc, obligation.as_goal()); - // Skip nested goals that aren't the *reason* for our goal's failure. - match (self.consider_ambiguities, nested_goal.result()) { - (true, Ok(Certainty::Maybe { cause: MaybeCause::Ambiguity, .. })) - | (false, Err(_)) => {} - _ => continue, - } - - self.with_derived_obligation(obligation, |this| nested_goal.visit_with(this))?; - } - - ControlFlow::Break(self.obligation.clone()) - } - - /// If a normalization of an associated item or a trait goal fails without trying any - /// candidates it's likely that normalizing its self type failed. We manually detect - /// such cases here. - fn detect_error_in_self_ty_normalization( - &mut self, - goal: &inspect::InspectGoal<'_, 'db>, - self_ty: Ty<'db>, - ) -> ControlFlow> { - assert!(!self.consider_ambiguities); - let interner = goal.infcx().interner; - if let TyKind::Alias(..) 
= self_ty.kind() { - let infer_term = goal.infcx().next_ty_var(); - let pred = PredicateKind::AliasRelate( - self_ty.into(), - infer_term.into(), - AliasRelationDirection::Equate, - ); - let obligation = Obligation::new( - interner, - self.obligation.cause.clone(), - goal.goal().param_env, - pred, - ); - self.with_derived_obligation(obligation, |this| { - goal.infcx().visit_proof_tree_at_depth( - goal.goal().with(interner, pred), - goal.depth() + 1, - this, - ) - }) - } else { - ControlFlow::Continue(()) - } - } - - /// When a higher-ranked projection goal fails, check that the corresponding - /// higher-ranked trait goal holds or not. This is because the process of - /// instantiating and then re-canonicalizing the binder of the projection goal - /// forces us to be unable to see that the leak check failed in the nested - /// `NormalizesTo` goal, so we don't fall back to the rigid projection check - /// that should catch when a projection goal fails due to an unsatisfied trait - /// goal. - fn detect_trait_error_in_higher_ranked_projection( - &mut self, - goal: &inspect::InspectGoal<'_, 'db>, - ) -> ControlFlow> { - let interner = goal.infcx().interner; - if let Some(projection_clause) = goal.goal().predicate.as_projection_clause() - && !projection_clause.bound_vars().is_empty() - { - let pred = projection_clause.map_bound(|proj| proj.projection_term.trait_ref(interner)); - let obligation = Obligation::new( - interner, - self.obligation.cause.clone(), - goal.goal().param_env, - deeply_normalize_for_diagnostics(goal.infcx(), goal.goal().param_env, pred), - ); - self.with_derived_obligation(obligation, |this| { - goal.infcx().visit_proof_tree_at_depth( - goal.goal().with(interner, pred), - goal.depth() + 1, - this, - ) - }) - } else { - ControlFlow::Continue(()) - } - } - - /// It is likely that `NormalizesTo` failed without any applicable candidates - /// because the alias is not well-formed. 
- /// - /// As we only enter `RigidAlias` candidates if the trait bound of the associated type - /// holds, we discard these candidates in `non_trivial_candidates` and always manually - /// check this here. - fn detect_non_well_formed_assoc_item( - &mut self, - goal: &inspect::InspectGoal<'_, 'db>, - alias: AliasTerm<'db>, - ) -> ControlFlow> { - let interner = goal.infcx().interner; - let obligation = Obligation::new( - interner, - self.obligation.cause.clone(), - goal.goal().param_env, - alias.trait_ref(interner), - ); - self.with_derived_obligation(obligation, |this| { - goal.infcx().visit_proof_tree_at_depth( - goal.goal().with(interner, alias.trait_ref(interner)), - goal.depth() + 1, - this, - ) - }) - } - - /// If we have no candidates, then it's likely that there is a - /// non-well-formed alias in the goal. - fn detect_error_from_empty_candidates( - &mut self, - goal: &inspect::InspectGoal<'_, 'db>, - ) -> ControlFlow> { - let interner = goal.infcx().interner; - let pred_kind = goal.goal().predicate.kind(); - - match pred_kind.no_bound_vars() { - Some(PredicateKind::Clause(ClauseKind::Trait(pred))) => { - self.detect_error_in_self_ty_normalization(goal, pred.self_ty())?; - } - Some(PredicateKind::NormalizesTo(pred)) => { - if let AliasTermKind::ProjectionTy | AliasTermKind::ProjectionConst = - pred.alias.kind(interner) - { - self.detect_error_in_self_ty_normalization(goal, pred.alias.self_ty())?; - self.detect_non_well_formed_assoc_item(goal, pred.alias)?; - } - } - Some(_) | None => {} - } - - ControlFlow::Break(self.obligation.clone()) - } -} - -impl<'db> ProofTreeVisitor<'db> for BestObligation<'db> { - type Result = ControlFlow>; - - #[instrument(level = "trace", skip(self, goal), fields(goal = ?goal.goal()))] - fn visit_goal(&mut self, goal: &inspect::InspectGoal<'_, 'db>) -> Self::Result { - let interner = goal.infcx().interner; - // Skip goals that aren't the *reason* for our goal's failure. 
- match (self.consider_ambiguities, goal.result()) { - (true, Ok(Certainty::Maybe { cause: MaybeCause::Ambiguity, .. })) | (false, Err(_)) => { - } - _ => return ControlFlow::Continue(()), - } - - let pred = goal.goal().predicate; - - let candidates = self.non_trivial_candidates(goal); - let candidate = match candidates.as_slice() { - [candidate] => candidate, - [] => return self.detect_error_from_empty_candidates(goal), - _ => return ControlFlow::Break(self.obligation.clone()), - }; - - // // Don't walk into impls that have `do_not_recommend`. - // if let inspect::ProbeKind::TraitCandidate { - // source: CandidateSource::Impl(impl_def_id), - // result: _, - // } = candidate.kind() - // && interner.do_not_recommend_impl(impl_def_id) - // { - // trace!("#[do_not_recommend] -> exit"); - // return ControlFlow::Break(self.obligation.clone()); - // } - - // FIXME: Also, what about considering >1 layer up the stack? May be necessary - // for normalizes-to. - let child_mode = match pred.kind().skip_binder() { - PredicateKind::Clause(ClauseKind::Trait(trait_pred)) => { - ChildMode::Trait(pred.kind().rebind(trait_pred)) - } - PredicateKind::Clause(ClauseKind::HostEffect(host_pred)) => { - ChildMode::Host(pred.kind().rebind(host_pred)) - } - PredicateKind::NormalizesTo(normalizes_to) - if matches!( - normalizes_to.alias.kind(interner), - AliasTermKind::ProjectionTy | AliasTermKind::ProjectionConst - ) => - { - ChildMode::Trait(pred.kind().rebind(TraitPredicate { - trait_ref: normalizes_to.alias.trait_ref(interner), - polarity: PredicatePolarity::Positive, - })) - } - PredicateKind::Clause(ClauseKind::WellFormed(term)) => { - return self.visit_well_formed_goal(candidate, term); - } - _ => ChildMode::PassThrough, - }; - - let nested_goals = candidate.instantiate_nested_goals(); - - // If the candidate requires some `T: FnPtr` bound which does not hold should not be treated as - // an actual candidate, instead we should treat them as if the impl was never considered to - // 
have potentially applied. As if `impl Trait for for<..> fn(..A) -> R` was written - // instead of `impl Trait for T`. - // - // We do this as a separate loop so that we do not choose to tell the user about some nested - // goal before we encounter a `T: FnPtr` nested goal. - for nested_goal in &nested_goals { - if let Some(poly_trait_pred) = nested_goal.goal().predicate.as_trait_clause() - && interner - .is_trait_lang_item(poly_trait_pred.def_id(), SolverTraitLangItem::FnPtrTrait) - && let Err(NoSolution) = nested_goal.result() - { - return ControlFlow::Break(self.obligation.clone()); - } - } - - for nested_goal in nested_goals { - trace!(nested_goal = ?(nested_goal.goal(), nested_goal.source(), nested_goal.result())); - - let nested_pred = nested_goal.goal().predicate; - - let make_obligation = || Obligation { - cause: ObligationCause::dummy(), - param_env: nested_goal.goal().param_env, - predicate: nested_pred, - recursion_depth: self.obligation.recursion_depth + 1, - }; - - let obligation = match (child_mode, nested_goal.source()) { - ( - ChildMode::Trait(_) | ChildMode::Host(_), - GoalSource::Misc | GoalSource::TypeRelating | GoalSource::NormalizeGoal(_), - ) => { - continue; - } - (ChildMode::Trait(_parent_trait_pred), GoalSource::ImplWhereBound) => { - make_obligation() - } - ( - ChildMode::Host(_parent_host_pred), - GoalSource::ImplWhereBound | GoalSource::AliasBoundConstCondition, - ) => make_obligation(), - (ChildMode::PassThrough, _) - | (_, GoalSource::AliasWellFormed | GoalSource::AliasBoundConstCondition) => { - make_obligation() - } - }; - - self.with_derived_obligation(obligation, |this| nested_goal.visit_with(this))?; - } - - // alias-relate may fail because the lhs or rhs can't be normalized, - // and therefore is treated as rigid. 
- if let Some(PredicateKind::AliasRelate(lhs, rhs, _)) = pred.kind().no_bound_vars() { - goal.infcx().visit_proof_tree_at_depth( - goal.goal().with(interner, ClauseKind::WellFormed(lhs)), - goal.depth() + 1, - self, - )?; - goal.infcx().visit_proof_tree_at_depth( - goal.goal().with(interner, ClauseKind::WellFormed(rhs)), - goal.depth() + 1, - self, - )?; - } - - self.detect_trait_error_in_higher_ranked_projection(goal)?; - - ControlFlow::Break(self.obligation.clone()) - } -} - -#[derive(Debug, Copy, Clone)] -enum ChildMode<'db> { - // Try to derive an `ObligationCause::{ImplDerived,BuiltinDerived}`, - // and skip all `GoalSource::Misc`, which represent useless obligations - // such as alias-eq which may not hold. - Trait(PolyTraitPredicate<'db>), - // Try to derive an `ObligationCause::{ImplDerived,BuiltinDerived}`, - // and skip all `GoalSource::Misc`, which represent useless obligations - // such as alias-eq which may not hold. - Host(Binder<'db, HostEffectPredicate>>), - // Skip trying to derive an `ObligationCause` from this obligation, and - // report *all* sub-obligations as if they came directly from the parent - // obligation. 
- PassThrough, -} - -impl<'db> NextSolverError<'db> { - pub fn to_debuggable_error(&self, infcx: &InferCtxt<'db>) -> FulfillmentError<'db> { - match self { - NextSolverError::TrueError(obligation) => { - fulfillment_error_for_no_solution(infcx, obligation.clone()) - } - NextSolverError::Ambiguity(obligation) => { - fulfillment_error_for_stalled(infcx, obligation.clone()) - } - NextSolverError::Overflow(obligation) => { - fulfillment_error_for_overflow(infcx, obligation.clone()) - } - } - } -} - -mod wf { - use hir_def::signatures::ImplSignature; - use hir_def::{GeneralConstId, ItemContainerId}; - use rustc_type_ir::inherent::{ - AdtDef, BoundExistentialPredicates, GenericArgs as _, IntoKind, SliceLike, Term as _, - Ty as _, - }; - use rustc_type_ir::lang_items::SolverTraitLangItem; - use rustc_type_ir::{ - Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, - }; - use tracing::{debug, instrument}; - - use crate::next_solver::infer::InferCtxt; - use crate::next_solver::infer::traits::{Obligation, ObligationCause, PredicateObligations}; - use crate::next_solver::{ - Binder, ClauseKind, Const, ConstKind, Ctor, DbInterner, ExistentialPredicate, GenericArgs, - ParamEnv, Predicate, PredicateKind, Region, SolverDefId, Term, TraitRef, Ty, TyKind, - }; - - /// Compute the predicates that are required for a type to be well-formed. - /// - /// This is only intended to be used in the new solver, since it does not - /// take into account recursion depth or proper error-reporting spans. - pub(crate) fn unnormalized_obligations<'db>( - infcx: &InferCtxt<'db>, - param_env: ParamEnv<'db>, - term: Term<'db>, - ) -> Option> { - debug_assert_eq!(term, infcx.resolve_vars_if_possible(term)); - - // However, if `arg` IS an unresolved inference variable, returns `None`, - // because we are not able to make any progress at all. This is to prevent - // cycles where we say "?0 is WF if ?0 is WF". 
- if term.is_infer() { - return None; - } - - let mut wf = - WfPredicates { infcx, param_env, out: PredicateObligations::new(), recursion_depth: 0 }; - wf.add_wf_preds_for_term(term); - Some(wf.out) - } - - struct WfPredicates<'a, 'db> { - infcx: &'a InferCtxt<'db>, - param_env: ParamEnv<'db>, - out: PredicateObligations<'db>, - recursion_depth: usize, - } - - impl<'a, 'db> WfPredicates<'a, 'db> { - fn interner(&self) -> DbInterner<'db> { - self.infcx.interner - } - - fn require_sized(&mut self, subty: Ty<'db>) { - if !subty.has_escaping_bound_vars() { - let cause = ObligationCause::new(); - let trait_ref = TraitRef::new( - self.interner(), - self.interner().require_trait_lang_item(SolverTraitLangItem::Sized), - [subty], - ); - self.out.push(Obligation::with_depth( - self.interner(), - cause, - self.recursion_depth, - self.param_env, - Binder::dummy(trait_ref), - )); - } - } - - /// Pushes all the predicates needed to validate that `term` is WF into `out`. - #[instrument(level = "debug", skip(self))] - fn add_wf_preds_for_term(&mut self, term: Term<'db>) { - term.visit_with(self); - debug!(?self.out); - } - - #[instrument(level = "debug", skip(self))] - fn nominal_obligations( - &mut self, - def_id: SolverDefId, - args: GenericArgs<'db>, - ) -> PredicateObligations<'db> { - // PERF: `Sized`'s predicates include `MetaSized`, but both are compiler implemented marker - // traits, so `MetaSized` will always be WF if `Sized` is WF and vice-versa. Determining - // the nominal obligations of `Sized` would in-effect just elaborate `MetaSized` and make - // the compiler do a bunch of work needlessly. 
- if let SolverDefId::TraitId(def_id) = def_id - && self.interner().is_trait_lang_item(def_id.into(), SolverTraitLangItem::Sized) - { - return Default::default(); - } - - self.interner() - .predicates_of(def_id) - .iter_instantiated(self.interner(), args) - .map(|pred| { - let cause = ObligationCause::new(); - Obligation::with_depth( - self.interner(), - cause, - self.recursion_depth, - self.param_env, - pred, - ) - }) - .filter(|pred| !pred.has_escaping_bound_vars()) - .collect() - } - - fn add_wf_preds_for_dyn_ty( - &mut self, - _ty: Ty<'db>, - data: &[Binder<'db, ExistentialPredicate<'db>>], - region: Region<'db>, - ) { - // Imagine a type like this: - // - // trait Foo { } - // trait Bar<'c> : 'c { } - // - // &'b (Foo+'c+Bar<'d>) - // ^ - // - // In this case, the following relationships must hold: - // - // 'b <= 'c - // 'd <= 'c - // - // The first conditions is due to the normal region pointer - // rules, which say that a reference cannot outlive its - // referent. - // - // The final condition may be a bit surprising. In particular, - // you may expect that it would have been `'c <= 'd`, since - // usually lifetimes of outer things are conservative - // approximations for inner things. However, it works somewhat - // differently with trait objects: here the idea is that if the - // user specifies a region bound (`'c`, in this case) it is the - // "master bound" that *implies* that bounds from other traits are - // all met. (Remember that *all bounds* in a type like - // `Foo+Bar+Zed` must be met, not just one, hence if we write - // `Foo<'x>+Bar<'y>`, we know that the type outlives *both* 'x and - // 'y.) - // - // Note: in fact we only permit builtin traits, not `Bar<'d>`, I - // am looking forward to the future here. 
- if !data.has_escaping_bound_vars() && !region.has_escaping_bound_vars() { - let implicit_bounds = object_region_bounds(self.interner(), data); - - let explicit_bound = region; - - self.out.reserve(implicit_bounds.len()); - for implicit_bound in implicit_bounds { - let cause = ObligationCause::new(); - let outlives = Binder::dummy(rustc_type_ir::OutlivesPredicate( - explicit_bound, - implicit_bound, - )); - self.out.push(Obligation::with_depth( - self.interner(), - cause, - self.recursion_depth, - self.param_env, - outlives, - )); - } - - // We don't add any wf predicates corresponding to the trait ref's generic arguments - // which allows code like this to compile: - // ```rust - // trait Trait {} - // fn foo(_: &dyn Trait<[u32]>) {} - // ``` - } - } - } - - impl<'a, 'db> TypeVisitor> for WfPredicates<'a, 'db> { - type Result = (); - - fn visit_ty(&mut self, t: Ty<'db>) -> Self::Result { - debug!("wf bounds for t={:?} t.kind={:#?}", t, t.kind()); - - let tcx = self.interner(); - - match t.kind() { - TyKind::Bool - | TyKind::Char - | TyKind::Int(..) - | TyKind::Uint(..) - | TyKind::Float(..) - | TyKind::Error(_) - | TyKind::Str - | TyKind::CoroutineWitness(..) - | TyKind::Never - | TyKind::Param(_) - | TyKind::Bound(..) - | TyKind::Placeholder(..) - | TyKind::Foreign(..) => { - // WfScalar, WfParameter, etc - } - - // Can only infer to `TyKind::Int(_) | TyKind::Uint(_)`. - TyKind::Infer(rustc_type_ir::IntVar(_)) => {} - - // Can only infer to `TyKind::Float(_)`. - TyKind::Infer(rustc_type_ir::FloatVar(_)) => {} - - TyKind::Slice(subty) => { - self.require_sized(subty); - } - - TyKind::Array(subty, len) => { - self.require_sized(subty); - // Note that the len being WF is implicitly checked while visiting. - // Here we just check that it's of type usize. 
- let cause = ObligationCause::new(); - self.out.push(Obligation::with_depth( - tcx, - cause, - self.recursion_depth, - self.param_env, - Binder::dummy(PredicateKind::Clause(ClauseKind::ConstArgHasType( - len, - Ty::new_unit(self.interner()), - ))), - )); - } - - TyKind::Pat(base_ty, _pat) => { - self.require_sized(base_ty); - } - - TyKind::Tuple(tys) => { - if let Some((_last, rest)) = tys.split_last() { - for &elem in rest { - self.require_sized(elem); - } - } - } - - TyKind::RawPtr(_, _) => { - // Simple cases that are WF if their type args are WF. - } - - TyKind::Alias( - rustc_type_ir::Projection | rustc_type_ir::Opaque | rustc_type_ir::Free, - data, - ) => { - let obligations = self.nominal_obligations(data.def_id, data.args); - self.out.extend(obligations); - } - TyKind::Alias(rustc_type_ir::Inherent, _data) => { - return; - } - - TyKind::Adt(def, args) => { - // WfNominalType - let obligations = self.nominal_obligations(def.def_id().0.into(), args); - self.out.extend(obligations); - } - - TyKind::FnDef(did, args) => { - // HACK: Check the return type of function definitions for - // well-formedness to mostly fix #84533. This is still not - // perfect and there may be ways to abuse the fact that we - // ignore requirements with escaping bound vars. That's a - // more general issue however. 
- let fn_sig = tcx.fn_sig(did).instantiate(tcx, args); - fn_sig.output().skip_binder().visit_with(self); - - let did = match did.0 { - hir_def::CallableDefId::FunctionId(id) => id.into(), - hir_def::CallableDefId::StructId(id) => SolverDefId::Ctor(Ctor::Struct(id)), - hir_def::CallableDefId::EnumVariantId(id) => { - SolverDefId::Ctor(Ctor::Enum(id)) - } - }; - let obligations = self.nominal_obligations(did, args); - self.out.extend(obligations); - } - - TyKind::Ref(r, rty, _) => { - // WfReference - if !r.has_escaping_bound_vars() && !rty.has_escaping_bound_vars() { - let cause = ObligationCause::new(); - self.out.push(Obligation::with_depth( - tcx, - cause, - self.recursion_depth, - self.param_env, - Binder::dummy(PredicateKind::Clause(ClauseKind::TypeOutlives( - rustc_type_ir::OutlivesPredicate(rty, r), - ))), - )); - } - } - - TyKind::Coroutine(did, args, ..) => { - // Walk ALL the types in the coroutine: this will - // include the upvar types as well as the yield - // type. Note that this is mildly distinct from - // the closure case, where we have to be careful - // about the signature of the closure. We don't - // have the problem of implied bounds here since - // coroutines don't take arguments. - let obligations = self.nominal_obligations(did.0.into(), args); - self.out.extend(obligations); - } - - TyKind::Closure(did, args) => { - // Note that we cannot skip the generic types - // types. Normally, within the fn - // body where they are created, the generics will - // always be WF, and outside of that fn body we - // are not directly inspecting closure types - // anyway, except via auto trait matching (which - // only inspects the upvar types). - // But when a closure is part of a type-alias-impl-trait - // then the function that created the defining site may - // have had more bounds available than the type alias - // specifies. 
This may cause us to have a closure in the - // hidden type that is not actually well formed and - // can cause compiler crashes when the user abuses unsafe - // code to procure such a closure. - // See tests/ui/type-alias-impl-trait/wf_check_closures.rs - let obligations = self.nominal_obligations(did.0.into(), args); - self.out.extend(obligations); - // Only check the upvar types for WF, not the rest - // of the types within. This is needed because we - // capture the signature and it may not be WF - // without the implied bounds. Consider a closure - // like `|x: &'a T|` -- it may be that `T: 'a` is - // not known to hold in the creator's context (and - // indeed the closure may not be invoked by its - // creator, but rather turned to someone who *can* - // verify that). - // - // The special treatment of closures here really - // ought not to be necessary either; the problem - // is related to #25860 -- there is no way for us - // to express a fn type complete with the implied - // bounds that it is assuming. I think in reality - // the WF rules around fn are a bit messed up, and - // that is the rot problem: `fn(&'a T)` should - // probably always be WF, because it should be - // shorthand for something like `where(T: 'a) { - // fn(&'a T) }`, as discussed in #25860. - let upvars = args.as_closure().tupled_upvars_ty(); - return upvars.visit_with(self); - } - - TyKind::CoroutineClosure(did, args) => { - // See the above comments. The same apply to coroutine-closures. - let obligations = self.nominal_obligations(did.0.into(), args); - self.out.extend(obligations); - let upvars = args.as_coroutine_closure().tupled_upvars_ty(); - return upvars.visit_with(self); - } - - TyKind::FnPtr(..) => { - // Let the visitor iterate into the argument/return - // types appearing in the fn signature. - } - TyKind::UnsafeBinder(_ty) => {} - - TyKind::Dynamic(data, r) => { - // WfObject - // - // Here, we defer WF checking due to higher-ranked - // regions. 
This is perhaps not ideal. - self.add_wf_preds_for_dyn_ty(t, data.as_slice(), r); - - // FIXME(#27579) RFC also considers adding trait - // obligations that don't refer to Self and - // checking those - if let Some(principal) = data.principal_def_id() { - self.out.push(Obligation::with_depth( - tcx, - ObligationCause::new(), - self.recursion_depth, - self.param_env, - Binder::dummy(PredicateKind::DynCompatible(principal)), - )); - } - } - - // Inference variables are the complicated case, since we don't - // know what type they are. We do two things: - // - // 1. Check if they have been resolved, and if so proceed with - // THAT type. - // 2. If not, we've at least simplified things (e.g., we went - // from `Vec?0>: WF` to `?0: WF`), so we can - // register a pending obligation and keep - // moving. (Goal is that an "inductive hypothesis" - // is satisfied to ensure termination.) - // See also the comment on `fn obligations`, describing cycle - // prevention, which happens before this can be reached. 
- TyKind::Infer(_) => { - let cause = ObligationCause::new(); - self.out.push(Obligation::with_depth( - tcx, - cause, - self.recursion_depth, - self.param_env, - Binder::dummy(PredicateKind::Clause(ClauseKind::WellFormed(t.into()))), - )); - } - } - - t.super_visit_with(self) - } - - fn visit_const(&mut self, c: Const<'db>) -> Self::Result { - let tcx = self.interner(); - - match c.kind() { - ConstKind::Unevaluated(uv) => { - if !c.has_escaping_bound_vars() { - let predicate = - Binder::dummy(PredicateKind::Clause(ClauseKind::ConstEvaluatable(c))); - let cause = ObligationCause::new(); - self.out.push(Obligation::with_depth( - tcx, - cause, - self.recursion_depth, - self.param_env, - predicate, - )); - - if let GeneralConstId::ConstId(uv_def) = uv.def.0 - && let ItemContainerId::ImplId(impl_) = - uv_def.loc(self.interner().db).container - && ImplSignature::of(self.interner().db, impl_).target_trait.is_none() - { - return; // Subtree is handled by above function - } else { - let obligations = self.nominal_obligations(uv.def.into(), uv.args); - self.out.extend(obligations); - } - } - } - ConstKind::Infer(_) => { - let cause = ObligationCause::new(); - - self.out.push(Obligation::with_depth( - tcx, - cause, - self.recursion_depth, - self.param_env, - Binder::dummy(PredicateKind::Clause(ClauseKind::WellFormed(c.into()))), - )); - } - ConstKind::Expr(_) => { - // FIXME(generic_const_exprs): this doesn't verify that given `Expr(N + 1)` the - // trait bound `typeof(N): Add` holds. This is currently unnecessary - // as `ConstKind::Expr` is only produced via normalization of `ConstKind::Unevaluated` - // which means that the `DefId` would have been typeck'd elsewhere. However in - // the future we may allow directly lowering to `ConstKind::Expr` in which case - // we would not be proving bounds we should. 
- - let predicate = - Binder::dummy(PredicateKind::Clause(ClauseKind::ConstEvaluatable(c))); - let cause = ObligationCause::new(); - self.out.push(Obligation::with_depth( - tcx, - cause, - self.recursion_depth, - self.param_env, - predicate, - )); - } - - ConstKind::Error(_) - | ConstKind::Param(_) - | ConstKind::Bound(..) - | ConstKind::Placeholder(..) => { - // These variants are trivially WF, so nothing to do here. - } - ConstKind::Value(..) => { - // FIXME: Enforce that values are structurally-matchable. - } - } - - c.super_visit_with(self) - } - - fn visit_predicate(&mut self, _p: Predicate<'db>) -> Self::Result { - panic!("predicate should not be checked for well-formedness"); - } - } - - /// Given an object type like `SomeTrait + Send`, computes the lifetime - /// bounds that must hold on the elided self type. These are derived - /// from the declarations of `SomeTrait`, `Send`, and friends -- if - /// they declare `trait SomeTrait : 'static`, for example, then - /// `'static` would appear in the list. - /// - /// N.B., in some cases, particularly around higher-ranked bounds, - /// this function returns a kind of conservative approximation. - /// That is, all regions returned by this function are definitely - /// required, but there may be other region bounds that are not - /// returned, as well as requirements like `for<'a> T: 'a`. - /// - /// Requires that trait definitions have been processed so that we can - /// elaborate predicates and walk supertraits. 
- pub(crate) fn object_region_bounds<'db>( - interner: DbInterner<'db>, - existential_predicates: &[Binder<'db, ExistentialPredicate<'db>>], - ) -> Vec> { - let erased_self_ty = Ty::new_unit(interner); - - let predicates = existential_predicates - .iter() - .map(|predicate| predicate.with_self_ty(interner, erased_self_ty)); - - rustc_type_ir::elaborate::elaborate(interner, predicates) - .filter_map(|pred| { - debug!(?pred); - match pred.kind().skip_binder() { - ClauseKind::TypeOutlives(rustc_type_ir::OutlivesPredicate(ref t, ref r)) => { - // Search for a bound of the form `erased_self_ty - // : 'a`, but be wary of something like `for<'a> - // erased_self_ty : 'a` (we interpret a - // higher-ranked bound like that as 'static, - // though at present the code in `fulfill.rs` - // considers such bounds to be unsatisfiable, so - // it's kind of a moot point since you could never - // construct such an object, but this seems - // correct even if that code changes). - if t == &erased_self_ty && !r.has_escaping_bound_vars() { - Some(*r) - } else { - None - } - } - ClauseKind::Trait(_) - | ClauseKind::HostEffect(..) 
- | ClauseKind::RegionOutlives(_) - | ClauseKind::Projection(_) - | ClauseKind::ConstArgHasType(_, _) - | ClauseKind::WellFormed(_) - | ClauseKind::UnstableFeature(_) - | ClauseKind::ConstEvaluatable(_) => None, - } - }) - .collect() - } -} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/canonicalizer.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/canonicalizer.rs index ccd93590107fc..33e4c175d0635 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/canonicalizer.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/canonicalizer.rs @@ -16,11 +16,11 @@ use rustc_type_ir::{ use smallvec::SmallVec; use tracing::debug; -use crate::next_solver::infer::InferCtxt; use crate::next_solver::{ - Binder, Canonical, CanonicalVarKind, CanonicalVars, Const, ConstKind, DbInterner, GenericArg, - Placeholder, Region, Ty, TyKind, + Binder, Canonical, CanonicalVarKind, CanonicalVarKinds, Const, ConstKind, DbInterner, + GenericArg, PlaceholderConst, PlaceholderRegion, Region, Ty, TyKind, }; +use crate::next_solver::{PlaceholderType, infer::InferCtxt}; /// When we canonicalize a value to form a query, we wind up replacing /// various parts of it with canonical variables. 
This struct stores @@ -498,7 +498,7 @@ impl<'cx, 'db> Canonicalizer<'cx, 'db> { { let base = Canonical { max_universe: UniverseIndex::ROOT, - variables: CanonicalVars::empty(tcx), + var_kinds: CanonicalVarKinds::empty(tcx), value: (), }; Canonicalizer::canonicalize_with_base( @@ -539,7 +539,7 @@ impl<'cx, 'db> Canonicalizer<'cx, 'db> { tcx, canonicalize_mode: canonicalize_region_mode, needs_canonical_flags, - variables: SmallVec::from_slice(base.variables.as_slice()), + variables: SmallVec::from_slice(base.var_kinds.as_slice()), query_state, indices: FxHashMap::default(), sub_root_lookup_table: Default::default(), @@ -562,7 +562,7 @@ impl<'cx, 'db> Canonicalizer<'cx, 'db> { debug_assert!(!out_value.has_infer() && !out_value.has_placeholders()); let canonical_variables = - CanonicalVars::new_from_slice(&canonicalizer.universe_canonicalized_variables()); + CanonicalVarKinds::new_from_slice(&canonicalizer.universe_canonicalized_variables()); let max_universe = canonical_variables .iter() @@ -570,7 +570,7 @@ impl<'cx, 'db> Canonicalizer<'cx, 'db> { .max() .unwrap_or(UniverseIndex::ROOT); - Canonical { max_universe, variables: canonical_variables, value: (base.value, out_value) } + Canonical { max_universe, var_kinds: canonical_variables, value: (base.value, out_value) } } /// Creates a canonical variable replacing `kind` from the input, @@ -670,22 +670,22 @@ impl<'cx, 'db> Canonicalizer<'cx, 'db> { CanonicalVarKind::Region(u) => CanonicalVarKind::Region(reverse_universe_map[&u]), CanonicalVarKind::Const(u) => CanonicalVarKind::Const(reverse_universe_map[&u]), CanonicalVarKind::PlaceholderTy(placeholder) => { - CanonicalVarKind::PlaceholderTy(Placeholder { - universe: reverse_universe_map[&placeholder.universe], - ..placeholder - }) + CanonicalVarKind::PlaceholderTy(PlaceholderType::new( + reverse_universe_map[&placeholder.universe], + placeholder.bound, + )) } CanonicalVarKind::PlaceholderRegion(placeholder) => { - CanonicalVarKind::PlaceholderRegion(Placeholder { - 
universe: reverse_universe_map[&placeholder.universe], - ..placeholder - }) + CanonicalVarKind::PlaceholderRegion(PlaceholderRegion::new( + reverse_universe_map[&placeholder.universe], + placeholder.bound, + )) } CanonicalVarKind::PlaceholderConst(placeholder) => { - CanonicalVarKind::PlaceholderConst(Placeholder { - universe: reverse_universe_map[&placeholder.universe], - ..placeholder - }) + CanonicalVarKind::PlaceholderConst(PlaceholderConst::new( + reverse_universe_map[&placeholder.universe], + placeholder.bound, + )) } }) .collect() diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/instantiate.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/instantiate.rs index 61d1e97746224..1738552a8e015 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/instantiate.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/instantiate.rs @@ -69,7 +69,7 @@ impl<'db, V> CanonicalExt<'db, V> for Canonical<'db, V> { where T: TypeFoldable>, { - assert_eq!(self.variables.len(), var_values.len()); + assert_eq!(self.var_kinds.len(), var_values.len()); let value = projection_fn(&self.value); instantiate_value(tcx, var_values, value) } @@ -90,15 +90,15 @@ where value } else { let delegate = FnMutDelegate { - regions: &mut |br: BoundRegion| match var_values[br.var].kind() { + regions: &mut |br: BoundRegion<'db>| match var_values[br.var].kind() { GenericArgKind::Lifetime(l) => l, r => panic!("{br:?} is a region but value is {r:?}"), }, - types: &mut |bound_ty: BoundTy| match var_values[bound_ty.var].kind() { + types: &mut |bound_ty: BoundTy<'db>| match var_values[bound_ty.var].kind() { GenericArgKind::Type(ty) => ty, r => panic!("{bound_ty:?} is a type but value is {r:?}"), }, - consts: &mut |bound_ct: BoundConst| match var_values[bound_ct.var].kind() { + consts: &mut |bound_ct: BoundConst<'db>| match var_values[bound_ct.var].kind() { GenericArgKind::Const(ct) => ct, c => 
panic!("{bound_ct:?} is a const but value is {c:?}"), }, @@ -356,7 +356,7 @@ impl<'db> InferCtxt<'db> { // result, then we can type the corresponding value from the // input. See the example above. let mut opt_values: IndexVec>> = - IndexVec::from_elem_n(None, query_response.variables.len()); + IndexVec::from_elem_n(None, query_response.var_kinds.len()); for (original_value, result_value) in iter::zip(&original_values.var_values, result_values) { @@ -368,7 +368,7 @@ impl<'db> InferCtxt<'db> { // more involved. They are also a lot rarer than region variables. if let TyKind::Bound(index_kind, b) = result_value.kind() && !matches!( - query_response.variables.as_slice()[b.var.as_usize()], + query_response.var_kinds.as_slice()[b.var.as_usize()], CanonicalVarKind::Ty { .. } ) { @@ -398,7 +398,7 @@ impl<'db> InferCtxt<'db> { // given variable in the loop above, use that. Otherwise, use // a fresh inference variable. let interner = self.interner; - let variables = query_response.variables; + let variables = query_response.var_kinds; let var_values = CanonicalVarValues::instantiate(interner, variables, |var_values, kind| { if kind.universe() != UniverseIndex::ROOT { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/mod.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/mod.rs index a0420a5a00b9d..1fefc0f265c55 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/mod.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/canonical/mod.rs @@ -23,7 +23,7 @@ use crate::next_solver::{ ArgOutlivesPredicate, Canonical, CanonicalVarValues, Const, DbInterner, GenericArg, - OpaqueTypeKey, PlaceholderConst, PlaceholderRegion, PlaceholderTy, Region, Ty, TyKind, + OpaqueTypeKey, PlaceholderConst, PlaceholderRegion, PlaceholderType, Region, Ty, TyKind, infer::InferCtxt, }; use instantiate::CanonicalExt; @@ -70,7 +70,7 @@ impl<'db> InferCtxt<'db> { let var_values = 
CanonicalVarValues::instantiate( self.interner, - canonical.variables, + canonical.var_kinds, |var_values, info| self.instantiate_canonical_var(info, var_values, |ui| universes[ui]), ); let result = canonical.instantiate(self.interner, &var_values); @@ -110,9 +110,9 @@ impl<'db> InferCtxt<'db> { CanonicalVarKind::Float => self.next_float_var().into(), - CanonicalVarKind::PlaceholderTy(PlaceholderTy { universe, bound }) => { + CanonicalVarKind::PlaceholderTy(PlaceholderType { universe, bound, .. }) => { let universe_mapped = universe_map(universe); - let placeholder_mapped = PlaceholderTy { universe: universe_mapped, bound }; + let placeholder_mapped = PlaceholderType::new(universe_mapped, bound); Ty::new_placeholder(self.interner, placeholder_mapped).into() } @@ -120,18 +120,16 @@ impl<'db> InferCtxt<'db> { self.next_region_var_in_universe(universe_map(ui)).into() } - CanonicalVarKind::PlaceholderRegion(PlaceholderRegion { universe, bound }) => { + CanonicalVarKind::PlaceholderRegion(PlaceholderRegion { universe, bound, .. }) => { let universe_mapped = universe_map(universe); - let placeholder_mapped: crate::next_solver::Placeholder< - crate::next_solver::BoundRegion, - > = PlaceholderRegion { universe: universe_mapped, bound }; + let placeholder_mapped = PlaceholderRegion::new(universe_mapped, bound); Region::new_placeholder(self.interner, placeholder_mapped).into() } CanonicalVarKind::Const(ui) => self.next_const_var_in_universe(universe_map(ui)).into(), - CanonicalVarKind::PlaceholderConst(PlaceholderConst { universe, bound }) => { + CanonicalVarKind::PlaceholderConst(PlaceholderConst { universe, bound, .. 
}) => { let universe_mapped = universe_map(universe); - let placeholder_mapped = PlaceholderConst { universe: universe_mapped, bound }; + let placeholder_mapped = PlaceholderConst::new(universe_mapped, bound); Const::new_placeholder(self.interner, placeholder_mapped).into() } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/mod.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/mod.rs index 21baacb116938..a6352c7899fff 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/mod.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/mod.rs @@ -30,7 +30,7 @@ use unify_key::{ConstVariableOrigin, ConstVariableValue, ConstVidKey}; pub use crate::next_solver::infer::traits::ObligationInspector; use crate::next_solver::{ - ArgOutlivesPredicate, BoundConst, BoundRegion, BoundTy, BoundVarKind, Goal, Predicate, + ArgOutlivesPredicate, BoundConst, BoundRegion, BoundTy, BoundVariableKind, Goal, Predicate, SolverContext, fold::BoundVarReplacerDelegate, infer::{at::ToTrace, select::EvaluationResult, traits::PredicateObligation}, @@ -53,7 +53,7 @@ mod outlives; pub mod region_constraints; pub mod relate; pub mod resolve; -pub(crate) mod select; +pub mod select; pub(crate) mod snapshot; pub(crate) mod traits; mod type_variable; @@ -366,12 +366,16 @@ impl<'db> InferCtxtBuilder<'db> { where T: TypeFoldable>, { - let infcx = self.build(input.typing_mode); + let infcx = self.build(input.typing_mode.0); let (value, args) = infcx.instantiate_canonical(&input.canonical); (infcx, value, args) } pub fn build(&mut self, typing_mode: TypingMode<'db>) -> InferCtxt<'db> { + // We do not allow creating an InferCtxt for an Interner without a crate, because this means + // an interner without a crate cannot access the cache, therefore constructing it doesn't need + // to reinit the cache, and we construct a lot of no-crate interners. 
+ self.interner.expect_crate(); let InferCtxtBuilder { interner } = *self; InferCtxt { interner, @@ -557,6 +561,16 @@ impl<'db> InferCtxt<'db> { traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, copy_def_id) } + pub fn type_is_use_cloned_modulo_regions(&self, param_env: ParamEnv<'db>, ty: Ty<'db>) -> bool { + let ty = self.resolve_vars_if_possible(ty); + + let Some(use_cloned_def_id) = self.interner.lang_items().UseCloned else { + return false; + }; + + traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, use_cloned_def_id) + } + pub fn unresolved_variables(&self) -> Vec> { let mut inner = self.inner.borrow_mut(); let mut vars: Vec> = inner @@ -1098,9 +1112,9 @@ impl<'db> InferCtxt<'db> { for bound_var_kind in bound_vars { let arg: GenericArg<'db> = match bound_var_kind { - BoundVarKind::Ty(_) => self.next_ty_var().into(), - BoundVarKind::Region(_) => self.next_region_var().into(), - BoundVarKind::Const => self.next_const_var().into(), + BoundVariableKind::Ty(_) => self.next_ty_var().into(), + BoundVariableKind::Region(_) => self.next_region_var().into(), + BoundVariableKind::Const => self.next_const_var().into(), }; args.push(arg); } @@ -1110,13 +1124,13 @@ impl<'db> InferCtxt<'db> { } impl<'db> BoundVarReplacerDelegate<'db> for ToFreshVars<'db> { - fn replace_region(&mut self, br: BoundRegion) -> Region<'db> { + fn replace_region(&mut self, br: BoundRegion<'db>) -> Region<'db> { self.args[br.var.index()].expect_region() } - fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> { + fn replace_ty(&mut self, bt: BoundTy<'db>) -> Ty<'db> { self.args[bt.var.index()].expect_ty() } - fn replace_const(&mut self, bv: BoundConst) -> Const<'db> { + fn replace_const(&mut self, bv: BoundConst<'db>) -> Const<'db> { self.args[bv.var.index()].expect_const() } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/region_constraints/mod.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/region_constraints/mod.rs 
index ae5930d55c72d..7bb39519f50ac 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/region_constraints/mod.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/region_constraints/mod.rs @@ -17,7 +17,7 @@ use super::MemberConstraint; use super::unify_key::RegionVidKey; use crate::next_solver::infer::snapshot::undo_log::{InferCtxtUndoLogs, Snapshot}; use crate::next_solver::infer::unify_key::RegionVariableValue; -use crate::next_solver::{AliasTy, Binder, DbInterner, ParamTy, PlaceholderTy, Region, Ty}; +use crate::next_solver::{AliasTy, Binder, DbInterner, ParamTy, PlaceholderType, Region, Ty}; #[derive(Debug, Clone, Default)] pub struct RegionConstraintStorage<'db> { @@ -122,7 +122,7 @@ pub struct Verify<'db> { #[derive(Clone, PartialEq, Eq, Hash)] pub enum GenericKind<'db> { Param(ParamTy), - Placeholder(PlaceholderTy), + Placeholder(PlaceholderType<'db>), Alias(AliasTy<'db>), } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/generalize.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/generalize.rs index 0f7ae99fa41d0..d621dd4906e81 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/generalize.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/generalize.rs @@ -531,7 +531,7 @@ impl<'db> TypeRelation> for Generalizer<'_, 'db> { } } - TyKind::Alias(_, data) => match self.structurally_relate_aliases { + TyKind::Alias(data) => match self.structurally_relate_aliases { StructurallyRelateAliases::No => self.generalize_alias_ty(data), StructurallyRelateAliases::Yes => relate::structurally_relate_tys(self, t, t), }, diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/higher_ranked.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/higher_ranked.rs index c523751e03e32..cfa864406c1e5 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/higher_ranked.rs 
+++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/higher_ranked.rs @@ -8,7 +8,7 @@ use crate::next_solver::fold::FnMutDelegate; use crate::next_solver::infer::InferCtxt; use crate::next_solver::{ Binder, BoundConst, BoundRegion, BoundTy, Const, DbInterner, PlaceholderConst, - PlaceholderRegion, PlaceholderTy, Region, Ty, + PlaceholderRegion, PlaceholderType, Region, Ty, }; impl<'db> InferCtxt<'db> { @@ -35,23 +35,14 @@ impl<'db> InferCtxt<'db> { let next_universe = self.create_next_universe(); let delegate = FnMutDelegate { - regions: &mut |br: BoundRegion| { - Region::new_placeholder( - self.interner, - PlaceholderRegion { universe: next_universe, bound: br }, - ) + regions: &mut |br: BoundRegion<'db>| { + Region::new_placeholder(self.interner, PlaceholderRegion::new(next_universe, br)) }, - types: &mut |bound_ty: BoundTy| { - Ty::new_placeholder( - self.interner, - PlaceholderTy { universe: next_universe, bound: bound_ty }, - ) + types: &mut |bound_ty: BoundTy<'db>| { + Ty::new_placeholder(self.interner, PlaceholderType::new(next_universe, bound_ty)) }, - consts: &mut |bound: BoundConst| { - Const::new_placeholder( - self.interner, - PlaceholderConst { universe: next_universe, bound }, - ) + consts: &mut |bound: BoundConst<'db>| { + Const::new_placeholder(self.interner, PlaceholderConst::new(next_universe, bound)) }, }; diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/lattice.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/lattice.rs index 1abe6a93f4dd5..3522827a9e959 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/lattice.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/infer/relate/lattice.rs @@ -165,8 +165,8 @@ impl<'db> TypeRelation> for LatticeOp<'_, 'db> { } ( - TyKind::Alias(rustc_type_ir::Opaque, AliasTy { def_id: a_def_id, .. }), - TyKind::Alias(rustc_type_ir::Opaque, AliasTy { def_id: b_def_id, .. 
}), + TyKind::Alias(AliasTy { kind: rustc_type_ir::Opaque { def_id: a_def_id }, .. }), + TyKind::Alias(AliasTy { kind: rustc_type_ir::Opaque { def_id: b_def_id }, .. }), ) if a_def_id == b_def_id => super_combine_tys(infcx, self, a, b), _ => super_combine_tys(infcx, self, a, b), diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/inspect.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/inspect.rs index 5286977549597..63a225b98f9fd 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/inspect.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/inspect.rs @@ -1,5 +1,3 @@ -pub(crate) use rustc_next_trait_solver::solve::inspect::*; - use rustc_ast_ir::try_visit; use rustc_next_trait_solver::{ canonical::instantiate_canonical_state, @@ -329,10 +327,6 @@ impl<'a, 'db> InspectGoal<'a, 'db> { self.result } - pub(crate) fn source(&self) -> GoalSource { - self.source - } - pub(crate) fn depth(&self) -> usize { self.depth } @@ -464,9 +458,10 @@ impl<'a, 'db> InspectGoal<'a, 'db> { pub(crate) fn visit_with>(&self, visitor: &mut V) -> V::Result { if self.depth < visitor.config().max_depth { try_visit!(visitor.visit_goal(self)); + V::Result::output() + } else { + visitor.on_recursion_limit() } - - V::Result::output() } } @@ -479,6 +474,10 @@ pub(crate) trait ProofTreeVisitor<'db> { } fn visit_goal(&mut self, goal: &InspectGoal<'_, 'db>) -> Self::Result; + + fn on_recursion_limit(&mut self) -> Self::Result { + Self::Result::output() + } } impl<'db> InferCtxt<'db> { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs index 5d7ad84e1fe2c..4095dbe47d852 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/interner.rs @@ -25,27 +25,28 @@ use rustc_abi::{ReprFlags, ReprOptions}; use rustc_hash::FxHashSet; use rustc_index::bit_set::DenseBitSet; use 
rustc_type_ir::{ - AliasTermKind, AliasTyKind, BoundVar, CoroutineWitnessTypes, DebruijnIndex, EarlyBinder, - FlagComputation, Flags, GenericArgKind, GenericTypeVisitable, ImplPolarity, InferTy, Interner, - TraitRef, TypeFlags, TypeVisitableExt, UniverseIndex, Upcast, Variance, + AliasTermKind, AliasTy, AliasTyKind, BoundVar, CoroutineWitnessTypes, DebruijnIndex, + EarlyBinder, FlagComputation, Flags, GenericArgKind, GenericTypeVisitable, ImplPolarity, + InferTy, Interner, TraitRef, TypeFlags, TypeVisitableExt, Upcast, Variance, elaborate::elaborate, error::TypeError, fast_reject, inherent::{self, Const as _, GenericsOf, IntoKind, SliceLike as _, Span as _, Ty as _}, lang_items::{SolverAdtLangItem, SolverLangItem, SolverTraitLangItem}, - solve::SizedTraitKind, + solve::{AdtDestructorKind, SizedTraitKind}, }; use crate::{ FnAbi, - db::{HirDatabase, InternedCoroutine, InternedCoroutineId}, + db::{HirDatabase, InternedClosure, InternedCoroutineId}, lower::GenericPredicates, method_resolution::TraitImpls, next_solver::{ AdtIdWrapper, AnyImplId, BoundConst, CallableIdWrapper, CanonicalVarKind, ClosureIdWrapper, - CoroutineIdWrapper, Ctor, FnSig, FxIndexMap, GeneralConstIdWrapper, OpaqueTypeKey, - RegionAssumptions, SimplifiedType, SolverContext, SolverDefIds, TraitIdWrapper, - TypeAliasIdWrapper, UnevaluatedConst, + Consts, CoroutineClosureIdWrapper, CoroutineIdWrapper, Ctor, FnSig, FxIndexMap, + GeneralConstIdWrapper, LateParamRegion, OpaqueTypeKey, RegionAssumptions, ScalarInt, + SimplifiedType, SolverContext, SolverDefIds, TraitIdWrapper, TypeAliasIdWrapper, + UnevaluatedConst, util::{explicit_item_bounds, explicit_item_self_bounds}, }, }; @@ -53,14 +54,11 @@ use crate::{ use super::{ Binder, BoundExistentialPredicates, BoundTy, BoundTyKind, Clause, ClauseKind, Clauses, Const, ErrorGuaranteed, ExprConst, ExternalConstraints, GenericArg, GenericArgs, ParamConst, ParamEnv, - ParamTy, PlaceholderConst, PlaceholderTy, PredefinedOpaques, Predicate, SolverDefId, Term, Ty, 
- TyKind, Tys, Valtree, ValueConst, + ParamTy, PredefinedOpaques, Predicate, SolverDefId, Term, Ty, TyKind, Tys, ValTree, ValueConst, abi::Safety, fold::{BoundVarReplacer, BoundVarReplacerDelegate, FnMutDelegate}, generics::{Generics, generics}, - region::{ - BoundRegion, BoundRegionKind, EarlyParamRegion, LateParamRegion, PlaceholderRegion, Region, - }, + region::{BoundRegion, BoundRegionKind, EarlyParamRegion, Region}, util::sizedness_constraint_for_ty, }; @@ -329,6 +327,7 @@ unsafe impl Sync for DbInterner<'_> {} impl<'db> DbInterner<'db> { // FIXME(next-solver): remove this method pub fn conjure() -> DbInterner<'db> { + // Here we can not reinit the cache since we do that when we attach the db. crate::with_attached_db(|db| DbInterner { db: unsafe { std::mem::transmute::<&dyn HirDatabase, &'db dyn HirDatabase>(db) }, krate: None, @@ -341,10 +340,13 @@ impl<'db> DbInterner<'db> { /// /// Elaboration is a special kind: it needs lang items (for `Sized`), therefore it needs `new_with()`. pub fn new_no_crate(db: &'db dyn HirDatabase) -> Self { + // We do not reinit the cache here, since anything accessing the cache needs an InferCtxt, + // and we panic when trying to construct an InferCtxt for an Interner without a crate. 
DbInterner { db, krate: None, lang_items: None } } pub fn new_with(db: &'db dyn HirDatabase, krate: Crate) -> DbInterner<'db> { + tls_cache::reinit_cache(db); DbInterner { db, krate: Some(krate), @@ -373,6 +375,11 @@ impl<'db> DbInterner<'db> { pub fn default_types<'a>(&self) -> &'a crate::next_solver::DefaultAny<'db> { crate::next_solver::default_types(self.db) } + + #[inline] + pub(crate) fn expect_crate(&self) -> Crate { + self.krate.expect("should have a crate") + } } // This is intentionally left as `()` @@ -390,43 +397,15 @@ interned_slice!( BoundVarKinds, StoredBoundVarKinds, bound_var_kinds, - BoundVarKind, - BoundVarKind, + BoundVariableKind<'db>, + BoundVariableKind<'static>, ); -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -pub enum BoundVarKind { - Ty(BoundTyKind), - Region(BoundRegionKind), - Const, -} - -impl BoundVarKind { - pub fn expect_region(self) -> BoundRegionKind { - match self { - BoundVarKind::Region(lt) => lt, - _ => panic!("expected a region, but found another kind"), - } - } - - pub fn expect_ty(self) -> BoundTyKind { - match self { - BoundVarKind::Ty(ty) => ty, - _ => panic!("expected a type, but found another kind"), - } - } - - pub fn expect_const(self) { - match self { - BoundVarKind::Const => (), - _ => panic!("expected a const, but found another kind"), - } - } -} +pub type BoundVariableKind<'db> = rustc_type_ir::BoundVariableKind>; interned_slice!( CanonicalVarsStorage, - CanonicalVars, + CanonicalVarKinds, StoredCanonicalVars, canonical_vars, CanonicalVarKind<'db>, @@ -438,22 +417,6 @@ pub struct DepNodeIndex; #[derive(Debug)] pub struct Tracked(T); -#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct Placeholder { - pub universe: UniverseIndex, - pub bound: T, -} - -impl std::fmt::Debug for Placeholder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result { - if self.universe == UniverseIndex::ROOT { - write!(f, "!{:?}", self.bound) - } else { - write!(f, "!{}_{:?}", self.universe.index(), 
self.bound) - } - } -} - #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] pub struct AllocId; @@ -762,17 +725,26 @@ impl<'db> inherent::AdtDef> for AdtDef { .transpose() } - fn destructor( - self, - _interner: DbInterner<'db>, - ) -> Option { - // FIXME(next-solver) - None + fn destructor(self, interner: DbInterner<'db>) -> Option { + crate::drop::destructor(interner.db, self.def_id().0).map(|_| AdtDestructorKind::NotConst) } fn is_manually_drop(self) -> bool { self.inner().flags.is_manually_drop } + + fn is_packed(self) -> bool { + self.repr().packed() + } + + fn field_representing_type_info( + self, + _interner: DbInterner<'db>, + _args: GenericArgs<'db>, + ) -> Option>> { + // FIXME + None + } } impl fmt::Debug for AdtDef { @@ -806,11 +778,16 @@ impl<'db> inherent::Features> for Features { false } - fn associated_const_equality(self) -> bool { + fn feature_bound_holds_in_crate(self, _symbol: Symbol) -> bool { false } +} - fn feature_bound_holds_in_crate(self, _symbol: ()) -> bool { +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, GenericTypeVisitable)] +pub struct Symbol; + +impl<'db> inherent::Symbol> for Symbol { + fn is_kw_underscore_lifetime(self) -> bool { false } } @@ -1022,7 +999,7 @@ impl<'db> Interner for DbInterner<'db> { type ForeignId = TypeAliasIdWrapper; type FunctionId = CallableIdWrapper; type ClosureId = ClosureIdWrapper; - type CoroutineClosureId = CoroutineIdWrapper; + type CoroutineClosureId = CoroutineClosureIdWrapper; type CoroutineId = CoroutineIdWrapper; type AdtId = AdtIdWrapper; type ImplId = AnyImplId; @@ -1036,7 +1013,6 @@ impl<'db> Interner for DbInterner<'db> { type Term = Term<'db>; type BoundVarKinds = BoundVarKinds<'db>; - type BoundVarKind = BoundVarKind; type PredefinedOpaques = PredefinedOpaques<'db>; @@ -1047,13 +1023,13 @@ impl<'db> Interner for DbInterner<'db> { PredefinedOpaques::new_from_slice(data) } - type CanonicalVarKinds = CanonicalVars<'db>; + type CanonicalVarKinds = CanonicalVarKinds<'db>; fn 
mk_canonical_var_kinds( self, kinds: &[rustc_type_ir::CanonicalVarKind], ) -> Self::CanonicalVarKinds { - CanonicalVars::new_from_slice(kinds) + CanonicalVarKinds::new_from_slice(kinds) } type ExternalConstraints = ExternalConstraints<'db>; @@ -1073,9 +1049,7 @@ impl<'db> Interner for DbInterner<'db> { type Tys = Tys<'db>; type FnInputTys = &'db [Ty<'db>]; type ParamTy = ParamTy; - type BoundTy = BoundTy; - type PlaceholderTy = PlaceholderTy; - type Symbol = (); + type Symbol = Symbol; type ErrorGuaranteed = ErrorGuaranteed; type BoundExistentialPredicates = BoundExistentialPredicates<'db>; @@ -1086,18 +1060,16 @@ impl<'db> Interner for DbInterner<'db> { type Abi = FnAbi; type Const = Const<'db>; - type PlaceholderConst = PlaceholderConst; type ParamConst = ParamConst; - type BoundConst = BoundConst; type ValueConst = ValueConst<'db>; - type ValTree = Valtree<'db>; + type ValTree = ValTree<'db>; + type Consts = Consts<'db>; + type ScalarInt = ScalarInt; type ExprConst = ExprConst; type Region = Region<'db>; type EarlyParamRegion = EarlyParamRegion; - type LateParamRegion = LateParamRegion; - type BoundRegion = BoundRegion; - type PlaceholderRegion = PlaceholderRegion; + type LateParamRegion = LateParamRegion<'db>; type RegionAssumptions = RegionAssumptions<'db>; @@ -1148,7 +1120,8 @@ impl<'db> Interner for DbInterner<'db> { self, f: impl FnOnce(&mut rustc_type_ir::search_graph::GlobalCache) -> R, ) -> R { - tls_cache::with_cache(self.db, f) + // We make sure to reinit the cache when constructing the Interner. 
+ tls_cache::borrow_assume_valid(self.db, f) } fn canonical_param_env_cache_get_or_insert( @@ -1198,6 +1171,7 @@ impl<'db> Interner for DbInterner<'db> { | SolverDefId::BuiltinDeriveImplId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::AnonConstId(_) => { return VariancesOf::empty(self); } @@ -1230,22 +1204,6 @@ impl<'db> Interner for DbInterner<'db> { AdtDef::new(def_id.0, self) } - fn alias_ty_kind(self, alias: rustc_type_ir::AliasTy) -> AliasTyKind { - match alias.def_id { - SolverDefId::InternedOpaqueTyId(_) => AliasTyKind::Opaque, - SolverDefId::TypeAliasId(type_alias) => match type_alias.loc(self.db).container { - ItemContainerId::ImplId(impl_) - if ImplSignature::of(self.db, impl_).target_trait.is_none() => - { - AliasTyKind::Inherent - } - ItemContainerId::TraitId(_) | ItemContainerId::ImplId(_) => AliasTyKind::Projection, - _ => AliasTyKind::Free, - }, - _ => unimplemented!("Unexpected alias: {:?}", alias.def_id), - } - } - fn alias_term_kind( self, alias: rustc_type_ir::AliasTerm, @@ -1315,10 +1273,13 @@ impl<'db> Interner for DbInterner<'db> { SolverDefId::TypeAliasId(it) => it.lookup(self.db()).container, SolverDefId::ConstId(it) => it.lookup(self.db()).container, SolverDefId::InternedClosureId(it) => { - return self.db().lookup_intern_closure(it).0.generic_def(self.db()).into(); + return it.loc(self.db).0.generic_def(self.db()).into(); } SolverDefId::InternedCoroutineId(it) => { - return self.db().lookup_intern_coroutine(it).0.generic_def(self.db()).into(); + return it.loc(self.db).0.generic_def(self.db()).into(); + } + SolverDefId::InternedCoroutineClosureId(it) => { + return it.loc(self.db).0.generic_def(self.db()).into(); } SolverDefId::StaticId(_) | SolverDefId::AdtId(_) @@ -1342,7 +1303,7 @@ impl<'db> Interner for DbInterner<'db> { 50 } - fn features(self) -> Self::Features { + fn features(self) -> Features { Features } @@ -1356,7 +1317,7 @@ impl<'db> 
Interner for DbInterner<'db> { fn coroutine_movability(self, def_id: Self::CoroutineId) -> rustc_ast_ir::Movability { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? - let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let InternedClosure(owner, expr_id) = def_id.0.loc(self.db); let store = ExpressionStore::of(self.db, owner); let expr = &store[expr_id]; match *expr { @@ -1365,16 +1326,17 @@ impl<'db> Interner for DbInterner<'db> { hir_def::hir::Movability::Static => rustc_ast_ir::Movability::Static, hir_def::hir::Movability::Movable => rustc_ast_ir::Movability::Movable, }, - hir_def::hir::ClosureKind::Async => rustc_ast_ir::Movability::Static, + hir_def::hir::ClosureKind::AsyncBlock { .. } => rustc_ast_ir::Movability::Static, _ => panic!("unexpected expression for a coroutine: {expr:?}"), }, - hir_def::hir::Expr::Async { .. } => rustc_ast_ir::Movability::Static, _ => panic!("unexpected expression for a coroutine: {expr:?}"), } } fn coroutine_for_closure(self, def_id: Self::CoroutineClosureId) -> Self::CoroutineId { - def_id + let InternedClosure(owner, coroutine_closure_expr) = def_id.0.loc(self.db); + let coroutine_expr = ExpressionStore::coroutine_for_closure(coroutine_closure_expr); + InternedCoroutineId::new(self.db, InternedClosure(owner, coroutine_expr)).into() } fn generics_require_sized_self(self, def_id: Self::DefId) -> bool { @@ -1471,7 +1433,9 @@ impl<'db> Interner for DbInterner<'db> { fn is_ty_assoc_of_self(ty: Ty<'_>) -> bool { // FIXME: Is this correct wrt. combined kind of assoc type bounds, i.e. `where Self::Assoc: Trait` // wrt. `Assoc2`, which we should exclude? - if let TyKind::Alias(AliasTyKind::Projection, alias) = ty.kind() { + if let TyKind::Alias(alias @ AliasTy { kind: AliasTyKind::Projection { .. }, .. 
}) = + ty.kind() + { is_ty_assoc_of_self(alias.self_ty()) } else { is_ty_self(ty) @@ -1524,7 +1488,7 @@ impl<'db> Interner for DbInterner<'db> { fn require_lang_item(self, lang_item: SolverLangItem) -> Self::DefId { let lang_items = self.lang_items(); let lang_item = match lang_item { - SolverLangItem::AsyncFnKindUpvars => unimplemented!(), + SolverLangItem::AsyncFnKindUpvars => lang_items.AsyncFnKindUpvars, SolverLangItem::AsyncFnOnceOutput => lang_items.AsyncFnOnceOutput, SolverLangItem::CallOnceFuture => lang_items.CallOnceFuture, SolverLangItem::CallRefFuture => lang_items.CallRefFuture, @@ -1535,6 +1499,8 @@ impl<'db> Interner for DbInterner<'db> { SolverLangItem::DynMetadata => { return lang_items.DynMetadata.expect("Lang item required but not found.").into(); } + SolverLangItem::FieldBase => lang_items.FieldBase, + SolverLangItem::FieldType => lang_items.FieldType, }; lang_item.expect("Lang item required but not found.").into() } @@ -1543,13 +1509,13 @@ impl<'db> Interner for DbInterner<'db> { let lang_items = self.lang_items(); let lang_item = match lang_item { SolverTraitLangItem::AsyncFn => lang_items.AsyncFn, - SolverTraitLangItem::AsyncFnKindHelper => unimplemented!(), + SolverTraitLangItem::AsyncFnKindHelper => lang_items.AsyncFnKindHelper, SolverTraitLangItem::AsyncFnMut => lang_items.AsyncFnMut, SolverTraitLangItem::AsyncFnOnce => lang_items.AsyncFnOnce, SolverTraitLangItem::AsyncFnOnceOutput => unimplemented!( "This is incorrectly marked as `SolverTraitLangItem`, and is not used by the solver." 
), - SolverTraitLangItem::AsyncIterator => unimplemented!(), + SolverTraitLangItem::AsyncIterator => lang_items.AsyncIterator, SolverTraitLangItem::Clone => lang_items.Clone, SolverTraitLangItem::Copy => lang_items.Copy, SolverTraitLangItem::Coroutine => lang_items.Coroutine, @@ -1560,7 +1526,7 @@ impl<'db> Interner for DbInterner<'db> { SolverTraitLangItem::FnMut => lang_items.FnMut, SolverTraitLangItem::FnOnce => lang_items.FnOnce, SolverTraitLangItem::FnPtrTrait => lang_items.FnPtrTrait, - SolverTraitLangItem::FusedIterator => unimplemented!(), + SolverTraitLangItem::FusedIterator => lang_items.FusedIterator, SolverTraitLangItem::Future => lang_items.Future, SolverTraitLangItem::Iterator => lang_items.Iterator, SolverTraitLangItem::PointeeTrait => lang_items.PointeeTrait, @@ -1571,10 +1537,9 @@ impl<'db> Interner for DbInterner<'db> { SolverTraitLangItem::Tuple => lang_items.Tuple, SolverTraitLangItem::Unpin => lang_items.Unpin, SolverTraitLangItem::Unsize => lang_items.Unsize, - SolverTraitLangItem::BikeshedGuaranteedNoDrop => { - unimplemented!() - } + SolverTraitLangItem::BikeshedGuaranteedNoDrop => lang_items.BikeshedGuaranteedNoDrop, SolverTraitLangItem::TrivialClone => lang_items.TrivialClone, + SolverTraitLangItem::Field => lang_items.Field, }; lang_item.expect("Lang item required but not found.").into() } @@ -1602,6 +1567,7 @@ impl<'db> Interner for DbInterner<'db> { AsyncIterator, BikeshedGuaranteedNoDrop, FusedIterator, + Field, AsyncFnOnceOutput, // This is incorrectly marked as `SolverTraitLangItem`, and is not used by the solver. 
} @@ -1647,6 +1613,8 @@ impl<'db> Interner for DbInterner<'db> { ignore = { AsyncFnKindUpvars, DynMetadata, + FieldBase, + FieldType, } Metadata, @@ -1671,6 +1639,8 @@ impl<'db> Interner for DbInterner<'db> { CallRefFuture, CallOnceFuture, AsyncFnOnceOutput, + FieldBase, + FieldType, } DynMetadata, @@ -1689,6 +1659,7 @@ impl<'db> Interner for DbInterner<'db> { AsyncIterator, BikeshedGuaranteedNoDrop, FusedIterator, + Field, AsyncFnOnceOutput, // This is incorrectly marked as `SolverTraitLangItem`, and is not used by the solver. } @@ -1763,6 +1734,7 @@ impl<'db> Interner for DbInterner<'db> { | SolverDefId::StaticId(_) | SolverDefId::InternedClosureId(_) | SolverDefId::InternedCoroutineId(_) + | SolverDefId::InternedCoroutineClosureId(_) | SolverDefId::InternedOpaqueTyId(_) | SolverDefId::EnumVariantId(_) | SolverDefId::AnonConstId(_) @@ -1865,7 +1837,7 @@ impl<'db> Interner for DbInterner<'db> { // // Impls which apply to an alias after normalization are handled by // `assemble_candidates_after_normalizing_self_ty`. - TyKind::Alias(_, _) | TyKind::Placeholder(..) | TyKind::Error(_) => (), + TyKind::Alias(..) | TyKind::Placeholder(..) | TyKind::Error(_) => (), // FIXME: These should ideally not exist as a self type. It would be nice for // the builtin auto trait impls of coroutines to instead directly recurse @@ -1958,12 +1930,6 @@ impl<'db> Interner for DbInterner<'db> { trait_data.flags.contains(TraitFlags::FUNDAMENTAL) } - fn trait_may_be_implemented_via_object(self, _trait_def_id: Self::TraitId) -> bool { - // FIXME(next-solver): should check the `TraitFlags` for - // the `#[rustc_do_not_implement_via_object]` flag - true - } - fn is_impl_trait_in_trait(self, _def_id: Self::DefId) -> bool { // FIXME(next-solver) false @@ -1976,7 +1942,7 @@ impl<'db> Interner for DbInterner<'db> { fn is_general_coroutine(self, def_id: Self::CoroutineId) -> bool { // FIXME: Make this a query? 
I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? - let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let InternedClosure(owner, expr_id) = def_id.0.loc(self.db); let store = ExpressionStore::of(self.db, owner); matches!( store[expr_id], @@ -1990,12 +1956,14 @@ impl<'db> Interner for DbInterner<'db> { fn coroutine_is_async(self, def_id: Self::CoroutineId) -> bool { // FIXME: Make this a query? I don't believe this can be accessed from bodies other than // the current infer query, except with revealed opaques - is it rare enough to not matter? - let InternedCoroutine(owner, expr_id) = def_id.0.loc(self.db); + let InternedClosure(owner, expr_id) = def_id.0.loc(self.db); let store = ExpressionStore::of(self.db, owner); matches!( store[expr_id], - hir_def::hir::Expr::Closure { closure_kind: hir_def::hir::ClosureKind::Async, .. } - | hir_def::hir::Expr::Async { .. } + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. }, + .. 
+ } ) } @@ -2060,32 +2028,33 @@ impl<'db> Interner for DbInterner<'db> { ) -> rustc_type_ir::Binder { struct Anonymize<'a, 'db> { interner: DbInterner<'db>, - map: &'a mut FxIndexMap, + map: &'a mut FxIndexMap>, } impl<'db> BoundVarReplacerDelegate<'db> for Anonymize<'_, 'db> { - fn replace_region(&mut self, br: BoundRegion) -> Region<'db> { + fn replace_region(&mut self, br: BoundRegion<'db>) -> Region<'db> { let entry = self.map.entry(br.var); let index = entry.index(); let var = BoundVar::from_usize(index); - let kind = (*entry.or_insert_with(|| BoundVarKind::Region(BoundRegionKind::Anon))) - .expect_region(); + let kind = (*entry + .or_insert_with(|| BoundVariableKind::Region(BoundRegionKind::Anon))) + .expect_region(); let br = BoundRegion { var, kind }; Region::new_bound(self.interner, DebruijnIndex::ZERO, br) } - fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> { + fn replace_ty(&mut self, bt: BoundTy<'db>) -> Ty<'db> { let entry = self.map.entry(bt.var); let index = entry.index(); let var = BoundVar::from_usize(index); - let kind = - (*entry.or_insert_with(|| BoundVarKind::Ty(BoundTyKind::Anon))).expect_ty(); + let kind = (*entry.or_insert_with(|| BoundVariableKind::Ty(BoundTyKind::Anon))) + .expect_ty(); Ty::new_bound(self.interner, DebruijnIndex::ZERO, BoundTy { var, kind }) } - fn replace_const(&mut self, bv: BoundConst) -> Const<'db> { + fn replace_const(&mut self, bv: BoundConst<'db>) -> Const<'db> { let entry = self.map.entry(bv.var); let index = entry.index(); let var = BoundVar::from_usize(index); - let () = (*entry.or_insert_with(|| BoundVarKind::Const)).expect_const(); - Const::new_bound(self.interner, DebruijnIndex::ZERO, BoundConst { var }) + let () = (*entry.or_insert_with(|| BoundVariableKind::Const)).expect_const(); + Const::new_bound(self.interner, DebruijnIndex::ZERO, BoundConst::new(var)) } } @@ -2118,16 +2087,15 @@ impl<'db> Interner for DbInterner<'db> { body.exprs().for_each(|(expr_id, expr)| { if matches!( expr, - 
hir_def::hir::Expr::Async { .. } - | hir_def::hir::Expr::Closure { - closure_kind: hir_def::hir::ClosureKind::Async - | hir_def::hir::ClosureKind::Coroutine(_), - .. - } + hir_def::hir::Expr::Closure { + closure_kind: hir_def::hir::ClosureKind::AsyncBlock { .. } + | hir_def::hir::ClosureKind::Coroutine(_), + .. + } ) { let coroutine = InternedCoroutineId::new( self.db, - InternedCoroutine(ExpressionStoreOwnerId::Body(def_id), expr_id), + InternedClosure(ExpressionStoreOwnerId::Body(def_id), expr_id), ); result.push(coroutine.into()); } @@ -2266,6 +2234,34 @@ impl<'db> Interner for DbInterner<'db> { UnevaluatedConst { def: GeneralConstIdWrapper(id), args: GenericArgs::empty(self) }, )) } + + fn anon_const_kind(self, _def_id: Self::DefId) -> rustc_type_ir::AnonConstKind { + // FIXME + rustc_type_ir::AnonConstKind::GCE + } + + fn alias_ty_kind_from_def_id(self, def_id: Self::DefId) -> AliasTyKind> { + match def_id { + SolverDefId::TypeAliasId(type_alias) => match type_alias.loc(self.db).container { + ItemContainerId::ExternBlockId(_) | ItemContainerId::ModuleId(_) => { + AliasTyKind::Free { def_id } + } + ItemContainerId::ImplId(_) => AliasTyKind::Inherent { def_id }, + ItemContainerId::TraitId(_) => AliasTyKind::Projection { def_id }, + }, + SolverDefId::InternedOpaqueTyId(_) => AliasTyKind::Opaque { def_id }, + _ => unreachable!(), + } + } + + fn closure_is_const(self, _def_id: Self::ClosureId) -> bool { + // FIXME + false + } + + fn item_name(self, _item_index: Self::DefId) -> Self::Symbol { + Symbol + } } fn is_ty_self(ty: Ty<'_>) -> bool { @@ -2295,14 +2291,14 @@ impl<'db> DbInterner<'db> { self.replace_escaping_bound_vars_uncached( value, FnMutDelegate { - regions: &mut |r: BoundRegion| { + regions: &mut |r: BoundRegion<'db>| { Region::new_bound( self, DebruijnIndex::ZERO, BoundRegion { var: shift_bv(r.var), kind: r.kind }, ) }, - types: &mut |t: BoundTy| { + types: &mut |t: BoundTy<'db>| { Ty::new_bound( self, DebruijnIndex::ZERO, @@ -2310,7 +2306,7 @@ 
impl<'db> DbInterner<'db> { ) }, consts: &mut |c| { - Const::new_bound(self, DebruijnIndex::ZERO, BoundConst { var: shift_bv(c.var) }) + Const::new_bound(self, DebruijnIndex::ZERO, BoundConst::new(shift_bv(c.var))) }, }, ) @@ -2430,6 +2426,7 @@ TrivialTypeTraversalImpls! { CallableIdWrapper, ClosureIdWrapper, CoroutineIdWrapper, + CoroutineClosureIdWrapper, AdtIdWrapper, AnyImplId, GeneralConstIdWrapper, @@ -2438,17 +2435,9 @@ TrivialTypeTraversalImpls! { Span, ParamConst, ParamTy, - BoundRegion, - Placeholder, - Placeholder, - Placeholder, - Placeholder, - BoundVarKind, EarlyParamRegion, - LateParamRegion, AdtDef, - BoundTy, - BoundConst, + ScalarInt, } mod tls_db { @@ -2501,6 +2490,7 @@ mod tls_db { } let _guard = DbGuard::new(self, db); + super::tls_cache::reinit_cache(db); op() } @@ -2523,10 +2513,14 @@ mod tls_db { #[inline] fn drop(&mut self) { self.state.database.set(self.prev); + if let Some(prev) = self.prev { + super::tls_cache::reinit_cache(unsafe { prev.as_ref() }); + } } } let _guard = DbGuard::new(self, db); + super::tls_cache::reinit_cache(db); op() } @@ -2581,22 +2575,38 @@ mod tls_cache { static GLOBAL_CACHE: RefCell> = const { RefCell::new(None) }; } - pub(super) fn with_cache<'db, T>( - db: &'db dyn HirDatabase, - f: impl FnOnce(&mut GlobalCache>) -> T, - ) -> T { + pub(super) fn reinit_cache(db: &dyn HirDatabase) { GLOBAL_CACHE.with_borrow_mut(|handle| { let (db_nonce, revision) = db.nonce_and_revision(); - let handle = match handle { + match handle { Some(handle) => { if handle.revision != revision || db_nonce != handle.db_nonce { *handle = Cache { cache: GlobalCache::default(), revision, db_nonce }; } - handle } - None => handle.insert(Cache { cache: GlobalCache::default(), revision, db_nonce }), + None => *handle = Some(Cache { cache: GlobalCache::default(), revision, db_nonce }), + } + }) + } + + pub(super) fn borrow_assume_valid<'db, T>( + db: &'db dyn HirDatabase, + f: impl FnOnce(&mut GlobalCache>) -> T, + ) -> T { + if 
cfg!(debug_assertions) { + let get_state = || { + GLOBAL_CACHE.with_borrow(|handle| { + handle.as_ref().map(|handle| (handle.db_nonce, handle.revision)) + }) }; + let old_state = get_state(); + reinit_cache(db); + let new_state = get_state(); + assert_eq!(old_state, new_state, "you assumed the cache is valid!"); + } + GLOBAL_CACHE.with_borrow_mut(|handle| { + let handle = handle.as_mut().expect("you assumed the cache is valid!"); // SAFETY: No idea f(unsafe { std::mem::transmute::< @@ -2641,13 +2651,15 @@ pub unsafe fn collect_ty_garbage() { let mut gc = intern::GarbageCollector::default(); gc.add_storage::(); - gc.add_storage::(); + gc.add_storage::(); + gc.add_storage::(); gc.add_storage::(); gc.add_storage::(); gc.add_storage::(); gc.add_storage::(); gc.add_storage::(); + gc.add_slice_storage::(); gc.add_slice_storage::(); gc.add_slice_storage::(); gc.add_slice_storage::(); @@ -2682,7 +2694,8 @@ macro_rules! impl_gc_visit { impl_gc_visit!( super::consts::ConstInterned, - super::consts::ValtreeInterned, + super::consts::ValTreeInterned, + super::allocation::AllocationInterned, PatternInterned, super::opaques::ExternalConstraintsInterned, super::predicate::PredicateInterned, @@ -2721,4 +2734,5 @@ impl_gc_visit_slice!( super::predicate::BoundExistentialPredicatesStorage, super::region::RegionAssumptionsStorage, super::ty::TysStorage, + super::consts::ConstsStorage, ); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ir_print.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ir_print.rs index e0732b3473748..5dd372a367563 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ir_print.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ir_print.rs @@ -12,7 +12,7 @@ impl<'db> IrPrint> for DbInterner<'db> { } fn print_debug(t: &ty::AliasTy, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - crate::with_attached_db(|db| match t.def_id { + crate::with_attached_db(|db| match t.kind.def_id() { 
SolverDefId::TypeAliasId(id) => fmt.write_str(&format!( "AliasTy({:?}[{:?}])", TypeAliasSignature::of(db, id).name.as_str(), diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/normalize.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/normalize.rs index bd678b3e78ff3..5d8f3fe5194aa 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/normalize.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/normalize.rs @@ -229,6 +229,7 @@ impl<'db> FallibleTypeFolder> for NormalizationFolder<'_, 'db> { } // Deeply normalize a value and return it +#[expect(dead_code, reason = "rustc has this")] pub(crate) fn deeply_normalize_for_diagnostics<'db, T: TypeFoldable>>( infcx: &InferCtxt<'db>, param_env: ParamEnv<'db>, diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/region.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/region.rs index dc2441f76e3ae..3f0aebac2dea7 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/region.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/region.rs @@ -1,12 +1,12 @@ //! Things related to regions. 
use hir_def::LifetimeParamId; -use intern::{Interned, InternedRef, Symbol, impl_internable}; +use intern::{Interned, InternedRef, impl_internable}; use macros::GenericTypeVisitable; use rustc_type_ir::{ - BoundVar, BoundVarIndexKind, DebruijnIndex, Flags, GenericTypeVisitable, INNERMOST, RegionVid, - TypeFlags, TypeFoldable, TypeVisitable, - inherent::{IntoKind, PlaceholderLike, SliceLike}, + BoundVarIndexKind, DebruijnIndex, Flags, GenericTypeVisitable, INNERMOST, RegionVid, TypeFlags, + TypeFoldable, TypeVisitable, + inherent::{IntoKind, SliceLike}, relate::Relate, }; @@ -15,10 +15,7 @@ use crate::next_solver::{ interned_slice, }; -use super::{ - SolverDefId, - interner::{BoundVarKind, DbInterner, Placeholder}, -}; +use super::{SolverDefId, interner::DbInterner}; pub type RegionKind<'db> = rustc_type_ir::RegionKind>; @@ -57,7 +54,7 @@ impl<'db> Region<'db> { Region::new(interner, RegionKind::ReEarlyParam(early_bound_region)) } - pub fn new_placeholder(interner: DbInterner<'db>, placeholder: PlaceholderRegion) -> Self { + pub fn new_placeholder(interner: DbInterner<'db>, placeholder: PlaceholderRegion<'db>) -> Self { Region::new(interner, RegionKind::RePlaceholder(placeholder)) } @@ -72,7 +69,7 @@ impl<'db> Region<'db> { pub fn new_bound( interner: DbInterner<'db>, index: DebruijnIndex, - bound: BoundRegion, + bound: BoundRegion<'db>, ) -> Region<'db> { Region::new(interner, RegionKind::ReBound(BoundVarIndexKind::Bound(index), bound)) } @@ -147,7 +144,7 @@ impl<'db> Region<'db> { } } -pub type PlaceholderRegion = Placeholder; +pub type PlaceholderRegion<'db> = rustc_type_ir::PlaceholderRegion>; #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct EarlyParamRegion { @@ -156,7 +153,7 @@ pub struct EarlyParamRegion { pub index: u32, } -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, GenericTypeVisitable)] /// The parameter representation of late-bound function parameters, "some region /// at least as big as the scope 
`fr.scope`". /// @@ -165,50 +162,19 @@ pub struct EarlyParamRegion { /// between others we use the `DefId` of the parameter. For this reason the `bound_region` field /// should basically always be `BoundRegionKind::Named` as otherwise there is no way of telling /// different parameters apart. -pub struct LateParamRegion { +pub struct LateParamRegion<'db> { pub scope: SolverDefId, - pub bound_region: BoundRegionKind, + pub bound_region: BoundRegionKind<'db>, } -impl std::fmt::Debug for LateParamRegion { +impl std::fmt::Debug for LateParamRegion<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "ReLateParam({:?}, {:?})", self.scope, self.bound_region) } } -#[derive(Copy, Clone, PartialEq, Eq, Hash)] -pub enum BoundRegionKind { - /// An anonymous region parameter for a given fn (&T) - Anon, - - /// Named region parameters for functions (a in &'a T) - /// - /// The `DefId` is needed to distinguish free regions in - /// the event of shadowing. - Named(SolverDefId), - - /// Anonymous region for the implicit env pointer parameter - /// to a closure - ClosureEnv, -} - -impl std::fmt::Debug for BoundRegionKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match *self { - BoundRegionKind::Anon => write!(f, "BrAnon"), - BoundRegionKind::Named(did) => { - write!(f, "BrNamed({did:?})") - } - BoundRegionKind::ClosureEnv => write!(f, "BrEnv"), - } - } -} - -#[derive(Copy, Clone, PartialEq, Eq, Hash)] -pub struct BoundRegion { - pub var: BoundVar, - pub kind: BoundRegionKind, -} +pub type BoundRegion<'db> = rustc_type_ir::BoundRegion>; +pub type BoundRegionKind<'db> = rustc_type_ir::BoundRegionKind>; impl rustc_type_ir::inherent::ParamLike for EarlyParamRegion { fn index(self) -> u32 { @@ -223,45 +189,6 @@ impl std::fmt::Debug for EarlyParamRegion { } } -impl<'db> rustc_type_ir::inherent::BoundVarLike> for BoundRegion { - fn var(self) -> BoundVar { - self.var - } - - fn assert_eq(self, var: BoundVarKind) { - 
assert_eq!(self.kind, var.expect_region()) - } -} - -impl core::fmt::Debug for BoundRegion { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.kind { - BoundRegionKind::Anon => write!(f, "{:?}", self.var), - BoundRegionKind::ClosureEnv => write!(f, "{:?}.Env", self.var), - BoundRegionKind::Named(def) => { - write!(f, "{:?}.Named({:?})", self.var, def) - } - } - } -} - -impl BoundRegionKind { - pub fn is_named(&self) -> bool { - matches!(self, BoundRegionKind::Named(_)) - } - - pub fn get_name(&self) -> Option { - None - } - - pub fn get_id(&self) -> Option { - match self { - BoundRegionKind::Named(id) => Some(*id), - _ => None, - } - } -} - impl std::fmt::Debug for Region<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.kind().fmt(f) @@ -323,15 +250,15 @@ impl<'db> Flags for Region<'db> { impl<'db> rustc_type_ir::inherent::Region> for Region<'db> { fn new_bound( interner: DbInterner<'db>, - debruijn: rustc_type_ir::DebruijnIndex, - var: BoundRegion, + debruijn: DebruijnIndex, + var: BoundRegion<'db>, ) -> Self { Region::new(interner, RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn), var)) } fn new_anon_bound( interner: DbInterner<'db>, - debruijn: rustc_type_ir::DebruijnIndex, + debruijn: DebruijnIndex, var: rustc_type_ir::BoundVar, ) -> Self { Region::new( @@ -357,38 +284,11 @@ impl<'db> rustc_type_ir::inherent::Region> for Region<'db> { interner.default_types().regions.statik } - fn new_placeholder( - interner: DbInterner<'db>, - var: as rustc_type_ir::Interner>::PlaceholderRegion, - ) -> Self { + fn new_placeholder(interner: DbInterner<'db>, var: PlaceholderRegion<'db>) -> Self { Region::new(interner, RegionKind::RePlaceholder(var)) } } -impl<'db> PlaceholderLike> for PlaceholderRegion { - type Bound = BoundRegion; - - fn universe(self) -> rustc_type_ir::UniverseIndex { - self.universe - } - - fn var(self) -> rustc_type_ir::BoundVar { - self.bound.var - } - - fn with_updated_universe(self, 
ui: rustc_type_ir::UniverseIndex) -> Self { - Placeholder { universe: ui, bound: self.bound } - } - - fn new(ui: rustc_type_ir::UniverseIndex, bound: Self::Bound) -> Self { - Placeholder { universe: ui, bound } - } - - fn new_anon(ui: rustc_type_ir::UniverseIndex, var: rustc_type_ir::BoundVar) -> Self { - Placeholder { universe: ui, bound: BoundRegion { var, kind: BoundRegionKind::Anon } } - } -} - impl<'db, V: super::WorldExposer> GenericTypeVisitable for Region<'db> { fn generic_visit_with(&self, visitor: &mut V) { if visitor.on_interned(self.interned).is_continue() { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/solver.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/solver.rs index 848bb110af2d0..d45ac6c959695 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/solver.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/solver.rs @@ -14,10 +14,13 @@ use rustc_type_ir::{ }; use tracing::debug; -use crate::next_solver::{ - AliasTy, AnyImplId, CanonicalVarKind, Clause, ClauseKind, CoercePredicate, GenericArgs, - ParamEnv, Predicate, PredicateKind, SubtypePredicate, Ty, TyKind, fold::fold_tys, - util::sizedness_fast_path, +use crate::{ + ParamEnvAndCrate, + next_solver::{ + AliasTy, AnyImplId, CanonicalVarKind, Clause, ClauseKind, CoercePredicate, GenericArgs, + ParamEnv, Predicate, PredicateKind, SubtypePredicate, Ty, TyKind, UnevaluatedConst, + fold::fold_tys, util::sizedness_fast_path, + }, }; use super::{ @@ -155,10 +158,11 @@ impl<'db> SolverDelegate for SolverContext<'db> { fold_tys(interner, clause, |ty| match ty.kind() { // Replace all other mentions of the same opaque type with the hidden type, // as the bounds must hold on the hidden type after all. - TyKind::Alias( - AliasTyKind::Opaque, - AliasTy { def_id: def_id2, args: args2, .. }, - ) if def_id == def_id2 && args == args2 => hidden_ty, + TyKind::Alias(AliasTy { + kind: AliasTyKind::Opaque { def_id: def_id2 }, + args: args2, + .. 
+ }) if def_id == def_id2 && args == args2 => hidden_ty, _ => ty, }) }; @@ -247,25 +251,26 @@ impl<'db> SolverDelegate for SolverContext<'db> { fn evaluate_const( &self, - _param_env: ParamEnv<'db>, - uv: rustc_type_ir::UnevaluatedConst, - ) -> Option<::Const> { - match uv.def.0 { + param_env: ParamEnv<'db>, + uv: UnevaluatedConst<'db>, + ) -> Option> { + let ec = match uv.def.0 { GeneralConstId::ConstId(c) => { let subst = uv.args; - let ec = self.cx().db.const_eval(c, subst, None).ok()?; - Some(ec) - } - GeneralConstId::StaticId(c) => { - let ec = self.cx().db.const_eval_static(c).ok()?; - Some(ec) + self.cx().db.const_eval(c, subst, None).ok()? } + GeneralConstId::StaticId(c) => self.cx().db.const_eval_static(c).ok()?, // TODO: Wire up const_eval_anon query in Phase 5. // For now, return an error const so normalization resolves the // unevaluated const to Error (matching the old behavior where // complex expressions produced ConstKind::Error directly). - GeneralConstId::AnonConstId(_) => Some(Const::error(self.cx())), - } + GeneralConstId::AnonConstId(_) => return Some(Const::error(self.cx())), + }; + Some(Const::new_from_allocation( + self.interner, + &ec, + ParamEnvAndCrate { param_env, krate: self.interner.expect_crate() }, + )) } fn compute_goal_fast_path( diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs index 8e892b65ea383..39abdaf079b63 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/ty.rs @@ -17,8 +17,8 @@ use rustc_type_ir::{ IntVid, Interner, TyVid, TypeFoldable, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor, UintTy, Upcast, WithCachedTypeInfo, inherent::{ - AdtDef as _, BoundExistentialPredicates, BoundVarLike, Const as _, GenericArgs as _, - IntoKind, ParamLike, PlaceholderLike, Safety as _, SliceLike, Ty as _, + AdtDef as _, 
BoundExistentialPredicates, Const as _, GenericArgs as _, IntoKind, ParamLike, + Safety as _, SliceLike, Ty as _, }, relate::Relate, solve::SizedTraitKind, @@ -27,11 +27,12 @@ use rustc_type_ir::{ use crate::{ FnAbi, - db::{HirDatabase, InternedCoroutine}, + db::{HirDatabase, InternedClosure}, lower::GenericPredicates, next_solver::{ AdtDef, AliasTy, Binder, CallableIdWrapper, Clause, ClauseKind, ClosureIdWrapper, Const, - CoroutineIdWrapper, FnSig, GenericArgKind, PolyFnSig, Region, TraitRef, TypeAliasIdWrapper, + CoroutineClosureIdWrapper, CoroutineIdWrapper, FnSig, GenericArgKind, PolyFnSig, Region, + TraitRef, TypeAliasIdWrapper, abi::Safety, impl_foldable_for_interned_slice, impl_stored_interned, interned_slice, util::{CoroutineArgsExt, IntegerTypeExt}, @@ -39,7 +40,7 @@ use crate::{ }; use super::{ - BoundVarKind, DbInterner, GenericArgs, Placeholder, SolverDefId, + DbInterner, GenericArgs, SolverDefId, util::{FloatExt, IntegerExt}, }; @@ -96,7 +97,7 @@ impl<'db> Ty<'db> { Ty::new(interner, TyKind::Param(ParamTy { id, index })) } - pub fn new_placeholder(interner: DbInterner<'db>, placeholder: PlaceholderTy) -> Self { + pub fn new_placeholder(interner: DbInterner<'db>, placeholder: PlaceholderType<'db>) -> Self { Ty::new(interner, TyKind::Placeholder(placeholder)) } @@ -176,7 +177,10 @@ impl<'db> Ty<'db> { def_id: SolverDefId, args: GenericArgs<'db>, ) -> Self { - Ty::new_alias(interner, AliasTyKind::Opaque, AliasTy::new_from_args(interner, def_id, args)) + Ty::new_alias( + interner, + AliasTy::new_from_args(interner, AliasTyKind::Opaque { def_id }, args), + ) } /// Returns the `Size` for primitive types (bool, uint, int, char, float). @@ -388,6 +392,11 @@ impl<'db> Ty<'db> { matches!(self.kind(), TyKind::Char) } + #[inline] + pub fn is_coroutine_closure(self) -> bool { + matches!(self.kind(), TyKind::CoroutineClosure(..)) + } + /// A scalar type is one that denotes an atomic datum, with no sub-components. 
/// (A RawPtr is scalar because it represents a non-managed pointer, so its /// contents are abstract to rustc.) @@ -437,6 +446,11 @@ impl<'db> Ty<'db> { matches!(self.kind(), TyKind::RawPtr(..)) } + #[inline] + pub fn is_ref(self) -> bool { + matches!(self.kind(), TyKind::Ref(..)) + } + #[inline] pub fn is_array(self) -> bool { matches!(self.kind(), TyKind::Array(..)) @@ -503,6 +517,14 @@ impl<'db> Ty<'db> { } } + /// Returns the type of `ty[i]`. + pub fn builtin_index(self) -> Option> { + match self.kind() { + TyKind::Array(ty, _) | TyKind::Slice(ty) => Some(ty), + _ => None, + } + } + /// Whether the type contains some non-lifetime, aka. type or const, error type. pub fn references_non_lt_error(self) -> bool { references_non_lt_error(&self) @@ -527,7 +549,7 @@ impl<'db> Ty<'db> { let unit_ty = Ty::new_unit(interner); let return_ty = Ty::new_coroutine( interner, - coroutine_id, + interner.coroutine_for_closure(coroutine_id), CoroutineArgs::new( interner, CoroutineArgsParts { @@ -680,12 +702,11 @@ impl<'db> Ty<'db> { let interner = DbInterner::new_no_crate(db); match self.kind() { - TyKind::Alias(AliasTyKind::Opaque, opaque_ty) => Some( - opaque_ty - .def_id + TyKind::Alias(AliasTy { kind: AliasTyKind::Opaque { def_id }, args, .. }) => Some( + def_id .expect_opaque_ty() .predicates(db) - .iter_instantiated_copied(interner, opaque_ty.args.as_slice()) + .iter_instantiated_copied(interner, args.as_slice()) .collect(), ), TyKind::Param(param) => { @@ -713,7 +734,7 @@ impl<'db> Ty<'db> { } } TyKind::Coroutine(coroutine_id, _args) => { - let InternedCoroutine(owner, _) = coroutine_id.0.loc(db); + let InternedClosure(owner, _) = coroutine_id.0.loc(db); let krate = owner.krate(db); if let Some(future_trait) = hir_def::lang_item::lang_items(db, krate).Future { // This is only used by type walking. 
@@ -742,9 +763,7 @@ impl<'db> Ty<'db> { true } (TyKind::FnDef(def_id, ..), TyKind::FnDef(def_id2, ..)) => def_id == def_id2, - (TyKind::Alias(_, alias, ..), TyKind::Alias(_, alias2)) => { - alias.def_id == alias2.def_id - } + (TyKind::Alias(alias), TyKind::Alias(alias2)) => alias.kind == alias2.kind, (TyKind::Foreign(ty_id, ..), TyKind::Foreign(ty_id2, ..)) => ty_id == ty_id2, (TyKind::Closure(id1, _), TyKind::Closure(id2, _)) => id1 == id2, (TyKind::Ref(.., mutability), TyKind::Ref(.., mutability2)) @@ -857,7 +876,7 @@ impl<'db> TypeSuperVisitable> for Ty<'db> { TyKind::CoroutineWitness(_did, ref args) => args.visit_with(visitor), TyKind::Closure(_did, ref args) => args.visit_with(visitor), TyKind::CoroutineClosure(_did, ref args) => args.visit_with(visitor), - TyKind::Alias(_, ref data) => data.visit_with(visitor), + TyKind::Alias(ref data) => data.visit_with(visitor), TyKind::Pat(ty, pat) => { try_visit!(ty.visit_with(visitor)); @@ -924,7 +943,7 @@ impl<'db> TypeSuperFoldable> for Ty<'db> { TyKind::CoroutineClosure(did, args) => { TyKind::CoroutineClosure(did, args.try_fold_with(folder)?) } - TyKind::Alias(kind, data) => TyKind::Alias(kind, data.try_fold_with(folder)?), + TyKind::Alias(data) => TyKind::Alias(data.try_fold_with(folder)?), TyKind::Pat(ty, pat) => { TyKind::Pat(ty.try_fold_with(folder)?, pat.try_fold_with(folder)?) 
} @@ -973,7 +992,7 @@ impl<'db> TypeSuperFoldable> for Ty<'db> { TyKind::CoroutineClosure(did, args) => { TyKind::CoroutineClosure(did, args.fold_with(folder)) } - TyKind::Alias(kind, data) => TyKind::Alias(kind, data.fold_with(folder)), + TyKind::Alias(data) => TyKind::Alias(data.fold_with(folder)), TyKind::Pat(ty, pat) => TyKind::Pat(ty.fold_with(folder), pat.fold_with(folder)), TyKind::Bool @@ -1044,11 +1063,11 @@ impl<'db> rustc_type_ir::inherent::Ty> for Ty<'db> { Ty::new(interner, TyKind::Param(param)) } - fn new_placeholder(interner: DbInterner<'db>, param: PlaceholderTy) -> Self { + fn new_placeholder(interner: DbInterner<'db>, param: PlaceholderType<'db>) -> Self { Ty::new(interner, TyKind::Placeholder(param)) } - fn new_bound(interner: DbInterner<'db>, debruijn: DebruijnIndex, var: BoundTy) -> Self { + fn new_bound(interner: DbInterner<'db>, debruijn: DebruijnIndex, var: BoundTy<'db>) -> Self { Ty::new(interner, TyKind::Bound(BoundVarIndexKind::Bound(debruijn), var)) } @@ -1069,8 +1088,8 @@ impl<'db> rustc_type_ir::inherent::Ty> for Ty<'db> { ) } - fn new_alias(interner: DbInterner<'db>, kind: AliasTyKind, alias_ty: AliasTy<'db>) -> Self { - Ty::new(interner, TyKind::Alias(kind, alias_ty)) + fn new_alias(interner: DbInterner<'db>, alias_ty: AliasTy<'db>) -> Self { + Ty::new(interner, TyKind::Alias(alias_ty)) } fn new_error(interner: DbInterner<'db>, guar: ErrorGuaranteed) -> Self { @@ -1107,7 +1126,7 @@ impl<'db> rustc_type_ir::inherent::Ty> for Ty<'db> { fn new_coroutine_closure( interner: DbInterner<'db>, - def_id: CoroutineIdWrapper, + def_id: CoroutineClosureIdWrapper, args: as Interner>::GenericArgs, ) -> Self { Ty::new(interner, TyKind::CoroutineClosure(def_id, args)) @@ -1351,7 +1370,7 @@ impl<'db> rustc_type_ir::inherent::Tys> for Tys<'db> { } } -pub type PlaceholderTy = Placeholder; +pub type PlaceholderType<'db> = rustc_type_ir::PlaceholderType>; #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct ParamTy { @@ -1374,27 +1393,8 @@ impl 
std::fmt::Debug for ParamTy { } } -#[derive(Copy, Clone, PartialEq, Eq, Hash)] -pub struct BoundTy { - pub var: BoundVar, - // FIXME: This is for diagnostics in rustc, do we really need it? - pub kind: BoundTyKind, -} - -impl std::fmt::Debug for BoundTy { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.kind { - BoundTyKind::Anon => write!(f, "{:?}", self.var), - BoundTyKind::Param(def_id) => write!(f, "{def_id:?}"), - } - } -} - -#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] -pub enum BoundTyKind { - Anon, - Param(SolverDefId), -} +pub type BoundTy<'db> = rustc_type_ir::BoundTy>; +pub type BoundTyKind<'db> = rustc_type_ir::BoundTyKind>; #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub struct ErrorGuaranteed; @@ -1430,40 +1430,6 @@ impl ParamLike for ParamTy { } } -impl<'db> BoundVarLike> for BoundTy { - fn var(self) -> BoundVar { - self.var - } - - fn assert_eq(self, var: BoundVarKind) { - assert_eq!(self.kind, var.expect_ty()) - } -} - -impl<'db> PlaceholderLike> for PlaceholderTy { - type Bound = BoundTy; - - fn universe(self) -> rustc_type_ir::UniverseIndex { - self.universe - } - - fn var(self) -> BoundVar { - self.bound.var - } - - fn with_updated_universe(self, ui: rustc_type_ir::UniverseIndex) -> Self { - Placeholder { universe: ui, bound: self.bound } - } - - fn new(ui: rustc_type_ir::UniverseIndex, bound: BoundTy) -> Self { - Placeholder { universe: ui, bound } - } - - fn new_anon(ui: rustc_type_ir::UniverseIndex, var: rustc_type_ir::BoundVar) -> Self { - Placeholder { universe: ui, bound: BoundTy { var, kind: BoundTyKind::Anon } } - } -} - impl<'db> DbInterner<'db> { /// Given a closure signature, returns an equivalent fn signature. 
Detuples /// and so forth -- so e.g., if we have a sig with `Fn<(u32, i32)>` then diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/util.rs b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/util.rs index c175062bda37c..858233cb2c900 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/util.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/next_solver/util.rs @@ -15,7 +15,7 @@ use rustc_type_ir::{ use crate::{ next_solver::{ - BoundConst, FxIndexMap, ParamEnv, Placeholder, PlaceholderConst, PlaceholderRegion, + BoundConst, FxIndexMap, ParamEnv, PlaceholderConst, PlaceholderRegion, PlaceholderType, PolyTraitRef, infer::{ InferCtxt, @@ -446,9 +446,10 @@ pub fn apply_args_to_binder<'db, T: TypeFoldable>>( args: GenericArgs<'db>, interner: DbInterner<'db>, ) -> T { - let types = &mut |ty: BoundTy| args.as_slice()[ty.var.index()].expect_ty(); - let regions = &mut |region: BoundRegion| args.as_slice()[region.var.index()].expect_region(); - let consts = &mut |const_: BoundConst| args.as_slice()[const_.var.index()].expect_const(); + let types = &mut |ty: BoundTy<'db>| args.as_slice()[ty.var.index()].expect_ty(); + let regions = + &mut |region: BoundRegion<'db>| args.as_slice()[region.var.index()].expect_region(); + let consts = &mut |const_: BoundConst<'db>| args.as_slice()[const_.var.index()].expect_const(); let mut instantiate = BoundVarReplacer::new(interner, FnMutDelegate { types, regions, consts }); b.skip_binder().fold_with(&mut instantiate) } @@ -497,9 +498,9 @@ impl<'db> TypeVisitor> for ContainsTypeErrors { /// The inverse of [`BoundVarReplacer`]: replaces placeholders with the bound vars from which they came. 
pub struct PlaceholderReplacer<'a, 'db> { infcx: &'a InferCtxt<'db>, - mapped_regions: FxIndexMap, - mapped_types: FxIndexMap, BoundTy>, - mapped_consts: FxIndexMap, + mapped_regions: FxIndexMap, BoundRegion<'db>>, + mapped_types: FxIndexMap, BoundTy<'db>>, + mapped_consts: FxIndexMap, BoundConst<'db>>, universe_indices: &'a [Option], current_index: DebruijnIndex, } @@ -507,9 +508,9 @@ pub struct PlaceholderReplacer<'a, 'db> { impl<'a, 'db> PlaceholderReplacer<'a, 'db> { pub fn replace_placeholders>>( infcx: &'a InferCtxt<'db>, - mapped_regions: FxIndexMap, - mapped_types: FxIndexMap, BoundTy>, - mapped_consts: FxIndexMap, + mapped_regions: FxIndexMap, BoundRegion<'db>>, + mapped_types: FxIndexMap, BoundTy<'db>>, + mapped_consts: FxIndexMap, BoundConst<'db>>, universe_indices: &'a [Option], value: T, ) -> T { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/opaques.rs b/src/tools/rust-analyzer/crates/hir-ty/src/opaques.rs index ce93a334221c1..2e85beea9163d 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/opaques.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/opaques.rs @@ -55,10 +55,10 @@ pub(crate) fn opaque_types_defined_by( .for_each(extend_with_taits); }; let extend_with_atpit_from_container = |container| match container { - ItemContainerId::ImplId(impl_id) => { - if ImplSignature::of(db, impl_id).target_trait.is_some() { - extend_with_atpit_from_assoc_items(&impl_id.impl_items(db).items); - } + ItemContainerId::ImplId(impl_id) + if ImplSignature::of(db, impl_id).target_trait.is_some() => + { + extend_with_atpit_from_assoc_items(&impl_id.impl_items(db).items); } ItemContainerId::TraitId(trait_id) => { extend_with_atpit_from_assoc_items(&trait_id.trait_items(db).items); @@ -196,10 +196,10 @@ fn tait_defining_bodies( .collect() }; match loc.container { - ItemContainerId::ImplId(impl_id) => { - if ImplSignature::of(db, impl_id).target_trait.is_some() { - return from_assoc_items(&impl_id.impl_items(db).items); - } + ItemContainerId::ImplId(impl_id) 
+ if ImplSignature::of(db, impl_id).target_trait.is_some() => + { + return from_assoc_items(&impl_id.impl_items(db).items); } ItemContainerId::TraitId(trait_id) => { return from_assoc_items(&trait_id.trait_items(db).items); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs index 9e687568216d9..5324d8c605495 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/closure_captures.rs @@ -1,24 +1,68 @@ use expect_test::{Expect, expect}; use hir_def::{ - DefWithBodyId, + AdtId, DefWithBodyId, LocalFieldId, VariantId, expr_store::{Body, ExpressionStore}, + hir::{BindingId, ExprOrPatId}, }; use hir_expand::{HirFileId, files::InFileWrapper}; use itertools::Itertools; -use span::TextRange; +use rustc_type_ir::inherent::{AdtDef as _, IntoKind}; +use span::{Edition, TextRange}; +use stdx::{format_to, never}; use syntax::{AstNode, AstPtr}; use test_fixture::WithFixture; use crate::{ InferenceResult, - db::HirDatabase, + closure_analysis::Place, display::{DisplayTarget, HirDisplay}, - mir::MirSpan, + next_solver::TyKind, test_db::TestDB, }; use super::{setup_tracing, visit_module}; +fn display_place(db: &TestDB, store: &ExpressionStore, place: &Place, local: BindingId) -> String { + let mut result = store[local].name.display(db, Edition::LATEST).to_string(); + let mut last_was_deref = false; + for (i, proj) in place.projections.iter().enumerate() { + match proj.kind { + hir_ty::closure_analysis::ProjectionKind::Deref => { + result.insert(0, '*'); + last_was_deref = true; + } + hir_ty::closure_analysis::ProjectionKind::Field { field_idx, variant_idx } => { + if last_was_deref { + result.insert(0, '('); + result.push(')'); + last_was_deref = false; + } + + let ty = place.ty_before_projection(i); + match ty.kind() { + TyKind::Tuple(_) => format_to!(result, ".{field_idx}"), + TyKind::Adt(adt_def, _) => { 
+ let variant = match adt_def.def_id().0 { + AdtId::StructId(id) => VariantId::from(id), + AdtId::UnionId(id) => id.into(), + AdtId::EnumId(id) => { + // Can't really do that for an enum, unfortunately, so try to do something alike. + id.enum_variants(db).variants[variant_idx as usize].0.into() + } + }; + let field = &variant.fields(db).fields() + [LocalFieldId::from_raw(la_arena::RawIdx::from_u32(field_idx))]; + format_to!(result, ".{}", field.name.display(db, Edition::LATEST)); + } + _ => never!("mismatching projection type"), + } + } + _ => never!("unexpected projection kind"), + } + } + result +} + fn check_closure_captures(#[rust_analyzer::rust_fixture] ra_fixture: &str, expect: Expect) { let _tracing = setup_tracing(); let (db, file_id) = TestDB::with_single_file(ra_fixture); @@ -40,84 +84,69 @@ fn check_closure_captures(#[rust_analyzer::rust_fixture] ra_fixture: &str, expec }; let infer = InferenceResult::of(&db, def); let db = &db; - captures_info.extend(infer.closure_info.iter().flat_map( - |(closure_id, (captures, _))| { - let closure = db.lookup_intern_closure(*closure_id); - let body_owner = closure.0; - let source_map = ExpressionStore::with_source_map(db, body_owner).1; - let closure_text_range = source_map - .expr_syntax(closure.1) - .expect("failed to map closure to SyntaxNode") - .value - .text_range(); - captures.iter().map(move |capture| { - fn text_range( - db: &TestDB, - syntax: InFileWrapper>, - ) -> TextRange { - let root = syntax.file_syntax(db); - syntax.value.to_node(&root).syntax().text_range() + captures_info.extend(infer.closures_data.iter().flat_map(|(closure, closure_data)| { + let (body, source_map) = Body::with_source_map(db, def); + let closure_text_range = source_map + .expr_syntax(*closure) + .expect("failed to map closure to SyntaxNode") + .value + .text_range(); + closure_data.min_captures.values().flatten().map(move |capture| { + fn text_range( + db: &TestDB, + syntax: InFileWrapper>, + ) -> TextRange { + let root = 
syntax.file_syntax(db); + syntax.value.to_node(&root).syntax().text_range() + } + + // FIXME: Deduplicate this with hir::Local::sources(). + let captured_local = capture.captured_local(); + let local_text_range = match body.self_param.zip(source_map.self_param_syntax()) + { + Some((param, source)) if param == captured_local => { + format!("{:?}", text_range(db, source)) } - - // FIXME: Deduplicate this with hir::Local::sources(). - let (body, source_map) = - Body::with_source_map(db, body_owner.as_def_with_body().unwrap()); - let local_text_range = - match body.self_param.zip(source_map.self_param_syntax()) { - Some((param, source)) if param == capture.local() => { - format!("{:?}", text_range(db, source)) - } - _ => source_map - .patterns_for_binding(capture.local()) - .iter() - .map(|&definition| { - text_range(db, source_map.pat_syntax(definition).unwrap()) - }) - .map(|it| format!("{it:?}")) - .join(", "), - }; - let place = capture.display_place(body_owner, db); - let capture_ty = capture - .ty - .get() - .skip_binder() - .display_test(db, DisplayTarget::from_crate(db, module.krate(db))) - .to_string(); - let spans = capture - .spans() + _ => source_map + .patterns_for_binding(captured_local) .iter() - .flat_map(|span| match *span { - MirSpan::ExprId(expr) => { - vec![text_range(db, source_map.expr_syntax(expr).unwrap())] - } - MirSpan::PatId(pat) => { - vec![text_range(db, source_map.pat_syntax(pat).unwrap())] - } - MirSpan::BindingId(binding) => source_map - .patterns_for_binding(binding) - .iter() - .map(|pat| text_range(db, source_map.pat_syntax(*pat).unwrap())) - .collect(), - MirSpan::SelfParam => { - vec![text_range(db, source_map.self_param_syntax().unwrap())] - } - MirSpan::Unknown => Vec::new(), + .map(|&definition| { + text_range(db, source_map.pat_syntax(definition).unwrap()) }) - .sorted_by_key(|it| it.start()) .map(|it| format!("{it:?}")) - .join(","); - - ( - closure_text_range, - local_text_range, - spans, - place, - capture_ty, - 
capture.kind(), - ) - }) - }, - )); + .join(", "), + }; + let place = display_place(db, body, &capture.place, captured_local); + let capture_ty = capture + .captured_ty(db) + .display_test(db, DisplayTarget::from_crate(db, module.krate(db))) + .to_string(); + let spans = capture + .info + .sources + .iter() + .flat_map(|span| match span.final_source() { + ExprOrPatId::ExprId(expr) => { + vec![text_range(db, source_map.expr_syntax(expr).unwrap())] + } + ExprOrPatId::PatId(pat) => { + vec![text_range(db, source_map.pat_syntax(pat).unwrap())] + } + }) + .sorted_by_key(|it| it.start()) + .map(|it| format!("{it:?}")) + .join(","); + + ( + closure_text_range, + local_text_range, + spans, + place, + capture_ty, + capture.info.capture_kind, + ) + }) + })); } captures_info.sort_unstable_by_key(|(closure_text_range, local_text_range, ..)| { (closure_text_range.start(), local_text_range.clone()) @@ -146,7 +175,7 @@ fn main() { let closure = || { let b = *a; }; } "#, - expect!["53..71;20..21;66..68 ByRef(Shared) *a &'? bool"], + expect!["53..71;20..21;66..68 ByRef(Immutable) *a &' bool"], ); } @@ -160,7 +189,7 @@ fn main() { let closure = || { let &mut ref b = a; }; } "#, - expect!["53..79;20..21;67..72 ByRef(Shared) *a &'? bool"], + expect!["53..79;20..21;62..72 ByRef(Immutable) *a &' bool"], ); check_closure_captures( r#" @@ -170,7 +199,7 @@ fn main() { let closure = || { let &mut ref mut b = a; }; } "#, - expect!["53..83;20..21;67..76 ByRef(Mut { kind: Default }) *a &'? mut bool"], + expect!["53..83;20..21;62..76 ByRef(Mutable) *a &' mut bool"], ); } @@ -184,7 +213,7 @@ fn main() { let closure = || { *a = false; }; } "#, - expect!["53..71;20..21;58..60 ByRef(Mut { kind: Default }) *a &'? mut bool"], + expect!["53..71;20..21;58..60 ByRef(Mutable) *a &' mut bool"], ); } @@ -198,7 +227,7 @@ fn main() { let closure = || { let ref mut b = *a; }; } "#, - expect!["53..79;20..21;62..71 ByRef(Mut { kind: Default }) *a &'? 
mut bool"], + expect!["53..79;20..21;74..76 ByRef(Mutable) *a &' mut bool"], ); } @@ -212,7 +241,7 @@ fn main() { let closure = || { let _ = *a else { return; }; }; } "#, - expect!["53..88;20..21;66..68 ByRef(Shared) *a &'? bool"], + expect![""], ); } @@ -244,8 +273,8 @@ fn main() { } "#, expect![[r#" - 71..89;36..41;84..86 ByRef(Shared) a &'? NonCopy - 109..131;36..41;122..128 ByRef(Mut { kind: Default }) a &'? mut NonCopy"#]], + 71..89;36..41;85..86 ByRef(Immutable) a &' NonCopy + 109..131;36..41;127..128 ByRef(Mutable) a &' mut NonCopy"#]], ); } @@ -260,7 +289,7 @@ fn main() { let closure = || { let b = a.a; }; } "#, - expect!["92..111;50..51;105..108 ByRef(Shared) a.a &'? i32"], + expect!["92..111;50..51;105..108 ByRef(Immutable) a.a &' i32"], ); } @@ -281,8 +310,8 @@ fn main() { } "#, expect![[r#" - 133..212;87..92;154..158 ByRef(Shared) a.a &'? i32 - 133..212;87..92;176..184 ByRef(Mut { kind: Default }) a.b &'? mut i32 + 133..212;87..92;155..158 ByRef(Immutable) a.a &' i32 + 133..212;87..92;181..184 ByRef(Mutable) a.b &' mut i32 133..212;87..92;202..205 ByValue a.c NonCopy"#]], ); } @@ -304,8 +333,8 @@ fn main() { } "#, expect![[r#" - 123..133;92..97;126..127 ByRef(Shared) a &'? Foo - 153..164;92..97;156..157 ByRef(Mut { kind: Default }) a &'? mut Foo"#]], + 123..133;92..97;126..127 ByRef(Immutable) a &' Foo + 153..164;92..97;156..157 ByRef(Mutable) a &' mut Foo"#]], ); } @@ -332,7 +361,7 @@ fn main() { } "#, expect![[r#" - 113..167;36..41;127..128,154..160 ByRef(Mut { kind: Default }) a &'? mut &'? mut bool + 113..167;36..41;127..128,159..160 ByRef(Mutable) a &' mut &'? mut bool 231..304;196..201;252..253,276..277,296..297 ByValue a NonCopy"#]], ); } @@ -371,8 +400,8 @@ fn main() { } "#, expect![[r#" - 125..163;36..41;134..135 ByRef(Shared) a &'? NonCopy - 183..225;36..41;192..193 ByRef(Mut { kind: Default }) a &'? 
mut NonCopy"#]], + 125..163;36..41;134..135 ByRef(Immutable) a &' NonCopy + 183..225;36..41;192..193 ByRef(Mutable) a &' mut NonCopy"#]], ); } @@ -386,7 +415,7 @@ fn main() { let mut closure = || { let (b | b) = a; }; } "#, - expect!["57..80;20..25;76..77,76..77 ByRef(Shared) a &'? bool"], + expect!["57..80;20..25;76..77 ByRef(Immutable) a &' bool"], ); } @@ -406,7 +435,7 @@ fn main() { } "#, expect![ - "57..149;20..25;78..80,98..100,118..124,134..135 ByRef(Mut { kind: Default }) a &'? mut bool" + "57..149;20..25;79..80,99..100,123..124,134..135 ByRef(Mutable) a &' mut bool" ], ); } @@ -421,7 +450,7 @@ fn main() { let mut closure = || { let b = *&mut a; }; } "#, - expect!["57..80;20..25;71..77 ByRef(Mut { kind: Default }) a &'? mut bool"], + expect!["57..80;20..25;76..77 ByRef(Mutable) a &' mut bool"], ); } @@ -440,10 +469,10 @@ fn main() { } "#, expect![[r#" - 54..72;20..25;67..69 ByRef(Shared) a &'? &'? bool - 92..114;20..25;105..111 ByRef(Mut { kind: Default }) a &'? mut &'? bool - 158..176;124..125;171..173 ByRef(Shared) a &'? &'? mut bool - 196..218;124..125;209..215 ByRef(Mut { kind: Default }) a &'? mut &'? mut bool"#]], + 54..72;20..25;68..69 ByRef(Immutable) a &' &'? bool + 92..114;20..25;110..111 ByRef(Mutable) a &' mut &'? bool + 158..176;124..125;172..173 ByRef(Immutable) a &' &'? mut bool + 196..218;124..125;214..215 ByRef(Mutable) a &' mut &'? mut bool"#]], ); } @@ -451,7 +480,7 @@ fn main() { fn multiple_capture_usages() { check_closure_captures( r#" -//- minicore:copy, fn +//- minicore: copy, fn struct A { a: i32, b: bool } fn main() { let mut a = A { a: 123, b: false }; @@ -462,7 +491,7 @@ fn main() { closure(); } "#, - expect!["99..165;49..54;120..121,133..134 ByRef(Mut { kind: Default }) a &'? mut A"], + expect!["99..165;49..54;120..121,133..134 ByRef(Mutable) a &' mut A"], ); } @@ -485,8 +514,8 @@ fn main() { } "#, expect![[r#" - 129..225;49..54;149..155 ByRef(Shared) s_ref &'? &'? 
mut S - 129..225;93..99;188..198 ByRef(Mut { kind: Default }) s_ref2 &'? mut &'? mut S"#]], + 129..225;49..54;158..163 ByRef(Immutable) s_ref &' &'? mut S + 129..225;93..99;201..207 ByRef(Mutable) s_ref2 &' mut &'? mut S"#]], ); } @@ -530,7 +559,7 @@ fn main() { }; } "#, - expect!["220..257;174..175;245..250 ByRef(Shared) c.b.x &'? i32"], + expect!["220..257;174..175;245..250 ByRef(Immutable) c.b.x &' i32"], ); } @@ -549,8 +578,8 @@ fn f() { } "#, expect![[r#" - 44..113;17..18;92..93 ByRef(Shared) a &'? i32 - 73..106;17..18;92..93 ByRef(Shared) a &'? i32"#]], + 44..113;17..18;92..93 ByRef(Immutable) a &' i32 + 73..106;17..18;92..93 ByRef(Immutable) a &' i32"#]], ); } @@ -568,6 +597,6 @@ fn f() { }; } "#, - expect!["77..110;46..47;96..97 ByRef(Shared) b &'? i32"], + expect!["77..110;46..47;96..97 ByRef(Immutable) b &' i32"], ); } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/coercion.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/coercion.rs index 438699b40983e..a80ce5002deab 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/coercion.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/coercion.rs @@ -309,7 +309,7 @@ fn takes_ref_str(x: &str) {} fn returns_string() -> String { loop {} } fn test() { takes_ref_str(&{ returns_string() }); - // ^^^^^^^^^^^^^^^^^^^^^ adjustments: Deref(None), Deref(Some(OverloadedDeref(Some(Not)))), Borrow(Ref(Not)) + // ^^^^^^^^^^^^^^^^^^^^^ adjustments: Deref(None), Deref(Some(OverloadedDeref(Not))), Borrow(Ref(Not)) } "#, ); @@ -598,6 +598,10 @@ fn test() { ); } +// FIXME: rustc emits the following error here: +// - error[E0277]: he size for values of type `impl Foo + ?Sized` cannot be known at compilation time +// ...but we don't emit any error here for now +#[ignore = "rustc emits E0277 here"] #[test] fn coerce_unsize_apit() { check( diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs index 
7cda259664c10..960155a8e4f93 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/incremental.rs @@ -49,6 +49,7 @@ fn foo() -> i32 { "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", "ExprScopes::body_expr_scopes_", + "body_upvars_mentioned", ] "#]], ); @@ -137,6 +138,7 @@ fn baz() -> i32 { "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", "ExprScopes::body_expr_scopes_", + "body_upvars_mentioned", "InferenceResult::for_body_", "FunctionSignature::of_", "FunctionSignature::with_source_map_", @@ -147,6 +149,7 @@ fn baz() -> i32 { "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", "ExprScopes::body_expr_scopes_", + "body_upvars_mentioned", "InferenceResult::for_body_", "FunctionSignature::of_", "FunctionSignature::with_source_map_", @@ -157,6 +160,7 @@ fn baz() -> i32 { "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", "ExprScopes::body_expr_scopes_", + "body_upvars_mentioned", ] "#]], ); @@ -205,6 +209,7 @@ fn baz() -> i32 { "Body::of_", "InferenceResult::for_body_", "ExprScopes::body_expr_scopes_", + "body_upvars_mentioned", "AttrFlags::query_", "FunctionSignature::with_source_map_", "FunctionSignature::of_", @@ -594,6 +599,7 @@ fn main() { "GenericPredicates::query_with_diagnostics_", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", + "body_upvars_mentioned", "InferenceResult::for_body_", "FunctionSignature::of_", "FunctionSignature::with_source_map_", @@ -616,6 +622,7 @@ fn main() { "impl_self_ty_with_diagnostics_query", "AttrFlags::query_", "GenericPredicates::query_with_diagnostics_", + "body_upvars_mentioned", ] "#]], ); @@ -686,6 +693,7 @@ fn main() { "GenericPredicates::query_with_diagnostics_", "GenericPredicates::query_with_diagnostics_", "ImplTraits::return_type_impl_traits_", + "body_upvars_mentioned", 
"InferenceResult::for_body_", "FunctionSignature::with_source_map_", "GenericPredicates::query_with_diagnostics_", @@ -703,6 +711,7 @@ fn main() { "impl_self_ty_with_diagnostics_query", "AttrFlags::query_", "GenericPredicates::query_with_diagnostics_", + "body_upvars_mentioned", ] "#]], ); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/method_resolution.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/method_resolution.rs index c8ed8aa2584c4..fc8c1f8164801 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/method_resolution.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/method_resolution.rs @@ -1367,7 +1367,7 @@ mod a { mod b { fn foo() { let x = super::a::Bar::new().0; - // ^^^^^^^^^^^^^^^^^^^^ adjustments: Deref(Some(OverloadedDeref(Some(Not)))) + // ^^^^^^^^^^^^^^^^^^^^ adjustments: Deref(Some(OverloadedDeref(Not))) // ^^^^^^^^^^^^^^^^^^^^^^ type: char } } @@ -2129,7 +2129,7 @@ impl Foo { use core::mem::ManuallyDrop; fn test() { ManuallyDrop::new(Foo).foo(); - //^^^^^^^^^^^^^^^^^^^^^^ adjustments: Deref(Some(OverloadedDeref(Some(Not)))), Borrow(Ref(Not)) + //^^^^^^^^^^^^^^^^^^^^^^ adjustments: Deref(Some(OverloadedDeref(Not))), Borrow(Ref(Not)) } "#, ); diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/patterns.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/patterns.rs index 42dc074309357..d6bc03f57dee0 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/patterns.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/patterns.rs @@ -294,6 +294,7 @@ fn infer_pattern_match_ergonomics_ref() { fn ref_pat_with_inference_variable() { check_no_mismatches( r#" +//- minicore: fn enum E { A } fn test() { let f = |e| match e { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression.rs index d3dfc44c227f9..e30fa779dac25 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression.rs +++ 
b/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression.rs @@ -2363,6 +2363,7 @@ fn test() { } "#, expect![[r#" + 46..49 'Foo': Foo 93..97 'self': Foo 108..125 '{ ... }': usize 118..119 'N': usize @@ -2856,3 +2857,59 @@ fn foo(v: T::T) {} "#, ); } + +#[test] +fn regression_22007() { + check_types( + r#" +//- minicore: fn +trait Super { + type Assoc; + fn foo(self) -> Self::Assoc + where + Self: Sub, + { loop {} } +} +trait Sub: Super {} + +struct Struct; +impl Super for Struct { + type Assoc = u8; +} +impl Sub for Struct {} + +fn foo() { + Struct.foo(); + // ^^^^^^^^^^^^ u8 +} + "#, + ); +} + +#[test] +fn regression_21885() { + check_no_mismatches( + r#" +//- minicore: coerce_unsized, future, result +use core::future::Future; + +trait Foo { + type Assoc; + + fn foo() -> &dyn Future>; +} + +struct Bar; + +impl Foo for Bar { + type Assoc = NotFound; + + fn foo() -> &dyn Future> { + &async { + Err(()) + } + } +} +"#, + ); +} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression/new_solver.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression/new_solver.rs index e6b3244cda248..565360dc25680 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression/new_solver.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/regression/new_solver.rs @@ -680,9 +680,9 @@ where expect![[r#" 43..47 'self': &'? Self 168..172 'self': &'? F - 205..227 '{ ... }': >::CallRefFuture<'> + 205..227 '{ ... }': >::CallRefFuture<'?> 215..219 'self': &'? 
F - 215..221 'self()': >::CallRefFuture<'> + 215..221 'self()': >::CallRefFuture<'?> "#]], ); } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs b/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs index 1d27d52a36604..278666ef35923 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/tests/traits.rs @@ -87,7 +87,7 @@ async fn test() { fn infer_async_closure() { check_types( r#" -//- minicore: future, option +//- minicore: future, option, async_fn async fn test() { let f = async move |x: i32| x + 42; f; @@ -3149,6 +3149,7 @@ impl core::iter::Iterator for core::ops::Range { fn infer_closure_arg() { check_infer( r#" +//- minicore: fn //- /lib.rs enum Option { diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/upvars.rs b/src/tools/rust-analyzer/crates/hir-ty/src/upvars.rs index 489895fe3cb7d..48f3c803d8322 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/upvars.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/upvars.rs @@ -1,8 +1,8 @@ //! A simple query to collect tall locals (upvars) a closure use. use hir_def::{ - DefWithBodyId, - expr_store::{Body, path::Path}, + DefWithBodyId, ExpressionStoreOwnerId, GenericDefId, VariantId, + expr_store::{ExpressionStore, path::Path}, hir::{BindingId, Expr, ExprId, ExprOrPatId, Pat}, resolver::{HasResolver, Resolver, ValueNs}, }; @@ -36,18 +36,89 @@ impl Upvars { pub fn is_empty(&self) -> bool { self.0.is_empty() } + + #[inline] + pub fn as_ref(&self) -> UpvarsRef<'_> { + UpvarsRef(&self.0) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +// Kept sorted. 
+pub struct UpvarsRef<'db>(&'db [BindingId]); + +impl UpvarsRef<'_> { + #[inline] + pub fn contains(self, local: BindingId) -> bool { + self.0.binary_search(&local).is_ok() + } + + #[inline] + pub fn iter(self) -> impl ExactSizeIterator { + self.0.iter().copied() + } + + #[inline] + pub fn is_empty(self) -> bool { + self.0.is_empty() + } + + #[inline] + pub const fn empty() -> Self { + UpvarsRef(&[]) + } } /// Returns a map from `Expr::Closure` to its upvars. -#[salsa::tracked(returns(as_deref))] pub fn upvars_mentioned( db: &dyn HirDatabase, - owner: DefWithBodyId, + owner: ExpressionStoreOwnerId, +) -> Option<&FxHashMap> { + return match owner { + ExpressionStoreOwnerId::Signature(owner) => signature_upvars_mentioned(db, owner), + ExpressionStoreOwnerId::Body(owner) => body_upvars_mentioned(db, owner), + ExpressionStoreOwnerId::VariantFields(owner) => variant_fields_upvars_mentioned(db, owner), + }; + + #[salsa::tracked(returns(as_deref))] + pub fn signature_upvars_mentioned( + db: &dyn HirDatabase, + owner: GenericDefId, + ) -> Option>> { + upvars_mentioned_impl(db, owner.into()) + } + + #[salsa::tracked(returns(as_deref))] + pub fn body_upvars_mentioned( + db: &dyn HirDatabase, + owner: DefWithBodyId, + ) -> Option>> { + upvars_mentioned_impl(db, owner.into()) + } + + #[salsa::tracked(returns(as_deref))] + pub fn variant_fields_upvars_mentioned( + db: &dyn HirDatabase, + owner: VariantId, + ) -> Option>> { + upvars_mentioned_impl(db, owner.into()) + } +} + +pub fn upvars_mentioned_impl( + db: &dyn HirDatabase, + owner: ExpressionStoreOwnerId, ) -> Option>> { - let body = Body::of(db, owner); + let store = ExpressionStore::of(db, owner); + if store.const_expr_origins().is_empty() { + // Save constructing a Resolver. 
+ return None; + } let mut resolver = owner.resolver(db); let mut result = FxHashMap::default(); - handle_expr_outside_closure(db, &mut resolver, owner, body, body.root_expr(), &mut result); + for root_expr in store.expr_roots() { + handle_expr_outside_closure(db, &mut resolver, owner, store, root_expr, &mut result); + } return if result.is_empty() { None } else { @@ -58,8 +129,8 @@ pub fn upvars_mentioned( fn handle_expr_outside_closure<'db>( db: &'db dyn HirDatabase, resolver: &mut Resolver<'db>, - owner: DefWithBodyId, - body: &Body, + owner: ExpressionStoreOwnerId, + body: &ExpressionStore, expr: ExprId, closures_map: &mut FxHashMap, ) { @@ -89,8 +160,8 @@ pub fn upvars_mentioned( fn handle_expr_inside_closure<'db>( db: &'db dyn HirDatabase, resolver: &mut Resolver<'db>, - owner: DefWithBodyId, - body: &Body, + owner: ExpressionStoreOwnerId, + body: &ExpressionStore, current_closure: ExprId, expr: ExprId, upvars: &mut FxHashSet, @@ -170,8 +241,8 @@ pub fn upvars_mentioned( fn resolve_maybe_upvar<'db>( db: &'db dyn HirDatabase, resolver: &mut Resolver<'db>, - owner: DefWithBodyId, - body: &Body, + owner: ExpressionStoreOwnerId, + body: &ExpressionStore, current_closure: ExprId, expr: ExprId, id: ExprOrPatId, @@ -179,8 +250,9 @@ fn resolve_maybe_upvar<'db>( path: &Path, ) { if let Path::BarePath(mod_path) = path - && matches!(mod_path.kind, PathKind::Plain) - && mod_path.segments().len() == 1 + && matches!(mod_path.kind, PathKind::Plain | PathKind::SELF) + // `self` is length zero. + && mod_path.segments().len() <= 1 { // Could be a variable. 
let guard = resolver.update_to_inner_scope(db, owner, expr); @@ -198,7 +270,9 @@ fn resolve_maybe_upvar<'db>( #[cfg(test)] mod tests { use expect_test::{Expect, expect}; - use hir_def::{ModuleDefId, expr_store::Body, nameres::crate_def_map}; + use hir_def::{ + AssocItemId, DefWithBodyId, ModuleDefId, expr_store::Body, nameres::crate_def_map, + }; use itertools::Itertools; use span::Edition; use test_fixture::WithFixture; @@ -217,10 +291,18 @@ mod tests { ModuleDefId::FunctionId(func) => Some(func), _ => None, }) + .chain(def_map.modules().flat_map(|(_, module)| { + module.scope.impls().flat_map(|impl_| &*impl_.impl_items(&db).items).filter_map( + |&(_, item)| match item { + AssocItemId::FunctionId(it) => Some(it), + _ => None, + }, + ) + })) .exactly_one() .unwrap_or_else(|_| panic!("expected one function")); let (body, source_map) = Body::with_source_map(&db, func.into()); - let Some(upvars) = upvars_mentioned(&db, func.into()) else { + let Some(upvars) = upvars_mentioned(&db, DefWithBodyId::from(func).into()) else { expectation.assert_eq(""); return; }; @@ -316,4 +398,19 @@ fn foo() { 49..110: a, b"#]], ); } + + #[test] + fn self_upvar() { + check( + r#" +struct Foo(i32); +impl Foo { + fn foo(&self) { + || self.0; + } +} + "#, + expect!["56..65: self"], + ); + } } diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/utils.rs b/src/tools/rust-analyzer/crates/hir-ty/src/utils.rs index 509109543cd6f..ae9b2c4618960 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/utils.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/utils.rs @@ -1,6 +1,8 @@ //! Helper functions for working with def, which don't need to be a separate //! query, but can't be computed directly from `*Data` (ie, which need a `db`). 
+use std::iter::Enumerate; + use base_db::target::{self, TargetData}; use hir_def::{ EnumId, EnumVariantId, FunctionId, Lookup, TraitId, attrs::AttrFlags, lang_item::LangItems, @@ -163,3 +165,54 @@ pub(crate) fn detect_variant_from_bytes<'a>( }; Some((var_id, var_layout)) } + +pub(crate) struct EnumerateAndAdjust { + enumerate: Enumerate, + gap_pos: usize, + gap_len: usize, +} + +impl Iterator for EnumerateAndAdjust +where + I: Iterator, +{ + type Item = (usize, ::Item); + + fn next(&mut self) -> Option<(usize, ::Item)> { + self.enumerate + .next() + .map(|(i, elem)| (if i < self.gap_pos { i } else { i + self.gap_len }, elem)) + } + + fn size_hint(&self) -> (usize, Option) { + self.enumerate.size_hint() + } +} + +pub(crate) trait EnumerateAndAdjustIterator { + fn enumerate_and_adjust( + self, + expected_len: usize, + gap_pos: Option, + ) -> EnumerateAndAdjust + where + Self: Sized; +} + +impl EnumerateAndAdjustIterator for T { + fn enumerate_and_adjust( + self, + expected_len: usize, + gap_pos: Option, + ) -> EnumerateAndAdjust + where + Self: Sized, + { + let actual_len = self.len(); + EnumerateAndAdjust { + enumerate: self.enumerate(), + gap_pos: gap_pos.map(|it| it as usize).unwrap_or(expected_len), + gap_len: expected_len - actual_len, + } + } +} diff --git a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs index 1945b04bb3cc3..a88457e3c745b 100644 --- a/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs +++ b/src/tools/rust-analyzer/crates/hir-ty/src/variance.rs @@ -216,7 +216,7 @@ impl<'db> Context<'db> { TyKind::Adt(def, args) => { self.add_constraints_from_args(def.def_id().0.into(), args, variance); } - TyKind::Alias(_, alias) => { + TyKind::Alias(alias) => { // FIXME: Probably not correct wrt. opaques. 
self.add_constraints_from_invariant_args(alias.args); } diff --git a/src/tools/rust-analyzer/crates/hir/Cargo.toml b/src/tools/rust-analyzer/crates/hir/Cargo.toml index d20ee1546fa48..89021441892c6 100644 --- a/src/tools/rust-analyzer/crates/hir/Cargo.toml +++ b/src/tools/rust-analyzer/crates/hir/Cargo.toml @@ -21,6 +21,7 @@ serde_json.workspace = true smallvec.workspace = true tracing = { workspace = true, features = ["attributes"] } triomphe.workspace = true +la-arena.workspace = true ra-ap-rustc_type_ir.workspace = true diff --git a/src/tools/rust-analyzer/crates/hir/src/attrs.rs b/src/tools/rust-analyzer/crates/hir/src/attrs.rs index 27e7985146107..223103b6e5d26 100644 --- a/src/tools/rust-analyzer/crates/hir/src/attrs.rs +++ b/src/tools/rust-analyzer/crates/hir/src/attrs.rs @@ -38,7 +38,11 @@ pub enum AttrsOwner { Field(FieldId), LifetimeParam(LifetimeParamId), TypeOrConstParam(TypeOrConstParamId), - /// Things that do not have attributes. Used for builtin derives. + /// Things that do not have attributes. + /// + /// Used for: + /// - builtin derives + /// - builtin types (as those do not have attributes) Dummy, } @@ -85,6 +89,19 @@ impl AttrsWithOwner { self.attrs.contains(AttrFlags::IS_UNSTABLE) } + /// Currently, it could be that `is_unstable() == true` but `unstable_feature == None` + /// (due to unstable features not being retrieved for fields etc.). 
+ #[inline] + pub fn unstable_feature(&self, db: &dyn HirDatabase) -> Option { + match self.owner { + AttrsOwner::AttrDef(owner) => self.attrs.unstable_feature(db, owner), + AttrsOwner::Field(_) + | AttrsOwner::LifetimeParam(_) + | AttrsOwner::TypeOrConstParam(_) + | AttrsOwner::Dummy => None, + } + } + #[inline] pub fn is_macro_export(&self) -> bool { self.attrs.contains(AttrFlags::IS_MACRO_EXPORT) diff --git a/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs b/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs index 555270bad8309..6cfb79d5a1f4b 100644 --- a/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs +++ b/src/tools/rust-analyzer/crates/hir/src/diagnostics.rs @@ -491,35 +491,35 @@ impl<'db> AnyDiagnostic<'db> { let file = record.file_id; let root = record.file_syntax(db); match record.value.to_node(&root) { - Either::Left(ast::Expr::RecordExpr(record_expr)) => { - if record_expr.record_expr_field_list().is_some() { - let field_list_parent_path = - record_expr.path().map(|path| AstPtr::new(&path)); - return Some( - MissingFields { - file, - field_list_parent: AstPtr::new(&Either::Left(record_expr)), - field_list_parent_path, - missed_fields, - } - .into(), - ); - } + Either::Left(ast::Expr::RecordExpr(record_expr)) + if record_expr.record_expr_field_list().is_some() => + { + let field_list_parent_path = + record_expr.path().map(|path| AstPtr::new(&path)); + return Some( + MissingFields { + file, + field_list_parent: AstPtr::new(&Either::Left(record_expr)), + field_list_parent_path, + missed_fields, + } + .into(), + ); } - Either::Right(ast::Pat::RecordPat(record_pat)) => { - if record_pat.record_pat_field_list().is_some() { - let field_list_parent_path = - record_pat.path().map(|path| AstPtr::new(&path)); - return Some( - MissingFields { - file, - field_list_parent: AstPtr::new(&Either::Right(record_pat)), - field_list_parent_path, - missed_fields, - } - .into(), - ); - } + Either::Right(ast::Pat::RecordPat(record_pat)) + if 
record_pat.record_pat_field_list().is_some() => + { + let field_list_parent_path = + record_pat.path().map(|path| AstPtr::new(&path)); + return Some( + MissingFields { + file, + field_list_parent: AstPtr::new(&Either::Right(record_pat)), + field_list_parent_path, + missed_fields, + } + .into(), + ); } _ => {} } diff --git a/src/tools/rust-analyzer/crates/hir/src/has_source.rs b/src/tools/rust-analyzer/crates/hir/src/has_source.rs index f9badc0b79016..45c9811cc0158 100644 --- a/src/tools/rust-analyzer/crates/hir/src/has_source.rs +++ b/src/tools/rust-analyzer/crates/hir/src/has_source.rs @@ -293,7 +293,7 @@ impl HasSource for Param<'_> { .map(|value| InFile { file_id, value }) } Callee::Closure(closure, _) => { - let InternedClosure(owner, expr_id) = db.lookup_intern_closure(closure); + let InternedClosure(owner, expr_id) = closure.loc(db); let (_, source_map) = ExpressionStore::with_source_map(db, owner); let ast @ InFile { file_id, value } = source_map.expr_syntax(expr_id).ok()?; let root = db.parse_or_expand(file_id); diff --git a/src/tools/rust-analyzer/crates/hir/src/lib.rs b/src/tools/rust-analyzer/crates/hir/src/lib.rs index 2829902035985..d24e2c0cb5837 100644 --- a/src/tools/rust-analyzer/crates/hir/src/lib.rs +++ b/src/tools/rust-analyzer/crates/hir/src/lib.rs @@ -85,14 +85,14 @@ use hir_ty::{ GenericPredicates, InferenceResult, ParamEnvAndCrate, TyDefId, TyLoweringDiagnostic, ValueTyDefId, all_super_traits, autoderef, check_orphan_rules, consteval::try_const_usize, - db::{InternedClosureId, InternedCoroutineId}, + db::{InternedClosure, InternedClosureId, InternedCoroutineClosureId}, diagnostics::BodyValidationDiagnostic, direct_super_traits, known_const_to_ast, layout::{Layout as TyLayout, RustcEnumVariantIdx, RustcFieldIdx, TagEncoding}, method_resolution::{ self, InherentImpls, MethodResolutionContext, MethodResolutionUnstableFeatures, }, - mir::{MutBorrowKind, interpret_mir}, + mir::interpret_mir, next_solver::{ AliasTy, AnyImplId, ClauseKind, 
ConstKind, DbInterner, EarlyBinder, EarlyParamRegion, ErrorGuaranteed, GenericArg, GenericArgs, ParamConst, ParamEnv, PolyFnSig, Region, @@ -108,9 +108,8 @@ use rustc_type_ir::{ TypeVisitor, fast_reject, inherent::{AdtDef, GenericArgs as _, IntoKind, SliceLike, Term as _, Ty as _}, }; -use smallvec::SmallVec; use span::{AstIdNode, Edition, FileId}; -use stdx::{format_to, impl_from, never, variance::PhantomCovariantLifetime}; +use stdx::{format_to, impl_from, never}; use syntax::{ AstNode, AstPtr, SmolStr, SyntaxNode, SyntaxNodePtr, TextRange, ToSmolStr, ast::{self, HasName as _, HasVisibility as _}, @@ -342,6 +341,10 @@ impl Crate { }) .map(Crate::from) } + + pub fn is_unstable_feature_enabled(self, db: &dyn HirDatabase, feature: &Symbol) -> bool { + crate_def_map(db, self.id).is_unstable_feature_enabled(feature) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -552,6 +555,23 @@ impl HasCrate for ModuleDef { } } +impl HasAttrs for ModuleDef { + fn attr_id(self, db: &dyn HirDatabase) -> attrs::AttrsOwner { + match self { + ModuleDef::Module(it) => it.attr_id(db), + ModuleDef::Function(it) => it.attr_id(db), + ModuleDef::Adt(it) => it.attr_id(db), + ModuleDef::EnumVariant(it) => it.attr_id(db), + ModuleDef::Const(it) => it.attr_id(db), + ModuleDef::Static(it) => it.attr_id(db), + ModuleDef::Trait(it) => it.attr_id(db), + ModuleDef::TypeAlias(it) => it.attr_id(db), + ModuleDef::Macro(it) => it.attr_id(db), + ModuleDef::BuiltinType(_) => attrs::AttrsOwner::Dummy, + } + } +} + impl HasVisibility for ModuleDef { fn visibility(&self, db: &dyn HirDatabase) -> Visibility { match *self { @@ -2284,11 +2304,9 @@ impl DefWithBody { } } (mir::MutabilityReason::Not, true) => { - if !infer.mutated_bindings_in_closure.contains(&binding_id) { - let should_ignore = body[binding_id].name.as_str().starts_with('_'); - if !should_ignore { - acc.push(UnusedMut { local }.into()) - } + let should_ignore = body[binding_id].name.as_str().starts_with('_'); + if !should_ignore { + 
acc.push(UnusedMut { local }.into()) } } } @@ -2950,7 +2968,7 @@ impl<'db> Param<'db> { } } Callee::Closure(closure, _) => { - let c = db.lookup_intern_closure(closure); + let c = closure.loc(db); let body_owner = c.0; let store = ExpressionStore::of(db, c.0); @@ -3124,7 +3142,7 @@ impl Const { let interner = DbInterner::new_no_crate(db); let ty = db.value_ty(self.id.into()).unwrap().instantiate_identity(); db.const_eval(self.id, GenericArgs::empty(interner), None).map(|it| EvaluatedConst { - const_: it, + allocation: it, def: self.id.into(), ty, }) @@ -3139,22 +3157,19 @@ impl HasVisibility for Const { pub struct EvaluatedConst<'db> { def: DefWithBodyId, - const_: hir_ty::next_solver::Const<'db>, + allocation: hir_ty::next_solver::Allocation<'db>, ty: Ty<'db>, } impl<'db> EvaluatedConst<'db> { pub fn render(&self, db: &dyn HirDatabase, display_target: DisplayTarget) -> String { - format!("{}", self.const_.display(db, display_target)) + format!("{}", self.allocation.display(db, display_target)) } pub fn render_debug(&self, db: &'db dyn HirDatabase) -> Result { - let kind = self.const_.kind(); - if let ConstKind::Value(c) = kind - && let ty = c.ty.kind() - && let TyKind::Int(_) | TyKind::Uint(_) = ty - { - let b = &c.value.inner().memory; + let ty = self.allocation.ty.kind(); + if let TyKind::Int(_) | TyKind::Uint(_) = ty { + let b = &self.allocation.memory; let value = u128::from_le_bytes(mir::pad16(b, false)); let value_signed = i128::from_le_bytes(mir::pad16(b, matches!(ty, TyKind::Int(_)))); let mut result = @@ -3166,7 +3181,7 @@ impl<'db> EvaluatedConst<'db> { return Ok(result); } } - mir::render_const_using_debug_impl(db, self.def, self.const_, self.ty) + mir::render_const_using_debug_impl(db, self.def, self.allocation, self.ty) } } @@ -3207,7 +3222,7 @@ impl Static { pub fn eval(self, db: &dyn HirDatabase) -> Result, ConstEvalError> { let ty = db.value_ty(self.id.into()).unwrap().instantiate_identity(); db.const_eval_static(self.id).map(|it| EvaluatedConst { 
- const_: it, + allocation: it, def: self.id.into(), ty, }) @@ -3318,6 +3333,19 @@ impl Trait { pub fn complete(self, db: &dyn HirDatabase) -> Complete { Complete::extract(true, self.attrs(db).attrs) } + + // Feature: Prefer Underscore Import Attribute + // Crate authors can declare that their trait prefers to be imported `as _`. This can be used + // for example for extension traits. To do that, a trait has to include the attribute + // `#[rust_analyzer::prefer_underscore_import]` + // + // When a trait includes this attribute, flyimport will import it `as _`, and the quickfix + // to import it will prefer to import it `as _` (but allow to import it normally as well). + // + // Malformed attributes will be ignored without warnings. + pub fn prefer_underscore_import(self, db: &dyn HirDatabase) -> bool { + AttrFlags::query(db, self.id.into()).contains(AttrFlags::PREFER_UNDERSCORE_IMPORT) + } } impl HasVisibility for Trait { @@ -5092,7 +5120,7 @@ impl<'db> TraitRef<'db> { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] enum AnyClosureId { ClosureId(InternedClosureId), - CoroutineClosureId(InternedCoroutineId), + CoroutineClosureId(InternedCoroutineClosureId), } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -5127,59 +5155,33 @@ impl<'db> Closure<'db> { } pub fn captured_items(&self, db: &'db dyn HirDatabase) -> Vec> { - let AnyClosureId::ClosureId(id) = self.id else { - // FIXME: Infer coroutine closures' captures. 
- return Vec::new(); + let closure = match self.id { + AnyClosureId::ClosureId(it) => it.loc(db), + AnyClosureId::CoroutineClosureId(it) => it.loc(db), }; - let owner = db.lookup_intern_closure(id).0; + let InternedClosure(owner, closure) = closure; let infer = InferenceResult::of(db, owner); - let info = infer.closure_info(id); - info.0 - .iter() - .cloned() - .map(|capture| ClosureCapture { - owner, - closure: id, - capture, - _marker: PhantomCovariantLifetime::new(), - }) + let param_env = body_param_env_from_has_crate(db, owner); + infer.closures_data[&closure] + .min_captures + .values() + .flatten() + .map(|capture| ClosureCapture { owner, closure, capture, param_env }) .collect() } - pub fn capture_types(&self, db: &'db dyn HirDatabase) -> Vec> { - let AnyClosureId::ClosureId(id) = self.id else { - // FIXME: Infer coroutine closures' captures. - return Vec::new(); - }; - let owner = db.lookup_intern_closure(id).0; - let Some(body_owner) = owner.as_def_with_body() else { - return Vec::new(); - }; - let infer = InferenceResult::of(db, body_owner); - let (captures, _) = infer.closure_info(id); - let env = body_param_env_from_has_crate(db, body_owner); - captures.iter().map(|capture| Type { env, ty: capture.ty(db, self.subst) }).collect() - } - - pub fn fn_trait(&self, db: &dyn HirDatabase) -> FnTrait { + pub fn fn_trait(&self, _db: &dyn HirDatabase) -> FnTrait { match self.id { - AnyClosureId::ClosureId(id) => { - let owner = db.lookup_intern_closure(id).0; - let Some(body_owner) = owner.as_def_with_body() else { - return FnTrait::FnOnce; - }; - let infer = InferenceResult::of(db, body_owner); - let info = infer.closure_info(id); - info.1.into() - } - AnyClosureId::CoroutineClosureId(_id) => { - // FIXME: Infer kind for coroutine closures. 
- match self.subst.as_coroutine_closure().kind() { - rustc_type_ir::ClosureKind::Fn => FnTrait::AsyncFn, - rustc_type_ir::ClosureKind::FnMut => FnTrait::AsyncFnMut, - rustc_type_ir::ClosureKind::FnOnce => FnTrait::AsyncFnOnce, - } - } + AnyClosureId::ClosureId(_) => match self.subst.as_closure().kind() { + rustc_type_ir::ClosureKind::Fn => FnTrait::Fn, + rustc_type_ir::ClosureKind::FnMut => FnTrait::FnMut, + rustc_type_ir::ClosureKind::FnOnce => FnTrait::FnOnce, + }, + AnyClosureId::CoroutineClosureId(_) => match self.subst.as_coroutine_closure().kind() { + rustc_type_ir::ClosureKind::Fn => FnTrait::AsyncFn, + rustc_type_ir::ClosureKind::FnMut => FnTrait::AsyncFnMut, + rustc_type_ir::ClosureKind::FnOnce => FnTrait::AsyncFnOnce, + }, } } } @@ -5252,51 +5254,120 @@ impl FnTrait { #[derive(Clone, Debug, PartialEq, Eq)] pub struct ClosureCapture<'db> { owner: ExpressionStoreOwnerId, - closure: InternedClosureId, - capture: hir_ty::CapturedItem, - _marker: PhantomCovariantLifetime<'db>, + closure: ExprId, + capture: &'db hir_ty::closure_analysis::CapturedPlace, + param_env: ParamEnvAndCrate<'db>, } impl<'db> ClosureCapture<'db> { pub fn local(&self) -> Local { - Local { parent: self.owner, binding_id: self.capture.local() } + Local { parent: self.owner, binding_id: self.capture.captured_local() } } /// Returns whether this place has any field (aka. non-deref) projections. pub fn has_field_projections(&self) -> bool { - self.capture.has_field_projections() + self.capture + .place + .projections + .iter() + .any(|proj| matches!(proj.kind, hir_ty::closure_analysis::ProjectionKind::Field { .. 
})) } - pub fn usages(&self) -> CaptureUsages { - CaptureUsages { parent: self.owner, spans: self.capture.spans() } + pub fn usages(&self) -> CaptureUsages<'db> { + CaptureUsages { parent: self.owner, sources: &self.capture.info.sources } } pub fn kind(&self) -> CaptureKind { - match self.capture.kind() { - hir_ty::CaptureKind::ByRef( - hir_ty::mir::BorrowKind::Shallow | hir_ty::mir::BorrowKind::Shared, + match self.capture.info.capture_kind { + hir_ty::closure_analysis::UpvarCapture::ByValue => CaptureKind::Move, + hir_ty::closure_analysis::UpvarCapture::ByUse => CaptureKind::SharedRef, // Good enough? + hir_ty::closure_analysis::UpvarCapture::ByRef( + hir_ty::closure_analysis::BorrowKind::Immutable, ) => CaptureKind::SharedRef, - hir_ty::CaptureKind::ByRef(hir_ty::mir::BorrowKind::Mut { - kind: MutBorrowKind::ClosureCapture, - }) => CaptureKind::UniqueSharedRef, - hir_ty::CaptureKind::ByRef(hir_ty::mir::BorrowKind::Mut { - kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow, - }) => CaptureKind::MutableRef, - hir_ty::CaptureKind::ByValue => CaptureKind::Move, + hir_ty::closure_analysis::UpvarCapture::ByRef( + hir_ty::closure_analysis::BorrowKind::UniqueImmutable, + ) => CaptureKind::UniqueSharedRef, + hir_ty::closure_analysis::UpvarCapture::ByRef( + hir_ty::closure_analysis::BorrowKind::Mutable, + ) => CaptureKind::MutableRef, } } /// Converts the place to a name that can be inserted into source code. 
- pub fn place_to_name(&self, db: &dyn HirDatabase) -> String { - self.capture.place_to_name(self.owner, db) + pub fn place_to_name(&self, db: &dyn HirDatabase, edition: Edition) -> String { + let mut result = self.local().name(db).display(db, edition).to_string(); + for (i, proj) in self.capture.place.projections.iter().enumerate() { + match proj.kind { + hir_ty::closure_analysis::ProjectionKind::Deref => {} + hir_ty::closure_analysis::ProjectionKind::Field { field_idx, variant_idx } => { + let ty = self.capture.place.ty_before_projection(i); + match ty.kind() { + TyKind::Tuple(_) => format_to!(result, "_{field_idx}"), + TyKind::Adt(adt_def, _) => { + let variant = match adt_def.def_id().0 { + AdtId::StructId(id) => VariantId::from(id), + AdtId::UnionId(id) => id.into(), + AdtId::EnumId(id) => { + id.enum_variants(db).variants[variant_idx as usize].0.into() + } + }; + let field = &variant.fields(db).fields() + [LocalFieldId::from_raw(la_arena::RawIdx::from_u32(field_idx))]; + format_to!(result, "_{}", field.name.display(db, edition)); + } + _ => never!("mismatching projection type"), + } + } + _ => never!("unexpected projection kind"), + } + } + result } - pub fn display_place_source_code(&self, db: &dyn HirDatabase) -> String { - self.capture.display_place_source_code(self.owner, db) + pub fn display_place_source_code(&self, db: &dyn HirDatabase, edition: Edition) -> String { + let mut result = self.local().name(db).display(db, edition).to_string(); + // We only need the derefs that have no field access after them, autoderef will do the rest. 
+ let mut last_derefs = 0; + for (i, proj) in self.capture.place.projections.iter().enumerate() { + match proj.kind { + hir_ty::closure_analysis::ProjectionKind::Deref => last_derefs += 1, + hir_ty::closure_analysis::ProjectionKind::Field { field_idx, variant_idx } => { + last_derefs = 0; + + let ty = self.capture.place.ty_before_projection(i); + match ty.kind() { + TyKind::Tuple(_) => format_to!(result, ".{field_idx}"), + TyKind::Adt(adt_def, _) => { + let variant = match adt_def.def_id().0 { + AdtId::StructId(id) => VariantId::from(id), + AdtId::UnionId(id) => id.into(), + AdtId::EnumId(id) => { + // Can't really do that for an enum, unfortunately, so try to do something alike. + id.enum_variants(db).variants[variant_idx as usize].0.into() + } + }; + let field = &variant.fields(db).fields() + [LocalFieldId::from_raw(la_arena::RawIdx::from_u32(field_idx))]; + format_to!(result, ".{}", field.name.display(db, edition)); + } + _ => never!("mismatching projection type"), + } + } + _ => never!("unexpected projection kind"), + } + } + result.insert_str(0, &"*".repeat(last_derefs)); + result } - pub fn display_place(&self, db: &dyn HirDatabase) -> String { - self.capture.display_place(self.owner, db) + pub fn ty(&self, _db: &'db dyn HirDatabase) -> Type<'db> { + Type { env: self.param_env, ty: self.capture.place.ty() } + } + + /// The type that is stored in the closure, which is different from [`Self::ty()`], representing + /// the place's type, when the capture is by ref. 
+ pub fn captured_ty(&self, db: &'db dyn HirDatabase) -> Type<'db> { + Type { env: self.param_env, ty: self.capture.captured_ty(db) } } } @@ -5309,38 +5380,43 @@ pub enum CaptureKind { } #[derive(Debug, Clone)] -pub struct CaptureUsages { +pub struct CaptureUsages<'db> { parent: ExpressionStoreOwnerId, - spans: SmallVec<[mir::MirSpan; 3]>, -} + sources: &'db [hir_ty::closure_analysis::CaptureSourceStack], +} + +impl CaptureUsages<'_> { + fn is_ref(store: &ExpressionStore, id: ExprOrPatId) -> bool { + match id { + ExprOrPatId::ExprId(expr) => matches!(store[expr], Expr::Ref { .. }), + // FIXME: Figure out if this is correct wrt. match ergonomics. + ExprOrPatId::PatId(pat) => match store[pat] { + Pat::Bind { id: binding, .. } => matches!( + store[binding].mode, + BindingAnnotation::Ref | BindingAnnotation::RefMut + ), + _ => false, + }, + } + } -impl CaptureUsages { pub fn sources(&self, db: &dyn HirDatabase) -> Vec { - let (body, source_map) = ExpressionStore::with_source_map(db, self.parent); - let mut result = Vec::with_capacity(self.spans.len()); - for &span in self.spans.iter() { - let is_ref = span.is_ref_span(body); - match span { - mir::MirSpan::ExprId(expr) => { + let (store, source_map) = ExpressionStore::with_source_map(db, self.parent); + let mut result = Vec::with_capacity(self.sources.len()); + for source in self.sources { + let source = source.final_source(); + let is_ref = Self::is_ref(store, source); + match source { + ExprOrPatId::ExprId(expr) => { if let Ok(expr) = source_map.expr_syntax(expr) { result.push(CaptureUsageSource { is_ref, source: expr }) } } - mir::MirSpan::PatId(pat) => { + ExprOrPatId::PatId(pat) => { if let Ok(pat) = source_map.pat_syntax(pat) { result.push(CaptureUsageSource { is_ref, source: pat }); } } - mir::MirSpan::BindingId(binding) => result.extend( - source_map - .patterns_for_binding(binding) - .iter() - .filter_map(|&pat| source_map.pat_syntax(pat).ok()) - .map(|pat| CaptureUsageSource { is_ref, source: pat }), - ), - 
mir::MirSpan::SelfParam | mir::MirSpan::Unknown => { - unreachable!("invalid capture usage span") - } } } result @@ -5736,8 +5812,11 @@ impl<'db> Type<'db> { // FIXME: We don't handle GATs yet. let projection = Ty::new_alias( interner, - AliasTyKind::Projection, - AliasTy::new_from_args(interner, alias.id.into(), args), + AliasTy::new_from_args( + interner, + AliasTyKind::Projection { def_id: alias.id.into() }, + args, + ), ); let infcx = interner.infer_ctxt().build(TypingMode::PostAnalysis); @@ -5818,6 +5897,18 @@ impl<'db> Type<'db> { matches!(self.ty.kind(), TyKind::RawPtr(..)) } + pub fn is_mutable_raw_ptr(&self) -> bool { + // Used outside of rust-analyzer (e.g. by `ra_ap_hir` consumers). + matches!(self.ty.kind(), TyKind::RawPtr(.., hir_ty::next_solver::Mutability::Mut)) + } + + pub fn as_raw_ptr(&self) -> Option<(Type<'db>, Mutability)> { + // Used outside of rust-analyzer (e.g. by `ra_ap_hir` consumers). + let TyKind::RawPtr(ty, m) = self.ty.kind() else { return None }; + let m = Mutability::from_mutable(matches!(m, hir_ty::next_solver::Mutability::Mut)); + Some((self.derived(ty), m)) + } + pub fn remove_raw_ptr(&self) -> Option> { if let TyKind::RawPtr(ty, _) = self.ty.kind() { Some(self.derived(ty)) } else { None } } @@ -6341,8 +6432,12 @@ impl<'db> Type<'db> { } pub fn as_associated_type_parent_trait(&self, db: &'db dyn HirDatabase) -> Option { - let TyKind::Alias(AliasTyKind::Projection, alias) = self.ty.kind() else { return None }; - match alias.def_id.expect_type_alias().loc(db).container { + let TyKind::Alias(AliasTy { kind: AliasTyKind::Projection { def_id }, .. 
}) = + self.ty.kind() + else { + return None; + }; + match def_id.expect_type_alias().loc(db).container { ItemContainerId::TraitId(id) => Some(Trait { id }), _ => None, } @@ -6520,7 +6615,7 @@ pub struct Callable<'db> { enum Callee<'db> { Def(CallableDefId), Closure(InternedClosureId, GenericArgs<'db>), - CoroutineClosure(InternedCoroutineId, GenericArgs<'db>), + CoroutineClosure(InternedCoroutineClosureId, GenericArgs<'db>), FnPtr, FnImpl(traits::FnTrait), BuiltinDeriveImplMethod { method: BuiltinDeriveImplMethod, impl_: BuiltinDeriveImplId }, @@ -6658,8 +6753,8 @@ impl Layout { let offset = stride.bytes() * tail; self.0.size.bytes().checked_sub(offset)?.checked_sub(tail_field_size) }), - layout::FieldsShape::Arbitrary { ref offsets, ref memory_index } => { - let tail = memory_index.last_index()?; + layout::FieldsShape::Arbitrary { ref offsets, ref in_memory_order } => { + let tail = in_memory_order[in_memory_order.len().checked_sub(1)? as u32]; let tail_field_size = field_size(tail.0.into_raw().into_u32() as usize)?; let offset = offsets.get(tail)?.bytes(); self.0.size.bytes().checked_sub(offset)?.checked_sub(tail_field_size) @@ -6679,10 +6774,11 @@ impl Layout { let size = field_size(0)?; stride.bytes().checked_sub(size) } - layout::FieldsShape::Arbitrary { ref offsets, ref memory_index } => { - let mut reverse_index = vec![None; memory_index.len()]; - for (src, (mem, offset)) in memory_index.iter().zip(offsets.iter()).enumerate() { - reverse_index[*mem as usize] = Some((src, offset.bytes())); + layout::FieldsShape::Arbitrary { ref offsets, ref in_memory_order } => { + let mut reverse_index = vec![None; in_memory_order.len()]; + for (mem, src) in in_memory_order.iter().enumerate() { + reverse_index[mem] = + Some((src.0.into_raw().into_u32() as usize, offsets[*src].bytes())); } if reverse_index.iter().any(|it| it.is_none()) { stdx::never!(); diff --git a/src/tools/rust-analyzer/crates/hir/src/semantics.rs b/src/tools/rust-analyzer/crates/hir/src/semantics.rs 
index 9996162485910..b7cc780ae42f7 100644 --- a/src/tools/rust-analyzer/crates/hir/src/semantics.rs +++ b/src/tools/rust-analyzer/crates/hir/src/semantics.rs @@ -1693,9 +1693,7 @@ impl<'db> SemanticsImpl<'db> { hir_ty::Adjust::NeverToAny => Adjust::NeverToAny, hir_ty::Adjust::Deref(Some(hir_ty::OverloadedDeref(m))) => { // FIXME: Should we handle unknown mutability better? - Adjust::Deref(Some(OverloadedDeref( - m.map(mutability).unwrap_or(Mutability::Shared), - ))) + Adjust::Deref(Some(OverloadedDeref(mutability(m)))) } hir_ty::Adjust::Deref(None) => Adjust::Deref(None), hir_ty::Adjust::Borrow(hir_ty::AutoBorrow::RawPtr(m)) => { diff --git a/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs b/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs index 1a34fa913425e..6c43f80ce8789 100644 --- a/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs +++ b/src/tools/rust-analyzer/crates/hir/src/source_analyzer.rs @@ -37,7 +37,7 @@ use hir_ty::{ lang_items::lang_items_for_bin_op, method_resolution::{self, CandidateId}, next_solver::{ - DbInterner, ErrorGuaranteed, GenericArgs, ParamEnv, Ty, TyKind, TypingMode, + AliasTy, DbInterner, ErrorGuaranteed, GenericArgs, ParamEnv, Ty, TyKind, TypingMode, infer::DbInternerInferExt, }, traits::structurally_normalize_ty, @@ -1293,10 +1293,14 @@ impl<'db> SourceAnalyzer<'db> { PathResolution::Def(ModuleDef::Adt(adt_id.into())), ) } - TyKind::Alias(AliasTyKind::Projection, alias) => { - let assoc_id = alias.def_id.expect_type_alias(); + TyKind::Alias(AliasTy { + kind: AliasTyKind::Projection { def_id }, + args, + .. 
+ }) => { + let assoc_id = def_id.expect_type_alias(); ( - GenericSubstitution::new(assoc_id.into(), alias.args, env), + GenericSubstitution::new(assoc_id.into(), args, env), PathResolution::Def(ModuleDef::TypeAlias(assoc_id.into())), ) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_braces.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_braces.rs index da1322de4b641..c5ec88ffb88a3 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_braces.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_braces.rs @@ -1,7 +1,7 @@ use either::Either; use syntax::{ AstNode, T, - ast::{self, edit::AstNodeEdit, syntax_factory::SyntaxFactory}, + ast::{self, edit::AstNodeEdit}, match_ast, }; @@ -56,15 +56,13 @@ pub(crate) fn add_braces(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( }, expr.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(expr.syntax()); + let editor = builder.make_editor(expr.syntax()); + let make = editor.make(); let new_expr = expr.reset_indent().indent(1.into()); let block_expr = make.block_expr(None, Some(new_expr)); editor.replace(expr.syntax(), block_expr.indent(expr.indent_level()).syntax()); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_dot_deref.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_dot_deref.rs index d27a6b4ce7709..1809b8f305b62 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_dot_deref.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_explicit_dot_deref.rs @@ -53,18 +53,18 @@ pub(crate) fn add_explicit_method_call_deref( "Insert explicit method call derefs", dot_token.text_range(), |builder| { - let mut edit = builder.make_editor(method_call_expr.syntax()); - let make = 
SyntaxFactory::without_mappings(); + let editor = builder.make_editor(method_call_expr.syntax()); + let make = editor.make(); let mut expr = receiver.clone(); for adjust_kind in adjustments { - expr = adjust_kind.wrap_expr(expr, &make); + expr = adjust_kind.wrap_expr(expr, make); } expr = make.expr_paren(expr).into(); - edit.replace(receiver.syntax(), expr.syntax()); + editor.replace(receiver.syntax(), expr.syntax()); - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs index 6a408e5254fd6..41e9b6cc84539 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_label_to_loop.rs @@ -3,7 +3,7 @@ use ide_db::{ }; use syntax::{ SyntaxToken, T, - ast::{self, AstNode, HasLoopBody, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, HasLoopBody}, syntax_editor::{Position, SyntaxEditor}, }; @@ -42,8 +42,8 @@ pub(crate) fn add_label_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) -> O "Add Label", loop_expr.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(loop_expr.syntax()); + let editor = builder.make_editor(loop_expr.syntax()); + let make = editor.make(); let label = make.lifetime("'l"); let elements = vec![ @@ -65,11 +65,10 @@ pub(crate) fn add_label_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) -> O _ => return, }; if let Some(token) = token { - insert_label_after_token(&mut editor, &make, &token, ctx, builder); + insert_label_after_token(&editor, &token, ctx, builder); } }); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); builder.rename(); }, @@ -85,12 +84,12 @@ fn loop_token(loop_expr: &ast::AnyHasLoopBody) -> 
Option { } fn insert_label_after_token( - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, token: &SyntaxToken, ctx: &AssistContext<'_>, builder: &mut SourceChangeBuilder, ) { + let make = editor.make(); let label = make.lifetime("'l"); let elements = vec![make.whitespace(" ").into(), label.syntax().clone().into()]; editor.insert_all(Position::after(token), elements); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs index 44b367059eca3..d1f1f9f123387 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_impl_members.rs @@ -148,9 +148,10 @@ fn add_missing_impl_members_inner( let target = impl_def.syntax().text_range(); acc.add(AssistId::quick_fix(assist_id), label, target, |edit| { - let make = SyntaxFactory::with_mappings(); + let editor = edit.make_editor(impl_def.syntax()); + let make = editor.make(); let new_item = add_trait_assoc_items_to_impl( - &make, + make, &ctx.sema, ctx.config, &missing_items, @@ -166,7 +167,7 @@ fn add_missing_impl_members_inner( let mut first_new_item = if let DefaultMethods::No = mode && let ast::AssocItem::Fn(func) = &first_new_item && let Some(body) = try_gen_trait_body( - &make, + make, ctx, func, trait_ref, @@ -175,7 +176,7 @@ fn add_missing_impl_members_inner( ) && let Some(func_body) = func.body() { - let (mut func_editor, _) = SyntaxEditor::new(first_new_item.syntax().clone()); + let (func_editor, _) = SyntaxEditor::new(first_new_item.syntax().clone()); func_editor.replace(func_body.syntax(), body.syntax()); ast::AssocItem::cast(func_editor.finish().new_root().clone()) } else { @@ -188,9 +189,8 @@ fn add_missing_impl_members_inner( .chain(other_items.iter().cloned()) .collect::>(); - let mut editor = edit.make_editor(impl_def.syntax()); if let 
Some(assoc_item_list) = impl_def.assoc_item_list() { - assoc_item_list.add_items(&mut editor, new_assoc_items); + assoc_item_list.add_items(&editor, new_assoc_items); } else { let assoc_item_list = make.assoc_item_list(new_assoc_items); editor.insert_all( @@ -218,7 +218,6 @@ fn add_missing_impl_members_inner( editor.add_annotation(first_new_item.syntax(), tabstop); }; }; - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs index b7510bb82676a..3c33ddec31624 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_missing_match_arms.rs @@ -8,7 +8,7 @@ use ide_db::{famous_defs::FamousDefs, helpers::mod_path_to_ast}; use itertools::Itertools; use syntax::ast::edit::IndentLevel; use syntax::ast::syntax_factory::SyntaxFactory; -use syntax::ast::{self, AstNode, MatchArmList, MatchExpr, Pat, make}; +use syntax::ast::{self, AstNode, MatchArmList, MatchExpr, Pat}; use syntax::syntax_editor::{Position, SyntaxEditor}; use syntax::{SyntaxKind, SyntaxNode, ToSmolStr}; @@ -74,7 +74,7 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) .filter(|pat| !matches!(pat, Pat::WildcardPat(_))) .collect(); - let make = SyntaxFactory::with_mappings(); + let make = SyntaxFactory::without_mappings(); let scope = ctx.sema.scope(expr.syntax())?; let module = scope.module(); @@ -271,12 +271,12 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) } }; - let mut editor = builder.make_editor(&old_place); + let editor = builder.make_editor(&old_place); let mut arms_edit = ArmsEdit { match_arm_list, place: old_place, last_arm: None }; - arms_edit.remove_wildcard_arms(ctx, &mut editor); - 
arms_edit.add_comma_after_last_arm(ctx, &make, &mut editor); - arms_edit.append_arms(&missing_arms, &make, &mut editor); + arms_edit.remove_wildcard_arms(ctx, &editor); + arms_edit.add_comma_after_last_arm(ctx, &make, &editor); + arms_edit.append_arms(&missing_arms, &make, &editor); if let Some(cap) = ctx.config.snippet_cap { if let Some(it) = missing_arms @@ -297,7 +297,6 @@ pub(crate) fn add_missing_match_arms(acc: &mut Assists, ctx: &AssistContext<'_>) } } - editor.add_mappings(make.take()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -358,7 +357,7 @@ struct ArmsEdit { } impl ArmsEdit { - fn remove_wildcard_arms(&mut self, ctx: &AssistContext<'_>, editor: &mut SyntaxEditor) { + fn remove_wildcard_arms(&mut self, ctx: &AssistContext<'_>, editor: &SyntaxEditor) { for arm in self.match_arm_list.arms() { if !matches!(arm.pat(), Some(Pat::WildcardPat(_))) { self.last_arm = Some(arm); @@ -387,7 +386,7 @@ impl ArmsEdit { } } - fn append_arms(&self, arms: &[ast::MatchArm], make: &SyntaxFactory, editor: &mut SyntaxEditor) { + fn append_arms(&self, arms: &[ast::MatchArm], make: &SyntaxFactory, editor: &SyntaxEditor) { let Some(mut before) = self.place.last_token() else { stdx::never!("match arm list not contain any token"); return; @@ -420,7 +419,7 @@ impl ArmsEdit { &self, ctx: &AssistContext<'_>, make: &SyntaxFactory, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, ) { if let Some(last_arm) = &self.last_arm && last_arm.comma_token().is_none() @@ -593,12 +592,12 @@ fn build_pat( ExtendedVariant::Variant { variant: var, use_self } => { let edition = module.krate(db).edition(db); let path = if use_self { - make::path_from_segments( + make.path_from_segments( [ - make::path_segment(make::name_ref_self_ty()), - make::path_segment(make::name_ref( - &var.name(db).display(db, edition).to_smolstr(), - )), + make.path_segment(make.name_ref_self_ty()), + make.path_segment( + make.name_ref(&var.name(db).display(db, edition).to_smolstr()), + ), ], false, ) @@ 
-612,7 +611,7 @@ fn build_pat( let pats = fields.into_iter().map(|f| { let name = name_generator.for_type(&f.ty(db).to_type(db), db, edition); match name { - Some(name) => make::ext::simple_ident_pat(make.name(&name)).into(), + Some(name) => make.ident_pat(false, false, make.name(&name)).into(), None => make.wildcard_pat().into(), } }); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs index c5e722d87e1ae..dcd2124f7bebc 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_turbo_fish.rs @@ -93,8 +93,8 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti "Add `: _` before assignment operator", ident.text_range(), |builder| { - let mut editor = builder.make_editor(let_stmt.syntax()); - let make = SyntaxFactory::without_mappings(); + let editor = builder.make_editor(let_stmt.syntax()); + let make = editor.make(); if let_stmt.semicolon_token().is_none() { editor.insert( @@ -141,14 +141,12 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti ident.text_range(), |builder| { builder.trigger_parameter_hints(); - - let make = SyntaxFactory::with_mappings(); - let mut editor = match &turbofish_target { + let editor = match &turbofish_target { Either::Left(it) => builder.make_editor(it.syntax()), Either::Right(it) => builder.make_editor(it.syntax()), }; - let fish_head = get_fish_head(&make, number_of_arguments); + let fish_head = get_fish_head(editor.make(), number_of_arguments); match turbofish_target { Either::Left(path_segment) => { @@ -180,8 +178,6 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti editor.add_annotation(arg.syntax(), builder.make_placeholder_snippet(cap)); } } - - editor.add_mappings(make.finish_with_mappings()); 
builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs index 2ea0d76b01617..b87a757047ac5 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/apply_demorgan.rs @@ -80,9 +80,8 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti _ => return None, }; - let make = SyntaxFactory::with_mappings(); - - let (mut editor, demorganed) = SyntaxEditor::with_ast_node(&bin_expr); + let (editor, demorganed) = SyntaxEditor::with_ast_node(&bin_expr); + let make = editor.make(); editor.replace(demorganed.op_token()?, make.token(inv_token)); let mut exprs = VecDeque::from([ @@ -98,7 +97,7 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti exprs.push_back((bin_expr.lhs()?, cbin_expr.lhs()?, prec)); exprs.push_back((bin_expr.rhs()?, cbin_expr.rhs()?, prec)); } else { - let mut inv = invert_boolean_expression(&make, expr); + let mut inv = invert_boolean_expression(make, expr); if precedence(&inv).needs_parentheses_in(prec) { inv = make.expr_paren(inv).into(); } @@ -108,7 +107,7 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti return None; } } else { - let mut inv = invert_boolean_expression(&make, demorganed.clone()); + let mut inv = invert_boolean_expression(make, demorganed.clone()); if precedence(&inv).needs_parentheses_in(prec) { inv = make.expr_paren(inv).into(); } @@ -116,7 +115,6 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti } } - editor.add_mappings(make.finish_with_mappings()); let edit = editor.finish(); let demorganed = ast::Expr::cast(edit.new_root().clone())?; @@ -126,7 +124,9 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti "Apply De Morgan's law", 
op_range, |builder| { - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(bin_expr.syntax()); + let make = editor.make(); + let (target_node, result_expr) = if let Some(neg_expr) = bin_expr .syntax() .parent() @@ -141,9 +141,9 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti bin_expr.syntax().parent().and_then(ast::ParenExpr::cast) { cov_mark::hit!(demorgan_double_parens); - (paren_expr.syntax().clone(), add_bang_paren(&make, demorganed)) + (paren_expr.syntax().clone(), add_bang_paren(make, demorganed)) } else { - (bin_expr.syntax().clone(), add_bang_paren(&make, demorganed)) + (bin_expr.syntax().clone(), add_bang_paren(make, demorganed)) }; let final_expr = if target_node @@ -156,9 +156,7 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti result_expr }; - let mut editor = builder.make_editor(&target_node); editor.replace(&target_node, final_expr.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -206,8 +204,8 @@ pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_> label, op_range, |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(method_call.syntax()); + let editor = builder.make_editor(method_call.syntax()); + let make = editor.make(); // replace the method name let new_name = match name.text().as_str() { "all" => make.name_ref("any"), @@ -217,7 +215,7 @@ pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_> editor.replace(name.syntax(), new_name.syntax()); // negate all tail expressions in the closure body - let tail_cb = &mut |e: &_| tail_cb_impl(&mut editor, &make, e); + let tail_cb = &mut |e: &_| tail_cb_impl(&editor, e); walk_expr(&closure_body, &mut |expr| { if let ast::Expr::ReturnExpr(ret_expr) = expr && let Some(ret_expr_arg) = &ret_expr.expr() @@ -240,8 +238,6 @@ pub(crate) fn 
apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_> } else { editor.insert(Position::before(method_call.syntax()), make.token(SyntaxKind::BANG)); } - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -269,18 +265,18 @@ fn validate_method_call_expr( it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr)) } -fn tail_cb_impl(editor: &mut SyntaxEditor, make: &SyntaxFactory, e: &ast::Expr) { +fn tail_cb_impl(editor: &SyntaxEditor, e: &ast::Expr) { match e { ast::Expr::BreakExpr(break_expr) => { if let Some(break_expr_arg) = break_expr.expr() { - for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(editor, make, e)) + for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(editor, e)) } } ast::Expr::ReturnExpr(_) => { // all return expressions have already been handled by the walk loop } e => { - let inverted_body = invert_boolean_expression(make, e.clone()); + let inverted_body = invert_boolean_expression(editor.make(), e.clone()); editor.replace(e.syntax(), inverted_body.syntax()); } } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs index adeb191719fb7..f9d618790c6e8 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/auto_import.rs @@ -6,7 +6,7 @@ use ide_db::{ active_parameter::ActiveParameter, helpers::mod_path_to_ast, imports::{ - import_assets::{ImportAssets, ImportCandidate, LocatedImport}, + import_assets::{ImportAssets, ImportCandidate, LocatedImport, TraitImportCandidate}, insert_use::{ImportScope, insert_use, insert_use_as_alias}, }, }; @@ -123,44 +123,48 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option< let (assist_id, import_name) = (AssistId::quick_fix("auto_import"), import_path.display(ctx.db(), edition)); - 
acc.add_group( - &group_label, - assist_id, - format!("Import `{import_name}`"), - range, - |builder| { + let add_normal_import = |acc: &mut Assists, label| { + acc.add_group(&group_label, assist_id, label, range, |builder| { let scope = builder.make_import_scope_mut(scope.clone()); insert_use(&scope, mod_path_to_ast(&import_path, edition), &ctx.config.insert_use); - }, - ); - - match import_assets.import_candidate() { - ImportCandidate::TraitAssocItem(name) | ImportCandidate::TraitMethod(name) => { - let is_method = - matches!(import_assets.import_candidate(), ImportCandidate::TraitMethod(_)); - let type_ = if is_method { "method" } else { "item" }; - let group_label = GroupLabel(format!( - "Import a trait for {} {} by alias", - type_, - name.assoc_item_name.text() - )); - acc.add_group( - &group_label, - assist_id, - format!("Import `{import_name} as _`"), - range, - |builder| { - let scope = builder.make_import_scope_mut(scope.clone()); - insert_use_as_alias( - &scope, - mod_path_to_ast(&import_path, edition), - &ctx.config.insert_use, - edition, - ); - }, + }) + }; + let add_underscore_import = |acc: &mut Assists, name: &TraitImportCandidate<'_>, label| { + let is_method = + matches!(import_assets.import_candidate(), ImportCandidate::TraitMethod(_)); + let type_ = if is_method { "method" } else { "item" }; + let group_label = GroupLabel(format!( + "Import a trait for {} {} by alias", + type_, + name.assoc_item_name.text() + )); + acc.add_group(&group_label, assist_id, label, range, |builder| { + let scope = builder.make_import_scope_mut(scope.clone()); + insert_use_as_alias( + &scope, + mod_path_to_ast(&import_path, edition), + &ctx.config.insert_use, + edition, ); - } - _ => {} + }); + }; + + if let ImportCandidate::TraitAssocItem(name) | ImportCandidate::TraitMethod(name) = + import_assets.import_candidate() + { + if let hir::ItemInNs::Types(hir::ModuleDef::Trait(trait_to_import)) = + import.item_to_import + && 
trait_to_import.prefer_underscore_import(ctx.db()) + { + // Flip the order of the suggestions and show a preference for `as _` in the name. + add_underscore_import(acc, name, format!("Import `{import_name}`")); + add_normal_import(acc, format!("Import `{import_name}` without `as _`")); + } else { + add_normal_import(acc, format!("Import `{import_name}`")); + add_underscore_import(acc, name, format!("Import `{import_name} as _`")); + } + } else { + add_normal_import(acc, format!("Import `{import_name}`")); } } Some(()) @@ -1957,4 +1961,72 @@ fn main() { "#, ); } + + #[test] + fn prefer_underscore_import() { + check_assist_by_label( + auto_import, + r#" +mod foo { + #[rust_analyzer::prefer_underscore_import] + pub trait Ext { + fn bar(&self) {} + } + impl Ext for T {} +} + +fn baz() { + 1.b$0ar(); +} + "#, + r#" +use foo::Ext as _; + +mod foo { + #[rust_analyzer::prefer_underscore_import] + pub trait Ext { + fn bar(&self) {} + } + impl Ext for T {} +} + +fn baz() { + 1.bar(); +} + "#, + "Import `foo::Ext`", + ); + check_assist_by_label( + auto_import, + r#" +mod foo { + #[rust_analyzer::prefer_underscore_import] + pub trait Ext { + fn bar(&self) {} + } + impl Ext for T {} +} + +fn baz() { + 1.b$0ar(); +} + "#, + r#" +use foo::Ext; + +mod foo { + #[rust_analyzer::prefer_underscore_import] + pub trait Ext { + fn bar(&self) {} + } + impl Ext for T {} +} + +fn baz() { + 1.bar(); +} + "#, + "Import `foo::Ext` without `as _`", + ); + } } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bind_unused_param.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bind_unused_param.rs index 0e85a77822bed..50e4a367e9a1b 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bind_unused_param.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/bind_unused_param.rs @@ -1,5 +1,5 @@ use crate::assist_context::{AssistContext, Assists}; -use ide_db::{LineIndexDatabase, assists::AssistId, defs::Definition}; +use 
ide_db::{assists::AssistId, defs::Definition, line_index}; use syntax::{ AstNode, ast::{self, HasName, edit::AstNodeEdit}, @@ -43,7 +43,7 @@ pub(crate) fn bind_unused_param(acc: &mut Assists, ctx: &AssistContext<'_>) -> O format!("Bind as `let _ = {name};`"), param.syntax().text_range(), |builder| { - let line_index = ctx.db().line_index(ctx.vfs_file_id()); + let line_index = line_index(ctx.db(), ctx.vfs_file_id()); let indent = func.indent_level(); let text_indent = indent + 1; diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs index c36c79ee998b4..a2a71bcba6baa 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_bool_then.rs @@ -77,7 +77,7 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> "Convert `if` expression to `bool::then` call", target, |builder| { - let (mut editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body); + let (editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body); // Rewrite all `Some(e)` in tail position to `e` for_each_tail_expr(&closure_body, &mut |e| { let e = match e { @@ -95,13 +95,13 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> let edit = editor.finish(); let closure_body = ast::Expr::cast(edit.new_root().clone()).unwrap(); - let mut editor = builder.make_editor(expr.syntax()); - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(expr.syntax()); + let make = editor.make(); let closure_body = match closure_body { ast::Expr::BlockExpr(block) => unwrap_trivial_block(block), e => e, }; - let cond = if invert_cond { invert_boolean_expression(&make, cond) } else { cond }; + let cond = if invert_cond { invert_boolean_expression(make, cond) } else { cond }; let parenthesize = matches!( cond, @@ -128,8 
+128,6 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into())); let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list); editor.replace(expr.syntax(), mcall.syntax()); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -187,7 +185,7 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_> e => mapless_make.block_expr(None, Some(e)), }; - let (mut editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body); + let (editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body); // Wrap all tails in `Some(...)` let none_path = mapless_make.expr_path(mapless_make.ident_path("None")); let some_path = mapless_make.expr_path(mapless_make.ident_path("Some")); @@ -210,8 +208,8 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_> let edit = editor.finish(); let closure_body = ast::BlockExpr::cast(edit.new_root().clone()).unwrap(); - let mut editor = builder.make_editor(mcall.syntax()); - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(mcall.syntax()); + let make = editor.make(); let cond = match &receiver { ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver), @@ -225,8 +223,6 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_> ) .indent(mcall.indent_level()); editor.replace(mcall.syntax().clone(), if_expr.syntax().clone()); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs index 9f9ced98d73b2..acade433978ce 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs +++ 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_closure_to_fn.rs @@ -24,7 +24,7 @@ use crate::assist_context::{AssistContext, Assists}; // This converts a closure to a freestanding function, changing all captures to parameters. // // ``` -// # //- minicore: copy +// # //- minicore: copy, fn // # struct String; // # impl String { // # fn new() -> Self {} @@ -90,6 +90,7 @@ pub(crate) fn convert_closure_to_fn(acc: &mut Assists, ctx: &AssistContext<'_>) } }) .collect::>>()?; + let capture_params_start = params.len(); let mut body = closure.body()?.clone_for_update(); let mut is_gen = false; @@ -152,7 +153,8 @@ pub(crate) fn convert_closure_to_fn(acc: &mut Assists, ctx: &AssistContext<'_>) .map(|(_, _, it)| it.clone()) .unwrap_or_else(|| make::name("fun_name")); let captures = closure_ty.captured_items(ctx.db()); - let capture_tys = closure_ty.capture_types(ctx.db()); + let capture_tys = + captures.iter().map(|capture| capture.captured_ty(ctx.db())).collect::>(); let mut captures_as_args = Vec::with_capacity(captures.len()); @@ -163,22 +165,28 @@ pub(crate) fn convert_closure_to_fn(acc: &mut Assists, ctx: &AssistContext<'_>) for (capture, capture_ty) in std::iter::zip(&captures, &capture_tys) { // FIXME: Allow configuring the replacement of `self`. 
- let capture_name = - if capture.local().is_self(ctx.db()) && !capture.has_field_projections() { - make::name("this") - } else { - make::name(&capture.place_to_name(ctx.db())) - }; + let is_self = capture.local().is_self(ctx.db()) && !capture.has_field_projections(); + let capture_name = if is_self { + make::name("this") + } else { + make::name(&capture.place_to_name(ctx.db(), ctx.edition())) + }; closure_mentioned_generic_params.extend(capture_ty.generic_params(ctx.db())); let capture_ty = capture_ty .display_source_code(ctx.db(), module.into(), true) .unwrap_or_else(|_| "_".to_owned()); - params.push(make::param( + let param = make::param( ast::Pat::IdentPat(make::ident_pat(false, false, capture_name.clone_subtree())), make::ty(&capture_ty), - )); + ); + if is_self { + // Always put `this` first. + params.insert(capture_params_start, param); + } else { + params.push(param); + } for capture_usage in capture.usages().sources(ctx.db()) { if capture_usage.file_id() != ctx.file_id() { @@ -188,24 +196,32 @@ pub(crate) fn convert_closure_to_fn(acc: &mut Assists, ctx: &AssistContext<'_>) let capture_usage_source = capture_usage.source(); let capture_usage_source = capture_usage_source.to_node(&body_root); - let expr = match capture_usage_source { + let mut expr = match capture_usage_source { Either::Left(expr) => expr, Either::Right(pat) => { let Some(expr) = expr_of_pat(pat) else { continue }; expr } }; + if !capture_usage.is_ref() { + expr = peel_ref(expr); + } let replacement = wrap_capture_in_deref_if_needed( &expr, &capture_name, capture.kind(), - capture_usage.is_ref(), + matches!(expr, ast::Expr::RefExpr(_)) || capture_usage.is_ref(), ) .clone_for_update(); capture_usages_replacement_map.push((expr, replacement)); } - captures_as_args.push(capture_as_arg(ctx, capture)); + let capture_as_arg = capture_as_arg(ctx, capture); + if is_self { + captures_as_args.insert(0, capture_as_arg); + } else { + captures_as_args.push(capture_as_arg); + } } let 
(closure_type_params, closure_where_clause) = @@ -463,24 +479,29 @@ fn compute_closure_type_params( (Some(make::generic_param_list(include_params)), where_clause) } +fn peel_parens(mut expr: ast::Expr) -> ast::Expr { + loop { + if ast::ParenExpr::can_cast(expr.syntax().kind()) { + let Some(parent) = expr.syntax().parent().and_then(ast::Expr::cast) else { break }; + expr = parent; + } else { + break; + } + } + expr +} + +fn peel_ref(mut expr: ast::Expr) -> ast::Expr { + expr = peel_parens(expr); + expr.syntax().parent().and_then(ast::RefExpr::cast).map(Into::into).unwrap_or(expr) +} + fn wrap_capture_in_deref_if_needed( expr: &ast::Expr, capture_name: &ast::Name, capture_kind: CaptureKind, is_ref: bool, ) -> ast::Expr { - fn peel_parens(mut expr: ast::Expr) -> ast::Expr { - loop { - if ast::ParenExpr::can_cast(expr.syntax().kind()) { - let Some(parent) = expr.syntax().parent().and_then(ast::Expr::cast) else { break }; - expr = parent; - } else { - break; - } - } - expr - } - let capture_name = make::expr_path(make::path_from_text(&capture_name.text())); if capture_kind == CaptureKind::Move || is_ref { return capture_name; @@ -507,8 +528,11 @@ fn wrap_capture_in_deref_if_needed( } fn capture_as_arg(ctx: &AssistContext<'_>, capture: &ClosureCapture<'_>) -> ast::Expr { - let place = parse_expr_from_str(&capture.display_place_source_code(ctx.db()), ctx.edition()) - .expect("`display_place_source_code()` produced an invalid expr"); + let place = parse_expr_from_str( + &capture.display_place_source_code(ctx.db(), ctx.edition()), + ctx.edition(), + ) + .expect("`display_place_source_code()` produced an invalid expr"); let needs_mut = match capture.kind() { CaptureKind::SharedRef => false, CaptureKind::MutableRef | CaptureKind::UniqueSharedRef => true, @@ -688,7 +712,7 @@ mod tests { check_assist( convert_closure_to_fn, r#" -//- minicore:copy +//- minicore: copy, fn fn main() { let s = &mut true; let closure = |$0| { *s = false; }; @@ -710,7 +734,7 @@ fn main() { 
check_assist( convert_closure_to_fn, r#" -//- minicore:copy +//- minicore: copy, fn struct A { a: i32, b: bool } fn main() { let mut a = A { a: 123, b: false }; @@ -740,8 +764,8 @@ fn main() { check_assist( convert_closure_to_fn, r#" -//- minicore:copy -struct A { b: &'static B, c: i32 } +//- minicore: copy, fn +struct A { b: &'static mut B, c: i32 } struct B(bool, i32); struct C; impl C { @@ -756,7 +780,7 @@ impl C { } "#, r#" -struct A { b: &'static B, c: i32 } +struct A { b: &'static mut B, c: i32 } struct B(bool, i32); struct C; impl C { @@ -778,7 +802,7 @@ impl C { check_assist( convert_closure_to_fn, r#" -//- minicore:copy +//- minicore: copy, fn struct A { b: &'static B, c: i32 } struct B(bool, i32); impl A { @@ -795,10 +819,10 @@ struct A { b: &'static B, c: i32 } struct B(bool, i32); impl A { fn foo(&self) { - fn closure(self_b_1: &i32) { - let b = *self_b_1; + fn closure(self_b: &B) { + let b = self_b.1; } - closure(&self.b.1); + closure(self.b); } } "#, @@ -810,7 +834,7 @@ impl A { check_assist( convert_closure_to_fn, r#" -//- minicore: copy, future +//- minicore: copy, future, async_fn fn foo(&self) { let closure = async |$0| 1; closure(); @@ -832,7 +856,7 @@ fn foo(&self) { check_assist( convert_closure_to_fn, r#" -//- minicore: copy, future +//- minicore: copy, future, fn fn foo() { let closure = |$0| async { 1 }; closure(); @@ -878,7 +902,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let closure = |$0| {}; closure(); @@ -898,7 +922,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let a = 1; let closure = |$0| a; @@ -918,7 +942,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let closure = |$0| 'label: {}; closure(); @@ -936,7 +960,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let closure = |$0| { 
const { () } @@ -956,7 +980,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let closure = |$0| unsafe { }; closure(); @@ -974,7 +998,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { { let closure = |$0| match () { @@ -1049,7 +1073,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn struct A { b: B } struct B(bool, i32); fn foo() { @@ -1066,7 +1090,7 @@ struct B(bool, i32); fn foo() { let mut a = A { b: B(true, 0) }; fn closure(a_b_1: &mut i32) { - let A { b: B(_, ref mut c) } = a_b_1; + let A { b: B(_, ref mut c) } = *a_b_1; } closure(&mut a.b.1); } @@ -1079,7 +1103,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let (mut a, b) = (0.1, "abc"); let closure = |$0p1: i32, p2: &mut bool| { @@ -1107,7 +1131,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let (mut a, b) = (0.1, "abc"); let closure = |$0p1: i32, p2| { @@ -1145,7 +1169,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, fn fn foo() { let (mut a, b) = (0.1, "abc"); let closure = |$0p1: i32, p2| { @@ -1183,7 +1207,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore: copy +//- minicore: copy, from struct Foo(A, B); impl, const C: usize> Foo { fn foo(a: A, b: D) @@ -1244,7 +1268,7 @@ fn foo() { check_assist( convert_closure_to_fn, r#" -//- minicore:copy +//- minicore: copy, fn fn main() { let a = &mut true; let closure = |$0| { diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs index 0d36a5ddb304c..f242fe831447a 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs +++ 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_comment_block.rs @@ -381,6 +381,21 @@ fn main() { ); } + #[test] + fn empty_block_to_line() { + check_assist( + convert_comment_block, + r#" +/**/$0 +fn main() {} +"#, + r#" + +fn main() {} +"#, + ); + } + #[test] fn end_of_line_block_to_line() { check_assist_not_applicable( diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_for_to_while_let.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_for_to_while_let.rs index a5c29a45a51f2..9eb4c0584b362 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_for_to_while_let.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_for_to_while_let.rs @@ -2,7 +2,7 @@ use hir::{Name, sym}; use ide_db::{famous_defs::FamousDefs, syntax_helpers::suggest_name}; use syntax::{ AstNode, - ast::{self, HasAttrs, HasLoopBody, edit::IndentLevel, syntax_factory::SyntaxFactory}, + ast::{self, HasAttrs, HasLoopBody, edit::IndentLevel}, syntax_editor::Position, }; @@ -48,8 +48,8 @@ pub(crate) fn convert_for_loop_to_while_let( "Replace this for loop with `while let`", for_loop.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(for_loop.syntax()); + let editor = builder.make_editor(for_loop.syntax()); + let make = editor.make(); let (iterable, method) = if impls_core_iter(&ctx.sema, &iterable) { (iterable, None) @@ -85,12 +85,7 @@ pub(crate) fn convert_for_loop_to_while_let( editor.insert(Position::before(for_loop.syntax()), make.whitespace(" ")); editor.insert(Position::before(for_loop.syntax()), label); } - crate::utils::insert_attributes( - for_loop.syntax(), - &mut editor, - for_loop.attrs(), - &make, - ); + crate::utils::insert_attributes(for_loop.syntax(), &editor, for_loop.attrs()); editor.insert( Position::before(for_loop.syntax()), @@ -110,7 +105,6 @@ pub(crate) fn convert_for_loop_to_while_let( 
editor.replace(for_loop.syntax(), while_loop.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs index 66ccd2a4e4093..18f3ced414026 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_from_to_tryfrom.rs @@ -74,8 +74,8 @@ pub(crate) fn convert_from_to_tryfrom(acc: &mut Assists, ctx: &AssistContext<'_> "Convert From to TryFrom", impl_.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(impl_.syntax()); + let editor = builder.make_editor(impl_.syntax()); + let make = editor.make(); editor.replace(trait_ty.syntax(), make.ty(&format!("TryFrom<{from_type}>")).syntax()); editor.replace( @@ -83,11 +83,11 @@ pub(crate) fn convert_from_to_tryfrom(acc: &mut Assists, ctx: &AssistContext<'_> make.ty("Result").syntax(), ); editor.replace(from_fn_name.syntax(), make.name("try_from").syntax()); - editor.replace(tail_expr.syntax(), wrap_ok(&make, tail_expr.clone()).syntax()); + editor.replace(tail_expr.syntax(), wrap_ok(make, tail_expr.clone()).syntax()); for r in return_exprs { let t = r.expr().unwrap_or_else(|| make.expr_unit()); - editor.replace(t.syntax(), wrap_ok(&make, t.clone()).syntax()); + editor.replace(t.syntax(), wrap_ok(make, t.clone()).syntax()); } let error_type_alias = @@ -111,7 +111,6 @@ pub(crate) fn convert_from_to_tryfrom(acc: &mut Assists, ctx: &AssistContext<'_> make.whitespace("\n").syntax_element(), ], ); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs index 63b1a0193bd6d..cc5cc490f1bc1 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_iter_for_each_to_for.rs @@ -1,12 +1,11 @@ use hir::{Name, sym}; use ide_db::famous_defs::FamousDefs; -use stdx::format_to; use syntax::{ AstNode, - ast::{self, HasArgList, HasLoopBody, edit::AstNodeEdit, syntax_factory::SyntaxFactory}, + ast::{self, HasArgList, HasLoopBody, edit::AstNodeEdit}, }; -use crate::{AssistContext, AssistId, Assists}; +use crate::{AssistContext, AssistId, Assists, utils::wrap_paren}; // Assist: convert_iter_for_each_to_for // @@ -57,7 +56,9 @@ pub(crate) fn convert_iter_for_each_to_for( "Replace this `Iterator::for_each` with a for loop", range, |builder| { - let make = SyntaxFactory::with_mappings(); + let target_node = stmt.as_ref().map_or(method.syntax(), AstNode::syntax); + let editor = builder.make_editor(target_node); + let make = editor.make(); let indent = stmt.as_ref().map_or_else(|| method.indent_level(), ast::ExprStmt::indent_level); @@ -68,9 +69,6 @@ pub(crate) fn convert_iter_for_each_to_for( .indent(indent); let expr_for_loop = make.expr_for_loop(param, receiver, block); - - let target_node = stmt.as_ref().map_or(method.syntax(), AstNode::syntax); - let mut editor = builder.make_editor(target_node); editor.replace(target_node, expr_for_loop.syntax()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, @@ -116,32 +114,40 @@ pub(crate) fn convert_for_loop_with_for_each( "Replace this for loop with `Iterator::for_each`", for_loop.syntax().text_range(), |builder| { - let mut buf = String::new(); + let editor = builder.make_editor(for_loop.syntax()); + let make = editor.make(); + + let mut receiver = iterable.clone(); - if let Some((expr_behind_ref, method, krate)) = + let iter_method = if let Some((expr_behind_ref, method, krate)) = 
is_ref_and_impls_iter_method(&ctx.sema, &iterable) { + receiver = expr_behind_ref; // We have either "for x in &col" and col implements a method called iter // or "for x in &mut col" and col implements a method called iter_mut - format_to!( - buf, - "{expr_behind_ref}.{}()", - method.display(ctx.db(), krate.edition(ctx.db())) - ); - } else if let ast::Expr::RangeExpr(..) = iterable { - // range expressions need to be parenthesized for the syntax to be correct - format_to!(buf, "({iterable})"); - } else if impls_core_iter(&ctx.sema, &iterable) { - format_to!(buf, "{iterable}"); - } else if let ast::Expr::RefExpr(_) = iterable { - format_to!(buf, "({iterable}).into_iter()"); + method.display(ctx.db(), krate.edition(ctx.db())).to_string() } else { - format_to!(buf, "{iterable}.into_iter()"); + "into_iter".to_owned() + }; + + receiver = wrap_paren(receiver, make, ast::prec::ExprPrecedence::Postfix); + + if !impls_core_iter(&ctx.sema, &iterable) { + receiver = make + .expr_method_call(receiver, make.name_ref(&iter_method), make.arg_list([])) + .into(); } - format_to!(buf, ".for_each(|{pat}| {body});"); + let loop_arg = make.expr_closure([make.untyped_param(pat)], body.into()); + let for_each = make.expr_method_call( + receiver, + make.name_ref("for_each"), + make.arg_list([loop_arg.into()]), + ); + let for_each = make.expr_stmt(for_each.into()); - builder.replace(for_loop.syntax().text_range(), buf) + editor.replace(for_loop.syntax(), for_each.syntax()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs index 9a9808e270fa2..1ae12390eedbf 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_let_else_to_match.rs @@ -1,8 +1,8 @@ use syntax::T; use 
syntax::ast::RangeItem; use syntax::ast::edit::AstNodeEdit; -use syntax::ast::syntax_factory::SyntaxFactory; use syntax::ast::{self, AstNode, HasName, LetStmt, Pat}; +use syntax::syntax_editor::SyntaxEditor; use crate::{AssistContext, AssistId, Assists}; @@ -25,6 +25,7 @@ use crate::{AssistContext, AssistId, Assists}; // } // ``` pub(crate) fn convert_let_else_to_match(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone()); // Should focus on the `else` token to trigger let let_stmt = ctx .find_token_syntax_at_offset(T![else]) @@ -45,10 +46,8 @@ pub(crate) fn convert_let_else_to_match(acc: &mut Assists, ctx: &AssistContext<' return None; } let pat = let_stmt.pat()?; - - let make = SyntaxFactory::with_mappings(); let mut idents = Vec::default(); - let pat_without_mut = remove_mut_and_collect_idents(&make, &pat, &mut idents)?; + let pat_without_mut = remove_mut_and_collect_idents(&editor, &pat, &mut idents)?; let bindings = idents .into_iter() .filter_map(|ref pat| { @@ -70,8 +69,7 @@ pub(crate) fn convert_let_else_to_match(acc: &mut Assists, ctx: &AssistContext<' }, let_stmt.syntax().text_range(), |builder| { - let mut editor = builder.make_editor(let_stmt.syntax()); - + let make = editor.make(); let binding_paths = bindings .iter() .map(|(name, _)| make.expr_path(make.ident_path(&name.to_string()))) @@ -115,18 +113,17 @@ pub(crate) fn convert_let_else_to_match(acc: &mut Assists, ctx: &AssistContext<' ); editor.replace(let_stmt.syntax(), new_let_stmt.syntax()); } - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } fn remove_mut_and_collect_idents( - make: &SyntaxFactory, + editor: &SyntaxEditor, pat: &ast::Pat, acc: &mut Vec, ) -> Option { + let make = editor.make(); Some(match pat { ast::Pat::IdentPat(p) => { acc.push(p.clone()); @@ -135,90 +132,92 @@ fn remove_mut_and_collect_idents( p.ref_token().is_some() && 
p.mut_token().is_some(), p.name()?, ); - if let Some(inner) = p.pat() { - non_mut_pat.set_pat(remove_mut_and_collect_idents(make, &inner, acc)); - } + let non_mut_pat = if let Some(inner) = p.pat() { + non_mut_pat.set_pat(remove_mut_and_collect_idents(editor, &inner, acc), editor) + } else { + non_mut_pat + }; non_mut_pat.into() } ast::Pat::BoxPat(p) => { - make.box_pat(remove_mut_and_collect_idents(make, &p.pat()?, acc)?).into() + let pat = remove_mut_and_collect_idents(editor, &p.pat()?, acc)?; + make.box_pat(pat).into() + } + ast::Pat::OrPat(p) => { + let pats = p + .pats() + .map(|pat| remove_mut_and_collect_idents(editor, &pat, acc)) + .collect::>>()?; + make.or_pat(pats, p.leading_pipe().is_some()).into() } - ast::Pat::OrPat(p) => make - .or_pat( - p.pats() - .map(|pat| remove_mut_and_collect_idents(make, &pat, acc)) - .collect::>>()?, - p.leading_pipe().is_some(), - ) - .into(), ast::Pat::ParenPat(p) => { - make.paren_pat(remove_mut_and_collect_idents(make, &p.pat()?, acc)?).into() + let pat = remove_mut_and_collect_idents(editor, &p.pat()?, acc)?; + make.paren_pat(pat).into() } - ast::Pat::RangePat(p) => make - .range_pat( - if let Some(start) = p.start() { - Some(remove_mut_and_collect_idents(make, &start, acc)?) - } else { - None - }, - if let Some(end) = p.end() { - Some(remove_mut_and_collect_idents(make, &end, acc)?) - } else { - None - }, - ) - .into(), - ast::Pat::RecordPat(p) => make - .record_pat_with_fields( + ast::Pat::RangePat(p) => { + let start = if let Some(start) = p.start() { + Some(remove_mut_and_collect_idents(editor, &start, acc)?) + } else { + None + }; + let end = if let Some(end) = p.end() { + Some(remove_mut_and_collect_idents(editor, &end, acc)?) + } else { + None + }; + make.range_pat(start, end).into() + } + ast::Pat::RecordPat(p) => { + let fields = p + .record_pat_field_list()? 
+ .fields() + .map(|field| { + remove_mut_and_collect_idents(editor, &field.pat()?, acc).map(|pat| { + if let Some(name_ref) = field.name_ref() { + make.record_pat_field(name_ref, pat) + } else { + make.record_pat_field_shorthand(pat) + } + }) + }) + .collect::>>()?; + make.record_pat_with_fields( p.path()?, - make.record_pat_field_list( - p.record_pat_field_list()? - .fields() - .map(|field| { - remove_mut_and_collect_idents(make, &field.pat()?, acc).map(|pat| { - if let Some(name_ref) = field.name_ref() { - make.record_pat_field(name_ref, pat) - } else { - make.record_pat_field_shorthand(pat) - } - }) - }) - .collect::>>()?, - p.record_pat_field_list()?.rest_pat(), - ), + make.record_pat_field_list(fields, p.record_pat_field_list()?.rest_pat()), ) - .into(), + .into() + } ast::Pat::RefPat(p) => { let inner = p.pat()?; if let ast::Pat::IdentPat(ident) = inner { acc.push(ident); p.clone().into() } else { - make.ref_pat(remove_mut_and_collect_idents(make, &inner, acc)?).into() + let pat = remove_mut_and_collect_idents(editor, &inner, acc)?; + make.ref_pat(pat).into() } } - ast::Pat::SlicePat(p) => make - .slice_pat( - p.pats() - .map(|pat| remove_mut_and_collect_idents(make, &pat, acc)) - .collect::>>()?, - ) - .into(), - ast::Pat::TuplePat(p) => make - .tuple_pat( - p.fields() - .map(|field| remove_mut_and_collect_idents(make, &field, acc)) - .collect::>>()?, - ) - .into(), - ast::Pat::TupleStructPat(p) => make - .tuple_struct_pat( - p.path()?, - p.fields() - .map(|field| remove_mut_and_collect_idents(make, &field, acc)) - .collect::>>()?, - ) - .into(), + ast::Pat::SlicePat(p) => { + let pats = p + .pats() + .map(|pat| remove_mut_and_collect_idents(editor, &pat, acc)) + .collect::>>()?; + make.slice_pat(pats).into() + } + ast::Pat::TuplePat(p) => { + let pats = p + .fields() + .map(|field| remove_mut_and_collect_idents(editor, &field, acc)) + .collect::>>()?; + make.tuple_pat(pats).into() + } + ast::Pat::TupleStructPat(p) => { + let fields = p + .fields() + 
.map(|field| remove_mut_and_collect_idents(editor, &field, acc)) + .collect::>>()?; + make.tuple_struct_pat(p.path()?, fields).into() + } ast::Pat::RestPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::PathPat(_) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_match_to_let_else.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_match_to_let_else.rs index 4b132d68ee3a5..bc49acc1ef356 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_match_to_let_else.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_match_to_let_else.rs @@ -1,7 +1,7 @@ use ide_db::defs::{Definition, NameRefClass}; use syntax::{ AstNode, SyntaxNode, - ast::{self, HasName, Name, edit::AstNodeEdit, syntax_factory::SyntaxFactory}, + ast::{self, HasName, Name, edit::AstNodeEdit}, syntax_editor::SyntaxEditor, }; @@ -121,8 +121,8 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti // Rename `extracted` with `binding` in `pat`. 
fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode { - let (mut editor, syntax) = SyntaxEditor::new(pat.syntax().clone()); - let make = SyntaxFactory::with_mappings(); + let (editor, syntax) = SyntaxEditor::new(pat.syntax().clone()); + let make = editor.make(); let extracted = extracted .iter() .map(|e| e.syntax().text_range() - pat.syntax().text_range().start()) @@ -145,7 +145,6 @@ fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> Syn editor.replace(extracted_syntax, binding.syntax()); } } - editor.add_mappings(make.finish_with_mappings()); let new_node = editor.finish().new_root().clone(); if let Some(pat) = ast::Pat::cast(new_node.clone()) { pat.dedent(1.into()).syntax().clone() diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs index 4ea56e3e613fb..5b691dba5ea76 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_named_struct_to_tuple_struct.rs @@ -3,9 +3,7 @@ use ide_db::{defs::Definition, search::FileReference}; use syntax::{ NodeOrToken, SyntaxKind, SyntaxNode, T, algo::next_non_trivia_token, - ast::{ - self, AstNode, HasAttrs, HasGenericParams, HasVisibility, syntax_factory::SyntaxFactory, - }, + ast::{self, AstNode, HasAttrs, HasGenericParams, HasVisibility}, match_ast, syntax_editor::{Element, Position, SyntaxEditor}, }; @@ -101,27 +99,26 @@ fn edit_struct_def( ) { // Note that we don't need to consider macro files in this function because this is // currently not triggered for struct definitions inside macro calls. 
+ let editor = builder.make_editor(strukt.syntax()); + let make = editor.make(); + let tuple_fields = record_fields.fields().filter_map(|f| { - let (mut editor, field) = - SyntaxEditor::with_ast_node(&ast::make::tuple_field(f.visibility(), f.ty()?)); - editor.insert_all( + let (field_editor, field) = + SyntaxEditor::with_ast_node(&make.tuple_field(f.visibility(), f.ty()?)); + field_editor.insert_all( Position::first_child_of(field.syntax()), f.attrs().map(|attr| attr.syntax().clone().into()).collect(), ); - let field_syntax = editor.finish().new_root().clone(); - let field = ast::TupleField::cast(field_syntax)?; - Some(field) + let field_syntax = field_editor.finish().new_root().clone(); + ast::TupleField::cast(field_syntax) }); - let make = SyntaxFactory::without_mappings(); - let mut edit = builder.make_editor(strukt.syntax()); - let tuple_fields = make.tuple_field_list(tuple_fields); let mut elements = vec![tuple_fields.syntax().clone().into()]; if let Either::Left(strukt) = strukt { if let Some(w) = strukt.where_clause() { - edit.delete(w.syntax()); + editor.delete(w.syntax()); elements.extend([ make.whitespace("\n").into(), @@ -136,23 +133,23 @@ fn edit_struct_def( .and_then(|tok| tok.next_token()) .filter(|tok| tok.kind() == SyntaxKind::WHITESPACE) { - edit.delete(tok); + editor.delete(tok); } } else { elements.push(make.token(T![;]).into()); } } - edit.replace_with_many(record_fields.syntax(), elements); + editor.replace_with_many(record_fields.syntax(), elements); if let Some(tok) = record_fields .l_curly_token() .and_then(|tok| tok.prev_token()) .filter(|tok| tok.kind() == SyntaxKind::WHITESPACE) { - edit.delete(tok) + editor.delete(tok) } - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); } fn edit_struct_references( @@ -168,18 +165,18 @@ fn edit_struct_references( for (file_id, refs) in usages { let source = ctx.sema.parse(file_id); - let mut edit = builder.make_editor(source.syntax()); + let editor = 
builder.make_editor(source.syntax()); for r in refs { - process_struct_name_reference(ctx, r, &mut edit, &source); + process_struct_name_reference(ctx, r, &editor, &source); } - builder.add_file_edits(file_id.file_id(ctx.db()), edit); + builder.add_file_edits(file_id.file_id(ctx.db()), editor); } } fn process_struct_name_reference( ctx: &AssistContext<'_>, r: FileReference, - edit: &mut SyntaxEditor, + edit: &SyntaxEditor, source: &ast::SourceFile, ) -> Option<()> { // First check if it's the last semgnet of a path that directly belongs to a record @@ -232,7 +229,7 @@ fn process_struct_name_reference( fn record_to_tuple_struct_like( ctx: &AssistContext<'_>, source: &ast::SourceFile, - edit: &mut SyntaxEditor, + editor: &SyntaxEditor, field_list: T, fields: impl FnOnce(&T) -> I, ) -> Option<()> @@ -240,7 +237,7 @@ where T: AstNode, I: IntoIterator, { - let make = SyntaxFactory::without_mappings(); + let make = editor.make(); let orig = ctx.sema.original_range_opt(field_list.syntax())?; let list_range = cover_edit_range(source.syntax(), orig.range); @@ -254,13 +251,13 @@ where }; if l_curly.kind() == T!['{'] { - delete_whitespace(edit, l_curly.prev_token()); - delete_whitespace(edit, l_curly.next_token()); - edit.replace(l_curly, make.token(T!['('])); + delete_whitespace(editor, l_curly.prev_token()); + delete_whitespace(editor, l_curly.next_token()); + editor.replace(l_curly, make.token(T!['('])); } if r_curly.kind() == T!['}'] { - delete_whitespace(edit, r_curly.prev_token()); - edit.replace(r_curly, make.token(T![')'])); + delete_whitespace(editor, r_curly.prev_token()); + editor.replace(r_curly, make.token(T![')'])); } for name_ref in fields(&field_list) { @@ -270,14 +267,14 @@ where if let Some(colon) = next_non_trivia_token(name_range.end().clone()) && colon.kind() == T![:] { - edit.delete(&colon); - edit.delete_all(name_range); + editor.delete(&colon); + editor.delete_all(name_range); if let Some(next) = next_non_trivia_token(colon.clone()) && next.kind() != 
T!['}'] { // Avoid overlapping delete whitespace on `{ field: }` - delete_whitespace(edit, colon.next_token()); + delete_whitespace(editor, colon.next_token()); } } } @@ -289,7 +286,6 @@ fn edit_field_references( builder: &mut SourceChangeBuilder, fields: impl Iterator, ) { - let make = SyntaxFactory::without_mappings(); for (index, field) in fields.enumerate() { let field = match ctx.sema.to_def(&field) { Some(it) => it, @@ -299,13 +295,14 @@ fn edit_field_references( let usages = def.usages(&ctx.sema).all(); for (file_id, refs) in usages { let source = ctx.sema.parse(file_id); - let mut edit = builder.make_editor(source.syntax()); + let editor = builder.make_editor(source.syntax()); + let make = editor.make(); for r in refs { if let Some(name_ref) = r.name.as_name_ref() { // Only edit the field reference if it's part of a `.field` access if name_ref.syntax().parent().and_then(ast::FieldExpr::cast).is_some() { - edit.replace_all( + editor.replace_all( cover_edit_range(source.syntax(), r.range), vec![make.name_ref(&index.to_string()).syntax().clone().into()], ); @@ -313,12 +310,12 @@ fn edit_field_references( } } - builder.add_file_edits(file_id.file_id(ctx.db()), edit); + builder.add_file_edits(file_id.file_id(ctx.db()), editor); } } } -fn delete_whitespace(edit: &mut SyntaxEditor, whitespace: Option) { +fn delete_whitespace(edit: &SyntaxEditor, whitespace: Option) { let Some(whitespace) = whitespace else { return }; let NodeOrToken::Token(token) = whitespace.syntax_element() else { return }; @@ -328,7 +325,7 @@ fn delete_whitespace(edit: &mut SyntaxEditor, whitespace: Option) } fn remove_trailing_comma(w: ast::WhereClause) -> SyntaxNode { - let (mut editor, w) = SyntaxEditor::new(w.syntax().clone()); + let (editor, w) = SyntaxEditor::new(w.syntax().clone()); if let Some(last) = w.last_child_or_token() && last.kind() == T![,] { diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_range_for_to_while.rs 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_range_for_to_while.rs index 61393950767f8..c83f8b076551b 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_range_for_to_while.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_range_for_to_while.rs @@ -6,7 +6,7 @@ use syntax::{ T, algo::previous_non_trivia_token, ast::{ - self, HasArgList, HasLoopBody, HasName, RangeItem, edit::AstNodeEdit, make, + self, HasArgList, HasLoopBody, HasName, RangeItem, edit::AstNodeEdit, syntax_factory::SyntaxFactory, }, syntax_editor::{Element, Position, SyntaxEditor}, @@ -36,11 +36,13 @@ use crate::assist_context::{AssistContext, Assists}; // } // ``` pub(crate) fn convert_range_for_to_while(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone()); + let make = editor.make(); let for_kw = ctx.find_token_syntax_at_offset(T![for])?; let for_ = ast::ForExpr::cast(for_kw.parent()?)?; let ast::Pat::IdentPat(pat) = for_.pat()? 
else { return None }; let iterable = for_.iterable()?; - let (start, end, step, inclusive) = extract_range(&iterable)?; + let (start, end, step, inclusive) = extract_range(&iterable, make)?; let name = pat.name()?; let body = for_.loop_body()?.stmt_list()?; let label = for_.label(); @@ -55,13 +57,11 @@ pub(crate) fn convert_range_for_to_while(acc: &mut Assists, ctx: &AssistContext< description, for_.syntax().text_range(), |builder| { - let mut edit = builder.make_editor(for_.syntax()); - let make = SyntaxFactory::with_mappings(); - + let make = editor.make(); let indent = for_.indent_level(); let pat = make.ident_pat(pat.ref_token().is_some(), true, name.clone()); let let_stmt = make.let_stmt(pat.into(), None, Some(start)); - edit.insert_all( + editor.insert_all( Position::before(for_.syntax()), vec![ let_stmt.syntax().syntax_element(), @@ -86,39 +86,36 @@ pub(crate) fn convert_range_for_to_while(acc: &mut Assists, ctx: &AssistContext< elements.push(make.token(T![loop]).syntax_element()); } - edit.replace_all( + editor.replace_all( for_kw.syntax_element()..=iterable.syntax().syntax_element(), elements, ); let op = ast::BinaryOp::Assignment { op: Some(ast::ArithOp::Add) }; - process_loop_body( - body, - label, - &mut edit, - vec![ - make.whitespace(&format!("\n{}", indent + 1)).syntax_element(), - make.expr_bin(var_expr, op, step).syntax().syntax_element(), - make.token(T![;]).syntax_element(), - ], - ); - - edit.add_mappings(make.finish_with_mappings()); - builder.add_file_edits(ctx.vfs_file_id(), edit); + let incrementer = vec![ + make.whitespace(&format!("\n{}", indent + 1)).syntax_element(), + make.expr_bin(var_expr, op, step).syntax().syntax_element(), + make.token(T![;]).syntax_element(), + ]; + process_loop_body(body, label, &editor, incrementer); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } -fn extract_range(iterable: &ast::Expr) -> Option<(ast::Expr, Option, ast::Expr, bool)> { +fn extract_range( + iterable: &ast::Expr, + make: 
&SyntaxFactory, +) -> Option<(ast::Expr, Option, ast::Expr, bool)> { Some(match iterable { - ast::Expr::ParenExpr(expr) => extract_range(&expr.expr()?)?, + ast::Expr::ParenExpr(expr) => extract_range(&expr.expr()?, make)?, ast::Expr::RangeExpr(range) => { let inclusive = range.op_kind()? == ast::RangeOp::Inclusive; - (range.start()?, range.end(), make::expr_literal("1").into(), inclusive) + (range.start()?, range.end(), make.expr_literal("1").into(), inclusive) } ast::Expr::MethodCallExpr(call) if call.name_ref()?.text() == "step_by" => { let [step] = Itertools::collect_array(call.arg_list()?.args())?; - let (start, end, _, inclusive) = extract_range(&call.receiver()?)?; + let (start, end, _, inclusive) = extract_range(&call.receiver()?, make)?; (start, end, step, inclusive) } _ => return None, @@ -128,9 +125,10 @@ fn extract_range(iterable: &ast::Expr) -> Option<(ast::Expr, Option, fn process_loop_body( body: ast::StmtList, label: Option, - edit: &mut SyntaxEditor, + editor: &SyntaxEditor, incrementer: Vec, ) -> Option<()> { + let make = editor.make(); let last = previous_non_trivia_token(body.r_curly_token()?)?.syntax_element(); let new_body = body.indent(1.into()); @@ -143,7 +141,7 @@ fn process_loop_body( ); if continues.is_empty() { - edit.insert_all(Position::after(last), incrementer); + editor.insert_all(Position::after(last), incrementer); return Some(()); } @@ -154,9 +152,9 @@ fn process_loop_body( let first = children.next()?; let block_content = first.clone()..=children.last().unwrap_or(first); - let continue_label = make::lifetime("'cont"); - let break_expr = make::expr_break(Some(continue_label.clone()), None); - let (mut new_edit, _) = SyntaxEditor::new(new_body.syntax().clone()); + let continue_label = make.lifetime("'cont"); + let break_expr = make.expr_break(Some(continue_label.clone()), None); + let (new_edit, _) = SyntaxEditor::new(new_body.syntax().clone()); for continue_expr in &continues { new_edit.replace(continue_expr.syntax(), 
break_expr.syntax()); } @@ -164,13 +162,13 @@ fn process_loop_body( let elements = itertools::chain( [ continue_label.syntax().syntax_element(), - make::token(T![:]).syntax_element(), - make::tokens::single_space().syntax_element(), + make.token(T![:]).syntax_element(), + make.whitespace(" ").syntax_element(), new_body.syntax_element(), ], incrementer, ); - edit.replace_all(block_content, elements.collect()); + editor.replace_all(block_content, elements.collect()); Some(()) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs index 004d09acac6e2..791a6a26af381 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs @@ -130,9 +130,10 @@ fn if_expr_to_guarded_return( "Convert to guarded return", target, |edit| { - let make = SyntaxFactory::without_mappings(); + let editor = edit.make_editor(if_expr.syntax()); + let make = editor.make(); let if_indent_level = IndentLevel::from_node(if_expr.syntax()); - let early_expression = else_block.make_early_block(&ctx.sema, &make); + let early_expression = else_block.make_early_block(&ctx.sema, make); let replacement = let_chains.into_iter().map(|expr| { if let ast::Expr::LetExpr(let_expr) = &expr && let (Some(pat), Some(expr)) = (let_expr.pat(), let_expr.expr()) @@ -145,8 +146,8 @@ fn if_expr_to_guarded_return( } else { // If. 
let new_expr = { - let then_branch = clean_stmt_block(&early_expression, &make); - let cond = invert_boolean_expression(&make, expr); + let then_branch = clean_stmt_block(&early_expression, make); + let cond = invert_boolean_expression(make, expr); make.expr_if(cond, then_branch, None).indent(if_indent_level) }; new_expr.syntax().clone() @@ -170,7 +171,6 @@ fn if_expr_to_guarded_return( .take_while(|i| *i != end_of_then), ) .collect(); - let mut editor = edit.make_editor(if_expr.syntax()); editor.replace_with_many(if_expr.syntax(), then_statements); edit.add_file_edits(ctx.vfs_file_id(), editor); }, @@ -209,22 +209,21 @@ fn let_stmt_to_guarded_return( "Convert to guarded return", target, |edit| { + let editor = edit.make_editor(let_stmt.syntax()); + let make = editor.make(); let let_indent_level = IndentLevel::from_node(let_stmt.syntax()); - let make = SyntaxFactory::without_mappings(); let replacement = { let let_else_stmt = make.let_else_stmt( happy_pattern, let_stmt.ty(), expr.reset_indent(), - else_block.make_early_block(&ctx.sema, &make), + else_block.make_early_block(&ctx.sema, make), ); let let_else_stmt = let_else_stmt.indent(let_indent_level); let_else_stmt.syntax().clone() }; - let mut editor = edit.make_editor(let_stmt.syntax()); editor.replace(let_stmt.syntax(), replacement); - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -268,7 +267,8 @@ impl<'db> ElseBlock<'db> { return block_expr.reset_indent(); } - let (mut edit, block_expr) = SyntaxEditor::with_ast_node(&block_expr.reset_indent()); + let (editor, block_expr) = SyntaxEditor::with_ast_node(&block_expr.reset_indent()); + let make = editor.make(); let last_stmt = block_expr.statements().last().map(|it| it.syntax().clone()); let tail_expr = block_expr.tail_expr().map(|it| it.syntax().clone()); @@ -277,13 +277,11 @@ impl<'db> ElseBlock<'db> { }; let whitespace = last_element.prev_sibling_or_token().filter(|it| it.kind() == WHITESPACE); - let 
make = SyntaxFactory::without_mappings(); - if let Some(tail_expr) = block_expr.tail_expr() && !self.kind.is_unit() { - let early_expr = self.kind.make_early_expr(sema, &make, Some(tail_expr.clone())); - edit.replace(tail_expr.syntax(), early_expr.syntax()); + let early_expr = self.kind.make_early_expr(sema, make, Some(tail_expr.clone())); + editor.replace(tail_expr.syntax(), early_expr.syntax()); } else { let last_stmt = match block_expr.tail_expr() { Some(expr) => make.expr_stmt(expr).syntax().clone(), @@ -291,14 +289,14 @@ impl<'db> ElseBlock<'db> { }; let whitespace = make.whitespace(&whitespace.map_or(String::new(), |it| it.to_string())); - let early_expr = self.kind.make_early_expr(sema, &make, None).syntax().clone().into(); - edit.replace_with_many( + let early_expr = self.kind.make_early_expr(sema, make, None).syntax().clone().into(); + editor.replace_with_many( last_element, vec![last_stmt.into(), whitespace.into(), early_expr], ); } - ast::BlockExpr::cast(edit.finish().new_root().clone()).unwrap() + ast::BlockExpr::cast(editor.finish().new_root().clone()).unwrap() } } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs index 1740cd024a89c..0af0cbc32a988 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_return_type_to_struct.rs @@ -72,15 +72,15 @@ pub(crate) fn convert_tuple_return_type_to_struct( "Convert tuple return type to tuple struct", target, move |edit| { - let mut syntax_editor = edit.make_editor(ret_type.syntax()); - let syntax_factory = SyntaxFactory::with_mappings(); + let editor = edit.make_editor(ret_type.syntax()); + let make = editor.make(); let usages = Definition::Function(fn_def).usages(&ctx.sema).all(); let struct_name = format!("{}Result", 
stdx::to_camel_case(&fn_name.to_string())); let parent = fn_.syntax().ancestors().find_map(>::cast); add_tuple_struct_def( edit, - &syntax_factory, + make, ctx, &usages, parent.as_ref().map(|it| it.syntax()).unwrap_or(fn_.syntax()), @@ -89,22 +89,12 @@ pub(crate) fn convert_tuple_return_type_to_struct( &target_module, ); - syntax_editor.replace( - ret_type.syntax(), - syntax_factory.ret_type(syntax_factory.ty(&struct_name)).syntax(), - ); + editor.replace(ret_type.syntax(), make.ret_type(make.ty(&struct_name)).syntax()); if let Some(fn_body) = fn_.body() { - replace_body_return_values( - &mut syntax_editor, - &syntax_factory, - ast::Expr::BlockExpr(fn_body), - &struct_name, - ); + replace_body_return_values(&editor, ast::Expr::BlockExpr(fn_body), &struct_name); } - - syntax_editor.add_mappings(syntax_factory.finish_with_mappings()); - edit.add_file_edits(ctx.vfs_file_id(), syntax_editor); + edit.add_file_edits(ctx.vfs_file_id(), editor); replace_usages(edit, ctx, &usages, &struct_name, &target_module); }, @@ -122,35 +112,22 @@ fn replace_usages( for (file_id, references) in usages.iter() { let Some(first_ref) = references.first() else { continue }; - let mut editor = edit.make_editor(first_ref.name.syntax().as_node().unwrap()); - let syntax_factory = SyntaxFactory::with_mappings(); + let editor = edit.make_editor(first_ref.name.syntax().as_node().unwrap()); + let make = editor.make(); - let refs_with_imports = augment_references_with_imports( - &syntax_factory, - ctx, - references, - struct_name, - target_module, - ); + let refs_with_imports = + augment_references_with_imports(make, ctx, references, struct_name, target_module); refs_with_imports.into_iter().rev().for_each(|(name, import_data)| { if let Some(fn_) = name.syntax().parent().and_then(ast::Fn::cast) { cov_mark::hit!(replace_trait_impl_fns); if let Some(ret_type) = fn_.ret_type() { - editor.replace( - ret_type.syntax(), - syntax_factory.ret_type(syntax_factory.ty(struct_name)).syntax(), - ); + 
editor.replace(ret_type.syntax(), make.ret_type(make.ty(struct_name)).syntax()); } if let Some(fn_body) = fn_.body() { - replace_body_return_values( - &mut editor, - &syntax_factory, - ast::Expr::BlockExpr(fn_body), - struct_name, - ); + replace_body_return_values(&editor, ast::Expr::BlockExpr(fn_body), struct_name); } } else { // replace tuple patterns @@ -172,27 +149,15 @@ fn replace_usages( for tuple_pat in tuple_pats { editor.replace( tuple_pat.syntax(), - syntax_factory - .tuple_struct_pat( - syntax_factory.path_from_text(struct_name), - tuple_pat.fields(), - ) + make.tuple_struct_pat(make.path_from_text(struct_name), tuple_pat.fields()) .syntax(), ); } } if let Some((import_scope, path)) = import_data { - insert_use_with_editor( - &import_scope, - path, - &ctx.config.insert_use, - &mut editor, - &syntax_factory, - ); + insert_use_with_editor(&import_scope, path, &ctx.config.insert_use, &editor); } }); - - editor.add_mappings(syntax_factory.finish_with_mappings()); edit.add_file_edits(file_id.file_id(ctx.db()), editor); } } @@ -296,12 +261,8 @@ fn add_tuple_struct_def( } /// Replaces each returned tuple in `body` with the constructor of the tuple struct named `struct_name`. 
-fn replace_body_return_values( - syntax_editor: &mut SyntaxEditor, - syntax_factory: &SyntaxFactory, - body: ast::Expr, - struct_name: &str, -) { +fn replace_body_return_values(editor: &SyntaxEditor, body: ast::Expr, struct_name: &str) { + let make = editor.make(); let mut exprs_to_wrap = Vec::new(); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); @@ -316,11 +277,11 @@ fn replace_body_return_values( for ret_expr in exprs_to_wrap { if let ast::Expr::TupleExpr(tuple_expr) = &ret_expr { - let struct_constructor = syntax_factory.expr_call( - syntax_factory.expr_path(syntax_factory.ident_path(struct_name)), - syntax_factory.arg_list(tuple_expr.fields()), + let struct_constructor = make.expr_call( + make.expr_path(make.ident_path(struct_name)), + make.arg_list(tuple_expr.fields()), ); - syntax_editor.replace(ret_expr.syntax(), struct_constructor.syntax()); + editor.replace(ret_expr.syntax(), struct_constructor.syntax()); } } } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs index 4ce7a9d866a9f..afbcf024b9fd4 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_tuple_struct_to_named_struct.rs @@ -86,32 +86,32 @@ pub(crate) fn convert_tuple_struct_to_named_struct( "Convert to named struct", target, |edit| { - let names = generate_names(tuple_fields.fields()); + let editor = edit.make_editor(syntax); + let names = generate_names(tuple_fields.fields(), editor.make()); edit_field_references(ctx, edit, tuple_fields.fields(), &names); - let mut editor = edit.make_editor(syntax); edit_struct_references(ctx, edit, strukt_def, &names); - edit_struct_def(&mut editor, &strukt_or_variant, tuple_fields, names); + edit_struct_def(&editor, &strukt_or_variant, tuple_fields, names); 
edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) } fn edit_struct_def( - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, strukt: &Either, tuple_fields: ast::TupleFieldList, names: Vec, ) { + let make = editor.make(); let record_fields = tuple_fields.fields().zip(names).filter_map(|(f, name)| { - let (mut field_editor, field) = - SyntaxEditor::with_ast_node(&ast::make::record_field(f.visibility(), name, f.ty()?)); + let (field_editor, field) = + SyntaxEditor::with_ast_node(&make.record_field(f.visibility(), name, f.ty()?)); field_editor.insert_all( Position::first_child_of(field.syntax()), f.attrs().map(|attr| attr.syntax().clone().into()).collect(), ); ast::RecordField::cast(field_editor.finish().new_root().clone()) }); - let make = SyntaxFactory::without_mappings(); let record_fields = make.record_field_list(record_fields); let tuple_fields_before = Position::before(tuple_fields.syntax()); @@ -119,21 +119,21 @@ fn edit_struct_def( if let Some(w) = strukt.where_clause() { editor.delete(w.syntax()); let mut insert_element = Vec::new(); - insert_element.push(ast::make::tokens::single_newline().syntax_element()); + insert_element.push(make.whitespace("\n").syntax_element()); insert_element.push(w.syntax().syntax_element()); if w.syntax().last_token().is_none_or(|t| t.kind() != SyntaxKind::COMMA) { - insert_element.push(ast::make::token(T![,]).into()); + insert_element.push(make.token(T![,]).into()); } - insert_element.push(ast::make::tokens::single_newline().syntax_element()); + insert_element.push(make.whitespace("\n").syntax_element()); editor.insert_all(tuple_fields_before, insert_element); } else { - editor.insert(tuple_fields_before, ast::make::tokens::single_space()); + editor.insert(tuple_fields_before, make.whitespace(" ")); } if let Some(t) = strukt.semicolon_token() { editor.delete(t); } } else { - editor.insert(tuple_fields_before, ast::make::tokens::single_space()); + editor.insert(tuple_fields_before, make.whitespace(" ")); } 
editor.replace(tuple_fields.syntax(), record_fields.syntax()); @@ -153,10 +153,10 @@ fn edit_struct_references( for (file_id, refs) in usages { let source = ctx.sema.parse(file_id); - let mut editor = edit.make_editor(source.syntax()); + let editor = edit.make_editor(source.syntax()); for r in refs { - process_struct_name_reference(ctx, r, &mut editor, &source, &strukt_def, names); + process_struct_name_reference(ctx, r, &editor, &source, &strukt_def, names); } edit.add_file_edits(file_id.file_id(ctx.db()), editor); @@ -166,12 +166,12 @@ fn edit_struct_references( fn process_struct_name_reference( ctx: &AssistContext<'_>, r: FileReference, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, source: &ast::SourceFile, strukt_def: &Definition, names: &[ast::Name], ) -> Option<()> { - let make = SyntaxFactory::without_mappings(); + let make = editor.make(); let name_ref = r.name.as_name_ref()?; let path_segment = name_ref.syntax().parent().and_then(ast::PathSegment::cast)?; let full_path = path_segment.syntax().parent().and_then(ast::Path::cast)?.top_path(); @@ -189,7 +189,7 @@ fn process_struct_name_reference( let range = ctx.sema.original_range_opt(tuple_struct_pat.syntax())?.range; let new = make.record_pat_with_fields( full_path, - generate_record_pat_list(&tuple_struct_pat, names), + generate_record_pat_list(&tuple_struct_pat, names, make), ); editor.replace_all(cover_edit_range(source.syntax(), range), vec![new.syntax().clone().into()]); }, @@ -231,10 +231,11 @@ fn process_struct_name_reference( fn process_delimiter( ctx: &AssistContext<'_>, source: &ast::SourceFile, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, list: &impl AstNode, first_insert: Vec, ) { + let make = editor.make(); let Some(range) = ctx.sema.original_range_opt(list.syntax()) else { return }; let place = cover_edit_range(source.syntax(), range.range); @@ -247,7 +248,6 @@ fn process_delimiter( syntax::NodeOrToken::Token(t) => Some(t.clone()), }; - let make = 
SyntaxFactory::without_mappings(); if let Some(l_paren) = l_paren && l_paren.kind() == T!['('] { @@ -284,7 +284,7 @@ fn edit_field_references( let usages = def.usages(&ctx.sema).all(); for (file_id, refs) in usages { let source = ctx.sema.parse(file_id); - let mut editor = edit.make_editor(source.syntax()); + let editor = edit.make_editor(source.syntax()); for r in refs { if let Some(name_ref) = r.name.as_name_ref() && let Some(original) = ctx.sema.original_range_opt(name_ref.syntax()) @@ -300,8 +300,10 @@ fn edit_field_references( } } -fn generate_names(fields: impl Iterator) -> Vec { - let make = SyntaxFactory::without_mappings(); +fn generate_names( + fields: impl Iterator, + make: &SyntaxFactory, +) -> Vec { fields .enumerate() .map(|(i, _)| { @@ -314,6 +316,7 @@ fn generate_names(fields: impl Iterator) -> Vec ast::RecordPatFieldList { let pure_fields = pat.fields().filter(|p| !matches!(p, ast::Pat::RestPat(_))); let rest_len = names.len().saturating_sub(pure_fields.clone().count()); @@ -325,8 +328,8 @@ fn generate_record_pat_list( let fields = before_rest .chain(after_rest) - .map(|(pat, name)| ast::make::record_pat_field(ast::make::name_ref(&name.text()), pat)); - ast::make::record_pat_field_list(fields, rest_pat) + .map(|(pat, name)| make.record_pat_field(make.name_ref(&name.text()), pat)); + make.record_pat_field_list(fields, rest_pat) } #[cfg(test)] diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs index f8215d6723d3f..793e7465c11ae 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_while_to_loop.rs @@ -6,7 +6,6 @@ use syntax::{ ast::{ self, HasLoopBody, edit::{AstNodeEdit, IndentLevel}, - syntax_factory::SyntaxFactory, }, syntax_editor::{Element, Position}, }; @@ -52,8 +51,8 @@ pub(crate) fn convert_while_to_loop(acc: 
&mut Assists, ctx: &AssistContext<'_>) "Convert while to loop", target, |builder| { - let make = SyntaxFactory::without_mappings(); - let mut edit = builder.make_editor(while_expr.syntax()); + let editor = builder.make_editor(while_expr.syntax()); + let make = editor.make(); let while_indent_level = IndentLevel::from_node(while_expr.syntax()); let break_block = make @@ -63,7 +62,7 @@ pub(crate) fn convert_while_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) ) .indent(IndentLevel(1)); - edit.replace_all( + editor.replace_all( while_kw.syntax_element()..=while_cond.syntax().syntax_element(), vec![make.token(T![loop]).syntax_element()], ); @@ -73,17 +72,17 @@ pub(crate) fn convert_while_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) let if_expr = make.expr_if(while_cond, then_branch, Some(break_block.into())); let stmts = iter::once(make.expr_stmt(if_expr.into()).into()); let block_expr = make.block_expr(stmts, None); - edit.replace(while_body.syntax(), block_expr.indent(while_indent_level).syntax()); + editor.replace(while_body.syntax(), block_expr.indent(while_indent_level).syntax()); } else { - let if_cond = invert_boolean_expression(&make, while_cond); + let if_cond = invert_boolean_expression(make, while_cond); let if_expr = make.expr_if(if_cond, break_block, None).indent(while_indent_level); if !while_body.syntax().text().contains_char('\n') { - edit.insert( + editor.insert( Position::after(&l_curly), make.whitespace(&format!("\n{while_indent_level}")), ); } - edit.insert_all( + editor.insert_all( Position::after(&l_curly), vec![ make.whitespace(&format!("\n{}", while_indent_level + 1)).into(), @@ -91,9 +90,7 @@ pub(crate) fn convert_while_to_loop(acc: &mut Assists, ctx: &AssistContext<'_>) ], ); }; - - edit.add_mappings(make.finish_with_mappings()); - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } diff --git 
a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_struct_binding.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_struct_binding.rs index ec4a83b642c01..9ffce445d1a1e 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_struct_binding.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_struct_binding.rs @@ -124,9 +124,9 @@ fn destructure_struct_binding_impl( data: &StructEditData, ) { let field_names = generate_field_names(ctx, data); - let mut editor = builder.make_editor(data.target.syntax()); - destructure_pat(ctx, &mut editor, data, &field_names); - update_usages(ctx, &mut editor, data, &field_names.into_iter().collect()); + let editor = builder.make_editor(data.target.syntax()); + destructure_pat(ctx, &editor, data, &field_names); + update_usages(ctx, &editor, data, &field_names.into_iter().collect()); builder.add_file_edits(ctx.vfs_file_id(), editor); } @@ -145,12 +145,8 @@ struct StructEditData { } impl StructEditData { - fn apply_to_destruct( - &self, - new_pat: ast::Pat, - editor: &mut SyntaxEditor, - make: &SyntaxFactory, - ) { + fn apply_to_destruct(&self, new_pat: ast::Pat, editor: &SyntaxEditor) { + let make = editor.make(); match &self.target { Target::IdentPat(pat) => { // If the binding is nested inside a record, we need to wrap the new @@ -275,15 +271,15 @@ fn get_names_in_scope( fn destructure_pat( _ctx: &AssistContext<'_>, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, data: &StructEditData, field_names: &[(SmolStr, SmolStr)], ) { + let make = editor.make(); let struct_path = mod_path_to_ast(&data.struct_def_path, data.edition); let is_ref = data.target.is_ref(); let is_mut = data.target.is_mut(); - let make = SyntaxFactory::with_mappings(); let new_pat = match data.kind { hir::StructKind::Tuple => { let ident_pats = field_names.iter().map(|(_, new_name)| { @@ -314,8 +310,7 @@ fn destructure_pat( hir::StructKind::Unit => 
make.path_pat(struct_path), }; - data.apply_to_destruct(new_pat, editor, &make); - editor.add_mappings(make.finish_with_mappings()); + data.apply_to_destruct(new_pat, editor); } fn generate_field_names(ctx: &AssistContext<'_>, data: &StructEditData) -> Vec<(SmolStr, SmolStr)> { @@ -354,18 +349,16 @@ fn new_field_name(base_name: SmolStr, names_in_scope: &FxHashSet) -> Sm fn update_usages( ctx: &AssistContext<'_>, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, data: &StructEditData, field_names: &FxHashMap, ) { let source = ctx.source_file().syntax(); - let make = SyntaxFactory::with_mappings(); let edits = data .usages .iter() - .filter_map(|r| build_usage_edit(ctx, &make, data, r, field_names)) + .filter_map(|r| build_usage_edit(ctx, editor.make(), data, r, field_names)) .collect_vec(); - editor.add_mappings(make.finish_with_mappings()); for (old, new) in edits { if let Some(range) = ctx.sema.original_range_opt(&old) { editor.replace_all(cover_edit_range(source, range.range), vec![new.into()]); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs index 23c11b258c1a7..291605056b3c6 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/destructure_tuple_binding.rs @@ -89,22 +89,17 @@ fn destructure_tuple_edit_impl( data: &TupleData, in_sub_pattern: bool, ) { - let mut syntax_editor = edit.make_editor(data.ident_pat.syntax()); - let syntax_factory = SyntaxFactory::with_mappings(); + let editor = edit.make_editor(data.ident_pat.syntax()); + let make = editor.make(); - let assignment_edit = - edit_tuple_assignment(ctx, edit, &mut syntax_editor, &syntax_factory, data, in_sub_pattern); - let current_file_usages_edit = edit_tuple_usages(data, ctx, &syntax_factory, in_sub_pattern); + let assignment_edit = edit_tuple_assignment(ctx, 
edit, &editor, data, in_sub_pattern); + let current_file_usages_edit = edit_tuple_usages(data, ctx, make, in_sub_pattern); - assignment_edit.apply(&mut syntax_editor, &syntax_factory); + assignment_edit.apply(&editor); if let Some(usages_edit) = current_file_usages_edit { - usages_edit - .into_iter() - .for_each(|usage_edit| usage_edit.apply(ctx, edit, &mut syntax_editor)) + usages_edit.into_iter().for_each(|usage_edit| usage_edit.apply(ctx, edit, &editor)) } - - syntax_editor.add_mappings(syntax_factory.finish_with_mappings()); - edit.add_file_edits(ctx.vfs_file_id(), syntax_editor); + edit.add_file_edits(ctx.vfs_file_id(), editor); } fn collect_data(ident_pat: IdentPat, ctx: &AssistContext<'_>) -> Option { @@ -175,11 +170,11 @@ struct TupleData { fn edit_tuple_assignment( ctx: &AssistContext<'_>, edit: &mut SourceChangeBuilder, - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, data: &TupleData, in_sub_pattern: bool, ) -> AssignmentEdit { + let make = editor.make(); let tuple_pat = { let original = &data.ident_pat; let is_ref = original.ref_token().is_some(); @@ -223,22 +218,17 @@ struct AssignmentEdit { } impl AssignmentEdit { - fn apply(self, syntax_editor: &mut SyntaxEditor, syntax_mapping: &SyntaxFactory) { + fn apply(self, editor: &SyntaxEditor) { + let make = editor.make(); // with sub_pattern: keep original tuple and add subpattern: `tup @ (_0, _1)` if self.in_sub_pattern { - self.ident_pat.set_pat_with_editor( - Some(self.tuple_pat.into()), - syntax_editor, - syntax_mapping, - ) + self.ident_pat.set_pat(Some(self.tuple_pat.into()), editor); } else if self.is_shorthand_field { - syntax_editor.insert(Position::after(self.ident_pat.syntax()), self.tuple_pat.syntax()); - syntax_editor - .insert(Position::after(self.ident_pat.syntax()), syntax_mapping.whitespace(" ")); - syntax_editor - .insert(Position::after(self.ident_pat.syntax()), syntax_mapping.token(T![:])); + editor.insert(Position::after(self.ident_pat.syntax()), 
self.tuple_pat.syntax()); + editor.insert(Position::after(self.ident_pat.syntax()), make.whitespace(" ")); + editor.insert(Position::after(self.ident_pat.syntax()), make.token(T![:])); } else { - syntax_editor.replace(self.ident_pat.syntax(), self.tuple_pat.syntax()) + editor.replace(self.ident_pat.syntax(), self.tuple_pat.syntax()) } } } @@ -317,7 +307,7 @@ impl EditTupleUsage { self, ctx: &AssistContext<'_>, edit: &mut SourceChangeBuilder, - syntax_editor: &mut SyntaxEditor, + syntax_editor: &SyntaxEditor, ) { match self { EditTupleUsage::NoIndex(range) => { @@ -907,6 +897,7 @@ fn main() { check_assist( assist, r#" +//- minicore: fn fn main() { let f = |$0t| t.0 + t.1; let v = f((1,2)); @@ -1111,6 +1102,7 @@ fn main() { check_assist( assist, r#" +//- minicore: fn fn main() { let $0t = (1,2); let v = t.1; diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/desugar_try_expr.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/desugar_try_expr.rs index 865dc862215f1..fc894f0fe9a0f 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/desugar_try_expr.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/desugar_try_expr.rs @@ -65,9 +65,8 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op "Replace try expression with match", target, |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(try_expr.syntax()); - + let editor = builder.make_editor(try_expr.syntax()); + let make = editor.make(); let sad_pat = match try_enum { TryEnum::Option => make.path_pat(make.ident_path("None")), TryEnum::Result => make @@ -77,7 +76,7 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op ) .into(), }; - let sad_expr = make.expr_return(Some(sad_expr(try_enum, &make, || { + let sad_expr = make.expr_return(Some(sad_expr(try_enum, make, || { make.expr_path(make.ident_path("err")) }))); @@ -95,7 +94,6 @@ pub(crate) fn 
desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op .indent(IndentLevel::from_node(try_expr.syntax())); editor.replace(try_expr.syntax(), expr_match.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); @@ -109,8 +107,8 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op "Replace try expression with let else", target, |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(let_stmt.syntax()); + let editor = builder.make_editor(let_stmt.syntax()); + let make = editor.make(); let indent_level = IndentLevel::from_node(let_stmt.syntax()); let fill_expr = || crate::utils::expr_fill_default(ctx.config); @@ -124,7 +122,7 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op make.block_expr( iter::once( make.expr_stmt( - make.expr_return(Some(sad_expr(try_enum, &make, fill_expr))).into(), + make.expr_return(Some(sad_expr(try_enum, make, fill_expr))).into(), ) .into(), ), @@ -133,7 +131,6 @@ pub(crate) fn desugar_try_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op .indent(indent_level), ); editor.replace(let_stmt.syntax(), new_let_stmt.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs index 6c5c21bfc90f3..1c8cbf5af5941 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_glob_import.rs @@ -8,7 +8,7 @@ use ide_db::{ use stdx::never; use syntax::{ AstNode, Direction, SyntaxNode, SyntaxToken, T, - ast::{self, Use, UseTree, VisibilityKind, make}, + ast::{self, Use, UseTree, VisibilityKind}, }; use crate::{ @@ -148,6 +148,8 @@ fn 
build_expanded_import( current_module: Module, reexport_public_items: bool, ) { + let editor = builder.make_editor(use_tree.syntax()); + let make = editor.make(); let (must_be_pub, visible_from) = if !reexport_public_items { (false, current_module) } else { @@ -167,15 +169,13 @@ fn build_expanded_import( if reexport_public_items { refs_in_target } else { refs_in_target.used_refs(ctx) }; let names_to_import = find_names_to_import(filtered_defs, imported_defs); - let expanded = make::use_tree_list(names_to_import.iter().map(|n| { - let path = make::ext::ident_path( + let expanded = make.use_tree_list(names_to_import.iter().map(|n| { + let path = make.ident_path( &n.display(ctx.db(), current_module.krate(ctx.db()).edition(ctx.db())).to_string(), ); - make::use_tree(path, None, None, false) - })) - .clone_for_update(); + make.use_tree(path, None, None, false) + })); - let mut editor = builder.make_editor(use_tree.syntax()); match use_tree.star_token() { Some(star) => { let needs_braces = use_tree.path().is_some() && names_to_import.len() != 1; diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_rest_pattern.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_rest_pattern.rs index a7e78dfc9c940..dc4976e8c29d7 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_rest_pattern.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/expand_rest_pattern.rs @@ -51,8 +51,8 @@ fn expand_record_rest_pattern( "Fill struct fields", rest_pat.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(rest_pat.syntax()); + let editor = builder.make_editor(rest_pat.syntax()); + let make = editor.make(); let new_fields = old_field_list.fields().chain(matched_fields.iter().map(|(f, _)| { make.record_pat_field_shorthand( make.ident_pat( @@ -66,8 +66,6 @@ fn expand_record_rest_pattern( let new_field_list = make.record_pat_field_list(new_fields, None); 
editor.replace(old_field_list.syntax(), new_field_list.syntax()); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -130,8 +128,8 @@ fn expand_tuple_struct_rest_pattern( "Fill tuple struct fields", rest_pat.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(rest_pat.syntax()); + let editor = builder.make_editor(rest_pat.syntax()); + let make = editor.make(); let mut name_gen = NameGenerator::new_from_scope_locals(ctx.sema.scope(pat.syntax())); let new_pat = make.tuple_struct_pat( @@ -141,7 +139,7 @@ fn expand_tuple_struct_rest_pattern( .chain(fields[prefix_count..fields.len() - suffix_count].iter().map(|f| { gen_unnamed_pat( ctx, - &make, + make, &mut name_gen, &f.ty(ctx.db()).to_type(ctx.sema.db), f.index(), @@ -151,8 +149,6 @@ fn expand_tuple_struct_rest_pattern( ); editor.replace(pat.syntax(), new_pat.syntax()); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -200,24 +196,21 @@ fn expand_tuple_rest_pattern( "Fill tuple fields", rest_pat.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(rest_pat.syntax()); - + let editor = builder.make_editor(rest_pat.syntax()); + let make = editor.make(); let mut name_gen = NameGenerator::new_from_scope_locals(ctx.sema.scope(pat.syntax())); let new_pat = make.tuple_pat( pat.fields() .take(prefix_count) .chain(fields[prefix_count..len - suffix_count].iter().enumerate().map( |(index, ty)| { - gen_unnamed_pat(ctx, &make, &mut name_gen, ty, prefix_count + index) + gen_unnamed_pat(ctx, make, &mut name_gen, ty, prefix_count + index) }, )) .chain(pat.fields().skip(prefix_count + 1)), ); editor.replace(pat.syntax(), new_pat.syntax()); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -264,8 +257,8 @@ fn 
expand_slice_rest_pattern( "Fill slice fields", rest_pat.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(rest_pat.syntax()); + let editor = builder.make_editor(rest_pat.syntax()); + let make = editor.make(); let mut name_gen = NameGenerator::new_from_scope_locals(ctx.sema.scope(pat.syntax())); let new_pat = make.slice_pat( @@ -273,14 +266,12 @@ fn expand_slice_rest_pattern( .take(prefix_count) .chain( (prefix_count..len - suffix_count) - .map(|index| gen_unnamed_pat(ctx, &make, &mut name_gen, &ty, index)), + .map(|index| gen_unnamed_pat(ctx, make, &mut name_gen, &ty, index)), ) .chain(pat.pats().skip(prefix_count + 1)), ); editor.replace(pat.syntax(), new_pat.syntax()); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs index 35e8baa18aca7..c87ded9dc47b7 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_expressions_from_format_string.rs @@ -8,7 +8,7 @@ use syntax::{ AstNode, AstToken, NodeOrToken, SyntaxKind::WHITESPACE, SyntaxToken, T, - ast::{self, TokenTree, syntax_factory::SyntaxFactory}, + ast::{self, TokenTree}, }; // Assist: extract_expressions_from_format_string @@ -57,7 +57,8 @@ pub(crate) fn extract_expressions_from_format_string( "Extract format expressions", tt.syntax().text_range(), |edit| { - let make = SyntaxFactory::without_mappings(); + let editor = edit.make_editor(tt.syntax()); + let make = editor.make(); // Extract existing arguments in macro let mut raw_tokens = tt.token_trees_and_tokens().skip(1).collect_vec(); let format_string_index = format_str_index(&raw_tokens, &fmt_string); 
@@ -110,7 +111,7 @@ pub(crate) fn extract_expressions_from_format_string( Arg::Expr(s) => { // insert arg let expr = ast::Expr::parse(&s, ctx.edition()).syntax_node(); - let mut expr_tt = utils::tt_from_syntax(expr, &make); + let mut expr_tt = utils::tt_from_syntax(expr, make); new_tt_bits.append(&mut expr_tt); } Arg::Placeholder => { @@ -131,7 +132,6 @@ pub(crate) fn extract_expressions_from_format_string( // Insert new args let new_tt = make.token_tree(tt_delimiter, new_tt_bits); - let mut editor = edit.make_editor(tt.syntax()); editor.replace(tt.syntax(), new_tt.syntax()); if let Some(cap) = ctx.config.snippet_cap { @@ -158,7 +158,6 @@ pub(crate) fn extract_expressions_from_format_string( editor.add_annotation(literal, annotation); } } - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs index fa5bb39c54ba8..4219e6845fad5 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_function.rs @@ -1362,26 +1362,26 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option break, - SyntaxKind::IMPL => { - if body.extracted_from_trait_impl() && matches!(anchor, Anchor::Method) { - let impl_node = find_non_trait_impl(&next_ancestor); - if let target_node @ Some(_) = impl_node.as_ref().and_then(last_impl_member) { - return target_node; - } + SyntaxKind::IMPL + if body.extracted_from_trait_impl() && matches!(anchor, Anchor::Method) => + { + let impl_node = find_non_trait_impl(&next_ancestor); + if let target_node @ Some(_) = impl_node.as_ref().and_then(last_impl_member) { + return target_node; } } SyntaxKind::ITEM_LIST if !matches!(anchor, Anchor::Freestanding) => continue, - SyntaxKind::ITEM_LIST => { - if 
ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) { - break; - } + SyntaxKind::ITEM_LIST + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::MODULE) => + { + break; } SyntaxKind::ASSOC_ITEM_LIST if !matches!(anchor, Anchor::Method) => continue, SyntaxKind::ASSOC_ITEM_LIST if body.extracted_from_trait_impl() => continue, - SyntaxKind::ASSOC_ITEM_LIST => { - if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::IMPL) { - break; - } + SyntaxKind::ASSOC_ITEM_LIST + if ancestors.peek().map(SyntaxNode::kind) == Some(SyntaxKind::IMPL) => + { + break; } _ => (), } @@ -2088,7 +2088,11 @@ fn fix_param_usages( for (param, usages) in usages_for_param { for usage in usages { match usage.syntax().ancestors().skip(1).find_map(ast::Expr::cast) { - Some(ast::Expr::MethodCallExpr(_) | ast::Expr::FieldExpr(_)) => { + Some( + ast::Expr::MethodCallExpr(_) + | ast::Expr::FieldExpr(_) + | ast::Expr::IndexExpr(_), + ) => { // do nothing } Some(ast::Expr::RefExpr(node)) @@ -2124,19 +2128,19 @@ fn update_external_control_flow(handler: &FlowHandler<'_>, syntax: &SyntaxNode) for event in syntax.preorder() { match event { WalkEvent::Enter(e) => match e.kind() { - SyntaxKind::LOOP_EXPR | SyntaxKind::WHILE_EXPR | SyntaxKind::FOR_EXPR => { - if nested_loop.is_none() { - nested_loop = Some(e.clone()); - } + SyntaxKind::LOOP_EXPR | SyntaxKind::WHILE_EXPR | SyntaxKind::FOR_EXPR + if nested_loop.is_none() => + { + nested_loop = Some(e.clone()); } SyntaxKind::FN | SyntaxKind::CONST | SyntaxKind::STATIC | SyntaxKind::IMPL - | SyntaxKind::MODULE => { - if nested_scope.is_none() { - nested_scope = Some(e.clone()); - } + | SyntaxKind::MODULE + if nested_scope.is_none() => + { + nested_scope = Some(e.clone()); } _ => {} }, @@ -3211,6 +3215,32 @@ fn $0fun_name(n: &mut i32) { ); } + #[test] + fn mut_index_from_outer_scope() { + check_assist( + extract_function, + r#" +//- minicore: index +fn foo() { + let mut arr = [1i32]; + $0arr[0] = 3;$0 + let _ = arr; +} +"#, + r#" 
+fn foo() { + let mut arr = [1i32]; + fun_name(&mut arr); + let _ = arr; +} + +fn $0fun_name(arr: &mut [i32; 1]) { + arr[0] = 3; +} +"#, + ); + } + #[test] fn mut_field_from_outer_scope() { check_assist( diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs index 3bbf9a0ad3a25..21013e2e614c7 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_struct_from_enum_variant.rs @@ -60,8 +60,8 @@ pub(crate) fn extract_struct_from_enum_variant( "Extract struct from enum variant", target, |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(variant.syntax()); + let editor = builder.make_editor(variant.syntax()); + let make = editor.make(); let edition = enum_hir.krate(ctx.db()).edition(ctx.db()); let variant_hir_name = variant_hir.name(ctx.db()); let enum_module_def = ModuleDef::from(enum_hir); @@ -87,7 +87,7 @@ pub(crate) fn extract_struct_from_enum_variant( if processed.is_empty() { continue; } - let mut file_editor = builder.make_editor(processed[0].0.syntax()); + let file_editor = builder.make_editor(processed[0].0.syntax()); processed.into_iter().for_each(|(path, node, import)| { apply_references( ctx.config.insert_use, @@ -95,11 +95,9 @@ pub(crate) fn extract_struct_from_enum_variant( node, import, edition, - &mut file_editor, - &make, + &file_editor, ) }); - file_editor.add_mappings(make.take()); builder.add_file_edits(file_id.file_id(ctx.db()), file_editor); } @@ -112,20 +110,12 @@ pub(crate) fn extract_struct_from_enum_variant( references, ); processed.into_iter().for_each(|(path, node, import)| { - apply_references( - ctx.config.insert_use, - path, - node, - import, - edition, - &mut editor, - &make, - ) + apply_references(ctx.config.insert_use, path, node, 
import, edition, &editor) }); } let generic_params = enum_ast.generic_param_list().and_then(|known_generics| { - extract_generic_params(&make, &known_generics, &field_list) + extract_generic_params(make, &known_generics, &field_list) }); // resolve GenericArg in field_list to actual type @@ -148,13 +138,13 @@ pub(crate) fn extract_struct_from_enum_variant( }; let (comments_for_struct, comments_to_delete) = - collect_variant_comments(&make, variant.syntax()); + collect_variant_comments(make, variant.syntax()); for element in &comments_to_delete { editor.delete(element.clone()); } let def = create_struct_def( - &make, + make, variant_name.clone(), &field_list, generic_params.clone(), @@ -173,15 +163,10 @@ pub(crate) fn extract_struct_from_enum_variant( insert_items.extend(comments_for_struct); insert_items.push(def.syntax().clone().into()); insert_items.push(make.whitespace(&format!("\n\n{indent}")).into()); - editor.insert_all_with_whitespace( - Position::before(enum_ast.syntax()), - insert_items, - &make, - ); + editor.insert_all_with_whitespace(Position::before(enum_ast.syntax()), insert_items); - update_variant(&make, &mut editor, &variant, generic_params); + update_variant(&editor, &variant, generic_params); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -230,6 +215,7 @@ fn extract_generic_params( ) -> Option { let mut generics = known_generics.generic_params().map(|param| (param, false)).collect_vec(); + #[expect(clippy::unnecessary_fold, reason = "this function has side effects")] let tagged_one = match field_list { Either::Left(field_list) => field_list .fields() @@ -263,6 +249,10 @@ fn tag_generics_in_variant(ty: &ast::Type, generics: &mut [(ast::GenericParam, b } } param if matches!(token.kind(), T![ident]) => { + #[expect( + clippy::collapsible_match, + reason = "it won't compile since in the guard, `param` is immutable" + )] if match param { ast::GenericParam::ConstParam(konst) => konst 
.name() @@ -340,11 +330,11 @@ fn create_struct_def( } fn update_variant( - make: &SyntaxFactory, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, variant: &ast::Variant, generics: Option, ) -> Option<()> { + let make = editor.make(); let name = variant.name()?; let generic_args = generics .filter(|generics| generics.generic_params().count() > 0) @@ -407,17 +397,11 @@ fn apply_references( node: SyntaxNode, import: Option<(ImportScope, hir::ModPath)>, edition: Edition, - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, ) { + let make = editor.make(); if let Some((scope, path)) = import { - insert_use_with_editor( - &scope, - mod_path_to_ast(&path, edition), - &insert_use_cfg, - editor, - make, - ); + insert_use_with_editor(&scope, mod_path_to_ast(&path, edition), &insert_use_cfg, editor); } // deep clone to prevent cycle let path = make.path_from_segments(iter::once(segment.clone()), false); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs index e4fdac27f47ff..eda35eba45c9a 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_type_alias.rs @@ -1,11 +1,8 @@ use either::Either; use hir::HirDisplay; -use ide_db::syntax_helpers::node_ext::walk_ty; +use ide_db::syntax_helpers::{node_ext::walk_ty, suggest_name::NameGenerator}; use syntax::{ - ast::{ - self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, - syntax_factory::SyntaxFactory, - }, + ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel}, syntax_editor, }; @@ -43,9 +40,10 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> ); let target = ty.syntax().text_range(); + let scope = ctx.sema.scope(ty.syntax())?; let resolved_ty = ctx.sema.resolve_type(&ty)?; let resolved_ty = if 
!resolved_ty.contains_unknown() { - let module = ctx.sema.scope(ty.syntax())?.module(); + let module = scope.module(); resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()? } else { ty.to_string() @@ -56,10 +54,11 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> "Extract type as type alias", target, |builder| { - let mut edit = builder.make_editor(node); - let make = SyntaxFactory::without_mappings(); + let editor = builder.make_editor(node); + let make = editor.make(); let resolved_ty = make.ty(&resolved_ty); + let name = &NameGenerator::new_from_scope_non_locals(Some(scope)).suggest_name("Type"); let mut known_generics = match item.generic_param_list() { Some(it) => it.generic_params().collect(), @@ -78,24 +77,24 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> // Replace original type with the alias let ty_args = generic_params.as_ref().map(|it| it.to_generic_args().generic_args()); let new_ty = if let Some(ty_args) = ty_args { - make.generic_ty_path_segment(make.name_ref("Type"), ty_args) + make.generic_ty_path_segment(make.name_ref(name), ty_args) } else { - make.path_segment(make.name_ref("Type")) + make.path_segment(make.name_ref(name)) }; - edit.replace(ty.syntax(), new_ty.syntax()); + editor.replace(ty.syntax(), new_ty.syntax()); // Insert new alias let ty_alias = - make.ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None))); + make.ty_alias(None, name, generic_params, None, None, Some((resolved_ty, None))); if let Some(cap) = ctx.config.snippet_cap && let Some(name) = ty_alias.name() { - edit.add_annotation(name.syntax(), builder.make_tabstop_before(cap)); + editor.add_annotation(name.syntax(), builder.make_tabstop_before(cap)); } let indent = IndentLevel::from_node(node); - edit.insert_all( + editor.insert_all( syntax_editor::Position::before(node), vec![ ty_alias.syntax().clone().into(), @@ -103,7 +102,7 @@ pub(crate) fn extract_type_alias(acc: 
&mut Assists, ctx: &AssistContext<'_>) -> ], ); - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } @@ -450,4 +449,41 @@ fn main() { "#, ) } + + #[test] + fn duplicate_names() { + check_assist( + extract_type_alias, + r" +struct Type; +struct S { + field: $0u8$0, +} + ", + r#" +struct Type; +type $0Type1 = u8; + +struct S { + field: Type1, +} + "#, + ); + + check_assist( + extract_type_alias, + r" +struct S { + field: $0u8$0, +} + ", + r#" +type $0Type1 = u8; + +struct S { + field: Type1, +} + "#, + ); + } } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs index 1556339d8df43..c5c57c76b47ca 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/extract_variable.rs @@ -4,14 +4,13 @@ use ide_db::{ syntax_helpers::{LexedStr, suggest_name}, }; use syntax::{ - NodeOrToken, SyntaxKind, SyntaxNode, T, - algo::ancestors_at_offset, + Direction, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken, T, TextRange, + algo::{ancestors_at_offset, skip_trivia_token}, ast::{ self, AstNode, edit::{AstNodeEdit, IndentLevel}, - syntax_factory::SyntaxFactory, }, - syntax_editor::Position, + syntax_editor::{Element, Position}, }; use crate::{AssistContext, AssistId, Assists, utils::is_body_const}; @@ -92,27 +91,54 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?; let range = node.text_range(); - let to_extract = node - .descendants() - .take_while(|it| range.contains_range(it.text_range())) - .find_map(valid_target_expr(ctx))?; + let (to_replace, analysis) = if node.kind() == SyntaxKind::TOKEN_TREE { + let (first, last) = extract_token_range_of(&node, ctx.selection_trimmed())?; - let ty = 
ctx.sema.type_of_expr(&to_extract).map(TypeInfo::adjusted); + let first_descend = ctx.sema.descend_into_macros_single_exact(first.clone()); + let last_descend = ctx.sema.descend_into_macros_single_exact(last.clone()); + let range = first_descend.text_range().cover(last_descend.text_range()); + + if first_descend.parent_ancestors().last() != last_descend.parent_ancestors().last() { + return None; + } + + let expr = first_descend + .parent_ancestors() + .skip_while(|it| !it.text_range().contains_range(range)) + .find_map(valid_target_expr(ctx))?; + let original_range = ctx.sema.original_range(expr.syntax()); + let (first, last) = extract_token_range_of(&node, original_range.range)?; + let to_extract = first.syntax_element()..=last.syntax_element(); + (to_extract, expr) + } else { + let expr = node + .descendants() + .take_while(|it| range.contains_range(it.text_range())) + .find_map(valid_target_expr(ctx))?; + let to_extract = expr.syntax().syntax_element(); + (to_extract.clone()..=to_extract, expr) + }; + let place = match to_replace.start() { + NodeOrToken::Node(node) => node.clone(), + NodeOrToken::Token(t) => t.parent()?, + }; + + let ty = ctx.sema.type_of_expr(&analysis).map(TypeInfo::adjusted); if matches!(&ty, Some(ty_info) if ty_info.is_unit()) { return None; } - let parent = to_extract.syntax().parent().and_then(ast::Expr::cast); + let parent = analysis.syntax().parent().and_then(ast::Expr::cast); // Any expression that autoderefs may need adjustment. 
let mut needs_adjust = parent.as_ref().is_some_and(|it| match it { ast::Expr::FieldExpr(_) | ast::Expr::MethodCallExpr(_) | ast::Expr::CallExpr(_) | ast::Expr::AwaitExpr(_) => true, - ast::Expr::IndexExpr(index) if index.base().as_ref() == Some(&to_extract) => true, + ast::Expr::IndexExpr(index) if index.base().as_ref() == Some(&analysis) => true, _ => false, }); - let mut to_extract_no_ref = peel_parens(to_extract.clone()); + let mut to_extract_no_ref = peel_parens(analysis.clone()); let needs_ref = needs_adjust && match &to_extract_no_ref { ast::Expr::FieldExpr(_) @@ -127,14 +153,14 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op } _ => false, }; - let module = ctx.sema.scope(to_extract.syntax())?.module(); - let target = to_extract.syntax().text_range(); + let module = ctx.sema.scope(analysis.syntax())?.module(); + let target = to_replace.start().text_range().cover(to_replace.end().text_range()); let needs_mut = match &parent { Some(ast::Expr::RefExpr(expr)) => expr.mut_token().is_some(), _ => needs_adjust && !needs_ref && ty.as_ref().is_some_and(|ty| ty.is_mutable_reference()), }; for kind in ExtractionKind::ALL { - let Some(anchor) = Anchor::from(&to_extract, kind) else { + let Some(anchor) = Anchor::from(&place, kind) else { continue; }; @@ -169,10 +195,18 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op kind.label(), target, |edit| { - let (var_name, expr_replace) = kind.get_name_and_expr(ctx, &to_extract); + let (var_name, expr_replace) = kind.get_name_and_expr(ctx, &analysis); + + let to_replace = + if expr_replace.ancestors().last() == to_replace.start().ancestors().last() { + let element = expr_replace.clone().syntax_element(); + element.clone()..=element + } else { + to_replace.clone() + }; - let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(&expr_replace); + let editor = edit.make_editor(&place); + let make = editor.make(); let pat_name = 
make.name(&var_name); + let name_expr = make.expr_path(make.ident_path(&var_name)); @@ -236,7 +270,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op ], ); - editor.replace(expr_replace, name_expr.syntax()); + editor.replace_all(to_replace, vec![name_expr.syntax().syntax_element()]); } Anchor::Replace(stmt) => { cov_mark::hit!(test_extract_var_expr_stmt); @@ -252,17 +286,16 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op make.block_expr([new_stmt], Some(name_expr)) } else { // `expr_replace` is a descendant of `to_wrap`, so we just replace it with `name_expr`. - editor.replace(expr_replace, name_expr.syntax()); + editor + .replace_all(to_replace, vec![name_expr.syntax().syntax_element()]); make.block_expr([new_stmt], Some(to_wrap.clone())) } // fixup indentation of block - .indent_with_mapping(indent_to, &make); + .indent_with_mapping(indent_to, make); editor.replace(to_wrap.syntax(), block.syntax()); } } - - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); edit.rename(); }, @@ -272,6 +305,23 @@ Some(()) } +fn extract_token_range_of( + node: &SyntaxNode, + range: TextRange, +) -> Option<(SyntaxToken, SyntaxToken)> { + let first = node.token_at_offset(range.start()).right_biased()?; + let last = node.token_at_offset(range.end()).left_biased()?; + + let first = skip_trivia_token(first, Direction::Next)?; + let last = skip_trivia_token(last, Direction::Prev)?; + + if first.text_range().ordering(last.text_range()).is_gt() { + return None; + } + + Some((first, last)) +} + fn peel_parens(mut expr: ast::Expr) -> ast::Expr { while let ast::Expr::ParenExpr(parens) = &expr { let Some(expr_inside) = parens.expr() else { break }; @@ -401,9 +451,8 @@ enum Anchor { } impl Anchor { - fn from(to_extract: &ast::Expr, kind: &ExtractionKind) -> Option { - let result = to_extract -
.syntax() + fn from(place: &SyntaxNode, kind: &ExtractionKind) -> Option { + let result = place .ancestors() .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind())) .find_map(|node| { @@ -435,7 +484,7 @@ impl Anchor { if let Some(stmt) = ast::Stmt::cast(node.clone()) { if let ast::Stmt::ExprStmt(stmt) = stmt - && stmt.expr().as_ref() == Some(to_extract) + && stmt.expr().is_some_and(|it| it.syntax() == place) { return Some(Anchor::Replace(stmt)); } @@ -446,7 +495,7 @@ impl Anchor { match kind { ExtractionKind::Constant | ExtractionKind::Static if result.is_none() => { - to_extract.syntax().ancestors().find_map(|node| { + place.ancestors().find_map(|node| { let item = ast::Item::cast(node.clone())?; let parent = item.syntax().parent()?; match parent.kind() { @@ -2771,6 +2820,186 @@ fn main() { let t2 = t; let x = s; } +"#, + "Extract into variable", + ); + } + + #[test] + fn extract_variable_in_token_tree() { + // FIXME: Keep the original trivia instead of extracting macro expanded? + check_assist_by_label( + extract_variable, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let x = foo!(= $02 + 3$0 + 4); +} +"#, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let $0var_name = 2+3; + let x = foo!(= var_name + 4); +} +"#, + "Extract into variable", + ); + + check_assist_by_label( + extract_variable, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let x = foo!(= $02 +$0 3 + 4); +} +"#, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let $0var_name = 2+3; + let x = foo!(= var_name + 4); +} +"#, + "Extract into variable", + ); + + check_assist_by_label( + extract_variable, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let x = foo!(= $02 + 3 + 4$0); +} +"#, + r#" +macro_rules! 
foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let $0var_name = 2+3+4; + let x = foo!(= var_name); +} +"#, + "Extract into variable", + ); + + // FIXME: Extract to inside the macro instead of outside the macro + check_assist_by_label( + extract_variable, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let x = foo!(= { + $02 + 3 + 4$0 + }); +} +"#, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let $0var_name = 2+3+4; + let x = foo!(= { + var_name + }); +} +"#, + "Extract into variable", + ); + } + + #[test] + fn extract_variable_in_token_tree_record_expr() { + check_assist_by_label( + extract_variable, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let x = foo!(= Foo { x: $02 + 3$0 }); +} +"#, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let $0x = 2+3; + let x = foo!(= Foo { x: x }); +} +"#, + "Extract into variable", + ); + + check_assist_by_label( + extract_variable, + r#" +macro_rules! foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let x = foo!(= Foo { x: $02 + 3$0 + 4 }); +} +"#, + r#" +macro_rules! 
foo { + (= $($t:tt)*) => { + $($t)* + }; +} + +fn main() { + let $0var_name = 2+3; + let x = foo!(= Foo { x: var_name + 4 }); +} "#, "Extract into variable", ); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs index 440f2d5f17ca4..d8714dd49c2db 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/fix_visibility.rs @@ -78,7 +78,7 @@ fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext<'_>) }; acc.add(AssistId::quick_fix("fix_visibility"), assist_label, target, |builder| { - let mut editor = builder.make_editor(vis_owner.syntax()); + let editor = builder.make_editor(vis_owner.syntax()); if let Some(current_visibility) = vis_owner.visibility() { editor.replace(current_visibility.syntax(), missing_visibility.syntax()); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs index 922a61bf3a854..17911150f5e7b 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_binexpr.rs @@ -1,6 +1,6 @@ use syntax::{ SyntaxKind, T, - ast::{self, AstNode, BinExpr, RangeItem, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, BinExpr, RangeItem}, syntax_editor::Position, }; @@ -48,14 +48,13 @@ pub(crate) fn flip_binexpr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option "Flip binary expression", op_token.text_range(), |builder| { - let mut editor = builder.make_editor(&expr.syntax().parent().unwrap()); - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(&expr.syntax().parent().unwrap()); + let make = editor.make(); if let FlipAction::FlipAndReplaceOp(binary_op) = action { editor.replace(op_token, make.token(binary_op)) }; 
editor.replace(lhs.syntax(), rhs.syntax()); editor.replace(rhs.syntax(), lhs.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -133,25 +132,25 @@ pub(crate) fn flip_range_expr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt "Flip range expression", op.text_range(), |builder| { - let mut edit = builder.make_editor(range_expr.syntax()); + let editor = builder.make_editor(range_expr.syntax()); match (start, end) { (Some(start), Some(end)) => { - edit.replace(start.syntax(), end.syntax()); - edit.replace(end.syntax(), start.syntax()); + editor.replace(start.syntax(), end.syntax()); + editor.replace(end.syntax(), start.syntax()); } (Some(start), None) => { - edit.delete(start.syntax()); - edit.insert(Position::after(&op), start.syntax()); + editor.delete(start.syntax()); + editor.insert(Position::after(&op), start.syntax()); } (None, Some(end)) => { - edit.delete(end.syntax()); - edit.insert(Position::before(&op), end.syntax()); + editor.delete(end.syntax()); + editor.insert(Position::before(&op), end.syntax()); } (None, None) => (), } - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs index 1e95d4772349e..65dc36cdca7ee 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_comma.rs @@ -2,7 +2,6 @@ use syntax::{ AstNode, Direction, NodeOrToken, SyntaxKind, SyntaxToken, T, algo::non_trivia_sibling, ast::{self, syntax_factory::SyntaxFactory}, - syntax_editor::SyntaxMapping, }; use crate::{AssistContext, AssistId, Assists}; @@ -42,14 +41,13 @@ pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( let target = comma.text_range(); 
acc.add(AssistId::refactor_rewrite("flip_comma"), "Flip comma", target, |builder| { let parent = comma.parent().unwrap(); - let mut editor = builder.make_editor(&parent); + let editor = builder.make_editor(&parent); if let Some(parent) = ast::TokenTree::cast(parent) { // An attribute. It often contains a path followed by a // token tree (e.g. `align(2)`), so we have to be smarter. - let (new_tree, mapping) = flip_tree(parent.clone(), comma); + let new_tree = flip_tree(parent.clone(), comma, editor.make()); editor.replace(parent.syntax(), new_tree.syntax()); - editor.add_mappings(mapping); } else { editor.replace(prev.clone(), next.clone()); editor.replace(next.clone(), prev.clone()); @@ -59,7 +57,7 @@ pub(crate) fn flip_comma(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( }) } -fn flip_tree(tree: ast::TokenTree, comma: SyntaxToken) -> (ast::TokenTree, SyntaxMapping) { +fn flip_tree(tree: ast::TokenTree, comma: SyntaxToken, make: &SyntaxFactory) -> ast::TokenTree { let mut tree_iter = tree.token_trees_and_tokens(); let before: Vec<_> = tree_iter.by_ref().take_while(|it| it.as_token() != Some(&comma)).collect(); @@ -100,10 +98,7 @@ fn flip_tree(tree: ast::TokenTree, comma: SyntaxToken) -> (ast::TokenTree, Synta &after[next_end..after.len() - 1], ] .concat(); - - let make = SyntaxFactory::with_mappings(); - let new_token_tree = make.token_tree(tree.left_delimiter_token().unwrap().kind(), result); - (new_token_tree, make.finish_with_mappings()) + make.token_tree(tree.left_delimiter_token().unwrap().kind(), result) } #[cfg(test)] diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_or_pattern.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_or_pattern.rs index 4829f5bec206b..bd56331f4128b 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_or_pattern.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_or_pattern.rs @@ -32,7 +32,7 @@ pub(crate) fn flip_or_pattern(acc: &mut Assists, 
ctx: &AssistContext<'_>) -> Opt let target = pipe.text_range(); acc.add(AssistId::refactor_rewrite("flip_or_pattern"), "Flip patterns", target, |builder| { - let mut editor = builder.make_editor(parent.syntax()); + let editor = builder.make_editor(parent.syntax()); editor.replace(before.clone(), after.clone()); editor.replace(after, before); builder.add_file_edits(ctx.vfs_file_id(), editor); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs index 9756268c7cc33..dfd280efa6303 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/flip_trait_bound.rs @@ -33,7 +33,7 @@ pub(crate) fn flip_trait_bound(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op "Flip trait bounds", target, |builder| { - let mut editor = builder.make_editor(parent.syntax()); + let editor = builder.make_editor(parent.syntax()); editor.replace(before.clone(), after.clone()); editor.replace(after, before); builder.add_file_edits(ctx.vfs_file_id(), editor); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs index fccc04770e897..0bb90f187c684 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_blanket_trait_impl.rs @@ -13,7 +13,7 @@ use syntax::{ AstNode, ast::{ self, AssocItem, GenericParam, HasAttrs, HasGenericParams, HasName, HasTypeBounds, - HasVisibility, edit::AstNodeEdit, make, + HasVisibility, edit::AstNodeEdit, syntax_factory::SyntaxFactory, }, syntax_editor::Position, }; @@ -73,24 +73,25 @@ pub(crate) fn generate_blanket_trait_impl( "Generate blanket trait implementation", name.syntax().text_range(), |builder| { - let mut edit = 
builder.make_editor(traitd.syntax()); - let namety = make::ty_path(make::path_from_text(&name.text())); + let editor = builder.make_editor(traitd.syntax()); + let make = editor.make(); + let namety = make.ty_path(make.path_from_text(&name.text())); let trait_where_clause = traitd.where_clause().map(|it| it.reset_indent()); - let bounds = traitd.type_bound_list().and_then(exlucde_sized); + let bounds = traitd.type_bound_list().and_then(|list| exclude_sized(make, list)); let is_unsafe = traitd.unsafe_token().is_some(); - let thisname = this_name(&traitd); - let thisty = make::ty_path(make::path_from_text(&thisname.text())); + let thisname = this_name(make, &traitd); + let thisty = make.ty_path(make.path_from_text(&thisname.text())); let indent = traitd.indent_level(); - let gendecl = make::generic_param_list([GenericParam::TypeParam(make::type_param( + let gendecl = make.generic_param_list([GenericParam::TypeParam(make.type_param( thisname.clone(), - apply_sized(has_sized(&traitd, &ctx.sema), bounds), + apply_sized(make, has_sized(&traitd, &ctx.sema), bounds), ))]); let trait_gen_args = traitd.generic_param_list().map(|param_list| param_list.to_generic_args()); - let impl_ = make::impl_trait( + let impl_ = make.impl_trait( cfg_attrs(&traitd), is_unsafe, traitd.generic_param_list(), @@ -98,20 +99,19 @@ pub(crate) fn generate_blanket_trait_impl( Some(gendecl), None, false, - namety, - thisty.clone(), + namety.into(), + thisty.into(), trait_where_clause, None, None, - ) - .clone_for_update(); + ); if let Some(trait_assoc_list) = traitd.assoc_item_list() { - let assoc_item_list = impl_.get_or_create_assoc_item_list(); + let assoc_item_list = impl_.get_or_create_assoc_item_list_with_editor(&editor); for item in trait_assoc_list.assoc_items() { let item = match item { ast::AssocItem::Fn(method) if method.body().is_none() => { - todo_fn(&method, ctx.config).into() + todo_fn(make, &method, ctx.config).into() } ast::AssocItem::Const(_) | ast::AssocItem::TypeAlias(_) => item, 
_ => continue, @@ -122,10 +122,10 @@ pub(crate) fn generate_blanket_trait_impl( let impl_ = impl_.indent(indent); - edit.insert_all( + editor.insert_all( Position::after(traitd.syntax()), vec![ - make::tokens::whitespace(&format!("\n\n{indent}")).into(), + make.whitespace(&format!("\n\n{indent}")).into(), impl_.syntax().clone().into(), ], ); @@ -135,8 +135,7 @@ pub(crate) fn generate_blanket_trait_impl( { builder.add_tabstop_before(cap, self_ty); } - - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); @@ -212,22 +211,26 @@ fn where_clause_sized(where_clause: Option) -> Option { }) } -fn apply_sized(has_sized: bool, bounds: Option) -> Option { +fn apply_sized( + make: &SyntaxFactory, + has_sized: bool, + bounds: Option, +) -> Option { if has_sized { return bounds; } let bounds = bounds .into_iter() .flat_map(|bounds| bounds.bounds()) - .chain([make::type_bound_text("?Sized")]); - make::type_bound_list(bounds) + .chain([make.type_bound_text("?Sized")]); + make.type_bound_list(bounds) } -fn exlucde_sized(bounds: ast::TypeBoundList) -> Option { - make::type_bound_list(bounds.bounds().filter(|bound| !ty_bound_is(bound, "Sized"))) +fn exclude_sized(make: &SyntaxFactory, bounds: ast::TypeBoundList) -> Option { + make.type_bound_list(bounds.bounds().filter(|bound| !ty_bound_is(bound, "Sized"))) } -fn this_name(traitd: &ast::Trait) -> ast::Name { +fn this_name(make: &SyntaxFactory, traitd: &ast::Trait) -> ast::Name { let has_iter = find_bound("Iterator", traitd.type_bound_list()).is_some(); let params = traitd @@ -245,7 +248,7 @@ fn this_name(traitd: &ast::Trait) -> ast::Name { let mut name_gen = suggest_name::NameGenerator::new_with_names(params.iter().map(String::as_str)); - make::name(&name_gen.suggest_name(if has_iter { "I" } else { "T" })) + make.name(&name_gen.suggest_name(if has_iter { "I" } else { "T" })) } fn find_bound(s: &str, bounds: Option) -> Option { @@ -260,16 +263,16 @@ fn ty_bound_is(bound: 
&ast::TypeBound, s: &str) -> bool { .is_some_and(|name| name.text() == s)) } -fn todo_fn(f: &ast::Fn, config: &AssistConfig) -> ast::Fn { - let params = f.param_list().unwrap_or_else(|| make::param_list(None, None)); - make::fn_( +fn todo_fn(make: &SyntaxFactory, f: &ast::Fn, config: &AssistConfig) -> ast::Fn { + let params = f.param_list().unwrap_or_else(|| make.param_list(None, None)); + make.fn_( cfg_attrs(f), f.visibility(), - f.name().unwrap_or_else(|| make::name("unnamed")), + f.name().unwrap_or_else(|| make.name("unnamed")), f.generic_param_list(), f.where_clause(), params, - make::block_expr(None, Some(crate::utils::expr_fill_default(config))), + make.block_expr(None, Some(crate::utils::expr_fill_default(config))), f.ret_type(), f.async_token().is_some(), f.const_token().is_some(), diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs index 2d92bf5146227..739b63173694a 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_default_from_new.rs @@ -73,12 +73,12 @@ pub(crate) fn generate_default_from_new(acc: &mut Assists, ctx: &AssistContext<' "Generate a Default impl from a new fn", target, move |builder| { - let make = SyntaxFactory::without_mappings(); - let default_impl = generate_default_impl(&make, &impl_, self_ty); + let editor = builder.make_editor(impl_.syntax()); + let make = editor.make(); + let default_impl = generate_default_impl(make, &impl_, self_ty); let indent = IndentLevel::from_node(impl_.syntax()); let default_impl = default_impl.indent(indent); - let mut editor = builder.make_editor(impl_.syntax()); editor.insert_all( Position::after(impl_.syntax()), vec![ diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs index 63033c7d5e398..9486aa6f01953 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_methods.rs @@ -4,7 +4,6 @@ use syntax::{ ast::{ self, AstNode, HasGenericParams, HasName, HasVisibility as _, edit::{AstNodeEdit, IndentLevel}, - syntax_factory::SyntaxFactory, }, syntax_editor::Position, }; @@ -107,7 +106,8 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' format!("Generate delegate for `{field_name}.{name}()`",), target, |edit| { - let make = SyntaxFactory::without_mappings(); + let editor = edit.make_editor(strukt.syntax()); + let make = editor.make(); let field = make .field_from_idents(["self", &field_name]) .expect("always be a valid expression"); @@ -145,7 +145,7 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' // compute the `body` let arg_list = method_source .param_list() - .map(|v| convert_param_list_to_arg_list(v, &make)) + .map(|v| convert_param_list_to_arg_list(v, make)) .unwrap_or_else(|| make.arg_list([])); let tail_expr = make.expr_method_call(field, make.name_ref(&name), arg_list).into(); @@ -173,12 +173,11 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' .indent(IndentLevel(1)); let item = ast::AssocItem::Fn(f.clone()); - let mut editor = edit.make_editor(strukt.syntax()); let fn_: Option = match impl_def { Some(impl_def) => match impl_def.assoc_item_list() { Some(assoc_item_list) => { let item = item.indent(IndentLevel::from_node(impl_def.syntax())); - assoc_item_list.add_items(&mut editor, vec![item.clone()]); + assoc_item_list.add_items(&editor, vec![item.clone()]); Some(item) } None => { @@ -229,7 +228,6 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' let tabstop = 
edit.make_tabstop_before(cap); editor.add_annotation(fn_.syntax(), tabstop); } - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, )?; diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_trait.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_trait.rs index abe447d9d9b7d..6639f10c1f360 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_trait.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_delegate_trait.rs @@ -494,7 +494,7 @@ fn remove_instantiated_params( } } -fn remove_useless_where_clauses(editor: &mut SyntaxEditor, delegate: &ast::Impl) { +fn remove_useless_where_clauses(editor: &SyntaxEditor, delegate: &ast::Impl) { let Some(wc) = delegate.where_clause() else { return; }; @@ -563,7 +563,7 @@ fn finalize_delegate( return Some(delegate.clone()); } - let (mut editor, delegate) = SyntaxEditor::with_ast_node(delegate); + let (editor, delegate) = SyntaxEditor::with_ast_node(delegate); // 1. Replace assoc_item_list if we have new items if let Some(items) = assoc_items @@ -577,7 +577,7 @@ fn finalize_delegate( // 2. 
Remove useless where clauses if remove_where_clauses { - remove_useless_where_clauses(&mut editor, &delegate); + remove_useless_where_clauses(&editor, &delegate); } ast::Impl::cast(editor.finish().new_root().clone()) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs index 5534dc1cd304f..a5bdf80ac725e 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_deref.rs @@ -2,7 +2,7 @@ use hir::{ModPath, ModuleDef}; use ide_db::{FileId, RootDatabase, famous_defs::FamousDefs}; use syntax::{ Edition, - ast::{self, AstNode, HasName, edit::AstNodeEdit, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, HasName, edit::AstNodeEdit}, syntax_editor::Position, }; @@ -138,7 +138,8 @@ fn generate_edit( trait_path: ModPath, edition: Edition, ) { - let make = SyntaxFactory::with_mappings(); + let editor = edit.make_editor(strukt.syntax()); + let make = editor.make(); let strukt_adt = ast::Adt::Struct(strukt.clone()); let trait_ty = make.ty(&trait_path.display(db, edition).to_string()); @@ -195,15 +196,12 @@ fn generate_edit( let body = make.assoc_item_list(assoc_items); let indent = strukt.indent_level(); - let impl_ = generate_trait_impl_intransitive_with_item(&make, &strukt_adt, trait_ty, body) + let impl_ = generate_trait_impl_intransitive_with_item(make, &strukt_adt, trait_ty, body) .indent(indent); - - let mut editor = edit.make_editor(strukt.syntax()); editor.insert_all( Position::after(strukt.syntax()), vec![make.whitespace(&format!("\n\n{indent}")).into(), impl_.syntax().clone().into()], ); - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(file_id, editor); } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs index 
7aeb5e3396969..0129b1db396b2 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_derive.rs @@ -1,7 +1,7 @@ use syntax::{ SyntaxKind::{ATTR, COMMENT, WHITESPACE}, T, - ast::{self, AstNode, HasAttrs, edit::IndentLevel, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, HasAttrs, edit::IndentLevel}, syntax_editor::{Element, Position}, }; @@ -42,17 +42,15 @@ pub(crate) fn generate_derive(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt }; acc.add(AssistId::generate("generate_derive"), "Add `#[derive]`", target, |edit| { - let make = SyntaxFactory::without_mappings(); - match derive_attr { None => { + let editor = edit.make_editor(nominal.syntax()); + let make = editor.make(); let derive = make.attr_outer(make.meta_token_tree( make.ident_path("derive"), make.token_tree(T!['('], vec![]), )); - - let mut editor = edit.make_editor(nominal.syntax()); let indent = IndentLevel::from_node(nominal.syntax()); let after_attrs_and_comments = nominal .syntax() diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs index b866022a7dfd6..e2783811f743b 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_is_method.rs @@ -1,12 +1,14 @@ use ide_db::assists::GroupLabel; -use itertools::Itertools; use stdx::to_lower_snake_case; -use syntax::ast::HasVisibility; -use syntax::ast::{self, AstNode, HasName}; +use syntax::{ + AstNode, Edition, + ast::{self, HasName, HasVisibility, edit::AstNodeEdit}, + syntax_editor::Position, +}; use crate::{ AssistContext, AssistId, Assists, - utils::{add_method_to_adt, find_struct_impl, is_selected}, + utils::{find_struct_impl, generate_impl_with_item, is_selected}, }; // Assist: generate_enum_is_method @@ 
-64,27 +66,63 @@ pub(crate) fn generate_enum_is_method(acc: &mut Assists, ctx: &AssistContext<'_> target, |builder| { let vis = parent_enum.visibility().map_or(String::new(), |v| format!("{v} ")); - let method = methods + + let fn_items: Vec = methods .iter() - .map(|Method { pattern_suffix, fn_name, variant_name }| { - format!( - " \ - /// Returns `true` if the {enum_lowercase_name} is [`{variant_name}`]. - /// - /// [`{variant_name}`]: {enum_name}::{variant_name} - #[must_use] - {vis}fn {fn_name}(&self) -> bool {{ - matches!(self, Self::{variant_name}{pattern_suffix}) - }}", - ) - }) - .join("\n\n"); - - add_method_to_adt(builder, &parent_enum, impl_def, &method); + .map(|method| build_fn_item(method, &enum_lowercase_name, &enum_name, &vis)) + .collect(); + + if let Some(impl_def) = &impl_def { + let editor = builder.make_editor(impl_def.syntax()); + impl_def.assoc_item_list().unwrap().add_items(&editor, fn_items); + builder.add_file_edits(ctx.vfs_file_id(), editor); + return; + } + + let editor = builder.make_editor(parent_enum.syntax()); + let make = editor.make(); + let indent = parent_enum.indent_level(); + let assoc_list = make.assoc_item_list(fn_items); + let new_impl = generate_impl_with_item(make, &parent_enum, Some(assoc_list)); + editor.insert_all( + Position::after(parent_enum.syntax()), + vec![ + make.whitespace(&format!("\n\n{indent}")).into(), + new_impl.syntax().clone().into(), + ], + ); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } +fn build_fn_item( + method: &Method, + enum_lowercase_name: &str, + enum_name: &ast::Name, + vis: &str, +) -> ast::AssocItem { + let Method { pattern_suffix, fn_name, variant_name } = method; + let fn_text = format!( + "/// Returns `true` if the {enum_lowercase_name} is [`{variant_name}`]. 
+/// +/// [`{variant_name}`]: {enum_name}::{variant_name} +#[must_use] +{vis}fn {fn_name}(&self) -> bool {{ + matches!(self, Self::{variant_name}{pattern_suffix}) +}}" + ); + let wrapped = format!("impl X {{ {fn_text} }}"); + let parse = syntax::SourceFile::parse(&wrapped, Edition::CURRENT); + let fn_ = parse + .tree() + .syntax() + .descendants() + .find_map(ast::Fn::cast) + .expect("fn text must produce a valid fn node"); + ast::AssocItem::Fn(fn_.indent(1.into())) +} + struct Method { pattern_suffix: &'static str, fn_name: String, diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs index 39a6382b7cc71..9a97ad1e8fe75 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_projection_method.rs @@ -1,12 +1,14 @@ use ide_db::assists::GroupLabel; -use itertools::Itertools; use stdx::to_lower_snake_case; -use syntax::ast::HasVisibility; -use syntax::ast::{self, AstNode, HasName}; +use syntax::{ + AstNode, Edition, + ast::{self, HasName, HasVisibility, edit::AstNodeEdit}, + syntax_editor::Position, +}; use crate::{ AssistContext, AssistId, Assists, - utils::{add_method_to_adt, find_struct_impl, is_selected}, + utils::{find_struct_impl, generate_impl_with_item, is_selected}, }; // Assist: generate_enum_try_into_method @@ -116,15 +118,6 @@ fn generate_enum_projection_method( assist_description: &str, props: ProjectionProps, ) -> Option<()> { - let ProjectionProps { - fn_name_prefix, - self_param, - return_prefix, - return_suffix, - happy_case, - sad_case, - } = props; - let variant = ctx.find_node_at_offset::()?; let parent_enum = ast::Adt::Enum(variant.parent_enum()); let variants = variant @@ -135,7 +128,7 @@ fn generate_enum_projection_method( .collect::>(); let methods = variants .iter() - .map(|variant| 
Method::new(variant, fn_name_prefix)) + .map(|variant| Method::new(variant, props.fn_name_prefix)) .collect::>>()?; let fn_names = methods.iter().map(|it| it.fn_name.clone()).collect::>(); stdx::never!(variants.is_empty()); @@ -151,30 +144,66 @@ fn generate_enum_projection_method( target, |builder| { let vis = parent_enum.visibility().map_or(String::new(), |v| format!("{v} ")); + let must_use = if ctx.config.assist_emit_must_use { "#[must_use]\n" } else { "" }; - let must_use = if ctx.config.assist_emit_must_use { "#[must_use]\n " } else { "" }; - - let method = methods + let fn_items: Vec = methods .iter() - .map(|Method { pattern_suffix, field_type, bound_name, fn_name, variant_name }| { - format!( - " \ - {must_use}{vis}fn {fn_name}({self_param}) -> {return_prefix}{field_type}{return_suffix} {{ - if let Self::{variant_name}{pattern_suffix} = self {{ - {happy_case}({bound_name}) - }} else {{ - {sad_case} - }} - }}" - ) - }) - .join("\n\n"); + .map(|method| build_fn_item(method, &vis, must_use, &props)) + .collect(); + + if let Some(impl_def) = &impl_def { + let editor = builder.make_editor(impl_def.syntax()); + impl_def.assoc_item_list().unwrap().add_items(&editor, fn_items); + builder.add_file_edits(ctx.vfs_file_id(), editor); + return; + } - add_method_to_adt(builder, &parent_enum, impl_def, &method); + let editor = builder.make_editor(parent_enum.syntax()); + let make = editor.make(); + let indent = parent_enum.indent_level(); + let assoc_list = make.assoc_item_list(fn_items); + let new_impl = generate_impl_with_item(make, &parent_enum, Some(assoc_list)); + editor.insert_all( + Position::after(parent_enum.syntax()), + vec![ + make.whitespace(&format!("\n\n{indent}")).into(), + new_impl.syntax().clone().into(), + ], + ); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } +fn build_fn_item( + method: &Method, + vis: &str, + must_use: &str, + props: &ProjectionProps, +) -> ast::AssocItem { + let Method { pattern_suffix, field_type, bound_name, fn_name, 
variant_name } = method; + let ProjectionProps { self_param, return_prefix, return_suffix, happy_case, sad_case, .. } = + props; + let fn_text = format!( + "{must_use}{vis}fn {fn_name}({self_param}) -> {return_prefix}{field_type}{return_suffix} {{ + if let Self::{variant_name}{pattern_suffix} = self {{ + {happy_case}({bound_name}) + }} else {{ + {sad_case} + }} +}}" + ); + let wrapped = format!("impl X {{ {fn_text} }}"); + let parse = syntax::SourceFile::parse(&wrapped, Edition::CURRENT); + let fn_ = parse + .tree() + .syntax() + .descendants() + .find_map(ast::Fn::cast) + .expect("fn text must produce a valid fn node"); + ast::AssocItem::Fn(fn_.indent(1.into())) +} + struct Method { pattern_suffix: String, field_type: ast::Type, @@ -185,6 +214,7 @@ struct Method { impl Method { fn new(variant: &ast::Variant, fn_name_prefix: &str) -> Option { + use itertools::Itertools as _; let variant_name = variant.name()?; let fn_name = format!("{fn_name_prefix}_{}", &to_lower_snake_case(&variant_name.text())); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs index 3514ebb811ee2..9b4d44d8b5177 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_enum_variant.rs @@ -59,12 +59,12 @@ pub(crate) fn generate_enum_variant(acc: &mut Assists, ctx: &AssistContext<'_>) let InRealFile { file_id, value: enum_node } = e.source(db)?.original_ast_node_rooted(db)?; acc.add(AssistId::generate("generate_enum_variant"), "Generate variant", target, |builder| { - let mut editor = builder.make_editor(enum_node.syntax()); - let make = SyntaxFactory::with_mappings(); - let field_list = parent.make_field_list(ctx, &make); + let editor = builder.make_editor(enum_node.syntax()); + let make = editor.make(); + let field_list = parent.make_field_list(ctx, make); let 
variant = make.variant(None, make.name(&name_ref.text()), field_list, None); if let Some(it) = enum_node.variant_list() { - it.add_variant(&mut editor, &variant); + it.add_variant(&editor, &variant); } builder.add_file_edits(file_id.file_id(ctx.db()), editor); }) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs index 6bcbd9b0ccc2a..55e5083811347 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_fn_type_alias.rs @@ -2,7 +2,7 @@ use either::Either; use ide_db::assists::{AssistId, GroupLabel}; use syntax::{ AstNode, - ast::{self, HasGenericParams, HasName, edit::IndentLevel, syntax_factory::SyntaxFactory}, + ast::{self, HasGenericParams, HasName, edit::IndentLevel}, syntax_editor, }; @@ -55,9 +55,8 @@ pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) style.label(), func_node.syntax().text_range(), |builder| { - let mut edit = builder.make_editor(func); - let make = SyntaxFactory::without_mappings(); - + let editor = builder.make_editor(func); + let make = editor.make(); let alias_name = format!("{}Fn", stdx::to_camel_case(&name.to_string())); let mut fn_params_vec = Vec::new(); @@ -104,7 +103,7 @@ pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) ); let indent = IndentLevel::from_node(insertion_node); - edit.insert_all( + editor.insert_all( syntax_editor::Position::before(insertion_node), vec![ ty_alias.syntax().clone().into(), @@ -115,10 +114,10 @@ pub(crate) fn generate_fn_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) if let Some(cap) = ctx.config.snippet_cap && let Some(name) = ty_alias.name() { - edit.add_annotation(name.syntax(), builder.make_placeholder_snippet(cap)); + editor.add_annotation(name.syntax(), builder.make_placeholder_snippet(cap)); } - 
builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs index 1adb3f4fe49a1..76246c3e8efd0 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_from_impl_for_enum.rs @@ -40,19 +40,18 @@ pub(crate) fn generate_from_impl_for_enum( "Generate `From` impl for this enum variant(s)", target, |edit| { - let make = SyntaxFactory::with_mappings(); + let editor = edit.make_editor(adt.syntax()); + let make = editor.make(); let indent = adt.indent_level(); let mut elements = Vec::new(); for variant_info in variants { - let impl_ = build_from_impl(&make, &adt, variant_info).indent(indent); + let impl_ = build_from_impl(make, &adt, variant_info).indent(indent); elements.push(make.whitespace(&format!("\n\n{indent}")).into()); elements.push(impl_.syntax().clone().into()); } - let mut editor = edit.make_editor(adt.syntax()); editor.insert_all(Position::after(adt.syntax()), elements); - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(file_id, editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs index fbf6241e43a38..6ef492619b50c 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_function.rs @@ -1166,10 +1166,10 @@ fn next_space_for_fn_after_call_site(expr: ast::CallableExpr) -> Option { break; } - SyntaxKind::ITEM_LIST => { - if ancestors.peek().map(|a| a.kind()) == Some(SyntaxKind::MODULE) { - break; - } + SyntaxKind::ITEM_LIST + if ancestors.peek().map(|a| a.kind()) == 
Some(SyntaxKind::MODULE) => + { + break; } _ => {} } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter_or_setter.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter_or_setter.rs index 4cd018d02d029..b884581041f45 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter_or_setter.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_getter_or_setter.rs @@ -410,35 +410,37 @@ fn parse_record_field( Some(RecordFieldInfo { field_name, field_ty, fn_name, target }) } -fn build_source_change( - builder: &mut SourceChangeBuilder, +fn items( ctx: &AssistContext<'_>, info_of_record_fields: Vec, - assist_info: AssistInfo, -) { - let syntax_factory = SyntaxFactory::without_mappings(); - - let items: Vec = info_of_record_fields + assist_info: &AssistInfo, + make: &SyntaxFactory, +) -> Vec { + info_of_record_fields .iter() .map(|record_field_info| { let method = match assist_info.assist_type { - AssistType::Set => { - generate_setter_from_info(&assist_info, record_field_info, &syntax_factory) - } - _ => { - generate_getter_from_info(ctx, &assist_info, record_field_info, &syntax_factory) - } + AssistType::Set => generate_setter_from_info(assist_info, record_field_info, make), + _ => generate_getter_from_info(ctx, assist_info, record_field_info, make), }; let new_fn = method; let new_fn = new_fn.indent(1.into()); new_fn.into() }) - .collect(); + .collect() +} +fn build_source_change( + builder: &mut SourceChangeBuilder, + ctx: &AssistContext<'_>, + info_of_record_fields: Vec, + assist_info: AssistInfo, +) { if let Some(impl_def) = &assist_info.impl_def { // We have an existing impl to add to - let mut editor = builder.make_editor(impl_def.syntax()); - impl_def.assoc_item_list().unwrap().add_items(&mut editor, items.clone()); + let editor = builder.make_editor(impl_def.syntax()); + let items = items(ctx, info_of_record_fields, &assist_info, editor.make()); + 
impl_def.assoc_item_list().unwrap().add_items(&editor, items.clone()); if let Some(cap) = ctx.config.snippet_cap && let Some(ast::AssocItem::Fn(fn_)) = items.last() @@ -451,22 +453,23 @@ fn build_source_change( builder.add_file_edits(ctx.vfs_file_id(), editor); return; } + + let editor = builder.make_editor(assist_info.strukt.syntax()); + let make = editor.make(); + let items = items(ctx, info_of_record_fields, &assist_info, make); let ty_params = assist_info.strukt.generic_param_list(); let ty_args = ty_params.as_ref().map(|it| it.to_generic_args()); - let impl_def = syntax_factory.impl_( + let impl_def = make.impl_( None, ty_params, ty_args, - syntax_factory - .ty_path(syntax_factory.ident_path(&assist_info.strukt.name().unwrap().to_string())) - .into(), + make.ty_path(make.ident_path(&assist_info.strukt.name().unwrap().to_string())).into(), None, - Some(syntax_factory.assoc_item_list(items)), + Some(make.assoc_item_list(items)), ); - let mut editor = builder.make_editor(assist_info.strukt.syntax()); editor.insert_all( Position::after(assist_info.strukt.syntax()), - vec![syntax_factory.whitespace("\n\n").into(), impl_def.syntax().clone().into()], + vec![make.whitespace("\n\n").into(), impl_def.syntax().clone().into()], ); if let Some(cap) = ctx.config.snippet_cap diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs index af123eeaa0ce8..c5a46f6981f59 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_impl.rs @@ -1,7 +1,5 @@ use syntax::{ - ast::{ - self, AstNode, HasGenericParams, HasName, edit::AstNodeEdit, syntax_factory::SyntaxFactory, - }, + ast::{self, AstNode, HasGenericParams, HasName, edit::AstNodeEdit}, syntax_editor::{Position, SyntaxEditor}, }; @@ -13,12 +11,8 @@ use crate::{ }, }; -fn insert_impl( - editor: &mut SyntaxEditor, - make: 
&SyntaxFactory, - impl_: &ast::Impl, - nominal: &impl AstNodeEdit, -) -> ast::Impl { +fn insert_impl(editor: &SyntaxEditor, impl_: &ast::Impl, nominal: &impl AstNodeEdit) -> ast::Impl { + let make = editor.make(); let indent = nominal.indent_level(); let impl_ = impl_.indent(indent); @@ -65,13 +59,11 @@ pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio format!("Generate impl for `{name}`"), target, |edit| { - let make = SyntaxFactory::with_mappings(); - // Generate the impl - let impl_ = generate_impl_with_factory(&make, &nominal); - - let mut editor = edit.make_editor(nominal.syntax()); + let editor = edit.make_editor(nominal.syntax()); + let make = editor.make(); + let impl_ = generate_impl_with_factory(make, &nominal); - let impl_ = insert_impl(&mut editor, &make, &impl_, &nominal); + let impl_ = insert_impl(&editor, &impl_, &nominal); // Add a tabstop after the left curly brace if let Some(cap) = ctx.config.snippet_cap && let Some(l_curly) = impl_.assoc_item_list().and_then(|it| it.l_curly_token()) @@ -79,8 +71,6 @@ pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio let tabstop = edit.make_tabstop_after(cap); editor.add_annotation(l_curly, tabstop); } - - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -117,13 +107,10 @@ pub(crate) fn generate_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> format!("Generate trait impl for `{name}`"), target, |edit| { - let make = SyntaxFactory::with_mappings(); - // Generate the impl - let impl_ = generate_trait_impl_intransitive(&make, &nominal, make.ty_placeholder()); - - let mut editor = edit.make_editor(nominal.syntax()); - - let impl_ = insert_impl(&mut editor, &make, &impl_, &nominal); + let editor = edit.make_editor(nominal.syntax()); + let make = editor.make(); + let impl_ = generate_trait_impl_intransitive(make, &nominal, make.ty_placeholder()); + let impl_ = insert_impl(&editor, &impl_, 
&nominal); // Make the trait type a placeholder snippet if let Some(cap) = ctx.config.snippet_cap { if let Some(trait_) = impl_.trait_() { @@ -136,8 +123,6 @@ pub(crate) fn generate_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> editor.add_annotation(l_curly, tabstop); } } - - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -176,8 +161,8 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> format!("Generate `{name}` impl for type"), target, |edit| { - let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(trait_.syntax()); + let editor = edit.make_editor(trait_.syntax()); + let make = editor.make(); let holder_arg = ast::GenericArg::TypeArg(make.type_arg(make.ty_placeholder())); let missing_items = utils::filter_assoc_items( @@ -213,7 +198,7 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> } else { let impl_ = make_impl_(None); let assoc_items = add_trait_assoc_items_to_impl( - &make, + make, &ctx.sema, ctx.config, &missing_items, @@ -225,8 +210,7 @@ pub(crate) fn generate_impl_trait(acc: &mut Assists, ctx: &AssistContext<'_>) -> make_impl_(Some(assoc_item_list)) }; - let impl_ = insert_impl(&mut editor, &make, &impl_, &trait_); - editor.add_mappings(make.finish_with_mappings()); + let impl_ = insert_impl(&editor, &impl_, &trait_); if let Some(cap) = ctx.config.snippet_cap { if let Some(generics) = impl_.trait_().and_then(|it| it.generic_arg_list()) { diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs index 31e49c8ce48e7..acf0819222d53 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_mut_trait_impl.rs @@ -67,10 +67,9 @@ pub(crate) fn generate_mut_trait_impl(acc: 
&mut Assists, ctx: &AssistContext<'_> format!("Generate `{trait_new}` impl from this `{trait_name}` trait"), target, |edit| { - let (mut editor, impl_clone) = SyntaxEditor::with_ast_node(&impl_def.reset_indent()); - let factory = SyntaxFactory::without_mappings(); + let (editor, impl_clone) = SyntaxEditor::with_ast_node(&impl_def.reset_indent()); - apply_generate_mut_impl(&mut editor, &factory, &impl_clone, trait_new); + apply_generate_mut_impl(&editor, &impl_clone, trait_new); let new_root = editor.finish(); let new_root = new_root.new_root(); @@ -79,12 +78,13 @@ pub(crate) fn generate_mut_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_> let new_impl = new_impl.indent(indent); - let mut editor = edit.make_editor(impl_def.syntax()); + let editor = edit.make_editor(impl_def.syntax()); + let make = editor.make(); editor.insert_all( Position::before(impl_def.syntax()), vec![ new_impl.syntax().syntax_element(), - factory.whitespace(&format!("\n\n{indent}")).syntax_element(), + make.whitespace(&format!("\n\n{indent}")).syntax_element(), ], ); @@ -98,7 +98,7 @@ pub(crate) fn generate_mut_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_> ) } -fn delete_with_trivia(editor: &mut SyntaxEditor, node: &SyntaxNode) { +fn delete_with_trivia(editor: &SyntaxEditor, node: &SyntaxNode) { let mut end: SyntaxElement = node.clone().into(); if let Some(next) = node.next_sibling_or_token() @@ -112,23 +112,23 @@ fn delete_with_trivia(editor: &mut SyntaxEditor, node: &SyntaxNode) { } fn apply_generate_mut_impl( - editor: &mut SyntaxEditor, - factory: &SyntaxFactory, + editor: &SyntaxEditor, impl_def: &ast::Impl, trait_new: &str, ) -> Option<()> { + let make = editor.make(); let path = impl_def.trait_().and_then(|t| t.syntax().descendants().find_map(ast::Path::cast))?; let seg = path.segment()?; let name_ref = seg.name_ref()?; - let new_name_ref = factory.name_ref(trait_new); + let new_name_ref = make.name_ref(trait_new); editor.replace(name_ref.syntax(), new_name_ref.syntax()); if 
let Some((name, new_name)) = impl_def.syntax().descendants().filter_map(ast::Name::cast).find_map(process_method_name) { - let new_name_node = factory.name(new_name); + let new_name_node = make.name(new_name); editor.replace(name.syntax(), new_name_node.syntax()); } @@ -137,14 +137,14 @@ fn apply_generate_mut_impl( } if let Some(self_param) = impl_def.syntax().descendants().find_map(ast::SelfParam::cast) { - let mut_self = factory.mut_self_param(); + let mut_self = make.mut_self_param(); editor.replace(self_param.syntax(), mut_self.syntax()); } if let Some(ret_type) = impl_def.syntax().descendants().find_map(ast::RetType::cast) - && let Some(new_ty) = process_ret_type(factory, &ret_type) + && let Some(new_ty) = process_ret_type(make, &ret_type) { - let new_ret = factory.ret_type(new_ty); + let new_ret = make.ret_type(new_ty); editor.replace(ret_type.syntax(), new_ret.syntax()) } @@ -154,13 +154,14 @@ fn apply_generate_mut_impl( _ => None, }) }) { - process_ref_mut(editor, factory, &fn_); + process_ref_mut(editor, &fn_); } Some(()) } -fn process_ref_mut(editor: &mut SyntaxEditor, factory: &SyntaxFactory, fn_: &ast::Fn) { +fn process_ref_mut(editor: &SyntaxEditor, fn_: &ast::Fn) { + let make = editor.make(); let Some(expr) = fn_.body().and_then(|b| b.tail_expr()) else { return }; let ast::Expr::RefExpr(ref_expr) = expr else { return }; @@ -171,8 +172,8 @@ fn process_ref_mut(editor: &mut SyntaxEditor, factory: &SyntaxFactory, fn_: &ast let Some(amp) = ref_expr.amp_token() else { return }; - let mut_kw = factory.token(T![mut]); - let space = factory.whitespace(" "); + let mut_kw = make.token(T![mut]); + let space = make.whitespace(" "); editor.insert(Position::after(amp.clone()), space.syntax_element()); editor.insert(Position::after(amp), mut_kw.syntax_element()); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs index 301d13c095842..520709adc5e29 100644 --- 
a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_new.rs @@ -3,10 +3,7 @@ use ide_db::{ use_trivial_constructor::use_trivial_constructor, }; use syntax::{ - ast::{ - self, AstNode, HasName, HasVisibility, StructKind, edit::AstNodeEdit, - syntax_factory::SyntaxFactory, - }, + ast::{self, AstNode, HasName, HasVisibility, StructKind, edit::AstNodeEdit}, syntax_editor::Position, }; @@ -38,11 +35,9 @@ use crate::{ // ``` pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let strukt = ctx.find_node_at_offset::()?; - - let make = SyntaxFactory::without_mappings(); - let field_list = match strukt.kind() { + let field_list: Vec<(String, ast::Type)> = match strukt.kind() { StructKind::Record(named) => { - named.fields().filter_map(|f| Some((f.name()?, f.ty()?))).collect::>() + named.fields().filter_map(|f| Some((f.name()?.to_string(), f.ty()?))).collect() } StructKind::Tuple(tuple) => { let mut name_generator = NameGenerator::default(); @@ -56,12 +51,12 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option ctx.db(), ctx.edition(), ) { - Some(name) => name, - None => name_generator.suggest_name(&format!("_{i}")), + Some(name) => name.to_string(), + None => name_generator.suggest_name(&format!("_{i}")).to_string(), }; - Some((make.name(name.as_str()), f.ty()?)) + Some((name, ty)) }) - .collect::>() + .collect() } StructKind::Unit => return None, }; @@ -74,7 +69,9 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option let target = strukt.syntax().text_range(); acc.add(AssistId::generate("generate_new"), "Generate `new`", target, |builder| { - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(strukt.syntax()); + let make = editor.make(); + let trivial_constructors = field_list .iter() .map(|(name, ty)| { @@ -100,13 +97,13 @@ pub(crate) fn generate_new(acc: &mut 
Assists, ctx: &AssistContext<'_>) -> Option edition, )?; - Some((make.name_ref(&name.text()), Some(expr))) + Some((make.name_ref(name), Some(expr))) }) .collect::>(); let params = field_list.iter().enumerate().filter_map(|(i, (name, ty))| { if trivial_constructors[i].is_none() { - Some(make.param(make.ident_pat(false, false, name.clone()).into(), ty.clone())) + Some(make.param(make.ident_pat(false, false, make.name(name)).into(), ty.clone())) } else { None } @@ -117,7 +114,7 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option if let Some(constructor) = trivial_constructors[i].clone() { constructor } else { - (make.name_ref(&name.text()), None) + (make.name_ref(name), None) } }); @@ -158,8 +155,6 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option ) .indent(1.into()); - let mut editor = builder.make_editor(strukt.syntax()); - // Get the node for set annotation let contain_fn = if let Some(impl_def) = impl_def { let fn_ = fn_.indent(impl_def.indent_level()); @@ -185,7 +180,7 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option let indent_level = strukt.indent_level(); let list = make.assoc_item_list([ast::AssocItem::Fn(fn_)]); let impl_def = - generate_impl_with_item(&make, &ast::Adt::Struct(strukt.clone()), Some(list)) + generate_impl_with_item(make, &ast::Adt::Struct(strukt.clone()), Some(list)) .indent(strukt.indent_level()); // Insert it after the adt @@ -235,8 +230,6 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option editor.add_annotation(name.syntax(), tabstop_before); } } - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_single_field_struct_from.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_single_field_struct_from.rs index 7746cdc068a1f..10c009a2ea440 100644 --- 
a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_single_field_struct_from.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_single_field_struct_from.rs @@ -80,8 +80,8 @@ pub(crate) fn generate_single_field_struct_from( "Generate single field `From`", strukt.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(strukt.syntax()); + let editor = builder.make_editor(strukt.syntax()); + let make = editor.make(); let indent = strukt.indent_level(); let ty_where_clause = strukt.where_clause(); @@ -95,7 +95,7 @@ pub(crate) fn generate_single_field_struct_from( let ty = make.ty(&strukt_name.text()); let constructor = - make_adt_constructor(names.as_deref(), constructors, &main_field_name, &make); + make_adt_constructor(names.as_deref(), constructors, &main_field_name, make); let body = make.block_expr([], Some(constructor)); let fn_ = make @@ -119,7 +119,7 @@ pub(crate) fn generate_single_field_struct_from( false, false, ) - .indent_with_mapping(1.into(), &make); + .indent_with_mapping(1.into(), make); let cfg_attrs = strukt.attrs().filter(|attr| matches!(attr.meta(), Some(ast::Meta::CfgMeta(_)))); @@ -139,13 +139,12 @@ pub(crate) fn generate_single_field_struct_from( None, ); - let (mut impl_editor, impl_root) = SyntaxEditor::with_ast_node(&impl_); - let assoc_list = - impl_root.get_or_create_assoc_item_list_with_editor(&mut impl_editor, &make); - assoc_list.add_items(&mut impl_editor, vec![fn_.into()]); + let (impl_editor, impl_root) = SyntaxEditor::with_ast_node(&impl_); + let assoc_list = impl_root.get_or_create_assoc_item_list_with_editor(&impl_editor); + assoc_list.add_items(&impl_editor, vec![fn_.into()]); let impl_ = ast::Impl::cast(impl_editor.finish().new_root().clone()) .unwrap() - .indent_with_mapping(indent, &make); + .indent_with_mapping(indent, make); editor.insert_all( Position::after(strukt.syntax()), @@ -154,8 +153,6 @@ pub(crate) fn 
generate_single_field_struct_from( impl_.syntax().clone().into(), ], ); - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_trait_from_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_trait_from_impl.rs index 2d3d05849b0ba..049398de8c559 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_trait_from_impl.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/generate_trait_from_impl.rs @@ -3,7 +3,7 @@ use ide_db::assists::AssistId; use syntax::{ AstNode, AstToken, SyntaxKind, T, ast::{ - self, HasDocComments, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, make, + self, HasDocComments, HasGenericParams, HasName, HasVisibility, edit::AstNodeEdit, syntax_factory::SyntaxFactory, }, syntax_editor::{Position, SyntaxEditor}, @@ -99,34 +99,35 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ impl_ast.syntax().text_range(), |builder| { let trait_items: ast::AssocItemList = { - let (mut trait_items_editor, trait_items) = + let (trait_items_editor, trait_items) = SyntaxEditor::with_ast_node(&impl_assoc_items); trait_items.assoc_items().for_each(|item| { - strip_body(&mut trait_items_editor, &item); - remove_items_visibility(&mut trait_items_editor, &item); + strip_body(&trait_items_editor, &item); + remove_items_visibility(&trait_items_editor, &item); }); ast::AssocItemList::cast(trait_items_editor.finish().new_root().clone()).unwrap() }; - let factory = SyntaxFactory::with_mappings(); - let trait_ast = factory.trait_( + let editor = builder.make_editor(impl_ast.syntax()); + let make = editor.make(); + let trait_ast = make.trait_( false, - &trait_name(&impl_assoc_items).text(), + &trait_name(&impl_assoc_items, make).text(), impl_ast.generic_param_list(), impl_ast.where_clause(), trait_items, ); let trait_name = 
trait_ast.name().expect("new trait should have a name"); - let trait_name_ref = factory.name_ref(&trait_name.to_string()); + let trait_name_ref = make.name_ref(&trait_name.to_string()); // Change `impl Foo` to `impl NewTrait for Foo` let mut elements = vec![ trait_name_ref.syntax().clone().into(), - make::tokens::single_space().into(), - make::token(T![for]).into(), - make::tokens::single_space().into(), + make.whitespace(" ").into(), + make.token(T![for]).into(), + make.whitespace(" ").into(), ]; if let Some(params) = impl_ast.generic_param_list() { @@ -134,10 +135,9 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ elements.insert(1, gen_args.syntax().clone().into()); } - let mut editor = builder.make_editor(impl_ast.syntax()); impl_assoc_items.assoc_items().for_each(|item| { - remove_items_visibility(&mut editor, &item); - remove_doc_comments(&mut editor, &item); + remove_items_visibility(&editor, &item); + remove_doc_comments(&editor, &item); }); editor.insert_all(Position::before(impl_name.syntax()), elements); @@ -147,7 +147,7 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ Position::before(impl_ast.syntax()), vec![ trait_ast.syntax().clone().into(), - make::tokens::whitespace(&format!("\n\n{}", impl_ast.indent_level())).into(), + make.whitespace(&format!("\n\n{}", impl_ast.indent_level())).into(), ], ); @@ -157,8 +157,6 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ editor.add_annotation(trait_name.syntax(), placeholder); editor.add_annotation(trait_name_ref.syntax(), placeholder); } - - editor.add_mappings(factory.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); @@ -166,20 +164,20 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ Some(()) } -fn trait_name(items: &ast::AssocItemList) -> ast::Name { +fn trait_name(items: &ast::AssocItemList, make: &SyntaxFactory) -> ast::Name { let mut 
fn_names = items .assoc_items() .filter_map(|x| if let ast::AssocItem::Fn(f) = x { f.name() } else { None }); fn_names .next() .and_then(|name| { - fn_names.next().is_none().then(|| make::name(&stdx::to_camel_case(&name.text()))) + fn_names.next().is_none().then(|| make.name(&stdx::to_camel_case(&name.text()))) }) - .unwrap_or_else(|| make::name("NewTrait")) + .unwrap_or_else(|| make.name("NewTrait")) } /// `E0449` Trait items always share the visibility of their trait -fn remove_items_visibility(editor: &mut SyntaxEditor, item: &ast::AssocItem) { +fn remove_items_visibility(editor: &SyntaxEditor, item: &ast::AssocItem) { if let Some(has_vis) = ast::AnyHasVisibility::cast(item.syntax().clone()) { if let Some(vis) = has_vis.visibility() && let Some(token) = vis.syntax().next_sibling_or_token() @@ -193,7 +191,7 @@ fn remove_items_visibility(editor: &mut SyntaxEditor, item: &ast::AssocItem) { } } -fn remove_doc_comments(editor: &mut SyntaxEditor, item: &ast::AssocItem) { +fn remove_doc_comments(editor: &SyntaxEditor, item: &ast::AssocItem) { for doc in item.doc_comments() { if let Some(next) = doc.syntax().next_token() && next.kind() == SyntaxKind::WHITESPACE @@ -204,7 +202,8 @@ fn remove_doc_comments(editor: &mut SyntaxEditor, item: &ast::AssocItem) { } } -fn strip_body(editor: &mut SyntaxEditor, item: &ast::AssocItem) { +fn strip_body(editor: &SyntaxEditor, item: &ast::AssocItem) { + let make = editor.make(); if let ast::AssocItem::Fn(f) = item && let Some(body) = f.body() { @@ -216,7 +215,7 @@ fn strip_body(editor: &mut SyntaxEditor, item: &ast::AssocItem) { editor.delete(prev); } - editor.replace(body.syntax(), make::tokens::semicolon()); + editor.replace(body.syntax(), make.token(T![;])); }; } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs index f55ef4229e587..2af074f1fcdfe 100644 --- 
a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs @@ -7,7 +7,7 @@ use ide_db::{ }; use syntax::{ Direction, TextRange, - ast::{self, AstNode, AstToken, HasName, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, AstToken, HasName}, syntax_editor::{Element, SyntaxEditor}, }; @@ -83,7 +83,8 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) "Inline variable", target.text_range(), move |builder| { - let mut editor = builder.make_editor(&target); + let editor = builder.make_editor(&target); + let make = editor.make(); if delete_let { editor.delete(let_stmt.syntax()); @@ -91,15 +92,13 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) && let Some(op_token) = bin_expr.op_token() { editor.delete(&op_token); - remove_whitespace(op_token, Direction::Prev, &mut editor); - remove_whitespace(let_stmt.syntax(), Direction::Prev, &mut editor); + remove_whitespace(op_token, Direction::Prev, &editor); + remove_whitespace(let_stmt.syntax(), Direction::Prev, &editor); } else { - remove_whitespace(let_stmt.syntax(), Direction::Next, &mut editor); + remove_whitespace(let_stmt.syntax(), Direction::Next, &editor); } } - let make = SyntaxFactory::with_mappings(); - for (name, should_wrap) in wrap_in_parens { let replacement = if should_wrap { make.expr_paren(initializer_expr.clone()).into() @@ -115,8 +114,6 @@ pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) editor.replace(name.syntax(), replacement.syntax()); } } - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -204,7 +201,7 @@ fn inline_usage( Some(InlineData { let_stmt, delete_let, target: ast::NameOrNameRef::NameRef(name), references }) } -fn remove_whitespace(elem: impl Element, dir: Direction, editor: &mut SyntaxEditor) { +fn remove_whitespace(elem: impl 
Element, dir: Direction, editor: &SyntaxEditor) { let token = match elem.syntax_element() { syntax::NodeOrToken::Node(node) => match dir { Direction::Next => node.last_token(), @@ -1054,6 +1051,7 @@ fn f() { check_assist( inline_local_variable, r#" +//- minicore: fn fn main() { let $0f = || 2; let _ = f(); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs index 4b60f0ac1e3cf..6d8750afdcff1 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_type_alias.rs @@ -70,7 +70,7 @@ pub(crate) fn inline_type_alias_uses(acc: &mut Assists, ctx: &AssistContext<'_>) let mut inline_refs_for_file = |file_id, refs: Vec| { let source = ctx.sema.parse(file_id); - let mut editor = builder.make_editor(source.syntax()); + let editor = builder.make_editor(source.syntax()); let (path_types, path_type_uses) = split_refs_and_uses(builder, refs, |path_type| { @@ -101,7 +101,7 @@ pub(crate) fn inline_type_alias_uses(acc: &mut Assists, ctx: &AssistContext<'_>) inline_refs_for_file(file_id, refs); } if !definition_deleted { - let mut editor = builder.make_editor(ast_alias.syntax()); + let editor = builder.make_editor(ast_alias.syntax()); editor.delete(ast_alias.syntax()); builder.add_file_edits(ctx.vfs_file_id(), editor) } @@ -156,7 +156,7 @@ pub(crate) fn inline_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) -> O "Inline type alias", alias_instance.syntax().text_range(), |builder| { - let mut editor = builder.make_editor(alias_instance.syntax()); + let editor = builder.make_editor(alias_instance.syntax()); let replace = replacement.replace_generic(&concrete_type); editor.replace(alias_instance.syntax(), replace); builder.add_file_edits(ctx.vfs_file_id(), editor); @@ -312,8 +312,8 @@ fn create_replacement( const_and_type_map: &ConstAndTypeMap, concrete_type: 
&ast::Type, ) -> SyntaxNode { - let (mut editor, updated_concrete_type) = SyntaxEditor::new(concrete_type.syntax().clone()); - + let (editor, updated_concrete_type) = SyntaxEditor::new(concrete_type.syntax().clone()); + let make = editor.make(); let mut replacements: Vec<(SyntaxNode, SyntaxNode)> = Vec::new(); let mut removals: Vec> = Vec::new(); @@ -368,7 +368,6 @@ fn create_replacement( }; let new_string = replacement_syntax.to_string(); let new = if new_string == "_" { - let make = SyntaxFactory::without_mappings(); make.wildcard_pat().syntax().clone() } else { replacement_syntax.clone() diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs index 5e8ea7daff90d..2cbeae1d19c25 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_lifetime.rs @@ -1,7 +1,7 @@ use ide_db::{FileId, FxHashSet}; use syntax::{ AstNode, SmolStr, T, TextRange, ToSmolStr, - ast::{self, HasGenericParams, HasName, syntax_factory::SyntaxFactory}, + ast::{self, HasGenericParams, HasName}, format_smolstr, syntax_editor::{Element, Position, SyntaxEditor}, }; @@ -97,23 +97,23 @@ fn generate_fn_def_assist( }; acc.add(AssistId::refactor(ASSIST_NAME), ASSIST_LABEL, lifetime_loc, |edit| { - let mut editor = edit.make_editor(fn_def.syntax()); - let factory = SyntaxFactory::with_mappings(); + let editor = edit.make_editor(fn_def.syntax()); + let make = editor.make(); if let Some(generic_list) = fn_def.generic_param_list() { - insert_lifetime_param(&mut editor, &factory, &generic_list, &new_lifetime_name); + insert_lifetime_param(&editor, &generic_list, &new_lifetime_name); } else { - insert_new_generic_param_list_fn(&mut editor, &factory, &fn_def, &new_lifetime_name); + insert_new_generic_param_list_fn(&editor, &fn_def, &new_lifetime_name); } - 
editor.replace(lifetime.syntax(), factory.lifetime(&new_lifetime_name).syntax()); + editor.replace(lifetime.syntax(), make.lifetime(&new_lifetime_name).syntax()); if let Some(pos) = loc_needing_lifetime.and_then(|l| l.to_position()) { editor.insert_all( pos, vec![ - factory.lifetime(&new_lifetime_name).syntax().clone().into(), - factory.whitespace(" ").into(), + make.lifetime(&new_lifetime_name).syntax().clone().into(), + make.whitespace(" ").into(), ], ); } @@ -123,19 +123,19 @@ fn generate_fn_def_assist( } fn insert_new_generic_param_list_fn( - editor: &mut SyntaxEditor, - factory: &SyntaxFactory, + editor: &SyntaxEditor, fn_def: &ast::Fn, lifetime_name: &str, ) -> Option<()> { + let make = editor.make(); let name = fn_def.name()?; editor.insert_all( Position::after(name.syntax()), vec![ - factory.token(T![<]).syntax_element(), - factory.lifetime(lifetime_name).syntax().syntax_element(), - factory.token(T![>]).syntax_element(), + make.token(T![<]).syntax_element(), + make.lifetime(lifetime_name).syntax().syntax_element(), + make.token(T![>]).syntax_element(), ], ); @@ -166,35 +166,35 @@ fn generate_impl_def_assist( let new_lifetime_name = generate_unique_lifetime_param_name(impl_def.generic_param_list())?; acc.add(AssistId::refactor(ASSIST_NAME), ASSIST_LABEL, lifetime_loc, |edit| { - let mut editor = edit.make_editor(impl_def.syntax()); - let factory = SyntaxFactory::without_mappings(); + let editor = edit.make_editor(impl_def.syntax()); + let make = editor.make(); if let Some(generic_list) = impl_def.generic_param_list() { - insert_lifetime_param(&mut editor, &factory, &generic_list, &new_lifetime_name); + insert_lifetime_param(&editor, &generic_list, &new_lifetime_name); } else { - insert_new_generic_param_list_imp(&mut editor, &factory, &impl_def, &new_lifetime_name); + insert_new_generic_param_list_imp(&editor, &impl_def, &new_lifetime_name); } - editor.replace(lifetime.syntax(), factory.lifetime(&new_lifetime_name).syntax()); + 
editor.replace(lifetime.syntax(), make.lifetime(&new_lifetime_name).syntax()); edit.add_file_edits(file_id, editor); }) } fn insert_new_generic_param_list_imp( - editor: &mut SyntaxEditor, - factory: &SyntaxFactory, + editor: &SyntaxEditor, impl_: &ast::Impl, lifetime_name: &str, ) -> Option<()> { + let make = editor.make(); let impl_kw = impl_.impl_token()?; editor.insert_all( Position::after(impl_kw), vec![ - factory.token(T![<]).syntax_element(), - factory.lifetime(lifetime_name).syntax().syntax_element(), - factory.token(T![>]).syntax_element(), + make.token(T![<]).syntax_element(), + make.lifetime(lifetime_name).syntax().syntax_element(), + make.token(T![>]).syntax_element(), ], ); @@ -202,22 +202,22 @@ fn insert_new_generic_param_list_imp( } fn insert_lifetime_param( - editor: &mut SyntaxEditor, - factory: &SyntaxFactory, + editor: &SyntaxEditor, generic_list: &ast::GenericParamList, lifetime_name: &str, ) -> Option<()> { + let make = editor.make(); let r_angle = generic_list.r_angle_token()?; let needs_comma = generic_list.generic_params().next().is_some(); let mut elements = Vec::new(); if needs_comma { - elements.push(factory.token(T![,]).syntax_element()); - elements.push(factory.whitespace(" ").syntax_element()); + elements.push(make.token(T![,]).syntax_element()); + elements.push(make.whitespace(" ").syntax_element()); } - let lifetime = factory.lifetime(lifetime_name); + let lifetime = make.lifetime(lifetime_name); elements.push(lifetime.syntax().clone().into()); editor.insert_all(Position::before(r_angle), elements); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_type_parameter.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_type_parameter.rs index db51070a6430b..95f223420b0b8 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_type_parameter.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/introduce_named_type_parameter.rs @@ -1,6 +1,6 
@@ use ide_db::syntax_helpers::suggest_name; use itertools::Itertools; -use syntax::ast::{self, AstNode, HasGenericParams, HasName, syntax_factory::SyntaxFactory}; +use syntax::ast::{self, AstNode, HasGenericParams, HasName}; use crate::{AssistContext, AssistId, Assists}; @@ -24,14 +24,14 @@ pub(crate) fn introduce_named_type_parameter( let fn_ = param.syntax().ancestors().nth(2).and_then(ast::Fn::cast)?; let type_bound_list = impl_trait_type.type_bound_list()?; - let make = SyntaxFactory::with_mappings(); let target = fn_.syntax().text_range(); acc.add( AssistId::refactor_rewrite("introduce_named_type_parameter"), "Replace impl trait with type parameter", target, |builder| { - let mut editor = builder.make_editor(fn_.syntax()); + let editor = builder.make_editor(fn_.syntax()); + let make = editor.make(); let existing_names = match fn_.generic_param_list() { Some(generic_param_list) => generic_param_list @@ -58,7 +58,6 @@ pub(crate) fn introduce_named_type_parameter( editor.add_annotation(type_param.syntax(), builder.make_tabstop_before(cap)); } - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs index 42bc05811fd14..1dd0833fad03d 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/merge_imports.rs @@ -4,11 +4,7 @@ use ide_db::imports::{ merge_imports::{MergeBehavior, try_merge_imports, try_merge_trees}, }; use syntax::{ - AstNode, SyntaxElement, SyntaxNode, - algo::neighbor, - ast::{self, syntax_factory::SyntaxFactory}, - match_ast, - syntax_editor::Removable, + AstNode, SyntaxElement, SyntaxNode, algo::neighbor, ast, match_ast, syntax_editor::Removable, }; use crate::{ @@ -76,17 +72,16 @@ pub(crate) fn merge_imports(acc: &mut Assists, ctx: 
&AssistContext<'_>) -> Optio }; acc.add(AssistId::refactor_rewrite("merge_imports"), "Merge imports", target, |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(&parent_node); + let editor = builder.make_editor(&parent_node); for edit in edits { match edit { Remove(it) => { let node = it.as_ref(); if let Some(left) = node.left() { - left.remove(&mut editor); + left.remove(&editor); } else if let Some(right) = node.right() { - right.remove(&mut editor); + right.remove(&editor); } } Replace(old, new) => { @@ -94,7 +89,6 @@ pub(crate) fn merge_imports(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio } } } - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs index 79b8bd5d3d489..e044068ff7578 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_bounds.rs @@ -45,22 +45,21 @@ pub(crate) fn move_bounds_to_where_clause( "Move to where clause", target, |builder| { - let mut edit = builder.make_editor(&parent); - let make = SyntaxFactory::without_mappings(); + let editor = builder.make_editor(&parent); let new_preds: Vec = type_param_list .generic_params() - .filter_map(|param| build_predicate(param, &make)) + .filter_map(|param| build_predicate(param, editor.make())) .collect(); match_ast! 
{ match (&parent) { - ast::Fn(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), - ast::Trait(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), - ast::Impl(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), - ast::Enum(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), - ast::Struct(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), - ast::TypeAlias(it) => it.get_or_create_where_clause(&mut edit, &make, new_preds.into_iter()), + ast::Fn(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()), + ast::Trait(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()), + ast::Impl(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()), + ast::Enum(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()), + ast::Struct(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()), + ast::TypeAlias(it) => it.get_or_create_where_clause(&editor, new_preds.into_iter()), _ => return, } }; @@ -72,11 +71,11 @@ pub(crate) fn move_bounds_to_where_clause( ast::GenericParam::ConstParam(_) => continue, }; if let Some(tbl) = param.type_bound_list() { - tbl.remove(&mut edit); + tbl.remove(&editor); } } - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_const_to_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_const_to_impl.rs index b3e79e4663e9d..86bdf3f8b4eb7 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_const_to_impl.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_const_to_impl.rs @@ -138,7 +138,6 @@ pub(crate) fn move_const_to_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> let fixup = if last_const.is_none() { "\n" } else { "" }; let indent = 
IndentLevel::from_node(parent_fn.syntax()); - let const_ = const_.clone_for_update(); let const_ = const_.reset_indent(); let const_ = const_.indent(indent); builder.insert(insert_offset, format!("\n{indent}{const_}{fixup}")); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs index 80587372e5165..7309cc6d06a5c 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/move_guard.rs @@ -73,24 +73,24 @@ pub(crate) fn move_guard_to_arm_body(acc: &mut Assists, ctx: &AssistContext<'_>) "Move guard to arm body", target, |builder| { - let mut edit = builder.make_editor(match_arm.syntax()); + let editor = builder.make_editor(match_arm.syntax()); for element in space_before_delete { if element.kind() == WHITESPACE { - edit.delete(element); + editor.delete(element); } } for rest_arm in &rest_arms { - edit.delete(rest_arm.syntax()); + editor.delete(rest_arm.syntax()); } if let Some(element) = space_after_arrow && element.kind() == WHITESPACE { - edit.replace(element, make.whitespace(" ")); + editor.replace(element, make.whitespace(" ")); } - edit.delete(guard.syntax()); - edit.replace(arm_expr.syntax(), if_expr.syntax()); - builder.add_file_edits(ctx.vfs_file_id(), edit); + editor.delete(guard.syntax()); + editor.replace(arm_expr.syntax(), if_expr.syntax()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } @@ -156,7 +156,8 @@ pub(crate) fn move_arm_cond_to_match_guard( "Move condition to match guard", replace_node.text_range(), |builder| { - let make = SyntaxFactory::without_mappings(); + let editor = builder.make_editor(match_arm.syntax()); + let make = editor.make(); let mut replace_arms = vec![]; // Dedent if if_expr is in a BlockExpr @@ -227,14 +228,12 @@ pub(crate) fn move_arm_cond_to_match_guard( } } - let mut edit = builder.make_editor(match_arm.syntax()); - let newline 
= make.whitespace(&format!("\n{indent_level}")); let replace_arms = replace_arms.iter().map(|it| it.syntax().syntax_element()); let replace_arms = Itertools::intersperse(replace_arms, newline.syntax_element()); - edit.replace_with_many(match_arm.syntax(), replace_arms.collect()); + editor.replace_with_many(match_arm.syntax(), replace_arms.collect()); - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs index 483c90d1032e5..ed61d32eb657f 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/promote_local_to_const.rs @@ -3,7 +3,7 @@ use ide_db::{assists::AssistId, defs::Definition}; use stdx::to_upper_snake_case; use syntax::{ AstNode, - ast::{self, HasName, syntax_factory::SyntaxFactory}, + ast::{self, HasName}, }; use crate::{ @@ -68,8 +68,8 @@ pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_>) "Promote local to constant", let_stmt.syntax().text_range(), |edit| { - let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(let_stmt.syntax()); + let editor = edit.make_editor(let_stmt.syntax()); + let make = editor.make(); let name = to_upper_snake_case(&name.to_string()); let usages = Definition::Local(local).usages(&ctx.sema).all(); if let Some(usages) = usages.references.get(&ctx.file_id()) { @@ -97,7 +97,6 @@ pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_>) editor.replace(let_stmt.syntax(), item.syntax()); - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs index 74ed2e14fa239..082052c9d42dc 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs @@ -1,10 +1,5 @@ use either::Either; -use syntax::{ - AstNode, - algo::find_node_at_range, - ast::{self, syntax_factory::SyntaxFactory}, - syntax_editor::SyntaxEditor, -}; +use syntax::{AstNode, algo::find_node_at_range, ast, syntax_editor::SyntaxEditor}; use crate::{ AssistId, @@ -75,8 +70,7 @@ pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_>) -> } let target = tgt.syntax().text_range(); - let (mut editor, edit_tgt) = SyntaxEditor::new(tgt.syntax().clone()); - + let (editor, edit_tgt) = SyntaxEditor::new(tgt.syntax().clone()); let assignments: Vec<_> = collector .assignments .into_iter() @@ -110,13 +104,12 @@ pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_>) -> "Pull assignment up", target, move |edit| { - let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(tgt.syntax()); + let editor = edit.make_editor(tgt.syntax()); + let make = editor.make(); let assign_expr = make.expr_assignment(collector.common_lhs, new_tgt.clone()); let assign_stmt = make.expr_stmt(assign_expr.into()); editor.replace(tgt.syntax(), assign_stmt.syntax()); - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs index 8b9e6570e917b..d7885d50651c0 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_method_call.rs @@ -1,6 +1,6 @@ use hir::{AsAssocItem, AssocItem, AssocItemContainer, ItemInNs, ModuleDef, 
db::HirDatabase}; use ide_db::assists::AssistId; -use syntax::{AstNode, ast, ast::syntax_factory::SyntaxFactory}; +use syntax::{AstNode, ast}; use crate::{ assist_context::{AssistContext, Assists}, @@ -59,17 +59,8 @@ pub(crate) fn qualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) -> format!("Qualify `{ident}` method call"), range, |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(call.syntax()); - qualify_candidate.qualify( - |_| {}, - &mut editor, - &make, - &receiver_path, - item_in_ns, - current_edition, - ); - editor.add_mappings(make.finish_with_mappings()); + let editor = builder.make_editor(call.syntax()); + qualify_candidate.qualify(|_| {}, &editor, &receiver_path, item_in_ns, current_edition); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs index c059f758c47e2..e3dd77360cca3 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/qualify_path.rs @@ -9,11 +9,7 @@ use ide_db::{ }; use syntax::Edition; use syntax::ast::HasGenericArgs; -use syntax::{ - AstNode, ast, - ast::{HasArgList, syntax_factory::SyntaxFactory}, - syntax_editor::SyntaxEditor, -}; +use syntax::{AstNode, ast, ast::HasArgList, syntax_editor::SyntaxEditor}; use crate::{ AssistId, GroupLabel, @@ -102,17 +98,14 @@ pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option label(ctx.db(), candidate, &import, current_edition), range, |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(&syntax_under_caret); + let editor = builder.make_editor(&syntax_under_caret); qualify_candidate.qualify( |replace_with: String| builder.replace(range, replace_with), - &mut editor, - &make, + &editor, &import.import_path, 
import.item_to_import, current_edition, ); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); @@ -131,8 +124,7 @@ impl QualifyCandidate<'_> { pub(crate) fn qualify( &self, mut replacer: impl FnMut(String), - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, import: &hir::ModPath, item: hir::ItemInNs, edition: Edition, @@ -151,10 +143,10 @@ impl QualifyCandidate<'_> { replacer(format!("<{qualifier} as {import}>::{segment}")); } QualifyCandidate::TraitMethod(db, mcall_expr) => { - Self::qualify_trait_method(db, mcall_expr, editor, make, import, item); + Self::qualify_trait_method(db, mcall_expr, editor, import, item); } QualifyCandidate::ImplMethod(db, mcall_expr, hir_fn) => { - Self::qualify_fn_call(db, mcall_expr, editor, make, import, hir_fn); + Self::qualify_fn_call(db, mcall_expr, editor, import, hir_fn); } } } @@ -162,11 +154,11 @@ impl QualifyCandidate<'_> { fn qualify_fn_call( db: &RootDatabase, mcall_expr: &ast::MethodCallExpr, - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, import: ast::Path, hir_fn: &hir::Function, ) -> Option<()> { + let make = editor.make(); let receiver = mcall_expr.receiver()?; let method_name = mcall_expr.name_ref()?; let generics = @@ -193,15 +185,14 @@ impl QualifyCandidate<'_> { fn qualify_trait_method( db: &RootDatabase, mcall_expr: &ast::MethodCallExpr, - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, import: ast::Path, item: hir::ItemInNs, ) -> Option<()> { let trait_method_name = mcall_expr.name_ref()?; let trait_ = item_as_trait(db, item)?; let method = find_trait_method(db, trait_, &trait_method_name)?; - Self::qualify_fn_call(db, mcall_expr, editor, make, import, &method) + Self::qualify_fn_call(db, mcall_expr, editor, import, &method) } } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs index d6d99b8b6d9dc..8234a0374e777 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/raw_string.rs @@ -1,7 +1,7 @@ use ide_db::source_change::SourceChangeBuilder; use syntax::{ - AstToken, - ast::{self, IsString, make::tokens::literal}, + AstNode, AstToken, + ast::{self, IsString}, }; use crate::{ @@ -162,23 +162,18 @@ fn replace_literal( builder: &mut SourceChangeBuilder, ctx: &AssistContext<'_>, ) { - let token = token.syntax(); - let node = token.parent().expect("no parent token"); - let mut edit = builder.make_editor(&node); - let new_literal = literal(new); - - edit.replace(token, mut_token(new_literal)); - - builder.add_file_edits(ctx.vfs_file_id(), edit); -} - -fn mut_token(token: syntax::SyntaxToken) -> syntax::SyntaxToken { - let node = token.parent().expect("no parent token"); - node.clone_for_update() - .children_with_tokens() - .filter_map(|it| it.into_token()) - .find(|it| it.text_range() == token.text_range() && it.text() == token.text()) - .unwrap() + let old_token = token.syntax(); + let parent = old_token.parent().expect("no parent token"); + let editor = builder.make_editor(&parent); + let make = editor.make(); + let new_literal = make.expr_literal(new); + let new_token = new_literal + .syntax() + .first_child_or_token() + .and_then(|it| it.into_token()) + .expect("literal has no token child"); + editor.replace(old_token, new_token); + builder.add_file_edits(ctx.vfs_file_id(), editor); } #[cfg(test)] diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs index f4c354b8a21d8..778533be5a005 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_dbg.rs @@ -1,7 +1,7 @@ use itertools::Itertools; 
use syntax::{ Edition, NodeOrToken, SyntaxNode, SyntaxToken, T, - ast::{self, AstNode, make}, + ast::{self, AstNode, syntax_factory::SyntaxFactory}, match_ast, syntax_editor::{Position, SyntaxEditor}, }; @@ -24,6 +24,8 @@ use crate::{AssistContext, AssistId, Assists}; // } // ``` pub(crate) fn remove_dbg(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone()); + let make = editor.make(); let macro_calls = if ctx.has_empty_selection() { vec![ctx.find_node_at_offset::()?] } else { @@ -39,15 +41,16 @@ pub(crate) fn remove_dbg(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( .collect() }; - let replacements = - macro_calls.into_iter().filter_map(compute_dbg_replacement).collect::>(); + let replacements = macro_calls + .into_iter() + .filter_map(|macro_expr| compute_dbg_replacement(macro_expr, make)) + .collect::>(); let target = replacements .iter() .flat_map(|(node_or_token, _)| node_or_token.iter()) .map(|t| t.text_range()) .reduce(|acc, range| acc.cover(range))?; acc.add(AssistId::quick_fix("remove_dbg"), "Remove dbg!()", target, |builder| { - let mut editor = builder.make_editor(ctx.source_file().syntax()); for (range, expr) in replacements { if let Some(expr) = expr { editor.insert(Position::before(range[0].clone()), expr.syntax()); @@ -68,6 +71,7 @@ pub(crate) fn remove_dbg(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( /// Returns `Some(_, None)` when the macro call should just be removed. 
fn compute_dbg_replacement( macro_expr: ast::MacroExpr, + make: &SyntaxFactory, ) -> Option<(Vec>, Option)> { let macro_call = macro_expr.macro_call()?; let tt = macro_call.token_tree()?; @@ -110,7 +114,7 @@ fn compute_dbg_replacement( } (replace, None) }, - _ => (vec![macro_call.syntax().clone().into()], Some(make::ext::expr_unit())), + _ => (vec![macro_call.syntax().clone().into()], Some(make.expr_unit())), } } } @@ -162,14 +166,14 @@ fn compute_dbg_replacement( }, None => false, }; - let expr = replace_nested_dbgs(expr.clone()); - let expr = if wrap { make::expr_paren(expr).into() } else { expr }; + let expr = replace_nested_dbgs(expr.clone(), make); + let expr = if wrap { make.expr_paren(expr).into() } else { expr }; (vec![macro_call.syntax().clone().into()], Some(expr)) } // dbg!(expr0, expr1, ...) exprs => { - let exprs = exprs.iter().cloned().map(replace_nested_dbgs); - let expr = make::expr_tuple(exprs); + let exprs = exprs.iter().cloned().map(|expr| replace_nested_dbgs(expr, make)); + let expr = make.expr_tuple(exprs); (vec![macro_call.syntax().clone().into()], Some(expr.into())) } }) @@ -189,12 +193,12 @@ fn pure_expr(expr: &ast::Expr) -> bool { } } -fn replace_nested_dbgs(expanded: ast::Expr) -> ast::Expr { +fn replace_nested_dbgs(expanded: ast::Expr, make: &SyntaxFactory) -> ast::Expr { if let ast::Expr::MacroExpr(mac) = &expanded { // Special-case when `expanded` itself is `dbg!()` since we cannot replace the whole tree // with `ted`. It should be fairly rare as it means the user wrote `dbg!(dbg!(..))` but you // never know how code ends up being! 
- let replaced = if let Some((_, expr_opt)) = compute_dbg_replacement(mac.clone()) { + let replaced = if let Some((_, expr_opt)) = compute_dbg_replacement(mac.clone(), make) { match expr_opt { Some(expr) => expr, None => { @@ -209,13 +213,13 @@ fn replace_nested_dbgs(expanded: ast::Expr) -> ast::Expr { return replaced; } - let (mut editor, expanded) = SyntaxEditor::with_ast_node(&expanded); + let (editor, expanded) = SyntaxEditor::with_ast_node(&expanded); // We need to collect to avoid mutation during traversal. let macro_exprs: Vec<_> = expanded.syntax().descendants().filter_map(ast::MacroExpr::cast).collect(); for mac in macro_exprs { - let expr_opt = match compute_dbg_replacement(mac.clone()) { + let expr_opt = match compute_dbg_replacement(mac.clone(), make) { Some((_, expr)) => expr, None => continue, }; diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_else_branches.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_else_branches.rs index 6a02c37015d33..0c03856417a0a 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_else_branches.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_else_branches.rs @@ -55,7 +55,7 @@ pub(crate) fn remove_else_branches(acc: &mut Assists, ctx: &AssistContext<'_>) - "Remove `else` branches", target, |builder| { - let mut editor = builder.make_editor(&else_token.parent().unwrap()); + let editor = builder.make_editor(&else_token.parent().unwrap()); match else_token.prev_token() { Some(it) if it.kind() == SyntaxKind::WHITESPACE => editor.delete(it), _ => (), diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs index b07a361adf48e..2a6024339f605 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_mut.rs @@ -22,7 +22,7 @@ pub(crate) fn 
remove_mut(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<( let target = mut_token.text_range(); acc.add(AssistId::refactor("remove_mut"), "Remove `mut` keyword", target, |builder| { - let mut editor = builder.make_editor(&mut_token.parent().unwrap()); + let editor = builder.make_editor(&mut_token.parent().unwrap()); match mut_token.next_token() { Some(it) if it.kind() == SyntaxKind::WHITESPACE => editor.delete(it), _ => (), diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_parentheses.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_parentheses.rs index f07da489e23ae..af249c97b9c42 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_parentheses.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_parentheses.rs @@ -1,8 +1,4 @@ -use syntax::{ - AstNode, SyntaxKind, T, - ast::{self, syntax_factory::SyntaxFactory}, - syntax_editor::Position, -}; +use syntax::{AstNode, SyntaxKind, T, ast, syntax_editor::Position}; use crate::{AssistContext, AssistId, Assists}; @@ -44,7 +40,8 @@ pub(crate) fn remove_parentheses(acc: &mut Assists, ctx: &AssistContext<'_>) -> "Remove redundant parentheses", target, |builder| { - let mut editor = builder.make_editor(parens.syntax()); + let editor = builder.make_editor(parens.syntax()); + let make = editor.make(); let prev_token = parens.syntax().first_token().and_then(|it| it.prev_token()); let need_to_add_ws = match prev_token { Some(it) => { @@ -54,9 +51,7 @@ pub(crate) fn remove_parentheses(acc: &mut Assists, ctx: &AssistContext<'_>) -> None => false, }; if need_to_add_ws { - let make = SyntaxFactory::with_mappings(); editor.insert(Position::before(parens.syntax()), make.whitespace(" ")); - editor.add_mappings(make.finish_with_mappings()); } editor.replace(parens.syntax(), expr.syntax()); builder.add_file_edits(ctx.vfs_file_id(), editor); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs 
b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs index 8b824c7c7f497..b91d678c9371f 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/remove_unused_param.rs @@ -80,7 +80,7 @@ pub(crate) fn remove_unused_param(acc: &mut Assists, ctx: &AssistContext<'_>) -> "Remove unused parameter", param.syntax().text_range(), |builder| { - let mut editor = builder.make_editor(&parent); + let editor = builder.make_editor(&parent); let elements = elements_to_remove(param.syntax()); for element in elements { editor.delete(element); @@ -116,7 +116,7 @@ fn process_usages( else { continue; }; - let mut editor = builder.make_editor(&parent); + let editor = builder.make_editor(&parent); for element in element_range { editor.delete(element); } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs index 990677d372139..facbab8019b2a 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs @@ -71,14 +71,12 @@ pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti "Reorder record fields", target, |builder| { - let mut editor = builder.make_editor(&parent_node); + let editor = builder.make_editor(&parent_node); match fields { - Either::Left((sorted, field_list)) => { - replace(&mut editor, field_list.fields(), sorted) - } + Either::Left((sorted, field_list)) => replace(&editor, field_list.fields(), sorted), Either::Right((sorted, field_list)) => { - replace(&mut editor, field_list.fields(), sorted) + replace(&editor, field_list.fields(), sorted) } } @@ -88,7 +86,7 @@ pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti } fn replace( - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, fields: 
impl Iterator, sorted_fields: impl IntoIterator, ) { diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs index 0ad5ec9d44246..df5281895abdc 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_impl_items.rs @@ -99,7 +99,7 @@ pub(crate) fn reorder_impl_items(acc: &mut Assists, ctx: &AssistContext<'_>) -> "Sort items by trait definition", target, |builder| { - let mut editor = builder.make_editor(&parent_node); + let editor = builder.make_editor(&parent_node); assoc_items .into_iter() diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_arith_op.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_arith_op.rs index b686dc056667c..5ad5efac05c24 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_arith_op.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_arith_op.rs @@ -1,7 +1,7 @@ use ide_db::assists::{AssistId, GroupLabel}; use syntax::{ AstNode, - ast::{self, ArithOp, BinaryOp, syntax_factory::SyntaxFactory}, + ast::{self, ArithOp, BinaryOp}, }; use crate::assist_context::{AssistContext, Assists}; @@ -83,8 +83,8 @@ fn replace_arith(acc: &mut Assists, ctx: &AssistContext<'_>, kind: ArithKind) -> kind.label(), op_expr.text_range(), |builder| { - let mut edit = builder.make_editor(rhs.syntax()); - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(rhs.syntax()); + let make = editor.make(); let method_name = kind.method_name(op); let needs_parentheses = @@ -92,10 +92,8 @@ fn replace_arith(acc: &mut Assists, ctx: &AssistContext<'_>, kind: ArithKind) -> let receiver = if needs_parentheses { make.expr_paren(lhs).into() } else { lhs }; let arith_expr = make.expr_method_call(receiver, make.name_ref(&method_name), make.arg_list([rhs])); - 
edit.replace(op_expr, arith_expr.syntax()); - - edit.add_mappings(make.finish_with_mappings()); - builder.add_file_edits(ctx.vfs_file_id(), edit); + editor.replace(op_expr, arith_expr.syntax()); + builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs index 5e595218f6b1a..751cd42f6e8df 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use syntax::{ SyntaxKind::WHITESPACE, T, - ast::{self, AstNode, HasName, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, HasName}, syntax_editor::{Position, SyntaxEditor}, }; @@ -128,10 +128,12 @@ fn add_assist( let label = format!("Convert to manual `impl {replace_trait_path} for {annotated_name}`"); acc.add(AssistId::refactor("replace_derive_with_manual_impl"), label, target, |builder| { - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(attr.syntax()); + let make = editor.make(); let insert_after = Position::after(adt.syntax()); let impl_is_unsafe = trait_.map(|s| s.is_unsafe(ctx.db())).unwrap_or(false); let impl_def = impl_def_from_trait( + &editor, &ctx.sema, ctx.config, adt, @@ -140,9 +142,7 @@ fn add_assist( replace_trait_path, impl_is_unsafe, ); - - let mut editor = builder.make_editor(attr.syntax()); - update_attribute(&make, &mut editor, old_derives, old_tree, old_trait_path, attr); + update_attribute(&editor, old_derives, old_tree, old_trait_path, attr); let trait_path = make.ty_path(replace_trait_path.clone()).into(); @@ -152,7 +152,7 @@ fn add_assist( impl_def.assoc_item_list().and_then(|list| list.assoc_items().next()), ) } else { - (generate_trait_impl(&make, impl_is_unsafe, adt, trait_path), None) 
+ (generate_trait_impl(make, impl_is_unsafe, adt, trait_path), None) }; if let Some(cap) = ctx.config.snippet_cap { @@ -178,12 +178,12 @@ fn add_assist( insert_after, vec![make.whitespace("\n\n").into(), impl_def.syntax().clone().into()], ); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }) } fn impl_def_from_trait( + editor: &SyntaxEditor, sema: &hir::Semantics<'_, ide_db::RootDatabase>, config: &AssistConfig, adt: &ast::Adt, @@ -192,6 +192,7 @@ fn impl_def_from_trait( trait_path: &ast::Path, impl_is_unsafe: bool, ) -> Option { + let make = editor.make(); let trait_ = trait_?; let target_scope = sema.scope(annotated_name.syntax())?; @@ -208,12 +209,11 @@ fn impl_def_from_trait( if trait_items.is_empty() { return None; } - let make = SyntaxFactory::without_mappings(); let trait_ty: ast::Type = make.ty_path(trait_path.clone()).into(); - let impl_def = generate_trait_impl(&make, impl_is_unsafe, adt, trait_ty.clone()); + let impl_def = generate_trait_impl(make, impl_is_unsafe, adt, trait_ty.clone()); let assoc_items = add_trait_assoc_items_to_impl( - &make, + make, sema, config, &trait_items, @@ -223,10 +223,10 @@ fn impl_def_from_trait( ); let assoc_item_list = if let Some((first, other)) = assoc_items.split_first() { let first_item = if let ast::AssocItem::Fn(func) = first - && let Some(body) = gen_trait_fn_body(&make, func, trait_path, adt, None) + && let Some(body) = gen_trait_fn_body(make, func, trait_path, adt, None) && let Some(func_body) = func.body() { - let (mut editor, _) = SyntaxEditor::new(first.syntax().clone()); + let (editor, _) = SyntaxEditor::new(first.syntax().clone()); editor.replace(func_body.syntax(), body.syntax()); ast::AssocItem::cast(editor.finish().new_root().clone()) } else { @@ -239,17 +239,17 @@ fn impl_def_from_trait( make.assoc_item_list_empty() }; - Some(generate_trait_impl_with_item(&make, impl_is_unsafe, adt, trait_ty, assoc_item_list)) + Some(generate_trait_impl_with_item(make, 
impl_is_unsafe, adt, trait_ty, assoc_item_list)) } fn update_attribute( - make: &SyntaxFactory, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, old_derives: &[ast::Path], old_tree: &ast::TokenTree, old_trait_path: &ast::Path, attr: &ast::Attr, ) { + let make = editor.make(); let new_derives = old_derives .iter() .filter(|t| t.to_string() != old_trait_path.to_string()) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs index ada2fd9b217ad..0badad7d0cbe6 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_if_let_with_match.rs @@ -111,9 +111,10 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' format!("Replace if{let_} with match"), available_range, move |builder| { - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(if_expr.syntax()); + let make = editor.make(); let match_expr: ast::Expr = { - let else_arm = make_else_arm(ctx, &make, else_block, &cond_bodies); + let else_arm = make_else_arm(ctx, make, else_block, &cond_bodies); let make_match_arm = |(pat, guard, body): (_, Option, ast::BlockExpr)| { // Dedent from original position, then indent for match arm @@ -131,6 +132,11 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' }; let arms = cond_bodies.into_iter().map(make_match_arm).chain([else_arm]); let expr = scrutinee_to_be_expr.reset_indent(); + let expr = if match_scrutinee_needs_paren(&expr) { + make.expr_paren(expr).into() + } else { + expr + }; let match_expr = make.expr_match(expr, make.match_arm_list(arms)).indent(indent); match_expr.into() }; @@ -146,10 +152,7 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext<' } else { match_expr }; - - let mut editor = 
builder.make_editor(if_expr.syntax()); editor.replace(if_expr.syntax(), expr.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -267,7 +270,8 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext<' format!("Replace match with if{let_}"), match_expr.syntax().text_range(), move |builder| { - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(match_expr.syntax()); + let make = editor.make(); let make_block_expr = |expr: ast::Expr| { // Blocks with modifiers (unsafe, async, etc.) are parsed as BlockExpr, but are // formatted without enclosing braces. If we encounter such block exprs, @@ -292,7 +296,7 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext<' _ => make.expr_let(if_let_pat, scrutinee).into(), }; let condition = if let Some(guard) = guard { - let guard = wrap_paren(guard, &make, ast::prec::ExprPrecedence::LAnd); + let guard = wrap_paren(guard, make, ast::prec::ExprPrecedence::LAnd); make.expr_bin(condition, ast::BinaryOp::LogicOp(ast::LogicOp::And), guard).into() } else { condition @@ -309,9 +313,7 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext<' ) .indent(IndentLevel::from_node(match_expr.syntax())); - let mut editor = builder.make_editor(match_expr.syntax()); editor.replace(match_expr.syntax(), if_let_expr.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -402,27 +404,37 @@ fn let_and_guard(cond: &ast::Expr) -> (Option, Option) } else if let ast::Expr::BinExpr(bin_expr) = cond && let Some(ast::Expr::LetExpr(let_expr)) = and_bin_expr_left(bin_expr).lhs() { - let (mut edit, new_expr) = SyntaxEditor::with_ast_node(bin_expr); - + let (editor, new_expr) = SyntaxEditor::with_ast_node(bin_expr); let left_bin = and_bin_expr_left(&new_expr); if let Some(rhs) = left_bin.rhs() { - 
edit.replace(left_bin.syntax(), rhs.syntax()); + editor.replace(left_bin.syntax(), rhs.syntax()); } else { if let Some(next) = left_bin.syntax().next_sibling_or_token() && next.kind() == SyntaxKind::WHITESPACE { - edit.delete(next); + editor.delete(next); } - edit.delete(left_bin.syntax()); + editor.delete(left_bin.syntax()); } - let new_expr = edit.finish().new_root().clone(); + let new_expr = editor.finish().new_root().clone(); (Some(let_expr), ast::Expr::cast(new_expr)) } else { (None, Some(cond.clone())) } } +fn match_scrutinee_needs_paren(expr: &ast::Expr) -> bool { + let make = SyntaxFactory::without_mappings(); + let fake_scrutinee = make.expr_unit(); + let fake_match = make.expr_match(fake_scrutinee, make.match_arm_list(std::iter::empty())); + let Some(fake_expr) = fake_match.expr() else { + stdx::never!(); + return false; + }; + expr.needs_parens_in_place_of(fake_match.syntax(), fake_expr.syntax()) +} + fn and_bin_expr_left(expr: &ast::BinExpr) -> ast::BinExpr { if expr.op_kind() == Some(ast::BinaryOp::LogicOp(ast::LogicOp::And)) && let Some(ast::Expr::BinExpr(left)) = expr.lhs() @@ -451,6 +463,26 @@ fn main() { ) } + #[test] + fn test_if_with_match_paren_jump_scrutinee() { + check_assist( + replace_if_let_with_match, + r#" +fn f() { + if $0(return) {} +} +"#, + r#" +fn f() { + match (return) { + true => {} + false => (), + } +} +"#, + ) + } + #[test] fn test_if_with_match_no_else() { check_assist( diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs index 38d8c38ef23d6..802d5f72b9973 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_is_method_with_if_let_method.rs @@ -1,6 +1,9 @@ use either::Either; use ide_db::syntax_helpers::suggest_name; -use syntax::ast::{self, AstNode, 
HasArgList, prec::ExprPrecedence, syntax_factory::SyntaxFactory}; +use syntax::{ + ast::{self, AstNode, HasArgList, prec::ExprPrecedence, syntax_factory::SyntaxFactory}, + syntax_editor::SyntaxEditor, +}; use crate::{ AssistContext, AssistId, Assists, @@ -41,9 +44,9 @@ pub(crate) fn replace_is_method_with_if_let_method( let method_kind = token.text().strip_suffix("_and").unwrap_or(token.text()); match method_kind { "is_some" | "is_ok" => { + let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone()); + let make = editor.make(); let receiver = call_expr.receiver()?; - let make = SyntaxFactory::with_mappings(); - let mut name_generator = suggest_name::NameGenerator::new_from_scope_locals( ctx.sema.scope(has_cond.syntax()), ); @@ -52,7 +55,7 @@ pub(crate) fn replace_is_method_with_if_let_method( } else { name_generator.for_variable(&receiver, &ctx.sema) }; - let (pat, predicate) = method_predicate(&call_expr, &var_name, &make); + let (pat, predicate) = method_predicate(&call_expr, &var_name, make); let (assist_id, message, text) = if method_kind == "is_some" { ("replace_is_some_with_if_let_some", "Replace `is_some` with `let Some`", "Some") @@ -65,8 +68,7 @@ pub(crate) fn replace_is_method_with_if_let_method( message, call_expr.syntax().text_range(), |edit| { - let mut editor = edit.make_editor(call_expr.syntax()); - + let make = editor.make(); let pat = make.tuple_struct_pat(make.ident_path(text), [pat]).into(); let let_expr = make.expr_let(pat, receiver); @@ -81,14 +83,12 @@ pub(crate) fn replace_is_method_with_if_let_method( let new_expr = if let Some(predicate) = predicate { let op = ast::BinaryOp::LogicOp(ast::LogicOp::And); - let predicate = wrap_paren(predicate, &make, ExprPrecedence::LAnd); + let predicate = wrap_paren(predicate, make, ExprPrecedence::LAnd); make.expr_bin(let_expr.into(), op, predicate).into() } else { ast::Expr::from(let_expr) }; editor.replace(call_expr.syntax(), new_expr.syntax()); - - 
editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs index 6ff5f0bbd30cf..85e72130e0a6c 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_let_with_if_let.rs @@ -46,8 +46,8 @@ pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext<'_> "Replace let with if let", target, |builder| { - let mut editor = builder.make_editor(let_stmt.syntax()); - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(let_stmt.syntax()); + let make = editor.make(); let ty = ctx.sema.type_of_expr(&init); let pat = if let_stmt.let_else().is_some() { // Do not add the wrapper type that implements `Try`, @@ -79,7 +79,6 @@ pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext<'_> let if_stmt = make.expr_stmt(if_expr.into()); editor.replace(let_stmt.syntax(), if_stmt.syntax()); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -88,6 +87,11 @@ pub(crate) fn replace_let_with_if_let(acc: &mut Assists, ctx: &AssistContext<'_> fn let_expr_needs_paren(expr: &ast::Expr) -> bool { let make = SyntaxFactory::without_mappings(); let fake_expr_let = make.expr_let(make.tuple_pat(None).into(), make.expr_unit()); + let fake_if = make.expr_if(fake_expr_let.into(), make.expr_empty_block(), None); + let Some(ast::Expr::LetExpr(fake_expr_let)) = fake_if.condition() else { + stdx::never!(); + return false; + }; let Some(fake_expr) = fake_expr_let.expr() else { stdx::never!(); return false; @@ -182,6 +186,24 @@ fn main() { ) } + #[test] + fn replace_let_record_expr() { + check_assist( + replace_let_with_if_let, + r" +fn main() { + $0let x = Foo { x 
}; +} + ", + r" +fn main() { + if let x = (Foo { x }) { + } +} + ", + ) + } + #[test] fn replace_let_else() { check_assist( diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs index 018642a047232..17ef7727eca39 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs @@ -5,7 +5,6 @@ use ide_db::{ defs::Definition, search::{SearchScope, UsageSearchResult}, }; -use syntax::ast::syntax_factory::SyntaxFactory; use syntax::{ AstNode, ast::{self, HasGenericParams, HasName, HasTypeBounds, Name, NameLike, PathType}, @@ -72,8 +71,8 @@ pub(crate) fn replace_named_generic_with_impl( "Replace named generic with impl trait", target, |edit| { - let mut editor = edit.make_editor(type_param.syntax()); - let make = SyntaxFactory::without_mappings(); + let editor = edit.make_editor(type_param.syntax()); + let make = editor.make(); // remove trait from generic param list if let Some(generic_params) = fn_.generic_param_list() { diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs index fd090cc081fa5..eebe93f005f99 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/replace_qualified_name_with_use.rs @@ -73,8 +73,8 @@ pub(crate) fn replace_qualified_name_with_use( // Now that we've brought the name into scope, re-qualify all paths that could be // affected (that is, all paths inside the node we added the `use` to). 
let scope_node = scope.as_syntax_node(); - let mut editor = builder.make_editor(scope_node); - shorten_paths(&mut editor, scope_node, &original_path); + let editor = builder.make_editor(scope_node); + shorten_paths(&editor, scope_node, &original_path); builder.add_file_edits(ctx.vfs_file_id(), editor); let path = drop_generic_args(&original_path); let edition = ctx @@ -111,7 +111,7 @@ fn target_path(ctx: &AssistContext<'_>, mut original_path: ast::Path) -> Option< } fn drop_generic_args(path: &ast::Path) -> ast::Path { - let (mut editor, path) = SyntaxEditor::with_ast_node(path); + let (editor, path) = SyntaxEditor::with_ast_node(path); if let Some(segment) = path.segment() && let Some(generic_args) = segment.generic_arg_list() { @@ -122,7 +122,7 @@ fn drop_generic_args(path: &ast::Path) -> ast::Path { } /// Mutates `node` to shorten `path` in all descendants of `node`. -fn shorten_paths(editor: &mut SyntaxEditor, node: &SyntaxNode, path: &ast::Path) { +fn shorten_paths(editor: &SyntaxEditor, node: &SyntaxNode, path: &ast::Path) { for child in node.children() { match_ast! 
{ match child { @@ -141,7 +141,7 @@ fn shorten_paths(editor: &mut SyntaxEditor, node: &SyntaxNode, path: &ast::Path) } } -fn maybe_replace_path(editor: &mut SyntaxEditor, path: ast::Path, target: ast::Path) -> Option<()> { +fn maybe_replace_path(editor: &SyntaxEditor, path: ast::Path, target: ast::Path) -> Option<()> { if !path_eq_no_generics(path.clone(), target) { return None; } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs index e973e70345dc2..911fa9d14b591 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/sort_items.rs @@ -127,7 +127,7 @@ impl AddRewrite for Assists { target: &SyntaxNode, ) -> Option<()> { self.add(AssistId::refactor_rewrite("sort_items"), label, target.text_range(), |builder| { - let mut editor = builder.make_editor(target); + let editor = builder.make_editor(target); old.into_iter() .zip(new) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs index 15143575e7d84..4d375080f50e0 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/toggle_macro_delimiter.rs @@ -2,7 +2,7 @@ use ide_db::assists::AssistId; use syntax::{ AstNode, SyntaxKind, SyntaxToken, T, algo::{previous_non_trivia_token, skip_trivia_token}, - ast::{self, syntax_factory::SyntaxFactory}, + ast, }; use crate::{AssistContext, Assists}; @@ -73,8 +73,8 @@ pub(crate) fn toggle_macro_delimiter(acc: &mut Assists, ctx: &AssistContext<'_>) }, token_tree.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(token_tree.syntax()); + let editor = builder.make_editor(token_tree.syntax()); + let make = 
editor.make(); match token { MacroDelims::LPar | MacroDelims::RPar => { @@ -100,7 +100,6 @@ pub(crate) fn toggle_macro_delimiter(acc: &mut Assists, ctx: &AssistContext<'_>) } } } - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_imports.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_imports.rs index accb5c28d6ed3..ab6317ad446d6 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_imports.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_imports.rs @@ -1,7 +1,7 @@ use syntax::{ AstNode, SyntaxKind, - ast::{self, HasAttrs, HasVisibility, edit::IndentLevel, make, syntax_factory::SyntaxFactory}, - syntax_editor::{Element, Position, Removable}, + ast::{self, HasAttrs, HasVisibility, edit::IndentLevel, syntax_factory::SyntaxFactory}, + syntax_editor::{Element, Position, Removable, SyntaxEditor}, }; use crate::{ @@ -22,6 +22,8 @@ use crate::{ // use std::fmt::Display; // ``` pub(crate) fn unmerge_imports(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { + let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone()); + let make = editor.make(); let tree = ctx.find_node_at_offset::()?; let tree_list = tree.syntax().parent().and_then(ast::UseTreeList::cast)?; @@ -31,7 +33,7 @@ pub(crate) fn unmerge_imports(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt } let use_ = tree_list.syntax().ancestors().find_map(ast::Use::cast)?; - let path = resolve_full_path(&tree)?; + let path = resolve_full_path(&tree, make)?; // If possible, explain what is going to be done. 
let label = match tree.path().and_then(|path| path.first_segment()) { @@ -41,16 +43,15 @@ pub(crate) fn unmerge_imports(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt let target = tree.syntax().text_range(); acc.add(AssistId::refactor_rewrite("unmerge_imports"), label, target, |builder| { - let make = SyntaxFactory::with_mappings(); + let make = editor.make(); let new_use = make.use_( use_.attrs(), use_.visibility(), make.use_tree(path, tree.use_tree_list(), tree.rename(), tree.star_token().is_some()), ); - let mut editor = builder.make_editor(use_.syntax()); // Remove the use tree from the current use item - tree.remove(&mut editor); + tree.remove(&editor); // Insert a newline and indentation, followed by the new use item editor.insert_all( Position::after(use_.syntax()), @@ -60,12 +61,11 @@ pub(crate) fn unmerge_imports(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt new_use.syntax().syntax_element(), ], ); - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }) } -fn resolve_full_path(tree: &ast::UseTree) -> Option { +fn resolve_full_path(tree: &ast::UseTree, make: &SyntaxFactory) -> Option { let paths = tree .syntax() .ancestors() @@ -73,7 +73,7 @@ fn resolve_full_path(tree: &ast::UseTree) -> Option { .filter_map(ast::UseTree::cast) .filter_map(|t| t.path()); - let final_path = paths.reduce(|prev, next| make::path_concat(next, prev))?; + let final_path = paths.reduce(|prev, next| make.path_concat(next, prev))?; if final_path.segment().is_some_and(|it| it.self_token().is_some()) { final_path.qualifier() } else { diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_match_arm.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_match_arm.rs index c4c03d3e35f56..65300ccefdb91 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_match_arm.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unmerge_match_arm.rs @@ -1,6 +1,6 @@ use 
syntax::{ Direction, SyntaxKind, T, - ast::{self, AstNode, edit::IndentLevel, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, edit::IndentLevel}, syntax_editor::{Element, Position}, }; @@ -56,8 +56,8 @@ pub(crate) fn unmerge_match_arm(acc: &mut Assists, ctx: &AssistContext<'_>) -> O "Unmerge match arm", pipe_token.text_range(), |edit| { - let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(&new_parent); + let editor = edit.make_editor(&new_parent); + let make = editor.make(); // It is guaranteed that `pats_after` has at least one element let new_pat = if pats_after.len() == 1 { pats_after[0].clone() @@ -101,7 +101,6 @@ pub(crate) fn unmerge_match_arm(acc: &mut Assists, ctx: &AssistContext<'_>) -> O insert_after_old_arm.push(new_match_arm.syntax().clone().into()); editor.insert_all(Position::after(match_arm.syntax()), insert_after_old_arm); - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }, ) diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unqualify_method_call.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unqualify_method_call.rs index ef395791e2518..045a27295297e 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unqualify_method_call.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unqualify_method_call.rs @@ -1,5 +1,5 @@ use hir::AsAssocItem; -use syntax::ast::{self, AstNode, HasArgList, prec::ExprPrecedence, syntax_factory::SyntaxFactory}; +use syntax::ast::{self, AstNode, HasArgList, prec::ExprPrecedence}; use crate::{AssistContext, AssistId, Assists}; @@ -50,8 +50,8 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) "Unqualify method call", call.syntax().text_range(), |builder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = builder.make_editor(call.syntax()); + let editor = builder.make_editor(call.syntax()); + let make = editor.make(); let new_arg_list = 
make.arg_list(args.args().skip(1)); let receiver = if first_arg.precedence().needs_parentheses_in(ExprPrecedence::Postfix) { @@ -67,10 +67,9 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) && let Some(trait_) = fun.container_or_implemented_trait(ctx.db()) && !scope.can_use_trait_methods(trait_) { - add_import(qualifier, ctx, &make, &mut editor); + add_import(qualifier, ctx, &editor); } - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ) @@ -79,8 +78,7 @@ pub(crate) fn unqualify_method_call(acc: &mut Assists, ctx: &AssistContext<'_>) fn add_import( qualifier: ast::Path, ctx: &AssistContext<'_>, - make: &SyntaxFactory, - editor: &mut syntax::syntax_editor::SyntaxEditor, + editor: &syntax::syntax_editor::SyntaxEditor, ) { if let Some(path_segment) = qualifier.segment() { // for `` @@ -112,7 +110,6 @@ fn add_import( import, &ctx.config.insert_use, editor, - make, ); } } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs index 5593ca3eb88fd..77941bcfb2bdb 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_block.rs @@ -3,7 +3,6 @@ use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - make, }, match_ast, syntax_editor::{Element, Position, SyntaxEditor}, @@ -72,18 +71,19 @@ pub(crate) fn unwrap_block(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option let replacement = replacement.stmt_list()?; acc.add(AssistId::refactor_rewrite("unwrap_block"), "Unwrap block", target, |builder| { - let mut edit = builder.make_editor(block.syntax()); + let editor = builder.make_editor(block.syntax()); let replacement = replacement.dedent(from_indent).indent(into_indent); let container = prefer_container.unwrap_or(container); - edit.replace_with_many(&container, 
extract_statements(replacement)); - delete_else_before(container, &mut edit); + editor.replace_with_many(&container, extract_statements(replacement)); + delete_else_before(container, &editor); - builder.add_file_edits(ctx.vfs_file_id(), edit); + builder.add_file_edits(ctx.vfs_file_id(), editor); }) } -fn delete_else_before(container: SyntaxNode, edit: &mut SyntaxEditor) { +fn delete_else_before(container: SyntaxNode, editor: &SyntaxEditor) { + let make = editor.make(); let Some(else_token) = container .siblings_with_tokens(syntax::Direction::Prev) .skip(1) @@ -94,16 +94,16 @@ fn delete_else_before(container: SyntaxNode, edit: &mut SyntaxEditor) { }; itertools::chain(else_token.prev_token(), else_token.next_token()) .filter(|it| it.kind() == SyntaxKind::WHITESPACE) - .for_each(|it| edit.delete(it)); + .for_each(|it| editor.delete(it)); let indent = IndentLevel::from_node(&container); - let newline = make::tokens::whitespace(&format!("\n{indent}")); - edit.replace(else_token, newline); + let newline = make.whitespace(&format!("\n{indent}")); + editor.replace(else_token, newline); } fn wrap_let(assign: &ast::LetStmt, replacement: ast::BlockExpr) -> ast::BlockExpr { let try_wrap_assign = || { let initializer = assign.initializer()?.syntax().syntax_element(); - let (mut edit, replacement) = SyntaxEditor::with_ast_node(&replacement); + let (editor, replacement) = SyntaxEditor::with_ast_node(&replacement); let tail_expr = replacement.tail_expr()?; let before = assign.syntax().children_with_tokens().take_while(|it| *it != initializer).collect(); @@ -114,9 +114,9 @@ fn wrap_let(assign: &ast::LetStmt, replacement: ast::BlockExpr) -> ast::BlockExp .skip(1) .collect(); - edit.insert_all(Position::before(tail_expr.syntax()), before); - edit.insert_all(Position::after(tail_expr.syntax()), after); - ast::BlockExpr::cast(edit.finish().new_root().clone()) + editor.insert_all(Position::before(tail_expr.syntax()), before); + editor.insert_all(Position::after(tail_expr.syntax()), 
after); + ast::BlockExpr::cast(editor.finish().new_root().clone()) }; try_wrap_assign().unwrap_or(replacement) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_return_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_return_type.rs index eea6c85e8df0a..1fe9ea4eb8759 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_return_type.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_return_type.rs @@ -5,7 +5,7 @@ use ide_db::{ }; use syntax::{ AstNode, NodeOrToken, SyntaxKind, - ast::{self, HasArgList, HasGenericArgs, syntax_factory::SyntaxFactory}, + ast::{self, HasArgList, HasGenericArgs}, match_ast, }; @@ -66,8 +66,8 @@ pub(crate) fn unwrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> let happy_type = extract_wrapped_type(type_ref)?; acc.add(kind.assist_id(), kind.label(), type_ref.syntax().text_range(), |builder| { - let mut editor = builder.make_editor(&parent); - let make = SyntaxFactory::with_mappings(); + let editor = builder.make_editor(&parent); + let make = editor.make(); let mut exprs_to_unwrap = Vec::new(); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_unwrap, e); @@ -168,7 +168,6 @@ pub(crate) fn unwrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> editor.add_annotation(final_placeholder.syntax(), builder.make_tabstop_after(cap)); } - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }) } diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_type_to_generic_arg.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_type_to_generic_arg.rs index 7b5adc1858b4f..935ae18905449 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_type_to_generic_arg.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_type_to_generic_arg.rs @@ -46,7 +46,7 @@ pub(crate) fn unwrap_type_to_generic_arg(acc: &mut Assists, ctx: 
&AssistContext< format!("Unwrap type to type argument {generic_arg}"), path_type.syntax().text_range(), |builder| { - let mut editor = builder.make_editor(path_type.syntax()); + let editor = builder.make_editor(path_type.syntax()); editor.replace(path_type.syntax(), generic_arg.syntax()); builder.add_file_edits(ctx.vfs_file_id(), editor); diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs index 0f089c9b66eb0..ddc0af31c33f2 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type.rs @@ -77,9 +77,9 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op kind.label(), type_ref.syntax().text_range(), |builder| { - let mut editor = builder.make_editor(&parent); - let make = SyntaxFactory::with_mappings(); - let alias = wrapper_alias(ctx, &make, core_wrapper, type_ref, &ty, kind.symbol()); + let editor = builder.make_editor(&parent); + let make = editor.make(); + let alias = wrapper_alias(ctx, make, core_wrapper, type_ref, &ty, kind.symbol()); let (ast_new_return_ty, semantic_new_return_ty) = alias.unwrap_or_else(|| { let (ast_ty, ty_constructor) = match kind { WrapperKind::Option => { @@ -156,8 +156,6 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op ); } } - - editor.add_mappings(make.finish_with_mappings()); builder.add_file_edits(ctx.vfs_file_id(), editor); }, ); @@ -964,7 +962,7 @@ fn foo(num: i32) -> Option { check_assist_by_label( wrap_return_type, r#" -//- minicore: option +//- minicore: option, fn fn foo(the_field: u32) ->$0 u32 { let true_closure = || { return true; }; if the_field < 5 { @@ -998,7 +996,7 @@ fn foo(the_field: u32) -> Option { check_assist_by_label( wrap_return_type, r#" -//- minicore: option +//- minicore: option, fn fn foo(the_field: u32) -> u32$0 { let 
true_closure = || { return true; @@ -2035,7 +2033,7 @@ fn foo(num: i32) -> Result { check_assist_by_label( wrap_return_type, r#" -//- minicore: result +//- minicore: result, fn fn foo(the_field: u32) ->$0 u32 { let true_closure = || { return true; }; if the_field < 5 { @@ -2069,7 +2067,7 @@ fn foo(the_field: u32) -> Result { check_assist_by_label( wrap_return_type, r#" -//- minicore: result +//- minicore: result, fn fn foo(the_field: u32) -> u32$0 { let true_closure = || { return true; diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs index 3b8988db7aae7..635fab857d087 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs @@ -1,8 +1,7 @@ use ide_db::source_change::SourceChangeBuilder; -use itertools::Itertools; use syntax::{ NodeOrToken, SyntaxToken, T, TextRange, algo, - ast::{self, AstNode, edit::AstNodeEdit, make, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, edit::AstNodeEdit}, }; use crate::{AssistContext, AssistId, Assists}; @@ -151,7 +150,7 @@ pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) - if let [attr] = &attrs[..] 
&& let Some(ast::Meta::CfgAttrMeta(meta)) = attr.meta() { - unwrap_cfg_attr(acc, meta) + unwrap_cfg_attr(acc, ctx, meta) } else { wrap_cfg_attrs(acc, ctx, attrs) } @@ -192,8 +191,8 @@ fn wrap_derive( } } let handle_source_change = |edit: &mut SourceChangeBuilder| { - let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(attr.syntax()); + let editor = edit.make_editor(attr.syntax()); + let make = editor.make(); let new_derive = make.attr_outer( make.meta_token_tree(make.ident_path("derive"), make.token_tree(T!['('], new_derive)), ); @@ -221,8 +220,6 @@ fn wrap_derive( let tabstop = edit.make_placeholder_snippet(snippet_cap); editor.add_annotation(cfg_predicate.syntax(), tabstop); } - - editor.add_mappings(make.finish_with_mappings()); edit.add_file_edits(ctx.vfs_file_id(), editor); }; @@ -239,8 +236,8 @@ fn wrap_cfg_attrs(acc: &mut Assists, ctx: &AssistContext<'_>, attrs: Vec, attrs: Vec, attrs: Vec Option<()> { +fn unwrap_cfg_attr( + acc: &mut Assists, + ctx: &AssistContext<'_>, + meta: ast::CfgAttrMeta, +) -> Option<()> { let top_attr = ast::Meta::from(meta.clone()).parent_attr()?; let range = top_attr.syntax().text_range(); - let inner_attrs = meta - .metas() - .map(|meta| { - if top_attr.excl_token().is_some() { - make::attr_inner(meta) - } else { - make::attr_outer(meta) - } - }) - .collect::>(); - if inner_attrs.is_empty() { + let inner_metas: Vec = meta.metas().collect(); + if inner_metas.is_empty() { return None; } - let handle_source_change = |f: &mut SourceChangeBuilder| { - let inner_attrs = inner_attrs - .iter() - .map(|it| it.to_string()) - .join(&format!("\n{}", top_attr.indent_level())); - f.replace(range, inner_attrs); - }; + let is_inner = top_attr.excl_token().is_some(); + let indent = top_attr.indent_level(); acc.add( AssistId::refactor("wrap_unwrap_cfg_attr"), "Extract Inner Attributes from `cfg_attr`", range, - handle_source_change, + |builder: &mut SourceChangeBuilder| { + let editor = 
builder.make_editor(top_attr.syntax()); + let make = editor.make(); + let mut elements = vec![]; + for (i, meta) in inner_metas.into_iter().enumerate() { + if i > 0 { + elements.push(make.whitespace(&format!("\n{indent}")).into()); + } + let attr = if is_inner { make.attr_inner(meta) } else { make.attr_outer(meta) }; + elements.push(attr.syntax().clone().into()); + } + editor.replace_with_many(top_attr.syntax(), elements); + builder.add_file_edits(ctx.vfs_file_id(), editor); + }, ); Some(()) } + #[cfg(test)] mod tests { use crate::tests::check_assist; diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs index a499607c1f711..048f3d7ce8f3e 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/tests/generated.rs @@ -463,7 +463,7 @@ fn doctest_convert_closure_to_fn() { check_doc_test( "convert_closure_to_fn", r#####" -//- minicore: copy +//- minicore: copy, fn struct String; impl String { fn new() -> Self {} diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/utils.rs b/src/tools/rust-analyzer/crates/ide-assists/src/utils.rs index 896743342c1a0..bf1062d207bba 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/utils.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/utils.rs @@ -15,7 +15,6 @@ use ide_db::{ syntax_helpers::{node_ext::preorder_expr, prettify_macro_expansion}, }; use itertools::Itertools; -use stdx::format_to; use syntax::{ AstNode, AstToken, Direction, NodeOrToken, SourceFile, SyntaxKind::*, @@ -248,20 +247,20 @@ pub fn add_trait_assoc_items_to_impl( }) .filter_map(|item| match item { ast::AssocItem::Fn(fn_) if fn_.body().is_none() => { - let (mut fn_editor, fn_) = SyntaxEditor::with_ast_node(&fn_); + let (fn_editor, fn_) = SyntaxEditor::with_ast_node(&fn_); let fill_expr: ast::Expr = match config.expr_fill_default { ExprFillDefaultMode::Todo | 
ExprFillDefaultMode::Default => make.expr_todo(), ExprFillDefaultMode::Underscore => make.expr_underscore().into(), }; let new_body = make.block_expr(None::, Some(fill_expr)); - fn_.replace_or_insert_body(&mut fn_editor, new_body); + fn_.replace_or_insert_body(&fn_editor, new_body); let new_fn_ = fn_editor.finish().new_root().clone(); ast::AssocItem::cast(new_fn_) } ast::AssocItem::TypeAlias(type_alias) => { - let (mut type_alias_editor, type_alias) = SyntaxEditor::with_ast_node(&type_alias); + let (type_alias_editor, type_alias) = SyntaxEditor::with_ast_node(&type_alias); if let Some(type_bound_list) = type_alias.type_bound_list() { - type_bound_list.remove(&mut type_alias_editor); + type_bound_list.remove(&type_alias_editor); }; let type_alias = type_alias_editor.finish().new_root().clone(); ast::AssocItem::cast(type_alias) @@ -346,10 +345,10 @@ fn invert_special_case(make: &SyntaxFactory, expr: &ast::Expr) -> Option, - make: &SyntaxFactory, ) { + let make = editor.make(); let mut attrs = attrs.into_iter().peekable(); if attrs.peek().is_none() { return; @@ -357,12 +356,10 @@ pub(crate) fn insert_attributes( let elem = before.syntax_element(); let indent = IndentLevel::from_element(&elem); let whitespace = format!("\n{indent}"); - edit.insert_all( - syntax::syntax_editor::Position::before(elem), - attrs - .flat_map(|attr| [attr.syntax().clone().into(), make.whitespace(&whitespace).into()]) - .collect(), - ); + let elements: Vec = attrs + .flat_map(|attr| [attr.syntax().clone().into(), make.whitespace(&whitespace).into()]) + .collect(); + editor.insert_all(syntax::syntax_editor::Position::before(elem), elements); } pub(crate) fn next_prev() -> impl Iterator { @@ -532,102 +529,6 @@ fn has_any_fn(imp: &ast::Impl, names: &[String]) -> bool { false } -/// Find the end of the `impl` block for the given `ast::Impl`. 
-// -// FIXME: this partially overlaps with `find_struct_impl` -pub(crate) fn find_impl_block_end(impl_def: ast::Impl, buf: &mut String) -> Option { - buf.push('\n'); - let end = impl_def - .assoc_item_list() - .and_then(|it| it.r_curly_token())? - .prev_sibling_or_token()? - .text_range() - .end(); - Some(end) -} - -/// Generates the surrounding `impl Type { }` including type and lifetime -/// parameters. -// FIXME: migrate remaining uses to `generate_impl` -pub(crate) fn generate_impl_text(adt: &ast::Adt, code: &str) -> String { - generate_impl_text_inner(adt, None, true, code) -} - -fn generate_impl_text_inner( - adt: &ast::Adt, - trait_text: Option<&str>, - trait_is_transitive: bool, - code: &str, -) -> String { - // Ensure lifetime params are before type & const params - let generic_params = adt.generic_param_list().map(|generic_params| { - let lifetime_params = - generic_params.lifetime_params().map(ast::GenericParam::LifetimeParam); - let ty_or_const_params = generic_params.type_or_const_params().filter_map(|param| { - let param = match param { - ast::TypeOrConstParam::Type(param) => { - // remove defaults since they can't be specified in impls - let mut bounds = - param.type_bound_list().map_or_else(Vec::new, |it| it.bounds().collect()); - if let Some(trait_) = trait_text { - // Add the current trait to `bounds` if the trait is transitive, - // meaning `impl Trait for U` requires `T: Trait`. 
- if trait_is_transitive { - bounds.push(make::type_bound_text(trait_)); - } - }; - // `{ty_param}: {bounds}` - let param = make::type_param(param.name()?, make::type_bound_list(bounds)); - ast::GenericParam::TypeParam(param) - } - ast::TypeOrConstParam::Const(param) => { - // remove defaults since they can't be specified in impls - let param = make::const_param(param.name()?, param.ty()?); - ast::GenericParam::ConstParam(param) - } - }; - Some(param) - }); - - make::generic_param_list(itertools::chain(lifetime_params, ty_or_const_params)) - }); - - // FIXME: use syntax::make & mutable AST apis instead - // `trait_text` and `code` can't be opaque blobs of text - let mut buf = String::with_capacity(code.len()); - - // Copy any cfg attrs from the original adt - buf.push_str("\n\n"); - let cfg_attrs = adt.attrs().filter(|attr| matches!(attr.meta(), Some(ast::Meta::CfgMeta(_)))); - cfg_attrs.for_each(|attr| buf.push_str(&format!("{attr}\n"))); - - // `impl{generic_params} {trait_text} for {name}{generic_params.to_generic_args()}` - buf.push_str("impl"); - if let Some(generic_params) = &generic_params { - format_to!(buf, "{generic_params}"); - } - buf.push(' '); - if let Some(trait_text) = trait_text { - buf.push_str(trait_text); - buf.push_str(" for "); - } - buf.push_str(&adt.name().unwrap().text()); - if let Some(generic_params) = generic_params { - format_to!(buf, "{}", generic_params.to_generic_args()); - } - - match adt.where_clause() { - Some(where_clause) => { - format_to!(buf, "\n{where_clause}\n{{\n{code}\n}}"); - } - None => { - format_to!(buf, " {{\n{code}\n}}"); - } - } - - buf -} - /// Generates the corresponding `impl Type {}` including type and lifetime /// parameters. 
pub(crate) fn generate_impl_with_item( @@ -919,28 +820,6 @@ fn generic_param_associated_bounds_with_factory( trait_where_clause.peek().is_some().then(|| make.where_clause(trait_where_clause)) } -pub(crate) fn add_method_to_adt( - builder: &mut SourceChangeBuilder, - adt: &ast::Adt, - impl_def: Option, - method: &str, -) { - let mut buf = String::with_capacity(method.len() + 2); - if impl_def.is_some() { - buf.push('\n'); - } - buf.push_str(method); - - let start_offset = impl_def - .and_then(|impl_def| find_impl_block_end(impl_def, &mut buf)) - .unwrap_or_else(|| { - buf = generate_impl_text(adt, &buf); - adt.syntax().text_range().end() - }); - - builder.insert(start_offset, buf); -} - #[derive(Debug)] pub(crate) struct ReferenceConversion<'db> { conversion: ReferenceConversionType, diff --git a/src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs b/src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs index b0d88737fe0f7..c0ddcb950cbac 100644 --- a/src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs +++ b/src/tools/rust-analyzer/crates/ide-assists/src/utils/gen_trait_fn_body.rs @@ -95,7 +95,7 @@ fn gen_clone_impl(make: &SyntaxFactory, adt: &ast::Adt) -> Option Option MyStruct (fields..) 
=> f.debug_tuple("MyStruct")...finish(), - let pat = make.tuple_struct_pat(variant_name.clone(), pats.into_iter()); + let pat = make.tuple_struct_pat(variant_name.clone(), pats); arms.push(make.match_arm(pat.into(), None, expr)); } None => { diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/completions.rs b/src/tools/rust-analyzer/crates/ide-completion/src/completions.rs index 9a09e9bd4a20b..2ed582598b729 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/completions.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/completions.rs @@ -234,17 +234,13 @@ impl Completions { Visible::Editable => true, Visible::No => return, }; - self.add( - render_path_resolution( - RenderContext::new(ctx) - .private_editable(is_private_editable) - .doc_aliases(doc_aliases), - path_ctx, - local_name, - resolution, - ) - .build(ctx.db), - ); + render_path_resolution( + RenderContext::new(ctx).private_editable(is_private_editable).doc_aliases(doc_aliases), + path_ctx, + local_name, + resolution, + ) + .add_to(self, ctx.db); } pub(crate) fn add_pattern_resolution( @@ -259,15 +255,13 @@ impl Completions { Visible::Editable => true, Visible::No => return, }; - self.add( - render_pattern_resolution( - RenderContext::new(ctx).private_editable(is_private_editable), - pattern_ctx, - local_name, - resolution, - ) - .build(ctx.db), - ); + render_pattern_resolution( + RenderContext::new(ctx).private_editable(is_private_editable), + pattern_ctx, + local_name, + resolution, + ) + .add_to(self, ctx.db); } pub(crate) fn add_enum_variants( @@ -276,9 +270,6 @@ impl Completions { path_ctx: &PathCompletionCtx<'_>, e: hir::Enum, ) { - if !ctx.check_stability_and_hidden(e) { - return; - } e.variants(ctx.db) .into_iter() .for_each(|variant| self.add_enum_variant(ctx, path_ctx, variant, None)); @@ -313,15 +304,13 @@ impl Completions { Visible::Editable => true, Visible::No => return, }; - self.add( - render_macro( - 
RenderContext::new(ctx).private_editable(is_private_editable), - path_ctx, - local_name, - mac, - ) - .build(ctx.db), - ); + render_macro( + RenderContext::new(ctx).private_editable(is_private_editable), + path_ctx, + local_name, + mac, + ) + .add_to(self, ctx.db); } pub(crate) fn add_function( @@ -337,17 +326,13 @@ impl Completions { Visible::No => return, }; let doc_aliases = ctx.doc_aliases(&func); - self.add( - render_fn( - RenderContext::new(ctx) - .private_editable(is_private_editable) - .doc_aliases(doc_aliases), - path_ctx, - local_name, - func, - ) - .build(ctx.db), - ); + render_fn( + RenderContext::new(ctx).private_editable(is_private_editable).doc_aliases(doc_aliases), + path_ctx, + local_name, + func, + ) + .add_to(self, ctx.db); } pub(crate) fn add_method( @@ -364,18 +349,14 @@ impl Completions { Visible::No => return, }; let doc_aliases = ctx.doc_aliases(&func); - self.add( - render_method( - RenderContext::new(ctx) - .private_editable(is_private_editable) - .doc_aliases(doc_aliases), - dot_access, - receiver, - local_name, - func, - ) - .build(ctx.db), - ); + render_method( + RenderContext::new(ctx).private_editable(is_private_editable).doc_aliases(doc_aliases), + dot_access, + receiver, + local_name, + func, + ) + .add_to(self, ctx.db); } pub(crate) fn add_method_with_import( @@ -391,19 +372,17 @@ impl Completions { Visible::No => return, }; let doc_aliases = ctx.doc_aliases(&func); - self.add( - render_method( - RenderContext::new(ctx) - .private_editable(is_private_editable) - .doc_aliases(doc_aliases) - .import_to_add(Some(import)), - dot_access, - None, - None, - func, - ) - .build(ctx.db), - ); + render_method( + RenderContext::new(ctx) + .private_editable(is_private_editable) + .doc_aliases(doc_aliases) + .import_to_add(Some(import)), + dot_access, + None, + None, + func, + ) + .add_to(self, ctx.db); } pub(crate) fn add_const(&mut self, ctx: &CompletionContext<'_>, konst: hir::Const) { diff --git 
a/src/tools/rust-analyzer/crates/ide-completion/src/completions/expr.rs b/src/tools/rust-analyzer/crates/ide-completion/src/completions/expr.rs index 99ca55bdaf74a..c15c67173ea3d 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/completions/expr.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/completions/expr.rs @@ -317,7 +317,12 @@ pub(crate) fn complete_expr_path( } // synthetic names currently leak out as we lack synthetic hygiene, so filter them // out here - ScopeDef::Local(_) => { + ScopeDef::Local(_) => + { + #[expect( + clippy::collapsible_match, + reason = "this changes meaning, causing the next arm to be selected" + )] if !name.as_str().starts_with('<') { acc.add_path_resolution(ctx, path_ctx, name, def, doc_aliases) } diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/completions/type.rs b/src/tools/rust-analyzer/crates/ide-completion/src/completions/type.rs index e2125a9678234..20bbf0dd8bacc 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/completions/type.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/completions/type.rs @@ -162,29 +162,27 @@ pub(crate) fn complete_type_path( } TypeLocation::GenericArg { args: Some(arg_list), of_trait: Some(trait_), .. 
- } => { - if arg_list.syntax().ancestors().find_map(ast::TypeBound::cast).is_some() { - let arg_idx = arg_list - .generic_args() - .filter(|arg| { - arg.syntax().text_range().end() - < ctx.original_token.text_range().start() - }) - .count(); - - let n_required_params = trait_.type_or_const_param_count(ctx.sema.db, true); - if arg_idx >= n_required_params { - trait_.items_with_supertraits(ctx.sema.db).into_iter().for_each(|it| { - if let hir::AssocItem::TypeAlias(alias) = it { - cov_mark::hit!(complete_assoc_type_in_generics_list); - acc.add_type_alias_with_eq(ctx, alias); - } - }); - - let n_params = trait_.type_or_const_param_count(ctx.sema.db, false); - if arg_idx >= n_params { - return; // only show assoc types + } if arg_list.syntax().ancestors().find_map(ast::TypeBound::cast).is_some() => { + let arg_idx = arg_list + .generic_args() + .filter(|arg| { + arg.syntax().text_range().end() + < ctx.original_token.text_range().start() + }) + .count(); + + let n_required_params = trait_.type_or_const_param_count(ctx.sema.db, true); + if arg_idx >= n_required_params { + trait_.items_with_supertraits(ctx.sema.db).into_iter().for_each(|it| { + if let hir::AssocItem::TypeAlias(alias) = it { + cov_mark::hit!(complete_assoc_type_in_generics_list); + acc.add_type_alias_with_eq(ctx, alias); } + }); + + let n_params = trait_.type_or_const_param_count(ctx.sema.db, false); + if arg_idx >= n_params { + return; // only show assoc types } } } diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/context.rs b/src/tools/rust-analyzer/crates/ide-completion/src/context.rs index 485e5f0cafd7f..b9520e9132143 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/context.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/context.rs @@ -4,12 +4,12 @@ mod analysis; #[cfg(test)] mod tests; -use std::iter; +use std::{iter, sync::LazyLock}; use base_db::toolchain_channel; use hir::{ DisplayTarget, HasAttrs, InFile, Local, ModuleDef, ModuleSource, Name, PathResolution, 
- ScopeDef, Semantics, SemanticsScope, Symbol, Type, TypeInfo, + ScopeDef, Semantics, SemanticsScope, Symbol, Type, TypeInfo, sym, }; use ide_db::{ FilePosition, FxHashMap, FxHashSet, RootDatabase, famous_defs::FamousDefs, @@ -411,7 +411,7 @@ pub(crate) enum CompletionAnalysis<'db> { fake_attribute_under_caret: Option, extern_crate: Option, }, - /// Set if we are inside the predicate of a #[cfg] or #[cfg_attr]. + /// Set if we are inside the predicate of a `#[cfg]` or `#[cfg_attr]`. CfgPredicate, MacroSegment, } @@ -601,7 +601,18 @@ impl CompletionContext<'_> { let Some(attrs) = attrs else { return true; }; - !attrs.is_unstable() || self.is_nightly + if !attrs.is_unstable() { + return true; + } + if !self.is_nightly { + return false; + } + // Unstable on nightly, but we still don't want to suggest internal features, unless the feature flag is enabled. + let Some(unstable_feature) = attrs.unstable_feature(self.db) else { + return true; + }; + !INTERNAL_FEATURES.contains(&unstable_feature) + || self.krate.is_unstable_feature_enabled(self.db, &unstable_feature) } pub(crate) fn check_stability_and_hidden(&self, item: I) -> bool @@ -924,3 +935,40 @@ const OP_TRAIT_LANG: &[hir::LangItem] = &[ hir::LangItem::Shr, hir::LangItem::Sub, ]; + +// FIXME: Find a way to keep this up to date somehow? 
+const INTERNAL_FEATURES_LIST: &[Symbol] = &[ + sym::abi_unadjusted, + sym::allocator_internals, + sym::allow_internal_unsafe, + sym::allow_internal_unstable, + sym::cfg_emscripten_wasm_eh, + sym::cfg_target_has_reliable_f16_f128, + sym::compiler_builtins, + sym::custom_mir, + sym::eii_internals, + sym::field_representing_type_raw, + sym::intrinsics, + sym::lang_items, + sym::link_cfg, + sym::more_maybe_bounds, + sym::negative_bounds, + sym::pattern_complexity_limit, + sym::prelude_import, + sym::profiler_runtime, + sym::rustc_attrs, + sym::staged_api, + sym::test_unstable_lint, + sym::builtin_syntax, + sym::link_llvm_intrinsics, + sym::needs_panic_runtime, + sym::panic_runtime, + sym::pattern_types, + sym::rustdoc_internals, + sym::contracts_internals, + sym::freeze_impls, + sym::unsized_fn_params, +]; + +static INTERNAL_FEATURES: LazyLock> = + LazyLock::new(|| INTERNAL_FEATURES_LIST.iter().cloned().collect()); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/context/analysis.rs b/src/tools/rust-analyzer/crates/ide-completion/src/context/analysis.rs index 2a293313f2c92..58c0f683a344c 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/context/analysis.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/context/analysis.rs @@ -1183,18 +1183,16 @@ fn classify_name_ref<'db>( let arg_name = arg_name.text(); for item in trait_.items_with_supertraits(sema.db) { match item { - hir::AssocItem::TypeAlias(assoc_ty) => { - if assoc_ty.name(sema.db).as_str() == arg_name { + hir::AssocItem::TypeAlias(assoc_ty) + if assoc_ty.name(sema.db).as_str() == arg_name => { override_location = Some(TypeLocation::AssocTypeEq); return None; - } - }, - hir::AssocItem::Const(const_) => { - if const_.name(sema.db)?.as_str() == arg_name { + }, + hir::AssocItem::Const(const_) + if const_.name(sema.db)?.as_str() == arg_name => { override_location = Some(TypeLocation::AssocConstEq); return None; - } - }, + }, _ => (), } } @@ -1592,7 +1590,7 @@ fn 
classify_name_ref<'db>( kind_macro_call(it)? }, ast::Meta(meta) => make_path_kind_attr(meta)?, - ast::Visibility(it) => PathKind::Vis { has_in_token: it.in_token().is_some() }, + ast::VisibilityInner(it) => PathKind::Vis { has_in_token: it.in_token().is_some() }, ast::UseTree(_) => PathKind::Use, // completing inside a qualifier ast::Path(parent) => { @@ -1621,7 +1619,7 @@ fn classify_name_ref<'db>( kind_macro_call(it)? }, ast::Meta(meta) => make_path_kind_attr(meta)?, - ast::Visibility(it) => PathKind::Vis { has_in_token: it.in_token().is_some() }, + ast::VisibilityInner(it) => PathKind::Vis { has_in_token: it.in_token().is_some() }, ast::UseTree(_) => PathKind::Use, ast::RecordExpr(it) => make_path_kind_expr(it.into()), _ => return None, diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/item.rs b/src/tools/rust-analyzer/crates/ide-completion/src/item.rs index e6dd1d37d9335..6abf4f632aa8a 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/item.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/item.rs @@ -84,7 +84,15 @@ pub struct CompletionItem { pub ref_match: Option<(CompletionItemRefMode, TextSize)>, /// The import data to add to completion's edits. - pub import_to_add: SmallVec<[String; 1]>, + pub import_to_add: SmallVec<[CompletionItemImport; 1]>, +} + +#[derive(Clone, UpmapFromRaFixture)] +pub struct CompletionItemImport { + /// The path to import. + pub path: String, + /// Whether to import `as _`. + pub as_underscore: bool, } #[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] @@ -184,6 +192,8 @@ pub struct CompletionRelevance { pub function: Option, /// true when there is an `await.method()` or `iter().method()` completion. pub is_skipping_completion: bool, + /// if inherent impl already exists in current module, user may not want to implement it again. 
+ pub has_local_inherent_impl: bool, } #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct CompletionRelevanceTraitInfo { @@ -275,6 +285,7 @@ impl CompletionRelevance { trait_, function, is_skipping_completion, + has_local_inherent_impl, } = self; // only applicable for completions within use items @@ -347,6 +358,10 @@ impl CompletionRelevance { score += fn_score; }; + if has_local_inherent_impl { + score -= 5; + } + score } @@ -578,7 +593,18 @@ impl Builder { let import_to_add = self .imports_to_add .into_iter() - .map(|import| import.import_path.display(db, self.edition).to_string()) + .map(|import| { + let path = import.import_path.display(db, self.edition).to_string(); + let as_underscore = + if let hir::ItemInNs::Types(hir::ModuleDef::Trait(trait_to_import)) = + import.item_to_import + { + trait_to_import.prefer_underscore_import(db) + } else { + false + }; + CompletionItemImport { path, as_underscore } + }) .collect(); CompletionItem { diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/lib.rs b/src/tools/rust-analyzer/crates/ide-completion/src/lib.rs index 3867e65ae57e4..3df511a5ad0f3 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/lib.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/lib.rs @@ -36,8 +36,8 @@ use crate::{ pub use crate::{ config::{AutoImportExclusionType, CallableSnippets, CompletionConfig}, item::{ - CompletionItem, CompletionItemKind, CompletionItemRefMode, CompletionRelevance, - CompletionRelevancePostfixMatch, CompletionRelevanceReturnType, + CompletionItem, CompletionItemImport, CompletionItemKind, CompletionItemRefMode, + CompletionRelevance, CompletionRelevancePostfixMatch, CompletionRelevanceReturnType, CompletionRelevanceTypeMatch, }, snippet::{Snippet, SnippetScope}, @@ -280,7 +280,7 @@ pub fn resolve_completion_edits( db: &RootDatabase, config: &CompletionConfig<'_>, FilePosition { file_id, offset }: FilePosition, - imports: impl IntoIterator, + imports: impl IntoIterator, ) -> Option> { let 
_p = tracing::info_span!("resolve_completion_edits").entered(); let sema = hir::Semantics::new(db); @@ -299,12 +299,18 @@ pub fn resolve_completion_edits( let new_ast = scope.clone_for_update(); let mut import_insert = TextEdit::builder(); - imports.into_iter().for_each(|full_import_path| { - insert_use::insert_use( - &new_ast, - make::path_from_text_with_edition(&full_import_path, current_edition), - &config.insert_use, - ); + imports.into_iter().for_each(|import| { + let full_path = make::path_from_text_with_edition(&import.path, current_edition); + if import.as_underscore { + insert_use::insert_use_as_alias( + &new_ast, + full_path, + &config.insert_use, + current_edition, + ); + } else { + insert_use::insert_use(&new_ast, full_path, &config.insert_use); + } }); diff(scope.as_syntax_node(), new_ast.as_syntax_node()).into_text_edit(&mut import_insert); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render.rs b/src/tools/rust-analyzer/crates/ide-completion/src/render.rs index b6da6fba638fc..a636c0603ba52 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/render.rs @@ -10,7 +10,7 @@ pub(crate) mod type_alias; pub(crate) mod union_literal; pub(crate) mod variant; -use hir::{AsAssocItem, HasAttrs, HirDisplay, ModuleDef, ScopeDef, Type}; +use hir::{AsAssocItem, HasAttrs, HirDisplay, Impl, ModuleDef, ScopeDef, Type}; use ide_db::text_edit::TextEdit; use ide_db::{ RootDatabase, SnippetCap, SymbolKind, @@ -23,7 +23,9 @@ use syntax::{AstNode, SmolStr, SyntaxKind, TextRange, ToSmolStr, ast, format_smo use crate::{ CompletionContext, CompletionItem, CompletionItemKind, CompletionItemRefMode, CompletionRelevance, - context::{DotAccess, DotAccessKind, PathCompletionCtx, PathKind, PatternContext}, + context::{ + DotAccess, DotAccessKind, PathCompletionCtx, PathKind, PatternContext, TypeLocation, + }, item::{Builder, CompletionRelevanceTypeMatch}, render::{ function::render_fn, @@ 
-90,27 +92,32 @@ impl<'a> RenderContext<'a> { && self.completion.token.parent().is_some_and(|it| it.kind() == SyntaxKind::MACRO_CALL) } - fn is_deprecated(&self, def: impl HasAttrs) -> bool { - def.attrs(self.db()).is_deprecated() - } - - fn is_deprecated_assoc_item(&self, as_assoc_item: impl AsAssocItem) -> bool { + /// Whether `def` is deprecated. + /// + /// This can happen for two reasons: + /// - the def is marked with `#[deprecated]` + /// - the def is an assoc item whose trait is deprecated + /// + /// In order to be able to check for the latter, we'd ideally want to `try_as_dyn<_, dyn AsAssocItem>(def)` + /// (see [`try_as_dyn`][]), but that function is currently unstable. Therefore, we employ a hack instead: + /// if `def` can be an assoc item, it should be passed to this method as follows: + /// ```ignore + /// self.is_deprecated(def, Some(def)) + /// ``` + /// otherwise, it should be passed as: + /// ```ignore + /// self.is_deprecated(def, None) + /// ``` + /// + /// [`try_as_dyn`]: https://doc.rust-lang.org/std/any/fn.try_as_dyn.html + fn is_deprecated(&self, def: impl HasAttrs, def_as_assoc_item: Option) -> bool { let db = self.db(); - let assoc = match as_assoc_item.as_assoc_item(db) { - Some(assoc) => assoc, - None => return false, - }; - - let is_assoc_deprecated = match assoc { - hir::AssocItem::Function(it) => self.is_deprecated(it), - hir::AssocItem::Const(it) => self.is_deprecated(it), - hir::AssocItem::TypeAlias(it) => self.is_deprecated(it), - }; - is_assoc_deprecated - || assoc - .container_or_implemented_trait(db) - .map(|trait_| self.is_deprecated(trait_)) - .unwrap_or(false) + def.attrs(db).is_deprecated() + || def_as_assoc_item + .and_then(|assoc| assoc.container_or_implemented_trait(db)) + .is_some_and(|trait_| { + self.is_deprecated(trait_, None /* traits can't be assoc items */) + }) } // FIXME: remove this @@ -127,7 +134,7 @@ pub(crate) fn render_field( ty: &hir::Type<'_>, ) -> CompletionItem { let db = ctx.db(); - let is_deprecated = 
ctx.is_deprecated(field); + let is_deprecated = ctx.is_deprecated(field, None /* fields can't be assoc items */); let name = field.name(db); let (name, escaped_name) = (name.as_str().to_smolstr(), name.display_no_db(ctx.completion.edition).to_smolstr()); @@ -422,6 +429,7 @@ fn render_resolution_path( } let completion = ctx.completion; + let module = completion.module; let cap = ctx.snippet_cap(); let db = completion.db; let config = completion.config; @@ -466,6 +474,7 @@ fn render_resolution_path( exact_name_match: compute_exact_name_match(completion, &name), is_local: matches!(resolution, ScopeDef::Local(_)), requires_import, + has_local_inherent_impl: compute_has_local_inherent_impl(db, path_ctx, &ty, module), ..CompletionRelevance::default() }); @@ -572,10 +581,15 @@ fn scope_def_docs(db: &RootDatabase, resolution: ScopeDef) -> Option, resolution: ScopeDef) -> bool { + let db = ctx.db(); match resolution { - ScopeDef::ModuleDef(it) => ctx.is_deprecated_assoc_item(it), - ScopeDef::GenericParam(it) => ctx.is_deprecated(it), - ScopeDef::AdtSelfType(it) => ctx.is_deprecated(it), + ScopeDef::ModuleDef(it) => ctx.is_deprecated(it, it.as_assoc_item(db)), + ScopeDef::GenericParam(it) => { + ctx.is_deprecated(it, None /* generic params can't be assoc items */) + } + ScopeDef::AdtSelfType(it) => { + ctx.is_deprecated(it, None /* `Self` can't be an assoc item */) + } _ => false, } } @@ -660,6 +674,18 @@ fn compute_type_match( match_types(ctx, expected_type, completion_ty) } +fn compute_has_local_inherent_impl( + db: &RootDatabase, + path_ctx: &PathCompletionCtx<'_>, + completion_ty: &hir::Type<'_>, + curr_module: hir::Module, +) -> bool { + matches!(path_ctx.kind, PathKind::Type { location: TypeLocation::ImplTarget }) + && Impl::all_for_type(db, completion_ty.clone()) + .iter() + .any(|imp| imp.trait_(db).is_none() && imp.module(db) == curr_module) +} + fn compute_exact_name_match(ctx: &CompletionContext<'_>, completion_name: &str) -> bool { 
ctx.expected_name.as_ref().is_some_and(|name| name.text() == completion_name) } @@ -717,7 +743,7 @@ fn path_ref_match( // FIXME: This might create inconsistent completions where we show a ref match in macro inputs // as long as nothing was typed yet if let Some(ref_mode) = compute_ref_match(completion, ty) { - item.ref_match(ref_mode, completion.position.offset); + item.ref_match(ref_mode, completion.source_range().start()); } } } @@ -832,6 +858,7 @@ mod tests { ), (relevance.trait_.is_some_and(|it| it.is_op_method), "op_method"), (relevance.requires_import, "requires_import"), + (relevance.has_local_inherent_impl, "has_local_inherent_impl"), ] .into_iter() .filter_map(|(cond, desc)| if cond { Some(desc) } else { None }) @@ -1214,6 +1241,7 @@ fn main() { Foo::Fo$0 } }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, trigger_call_info: true, }, @@ -1264,6 +1292,7 @@ fn main() { Foo::Fo$0 } }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, trigger_call_info: true, }, @@ -1407,6 +1436,7 @@ fn main() { Foo::Fo$0 } }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, trigger_call_info: true, }, @@ -1490,6 +1520,7 @@ fn main() { let _: m::Spam = S$0 } }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, trigger_call_info: true, }, @@ -1526,6 +1557,7 @@ fn main() { let _: m::Spam = S$0 } }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, trigger_call_info: true, }, @@ -1539,6 +1571,32 @@ fn main() { let _: m::Spam = S$0 } check( r#" #[deprecated] +mod something_deprecated {} + +fn main() { som$0 } +"#, + SymbolKind::Module, + expect![[r#" + [ + CompletionItem { + label: "something_deprecated", + detail_left: None, + detail_right: None, + source_range: 55..58, + delete: 55..58, + insert: "something_deprecated", + kind: SymbolKind( + Module, + ), + deprecated: true, + }, + ] + "#]], + ); + + check( + r#" +#[deprecated] fn something_deprecated() {} fn main() { 
som$0 } @@ -1581,10 +1639,295 @@ fn main() { som$0 } "#]], ); + check( + r#" +#[deprecated] +struct A; + +fn main() { A$0 } +"#, + SymbolKind::Struct, + expect![[r#" + [ + CompletionItem { + label: "A", + detail_left: None, + detail_right: Some( + "A", + ), + source_range: 37..38, + delete: 37..38, + insert: "A", + kind: SymbolKind( + Struct, + ), + detail: "A", + deprecated: true, + }, + ] + "#]], + ); + + check( + r#" +#[deprecated] +enum A {} + +fn main() { A$0 } +"#, + SymbolKind::Enum, + expect![[r#" + [ + CompletionItem { + label: "A", + detail_left: None, + detail_right: Some( + "A", + ), + source_range: 37..38, + delete: 37..38, + insert: "A", + kind: SymbolKind( + Enum, + ), + detail: "A", + deprecated: true, + }, + ] + "#]], + ); + + check( + r#" +enum A { + Okay, + #[deprecated] + Old, +} + +fn main() { A::$0 } +"#, + SymbolKind::Variant, + expect![[r#" + [ + CompletionItem { + label: "Okay", + detail_left: None, + detail_right: Some( + "Okay", + ), + source_range: 64..64, + delete: 64..64, + insert: "Okay$0", + kind: SymbolKind( + Variant, + ), + detail: "Okay", + relevance: CompletionRelevance { + exact_name_match: false, + type_match: None, + is_local: false, + trait_: None, + is_name_already_imported: false, + requires_import: false, + is_private_editable: false, + postfix_match: None, + function: Some( + CompletionRelevanceFn { + has_params: false, + has_self_param: false, + return_type: DirectConstructor, + }, + ), + is_skipping_completion: false, + has_local_inherent_impl: false, + }, + trigger_call_info: true, + }, + CompletionItem { + label: "Old", + detail_left: None, + detail_right: Some( + "Old", + ), + source_range: 64..64, + delete: 64..64, + insert: "Old$0", + kind: SymbolKind( + Variant, + ), + detail: "Old", + deprecated: true, + relevance: CompletionRelevance { + exact_name_match: false, + type_match: None, + is_local: false, + trait_: None, + is_name_already_imported: false, + requires_import: false, + is_private_editable: false, + 
postfix_match: None, + function: Some( + CompletionRelevanceFn { + has_params: false, + has_self_param: false, + return_type: DirectConstructor, + }, + ), + is_skipping_completion: false, + has_local_inherent_impl: false, + }, + trigger_call_info: true, + }, + ] + "#]], + ); + + check( + r#" +#[deprecated] +const A: i32 = 0; + +fn main() { A$0 } +"#, + SymbolKind::Const, + expect![[r#" + [ + CompletionItem { + label: "A", + detail_left: None, + detail_right: Some( + "i32", + ), + source_range: 45..46, + delete: 45..46, + insert: "A", + kind: SymbolKind( + Const, + ), + detail: "i32", + deprecated: true, + }, + ] + "#]], + ); + + check( + r#" +#[deprecated] +static A: i32 = 0; + +fn main() { A$0 } +"#, + SymbolKind::Static, + expect![[r#" + [ + CompletionItem { + label: "A", + detail_left: None, + detail_right: Some( + "i32", + ), + source_range: 46..47, + delete: 46..47, + insert: "A", + kind: SymbolKind( + Static, + ), + detail: "i32", + deprecated: true, + }, + ] + "#]], + ); + + check( + r#" +#[deprecated] +trait A {} + +impl A$0 +"#, + SymbolKind::Trait, + expect![[r#" + [ + CompletionItem { + label: "A", + detail_left: None, + detail_right: None, + source_range: 31..32, + delete: 31..32, + insert: "A", + kind: SymbolKind( + Trait, + ), + deprecated: true, + }, + ] + "#]], + ); + + check( + r#" +#[deprecated] +type A = i32; + +fn main() { A$0 } +"#, + SymbolKind::TypeAlias, + expect![[r#" + [ + CompletionItem { + label: "A", + detail_left: None, + detail_right: None, + source_range: 41..42, + delete: 41..42, + insert: "A", + kind: SymbolKind( + TypeAlias, + ), + deprecated: true, + }, + ] + "#]], + ); + + check( + r#" +#[deprecated] +macro_rules! a { _ => {}} + +fn main() { a$0 } +"#, + SymbolKind::Macro, + expect![[r#" + [ + CompletionItem { + label: "a!(…)", + detail_left: None, + detail_right: Some( + "macro_rules! a", + ), + source_range: 53..54, + delete: 53..54, + insert: "a!($0)", + kind: SymbolKind( + Macro, + ), + lookup: "a!", + detail: "macro_rules! 
a", + deprecated: true, + }, + ] + "#]], + ); + check( r#" struct A { #[deprecated] the_field: u32 } -fn foo() { A { the$0 } } + +fn main() { A { the$0 } } "#, SymbolKind::Field, expect![[r#" @@ -1595,8 +1938,8 @@ fn foo() { A { the$0 } } detail_right: Some( "u32", ), - source_range: 57..60, - delete: 57..60, + source_range: 59..62, + delete: 59..62, insert: "the_field", kind: SymbolKind( Field, @@ -1616,6 +1959,7 @@ fn foo() { A { the$0 } } postfix_match: None, function: None, is_skipping_completion: false, + has_local_inherent_impl: false, }, }, ] @@ -1675,6 +2019,7 @@ impl S { }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, }, CompletionItem { @@ -1766,6 +2111,7 @@ use self::E::*; }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, trigger_call_info: true, }, @@ -1836,6 +2182,7 @@ fn foo(s: S) { s.$0 } }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, }, ] @@ -2048,6 +2395,7 @@ fn f() -> i32 { postfix_match: None, function: None, is_skipping_completion: false, + has_local_inherent_impl: false, }, }, ] @@ -2115,6 +2463,53 @@ fn go(world: &WorldSnapshot) { go(w$0) } ); } + #[test] + fn complete_ref_match_after_keyword_prefix() { + // About https://github.com/rust-lang/rust-analyzer/issues/15139 + check_kinds( + r#" +fn foo(data: &i32) {} +fn main() { + let indent = 2i32; + foo(in$0) +} +"#, + &[CompletionItemKind::SymbolKind(SymbolKind::Local)], + expect![[r#" + [ + CompletionItem { + label: "indent", + detail_left: None, + detail_right: Some( + "i32", + ), + source_range: 65..67, + delete: 65..67, + insert: "indent", + kind: SymbolKind( + Local, + ), + detail: "i32", + relevance: CompletionRelevance { + exact_name_match: false, + type_match: None, + is_local: true, + trait_: None, + is_name_already_imported: false, + requires_import: false, + is_private_editable: false, + postfix_match: None, + function: None, + is_skipping_completion: false, + has_local_inherent_impl: false, + }, + 
ref_match: "&@65", + }, + ] + "#]], + ); + } + #[test] fn too_many_arguments() { cov_mark::check!(too_many_arguments); @@ -2193,6 +2588,48 @@ fn f() { ); } + #[test] + fn score_has_local_inherent_impl() { + check_relevance( + r#" +trait Foob {} +struct Fooa {} +impl Fooa {} + +impl Foo$0 +"#, + expect![[r#" + tt Foob [] + st Fooa Fooa [has_local_inherent_impl] + "#]], + ); + + // inherent impl in different modules, not trigger `has_local_inherent_impl` + check_relevance( + r#" +trait Foob {} +struct Fooa {} + +mod a { + use super::*; + impl Fooa {} +} + +mod b { + use super::*; + impl Foo$0 +} + +"#, + expect![[r#" + st Fooa Fooa [] + tt Foob [] + md a [] + md b [] + "#]], + ); + } + #[test] fn test_avoid_redundant_suggestion() { check_relevance( @@ -2861,6 +3298,7 @@ fn foo(f: Foo) { let _: &u32 = f.b$0 } }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, ref_match: "&@107", }, @@ -2948,6 +3386,7 @@ fn foo() { postfix_match: None, function: None, is_skipping_completion: false, + has_local_inherent_impl: false, }, }, ] @@ -3006,6 +3445,7 @@ fn main() { }, ), is_skipping_completion: false, + has_local_inherent_impl: false, }, ref_match: "&@92", }, @@ -3476,6 +3916,7 @@ fn main() { postfix_match: None, function: None, is_skipping_completion: false, + has_local_inherent_impl: false, }, }, CompletionItem { @@ -3510,6 +3951,7 @@ fn main() { postfix_match: None, function: None, is_skipping_completion: false, + has_local_inherent_impl: false, }, }, ] diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render/const_.rs b/src/tools/rust-analyzer/crates/ide-completion/src/render/const_.rs index 707a8aed4fb9e..134a77a8991e3 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render/const_.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/render/const_.rs @@ -21,7 +21,7 @@ fn render(ctx: RenderContext<'_>, const_: hir::Const) -> Option let mut item = CompletionItem::new(SymbolKind::Const, ctx.source_range(), name, 
ctx.completion.edition); item.set_documentation(ctx.docs(const_)) - .set_deprecated(ctx.is_deprecated(const_) || ctx.is_deprecated_assoc_item(const_)) + .set_deprecated(ctx.is_deprecated(const_, const_.as_assoc_item(db))) .detail(detail) .set_relevance(ctx.completion_relevance()); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render/function.rs b/src/tools/rust-analyzer/crates/ide-completion/src/render/function.rs index dfa30841e7db1..18151cffcd391 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render/function.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/render/function.rs @@ -147,7 +147,7 @@ fn render( detail(ctx.completion, func) }; item.set_documentation(ctx.docs(func)) - .set_deprecated(ctx.is_deprecated(func) || ctx.is_deprecated_assoc_item(func)) + .set_deprecated(ctx.is_deprecated(func, func.as_assoc_item(db))) .detail(detail) .lookup_by(name.as_str().to_smolstr()); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render/literal.rs b/src/tools/rust-analyzer/crates/ide-completion/src/render/literal.rs index 6e49af980aeaa..b7de3da468dde 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render/literal.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/render/literal.rs @@ -189,8 +189,12 @@ impl Variant { fn is_deprecated(self, ctx: &RenderContext<'_>) -> bool { match self { - Variant::Struct(it) => ctx.is_deprecated(it), - Variant::EnumVariant(it) => ctx.is_deprecated(it), + Variant::Struct(it) => { + ctx.is_deprecated(it, None /* structs can't be assoc items */) + } + Variant::EnumVariant(it) => { + ctx.is_deprecated(it, None /* enum variants can't be assoc items */) + } } } diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render/macro_.rs b/src/tools/rust-analyzer/crates/ide-completion/src/render/macro_.rs index 8cdeb8abbff77..ff4cf9a75b60e 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render/macro_.rs +++ 
b/src/tools/rust-analyzer/crates/ide-completion/src/render/macro_.rs @@ -64,7 +64,7 @@ fn render( label(&ctx, needs_bang, bra, ket, &name.to_smolstr()), completion.edition, ); - item.set_deprecated(ctx.is_deprecated(macro_)) + item.set_deprecated(ctx.is_deprecated(macro_, None /* macros can't be assoc items */)) .detail(macro_.display(completion.db, completion.display_target).to_string()) .set_documentation(docs) .set_relevance(ctx.completion_relevance()); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render/pattern.rs b/src/tools/rust-analyzer/crates/ide-completion/src/render/pattern.rs index fb35d7b9b6714..022e97e4f7600 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render/pattern.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/render/pattern.rs @@ -126,7 +126,9 @@ fn build_completion( ctx.completion.edition, ); item.set_documentation(ctx.docs(def)) - .set_deprecated(ctx.is_deprecated(def)) + .set_deprecated( + ctx.is_deprecated(def, None /* the two current `def` arguments to this function, `Struct` and `EnumVariant`, both can't be assoc items */), + ) .detail(&pat) .lookup_by(lookup) .set_relevance(relevance); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render/type_alias.rs b/src/tools/rust-analyzer/crates/ide-completion/src/render/type_alias.rs index 3fc0f369e5ada..2b79ca2deb693 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render/type_alias.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/render/type_alias.rs @@ -47,7 +47,7 @@ fn render( ctx.completion.edition, ); item.set_documentation(ctx.docs(type_alias)) - .set_deprecated(ctx.is_deprecated(type_alias) || ctx.is_deprecated_assoc_item(type_alias)) + .set_deprecated(ctx.is_deprecated(type_alias, type_alias.as_assoc_item(db))) .detail(detail) .set_relevance(ctx.completion_relevance()); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/render/union_literal.rs 
b/src/tools/rust-analyzer/crates/ide-completion/src/render/union_literal.rs index 23f0d4e06f2c8..7164c94fde946 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/render/union_literal.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/render/union_literal.rs @@ -95,7 +95,7 @@ pub(crate) fn render_union_literal( ); item.set_documentation(ctx.docs(un)) - .set_deprecated(ctx.is_deprecated(un)) + .set_deprecated(ctx.is_deprecated(un, None /* unions can't be assoc items */)) .detail(detail) .set_relevance(ctx.completion_relevance()); diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/tests/expression.rs b/src/tools/rust-analyzer/crates/ide-completion/src/tests/expression.rs index 4a5983097a12c..294434297eccb 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/tests/expression.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/tests/expression.rs @@ -2273,7 +2273,7 @@ fn main() { $0 } //- /std.rs crate:std -#[unstable] +#[unstable(feature = "some_non_internal_feature")] pub struct UnstableButWeAreOnNightlyAnyway; "#, expect![[r#" @@ -2317,6 +2317,112 @@ pub struct UnstableButWeAreOnNightlyAnyway; ); } +#[test] +fn expr_unstable_item_internal_feature() { + check( + r#" +//- toolchain:nightly +//- /main.rs crate:main deps:std +use std::*; +fn main() { + $0 +} +//- /std.rs crate:std +#[unstable(feature = "intrinsics")] +pub mod intrinsics {} + "#, + expect![[r#" + fn main() fn() + md std + bt u32 u32 + kw async + kw const + kw crate:: + kw enum + kw extern + kw false + kw fn + kw for + kw if + kw if let + kw impl + kw impl for + kw let + kw letm + kw loop + kw match + kw mod + kw return + kw self:: + kw static + kw struct + kw trait + kw true + kw type + kw union + kw unsafe + kw use + kw while + kw while let + sn macro_rules + sn pd + sn ppd + "#]], + ); + check( + r#" +//- toolchain:nightly +//- /main.rs crate:main deps:std +#![feature(intrinsics)] +use std::*; +fn main() { + $0 +} +//- /std.rs crate:std 
+#[unstable(feature = "intrinsics")] +pub mod intrinsics {} + "#, + expect![[r#" + fn main() fn() + md intrinsics + md std + bt u32 u32 + kw async + kw const + kw crate:: + kw enum + kw extern + kw false + kw fn + kw for + kw if + kw if let + kw impl + kw impl for + kw let + kw letm + kw loop + kw match + kw mod + kw return + kw self:: + kw static + kw struct + kw trait + kw true + kw type + kw union + kw unsafe + kw use + kw while + kw while let + sn macro_rules + sn pd + sn ppd + "#]], + ); +} + #[test] fn inside_format_args_completions_work() { check( diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/tests/flyimport.rs b/src/tools/rust-analyzer/crates/ide-completion/src/tests/flyimport.rs index 5391e6c9ce6e5..60ae077d01426 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/tests/flyimport.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/tests/flyimport.rs @@ -2057,3 +2057,38 @@ fn main() { "#, ); } + +#[test] +fn prefer_underscore_import() { + check_edit( + "bar", + r#" +mod foo { + #[rust_analyzer::prefer_underscore_import] + pub trait Ext { + fn bar(&self) {} + } + impl Ext for T {} +} + +fn baz() { + 1.bar$0 +} + "#, + r#" +use foo::Ext as _; + +mod foo { + #[rust_analyzer::prefer_underscore_import] + pub trait Ext { + fn bar(&self) {} + } + impl Ext for T {} +} + +fn baz() { + 1.bar();$0 +} + "#, + ); +} diff --git a/src/tools/rust-analyzer/crates/ide-completion/src/tests/item.rs b/src/tools/rust-analyzer/crates/ide-completion/src/tests/item.rs index 2f032c3f4ca56..45024ad21638c 100644 --- a/src/tools/rust-analyzer/crates/ide-completion/src/tests/item.rs +++ b/src/tools/rust-analyzer/crates/ide-completion/src/tests/item.rs @@ -380,3 +380,23 @@ foo!(f$0); "#]], ); } + +#[test] +fn completes_variant_through_hidden_enum_alias() { + check( + r#" +//- /lib.rs crate:dep +#[doc(hidden)] +pub enum Foo { Variant } +pub type Bar = Foo; + +//- /main.rs crate:main deps:dep +fn main() { + let x = dep::Bar::V$0; +} +"#, + expect![[r#" + 
ev Variant Variant + "#]], + ); +} diff --git a/src/tools/rust-analyzer/crates/ide-db/Cargo.toml b/src/tools/rust-analyzer/crates/ide-db/Cargo.toml index fca06b69d1bb8..2c0919a183709 100644 --- a/src/tools/rust-analyzer/crates/ide-db/Cargo.toml +++ b/src/tools/rust-analyzer/crates/ide-db/Cargo.toml @@ -25,7 +25,6 @@ arrayvec.workspace = true memchr = "2.7.5" salsa.workspace = true salsa-macros.workspace = true -query-group.workspace = true triomphe.workspace = true nohash-hasher.workspace = true bitflags.workspace = true diff --git a/src/tools/rust-analyzer/crates/ide-db/src/generated/lints.rs b/src/tools/rust-analyzer/crates/ide-db/src/generated/lints.rs index c25feceb4157d..52a5a95450974 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/generated/lints.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/generated/lints.rs @@ -285,6 +285,13 @@ pub const DEFAULT_LINTS: &[Lint] = &[ warn_since: None, deny_since: None, }, + Lint { + label: "deprecated_llvm_intrinsic", + description: r##"detects uses of deprecated LLVM intrinsics"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, Lint { label: "deprecated_safe_2024", description: r##"detects unsafe functions being used as safe functions"##, @@ -425,6 +432,13 @@ pub const DEFAULT_LINTS: &[Lint] = &[ warn_since: None, deny_since: None, }, + Lint { + label: "float_literal_f32_fallback", + description: r##"detects unsuffixed floating point literals whose type fallback to `f32`"##, + default_severity: Severity::Warning, + warn_since: None, + deny_since: None, + }, Lint { label: "for_loops_over_fallibles", description: r##"for-looping over an `Option` or a `Result`, which is more clearly expressed as an `if let`"##, @@ -1345,7 +1359,7 @@ pub const DEFAULT_LINTS: &[Lint] = &[ Lint { label: "uninhabited_static", description: r##"uninhabited static"##, - default_severity: Severity::Warning, + default_severity: Severity::Error, warn_since: None, deny_since: None, }, @@ -1681,7 +1695,7 @@ 
pub const DEFAULT_LINTS: &[Lint] = &[ Lint { label: "varargs_without_pattern", description: r##"detects usage of `...` arguments without a pattern in non-foreign items"##, - default_severity: Severity::Warning, + default_severity: Severity::Error, warn_since: None, deny_since: None, }, @@ -1715,7 +1729,7 @@ pub const DEFAULT_LINTS: &[Lint] = &[ }, Lint { label: "future_incompatible", - description: r##"lint group for: internal-eq-trait-method-impls, aarch64-softfloat-neon, ambiguous-associated-items, ambiguous-derive-helpers, ambiguous-glob-imported-traits, ambiguous-glob-imports, ambiguous-import-visibilities, ambiguous-panic-imports, coherence-leak-check, conflicting-repr-hints, const-evaluatable-unchecked, elided-lifetimes-in-associated-constant, forbidden-lint-groups, ill-formed-attribute-input, invalid-macro-export-arguments, invalid-type-param-default, late-bound-lifetime-arguments, legacy-derive-helpers, macro-expanded-macro-exports-accessed-by-absolute-paths, out-of-scope-macro-calls, patterns-in-fns-without-body, proc-macro-derive-resolution-fallback, pub-use-of-private-extern-crate, repr-c-enums-larger-than-int, repr-transparent-non-zst-fields, self-constructor-from-outer-item, semicolon-in-expressions-from-macros, uncovered-param-in-projection, uninhabited-static, unstable-name-collisions, unstable-syntax-pre-expansion, unsupported-calling-conventions, varargs-without-pattern"##, + description: r##"lint group for: internal-eq-trait-method-impls, aarch64-softfloat-neon, ambiguous-associated-items, ambiguous-derive-helpers, ambiguous-glob-imported-traits, ambiguous-glob-imports, ambiguous-import-visibilities, ambiguous-panic-imports, coherence-leak-check, conflicting-repr-hints, const-evaluatable-unchecked, elided-lifetimes-in-associated-constant, float-literal-f32-fallback, forbidden-lint-groups, ill-formed-attribute-input, invalid-macro-export-arguments, invalid-type-param-default, late-bound-lifetime-arguments, legacy-derive-helpers, 
macro-expanded-macro-exports-accessed-by-absolute-paths, out-of-scope-macro-calls, patterns-in-fns-without-body, proc-macro-derive-resolution-fallback, pub-use-of-private-extern-crate, repr-c-enums-larger-than-int, repr-transparent-non-zst-fields, self-constructor-from-outer-item, semicolon-in-expressions-from-macros, uncovered-param-in-projection, uninhabited-static, unstable-name-collisions, unstable-syntax-pre-expansion, unsupported-calling-conventions, varargs-without-pattern"##, default_severity: Severity::Allow, warn_since: None, deny_since: None, @@ -1790,13 +1804,6 @@ pub const DEFAULT_LINTS: &[Lint] = &[ warn_since: None, deny_since: None, }, - Lint { - label: "warnings", - description: r##"lint group for: all lints that are set to issue warnings"##, - default_severity: Severity::Allow, - warn_since: None, - deny_since: None, - }, ]; pub const DEFAULT_LINT_GROUPS: &[LintGroup] = &[ @@ -1813,7 +1820,7 @@ pub const DEFAULT_LINT_GROUPS: &[LintGroup] = &[ LintGroup { lint: Lint { label: "future_incompatible", - description: r##"lint group for: internal-eq-trait-method-impls, aarch64-softfloat-neon, ambiguous-associated-items, ambiguous-derive-helpers, ambiguous-glob-imported-traits, ambiguous-glob-imports, ambiguous-import-visibilities, ambiguous-panic-imports, coherence-leak-check, conflicting-repr-hints, const-evaluatable-unchecked, elided-lifetimes-in-associated-constant, forbidden-lint-groups, ill-formed-attribute-input, invalid-macro-export-arguments, invalid-type-param-default, late-bound-lifetime-arguments, legacy-derive-helpers, macro-expanded-macro-exports-accessed-by-absolute-paths, out-of-scope-macro-calls, patterns-in-fns-without-body, proc-macro-derive-resolution-fallback, pub-use-of-private-extern-crate, repr-c-enums-larger-than-int, repr-transparent-non-zst-fields, self-constructor-from-outer-item, semicolon-in-expressions-from-macros, uncovered-param-in-projection, uninhabited-static, unstable-name-collisions, unstable-syntax-pre-expansion, 
unsupported-calling-conventions, varargs-without-pattern"##, + description: r##"lint group for: internal-eq-trait-method-impls, aarch64-softfloat-neon, ambiguous-associated-items, ambiguous-derive-helpers, ambiguous-glob-imported-traits, ambiguous-glob-imports, ambiguous-import-visibilities, ambiguous-panic-imports, coherence-leak-check, conflicting-repr-hints, const-evaluatable-unchecked, elided-lifetimes-in-associated-constant, float-literal-f32-fallback, forbidden-lint-groups, ill-formed-attribute-input, invalid-macro-export-arguments, invalid-type-param-default, late-bound-lifetime-arguments, legacy-derive-helpers, macro-expanded-macro-exports-accessed-by-absolute-paths, out-of-scope-macro-calls, patterns-in-fns-without-body, proc-macro-derive-resolution-fallback, pub-use-of-private-extern-crate, repr-c-enums-larger-than-int, repr-transparent-non-zst-fields, self-constructor-from-outer-item, semicolon-in-expressions-from-macros, uncovered-param-in-projection, uninhabited-static, unstable-name-collisions, unstable-syntax-pre-expansion, unsupported-calling-conventions, varargs-without-pattern"##, default_severity: Severity::Allow, warn_since: None, deny_since: None, @@ -1831,6 +1838,7 @@ pub const DEFAULT_LINT_GROUPS: &[LintGroup] = &[ "conflicting_repr_hints", "const_evaluatable_unchecked", "elided_lifetimes_in_associated_constant", + "float_literal_f32_fallback", "forbidden_lint_groups", "ill_formed_attribute_input", "invalid_macro_export_arguments", @@ -2485,6 +2493,22 @@ The tracking issue for this feature is: [#40180] [#40180]: https://github.com/rust-lang/rust/issues/40180 +------------------------ +"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "abort_immediate", + description: r##"# `abort_immediate` + + + +The tracking issue for this feature is: [#154601] + +[#154601]: https://github.com/rust-lang/rust/issues/154601 + ------------------------ "##, default_severity: Severity::Allow, @@ -4615,6 
+4639,47 @@ Allows checking whether or not the backend correctly supports unstable float typ This feature has no tracking issue, and is therefore likely internal to the compiler, not being intended for general use. ------------------------ +"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "cfg_target_object_format", + description: r##"# `cfg_target_object_format` + +The tracking issue for this feature is: [#152586] + +[#152586]: https://github.com/rust-lang/rust/issues/152586 + +------------------------ + +The `cfg_target_object_format` feature makes it possible to execute different code +depending on the current target's object file format. + +## Examples + +```rust +#![feature(cfg_target_object_format)] + +#[cfg(target_object_format = "elf")] +fn a() { + // ... +} + +#[cfg(target_object_format = "mach-o")] +fn a() { + // ... +} + +fn b() { + if cfg!(target_object_format = "wasm") { + // ... + } else { + // ... + } +} +``` "##, default_severity: Severity::Allow, warn_since: None, @@ -6927,6 +6992,22 @@ The tracking issue for this feature is: [#154181] [#154181]: https://github.com/rust-lang/rust/issues/154181 +------------------------ +"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "diagnostic_on_unknown", + description: r##"# `diagnostic_on_unknown` + +Allows giving unresolved imports a custom diagnostic message + +The tracking issue for this feature is: [#152900] + +[#152900]: https://github.com/rust-lang/rust/issues/152900 + ------------------------ "##, default_severity: Severity::Allow, @@ -7698,6 +7779,22 @@ The tracking issue for this feature is: [#116909] --- Enable the `f16` type for IEEE 16-bit floating numbers (half precision). 
+"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "f32_from_f16", + description: r##"# `f32_from_f16` + + + +The tracking issue for this feature is: [#154005] + +[#154005]: https://github.com/rust-lang/rust/issues/154005 + +------------------------ "##, default_severity: Severity::Allow, warn_since: None, @@ -8001,6 +8098,22 @@ The tracking issue for this feature is: [#91079] This feature is internal to the Rust compiler and is not intended for general use. +------------------------ +"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "fma4_target_feature", + description: r##"# `fma4_target_feature` + +fma4 target feature on x86. + +The tracking issue for this feature is: [#155233] + +[#155233]: https://github.com/rust-lang/rust/issues/155233 + ------------------------ "##, default_severity: Severity::Allow, @@ -8882,22 +8995,6 @@ The tracking issue for this feature is: [#134821] [#134821]: https://github.com/rust-lang/rust/issues/134821 ------------------------- -"##, - default_severity: Severity::Allow, - warn_since: None, - deny_since: None, - }, - Lint { - label: "int_lowest_highest_one", - description: r##"# `int_lowest_highest_one` - - - -The tracking issue for this feature is: [#145203] - -[#145203]: https://github.com/rust-lang/rust/issues/145203 - ------------------------ "##, default_severity: Severity::Allow, @@ -9245,22 +9342,6 @@ The tracking issue for this feature is: [#111192] [#111192]: https://github.com/rust-lang/rust/issues/111192 ------------------------- -"##, - default_severity: Severity::Allow, - warn_since: None, - deny_since: None, - }, - Lint { - label: "isolate_most_least_significant_one", - description: r##"# `isolate_most_least_significant_one` - - - -The tracking issue for this feature is: [#136909] - -[#136909]: https://github.com/rust-lang/rust/issues/136909 - ------------------------ "##, default_severity: Severity::Allow, @@ 
-12989,19 +13070,20 @@ only discuss a few of them. ------------------------ The `rustc_attrs` feature allows debugging rustc type layouts by using -`#[rustc_layout(...)]` to debug layout at compile time (it even works +`#[rustc_dump_layout(...)]` to debug layout at compile time (it even works with `cargo check`) as an alternative to `rustc -Z print-type-sizes` that is way more verbose. -Options provided by `#[rustc_layout(...)]` are `debug`, `size`, `align`, -`abi`. Note that it only works on sized types without generics. +Options provided by `#[rustc_dump_layout(...)]` are `backend_repr`, `align`, +`debug`, `homogeneous_aggregate` and `size`. +Note that it only works on sized types without generics. ## Examples ```rust,compile_fail #![feature(rustc_attrs)] -#[rustc_layout(abi, size)] +#[rustc_dump_layout(backend_repr, size)] pub enum X { Y(u8, u8, u8), Z(isize), @@ -13011,7 +13093,7 @@ pub enum X { When that is compiled, the compiler will error with something like ```text -error: abi: Aggregate { sized: true } +error: backend_repr: Aggregate { sized: true } --> src/lib.rs:4:1 | 4 | / pub enum T { @@ -14867,6 +14949,38 @@ The tracking issue for this feature is: [#109929] [#109929]: https://github.com/rust-lang/rust/issues/109929 +------------------------ +"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "transmute_neo", + description: r##"# `transmute_neo` + + + +The tracking issue for this feature is: [#155079] + +[#155079]: https://github.com/rust-lang/rust/issues/155079 + +------------------------ +"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "transmute_prefix", + description: r##"# `transmute_prefix` + + + +The tracking issue for this feature is: [#155079] + +[#155079]: https://github.com/rust-lang/rust/issues/155079 + ------------------------ "##, default_severity: Severity::Allow, @@ -15164,6 +15278,22 @@ The tracking issue for this feature is: 
[#63178] [#63178]: https://github.com/rust-lang/rust/issues/63178 +------------------------ +"##, + default_severity: Severity::Allow, + warn_since: None, + deny_since: None, + }, + Lint { + label: "try_from_int_error_kind", + description: r##"# `try_from_int_error_kind` + + + +The tracking issue for this feature is: [#153978] + +[#153978]: https://github.com/rust-lang/rust/issues/153978 + ------------------------ "##, default_severity: Severity::Allow, @@ -15526,22 +15656,6 @@ The tracking issue for this feature is: [#100499] [#100499]: https://github.com/rust-lang/rust/issues/100499 ------------------------- -"##, - default_severity: Severity::Allow, - warn_since: None, - deny_since: None, - }, - Lint { - label: "uint_bit_width", - description: r##"# `uint_bit_width` - - - -The tracking issue for this feature is: [#142326] - -[#142326]: https://github.com/rust-lang/rust/issues/142326 - ------------------------ "##, default_severity: Severity::Allow, diff --git a/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs b/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs index 9318c3e132725..fe30a4dc5cc92 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/imports/insert_use.rs @@ -9,7 +9,7 @@ use syntax::{ Direction, NodeOrToken, SyntaxKind, SyntaxNode, algo, ast::{ self, AstNode, HasAttrs, HasModuleItem, HasVisibility, PathSegmentKind, - edit_in_place::Removable, make, syntax_factory::SyntaxFactory, + edit_in_place::Removable, make, }, syntax_editor::{Position, SyntaxEditor}, ted, @@ -175,10 +175,9 @@ pub fn insert_use_with_editor( scope: &ImportScope, path: ast::Path, cfg: &InsertUseConfig, - syntax_editor: &mut SyntaxEditor, - syntax_factory: &SyntaxFactory, + syntax_editor: &SyntaxEditor, ) { - insert_use_with_alias_option_with_editor(scope, path, cfg, None, syntax_editor, syntax_factory); + insert_use_with_alias_option_with_editor(scope, path, cfg, None, 
syntax_editor); } pub fn insert_use_as_alias( @@ -269,9 +268,9 @@ fn insert_use_with_alias_option_with_editor( path: ast::Path, cfg: &InsertUseConfig, alias: Option, - syntax_editor: &mut SyntaxEditor, - syntax_factory: &SyntaxFactory, + syntax_editor: &SyntaxEditor, ) { + let make = syntax_editor.make(); let _p = tracing::info_span!("insert_use_with_alias_option").entered(); let mut mb = match cfg.granularity { ImportGranularity::Crate => Some(MergeBehavior::Crate), @@ -301,7 +300,7 @@ fn insert_use_with_alias_option_with_editor( }; } - let use_tree = syntax_factory.use_tree(path, None, alias, false); + let use_tree = make.use_tree(path, None, alias, false); if mb == Some(MergeBehavior::One) && use_tree.path().is_some() { use_tree.wrap_in_tree_list(); } @@ -324,7 +323,7 @@ fn insert_use_with_alias_option_with_editor( } // either we weren't allowed to merge or there is no import that fits the merge conditions // so look for the place we have to insert to - insert_use_with_editor_(scope, use_item, cfg.group, syntax_editor, syntax_factory); + insert_use_with_editor_(scope, use_item, cfg.group, syntax_editor); } pub fn ast_to_remove_for_path_in_use_stmt(path: &ast::Path) -> Option> { @@ -604,9 +603,9 @@ fn insert_use_with_editor_( scope: &ImportScope, use_item: ast::Use, group_imports: bool, - syntax_editor: &mut SyntaxEditor, - syntax_factory: &SyntaxFactory, + syntax_editor: &SyntaxEditor, ) { + let make = syntax_editor.make(); let scope_syntax = scope.as_syntax_node(); let insert_use_tree = use_item.use_tree().expect("`use_item` should have a use tree for `insert_path`"); @@ -656,7 +655,7 @@ fn insert_use_with_editor_( cov_mark::hit!(insert_group_new_group); syntax_editor.insert(Position::before(&node), use_item.syntax()); if let Some(node) = algo::non_trivia_sibling(node.into(), Direction::Prev) { - syntax_editor.insert(Position::after(node), syntax_factory.whitespace("\n")); + syntax_editor.insert(Position::after(node), make.whitespace("\n")); } return; } @@ 
-664,7 +663,7 @@ fn insert_use_with_editor_( if let Some(node) = last { cov_mark::hit!(insert_group_no_group); syntax_editor.insert(Position::after(&node), use_item.syntax()); - syntax_editor.insert(Position::after(node), syntax_factory.whitespace("\n")); + syntax_editor.insert(Position::after(node), make.whitespace("\n")); return; } } else { @@ -703,24 +702,18 @@ fn insert_use_with_editor_( { cov_mark::hit!(insert_empty_inner_attr); syntax_editor.insert(Position::after(&last_inner_element), use_item.syntax()); - syntax_editor.insert(Position::after(last_inner_element), syntax_factory.whitespace("\n")); + syntax_editor.insert(Position::after(last_inner_element), make.whitespace("\n")); } else { match l_curly { Some(b) => { cov_mark::hit!(insert_empty_module); - syntax_editor.insert(Position::after(&b), syntax_factory.whitespace("\n")); - syntax_editor.insert_with_whitespace( - Position::after(&b), - use_item.syntax(), - syntax_factory, - ); + syntax_editor.insert(Position::after(&b), make.whitespace("\n")); + syntax_editor.insert_with_whitespace(Position::after(&b), use_item.syntax()); } None => { cov_mark::hit!(insert_empty_file); - syntax_editor.insert( - Position::first_child_of(scope_syntax), - syntax_factory.whitespace("\n\n"), - ); + syntax_editor + .insert(Position::first_child_of(scope_syntax), make.whitespace("\n\n")); syntax_editor.insert(Position::first_child_of(scope_syntax), use_item.syntax()); } } diff --git a/src/tools/rust-analyzer/crates/ide-db/src/lib.rs b/src/tools/rust-analyzer/crates/ide-db/src/lib.rs index 8d16826e191da..6b72a3033990b 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/lib.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/lib.rs @@ -61,7 +61,7 @@ use std::{fmt, mem::ManuallyDrop}; use base_db::{ CrateGraphBuilder, CratesMap, FileSourceRootInput, FileText, Files, Nonce, SourceDatabase, - SourceRoot, SourceRootId, SourceRootInput, query_group, set_all_crates_with_durability, + SourceRoot, SourceRootId, SourceRootInput, 
set_all_crates_with_durability, }; use hir::{ FilePositionWrapper, FileRangeWrapper, @@ -252,15 +252,20 @@ impl RootDatabase { } } -#[query_group::query_group] -pub trait LineIndexDatabase: base_db::SourceDatabase { - #[salsa::invoke_interned(line_index)] - fn line_index(&self, file_id: FileId) -> Arc; -} - -fn line_index(db: &dyn LineIndexDatabase, file_id: FileId) -> Arc { - let text = db.file_text(file_id).text(db); - Arc::new(LineIndex::new(text)) +pub fn line_index(db: &dyn SourceDatabase, file_id: FileId) -> &Arc { + #[salsa::interned] + pub struct InternedFileId { + id: FileId, + } + #[salsa::tracked(returns(ref))] + fn line_index<'db>( + db: &'db dyn SourceDatabase, + file_id: InternedFileId<'db>, + ) -> Arc { + let text = db.file_text(file_id.id(db)).text(db); + Arc::new(LineIndex::new(text)) + } + line_index(db, InternedFileId::new(db, file_id)) } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -380,7 +385,7 @@ pub enum Severity { Allow, } -#[derive(Debug, Clone, Copy)] +#[derive(Clone, Copy)] pub struct MiniCore<'a>(&'a str); impl<'a> MiniCore<'a> { @@ -395,6 +400,15 @@ impl<'a> MiniCore<'a> { } } +impl std::fmt::Debug for MiniCore<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("MiniCore") + // don't print the whole contents if they correspond to the default + .field(if self.0 == test_utils::MiniCore::RAW_SOURCE { &"" } else { &self.0 }) + .finish() + } +} + impl<'a> Default for MiniCore<'a> { #[inline] fn default() -> Self { diff --git a/src/tools/rust-analyzer/crates/ide-db/src/path_transform.rs b/src/tools/rust-analyzer/crates/ide-db/src/path_transform.rs index 407276a2defc9..2d4a6b8b5b1ba 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/path_transform.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/path_transform.rs @@ -278,7 +278,7 @@ impl Ctx<'_> { // `transform_path` may update a node's parent and that would break the // tree traversal. 
Thus all paths in the tree are collected into a vec // so that such operation is safe. - let (mut editor, item) = SyntaxEditor::new(self.transform_path(item)); + let (editor, item) = SyntaxEditor::new(self.transform_path(item)); preorder_rev(&item).filter_map(ast::Lifetime::cast).for_each(|lifetime| { if let Some(subst) = self.lifetime_substs.get(&lifetime.syntax().text().to_string()) { editor.replace(lifetime.syntax(), subst.clone().syntax()); @@ -329,22 +329,22 @@ impl Ctx<'_> { result } - let (mut editor, root_path) = SyntaxEditor::new(path.clone()); + let (editor, root_path) = SyntaxEditor::new(path.clone()); let result = find_child_paths_and_ident_pats(&root_path); for sub_path in result { let new = self.transform_path(sub_path.syntax()); editor.replace(sub_path.syntax(), new); } - let (mut editor, update_sub_item) = SyntaxEditor::new(editor.finish().new_root().clone()); + let (editor, update_sub_item) = SyntaxEditor::new(editor.finish().new_root().clone()); let item = find_child_paths_and_ident_pats(&update_sub_item); for sub_path in item { - self.transform_path_or_ident_pat(&mut editor, &sub_path); + self.transform_path_or_ident_pat(&editor, &sub_path); } editor.finish().new_root().clone() } fn transform_path_or_ident_pat( &self, - editor: &mut SyntaxEditor, + editor: &SyntaxEditor, item: &Either, ) -> Option<()> { match item { @@ -353,7 +353,7 @@ impl Ctx<'_> { } } - fn transform_path_(&self, editor: &mut SyntaxEditor, path: &ast::Path) -> Option<()> { + fn transform_path_(&self, editor: &SyntaxEditor, path: &ast::Path) -> Option<()> { if path.qualifier().is_some() { return None; } @@ -448,7 +448,7 @@ impl Ctx<'_> { }; let found_path = self.target_module.find_path(self.source_scope.db, def, cfg)?; let res = mod_path_to_ast(&found_path, self.target_edition); - let (mut res_editor, res) = SyntaxEditor::with_ast_node(&res); + let (res_editor, res) = SyntaxEditor::with_ast_node(&res); if let Some(args) = path.segment().and_then(|it| it.generic_arg_list()) && 
let Some(segment) = res.segment() { @@ -522,11 +522,7 @@ impl Ctx<'_> { Some(()) } - fn transform_ident_pat( - &self, - editor: &mut SyntaxEditor, - ident_pat: &ast::IdentPat, - ) -> Option<()> { + fn transform_ident_pat(&self, editor: &SyntaxEditor, ident_pat: &ast::IdentPat) -> Option<()> { let name = ident_pat.name()?; let temp_path = make::path_from_text(&name.text()); diff --git a/src/tools/rust-analyzer/crates/ide-db/src/search.rs b/src/tools/rust-analyzer/crates/ide-db/src/search.rs index 69459a4b72dac..f41e29307007d 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/search.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/search.rs @@ -449,6 +449,8 @@ impl Definition { scope: None, include_self_kw_refs: None, search_self_mod: false, + included_categories: ReferenceCategory::all(), + exclude_library_files: false, } } } @@ -465,6 +467,10 @@ pub struct FindUsages<'a> { include_self_kw_refs: Option>, /// whether to search for the `self` module search_self_mod: bool, + /// categories to include while collecting usages + included_categories: ReferenceCategory, + /// whether to skip files from library source roots + exclude_library_files: bool, } impl<'a> FindUsages<'a> { @@ -495,6 +501,16 @@ impl<'a> FindUsages<'a> { self } + pub fn set_included_categories(mut self, categories: ReferenceCategory) -> Self { + self.included_categories = categories; + self + } + + pub fn set_exclude_library_files(mut self, exclude_library_files: bool) -> Self { + self.exclude_library_files = exclude_library_files; + self + } + pub fn at_least_one(&self) -> bool { let mut found = false; self.search(&mut |_, _| { @@ -516,14 +532,21 @@ impl<'a> FindUsages<'a> { fn scope_files<'b>( db: &'b RootDatabase, scope: &'b SearchScope, + exclude_library_files: bool, ) -> impl Iterator, EditionedFileId, TextRange)> + 'b { - scope.entries.iter().map(|(&file_id, &search_range)| { - let text = db.file_text(file_id.file_id(db)).text(db); - let search_range = - search_range.unwrap_or_else(|| 
TextRange::up_to(TextSize::of(&**text))); + scope + .entries + .iter() + .filter(move |(file_id, _)| { + !exclude_library_files || !is_library_file(db, file_id.file_id(db)) + }) + .map(|(&file_id, &search_range)| { + let text = db.file_text(file_id.file_id(db)).text(db); + let search_range = + search_range.unwrap_or_else(|| TextRange::up_to(TextSize::of(&**text))); - (text.clone(), file_id, search_range) - }) + (text.clone(), file_id, search_range) + }) } fn match_indices<'b>( @@ -649,6 +672,7 @@ impl<'a> FindUsages<'a> { fn collect_possible_aliases( sema: &Semantics<'_, RootDatabase>, container: Adt, + exclude_library_files: bool, ) -> Option<(FxHashSet, Vec>)> { fn insert_type_alias( db: &RootDatabase, @@ -682,9 +706,11 @@ impl<'a> FindUsages<'a> { }; let finder = Finder::new(current_to_process.as_bytes()); - for (file_text, file_id, search_range) in - FindUsages::scope_files(db, ¤t_to_process_search_scope) - { + for (file_text, file_id, search_range) in FindUsages::scope_files( + db, + ¤t_to_process_search_scope, + exclude_library_files, + ) { let tree = LazyCell::new(move || sema.parse(file_id).syntax().clone()); for offset in FindUsages::match_indices(&file_text, &finder, search_range) { @@ -869,7 +895,7 @@ impl<'a> FindUsages<'a> { } let Some((container_possible_aliases, is_possibly_self)) = - collect_possible_aliases(self.sema, container) + collect_possible_aliases(self.sema, container, self.exclude_library_files) else { return false; }; @@ -906,7 +932,7 @@ impl<'a> FindUsages<'a> { self, &finder, name, - FindUsages::scope_files(self.sema.db, search_scope), + FindUsages::scope_files(self.sema.db, search_scope, self.exclude_library_files), |path, name_position| { has_any_name(path, |name| container_possible_aliases.contains(name)) && !self_positions.contains(&name_position) @@ -931,6 +957,9 @@ impl<'a> FindUsages<'a> { Some(scope) => base.intersection(scope), } }; + if search_scope.entries.is_empty() { + return; + } let name = match (self.rename, self.def) { 
(Some(rename), _) => { @@ -982,7 +1011,9 @@ impl<'a> FindUsages<'a> { let finder = &Finder::new(name); let include_self_kw_refs = self.include_self_kw_refs.as_ref().map(|ty| (ty, Finder::new("Self"))); - for (text, file_id, search_range) in Self::scope_files(sema.db, &search_scope) { + for (text, file_id, search_range) in + Self::scope_files(sema.db, &search_scope, self.exclude_library_files) + { let tree = LazyCell::new(move || sema.parse(file_id).syntax().clone()); // Search for occurrences of the items name @@ -1039,7 +1070,9 @@ impl<'a> FindUsages<'a> { let is_crate_root = module.is_crate_root(self.sema.db).then(|| Finder::new("crate")); let finder = &Finder::new("super"); - for (text, file_id, search_range) in Self::scope_files(sema.db, &scope) { + for (text, file_id, search_range) in + Self::scope_files(sema.db, &scope, self.exclude_library_files) + { self.sema.db.unwind_if_revision_cancelled(); let tree = LazyCell::new(move || sema.parse(file_id).syntax().clone()); @@ -1118,6 +1151,10 @@ impl<'a> FindUsages<'a> { name_ref: &ast::NameRef, sink: &mut dyn FnMut(EditionedFileId, FileReference) -> bool, ) -> bool { + if self.is_excluded_name_ref(name_ref) { + return false; + } + // See https://github.com/rust-lang/rust-analyzer/pull/15864/files/e0276dc5ddc38c65240edb408522bb869f15afb4#r1389848845 let ty_eq = |ty: hir::Type<'_>| match (ty.as_adt(), self_ty.as_adt()) { (Some(ty), Some(self_ty)) => ty == self_ty, @@ -1146,6 +1183,10 @@ impl<'a> FindUsages<'a> { name_ref: &ast::NameRef, sink: &mut dyn FnMut(EditionedFileId, FileReference) -> bool, ) -> bool { + if self.is_excluded_name_ref(name_ref) { + return false; + } + match NameRefClass::classify(self.sema, name_ref) { Some(NameRefClass::Definition(def @ Definition::Module(_), _)) if def == self.def => { let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax()); @@ -1210,6 +1251,10 @@ impl<'a> FindUsages<'a> { name_ref: &ast::NameRef, sink: &mut dyn FnMut(EditionedFileId, FileReference) -> 
bool, ) -> bool { + if self.is_excluded_name_ref(name_ref) { + return false; + } + match NameRefClass::classify(self.sema, name_ref) { Some(NameRefClass::Definition(def, _)) if self.def == def @@ -1241,18 +1286,17 @@ impl<'a> FindUsages<'a> { }; sink(file_id, reference) } - Some(NameRefClass::Definition(def, _)) if self.include_self_kw_refs.is_some() => { - if self.include_self_kw_refs == def_to_ty(self.sema, &def) { - let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax()); - let reference = FileReference { - range, - name: FileReferenceNode::NameRef(name_ref.clone()), - category: ReferenceCategory::new(self.sema, &def, name_ref), - }; - sink(file_id, reference) - } else { - false - } + Some(NameRefClass::Definition(def, _)) + if self.include_self_kw_refs.is_some() + && self.include_self_kw_refs == def_to_ty(self.sema, &def) => + { + let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax()); + let reference = FileReference { + range, + name: FileReferenceNode::NameRef(name_ref.clone()), + category: ReferenceCategory::new(self.sema, &def, name_ref), + }; + sink(file_id, reference) } Some(NameRefClass::FieldShorthand { local_ref: local, @@ -1283,6 +1327,13 @@ impl<'a> FindUsages<'a> { } } + fn is_excluded_name_ref(&self, name_ref: &ast::NameRef) -> bool { + (!self.included_categories.contains(ReferenceCategory::TEST) + && is_name_ref_in_test(self.sema, name_ref)) + || (!self.included_categories.contains(ReferenceCategory::IMPORT) + && is_name_ref_in_import(name_ref)) + } + fn found_name( &self, name: &ast::Name, @@ -1409,3 +1460,8 @@ fn is_name_ref_in_test(sema: &Semantics<'_, RootDatabase>, name_ref: &ast::NameR None => false, }) } + +fn is_library_file(db: &RootDatabase, file_id: span::FileId) -> bool { + let source_root = db.file_source_root(file_id).source_root_id(db); + db.source_root(source_root).source_root(db).is_library +} diff --git a/src/tools/rust-analyzer/crates/ide-db/src/source_change.rs 
b/src/tools/rust-analyzer/crates/ide-db/src/source_change.rs index 4a83f707fcacd..81b679ead233c 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/source_change.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/source_change.rs @@ -285,11 +285,11 @@ impl SourceChangeBuilder { SyntaxEditor::new(node.ancestors().last().unwrap_or_else(|| node.clone())).0 } - pub fn add_file_edits(&mut self, file_id: impl Into, edit: SyntaxEditor) { + pub fn add_file_edits(&mut self, file_id: impl Into, editor: SyntaxEditor) { match self.file_editors.entry(file_id.into()) { - Entry::Occupied(mut entry) => entry.get_mut().merge(edit), + Entry::Occupied(mut entry) => entry.get_mut().merge(editor), Entry::Vacant(entry) => { - entry.insert(edit); + entry.insert(editor); } } } diff --git a/src/tools/rust-analyzer/crates/ide-db/src/symbol_index.rs b/src/tools/rust-analyzer/crates/ide-db/src/symbol_index.rs index 2ad3a51c3d9a3..55acf5abf868d 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/symbol_index.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/symbol_index.rs @@ -400,22 +400,16 @@ impl<'db> SymbolIndex<'db> { /// The symbol index for a given module. These modules should only be in source roots that /// are inside local_roots. pub fn module_symbols(db: &dyn HirDatabase, module: Module) -> &SymbolIndex<'_> { - // FIXME: - #[salsa::interned] - struct InternedModuleId { - id: hir::ModuleId, - } - #[salsa::tracked(returns(ref))] fn module_symbols<'db>( db: &'db dyn HirDatabase, - module: InternedModuleId<'db>, + module: hir::ModuleId, ) -> SymbolIndex<'db> { let _p = tracing::info_span!("module_symbols").entered(); // We call this without attaching because this runs in parallel, so we need to attach here. 
hir::attach_db(db, || { - let module: Module = module.id(db).into(); + let module: Module = module.into(); SymbolIndex::new(SymbolCollector::new_module( db, module, @@ -424,7 +418,7 @@ impl<'db> SymbolIndex<'db> { }) } - module_symbols(db, InternedModuleId::new(db, hir::ModuleId::from(module))) + module_symbols(db, hir::ModuleId::from(module)) } /// The symbol index for all extern prelude crates. diff --git a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string_exprs.rs b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string_exprs.rs index 8f25833fffb8d..6cc3334196d22 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string_exprs.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/format_string_exprs.rs @@ -125,7 +125,7 @@ pub fn parse_format_exprs(input: &str) -> Result<(String, Vec), ()> { // if the expression consists of a single number, like "0" or "12", it can refer to // format args in the order they are specified. 
// see: https://doc.rust-lang.org/std/fmt/#positional-parameters - if trimmed.chars().fold(true, |only_num, c| c.is_ascii_digit() && only_num) { + if trimmed.chars().all(|c| c.is_ascii_digit()) { output.push_str(trimmed); } else if matches!(state, State::Expr) { extracted_expressions.push(Arg::Expr(trimmed.into())); diff --git a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs index e30b21c139fad..11ba815dab2ec 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/node_ext.rs @@ -216,7 +216,12 @@ pub fn walk_ty(ty: &ast::Type, cb: &mut dyn FnMut(ast::Type) -> bool) { preorder.skip_subtree(); cb(ty); } - Some(ty) => { + Some(ty) => + { + #[expect( + clippy::collapsible_match, + reason = "it won't compile due to exhaustiveness" + )] if cb(ty) { preorder.skip_subtree(); } diff --git a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/suggest_name.rs b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/suggest_name.rs index 3a785fbe80a0f..09e6115320664 100644 --- a/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/suggest_name.rs +++ b/src/tools/rust-analyzer/crates/ide-db/src/syntax_helpers/suggest_name.rs @@ -123,6 +123,20 @@ impl NameGenerator { generator } + pub fn new_from_scope_non_locals(scope: Option>) -> Self { + let mut generator = Self::default(); + if let Some(scope) = scope { + scope.process_all_names(&mut |name, scope| { + if let hir::ScopeDef::Local(_) = scope { + return; + } + generator.insert(name.as_str()); + }); + } + + generator + } + /// Suggest a name without conflicts. If the name conflicts with existing names, /// it will try to resolve the conflict by adding a numeric suffix. 
pub fn suggest_name(&mut self, name: &str) -> SmolStr { diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs index 4ed71f0d3fb82..4c0985c7ae965 100644 --- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs @@ -357,6 +357,7 @@ fn f() { fn arg_count_lambda() { check_diagnostics( r#" +//- minicore: fn fn main() { let f = |()| (); f(); diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs index 1abb50144d34f..6a37702fc50e2 100644 --- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/missing_unsafe.rs @@ -845,6 +845,7 @@ fn foo(v: &Union) { fn union_destructuring() { check_diagnostics( r#" +//- minicore: fn union Union { field: u8 } fn foo(v @ Union { field: _field }: &Union) { // ^^^^^^ error: access to union field is unsafe and requires an unsafe function or block diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs index dec7be8b74275..2ec41d0528496 100644 --- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs @@ -95,6 +95,7 @@ fn foo() -> u8 { fn remove_trailing_return_closure() { check_diagnostics( r#" +//- minicore: fn fn foo() -> u8 { let bar = || return 2; bar() //^^^^^^^^ 💡 weak: replace return ; with @@ -103,6 +104,7 @@ fn foo() -> u8 { ); check_diagnostics( r#" +//- minicore: fn fn foo() -> u8 { let bar = || { return 2; @@ -276,6 +278,7 @@ fn 
foo() -> u8 { fn replace_in_closure() { check_fix( r#" +//- minicore: fn fn foo() -> u8 { let bar = || return$0 2; bar() @@ -290,6 +293,7 @@ fn foo() -> u8 { ); check_fix( r#" +//- minicore: fn fn foo() -> u8 { let bar = || { return$0 2; diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs index ff0e6a254b6ae..98a4474ef1e94 100644 --- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -10,7 +10,6 @@ use syntax::{ ast::{ self, BlockExpr, Expr, ExprStmt, HasArgList, edit::{AstNodeEdit, IndentLevel}, - syntax_factory::SyntaxFactory, }, }; @@ -235,7 +234,7 @@ fn remove_unnecessary_wrapper( let file_id = expr_ptr.file_id.original_file(db); let mut builder = SourceChangeBuilder::new(file_id.file_id(ctx.sema.db)); - let mut editor; + let editor; match inner_arg { // We're returning `()` Expr::TupleExpr(tup) if tup.fields().next().is_none() => { @@ -245,7 +244,7 @@ fn remove_unnecessary_wrapper( .and_then(Either::::cast)?; editor = builder.make_editor(parent.syntax()); - let make = SyntaxFactory::with_mappings(); + let make = editor.make(); match parent { Either::Left(ret_expr) => { @@ -261,8 +260,6 @@ fn remove_unnecessary_wrapper( editor.replace(stmt_list.syntax().parent()?, new_block.syntax()); } } - - editor.add_mappings(make.finish_with_mappings()); } _ => { editor = builder.make_editor(call_expr.syntax()); @@ -1248,7 +1245,7 @@ trait B {} fn test(a: &dyn A) -> &dyn B { a - //^ error: expected &(dyn B + 'static), found &(dyn A + 'static) + //^💡 error: expected &(dyn B + 'static), found &(dyn A + 'static) } "#, ); diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/unlinked_file.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/unlinked_file.rs index d7a0a3b0f59d4..570319c347d49 100644 --- 
a/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/unlinked_file.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/handlers/unlinked_file.rs @@ -4,11 +4,9 @@ use std::iter; use hir::crate_def_map; use hir::{InFile, ModuleSource}; -use ide_db::base_db; use ide_db::text_edit::TextEdit; -use ide_db::{ - FileId, FileRange, LineIndexDatabase, base_db::SourceDatabase, source_change::SourceChange, -}; +use ide_db::{FileId, FileRange, base_db::SourceDatabase, source_change::SourceChange}; +use ide_db::{base_db, line_index}; use paths::Utf8Component; use syntax::{ AstNode, TextRange, @@ -26,7 +24,7 @@ pub(crate) fn unlinked_file( acc: &mut Vec, file_id: FileId, ) { - let mut range = TextRange::up_to(ctx.sema.db.line_index(file_id).len()); + let mut range = TextRange::up_to(line_index(ctx.sema.db, file_id).len()); let fixes = fixes(ctx, file_id, range); // FIXME: This is a hack for the vscode extension to notice whether there is an autofix or not before having to resolve diagnostics. // This is to prevent project linking popups from appearing when there is an autofix. 
https://github.com/rust-lang/rust-analyzer/issues/14523 diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests.rs index 3dc155efe96b9..fc49542e3ccdd 100644 --- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests.rs @@ -4,9 +4,10 @@ mod overly_long_real_world_cases; use hir::setup_tracing; use ide_db::{ - LineIndexDatabase, RootDatabase, + RootDatabase, assists::{AssistResolveStrategy, ExprFillDefaultMode}, base_db::SourceDatabase, + line_index, }; use itertools::Itertools; use stdx::trim_indent; @@ -242,7 +243,7 @@ pub(crate) fn check_diagnostics_with_config( .into_group_map(); for file_id in files { let file_id = file_id.file_id(&db); - let line_index = db.line_index(file_id); + let line_index = line_index(&db, file_id); let mut actual = annotations.remove(&file_id).unwrap_or_default(); let mut expected = extract_annotations(db.file_text(file_id).text(&db)); diff --git a/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests/overly_long_real_world_cases.rs b/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests/overly_long_real_world_cases.rs index c6831d818aac1..9883bcc84ff8c 100644 --- a/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests/overly_long_real_world_cases.rs +++ b/src/tools/rust-analyzer/crates/ide-diagnostics/src/tests/overly_long_real_world_cases.rs @@ -6,6 +6,7 @@ use crate::tests::check_diagnostics_with_disabled; fn tracing_infinite_repeat() { check_diagnostics_with_disabled( r#" +//- minicore: fn //- /core.rs crate:core #[rustc_builtin_macro] #[macro_export] diff --git a/src/tools/rust-analyzer/crates/ide/src/annotations.rs b/src/tools/rust-analyzer/crates/ide/src/annotations.rs index 21b2339c722c7..f716f94d7141b 100644 --- a/src/tools/rust-analyzer/crates/ide/src/annotations.rs +++ b/src/tools/rust-analyzer/crates/ide/src/annotations.rs @@ -216,7 +216,12 @@ pub(crate) fn 
resolve_annotation( *data = find_all_refs( &Semantics::new(db), pos, - &FindAllRefsConfig { search_scope: None, ra_fixture: config.ra_fixture }, + &FindAllRefsConfig { + search_scope: None, + ra_fixture: config.ra_fixture, + exclude_imports: false, + exclude_tests: false, + }, ) .map(|result| { result diff --git a/src/tools/rust-analyzer/crates/ide/src/folding_ranges.rs b/src/tools/rust-analyzer/crates/ide/src/folding_ranges.rs index 375e42cc833ee..965190b27df42 100644 --- a/src/tools/rust-analyzer/crates/ide/src/folding_ranges.rs +++ b/src/tools/rust-analyzer/crates/ide/src/folding_ranges.rs @@ -487,7 +487,7 @@ mod tests { "The amount of folds is different than the expected amount" ); - for (fold, (range, attr, collapsed_text)) in folds.iter().zip(ranges.into_iter()) { + for (fold, (range, attr, collapsed_text)) in folds.iter().zip(ranges) { assert_eq!(fold.range.start(), range.start(), "mismatched start of folding ranges"); assert_eq!(fold.range.end(), range.end(), "mismatched end of folding ranges"); diff --git a/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs b/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs index c8e01e21ec9ce..e6ef7b894913e 100644 --- a/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs +++ b/src/tools/rust-analyzer/crates/ide/src/highlight_related.rs @@ -2083,6 +2083,7 @@ fn test() { fn return_in_macros() { check( r#" +//- minicore: fn macro_rules! 
N { ($i:ident, $x:expr, $blk:expr) => { for $i in 0..$x { diff --git a/src/tools/rust-analyzer/crates/ide/src/hover/render.rs b/src/tools/rust-analyzer/crates/ide/src/hover/render.rs index af78e9a40c9f9..4d712bf0f0e0c 100644 --- a/src/tools/rust-analyzer/crates/ide/src/hover/render.rs +++ b/src/tools/rust-analyzer/crates/ide/src/hover/render.rs @@ -1009,8 +1009,9 @@ fn closure_ty( display_target: DisplayTarget, ) -> Option { let c = original.as_closure()?; - let mut captures_rendered = c.captured_items(sema.db) - .into_iter() + let captures = c.captured_items(sema.db); + let mut captures_rendered = captures + .iter() .map(|it| { let borrow_kind = match it.kind() { CaptureKind::SharedRef => "immutable borrow", @@ -1018,7 +1019,7 @@ fn closure_ty( CaptureKind::MutableRef => "mutable borrow", CaptureKind::Move => "move", }; - format!("* `{}` by {}", it.display_place(sema.db), borrow_kind) + format!("* `{}` by {}", it.display_place_source_code(sema.db, display_target.edition), borrow_kind) }) .join("\n"); if captures_rendered.trim().is_empty() { @@ -1031,8 +1032,8 @@ fn closure_ty( } }; walk_and_push_ty(sema.db, original, &mut push_new_def); - c.capture_types(sema.db).into_iter().for_each(|ty| { - walk_and_push_ty(sema.db, &ty, &mut push_new_def); + captures.iter().for_each(|capture| { + walk_and_push_ty(sema.db, &capture.ty(sema.db), &mut push_new_def); }); let adjusted = if let Some(adjusted_ty) = adjusted { diff --git a/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs b/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs index 9c53b05539e27..491471428fc61 100644 --- a/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs +++ b/src/tools/rust-analyzer/crates/ide/src/hover/tests.rs @@ -423,7 +423,7 @@ fn main() { ## Captures * `x.f1` by move - * `(*x.f2.0.0).f` by mutable borrow + * `x.f2.0.0.f` by mutable borrow "#]], ); check( @@ -11078,6 +11078,58 @@ impl PublicFlags for NoteDialects { ); } +#[test] +fn hover_recursive_const_fn() { + check( + r#" +//- 
minicore: option +enum Child { + Static { child: &'static MyEnum }, +} + +enum MyEnum { + Unit, + Array(Child), +} + +impl MyEnum { + pub const fn static_array(child: &'static MyEnum) -> Self { + MyEnum::Array(Child::Static { child }) + } +} + +pub trait MyTrait { + const MY_CONST: &'static MyEnum; +} + +impl MyTrait for Option where T: MyTrait { + const MY_CONST: &'static MyEnum = &MyEnum::static_array(T::MY_CONST); +} + +impl MyTrait for () { + const MY_CONST: &'static MyEnum = &MyEnum::Unit; +} + +pub struct Address; + +impl MyTrait for Address { + const MY_CONST$0: &'static MyEnum = ( as MyTrait>::MY_CONST); +} + "#, + expect![[r#" + *MY_CONST* + + ```rust + ra_test_fixture::Address + ``` + + ```rust + const MY_CONST: &'static MyEnum = &Array(Static { child: &Unit }) + ``` + "#]], + ); +} + #[test] fn bounds_from_container_do_not_panic() { check( diff --git a/src/tools/rust-analyzer/crates/ide/src/inlay_hints.rs b/src/tools/rust-analyzer/crates/ide/src/inlay_hints.rs index f51d7f5207863..0d2239c71fe9e 100644 --- a/src/tools/rust-analyzer/crates/ide/src/inlay_hints.rs +++ b/src/tools/rust-analyzer/crates/ide/src/inlay_hints.rs @@ -235,7 +235,7 @@ fn hints( param_name::hints(hints, famous_defs, config, file_id, ast::Expr::from(it)) } ast::Expr::ClosureExpr(it) => { - closure_captures::hints(hints, famous_defs, config, it.clone()); + closure_captures::hints(hints, famous_defs, config, it.clone(), file_id.edition(sema.db)); closure_ret::hints(hints, famous_defs, config, display_target, it) }, ast::Expr::RangeExpr(it) => range_exclusive::hints(hints, famous_defs, config, it), @@ -1085,9 +1085,10 @@ fn foo() { fn closure_dependency_cycle_no_panic() { check( r#" +//- minicore: fn fn foo() { let closure; - // ^^^^^^^ impl Fn() + // ^^^^^^^ impl FnOnce() closure = || { closure(); }; @@ -1095,9 +1096,9 @@ fn foo() { fn bar() { let closure1; - // ^^^^^^^^ impl Fn() + // ^^^^^^^^ impl FnOnce() let closure2; - // ^^^^^^^^ impl Fn() + // ^^^^^^^^ impl FnOnce() closure1 = || 
{ closure2(); }; diff --git a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/bind_pat.rs b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/bind_pat.rs index b901c6b67d3e4..f194bb183e18d 100644 --- a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/bind_pat.rs +++ b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/bind_pat.rs @@ -906,7 +906,7 @@ fn fallible() -> ControlFlow<()> { check_with_config( InlayHintsConfig { type_hints: true, ..DISABLED_CONFIG }, r#" -//- minicore: fn +//- minicore: fn, add, builtin_impls fn main() { let x = || 2; //^ impl Fn() -> i32 @@ -928,7 +928,7 @@ fn main() { ..DISABLED_CONFIG }, r#" -//- minicore: fn +//- minicore: fn, add, builtin_impls fn main() { let x = || 2; //^ || -> i32 @@ -950,7 +950,7 @@ fn main() { ..DISABLED_CONFIG }, r#" -//- minicore: fn +//- minicore: fn, add, builtin_impls fn main() { let x = || 2; //^ … @@ -1094,6 +1094,7 @@ fn test(v: S<(S, S<()>)>, f: F) { check_edit( TEST_CONFIG, r#" +//- minicore: fn fn test(t: T) { let f = |a, b, c| {}; let result = f(42, "", t); diff --git a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/closure_captures.rs b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/closure_captures.rs index f8d4ddc6eb57a..f4ac9c42f459c 100644 --- a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/closure_captures.rs +++ b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/closure_captures.rs @@ -3,6 +3,7 @@ //! Tests live in [`bind_pat`][super::bind_pat] module. 
use ide_db::famous_defs::FamousDefs; use ide_db::text_edit::{TextRange, TextSize}; +use span::Edition; use stdx::{TupleExt, never}; use syntax::ast::{self, AstNode}; @@ -15,6 +16,7 @@ pub(super) fn hints( FamousDefs(sema, _): &FamousDefs<'_, '_>, config: &InlayHintsConfig<'_>, closure: ast::ClosureExpr, + edition: Edition, ) -> Option<()> { if !config.closure_capture_hints { return None; @@ -60,7 +62,7 @@ pub(super) fn hints( hir::CaptureKind::MutableRef => "&mut ", hir::CaptureKind::Move => "", }, - capture.display_place(sema.db) + capture.display_place_source_code(sema.db, edition) ); if never!(label.is_empty()) { continue; diff --git a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/lifetime.rs b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/lifetime.rs index 4982b60f1dc8e..7a8a6eb84a5fa 100644 --- a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/lifetime.rs +++ b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/lifetime.rs @@ -31,7 +31,6 @@ pub(super) fn fn_hints( let param_list = func.param_list()?; let generic_param_list = func.generic_param_list(); let ret_type = func.ret_type(); - let self_param = param_list.self_param().filter(|it| it.amp_token().is_some()); let gpl_append_range = func.name()?.syntax().text_range(); hints_( acc, @@ -49,7 +48,7 @@ pub(super) fn fn_hints( }), generic_param_list, ret_type, - self_param, + param_list.self_param(), |acc, allocated_lifetimes| { acc.push(InlayHint { range: gpl_append_range, @@ -208,6 +207,20 @@ fn hints_( Some(lt) => matches!(lt.text().as_str(), "'_"), None => true, }; + let self_param = self_param.and_then(|it| { + if it.colon_token().is_none() { + return Some((it.amp_token(), it.lifetime())); + } + it.ty().map(|ty| { + let ref_type = ty.syntax().descendants().find_map(ast::RefType::cast); + let lifetime = ref_type + .as_ref() + .and_then(|it| it.lifetime()) + .or_else(|| ty.syntax().descendants().find_map(ast::Lifetime::cast)); + (ref_type.and_then(|it| it.amp_token()), lifetime) + }) + }); + let 
self_param = self_param.filter(|(amp, lt)| amp.is_some() || lt.is_some()); let mk_lt_hint = |t: SyntaxToken, label: String| InlayHint { range: t.text_range(), @@ -222,10 +235,9 @@ fn hints_( let potential_lt_refs = { let mut acc: Vec<_> = vec![]; - if let Some(self_param) = &self_param { - let lifetime = self_param.lifetime(); + if let Some((amp_token, lifetime)) = self_param.clone() { let is_elided = is_elided(&lifetime); - acc.push((None, self_param.amp_token(), lifetime, is_elided)); + acc.push((None, amp_token, lifetime, is_elided)); } params.for_each(|(name, ty)| { // FIXME: check path types @@ -240,17 +252,14 @@ fn hints_( is_trivial = false; true } - ast::Type::PathType(t) => { + ast::Type::PathType(t) if t.path() .and_then(|it| it.segment()) .and_then(|it| it.parenthesized_arg_list()) - .is_some() - { - is_trivial = false; - true - } else { - false - } + .is_some() => + { + is_trivial = false; + true } _ => false, }) @@ -339,17 +348,14 @@ fn hints_( is_trivial = false; true } - ast::Type::PathType(t) => { + ast::Type::PathType(t) if t.path() .and_then(|it| it.segment()) .and_then(|it| it.parenthesized_arg_list()) - .is_some() - { - is_trivial = false; - true - } else { - false - } + .is_some() => + { + is_trivial = false; + true } _ => false, }) @@ -439,6 +445,9 @@ fn nested_out(a: &()) -> & &X< &()>{} //^'0 ^'0 ^'0 ^'0 impl () { + fn foo(self, x: &()) -> &() {} + // ^^^<'0> + // ^'0 ^'0 fn foo(&self) {} // ^^^<'0> // ^'0 @@ -448,6 +457,10 @@ impl () { fn foo(&self, a: &()) -> &() {} // ^^^<'0, '1> // ^'0 ^'1 ^'0 + fn foo(self: &Self, a: &()) -> &() {} + // ^^^<'0, '1> + // ^'0 ^'1 ^'0 + } "#, ); diff --git a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs index 8dddf9d37e4fb..c780ce5864963 100644 --- a/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs +++ b/src/tools/rust-analyzer/crates/ide/src/inlay_hints/param_name.rs @@ -427,6 +427,7 @@ fn main() { fn 
param_hints_on_closure() { check_params( r#" +//- minicore: fn fn main() { let clo = |a: u8, b: u8| a + b; clo( diff --git a/src/tools/rust-analyzer/crates/ide/src/interpret.rs b/src/tools/rust-analyzer/crates/ide/src/interpret.rs index 3741822547e45..994d325cccde6 100644 --- a/src/tools/rust-analyzer/crates/ide/src/interpret.rs +++ b/src/tools/rust-analyzer/crates/ide/src/interpret.rs @@ -1,5 +1,5 @@ use hir::{ConstEvalError, DefWithBody, DisplayTarget, Semantics}; -use ide_db::{FilePosition, LineIndexDatabase, RootDatabase, base_db::SourceDatabase}; +use ide_db::{FilePosition, RootDatabase, base_db::SourceDatabase, line_index}; use std::time::{Duration, Instant}; use stdx::format_to; use syntax::{AstNode, TextRange, algo::ancestors_at_offset, ast}; @@ -40,7 +40,7 @@ fn find_and_interpret(db: &RootDatabase, position: FilePosition) -> Option<(Dura let path = source_root.path_for_file(&file_id).map(|x| x.to_string()); let path = path.as_deref().unwrap_or(""); - match db.line_index(file_id).try_line_col(text_range.start()) { + match line_index(db, file_id).try_line_col(text_range.start()) { Some(line_col) => format!("file://{path}:{}:{}", line_col.line + 1, line_col.col), None => format!("file://{path} range {text_range:?}"), } @@ -68,7 +68,7 @@ pub(crate) fn render_const_eval_error( let source_root = db.source_root(source_root).source_root(db); let path = source_root.path_for_file(&file_id).map(|x| x.to_string()); let path = path.as_deref().unwrap_or(""); - match db.line_index(file_id).try_line_col(text_range.start()) { + match line_index(db, file_id).try_line_col(text_range.start()) { Some(line_col) => format!("file://{path}:{}:{}", line_col.line + 1, line_col.col), None => format!("file://{path} range {text_range:?}"), } diff --git a/src/tools/rust-analyzer/crates/ide/src/lib.rs b/src/tools/rust-analyzer/crates/ide/src/lib.rs index 270998cdf751c..0af2a1f82039e 100644 --- a/src/tools/rust-analyzer/crates/ide/src/lib.rs +++ 
b/src/tools/rust-analyzer/crates/ide/src/lib.rs @@ -64,9 +64,11 @@ use cfg::CfgOptions; use fetch_crates::CrateInfo; use hir::{ChangeWithProcMacros, EditionedFileId, crate_def_map, sym}; use ide_db::base_db::relevant_crates; +use ide_db::base_db::salsa::Durability; +use ide_db::line_index; use ide_db::ra_fixture::RaFixtureAnalysis; use ide_db::{ - FxHashMap, FxIndexSet, LineIndexDatabase, + FxHashMap, FxIndexSet, base_db::{ CrateOrigin, CrateWorkspaceData, Env, FileSet, SourceDatabase, VfsPath, salsa::{Cancelled, Database}, @@ -126,7 +128,8 @@ pub use ide_assists::{ }; pub use ide_completion::{ CallableSnippets, CompletionConfig, CompletionFieldsToResolve, CompletionItem, - CompletionItemKind, CompletionItemRefMode, CompletionRelevance, Snippet, SnippetScope, + CompletionItemImport, CompletionItemKind, CompletionItemRefMode, CompletionRelevance, Snippet, + SnippetScope, }; pub use ide_db::{ FileId, FilePosition, FileRange, RootDatabase, Severity, SymbolKind, @@ -202,10 +205,18 @@ impl AnalysisHost { self.db.per_query_memory_usage() } pub fn trigger_cancellation(&mut self) { - self.db.trigger_cancellation(); + // We need to do a synthetic write right now due to how fixpoint cycles handle cancellation + // the revision bump there is a reset marker for clearing fixpoint poisoning. + // That is `trigger_cancellation` is currently bugged wrt to cancellation. + // self.db.trigger_cancellation(); + self.db.synthetic_write(Durability::LOW); } pub fn trigger_garbage_collection(&mut self) { - self.db.trigger_lru_eviction(); + // We need to do a synthetic write right now due to how fixpoint cycles handle cancellation + // the revision bump there is a reset marker for clearing fixpoint poisoning. + // That is `trigger_lru_eviction` is currently bugged wrt to cancellation. + // self.db.trigger_lru_eviction(); + self.db.synthetic_write(Durability::LOW); // SAFETY: `trigger_lru_eviction` triggers cancellation, so all running queries were canceled. 
unsafe { hir::collect_ty_garbage() }; } @@ -358,7 +369,7 @@ impl Analysis { /// Gets the file's `LineIndex`: data structure to convert between absolute /// offsets and line/column representation. pub fn file_line_index(&self, file_id: FileId) -> Cancellable> { - self.with_db(|db| db.line_index(file_id)) + self.with_db(|db| line_index(db, file_id).clone()) } /// Selects the next syntactic nodes encompassing the range. @@ -768,7 +779,7 @@ impl Analysis { &self, config: &CompletionConfig<'_>, position: FilePosition, - imports: impl IntoIterator + std::panic::UnwindSafe, + imports: impl IntoIterator + std::panic::UnwindSafe, ) -> Cancellable> { Ok(self .with_db(|db| ide_completion::resolve_completion_edits(db, config, position, imports))? diff --git a/src/tools/rust-analyzer/crates/ide/src/references.rs b/src/tools/rust-analyzer/crates/ide/src/references.rs index 0288099bbcc2b..4ed3d1c7d7e4a 100644 --- a/src/tools/rust-analyzer/crates/ide/src/references.rs +++ b/src/tools/rust-analyzer/crates/ide/src/references.rs @@ -20,6 +20,7 @@ use hir::{PathResolution, Semantics}; use ide_db::{ FileId, RootDatabase, + base_db::SourceDatabase, defs::{Definition, NameClass, NameRefClass}, helpers::pick_best_token, ra_fixture::{RaFixtureConfig, UpmapFromRaFixture}, @@ -91,6 +92,8 @@ pub struct Declaration { pub struct FindAllRefsConfig<'a> { pub search_scope: Option, pub ra_fixture: RaFixtureConfig<'a>, + pub exclude_imports: bool, + pub exclude_tests: bool, } /// Find all references to the item at the given position. 
@@ -125,10 +128,23 @@ pub(crate) fn find_all_refs( ) -> Option> { let _p = tracing::info_span!("find_all_refs").entered(); let syntax = sema.parse_guess_edition(position.file_id).syntax().clone(); + let exclude_library_refs = !is_library_file(sema.db, position.file_id); let make_searcher = |literal_search: bool| { move |def: Definition| { - let mut usages = - def.usages(sema).set_scope(config.search_scope.as_ref()).include_self_refs().all(); + let mut included_categories = ReferenceCategory::all(); + if config.exclude_imports { + included_categories.remove(ReferenceCategory::IMPORT); + } + if config.exclude_tests { + included_categories.remove(ReferenceCategory::TEST); + } + let mut usages = def + .usages(sema) + .set_scope(config.search_scope.as_ref()) + .set_included_categories(included_categories) + .set_exclude_library_files(exclude_library_refs) + .include_self_refs() + .all(); if literal_search { retain_adt_literal_usages(&mut usages, def, sema); } @@ -207,6 +223,11 @@ pub(crate) fn find_all_refs( } } +fn is_library_file(db: &RootDatabase, file_id: FileId) -> bool { + let source_root = db.file_source_root(file_id).source_root_id(db); + db.source_root(source_root).source_root(db).is_library +} + pub(crate) fn find_defs( sema: &Semantics<'_, RootDatabase>, syntax: &SyntaxNode, @@ -469,7 +490,7 @@ mod tests { #[test] fn exclude_tests() { - check( + check_with_filters( r#" fn test_func() {} @@ -482,6 +503,8 @@ fn test() { test_func(); } "#, + false, + false, expect![[r#" test_func Function FileId(0) 0..17 3..12 @@ -490,7 +513,7 @@ fn test() { "#]], ); - check( + check_with_filters( r#" fn test_func() {} @@ -503,6 +526,8 @@ fn test() { test_func(); } "#, + false, + false, expect![[r#" test_func Function FileId(0) 0..17 3..12 @@ -510,6 +535,133 @@ fn test() { FileId(0) 96..105 test "#]], ); + + check_with_filters( + r#" +fn test_func() {} + +fn func() { + test_func$0(); +} + +#[test] +fn test() { + test_func(); +} +"#, + false, + true, + expect![[r#" + test_func 
Function FileId(0) 0..17 3..12 + + FileId(0) 35..44 + "#]], + ); + } + + #[test] + fn exclude_library_refs_filtering() { + // exclude refs in 3rd party lib + check_with_filters( + r#" +//- /main.rs crate:main deps:dep +use dep::foo; + +fn main() { + foo$0(); +} + +//- /dep/lib.rs crate:dep new_source_root:library +pub fn foo() {} + +pub fn also_calls_foo() { + foo(); +} +"#, + false, + false, + expect![[r#" + foo Function FileId(1) 0..15 7..10 + + FileId(0) 9..12 import + FileId(0) 31..34 + "#]], + ); + + // exclude refs in stdlib + check_with_filters( + r#" +//- minicore: option +fn main() { + let _ = core::option::Option::Some$0(0); +} +"#, + false, + false, + expect![[r#" + Some Variant FileId(1) 5999..6031 6024..6028 + + FileId(0) 46..50 + "#]], + ); + + // keep refs in local lib + check_with_filters( + r#" +//- /main.rs crate:main deps:dep +use dep::foo; + +fn main() { + foo$0(); +} + +//- /dep/lib.rs crate:dep +pub fn foo() {} + +pub fn also_calls_foo() { + foo(); +} +"#, + false, + false, + expect![[r#" + foo Function FileId(1) 0..15 7..10 + + FileId(0) 9..12 import + FileId(0) 31..34 + FileId(1) 47..50 + "#]], + ); + } + + #[test] + fn find_refs_from_library_source_keeps_library_refs() { + check_with_filters( + r#" +//- /main.rs crate:main deps:dep +use dep::foo; + +fn main() { + foo(); +} + +//- /dep/lib.rs crate:dep new_source_root:library +pub fn foo$0() {} + +pub fn also_calls_foo() { + foo(); +} +"#, + false, + false, + expect![[r#" + foo Function FileId(1) 0..15 7..10 + + FileId(0) 9..12 import + FileId(0) 31..34 + FileId(1) 47..50 + "#]], + ); } #[test] @@ -1556,18 +1708,39 @@ fn main() { } fn check(#[rust_analyzer::rust_fixture] ra_fixture: &str, expect: Expect) { - check_with_scope(ra_fixture, None, expect) + check_with_filters(ra_fixture, false, false, expect) + } + + fn check_with_filters( + #[rust_analyzer::rust_fixture] ra_fixture: &str, + exclude_imports: bool, + exclude_tests: bool, + expect: Expect, + ) { + 
check_with_scope_and_filters(ra_fixture, None, exclude_imports, exclude_tests, expect) } fn check_with_scope( #[rust_analyzer::rust_fixture] ra_fixture: &str, search_scope: Option<&mut dyn FnMut(&RootDatabase) -> SearchScope>, expect: Expect, + ) { + check_with_scope_and_filters(ra_fixture, search_scope, false, false, expect) + } + + fn check_with_scope_and_filters( + #[rust_analyzer::rust_fixture] ra_fixture: &str, + search_scope: Option<&mut dyn FnMut(&RootDatabase) -> SearchScope>, + exclude_imports: bool, + exclude_tests: bool, + expect: Expect, ) { let (analysis, pos) = fixture::position(ra_fixture); let config = FindAllRefsConfig { search_scope: search_scope.map(|it| it(&analysis.db)), ra_fixture: RaFixtureConfig::default(), + exclude_imports, + exclude_tests, }; let refs = analysis.find_all_refs(pos, &config).unwrap().unwrap(); @@ -2567,6 +2740,7 @@ fn test() { fn goto_ref_fn_kw() { check( r#" +//- minicore: fn macro_rules! N { ($i:ident, $x:expr, $blk:expr) => { for $i in 0..$x { diff --git a/src/tools/rust-analyzer/crates/ide/src/runnables.rs b/src/tools/rust-analyzer/crates/ide/src/runnables.rs index 3b472390d2f43..60750608a5b49 100644 --- a/src/tools/rust-analyzer/crates/ide/src/runnables.rs +++ b/src/tools/rust-analyzer/crates/ide/src/runnables.rs @@ -606,14 +606,14 @@ fn has_test_function_or_multiple_test_submodules( return true; } } - hir::ModuleDef::Module(submodule) => { + hir::ModuleDef::Module(submodule) if has_test_function_or_multiple_test_submodules( sema, &submodule, consider_exported_main, - ) { - number_of_test_submodules += 1; - } + ) => + { + number_of_test_submodules += 1; } _ => (), } diff --git a/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/highlight.rs b/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/highlight.rs index 0e101ab235f54..6823736d12730 100644 --- a/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/highlight.rs +++ b/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/highlight.rs @@ 
-307,12 +307,12 @@ fn highlight_name_ref( h |= HlMod::Consuming; } // highlight unsafe traits as unsafe only in their implementations - Definition::Trait(trait_) if trait_.is_unsafe(db) => { - if ast::Impl::for_trait_name_ref(&name_ref) - .is_some_and(|impl_| impl_.unsafe_token().is_some()) - { - h |= HlMod::Unsafe; - } + Definition::Trait(trait_) + if trait_.is_unsafe(db) + && ast::Impl::for_trait_name_ref(&name_ref) + .is_some_and(|impl_| impl_.unsafe_token().is_some()) => + { + h |= HlMod::Unsafe; } Definition::Function(_) => { let is_unsafe = name_ref diff --git a/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/tests.rs b/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/tests.rs index e8d185b7b6369..d687cb40a9697 100644 --- a/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/tests.rs +++ b/src/tools/rust-analyzer/crates/ide/src/syntax_highlighting/tests.rs @@ -1348,7 +1348,7 @@ fn benchmark_syntax_highlighting_parser() { }) .count() }; - assert_eq!(hash, 1606); + assert_eq!(hash, 1631); } #[test] diff --git a/src/tools/rust-analyzer/crates/ide/src/view_syntax_tree.rs b/src/tools/rust-analyzer/crates/ide/src/view_syntax_tree.rs index ecd93e8b28190..7732b180829ad 100644 --- a/src/tools/rust-analyzer/crates/ide/src/view_syntax_tree.rs +++ b/src/tools/rust-analyzer/crates/ide/src/view_syntax_tree.rs @@ -1,6 +1,6 @@ use hir::Semantics; use ide_db::{ - FileId, LineIndexDatabase, RootDatabase, + FileId, RootDatabase, line_index, line_index::{LineCol, LineIndex}, }; use span::{TextRange, TextSize}; @@ -20,7 +20,7 @@ use triomphe::Arc; // | VS Code | **Rust Syntax Tree** | pub(crate) fn view_syntax_tree(db: &RootDatabase, file_id: FileId) -> String { let sema = Semantics::new(db); - let line_index = db.line_index(file_id); + let line_index = line_index(db, file_id).clone(); let parse = sema.parse_guess_edition(file_id); let ctx = SyntaxTreeCtx { line_index, in_string: None }; diff --git 
a/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs b/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs index cc09a1aae7a6d..614411598b2fd 100644 --- a/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs +++ b/src/tools/rust-analyzer/crates/intern/src/symbol/symbols.rs @@ -123,7 +123,6 @@ define_symbols! { all, alloc_layout, alloc, - allow_internal_unsafe, allow, any, as_str, @@ -178,6 +177,7 @@ define_symbols! { Continue, convert, copy, + use_cloned, Copy, core_panic, core, @@ -200,7 +200,9 @@ define_symbols! { derive, discriminant_kind, discriminant_type, - dispatch_from_dyn,destruct, + dispatch_from_dyn, + destruct, + bikeshed_guaranteed_no_drop, div_assign, div, doc, @@ -233,6 +235,8 @@ define_symbols! { async_fn_once_output, async_fn_mut, async_fn, + async_fn_kind_helper, + async_fn_kind_upvars, call_ref_future, call_once_future, fn_ptr_addr, @@ -298,6 +302,8 @@ define_symbols! { iter, Iterator, iterator, + fused_iterator, + async_iterator, keyword, lang, lang_items, @@ -541,4 +547,35 @@ define_symbols! 
{ DispatchFromDyn, define_opaque, marker, + abi_unadjusted, + allocator_internals, + allow_internal_unsafe, + allow_internal_unstable, + cfg_emscripten_wasm_eh, + cfg_target_has_reliable_f16_f128, + compiler_builtins, + custom_mir, + eii_internals, + field_representing_type_raw, + intrinsics, + link_cfg, + more_maybe_bounds, + negative_bounds, + pattern_complexity_limit, + profiler_runtime, + rustc_attrs, + staged_api, + test_unstable_lint, + builtin_syntax, + link_llvm_intrinsics, + needs_panic_runtime, + panic_runtime, + pattern_types, + rustdoc_internals, + contracts_internals, + freeze_impls, + unsized_fn_params, + field, + field_base, + field_type, } diff --git a/src/tools/rust-analyzer/crates/load-cargo/src/lib.rs b/src/tools/rust-analyzer/crates/load-cargo/src/lib.rs index 68bf78e037c0c..839df181597ba 100644 --- a/src/tools/rust-analyzer/crates/load-cargo/src/lib.rs +++ b/src/tools/rust-analyzer/crates/load-cargo/src/lib.rs @@ -884,7 +884,7 @@ mod tests { let fsc = builder.build(); let src = SourceRootConfig { fsc, local_filesets: vec![0, 1, 2, 3] }; let mut vc = src.source_root_parent_map().into_iter().collect::>(); - vc.sort_by(|x, y| x.0.0.cmp(&y.0.0)); + vc.sort_by_key(|x| x.0.0); assert_eq!(vc, vec![(SourceRootId(2), SourceRootId(1)), (SourceRootId(3), SourceRootId(1))]) } @@ -899,7 +899,7 @@ mod tests { let fsc = builder.build(); let src = SourceRootConfig { fsc, local_filesets: vec![0, 1, 3] }; let mut vc = src.source_root_parent_map().into_iter().collect::>(); - vc.sort_by(|x, y| x.0.0.cmp(&y.0.0)); + vc.sort_by_key(|x| x.0.0); assert_eq!(vc, vec![(SourceRootId(3), SourceRootId(1)),]) } @@ -914,7 +914,7 @@ mod tests { let fsc = builder.build(); let src = SourceRootConfig { fsc, local_filesets: vec![0, 1, 3] }; let mut vc = src.source_root_parent_map().into_iter().collect::>(); - vc.sort_by(|x, y| x.0.0.cmp(&y.0.0)); + vc.sort_by_key(|x| x.0.0); assert_eq!(vc, vec![(SourceRootId(3), SourceRootId(1)),]) } @@ -930,7 +930,7 @@ mod tests { let fsc = 
builder.build(); let src = SourceRootConfig { fsc, local_filesets: vec![0, 1] }; let mut vc = src.source_root_parent_map().into_iter().collect::>(); - vc.sort_by(|x, y| x.0.0.cmp(&y.0.0)); + vc.sort_by_key(|x| x.0.0); assert_eq!(vc, vec![(SourceRootId(1), SourceRootId(0)),]) } @@ -946,7 +946,7 @@ mod tests { let fsc = builder.build(); let src = SourceRootConfig { fsc, local_filesets: vec![0, 1] }; let mut vc = src.source_root_parent_map().into_iter().collect::>(); - vc.sort_by(|x, y| x.0.0.cmp(&y.0.0)); + vc.sort_by_key(|x| x.0.0); assert_eq!(vc, vec![(SourceRootId(1), SourceRootId(0)),]) } diff --git a/src/tools/rust-analyzer/crates/parser/src/grammar.rs b/src/tools/rust-analyzer/crates/parser/src/grammar.rs index 1ff8a56b580f0..0623e7ea19ab0 100644 --- a/src/tools/rust-analyzer/crates/parser/src/grammar.rs +++ b/src/tools/rust-analyzer/crates/parser/src/grammar.rs @@ -228,50 +228,64 @@ fn opt_visibility(p: &mut Parser<'_>, in_tuple_field: bool) -> bool { let m = p.start(); p.bump(T![pub]); - if p.at(T!['(']) { - match p.nth(1) { - // test crate_visibility - // pub(crate) struct S; - // pub(self) struct S; - // pub(super) struct S; - - // test_err crate_visibility_empty_recover - // pub() struct S; - - // test pub_parens_typepath - // struct B(pub (super::A)); - // struct B(pub (crate::A,)); - T![crate] | T![self] | T![super] | T![ident] | T![')'] if p.nth(2) != T![:] => { - // If we are in a tuple struct, then the parens following `pub` - // might be an tuple field, not part of the visibility. So in that - // case we don't want to consume an identifier. 
- - // test pub_tuple_field - // struct MyStruct(pub (u32, u32)); - // struct MyStruct(pub (u32)); - // struct MyStruct(pub ()); - if !(in_tuple_field && matches!(p.nth(1), T![ident] | T![')'])) { - p.bump(T!['(']); - paths::vis_path(p); - p.expect(T![')']); - } - } - // test crate_visibility_in - // pub(in super::A) struct S; - // pub(in crate) struct S; - T![in] => { - p.bump(T!['(']); - p.bump(T![in]); - paths::vis_path(p); - p.expect(T![')']); - } - _ => {} - } - } + opt_visibility_inner(p, in_tuple_field); m.complete(p, VISIBILITY); true } +fn opt_visibility_inner(p: &mut Parser<'_>, in_tuple_field: bool) -> bool { + if !p.at(T!['(']) { + return false; + } + + match p.nth(1) { + // test crate_visibility + // pub(crate) struct S; + // pub(self) struct S; + // pub(super) struct S; + + // test_err crate_visibility_empty_recover + // pub() struct S; + + // test pub_parens_typepath + // struct B(pub (super::A)); + // struct B(pub (crate::A,)); + T![crate] | T![self] | T![super] | T![ident] | T![')'] + if p.nth(2) != T![:] + // If we are in a tuple struct, then the parens following `pub` + // might be an tuple field, not part of the visibility. So in that + // case we don't want to consume an identifier. 
+ + // test pub_tuple_field + // struct MyStruct(pub (u32, u32)); + // struct MyStruct(pub (u32)); + // struct MyStruct(pub ()); + && !(in_tuple_field && matches!(p.nth(1), T![ident] | T![')'])) => + { + let m = p.start(); + p.bump(T!['(']); + paths::vis_path(p); + p.expect(T![')']); + m.complete(p, VISIBILITY_INNER); + return true; + } + // test crate_visibility_in + // pub(in super::A) struct S; + // pub(in crate) struct S; + T![in] => { + let m = p.start(); + p.bump(T!['(']); + p.bump(T![in]); + paths::vis_path(p); + p.expect(T![')']); + m.complete(p, VISIBILITY_INNER); + return true; + } + _ => {} + } + false +} + fn opt_rename(p: &mut Parser<'_>) { if p.at(T![as]) { let m = p.start(); diff --git a/src/tools/rust-analyzer/crates/parser/src/grammar/items.rs b/src/tools/rust-analyzer/crates/parser/src/grammar/items.rs index c0acdde2a7240..c5c6e04dd49aa 100644 --- a/src/tools/rust-analyzer/crates/parser/src/grammar/items.rs +++ b/src/tools/rust-analyzer/crates/parser/src/grammar/items.rs @@ -167,6 +167,25 @@ pub(super) fn opt_item(p: &mut Parser<'_>, m: Marker, is_in_extern: bool) -> Res has_mods = true; } + if p.at(T![impl]) + && p.nth(1) == T!['('] + && ((matches!(p.nth(2), T![crate] | T![super] | T![self]) && p.nth(3) == T![')']) + || p.nth(2) == T![in]) + { + // test impl_restrictions + // pub unsafe impl(crate) trait Foo {} + // impl(in super::bar) trait Bar {} + // impl () {} + // impl (i32) {} + let m = p.start(); + p.bump(T![impl]); + if !opt_visibility_inner(p, false) { + p.error("expected an impl restriction"); + } + m.complete(p, IMPL_RESTRICTION); + has_mods = true; + } + // test default_item // default impl T for Foo {} if p.at_contextual_kw(T![default]) { @@ -216,6 +235,7 @@ pub(super) fn opt_item(p: &mut Parser<'_>, m: Marker, is_in_extern: bool) -> Res T![trait] => traits::trait_(p, m), T![impl] => traits::impl_(p, m), + T![type] if p.nth(1) == T![const] => consts::konst(p, m), T![type] => type_alias(p, m), // test extern_block @@ -247,6 +267,9 @@ 
fn opt_item_without_modifiers(p: &mut Parser<'_>, m: Marker) -> Result<(), Marke T![use] => use_item::use_(p, m), T![mod] => mod_item(p, m), + // test type_const + // type const FOO: i32 = 2; + T![type] if la == T![const] => consts::konst(p, m), T![type] => type_alias(p, m), T![struct] => adt::strukt(p, m), T![enum] => adt::enum_(p, m), diff --git a/src/tools/rust-analyzer/crates/parser/src/grammar/items/adt.rs b/src/tools/rust-analyzer/crates/parser/src/grammar/items/adt.rs index cfba4c3a77b2a..a030190ad34ba 100644 --- a/src/tools/rust-analyzer/crates/parser/src/grammar/items/adt.rs +++ b/src/tools/rust-analyzer/crates/parser/src/grammar/items/adt.rs @@ -133,7 +133,30 @@ pub(crate) fn record_field_list(p: &mut Parser<'_>) { // struct S { #[attr] f: f32 } attributes::outer_attrs(p); opt_visibility(p, false); - p.eat(T![unsafe]); + + if p.at(T![mut]) && p.nth(1) == T!['('] { + // test record_mut_restrictions_before + // struct Foo { mut(super) unsafe i: i32 } + let m = p.start(); + p.bump(T![mut]); + if !opt_visibility_inner(p, false) { + p.error("expected a mut restriction"); + } + m.complete(p, MUT_RESTRICTION); + } + + // We accept mut restriction both after and before `unsafe`, as the order is undecided yet. 
+ if p.eat(T![unsafe]) && p.at(T![mut]) && p.nth(1) == T!['('] { + // test record_mut_restrictions_after + // struct Foo { unsafe mut(super) i: i32 } + let m = p.start(); + p.bump(T![mut]); + if !opt_visibility_inner(p, false) { + p.error("expected a mut restriction"); + } + m.complete(p, MUT_RESTRICTION); + } + if p.at(IDENT) { name(p); p.expect(T![:]); @@ -175,6 +198,18 @@ fn tuple_field_list(p: &mut Parser<'_>) { // struct S (#[attr] f32); attributes::outer_attrs(p); let has_vis = opt_visibility(p, true); + + if p.at(T![mut]) && p.nth(1) == T!['('] { + // test tuple_mut_restrictions + // struct Foo(pub(crate) mut(super) i32); + let m = p.start(); + p.bump(T![mut]); + if !opt_visibility_inner(p, false) { + p.error("expected a mut restriction"); + } + m.complete(p, MUT_RESTRICTION); + } + if !p.at_ts(types::TYPE_FIRST) { p.error("expected a type"); if has_vis { diff --git a/src/tools/rust-analyzer/crates/parser/src/grammar/items/consts.rs b/src/tools/rust-analyzer/crates/parser/src/grammar/items/consts.rs index e6a8aca5861a6..cc5bb73bdcabc 100644 --- a/src/tools/rust-analyzer/crates/parser/src/grammar/items/consts.rs +++ b/src/tools/rust-analyzer/crates/parser/src/grammar/items/consts.rs @@ -3,6 +3,7 @@ use super::*; // test const_item // const C: u32 = 92; pub(super) fn konst(p: &mut Parser<'_>, m: Marker) { + p.eat(T![type]); p.bump(T![const]); const_or_static(p, m, true); } diff --git a/src/tools/rust-analyzer/crates/parser/src/syntax_kind/generated.rs b/src/tools/rust-analyzer/crates/parser/src/syntax_kind/generated.rs index 9cd48f2aa4f3e..59fa3ee773882 100644 --- a/src/tools/rust-analyzer/crates/parser/src/syntax_kind/generated.rs +++ b/src/tools/rust-analyzer/crates/parser/src/syntax_kind/generated.rs @@ -218,6 +218,7 @@ pub enum SyntaxKind { IDENT_PAT, IF_EXPR, IMPL, + IMPL_RESTRICTION, IMPL_TRAIT_TYPE, INDEX_EXPR, INFER_TYPE, @@ -247,6 +248,7 @@ pub enum SyntaxKind { MATCH_GUARD, METHOD_CALL_EXPR, MODULE, + MUT_RESTRICTION, NAME, NAME_REF, NEVER_TYPE, @@ 
-318,6 +320,7 @@ pub enum SyntaxKind { VARIANT, VARIANT_LIST, VISIBILITY, + VISIBILITY_INNER, WHERE_CLAUSE, WHERE_PRED, WHILE_EXPR, @@ -399,6 +402,7 @@ impl SyntaxKind { | IDENT_PAT | IF_EXPR | IMPL + | IMPL_RESTRICTION | IMPL_TRAIT_TYPE | INDEX_EXPR | INFER_TYPE @@ -428,6 +432,7 @@ impl SyntaxKind { | MATCH_GUARD | METHOD_CALL_EXPR | MODULE + | MUT_RESTRICTION | NAME | NAME_REF | NEVER_TYPE @@ -499,6 +504,7 @@ impl SyntaxKind { | VARIANT | VARIANT_LIST | VISIBILITY + | VISIBILITY_INNER | WHERE_CLAUSE | WHERE_PRED | WHILE_EXPR diff --git a/src/tools/rust-analyzer/crates/parser/test_data/generated/runner.rs b/src/tools/rust-analyzer/crates/parser/test_data/generated/runner.rs index 71978390df6a8..6dfb78b12878b 100644 --- a/src/tools/rust-analyzer/crates/parser/test_data/generated/runner.rs +++ b/src/tools/rust-analyzer/crates/parser/test_data/generated/runner.rs @@ -348,6 +348,10 @@ mod ok { run_and_expect_no_errors("test_data/parser/inline/ok/impl_item_never_type.rs"); } #[test] + fn impl_restrictions() { + run_and_expect_no_errors("test_data/parser/inline/ok/impl_restrictions.rs"); + } + #[test] fn impl_trait_type() { run_and_expect_no_errors("test_data/parser/inline/ok/impl_trait_type.rs"); } @@ -556,6 +560,14 @@ mod ok { run_and_expect_no_errors("test_data/parser/inline/ok/record_literal_field_with_attr.rs"); } #[test] + fn record_mut_restrictions_after() { + run_and_expect_no_errors("test_data/parser/inline/ok/record_mut_restrictions_after.rs"); + } + #[test] + fn record_mut_restrictions_before() { + run_and_expect_no_errors("test_data/parser/inline/ok/record_mut_restrictions_before.rs"); + } + #[test] fn record_pat_field() { run_and_expect_no_errors("test_data/parser/inline/ok/record_pat_field.rs"); } @@ -658,6 +670,10 @@ mod ok { run_and_expect_no_errors("test_data/parser/inline/ok/tuple_field_attrs.rs"); } #[test] + fn tuple_mut_restrictions() { + run_and_expect_no_errors("test_data/parser/inline/ok/tuple_mut_restrictions.rs"); + } + #[test] fn tuple_pat() { 
run_and_expect_no_errors("test_data/parser/inline/ok/tuple_pat.rs"); } #[test] fn tuple_pat_fields() { @@ -672,6 +688,8 @@ mod ok { #[test] fn type_alias() { run_and_expect_no_errors("test_data/parser/inline/ok/type_alias.rs"); } #[test] + fn type_const() { run_and_expect_no_errors("test_data/parser/inline/ok/type_const.rs"); } + #[test] fn type_item_type_params() { run_and_expect_no_errors("test_data/parser/inline/ok/type_item_type_params.rs"); } diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/err/crate_visibility_empty_recover.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/err/crate_visibility_empty_recover.rast index 172bc099b58d0..37116ca895be6 100644 --- a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/err/crate_visibility_empty_recover.rast +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/err/crate_visibility_empty_recover.rast @@ -2,8 +2,9 @@ SOURCE_FILE STRUCT VISIBILITY PUB_KW "pub" - L_PAREN "(" - R_PAREN ")" + VISIBILITY_INNER + L_PAREN "(" + R_PAREN ")" WHITESPACE " " STRUCT_KW "struct" WHITESPACE " " diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility.rast index 8738292a9f7fe..c946f19321e58 100644 --- a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility.rast +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility.rast @@ -2,12 +2,13 @@ SOURCE_FILE STRUCT VISIBILITY PUB_KW "pub" - L_PAREN "(" - PATH - PATH_SEGMENT - NAME_REF - CRATE_KW "crate" - R_PAREN ")" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + CRATE_KW "crate" + R_PAREN ")" WHITESPACE " " STRUCT_KW "struct" WHITESPACE " " @@ -18,12 +19,13 @@ SOURCE_FILE STRUCT VISIBILITY PUB_KW "pub" - L_PAREN "(" - PATH - PATH_SEGMENT - NAME_REF - SELF_KW "self" - R_PAREN ")" + VISIBILITY_INNER + L_PAREN "(" + PATH 
+ PATH_SEGMENT + NAME_REF + SELF_KW "self" + R_PAREN ")" WHITESPACE " " STRUCT_KW "struct" WHITESPACE " " @@ -34,12 +36,13 @@ SOURCE_FILE STRUCT VISIBILITY PUB_KW "pub" - L_PAREN "(" - PATH - PATH_SEGMENT - NAME_REF - SUPER_KW "super" - R_PAREN ")" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + SUPER_KW "super" + R_PAREN ")" WHITESPACE " " STRUCT_KW "struct" WHITESPACE " " diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility_in.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility_in.rast index ac45c56956790..1a551ea2212fe 100644 --- a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility_in.rast +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/crate_visibility_in.rast @@ -2,19 +2,20 @@ SOURCE_FILE STRUCT VISIBILITY PUB_KW "pub" - L_PAREN "(" - IN_KW "in" - WHITESPACE " " - PATH + VISIBILITY_INNER + L_PAREN "(" + IN_KW "in" + WHITESPACE " " PATH + PATH + PATH_SEGMENT + NAME_REF + SUPER_KW "super" + COLON2 "::" PATH_SEGMENT NAME_REF - SUPER_KW "super" - COLON2 "::" - PATH_SEGMENT - NAME_REF - IDENT "A" - R_PAREN ")" + IDENT "A" + R_PAREN ")" WHITESPACE " " STRUCT_KW "struct" WHITESPACE " " @@ -25,14 +26,15 @@ SOURCE_FILE STRUCT VISIBILITY PUB_KW "pub" - L_PAREN "(" - IN_KW "in" - WHITESPACE " " - PATH - PATH_SEGMENT - NAME_REF - CRATE_KW "crate" - R_PAREN ")" + VISIBILITY_INNER + L_PAREN "(" + IN_KW "in" + WHITESPACE " " + PATH + PATH_SEGMENT + NAME_REF + CRATE_KW "crate" + R_PAREN ")" WHITESPACE " " STRUCT_KW "struct" WHITESPACE " " diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/impl_restrictions.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/impl_restrictions.rast new file mode 100644 index 0000000000000..5f2680cbaa921 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/impl_restrictions.rast @@ -0,0 +1,80 @@ +SOURCE_FILE + TRAIT + 
VISIBILITY + PUB_KW "pub" + WHITESPACE " " + UNSAFE_KW "unsafe" + WHITESPACE " " + IMPL_RESTRICTION + IMPL_KW "impl" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + CRATE_KW "crate" + R_PAREN ")" + WHITESPACE " " + TRAIT_KW "trait" + WHITESPACE " " + NAME + IDENT "Foo" + WHITESPACE " " + ASSOC_ITEM_LIST + L_CURLY "{" + R_CURLY "}" + WHITESPACE "\n" + TRAIT + IMPL_RESTRICTION + IMPL_KW "impl" + VISIBILITY_INNER + L_PAREN "(" + IN_KW "in" + WHITESPACE " " + PATH + PATH + PATH_SEGMENT + NAME_REF + SUPER_KW "super" + COLON2 "::" + PATH_SEGMENT + NAME_REF + IDENT "bar" + R_PAREN ")" + WHITESPACE " " + TRAIT_KW "trait" + WHITESPACE " " + NAME + IDENT "Bar" + WHITESPACE " " + ASSOC_ITEM_LIST + L_CURLY "{" + R_CURLY "}" + WHITESPACE "\n" + IMPL + IMPL_KW "impl" + WHITESPACE " " + TUPLE_TYPE + L_PAREN "(" + R_PAREN ")" + WHITESPACE " " + ASSOC_ITEM_LIST + L_CURLY "{" + R_CURLY "}" + WHITESPACE "\n" + IMPL + IMPL_KW "impl" + WHITESPACE " " + PAREN_TYPE + L_PAREN "(" + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "i32" + R_PAREN ")" + WHITESPACE " " + ASSOC_ITEM_LIST + L_CURLY "{" + R_CURLY "}" + WHITESPACE "\n" diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/impl_restrictions.rs b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/impl_restrictions.rs new file mode 100644 index 0000000000000..0a46b158affc9 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/impl_restrictions.rs @@ -0,0 +1,4 @@ +pub unsafe impl(crate) trait Foo {} +impl(in super::bar) trait Bar {} +impl () {} +impl (i32) {} diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_after.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_after.rast new file mode 100644 index 0000000000000..ebe3c81468322 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_after.rast 
@@ -0,0 +1,35 @@ +SOURCE_FILE + STRUCT + STRUCT_KW "struct" + WHITESPACE " " + NAME + IDENT "Foo" + WHITESPACE " " + RECORD_FIELD_LIST + L_CURLY "{" + WHITESPACE " " + RECORD_FIELD + UNSAFE_KW "unsafe" + WHITESPACE " " + MUT_RESTRICTION + MUT_KW "mut" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + SUPER_KW "super" + R_PAREN ")" + WHITESPACE " " + NAME + IDENT "i" + COLON ":" + WHITESPACE " " + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "i32" + WHITESPACE " " + R_CURLY "}" + WHITESPACE "\n" diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_after.rs b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_after.rs new file mode 100644 index 0000000000000..bf37f60f85442 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_after.rs @@ -0,0 +1 @@ +struct Foo { unsafe mut(super) i: i32 } diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_before.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_before.rast new file mode 100644 index 0000000000000..7f76e737f7eef --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_before.rast @@ -0,0 +1,35 @@ +SOURCE_FILE + STRUCT + STRUCT_KW "struct" + WHITESPACE " " + NAME + IDENT "Foo" + WHITESPACE " " + RECORD_FIELD_LIST + L_CURLY "{" + WHITESPACE " " + RECORD_FIELD + MUT_RESTRICTION + MUT_KW "mut" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + SUPER_KW "super" + R_PAREN ")" + WHITESPACE " " + UNSAFE_KW "unsafe" + WHITESPACE " " + NAME + IDENT "i" + COLON ":" + WHITESPACE " " + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "i32" + WHITESPACE " " + R_CURLY "}" + WHITESPACE "\n" diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_before.rs 
b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_before.rs new file mode 100644 index 0000000000000..9bbb80f205a06 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/record_mut_restrictions_before.rs @@ -0,0 +1 @@ +struct Foo { mut(super) unsafe i: i32 } diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/tuple_mut_restrictions.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/tuple_mut_restrictions.rast new file mode 100644 index 0000000000000..944133ff8551c --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/tuple_mut_restrictions.rast @@ -0,0 +1,37 @@ +SOURCE_FILE + STRUCT + STRUCT_KW "struct" + WHITESPACE " " + NAME + IDENT "Foo" + TUPLE_FIELD_LIST + L_PAREN "(" + TUPLE_FIELD + VISIBILITY + PUB_KW "pub" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + CRATE_KW "crate" + R_PAREN ")" + WHITESPACE " " + MUT_RESTRICTION + MUT_KW "mut" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + SUPER_KW "super" + R_PAREN ")" + WHITESPACE " " + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "i32" + R_PAREN ")" + SEMICOLON ";" + WHITESPACE "\n" diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/tuple_mut_restrictions.rs b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/tuple_mut_restrictions.rs new file mode 100644 index 0000000000000..42653b0043b65 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/tuple_mut_restrictions.rs @@ -0,0 +1 @@ +struct Foo(pub(crate) mut(super) i32); diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/type_const.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/type_const.rast new file mode 100644 index 0000000000000..9ceae9e44b3a8 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/type_const.rast 
@@ -0,0 +1,22 @@ +SOURCE_FILE + CONST + TYPE_KW "type" + WHITESPACE " " + CONST_KW "const" + WHITESPACE " " + NAME + IDENT "FOO" + COLON ":" + WHITESPACE " " + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "i32" + WHITESPACE " " + EQ "=" + WHITESPACE " " + LITERAL + INT_NUMBER "2" + SEMICOLON ";" + WHITESPACE "\n" diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/type_const.rs b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/type_const.rs new file mode 100644 index 0000000000000..8e2c4259227a0 --- /dev/null +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/inline/ok/type_const.rs @@ -0,0 +1 @@ +type const FOO: i32 = 2; diff --git a/src/tools/rust-analyzer/crates/parser/test_data/parser/ok/0012_visibility.rast b/src/tools/rust-analyzer/crates/parser/test_data/parser/ok/0012_visibility.rast index 3d9322947b35a..348498daa990b 100644 --- a/src/tools/rust-analyzer/crates/parser/test_data/parser/ok/0012_visibility.rast +++ b/src/tools/rust-analyzer/crates/parser/test_data/parser/ok/0012_visibility.rast @@ -52,12 +52,13 @@ SOURCE_FILE FN VISIBILITY PUB_KW "pub" - L_PAREN "(" - PATH - PATH_SEGMENT - NAME_REF - CRATE_KW "crate" - R_PAREN ")" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + CRATE_KW "crate" + R_PAREN ")" WHITESPACE " " FN_KW "fn" WHITESPACE " " @@ -75,12 +76,13 @@ SOURCE_FILE FN VISIBILITY PUB_KW "pub" - L_PAREN "(" - PATH - PATH_SEGMENT - NAME_REF - SUPER_KW "super" - R_PAREN ")" + VISIBILITY_INNER + L_PAREN "(" + PATH + PATH_SEGMENT + NAME_REF + SUPER_KW "super" + R_PAREN ")" WHITESPACE " " FN_KW "fn" WHITESPACE " " @@ -98,24 +100,25 @@ SOURCE_FILE FN VISIBILITY PUB_KW "pub" - L_PAREN "(" - IN_KW "in" - WHITESPACE " " - PATH + VISIBILITY_INNER + L_PAREN "(" + IN_KW "in" + WHITESPACE " " PATH PATH + PATH + PATH_SEGMENT + NAME_REF + IDENT "foo" + COLON2 "::" PATH_SEGMENT NAME_REF - IDENT "foo" + IDENT "bar" COLON2 "::" PATH_SEGMENT NAME_REF - IDENT "bar" - COLON2 "::" - 
PATH_SEGMENT - NAME_REF - IDENT "baz" - R_PAREN ")" + IDENT "baz" + R_PAREN ")" WHITESPACE " " FN_KW "fn" WHITESPACE " " diff --git a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs index cd8944aa6170a..248de70f0e71d 100644 --- a/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs +++ b/src/tools/rust-analyzer/crates/proc-macro-api/src/legacy_protocol/msg/flat.rs @@ -495,6 +495,10 @@ impl<'a, T: SpanTransformer> Writer<'a, '_, T, tt::iter::TtIt } } + #[expect( + clippy::explicit_counter_loop, + reason = "it looks better the current way since we use `first_tt` before the loop" + )] fn subtree(&mut self, idx: usize, n_tt: usize, subtree: tt::iter::TtIter<'a>) { let mut first_tt = self.token_tree.len(); self.token_tree.resize(first_tt + n_tt, !0); diff --git a/src/tools/rust-analyzer/crates/project-model/src/env.rs b/src/tools/rust-analyzer/crates/project-model/src/env.rs index ab45917a5663b..37cfcd554524a 100644 --- a/src/tools/rust-analyzer/crates/project-model/src/env.rs +++ b/src/tools/rust-analyzer/crates/project-model/src/env.rs @@ -160,7 +160,7 @@ env.RA_TEST_NOT_AN_OBJECT = "value" ("RA_TEST_UNSET", None), ] .iter() - .map(|(k, v)| (k.to_string(), v.map(ToString::to_string))) + .map(|(k, v)| (k.to_string(), v.map(str::to_owned))) .collect(); let env = cargo_config_env(&Some(config), &extra_env); assert_eq!(env.get("RA_TEST_WORKSPACE_DIR").as_deref(), Some(cwd.join("").as_str())); diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/Cargo.toml b/src/tools/rust-analyzer/crates/rust-analyzer/Cargo.toml index beb83a8173a09..da2ec740197ea 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/Cargo.toml +++ b/src/tools/rust-analyzer/crates/rust-analyzer/Cargo.toml @@ -28,7 +28,7 @@ dissimilar.workspace = true ide-completion.workspace = true indexmap.workspace = true itertools.workspace = true -scip = "0.5.2" +scip = "0.7.1" 
lsp-types = { version = "=0.95.0", features = ["proposed"] } parking_lot = "0.12.4" xflags = "0.3.2" diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/analysis_stats.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/analysis_stats.rs index e56727d39d671..bf9a66bf3fc29 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/analysis_stats.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/analysis_stats.rs @@ -26,8 +26,9 @@ use ide::{ InlayHintsConfig, LineCol, RaFixtureConfig, RootDatabase, }; use ide_db::{ - EditionedFileId, LineIndexDatabase, SnippetCap, + EditionedFileId, SnippetCap, base_db::{SourceDatabase, salsa::Database}, + line_index, }; use itertools::Itertools; use load_cargo::{LoadCargoConfig, ProcMacroServerChoice, load_workspace}; @@ -1487,7 +1488,7 @@ fn location_csv_expr(db: &RootDatabase, vfs: &Vfs, sm: &BodySourceMap, expr_id: let node = src.map(|e| e.to_node(&root).syntax().clone()); let original_range = node.as_ref().original_file_range_rooted(db); let path = vfs.file_path(original_range.file_id.file_id(db)); - let line_index = db.line_index(original_range.file_id.file_id(db)); + let line_index = line_index(db, original_range.file_id.file_id(db)); let text_range = original_range.range; let (start, end) = (line_index.line_col(text_range.start()), line_index.line_col(text_range.end())); @@ -1503,7 +1504,7 @@ fn location_csv_pat(db: &RootDatabase, vfs: &Vfs, sm: &BodySourceMap, pat_id: Pa let node = src.map(|e| e.to_node(&root).syntax().clone()); let original_range = node.as_ref().original_file_range_rooted(db); let path = vfs.file_path(original_range.file_id.file_id(db)); - let line_index = db.line_index(original_range.file_id.file_id(db)); + let line_index = line_index(db, original_range.file_id.file_id(db)); let text_range = original_range.range; let (start, end) = (line_index.line_col(text_range.start()), line_index.line_col(text_range.end())); @@ -1522,7 +1523,7 @@ fn expr_syntax_range<'a>( let node 
= src.map(|e| e.to_node(&root).syntax().clone()); let original_range = node.as_ref().original_file_range_rooted(db); let path = vfs.file_path(original_range.file_id.file_id(db)); - let line_index = db.line_index(original_range.file_id.file_id(db)); + let line_index = line_index(db, original_range.file_id.file_id(db)); let text_range = original_range.range; let (start, end) = (line_index.line_col(text_range.start()), line_index.line_col(text_range.end())); @@ -1543,7 +1544,7 @@ fn pat_syntax_range<'a>( let node = src.map(|e| e.to_node(&root).syntax().clone()); let original_range = node.as_ref().original_file_range_rooted(db); let path = vfs.file_path(original_range.file_id.file_id(db)); - let line_index = db.line_index(original_range.file_id.file_id(db)); + let line_index = line_index(db, original_range.file_id.file_id(db)); let text_range = original_range.range; let (start, end) = (line_index.line_col(text_range.start()), line_index.line_col(text_range.end())); diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/diagnostics.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/diagnostics.rs index efbaad3c4936e..e50e1c26bb971 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/diagnostics.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/diagnostics.rs @@ -6,7 +6,7 @@ use rustc_hash::FxHashSet; use hir::{Crate, Module, db::HirDatabase, sym}; use ide::{AnalysisHost, AssistResolveStrategy, Diagnostic, DiagnosticsConfig, Severity}; -use ide_db::{LineIndexDatabase, base_db::SourceDatabase}; +use ide_db::{base_db::SourceDatabase, line_index}; use load_cargo::{LoadCargoConfig, ProcMacroServerChoice, load_workspace_at}; use crate::cli::{flags, progress_report::ProgressReport}; @@ -99,7 +99,7 @@ impl flags::Diagnostics { } let Diagnostic { code, message, range, severity, .. 
} = diagnostic; - let line_index = db.line_index(range.file_id); + let line_index = line_index(db, range.file_id); let start = line_index.line_col(range.range.start()); let end = line_index.line_col(range.range.end()); bar.println(format!( diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/lsif.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/lsif.rs index 3950a581fd776..4f6de6850abbd 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/lsif.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/lsif.rs @@ -8,7 +8,7 @@ use ide::{ RootDatabase, StaticIndex, StaticIndexedFile, TokenId, TokenStaticData, VendoredLibrariesConfig, }; -use ide_db::{LineIndexDatabase, line_index::WideEncoding}; +use ide_db::{line_index, line_index::WideEncoding}; use load_cargo::{LoadCargoConfig, ProcMacroServerChoice, load_workspace}; use lsp_types::lsif; use project_model::{CargoConfig, ProjectManifest, ProjectWorkspace, RustLibSource}; @@ -120,9 +120,9 @@ impl LsifManager<'_, '_> { } let file_id = id.file_id; let doc_id = self.get_file_id(file_id); - let line_index = self.db.line_index(file_id); + let line_index = line_index(self.db, file_id); let line_index = LineIndex { - index: line_index, + index: line_index.clone(), encoding: PositionEncoding::Wide(WideEncoding::Utf16), endings: LineEndings::Unix, }; @@ -241,9 +241,9 @@ impl LsifManager<'_, '_> { let StaticIndexedFile { file_id, tokens, folds, .. 
} = file; let doc_id = self.get_file_id(file_id); let text = self.analysis.file_text(file_id).unwrap(); - let line_index = self.db.line_index(file_id); + let line_index = line_index(self.db, file_id); let line_index = LineIndex { - index: line_index, + index: line_index.clone(), encoding: PositionEncoding::Wide(WideEncoding::Utf16), endings: LineEndings::Unix, }; diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/run_tests.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/run_tests.rs index e8c88cadf6f0a..0f7ef84a0eb81 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/run_tests.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/run_tests.rs @@ -2,7 +2,7 @@ use hir::{Crate, Module}; use hir_ty::db::HirDatabase; -use ide_db::{LineIndexDatabase, base_db::SourceDatabase}; +use ide_db::{base_db::SourceDatabase, line_index}; use profile::StopWatch; use project_model::{CargoConfig, RustLibSource}; use syntax::TextRange; @@ -38,7 +38,7 @@ impl flags::RunTests { }) .filter(|x| x.is_test(db)); let span_formatter = |file_id, text_range: TextRange| { - let line_col = match db.line_index(file_id).try_line_col(text_range.start()) { + let line_col = match line_index(db, file_id).try_line_col(text_range.start()) { None => " (unknown line col)".to_owned(), Some(x) => format!("#{}:{}", x.line + 1, x.col), }; diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/scip.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/scip.rs index ef6d4399e663c..bca38ed82fdd5 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/scip.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/scip.rs @@ -7,7 +7,7 @@ use ide::{ RootDatabase, StaticIndex, StaticIndexedFile, SymbolInformationKind, TextRange, TokenId, TokenStaticData, VendoredLibrariesConfig, }; -use ide_db::LineIndexDatabase; +use ide_db::line_index; use load_cargo::{LoadCargoConfig, ProcMacroServerChoice, load_workspace_at}; use rustc_hash::{FxHashMap, 
FxHashSet}; use scip::types::{self as scip_types, SymbolInformation}; @@ -348,7 +348,7 @@ fn get_relative_filepath( fn get_line_index(db: &RootDatabase, file_id: FileId) -> LineIndex { LineIndex { - index: db.line_index(file_id), + index: line_index(db, file_id).clone(), encoding: PositionEncoding::Utf8, endings: LineEndings::Unix, } diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/unresolved_references.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/unresolved_references.rs index 2d9b870f4de82..f8eacbb670587 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/unresolved_references.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/cli/unresolved_references.rs @@ -1,7 +1,7 @@ //! Reports references in code that the IDE layer cannot resolve. use hir::{AnyDiagnostic, Crate, Module, Semantics, db::HirDatabase, sym}; use ide::{AnalysisHost, RootDatabase, TextRange}; -use ide_db::{FxHashSet, LineIndexDatabase as _, base_db::SourceDatabase, defs::NameRefClass}; +use ide_db::{FxHashSet, base_db::SourceDatabase, defs::NameRefClass, line_index}; use load_cargo::{LoadCargoConfig, ProcMacroServerChoice, load_workspace_at}; use parser::SyntaxKind; use syntax::{AstNode, WalkEvent, ast}; @@ -75,7 +75,7 @@ impl flags::UnresolvedReferences { let file_path = vfs.file_path(file_id); eprintln!("processing crate: {crate_name}, module: {file_path}",); - let line_index = db.line_index(file_id); + let line_index = line_index(db, file_id); let file_text = db.file_text(file_id); for range in find_unresolved_references(db, &sema, file_id, &module) { diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/handlers/request.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/handlers/request.rs index c1806c82c724c..0c1c067ffa919 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/handlers/request.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/handlers/request.rs @@ -7,9 +7,9 @@ use anyhow::Context; use 
base64::{Engine, prelude::BASE64_STANDARD}; use ide::{ - AssistKind, AssistResolveStrategy, Cancellable, CompletionFieldsToResolve, FilePosition, - FileRange, FileStructureConfig, FindAllRefsConfig, HoverAction, HoverGotoTypeData, - InlayFieldsToResolve, Query, RangeInfo, ReferenceCategory, Runnable, RunnableKind, + AssistKind, AssistResolveStrategy, Cancellable, CompletionFieldsToResolve, + CompletionItemImport, FilePosition, FileRange, FileStructureConfig, FindAllRefsConfig, + HoverAction, HoverGotoTypeData, InlayFieldsToResolve, Query, RangeInfo, Runnable, RunnableKind, SingleResolve, SourceChange, TextEdit, }; use ide_db::{FxHashMap, SymbolKind}; @@ -1233,7 +1233,10 @@ pub(crate) fn handle_completion_resolve( .resolve_completion_edits( &forced_resolve_completions_config, position, - resolve_data.imports.into_iter().map(|import| import.full_import_path), + resolve_data.imports.into_iter().map(|import| CompletionItemImport { + path: import.full_import_path, + as_underscore: import.as_underscore, + }), )? .into_iter() .flat_map(|edit| edit.into_iter().map(|indel| to_proto::text_edit(&line_index, indel))) @@ -1396,12 +1399,13 @@ pub(crate) fn handle_references( let exclude_imports = snap.config.find_all_refs_exclude_imports(); let exclude_tests = snap.config.find_all_refs_exclude_tests(); - let Some(refs) = snap.analysis.find_all_refs( position, &FindAllRefsConfig { search_scope: None, ra_fixture: snap.config.ra_fixture(snap.minicore()), + exclude_imports, + exclude_tests, }, )? 
else { @@ -1423,12 +1427,7 @@ pub(crate) fn handle_references( refs.references .into_iter() .flat_map(|(file_id, refs)| { - refs.into_iter() - .filter(|&(_, category)| { - (!exclude_imports || !category.contains(ReferenceCategory::IMPORT)) - && (!exclude_tests || !category.contains(ReferenceCategory::TEST)) - }) - .map(move |(range, _)| FileRange { file_id, range }) + refs.into_iter().map(move |(range, _)| FileRange { file_id, range }) }) .chain(decl) }) @@ -2211,7 +2210,10 @@ fn show_ref_command_link( *position, &FindAllRefsConfig { search_scope: None, + ra_fixture: snap.config.ra_fixture(snap.minicore()), + exclude_imports: snap.config.find_all_refs_exclude_imports(), + exclude_tests: snap.config.find_all_refs_exclude_tests(), }, ) .unwrap_or(None) diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp.rs index c7a5a95e66bbf..a6a35dadd9919 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp.rs @@ -3,7 +3,7 @@ use core::fmt; use hir::Mutability; -use ide::{CompletionItem, CompletionItemRefMode, CompletionRelevance}; +use ide::{CompletionItem, CompletionItemImport, CompletionItemRefMode, CompletionRelevance}; use tenthash::TentHash; pub mod ext; @@ -136,8 +136,10 @@ pub(crate) fn completion_item_hash(item: &CompletionItem, is_ref_completion: boo hasher.update(item.import_to_add.len().to_ne_bytes()); for import_path in &item.import_to_add { - hasher.update(import_path.len().to_ne_bytes()); - hasher.update(import_path); + let CompletionItemImport { path, as_underscore } = import_path; + hasher.update(path.len().to_ne_bytes()); + hasher.update(path); + hasher.update([u8::from(*as_underscore)]); } hasher.finalize() diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/ext.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/ext.rs index e6493eefef17a..5d0d9209de2f8 100644 --- 
a/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/ext.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/ext.rs @@ -858,6 +858,7 @@ pub struct InlayHintResolveData { #[derive(Debug, Serialize, Deserialize)] pub struct CompletionImport { pub full_import_path: String, + pub as_underscore: bool, } #[derive(Debug, Deserialize, Default)] diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/to_proto.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/to_proto.rs index 5fa95252e7cbf..d857f23b6703c 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/to_proto.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/lsp/to_proto.rs @@ -413,7 +413,10 @@ fn completion_item( item.import_to_add .clone() .into_iter() - .map(|import_path| lsp_ext::CompletionImport { full_import_path: import_path }) + .map(|import| lsp_ext::CompletionImport { + full_import_path: import.path, + as_underscore: import.as_underscore, + }) .collect() } else { Vec::new() diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/src/main_loop.rs b/src/tools/rust-analyzer/crates/rust-analyzer/src/main_loop.rs index a8c3d062d041c..f5b3658ea90ca 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/src/main_loop.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/src/main_loop.rs @@ -144,12 +144,11 @@ impl fmt::Debug for Event { }; match self { - Event::Lsp(lsp_server::Message::Notification(not)) => { - if notification_is::(not) - || notification_is::(not) - { - return debug_non_verbose(not, f); - } + Event::Lsp(lsp_server::Message::Notification(not)) + if (notification_is::(not) + || notification_is::(not)) => + { + return debug_non_verbose(not, f); } Event::Task(Task::Response(resp)) => { return f diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/flycheck.rs b/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/flycheck.rs index c6f1f81139d28..e5d4d7c88e803 100644 --- 
a/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/flycheck.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/flycheck.rs @@ -112,6 +112,7 @@ fn main() {} } #[test] +#[ignore = "this test tends to stuck, FIXME: investigate that"] fn test_flycheck_diagnostics_with_override_command_cleared_after_fix() { if skip_slow_tests() { return; diff --git a/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/main.rs b/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/main.rs index 3c57e36b4fe9f..a8632630784be 100644 --- a/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/main.rs +++ b/src/tools/rust-analyzer/crates/rust-analyzer/tests/slow-tests/main.rs @@ -1577,6 +1577,6 @@ fn test() { let res: serde_json::Value = serde_json::from_str(res.as_str().unwrap()).unwrap(); let arr = res.as_array().unwrap(); - assert_eq!(arr.len(), 2); + assert_eq!(arr.len(), 1); expect![[r#"{"goal":"Goal { param_env: ParamEnv { clauses: [] }, predicate: Binder { value: TraitPredicate(usize: Trait, polarity:Positive), bound_vars: [] } }","result":"Err(NoSolution)","depth":0,"candidates":[]}"#]].assert_eq(&arr[0].to_string()); } diff --git a/src/tools/rust-analyzer/crates/span/src/hygiene.rs b/src/tools/rust-analyzer/crates/span/src/hygiene.rs index fe05ef9465181..0a81cef52ec5a 100644 --- a/src/tools/rust-analyzer/crates/span/src/hygiene.rs +++ b/src/tools/rust-analyzer/crates/span/src/hygiene.rs @@ -81,25 +81,24 @@ const _: () = { #[derive(Hash)] struct StructKey<'db, T0, T1, T2, T3>(T0, T1, T2, T3, std::marker::PhantomData<&'db ()>); - impl<'db, T0, T1, T2, T3> zalsa_::interned::HashEqLike> - for SyntaxContextData + impl<'db, T0, T1, T2, T3> zalsa_::HashEqLike> for SyntaxContextData where - Option: zalsa_::interned::HashEqLike, - Transparency: zalsa_::interned::HashEqLike, - Edition: zalsa_::interned::HashEqLike, - SyntaxContext: zalsa_::interned::HashEqLike, + Option: zalsa_::HashEqLike, + Transparency: zalsa_::HashEqLike, + 
Edition: zalsa_::HashEqLike, + SyntaxContext: zalsa_::HashEqLike, { fn hash(&self, h: &mut H) { - zalsa_::interned::HashEqLike::::hash(&self.outer_expn, &mut *h); - zalsa_::interned::HashEqLike::::hash(&self.outer_transparency, &mut *h); - zalsa_::interned::HashEqLike::::hash(&self.edition, &mut *h); - zalsa_::interned::HashEqLike::::hash(&self.parent, &mut *h); + zalsa_::HashEqLike::::hash(&self.outer_expn, &mut *h); + zalsa_::HashEqLike::::hash(&self.outer_transparency, &mut *h); + zalsa_::HashEqLike::::hash(&self.edition, &mut *h); + zalsa_::HashEqLike::::hash(&self.parent, &mut *h); } fn eq(&self, data: &StructKey<'db, T0, T1, T2, T3>) -> bool { - zalsa_::interned::HashEqLike::::eq(&self.outer_expn, &data.0) - && zalsa_::interned::HashEqLike::::eq(&self.outer_transparency, &data.1) - && zalsa_::interned::HashEqLike::::eq(&self.edition, &data.2) - && zalsa_::interned::HashEqLike::::eq(&self.parent, &data.3) + zalsa_::HashEqLike::::eq(&self.outer_expn, &data.0) + && zalsa_::HashEqLike::::eq(&self.outer_transparency, &data.1) + && zalsa_::HashEqLike::::eq(&self.edition, &data.2) + && zalsa_::HashEqLike::::eq(&self.parent, &data.3) } } impl zalsa_struct_::Configuration for SyntaxContext { @@ -203,10 +202,10 @@ const _: () = { impl<'db> SyntaxContext { pub fn new< Db, - T0: zalsa_::interned::Lookup> + std::hash::Hash, - T1: zalsa_::interned::Lookup + std::hash::Hash, - T2: zalsa_::interned::Lookup + std::hash::Hash, - T3: zalsa_::interned::Lookup + std::hash::Hash, + T0: zalsa_::Lookup> + std::hash::Hash, + T1: zalsa_::Lookup + std::hash::Hash, + T2: zalsa_::Lookup + std::hash::Hash, + T3: zalsa_::Lookup + std::hash::Hash, >( db: &'db Db, outer_expn: T0, @@ -218,10 +217,10 @@ const _: () = { ) -> Self where Db: ?Sized + salsa::Database, - Option: zalsa_::interned::HashEqLike, - Transparency: zalsa_::interned::HashEqLike, - Edition: zalsa_::interned::HashEqLike, - SyntaxContext: zalsa_::interned::HashEqLike, + Option: zalsa_::HashEqLike, + Transparency: 
zalsa_::HashEqLike, + Edition: zalsa_::HashEqLike, + SyntaxContext: zalsa_::HashEqLike, { let (zalsa, zalsa_local) = db.zalsas(); @@ -236,10 +235,10 @@ const _: () = { std::marker::PhantomData, ), |id, data| SyntaxContextData { - outer_expn: zalsa_::interned::Lookup::into_owned(data.0), - outer_transparency: zalsa_::interned::Lookup::into_owned(data.1), - edition: zalsa_::interned::Lookup::into_owned(data.2), - parent: zalsa_::interned::Lookup::into_owned(data.3), + outer_expn: zalsa_::Lookup::into_owned(data.0), + outer_transparency: zalsa_::Lookup::into_owned(data.1), + edition: zalsa_::Lookup::into_owned(data.2), + parent: zalsa_::Lookup::into_owned(data.3), opaque: opaque(zalsa_::FromId::from_id(id)), opaque_and_semiopaque: opaque_and_semiopaque(zalsa_::FromId::from_id(id)), }, diff --git a/src/tools/rust-analyzer/crates/syntax/rust.ungram b/src/tools/rust-analyzer/crates/syntax/rust.ungram index 324b2bbd58e1c..768cf2013d650 100644 --- a/src/tools/rust-analyzer/crates/syntax/rust.ungram +++ b/src/tools/rust-analyzer/crates/syntax/rust.ungram @@ -277,6 +277,7 @@ RecordFieldList = RecordField = Attr* Visibility? 'unsafe'? + MutRestriction? Name ':' Type ('=' default_val:ConstArg)? TupleFieldList = @@ -284,12 +285,16 @@ TupleFieldList = TupleField = Attr* Visibility? + MutRestriction? Type FieldList = RecordFieldList | TupleFieldList +MutRestriction = + 'mut' VisibilityInner + Enum = Attr* Visibility? 'enum' Name GenericParamList? WhereClause? @@ -323,6 +328,7 @@ VariantDef = Const = Attr* Visibility? 'default'? + 'type'? 'const' (Name | '_') GenericParamList? ':' Type ('=' body:Expr)? WhereClause? ';' @@ -336,10 +342,14 @@ Static = Trait = Attr* Visibility? 'unsafe'? 'auto'? + ImplRestriction? 'trait' Name GenericParamList? (((':' TypeBoundList?)? WhereClause? AssocItemList) | ('=' TypeBoundList? WhereClause? 
';')) +ImplRestriction = + 'impl' VisibilityInner + AssocItemList = '{' Attr* AssocItem* '}' @@ -368,8 +378,10 @@ ExternItem = | TypeAlias Visibility = - 'pub' ('(' 'in'? Path ')')? + 'pub' VisibilityInner? +VisibilityInner = + '(' 'in'? Path ')' //****************************// // Statements and Expressions // diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/edit.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/edit.rs index 23a0411eadbd6..b20aa90d06f23 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/edit.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/edit.rs @@ -1,12 +1,15 @@ //! This module contains functions for editing syntax trees. As the trees are //! immutable, all function here return a fresh copy of the tree, instead of //! doing an in-place modification. +use parser::T; use std::{fmt, iter, ops}; use crate::{ - AstToken, NodeOrToken, SyntaxElement, SyntaxNode, SyntaxToken, - ast::{self, AstNode, make}, - syntax_editor::{SyntaxEditor, SyntaxMappingBuilder}, + AstToken, NodeOrToken, SyntaxElement, + SyntaxKind::WHITESPACE, + SyntaxNode, SyntaxToken, + ast::{self, AstNode, HasName, make}, + syntax_editor::{Position, SyntaxEditor, SyntaxMappingBuilder}, ted, }; @@ -105,7 +108,7 @@ impl IndentLevel { } pub(super) fn clone_increase_indent(self, node: &SyntaxNode) -> SyntaxNode { - let (mut editor, node) = SyntaxEditor::new(node.clone()); + let (editor, node) = SyntaxEditor::new(node.clone()); let tokens = node .preorder_with_tokens() .filter_map(|event| match event { @@ -139,7 +142,7 @@ impl IndentLevel { } pub(super) fn clone_decrease_indent(self, node: &SyntaxNode) -> SyntaxNode { - let (mut editor, node) = SyntaxEditor::new(node.clone()); + let (editor, node) = SyntaxEditor::new(node.clone()); let tokens = node .preorder_with_tokens() .filter_map(|event| match event { @@ -194,6 +197,61 @@ pub trait AstNodeEdit: AstNode + Clone + Sized { impl AstNodeEdit for N {} +impl ast::IdentPat { + pub fn set_pat(&self, pat: Option, 
editor: &SyntaxEditor) -> ast::IdentPat { + let make = editor.make(); + match pat { + None => { + if let Some(at_token) = self.at_token() { + // Remove `@ Pat` + let start = at_token.clone().into(); + let end = self + .pat() + .map(|it| it.syntax().clone().into()) + .unwrap_or_else(|| at_token.into()); + editor.delete_all(start..=end); + + // Remove any trailing ws + if let Some(last) = + self.syntax().last_token().filter(|it| it.kind() == WHITESPACE) + { + last.detach(); + } + } + } + Some(pat) => { + if let Some(old_pat) = self.pat() { + // Replace existing pattern + editor.replace(old_pat.syntax(), pat.syntax()) + } else if let Some(at_token) = self.at_token() { + // Have an `@` token but not a pattern yet + editor.insert(Position::after(at_token), pat.syntax()); + } else { + // Don't have an `@`, should have a name + let name = self.name().unwrap(); + let elements = vec![ + make.whitespace(" ").into(), + make.token(T![@]).into(), + make.whitespace(" ").into(), + pat.syntax().clone().into(), + ]; + + if self.syntax().parent().is_none() { + let (local, local_self) = SyntaxEditor::with_ast_node(self); + let local_name = local_self.name().unwrap(); + local.insert_all(Position::after(local_name.syntax()), elements); + let edit = local.finish(); + return ast::IdentPat::cast(edit.new_root().clone()).unwrap(); + } else { + editor.insert_all(Position::after(name.syntax()), elements); + } + } + } + } + self.clone() + } +} + #[test] fn test_increase_indent() { let arm_list = { diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs index 7f59ae4213829..46ea4daba82ea 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/edit_in_place.rs @@ -9,221 +9,12 @@ use crate::{ SyntaxKind::{ATTR, COMMENT, WHITESPACE}, SyntaxNode, SyntaxToken, algo::{self, neighbor}, - ast::{self, HasGenericParams, edit::IndentLevel, make, 
syntax_factory::SyntaxFactory}, - syntax_editor::{Position, SyntaxEditor}, + ast::{self, edit::IndentLevel, make}, ted, }; use super::{GenericParam, HasName}; -pub trait GenericParamsOwnerEdit: ast::HasGenericParams { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList; - fn get_or_create_where_clause(&self) -> ast::WhereClause; -} - -impl GenericParamsOwnerEdit for ast::Fn { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(fn_token) = self.fn_token() { - ted::Position::after(fn_token) - } else if let Some(param_list) = self.param_list() { - ted::Position::before(param_list.syntax) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = if let Some(ty) = self.ret_type() { - ted::Position::after(ty.syntax()) - } else if let Some(param_list) = self.param_list() { - ted::Position::after(param_list.syntax()) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Impl { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = match self.impl_token() { - Some(imp_token) => ted::Position::after(imp_token), - None => ted::Position::last_child_of(self.syntax()), - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = match self.assoc_item_list() { - Some(items) => ted::Position::before(items.syntax()), - None => ted::Position::last_child_of(self.syntax()), - }; - 
create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Trait { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(trait_token) = self.trait_token() { - ted::Position::after(trait_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = match (self.assoc_item_list(), self.semicolon_token()) { - (Some(items), _) => ted::Position::before(items.syntax()), - (_, Some(tok)) => ted::Position::before(tok), - (None, None) => ted::Position::last_child_of(self.syntax()), - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::TypeAlias { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(trait_token) = self.type_token() { - ted::Position::after(trait_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = match self.eq_token() { - Some(tok) => ted::Position::before(tok), - None => match self.semicolon_token() { - Some(tok) => ted::Position::before(tok), - None => ted::Position::last_child_of(self.syntax()), - }, - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Struct { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match 
self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(struct_token) = self.struct_token() { - ted::Position::after(struct_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let tfl = self.field_list().and_then(|fl| match fl { - ast::FieldList::RecordFieldList(_) => None, - ast::FieldList::TupleFieldList(it) => Some(it), - }); - let position = if let Some(tfl) = tfl { - ted::Position::after(tfl.syntax()) - } else if let Some(gpl) = self.generic_param_list() { - ted::Position::after(gpl.syntax()) - } else if let Some(name) = self.name() { - ted::Position::after(name.syntax()) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -impl GenericParamsOwnerEdit for ast::Enum { - fn get_or_create_generic_param_list(&self) -> ast::GenericParamList { - match self.generic_param_list() { - Some(it) => it, - None => { - let position = if let Some(name) = self.name() { - ted::Position::after(name.syntax) - } else if let Some(enum_token) = self.enum_token() { - ted::Position::after(enum_token) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_generic_param_list(position) - } - } - } - - fn get_or_create_where_clause(&self) -> ast::WhereClause { - if self.where_clause().is_none() { - let position = if let Some(gpl) = self.generic_param_list() { - ted::Position::after(gpl.syntax()) - } else if let Some(name) = self.name() { - ted::Position::after(name.syntax()) - } else { - ted::Position::last_child_of(self.syntax()) - }; - create_where_clause(position); - } - self.where_clause().unwrap() - } -} - -fn create_where_clause(position: ted::Position) { - let where_clause = 
make::where_clause(empty()).clone_for_update(); - ted::insert(position, where_clause.syntax()); -} - -fn create_generic_param_list(position: ted::Position) -> ast::GenericParamList { - let gpl = make::generic_param_list(empty()).clone_for_update(); - ted::insert_raw(position, gpl.syntax()); - gpl -} - pub trait AttrsOwnerEdit: ast::HasAttrs { fn remove_attrs_and_docs(&self) { remove_attrs_and_docs(self.syntax()); @@ -723,10 +514,11 @@ fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> { let indent = IndentLevel::from_node(node); match l.next_sibling_or_token() { - Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => { - if ws.next_sibling_or_token()?.into_token()? == r { - ted::replace(ws, make::tokens::whitespace(&format!("\n{indent}"))); - } + Some(ws) + if ws.kind() == SyntaxKind::WHITESPACE + && ws.next_sibling_or_token()?.into_token()? == r => + { + ted::replace(ws, make::tokens::whitespace(&format!("\n{indent}"))); } Some(ws) if ws.kind() == T!['}'] => { ted::insert(ted::Position::after(l), make::tokens::whitespace(&format!("\n{indent}"))); @@ -736,128 +528,6 @@ fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> { Some(()) } -impl ast::IdentPat { - pub fn set_pat(&self, pat: Option) { - match pat { - None => { - if let Some(at_token) = self.at_token() { - // Remove `@ Pat` - let start = at_token.clone().into(); - let end = self - .pat() - .map(|it| it.syntax().clone().into()) - .unwrap_or_else(|| at_token.into()); - - ted::remove_all(start..=end); - - // Remove any trailing ws - if let Some(last) = - self.syntax().last_token().filter(|it| it.kind() == WHITESPACE) - { - last.detach(); - } - } - } - Some(pat) => { - if let Some(old_pat) = self.pat() { - // Replace existing pattern - ted::replace(old_pat.syntax(), pat.syntax()) - } else if let Some(at_token) = self.at_token() { - // Have an `@` token but not a pattern yet - ted::insert(ted::Position::after(at_token), pat.syntax()); - } else { - // Don't have an `@`, should have a name - 
let name = self.name().unwrap(); - - ted::insert_all( - ted::Position::after(name.syntax()), - vec![ - make::token(T![@]).into(), - make::tokens::single_space().into(), - pat.syntax().clone().into(), - ], - ) - } - } - } - } - - pub fn set_pat_with_editor( - &self, - pat: Option, - syntax_editor: &mut SyntaxEditor, - syntax_factory: &SyntaxFactory, - ) { - match pat { - None => { - if let Some(at_token) = self.at_token() { - // Remove `@ Pat` - let start = at_token.clone().into(); - let end = self - .pat() - .map(|it| it.syntax().clone().into()) - .unwrap_or_else(|| at_token.into()); - syntax_editor.delete_all(start..=end); - - // Remove any trailing ws - if let Some(last) = - self.syntax().last_token().filter(|it| it.kind() == WHITESPACE) - { - last.detach(); - } - } - } - Some(pat) => { - if let Some(old_pat) = self.pat() { - // Replace existing pattern - syntax_editor.replace(old_pat.syntax(), pat.syntax()) - } else if let Some(at_token) = self.at_token() { - // Have an `@` token but not a pattern yet - syntax_editor.insert(Position::after(at_token), pat.syntax()); - } else { - // Don't have an `@`, should have a name - let name = self.name().unwrap(); - - syntax_editor.insert_all( - Position::after(name.syntax()), - vec![ - syntax_factory.whitespace(" ").into(), - syntax_factory.token(T![@]).into(), - syntax_factory.whitespace(" ").into(), - pat.syntax().clone().into(), - ], - ) - } - } - } - } -} - -pub trait HasVisibilityEdit: ast::HasVisibility { - fn set_visibility(&self, visibility: Option) { - if let Some(visibility) = visibility { - match self.visibility() { - Some(current_visibility) => { - ted::replace(current_visibility.syntax(), visibility.syntax()) - } - None => { - let vis_before = self - .syntax() - .children_with_tokens() - .find(|it| !matches!(it.kind(), WHITESPACE | COMMENT | ATTR)) - .unwrap_or_else(|| self.syntax().first_child_or_token().unwrap()); - - ted::insert(ted::Position::before(vis_before), visibility.syntax()); - } - } - } else if 
let Some(visibility) = self.visibility() { - ted::remove(visibility.syntax()); - } - } -} - -impl HasVisibilityEdit for T {} - pub trait Indent: AstNode + Clone + Sized { fn indent_level(&self) -> IndentLevel { IndentLevel::from_node(self.syntax()) @@ -879,8 +549,6 @@ impl Indent for N {} #[cfg(test)] mod tests { - use std::fmt; - use parser::Edition; use crate::SourceFile; @@ -892,33 +560,6 @@ mod tests { parse.tree().syntax().descendants().find_map(N::cast).unwrap().clone_for_update() } - #[test] - fn test_create_generic_param_list() { - fn check_create_gpl(before: &str, after: &str) { - let gpl_owner = ast_mut_from_text::(before); - gpl_owner.get_or_create_generic_param_list(); - assert_eq!(gpl_owner.to_string(), after); - } - - check_create_gpl::("fn foo", "fn foo<>"); - check_create_gpl::("fn foo() {}", "fn foo<>() {}"); - - check_create_gpl::("impl", "impl<>"); - check_create_gpl::("impl Struct {}", "impl<> Struct {}"); - check_create_gpl::("impl Trait for Struct {}", "impl<> Trait for Struct {}"); - - check_create_gpl::("trait Trait<>", "trait Trait<>"); - check_create_gpl::("trait Trait<> {}", "trait Trait<> {}"); - - check_create_gpl::("struct A", "struct A<>"); - check_create_gpl::("struct A;", "struct A<>;"); - check_create_gpl::("struct A();", "struct A<>();"); - check_create_gpl::("struct A {}", "struct A<> {}"); - - check_create_gpl::("enum E", "enum E<>"); - check_create_gpl::("enum E {", "enum E<> {"); - } - #[test] fn test_increase_indent() { let arm_list = ast_mut_from_text::( @@ -936,32 +577,4 @@ mod tests { }", ); } - - #[test] - fn test_ident_pat_set_pat() { - #[track_caller] - fn check(before: &str, expected: &str, pat: Option) { - let pat = pat.map(|it| it.clone_for_update()); - - let ident_pat = ast_mut_from_text::(&format!("fn f() {{ {before} }}")); - ident_pat.set_pat(pat); - - let after = ast_mut_from_text::(&format!("fn f() {{ {expected} }}")); - assert_eq!(ident_pat.to_string(), after.to_string()); - } - - // replacing - check("let a @ 
_;", "let a @ ();", Some(make::tuple_pat([]).into())); - - // note: no trailing semicolon is added for the below tests since it - // seems to be picked up by the ident pat during error recovery? - - // adding - check("let a ", "let a @ ()", Some(make::tuple_pat([]).into())); - check("let a @ ", "let a @ ()", Some(make::tuple_pat([]).into())); - - // removing - check("let a @ ()", "let a", None); - check("let a @ ", "let a", None); - } } diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/generated/nodes.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/generated/nodes.rs index cd7f6a018ab2b..9a2bba9ebf0dd 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/generated/nodes.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/generated/nodes.rs @@ -484,6 +484,8 @@ impl Const { pub fn const_token(&self) -> Option { support::token(&self.syntax, T![const]) } #[inline] pub fn default_token(&self) -> Option { support::token(&self.syntax, T![default]) } + #[inline] + pub fn type_token(&self) -> Option { support::token(&self.syntax, T![type]) } } pub struct ConstArg { pub(crate) syntax: SyntaxNode, @@ -808,6 +810,15 @@ impl Impl { #[inline] pub fn unsafe_token(&self) -> Option { support::token(&self.syntax, T![unsafe]) } } +pub struct ImplRestriction { + pub(crate) syntax: SyntaxNode, +} +impl ImplRestriction { + #[inline] + pub fn visibility_inner(&self) -> Option { support::child(&self.syntax) } + #[inline] + pub fn impl_token(&self) -> Option { support::token(&self.syntax, T![impl]) } +} pub struct ImplTraitType { pub(crate) syntax: SyntaxNode, } @@ -1114,6 +1125,15 @@ impl Module { #[inline] pub fn mod_token(&self) -> Option { support::token(&self.syntax, T![mod]) } } +pub struct MutRestriction { + pub(crate) syntax: SyntaxNode, +} +impl MutRestriction { + #[inline] + pub fn visibility_inner(&self) -> Option { support::child(&self.syntax) } + #[inline] + pub fn mut_token(&self) -> Option { support::token(&self.syntax, T![mut]) } +} pub struct Name { 
pub(crate) syntax: SyntaxNode, } @@ -1400,6 +1420,8 @@ impl RecordField { #[inline] pub fn default_val(&self) -> Option { support::child(&self.syntax) } #[inline] + pub fn mut_restriction(&self) -> Option { support::child(&self.syntax) } + #[inline] pub fn ty(&self) -> Option { support::child(&self.syntax) } #[inline] pub fn colon_token(&self) -> Option { support::token(&self.syntax, T![:]) } @@ -1690,6 +1712,8 @@ impl Trait { #[inline] pub fn assoc_item_list(&self) -> Option { support::child(&self.syntax) } #[inline] + pub fn impl_restriction(&self) -> Option { support::child(&self.syntax) } + #[inline] pub fn semicolon_token(&self) -> Option { support::token(&self.syntax, T![;]) } #[inline] pub fn eq_token(&self) -> Option { support::token(&self.syntax, T![=]) } @@ -1742,6 +1766,8 @@ impl ast::HasAttrs for TupleField {} impl ast::HasDocComments for TupleField {} impl ast::HasVisibility for TupleField {} impl TupleField { + #[inline] + pub fn mut_restriction(&self) -> Option { support::child(&self.syntax) } #[inline] pub fn ty(&self) -> Option { support::child(&self.syntax) } } @@ -2000,6 +2026,15 @@ pub struct Visibility { pub(crate) syntax: SyntaxNode, } impl Visibility { + #[inline] + pub fn visibility_inner(&self) -> Option { support::child(&self.syntax) } + #[inline] + pub fn pub_token(&self) -> Option { support::token(&self.syntax, T![pub]) } +} +pub struct VisibilityInner { + pub(crate) syntax: SyntaxNode, +} +impl VisibilityInner { #[inline] pub fn path(&self) -> Option { support::child(&self.syntax) } #[inline] @@ -2008,8 +2043,6 @@ impl Visibility { pub fn r_paren_token(&self) -> Option { support::token(&self.syntax, T![')']) } #[inline] pub fn in_token(&self) -> Option { support::token(&self.syntax, T![in]) } - #[inline] - pub fn pub_token(&self) -> Option { support::token(&self.syntax, T![pub]) } } pub struct WhereClause { pub(crate) syntax: SyntaxNode, @@ -4192,6 +4225,38 @@ impl fmt::Debug for Impl { f.debug_struct("Impl").field("syntax", 
&self.syntax).finish() } } +impl AstNode for ImplRestriction { + #[inline] + fn kind() -> SyntaxKind + where + Self: Sized, + { + IMPL_RESTRICTION + } + #[inline] + fn can_cast(kind: SyntaxKind) -> bool { kind == IMPL_RESTRICTION } + #[inline] + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { Some(Self { syntax }) } else { None } + } + #[inline] + fn syntax(&self) -> &SyntaxNode { &self.syntax } +} +impl hash::Hash for ImplRestriction { + fn hash(&self, state: &mut H) { self.syntax.hash(state); } +} +impl Eq for ImplRestriction {} +impl PartialEq for ImplRestriction { + fn eq(&self, other: &Self) -> bool { self.syntax == other.syntax } +} +impl Clone for ImplRestriction { + fn clone(&self) -> Self { Self { syntax: self.syntax.clone() } } +} +impl fmt::Debug for ImplRestriction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ImplRestriction").field("syntax", &self.syntax).finish() + } +} impl AstNode for ImplTraitType { #[inline] fn kind() -> SyntaxKind @@ -5120,6 +5185,38 @@ impl fmt::Debug for Module { f.debug_struct("Module").field("syntax", &self.syntax).finish() } } +impl AstNode for MutRestriction { + #[inline] + fn kind() -> SyntaxKind + where + Self: Sized, + { + MUT_RESTRICTION + } + #[inline] + fn can_cast(kind: SyntaxKind) -> bool { kind == MUT_RESTRICTION } + #[inline] + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { Some(Self { syntax }) } else { None } + } + #[inline] + fn syntax(&self) -> &SyntaxNode { &self.syntax } +} +impl hash::Hash for MutRestriction { + fn hash(&self, state: &mut H) { self.syntax.hash(state); } +} +impl Eq for MutRestriction {} +impl PartialEq for MutRestriction { + fn eq(&self, other: &Self) -> bool { self.syntax == other.syntax } +} +impl Clone for MutRestriction { + fn clone(&self) -> Self { Self { syntax: self.syntax.clone() } } +} +impl fmt::Debug for MutRestriction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + 
f.debug_struct("MutRestriction").field("syntax", &self.syntax).finish() + } +} impl AstNode for Name { #[inline] fn kind() -> SyntaxKind @@ -7392,6 +7489,38 @@ impl fmt::Debug for Visibility { f.debug_struct("Visibility").field("syntax", &self.syntax).finish() } } +impl AstNode for VisibilityInner { + #[inline] + fn kind() -> SyntaxKind + where + Self: Sized, + { + VISIBILITY_INNER + } + #[inline] + fn can_cast(kind: SyntaxKind) -> bool { kind == VISIBILITY_INNER } + #[inline] + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { Some(Self { syntax }) } else { None } + } + #[inline] + fn syntax(&self) -> &SyntaxNode { &self.syntax } +} +impl hash::Hash for VisibilityInner { + fn hash(&self, state: &mut H) { self.syntax.hash(state); } +} +impl Eq for VisibilityInner {} +impl PartialEq for VisibilityInner { + fn eq(&self, other: &Self) -> bool { self.syntax == other.syntax } +} +impl Clone for VisibilityInner { + fn clone(&self) -> Self { Self { syntax: self.syntax.clone() } } +} +impl fmt::Debug for VisibilityInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("VisibilityInner").field("syntax", &self.syntax).finish() + } +} impl AstNode for WhereClause { #[inline] fn kind() -> SyntaxKind @@ -10092,6 +10221,11 @@ impl std::fmt::Display for Impl { std::fmt::Display::fmt(self.syntax(), f) } } +impl std::fmt::Display for ImplRestriction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self.syntax(), f) + } +} impl std::fmt::Display for ImplTraitType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self.syntax(), f) @@ -10237,6 +10371,11 @@ impl std::fmt::Display for Module { std::fmt::Display::fmt(self.syntax(), f) } } +impl std::fmt::Display for MutRestriction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self.syntax(), f) + } +} impl std::fmt::Display for Name { fn 
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self.syntax(), f) @@ -10592,6 +10731,11 @@ impl std::fmt::Display for Visibility { std::fmt::Display::fmt(self.syntax(), f) } } +impl std::fmt::Display for VisibilityInner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self.syntax(), f) + } +} impl std::fmt::Display for WhereClause { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self.syntax(), f) diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/node_ext.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/node_ext.rs index 03118d01dc90a..751f8d7e1cbe1 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/node_ext.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/node_ext.rs @@ -947,6 +947,15 @@ pub enum VisibilityKind { } impl ast::Visibility { + pub fn kind(&self) -> VisibilityKind { + match self.visibility_inner() { + Some(inner) => inner.kind(), + None => VisibilityKind::Pub, + } + } +} + +impl ast::VisibilityInner { pub fn kind(&self) -> VisibilityKind { match self.path() { Some(path) => { diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory.rs index f3ae7544cc37f..9369a4e700cbd 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory.rs @@ -12,6 +12,7 @@ use std::cell::{RefCell, RefMut}; use crate::syntax_editor::SyntaxMapping; +#[derive(Debug)] pub struct SyntaxFactory { // Stored in a refcell so that the factory methods can be &self mappings: Option>, @@ -19,7 +20,7 @@ pub struct SyntaxFactory { impl SyntaxFactory { /// Creates a new [`SyntaxFactory`], generating mappings between input nodes and generated nodes. 
- pub fn with_mappings() -> Self { + pub(crate) fn with_mappings() -> Self { Self { mappings: Some(RefCell::new(SyntaxMapping::default())) } } @@ -28,13 +29,8 @@ impl SyntaxFactory { Self { mappings: None } } - /// Gets all of the tracked syntax mappings, if any. - pub fn finish_with_mappings(self) -> SyntaxMapping { - self.mappings.unwrap_or_default().into_inner() - } - /// Take all of the tracked syntax mappings, leaving `SyntaxMapping::default()` in its place, if any. - pub fn take(&self) -> SyntaxMapping { + pub(crate) fn take(&self) -> SyntaxMapping { self.mappings.as_ref().map(|mappings| mappings.take()).unwrap_or_default() } diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs index c66f096e8342c..0f3b3d301c544 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -67,6 +67,26 @@ impl SyntaxFactory { make::type_bound(bound).clone_for_update() } + pub fn type_bound_text(&self, bound: &str) -> ast::TypeBound { + make::type_bound_text(bound).clone_for_update() + } + + pub fn use_tree_list( + &self, + use_trees: impl IntoIterator, + ) -> ast::UseTreeList { + let (use_trees, input) = iterator_input(use_trees); + let ast = make::use_tree_list(use_trees).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_children(input, ast.use_trees().map(|b| b.syntax().clone())); + builder.finish(&mut mapping); + } + + ast + } + pub fn type_bound_list( &self, bounds: impl IntoIterator, @@ -76,7 +96,9 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input, ast.bounds().map(|b| b.syntax().clone())); + for (input_node, output_bound) in 
input.into_iter().zip(ast.bounds()) { + builder.map_node(input_node, output_bound.syntax().clone()); + } builder.finish(&mut mapping); } @@ -171,6 +193,18 @@ impl SyntaxFactory { ast } + pub fn untyped_param(&self, pat: ast::Pat) -> ast::Param { + let ast = make::untyped_param(pat.clone()).clone_for_update(); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(pat.syntax().clone(), ast.pat().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn ty_fn_ptr>( &self, is_unsafe: bool, @@ -189,7 +223,7 @@ impl SyntaxFactory { } builder.map_children( params_input, - ast.param_list().unwrap().params().map(|p| p.syntax().clone()), + ast.syntax().children().filter(|c| ast::Param::can_cast(c.kind())), ); if let Some(ret_type) = ret_type { builder @@ -222,13 +256,17 @@ impl SyntaxFactory { builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone()); } } + builder.finish(&mut mapping); + if let Some(type_bound_list) = ast.type_bound_list() { - builder.map_children( + let mut bounds_builder = + SyntaxMappingBuilder::new(type_bound_list.syntax().clone()); + bounds_builder.map_children( bounds_input, type_bound_list.bounds().map(|b| b.syntax().clone()), ); + bounds_builder.finish(&mut mapping); } - builder.finish(&mut mapping); } ast @@ -464,11 +502,13 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); builder.map_node(name_ref.syntax().clone(), ast.name_ref().unwrap().syntax().clone()); - builder.map_children( - input, - ast.generic_arg_list().unwrap().generic_args().map(|a| a.syntax().clone()), - ); builder.finish(&mut mapping); + + let generic_arg_list = ast.generic_arg_list().unwrap(); + let mut arg_builder = SyntaxMappingBuilder::new(generic_arg_list.syntax().clone()); + arg_builder + .map_children(input, generic_arg_list.generic_args().map(|a| 
a.syntax().clone())); + arg_builder.finish(&mut mapping); } ast @@ -609,9 +649,16 @@ impl SyntaxFactory { let ast = make::path_from_segments(segments, is_abs).clone_for_update(); if let Some(mut mapping) = self.mappings() { - let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input, ast.segments().map(|it| it.syntax().clone())); - builder.finish(&mut mapping); + let mut current_path = Some(ast.clone()); + for input_segment in input.iter().rev() { + let Some(path) = current_path else { break }; + if let Some(segment) = path.segment() { + let mut builder = SyntaxMappingBuilder::new(path.syntax().clone()); + builder.map_node(input_segment.clone(), segment.syntax().clone()); + builder.finish(&mut mapping); + } + current_path = path.qualifier(); + } } ast @@ -1053,13 +1100,15 @@ impl SyntaxFactory { let ast = make::expr_closure(args, expr.clone()).clone_for_update(); if let Some(mut mapping) = self.mappings() { - let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone()); - builder.map_children( - input, - ast.param_list().unwrap().params().map(|param| param.syntax().clone()), - ); + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); builder.map_node(expr.syntax().clone(), ast.body().unwrap().syntax().clone()); builder.finish(&mut mapping); + + let param_list = ast.param_list().unwrap(); + let mut params_builder = SyntaxMappingBuilder::new(param_list.syntax().clone()); + params_builder + .map_children(input, param_list.params().map(|param| param.syntax().clone())); + params_builder.finish(&mut mapping); } ast diff --git a/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs b/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs index 83ab87c1c687e..29b3b0930adab 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/ast/token_ext.rs @@ -32,11 +32,7 @@ impl ast::Comment { } pub fn prefix(&self) -> &'static str { - let &(prefix, _kind) = 
CommentKind::BY_PREFIX - .iter() - .find(|&(prefix, kind)| self.kind() == *kind && self.text().starts_with(prefix)) - .unwrap(); - prefix + self.kind().prefix() } /// Returns the textual content of a doc comment node as a single string with prefix and suffix diff --git a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs index 8e4dc75d22194..edd063ffd4617 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor.rs @@ -5,6 +5,7 @@ //! [`SyntaxEditor`]: https://github.com/dotnet/roslyn/blob/43b0b05cc4f492fd5de00f6f6717409091df8daa/src/Workspaces/Core/Portable/Editing/SyntaxEditor.cs use std::{ + cell::RefCell, fmt, iter, num::NonZeroU32, ops::RangeInclusive, @@ -29,9 +30,9 @@ pub use mapping::{SyntaxMapping, SyntaxMappingBuilder}; #[derive(Debug)] pub struct SyntaxEditor { root: SyntaxNode, - changes: Vec, - mappings: SyntaxMapping, - annotations: Vec<(SyntaxElement, SyntaxAnnotation)>, + changes: RefCell>, + annotations: RefCell>, + make: SyntaxFactory, } impl SyntaxEditor { @@ -50,9 +51,9 @@ impl SyntaxEditor { let editor = Self { root: root.clone(), - changes: Vec::new(), - mappings: SyntaxMapping::default(), - annotations: Vec::new(), + changes: RefCell::new(Vec::new()), + annotations: RefCell::new(Vec::new()), + make: SyntaxFactory::with_mappings(), }; (editor, root) @@ -68,20 +69,21 @@ impl SyntaxEditor { (editor, T::cast(root).unwrap()) } - pub fn add_annotation(&mut self, element: impl Element, annotation: SyntaxAnnotation) { - self.annotations.push((element.syntax_element(), annotation)) + pub fn make(&self) -> &SyntaxFactory { + &self.make } - pub fn add_annotation_all( - &mut self, - elements: Vec, - annotation: SyntaxAnnotation, - ) { + pub fn add_annotation(&self, element: impl Element, annotation: SyntaxAnnotation) { + self.annotations.borrow_mut().push((element.syntax_element(), annotation)) + } + + pub fn 
add_annotation_all(&self, elements: Vec, annotation: SyntaxAnnotation) { self.annotations + .borrow_mut() .extend(elements.into_iter().map(|e| e.syntax_element()).zip(iter::repeat(annotation))); } - pub fn merge(&mut self, mut other: SyntaxEditor) { + pub fn merge(&self, other: SyntaxEditor) { debug_assert!( self.root == other.root || other.root.ancestors().any(|node| node == self.root), "{:?} is not in the same tree as {:?}", @@ -89,102 +91,92 @@ impl SyntaxEditor { self.root ); - self.changes.append(&mut other.changes); - self.mappings.merge(other.mappings); - self.annotations.append(&mut other.annotations); + self.changes.borrow_mut().append(&mut other.changes.into_inner()); + if let Some(mut m) = self.make.mappings() { + m.merge(other.make.take()); + } + self.annotations.borrow_mut().append(&mut other.annotations.into_inner()); } - pub fn insert(&mut self, position: Position, element: impl Element) { + pub fn insert(&self, position: Position, element: impl Element) { debug_assert!(is_ancestor_or_self(&position.parent(), &self.root)); - self.changes.push(Change::Insert(position, element.syntax_element())) + self.changes.borrow_mut().push(Change::Insert(position, element.syntax_element())) } - pub fn insert_all(&mut self, position: Position, elements: Vec) { + pub fn insert_all(&self, position: Position, elements: Vec) { debug_assert!(is_ancestor_or_self(&position.parent(), &self.root)); - self.changes.push(Change::InsertAll(position, elements)) + self.changes.borrow_mut().push(Change::InsertAll(position, elements)) } - pub fn insert_with_whitespace( - &mut self, - position: Position, - element: impl Element, - factory: &SyntaxFactory, - ) { - self.insert_all_with_whitespace(position, vec![element.syntax_element()], factory) + pub fn insert_with_whitespace(&self, position: Position, element: impl Element) { + self.insert_all_with_whitespace(position, vec![element.syntax_element()]) } - pub fn insert_all_with_whitespace( - &mut self, - position: Position, - mut 
elements: Vec, - factory: &SyntaxFactory, - ) { + pub fn insert_all_with_whitespace(&self, position: Position, mut elements: Vec) { if let Some(first) = elements.first() - && let Some(ws) = ws_before(&position, first, factory) + && let Some(ws) = ws_before(&position, first, &self.make) { elements.insert(0, ws.into()); } if let Some(last) = elements.last() - && let Some(ws) = ws_after(&position, last, factory) + && let Some(ws) = ws_after(&position, last, &self.make) { elements.push(ws.into()); } self.insert_all(position, elements) } - pub fn delete(&mut self, element: impl Element) { + pub fn delete(&self, element: impl Element) { let element = element.syntax_element(); debug_assert!(is_ancestor_or_self_of_element(&element, &self.root)); debug_assert!( !matches!(&element, SyntaxElement::Node(node) if node == &self.root), "should not delete root node" ); - self.changes.push(Change::Replace(element.syntax_element(), None)); + self.changes.borrow_mut().push(Change::Replace(element.syntax_element(), None)); } - pub fn delete_all(&mut self, range: RangeInclusive) { + pub fn delete_all(&self, range: RangeInclusive) { if range.start() == range.end() { self.delete(range.start()); return; } debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root)); - self.changes.push(Change::ReplaceAll(range, Vec::new())) + self.changes.borrow_mut().push(Change::ReplaceAll(range, Vec::new())) } - pub fn replace(&mut self, old: impl Element, new: impl Element) { + pub fn replace(&self, old: impl Element, new: impl Element) { let old = old.syntax_element(); debug_assert!(is_ancestor_or_self_of_element(&old, &self.root)); - self.changes.push(Change::Replace(old.syntax_element(), Some(new.syntax_element()))); + self.changes + .borrow_mut() + .push(Change::Replace(old.syntax_element(), Some(new.syntax_element()))); } - pub fn replace_with_many(&mut self, old: impl Element, new: Vec) { + pub fn replace_with_many(&self, old: impl Element, new: Vec) { let old = old.syntax_element(); 
debug_assert!(is_ancestor_or_self_of_element(&old, &self.root)); debug_assert!( !(matches!(&old, SyntaxElement::Node(node) if node == &self.root) && new.len() > 1), "cannot replace root node with many elements" ); - self.changes.push(Change::ReplaceWithMany(old.syntax_element(), new)); + self.changes.borrow_mut().push(Change::ReplaceWithMany(old.syntax_element(), new)); } - pub fn replace_all(&mut self, range: RangeInclusive, new: Vec) { + pub fn replace_all(&self, range: RangeInclusive, new: Vec) { if range.start() == range.end() { self.replace_with_many(range.start(), new); return; } debug_assert!(is_ancestor_or_self_of_element(range.start(), &self.root)); - self.changes.push(Change::ReplaceAll(range, new)) + self.changes.borrow_mut().push(Change::ReplaceAll(range, new)) } pub fn finish(self) -> SyntaxEdit { edit_algo::apply_edits(self) } - - pub fn add_mappings(&mut self, other: SyntaxMapping) { - self.mappings.merge(other); - } } /// Represents a completed [`SyntaxEditor`] operation. 
@@ -538,7 +530,7 @@ mod tests { use crate::{ AstNode, - ast::{self, make, syntax_factory::SyntaxFactory}, + ast::{self, make}, }; use super::*; @@ -559,13 +551,12 @@ mod tests { .into(), ); - let (mut editor, root) = SyntaxEditor::with_ast_node(&root); + let (editor, root) = SyntaxEditor::with_ast_node(&root); + let make = editor.make(); let to_wrap = root.syntax().descendants().find_map(ast::TupleExpr::cast).unwrap(); let to_replace = root.syntax().descendants().find_map(ast::BinExpr::cast).unwrap(); - let make = SyntaxFactory::with_mappings(); - let name = make::name("var_name"); let name_ref = make::name_ref("var_name").clone_for_update(); @@ -574,7 +565,8 @@ mod tests { editor.add_annotation(name_ref.syntax(), placeholder_snippet); let new_block = make.block_expr( - [make + [editor + .make() .let_stmt( make.ident_pat(false, false, name.clone()).into(), None, @@ -586,7 +578,6 @@ mod tests { editor.replace(to_replace.syntax(), name_ref.syntax()); editor.replace(to_wrap.syntax(), new_block.syntax()); - editor.add_mappings(make.finish_with_mappings()); let edit = editor.finish(); @@ -600,8 +591,8 @@ mod tests { assert_eq!(edit.find_annotation(placeholder_snippet).len(), 2); assert!( edit.annotations - .iter() - .flat_map(|(_, elements)| elements) + .values() + .flatten() .all(|element| element.ancestors().any(|it| &it == edit.new_root())) ) } @@ -618,9 +609,9 @@ mod tests { None, ); - let (mut editor, root) = SyntaxEditor::with_ast_node(&root); + let (editor, root) = SyntaxEditor::with_ast_node(&root); + let make = editor.make(); let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap(); - let make = SyntaxFactory::without_mappings(); editor.insert( Position::first_child_of(root.stmt_list().unwrap().syntax()), @@ -669,14 +660,13 @@ mod tests { ), ); - let (mut editor, root) = SyntaxEditor::with_ast_node(&root); + let (editor, root) = SyntaxEditor::with_ast_node(&root); + let make = editor.make(); let inner_block = 
root.syntax().descendants().flat_map(ast::BlockExpr::cast).nth(1).unwrap(); let second_let = root.syntax().descendants().find_map(ast::LetStmt::cast).unwrap(); - let make = SyntaxFactory::with_mappings(); - let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone()))); let first_let = make.let_stmt( @@ -697,7 +687,6 @@ mod tests { ); editor.insert(Position::after(second_let.syntax()), third_let.syntax()); editor.replace(inner_block.syntax(), new_block_expr.syntax()); - editor.add_mappings(make.finish_with_mappings()); let edit = editor.finish(); @@ -724,10 +713,10 @@ mod tests { None, ); - let (mut editor, root) = SyntaxEditor::with_ast_node(&root); + let (editor, root) = SyntaxEditor::with_ast_node(&root); + let make = editor.make(); let inner_block = root; - let make = SyntaxFactory::with_mappings(); let new_block_expr = make.block_expr([], Some(ast::Expr::BlockExpr(inner_block.clone()))); @@ -742,7 +731,6 @@ mod tests { first_let.syntax(), ); editor.replace(inner_block.syntax(), new_block_expr.syntax()); - editor.add_mappings(make.finish_with_mappings()); let edit = editor.finish(); @@ -772,7 +760,7 @@ mod tests { false, ); - let (mut editor, parent_fn) = SyntaxEditor::with_ast_node(&parent_fn); + let (editor, parent_fn) = SyntaxEditor::with_ast_node(&parent_fn); if let Some(ret_ty) = parent_fn.ret_type() { editor.delete(ret_ty.syntax().clone()); @@ -799,7 +787,7 @@ mod tests { let arg_list = make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]); - let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); + let (editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); let target_expr = make::token(parser::SyntaxKind::UNDERSCORE); @@ -818,7 +806,7 @@ mod tests { let arg_list = make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]); - let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); + let (editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); 
let target_expr = make::expr_literal("3").clone_for_update(); @@ -837,7 +825,7 @@ mod tests { let arg_list = make::arg_list([make::expr_literal("1").into(), make::expr_literal("2").into()]); - let (mut editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); + let (editor, arg_list) = SyntaxEditor::with_ast_node(&arg_list); let target_expr = make::ext::expr_unit().clone_for_update(); diff --git a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edit_algo.rs b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edit_algo.rs index 78e7083f97e4c..27ea03ec09e7d 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edit_algo.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edit_algo.rs @@ -35,7 +35,10 @@ pub(super) fn apply_edits(editor: SyntaxEditor) -> SyntaxEdit { // - changed nodes become part of the changed node set (useful for the formatter to only change those parts) // - Propagate annotations - let SyntaxEditor { root, mut changes, mappings, annotations } = editor; + let SyntaxEditor { root, changes, annotations, make } = editor; + let mut changes = changes.into_inner(); + let annotations = annotations.into_inner(); + let mappings = make.take(); let mut node_depths = FxHashMap::::default(); let mut get_node_depth = |node: SyntaxNode| { diff --git a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edits.rs b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edits.rs index d741adb6e3449..28e8ceed708f3 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edits.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/syntax_editor/edits.rs @@ -3,10 +3,7 @@ use crate::{ AstToken, Direction, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, T, algo::neighbor, - ast::{ - self, AstNode, Fn, GenericParam, HasGenericParams, HasName, edit::IndentLevel, make, - syntax_factory::SyntaxFactory, - }, + ast::{self, AstNode, Fn, GenericParam, HasGenericParams, HasName, edit::IndentLevel, make}, 
syntax_editor::{Position, SyntaxEditor}, }; @@ -15,10 +12,10 @@ pub trait GetOrCreateWhereClause: ast::HasGenericParams { fn get_or_create_where_clause( &self, - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, new_preds: impl Iterator, ) { + let make = editor.make(); let existing = self.where_clause(); let all_preds: Vec<_> = existing.iter().flat_map(|wc| wc.predicates()).chain(new_preds).collect(); @@ -113,7 +110,7 @@ impl GetOrCreateWhereClause for ast::Enum { impl SyntaxEditor { /// Adds a new generic param to the function using `SyntaxEditor` - pub fn add_generic_param(&mut self, function: &Fn, new_param: GenericParam) { + pub fn add_generic_param(&self, function: &Fn, new_param: GenericParam) { match function.generic_param_list() { Some(generic_param_list) => match generic_param_list.generic_params().last() { Some(last_param) => { @@ -177,8 +174,8 @@ impl SyntaxEditor { } } -fn get_or_insert_comma_after(editor: &mut SyntaxEditor, syntax: &SyntaxNode) -> SyntaxToken { - let make = SyntaxFactory::without_mappings(); +fn get_or_insert_comma_after(editor: &SyntaxEditor, syntax: &SyntaxNode) -> SyntaxToken { + let make = editor.make(); match syntax .siblings_with_tokens(Direction::Next) .filter_map(|it| it.into_token()) @@ -198,7 +195,7 @@ impl ast::AssocItemList { /// /// Attention! This function does align the first line of `item` with respect to `self`, /// but it does _not_ change indentation of other lines (if any). 
- pub fn add_items(&self, editor: &mut SyntaxEditor, items: Vec) { + pub fn add_items(&self, editor: &SyntaxEditor, items: Vec) { let (indent, position, whitespace) = match self.assoc_items().last() { Some(last_item) => ( IndentLevel::from_node(last_item.syntax()), @@ -232,9 +229,9 @@ impl ast::AssocItemList { impl ast::Impl { pub fn get_or_create_assoc_item_list_with_editor( &self, - editor: &mut SyntaxEditor, - make: &SyntaxFactory, + editor: &SyntaxEditor, ) -> ast::AssocItemList { + let make = editor.make(); if let Some(list) = self.assoc_item_list() { list } else { @@ -249,8 +246,8 @@ impl ast::Impl { } impl ast::VariantList { - pub fn add_variant(&self, editor: &mut SyntaxEditor, variant: &ast::Variant) { - let make = SyntaxFactory::without_mappings(); + pub fn add_variant(&self, editor: &SyntaxEditor, variant: &ast::Variant) { + let make = editor.make(); let (indent, position) = match self.variants().last() { Some(last_item) => ( IndentLevel::from_node(last_item.syntax()), @@ -274,7 +271,7 @@ impl ast::VariantList { } impl ast::Fn { - pub fn replace_or_insert_body(&self, editor: &mut SyntaxEditor, body: ast::BlockExpr) { + pub fn replace_or_insert_body(&self, editor: &SyntaxEditor, body: ast::BlockExpr) { if let Some(old_body) = self.body() { editor.replace(old_body.syntax(), body.syntax()); } else { @@ -290,8 +287,8 @@ impl ast::Fn { } } -fn normalize_ws_between_braces(editor: &mut SyntaxEditor, node: &SyntaxNode) -> Option<()> { - let make = SyntaxFactory::without_mappings(); +fn normalize_ws_between_braces(editor: &SyntaxEditor, node: &SyntaxNode) -> Option<()> { + let make = editor.make(); let l = node .children_with_tokens() .filter_map(|it| it.into_token()) @@ -304,10 +301,11 @@ fn normalize_ws_between_braces(editor: &mut SyntaxEditor, node: &SyntaxNode) -> let indent = IndentLevel::from_node(node); match l.next_sibling_or_token() { - Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => { - if ws.next_sibling_or_token()?.into_token()? 
== r { - editor.replace(ws, make.whitespace(&format!("\n{indent}"))); - } + Some(ws) + if ws.kind() == SyntaxKind::WHITESPACE + && ws.next_sibling_or_token()?.into_token()? == r => + { + editor.replace(ws, make.whitespace(&format!("\n{indent}"))); } Some(ws) if ws.kind() == T!['}'] => { editor.insert(Position::after(l), make.whitespace(&format!("\n{indent}"))); @@ -318,11 +316,11 @@ fn normalize_ws_between_braces(editor: &mut SyntaxEditor, node: &SyntaxNode) -> } pub trait Removable: AstNode { - fn remove(&self, editor: &mut SyntaxEditor); + fn remove(&self, editor: &SyntaxEditor); } impl Removable for ast::TypeBoundList { - fn remove(&self, editor: &mut SyntaxEditor) { + fn remove(&self, editor: &SyntaxEditor) { match self.syntax().siblings_with_tokens(Direction::Prev).find(|it| it.kind() == T![:]) { Some(colon) => editor.delete_all(colon..=self.syntax().clone().into()), None => editor.delete(self.syntax()), @@ -331,9 +329,8 @@ impl Removable for ast::TypeBoundList { } impl Removable for ast::Use { - fn remove(&self, editor: &mut SyntaxEditor) { - let make = SyntaxFactory::without_mappings(); - + fn remove(&self, editor: &SyntaxEditor) { + let make = editor.make(); let next_ws = self .syntax() .next_sibling_or_token() @@ -355,7 +352,7 @@ impl Removable for ast::Use { } impl Removable for ast::UseTree { - fn remove(&self, editor: &mut SyntaxEditor) { + fn remove(&self, editor: &SyntaxEditor) { for dir in [Direction::Next, Direction::Prev] { if let Some(next_use_tree) = neighbor(self, dir) { let separators = self @@ -379,7 +376,7 @@ mod tests { use stdx::trim_indent; use test_utils::assert_eq_text; - use crate::SourceFile; + use crate::{SourceFile, ast::syntax_factory::SyntaxFactory}; use super::*; @@ -492,9 +489,9 @@ enum Foo { } fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) { - let (mut editor, enum_) = SyntaxEditor::with_ast_node(&ast_from_text::(before)); + let (editor, enum_) = SyntaxEditor::with_ast_node(&ast_from_text::(before)); 
if let Some(it) = enum_.variant_list() { - it.add_variant(&mut editor, &variant) + it.add_variant(&editor, &variant) } let edit = editor.finish(); let after = edit.new_root.to_string(); diff --git a/src/tools/rust-analyzer/crates/syntax/src/validation.rs b/src/tools/rust-analyzer/crates/syntax/src/validation.rs index 485140be8f69c..4622590656e4c 100644 --- a/src/tools/rust-analyzer/crates/syntax/src/validation.rs +++ b/src/tools/rust-analyzer/crates/syntax/src/validation.rs @@ -240,8 +240,16 @@ fn validate_numeric_name(name_ref: Option, errors: &mut Vec) { - let path_without_in_token = vis.in_token().is_none() - && vis.path().and_then(|p| p.as_single_name_ref()).and_then(|n| n.ident_token()).is_some(); + let path_without_in_token = if let Some(inner) = vis.visibility_inner() { + inner.in_token().is_none() + && inner + .path() + .and_then(|p| p.as_single_name_ref()) + .and_then(|n| n.ident_token()) + .is_some() + } else { + false + }; if path_without_in_token { errors.push(SyntaxError::new("incorrect visibility restriction", vis.syntax.text_range())); } diff --git a/src/tools/rust-analyzer/crates/syntax/test_data/parser/validation/0037_visibility_in_traits.rast b/src/tools/rust-analyzer/crates/syntax/test_data/parser/validation/0037_visibility_in_traits.rast index 90c258cd1a6c4..2d6d4b2681354 100644 --- a/src/tools/rust-analyzer/crates/syntax/test_data/parser/validation/0037_visibility_in_traits.rast +++ b/src/tools/rust-analyzer/crates/syntax/test_data/parser/validation/0037_visibility_in_traits.rast @@ -51,12 +51,13 @@ SOURCE_FILE@0..118 TYPE_ALIAS@56..81 VISIBILITY@56..66 PUB_KW@56..59 "pub" - L_PAREN@59..60 "(" - PATH@60..65 - PATH_SEGMENT@60..65 - NAME_REF@60..65 - CRATE_KW@60..65 "crate" - R_PAREN@65..66 ")" + VISIBILITY_INNER@59..66 + L_PAREN@59..60 "(" + PATH@60..65 + PATH_SEGMENT@60..65 + NAME_REF@60..65 + CRATE_KW@60..65 "crate" + R_PAREN@65..66 ")" WHITESPACE@66..67 " " TYPE_KW@67..71 "type" WHITESPACE@71..72 " " @@ -73,12 +74,13 @@ SOURCE_FILE@0..118 
CONST@86..115 VISIBILITY@86..96 PUB_KW@86..89 "pub" - L_PAREN@89..90 "(" - PATH@90..95 - PATH_SEGMENT@90..95 - NAME_REF@90..95 - CRATE_KW@90..95 "crate" - R_PAREN@95..96 ")" + VISIBILITY_INNER@89..96 + L_PAREN@89..90 "(" + PATH@90..95 + PATH_SEGMENT@90..95 + NAME_REF@90..95 + CRATE_KW@90..95 "crate" + R_PAREN@95..96 ")" WHITESPACE@96..97 " " CONST_KW@97..102 "const" WHITESPACE@102..103 " " diff --git a/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs b/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs index 86fb08073253c..a51698aca8074 100644 --- a/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs +++ b/src/tools/rust-analyzer/crates/test-utils/src/minicore.rs @@ -936,6 +936,14 @@ pub mod ops { } } } + + mod internal_implementation_detail { + #[lang = "async_fn_kind_helper"] + trait AsyncFnKindHelper { + #[lang = "async_fn_kind_upvars"] + type Upvars<'closure_env, Inputs, Upvars, BorrowedUpvarsAsFnPtr>; + } + } } pub use self::async_function::{AsyncFn, AsyncFnMut, AsyncFnOnce}; // endregion:async_fn diff --git a/src/tools/rust-analyzer/docs/book/src/contributing/lsp-extensions.md b/src/tools/rust-analyzer/docs/book/src/contributing/lsp-extensions.md index 22c1784ac293c..8ba6f6ab531e6 100644 --- a/src/tools/rust-analyzer/docs/book/src/contributing/lsp-extensions.md +++ b/src/tools/rust-analyzer/docs/book/src/contributing/lsp-extensions.md @@ -1,5 +1,5 @@ $DIR/unsized-return-suggest-ref-issue-152064.rs:7:16 + | +LL | for s in o.map(|s| s[3..8]) {} + | ^^^ doesn't have a size known at compile-time + | + = help: the trait `Sized` is not implemented for `str` +note: required by an implicit `Sized` bound in `Option::::map` + --> $SRC_DIR/core/src/option.rs:LL:COL + +error[E0277]: the size for values of type `str` cannot be known at compilation time + --> $DIR/unsized-return-suggest-ref-issue-152064.rs:7:24 + | +LL | for s in o.map(|s| s[3..8]) {} + | ^^^^^^^ doesn't have a size known at compile-time + | + = help: the trait `Sized` is not 
implemented for `str` + = note: the return type of a function must have a statically known size + +error[E0277]: the size for values of type `str` cannot be known at compilation time + --> $DIR/unsized-return-suggest-ref-issue-152064.rs:7:14 + | +LL | for s in o.map(|s| s[3..8]) {} + | ^^^^^^^^^^^^^^^^^^ doesn't have a size known at compile-time + | + = help: the trait `Sized` is not implemented for `str` +note: required by an implicit `Sized` bound in `Option` + --> $SRC_DIR/core/src/option.rs:LL:COL +help: consider borrowing the value + | +LL | for s in o.map(|s| &s[3..8]) {} + | + + +error[E0277]: `Option` is not an iterator + --> $DIR/unsized-return-suggest-ref-issue-152064.rs:7:14 + | +LL | for s in o.map(|s| s[3..8]) {} + | ^^^^^^^^^^^^^^^^^^ `Option` is not an iterator + | + = help: the trait `IntoIterator` is not implemented for `Option` +help: the following other types implement trait `IntoIterator` + --> $SRC_DIR/core/src/option.rs:LL:COL + | + = note: `Option` + ::: $SRC_DIR/core/src/option.rs:LL:COL + | + = note: `&Option` + ::: $SRC_DIR/core/src/option.rs:LL:COL + | + = note: `&mut Option` + +error[E0277]: the size for values of type `[u8]` cannot be known at compilation time + --> $DIR/unsized-return-suggest-ref-issue-152064.rs:15:18 + | +LL | for s in arr.map(|s| s[3..8]) {} + | ^^^ doesn't have a size known at compile-time + | + = help: the trait `Sized` is not implemented for `[u8]` +note: required by an implicit `Sized` bound in `Option::::map` + --> $SRC_DIR/core/src/option.rs:LL:COL + +error[E0277]: the size for values of type `[u8]` cannot be known at compilation time + --> $DIR/unsized-return-suggest-ref-issue-152064.rs:15:26 + | +LL | for s in arr.map(|s| s[3..8]) {} + | ^^^^^^^ doesn't have a size known at compile-time + | + = help: the trait `Sized` is not implemented for `[u8]` + = note: the return type of a function must have a statically known size + +error[E0277]: the size for values of type `[u8]` cannot be known at compilation time + 
--> $DIR/unsized-return-suggest-ref-issue-152064.rs:15:14 + | +LL | for s in arr.map(|s| s[3..8]) {} + | ^^^^^^^^^^^^^^^^^^^^ doesn't have a size known at compile-time + | + = help: the trait `Sized` is not implemented for `[u8]` +note: required by an implicit `Sized` bound in `Option` + --> $SRC_DIR/core/src/option.rs:LL:COL +help: consider borrowing the value + | +LL | for s in arr.map(|s| &s[3..8]) {} + | + + +error[E0277]: `Option<[u8]>` is not an iterator + --> $DIR/unsized-return-suggest-ref-issue-152064.rs:15:14 + | +LL | for s in arr.map(|s| s[3..8]) {} + | ^^^^^^^^^^^^^^^^^^^^ `Option<[u8]>` is not an iterator + | + = help: the trait `IntoIterator` is not implemented for `Option<[u8]>` +help: the following other types implement trait `IntoIterator` + --> $SRC_DIR/core/src/option.rs:LL:COL + | + = note: `Option` + ::: $SRC_DIR/core/src/option.rs:LL:COL + | + = note: `&Option` + ::: $SRC_DIR/core/src/option.rs:LL:COL + | + = note: `&mut Option` + +error: aborting due to 8 previous errors + +For more information about this error, try `rustc --explain E0277`. 
diff --git a/tests/ui/extern/extern-crate-rename.rs b/tests/ui/extern/extern-crate-rename.rs index 9eeea6dc57115..f7642b3b8dcc3 100644 --- a/tests/ui/extern/extern-crate-rename.rs +++ b/tests/ui/extern/extern-crate-rename.rs @@ -1,5 +1,6 @@ //@ aux-build:m1.rs //@ aux-build:m2.rs +//@ reference: items.extern-crate.as extern crate m1; diff --git a/tests/ui/extern/extern-crate-rename.stderr b/tests/ui/extern/extern-crate-rename.stderr index 88b78a07485d4..d8236ad7f7030 100644 --- a/tests/ui/extern/extern-crate-rename.stderr +++ b/tests/ui/extern/extern-crate-rename.stderr @@ -1,5 +1,5 @@ error[E0259]: the name `m1` is defined multiple times - --> $DIR/extern-crate-rename.rs:6:1 + --> $DIR/extern-crate-rename.rs:7:1 | LL | extern crate m1; | ---------------- previous import of the extern crate `m1` here diff --git a/tests/ui/extern/extern-ffi-fn-with-body.rs b/tests/ui/extern/extern-ffi-fn-with-body.rs index ef234e8afd8ca..da822f4abb0a4 100644 --- a/tests/ui/extern/extern-ffi-fn-with-body.rs +++ b/tests/ui/extern/extern-ffi-fn-with-body.rs @@ -1,3 +1,6 @@ +//@ reference: items.extern.fn.body +//@ reference: items.fn.extern.intro + extern "C" { fn foo() -> i32 { //~ ERROR incorrect function inside `extern` block return 0; diff --git a/tests/ui/extern/extern-ffi-fn-with-body.stderr b/tests/ui/extern/extern-ffi-fn-with-body.stderr index dc34490b39a00..38d381a93835f 100644 --- a/tests/ui/extern/extern-ffi-fn-with-body.stderr +++ b/tests/ui/extern/extern-ffi-fn-with-body.stderr @@ -1,5 +1,5 @@ error: incorrect function inside `extern` block - --> $DIR/extern-ffi-fn-with-body.rs:2:8 + --> $DIR/extern-ffi-fn-with-body.rs:5:8 | LL | extern "C" { | ---------- `extern` blocks define existing foreign functions and functions inside of them cannot have a body diff --git a/tests/ui/extern/extern-thiscall.rs b/tests/ui/extern/extern-thiscall.rs index 3fa796bdbe858..e8d982eb978bd 100644 --- a/tests/ui/extern/extern-thiscall.rs +++ b/tests/ui/extern/extern-thiscall.rs @@ -1,5 +1,6 @@ 
//@ run-pass //@ only-x86 +//@ reference: items.extern.abi.thiscall trait A { extern "thiscall" fn test1(i: i32); diff --git a/tests/ui/extern/function-definition-in-extern-block-75283.rs b/tests/ui/extern/function-definition-in-extern-block-75283.rs index e2b7358743ba1..64c40d2ff7b5c 100644 --- a/tests/ui/extern/function-definition-in-extern-block-75283.rs +++ b/tests/ui/extern/function-definition-in-extern-block-75283.rs @@ -1,3 +1,4 @@ +//@ reference: items.extern.fn.body // https://github.com/rust-lang/rust/issues/75283 extern "C" { fn lol() { //~ ERROR incorrect function inside `extern` block diff --git a/tests/ui/extern/function-definition-in-extern-block-75283.stderr b/tests/ui/extern/function-definition-in-extern-block-75283.stderr index 67be1c2959922..19b6a8548c02a 100644 --- a/tests/ui/extern/function-definition-in-extern-block-75283.stderr +++ b/tests/ui/extern/function-definition-in-extern-block-75283.stderr @@ -1,5 +1,5 @@ error: incorrect function inside `extern` block - --> $DIR/function-definition-in-extern-block-75283.rs:3:8 + --> $DIR/function-definition-in-extern-block-75283.rs:4:8 | LL | extern "C" { | ---------- `extern` blocks define existing foreign functions and functions inside of them cannot have a body diff --git a/tests/ui/extern/issue-28324.rs b/tests/ui/extern/issue-28324.rs index 4af400d823bb1..0dec04cb5fa69 100644 --- a/tests/ui/extern/issue-28324.rs +++ b/tests/ui/extern/issue-28324.rs @@ -1,3 +1,5 @@ +//@ reference: const-eval.const-expr.path-static +//@ reference: items.extern.static.safety extern "C" { static error_message_count: u32; } diff --git a/tests/ui/extern/issue-28324.stderr b/tests/ui/extern/issue-28324.stderr index 4637163bc5c34..689eda0f338c9 100644 --- a/tests/ui/extern/issue-28324.stderr +++ b/tests/ui/extern/issue-28324.stderr @@ -1,11 +1,11 @@ error[E0080]: cannot access extern static `error_message_count` - --> $DIR/issue-28324.rs:5:23 + --> $DIR/issue-28324.rs:7:23 | LL | pub static BAZ: u32 = 
*&error_message_count; | ^^^^^^^^^^^^^^^^^^^^^ evaluation of `BAZ` failed here error[E0133]: use of extern static is unsafe and requires unsafe function or block - --> $DIR/issue-28324.rs:5:25 + --> $DIR/issue-28324.rs:7:25 | LL | pub static BAZ: u32 = *&error_message_count; | ^^^^^^^^^^^^^^^^^^^ use of extern static diff --git a/tests/ui/extern/issue-47725.rs b/tests/ui/extern/issue-47725.rs index 6b4d0dd30e024..6941508e6e5e4 100644 --- a/tests/ui/extern/issue-47725.rs +++ b/tests/ui/extern/issue-47725.rs @@ -1,3 +1,4 @@ +//@ reference: items.extern.attributes.link_name.allowed-positions #![warn(unused_attributes)] //~ NOTE lint level is defined here #[link_name = "foo"] diff --git a/tests/ui/extern/issue-47725.stderr b/tests/ui/extern/issue-47725.stderr index 023f4265c80fc..b018554d12d21 100644 --- a/tests/ui/extern/issue-47725.stderr +++ b/tests/ui/extern/issue-47725.stderr @@ -1,5 +1,5 @@ error[E0539]: malformed `link_name` attribute input - --> $DIR/issue-47725.rs:19:1 + --> $DIR/issue-47725.rs:20:1 | LL | #[link_name] | ^^^^^^^^^^^^ @@ -11,7 +11,7 @@ LL | #[link_name = "name"] | ++++++++ warning: `#[link_name]` attribute cannot be used on structs - --> $DIR/issue-47725.rs:3:1 + --> $DIR/issue-47725.rs:4:1 | LL | #[link_name = "foo"] | ^^^^^^^^^^^^^^^^^^^^ @@ -19,13 +19,13 @@ LL | #[link_name = "foo"] = warning: this was previously accepted by the compiler but is being phased out; it will become a hard error in a future release! 
= help: `#[link_name]` can be applied to foreign functions and foreign statics note: the lint level is defined here - --> $DIR/issue-47725.rs:1:9 + --> $DIR/issue-47725.rs:2:9 | LL | #![warn(unused_attributes)] | ^^^^^^^^^^^^^^^^^ warning: `#[link_name]` attribute cannot be used on foreign modules - --> $DIR/issue-47725.rs:10:1 + --> $DIR/issue-47725.rs:11:1 | LL | #[link_name = "foobar"] | ^^^^^^^^^^^^^^^^^^^^^^^ @@ -34,7 +34,7 @@ LL | #[link_name = "foobar"] = help: `#[link_name]` can be applied to foreign functions and foreign statics warning: `#[link_name]` attribute cannot be used on foreign modules - --> $DIR/issue-47725.rs:19:1 + --> $DIR/issue-47725.rs:20:1 | LL | #[link_name] | ^^^^^^^^^^^^ diff --git a/tests/ui/extern/issue-95829.rs b/tests/ui/extern/issue-95829.rs index 493d53d2532f2..3103274879779 100644 --- a/tests/ui/extern/issue-95829.rs +++ b/tests/ui/extern/issue-95829.rs @@ -1,4 +1,5 @@ //@ edition:2018 +//@ reference: items.extern.fn.qualifiers extern "C" { async fn L() { //~ ERROR: incorrect function inside `extern` block diff --git a/tests/ui/extern/issue-95829.stderr b/tests/ui/extern/issue-95829.stderr index 2acd0fa3a2650..84cb3ca44ba88 100644 --- a/tests/ui/extern/issue-95829.stderr +++ b/tests/ui/extern/issue-95829.stderr @@ -1,5 +1,5 @@ error: incorrect function inside `extern` block - --> $DIR/issue-95829.rs:4:14 + --> $DIR/issue-95829.rs:5:14 | LL | extern "C" { | ---------- `extern` blocks define existing foreign functions and functions inside of them cannot have a body @@ -16,7 +16,7 @@ LL | | } = note: for more information, visit https://doc.rust-lang.org/std/keyword.extern.html error: functions in `extern` blocks cannot have `async` qualifier - --> $DIR/issue-95829.rs:4:5 + --> $DIR/issue-95829.rs:5:5 | LL | extern "C" { | ---------- in this `extern` block diff --git a/tests/ui/macros/macro-guard-matcher-recursion.rs b/tests/ui/macros/macro-guard-matcher-recursion.rs new file mode 100644 index 0000000000000..c916d95003770 --- 
/dev/null +++ b/tests/ui/macros/macro-guard-matcher-recursion.rs @@ -0,0 +1,10 @@ +//! Regression test for +#![feature(macro_guard_matcher)] +fn main() { + macro_rules! m { + ($g : guard) => { + m!($g) //~ ERROR recursion limit reached while expanding `m!` + }; + } + m!(if x) +} diff --git a/tests/ui/macros/macro-guard-matcher-recursion.stderr b/tests/ui/macros/macro-guard-matcher-recursion.stderr new file mode 100644 index 0000000000000..e12b2de765100 --- /dev/null +++ b/tests/ui/macros/macro-guard-matcher-recursion.stderr @@ -0,0 +1,14 @@ +error: recursion limit reached while expanding `m!` + --> $DIR/macro-guard-matcher-recursion.rs:6:13 + | +LL | m!($g) + | ^^^^^^ +... +LL | m!(if x) + | -------- in this macro invocation + | + = help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`macro_guard_matcher_recursion`) + = note: this error originates in the macro `m` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: aborting due to 1 previous error +