Path: blob/master/rust/pin-init/examples/static_init.rs
29278 views
// SPDX-License-Identifier: Apache-2.0 OR MIT12#![allow(clippy::undocumented_unsafe_blocks)]3#![cfg_attr(feature = "alloc", feature(allocator_api))]4#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]5#![allow(unused_imports)]67use core::{8cell::{Cell, UnsafeCell},9mem::MaybeUninit,10ops,11pin::Pin,12time::Duration,13};14use pin_init::*;15#[cfg(feature = "std")]16use std::{17sync::Arc,18thread::{sleep, Builder},19};2021#[allow(unused_attributes)]22mod mutex;23use mutex::*;2425pub struct StaticInit<T, I> {26cell: UnsafeCell<MaybeUninit<T>>,27init: Cell<Option<I>>,28lock: SpinLock,29present: Cell<bool>,30}3132unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}33unsafe impl<T: Send, I> Send for StaticInit<T, I> {}3435impl<T, I: PinInit<T>> StaticInit<T, I> {36pub const fn new(init: I) -> Self {37Self {38cell: UnsafeCell::new(MaybeUninit::uninit()),39init: Cell::new(Some(init)),40lock: SpinLock::new(),41present: Cell::new(false),42}43}44}4546impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {47type Target = T;48fn deref(&self) -> &Self::Target {49if self.present.get() {50unsafe { (*self.cell.get()).assume_init_ref() }51} else {52println!("acquire spinlock on static init");53let _guard = self.lock.acquire();54println!("rechecking present...");55std::thread::sleep(std::time::Duration::from_millis(200));56if self.present.get() {57return unsafe { (*self.cell.get()).assume_init_ref() };58}59println!("doing init");60let ptr = self.cell.get().cast::<T>();61match self.init.take() {62Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },63None => unsafe { core::hint::unreachable_unchecked() },64}65self.present.set(true);66unsafe { (*self.cell.get()).assume_init_ref() }67}68}69}7071pub struct CountInit;7273unsafe impl PinInit<CMutex<usize>> for CountInit {74unsafe fn __pinned_init(75self,76slot: *mut CMutex<usize>,77) -> Result<(), core::convert::Infallible> {78let init = CMutex::new(0);79std::thread::sleep(std::time::Duration::from_millis(1000));80unsafe { init.__pinned_init(slot) }81}82}8384pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);8586fn main() {87#[cfg(feature = "std")]88{89let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();90let mut handles = vec![];91let thread_count = 20;92let workload = 1_000;93for i in 0..thread_count {94let mtx = mtx.clone();95handles.push(96Builder::new()97.name(format!("worker #{i}"))98.spawn(move || {99for _ in 0..workload {100*COUNT.lock() += 1;101std::thread::sleep(std::time::Duration::from_millis(10));102*mtx.lock() += 1;103std::thread::sleep(std::time::Duration::from_millis(10));104*COUNT.lock() += 1;105}106println!("{i} halfway");107sleep(Duration::from_millis((i as u64) * 10));108for _ in 0..workload {109std::thread::sleep(std::time::Duration::from_millis(10));110*mtx.lock() += 1;111}112println!("{i} finished");113})114.expect("should not fail"),115);116}117for h in handles {118h.join().expect("thread panicked");119}120println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());121assert_eq!(*mtx.lock(), workload * thread_count * 2);122}123}124125126