1use super::img::*;
3use super::num_range::*;
4use super::threadpool::*;
5use crate::types::*;
6use anyhow::Result;
7use jpegxl_sys::common::types::*;
8use jpegxl_sys::decode::*;
9use jpegxl_sys::encoder::encode::*;
10use jpegxl_sys::metadata::codestream_header::*;
11use jpegxl_sys::threads::parallel_runner::*;
12use std::ffi::c_void;
13use std::io::Read;
14
15struct JxlDecoderHandle {
16 handle: *mut JxlDecoder,
17}
18
19impl Drop for JxlDecoderHandle {
20 fn drop(&mut self) {
21 unsafe {
22 JxlDecoderDestroy(self.handle);
23 }
24 }
25}
26
27struct JxlEncoderHandle {
28 handle: *mut JxlEncoder,
29}
30
31impl Drop for JxlEncoderHandle {
32 fn drop(&mut self) {
33 unsafe {
34 JxlEncoderDestroy(self.handle);
35 }
36 }
37}
38
39struct ThreadPoolRunner {
40 thread_pool: ThreadPool<()>,
41}
42
43impl ThreadPoolRunner {
44 fn new(workers: usize) -> Result<Self> {
45 let thread_pool = ThreadPool::new(workers, Some("jxl-thread-runner-"), true)?;
46 Ok(Self { thread_pool })
47 }
48}
49
50#[derive(Clone, Copy)]
51struct JpegxlPointer(*mut c_void);
52
53unsafe impl Send for JpegxlPointer {}
54
55unsafe extern "C-unwind" fn thread_pool_runner(
56 runner_opaque: *mut c_void,
57 jpegxl_opaque: *mut c_void,
58 init: JxlParallelRunInit,
59 func: JxlParallelRunFunction,
60 start_range: u32,
61 end_range: u32,
62) -> JxlParallelRetCode {
63 if runner_opaque.is_null() || jpegxl_opaque.is_null() {
64 return JXL_PARALLEL_RET_RUNNER_ERROR;
65 }
66 let runner = unsafe { &*(runner_opaque as *const ThreadPoolRunner) };
67 let initre = unsafe { init(jpegxl_opaque, runner.thread_pool.size()) };
68 if initre != JXL_PARALLEL_RET_SUCCESS {
69 return initre;
70 }
71 let jpegxl = JpegxlPointer(jpegxl_opaque);
72 for i in start_range..end_range {
73 let jpegxl = jpegxl;
74 let func = func;
75 match runner.thread_pool.execute(
76 move |thread_id| unsafe {
77 let jpegxl = jpegxl;
78 func(jpegxl.0, i, thread_id)
79 },
80 true,
81 ) {
82 Ok(_) => {}
83 Err(_) => return JXL_PARALLEL_RET_RUNNER_ERROR,
84 }
85 }
86 runner.thread_pool.join();
87 JXL_PARALLEL_RET_SUCCESS
88}
89
90fn check_decoder_status(status: JxlDecoderStatus) -> Result<()> {
91 match status {
92 JxlDecoderStatus::Success => Ok(()),
93 _ => Err(anyhow::anyhow!("JXL decoder error: {:?}", status)),
94 }
95}
96
97fn check_encoder_status(status: JxlEncoderStatus) -> Result<()> {
98 match status {
99 JxlEncoderStatus::Success => Ok(()),
100 _ => Err(anyhow::anyhow!("JXL encoder error: {:?}", status)),
101 }
102}
103
104fn default_basic_info() -> JxlBasicInfo {
105 let basic_info = std::mem::MaybeUninit::<JxlBasicInfo>::zeroed();
106 unsafe { basic_info.assume_init_read() }
107}
108
109pub fn decode_jxl<R: Read>(mut r: R) -> Result<ImageData> {
111 let decoder = unsafe { JxlDecoderCreate(std::ptr::null()) };
112 if decoder.is_null() {
113 return Err(anyhow::anyhow!("Failed to create JXL decoder"));
114 }
115 let dh = JxlDecoderHandle { handle: decoder };
116 let events = JxlDecoderStatus::BasicInfo as i32
117 | JxlDecoderStatus::FullImage as i32
118 | JxlDecoderStatus::ColorEncoding as i32;
119 check_decoder_status(unsafe { JxlDecoderSubscribeEvents(dh.handle, events) })?;
120 let mut data = Vec::new();
121 r.read_to_end(&mut data)?;
122 check_decoder_status(unsafe { JxlDecoderSetInput(dh.handle, data.as_ptr(), data.len()) })?;
123 unsafe {
124 JxlDecoderCloseInput(dh.handle);
125 };
126 let mut basic_info = default_basic_info();
127 let mut color_type = ImageColorType::Rgb;
128 let mut buffer = Vec::new();
129 loop {
130 let status = unsafe { JxlDecoderProcessInput(dh.handle) };
131 match status {
132 JxlDecoderStatus::BasicInfo => {
133 check_decoder_status(unsafe {
134 JxlDecoderGetBasicInfo(dh.handle, &mut basic_info)
135 })?;
136 match basic_info.num_color_channels {
137 1 => color_type = ImageColorType::Grayscale,
138 3 => {
139 if basic_info.alpha_bits > 0 {
140 color_type = ImageColorType::Rgba;
141 } else {
142 color_type = ImageColorType::Rgb;
143 }
144 }
145 _ => {
146 return Err(anyhow::anyhow!(
147 "Unsupported number of color channels: {}",
148 basic_info.num_color_channels
149 ));
150 }
151 }
152 if !matches!(basic_info.bits_per_sample, 8 | 16) {
153 return Err(anyhow::anyhow!(
154 "Unsupported bits per sample: {}",
155 basic_info.bits_per_sample
156 ));
157 }
158 }
159 JxlDecoderStatus::NeedImageOutBuffer => {
160 let format = JxlPixelFormat {
161 num_channels: color_type.bpp(1) as u32,
162 data_type: if basic_info.bits_per_sample <= 8 {
163 JxlDataType::Uint8
164 } else {
165 JxlDataType::Uint16
166 },
167 endianness: JxlEndianness::Little,
168 align: 0,
169 };
170 let mut buffer_size: usize = 0;
171 check_decoder_status(unsafe {
172 JxlDecoderImageOutBufferSize(dh.handle, &format, &mut buffer_size)
173 })?;
174 buffer.resize(buffer_size, 0);
175 check_decoder_status(unsafe {
176 JxlDecoderSetImageOutBuffer(
177 dh.handle,
178 &format,
179 buffer.as_mut_ptr() as *mut _,
180 buffer_size,
181 )
182 })?;
183 }
184 JxlDecoderStatus::Success => {
185 break;
186 }
187 JxlDecoderStatus::Error => {
188 return Err(anyhow::anyhow!("JXL decoding error"));
189 }
190 _ => {}
191 }
192 }
193 Ok(ImageData {
194 width: basic_info.xsize,
195 height: basic_info.ysize,
196 color_type,
197 depth: basic_info.bits_per_sample as u8,
198 data: buffer,
199 })
200}
201
202pub fn encode_jxl(mut img: ImageData, config: &ExtraConfig) -> Result<Vec<u8>> {
204 let encoder = unsafe { JxlEncoderCreate(std::ptr::null()) };
205 if encoder.is_null() {
206 return Err(anyhow::anyhow!("Failed to create JXL encoder"));
207 }
208 let eh = JxlEncoderHandle { handle: encoder };
209 let ph = if config.jxl_workers > 1 {
210 let ph = ThreadPoolRunner::new(config.jxl_workers)?;
211 Some(ph)
212 } else {
213 None
214 };
215 if let Some(ph) = &ph {
216 check_encoder_status(unsafe {
217 JxlEncoderSetParallelRunner(
218 eh.handle,
219 thread_pool_runner,
220 ph as *const _ as *mut c_void,
221 )
222 })?;
223 }
224 let mut basic_info = default_basic_info();
225 basic_info.xsize = img.width;
226 basic_info.ysize = img.height;
227 basic_info.bits_per_sample = match img.depth {
228 8 => 8,
229 16 => 16,
230 _ => {
231 return Err(anyhow::anyhow!(
232 "Unsupported bits per sample: {}",
233 img.depth
234 ));
235 }
236 };
237 basic_info.alpha_bits = match img.color_type {
238 ImageColorType::Rgba | ImageColorType::Bgra => img.depth as u32,
239 _ => 0,
240 };
241 basic_info.num_color_channels = match img.color_type {
242 ImageColorType::Bgr | ImageColorType::Rgb | ImageColorType::Bgra | ImageColorType::Rgba => {
243 3
244 }
245 ImageColorType::Grayscale => 1,
246 };
247 basic_info.num_extra_channels = if basic_info.alpha_bits > 0 { 1 } else { 0 };
248 basic_info.orientation = JxlOrientation::Identity;
249 basic_info.uses_original_profile = JxlBool::True;
250 check_encoder_status(unsafe { JxlEncoderSetBasicInfo(eh.handle, &basic_info) })?;
251 let options = unsafe { JxlEncoderFrameSettingsCreate(eh.handle, std::ptr::null()) };
252 if options.is_null() {
253 return Err(anyhow::anyhow!(
254 "Failed to create JXL encoder frame settings"
255 ));
256 }
257 check_encoder_status(unsafe {
258 JxlEncoderSetFrameLossless(options, JxlBool::from(config.jxl_lossless))
259 })?;
260 if !config.jxl_lossless {
261 let distance = check_range(config.jxl_distance, 0.0, 25.0)
262 .map_err(|e| anyhow::anyhow!("Invalid JXL distance: {}", e))?;
263 check_encoder_status(unsafe { JxlEncoderSetFrameDistance(options, distance) })?;
264 }
265 let format = JxlPixelFormat {
266 num_channels: img.color_type.bpp(1) as u32,
267 data_type: if img.depth <= 8 {
268 JxlDataType::Uint8
269 } else {
270 JxlDataType::Uint16
271 },
272 endianness: JxlEndianness::Little,
273 align: 0,
274 };
275 match img.color_type {
276 ImageColorType::Bgr => {
277 convert_bgr_to_rgb(&mut img)?;
278 }
279 ImageColorType::Bgra => {
280 convert_bgra_to_rgba(&mut img)?;
281 }
282 _ => {}
283 };
284 check_encoder_status(unsafe {
285 JxlEncoderAddImageFrame(
286 options,
287 &format,
288 img.data.as_ptr() as *const _,
289 img.data.len(),
290 )
291 })?;
292 unsafe { JxlEncoderCloseInput(eh.handle) };
293 let mut compressed_data = Vec::new();
294 let mut buffer = [0u8; 4096];
295 loop {
296 let mut avail_out = buffer.len();
297 let mut next_out = buffer.as_mut_ptr();
298 let status = unsafe { JxlEncoderProcessOutput(eh.handle, &mut next_out, &mut avail_out) };
299 let used = buffer.len() - avail_out;
300 compressed_data.extend_from_slice(&buffer[..used]);
301 match status {
302 JxlEncoderStatus::Success => break,
303 JxlEncoderStatus::NeedMoreOutput => {}
304 _ => {
305 return Err(anyhow::anyhow!("JXL encoding error: {:?}", status));
306 }
307 }
308 }
309 Ok(compressed_data)
310}