rune_macros/
defun.rs

1#![allow(clippy::manual_unwrap_or_default)]
2use darling::FromMeta;
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::Error;
6
7pub(crate) fn expand(function: Function, spec: Spec) -> TokenStream {
8    if let Some(required) = spec.required {
9        let actual_required = function.args.iter().filter(|x| x.is_positional_arg()).count();
10        if required as usize > actual_required {
11            return quote! { compile_error!("Spec `required` is larger then the number of arguments provided"); };
12        }
13    }
14
15    let body = function.body;
16    let subr = function.name;
17    let subr_name = subr.to_string();
18    let struct_name = format_ident!("__subr_{}", &subr_name);
19    let func_name = format_ident!("__wrapper_fn_{}", &subr_name);
20    let lisp_name = spec.name.unwrap_or_else(|| map_function_name(&subr_name));
21    let (required, optional, rest) = parse_call_signature(&function.args, spec.required);
22
23    let arg_conversion = get_arg_conversion(&function.args);
24
25    let create_args = if !function.args.iter().any(|x| matches!(x, ArgType::Env(MUT))) {
26        // If mut Env is not needed, then we can just pass a slice from the
27        // stack directly. This is the cheapest option
28        quote! { let args = &env.stack[..arg_cnt]; }
29    } else if !rest || function.args.iter().any(|x| matches!(x, ArgType::ArgSlice)) {
30        // If number of arguments are known, we can allocate them on the stack
31        // as an array to avoid allocation. This is the purpose of the ArgSlice
32        // type, to give us a know set of arguments when the callee is expecting
33        // a slice.
34        let positional_args = (required + optional) as usize;
35        quote! {
36            let init: [crate::core::gc::Slot<crate::core::object::Object>; #positional_args] = core::array::from_fn(|x| crate::core::gc::Slot::new(crate::core::object::NIL));
37            rune_core::macros::root!(args, init(init), cx);
38            let pos_count = std::cmp::min(arg_cnt, #positional_args);
39            let stack_slice = &env.stack[..arg_cnt];
40            for i in 0..pos_count {
41                args[i].set(&stack_slice[i]);
42            }
43            let args = &args[..pos_count];
44        }
45    } else {
46        // If the function requires a mutable env and arguments are unbounded,
47        // we should use the ArgSlice type to avoid allocating
48        quote! { compile_error!("Can't use an argument slice with a mutable enviroment. Use `ArgSlice` instead."); }
49    };
50
51    let err = if function.fallible {
52        quote! {?}
53    } else {
54        quote! {}
55    };
56
57    // Create the context from a pointer to get around the issue that the
58    // return val is bound to the mutable borrow, meaning we can use them
59    // both in the into_obj function. Similar to the rebind! macro.
60    let subr_call = quote! {
61        let ptr = cx as *mut crate::core::gc::Context;
62        let val = #subr(#(#arg_conversion),*)#err;
63        let cx: &'ob mut crate::core::gc::Context = unsafe {&mut *ptr};
64        Ok(crate::core::object::IntoObject::into_obj(val, cx).into())
65    };
66
67    let arg_count_guard = match (required as usize, optional as usize, rest) {
68        (r, 0, false) => quote! {arg_cnt != #r},
69        (0, o, false) => quote! {#o < arg_cnt},
70        (r, o, false) => quote! {arg_cnt < #r || #r + #o < arg_cnt},
71        (r, _, true) => quote! {arg_cnt < #r},
72    };
73
74    quote! {
75        #[automatically_derived]
76        #[doc(hidden)]
77        fn #func_name<'ob>(
78            arg_cnt: usize,
79            env: &mut crate::core::gc::Rt<crate::core::env::Env>,
80            cx: &'ob mut crate::core::gc::Context,
81        ) -> anyhow::Result<crate::core::object::Object<'ob>> {
82            #[allow(clippy::manual_range_contains)]
83            if #arg_count_guard {
84                let upper = #required + #optional;
85                let expected = if arg_cnt > upper as usize {upper} else {#required};
86                return Err(crate::data::LispError::arg_cnt(#lisp_name, expected, arg_cnt as u16, cx).into());
87            }
88            #create_args
89            #subr_call
90        }
91
92        #[automatically_derived]
93        #[doc(hidden)]
94        #[allow(non_upper_case_globals)]
95        pub(crate) const #struct_name: crate::core::object::SubrFn = crate::core::object::SubrFn {
96            name: #lisp_name,
97            subr: #func_name,
98            args: crate::core::object::FnArgs {
99                required: #required,
100                optional: #optional,
101                rest: #rest,
102                advice: false,
103            }
104        };
105
106        #body
107    }
108}
109
110fn get_arg_conversion(args: &[ArgType]) -> Vec<TokenStream> {
111    let is_mut = args.iter().any(|ty| matches!(ty, ArgType::Context(MUT)));
112    args.iter()
113        .enumerate()
114        .map(|(idx, arg_type)| match arg_type {
115            ArgType::Context(_) => quote! {cx},
116            ArgType::Env(_) => quote! {env},
117            // Rt<Gc<..>>
118            ArgType::Rt(gc) => match gc {
119                Gc::Obj => quote! {&args[#idx]},
120                Gc::Other => quote! {args[#idx].try_as()?},
121            },
122            // Gc<..>
123            ArgType::Gc(gc) => {
124                let bind = quote! {args[#idx].bind(cx)};
125                match gc {
126                    Gc::Obj => bind,
127                    Gc::Other => quote! { std::convert::TryFrom::try_from(#bind)? },
128                }
129            }
130            // &[Gc<..>]
131            ArgType::Slice(gc) => {
132                let bind =
133                    quote! {crate::core::gc::Rt::bind_slice(&args[(#idx).min(args.len())..], cx)};
134                match gc {
135                    Gc::Obj => bind,
136                    Gc::Other => quote! {crate::core::object::try_from_slice(#bind)?},
137                }
138            }
139            // &[Rt<Gc<..>>]
140            ArgType::SliceRt(gc) => match gc {
141                Gc::Obj => quote! {&args[(#idx).min(args.len())..]},
142                Gc::Other => unreachable!(),
143            },
144            // ArgSlice
145            ArgType::ArgSlice => {
146                let positional = args.iter().filter(|x| x.is_positional_arg()).count();
147                quote! {crate::core::env::ArgSlice::new(arg_cnt.saturating_sub(#positional))}
148            }
149            // Option<Rt<Gc<..>>>
150            ArgType::OptionRt => {
151                quote! {
152                    match args.get(#idx) {
153                        Some(x) => crate::core::gc::Rt::try_as_option(x)?,
154                        None => None,
155                    }
156                }
157            }
158            // Option<T>
159            ArgType::Option => {
160                let bind = quote! {x.bind(cx)};
161                quote! {
162                    match args.get(#idx) {
163                        Some(x) => crate::core::object::Gc::try_from_option(#bind)?,
164                        None => None,
165                    }
166                }
167            }
168            ArgType::Other => {
169                if is_mut {
170                    quote! { std::convert::TryFrom::try_from(&args[#idx])? }
171                } else {
172                    let bind = quote! {args[#idx].bind(cx)};
173                    quote! { std::convert::TryFrom::try_from(#bind)? }
174                }
175            }
176        })
177        .collect()
178}
179
180fn parse_call_signature(args: &[ArgType], spec_required: Option<u16>) -> (u16, u16, bool) {
181    let required = {
182        let actual_required = args.iter().filter(|x| x.is_required_arg()).count();
183        let spec_required = match spec_required {
184            Some(x) => x as usize,
185            None => 0,
186        };
187        std::cmp::max(actual_required, spec_required)
188    };
189
190    let optional = {
191        let pos_args = args.iter().filter(|x| x.is_positional_arg()).count();
192        pos_args - required
193    };
194
195    let rest = args.iter().any(|x| x.is_rest_arg());
196
197    let required = u16::try_from(required).unwrap();
198    let optional = u16::try_from(optional).unwrap();
199    (required, optional, rest)
200}
201
202fn get_path_ident_name(type_path: &syn::TypePath) -> String {
203    type_path.path.segments.last().unwrap().ident.to_string()
204}
205
/// Converts a Rust function name into its default lisp name by turning every
/// underscore into a hyphen.
fn map_function_name(name: &str) -> String {
    name.chars().map(|ch| if ch == '_' { '-' } else { ch }).collect()
}
209
/// Payload value for [`ArgType::Context`] / [`ArgType::Env`] meaning the
/// parameter is taken by mutable reference.
const MUT: bool = true;

/// Distinguishes a plain `Object` from any other garbage-collected type that
/// needs a conversion.
#[derive(PartialEq, Debug, Copy, Clone)]
enum Gc {
    /// The generic object type; no conversion is required.
    Obj,
    /// Any other Gc-managed type.
    Other,
}

/// Classification of one parameter of a function exposed to lisp.
#[derive(PartialEq, Debug, Copy, Clone)]
enum ArgType {
    /// `&Context` / `&mut Context`; the flag is `true` for `&mut`.
    Context(bool),
    /// `&Rt<Env>` / `&mut Rt<Env>`; the flag is `true` for `&mut`.
    Env(bool),
    /// A rooted value (`Rt<..>` / `Rto<..>`).
    Rt(Gc),
    /// A raw Gc value (`Object`, `Gc<..>`, ...).
    Gc(Gc),
    /// A slice of Gc values.
    Slice(Gc),
    /// A slice of rooted values.
    SliceRt(Gc),
    /// The `ArgSlice` type.
    ArgSlice,
    /// `Option<T>` or `OptionalFlag`.
    Option,
    /// `Option<&Rt<..>>` / `Option<&Rto<..>>`.
    OptionRt,
    /// Anything not recognised above.
    Other,
}

impl ArgType {
    /// `true` for parameters that always consume exactly one lisp argument.
    fn is_required_arg(self) -> bool {
        match self {
            Self::Rt(_) | Self::Gc(_) | Self::Other => true,
            Self::Context(_)
            | Self::Env(_)
            | Self::Slice(_)
            | Self::SliceRt(_)
            | Self::ArgSlice
            | Self::Option
            | Self::OptionRt => false,
        }
    }

    /// `true` for required parameters and for optional (`Option`-like) ones.
    fn is_positional_arg(self) -> bool {
        match self {
            Self::Rt(_) | Self::Gc(_) | Self::Other | Self::Option | Self::OptionRt => true,
            Self::Context(_)
            | Self::Env(_)
            | Self::Slice(_)
            | Self::SliceRt(_)
            | Self::ArgSlice => false,
        }
    }

    /// `true` for parameters that receive the remaining (rest) arguments.
    fn is_rest_arg(self) -> bool {
        match self {
            Self::SliceRt(_) | Self::Slice(_) | Self::ArgSlice => true,
            Self::Context(_)
            | Self::Env(_)
            | Self::Rt(_)
            | Self::Gc(_)
            | Self::Option
            | Self::OptionRt
            | Self::Other => false,
        }
    }
}
248
249pub(crate) struct Function {
250    name: syn::Ident,
251    body: syn::Item,
252    args: Vec<ArgType>,
253    fallible: bool,
254}
255
256impl syn::parse::Parse for Function {
257    fn parse(input: syn::parse::ParseStream) -> Result<Self, Error> {
258        let item: syn::Item = input.parse()?;
259        parse_fn(item)
260    }
261}
262
263fn parse_fn(item: syn::Item) -> Result<Function, Error> {
264    match item {
265        syn::Item::Fn(syn::ItemFn { ref sig, .. }) => {
266            if sig.unsafety.is_some() {
267                Err(Error::new_spanned(sig, "lisp functions cannot be `unsafe`"))
268            } else {
269                let args = parse_signature(sig)?;
270                check_invariants(&args, sig)?;
271                let fallible = return_type_is_result(&sig.output);
272                Ok(Function { name: sig.ident.clone(), body: item, args, fallible })
273            }
274        }
275        _ => Err(Error::new_spanned(item, "`lisp_fn` attribute can only be used on functions")),
276    }
277}
278
279fn check_invariants(args: &[ArgType], sig: &syn::Signature) -> Result<(), Error> {
280    let is_mut = args.iter().any(|x| matches!(x, ArgType::Context(MUT)));
281    if is_mut {
282        let mut iter = sig.inputs.iter().zip(args.iter());
283        if let Some((arg, _)) = iter.find(|(_, ty)| matches!(ty, ArgType::Gc(_))) {
284            return Err(Error::new_spanned(
285                arg,
286                "Can't have raw Gc pointer in function with mutable Context",
287            ));
288        }
289    }
290    let mut iter = sig.inputs.iter().zip(args.iter());
291    if let Some((arg, _)) = iter.find(|(_, ty)| matches!(ty, ArgType::SliceRt(Gc::Other))) {
292        return Err(Error::new_spanned(arg, "Converting an Rt slice to is unimplemented"));
293    }
294
295    for (arg, ty) in sig.inputs.iter().zip(args.iter()) {
296        if matches!(ty, ArgType::Rt(_))
297            && let syn::FnArg::Typed(ty) = arg
298            && !matches!(ty.ty.as_ref(), syn::Type::Reference(_))
299        {
300            return Err(Error::new_spanned(arg, "Can't take Rt by value"));
301        }
302    }
303
304    let rest_args = args.iter().filter(|x| x.is_rest_arg()).count();
305    if rest_args > 1 {
306        return Err(Error::new_spanned(sig, "Found duplicate argument slice in signature"));
307    }
308
309    let first_opt = args.iter().position(|x| matches!(x, ArgType::Option));
310    let last_required = args.iter().rposition(|x| x.is_required_arg()).unwrap_or_default();
311    if let Some(first_optional) = first_opt
312        && last_required > first_optional
313    {
314        let arg = sig.inputs.iter().nth(last_required).unwrap();
315        return Err(Error::new_spanned(
316            arg,
317            "Required argument is after the first optional argument",
318        ));
319    }
320    Ok(())
321}
322
323fn parse_signature(sig: &syn::Signature) -> Result<Vec<ArgType>, Error> {
324    let mut args = Vec::new();
325    for input in &sig.inputs {
326        match input {
327            syn::FnArg::Receiver(x) => {
328                return Err(Error::new_spanned(x, "Self is not valid in lisp functions"));
329            }
330            syn::FnArg::Typed(pat_type) => {
331                let ty = pat_type.ty.as_ref().clone();
332                let arg = get_arg_type(&ty)?;
333                args.push(arg);
334            }
335        }
336    }
337    Ok(args)
338}
339
340fn return_type_is_result(output: &syn::ReturnType) -> bool {
341    match output {
342        syn::ReturnType::Type(_, ty) => match ty.as_ref() {
343            syn::Type::Path(path) => get_path_ident_name(path) == "Result",
344            _ => false,
345        },
346        syn::ReturnType::Default => false,
347    }
348}
349
350fn get_arg_type(ty: &syn::Type) -> Result<ArgType, Error> {
351    Ok(match ty {
352        syn::Type::Reference(syn::TypeReference { elem, mutability, .. }) => match elem.as_ref() {
353            syn::Type::Path(path) => match &*get_path_ident_name(path) {
354                "Context" => ArgType::Context(mutability.is_some()),
355                "Rt" | "Rto" => get_rt_type(path, mutability.is_some())?,
356                _ => ArgType::Other,
357            },
358            syn::Type::Slice(slice) => match get_arg_type(slice.elem.as_ref())? {
359                ArgType::Rt(rt) => ArgType::SliceRt(rt),
360                ArgType::Gc(gc) => ArgType::Slice(gc),
361                _ => ArgType::Slice(Gc::Other),
362            },
363            _ => ArgType::Other,
364        },
365        syn::Type::Path(path) => {
366            let name = get_path_ident_name(path);
367            match &*name {
368                "ArgSlice" => ArgType::ArgSlice,
369                "Rt" | "Rto" => get_rt_type(path, false)?,
370                "Option" => {
371                    let outer = path.path.segments.last().unwrap();
372                    match get_generic_param(outer) {
373                        Some(syn::Type::Reference(inner)) => match inner.elem.as_ref() {
374                            syn::Type::Path(path) => match &*get_path_ident_name(path) {
375                                "Rt" | "Rto" => ArgType::OptionRt,
376                                _ => ArgType::Option,
377                            },
378                            _ => ArgType::Option,
379                        },
380                        _ => ArgType::Option,
381                    }
382                }
383                "OptionalFlag" => ArgType::Option,
384                _ => get_object_type(path),
385            }
386        }
387        _ => ArgType::Other,
388    })
389}
390
391fn get_object_type(type_path: &syn::TypePath) -> ArgType {
392    let outer_type = type_path.path.segments.last().unwrap();
393    if outer_type.ident == "Object" {
394        ArgType::Gc(Gc::Obj)
395    } else if outer_type.ident == "Function"
396        || outer_type.ident == "Number"
397        || outer_type.ident == "List"
398    {
399        ArgType::Gc(Gc::Other)
400    } else if outer_type.ident == "Slot" {
401        match get_generic_param(outer_type) {
402            Some(syn::Type::Path(inner)) => get_object_type(inner),
403            _ => ArgType::Gc(Gc::Other),
404        }
405    } else if outer_type.ident == "Gc" {
406        let inner = match get_generic_param(outer_type) {
407            Some(syn::Type::Path(generic)) if get_path_ident_name(generic) == "Object" => Gc::Obj,
408            _ => Gc::Other,
409        };
410        ArgType::Gc(inner)
411    } else if outer_type.ident == "Env" {
412        ArgType::Env(false)
413    } else {
414        ArgType::Other
415    }
416}
417
418fn get_rt_type(type_path: &syn::TypePath, mutable: bool) -> Result<ArgType, Error> {
419    let segment = type_path.path.segments.last().unwrap();
420    match get_generic_param(segment) {
421        Some(syn::Type::Path(inner)) => match get_object_type(inner) {
422            ArgType::Gc(gc) => Ok(ArgType::Rt(gc)),
423            ArgType::Env(_) => Ok(ArgType::Env(mutable)),
424            _ => Err(Error::new_spanned(inner, "Found Rt of non-Gc type")),
425        },
426        _ => Ok(ArgType::Other),
427    }
428}
429
430fn get_generic_param(outer_type: &syn::PathSegment) -> Option<&syn::Type> {
431    match &outer_type.arguments {
432        syn::PathArguments::AngleBracketed(generic) => match generic.args.first().unwrap() {
433            syn::GenericArgument::Type(ty) => Some(ty),
434            _ => None,
435        },
436        _ => None,
437    }
438}
439
440#[derive(Default, PartialEq, Debug, FromMeta)]
441pub(crate) struct Spec {
442    #[darling(default)]
443    name: Option<String>,
444    #[darling(default)]
445    required: Option<u16>,
446}
447
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `stream` as a function and checks its computed call signature
    /// `(required, optional, rest)` against `expect`.
    fn test_sig(stream: TokenStream, min: Option<u16>, expect: (u16, u16, bool)) {
        let function: Function = syn::parse2(stream).unwrap();
        let sig = parse_call_signature(&function.args, min);
        assert_eq!(sig, expect);
    }

    #[test]
    fn sig() {
        test_sig(quote! {fn foo() -> u8 {}}, None, (0, 0, false));
        test_sig(quote! {fn foo(vars: &[u8]) -> u8 {0}}, None, (0, 0, true));
        test_sig(quote! {fn foo(var: u8) -> u8 {0}}, None, (1, 0, false));
        test_sig(quote! {fn foo(var0: u8, var1: u8, vars: &[u8]) -> u8 {0}}, None, (2, 0, true));
        test_sig(
            quote! {fn foo(var0: u8, var1: Option<u8>, vars: &[u8]) -> u8 {0}},
            None,
            (1, 1, true),
        );
        test_sig(
            quote! {fn foo(var0: u8, var1: Option<u8>, var2: Option<u8>) -> u8 {0}},
            Some(2),
            (2, 1, false),
        );
        test_sig(
            quote! { fn foo(a: &Rt<Slot<Gc<foo>>>, b: &[Rt<Slot<Object>>], env: &Rt<Env>, cx: &mut Context) -> u8 {0} },
            None,
            (1, 0, true),
        );
        test_sig(
            quote! { fn foo(env: &Rt<Env>, a: &Rt<Slot<Gc<foo>>>, x: Option<u8>, cx: &mut Context, b: &[Rt<Slot<Object>>]) -> u8 {0} },
            None,
            (1, 1, true),
        );
    }

    /// Parses `fn foo(<args>) -> u8 {0}` and checks that the parameters are
    /// classified exactly as `expect`: same length and same order.
    #[expect(clippy::needless_pass_by_value)]
    fn test_args(args: TokenStream, expect: &[ArgType]) {
        let stream = quote! {fn foo(#args) -> u8 {0}};
        let function: Function = syn::parse2(stream).unwrap();
        // Compare whole sequences; zipping would hide a length mismatch.
        assert_eq!(function.args, expect, "input: `{args}`");
    }

    #[test]
    fn test_arguments() {
        test_args(quote! {x: Object}, &[ArgType::Gc(Gc::Obj)]);
        test_args(quote! {x: Gc<T>}, &[ArgType::Gc(Gc::Other)]);
        test_args(quote! {x: &Rt<Slot<Object>>}, &[ArgType::Rt(Gc::Obj)]);
        test_args(quote! {x: &Rt<Slot<Gc<Object>>>}, &[ArgType::Rt(Gc::Obj)]);
        test_args(quote! {x: &Rt<Slot<Gc<T>>>}, &[ArgType::Rt(Gc::Other)]);
        test_args(quote! {x: &Rto<Object>}, &[ArgType::Rt(Gc::Obj)]);
        test_args(quote! {x: &Rto<Gc<Object>>}, &[ArgType::Rt(Gc::Obj)]);
        test_args(quote! {x: &Rto<Gc<T>>}, &[ArgType::Rt(Gc::Other)]);
        test_args(quote! {x: u8}, &[ArgType::Other]);
        test_args(quote! {x: Option<u8>}, &[ArgType::Option]);
        test_args(quote! {x: Option<()>}, &[ArgType::Option]);
        test_args(quote! {x: OptionalFlag}, &[ArgType::Option]);
        test_args(quote! {x: Option<&Rt<Slot<Object>>>}, &[ArgType::OptionRt]);
        test_args(quote! {x: Option<&Rto<Object>>}, &[ArgType::OptionRt]);
        test_args(quote! {x: &[Object]}, &[ArgType::Slice(Gc::Obj)]);
        test_args(quote! {x: &[Gc<T>]}, &[ArgType::Slice(Gc::Other)]);
        test_args(quote! {x: &[u8]}, &[ArgType::Slice(Gc::Other)]);
        test_args(quote! {x: ArgSlice}, &[ArgType::ArgSlice]);
        test_args(quote! {x: &[Rt<Slot<Object>>]}, &[ArgType::SliceRt(Gc::Obj)]);
        test_args(quote! {x: &[Rto<Object>]}, &[ArgType::SliceRt(Gc::Obj)]);
        test_args(quote! {x: &mut Context}, &[ArgType::Context(MUT)]);
        test_args(quote! {x: &Context}, &[ArgType::Context(false)]);
        test_args(quote! {x: &Rt<Env>}, &[ArgType::Env(false)]);
        test_args(quote! {x: &mut Rt<Env>}, &[ArgType::Env(MUT)]);
        test_args(
            quote! {x: u8, s: &[Rt<Slot<Object>>], y: &Context, z: &Rt<Env>},
            &[
                ArgType::Other,
                ArgType::SliceRt(Gc::Obj),
                ArgType::Context(false),
                ArgType::Env(false),
            ],
        );

        test_args(
            quote! {x: u8, s: &[Rto<Object>], y: &Context, z: &Rt<Env>},
            &[
                ArgType::Other,
                ArgType::SliceRt(Gc::Obj),
                ArgType::Context(false),
                ArgType::Env(false),
            ],
        );
    }

    /// Asserts that `stream` is rejected when parsed as a [`Function`].
    fn check_error(stream: TokenStream) {
        println!("signature: {stream}");
        let function: Result<Function, _> = syn::parse2(stream);
        assert!(function.is_err());
    }

    #[test]
    fn test_error() {
        check_error(quote! {fn foo(a: Object, a: &mut Context) {}});
        check_error(quote! {fn foo(a: &[Rt<T>]) {}});
        check_error(quote! {fn foo(a: Rt<Slot<Object>>) {}});
        check_error(quote! {fn foo(a: u8, b: &[Object], c: &[Object]) {}});
        check_error(quote! {fn foo(a: u8, b: Option<u8>, c: u8) {}});
    }

    #[test]
    fn test_expand() {
        let stream = quote! {
            fn car<'ob>(list: Gc<List>, cx: &'ob Context) -> Object<'ob> {
                match list.get() {
                    List::Cons(cons) => cons.car(),
                    List::Nil => NIL,
                }
            }
        };
        let function: Function = syn::parse2(stream).unwrap();
        let spec = Spec { name: Some("+".into()), ..Default::default() };
        let result = expand(function, spec);
        println!("{result}");
        // A valid signature must expand to real code, not a diagnostic.
        let rendered = result.to_string();
        assert!(!rendered.is_empty());
        assert!(!rendered.contains("compile_error"));
    }
}