Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
torvalds
GitHub Repository: torvalds/linux
Path: blob/master/rust/pin-init/examples/pthread_mutex.rs
29278 views
1
// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3
// inspired by <https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs>
4
#![allow(clippy::undocumented_unsafe_blocks)]
5
#![cfg_attr(feature = "alloc", feature(allocator_api))]
6
#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
7
8
#[cfg(not(windows))]
9
mod pthread_mtx {
10
#[cfg(feature = "alloc")]
11
use core::alloc::AllocError;
12
use core::{
13
cell::UnsafeCell,
14
marker::PhantomPinned,
15
mem::MaybeUninit,
16
ops::{Deref, DerefMut},
17
pin::Pin,
18
};
19
use pin_init::*;
20
use std::convert::Infallible;
21
22
#[pin_data(PinnedDrop)]
23
pub struct PThreadMutex<T> {
24
#[pin]
25
raw: UnsafeCell<libc::pthread_mutex_t>,
26
data: UnsafeCell<T>,
27
#[pin]
28
pin: PhantomPinned,
29
}
30
31
unsafe impl<T: Send> Send for PThreadMutex<T> {}
32
unsafe impl<T: Send> Sync for PThreadMutex<T> {}
33
34
#[pinned_drop]
35
impl<T> PinnedDrop for PThreadMutex<T> {
36
fn drop(self: Pin<&mut Self>) {
37
unsafe {
38
libc::pthread_mutex_destroy(self.raw.get());
39
}
40
}
41
}
42
43
/// Errors that can occur while creating a `PThreadMutex`.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so that it can be shown to users
/// and converted into `Box<dyn std::error::Error>` with `?`.
#[derive(Debug)]
pub enum Error {
    /// A pthread call failed; carries the error built from the code it returned.
    #[allow(dead_code)]
    IO(std::io::Error),
    /// A memory allocation failed.
    #[allow(dead_code)]
    Alloc,
}

impl std::fmt::Display for Error {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IO(err) => write!(f, "I/O error: {err}"),
            Self::Alloc => f.write_str("memory allocation failed"),
        }
    }
}

impl std::error::Error for Error {}
50
51
impl From<Infallible> for Error {
52
fn from(e: Infallible) -> Self {
53
match e {}
54
}
55
}
56
57
// Maps an allocation failure onto `Error::Alloc`; the `AllocError` value itself is discarded.
#[cfg(feature = "alloc")]
impl From<AllocError> for Error {
    fn from(_alloc_err: AllocError) -> Self {
        Error::Alloc
    }
}
63
64
impl<T> PThreadMutex<T> {
65
#[allow(dead_code)]
66
pub fn new(data: T) -> impl PinInit<Self, Error> {
67
fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
68
let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
69
// we can cast, because `UnsafeCell` has the same layout as T.
70
let slot: *mut libc::pthread_mutex_t = slot.cast();
71
let mut attr = MaybeUninit::uninit();
72
let attr = attr.as_mut_ptr();
73
// SAFETY: ptr is valid
74
let ret = unsafe { libc::pthread_mutexattr_init(attr) };
75
if ret != 0 {
76
return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
77
}
78
// SAFETY: attr is initialized
79
let ret = unsafe {
80
libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
81
};
82
if ret != 0 {
83
// SAFETY: attr is initialized
84
unsafe { libc::pthread_mutexattr_destroy(attr) };
85
return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
86
}
87
// SAFETY: slot is valid
88
unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
89
// SAFETY: attr and slot are valid ptrs and attr is initialized
90
let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
91
// SAFETY: attr was initialized
92
unsafe { libc::pthread_mutexattr_destroy(attr) };
93
if ret != 0 {
94
return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
95
}
96
Ok(())
97
};
98
// SAFETY: mutex has been initialized
99
unsafe { pin_init_from_closure(init) }
100
}
101
try_pin_init!(Self {
102
data: UnsafeCell::new(data),
103
raw <- init_raw(),
104
pin: PhantomPinned,
105
}? Error)
106
}
107
108
#[allow(dead_code)]
109
pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
110
// SAFETY: raw is always initialized
111
unsafe { libc::pthread_mutex_lock(self.raw.get()) };
112
PThreadMutexGuard { mtx: self }
113
}
114
}
115
116
pub struct PThreadMutexGuard<'a, T> {
117
mtx: &'a PThreadMutex<T>,
118
}
119
120
impl<T> Drop for PThreadMutexGuard<'_, T> {
121
fn drop(&mut self) {
122
// SAFETY: raw is always initialized
123
unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
124
}
125
}
126
127
impl<T> Deref for PThreadMutexGuard<'_, T> {
128
type Target = T;
129
130
fn deref(&self) -> &Self::Target {
131
unsafe { &*self.mtx.data.get() }
132
}
133
}
134
135
impl<T> DerefMut for PThreadMutexGuard<'_, T> {
136
fn deref_mut(&mut self) -> &mut Self::Target {
137
unsafe { &mut *self.mtx.data.get() }
138
}
139
}
140
}
141
142
// Example driver: several worker threads increment one shared counter behind a
// `PThreadMutex`, and the final value is checked against the expected total.
// Also runs as a test when built with `--test` (ignored under Miri).
#[cfg_attr(test, test)]
#[cfg_attr(all(test, miri), ignore)]
fn main() {
    #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
    {
        use core::pin::Pin;
        use pin_init::*;
        use pthread_mtx::*;
        use std::{
            sync::Arc,
            thread::{sleep, Builder},
            time::Duration,
        };

        let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = 1_000_000;

        // Spawn the workers in order; each one gets its own handle to the shared mutex.
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        // First half of the increments.
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} halfway");
                        // Stagger the second half so the workers finish at different times.
                        sleep(Duration::from_millis((i as u64) * 10));
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        // Every worker performed `2 * workload` increments.
        println!("{:?}", &*mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}
185
186