Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 39 additions & 54 deletions library/core/src/str/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,35 @@ unsafe impl const SliceIndex<str> for ops::RangeFull {
}
}

/// Check that a range is in bounds for slicing a string.
/// If this returns true, it is safe to call `slice.get_unchecked(range)` or
/// `slice.get_unchecked_mut(range)`.
#[inline(always)]
const fn check_range(slice: &str, range: crate::range::Range<usize>) -> bool {
let crate::range::Range { start, end } = range;
let bytes = slice.as_bytes();

if start > end || end > slice.len() {
return false;
}

if start == slice.len() {
// If `start == slice.len()`, then `end == slice.len()` must also be true.
return true;
}

// SAFETY:
// `start > end || end > slice.len()` is false, so `start <= end <= slice.len()` is true.
// `start == slice.len()` is false, so `start < slice.len()` is also true.
//
// No need to check for `end == 0`, because if `end == 0` is true then `start == slice.len()`
// would also be true, which is already handled above.
unsafe {
(start == 0 || bytes.as_ptr().add(start).read().is_utf8_char_boundary())
&& (end == slice.len() || bytes.as_ptr().add(end).read().is_utf8_char_boundary())
}
}

/// Implements substring slicing with syntax `&self[begin .. end]` or `&mut
/// self[begin .. end]`.
///
Expand Down Expand Up @@ -159,30 +188,11 @@ unsafe impl const SliceIndex<str> for ops::Range<usize> {
type Output = str;
#[inline]
fn get(self, slice: &str) -> Option<&Self::Output> {
if self.start <= self.end
&& slice.is_char_boundary(self.start)
&& slice.is_char_boundary(self.end)
{
// SAFETY: just checked that `start` and `end` are on a char boundary,
// and we are passing in a safe reference, so the return value will also be one.
// We also checked char boundaries, so this is valid UTF-8.
Some(unsafe { &*self.get_unchecked(slice) })
} else {
None
}
range::Range::from(self).get(slice)
}
#[inline]
fn get_mut(self, slice: &mut str) -> Option<&mut Self::Output> {
if self.start <= self.end
&& slice.is_char_boundary(self.start)
&& slice.is_char_boundary(self.end)
{
// SAFETY: just checked that `start` and `end` are on a char boundary.
// We know the pointer is unique because we got it from `slice`.
Some(unsafe { &mut *self.get_unchecked_mut(slice) })
} else {
None
}
range::Range::from(self).get_mut(slice)
}
#[inline]
#[track_caller]
Expand Down Expand Up @@ -235,26 +245,11 @@ unsafe impl const SliceIndex<str> for ops::Range<usize> {
}
#[inline]
fn index(self, slice: &str) -> &Self::Output {
let (start, end) = (self.start, self.end);
match self.get(slice) {
Some(s) => s,
None => super::slice_error_fail(slice, start, end),
}
range::Range::from(self).index(slice)
}
#[inline]
fn index_mut(self, slice: &mut str) -> &mut Self::Output {
// is_char_boundary checks that the index is in [0, .len()]
// cannot reuse `get` as above, because of NLL trouble
if self.start <= self.end
&& slice.is_char_boundary(self.start)
&& slice.is_char_boundary(self.end)
{
// SAFETY: just checked that `start` and `end` are on a char boundary,
// and we are passing in a safe reference, so the return value will also be one.
unsafe { &mut *self.get_unchecked_mut(slice) }
} else {
super::slice_error_fail(slice, self.start, self.end)
}
range::Range::from(self).index_mut(slice)
}
}

Expand All @@ -264,11 +259,8 @@ unsafe impl const SliceIndex<str> for range::Range<usize> {
type Output = str;
#[inline]
fn get(self, slice: &str) -> Option<&Self::Output> {
if self.start <= self.end
&& slice.is_char_boundary(self.start)
&& slice.is_char_boundary(self.end)
{
// SAFETY: just checked that `start` and `end` are on a char boundary,
if check_range(slice, self) {
// SAFETY: just checked that `self` is in bounds,
// and we are passing in a safe reference, so the return value will also be one.
// We also checked char boundaries, so this is valid UTF-8.
Some(unsafe { &*self.get_unchecked(slice) })
Expand All @@ -278,11 +270,8 @@ unsafe impl const SliceIndex<str> for range::Range<usize> {
}
#[inline]
fn get_mut(self, slice: &mut str) -> Option<&mut Self::Output> {
if self.start <= self.end
&& slice.is_char_boundary(self.start)
&& slice.is_char_boundary(self.end)
{
// SAFETY: just checked that `start` and `end` are on a char boundary.
if check_range(slice, self) {
// SAFETY: just checked that `self` is in bounds.
// We know the pointer is unique because we got it from `slice`.
Some(unsafe { &mut *self.get_unchecked_mut(slice) })
} else {
Expand Down Expand Up @@ -348,13 +337,9 @@ unsafe impl const SliceIndex<str> for range::Range<usize> {
}
#[inline]
fn index_mut(self, slice: &mut str) -> &mut Self::Output {
// is_char_boundary checks that the index is in [0, .len()]
// cannot reuse `get` as above, because of NLL trouble
if self.start <= self.end
&& slice.is_char_boundary(self.start)
&& slice.is_char_boundary(self.end)
{
// SAFETY: just checked that `start` and `end` are on a char boundary,
if check_range(slice, self) {
// SAFETY: just checked that `self` is in bounds,
// and we are passing in a safe reference, so the return value will also be one.
unsafe { &mut *self.get_unchecked_mut(slice) }
} else {
Expand Down
13 changes: 7 additions & 6 deletions tests/codegen-llvm/str-range-indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@ macro_rules! tests {
};
}

// 9 comparisons required:
// start <= end
// && (start == 0 || (start >= len && start == len) || bytes[start] >= -0x40)
// && (end == 0 || (end >= len && end == len) || bytes[end] >= -0x40)
// 7 comparisons required:
// start <= end && end <= len
// && (start == len ||
// ( (start == 0 || bytes[start] >= -0x40)
// && (end == len || bytes[end] >= -0x40)))

// CHECK-LABEL: @get_range
// CHECK-COUNT-9: %{{.+}} = icmp
// CHECK-COUNT-7: %{{.+}} = icmp
// CHECK-NOT: %{{.+}} = icmp
// CHECK: ret

// CHECK-LABEL: @index_range
// CHECK-COUNT-9: %{{.+}} = icmp
// CHECK-COUNT-7: %{{.+}} = icmp
// CHECK-NOT: %{{.+}} = icmp
// CHECK: ret
tests!(Range<usize>, get_range, index_range);
Expand Down
Loading