use std::{
arch::x86_64::{
__m128i, _mm_cmpestri, _mm_cmpestrm, _mm_extract_epi16, _mm_loadu_si128, _SIDD_CMP_EQUAL_ORDERED,
},
cmp::min,
slice,
};
// Textually includes `src/simd_macros.rs` from the build output directory (`OUT_DIR`).
// NOTE(review): presumably produced by a build script and the source of the `simd_bytes!`
// macro used by the tests in this file — confirm against the build setup.
include!(concat!(env!("OUT_DIR"), "/src/simd_macros.rs"));
/// Number of haystack bytes handled by one packed-compare operation; every SIMD load in this
/// file reads one `__m128i`, i.e. this many bytes.
const BYTES_PER_OPERATION: usize = 16;
/// Reinterprets 16 raw bytes as a `__m128i` register value.
///
/// Callers write the `bytes` field and read the `simd` field. That kind of type punning is
/// only well-defined when both fields are guaranteed to start at offset 0, which the default
/// (Rust) representation does not promise for unions — hence the explicit `repr(C)`.
#[repr(C)]
union TransmuteToSimd {
    /// The value viewed as a SIMD register.
    simd: __m128i,
    /// The same 16 bytes viewed as a plain array.
    bytes: [u8; 16],
}
/// Supplies the operands that `PackedCompare` feeds to the packed-compare intrinsics.
trait PackedCompareControl {
/// Control value passed as the last argument of `_mm_cmpestri` / `_mm_cmpestrm`.
const CONTROL_BYTE: i32;
/// The needle, already packed into a SIMD register; passed as the first intrinsic operand.
fn needle(&self) -> __m128i;
/// Length of the needle, passed to the intrinsics alongside `needle()`.
fn needle_len(&self) -> i32;
}
/// Returns the index of the first position in `haystack` reported by the packed compare
/// described by `packed`, or `None` when there is none.
///
/// The search runs in three phases:
///
/// 1. If `haystack` does not start on a 16-byte boundary, the enclosing aligned 16-byte block
///    is inspected with a mask compare and matches located before the haystack are discarded.
/// 2. Every complete 16-byte chunk is inspected in turn.
/// 3. The remaining tail (fewer than 16 bytes) is inspected with an explicit length.
///
/// Each compare loads a full 16 bytes, so phases 1 and 3 may read bytes that are outside
/// `haystack`; those loads start on a 16-byte boundary and therefore stay inside the aligned
/// block that also contains haystack bytes.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE4.2.
#[inline]
#[target_feature(enable = "sse4.2")]
unsafe fn find<C>(packed: PackedCompare<C>, mut haystack: &[u8]) -> Option<usize>
where
    C: PackedCompareControl,
{
    if haystack.is_empty() {
        return None;
    }

    // Number of bytes of the original haystack that have already been ruled out.
    let mut offset = 0;

    if let Some(misaligned) = Misalignment::new(haystack) {
        if let Some(location) = packed.cmpestrm(misaligned.leading, misaligned.leading_junk) {
            // The mask covers the whole 16-byte block, so a hit may lie *after* the end of a
            // short haystack; only accept hits that are really inside it.
            if location < haystack.len() {
                return Some(location);
            }
        }

        haystack = &haystack[misaligned.bytes_until_alignment..];
        offset += misaligned.bytes_until_alignment;
    }

    let n_complete_chunks = haystack.len() / BYTES_PER_OPERATION;

    // Walk a raw pointer instead of re-slicing on every iteration.
    let mut haystack_ptr = haystack.as_ptr();
    let mut chunk_offset = 0;
    for _ in 0..n_complete_chunks {
        if let Some(location) = packed.cmpestri(haystack_ptr, BYTES_PER_OPERATION as i32) {
            return Some(offset + chunk_offset + location);
        }

        haystack_ptr = haystack_ptr.add(BYTES_PER_OPERATION);
        chunk_offset += BYTES_PER_OPERATION;
    }
    haystack = &haystack[chunk_offset..];
    offset += chunk_offset;

    // Nothing left to search. This must be checked *before* touching memory: the tail pointer
    // is now one past the end of the data, and a 16-byte load from there can reach into a
    // following page that is not mapped.
    if haystack.is_empty() {
        return None;
    }

    // Fewer than 16 bytes remain, so the length trivially fits in an `i32`.
    debug_assert!(haystack.len() < BYTES_PER_OPERATION);
    packed
        .cmpestri(haystack.as_ptr(), haystack.len() as i32)
        .map(|loc| offset + loc)
}
/// Wrapper that turns a `PackedCompareControl` implementor into something that can run the
/// `cmpestrm` / `cmpestri` packed compares.
struct PackedCompare<T>(T);
impl<T> PackedCompare<T>
where
    T: PackedCompareControl,
{
    /// Mask-based compare of the 16 bytes starting at `haystack.as_ptr()`.
    ///
    /// The low 16 bits of the compare result are treated as one bit per loaded byte. The
    /// lowest `leading_junk` bits are shifted away, so the returned index is relative to the
    /// byte that follows the junk. Returns `None` when no bit survives.
    ///
    /// Always loads 16 bytes, regardless of `haystack.len()`.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn cmpestrm(&self, haystack: &[u8], leading_junk: usize) -> Option<usize> {
        let chunk = _mm_loadu_si128(haystack.as_ptr() as *const __m128i);
        let result = _mm_cmpestrm(
            self.0.needle(),
            self.0.needle_len(),
            chunk,
            BYTES_PER_OPERATION as i32,
            T::CONTROL_BYTE,
        );

        let all_matches = _mm_extract_epi16(result, 0) as u16;
        if all_matches == 0 {
            return None;
        }

        // Bit 0 corresponds to the first loaded byte; dropping the low bits removes the
        // positions that precede the caller's data.
        let in_haystack = all_matches >> leading_junk;
        if in_haystack == 0 {
            // Every hit was located in the junk before the haystack.
            return None;
        }

        let first_match = in_haystack.trailing_zeros() as usize;
        debug_assert!(first_match < 16);
        Some(first_match)
    }

    /// Index-based compare of the 16 bytes starting at `haystack`, of which only the first
    /// `haystack_len` are declared valid to the instruction.
    ///
    /// Returns the reported index when it is below 16, otherwise `None`.
    ///
    /// Always loads 16 bytes, regardless of `haystack_len`.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    unsafe fn cmpestri(&self, haystack: *const u8, haystack_len: i32) -> Option<usize> {
        let chunk = _mm_loadu_si128(haystack as *const __m128i);
        let index = _mm_cmpestri(
            self.0.needle(),
            self.0.needle_len(),
            chunk,
            haystack_len,
            T::CONTROL_BYTE,
        );

        match index {
            hit if hit < 16 => Some(hit as usize),
            _ => None,
        }
    }
}
/// Describes how a haystack that does not start on a 16-byte boundary relates to the aligned
/// 16-byte block containing its first byte.
#[derive(Debug)]
struct Misalignment<'a> {
/// Bytes starting at the aligned block start; includes the junk bytes that precede the
/// haystack and is at most 16 bytes long.
leading: &'a [u8],
/// Number of bytes between the aligned block start and the haystack's first byte.
leading_junk: usize,
/// Number of haystack bytes to skip to reach the next aligned boundary, or the whole haystack
/// length when the haystack ends before that boundary.
bytes_until_alignment: usize,
}
impl<'a> Misalignment<'a> {
    /// Analyses the alignment of `haystack` relative to 16-byte boundaries.
    ///
    /// Returns `None` when the haystack already starts on a 16-byte boundary. Otherwise the
    /// result describes the aligned block holding the first haystack byte:
    ///
    /// * `leading` starts at the aligned address and covers the junk plus as many haystack
    ///   bytes as fit into the block (capped at 16 bytes in total);
    /// * `leading_junk` is the distance from the aligned address to the haystack start;
    /// * `bytes_until_alignment` is the number of haystack bytes inside the block when the
    ///   haystack reaches the block end, or the full haystack length when it ends earlier.
    #[inline]
    fn new(haystack: &[u8]) -> Option<Self> {
        let start_addr = haystack.as_ptr() as usize;
        let block_addr = start_addr & !0xF;

        // Already aligned: nothing special to do.
        if block_addr == start_addr {
            return None;
        }

        let leading_junk = start_addr - block_addr;
        let leading_len = min(haystack.len() + leading_junk, BYTES_PER_OPERATION);
        let reaches_block_end = leading_len == BYTES_PER_OPERATION;

        let leading = unsafe { slice::from_raw_parts(block_addr as *const u8, leading_len) };

        let bytes_until_alignment = if reaches_block_end {
            // Distance from the haystack start to the end of the aligned block.
            BYTES_PER_OPERATION - leading_junk
        } else {
            haystack.len()
        };

        Some(Misalignment {
            leading,
            leading_junk,
            bytes_until_alignment,
        })
    }
}
/// A set of up to 16 bytes to search for; `find` reports the first haystack position found by
/// a packed compare against this set.
pub struct Bytes {
/// The needle bytes packed into a SIMD register.
needle: __m128i,
/// Needle length handed to the compare instruction as given to `new`.
needle_len: i32,
}
impl Bytes {
    /// Builds a searcher from a 16-byte buffer of which the first `needle_len` bytes are the
    /// bytes of interest. `needle_len` is stored as given, without validation.
    pub fn new(bytes: [u8; 16], needle_len: i32) -> Self {
        let needle = unsafe { TransmuteToSimd { bytes }.simd };
        Bytes { needle, needle_len }
    }

