supp_macro/lib.rs
1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, ItemFn, Lit, Path, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn local_db_sqlx_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(item as ItemFn);
8 let fn_name = &input.sig.ident;
9 let block = &input.block;
10
11 let expanded = quote! {
12 #[sqlx::test(migrations = "../migrations")]
13 async fn #fn_name(pool: PgPool) -> Result<(), anyhow::Error> {
14 setup().await;
15 DB_POOL.set(&pool);
16 #block
17 Ok(())
18 }
19 };
20
21 TokenStream::from(expanded)
22}
23
24#[proc_macro_derive(Builder, attributes(builder))]
25pub fn builder_macro(input: TokenStream) -> TokenStream {
26 // Parse the input tokens into a syntax tree
27 let input = parse_macro_input!(input as DeriveInput);
28 let name = &input.ident;
29 let generics = &input.generics; // Capture generics (including lifetimes)
30 let builder_name = syn::Ident::new(&format!("{name}Builder"), name.span());
31
32 // Check for custom error_kind attribute
33 let mut error_kind = None;
34
35 // Parse attributes
36 for attr in &input.attrs {
37 if attr.path().is_ident("builder") {
38 attr.parse_nested_meta(|meta| {
39 if meta.path.is_ident("error_kind")
40 && let Ok(Lit::Str(lit_str)) = meta.value()?.parse()
41 {
42 error_kind = Some(lit_str.parse::<Path>().unwrap());
43 }
44 Ok(())
45 })
46 .unwrap();
47 }
48 }
49
50 // Set a default error kind if none is provided
51 let error_kind = error_kind.expect(
52 "Error kind (e.g., FinanceError) must be specified with #[builder(error_kind = \"...\")]",
53 );
54
55 // Define a custom error type based on the struct name, e.g., CommodityError for Commodity
56 let custom_error_name = syn::Ident::new(&format!("{name}Error"), name.span());
57
58 let fields = if let Data::Struct(data) = &input.data {
59 if let Fields::Named(fields) = &data.fields {
60 fields.named.iter().collect::<Vec<_>>()
61 } else {
62 panic!("Builder macro only supports structs with named fields");
63 }
64 } else {
65 panic!("Builder macro only supports structs");
66 };
67
68 // Generate builder struct fields with the same generics (including lifetimes)
69 let builder_fields = fields.iter().map(|field| {
70 let field_name = &field.ident;
71 let field_ty = &field.ty;
72 let builder_field_type = quote! { Option<#field_ty> };
73 quote! {
74 #field_name: #builder_field_type
75 }
76 });
77
78 // Generate initialization in new()
79 let builder_fields_init = fields.iter().map(|field| {
80 let field_name = &field.ident;
81 quote! {
82 #field_name: None
83 }
84 });
85
86 // Generate setter methods
87 let setters = fields.iter().map(|field| {
88 let field_name = &field.ident;
89 let field_type = &field.ty;
90
91 if is_option_type(field_type) {
92 let inner_type = get_inner_type(field_type);
93 if is_string_type(&inner_type) {
94 // For Option<String>, accept &str
95 quote! {
96 pub fn #field_name(&mut self, value: &str) -> &mut Self {
97 self.#field_name = Some(Some(value.to_string()));
98 self
99 }
100 }
101 } else {
102 // For Option<T>, accept T directly
103 quote! {
104 pub fn #field_name(&mut self, value: #inner_type) -> &mut Self {
105 self.#field_name = Some(Some(value));
106 self
107 }
108 }
109 }
110 } else if is_string_type(field_type) {
111 // For String, accept &str
112 quote! {
113 pub fn #field_name(&mut self, value: &str) -> &mut Self {
114 self.#field_name = Some(value.to_string());
115 self
116 }
117 }
118 } else {
119 // For non-Option<T> and non-String fields, accept T directly
120 quote! {
121 pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
122 self.#field_name = Some(value);
123 self
124 }
125 }
126 }
127 });
128
129 // Generate code to check for missing required fields
130 let check_required_fields = fields
131 .iter()
132 .filter(|field| !is_option_type(&field.ty))
133 .map(|field| {
134 let field_name = &field.ident;
135 let field_name_str = field_name.as_ref().unwrap().to_string();
136 quote! {
137 if self.#field_name.is_none() {
138 missing_fields.push(#field_name_str);
139 }
140 }
141 });
142
143 // Generate build_fields
144 let build_fields = fields.iter().map(|field| {
145 let field_name = &field.ident;
146 if is_option_type(&field.ty) {
147 quote! {
148 #field_name: self.#field_name.clone().unwrap_or(None)
149 }
150 } else {
151 quote! {
152 #field_name: self.#field_name.clone().unwrap()
153 }
154 }
155 });
156
157 // Extract the lifetime parameters from generics for use in the builder struct
158 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
159
160 let expanded = quote! {
161 pub struct #builder_name #impl_generics #where_clause {
162 #(#builder_fields),*
163 }
164
165 impl #impl_generics #builder_name #ty_generics #where_clause {
166 pub fn new() -> Self {
167 Self {
168 #(#builder_fields_init),*
169 }
170 }
171
172 #(#setters)*
173
174 pub fn build(&self) -> Result<#name #ty_generics, #error_kind> {
175 let mut missing_fields = Vec::new();
176 #(#check_required_fields)*
177
178 if !missing_fields.is_empty() {
179 return Err(#error_kind::from(#custom_error_name::Build(format!(
180 "{} fields are missing: {}",
181 stringify!(#name),
182 missing_fields.join(", ")
183 ))));
184 }
185
186 Ok(#name {
187 #(#build_fields),*
188 })
189 }
190 }
191
192 impl #impl_generics #name #ty_generics #where_clause {
193 pub fn builder() -> #builder_name #ty_generics {
194 #builder_name::new()
195 }
196 }
197 };
198
199 TokenStream::from(expanded)
200}
201
202/// Helper function to determine if a type is an `Option<T>`
203fn is_option_type(ty: &syn::Type) -> bool {
204 matches!(ty, syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.iter().any(|segment| segment.ident == "Option"))
205}
206
207/// Helper function to get the inner type of an `Option<T>`
208fn get_inner_type(ty: &syn::Type) -> syn::Type {
209 if let syn::Type::Path(type_path) = ty
210 && let Some(segment) = type_path.path.segments.first()
211 && segment.ident == "Option"
212 && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
213 && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
214 {
215 return inner_type.clone();
216 }
217 ty.clone()
218}
219
220/// Helper function to check if the type is String
221fn is_string_type(ty: &syn::Type) -> bool {
222 if let syn::Type::Path(type_path) = ty
223 && let Some(segment) = type_path.path.segments.last()
224 {
225 return segment.ident == "String";
226 }
227 false
228}
229
230/// A procedural macro for generating typed Command implementations with compile-time validation.
231///
/// This macro provides pure value-based argument passing with compile-time type safety by generating:
/// - A unit struct for the command whose `new()` starts a type-state builder chain
/// - Typed "runner" structs that hold every argument by value (no `HashMap` usage)
/// - Individual typed variables available directly in command scope
/// - Compile-time validation of argument types and required/optional fields
/// - Zero runtime argument parsing or validation overhead
238///
239/// # Syntax
240///
241/// ```ignore
242/// command! {
243/// CommandName {
244/// #[required]
245/// arg_name: Type,
246/// #[optional]
247/// opt_name: Type,
248/// } => {
249/// // Command implementation body
250/// // Individual typed variables are available in scope
251/// }
252/// }
253/// ```
254///
255/// # Generated Code
256///
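/// The macro generates:
/// - A `CommandName` unit struct with a `new()` method that returns the initial runner
/// - One `CommandNameRunner<bits>` struct per combination of already-set required fields
///   (required fields as `Type`, optional fields as `Option<Type>`), i.e. `2^n` runner
///   types for `n` required arguments
/// - A by-value setter per argument; setting a required field returns the runner type
///   that has the corresponding bit set
/// - An async `run(self)` method, present only on the runner with every required field set,
///   which binds each argument to a local variable of the same name before the command body
/// - Pure compile-time type validation with no runtime overhead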
262///
263/// # Examples
264///
265/// ## Simple command with no arguments
266///
267/// ```rust
268/// # use supp_macro::command;
269/// # use async_trait::async_trait;
270/// #
271/// # #[derive(Debug)]
272/// # pub enum CmdError {
273/// # Args(String),
274/// # }
275/// #
276/// # impl std::fmt::Display for CmdError {
277/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278/// # match self {
279/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
280/// # }
281/// # }
282/// # }
283/// #
284/// # impl std::error::Error for CmdError {}
285/// #
286/// # #[derive(Debug)]
287/// # pub enum CmdResult {
288/// # String(String),
289/// # }
290/// #
291/// # #[derive(Debug, Default)]
292/// # pub struct CommandArgs {}
293/// # impl CommandArgs { pub fn new() -> Self { Self::default() } }
294/// #
295/// # #[async_trait]
296/// # pub trait Command: std::fmt::Debug {
297/// # type Args;
298/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
299/// # }
300///
301/// command! {
302/// GetVersion {
303/// } => {
304/// Ok(Some(CmdResult::String("1.0.0".to_string())))
305/// }
306/// }
307///
308/// # #[tokio::main]
309/// # async fn main() {
310/// let result = GetVersion::new().run().await.unwrap();
311/// # }
312/// ```
313///
314/// ## Command with required arguments (server-compatible types)
315///
316/// ```rust
317/// # use supp_macro::command;
318/// # use async_trait::async_trait;
319/// # use uuid::Uuid;
320/// # use num_rational::Rational64;
321/// #
322/// # #[derive(Debug, Clone)]
323/// # pub enum Argument {
324/// # String(String),
325/// # Uuid(Uuid),
326/// # Rational(Rational64),
327/// # }
328/// #
329/// # #[derive(Debug)]
330/// # pub enum CmdError {
331/// # Args(String),
332/// # }
333/// #
334/// # impl std::fmt::Display for CmdError {
335/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336/// # match self {
337/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
338/// # }
339/// # }
340/// # }
341/// #
342/// # impl std::error::Error for CmdError {}
343/// #
344/// # #[derive(Debug)]
345/// # pub enum CmdResult {
346/// # String(String),
347/// # }
348/// #
349/// # #[derive(Debug, Default)]
350/// # pub struct CommandArgs {
351/// # pub symbol: Option<String>,
352/// # pub name: Option<String>,
353/// # pub user_id: Option<uuid::Uuid>,
354/// # }
355/// # impl CommandArgs {
356/// # pub fn new() -> Self { Self::default() }
357/// # pub fn symbol(mut self, v: String) -> Self { self.symbol = Some(v); self }
358/// # pub fn name(mut self, v: String) -> Self { self.name = Some(v); self }
359/// # pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
360/// # }
361/// #
362/// # #[async_trait]
363/// # pub trait Command: std::fmt::Debug {
364/// # type Args;
365/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
366/// # }
367///
368/// // This creates a commodity in the financial system
369/// command! {
370/// CreateCommodity {
371/// #[required]
372/// symbol: String,
373/// #[required]
374/// name: String,
375/// #[required]
376/// user_id: Uuid,
377/// } => {
378/// // Individual typed variables are automatically available
379/// Ok(Some(CmdResult::String(format!(
380/// "Created commodity {} ({}) for user {}",
381/// name, symbol, user_id
382/// ))))
383/// }
384/// }
385///
386/// # #[tokio::main]
387/// # async fn main() {
388/// let result = CreateCommodity::new()
389/// .symbol("USD".to_string())
390/// .name("US Dollar".to_string())
391/// .user_id(uuid::Uuid::new_v4())
392/// .run()
393/// .await
394/// .unwrap();
395/// # }
396/// ```
397///
398/// ## Command with optional arguments
399///
400/// ```rust
401/// # use supp_macro::command;
402/// # use async_trait::async_trait;
403/// # use uuid::Uuid;
404/// #
405/// # #[derive(Debug, Clone)]
406/// # pub enum Argument {
407/// # String(String),
408/// # Uuid(Uuid),
409/// # }
410/// #
411/// # #[derive(Debug)]
412/// # pub enum CmdError {
413/// # Args(String),
414/// # }
415/// #
416/// # impl std::fmt::Display for CmdError {
417/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418/// # match self {
419/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
420/// # }
421/// # }
422/// # }
423/// #
424/// # impl std::error::Error for CmdError {}
425/// #
426/// # #[derive(Debug)]
427/// # pub enum CmdResult {
428/// # String(String),
429/// # }
430/// #
431/// # #[derive(Debug, Default)]
432/// # pub struct CommandArgs {
433/// # pub user_id: Option<uuid::Uuid>,
434/// # pub account: Option<String>,
435/// # }
436/// # impl CommandArgs {
437/// # pub fn new() -> Self { Self::default() }
438/// # pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
439/// # pub fn account(mut self, v: String) -> Self { self.account = Some(v); self }
440/// # }
441/// #
442/// # #[async_trait]
443/// # pub trait Command: std::fmt::Debug {
444/// # type Args;
445/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
446/// # }
447///
448/// command! {
449/// ListTransactions {
450/// #[required]
451/// user_id: Uuid,
452/// #[optional]
453/// account: String,
454/// } => {
455/// let filter = if let Some(account) = account {
456/// format!(" for account {}", account)
457/// } else {
458/// String::new()
459/// };
460/// Ok(Some(CmdResult::String(format!("Listing transactions for user {}{}", user_id, filter))))
461/// }
462/// }
463/// ```
464///
465/// ## Command with mixed required and optional arguments
466///
467/// ```rust
468/// # use supp_macro::command;
469/// # use async_trait::async_trait;
470/// #
471/// # #[derive(Debug, Clone)]
472/// # pub enum Argument {
473/// # String(String),
474/// # Integer(i64),
475/// # Boolean(bool),
476/// # }
477/// #
478/// # #[derive(Debug)]
479/// # pub enum CmdError {
480/// # Args(String),
481/// # }
482/// #
483/// # impl std::fmt::Display for CmdError {
484/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485/// # match self {
486/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
487/// # }
488/// # }
489/// # }
490/// #
491/// # impl std::error::Error for CmdError {}
492/// #
493/// # #[derive(Debug)]
494/// # pub enum CmdResult {
495/// # Success(String),
496/// # }
497/// #
498/// # impl TryFrom<Argument> for String {
499/// # type Error = CmdError;
500/// # fn try_from(arg: Argument) -> Result<Self, Self::Error> {
501/// # match arg {
502/// # Argument::String(s) => Ok(s),
503/// # _ => Err(CmdError::Args(format!("Cannot convert {:?} to String", arg))),
504/// # }
505/// # }
506/// # }
507/// #
508/// # impl TryFrom<Argument> for i64 {
509/// # type Error = CmdError;
510/// # fn try_from(arg: Argument) -> Result<Self, Self::Error> {
511/// # match arg {
512/// # Argument::Integer(i) => Ok(i),
513/// # _ => Err(CmdError::Args(format!("Cannot convert {:?} to i64", arg))),
514/// # }
515/// # }
516/// # }
517/// #
518/// # impl TryFrom<Argument> for bool {
519/// # type Error = CmdError;
520/// # fn try_from(arg: Argument) -> Result<Self, Self::Error> {
521/// # match arg {
522/// # Argument::Boolean(b) => Ok(b),
523/// # _ => Err(CmdError::Args(format!("Cannot convert {:?} to bool", arg))),
524/// # }
525/// # }
526/// # }
527/// #
528/// # #[async_trait]
529/// # pub trait TypedCommand {
530/// # type Args;
531/// # async fn run_typed(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
532/// # }
533/// #
534/// # #[derive(Debug, Default)]
535/// # pub struct CommandArgs {
536/// # pub user_id: Option<i64>,
537/// # pub username: Option<String>,
538/// # pub email: Option<String>,
539/// # pub is_admin: Option<bool>,
540/// # }
541/// # impl CommandArgs {
542/// # pub fn new() -> Self { Self::default() }
543/// # pub fn user_id(mut self, v: i64) -> Self { self.user_id = Some(v); self }
544/// # pub fn username(mut self, v: String) -> Self { self.username = Some(v); self }
545/// # pub fn email(mut self, v: String) -> Self { self.email = Some(v); self }
546/// # pub fn is_admin(mut self, v: bool) -> Self { self.is_admin = Some(v); self }
547/// # }
548/// #
549/// # #[async_trait]
550/// # pub trait Command {
551/// # type Args;
552/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
553/// # }
554///
555/// command! {
556/// CreateUserCommand {
557/// #[required]
558/// user_id: i64,
559/// #[required]
560/// username: String,
561/// #[optional]
562/// email: String,
563/// #[optional]
564/// is_admin: bool,
565/// } => {
566/// let email_str = email.map_or_else(|| format!("{}@example.com", username), |s| s.to_string());
567/// let admin_status = is_admin.unwrap_or(false);
568///
569/// let message = format!(
570/// "Created user {} (ID: {}, Email: {}, Admin: {})",
571/// username, user_id, email_str, admin_status
572/// );
573/// Ok(Some(CmdResult::Success(message)))
574/// }
575/// }
576///
577/// # #[tokio::main]
578/// # async fn main() {
579/// let result = CreateUserCommand::new()
580/// .user_id(123)
581/// .username("alice".to_string())
582/// .is_admin(true)
583/// .run()
584/// .await
585/// .unwrap();
586/// # }
587/// ```
588///
589/// ## Server-compatible Command implementation
590///
591/// ```rust
592/// # use supp_macro::command;
593/// # use async_trait::async_trait;
594/// #
595/// # #[derive(Debug, Clone)]
596/// # pub enum Argument {
597/// # String(String),
598/// # Integer(i64),
599/// # Boolean(bool),
600/// # }
601/// #
602/// # #[derive(Debug)]
603/// # pub enum CmdError {
604/// # Args(String),
605/// # }
606/// #
607/// # impl std::fmt::Display for CmdError {
608/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609/// # match self {
610/// # CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
611/// # }
612/// # }
613/// # }
614/// #
615/// # impl std::error::Error for CmdError {}
616/// #
617/// # #[derive(Debug)]
618/// # pub enum CmdResult {
619/// # Success(String),
620/// # }
621/// #
622/// # #[derive(Debug, Default)]
623/// # pub struct CommandArgs {
624/// # pub a: Option<i64>,
625/// # pub b: Option<i64>,
626/// # }
627/// # impl CommandArgs {
628/// # pub fn new() -> Self { Self::default() }
629/// # pub fn a(mut self, v: i64) -> Self { self.a = Some(v); self }
630/// # pub fn b(mut self, v: i64) -> Self { self.b = Some(v); self }
631/// # }
632/// #
633/// # #[async_trait]
634/// # pub trait Command {
635/// # type Args;
636/// # async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
637/// # }
638///
639/// command! {
640/// CalculateCommand {
641/// #[required]
642/// a: i64,
643/// #[required]
644/// b: i64,
645/// } => {
646/// let result = a + b;
647/// Ok(Some(CmdResult::Success(format!("{} + {} = {}", a, b, result))))
648/// }
649/// }
650///
651/// # #[tokio::main]
652/// # async fn main() {
653/// let result = CalculateCommand::new()
654/// .a(10)
655/// .b(20)
656/// .run()
657/// .await
658/// .unwrap();
659/// # }
660/// ```
661///
662/// ## Migration from Manual Commands
663///
664/// The macro makes it easy to migrate from manual Command implementations:
665///
666/// ```rust,ignore
667/// // BEFORE: Manual implementation
668/// #[derive(Debug)]
669/// pub struct GetConfig;
670///
671/// #[async_trait]
672/// impl Command for GetConfig {
673/// async fn run<'a>(&self, args: &'a HashMap<&'a str, &'a Argument>) -> Result<Option<CmdResult>, CmdError> {
674/// if let Some(Argument::String(name)) = args.get("name") {
675/// Ok(config(name).await?.map(|v| CmdResult::String(v)))
676/// } else {
677/// Err(CmdError::Args("No field name provided".to_string()))
678/// }
679/// }
680/// }
681///
682/// // AFTER: Using the macro
683/// command! {
684/// GetConfig {
685/// #[required]
686/// name: String,
687/// } => {
688/// Ok(config(name).await?.map(|v| CmdResult::String(v)))
689/// }
690/// }
691/// ```
692///
693/// # Error Handling
694///
695/// The new pure typed system provides compile-time error prevention:
696///
697/// - Missing required arguments are compile-time errors (cannot compile without them)
698/// - Invalid argument types are compile-time errors (type checking at build time)
699/// - Runtime errors only occur in the command body logic itself
700/// - No argument validation overhead at runtime
701///
702/// # Supported Argument Types
703///
704/// The macro supports any Rust type for arguments:
705/// - `String` - Text arguments
706/// - `i64`, `u64`, etc. - Integer arguments
707/// - `bool` - Boolean arguments
708/// - `Rational64` - Rational number arguments (for financial precision)
709/// - `Uuid` - UUID arguments
710/// - `Vec<u8>` - Binary data arguments
711/// - `DateTime<Utc>` - `DateTime` arguments
712/// - Custom types - Any type can be used as an argument
713/// - `Option<T>` - Automatically applied for optional arguments
714///
715
716#[proc_macro]
717pub fn command(input: TokenStream) -> TokenStream {
718 let input = parse_macro_input!(input as CommandInput);
719
720 let name = &input.name;
721 let required_args = &input.required_args;
722 let optional_args = &input.optional_args;
723 let body = &input.body;
724
725 // Generate progressive runner types for all combinations of required fields
726 let runner_types = generate_progressive_runner_types(name, required_args, optional_args, body);
727
728 // Generate the main command struct
729 let command_struct = quote! {
730 #[derive(Debug)]
731 pub struct #name;
732 };
733
734 // Generate the new() method that starts the builder chain
735 let new_method = generate_new_method(name, required_args.len(), optional_args);
736
737 let expanded = quote! {
738 #command_struct
739
740 #runner_types
741
742 #new_method
743 };
744
745 TokenStream::from(expanded)
746}
747
748/// Generate all possible runner type combinations for required fields
749fn generate_progressive_runner_types(
750 command_name: &syn::Ident,
751 required_args: &[(syn::Ident, syn::Type)],
752 optional_args: &[(syn::Ident, syn::Type)],
753 body: &syn::Block,
754) -> proc_macro2::TokenStream {
755 let num_required = required_args.len();
756 let total_combinations = 1 << num_required; // 2^num_required
757
758 let mut runner_types = Vec::new();
759
760 // Generate a runner type for each possible combination of set required fields
761 for combination in 0..total_combinations {
762 let runner_type = generate_single_runner_type(
763 command_name,
764 required_args,
765 optional_args,
766 combination,
767 num_required,
768 body,
769 );
770 runner_types.push(runner_type);
771 }
772
773 quote! {
774 #(#runner_types)*
775 }
776}
777
778/// Generate a single runner type for a specific combination of set fields
779fn generate_single_runner_type(
780 command_name: &syn::Ident,
781 required_args: &[(syn::Ident, syn::Type)],
782 optional_args: &[(syn::Ident, syn::Type)],
783 combination: usize,
784 num_required: usize,
785 body: &syn::Block,
786) -> proc_macro2::TokenStream {
787 // Create binary representation for the runner type name
788 let binary_suffix = format!("{:0width$b}", combination, width = num_required.max(1));
789 let runner_name = syn::Ident::new(
790 &format!("{command_name}Runner{binary_suffix}"),
791 command_name.span(),
792 );
793
794 // Determine which required fields are set in this combination
795 let mut struct_fields = Vec::new();
796 for (i, (field_name, field_type)) in required_args.iter().enumerate() {
797 if (combination >> i) & 1 == 1 {
798 // This required field is set in this combination
799 struct_fields.push(quote! {
800 pub #field_name: #field_type
801 });
802 }
803 }
804
805 // Always include optional fields in all runner types
806 for (field_name, field_type) in optional_args {
807 struct_fields.push(quote! {
808 pub #field_name: Option<#field_type>
809 });
810 }
811
812 // Generate the struct definition
813 let struct_def = if struct_fields.is_empty() {
814 quote! {
815 #[derive(Debug)]
816 pub struct #runner_name;
817 }
818 } else {
819 quote! {
820 #[derive(Debug)]
821 pub struct #runner_name {
822 #(#struct_fields),*
823 }
824 }
825 };
826
827 // Generate transition methods for this runner type
828 let transition_methods = generate_transition_methods(
829 command_name,
830 required_args,
831 optional_args,
832 combination,
833 num_required,
834 );
835
836 // Generate run method if this is the complete state (all required fields set)
837 let complete_mask = (1 << num_required) - 1;
838 let run_method = if combination == complete_mask {
839 generate_run_method(command_name, required_args, optional_args, body)
840 } else {
841 quote! {}
842 };
843
844 quote! {
845 #struct_def
846
847 impl #runner_name {
848 #transition_methods
849 #run_method
850 }
851 }
852}
853
854/// Generate transition methods for a runner type (field setters)
855fn generate_transition_methods(
856 command_name: &syn::Ident,
857 required_args: &[(syn::Ident, syn::Type)],
858 optional_args: &[(syn::Ident, syn::Type)],
859 current_combination: usize,
860 num_required: usize,
861) -> proc_macro2::TokenStream {
862 let mut methods = Vec::new();
863
864 // Generate setter methods for required fields not yet set
865 for (i, (field_name, field_type)) in required_args.iter().enumerate() {
866 if (current_combination >> i) & 1 == 0 {
867 // This required field is not set yet, generate a setter
868 let new_combination = current_combination | (1 << i);
869 let binary_suffix =
870 format!("{:0width$b}", new_combination, width = num_required.max(1));
871 let target_runner = syn::Ident::new(
872 &format!("{command_name}Runner{binary_suffix}"),
873 command_name.span(),
874 );
875
876 let method = generate_field_setter_method(
877 command_name,
878 required_args,
879 optional_args,
880 field_name,
881 field_type,
882 current_combination,
883 new_combination,
884 &target_runner,
885 num_required,
886 );
887 methods.push(method);
888 }
889 }
890
891 // Generate setter methods for optional fields (available on all runner types)
892 for (field_name, field_type) in optional_args {
893 let current_runner = syn::Ident::new(
894 &format!(
895 "{}Runner{:0width$b}",
896 command_name,
897 current_combination,
898 width = num_required.max(1)
899 ),
900 command_name.span(),
901 );
902
903 let method = generate_optional_field_setter(
904 field_name,
905 field_type,
906 ¤t_runner,
907 required_args,
908 optional_args,
909 current_combination,
910 num_required,
911 );
912 methods.push(method);
913 }
914
915 quote! {
916 #(#methods)*
917 }
918}
919
920/// Generate a setter method for a required field
921fn generate_field_setter_method(
922 _command_name: &syn::Ident,
923 required_args: &[(syn::Ident, syn::Type)],
924 optional_args: &[(syn::Ident, syn::Type)],
925 field_name: &syn::Ident,
926 field_type: &syn::Type,
927 current_combination: usize,
928 _new_combination: usize,
929 target_runner: &syn::Ident,
930 _num_required: usize,
931) -> proc_macro2::TokenStream {
932 // Generate field assignments for the new state
933 let mut field_assignments = Vec::new();
934
935 // Handle required fields
936 for (i, (req_field_name, _)) in required_args.iter().enumerate() {
937 if req_field_name == field_name {
938 // This is the field being set
939 field_assignments.push(quote! {
940 #req_field_name: value
941 });
942 } else if (current_combination >> i) & 1 == 1 {
943 // This field was already set, move it from self
944 field_assignments.push(quote! {
945 #req_field_name: self.#req_field_name
946 });
947 }
948 // Fields not set in either state are omitted
949 }
950
951 // Handle optional fields (always present, move from self)
952 for (opt_field_name, _) in optional_args {
953 field_assignments.push(quote! {
954 #opt_field_name: self.#opt_field_name
955 });
956 }
957
958 // Generate the constructor call
959 let constructor = if field_assignments.is_empty() {
960 quote! { #target_runner }
961 } else {
962 quote! {
963 #target_runner {
964 #(#field_assignments),*
965 }
966 }
967 };
968
969 quote! {
970 pub fn #field_name(self, value: #field_type) -> #target_runner {
971 #constructor
972 }
973 }
974}
975
976/// Generate a setter method for an optional field
977fn generate_optional_field_setter(
978 field_name: &syn::Ident,
979 field_type: &syn::Type,
980 current_runner: &syn::Ident,
981 required_args: &[(syn::Ident, syn::Type)],
982 optional_args: &[(syn::Ident, syn::Type)],
983 current_combination: usize,
984 _num_required: usize,
985) -> proc_macro2::TokenStream {
986 // Generate field assignments (same state, but update the optional field)
987 let mut field_assignments = Vec::new();
988
989 // Handle required fields (move from self if set)
990 for (i, (req_field_name, _)) in required_args.iter().enumerate() {
991 if (current_combination >> i) & 1 == 1 {
992 field_assignments.push(quote! {
993 #req_field_name: self.#req_field_name
994 });
995 }
996 }
997
998 // Handle optional fields
999 for (opt_field_name, _) in optional_args {
1000 if opt_field_name == field_name {
1001 // This is the field being set
1002 field_assignments.push(quote! {
1003 #opt_field_name: Some(value)
1004 });
1005 } else {
1006 // Move other optional fields from self
1007 field_assignments.push(quote! {
1008 #opt_field_name: self.#opt_field_name
1009 });
1010 }
1011 }
1012
1013 let constructor = if field_assignments.is_empty() {
1014 quote! { #current_runner }
1015 } else {
1016 quote! {
1017 #current_runner {
1018 #(#field_assignments),*
1019 }
1020 }
1021 };
1022
1023 quote! {
1024 pub fn #field_name(self, value: #field_type) -> #current_runner {
1025 #constructor
1026 }
1027 }
1028}
1029
1030/// Generate the run method for the complete runner state
1031fn generate_run_method(
1032 _command_name: &syn::Ident,
1033 required_args: &[(syn::Ident, syn::Type)],
1034 optional_args: &[(syn::Ident, syn::Type)],
1035 body: &syn::Block,
1036) -> proc_macro2::TokenStream {
1037 // Extract field values directly (no unwrap needed!)
1038 let mut variable_assignments = Vec::new();
1039
1040 // Required fields - direct field access
1041 for (field_name, _) in required_args {
1042 variable_assignments.push(quote! {
1043 let #field_name = self.#field_name;
1044 });
1045 }
1046
1047 // Optional fields - direct field access
1048 for (field_name, _) in optional_args {
1049 variable_assignments.push(quote! {
1050 let #field_name = self.#field_name;
1051 });
1052 }
1053
1054 quote! {
1055 pub async fn run(self) -> Result<Option<CmdResult>, CmdError> {
1056 // Zero runtime checks - direct field access!
1057 #(#variable_assignments)*
1058
1059 // Original command body
1060 #body
1061 }
1062 }
1063}
1064
1065/// Generate the `new()` method for the command
1066fn generate_new_method(
1067 command_name: &syn::Ident,
1068 num_required: usize,
1069 optional_args: &[(syn::Ident, syn::Type)],
1070) -> proc_macro2::TokenStream {
1071 let initial_runner = syn::Ident::new(
1072 &format!(
1073 "{}Runner{:0width$b}",
1074 command_name,
1075 0,
1076 width = num_required.max(1)
1077 ),
1078 command_name.span(),
1079 );
1080
1081 // Initial state has no required fields set, but has optional fields as None
1082 let constructor = if optional_args.is_empty() && num_required > 0 {
1083 // Unit struct (no fields at all in initial state)
1084 quote! { #initial_runner }
1085 } else {
1086 // Struct with optional fields initialized to None
1087 let optional_field_inits = optional_args.iter().map(|(field_name, _)| {
1088 quote! { #field_name: None }
1089 });
1090
1091 if optional_field_inits.len() > 0 {
1092 quote! {
1093 #initial_runner {
1094 #(#optional_field_inits),*
1095 }
1096 }
1097 } else {
1098 quote! { #initial_runner }
1099 }
1100 };
1101
1102 quote! {
1103 impl #command_name {
1104 pub fn new() -> #initial_runner {
1105 #constructor
1106 }
1107 }
1108 }
1109}
1110
/// Parsed form of a `command!` invocation: `Name { fields } => { body }`.
struct CommandInput {
    /// Identifier of the command, read before the braced field list.
    name: syn::Ident,
    /// `(name, type)` pairs of the fields marked `#[required]` or left unmarked.
    required_args: Vec<(syn::Ident, syn::Type)>,
    /// `(name, type)` pairs of the fields marked `#[optional]`.
    optional_args: Vec<(syn::Ident, syn::Type)>,
    /// Block that follows the `=>` token.
    body: syn::Block,
}
1117
1118impl syn::parse::Parse for CommandInput {
1119 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1120 let name: syn::Ident = input.parse()?;
1121
1122 let content;
1123 syn::braced!(content in input);
1124
1125 let mut required_args = Vec::new();
1126 let mut optional_args = Vec::new();
1127
1128 while !content.is_empty() {
1129 // Parse attributes
1130 let mut is_optional = false;
1131 let mut is_required = false;
1132
1133 while content.peek(syn::Token![#]) {
1134 content.parse::<syn::Token![#]>()?;
1135 let attr_content;
1136 syn::bracketed!(attr_content in content);
1137 let attr_name: syn::Ident = attr_content.parse()?;
1138
1139 if attr_name == "optional" {
1140 is_optional = true;
1141 } else if attr_name == "required" {
1142 is_required = true;
1143 } else {
1144 return Err(syn::Error::new(
1145 attr_name.span(),
1146 "Unknown attribute. Use #[required] or #[optional]",
1147 ));
1148 }
1149 }
1150
1151 // Parse the field
1152 let arg_name: syn::Ident = content.parse()?;
1153 content.parse::<syn::Token![:]>()?;
1154 let arg_type: syn::Type = content.parse()?;
1155
1156 if content.peek(syn::Token![,]) {
1157 content.parse::<syn::Token![,]>()?;
1158 }
1159
1160 // Determine if optional (default to required if no attribute specified)
1161 let is_optional_field = if is_required && is_optional {
1162 return Err(syn::Error::new(
1163 arg_name.span(),
1164 "Field cannot be both #[required] and #[optional]",
1165 ));
1166 } else if is_optional {
1167 true
1168 } else {
1169 false // Default to required
1170 };
1171
1172 if is_optional_field {
1173 optional_args.push((arg_name, arg_type));
1174 } else {
1175 required_args.push((arg_name, arg_type));
1176 }
1177 }
1178
1179 // The '=>' is outside the braces
1180 input.parse::<syn::Token![=>]>()?;
1181 let body: syn::Block = input.parse()?;
1182
1183 Ok(CommandInput {
1184 name,
1185 required_args,
1186 optional_args,
1187 body,
1188 })
1189 }
1190}