1use std::cmp::Ordering;
2use std::io::BufRead;
3
4use regex::Regex;
5
6use crate::FxHashMap;
7use crate::SplitByCharacterClass;
8use crate::errors::Error;
9use jieba_macros::generate_hmm_data;
10
11thread_local! {
12 static RE_SKIP: Regex = Regex::new(r"([a-zA-Z0-9]+(?:.\d+)?%?)").unwrap();
13}
14
/// Returns `true` when `c` is one of the code points U+4E00 through U+9FD5
/// (both ends inclusive), the range of Han characters this module hands to
/// the HMM decoder.
#[inline]
fn is_hmm_han(c: char) -> bool {
    const FIRST_HAN: char = '\u{4E00}';
    const LAST_HAN: char = '\u{9FD5}';
    (FIRST_HAN..=LAST_HAN).contains(&c)
}
20
21struct HmmSkipSplitter<'r, 't> {
23 finder: regex::Matches<'r, 't>,
24 text: &'t str,
25 last: usize,
26 matched: Option<regex::Match<'t>>,
27}
28
29impl<'r, 't> HmmSkipSplitter<'r, 't> {
30 fn new(re: &'r Regex, text: &'t str) -> Self {
31 HmmSkipSplitter {
32 finder: re.find_iter(text),
33 text,
34 last: 0,
35 matched: None,
36 }
37 }
38}
39
40impl<'t> Iterator for HmmSkipSplitter<'_, 't> {
41 type Item = &'t str;
42
43 fn next(&mut self) -> Option<&'t str> {
44 if let Some(m) = self.matched.take() {
45 return Some(m.as_str());
46 }
47 match self.finder.next() {
48 None => {
49 if self.last >= self.text.len() {
50 None
51 } else {
52 let s = &self.text[self.last..];
53 self.last = self.text.len();
54 Some(s)
55 }
56 }
57 Some(m) => {
58 if self.last == m.start() {
59 self.last = m.end();
60 Some(m.as_str())
61 } else {
62 let unmatched = &self.text[self.last..m.start()];
63 self.last = m.end();
64 self.matched = Some(m);
65 Some(unmatched)
66 }
67 }
68 }
69 }
70}
71
/// Number of HMM tagging states (the variants of [`State`]); used as the
/// dimension of the probability tables in this file.
pub const NUM_STATES: usize = 4;
73
/// Position tag assigned to each character by the HMM decoder.
///
/// The explicit discriminants matter: the code casts a `State` to `usize` to
/// index the probability tables and `ALLOWED_PREV_STATUS`.
#[derive(Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Clone, Copy)]
pub enum State {
    /// Character that opens a multi-character word.
    Begin = 0,
    /// Character that closes a multi-character word.
    End = 1,
    /// Character inside a word, between its `Begin` and `End`.
    Middle = 2,
    /// Character that forms a word on its own.
    Single = 3,
}
97
/// For each state (indexed by `State as usize`), the two states that
/// `viterbi` considers as its possible predecessor.
static ALLOWED_PREV_STATUS: [[State; 2]; NUM_STATES] = [
    // Begin: a new word can only start after a word has finished.
    [State::End, State::Single],
    // End: must follow the start or the interior of a word.
    [State::Begin, State::Middle],
    // Middle: continues a word that has started and not yet ended.
    [State::Middle, State::Begin],
    // Single: like Begin, can only follow a finished word.
    [State::Single, State::End],
];
111
// NOTE(review): this macro presumably expands to the `INITIAL_PROBS`,
// `TRANS_PROBS` and `EMIT_PROBS` tables read by `BuiltinHmm` below — they are
// not defined in the visible part of this file. Confirm against `jieba_macros`.
generate_hmm_data!();
113
/// Score returned by the `emit_prob` implementations in this file when a
/// character has no entry in the emission table. Scores are summed along a
/// path in `viterbi`, so this very negative value effectively rules the
/// state out for that character.
const MIN_FLOAT: f64 = -3.14e100;
115
/// Source of the scores consumed by `viterbi`.
///
/// All `state`, `from` and `to` arguments are `State` values cast to `usize`.
/// `viterbi` adds the returned scores together along a path.
pub(crate) trait HmmParams {
    /// Score of a sequence starting in `state`.
    fn initial_prob(&self, state: usize) -> f64;
    /// Score of moving from state `from` to state `to`.
    fn trans_prob(&self, from: usize, to: usize) -> f64;
    /// Score of observing character `ch` while in `state`.
    fn emit_prob(&self, state: usize, ch: char) -> f64;
}
121
122pub(crate) struct BuiltinHmm;
124
125impl HmmParams for BuiltinHmm {
126 #[inline]
127 fn initial_prob(&self, state: usize) -> f64 {
128 INITIAL_PROBS[state]
129 }
130
131 #[inline]
132 fn trans_prob(&self, from: usize, to: usize) -> f64 {
133 TRANS_PROBS[from][to]
134 }
135
136 #[inline]
137 fn emit_prob(&self, state: usize, ch: char) -> f64 {
138 EMIT_PROBS[state].get(&ch).cloned().unwrap_or(MIN_FLOAT)
139 }
140}
141
/// Scratch buffers reused across `viterbi` calls so that decoding does not
/// have to allocate fresh vectors for every block of text.
#[derive(Default)]
pub(crate) struct HmmContext {
    /// Best path score per (position, state), flattened as
    /// `position * number_of_states + state`.
    v: Vec<f64>,
    /// Back-pointers with the same layout as `v`: the predecessor state on
    /// the best path; `None` for the first position.
    prev: Vec<Option<State>>,
    /// Decoded state for every character of the last sentence.
    best_path: Vec<State>,
    /// `(byte offset, char)` pairs of the last sentence.
    chars: Vec<(usize, char)>,
}
149
150#[allow(non_snake_case, clippy::needless_range_loop)]
151fn viterbi(sentence: &str, params: &impl HmmParams, hmm_context: &mut HmmContext) {
152 let states = [State::Begin, State::Middle, State::End, State::Single];
153 #[allow(non_snake_case)]
154 let R = states.len();
155
156 hmm_context.chars.clear();
158 hmm_context.chars.extend(sentence.char_indices());
159 let chars = &hmm_context.chars;
160 let C = chars.len();
161 assert!(C > 1);
162
163 if hmm_context.prev.len() < R * C {
164 hmm_context.prev.resize(R * C, None);
165 }
166 hmm_context.prev[..R].fill(None);
167
168 if hmm_context.v.len() < R * C {
169 hmm_context.v.resize(R * C, 0.0);
170 }
171
172 if hmm_context.best_path.len() < C {
173 hmm_context.best_path.resize(C, State::Begin);
174 }
175
176 let first_char = chars[0].1;
177 for y in &states {
178 let prob = params.initial_prob(*y as usize) + params.emit_prob(*y as usize, first_char);
179 hmm_context.v[*y as usize] = prob;
180 }
181
182 for t in 1..C {
183 let ch = chars[t].1;
184 for y in &states {
185 let em_prob = params.emit_prob(*y as usize, ch);
186 let (prob, state) = ALLOWED_PREV_STATUS[*y as usize]
187 .iter()
188 .map(|y0| {
189 (
190 hmm_context.v[(t - 1) * R + (*y0 as usize)]
191 + params.trans_prob(*y0 as usize, *y as usize)
192 + em_prob,
193 *y0,
194 )
195 })
196 .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
197 .unwrap();
198 let idx = (t * R) + (*y as usize);
199 hmm_context.v[idx] = prob;
200 hmm_context.prev[idx] = Some(state);
201 }
202 }
203
204 let (_prob, state) = [State::End, State::Single]
205 .iter()
206 .map(|y| (hmm_context.v[(C - 1) * R + (*y as usize)], y))
207 .max_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal))
208 .unwrap();
209
210 let mut t = C - 1;
211 let mut curr = *state;
212
213 hmm_context.best_path[t] = *state;
214 while let Some(p) = hmm_context.prev[t * R + (curr as usize)] {
215 assert!(t > 0);
216 hmm_context.best_path[t - 1] = p;
217 curr = p;
218 t -= 1;
219 }
220 hmm_context.best_path.truncate(C);
221}
222
223#[allow(non_snake_case)]
224fn cut_internal<'a>(
225 sentence: &'a str,
226 words: &mut Vec<&'a str>,
227 params: &impl HmmParams,
228 hmm_context: &mut HmmContext,
229) {
230 let str_len = sentence.len();
231 viterbi(sentence, params, hmm_context);
232 let mut begin = 0;
233 let mut next_byte_offset = 0;
234
235 for (i, &(curr_byte_offset, _)) in hmm_context.chars.iter().enumerate() {
236 let state = hmm_context.best_path[i];
237 match state {
238 State::Begin => begin = curr_byte_offset,
239 State::End => {
240 let byte_start = begin;
241 let byte_end = hmm_context.chars.get(i + 1).map_or(str_len, |&(offset, _)| offset);
242 words.push(&sentence[byte_start..byte_end]);
243 next_byte_offset = byte_end;
244 }
245 State::Single => {
246 let byte_start = curr_byte_offset;
247 let byte_end = hmm_context.chars.get(i + 1).map_or(str_len, |&(offset, _)| offset);
248 words.push(&sentence[byte_start..byte_end]);
249 next_byte_offset = byte_end;
250 }
251 State::Middle => { }
252 }
253 }
254
255 if next_byte_offset < str_len {
256 let byte_start = next_byte_offset;
257 words.push(&sentence[byte_start..]);
258 }
259}
260
261#[allow(non_snake_case)]
262pub(crate) fn cut_with_allocated_memory<'a>(
263 sentence: &'a str,
264 words: &mut Vec<&'a str>,
265 params: &impl HmmParams,
266 hmm_context: &mut HmmContext,
267) {
268 RE_SKIP.with(|re_skip| {
269 let splitter = SplitByCharacterClass::new(sentence, is_hmm_han);
270 for state in splitter {
271 let block = state.as_str();
272 if block.is_empty() {
273 continue;
274 }
275 if state.is_matched() {
276 if block.chars().nth(1).is_some() {
277 cut_internal(block, words, params, hmm_context);
278 } else {
279 words.push(block);
280 }
281 } else {
282 let skip_splitter = HmmSkipSplitter::new(re_skip, block);
283 for x in skip_splitter {
284 if x.is_empty() {
285 continue;
286 }
287 words.push(x);
288 }
289 }
290 }
291 })
292}
293
294#[derive(Debug, Clone)]
299pub struct HmmModel {
300 initial_probs: [f64; NUM_STATES],
301 trans_probs: [[f64; NUM_STATES]; NUM_STATES],
302 emit_probs: [FxHashMap<Box<str>, f64>; NUM_STATES],
303}
304
305impl HmmParams for HmmModel {
306 #[inline]
307 fn initial_prob(&self, state: usize) -> f64 {
308 self.initial_probs[state]
309 }
310
311 #[inline]
312 fn trans_prob(&self, from: usize, to: usize) -> f64 {
313 self.trans_probs[from][to]
314 }
315
316 #[inline]
317 fn emit_prob(&self, state: usize, ch: char) -> f64 {
318 let mut buf = [0u8; 4];
319 let s = ch.encode_utf8(&mut buf);
320 self.emit_probs[state].get(s).copied().unwrap_or(MIN_FLOAT)
321 }
322}
323
324impl HmmModel {
325 pub fn load<R: BufRead>(reader: &mut R) -> Result<Self, Error> {
329 let mut data_lines = Vec::new();
330 let mut buf = String::new();
331 while reader.read_line(&mut buf)? > 0 {
332 {
333 let line = buf.trim();
334 if !line.is_empty() && !line.starts_with('#') {
335 data_lines.push(line.to_string());
336 }
337 }
338 buf.clear();
339 }
340
341 if data_lines.len() < 9 {
343 return Err(Error::InvalidHmmModel(format!(
344 "expected at least 9 data lines, got {}",
345 data_lines.len()
346 )));
347 }
348
349 let initial_probs = Self::parse_prob_line(&data_lines[0], "initial")?;
350
351 let mut trans_probs = [[0.0f64; NUM_STATES]; NUM_STATES];
353 for i in 0..NUM_STATES {
354 let vals = Self::parse_prob_line(&data_lines[1 + i], "transition")?;
355 trans_probs[i] = vals;
356 }
357
358 let mut emit_probs: [FxHashMap<Box<str>, f64>; NUM_STATES] = [
360 FxHashMap::default(),
361 FxHashMap::default(),
362 FxHashMap::default(),
363 FxHashMap::default(),
364 ];
365 for i in 0..NUM_STATES {
366 for pair in data_lines[5 + i].split(',') {
367 let pair = pair.trim();
368 if pair.is_empty() {
369 continue;
370 }
371 let colon_pos = pair
372 .rfind(':')
373 .ok_or_else(|| Error::InvalidHmmModel(format!("invalid emit pair (missing ':'): `{pair}`")))?;
374 let ch = &pair[..colon_pos];
375 let prob: f64 = pair[colon_pos + 1..]
376 .parse()
377 .map_err(|e| Error::InvalidHmmModel(format!("invalid emit prob: {e}")))?;
378 emit_probs[i].insert(ch.into(), prob);
379 }
380 }
381
382 Ok(HmmModel {
383 initial_probs,
384 trans_probs,
385 emit_probs,
386 })
387 }
388
389 fn parse_prob_line(line: &str, context: &str) -> Result<[f64; NUM_STATES], Error> {
390 let vals: Vec<f64> = line
391 .split_whitespace()
392 .map(|v| {
393 v.parse::<f64>()
394 .map_err(|e| Error::InvalidHmmModel(format!("invalid {context} prob `{v}`: {e}")))
395 })
396 .collect::<Result<_, _>>()?;
397 if vals.len() != NUM_STATES {
398 return Err(Error::InvalidHmmModel(format!(
399 "expected {NUM_STATES} {context} values, got {}",
400 vals.len()
401 )));
402 }
403 Ok([vals[0], vals[1], vals[2], vals[3]])
404 }
405}
406
/// Returns a handle to the HMM parameters compiled into the crate.
pub(crate) fn builtin_hmm() -> BuiltinHmm {
    BuiltinHmm
}
410
#[cfg(test)]
mod tests {
    use expect_test::expect;

