rune_macros/
trace.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3
4pub(crate) fn expand(orig: &syn::DeriveInput) -> TokenStream {
5    let orig_name = &orig.ident;
6    let rooted_name = format_ident!("Rooted{orig_name}");
7    let generic_params = &orig.generics;
8
9    let derive = match &orig.data {
10        syn::Data::Struct(s) => derive_struct(orig, s),
11        syn::Data::Enum(data_enum) => derive_enum(orig, data_enum),
12        syn::Data::Union(_) => panic!("Derive Trace for Unions is not supported"),
13    };
14
15    quote! {
16        #derive
17
18        impl #generic_params crate::core::gc::RootedDeref for #orig_name #generic_params {
19            type Target = #rooted_name #generic_params;
20
21            fn rooted_deref(rooted: &crate::core::gc::Rt<Self>) -> &Self::Target {
22                unsafe { &*(rooted as *const crate::core::gc::Rt<Self>).cast::<Self::Target>() }            }
23
24            fn rooted_derefmut(rooted: &mut crate::core::gc::Rt<Self>) -> &mut Self::Target {
25                unsafe { &mut *(rooted as *mut crate::core::gc::Rt<Self>).cast::<Self::Target>() }
26            }
27        }
28    }
29}
30
31fn derive_enum(orig: &syn::DeriveInput, data_enum: &syn::DataEnum) -> TokenStream {
32    let rt = quote!(crate::core::gc::Rt);
33    let vis = &orig.vis;
34    let orig_name = &orig.ident;
35    let rooted_name = format_ident!("Rooted{orig_name}");
36    let generic_params = &orig.generics;
37    let repr = get_repr(&orig.attrs);
38
39    let mut new_fields = TokenStream::new();
40    let mut mark_fields = TokenStream::new();
41
42    for x in &data_enum.variants {
43        let no_trace = no_trace(&x.attrs);
44        let ident = &x.ident;
45        match &x.fields {
46            syn::Fields::Unit => {
47                new_fields.extend(quote! { #ident, });
48                mark_fields.extend(quote! { #orig_name::#ident => {}, });
49            }
50            syn::Fields::Unnamed(unnamed_fields) => {
51                if no_trace {
52                    new_fields.extend(quote! { #ident #unnamed_fields, });
53
54                    let num_fields = unnamed_fields.unnamed.iter().count();
55                    let mut punctuated = syn::punctuated::Punctuated::<_, syn::Token![,]>::new();
56                    for _ in 0..num_fields {
57                        punctuated.push(syn::Ident::new("_", proc_macro2::Span::call_site()));
58                    }
59
60                    mark_fields.extend(quote! { #orig_name::#ident (#punctuated) => {}, });
61                } else {
62                    let mut rooted_fields = TokenStream::new();
63                    let mut trace_fields = TokenStream::new();
64                    let mut bind_fields = TokenStream::new();
65                    for (i, field) in unnamed_fields.unnamed.iter().enumerate() {
66                        let binding = format_ident!("x{i}");
67                        bind_fields.extend(quote! {#binding,});
68                        rooted_fields.extend(quote! {#rt<#field>,});
69                        trace_fields
70                            .extend(quote! {crate::core::gc::Trace::trace(#binding, state);});
71                    }
72                    new_fields.extend(quote! { #ident (#rooted_fields), });
73                    mark_fields
74                        .extend(quote! { #orig_name::#ident(#bind_fields) => {#trace_fields}, });
75                }
76            }
77            syn::Fields::Named(_) => unreachable!(),
78        }
79    }
80
81    let doc_string = format!("Automatically derived from [{orig_name}] via `#[derive(Trace)]`");
82    quote! {
83        impl #generic_params crate::core::gc::Trace for #orig_name #generic_params {
84            fn trace(&self, state: &mut crate::core::gc::GcState) {
85                match self {
86                    #mark_fields
87                }
88            }
89        }
90
91        #[automatically_derived]
92        #[allow(non_camel_case_types)]
93        #[doc = #doc_string]
94        #repr
95        #vis enum #rooted_name #generic_params {#new_fields}
96    }
97}
98
99fn derive_struct(orig: &syn::DeriveInput, data_struct: &syn::DataStruct) -> TokenStream {
100    let rt = quote!(crate::core::gc::Rt);
101    let vis = &orig.vis;
102    let orig_name = &orig.ident;
103    let rooted_name = format_ident!("Rooted{orig_name}");
104    let generic_params = &orig.generics;
105    let repr = get_repr(&orig.attrs);
106
107    let mut new_fields = TokenStream::new();
108    let mut mark_fields = TokenStream::new();
109    let mut test_fields = TokenStream::new();
110
111    match &data_struct.fields {
112        syn::Fields::Named(fields) => {
113            for x in &fields.named {
114                #[rustfmt::skip]
115                let syn::Field { vis, ident, ty, attrs, .. } = &x;
116                let ident = ident.as_ref().expect("named fields should have an identifer");
117                let panic_string = format!(
118                    "Field '{}' of struct '{}' is incorrectly aligned in #[derive(Trace)]",
119                    stringify!(#ident),
120                    stringify!(#rooted_name),
121                );
122                test_fields.extend(quote! {
123                    if std::mem::offset_of!(#orig_name, #ident) != std::mem::offset_of!(#rooted_name, #ident) {
124                        panic!(#panic_string);
125                    }
126                });
127                if no_trace(attrs) {
128                    new_fields.extend(quote! {#vis #ident: #ty,});
129                    // Remove dead_code warnings
130                    mark_fields.extend(quote! {let _ = &self.#ident;});
131                } else {
132                    new_fields.extend(quote! {#vis #ident: #rt<#ty>,});
133                    mark_fields
134                        .extend(quote! {crate::core::gc::Trace::trace(&self.#ident, state);});
135                }
136            }
137            new_fields = quote! {{#new_fields}};
138        }
139        syn::Fields::Unnamed(fields) => {
140            for (i, x) in fields.unnamed.iter().enumerate() {
141                let syn::Field { vis, ty, attrs, .. } = &x;
142                let idx = syn::Index::from(i);
143                let panic_string = format!(
144                    "Field '{}' of struct '{}' is incorrectly aligned in #[derive(Trace)]",
145                    stringify!(#idx),
146                    stringify!(#rooted_name),
147                );
148                test_fields.extend(quote! {
149                    if std::mem::offset_of!(#orig_name, #idx) != std::mem::offset_of!(#rooted_name, #idx) {
150                        panic!(#panic_string);
151                    }
152                });
153                if no_trace(attrs) {
154                    new_fields.extend(quote! {#vis #ty,});
155                    // Remove dead_code warnings
156                    mark_fields.extend(quote! {let _ = &self.#idx;});
157                } else {
158                    new_fields.extend(quote! {#vis #rt<#ty>,});
159                    mark_fields.extend(quote! {crate::core::gc::Trace::trace(&self.#idx, state);});
160                }
161            }
162            new_fields = quote! {(#new_fields);};
163        }
164        syn::Fields::Unit => panic!("fieldless structs don't need tracing"),
165    }
166    let test_mod = format_ident!("derive_trace_{orig_name}");
167    let doc_string = format!("Automatically derived from [{orig_name}] via `#[derive(Trace)]`");
168    quote! {
169        impl #generic_params crate::core::gc::Trace for #orig_name #generic_params {
170            fn trace(&self, state: &mut crate::core::gc::GcState) {
171                #mark_fields
172            }
173        }
174
175        #[automatically_derived]
176        #[allow(non_camel_case_types)]
177        #[doc = #doc_string]
178        #repr
179        #vis struct #rooted_name #generic_params #new_fields
180
181        #[allow(dead_code)]
182        #[allow(non_snake_case)]
183        // Ensure at compile time that all fields are at the same offset
184        const #test_mod: () = {
185            #test_fields
186        };
187    }
188}
189
190fn get_repr(attrs: &[syn::Attribute]) -> TokenStream {
191    for attr in attrs {
192        if let syn::Meta::List(list) = &attr.meta {
193            if list.path.is_ident("repr") {
194                return quote! {#attr};
195            }
196        }
197    }
198    quote! {}
199}
200
201fn no_trace(attrs: &[syn::Attribute]) -> bool {
202    for attr in attrs {
203        if let syn::Meta::Path(path) = &attr.meta {
204            if path.is_ident("no_trace") {
205                return true;
206            }
207        }
208    }
209    false
210}
211
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `stream` as a derive input, expands it, prints the generated code and
    /// returns it rendered as a string so tests can make assertions on it.
    fn expand_to_string(stream: TokenStream) -> String {
        let input: syn::DeriveInput = syn::parse2(stream).unwrap();
        let result = expand(&input).to_string();
        println!("{result}");
        result
    }

    #[test]
    fn test_expand_struct() {
        let result = expand_to_string(quote!(
            struct LispStack(Vec<GcObj<'static>>);
        ));
        assert!(result.contains("RootedLispStack"));

        let result = expand_to_string(quote!(
            struct Foo {
                a: A,
                #[no_trace]
                b: B,
            }
        ));
        assert!(result.contains("RootedFoo"));
    }

    #[test]
    fn test_expand_enum() {
        let result = expand_to_string(quote!(
            enum LispStack {
                A,
                B(i32),
                C(String, usize),
                #[no_trace]
                D(i32, usize),
            }
        ));
        assert!(result.contains("RootedLispStack"));
    }

    #[test]
    fn test_repr_c() {
        let result = expand_to_string(quote!(
            #[repr(C)]
            struct Foo {
                a: A,
                b: B,
            }
        ));
        assert!(result.contains("RootedFoo"));
        // The representation attribute must be forwarded to the generated type.
        assert!(result.contains("repr"));
    }
}