msg_tool\utils/
threadpool.rs

1//! Thread pool utilities
2use crate::ext::mutex::*;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{
5    Arc, Condvar, Mutex,
6    mpsc::{Receiver, SyncSender, TrySendError, sync_channel},
7};
8use std::thread::{self, JoinHandle};
9
/// A boxed task: it receives the id of the worker that runs it and produces a `T`.
type Job<T> = Box<dyn FnOnce(usize) -> T + Send + 'static>;

/// A simple generic thread pool.
///
/// - `T`: the return type of tasks. Unless the pool was created with `no_result = true`,
///   each completed task's value is appended to [`results`](Self::results) in the order the
///   tasks finished (not the order they were submitted).
/// - `execute` accepts a task and a `block_if_full` flag:
///     * if `true`, submission blocks while the pool is saturated until a slot frees up;
///     * if `false`, submission returns `ExecuteError::Full` when the pool is saturated.
/// - `join` waits until all submitted tasks have completed; it does not shut down the pool.
/// - Dropping the pool closes the submission queue and joins every worker thread.
pub struct ThreadPool<T: Send + 'static> {
    /// Submission side of the bounded job queue; taken (set to `None`) when the pool is dropped.
    sender: Option<SyncSender<Job<T>>>,
    /// Receiving side of the job queue. Workers use their own `Arc` clones; this copy is never
    /// read directly, but it keeps the receiver alive for as long as the pool exists.
    #[allow(dead_code)]
    receiver: Arc<Mutex<Receiver<Job<T>>>>,
    /// Handles of the spawned worker threads, joined on drop.
    workers: Vec<JoinHandle<()>>,
    /// Completed task results (left empty when the pool was built with `no_result = true`).
    pub results: Arc<Mutex<Vec<T>>>,
    /// Number of pending tasks (queued + running).
    pending: Arc<AtomicUsize>,
    /// Mutex/condvar pair used by `join` to sleep until `pending` drops to zero.
    pending_pair: Arc<(Mutex<()>, Condvar)>,
    /// Number of worker threads.
    size: usize,
}
32
/// Error type for [`ThreadPool::execute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecuteError {
    /// The bounded submission queue is full. Returned only when `block_if_full` is `false`.
    Full,
    /// The pool's submission channel is closed; the task was not queued.
    Closed,
}

impl std::error::Error for ExecuteError {}

