Skip to main content

supp_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, ItemFn, Lit, Path, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn local_db_sqlx_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(item as ItemFn);
8    let fn_name = &input.sig.ident;
9    let block = &input.block;
10
11    let expanded = quote! {
12        #[sqlx::test(migrations = "../migrations")]
13        async fn #fn_name(pool: PgPool) -> Result<(), anyhow::Error> {
14            setup().await;
15            DB_POOL.set(&pool);
16            #block
17        Ok(())
18        }
19    };
20
21    TokenStream::from(expanded)
22}
23
24#[proc_macro_derive(Builder, attributes(builder))]
25pub fn builder_macro(input: TokenStream) -> TokenStream {
26    // Parse the input tokens into a syntax tree
27    let input = parse_macro_input!(input as DeriveInput);
28    let name = &input.ident;
29    let generics = &input.generics; // Capture generics (including lifetimes)
30    let builder_name = syn::Ident::new(&format!("{name}Builder"), name.span());
31
32    // Check for custom error_kind attribute
33    let mut error_kind = None;
34
35    // Parse attributes
36    for attr in &input.attrs {
37        if attr.path().is_ident("builder") {
38            attr.parse_nested_meta(|meta| {
39                if meta.path.is_ident("error_kind")
40                    && let Ok(Lit::Str(lit_str)) = meta.value()?.parse()
41                {
42                    error_kind = Some(lit_str.parse::<Path>().unwrap());
43                }
44                Ok(())
45            })
46            .unwrap();
47        }
48    }
49
50    // Set a default error kind if none is provided
51    let error_kind = error_kind.expect(
52        "Error kind (e.g., FinanceError) must be specified with #[builder(error_kind = \"...\")]",
53    );
54
55    // Define a custom error type based on the struct name, e.g., CommodityError for Commodity
56    let custom_error_name = syn::Ident::new(&format!("{name}Error"), name.span());
57
58    let fields = if let Data::Struct(data) = &input.data {
59        if let Fields::Named(fields) = &data.fields {
60            fields.named.iter().collect::<Vec<_>>()
61        } else {
62            panic!("Builder macro only supports structs with named fields");
63        }
64    } else {
65        panic!("Builder macro only supports structs");
66    };
67
68    // Generate builder struct fields with the same generics (including lifetimes)
69    let builder_fields = fields.iter().map(|field| {
70        let field_name = &field.ident;
71        let field_ty = &field.ty;
72        let builder_field_type = quote! { Option<#field_ty> };
73        quote! {
74            #field_name: #builder_field_type
75        }
76    });
77
78    // Generate initialization in new()
79    let builder_fields_init = fields.iter().map(|field| {
80        let field_name = &field.ident;
81        quote! {
82            #field_name: None
83        }
84    });
85
86    // Generate setter methods
87    let setters = fields.iter().map(|field| {
88        let field_name = &field.ident;
89        let field_type = &field.ty;
90
91        if is_option_type(field_type) {
92            let inner_type = get_inner_type(field_type);
93            if is_string_type(&inner_type) {
94                // For Option<String>, accept &str
95                quote! {
96                    pub fn #field_name(&mut self, value: &str) -> &mut Self {
97                        self.#field_name = Some(Some(value.to_string()));
98                        self
99                    }
100                }
101            } else {
102                // For Option<T>, accept T directly
103                quote! {
104                    pub fn #field_name(&mut self, value: #inner_type) -> &mut Self {
105                        self.#field_name = Some(Some(value));
106                        self
107                    }
108                }
109            }
110        } else if is_string_type(field_type) {
111            // For String, accept &str
112            quote! {
113                pub fn #field_name(&mut self, value: &str) -> &mut Self {
114                    self.#field_name = Some(value.to_string());
115                    self
116                }
117            }
118        } else {
119            // For non-Option<T> and non-String fields, accept T directly
120            quote! {
121                pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
122                    self.#field_name = Some(value);
123                    self
124                }
125            }
126        }
127    });
128
129    // Generate code to check for missing required fields
130    let check_required_fields = fields
131        .iter()
132        .filter(|field| !is_option_type(&field.ty))
133        .map(|field| {
134            let field_name = &field.ident;
135            let field_name_str = field_name.as_ref().unwrap().to_string();
136            quote! {
137                if self.#field_name.is_none() {
138                    missing_fields.push(#field_name_str);
139                }
140            }
141        });
142
143    // Generate build_fields
144    let build_fields = fields.iter().map(|field| {
145        let field_name = &field.ident;
146        if is_option_type(&field.ty) {
147            quote! {
148                #field_name: self.#field_name.clone().unwrap_or(None)
149            }
150        } else {
151            quote! {
152                #field_name: self.#field_name.clone().unwrap()
153            }
154        }
155    });
156
157    // Extract the lifetime parameters from generics for use in the builder struct
158    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
159
160    let expanded = quote! {
161        pub struct #builder_name #impl_generics #where_clause {
162            #(#builder_fields),*
163        }
164
165        impl #impl_generics #builder_name #ty_generics #where_clause {
166            pub fn new() -> Self {
167                Self {
168                    #(#builder_fields_init),*
169                }
170            }
171
172            #(#setters)*
173
174            pub fn build(&self) -> Result<#name #ty_generics, #error_kind> {
175                let mut missing_fields = Vec::new();
176                #(#check_required_fields)*
177
178                if !missing_fields.is_empty() {
179                    return Err(#error_kind::from(#custom_error_name::Build(format!(
180                        "{} fields are missing: {}",
181                        stringify!(#name),
182                        missing_fields.join(", ")
183                    ))));
184                }
185
186                Ok(#name {
187                    #(#build_fields),*
188                })
189            }
190        }
191
192        impl #impl_generics #name #ty_generics #where_clause {
193            pub fn builder() -> #builder_name #ty_generics {
194                #builder_name::new()
195            }
196        }
197    };
198
199    TokenStream::from(expanded)
200}
201
202/// Helper function to determine if a type is an `Option<T>`
203fn is_option_type(ty: &syn::Type) -> bool {
204    matches!(ty, syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.iter().any(|segment| segment.ident == "Option"))
205}
206
207/// Helper function to get the inner type of an `Option<T>`
208fn get_inner_type(ty: &syn::Type) -> syn::Type {
209    if let syn::Type::Path(type_path) = ty
210        && let Some(segment) = type_path.path.segments.first()
211        && segment.ident == "Option"
212        && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
213        && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
214    {
215        return inner_type.clone();
216    }
217    ty.clone()
218}
219
220/// Helper function to check if the type is String
221fn is_string_type(ty: &syn::Type) -> bool {
222    if let syn::Type::Path(type_path) = ty
223        && let Some(segment) = type_path.path.segments.last()
224    {
225        return segment.ident == "String";
226    }
227    false
228}
229
230/// A procedural macro for generating typed Command implementations with compile-time validation.
231///
232/// This macro provides pure value-based argument passing with compile-time type safety by generating:
233/// - Typed Args structs with proper field types passed by value only
234/// - Commands that accept Args structs directly (no `HashMap` usage)
235/// - Individual typed variables available directly in command scope
236/// - Compile-time validation of argument types and required/optional fields
237/// - Zero runtime argument parsing or validation overhead
238///
239/// # Syntax
240///
241/// ```ignore
242/// command! {
243///     CommandName {
244///         #[required]
245///         arg_name: Type,
246///         #[optional]
247///         opt_name: Type,
248///     } => {
249///         // Command implementation body
250///         // Individual typed variables are available in scope
251///     }
252/// }
253/// ```
254///
255/// # Generated Code
256///
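/// The macro generates:
/// - A `CommandName` unit struct implementing the `Command` trait; `CommandName::new()` starts
///   a typestate builder chain
/// - One runner struct per combination of set required arguments, named `CommandNameRunner`
///   followed by a bit pattern (e.g. `CommandNameRunner01`, where the rightmost digit tracks the
///   first required argument). Set required arguments are stored as `Type`, optional arguments
///   as `Option<Type>`
/// - A setter for each required argument not yet set, which moves to the next runner type, and
///   a setter for each optional argument on every runner type
/// - A `run()` method only on the runner in which every required argument is set; inside the
///   command body each argument is available as an individually typed variable
/// - Pure compile-time type validation with no runtime overhead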
262///
263/// # Examples
264///
265/// ## Simple command with no arguments
266///
267/// ```rust
268/// # use supp_macro::command;
269/// # use async_trait::async_trait;
270/// #
271/// # #[derive(Debug)]
272/// # pub enum CmdError {
273/// #     Args(String),
274/// # }
275/// #
276/// # impl std::fmt::Display for CmdError {
277/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278/// #         match self {
279/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
280/// #         }
281/// #     }
282/// # }
283/// #
284/// # impl std::error::Error for CmdError {}
285/// #
286/// # #[derive(Debug)]
287/// # pub enum CmdResult {
288/// #     String(String),
289/// # }
290/// #
291/// # #[derive(Debug, Default)]
292/// # pub struct CommandArgs {}
293/// # impl CommandArgs { pub fn new() -> Self { Self::default() } }
294/// #
295/// # #[async_trait]
296/// # pub trait Command: std::fmt::Debug {
297/// #     type Args;
298/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
299/// # }
300///
301/// command! {
302///     GetVersion {
303///     } => {
304///         Ok(Some(CmdResult::String("1.0.0".to_string())))
305///     }
306/// }
307///
308/// # #[tokio::main]
309/// # async fn main() {
310/// let result = GetVersion::new().run().await.unwrap();
311/// # }
312/// ```
313///
314/// ## Command with required arguments (server-compatible types)
315///
316/// ```rust
317/// # use supp_macro::command;
318/// # use async_trait::async_trait;
319/// # use uuid::Uuid;
320/// # use num_rational::Rational64;
321/// #
322/// # #[derive(Debug, Clone)]
323/// # pub enum Argument {
324/// #     String(String),
325/// #     Uuid(Uuid),
326/// #     Rational(Rational64),
327/// # }
328/// #
329/// # #[derive(Debug)]
330/// # pub enum CmdError {
331/// #     Args(String),
332/// # }
333/// #
334/// # impl std::fmt::Display for CmdError {
335/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336/// #         match self {
337/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
338/// #         }
339/// #     }
340/// # }
341/// #
342/// # impl std::error::Error for CmdError {}
343/// #
344/// # #[derive(Debug)]
345/// # pub enum CmdResult {
346/// #     String(String),
347/// # }
348/// #
349/// # #[derive(Debug, Default)]
350/// # pub struct CommandArgs {
351/// #     pub symbol: Option<String>,
352/// #     pub name: Option<String>,
353/// #     pub user_id: Option<uuid::Uuid>,
354/// # }
355/// # impl CommandArgs {
356/// #     pub fn new() -> Self { Self::default() }
357/// #     pub fn symbol(mut self, v: String) -> Self { self.symbol = Some(v); self }
358/// #     pub fn name(mut self, v: String) -> Self { self.name = Some(v); self }
359/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
360/// # }
361/// #
362/// # #[async_trait]
363/// # pub trait Command: std::fmt::Debug {
364/// #     type Args;
365/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
366/// # }
367///
368/// // This creates a commodity in the financial system
369/// command! {
370///     CreateCommodity {
371///         #[required]
372///         symbol: String,
373///         #[required]
374///         name: String,
375///         #[required]
376///         user_id: Uuid,
377///     } => {
378///         // Individual typed variables are automatically available
379///         Ok(Some(CmdResult::String(format!(
380///             "Created commodity {} ({}) for user {}",
381///             name, symbol, user_id
382///         ))))
383///     }
384/// }
385///
386/// # #[tokio::main]
387/// # async fn main() {
388/// let result = CreateCommodity::new()
389///     .symbol("USD".to_string())
390///     .name("US Dollar".to_string())
391///     .user_id(uuid::Uuid::new_v4())
392///     .run()
393///     .await
394///     .unwrap();
395/// # }
396/// ```
397///
398/// ## Command with optional arguments
399///
400/// ```rust
401/// # use supp_macro::command;
402/// # use async_trait::async_trait;
403/// # use uuid::Uuid;
404/// #
405/// # #[derive(Debug, Clone)]
406/// # pub enum Argument {
407/// #     String(String),
408/// #     Uuid(Uuid),
409/// # }
410/// #
411/// # #[derive(Debug)]
412/// # pub enum CmdError {
413/// #     Args(String),
414/// # }
415/// #
416/// # impl std::fmt::Display for CmdError {
417/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418/// #         match self {
419/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
420/// #         }
421/// #     }
422/// # }
423/// #
424/// # impl std::error::Error for CmdError {}
425/// #
426/// # #[derive(Debug)]
427/// # pub enum CmdResult {
428/// #     String(String),
429/// # }
430/// #
431/// # #[derive(Debug, Default)]
432/// # pub struct CommandArgs {
433/// #     pub user_id: Option<uuid::Uuid>,
434/// #     pub account: Option<String>,
435/// # }
436/// # impl CommandArgs {
437/// #     pub fn new() -> Self { Self::default() }
438/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
439/// #     pub fn account(mut self, v: String) -> Self { self.account = Some(v); self }
440/// # }
441/// #
442/// # #[async_trait]
443/// # pub trait Command: std::fmt::Debug {
444/// #     type Args;
445/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
446/// # }
447///
448/// command! {
449///     ListTransactions {
450///         #[required]
451///         user_id: Uuid,
452///         #[optional]
453///         account: String,
454///     } => {
455///         let filter = if let Some(account) = account {
456///             format!(" for account {}", account)
457///         } else {
458///             String::new()
459///         };
460///         Ok(Some(CmdResult::String(format!("Listing transactions for user {}{}", user_id, filter))))
461///     }
462/// }
463/// ```
464///
465/// ## Command with mixed required and optional arguments
466///
467/// ```rust
468/// # use supp_macro::command;
469/// # use async_trait::async_trait;
470/// #
471/// # #[derive(Debug, Clone)]
472/// # pub enum Argument {
473/// #     String(String),
474/// #     Integer(i64),
475/// #     Boolean(bool),
476/// # }
477/// #
478/// # #[derive(Debug)]
479/// # pub enum CmdError {
480/// #     Args(String),
481/// # }
482/// #
483/// # impl std::fmt::Display for CmdError {
484/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485/// #         match self {
486/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
487/// #         }
488/// #     }
489/// # }
490/// #
491/// # impl std::error::Error for CmdError {}
492/// #
493/// # #[derive(Debug)]
494/// # pub enum CmdResult {
495/// #     Success(String),
496/// # }
497/// #
498/// # impl TryFrom<Argument> for String {
499/// #     type Error = CmdError;
500/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
501/// #         match arg {
502/// #             Argument::String(s) => Ok(s),
503/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to String", arg))),
504/// #         }
505/// #     }
506/// # }
507/// #
508/// # impl TryFrom<Argument> for i64 {
509/// #     type Error = CmdError;
510/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
511/// #         match arg {
512/// #             Argument::Integer(i) => Ok(i),
513/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to i64", arg))),
514/// #         }
515/// #     }
516/// # }
517/// #
518/// # impl TryFrom<Argument> for bool {
519/// #     type Error = CmdError;
520/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
521/// #         match arg {
522/// #             Argument::Boolean(b) => Ok(b),
523/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to bool", arg))),
524/// #         }
525/// #     }
526/// # }
527/// #
528/// # #[async_trait]
529/// # pub trait TypedCommand {
530/// #     type Args;
531/// #     async fn run_typed(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
532/// # }
533/// #
534/// # #[derive(Debug, Default)]
535/// # pub struct CommandArgs {
536/// #     pub user_id: Option<i64>,
537/// #     pub username: Option<String>,
538/// #     pub email: Option<String>,
539/// #     pub is_admin: Option<bool>,
540/// # }
541/// # impl CommandArgs {
542/// #     pub fn new() -> Self { Self::default() }
543/// #     pub fn user_id(mut self, v: i64) -> Self { self.user_id = Some(v); self }
544/// #     pub fn username(mut self, v: String) -> Self { self.username = Some(v); self }
545/// #     pub fn email(mut self, v: String) -> Self { self.email = Some(v); self }
546/// #     pub fn is_admin(mut self, v: bool) -> Self { self.is_admin = Some(v); self }
547/// # }
548/// #
549/// # #[async_trait]
550/// # pub trait Command {
551/// #     type Args;
552/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
553/// # }
554///
555/// command! {
556///     CreateUserCommand {
557///         #[required]
558///         user_id: i64,
559///         #[required]
560///         username: String,
561///         #[optional]
562///         email: String,
563///         #[optional]
564///         is_admin: bool,
565///     } => {
566///         let email_str = email.map_or_else(|| format!("{}@example.com", username), |s| s.to_string());
567///         let admin_status = is_admin.unwrap_or(false);
568///
569///         let message = format!(
570///             "Created user {} (ID: {}, Email: {}, Admin: {})",
571///             username, user_id, email_str, admin_status
572///         );
573///         Ok(Some(CmdResult::Success(message)))
574///     }
575/// }
576///
577/// # #[tokio::main]
578/// # async fn main() {
579/// let result = CreateUserCommand::new()
580///     .user_id(123)
581///     .username("alice".to_string())
582///     .is_admin(true)
583///     .run()
584///     .await
585///     .unwrap();
586/// # }
587/// ```
588///
589/// ## Server-compatible Command implementation
590///
591/// ```rust
592/// # use supp_macro::command;
593/// # use async_trait::async_trait;
594/// #
595/// # #[derive(Debug, Clone)]
596/// # pub enum Argument {
597/// #     String(String),
598/// #     Integer(i64),
599/// #     Boolean(bool),
600/// # }
601/// #
602/// # #[derive(Debug)]
603/// # pub enum CmdError {
604/// #     Args(String),
605/// # }
606/// #
607/// # impl std::fmt::Display for CmdError {
608/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609/// #         match self {
610/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
611/// #         }
612/// #     }
613/// # }
614/// #
615/// # impl std::error::Error for CmdError {}
616/// #
617/// # #[derive(Debug)]
618/// # pub enum CmdResult {
619/// #     Success(String),
620/// # }
621/// #
622/// # #[derive(Debug, Default)]
623/// # pub struct CommandArgs {
624/// #     pub a: Option<i64>,
625/// #     pub b: Option<i64>,
626/// # }
627/// # impl CommandArgs {
628/// #     pub fn new() -> Self { Self::default() }
629/// #     pub fn a(mut self, v: i64) -> Self { self.a = Some(v); self }
630/// #     pub fn b(mut self, v: i64) -> Self { self.b = Some(v); self }
631/// # }
632/// #
633/// # #[async_trait]
634/// # pub trait Command {
635/// #     type Args;
636/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
637/// # }
638///
639/// command! {
640///     CalculateCommand {
641///         #[required]
642///         a: i64,
643///         #[required]
644///         b: i64,
645///     } => {
646///         let result = a + b;
647///         Ok(Some(CmdResult::Success(format!("{} + {} = {}", a, b, result))))
648///     }
649/// }
650///
651/// # #[tokio::main]
652/// # async fn main() {
653/// let result = CalculateCommand::new()
654///     .a(10)
655///     .b(20)
656///     .run()
657///     .await
658///     .unwrap();
659/// # }
660/// ```
661///
662/// ## Migration from Manual Commands
663///
664/// The macro makes it easy to migrate from manual Command implementations:
665///
666/// ```rust,ignore
667/// // BEFORE: Manual implementation
668/// #[derive(Debug)]
669/// pub struct GetConfig;
670///
671/// #[async_trait]
672/// impl Command for GetConfig {
673///     async fn run<'a>(&self, args: &'a HashMap<&'a str, &'a Argument>) -> Result<Option<CmdResult>, CmdError> {
674///         if let Some(Argument::String(name)) = args.get("name") {
675///             Ok(config(name).await?.map(|v| CmdResult::String(v)))
676///         } else {
677///             Err(CmdError::Args("No field name provided".to_string()))
678///         }
679///     }
680/// }
681///
682/// // AFTER: Using the macro
683/// command! {
684///     GetConfig {
685///         #[required]
686///         name: String,
687///     } => {
688///         Ok(config(name).await?.map(|v| CmdResult::String(v)))
689///     }
690/// }
691/// ```
692///
693/// # Error Handling
694///
695/// The new pure typed system provides compile-time error prevention:
696///
697/// - Missing required arguments are compile-time errors (cannot compile without them)
698/// - Invalid argument types are compile-time errors (type checking at build time)
699/// - Runtime errors only occur in the command body logic itself
700/// - No argument validation overhead at runtime
701///
702/// # Supported Argument Types
703///
704/// The macro supports any Rust type for arguments:
705/// - `String` - Text arguments
706/// - `i64`, `u64`, etc. - Integer arguments
707/// - `bool` - Boolean arguments
708/// - `Rational64` - Rational number arguments (for financial precision)
709/// - `Uuid` - UUID arguments
710/// - `Vec<u8>` - Binary data arguments
711/// - `DateTime<Utc>` - `DateTime` arguments
712/// - Custom types - Any type can be used as an argument
713/// - `Option<T>` - Automatically applied for optional arguments
714///
715
716#[proc_macro]
717pub fn command(input: TokenStream) -> TokenStream {
718    let input = parse_macro_input!(input as CommandInput);
719
720    let name = &input.name;
721    let required_args = &input.required_args;
722    let optional_args = &input.optional_args;
723    let body = &input.body;
724
725    // Generate progressive runner types for all combinations of required fields
726    let runner_types = generate_progressive_runner_types(name, required_args, optional_args, body);
727
728    // Generate the main command struct
729    let command_struct = quote! {
730        #[derive(Debug)]
731        pub struct #name;
732    };
733
734    // Generate the new() method that starts the builder chain
735    let new_method = generate_new_method(name, required_args.len(), optional_args);
736
737    let expanded = quote! {
738        #command_struct
739
740        #runner_types
741
742        #new_method
743    };
744
745    TokenStream::from(expanded)
746}
747
748/// Generate all possible runner type combinations for required fields
749fn generate_progressive_runner_types(
750    command_name: &syn::Ident,
751    required_args: &[(syn::Ident, syn::Type)],
752    optional_args: &[(syn::Ident, syn::Type)],
753    body: &syn::Block,
754) -> proc_macro2::TokenStream {
755    let num_required = required_args.len();
756    let total_combinations = 1 << num_required; // 2^num_required
757
758    let mut runner_types = Vec::new();
759
760    // Generate a runner type for each possible combination of set required fields
761    for combination in 0..total_combinations {
762        let runner_type = generate_single_runner_type(
763            command_name,
764            required_args,
765            optional_args,
766            combination,
767            num_required,
768            body,
769        );
770        runner_types.push(runner_type);
771    }
772
773    quote! {
774        #(#runner_types)*
775    }
776}
777
778/// Generate a single runner type for a specific combination of set fields
779fn generate_single_runner_type(
780    command_name: &syn::Ident,
781    required_args: &[(syn::Ident, syn::Type)],
782    optional_args: &[(syn::Ident, syn::Type)],
783    combination: usize,
784    num_required: usize,
785    body: &syn::Block,
786) -> proc_macro2::TokenStream {
787    // Create binary representation for the runner type name
788    let binary_suffix = format!("{:0width$b}", combination, width = num_required.max(1));
789    let runner_name = syn::Ident::new(
790        &format!("{command_name}Runner{binary_suffix}"),
791        command_name.span(),
792    );
793
794    // Determine which required fields are set in this combination
795    let mut struct_fields = Vec::new();
796    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
797        if (combination >> i) & 1 == 1 {
798            // This required field is set in this combination
799            struct_fields.push(quote! {
800                pub #field_name: #field_type
801            });
802        }
803    }
804
805    // Always include optional fields in all runner types
806    for (field_name, field_type) in optional_args {
807        struct_fields.push(quote! {
808            pub #field_name: Option<#field_type>
809        });
810    }
811
812    // Generate the struct definition
813    let struct_def = if struct_fields.is_empty() {
814        quote! {
815            #[derive(Debug)]
816            pub struct #runner_name;
817        }
818    } else {
819        quote! {
820            #[derive(Debug)]
821            pub struct #runner_name {
822                #(#struct_fields),*
823            }
824        }
825    };
826
827    // Generate transition methods for this runner type
828    let transition_methods = generate_transition_methods(
829        command_name,
830        required_args,
831        optional_args,
832        combination,
833        num_required,
834    );
835
836    // Generate run method if this is the complete state (all required fields set)
837    let complete_mask = (1 << num_required) - 1;
838    let run_method = if combination == complete_mask {
839        generate_run_method(command_name, required_args, optional_args, body)
840    } else {
841        quote! {}
842    };
843
844    quote! {
845        #struct_def
846
847        impl #runner_name {
848            #transition_methods
849            #run_method
850        }
851    }
852}
853
854/// Generate transition methods for a runner type (field setters)
855fn generate_transition_methods(
856    command_name: &syn::Ident,
857    required_args: &[(syn::Ident, syn::Type)],
858    optional_args: &[(syn::Ident, syn::Type)],
859    current_combination: usize,
860    num_required: usize,
861) -> proc_macro2::TokenStream {
862    let mut methods = Vec::new();
863
864    // Generate setter methods for required fields not yet set
865    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
866        if (current_combination >> i) & 1 == 0 {
867            // This required field is not set yet, generate a setter
868            let new_combination = current_combination | (1 << i);
869            let binary_suffix =
870                format!("{:0width$b}", new_combination, width = num_required.max(1));
871            let target_runner = syn::Ident::new(
872                &format!("{command_name}Runner{binary_suffix}"),
873                command_name.span(),
874            );
875
876            let method = generate_field_setter_method(
877                command_name,
878                required_args,
879                optional_args,
880                field_name,
881                field_type,
882                current_combination,
883                new_combination,
884                &target_runner,
885                num_required,
886            );
887            methods.push(method);
888        }
889    }
890
891    // Generate setter methods for optional fields (available on all runner types)
892    for (field_name, field_type) in optional_args {
893        let current_runner = syn::Ident::new(
894            &format!(
895                "{}Runner{:0width$b}",
896                command_name,
897                current_combination,
898                width = num_required.max(1)
899            ),
900            command_name.span(),
901        );
902
903        let method = generate_optional_field_setter(
904            field_name,
905            field_type,
906            &current_runner,
907            required_args,
908            optional_args,
909            current_combination,
910            num_required,
911        );
912        methods.push(method);
913    }
914
915    quote! {
916        #(#methods)*
917    }
918}
919
920/// Generate a setter method for a required field
921fn generate_field_setter_method(
922    _command_name: &syn::Ident,
923    required_args: &[(syn::Ident, syn::Type)],
924    optional_args: &[(syn::Ident, syn::Type)],
925    field_name: &syn::Ident,
926    field_type: &syn::Type,
927    current_combination: usize,
928    _new_combination: usize,
929    target_runner: &syn::Ident,
930    _num_required: usize,
931) -> proc_macro2::TokenStream {
932    // Generate field assignments for the new state
933    let mut field_assignments = Vec::new();
934
935    // Handle required fields
936    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
937        if req_field_name == field_name {
938            // This is the field being set
939            field_assignments.push(quote! {
940                #req_field_name: value
941            });
942        } else if (current_combination >> i) & 1 == 1 {
943            // This field was already set, move it from self
944            field_assignments.push(quote! {
945                #req_field_name: self.#req_field_name
946            });
947        }
948        // Fields not set in either state are omitted
949    }
950
951    // Handle optional fields (always present, move from self)
952    for (opt_field_name, _) in optional_args {
953        field_assignments.push(quote! {
954            #opt_field_name: self.#opt_field_name
955        });
956    }
957
958    // Generate the constructor call
959    let constructor = if field_assignments.is_empty() {
960        quote! { #target_runner }
961    } else {
962        quote! {
963            #target_runner {
964                #(#field_assignments),*
965            }
966        }
967    };
968
969    quote! {
970        pub fn #field_name(self, value: #field_type) -> #target_runner {
971            #constructor
972        }
973    }
974}
975
976/// Generate a setter method for an optional field
977fn generate_optional_field_setter(
978    field_name: &syn::Ident,
979    field_type: &syn::Type,
980    current_runner: &syn::Ident,
981    required_args: &[(syn::Ident, syn::Type)],
982    optional_args: &[(syn::Ident, syn::Type)],
983    current_combination: usize,
984    _num_required: usize,
985) -> proc_macro2::TokenStream {
986    // Generate field assignments (same state, but update the optional field)
987    let mut field_assignments = Vec::new();
988
989    // Handle required fields (move from self if set)
990    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
991        if (current_combination >> i) & 1 == 1 {
992            field_assignments.push(quote! {
993                #req_field_name: self.#req_field_name
994            });
995        }
996    }
997
998    // Handle optional fields
999    for (opt_field_name, _) in optional_args {
1000        if opt_field_name == field_name {
1001            // This is the field being set
1002            field_assignments.push(quote! {
1003                #opt_field_name: Some(value)
1004            });
1005        } else {
1006            // Move other optional fields from self
1007            field_assignments.push(quote! {
1008                #opt_field_name: self.#opt_field_name
1009            });
1010        }
1011    }
1012
1013    let constructor = if field_assignments.is_empty() {
1014        quote! { #current_runner }
1015    } else {
1016        quote! {
1017            #current_runner {
1018                #(#field_assignments),*
1019            }
1020        }
1021    };
1022
1023    quote! {
1024        pub fn #field_name(self, value: #field_type) -> #current_runner {
1025            #constructor
1026        }
1027    }
1028}
1029
1030/// Generate the run method for the complete runner state
1031fn generate_run_method(
1032    _command_name: &syn::Ident,
1033    required_args: &[(syn::Ident, syn::Type)],
1034    optional_args: &[(syn::Ident, syn::Type)],
1035    body: &syn::Block,
1036) -> proc_macro2::TokenStream {
1037    // Extract field values directly (no unwrap needed!)
1038    let mut variable_assignments = Vec::new();
1039
1040    // Required fields - direct field access
1041    for (field_name, _) in required_args {
1042        variable_assignments.push(quote! {
1043            let #field_name = self.#field_name;
1044        });
1045    }
1046
1047    // Optional fields - direct field access
1048    for (field_name, _) in optional_args {
1049        variable_assignments.push(quote! {
1050            let #field_name = self.#field_name;
1051        });
1052    }
1053
1054    quote! {
1055        pub async fn run(self) -> Result<Option<CmdResult>, CmdError> {
1056            // Zero runtime checks - direct field access!
1057            #(#variable_assignments)*
1058
1059            // Original command body
1060            #body
1061        }
1062    }
1063}
1064
1065/// Generate the `new()` method for the command
1066fn generate_new_method(
1067    command_name: &syn::Ident,
1068    num_required: usize,
1069    optional_args: &[(syn::Ident, syn::Type)],
1070) -> proc_macro2::TokenStream {
1071    let initial_runner = syn::Ident::new(
1072        &format!(
1073            "{}Runner{:0width$b}",
1074            command_name,
1075            0,
1076            width = num_required.max(1)
1077        ),
1078        command_name.span(),
1079    );
1080
1081    // Initial state has no required fields set, but has optional fields as None
1082    let constructor = if optional_args.is_empty() && num_required > 0 {
1083        // Unit struct (no fields at all in initial state)
1084        quote! { #initial_runner }
1085    } else {
1086        // Struct with optional fields initialized to None
1087        let optional_field_inits = optional_args.iter().map(|(field_name, _)| {
1088            quote! { #field_name: None }
1089        });
1090
1091        if optional_field_inits.len() > 0 {
1092            quote! {
1093                #initial_runner {
1094                    #(#optional_field_inits),*
1095                }
1096            }
1097        } else {
1098            quote! { #initial_runner }
1099        }
1100    };
1101
1102    quote! {
1103        impl #command_name {
1104            pub fn new() -> #initial_runner {
1105                #constructor
1106            }
1107        }
1108    }
1109}
1110
1111struct CommandInput {
1112    name: syn::Ident,
1113    required_args: Vec<(syn::Ident, syn::Type)>,
1114    optional_args: Vec<(syn::Ident, syn::Type)>,
1115    body: syn::Block,
1116}
1117
1118impl syn::parse::Parse for CommandInput {
1119    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1120        let name: syn::Ident = input.parse()?;
1121
1122        let content;
1123        syn::braced!(content in input);
1124
1125        let mut required_args = Vec::new();
1126        let mut optional_args = Vec::new();
1127
1128        while !content.is_empty() {
1129            // Parse attributes
1130            let mut is_optional = false;
1131            let mut is_required = false;
1132
1133            while content.peek(syn::Token![#]) {
1134                content.parse::<syn::Token![#]>()?;
1135                let attr_content;
1136                syn::bracketed!(attr_content in content);
1137                let attr_name: syn::Ident = attr_content.parse()?;
1138
1139                if attr_name == "optional" {
1140                    is_optional = true;
1141                } else if attr_name == "required" {
1142                    is_required = true;
1143                } else {
1144                    return Err(syn::Error::new(
1145                        attr_name.span(),
1146                        "Unknown attribute. Use #[required] or #[optional]",
1147                    ));
1148                }
1149            }
1150
1151            // Parse the field
1152            let arg_name: syn::Ident = content.parse()?;
1153            content.parse::<syn::Token![:]>()?;
1154            let arg_type: syn::Type = content.parse()?;
1155
1156            if content.peek(syn::Token![,]) {
1157                content.parse::<syn::Token![,]>()?;
1158            }
1159
1160            // Determine if optional (default to required if no attribute specified)
1161            let is_optional_field = if is_required && is_optional {
1162                return Err(syn::Error::new(
1163                    arg_name.span(),
1164                    "Field cannot be both #[required] and #[optional]",
1165                ));
1166            } else if is_optional {
1167                true
1168            } else {
1169                false // Default to required
1170            };
1171
1172            if is_optional_field {
1173                optional_args.push((arg_name, arg_type));
1174            } else {
1175                required_args.push((arg_name, arg_type));
1176            }
1177        }
1178
1179        // The '=>' is outside the braces
1180        input.parse::<syn::Token![=>]>()?;
1181        let body: syn::Block = input.parse()?;
1182
1183        Ok(CommandInput {
1184            name,
1185            required_args,
1186            optional_args,
1187            body,
1188        })
1189    }
1190}