Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
torvalds
GitHub Repository: torvalds/linux
Path: blob/master/rust/macros/kunit.rs
29266 views
1
// SPDX-License-Identifier: GPL-2.0
2
3
//! Procedural macro to run KUnit tests using a userspace-like syntax.
4
//!
5
//! Copyright (c) 2023 José Expósito <[email protected]>
6
7
use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
8
use std::collections::HashMap;
9
use std::fmt::Write;
10
11
pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
12
let attr = attr.to_string();
13
14
if attr.is_empty() {
15
panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
16
}
17
18
if attr.len() > 255 {
19
panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
20
}
21
22
let mut tokens: Vec<_> = ts.into_iter().collect();
23
24
// Scan for the `mod` keyword.
25
tokens
26
.iter()
27
.find_map(|token| match token {
28
TokenTree::Ident(ident) => match ident.to_string().as_str() {
29
"mod" => Some(true),
30
_ => None,
31
},
32
_ => None,
33
})
34
.expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
35
36
// Retrieve the main body. The main body should be the last token tree.
37
let body = match tokens.pop() {
38
Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
39
_ => panic!("Cannot locate main body of module"),
40
};
41
42
// Get the functions set as tests. Search for `[test]` -> `fn`.
43
let mut body_it = body.stream().into_iter();
44
let mut tests = Vec::new();
45
let mut attributes: HashMap<String, TokenStream> = HashMap::new();
46
while let Some(token) = body_it.next() {
47
match token {
48
TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
49
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
50
if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
51
// Collect attributes because we need to find which are tests. We also
52
// need to copy `cfg` attributes so tests can be conditionally enabled.
53
attributes
54
.entry(name.to_string())
55
.or_default()
56
.extend([token, TokenTree::Group(g)]);
57
}
58
continue;
59
}
60
_ => (),
61
},
62
TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
63
if let Some(TokenTree::Ident(test_name)) = body_it.next() {
64
tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
65
}
66
}
67
68
_ => (),
69
}
70
attributes.clear();
71
}
72
73
// Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
74
let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
75
tokens.insert(
76
0,
77
TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
78
);
79
80
// Generate the test KUnit test suite and a test case for each `#[test]`.
81
// The code generated for the following test module:
82
//
83
// ```
84
// #[kunit_tests(kunit_test_suit_name)]
85
// mod tests {
86
// #[test]
87
// fn foo() {
88
// assert_eq!(1, 1);
89
// }
90
//
91
// #[test]
92
// fn bar() {
93
// assert_eq!(2, 2);
94
// }
95
// }
96
// ```
97
//
98
// Looks like:
99
//
100
// ```
101
// unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); }
102
// unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); }
103
//
104
// static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
105
// ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo),
106
// ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar),
107
// ::kernel::kunit::kunit_case_null(),
108
// ];
109
//
110
// ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
111
// ```
112
let mut kunit_macros = "".to_owned();
113
let mut test_cases = "".to_owned();
114
let mut assert_macros = "".to_owned();
115
let path = crate::helpers::file();
116
let num_tests = tests.len();
117
for (test, cfg_attr) in tests {
118
let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
119
// Append any `cfg` attributes the user might have written on their tests so we don't
120
// attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
121
// the length of the assert message.
122
let kunit_wrapper = format!(
123
r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
124
{{
125
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
126
{cfg_attr} {{
127
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
128
use ::kernel::kunit::is_test_result_ok;
129
assert!(is_test_result_ok({test}()));
130
}}
131
}}"#,
132
);
133
writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
134
writeln!(
135
test_cases,
136
" ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
137
)
138
.unwrap();
139
writeln!(
140
assert_macros,
141
r#"
142
/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
143
#[allow(unused)]
144
macro_rules! assert {{
145
($cond:expr $(,)?) => {{{{
146
kernel::kunit_assert!("{test}", "{path}", 0, $cond);
147
}}}}
148
}}
149
150
/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
151
#[allow(unused)]
152
macro_rules! assert_eq {{
153
($left:expr, $right:expr $(,)?) => {{{{
154
kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
155
}}}}
156
}}
157
"#
158
)
159
.unwrap();
160
}
161
162
writeln!(kunit_macros).unwrap();
163
writeln!(
164
kunit_macros,
165
"static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];",
166
num_tests + 1
167
)
168
.unwrap();
169
170
writeln!(
171
kunit_macros,
172
"::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
173
)
174
.unwrap();
175
176
// Remove the `#[test]` macros.
177
// We do this at a token level, in order to preserve span information.
178
let mut new_body = vec![];
179
let mut body_it = body.stream().into_iter();
180
181
while let Some(token) = body_it.next() {
182
match token {
183
TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
184
Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
185
Some(next) => {
186
new_body.extend([token, next]);
187
}
188
_ => {
189
new_body.push(token);
190
}
191
},
192
_ => {
193
new_body.push(token);
194
}
195
}
196
}
197
198
let mut final_body = TokenStream::new();
199
final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
200
final_body.extend(new_body);
201
final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
202
203
tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
204
205
tokens.into_iter().collect()
206
}
207
208