msg_tool\scripts\escude/
lzw.rs

1use crate::ext::io::*;
2use crate::utils::bit_stream::*;
3use anyhow::Result;
4use std::io::Write;
5
6pub struct LZWDecoder<'a> {
7    m_input: MsbBitStream<MemReaderRef<'a>>,
8    m_output_size: u32,
9}
10
11impl<'a> LZWDecoder<'a> {
12    pub fn new(input: &'a [u8]) -> Result<Self> {
13        let mut input_reader = MemReaderRef::new(input);
14        let size = input_reader.peek_u32_be_at(0x4)?;
15        let m_input = MsbBitStream::new(MemReaderRef::new(&input[0x8..]));
16        Ok(LZWDecoder {
17            m_input,
18            m_output_size: size,
19        })
20    }
21
22    pub fn unpack(&mut self) -> Result<Vec<u8>> {
23        let size = self.m_output_size as usize;
24        let mut output = Vec::with_capacity(size);
25        output.resize(size, 0);
26        let mut dict = Vec::with_capacity(0x8900);
27        dict.resize(0x8900, 0u32);
28        let mut token_width = 9;
29        let mut dict_pos = 0;
30        let mut dst = 0;
31        while dst < size {
32            let mut token = self.m_input.get_bits(token_width)?;
33            if token == 0x100 {
34                // End of stream
35                break;
36            } else if token == 0x101 {
37                token_width += 1;
38                if token_width > 24 {
39                    return Err(anyhow::anyhow!("Token width exceeded maximum of 12 bits"));
40                }
41            } else if token == 0x102 {
42                token_width = 9;
43                dict_pos = 0;
44            } else {
45                if dict_pos > 0x8900 {
46                    return Err(anyhow::anyhow!(
47                        "Dictionary position exceeded maximum of 0x8900"
48                    ));
49                }
50                dict[dict_pos] = dst as u32;
51                dict_pos += 1;
52                if token < 0x100 {
53                    output[dst] = token as u8;
54                    dst += 1;
55                } else {
56                    token -= 0x103;
57                    if token >= dict_pos as u32 {
58                        return Err(anyhow::anyhow!("Token out of bounds: {}", token));
59                    }
60                    let src = dict[token as usize];
61                    let count =
62                        (self.m_output_size - dst as u32).min(dict[token as usize + 1] - src + 1);
63                    for i in 0..count {
64                        output[dst + i as usize] = output[src as usize + i as usize];
65                    }
66                    dst += count as usize;
67                }
68            }
69        }
70        Ok(output)
71    }
72}
73
74pub struct LZWEncoder {
75    buf: MemWriter,
76}
77
78impl LZWEncoder {
79    pub fn new() -> Self {
80        LZWEncoder {
81            buf: MemWriter::new(),
82        }
83    }
84
85    pub fn encode(mut self, input: &[u8], fake: bool) -> Result<Vec<u8>> {
86        self.buf.write_all(b"acp\0")?;
87        self.buf.write_u32_be(input.len() as u32)?;
88        let mut writer = MsbBitWriter::new(&mut self.buf);
89        if fake {
90            for i in 0..input.len() {
91                if i > 0 && i % 0x4000 == 0 {
92                    writer.put_bits(0x102, 9)?;
93                }
94                writer.put_bits(input[i] as u32, 9)?;
95            }
96            writer.put_bits(0x100, 9)?; // End of stream
97            writer.flush()?;
98        } else {
99            let mut dict = std::collections::HashMap::new();
100            for i in 0..256 {
101                dict.insert(vec![i as u8], i);
102            }
103            let mut next_code = 0x103u32;
104            let mut token_width = 9;
105
106            let mut i = 0;
107            while i < input.len() {
108                let mut current = vec![input[i]];
109                i += 1;
110
111                while i < input.len()
112                    && dict.contains_key(&{
113                        let mut temp = current.clone();
114                        temp.push(input[i]);
115                        temp
116                    })
117                {
118                    current.push(input[i]);
119                    i += 1;
120                }
121
122                let code = dict[&current];
123                writer.put_bits(code, token_width)?;
124
125                if i < input.len() {
126                    let mut new_entry = current.clone();
127                    new_entry.push(input[i]);
128                    dict.insert(new_entry, next_code);
129                    next_code += 1;
130
131                    if next_code >= (1 << token_width) && token_width < 24 {
132                        writer.put_bits(0x101, token_width)?; // Increase token width
133                        token_width += 1;
134                    }
135
136                    if dict.len() >= 0x8900 {
137                        writer.put_bits(0x102, token_width)?; // Clear dictionary
138                        dict.clear();
139                        for j in 0..256 {
140                            dict.insert(vec![j as u8], j);
141                        }
142                        next_code = 0x103;
143                        token_width = 9;
144                    }
145                }
146            }
147            writer.put_bits(0x100, token_width)?; // End of stream
148            writer.flush()?;
149        }
150
151        Ok(self.buf.into_inner())
152    }
153}