supp_macro/lib.rs
1use proc_macro::TokenStream;
2use quote::quote;
3use syn::Lit;
4use syn::{Data, DeriveInput, Fields, ItemFn, Path, parse_macro_input};
5
6#[proc_macro_attribute]
7pub fn local_db_sqlx_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
8 let input = parse_macro_input!(item as ItemFn);
9 let fn_name = &input.sig.ident;
10 let block = &input.block;
11
12 let expanded = quote! {
13 #[sqlx::test(migrations = "../migrations")]
14 async fn #fn_name(pool: PgPool) -> Result<(), anyhow::Error> {
15 setup().await;
16 DB_POOL.set(&pool);
17 #block
18 Ok(())
19 }
20 };
21
22 TokenStream::from(expanded)
23}
24
25#[proc_macro_derive(Builder, attributes(builder))]
26pub fn builder_macro(input: TokenStream) -> TokenStream {
27 // Parse the input tokens into a syntax tree
28 let input = parse_macro_input!(input as DeriveInput);
29 let name = &input.ident;
30 let generics = &input.generics; // Capture generics (including lifetimes)
31 let builder_name = syn::Ident::new(&format!("{name}Builder"), name.span());
32
33 // Check for custom error_kind attribute
34 let mut error_kind = None;
35
36 // Parse attributes
37 for attr in &input.attrs {
38 if attr.path().is_ident("builder") {
39 attr.parse_nested_meta(|meta| {
40 if meta.path.is_ident("error_kind")
41 && let Ok(Lit::Str(lit_str)) = meta.value()?.parse()
42 {
43 error_kind = Some(lit_str.parse::<Path>().unwrap());
44 }
45 Ok(())
46 })
47 .unwrap();
48 }
49 }
50
51 // Set a default error kind if none is provided
52 let error_kind = error_kind.expect(
53 "Error kind (e.g., FinanceError) must be specified with #[builder(error_kind = \"...\")]",
54 );
55
56 // Define a custom error type based on the struct name, e.g., CommodityError for Commodity
57 let custom_error_name = syn::Ident::new(&format!("{name}Error"), name.span());
58
59 let fields = if let Data::Struct(data) = &input.data {
60 if let Fields::Named(fields) = &data.fields {
61 fields.named.iter().collect::<Vec<_>>()
62 } else {
63 panic!("Builder macro only supports structs with named fields");
64 }
65 } else {
66 panic!("Builder macro only supports structs");
67 };
68
69 // Generate builder struct fields with the same generics (including lifetimes)
70 let builder_fields = fields.iter().map(|field| {
71 let field_name = &field.ident;
72 let field_ty = &field.ty;
73 let builder_field_type = quote! { Option<#field_ty> };
74 quote! {
75 #field_name: #builder_field_type
76 }
77 });
78
79 // Generate initialization in new()
80 let builder_fields_init = fields.iter().map(|field| {
81 let field_name = &field.ident;
82 quote! {
83 #field_name: None
84 }
85 });
86
87 // Generate setter methods
88 let setters = fields.iter().map(|field| {
89 let field_name = &field.ident;
90 let field_type = &field.ty;
91
92 if is_option_type(field_type) {
93 let inner_type = get_inner_type(field_type);
94 if is_string_type(&inner_type) {
95 // For Option<String>, accept &str
96 quote! {
97 pub fn #field_name(&mut self, value: &str) -> &mut Self {
98 self.#field_name = Some(Some(value.to_string()));
99 self
100 }
101 }
102 } else {
103 // For Option<T>, accept T directly
104 quote! {
105 pub fn #field_name(&mut self, value: #inner_type) -> &mut Self {
106 self.#field_name = Some(Some(value));
107 self
108 }
109 }
110 }
111 } else if is_string_type(field_type) {
112 // For String, accept &str
113 quote! {
114 pub fn #field_name(&mut self, value: &str) -> &mut Self {
115 self.#field_name = Some(value.to_string());
116 self
117 }
118 }
119 } else {
120 // For non-Option<T> and non-String fields, accept T directly
121 quote! {
122 pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
123 self.#field_name = Some(value);
124 self
125 }
126 }
127 }
128 });
129
130 // Generate code to check for missing required fields
131 let check_required_fields = fields
132 .iter()
133 .filter(|field| !is_option_type(&field.ty))
134 .map(|field| {
135 let field_name = &field.ident;
136 let field_name_str = field_name.as_ref().unwrap().to_string();
137 quote! {
138 if self.#field_name.is_none() {
139 missing_fields.push(#field_name_str);
140 }
141 }
142 });
143
144 // Generate build_fields
145 let build_fields = fields.iter().map(|field| {
146 let field_name = &field.ident;
147 if is_option_type(&field.ty) {
148 quote! {
149 #field_name: self.#field_name.clone().unwrap_or(None)
150 }
151 } else {
152 quote! {
153 #field_name: self.#field_name.clone().unwrap()
154 }
155 }
156 });
157
158 // Extract the lifetime parameters from generics for use in the builder struct
159 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
160
161 let expanded = quote! {
162 pub struct #builder_name #impl_generics #where_clause {
163 #(#builder_fields),*
164 }
165
166 impl #impl_generics #builder_name #ty_generics #where_clause {
167 pub fn new() -> Self {
168 Self {
169 #(#builder_fields_init),*
170 }
171 }
172
173 #(#setters)*
174
175 pub fn build(&self) -> Result<#name #ty_generics, #error_kind> {
176 let mut missing_fields = Vec::new();
177 #(#check_required_fields)*
178
179 if !missing_fields.is_empty() {
180 return Err(#error_kind::from(#custom_error_name::Build(format!(
181 "{} fields are missing: {}",
182 stringify!(#name),
183 missing_fields.join(", ")
184 ))));
185 }
186
187 Ok(#name {
188 #(#build_fields),*
189 })
190 }
191 }
192
193 impl #impl_generics #name #ty_generics #where_clause {
194 pub fn builder() -> #builder_name #ty_generics {
195 #builder_name::new()
196 }
197 }
198 };
199
200 TokenStream::from(expanded)
201}
202
203/// Helper function to determine if a type is an `Option<T>`
204fn is_option_type(ty: &syn::Type) -> bool {
205 matches!(ty, syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.iter().any(|segment| segment.ident == "Option"))
206}
207
208/// Helper function to get the inner type of an `Option<T>`
209fn get_inner_type(ty: &syn::Type) -> syn::Type {
210 if let syn::Type::Path(type_path) = ty
211 && let Some(segment) = type_path.path.segments.first()
212 && segment.ident == "Option"
213 && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
214 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
215 {
216 return inner_type.clone();
217 }
218 ty.clone()
219}
220
221/// Helper function to check if the type is String
222fn is_string_type(ty: &syn::Type) -> bool {
223 if let syn::Type::Path(type_path) = ty
224 && let Some(segment) = type_path.path.segments.last()
225 {
226 return segment.ident == "String";
227 }
228 false
229}
230
/// A procedural macro for generating typed, builder-style command runners with compile-time validation.
///
/// This macro provides pure value-based argument passing with compile-time type safety by generating:
/// - Typestate "runner" structs whose fields hold the arguments set so far, passed by value only
/// - Chained setter methods for each argument (no `HashMap` usage)
/// - Individual typed variables available directly in command scope
/// - Compile-time validation of argument types and required/optional fields
/// - Zero runtime argument parsing or validation overhead
239///
240/// # Syntax
241///
242/// ```ignore
243/// command! {
244/// CommandName {
245/// #[required]
246/// arg_name: Type,
247/// #[optional]
248/// opt_name: Type,
249/// } => {
250/// // Command implementation body
251/// // Individual typed variables are available in scope
252/// }
253/// }
254/// ```
255///
256/// # Generated Code
257///
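/// The macro generates:
/// - A unit struct `CommandName` whose `new()` returns the initial runner (no required arguments set)
/// - One runner struct per subset of the required arguments. Each is named `CommandNameRunner`
///   followed by a binary suffix recording which required arguments are set (the rightmost digit is
///   the first required argument), so N required arguments produce 2^N runner types
/// - A setter per unset required argument, which moves the runner into the next state, and a setter
///   per optional argument (stored as `Option<Type>`), which keeps the current state
/// - A `pub async fn run(self) -> Result<Option<CmdResult>, CmdError>` method, present only on the
///   runner with every required argument set. It binds each argument to a local variable of the
///   same name and then evaluates the command body
/// - No trait implementations. `CmdResult` and `CmdError` are referenced unqualified, so they must
///   be in scope where the macro is invoked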
263///
264/// # Examples
265///
266/// ## Simple command with no arguments
267///
268/// ```rust
269/// # use supp_macro::command;
270/// # use async_trait::async_trait;
271/// #
272/// # #[derive(Debug)]
273/// # pub enum CmdError {
274/// # Args(String),
275/// # }
276/// #
277/// # impl std::fmt::Display for CmdError {
278/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279/// # match self {
280/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
281/// # }
282/// # }
283/// # }
284/// #
285/// # impl std::error::Error for CmdError {}
286/// #
287/// # #[derive(Debug)]
288/// # pub enum CmdResult {
289/// # String(String),
290/// # }
291/// #
292/// # #[derive(Debug, Default)]
293/// # pub struct CommandArgs {}
294/// # impl CommandArgs { pub fn new() -> Self { Self::default() } }
295/// #
296/// # #[async_trait]
297/// # pub trait Command: std::fmt::Debug {
298/// # type Args;
299/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
300/// # }
301///
302/// command! {
303/// GetVersion {
304/// } => {
305/// Ok(Some(CmdResult::String("1.0.0".to_string())))
306/// }
307/// }
308///
309/// # #[tokio::main]
310/// # async fn main() {
311/// let result = GetVersion::new().run().await.unwrap();
312/// # }
313/// ```
314///
315/// ## Command with required arguments (server-compatible types)
316///
317/// ```rust
318/// # use supp_macro::command;
319/// # use async_trait::async_trait;
320/// # use uuid::Uuid;
321/// # use num_rational::Rational64;
322/// #
323/// # #[derive(Debug, Clone)]
324/// # pub enum Argument {
325/// # String(String),
326/// # Uuid(Uuid),
327/// # Rational(Rational64),
328/// # }
329/// #
330/// # #[derive(Debug)]
331/// # pub enum CmdError {
332/// # Args(String),
333/// # }
334/// #
335/// # impl std::fmt::Display for CmdError {
336/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337/// # match self {
338/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
339/// # }
340/// # }
341/// # }
342/// #
343/// # impl std::error::Error for CmdError {}
344/// #
345/// # #[derive(Debug)]
346/// # pub enum CmdResult {
347/// # String(String),
348/// # }
349/// #
350/// # #[derive(Debug, Default)]
351/// # pub struct CommandArgs {
352/// # pub symbol: Option<String>,
353/// # pub name: Option<String>,
354/// # pub user_id: Option<uuid::Uuid>,
355/// # }
356/// # impl CommandArgs {
357/// # pub fn new() -> Self { Self::default() }
358/// # pub fn symbol(mut self, v: String) -> Self { self.symbol = Some(v); self }
359/// # pub fn name(mut self, v: String) -> Self { self.name = Some(v); self }
360/// # pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
361/// # }
362/// #
363/// # #[async_trait]
364/// # pub trait Command: std::fmt::Debug {
365/// # type Args;
366/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
367/// # }
368///
369/// // This creates a commodity in the financial system
370/// command! {
371/// CreateCommodity {
372/// #[required]
373/// symbol: String,
374/// #[required]
375/// name: String,
376/// #[required]
377/// user_id: Uuid,
378/// } => {
379/// // Individual typed variables are automatically available
380/// Ok(Some(CmdResult::String(format!(
381/// "Created commodity {} ({}) for user {}",
382/// name, symbol, user_id
383/// ))))
384/// }
385/// }
386///
387/// # #[tokio::main]
388/// # async fn main() {
389/// let result = CreateCommodity::new()
390/// .symbol("USD".to_string())
391/// .name("US Dollar".to_string())
392/// .user_id(uuid::Uuid::new_v4())
393/// .run()
394/// .await
395/// .unwrap();
396/// # }
397/// ```
398///
399/// ## Command with optional arguments
400///
401/// ```rust
402/// # use supp_macro::command;
403/// # use async_trait::async_trait;
404/// # use uuid::Uuid;
405/// #
406/// # #[derive(Debug, Clone)]
407/// # pub enum Argument {
408/// # String(String),
409/// # Uuid(Uuid),
410/// # }
411/// #
412/// # #[derive(Debug)]
413/// # pub enum CmdError {
414/// # Args(String),
415/// # }
416/// #
417/// # impl std::fmt::Display for CmdError {
418/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419/// # match self {
420/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
421/// # }
422/// # }
423/// # }
424/// #
425/// # impl std::error::Error for CmdError {}
426/// #
427/// # #[derive(Debug)]
428/// # pub enum CmdResult {
429/// # String(String),
430/// # }
431/// #
432/// # #[derive(Debug, Default)]
433/// # pub struct CommandArgs {
434/// # pub user_id: Option<uuid::Uuid>,
435/// # pub account: Option<String>,
436/// # }
437/// # impl CommandArgs {
438/// # pub fn new() -> Self { Self::default() }
439/// # pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
440/// # pub fn account(mut self, v: String) -> Self { self.account = Some(v); self }
441/// # }
442/// #
443/// # #[async_trait]
444/// # pub trait Command: std::fmt::Debug {
445/// # type Args;
446/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
447/// # }
448///
449/// command! {
450/// ListTransactions {
451/// #[required]
452/// user_id: Uuid,
453/// #[optional]
454/// account: String,
455/// } => {
456/// let filter = if let Some(account) = account {
457/// format!(" for account {}", account)
458/// } else {
459/// String::new()
460/// };
461/// Ok(Some(CmdResult::String(format!("Listing transactions for user {}{}", user_id, filter))))
462/// }
463/// }
464/// ```
465///
466/// ## Command with mixed required and optional arguments
467///
468/// ```rust
469/// # use supp_macro::command;
470/// # use async_trait::async_trait;
471/// #
472/// # #[derive(Debug, Clone)]
473/// # pub enum Argument {
474/// # String(String),
475/// # Integer(i64),
476/// # Boolean(bool),
477/// # }
478/// #
479/// # #[derive(Debug)]
480/// # pub enum CmdError {
481/// # Args(String),
482/// # }
483/// #
484/// # impl std::fmt::Display for CmdError {
485/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486/// # match self {
487/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
488/// # }
489/// # }
490/// # }
491/// #
492/// # impl std::error::Error for CmdError {}
493/// #
494/// # #[derive(Debug)]
495/// # pub enum CmdResult {
496/// # Success(String),
497/// # }
498/// #
499/// # impl TryFrom<Argument> for String {
500/// # type Error = CmdError;
501/// # fn try_from(arg: Argument) -> Result<Self, Self::Error> {
502/// # match arg {
503/// # Argument::String(s) => Ok(s),
504/// # _ => Err(CmdError::Args(format!("Cannot convert {:?} to String", arg))),
505/// # }
506/// # }
507/// # }
508/// #
509/// # impl TryFrom<Argument> for i64 {
510/// # type Error = CmdError;
511/// # fn try_from(arg: Argument) -> Result<Self, Self::Error> {
512/// # match arg {
513/// # Argument::Integer(i) => Ok(i),
514/// # _ => Err(CmdError::Args(format!("Cannot convert {:?} to i64", arg))),
515/// # }
516/// # }
517/// # }
518/// #
519/// # impl TryFrom<Argument> for bool {
520/// # type Error = CmdError;
521/// # fn try_from(arg: Argument) -> Result<Self, Self::Error> {
522/// # match arg {
523/// # Argument::Boolean(b) => Ok(b),
524/// # _ => Err(CmdError::Args(format!("Cannot convert {:?} to bool", arg))),
525/// # }
526/// # }
527/// # }
528/// #
529/// # #[async_trait]
530/// # pub trait TypedCommand {
531/// # type Args;
532/// # async fn run_typed(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
533/// # }
534/// #
535/// # #[derive(Debug, Default)]
536/// # pub struct CommandArgs {
537/// # pub user_id: Option<i64>,
538/// # pub username: Option<String>,
539/// # pub email: Option<String>,
540/// # pub is_admin: Option<bool>,
541/// # }
542/// # impl CommandArgs {
543/// # pub fn new() -> Self { Self::default() }
544/// # pub fn user_id(mut self, v: i64) -> Self { self.user_id = Some(v); self }
545/// # pub fn username(mut self, v: String) -> Self { self.username = Some(v); self }
546/// # pub fn email(mut self, v: String) -> Self { self.email = Some(v); self }
547/// # pub fn is_admin(mut self, v: bool) -> Self { self.is_admin = Some(v); self }
548/// # }
549/// #
550/// # #[async_trait]
551/// # pub trait Command {
552/// # type Args;
553/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
554/// # }
555///
556/// command! {
557/// CreateUserCommand {
558/// #[required]
559/// user_id: i64,
560/// #[required]
561/// username: String,
562/// #[optional]
563/// email: String,
564/// #[optional]
565/// is_admin: bool,
566/// } => {
567/// let email_str = email.map_or_else(|| format!("{}@example.com", username), |s| s.to_string());
568/// let admin_status = is_admin.unwrap_or(false);
569///
570/// let message = format!(
571/// "Created user {} (ID: {}, Email: {}, Admin: {})",
572/// username, user_id, email_str, admin_status
573/// );
574/// Ok(Some(CmdResult::Success(message)))
575/// }
576/// }
577///
578/// # #[tokio::main]
579/// # async fn main() {
580/// let result = CreateUserCommand::new()
581/// .user_id(123)
582/// .username("alice".to_string())
583/// .is_admin(true)
584/// .run()
585/// .await
586/// .unwrap();
587/// # }
588/// ```
589///
590/// ## Server-compatible Command implementation
591///
592/// ```rust
593/// # use supp_macro::command;
594/// # use async_trait::async_trait;
595/// #
596/// # #[derive(Debug, Clone)]
597/// # pub enum Argument {
598/// # String(String),
599/// # Integer(i64),
600/// # Boolean(bool),
601/// # }
602/// #
603/// # #[derive(Debug)]
604/// # pub enum CmdError {
605/// # Args(String),
606/// # }
607/// #
608/// # impl std::fmt::Display for CmdError {
609/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
610/// # match self {
611/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
612/// # }
613/// # }
614/// # }
615/// #
616/// # impl std::error::Error for CmdError {}
617/// #
618/// # #[derive(Debug)]
619/// # pub enum CmdResult {
620/// # Success(String),
621/// # }
622/// #
623/// # #[derive(Debug, Default)]
624/// # pub struct CommandArgs {
625/// # pub a: Option<i64>,
626/// # pub b: Option<i64>,
627/// # }
628/// # impl CommandArgs {
629/// # pub fn new() -> Self { Self::default() }
630/// # pub fn a(mut self, v: i64) -> Self { self.a = Some(v); self }
631/// # pub fn b(mut self, v: i64) -> Self { self.b = Some(v); self }
632/// # }
633/// #
634/// # #[async_trait]
635/// # pub trait Command {
636/// # type Args;
637/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
638/// # }
639///
640/// command! {
641/// CalculateCommand {
642/// #[required]
643/// a: i64,
644/// #[required]
645/// b: i64,
646/// } => {
647/// let result = a + b;
648/// Ok(Some(CmdResult::Success(format!("{} + {} = {}", a, b, result))))
649/// }
650/// }
651///
652/// # #[tokio::main]
653/// # async fn main() {
654/// let result = CalculateCommand::new()
655/// .a(10)
656/// .b(20)
657/// .run()
658/// .await
659/// .unwrap();
660/// # }
661/// ```
662///
663/// ## Migration from Manual Commands
664///
665/// The macro makes it easy to migrate from manual Command implementations:
666///
667/// ```rust,ignore
668/// // BEFORE: Manual implementation
669/// #[derive(Debug)]
670/// pub struct GetConfig;
671///
672/// #[async_trait]
673/// impl Command for GetConfig {
674/// async fn run<'a>(&self, args: &'a HashMap<&'a str, &'a Argument>) -> Result<Option<CmdResult>, CmdError> {
675/// if let Some(Argument::String(name)) = args.get("name") {
676/// Ok(config(name).await?.map(|v| CmdResult::String(v)))
677/// } else {
678/// Err(CmdError::Args("No field name provided".to_string()))
679/// }
680/// }
681/// }
682///
683/// // AFTER: Using the macro
684/// command! {
685/// GetConfig {
686/// #[required]
687/// name: String,
688/// } => {
689/// Ok(config(name).await?.map(|v| CmdResult::String(v)))
690/// }
691/// }
692/// ```
693///
694/// # Error Handling
695///
696/// The new pure typed system provides compile-time error prevention:
697///
698/// - Missing required arguments are compile-time errors (cannot compile without them)
699/// - Invalid argument types are compile-time errors (type checking at build time)
700/// - Runtime errors only occur in the command body logic itself
701/// - No argument validation overhead at runtime
702///
703/// # Supported Argument Types
704///
705/// The macro supports any Rust type for arguments:
706/// - `String` - Text arguments
707/// - `i64`, `u64`, etc. - Integer arguments
708/// - `bool` - Boolean arguments
709/// - `Rational64` - Rational number arguments (for financial precision)
710/// - `Uuid` - UUID arguments
711/// - `Vec<u8>` - Binary data arguments
712/// - `DateTime<Utc>` - `DateTime` arguments
713/// - Custom types - Any type can be used as an argument
714/// - `Option<T>` - Automatically applied for optional arguments
715#[proc_macro]
716pub fn command(input: TokenStream) -> TokenStream {
717 let input = parse_macro_input!(input as CommandInput);
718
719 let name = &input.name;
720 let required_args = &input.required_args;
721 let optional_args = &input.optional_args;
722 let body = &input.body;
723
724 // Generate progressive runner types for all combinations of required fields
725 let runner_types = generate_progressive_runner_types(name, required_args, optional_args, body);
726
727 // Generate the main command struct
728 let command_struct = quote! {
729 #[derive(Debug)]
730 pub struct #name;
731 };
732
733 // Generate the new() method that starts the builder chain
734 let new_method = generate_new_method(name, required_args.len(), optional_args);
735
736 let expanded = quote! {
737 #command_struct
738
739 #runner_types
740
741 #new_method
742 };
743
744 TokenStream::from(expanded)
745}
746
747/// Generate all possible runner type combinations for required fields
748fn generate_progressive_runner_types(
749 command_name: &syn::Ident,
750 required_args: &[(syn::Ident, syn::Type)],
751 optional_args: &[(syn::Ident, syn::Type)],
752 body: &syn::Block,
753) -> proc_macro2::TokenStream {
754 let num_required = required_args.len();
755 let total_combinations = 1 << num_required; // 2^num_required
756
757 let mut runner_types = Vec::new();
758
759 // Generate a runner type for each possible combination of set required fields
760 for combination in 0..total_combinations {
761 let runner_type = generate_single_runner_type(
762 command_name,
763 required_args,
764 optional_args,
765 combination,
766 num_required,
767 body,
768 );
769 runner_types.push(runner_type);
770 }
771
772 quote! {
773 #(#runner_types)*
774 }
775}
776
777/// Generate a single runner type for a specific combination of set fields
778fn generate_single_runner_type(
779 command_name: &syn::Ident,
780 required_args: &[(syn::Ident, syn::Type)],
781 optional_args: &[(syn::Ident, syn::Type)],
782 combination: usize,
783 num_required: usize,
784 body: &syn::Block,
785) -> proc_macro2::TokenStream {
786 // Create binary representation for the runner type name
787 let binary_suffix = format!("{:0width$b}", combination, width = num_required.max(1));
788 let runner_name = syn::Ident::new(
789 &format!("{command_name}Runner{binary_suffix}"),
790 command_name.span(),
791 );
792
793 // Determine which required fields are set in this combination
794 let mut struct_fields = Vec::new();
795 for (i, (field_name, field_type)) in required_args.iter().enumerate() {
796 if (combination >> i) & 1 == 1 {
797 // This required field is set in this combination
798 struct_fields.push(quote! {
799 pub #field_name: #field_type
800 });
801 }
802 }
803
804 // Always include optional fields in all runner types
805 for (field_name, field_type) in optional_args {
806 struct_fields.push(quote! {
807 pub #field_name: Option<#field_type>
808 });
809 }
810
811 // Generate the struct definition
812 let struct_def = if struct_fields.is_empty() {
813 quote! {
814 #[derive(Debug)]
815 pub struct #runner_name;
816 }
817 } else {
818 quote! {
819 #[derive(Debug)]
820 pub struct #runner_name {
821 #(#struct_fields),*
822 }
823 }
824 };
825
826 // Generate transition methods for this runner type
827 let transition_methods = generate_transition_methods(
828 command_name,
829 required_args,
830 optional_args,
831 combination,
832 num_required,
833 );
834
835 // Generate run method if this is the complete state (all required fields set)
836 let complete_mask = (1 << num_required) - 1;
837 let run_method = if combination == complete_mask {
838 generate_run_method(command_name, required_args, optional_args, body)
839 } else {
840 quote! {}
841 };
842
843 quote! {
844 #struct_def
845
846 impl #runner_name {
847 #transition_methods
848 #run_method
849 }
850 }
851}
852
853/// Generate transition methods for a runner type (field setters)
854fn generate_transition_methods(
855 command_name: &syn::Ident,
856 required_args: &[(syn::Ident, syn::Type)],
857 optional_args: &[(syn::Ident, syn::Type)],
858 current_combination: usize,
859 num_required: usize,
860) -> proc_macro2::TokenStream {
861 let mut methods = Vec::new();
862
863 // Generate setter methods for required fields not yet set
864 for (i, (field_name, field_type)) in required_args.iter().enumerate() {
865 if (current_combination >> i) & 1 == 0 {
866 // This required field is not set yet, generate a setter
867 let new_combination = current_combination | (1 << i);
868 let binary_suffix =
869 format!("{:0width$b}", new_combination, width = num_required.max(1));
870 let target_runner = syn::Ident::new(
871 &format!("{command_name}Runner{binary_suffix}"),
872 command_name.span(),
873 );
874
875 let method = generate_field_setter_method(
876 required_args,
877 optional_args,
878 field_name,
879 field_type,
880 current_combination,
881 &target_runner,
882 );
883 methods.push(method);
884 }
885 }
886
887 // Generate setter methods for optional fields (available on all runner types)
888 for (field_name, field_type) in optional_args {
889 let current_runner = syn::Ident::new(
890 &format!(
891 "{}Runner{:0width$b}",
892 command_name,
893 current_combination,
894 width = num_required.max(1)
895 ),
896 command_name.span(),
897 );
898
899 let method = generate_optional_field_setter(
900 field_name,
901 field_type,
902 ¤t_runner,
903 required_args,
904 optional_args,
905 current_combination,
906 num_required,
907 );
908 methods.push(method);
909 }
910
911 quote! {
912 #(#methods)*
913 }
914}
915
916/// Generate a setter method for a required field
917fn generate_field_setter_method(
918 required_args: &[(syn::Ident, syn::Type)],
919 optional_args: &[(syn::Ident, syn::Type)],
920 field_name: &syn::Ident,
921 field_type: &syn::Type,
922 current_combination: usize,
923 target_runner: &syn::Ident,
924) -> proc_macro2::TokenStream {
925 // Generate field assignments for the new state
926 let mut field_assignments = Vec::new();
927
928 // Handle required fields
929 for (i, (req_field_name, _)) in required_args.iter().enumerate() {
930 if req_field_name == field_name {
931 // This is the field being set
932 field_assignments.push(quote! {
933 #req_field_name: value
934 });
935 } else if (current_combination >> i) & 1 == 1 {
936 // This field was already set, move it from self
937 field_assignments.push(quote! {
938 #req_field_name: self.#req_field_name
939 });
940 }
941 // Fields not set in either state are omitted
942 }
943
944 // Handle optional fields (always present, move from self)
945 for (opt_field_name, _) in optional_args {
946 field_assignments.push(quote! {
947 #opt_field_name: self.#opt_field_name
948 });
949 }
950
951 // Generate the constructor call
952 let constructor = if field_assignments.is_empty() {
953 quote! { #target_runner }
954 } else {
955 quote! {
956 #target_runner {
957 #(#field_assignments),*
958 }
959 }
960 };
961
962 quote! {
963 pub fn #field_name(self, value: #field_type) -> #target_runner {
964 #constructor
965 }
966 }
967}
968
969/// Generate a setter method for an optional field
970fn generate_optional_field_setter(
971 field_name: &syn::Ident,
972 field_type: &syn::Type,
973 current_runner: &syn::Ident,
974 required_args: &[(syn::Ident, syn::Type)],
975 optional_args: &[(syn::Ident, syn::Type)],
976 current_combination: usize,
977 _num_required: usize,
978) -> proc_macro2::TokenStream {
979 // Generate field assignments (same state, but update the optional field)
980 let mut field_assignments = Vec::new();
981
982 // Handle required fields (move from self if set)
983 for (i, (req_field_name, _)) in required_args.iter().enumerate() {
984 if (current_combination >> i) & 1 == 1 {
985 field_assignments.push(quote! {
986 #req_field_name: self.#req_field_name
987 });
988 }
989 }
990
991 // Handle optional fields
992 for (opt_field_name, _) in optional_args {
993 if opt_field_name == field_name {
994 // This is the field being set
995 field_assignments.push(quote! {
996 #opt_field_name: Some(value)
997 });
998 } else {
999 // Move other optional fields from self
1000 field_assignments.push(quote! {
1001 #opt_field_name: self.#opt_field_name
1002 });
1003 }
1004 }
1005
1006 let constructor = if field_assignments.is_empty() {
1007 quote! { #current_runner }
1008 } else {
1009 quote! {
1010 #current_runner {
1011 #(#field_assignments),*
1012 }
1013 }
1014 };
1015
1016 quote! {
1017 pub fn #field_name(self, value: #field_type) -> #current_runner {
1018 #constructor
1019 }
1020 }
1021}
1022
1023/// Generate the run method for the complete runner state
1024fn generate_run_method(
1025 _command_name: &syn::Ident,
1026 required_args: &[(syn::Ident, syn::Type)],
1027 optional_args: &[(syn::Ident, syn::Type)],
1028 body: &syn::Block,
1029) -> proc_macro2::TokenStream {
1030 // Extract field values directly (no unwrap needed!)
1031 let mut variable_assignments = Vec::new();
1032
1033 // Required fields - direct field access
1034 for (field_name, _) in required_args {
1035 variable_assignments.push(quote! {
1036 let #field_name = self.#field_name;
1037 });
1038 }
1039
1040 // Optional fields - direct field access
1041 for (field_name, _) in optional_args {
1042 variable_assignments.push(quote! {
1043 let #field_name = self.#field_name;
1044 });
1045 }
1046
1047 quote! {
1048 pub async fn run(self) -> Result<Option<CmdResult>, CmdError> {
1049 // Zero runtime checks - direct field access!
1050 #(#variable_assignments)*
1051
1052 // Original command body
1053 #body
1054 }
1055 }
1056}
1057
1058/// Generate the `new()` method for the command
1059fn generate_new_method(
1060 command_name: &syn::Ident,
1061 num_required: usize,
1062 optional_args: &[(syn::Ident, syn::Type)],
1063) -> proc_macro2::TokenStream {
1064 let initial_runner = syn::Ident::new(
1065 &format!(
1066 "{}Runner{:0width$b}",
1067 command_name,
1068 0,
1069 width = num_required.max(1)
1070 ),
1071 command_name.span(),
1072 );
1073
1074 // Initial state has no required fields set, but has optional fields as None
1075 let constructor = if optional_args.is_empty() && num_required > 0 {
1076 // Unit struct (no fields at all in initial state)
1077 quote! { #initial_runner }
1078 } else {
1079 // Struct with optional fields initialized to None
1080 let optional_field_inits = optional_args.iter().map(|(field_name, _)| {
1081 quote! { #field_name: None }
1082 });
1083
1084 if optional_field_inits.len() > 0 {
1085 quote! {
1086 #initial_runner {
1087 #(#optional_field_inits),*
1088 }
1089 }
1090 } else {
1091 quote! { #initial_runner }
1092 }
1093 };
1094
1095 quote! {
1096 impl #command_name {
1097 pub fn new() -> #initial_runner {
1098 #constructor
1099 }
1100 }
1101 }
1102}
1103
1104struct CommandInput {
1105 name: syn::Ident,
1106 required_args: Vec<(syn::Ident, syn::Type)>,
1107 optional_args: Vec<(syn::Ident, syn::Type)>,
1108 body: syn::Block,
1109}
1110
1111impl syn::parse::Parse for CommandInput {
1112 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1113 let name: syn::Ident = input.parse()?;
1114
1115 let content;
1116 syn::braced!(content in input);
1117
1118 let mut required_args = Vec::new();
1119 let mut optional_args = Vec::new();
1120
1121 while !content.is_empty() {
1122 // Parse attributes
1123 let mut is_optional = false;
1124 let mut is_required = false;
1125
1126 while content.peek(syn::Token![#]) {
1127 content.parse::<syn::Token![#]>()?;
1128 let attr_content;
1129 syn::bracketed!(attr_content in content);
1130 let attr_name: syn::Ident = attr_content.parse()?;
1131
1132 if attr_name == "optional" {
1133 is_optional = true;
1134 } else if attr_name == "required" {
1135 is_required = true;
1136 } else {
1137 return Err(syn::Error::new(
1138 attr_name.span(),
1139 "Unknown attribute. Use #[required] or #[optional]",
1140 ));
1141 }
1142 }
1143
1144 // Parse the field
1145 let arg_name: syn::Ident = content.parse()?;
1146 content.parse::<syn::Token![:]>()?;
1147 let arg_type: syn::Type = content.parse()?;
1148
1149 if content.peek(syn::Token![,]) {
1150 content.parse::<syn::Token![,]>()?;
1151 }
1152
1153 // Determine if optional (default to required if no attribute specified)
1154 let is_optional_field = if is_required && is_optional {
1155 return Err(syn::Error::new(
1156 arg_name.span(),
1157 "Field cannot be both #[required] and #[optional]",
1158 ));
1159 } else if is_optional {
1160 true
1161 } else {
1162 false // Default to required
1163 };
1164
1165 if is_optional_field {
1166 optional_args.push((arg_name, arg_type));
1167 } else {
1168 required_args.push((arg_name, arg_type));
1169 }
1170 }
1171
1172 // The '=>' is outside the braces
1173 input.parse::<syn::Token![=>]>()?;
1174 let body: syn::Block = input.parse()?;
1175
1176 Ok(CommandInput {
1177 name,
1178 required_args,
1179 optional_args,
1180 body,
1181 })
1182 }
1183}