GitHub Repository: torvalds/linux
Path: blob/master/rust/pin-init/examples/mutex.rs
// SPDX-License-Identifier: Apache-2.0 OR MIT

#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
#![allow(clippy::missing_safety_doc)]

use core::{
    cell::{Cell, UnsafeCell},
    marker::PhantomPinned,
    ops::{Deref, DerefMut},
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
};
#[cfg(feature = "std")]
use std::{
    sync::Arc,
    thread::{self, sleep, Builder, Thread},
    time::Duration,
};

use pin_init::*;
#[allow(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;

pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            #[cfg(feature = "std")]
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    #[inline]
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}
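
// Illustrative sketch: `SpinLockGuard` releases the lock in `Drop`, so a
// critical section is simply the lifetime of the guard. `with_spin_lock` is a
// hypothetical helper (not part of the original example and not used elsewhere
// in this file) that shows the pattern.
#[allow(dead_code)]
fn with_spin_lock<R>(lock: &SpinLock, f: impl FnOnce() -> R) -> R {
    let _guard = lock.acquire(); // lock held from here...
    f() // ...until `_guard` is dropped at the end of this scope
}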

#[pin_data]
pub struct CMutex<T> {
    #[pin]
    wait_list: ListHead,
    spin_lock: SpinLock,
    locked: Cell<bool>,
    #[pin]
    data: UnsafeCell<T>,
}

impl<T> CMutex<T> {
    #[inline]
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
        pin_init!(CMutex {
            wait_list <- ListHead::new(),
            spin_lock: SpinLock::new(),
            locked: Cell::new(false),
            data <- unsafe {
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
                    val.__pinned_init(slot.cast::<T>())
                })
            },
        })
    }

    #[inline]
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
        let mut sguard = self.spin_lock.acquire();
        if self.locked.get() {
            // Slow path: the mutex is already held, so enqueue a stack-pinned
            // wait entry and sleep until the current owner wakes this thread
            // (without the `std` feature this loop just spins).
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
            // println!("wait list length: {}", self.wait_list.size());
            while self.locked.get() {
                drop(sguard);
                #[cfg(feature = "std")]
                thread::park();
                sguard = self.spin_lock.acquire();
            }
            // This does have an effect, as the ListHead inside wait_entry implements Drop!
            #[expect(clippy::drop_non_drop)]
            drop(wait_entry);
        }
        self.locked.set(true);
        unsafe {
            Pin::new_unchecked(CMutexGuard {
                mtx: self,
                _pin: PhantomPinned,
            })
        }
    }

    #[allow(dead_code)]
    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
        unsafe { &mut *self.data.get() }
    }
}
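
// Illustrative sketch: a hypothetical helper (not used elsewhere in this file)
// showing that `get_data_mut` gives lock-free access when the caller already
// holds exclusive, pinned access to the mutex.
#[allow(dead_code)]
fn set_directly<T>(mtx: Pin<&mut CMutex<T>>, value: T) {
    // No locking needed: `Pin<&mut CMutex<T>>` proves exclusive access.
    *mtx.get_data_mut() = value;
}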

// SAFETY: the mutex only hands out access to `T` under mutual exclusion, so
// sharing or sending a `CMutex<T>` across threads only requires `T: Send`.
unsafe impl<T: Send> Send for CMutex<T> {}
unsafe impl<T: Send> Sync for CMutex<T> {}

pub struct CMutexGuard<'a, T> {
    mtx: &'a CMutex<T>,
    _pin: PhantomPinned,
}

impl<T> Drop for CMutexGuard<'_, T> {
    #[inline]
    fn drop(&mut self) {
        let sguard = self.mtx.spin_lock.acquire();
        self.mtx.locked.set(false);
        if let Some(list_field) = self.mtx.wait_list.next() {
            // Wake the first waiter. `WaitEntry` is `#[repr(C)]` with the
            // `ListHead` as its first field, so the list pointer can be cast
            // back to the containing `WaitEntry`.
            let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
            #[cfg(feature = "std")]
            unsafe {
                (*_wait_entry).thread.unpark()
            };
        }
        drop(sguard);
    }
}

impl<T> Deref for CMutexGuard<'_, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe { &*self.mtx.data.get() }
    }
}

impl<T> DerefMut for CMutexGuard<'_, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { &mut *self.mtx.data.get() }
    }
}
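
// Illustrative sketch: besides `Arc::pin_init` (used in `main` below), a
// `CMutex` can also be pinned on the stack with `stack_pin_init!`. This
// hypothetical demo is not called anywhere in this file.
#[allow(dead_code)]
fn stack_mutex_demo() {
    stack_pin_init!(let mtx = CMutex::new(42_usize));
    *mtx.lock() += 1;
    assert_eq!(*mtx.lock(), 43);
}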

#[pin_data]
#[repr(C)]
struct WaitEntry {
    #[pin]
    wait_list: ListHead,
    #[cfg(feature = "std")]
    thread: Thread,
}

impl WaitEntry {
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        #[cfg(feature = "std")]
        {
            pin_init!(Self {
                thread: thread::current(),
                wait_list <- ListHead::insert_prev(list),
            })
        }
        #[cfg(not(feature = "std"))]
        {
            pin_init!(Self {
                wait_list <- ListHead::insert_prev(list),
            })
        }
    }
}

#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let mut handles = vec![];
        let thread_count = 20;
        let workload = if cfg!(miri) { 100 } else { 1_000 };
        for i in 0..thread_count {
            let mtx = mtx.clone();
            handles.push(
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail"),
            );
        }
        for h in handles {
            h.join().expect("thread panicked");
        }
        println!("{:?}", &*mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}
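
// Note: `#[cfg_attr(test, test)]` turns `main` into a test case when compiled
// as a test; the body is a no-op without the `std` feature, and under Miri the
// per-loop workload drops from 1_000 to 100 iterations, presumably to keep the
// interpreted run fast.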