1use alloc::{string::String, vec::Vec};
2use core::fmt;
3
4use serde::de::{Deserializer, Error, Unexpected, Visitor};
5use serde_core as serde;
6
7use crate::SmolStr;
8
9fn smol_str<'de: 'a, 'a, D>(deserializer: D) -> Result<SmolStr, D::Error>
11where
12 D: Deserializer<'de>,
13{
14 struct SmolStrVisitor;
15
16 impl<'a> Visitor<'a> for SmolStrVisitor {
17 type Value = SmolStr;
18
19 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
20 formatter.write_str("a string")
21 }
22
23 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
24 where
25 E: Error,
26 {
27 Ok(SmolStr::from(v))
28 }
29
30 fn visit_borrowed_str<E>(self, v: &'a str) -> Result<Self::Value, E>
31 where
32 E: Error,
33 {
34 Ok(SmolStr::from(v))
35 }
36
37 fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
38 where
39 E: Error,
40 {
41 Ok(SmolStr::from(v))
42 }
43
44 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
45 where
46 E: Error,
47 {
48 match core::str::from_utf8(v) {
49 Ok(s) => Ok(SmolStr::from(s)),
50 Err(_) => Err(Error::invalid_value(Unexpected::Bytes(v), &self)),
51 }
52 }
53
54 fn visit_borrowed_bytes<E>(self, v: &'a [u8]) -> Result<Self::Value, E>
55 where
56 E: Error,
57 {
58 match core::str::from_utf8(v) {
59 Ok(s) => Ok(SmolStr::from(s)),
60 Err(_) => Err(Error::invalid_value(Unexpected::Bytes(v), &self)),
61 }
62 }
63
64 fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
65 where
66 E: Error,
67 {
68 match String::from_utf8(v) {
69 Ok(s) => Ok(SmolStr::from(s)),
70 Err(e) => Err(Error::invalid_value(Unexpected::Bytes(&e.into_bytes()), &self)),
71 }
72 }
73 }
74
75 deserializer.deserialize_str(SmolStrVisitor)
76}
77
78impl serde::Serialize for SmolStr {
79 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
80 where
81 S: serde::Serializer,
82 {
83 self.as_str().serialize(serializer)
84 }
85}
86
87impl<'de> serde::Deserialize<'de> for SmolStr {
88 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
89 where
90 D: serde::Deserializer<'de>,
91 {
92 smol_str(deserializer)
93 }
94}