jieba_rs/
hmm.rs

1use std::cmp::Ordering;
2use std::io::BufRead;
3
4use regex::Regex;
5
6use crate::FxHashMap;
7use crate::SplitByCharacterClass;
8use crate::errors::Error;
9use jieba_macros::generate_hmm_data;
10
11thread_local! {
12    static RE_SKIP: Regex = Regex::new(r"([a-zA-Z0-9]+(?:.\d+)?%?)").unwrap();
13}
14
/// Returns `true` when `c` lies in the CJK range used by the HMM segmenter,
/// i.e. `[\u{4E00}-\u{9FD5}]` (both ends inclusive).
#[inline]
fn is_hmm_han(c: char) -> bool {
    ('\u{4E00}'..='\u{9FD5}').contains(&c)
}
20
21/// Regex-based splitter for RE_SKIP in HMM.
22struct HmmSkipSplitter<'r, 't> {
23    finder: regex::Matches<'r, 't>,
24    text: &'t str,
25    last: usize,
26    matched: Option<regex::Match<'t>>,
27}
28
29impl<'r, 't> HmmSkipSplitter<'r, 't> {
30    fn new(re: &'r Regex, text: &'t str) -> Self {
31        HmmSkipSplitter {
32            finder: re.find_iter(text),
33            text,
34            last: 0,
35            matched: None,
36        }
37    }
38}
39
40impl<'t> Iterator for HmmSkipSplitter<'_, 't> {
41    type Item = &'t str;
42
43    fn next(&mut self) -> Option<&'t str> {
44        if let Some(m) = self.matched.take() {
45            return Some(m.as_str());
46        }
47        match self.finder.next() {
48            None => {
49                if self.last >= self.text.len() {
50                    None
51                } else {
52                    let s = &self.text[self.last..];
53                    self.last = self.text.len();
54                    Some(s)
55                }
56            }
57            Some(m) => {
58                if self.last == m.start() {
59                    self.last = m.end();
60                    Some(m.as_str())
61                } else {
62                    let unmatched = &self.text[self.last..m.start()];
63                    self.last = m.end();
64                    self.matched = Some(m);
65                    Some(unmatched)
66                }
67            }
68        }
69    }
70}
71
/// Number of HMM states, i.e. the number of `State` variants.
pub const NUM_STATES: usize = 4;

/// Label assigned by the HMM to each Unicode Scalar Value of the input.
///
/// The labels describe the proposed segmentation. Every segment follows one
/// of these two patterns:
///
///   Begin, [Middle...], End
///   Single
///
/// Each variant also carries an index value in 0-3 that can be used to index
/// arrays holding per-state data.
///
/// WARNING: The comments in the hmm.model data file format imply that the
/// index of each state can be reassigned at the top of the file, but
/// `jieba-macros` currently ignores that mapping. Do not reassign these
/// indices without verifying how the change interacts with `jieba-macros`.
/// These indices must also match the order of ALLOWED_PREV_STATUS.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum State {
    Begin = 0,
    End = 1,
    Middle = 2,
    Single = 3,
}
97
// Mapping representing the allowed transitions into each state: the row at
// index `state as usize` lists the two states that may immediately precede
// `state`.
//
// WARNING: Row ordering must match the indices in State.
static ALLOWED_PREV_STATUS: [[State; 2]; NUM_STATES] = [
    // Can precede State::Begin
    [State::End, State::Single],
    // Can precede State::End
    [State::Begin, State::Middle],
    // Can precede State::Middle
    [State::Middle, State::Begin],
    // Can precede State::Single
    [State::Single, State::End],
];
111
// Expands to the embedded model data. `BuiltinHmm` reads `INITIAL_PROBS`,
// `TRANS_PROBS` and `EMIT_PROBS`, which are presumably defined by this macro
// (confirm against `jieba-macros`).
generate_hmm_data!();