    /// Returns the index of the first position in `haystack` reported by the packed compare,
    /// or `None` if there is none.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU supports SSE4.2.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    pub unsafe fn find(&self, haystack: &[u8]) -> Option<usize> {
        let packed = PackedCompare(self);
        find(packed, haystack)
    }
}
/// Lets a borrowed `Bytes` drive the packed compares.
impl<'b> PackedCompareControl for &'b Bytes {
// `0` leaves every control flag at its zero value.
// NOTE(review): per Intel's PCMPESTRI documentation this should select unsigned bytes,
// "equal any" comparison and the least-significant index — confirm. The tests in this file
// compare the result against a "first byte contained in the needle" oracle.
const CONTROL_BYTE: i32 = 0;
fn needle(&self) -> __m128i {
self.needle
}
fn needle_len(&self) -> i32 {
self.needle_len
}
}
/// Searches a haystack for a byte sequence of arbitrary length.
pub struct ByteSubstring<'a> {
/// The full needle; used to confirm every candidate position.
complete_needle: &'a [u8],
/// The first (at most 16) needle bytes packed into a SIMD register, zero padded.
needle: __m128i,
/// Number of meaningful bytes in `needle` (the needle length capped at 16).
needle_len: i32,
}
impl<'a> ByteSubstring<'a> {
    /// Creates a searcher for `needle`.
    ///
    /// Only the first 16 bytes of the needle are packed into the SIMD register; the complete
    /// needle is kept to verify candidates.
    pub fn new(needle: &'a [u8]) -> Self {
        let mut simd_needle = [0; 16];
        let len = needle.len().min(simd_needle.len());
        simd_needle[..len].copy_from_slice(&needle[..len]);

        let packed_needle = unsafe { TransmuteToSimd { bytes: simd_needle }.simd };
        ByteSubstring {
            complete_needle: needle,
            needle: packed_needle,
            needle_len: len as i32,
        }
    }