impl std::fmt::Display for ExecuteError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExecuteError::Full => write!(f, "ThreadPool is full"),
            ExecuteError::Closed => write!(f, "ThreadPool is closed"),
        }
    }
}
52
53impl<T: Send + 'static> ThreadPool<T> {
54    /// Get the number of worker threads in the pool.
55    pub fn size(&self) -> usize {
56        self.size
57    }
58
59    /// Create a new thread pool with `size` workers.
60    /// The internal submission channel is bounded to `size`, so when all workers are busy and
61    /// the channel is full, further submissions will block or return error depending on the flag.
62    ///
63    /// * `name` - Optional base name for worker threads. If None, "threadpool-worker-" is used.
64    /// * `no_result` - If true, results are not stored (saves some overhead if not needed).
65    pub fn new<'a>(
66        size: usize,
67        name: Option<&'a str>,
68        no_result: bool,
69    ) -> Result<Self, std::io::Error> {
70        if size == 0 {
71            return Err(std::io::Error::new(
72                std::io::ErrorKind::InvalidInput,
73                "worker size must be > 0",
74            ));
75        }
76
77        let (tx, rx) = sync_channel::<Job<T>>(size);
78        let receiver = Arc::new(Mutex::new(rx));
79        let results = Arc::new(Mutex::new(Vec::new()));
80        let pending = Arc::new(AtomicUsize::new(0));
81        let pending_pair = Arc::new((Mutex::new(()), Condvar::new()));
82        let thread_name = name.unwrap_or("threadpool-worker-");
83
84        let mut workers = Vec::with_capacity(size);
85        for id in 0..size {
86            let rx_clone = Arc::clone(&receiver);
87            let results_clone = Arc::clone(&results);
88            let pending_clone = Arc::clone(&pending);
89            let pending_pair_clone = Arc::clone(&pending_pair);
90
91            let handle = thread::Builder::new()
92                .name(format!("{}{}", thread_name, id))
93                .spawn(move || {
94                    loop {
95                        // Lock receiver to call recv. Using a Mutex around Receiver serializes
96                        // the recv calls but is fine for this simple implementation.
97                        let job = {
98                            let guard = rx_clone.lock_blocking();
99                            // If recv returns Err, sender was dropped -> exit thread
100                            guard.recv()
101                        };
102
103                        match job {
104                            Ok(job) => {
105                                // Execute the job and store result
106                                let res = job(id);
107                                if !no_result {
108                                    let mut r = results_clone.lock_blocking();
109                                    r.push(res);
110                                }
111
112                                // Decrement pending count and notify join waiters
113                                pending_clone.fetch_sub(1, Ordering::SeqCst);
114                                let (lock, cvar) = &*pending_pair_clone;
115                                let _g = lock.lock_blocking();
116                                cvar.notify_all();
117                            }
118                            Err(_) => {
119                                // Channel closed -> shutdown worker
120                                break;
121                            }
122                        }
123                    }
124                })?;
125
126            workers.push(handle);
127        }
128
129        Ok(ThreadPool {
130            sender: Some(tx),
131            receiver,
132            workers,
133            results,
134            pending,
135            pending_pair,
136            size,
137        })
138    }
139
140    /// Execute a task. If `block_if_full` is true, this call will block when the internal
141    /// submission channel is full (i.e. all workers busy and buffer full) until space becomes available.
142    /// If `block_if_full` is false, this returns Err(ExecuteError::Full) when the channel is full.
143    ///
144    /// job: a closure that takes the worker id (0..size-1) and returns a T.
145    pub fn execute<F>(&self, job: F, block_if_full: bool) -> Result<(), ExecuteError>
146    where
147        F: FnOnce(usize) -> T + Send + 'static,
148    {
149        let sender = match &self.sender {
150            Some(s) => s,
151            None => return Err(ExecuteError::Closed),
152        };
153
154        // Increase pending count for this submission. If submission fails we will decrement.
155        self.pending.fetch_add(1, Ordering::SeqCst);
156
157        let boxed: Job<T> = Box::new(job);
158
159        if block_if_full {
160            // This will block until there is space in the bounded channel or the channel is closed.
161            if sender.send(boxed).is_err() {
162                // Channel closed
163                self.pending.fetch_sub(1, Ordering::SeqCst);
164                return Err(ExecuteError::Closed);
165            }
166            Ok(())
167        } else {
168            match sender.try_send(boxed) {
169                Ok(()) => Ok(()),
170                Err(TrySendError::Full(_)) => {
171                    // revert pending increment
172                    self.pending.fetch_sub(1, Ordering::SeqCst);
173                    Err(ExecuteError::Full)
174                }
175                Err(TrySendError::Disconnected(_)) => {
176                    self.pending.fetch_sub(1, Ordering::SeqCst);
177                    Err(ExecuteError::Closed)
178                }
179            }
180        }
181    }
182
183    /// Wait until all submitted tasks have completed. This does not shut down the pool; new tasks
184    /// can still be submitted after join returns.
185    pub fn join(&self) {
186        // Fast path
187        if self.pending.load(Ordering::SeqCst) == 0 {
188            return;
189        }
190
191        let (lock, cvar) = &*self.pending_pair;
192        let mut guard = lock.lock_blocking();
193        while self.pending.load(Ordering::SeqCst) != 0 {
194            guard = match cvar.wait(guard) {
195                Ok(g) => g,
196                Err(poisoned) => poisoned.into_inner(),
197            };
198        }
199    }
200
201    /// Take all results, leaving an empty results vector.
202    pub fn take_results(&self) -> Vec<T> {
203        let mut results = self.results.lock_blocking();
204        results.split_off(0)
205    }
206
207    /// Wait until all submitted tasks have completed, then return the results.
208    pub fn into_results(self) -> Vec<T> {
209        self.join();
210        let mut results = self.results.lock_blocking();
211        results.split_off(0)
212    }
213}
214
215impl<T: Send + 'static> Drop for ThreadPool<T> {
216    fn drop(&mut self) {
217        // Close sender so worker threads exit recv loop
218        self.sender.take();
219        // Dropping the sender (SyncSender) happens above; but to ensure we close the channel we
220        // explicitly drop any remaining clones by letting sender go out of scope.
221
222        // Join worker threads
223        while let Some(handle) = self.workers.pop() {
224            let _ = handle.join();
225        }
226    }
227}