/// Fallback emission score returned by `emit_prob` when a character has no
/// entry in the emission table of the requested state.
const MIN_FLOAT: f64 = -3.14e100;
115
/// Source of the model parameters consumed by `viterbi`.
///
/// Every `state`/`from`/`to` argument is a `State` discriminant cast to
/// `usize`. `viterbi` adds the returned values together, so they are
/// presumably log-probabilities (confirm against the model data).
pub(crate) trait HmmParams {
    /// Score for a sequence starting in `state`.
    fn initial_prob(&self, state: usize) -> f64;
    /// Score for moving from state `from` to state `to`.
    fn trans_prob(&self, from: usize, to: usize) -> f64;
    /// Score for `state` emitting the character `ch`.
    fn emit_prob(&self, state: usize, ch: char) -> f64;
}
121
122/// The compile-time embedded HMM parameters.
123pub(crate) struct BuiltinHmm;
124
125impl HmmParams for BuiltinHmm {
126    #[inline]
127    fn initial_prob(&self, state: usize) -> f64 {
128        INITIAL_PROBS[state]
129    }
130
131    #[inline]
132    fn trans_prob(&self, from: usize, to: usize) -> f64 {
133        TRANS_PROBS[from][to]
134    }
135
136    #[inline]
137    fn emit_prob(&self, state: usize, ch: char) -> f64 {
138        EMIT_PROBS[state].get(&ch).cloned().unwrap_or(MIN_FLOAT)
139    }
140}
141
/// Reusable scratch buffers for `viterbi`, kept between calls so repeated
/// segmentation does not have to reallocate them.
#[derive(Default)]
pub(crate) struct HmmContext {
    /// Flat score table: entry `t * R + state` is the best score of a path
    /// that labels character `t` with `state` (`R` = number of states).
    v: Vec<f64>,
    /// Back-pointers with the same layout as `v`; the entries for the first
    /// character are `None`, which stops the backtrace.
    prev: Vec<Option<State>>,
    /// Decoded label for each character of the most recent sentence.
    best_path: Vec<State>,
    /// `(byte offset, char)` pairs of the most recent sentence.
    chars: Vec<(usize, char)>,
}
149
150#[allow(non_snake_case, clippy::needless_range_loop)]
151fn viterbi(sentence: &str, params: &impl HmmParams, hmm_context: &mut HmmContext) {
152    let states = [State::Begin, State::Middle, State::End, State::Single];
153    #[allow(non_snake_case)]
154    let R = states.len();
155
156    // Collect char byte offsets into reusable scratch space, derive C from the length.
157    hmm_context.chars.clear();
158    hmm_context.chars.extend(sentence.char_indices());
159    let chars = &hmm_context.chars;
160    let C = chars.len();
161    assert!(C > 1);
162
163    if hmm_context.prev.len() < R * C {
164        hmm_context.prev.resize(R * C, None);
165    }
166    hmm_context.prev[..R].fill(None);
167
168    if hmm_context.v.len() < R * C {
169        hmm_context.v.resize(R * C, 0.0);
170    }
171
172    if hmm_context.best_path.len() < C {
173        hmm_context.best_path.resize(C, State::Begin);
174    }
175
176    let first_char = chars[0].1;
177    for y in &states {
178        let prob = params.initial_prob(*y as usize) + params.emit_prob(*y as usize, first_char);
179        hmm_context.v[*y as usize] = prob;
180    }
181
182    for t in 1..C {
183        let ch = chars[t].1;
184        for y in &states {
185            let em_prob = params.emit_prob(*y as usize, ch);
186            let (prob, state) = ALLOWED_PREV_STATUS[*y as usize]
187                .iter()
188                .map(|y0| {
189                    (
190                        hmm_context.v[(t - 1) * R + (*y0 as usize)]
191                            + params.trans_prob(*y0 as usize, *y as usize)
192                            + em_prob,
193                        *y0,
194                    )
195                })
196                .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
197                .unwrap();
198            let idx = (t * R) + (*y as usize);
199            hmm_context.v[idx] = prob;
200            hmm_context.prev[idx] = Some(state);
201        }
202    }
203
204    let (_prob, state) = [State::End, State::Single]
205        .iter()
206        .map(|y| (hmm_context.v[(C - 1) * R + (*y as usize)], y))
207        .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
208        .unwrap();
209
210    let mut t = C - 1;
211    let mut curr = *state;
212
213    hmm_context.best_path[t] = *state;
214    while let Some(p) = hmm_context.prev[t * R + (curr as usize)] {
215        assert!(t > 0);
216        hmm_context.best_path[t - 1] = p;
217        curr = p;
218        t -= 1;
219    }
220    hmm_context.best_path.truncate(C);
221}
222
223#[allow(non_snake_case)]
224fn cut_internal<'a>(
225    sentence: &'a str,
226    words: &mut Vec<&'a str>,
227    params: &impl HmmParams,
228    hmm_context: &mut HmmContext,
229) {
230    let str_len = sentence.len();
231    viterbi(sentence, params, hmm_context);
232    let mut begin = 0;
233    let mut next_byte_offset = 0;
234
235    for (i, &(curr_byte_offset, _)) in hmm_context.chars.iter().enumerate() {
236        let state = hmm_context.best_path[i];
237        match state {
238            State::Begin => begin = curr_byte_offset,
239            State::End => {
240                let byte_start = begin;
241                let byte_end = hmm_context.chars.get(i + 1).map_or(str_len, |&(offset, _)| offset);
242                words.push(&sentence[byte_start..byte_end]);
243                next_byte_offset = byte_end;
244            }
245            State::Single => {
246                let byte_start = curr_byte_offset;
247                let byte_end = hmm_context.chars.get(i + 1).map_or(str_len, |&(offset, _)| offset);
248                words.push(&sentence[byte_start..byte_end]);
249                next_byte_offset = byte_end;
250            }
251            State::Middle => { /* do nothing */ }
252        }
253    }
254
255    if next_byte_offset < str_len {
256        let byte_start = next_byte_offset;
257        words.push(&sentence[byte_start..]);
258    }
259}
260
261#[allow(non_snake_case)]
262pub(crate) fn cut_with_allocated_memory<'a>(
263    sentence: &'a str,
264    words: &mut Vec<&'a str>,
265    params: &impl HmmParams,
266    hmm_context: &mut HmmContext,
267) {
268    RE_SKIP.with(|re_skip| {
269        let splitter = SplitByCharacterClass::new(sentence, is_hmm_han);
270        for state in splitter {
271            let block = state.as_str();
272            if block.is_empty() {
273                continue;
274            }
275            if state.is_matched() {
276                if block.chars().nth(1).is_some() {
277                    cut_internal(block, words, params, hmm_context);
278                } else {
279                    words.push(block);
280                }
281            } else {
282                let skip_splitter = HmmSkipSplitter::new(re_skip, block);
283                for x in skip_splitter {
284                    if x.is_empty() {
285                        continue;
286                    }
287                    words.push(x);
288                }
289            }
290        }
291    })
292}
293
294/// A runtime-loadable HMM model for custom segmentation.
295///
296/// This allows loading HMM parameters trained with `scripts/train_hmm.py`
297/// instead of using the compile-time embedded model.
298#[derive(Debug, Clone)]
299pub struct HmmModel {
300    initial_probs: [f64; NUM_STATES],
301    trans_probs: [[f64; NUM_STATES]; NUM_STATES],
302    emit_probs: [FxHashMap<Box<str>, f64>; NUM_STATES],
303}
304
305impl HmmParams for HmmModel {
306    #[inline]
307    fn initial_prob(&self, state: usize) -> f64 {
308        self.initial_probs[state]
309    }
310
311    #[inline]
312    fn trans_prob(&self, from: usize, to: usize) -> f64 {
313        self.trans_probs[from][to]
314    }
315
316    #[inline]
317    fn emit_prob(&self, state: usize, ch: char) -> f64 {
318        let mut buf = [0u8; 4];
319        let s = ch.encode_utf8(&mut buf);
320        self.emit_probs[state].get(s).copied().unwrap_or(MIN_FLOAT)
321    }
322}
323
324impl HmmModel {
325    /// Load an HMM model from a reader in the `hmm.model` file format.
326    ///
327    /// The format is compatible with the output of `scripts/train_hmm.py`.
328    pub fn load<R: BufRead>(reader: &mut R) -> Result<Self, Error> {
329        let mut data_lines = Vec::new();
330        let mut buf = String::new();
331        while reader.read_line(&mut buf)? > 0 {
332            {
333                let line = buf.trim();
334                if !line.is_empty() && !line.starts_with('#') {
335                    data_lines.push(line.to_string());
336                }
337            }
338            buf.clear();
339        }
340
341        // Line 0: start probs (4 values)
342        if data_lines.len() < 9 {
343            return Err(Error::InvalidHmmModel(format!(
344                "expected at least 9 data lines, got {}",
345                data_lines.len()
346            )));
347        }
348
349        let initial_probs = Self::parse_prob_line(&data_lines[0], "initial")?;
350
351        // Lines 1-4: transition matrix
352        let mut trans_probs = [[0.0f64; NUM_STATES]; NUM_STATES];
353        for i in 0..NUM_STATES {
354            let vals = Self::parse_prob_line(&data_lines[1 + i], "transition")?;
355            trans_probs[i] = vals;
356        }
357
358        // Lines 5-8: emission probs (comma-separated char:prob pairs)
359        let mut emit_probs: [FxHashMap<Box<str>, f64>; NUM_STATES] = [
360            FxHashMap::default(),
361            FxHashMap::default(),
362            FxHashMap::default(),
363            FxHashMap::default(),
364        ];
365        for i in 0..NUM_STATES {
366            for pair in data_lines[5 + i].split(',') {
367                let pair = pair.trim();
368                if pair.is_empty() {
369                    continue;
370                }
371                let colon_pos = pair
372                    .rfind(':')
373                    .ok_or_else(|| Error::InvalidHmmModel(format!("invalid emit pair (missing ':'): `{pair}`")))?;
374                let ch = &pair[..colon_pos];
375                let prob: f64 = pair[colon_pos + 1..]
376                    .parse()
377                    .map_err(|e| Error::InvalidHmmModel(format!("invalid emit prob: {e}")))?;
378                emit_probs[i].insert(ch.into(), prob);
379            }
380        }
381
382        Ok(HmmModel {
383            initial_probs,
384            trans_probs,
385            emit_probs,
386        })
387    }
388
389    fn parse_prob_line(line: &str, context: &str) -> Result<[f64; NUM_STATES], Error> {
390        let vals: Vec<f64> = line
391            .split_whitespace()
392            .map(|v| {
393                v.parse::<f64>()
394                    .map_err(|e| Error::InvalidHmmModel(format!("invalid {context} prob `{v}`: {e}")))
395            })
396            .collect::<Result<_, _>>()?;
397        if vals.len() != NUM_STATES {
398            return Err(Error::InvalidHmmModel(format!(
399                "expected {NUM_STATES} {context} values, got {}",
400                vals.len()
401            )));
402        }
403        Ok([vals[0], vals[1], vals[2], vals[3]])
404    }
405}
406
/// Returns the compile-time embedded HMM parameters.
pub(crate) fn builtin_hmm() -> BuiltinHmm {
    BuiltinHmm
}
410
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::{BuiltinHmm, HmmContext, cut_with_allocated_memory, viterbi};

    /// Segments `sentence` with the built-in model and a fresh scratch context.
    fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
        let mut scratch = HmmContext::default();
        cut_with_allocated_memory(sentence, words, &BuiltinHmm, &mut scratch)
    }

    /// The decoded label sequence for a known sentence stays stable.
    #[test]
    fn test_viterbi() {
        let mut scratch = HmmContext::default();
        viterbi("小明硕士毕业于中国科学院计算所", &BuiltinHmm, &mut scratch);

        let labels = format!("{:?}", scratch.best_path);
        expect![[
            r#"[Begin, End, Begin, End, Begin, Middle, End, Begin, End, Begin, Middle, End, Begin, End, Single]"#
        ]]
        .assert_eq(&labels);
    }

    /// The same sentence is cut into the expected words.
    #[test]
    fn test_hmm_cut() {
        let sentence = "小明硕士毕业于中国科学院计算所";
        let mut words = Vec::with_capacity(sentence.chars().count() / 2);
        cut(sentence, &mut words);

        let rendered = format!("{:?}", words);
        expect![[r#"["小明", "硕士", "毕业于", "中国", "科学院", "计算", "所"]"#]].assert_eq(&rendered);
    }
}