diff --git a/src/shm/pool.rs b/src/shm/pool.rs index d4d5e21..07dcc21 100644 --- a/src/shm/pool.rs +++ b/src/shm/pool.rs @@ -1,8 +1,16 @@ +use std::cell::Cell; use std::os::unix::io::RawFd; -use std::sync::RwLock; +use std::sync::{RwLock, Once, ONCE_INIT}; use std::ptr; +use nix::{c_int, c_void, libc}; use nix::sys::mman; +use nix::sys::signal::{self, SigAction, Signal, SigHandler}; + +thread_local!(static SIGBUS_GUARD: Cell<(*const MemMap, bool)> = Cell::new((ptr::null_mut(), false))); + +static SIGBUS_INIT: Once = ONCE_INIT; +static mut OLD_SIGBUS_HANDLER: *mut SigAction = 0 as *mut SigAction; pub struct Pool { map: RwLock @@ -30,13 +38,32 @@ impl Pool { } pub fn with_data_slice(&self, f: F) -> Result<(),()> { - // TODO: handle SIGBUS - let guard = self.map.read().unwrap(); + // Place the sigbus handler + SIGBUS_INIT.call_once(|| { + unsafe { place_sigbus_handler(); } + }); - let slice = guard.get_slice(); + let pool_guard = self.map.read().unwrap(); + + // Prepare the access + SIGBUS_GUARD.with(|guard| { + let (p,_) = guard.get(); + if !p.is_null() { + // Recursive call of this method is not supported + panic!("Recursive access to a SHM pool content is not supported."); + } + guard.set((&*pool_guard as *const MemMap, false)) + }); + + let slice = pool_guard.get_slice(); f(slice); - Ok(()) + // Cleanup Post-access + SIGBUS_GUARD.with(|guard| { + let (_, triggered) = guard.get(); + guard.set((ptr::null_mut(), false)); + if triggered { Err(()) } else { Ok(()) } + }) } } @@ -88,6 +115,14 @@ impl MemMap { // which is perfectly safe even if self.ptr is null unsafe { ::std::slice::from_raw_parts(self.ptr, self.size) } } + + fn contains(&self, ptr: *mut u8) -> bool { + ptr >= self.ptr && ptr < unsafe { self.ptr.offset(self.size as isize) } + } + + fn nullify(&self) -> Result<(),()> { + unsafe { nullify_map(self.ptr, self.size) } + } } impl Drop for MemMap { @@ -116,3 +151,81 @@ unsafe fn unmap(ptr: *mut u8, size: usize) -> Result<(),()> { let ret = mman::munmap(ptr as *mut _, size); ret.map_err(|_| ()) } + +unsafe fn nullify_map(ptr: *mut u8, size: usize) -> Result<(), ()> { + let ret = mman::mmap( + ptr as *mut _, + size, + mman::PROT_READ, + mman::MAP_ANONYMOUS | mman::MAP_PRIVATE | mman::MAP_FIXED, + -1, + 0 + ); + ret.map(|_| ()).map_err(|_| ()) +} + +unsafe fn place_sigbus_handler() { + // create our sigbus handler + let action = SigAction::new( + SigHandler::SigAction(sigbus_handler), + signal::SA_NODEFER, + signal::SigSet::empty() + ); + match signal::sigaction(Signal::SIGBUS, &action) { + Ok(old_signal) => { + OLD_SIGBUS_HANDLER = Box::into_raw(Box::new(old_signal)); + }, + Err(e) => { + panic!("sigaction failed sor SIGBUS handler: {:?}", e) + } + } +} + +unsafe fn reraise_sigbus() { + // reset the old sigaction + let _ = signal::sigaction(Signal::SIGBUS, &*OLD_SIGBUS_HANDLER); + let _ = signal::raise(Signal::SIGBUS); +} + +extern "C" fn sigbus_handler(_signum: c_int, info: *mut libc::siginfo_t, _context: *mut c_void) { + let faulty_ptr = unsafe { siginfo_si_addr(info) } as *mut u8; + SIGBUS_GUARD.with(|guard| { + let (memmap, _) = guard.get(); + match unsafe { memmap.as_ref() }.map(|m| (m, m.contains(faulty_ptr))) { + Some((m, true)) => { + // we are in a faulty memory pool ! + // remember that it was faulty + guard.set((memmap, true)); + // nullify the pool + if m.nullify().is_err() { + // something terrible occured ! + unsafe { reraise_sigbus() } + } + }, + _ => { + // something else occured, let's die honorably + unsafe { reraise_sigbus() } + } + } + }); +} + +// This was shamelessly stolen from rustc's source +// so I expect it to work whevener rust works +// I guess it's good enough? + +#[cfg(any(target_os = "linux", target_os = "android"))] +unsafe fn siginfo_si_addr(info: *mut libc::siginfo_t) -> *mut c_void { + #[repr(C)] + struct siginfo_t { + a: [libc::c_int; 3], // si_signo, si_errno, si_code + si_addr: *mut libc::c_void, + } + + (*(info as *const siginfo_t)).si_addr +} + +#[cfg(not(any(target_os = "linux", target_os = "android")))] +unsafe fn siginfo_si_addr(info: *mut libc::siginfo_t) -> *mut c_void { + (*info).si_addr +}