Path: blob/master/rust/pin-init/examples/pthread_mutex.rs
29278 views
// SPDX-License-Identifier: Apache-2.0 OR MIT12// inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>3#![allow(clippy::undocumented_unsafe_blocks)]4#![cfg_attr(feature = "alloc", feature(allocator_api))]5#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]67#[cfg(not(windows))]8mod pthread_mtx {9#[cfg(feature = "alloc")]10use core::alloc::AllocError;11use core::{12cell::UnsafeCell,13marker::PhantomPinned,14mem::MaybeUninit,15ops::{Deref, DerefMut},16pin::Pin,17};18use pin_init::*;19use std::convert::Infallible;2021#[pin_data(PinnedDrop)]22pub struct PThreadMutex<T> {23#[pin]24raw: UnsafeCell<libc::pthread_mutex_t>,25data: UnsafeCell<T>,26#[pin]27pin: PhantomPinned,28}2930unsafe impl<T: Send> Send for PThreadMutex<T> {}31unsafe impl<T: Send> Sync for PThreadMutex<T> {}3233#[pinned_drop]34impl<T> PinnedDrop for PThreadMutex<T> {35fn drop(self: Pin<&mut Self>) {36unsafe {37libc::pthread_mutex_destroy(self.raw.get());38}39}40}4142#[derive(Debug)]43pub enum Error {44#[allow(dead_code)]45IO(std::io::Error),46#[allow(dead_code)]47Alloc,48}4950impl From<Infallible> for Error {51fn from(e: Infallible) -> Self {52match e {}53}54}5556#[cfg(feature = "alloc")]57impl From<AllocError> for Error {58fn from(_: AllocError) -> Self {59Self::Alloc60}61}6263impl<T> PThreadMutex<T> {64#[allow(dead_code)]65pub fn new(data: T) -> impl PinInit<Self, Error> {66fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {67let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {68// we can cast, because `UnsafeCell` has the same layout as T.69let slot: *mut libc::pthread_mutex_t = slot.cast();70let mut attr = MaybeUninit::uninit();71let attr = attr.as_mut_ptr();72// SAFETY: ptr is valid73let ret = unsafe { libc::pthread_mutexattr_init(attr) };74if ret != 0 {75return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));76}77// SAFETY: attr is initialized78let ret = unsafe {79libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)80};81if ret != 0 {82// SAFETY: attr is initialized83unsafe { libc::pthread_mutexattr_destroy(attr) };84return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));85}86// SAFETY: slot is valid87unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };88// SAFETY: attr and slot are valid ptrs and attr is initialized89let ret = unsafe { libc::pthread_mutex_init(slot, attr) };90// SAFETY: attr was initialized91unsafe { libc::pthread_mutexattr_destroy(attr) };92if ret != 0 {93return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));94}95Ok(())96};97// SAFETY: mutex has been initialized98unsafe { pin_init_from_closure(init) }99}100try_pin_init!(Self {101data: UnsafeCell::new(data),102raw <- init_raw(),103pin: PhantomPinned,104}? Error)105}106107#[allow(dead_code)]108pub fn lock(&self) -> PThreadMutexGuard<'_, T> {109// SAFETY: raw is always initialized110unsafe { libc::pthread_mutex_lock(self.raw.get()) };111PThreadMutexGuard { mtx: self }112}113}114115pub struct PThreadMutexGuard<'a, T> {116mtx: &'a PThreadMutex<T>,117}118119impl<T> Drop for PThreadMutexGuard<'_, T> {120fn drop(&mut self) {121// SAFETY: raw is always initialized122unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };123}124}125126impl<T> Deref for PThreadMutexGuard<'_, T> {127type Target = T;128129fn deref(&self) -> &Self::Target {130unsafe { &*self.mtx.data.get() }131}132}133134impl<T> DerefMut for PThreadMutexGuard<'_, T> {135fn deref_mut(&mut self) -> &mut Self::Target {136unsafe { &mut *self.mtx.data.get() }137}138}139}140141#[cfg_attr(test, test)]142#[cfg_attr(all(test, miri), ignore)]143fn main() {144#[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]145{146use core::pin::Pin;147use pin_init::*;148use pthread_mtx::*;149use std::{150sync::Arc,151thread::{sleep, Builder},152time::Duration,153};154let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();155let mut handles = vec![];156let thread_count = 20;157let workload = 1_000_000;158for i in 0..thread_count {159let mtx = mtx.clone();160handles.push(161Builder::new()162.name(format!("worker #{i}"))163.spawn(move || {164for _ in 0..workload {165*mtx.lock() += 1;166}167println!("{i} halfway");168sleep(Duration::from_millis((i as u64) * 10));169for _ in 0..workload {170*mtx.lock() += 1;171}172println!("{i} finished");173})174.expect("should not fail"),175);176}177for h in handles {178h.join().expect("thread panicked");179}180println!("{:?}", &*mtx.lock());181assert_eq!(*mtx.lock(), workload * thread_count * 2);182}183}184185186