    use super::{BuiltinHmm, HmmContext, cut_with_allocated_memory, viterbi};

    /// Segments `sentence` with the built-in model and a fresh scratch context.
    fn cut<'a>(sentence: &'a str, words: &mut Vec<&'a str>) {
        let mut scratch = HmmContext::default();
        cut_with_allocated_memory(sentence, words, &BuiltinHmm, &mut scratch);
    }

    /// The decoder tags every character of a known sentence as expected.
    #[test]
    #[allow(non_snake_case)]
    fn test_viterbi() {
        let mut hmm_context = HmmContext::default();
        viterbi("小明硕士毕业于中国科学院计算所", &BuiltinHmm, &mut hmm_context);

        let decoded = format!("{:?}", hmm_context.best_path);
        expect![[
            r#"[Begin, End, Begin, End, Begin, Middle, End, Begin, End, Begin, Middle, End, Begin, End, Single]"#
        ]]
        .assert_eq(&decoded);
    }

    /// The same sentence is cut into the expected words end to end.
    #[test]
    fn test_hmm_cut() {
        let sentence = "小明硕士毕业于中国科学院计算所";
        let mut words = Vec::with_capacity(sentence.chars().count() / 2);
        cut(sentence, &mut words);

        let rendered = format!("{words:?}");
        expect![[r#"["小明", "硕士", "毕业于", "中国", "科学院", "计算", "所"]"#]].assert_eq(&rendered);
    }
}
442}