Validate memory pointers

This commit is contained in:
Jeremy Soller 2016-09-20 18:03:14 -06:00
parent be3bcbb878
commit 8dfd003c72
3 changed files with 96 additions and 63 deletions

View file

@ -65,6 +65,9 @@ impl UserInner {
} }
fn capture_inner(&self, address: usize, size: usize, writable: bool) -> Result<usize> { fn capture_inner(&self, address: usize, size: usize, writable: bool) -> Result<usize> {
if size == 0 {
Ok(0)
} else {
let context_lock = self.context.upgrade().ok_or(Error::new(ESRCH))?; let context_lock = self.context.upgrade().ok_or(Error::new(ESRCH))?;
let context = context_lock.read(); let context = context_lock.read();
@ -112,10 +115,14 @@ impl UserInner {
&mut temporary_page &mut temporary_page
)); ));
return Ok(to_address + offset); Ok(to_address + offset)
}
} }
pub fn release(&self, address: usize) -> Result<()> { pub fn release(&self, address: usize) -> Result<()> {
if address == 0 {
Ok(())
} else {
let context_lock = self.context.upgrade().ok_or(Error::new(ESRCH))?; let context_lock = self.context.upgrade().ok_or(Error::new(ESRCH))?;
let context = context_lock.read(); let context = context_lock.read();
@ -136,6 +143,7 @@ impl UserInner {
Err(Error::new(EFAULT)) Err(Error::new(EFAULT))
} }
}
pub fn read(&self, buf: &mut [u8]) -> Result<usize> { pub fn read(&self, buf: &mut [u8]) -> Result<usize> {
let packet_size = mem::size_of::<Packet>(); let packet_size = mem::size_of::<Packet>();

View file

@ -1,15 +1,39 @@
use core::slice; use core::{mem, slice};
use arch::paging::{ActivePageTable, Page, VirtualAddress, entry};
use syscall::error::*; use syscall::error::*;
/// Convert a pointer and length to slice, if valid fn validate(address: usize, size: usize, flags: entry::EntryFlags) -> Result<()> {
/// TODO: Check validity let active_table = unsafe { ActivePageTable::new() };
pub fn validate_slice<T>(ptr: *const T, len: usize) -> Result<&'static [T]> {
Ok(unsafe { slice::from_raw_parts(ptr, len) }) let start_page = Page::containing_address(VirtualAddress::new(address));
let end_page = Page::containing_address(VirtualAddress::new(address + size - 1));
for page in Page::range_inclusive(start_page, end_page) {
let page_flags = active_table.translate_page_flags(page).ok_or(Error::new(EFAULT))?;
if ! page_flags.contains(flags) {
return Err(Error::new(EFAULT));
}
}
Ok(())
} }
/// Convert a pointer and length to slice, if valid /// Convert a pointer and length to slice, if valid
/// TODO: Check validity pub fn validate_slice<T>(ptr: *const T, len: usize) -> Result<&'static [T]> {
pub fn validate_slice_mut<T>(ptr: *mut T, len: usize) -> Result<&'static mut [T]> { if len == 0 {
Ok(unsafe { slice::from_raw_parts_mut(ptr, len) }) Ok(&[])
} else {
validate(ptr as usize, len * mem::size_of::<T>(), entry::PRESENT /* TODO | entry::USER_ACCESSIBLE */)?;
Ok(unsafe { slice::from_raw_parts(ptr, len) })
}
}
/// Convert a pointer and length to slice, if valid
pub fn validate_slice_mut<T>(ptr: *mut T, len: usize) -> Result<&'static mut [T]> {
if len == 0 {
Ok(&mut [])
} else {
validate(ptr as usize, len * mem::size_of::<T>(), entry::PRESENT | entry::WRITABLE /* TODO | entry::USER_ACCESSIBLE */)?;
Ok(unsafe { slice::from_raw_parts_mut(ptr, len) })
}
} }

View file

@ -31,6 +31,7 @@ fn main(){
loop { loop {
let mut packet = Packet::default(); let mut packet = Packet::default();
socket.read(&mut packet).expect("example: failed to read events from example scheme"); socket.read(&mut packet).expect("example: failed to read events from example scheme");
println!("{:?}", packet);
scheme.handle(&mut packet); scheme.handle(&mut packet);
socket.write(&packet).expect("example: failed to write responses to example scheme"); socket.write(&packet).expect("example: failed to write responses to example scheme");
} }