msg_tool\utils/
jxl.rs

1//! JPEG XL image support
2use super::img::*;
3use super::num_range::*;
4use super::threadpool::*;
5use crate::types::*;
6use anyhow::Result;
7use jpegxl_sys::common::types::*;
8use jpegxl_sys::decode::*;
9use jpegxl_sys::encoder::encode::*;
10use jpegxl_sys::metadata::codestream_header::*;
11use jpegxl_sys::threads::parallel_runner::*;
12use std::ffi::c_void;
13use std::io::Read;
14
15struct JxlDecoderHandle {
16    handle: *mut JxlDecoder,
17}
18
19impl Drop for JxlDecoderHandle {
20    fn drop(&mut self) {
21        unsafe {
22            JxlDecoderDestroy(self.handle);
23        }
24    }
25}
26
27struct JxlEncoderHandle {
28    handle: *mut JxlEncoder,
29}
30
31impl Drop for JxlEncoderHandle {
32    fn drop(&mut self) {
33        unsafe {
34            JxlEncoderDestroy(self.handle);
35        }
36    }
37}
38
39struct ThreadPoolRunner {
40    thread_pool: ThreadPool<()>,
41}
42
43impl ThreadPoolRunner {
44    fn new(workers: usize) -> Result<Self> {
45        let thread_pool = ThreadPool::new(workers, Some("jxl-thread-runner-"), true)?;
46        Ok(Self { thread_pool })
47    }
48}
49
/// Copyable wrapper around libjxl's opaque per-call pointer, so it can be moved
/// into jobs submitted to the thread pool in `thread_pool_runner`.
#[derive(Clone, Copy)]
struct JpegxlPointer(*mut c_void);

