diff --git a/src/shm/pool.rs b/src/shm/pool.rs index 4aa07f3..d4d5e21 100644 --- a/src/shm/pool.rs +++ b/src/shm/pool.rs @@ -49,17 +49,34 @@ struct MemMap { impl MemMap { fn new(fd: RawFd, size: usize) -> Result { Ok(MemMap { - ptr: map(fd, size)?, + ptr: unsafe { map(fd, size) }?, fd: fd, size: size }) } fn remap(&mut self, newsize: usize) -> Result<(),()> { - unmap(self.ptr, self.size)?; - self.ptr = map(self.fd, newsize)?; - self.size = newsize; - Ok(()) + if self.ptr.is_null() { + return Err(()) + } + // memunmap cannot fail, as we are unmapping a pre-existing map + let _ = unsafe { unmap(self.ptr, self.size) }; + // remap the fd with the new size + match unsafe { map(self.fd, newsize) } { + Ok(ptr) => { + // update the parameters + self.ptr = ptr; + self.size = newsize; + Ok(()) + }, + Err(()) => { + // set ourselves in an empty state + self.ptr = ptr::null_mut(); + self.size = 0; + self.fd = -1; + Err(()) + } + } } fn size(&self) -> usize { @@ -67,27 +84,35 @@ impl MemMap { } fn get_slice(&self) -> &[u8] { + // if we are in the 'invalid state', self.size == 0 and we return &[] + // which is perfectly safe even if self.ptr is null unsafe { ::std::slice::from_raw_parts(self.ptr, self.size) } } } +impl Drop for MemMap { + fn drop(&mut self) { + if !self.ptr.is_null() { + let _ = unsafe { unmap(self.ptr, self.size) }; + } + } +} + // mman::mmap should really be unsafe... why isn't it? -#[allow(unused_unsafe)] -fn map(fd: RawFd, size: usize) -> Result<*mut u8, ()> { - let ret = unsafe { mman::mmap( +unsafe fn map(fd: RawFd, size: usize) -> Result<*mut u8, ()> { + let ret = mman::mmap( ptr::null_mut(), size, mman::PROT_READ, mman::MAP_SHARED, fd, 0 - ) }; + ); ret.map(|p| p as *mut u8).map_err(|_| ()) } // mman::munmap should really be unsafe... why isn't it? -#[allow(unused_unsafe)] -fn unmap(ptr: *mut u8, size: usize) -> Result<(),()> { - let ret = unsafe { mman::munmap(ptr as *mut _, size) }; +unsafe fn unmap(ptr: *mut u8, size: usize) -> Result<(),()> { + let ret = mman::munmap(ptr as *mut _, size); ret.map_err(|_| ()) }