    /// Length of the complete needle (not the 16-byte-capped SIMD prefix).
    #[cfg(feature = "pattern")]
    pub fn needle_len(&self) -> usize {
        self.complete_needle.len()
    }

    /// Returns the index of the first occurrence of the complete needle in `haystack`.
    ///
    /// The packed search only yields *candidates* (see the
    /// `byte_substring_has_false_positive` test), so each one is confirmed against the
    /// complete needle; on a mismatch the search resumes one byte past the candidate.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU supports SSE4.2.
    #[inline]
    #[target_feature(enable = "sse4.2")]
    pub unsafe fn find(&self, haystack: &[u8]) -> Option<usize> {
        let mut search_from = 0;
        loop {
            let relative = match find(PackedCompare(self), &haystack[search_from..]) {
                Some(idx) => idx,
                None => return None,
            };

            let candidate = search_from + relative;
            if haystack[candidate..].starts_with(self.complete_needle) {
                return Some(candidate);
            }
            search_from = candidate + 1;
        }
    }
}
/// Lets a borrowed `ByteSubstring` drive the packed compares, using the
/// `_SIDD_CMP_EQUAL_ORDERED` comparison mode.
impl<'a, 'b> PackedCompareControl for &'b ByteSubstring<'a> {
const CONTROL_BYTE: i32 = _SIDD_CMP_EQUAL_ORDERED;
// The SIMD needle holds at most the first 16 bytes of the complete needle.
fn needle(&self) -> __m128i {
self.needle
}
// Length of the SIMD needle, i.e. the complete needle's length capped at 16.
fn needle_len(&self) -> i32 {
self.needle_len
}
}
#[cfg(test)]
mod test {
use proptest::prelude::*;
use std::{fmt, str};
use memmap::MmapMut;
use region::Protection;
use super::*;
// Shared needles for the tests below, built once on first use.
lazy_static! {
// A single space.
static ref SPACE: Bytes = simd_bytes!(b' ');
// The three XML delimiters `<`, `>` and `&`.
static ref XML_DELIM_3: Bytes = simd_bytes!(b'<', b'>', b'&');
// The three XML delimiters plus both quote characters.
static ref XML_DELIM_5: Bytes = simd_bytes!(b'<', b'>', b'&', b'\'', b'"');
}
/// Straightforward reference implementations the SIMD searches are compared against.
trait SliceFindPolyfill<T> {
    /// Index of the first element that equals any element of `needles`.
    fn find_any(&self, needles: &[T]) -> Option<usize>;
    /// Index of the first position at which `needle` occurs as a contiguous run.
    fn find_seq(&self, needle: &[T]) -> Option<usize>;
}

impl<T> SliceFindPolyfill<T> for [T]
where
    T: PartialEq,
{
    fn find_any(&self, needles: &[T]) -> Option<usize> {
        self.iter()
            .position(|item| needles.iter().any(|candidate| *candidate == *item))
    }

    fn find_seq(&self, needle: &[T]) -> Option<usize> {
        // Only start positions inside the slice are tried, so an empty slice never matches,
        // not even an empty needle.
        for start in 0..self.len() {
            if self[start..].starts_with(needle) {
                return Some(start);
            }
        }
        None
    }
}
/// Test haystack: a byte buffer plus an index at which a search may begin, so the same data
/// can be searched from different memory offsets.
struct Haystack {
data: Vec<u8>,
start: usize,
}
impl Haystack {
/// The whole buffer, ignoring `start`.
fn without_start(&self) -> &[u8] {
&self.data
}
/// The buffer from index `start` onwards.
fn with_start(&self) -> &[u8] {
&self.data[self.start..]
}
}
/// Debug output that also shows the buffer's address, in addition to the data and `start`.
impl fmt::Debug for Haystack {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Haystack")
.field("data", &self.data)
.field("(addr)", &self.data.as_ptr())
.field("start", &self.start)
.finish()
}
}
/// Strategy producing an arbitrary byte vector together with a start index in `0..=len`.
fn haystack() -> BoxedStrategy<Haystack> {
any::<Vec<u8>>()
.prop_flat_map(|data| {
// The start index depends on the generated data's length, hence the flat-map.
let len = 0..=data.len();
(Just(data), len)
})
.prop_map(|(data, start)| Haystack { data, start })
.boxed()
}
/// Test needle: a 16-byte buffer of which only the first `len` bytes are meaningful.
#[derive(Debug)]
struct Needle {
data: [u8; 16],
len: usize,
}
impl Needle {
/// The meaningful prefix of the buffer.
fn as_slice(&self) -> &[u8] {
&self.data[..self.len]
}
}
/// Strategy producing 16 arbitrary bytes and a length in `0..=16`.
fn needle() -> BoxedStrategy<Needle> {
(any::<[u8; 16]>(), 0..=16_usize)
.prop_map(|(data, len)| Needle { data, len })
.boxed()
}
// Property tests: `Bytes::find` must agree with the `find_any` reference implementation.
proptest! {
// Searches the whole generated buffer.
#[test]
fn works_as_find_does_for_up_to_and_including_16_bytes(
(haystack, needle) in (haystack(), needle())
) {
let haystack = haystack.without_start();
let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
let them = haystack.find_any(needle.as_slice());
assert_eq!(us, them);
}
// Searches from the generated start index, so the slice begins at varying memory offsets.
#[test]
fn works_as_find_does_for_various_memory_offsets(
(needle, haystack) in (needle(), haystack())
) {
let haystack = haystack.with_start();
let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
let them = haystack.find_any(needle.as_slice());
assert_eq!(us, them);
}
}
/// A zero byte is a valid needle: it is found where present and never in an empty haystack.
#[test]
fn can_search_for_null_bytes() {
unsafe {
let null = simd_bytes!(b'\0');
assert_eq!(Some(1), null.find(b"a\0"));
assert_eq!(Some(0), null.find(b"\0"));
assert_eq!(None, null.find(b""));
}
}
/// Zero bytes in the haystack neither stop the search nor count as a match for another needle.
#[test]
fn can_search_in_null_bytes() {
unsafe {
let a = simd_bytes!(b'a');
assert_eq!(Some(1), a.find(b"\0a"));
assert_eq!(None, a.find(b"\0"));
}
}
/// A trailing space is found at every index from 0 through 17, i.e. on both sides of the
/// 16-byte chunk size.
#[test]
fn space_is_found() {
unsafe {
assert_eq!(Some(0), SPACE.find(b" "));
assert_eq!(Some(1), SPACE.find(b"0 "));
assert_eq!(Some(2), SPACE.find(b"01 "));
assert_eq!(Some(3), SPACE.find(b"012 "));
assert_eq!(Some(4), SPACE.find(b"0123 "));
assert_eq!(Some(5), SPACE.find(b"01234 "));
assert_eq!(Some(6), SPACE.find(b"012345 "));
assert_eq!(Some(7), SPACE.find(b"0123456 "));
assert_eq!(Some(8), SPACE.find(b"01234567 "));
assert_eq!(Some(9), SPACE.find(b"012345678 "));
assert_eq!(Some(10), SPACE.find(b"0123456789 "));
assert_eq!(Some(11), SPACE.find(b"0123456789A "));
assert_eq!(Some(12), SPACE.find(b"0123456789AB "));
assert_eq!(Some(13), SPACE.find(b"0123456789ABC "));
assert_eq!(Some(14), SPACE.find(b"0123456789ABCD "));
assert_eq!(Some(15), SPACE.find(b"0123456789ABCDE "));
assert_eq!(Some(16), SPACE.find(b"0123456789ABCDEF "));
assert_eq!(Some(17), SPACE.find(b"0123456789ABCDEFG "));
}
}
/// Haystacks of length 0 through 17 that contain no space yield `None`.
#[test]
fn space_not_found() {
unsafe {
assert_eq!(None, SPACE.find(b""));
assert_eq!(None, SPACE.find(b"0"));
assert_eq!(None, SPACE.find(b"01"));
assert_eq!(None, SPACE.find(b"012"));
assert_eq!(None, SPACE.find(b"0123"));
assert_eq!(None, SPACE.find(b"01234"));
assert_eq!(None, SPACE.find(b"012345"));
assert_eq!(None, SPACE.find(b"0123456"));
assert_eq!(None, SPACE.find(b"01234567"));
assert_eq!(None, SPACE.find(b"012345678"));
assert_eq!(None, SPACE.find(b"0123456789"));
assert_eq!(None, SPACE.find(b"0123456789A"));
assert_eq!(None, SPACE.find(b"0123456789AB"));
assert_eq!(None, SPACE.find(b"0123456789ABC"));
assert_eq!(None, SPACE.find(b"0123456789ABCD"));
assert_eq!(None, SPACE.find(b"0123456789ABCDE"));
assert_eq!(None, SPACE.find(b"0123456789ABCDEF"));
assert_eq!(None, SPACE.find(b"0123456789ABCDEFG"));
}
}
/// Searches the same heap buffer from every possible start index, so the slice begins at each
/// offset within a 16-byte block; the space at index 16 must be reported relative to the
/// slice start, and the empty slice past it must yield `None`.
#[test]
fn works_on_nonaligned_beginnings() {
    unsafe {
        let s = b"0123456789ABCDEF ".to_vec();
        let space_at = s.len() - 1;

        for start in 0..=space_at {
            assert_eq!(Some(space_at - start), SPACE.find(&s[start..]));
        }
        assert_eq!(None, SPACE.find(&s[space_at + 1..]));
    }
}
/// The only needle byte in the buffer sits at index 0, directly *before* the searched slice
/// (which starts at index 1). A search of the slice must not report it.
#[test]
fn misalignment_does_not_cause_a_false_positive_before_start() {
const AAAA: u8 = 0x01;
let needle = Needle {
data: [
AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
],
len: 1,
};
let haystack = Haystack {
data: vec![
AAAA, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00,
],
start: 1,
};
let haystack = haystack.with_start();
// Preconditions of the scenario: the slice really is misaligned and spans several chunks.
assert_ne!(0, (haystack.as_ptr() as usize) % 16);
assert!(haystack.len() > 64);
let us = unsafe { Bytes::new(needle.data, needle.len as i32).find(haystack) };
assert_eq!(None, us);
}
/// Each of the three delimiters is found on its own; an empty haystack yields `None`.
#[test]
fn xml_delim_3_is_found() {
unsafe {
assert_eq!(Some(0), XML_DELIM_3.find(b"<"));
assert_eq!(Some(0), XML_DELIM_3.find(b">"));
assert_eq!(Some(0), XML_DELIM_3.find(b"&"));
assert_eq!(None, XML_DELIM_3.find(b""));
}
}
/// Each of the five delimiters is found on its own; an empty haystack yields `None`.
#[test]
fn xml_delim_5_is_found() {
unsafe {
assert_eq!(Some(0), XML_DELIM_5.find(b"<"));
assert_eq!(Some(0), XML_DELIM_5.find(b">"));
assert_eq!(Some(0), XML_DELIM_5.find(b"&"));
assert_eq!(Some(0), XML_DELIM_5.find(b"'"));
assert_eq!(Some(0), XML_DELIM_5.find(b"\""));
assert_eq!(None, XML_DELIM_5.find(b""));
}
}
// Property test: `ByteSubstring::find` must agree with the `find_seq` reference
// implementation for arbitrary needles and haystacks.
proptest! {
#[test]
fn works_as_find_does_for_byte_substrings(
(needle, haystack) in (any::<Vec<u8>>(), any::<Vec<u8>>())
) {
let us = unsafe {
let s = ByteSubstring::new(&needle);
s.find(&haystack)
};
let them = haystack.find_seq(&needle);
assert_eq!(us, them);
}
}
/// A trailing "zz" is found at every index from 0 through 17, i.e. on both sides of the
/// 16-byte chunk size.
#[test]
fn byte_substring_is_found() {
unsafe {
let substr = ByteSubstring::new(b"zz");
assert_eq!(Some(0), substr.find(b"zz"));
assert_eq!(Some(1), substr.find(b"0zz"));
assert_eq!(Some(2), substr.find(b"01zz"));
assert_eq!(Some(3), substr.find(b"012zz"));
assert_eq!(Some(4), substr.find(b"0123zz"));
assert_eq!(Some(5), substr.find(b"01234zz"));
assert_eq!(Some(6), substr.find(b"012345zz"));
assert_eq!(Some(7), substr.find(b"0123456zz"));
assert_eq!(Some(8), substr.find(b"01234567zz"));
assert_eq!(Some(9), substr.find(b"012345678zz"));
assert_eq!(Some(10), substr.find(b"0123456789zz"));
assert_eq!(Some(11), substr.find(b"0123456789Azz"));
assert_eq!(Some(12), substr.find(b"0123456789ABzz"));
assert_eq!(Some(13), substr.find(b"0123456789ABCzz"));
assert_eq!(Some(14), substr.find(b"0123456789ABCDzz"));
assert_eq!(Some(15), substr.find(b"0123456789ABCDEzz"));
assert_eq!(Some(16), substr.find(b"0123456789ABCDEFzz"));
assert_eq!(Some(17), substr.find(b"0123456789ABCDEFGzz"));
}
}
/// Haystacks of length 0 through 17 that do not contain "zz" yield `None`.
#[test]
fn byte_substring_is_not_found() {
unsafe {
let substr = ByteSubstring::new(b"zz");
assert_eq!(None, substr.find(b""));
assert_eq!(None, substr.find(b"0"));
assert_eq!(None, substr.find(b"01"));
assert_eq!(None, substr.find(b"012"));
assert_eq!(None, substr.find(b"0123"));
assert_eq!(None, substr.find(b"01234"));
assert_eq!(None, substr.find(b"012345"));
assert_eq!(None, substr.find(b"0123456"));
assert_eq!(None, substr.find(b"01234567"));
assert_eq!(None, substr.find(b"012345678"));
assert_eq!(None, substr.find(b"0123456789"));
assert_eq!(None, substr.find(b"0123456789A"));
assert_eq!(None, substr.find(b"0123456789AB"));
assert_eq!(None, substr.find(b"0123456789ABC"));
assert_eq!(None, substr.find(b"0123456789ABCD"));
assert_eq!(None, substr.find(b"0123456789ABCDE"));
assert_eq!(None, substr.find(b"0123456789ABCDEF"));
assert_eq!(None, substr.find(b"0123456789ABCDEFG"));
}
}
/// The haystack is seventeen `a`s followed by `b`, so the only real occurrence of "ab" starts
/// at index 16; earlier candidates must be rejected by the complete-needle check.
#[test]
fn byte_substring_has_false_positive() {
unsafe {
let substr = ByteSubstring::new(b"ab");
assert_eq!(Some(16), substr.find(b"aaaaaaaaaaaaaaaaab"))
};
}
/// A 17-byte needle — longer than the 16 bytes that fit in the SIMD register — is still found.
#[test]
fn byte_substring_needle_is_longer_than_16_bytes() {
unsafe {
let needle = b"0123456789abcdefg";
let haystack = b"0123456789abcdefgh";
assert_eq!(Some(0), ByteSubstring::new(needle).find(haystack));
}
}
/// Calls `f` with a copy of `value` placed at the very end of a page whose successor page has
/// been protected with `Protection::None`.
///
/// Two anonymous pages are mapped; `value` is copied so that its last byte is the last byte of
/// the first page. A read past the end of the string therefore touches the protected page and
/// is expected to fault.
///
/// Panics if `value` is longer than one page or if mapping/protecting fails.
fn with_guarded_string(value: &str, f: impl FnOnce(&str)) {
let page_size = region::page::size();
assert!(value.len() <= page_size);
let mut mmap = MmapMut::map_anon(2 * page_size).unwrap();
let (first_page, second_page) = mmap.split_at_mut(page_size);
unsafe {
region::protect(second_page.as_ptr(), page_size, Protection::None).unwrap();
}
// Right-align the string within the first page.
let dest = &mut first_page[page_size - value.len()..];
dest.copy_from_slice(value.as_bytes());
// The bytes were copied from a `&str`, so they are valid UTF-8.
f(unsafe { str::from_utf8_unchecked(dest) });
}
/// Searches every suffix of a string that ends exactly at a protected page boundary for its
/// last byte.
// NOTE(review): the needle `f` is present in every suffix, so only the "found" path is
// exercised here; a search that finds nothing in text ending at the boundary is not covered.
#[test]
fn works_at_page_boundary() {
with_guarded_string("0123456789abcdef", |text| {
let needle = simd_bytes!(b'f');
for offset in 0..text.len() {
let tail = &text[offset..];
unsafe {
assert_eq!(Some(tail.len() - 1), needle.find(tail.as_bytes()));
}
}
});
}
}