// SAFETY: the wrapped pointer is never dereferenced in Rust; it is only passed
// back to libjxl's `JxlParallelRunFunction`, which the parallel-runner API is
// designed to call from worker threads.
unsafe impl Send for JpegxlPointer {}
54
55unsafe extern "C-unwind" fn thread_pool_runner(
56    runner_opaque: *mut c_void,
57    jpegxl_opaque: *mut c_void,
58    init: JxlParallelRunInit,
59    func: JxlParallelRunFunction,
60    start_range: u32,
61    end_range: u32,
62) -> JxlParallelRetCode {
63    if runner_opaque.is_null() || jpegxl_opaque.is_null() {
64        return JXL_PARALLEL_RET_RUNNER_ERROR;
65    }
66    let runner = unsafe { &*(runner_opaque as *const ThreadPoolRunner) };
67    let initre = unsafe { init(jpegxl_opaque, runner.thread_pool.size()) };
68    if initre != JXL_PARALLEL_RET_SUCCESS {
69        return initre;
70    }
71    let jpegxl = JpegxlPointer(jpegxl_opaque);
72    for i in start_range..end_range {
73        let jpegxl = jpegxl;
74        let func = func;
75        match runner.thread_pool.execute(
76            move |thread_id| unsafe {
77                let jpegxl = jpegxl;
78                func(jpegxl.0, i, thread_id)
79            },
80            true,
81        ) {
82            Ok(_) => {}
83            Err(_) => return JXL_PARALLEL_RET_RUNNER_ERROR,
84        }
85    }
86    runner.thread_pool.join();
87    JXL_PARALLEL_RET_SUCCESS
88}
89
90fn check_decoder_status(status: JxlDecoderStatus) -> Result<()> {
91    match status {
92        JxlDecoderStatus::Success => Ok(()),
93        _ => Err(anyhow::anyhow!("JXL decoder error: {:?}", status)),
94    }
95}
96
97fn check_encoder_status(status: JxlEncoderStatus) -> Result<()> {
98    match status {
99        JxlEncoderStatus::Success => Ok(()),
100        _ => Err(anyhow::anyhow!("JXL encoder error: {:?}", status)),
101    }
102}
103
104fn default_basic_info() -> JxlBasicInfo {
105    let basic_info = std::mem::MaybeUninit::<JxlBasicInfo>::zeroed();
106    unsafe { basic_info.assume_init_read() }
107}
108
109/// Decode JXL image from reader
110pub fn decode_jxl<R: Read>(mut r: R) -> Result<ImageData> {
111    let decoder = unsafe { JxlDecoderCreate(std::ptr::null()) };
112    if decoder.is_null() {
113        return Err(anyhow::anyhow!("Failed to create JXL decoder"));
114    }
115    let dh = JxlDecoderHandle { handle: decoder };
116    let events = JxlDecoderStatus::BasicInfo as i32
117        | JxlDecoderStatus::FullImage as i32
118        | JxlDecoderStatus::ColorEncoding as i32;
119    check_decoder_status(unsafe { JxlDecoderSubscribeEvents(dh.handle, events) })?;
120    let mut data = Vec::new();
121    r.read_to_end(&mut data)?;
122    check_decoder_status(unsafe { JxlDecoderSetInput(dh.handle, data.as_ptr(), data.len()) })?;
123    unsafe {
124        JxlDecoderCloseInput(dh.handle);
125    };
126    let mut basic_info = default_basic_info();
127    let mut color_type = ImageColorType::Rgb;
128    let mut buffer = Vec::new();
129    loop {
130        let status = unsafe { JxlDecoderProcessInput(dh.handle) };
131        match status {
132            JxlDecoderStatus::BasicInfo => {
133                check_decoder_status(unsafe {
134                    JxlDecoderGetBasicInfo(dh.handle, &mut basic_info)
135                })?;
136                match basic_info.num_color_channels {
137                    1 => color_type = ImageColorType::Grayscale,
138                    3 => {
139                        if basic_info.alpha_bits > 0 {
140                            color_type = ImageColorType::Rgba;
141                        } else {
142                            color_type = ImageColorType::Rgb;
143                        }
144                    }
145                    _ => {
146                        return Err(anyhow::anyhow!(
147                            "Unsupported number of color channels: {}",
148                            basic_info.num_color_channels
149                        ));
150                    }
151                }
152                if !matches!(basic_info.bits_per_sample, 8 | 16) {
153                    return Err(anyhow::anyhow!(
154                        "Unsupported bits per sample: {}",
155                        basic_info.bits_per_sample
156                    ));
157                }
158            }
159            JxlDecoderStatus::NeedImageOutBuffer => {
160                let format = JxlPixelFormat {
161                    num_channels: color_type.bpp(1) as u32,
162                    data_type: if basic_info.bits_per_sample <= 8 {
163                        JxlDataType::Uint8
164                    } else {
165                        JxlDataType::Uint16
166                    },
167                    endianness: JxlEndianness::Little,
168                    align: 0,
169                };
170                let mut buffer_size: usize = 0;
171                check_decoder_status(unsafe {
172                    JxlDecoderImageOutBufferSize(dh.handle, &format, &mut buffer_size)
173                })?;
174                buffer.resize(buffer_size, 0);
175                check_decoder_status(unsafe {
176                    JxlDecoderSetImageOutBuffer(
177                        dh.handle,
178                        &format,
179                        buffer.as_mut_ptr() as *mut _,
180                        buffer_size,
181                    )
182                })?;
183            }
184            JxlDecoderStatus::Success => {
185                break;
186            }
187            JxlDecoderStatus::Error => {
188                return Err(anyhow::anyhow!("JXL decoding error"));
189            }
190            _ => {}
191        }
192    }
193    Ok(ImageData {
194        width: basic_info.xsize,
195        height: basic_info.ysize,
196        color_type,
197        depth: basic_info.bits_per_sample as u8,
198        data: buffer,
199    })
200}
201
202/// Encode image data to JXL format
203pub fn encode_jxl(mut img: ImageData, config: &ExtraConfig) -> Result<Vec<u8>> {
204    let encoder = unsafe { JxlEncoderCreate(std::ptr::null()) };
205    if encoder.is_null() {
206        return Err(anyhow::anyhow!("Failed to create JXL encoder"));
207    }
208    let eh = JxlEncoderHandle { handle: encoder };
209    let ph = if config.jxl_workers > 1 {
210        let ph = ThreadPoolRunner::new(config.jxl_workers)?;
211        Some(ph)
212    } else {
213        None
214    };
215    if let Some(ph) = &ph {
216        check_encoder_status(unsafe {
217            JxlEncoderSetParallelRunner(
218                eh.handle,
219                thread_pool_runner,
220                ph as *const _ as *mut c_void,
221            )
222        })?;
223    }
224    let mut basic_info = default_basic_info();
225    basic_info.xsize = img.width;
226    basic_info.ysize = img.height;
227    basic_info.bits_per_sample = match img.depth {
228        8 => 8,
229        16 => 16,
230        _ => {
231            return Err(anyhow::anyhow!(
232                "Unsupported bits per sample: {}",
233                img.depth
234            ));
235        }
236    };
237    basic_info.alpha_bits = match img.color_type {
238        ImageColorType::Rgba | ImageColorType::Bgra => img.depth as u32,
239        _ => 0,
240    };
241    basic_info.num_color_channels = match img.color_type {
242        ImageColorType::Bgr | ImageColorType::Rgb | ImageColorType::Bgra | ImageColorType::Rgba => {
243            3
244        }
245        ImageColorType::Grayscale => 1,
246    };
247    basic_info.num_extra_channels = if basic_info.alpha_bits > 0 { 1 } else { 0 };
248    basic_info.orientation = JxlOrientation::Identity;
249    basic_info.uses_original_profile = JxlBool::True;
250    check_encoder_status(unsafe { JxlEncoderSetBasicInfo(eh.handle, &basic_info) })?;
251    let options = unsafe { JxlEncoderFrameSettingsCreate(eh.handle, std::ptr::null()) };
252    if options.is_null() {
253        return Err(anyhow::anyhow!(
254            "Failed to create JXL encoder frame settings"
255        ));
256    }
257    check_encoder_status(unsafe {
258        JxlEncoderSetFrameLossless(options, JxlBool::from(config.jxl_lossless))
259    })?;
260    if !config.jxl_lossless {
261        let distance = check_range(config.jxl_distance, 0.0, 25.0)
262            .map_err(|e| anyhow::anyhow!("Invalid JXL distance: {}", e))?;
263        check_encoder_status(unsafe { JxlEncoderSetFrameDistance(options, distance) })?;
264    }
265    let format = JxlPixelFormat {
266        num_channels: img.color_type.bpp(1) as u32,
267        data_type: if img.depth <= 8 {
268            JxlDataType::Uint8
269        } else {
270            JxlDataType::Uint16
271        },
272        endianness: JxlEndianness::Little,
273        align: 0,
274    };
275    match img.color_type {
276        ImageColorType::Bgr => {
277            convert_bgr_to_rgb(&mut img)?;
278        }
279        ImageColorType::Bgra => {
280            convert_bgra_to_rgba(&mut img)?;
281        }
282        _ => {}
283    };
284    check_encoder_status(unsafe {
285        JxlEncoderAddImageFrame(
286            options,
287            &format,
288            img.data.as_ptr() as *const _,
289            img.data.len(),
290        )
291    })?;
292    unsafe { JxlEncoderCloseInput(eh.handle) };
293    let mut compressed_data = Vec::new();
294    let mut buffer = [0u8; 4096];
295    loop {
296        let mut avail_out = buffer.len();
297        let mut next_out = buffer.as_mut_ptr();
298        let status = unsafe { JxlEncoderProcessOutput(eh.handle, &mut next_out, &mut avail_out) };
299        let used = buffer.len() - avail_out;
300        compressed_data.extend_from_slice(&buffer[..used]);
301        match status {
302            JxlEncoderStatus::Success => break,
303            JxlEncoderStatus::NeedMoreOutput => {}
304            _ => {
305                return Err(anyhow::anyhow!("JXL encoding error: {:?}", status));
306            }
307        }
308    }
309    Ok(compressed_data)
310}