Skip to main content

supp_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::Lit;
4use syn::{Data, DeriveInput, Fields, ItemFn, Path, parse_macro_input};
5
6#[proc_macro_attribute]
7pub fn local_db_sqlx_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
8    let input = parse_macro_input!(item as ItemFn);
9    let fn_name = &input.sig.ident;
10    let block = &input.block;
11
12    let expanded = quote! {
13        #[sqlx::test(migrations = "../migrations")]
14        async fn #fn_name(pool: PgPool) -> Result<(), anyhow::Error> {
15            setup().await;
16            DB_POOL.set(&pool);
17            #block
18        Ok(())
19        }
20    };
21
22    TokenStream::from(expanded)
23}
24
25#[proc_macro_derive(Builder, attributes(builder))]
26pub fn builder_macro(input: TokenStream) -> TokenStream {
27    // Parse the input tokens into a syntax tree
28    let input = parse_macro_input!(input as DeriveInput);
29    let name = &input.ident;
30    let generics = &input.generics; // Capture generics (including lifetimes)
31    let builder_name = syn::Ident::new(&format!("{name}Builder"), name.span());
32
33    // Check for custom error_kind attribute
34    let mut error_kind = None;
35
36    // Parse attributes
37    for attr in &input.attrs {
38        if attr.path().is_ident("builder") {
39            attr.parse_nested_meta(|meta| {
40                if meta.path.is_ident("error_kind")
41                    && let Ok(Lit::Str(lit_str)) = meta.value()?.parse()
42                {
43                    error_kind = Some(lit_str.parse::<Path>().unwrap());
44                }
45                Ok(())
46            })
47            .unwrap();
48        }
49    }
50
51    // Set a default error kind if none is provided
52    let error_kind = error_kind.expect(
53        "Error kind (e.g., FinanceError) must be specified with #[builder(error_kind = \"...\")]",
54    );
55
56    // Define a custom error type based on the struct name, e.g., CommodityError for Commodity
57    let custom_error_name = syn::Ident::new(&format!("{name}Error"), name.span());
58
59    let fields = if let Data::Struct(data) = &input.data {
60        if let Fields::Named(fields) = &data.fields {
61            fields.named.iter().collect::<Vec<_>>()
62        } else {
63            panic!("Builder macro only supports structs with named fields");
64        }
65    } else {
66        panic!("Builder macro only supports structs");
67    };
68
69    // Generate builder struct fields with the same generics (including lifetimes)
70    let builder_fields = fields.iter().map(|field| {
71        let field_name = &field.ident;
72        let field_ty = &field.ty;
73        let builder_field_type = quote! { Option<#field_ty> };
74        quote! {
75            #field_name: #builder_field_type
76        }
77    });
78
79    // Generate initialization in new()
80    let builder_fields_init = fields.iter().map(|field| {
81        let field_name = &field.ident;
82        quote! {
83            #field_name: None
84        }
85    });
86
87    // Generate setter methods
88    let setters = fields.iter().map(|field| {
89        let field_name = &field.ident;
90        let field_type = &field.ty;
91
92        if is_option_type(field_type) {
93            let inner_type = get_inner_type(field_type);
94            if is_string_type(&inner_type) {
95                // For Option<String>, accept &str
96                quote! {
97                    pub fn #field_name(&mut self, value: &str) -> &mut Self {
98                        self.#field_name = Some(Some(value.to_string()));
99                        self
100                    }
101                }
102            } else {
103                // For Option<T>, accept T directly
104                quote! {
105                    pub fn #field_name(&mut self, value: #inner_type) -> &mut Self {
106                        self.#field_name = Some(Some(value));
107                        self
108                    }
109                }
110            }
111        } else if is_string_type(field_type) {
112            // For String, accept &str
113            quote! {
114                pub fn #field_name(&mut self, value: &str) -> &mut Self {
115                    self.#field_name = Some(value.to_string());
116                    self
117                }
118            }
119        } else {
120            // For non-Option<T> and non-String fields, accept T directly
121            quote! {
122                pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
123                    self.#field_name = Some(value);
124                    self
125                }
126            }
127        }
128    });
129
130    // Generate code to check for missing required fields
131    let check_required_fields = fields
132        .iter()
133        .filter(|field| !is_option_type(&field.ty))
134        .map(|field| {
135            let field_name = &field.ident;
136            let field_name_str = field_name.as_ref().unwrap().to_string();
137            quote! {
138                if self.#field_name.is_none() {
139                    missing_fields.push(#field_name_str);
140                }
141            }
142        });
143
144    // Generate build_fields
145    let build_fields = fields.iter().map(|field| {
146        let field_name = &field.ident;
147        if is_option_type(&field.ty) {
148            quote! {
149                #field_name: self.#field_name.clone().unwrap_or(None)
150            }
151        } else {
152            quote! {
153                #field_name: self.#field_name.clone().unwrap()
154            }
155        }
156    });
157
158    // Extract the lifetime parameters from generics for use in the builder struct
159    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
160
161    let expanded = quote! {
162        pub struct #builder_name #impl_generics #where_clause {
163            #(#builder_fields),*
164        }
165
166        impl #impl_generics #builder_name #ty_generics #where_clause {
167            pub fn new() -> Self {
168                Self {
169                    #(#builder_fields_init),*
170                }
171            }
172
173            #(#setters)*
174
175            pub fn build(&self) -> Result<#name #ty_generics, #error_kind> {
176                let mut missing_fields = Vec::new();
177                #(#check_required_fields)*
178
179                if !missing_fields.is_empty() {
180                    return Err(#error_kind::from(#custom_error_name::Build(format!(
181                        "{} fields are missing: {}",
182                        stringify!(#name),
183                        missing_fields.join(", ")
184                    ))));
185                }
186
187                Ok(#name {
188                    #(#build_fields),*
189                })
190            }
191        }
192
193        impl #impl_generics #name #ty_generics #where_clause {
194            pub fn builder() -> #builder_name #ty_generics {
195                #builder_name::new()
196            }
197        }
198    };
199
200    TokenStream::from(expanded)
201}
202
203/// Helper function to determine if a type is an `Option<T>`
204fn is_option_type(ty: &syn::Type) -> bool {
205    matches!(ty, syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.iter().any(|segment| segment.ident == "Option"))
206}
207
208/// Helper function to get the inner type of an `Option<T>`
209fn get_inner_type(ty: &syn::Type) -> syn::Type {
210    if let syn::Type::Path(type_path) = ty
211        && let Some(segment) = type_path.path.segments.first()
212        && segment.ident == "Option"
213        && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
214        && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
215    {
216        return inner_type.clone();
217    }
218    ty.clone()
219}
220
221/// Helper function to check if the type is String
222fn is_string_type(ty: &syn::Type) -> bool {
223    if let syn::Type::Path(type_path) = ty
224        && let Some(segment) = type_path.path.segments.last()
225    {
226        return segment.ident == "String";
227    }
228    false
229}
230
231/// A procedural macro for generating typed Command implementations with compile-time validation.
232///
/// This macro provides pure value-based argument passing with compile-time type safety by generating:
/// - Typestate "runner" structs that hold each argument by value with its declared type
/// - A builder chain (`CommandName::new().arg(value)...run()`) instead of a `HashMap` of arguments
/// - Individual typed variables available directly in command scope
/// - Compile-time validation of argument types and required/optional fields
/// - Zero runtime argument parsing or validation overhead
239///
240/// # Syntax
241///
242/// ```ignore
243/// command! {
244///     CommandName {
245///         #[required]
246///         arg_name: Type,
247///         #[optional]
248///         opt_name: Type,
249///     } => {
250///         // Command implementation body
251///         // Individual typed variables are available in scope
252///     }
253/// }
254/// ```
255///
256/// # Generated Code
257///
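/// The macro generates:
/// - A unit struct `CommandName` (deriving `Debug`) and a `CommandName::new()` method that starts the builder chain
/// - One runner struct per subset of the required arguments, named `CommandNameRunner` plus a zero-padded binary
///   suffix recording which required arguments have been set (e.g. `CreateCommodityRunner000` … `CreateCommodityRunner111`);
///   optional arguments are stored as `Option<Type>` in every runner
/// - A setter for each required argument not yet set (moving to the next runner state) and for each optional argument
/// - A `run()` method, available only once every required argument is set, that executes the command body with
///   each argument bound to an individual typed variable (optional ones as `Option<Type>`)
/// - Pure compile-time type validation with no runtime overhead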
263///
264/// # Examples
265///
266/// ## Simple command with no arguments
267///
268/// ```rust
269/// # use supp_macro::command;
270/// # use async_trait::async_trait;
271/// #
272/// # #[derive(Debug)]
273/// # pub enum CmdError {
274/// #     Args(String),
275/// # }
276/// #
277/// # impl std::fmt::Display for CmdError {
278/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279/// #         match self {
280/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
281/// #         }
282/// #     }
283/// # }
284/// #
285/// # impl std::error::Error for CmdError {}
286/// #
287/// # #[derive(Debug)]
288/// # pub enum CmdResult {
289/// #     String(String),
290/// # }
291/// #
292/// # #[derive(Debug, Default)]
293/// # pub struct CommandArgs {}
294/// # impl CommandArgs { pub fn new() -> Self { Self::default() } }
295/// #
296/// # #[async_trait]
297/// # pub trait Command: std::fmt::Debug {
298/// #     type Args;
299/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
300/// # }
301///
302/// command! {
303///     GetVersion {
304///     } => {
305///         Ok(Some(CmdResult::String("1.0.0".to_string())))
306///     }
307/// }
308///
309/// # #[tokio::main]
310/// # async fn main() {
311/// let result = GetVersion::new().run().await.unwrap();
312/// # }
313/// ```
314///
315/// ## Command with required arguments (server-compatible types)
316///
317/// ```rust
318/// # use supp_macro::command;
319/// # use async_trait::async_trait;
320/// # use uuid::Uuid;
321/// # use num_rational::Rational64;
322/// #
323/// # #[derive(Debug, Clone)]
324/// # pub enum Argument {
325/// #     String(String),
326/// #     Uuid(Uuid),
327/// #     Rational(Rational64),
328/// # }
329/// #
330/// # #[derive(Debug)]
331/// # pub enum CmdError {
332/// #     Args(String),
333/// # }
334/// #
335/// # impl std::fmt::Display for CmdError {
336/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337/// #         match self {
338/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
339/// #         }
340/// #     }
341/// # }
342/// #
343/// # impl std::error::Error for CmdError {}
344/// #
345/// # #[derive(Debug)]
346/// # pub enum CmdResult {
347/// #     String(String),
348/// # }
349/// #
350/// # #[derive(Debug, Default)]
351/// # pub struct CommandArgs {
352/// #     pub symbol: Option<String>,
353/// #     pub name: Option<String>,
354/// #     pub user_id: Option<uuid::Uuid>,
355/// # }
356/// # impl CommandArgs {
357/// #     pub fn new() -> Self { Self::default() }
358/// #     pub fn symbol(mut self, v: String) -> Self { self.symbol = Some(v); self }
359/// #     pub fn name(mut self, v: String) -> Self { self.name = Some(v); self }
360/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
361/// # }
362/// #
363/// # #[async_trait]
364/// # pub trait Command: std::fmt::Debug {
365/// #     type Args;
366/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
367/// # }
368///
369/// // This creates a commodity in the financial system
370/// command! {
371///     CreateCommodity {
372///         #[required]
373///         symbol: String,
374///         #[required]
375///         name: String,
376///         #[required]
377///         user_id: Uuid,
378///     } => {
379///         // Individual typed variables are automatically available
380///         Ok(Some(CmdResult::String(format!(
381///             "Created commodity {} ({}) for user {}",
382///             name, symbol, user_id
383///         ))))
384///     }
385/// }
386///
387/// # #[tokio::main]
388/// # async fn main() {
389/// let result = CreateCommodity::new()
390///     .symbol("USD".to_string())
391///     .name("US Dollar".to_string())
392///     .user_id(uuid::Uuid::new_v4())
393///     .run()
394///     .await
395///     .unwrap();
396/// # }
397/// ```
398///
399/// ## Command with optional arguments
400///
401/// ```rust
402/// # use supp_macro::command;
403/// # use async_trait::async_trait;
404/// # use uuid::Uuid;
405/// #
406/// # #[derive(Debug, Clone)]
407/// # pub enum Argument {
408/// #     String(String),
409/// #     Uuid(Uuid),
410/// # }
411/// #
412/// # #[derive(Debug)]
413/// # pub enum CmdError {
414/// #     Args(String),
415/// # }
416/// #
417/// # impl std::fmt::Display for CmdError {
418/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419/// #         match self {
420/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
421/// #         }
422/// #     }
423/// # }
424/// #
425/// # impl std::error::Error for CmdError {}
426/// #
427/// # #[derive(Debug)]
428/// # pub enum CmdResult {
429/// #     String(String),
430/// # }
431/// #
432/// # #[derive(Debug, Default)]
433/// # pub struct CommandArgs {
434/// #     pub user_id: Option<uuid::Uuid>,
435/// #     pub account: Option<String>,
436/// # }
437/// # impl CommandArgs {
438/// #     pub fn new() -> Self { Self::default() }
439/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
440/// #     pub fn account(mut self, v: String) -> Self { self.account = Some(v); self }
441/// # }
442/// #
443/// # #[async_trait]
444/// # pub trait Command: std::fmt::Debug {
445/// #     type Args;
446/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
447/// # }
448///
449/// command! {
450///     ListTransactions {
451///         #[required]
452///         user_id: Uuid,
453///         #[optional]
454///         account: String,
455///     } => {
456///         let filter = if let Some(account) = account {
457///             format!(" for account {}", account)
458///         } else {
459///             String::new()
460///         };
461///         Ok(Some(CmdResult::String(format!("Listing transactions for user {}{}", user_id, filter))))
462///     }
463/// }
464/// ```
465///
466/// ## Command with mixed required and optional arguments
467///
468/// ```rust
469/// # use supp_macro::command;
470/// # use async_trait::async_trait;
471/// #
472/// # #[derive(Debug, Clone)]
473/// # pub enum Argument {
474/// #     String(String),
475/// #     Integer(i64),
476/// #     Boolean(bool),
477/// # }
478/// #
479/// # #[derive(Debug)]
480/// # pub enum CmdError {
481/// #     Args(String),
482/// # }
483/// #
484/// # impl std::fmt::Display for CmdError {
485/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486/// #         match self {
487/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
488/// #         }
489/// #     }
490/// # }
491/// #
492/// # impl std::error::Error for CmdError {}
493/// #
494/// # #[derive(Debug)]
495/// # pub enum CmdResult {
496/// #     Success(String),
497/// # }
498/// #
499/// # impl TryFrom<Argument> for String {
500/// #     type Error = CmdError;
501/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
502/// #         match arg {
503/// #             Argument::String(s) => Ok(s),
504/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to String", arg))),
505/// #         }
506/// #     }
507/// # }
508/// #
509/// # impl TryFrom<Argument> for i64 {
510/// #     type Error = CmdError;
511/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
512/// #         match arg {
513/// #             Argument::Integer(i) => Ok(i),
514/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to i64", arg))),
515/// #         }
516/// #     }
517/// # }
518/// #
519/// # impl TryFrom<Argument> for bool {
520/// #     type Error = CmdError;
521/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
522/// #         match arg {
523/// #             Argument::Boolean(b) => Ok(b),
524/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to bool", arg))),
525/// #         }
526/// #     }
527/// # }
528/// #
529/// # #[async_trait]
530/// # pub trait TypedCommand {
531/// #     type Args;
532/// #     async fn run_typed(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
533/// # }
534/// #
535/// # #[derive(Debug, Default)]
536/// # pub struct CommandArgs {
537/// #     pub user_id: Option<i64>,
538/// #     pub username: Option<String>,
539/// #     pub email: Option<String>,
540/// #     pub is_admin: Option<bool>,
541/// # }
542/// # impl CommandArgs {
543/// #     pub fn new() -> Self { Self::default() }
544/// #     pub fn user_id(mut self, v: i64) -> Self { self.user_id = Some(v); self }
545/// #     pub fn username(mut self, v: String) -> Self { self.username = Some(v); self }
546/// #     pub fn email(mut self, v: String) -> Self { self.email = Some(v); self }
547/// #     pub fn is_admin(mut self, v: bool) -> Self { self.is_admin = Some(v); self }
548/// # }
549/// #
550/// # #[async_trait]
551/// # pub trait Command {
552/// #     type Args;
553/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
554/// # }
555///
556/// command! {
557///     CreateUserCommand {
558///         #[required]
559///         user_id: i64,
560///         #[required]
561///         username: String,
562///         #[optional]
563///         email: String,
564///         #[optional]
565///         is_admin: bool,
566///     } => {
567///         let email_str = email.map_or_else(|| format!("{}@example.com", username), |s| s.to_string());
568///         let admin_status = is_admin.unwrap_or(false);
569///
570///         let message = format!(
571///             "Created user {} (ID: {}, Email: {}, Admin: {})",
572///             username, user_id, email_str, admin_status
573///         );
574///         Ok(Some(CmdResult::Success(message)))
575///     }
576/// }
577///
578/// # #[tokio::main]
579/// # async fn main() {
580/// let result = CreateUserCommand::new()
581///     .user_id(123)
582///     .username("alice".to_string())
583///     .is_admin(true)
584///     .run()
585///     .await
586///     .unwrap();
587/// # }
588/// ```
589///
590/// ## Server-compatible Command implementation
591///
592/// ```rust
593/// # use supp_macro::command;
594/// # use async_trait::async_trait;
595/// #
596/// # #[derive(Debug, Clone)]
597/// # pub enum Argument {
598/// #     String(String),
599/// #     Integer(i64),
600/// #     Boolean(bool),
601/// # }
602/// #
603/// # #[derive(Debug)]
604/// # pub enum CmdError {
605/// #     Args(String),
606/// # }
607/// #
608/// # impl std::fmt::Display for CmdError {
609/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
610/// #         match self {
611/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
612/// #         }
613/// #     }
614/// # }
615/// #
616/// # impl std::error::Error for CmdError {}
617/// #
618/// # #[derive(Debug)]
619/// # pub enum CmdResult {
620/// #     Success(String),
621/// # }
622/// #
623/// # #[derive(Debug, Default)]
624/// # pub struct CommandArgs {
625/// #     pub a: Option<i64>,
626/// #     pub b: Option<i64>,
627/// # }
628/// # impl CommandArgs {
629/// #     pub fn new() -> Self { Self::default() }
630/// #     pub fn a(mut self, v: i64) -> Self { self.a = Some(v); self }
631/// #     pub fn b(mut self, v: i64) -> Self { self.b = Some(v); self }
632/// # }
633/// #
634/// # #[async_trait]
635/// # pub trait Command {
636/// #     type Args;
637/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
638/// # }
639///
640/// command! {
641///     CalculateCommand {
642///         #[required]
643///         a: i64,
644///         #[required]
645///         b: i64,
646///     } => {
647///         let result = a + b;
648///         Ok(Some(CmdResult::Success(format!("{} + {} = {}", a, b, result))))
649///     }
650/// }
651///
652/// # #[tokio::main]
653/// # async fn main() {
654/// let result = CalculateCommand::new()
655///     .a(10)
656///     .b(20)
657///     .run()
658///     .await
659///     .unwrap();
660/// # }
661/// ```
662///
663/// ## Migration from Manual Commands
664///
665/// The macro makes it easy to migrate from manual Command implementations:
666///
667/// ```rust,ignore
668/// // BEFORE: Manual implementation
669/// #[derive(Debug)]
670/// pub struct GetConfig;
671///
672/// #[async_trait]
673/// impl Command for GetConfig {
674///     async fn run<'a>(&self, args: &'a HashMap<&'a str, &'a Argument>) -> Result<Option<CmdResult>, CmdError> {
675///         if let Some(Argument::String(name)) = args.get("name") {
676///             Ok(config(name).await?.map(|v| CmdResult::String(v)))
677///         } else {
678///             Err(CmdError::Args("No field name provided".to_string()))
679///         }
680///     }
681/// }
682///
683/// // AFTER: Using the macro
684/// command! {
685///     GetConfig {
686///         #[required]
687///         name: String,
688///     } => {
689///         Ok(config(name).await?.map(|v| CmdResult::String(v)))
690///     }
691/// }
692/// ```
693///
694/// # Error Handling
695///
696/// The new pure typed system provides compile-time error prevention:
697///
698/// - Missing required arguments are compile-time errors (cannot compile without them)
699/// - Invalid argument types are compile-time errors (type checking at build time)
700/// - Runtime errors only occur in the command body logic itself
701/// - No argument validation overhead at runtime
702///
703/// # Supported Argument Types
704///
705/// The macro supports any Rust type for arguments:
706/// - `String` - Text arguments
707/// - `i64`, `u64`, etc. - Integer arguments
708/// - `bool` - Boolean arguments
709/// - `Rational64` - Rational number arguments (for financial precision)
710/// - `Uuid` - UUID arguments
711/// - `Vec<u8>` - Binary data arguments
712/// - `DateTime<Utc>` - `DateTime` arguments
713/// - Custom types - Any type can be used as an argument
714/// - `Option<T>` - Automatically applied for optional arguments
715#[proc_macro]
716pub fn command(input: TokenStream) -> TokenStream {
717    let input = parse_macro_input!(input as CommandInput);
718
719    let name = &input.name;
720    let required_args = &input.required_args;
721    let optional_args = &input.optional_args;
722    let body = &input.body;
723
724    // Generate progressive runner types for all combinations of required fields
725    let runner_types = generate_progressive_runner_types(name, required_args, optional_args, body);
726
727    // Generate the main command struct
728    let command_struct = quote! {
729        #[derive(Debug)]
730        pub struct #name;
731    };
732
733    // Generate the new() method that starts the builder chain
734    let new_method = generate_new_method(name, required_args.len(), optional_args);
735
736    let expanded = quote! {
737        #command_struct
738
739        #runner_types
740
741        #new_method
742    };
743
744    TokenStream::from(expanded)
745}
746
747/// Generate all possible runner type combinations for required fields
748fn generate_progressive_runner_types(
749    command_name: &syn::Ident,
750    required_args: &[(syn::Ident, syn::Type)],
751    optional_args: &[(syn::Ident, syn::Type)],
752    body: &syn::Block,
753) -> proc_macro2::TokenStream {
754    let num_required = required_args.len();
755    let total_combinations = 1 << num_required; // 2^num_required
756
757    let mut runner_types = Vec::new();
758
759    // Generate a runner type for each possible combination of set required fields
760    for combination in 0..total_combinations {
761        let runner_type = generate_single_runner_type(
762            command_name,
763            required_args,
764            optional_args,
765            combination,
766            num_required,
767            body,
768        );
769        runner_types.push(runner_type);
770    }
771
772    quote! {
773        #(#runner_types)*
774    }
775}
776
777/// Generate a single runner type for a specific combination of set fields
778fn generate_single_runner_type(
779    command_name: &syn::Ident,
780    required_args: &[(syn::Ident, syn::Type)],
781    optional_args: &[(syn::Ident, syn::Type)],
782    combination: usize,
783    num_required: usize,
784    body: &syn::Block,
785) -> proc_macro2::TokenStream {
786    // Create binary representation for the runner type name
787    let binary_suffix = format!("{:0width$b}", combination, width = num_required.max(1));
788    let runner_name = syn::Ident::new(
789        &format!("{command_name}Runner{binary_suffix}"),
790        command_name.span(),
791    );
792
793    // Determine which required fields are set in this combination
794    let mut struct_fields = Vec::new();
795    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
796        if (combination >> i) & 1 == 1 {
797            // This required field is set in this combination
798            struct_fields.push(quote! {
799                pub #field_name: #field_type
800            });
801        }
802    }
803
804    // Always include optional fields in all runner types
805    for (field_name, field_type) in optional_args {
806        struct_fields.push(quote! {
807            pub #field_name: Option<#field_type>
808        });
809    }
810
811    // Generate the struct definition
812    let struct_def = if struct_fields.is_empty() {
813        quote! {
814            #[derive(Debug)]
815            pub struct #runner_name;
816        }
817    } else {
818        quote! {
819            #[derive(Debug)]
820            pub struct #runner_name {
821                #(#struct_fields),*
822            }
823        }
824    };
825
826    // Generate transition methods for this runner type
827    let transition_methods = generate_transition_methods(
828        command_name,
829        required_args,
830        optional_args,
831        combination,
832        num_required,
833    );
834
835    // Generate run method if this is the complete state (all required fields set)
836    let complete_mask = (1 << num_required) - 1;
837    let run_method = if combination == complete_mask {
838        generate_run_method(command_name, required_args, optional_args, body)
839    } else {
840        quote! {}
841    };
842
843    quote! {
844        #struct_def
845
846        impl #runner_name {
847            #transition_methods
848            #run_method
849        }
850    }
851}
852
853/// Generate transition methods for a runner type (field setters)
854fn generate_transition_methods(
855    command_name: &syn::Ident,
856    required_args: &[(syn::Ident, syn::Type)],
857    optional_args: &[(syn::Ident, syn::Type)],
858    current_combination: usize,
859    num_required: usize,
860) -> proc_macro2::TokenStream {
861    let mut methods = Vec::new();
862
863    // Generate setter methods for required fields not yet set
864    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
865        if (current_combination >> i) & 1 == 0 {
866            // This required field is not set yet, generate a setter
867            let new_combination = current_combination | (1 << i);
868            let binary_suffix =
869                format!("{:0width$b}", new_combination, width = num_required.max(1));
870            let target_runner = syn::Ident::new(
871                &format!("{command_name}Runner{binary_suffix}"),
872                command_name.span(),
873            );
874
875            let method = generate_field_setter_method(
876                required_args,
877                optional_args,
878                field_name,
879                field_type,
880                current_combination,
881                &target_runner,
882            );
883            methods.push(method);
884        }
885    }
886
887    // Generate setter methods for optional fields (available on all runner types)
888    for (field_name, field_type) in optional_args {
889        let current_runner = syn::Ident::new(
890            &format!(
891                "{}Runner{:0width$b}",
892                command_name,
893                current_combination,
894                width = num_required.max(1)
895            ),
896            command_name.span(),
897        );
898
899        let method = generate_optional_field_setter(
900            field_name,
901            field_type,
902            &current_runner,
903            required_args,
904            optional_args,
905            current_combination,
906            num_required,
907        );
908        methods.push(method);
909    }
910
911    quote! {
912        #(#methods)*
913    }
914}
915
916/// Generate a setter method for a required field
917fn generate_field_setter_method(
918    required_args: &[(syn::Ident, syn::Type)],
919    optional_args: &[(syn::Ident, syn::Type)],
920    field_name: &syn::Ident,
921    field_type: &syn::Type,
922    current_combination: usize,
923    target_runner: &syn::Ident,
924) -> proc_macro2::TokenStream {
925    // Generate field assignments for the new state
926    let mut field_assignments = Vec::new();
927
928    // Handle required fields
929    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
930        if req_field_name == field_name {
931            // This is the field being set
932            field_assignments.push(quote! {
933                #req_field_name: value
934            });
935        } else if (current_combination >> i) & 1 == 1 {
936            // This field was already set, move it from self
937            field_assignments.push(quote! {
938                #req_field_name: self.#req_field_name
939            });
940        }
941        // Fields not set in either state are omitted
942    }
943
944    // Handle optional fields (always present, move from self)
945    for (opt_field_name, _) in optional_args {
946        field_assignments.push(quote! {
947            #opt_field_name: self.#opt_field_name
948        });
949    }
950
951    // Generate the constructor call
952    let constructor = if field_assignments.is_empty() {
953        quote! { #target_runner }
954    } else {
955        quote! {
956            #target_runner {
957                #(#field_assignments),*
958            }
959        }
960    };
961
962    quote! {
963        pub fn #field_name(self, value: #field_type) -> #target_runner {
964            #constructor
965        }
966    }
967}
968
969/// Generate a setter method for an optional field
970fn generate_optional_field_setter(
971    field_name: &syn::Ident,
972    field_type: &syn::Type,
973    current_runner: &syn::Ident,
974    required_args: &[(syn::Ident, syn::Type)],
975    optional_args: &[(syn::Ident, syn::Type)],
976    current_combination: usize,
977    _num_required: usize,
978) -> proc_macro2::TokenStream {
979    // Generate field assignments (same state, but update the optional field)
980    let mut field_assignments = Vec::new();
981
982    // Handle required fields (move from self if set)
983    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
984        if (current_combination >> i) & 1 == 1 {
985            field_assignments.push(quote! {
986                #req_field_name: self.#req_field_name
987            });
988        }
989    }
990
991    // Handle optional fields
992    for (opt_field_name, _) in optional_args {
993        if opt_field_name == field_name {
994            // This is the field being set
995            field_assignments.push(quote! {
996                #opt_field_name: Some(value)
997            });
998        } else {
999            // Move other optional fields from self
1000            field_assignments.push(quote! {
1001                #opt_field_name: self.#opt_field_name
1002            });
1003        }
1004    }
1005
1006    let constructor = if field_assignments.is_empty() {
1007        quote! { #current_runner }
1008    } else {
1009        quote! {
1010            #current_runner {
1011                #(#field_assignments),*
1012            }
1013        }
1014    };
1015
1016    quote! {
1017        pub fn #field_name(self, value: #field_type) -> #current_runner {
1018            #constructor
1019        }
1020    }
1021}
1022
1023/// Generate the run method for the complete runner state
1024fn generate_run_method(
1025    _command_name: &syn::Ident,
1026    required_args: &[(syn::Ident, syn::Type)],
1027    optional_args: &[(syn::Ident, syn::Type)],
1028    body: &syn::Block,
1029) -> proc_macro2::TokenStream {
1030    // Extract field values directly (no unwrap needed!)
1031    let mut variable_assignments = Vec::new();
1032
1033    // Required fields - direct field access
1034    for (field_name, _) in required_args {
1035        variable_assignments.push(quote! {
1036            let #field_name = self.#field_name;
1037        });
1038    }
1039
1040    // Optional fields - direct field access
1041    for (field_name, _) in optional_args {
1042        variable_assignments.push(quote! {
1043            let #field_name = self.#field_name;
1044        });
1045    }
1046
1047    quote! {
1048        pub async fn run(self) -> Result<Option<CmdResult>, CmdError> {
1049            // Zero runtime checks - direct field access!
1050            #(#variable_assignments)*
1051
1052            // Original command body
1053            #body
1054        }
1055    }
1056}
1057
1058/// Generate the `new()` method for the command
1059fn generate_new_method(
1060    command_name: &syn::Ident,
1061    num_required: usize,
1062    optional_args: &[(syn::Ident, syn::Type)],
1063) -> proc_macro2::TokenStream {
1064    let initial_runner = syn::Ident::new(
1065        &format!(
1066            "{}Runner{:0width$b}",
1067            command_name,
1068            0,
1069            width = num_required.max(1)
1070        ),
1071        command_name.span(),
1072    );
1073
1074    // Initial state has no required fields set, but has optional fields as None
1075    let constructor = if optional_args.is_empty() && num_required > 0 {
1076        // Unit struct (no fields at all in initial state)
1077        quote! { #initial_runner }
1078    } else {
1079        // Struct with optional fields initialized to None
1080        let optional_field_inits = optional_args.iter().map(|(field_name, _)| {
1081            quote! { #field_name: None }
1082        });
1083
1084        if optional_field_inits.len() > 0 {
1085            quote! {
1086                #initial_runner {
1087                    #(#optional_field_inits),*
1088                }
1089            }
1090        } else {
1091            quote! { #initial_runner }
1092        }
1093    };
1094
1095    quote! {
1096        impl #command_name {
1097            pub fn new() -> #initial_runner {
1098                #constructor
1099            }
1100        }
1101    }
1102}
1103
/// Parsed form of a command definition: `Name { fields… } => { body }`.
struct CommandInput {
    /// Identifier that precedes the braced field list.
    name: syn::Ident,
    /// Fields marked `#[required]` or carrying no attribute, in declaration order.
    required_args: Vec<(syn::Ident, syn::Type)>,
    /// Fields marked `#[optional]`, in declaration order.
    optional_args: Vec<(syn::Ident, syn::Type)>,
    /// Block following `=>` (presumably the command body spliced into `run`; verify at call site).
    body: syn::Block,
}
1110
1111impl syn::parse::Parse for CommandInput {
1112    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1113        let name: syn::Ident = input.parse()?;
1114
1115        let content;
1116        syn::braced!(content in input);
1117
1118        let mut required_args = Vec::new();
1119        let mut optional_args = Vec::new();
1120
1121        while !content.is_empty() {
1122            // Parse attributes
1123            let mut is_optional = false;
1124            let mut is_required = false;
1125
1126            while content.peek(syn::Token![#]) {
1127                content.parse::<syn::Token![#]>()?;
1128                let attr_content;
1129                syn::bracketed!(attr_content in content);
1130                let attr_name: syn::Ident = attr_content.parse()?;
1131
1132                if attr_name == "optional" {
1133                    is_optional = true;
1134                } else if attr_name == "required" {
1135                    is_required = true;
1136                } else {
1137                    return Err(syn::Error::new(
1138                        attr_name.span(),
1139                        "Unknown attribute. Use #[required] or #[optional]",
1140                    ));
1141                }
1142            }
1143
1144            // Parse the field
1145            let arg_name: syn::Ident = content.parse()?;
1146            content.parse::<syn::Token![:]>()?;
1147            let arg_type: syn::Type = content.parse()?;
1148
1149            if content.peek(syn::Token![,]) {
1150                content.parse::<syn::Token![,]>()?;
1151            }
1152
1153            // Determine if optional (default to required if no attribute specified)
1154            let is_optional_field = if is_required && is_optional {
1155                return Err(syn::Error::new(
1156                    arg_name.span(),
1157                    "Field cannot be both #[required] and #[optional]",
1158                ));
1159            } else if is_optional {
1160                true
1161            } else {
1162                false // Default to required
1163            };
1164
1165            if is_optional_field {
1166                optional_args.push((arg_name, arg_type));
1167            } else {
1168                required_args.push((arg_name, arg_type));
1169            }
1170        }
1171
1172        // The '=>' is outside the braces
1173        input.parse::<syn::Token![=>]>()?;
1174        let body: syn::Block = input.parse()?;
1175
1176        Ok(CommandInput {
1177            name,
1178            required_args,
1179            optional_args,
1180            body,
1181        })
1182    }
1183}