1
use proc_macro::TokenStream;
2
use quote::quote;
3
use syn::{Data, DeriveInput, Fields, ItemFn, Lit, Path, parse_macro_input};
4

            
5
#[proc_macro_attribute]
6
93
pub fn local_db_sqlx_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
7
93
    let input = parse_macro_input!(item as ItemFn);
8
93
    let fn_name = &input.sig.ident;
9
93
    let block = &input.block;
10

            
11
93
    let expanded = quote! {
12
        #[sqlx::test(migrations = "../migrations")]
13
        async fn #fn_name(pool: PgPool) -> Result<(), anyhow::Error> {
14
            setup().await;
15
            DB_POOL.set(&pool);
16
            #block
17
        Ok(())
18
        }
19
    };
20

            
21
93
    TokenStream::from(expanded)
22
93
}
23

            
24
#[proc_macro_derive(Builder, attributes(builder))]
25
48
pub fn builder_macro(input: TokenStream) -> TokenStream {
26
    // Parse the input tokens into a syntax tree
27
48
    let input = parse_macro_input!(input as DeriveInput);
28
48
    let name = &input.ident;
29
48
    let generics = &input.generics; // Capture generics (including lifetimes)
30
48
    let builder_name = syn::Ident::new(&format!("{name}Builder"), name.span());
31

            
32
    // Check for custom error_kind attribute
33
48
    let mut error_kind = None;
34

            
35
    // Parse attributes
36
144
    for attr in &input.attrs {
37
144
        if attr.path().is_ident("builder") {
38
48
            attr.parse_nested_meta(|meta| {
39
48
                if meta.path.is_ident("error_kind")
40
48
                    && let Ok(Lit::Str(lit_str)) = meta.value()?.parse()
41
48
                {
42
48
                    error_kind = Some(lit_str.parse::<Path>().unwrap());
43
48
                }
44
48
                Ok(())
45
48
            })
46
48
            .unwrap();
47
96
        }
48
    }
49

            
50
    // Set a default error kind if none is provided
51
48
    let error_kind = error_kind.expect(
52
48
        "Error kind (e.g., FinanceError) must be specified with #[builder(error_kind = \"...\")]",
53
    );
54

            
55
    // Define a custom error type based on the struct name, e.g., CommodityError for Commodity
56
48
    let custom_error_name = syn::Ident::new(&format!("{name}Error"), name.span());
57

            
58
48
    let fields = if let Data::Struct(data) = &input.data {
59
48
        if let Fields::Named(fields) = &data.fields {
60
48
            fields.named.iter().collect::<Vec<_>>()
61
        } else {
62
            panic!("Builder macro only supports structs with named fields");
63
        }
64
    } else {
65
        panic!("Builder macro only supports structs");
66
    };
67

            
68
    // Generate builder struct fields with the same generics (including lifetimes)
69
216
    let builder_fields = fields.iter().map(|field| {
70
216
        let field_name = &field.ident;
71
216
        let field_ty = &field.ty;
72
216
        let builder_field_type = quote! { Option<#field_ty> };
73
216
        quote! {
74
            #field_name: #builder_field_type
75
        }
76
216
    });
77

            
78
    // Generate initialization in new()
79
216
    let builder_fields_init = fields.iter().map(|field| {
80
216
        let field_name = &field.ident;
81
216
        quote! {
82
            #field_name: None
83
        }
84
216
    });
85

            
86
    // Generate setter methods
87
216
    let setters = fields.iter().map(|field| {
88
216
        let field_name = &field.ident;
89
216
        let field_type = &field.ty;
90

            
91
216
        if is_option_type(field_type) {
92
56
            let inner_type = get_inner_type(field_type);
93
56
            if is_string_type(&inner_type) {
94
                // For Option<String>, accept &str
95
8
                quote! {
96
                    pub fn #field_name(&mut self, value: &str) -> &mut Self {
97
                        self.#field_name = Some(Some(value.to_string()));
98
                        self
99
                    }
100
                }
101
            } else {
102
                // For Option<T>, accept T directly
103
48
                quote! {
104
                    pub fn #field_name(&mut self, value: #inner_type) -> &mut Self {
105
                        self.#field_name = Some(Some(value));
106
                        self
107
                    }
108
                }
109
            }
110
160
        } else if is_string_type(field_type) {
111
            // For String, accept &str
112
16
            quote! {
113
                pub fn #field_name(&mut self, value: &str) -> &mut Self {
114
                    self.#field_name = Some(value.to_string());
115
                    self
116
                }
117
            }
118
        } else {
119
            // For non-Option<T> and non-String fields, accept T directly
120
144
            quote! {
121
                pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
122
                    self.#field_name = Some(value);
123
                    self
124
                }
125
            }
126
        }
127
216
    });
128

            
129
    // Generate code to check for missing required fields
130
48
    let check_required_fields = fields
131
48
        .iter()
132
216
        .filter(|field| !is_option_type(&field.ty))
133
160
        .map(|field| {
134
160
            let field_name = &field.ident;
135
160
            let field_name_str = field_name.as_ref().unwrap().to_string();
136
160
            quote! {
137
                if self.#field_name.is_none() {
138
                    missing_fields.push(#field_name_str);
139
                }
140
            }
141
160
        });
142

            
143
    // Generate build_fields
144
216
    let build_fields = fields.iter().map(|field| {
145
216
        let field_name = &field.ident;
146
216
        if is_option_type(&field.ty) {
147
56
            quote! {
148
                #field_name: self.#field_name.clone().unwrap_or(None)
149
            }
150
        } else {
151
160
            quote! {
152
                #field_name: self.#field_name.clone().unwrap()
153
            }
154
        }
155
216
    });
156

            
157
    // Extract the lifetime parameters from generics for use in the builder struct
158
48
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
159

            
160
48
    let expanded = quote! {
161
        pub struct #builder_name #impl_generics #where_clause {
162
            #(#builder_fields),*
163
        }
164

            
165
        impl #impl_generics #builder_name #ty_generics #where_clause {
166
            pub fn new() -> Self {
167
                Self {
168
                    #(#builder_fields_init),*
169
                }
170
            }
171

            
172
            #(#setters)*
173

            
174
            pub fn build(&self) -> Result<#name #ty_generics, #error_kind> {
175
                let mut missing_fields = Vec::new();
176
                #(#check_required_fields)*
177

            
178
                if !missing_fields.is_empty() {
179
                    return Err(#error_kind::from(#custom_error_name::Build(format!(
180
                        "{} fields are missing: {}",
181
                        stringify!(#name),
182
                        missing_fields.join(", ")
183
                    ))));
184
                }
185

            
186
                Ok(#name {
187
                    #(#build_fields),*
188
                })
189
            }
190
        }
191

            
192
        impl #impl_generics #name #ty_generics #where_clause {
193
            pub fn builder() -> #builder_name #ty_generics {
194
                #builder_name::new()
195
            }
196
        }
197
    };
198

            
199
48
    TokenStream::from(expanded)
200
48
}
201

            
202
/// Helper function to determine if a type is an `Option<T>`
203
648
fn is_option_type(ty: &syn::Type) -> bool {
204
648
    matches!(ty, syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.iter().any(|segment| segment.ident == "Option"))
205
648
}
206

            
207
/// Helper function to get the inner type of an `Option<T>`
208
56
fn get_inner_type(ty: &syn::Type) -> syn::Type {
209
56
    if let syn::Type::Path(type_path) = ty
210
56
        && let Some(segment) = type_path.path.segments.first()
211
56
        && segment.ident == "Option"
212
56
        && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
213
56
        && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
214
    {
215
56
        return inner_type.clone();
216
    }
217
    ty.clone()
218
56
}
219

            
220
/// Helper function to check if the type is String
221
216
fn is_string_type(ty: &syn::Type) -> bool {
222
216
    if let syn::Type::Path(type_path) = ty
223
216
        && let Some(segment) = type_path.path.segments.last()
224
    {
225
216
        return segment.ident == "String";
226
    }
227
    false
228
216
}
229

            
230
/// A procedural macro for generating typed Command implementations with compile-time validation.
231
///
232
/// This macro provides pure value-based argument passing with compile-time type safety by generating:
233
/// - Typed Args structs with proper field types passed by value only
234
/// - Commands that accept Args structs directly (no `HashMap` usage)
235
/// - Individual typed variables available directly in command scope
236
/// - Compile-time validation of argument types and required/optional fields
237
/// - Zero runtime argument parsing or validation overhead
238
///
239
/// # Syntax
240
///
241
/// ```ignore
242
/// command! {
243
///     CommandName {
244
///         #[required]
245
///         arg_name: Type,
246
///         #[optional]
247
///         opt_name: Type,
248
///     } => {
249
///         // Command implementation body
250
///         // Individual typed variables are available in scope
251
///     }
252
/// }
253
/// ```
254
///
255
/// # Generated Code
256
///
257
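/// The macro generates:
/// - A `CommandNameArgs` struct with typed fields (required fields as `Type`, optional as `Option<Type>`)
/// - A `CommandName` struct implementing `Command` trait with typed `run(args: CommandNameArgs)` method
/// - A `CommandName::new()` constructor that starts a typed builder chain
/// - One `CommandNameRunner<bits>` type per combination of set/unset required fields; each
///   required-field setter returns the runner for the resulting state, and `run()` exists only
///   on the runner in which every required field has been set, so a missing required argument
///   fails to compile
/// - Individual typed variables extracted from the Args struct and available in the command body
/// - Pure compile-time type validation with no runtime overhead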
262
///
263
/// # Examples
264
///
265
/// ## Simple command with no arguments
266
///
267
/// ```rust
268
/// # use supp_macro::command;
269
/// # use async_trait::async_trait;
270
/// #
271
/// # #[derive(Debug)]
272
/// # pub enum CmdError {
273
/// #     Args(String),
274
/// # }
275
/// #
276
/// # impl std::fmt::Display for CmdError {
277
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278
/// #         match self {
279
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
280
/// #         }
281
/// #     }
282
/// # }
283
/// #
284
/// # impl std::error::Error for CmdError {}
285
/// #
286
/// # #[derive(Debug)]
287
/// # pub enum CmdResult {
288
/// #     String(String),
289
/// # }
290
/// #
291
/// # #[derive(Debug, Default)]
292
/// # pub struct CommandArgs {}
293
/// # impl CommandArgs { pub fn new() -> Self { Self::default() } }
294
/// #
295
/// # #[async_trait]
296
/// # pub trait Command: std::fmt::Debug {
297
/// #     type Args;
298
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
299
/// # }
300
///
301
/// command! {
302
///     GetVersion {
303
///     } => {
304
///         Ok(Some(CmdResult::String("1.0.0".to_string())))
305
///     }
306
/// }
307
///
308
/// # #[tokio::main]
309
/// # async fn main() {
310
/// let result = GetVersion::new().run().await.unwrap();
311
/// # }
312
/// ```
313
///
314
/// ## Command with required arguments (server-compatible types)
315
///
316
/// ```rust
317
/// # use supp_macro::command;
318
/// # use async_trait::async_trait;
319
/// # use uuid::Uuid;
320
/// # use num_rational::Rational64;
321
/// #
322
/// # #[derive(Debug, Clone)]
323
/// # pub enum Argument {
324
/// #     String(String),
325
/// #     Uuid(Uuid),
326
/// #     Rational(Rational64),
327
/// # }
328
/// #
329
/// # #[derive(Debug)]
330
/// # pub enum CmdError {
331
/// #     Args(String),
332
/// # }
333
/// #
334
/// # impl std::fmt::Display for CmdError {
335
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336
/// #         match self {
337
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
338
/// #         }
339
/// #     }
340
/// # }
341
/// #
342
/// # impl std::error::Error for CmdError {}
343
/// #
344
/// # #[derive(Debug)]
345
/// # pub enum CmdResult {
346
/// #     String(String),
347
/// # }
348
/// #
349
/// # #[derive(Debug, Default)]
350
/// # pub struct CommandArgs {
351
/// #     pub symbol: Option<String>,
352
/// #     pub name: Option<String>,
353
/// #     pub user_id: Option<uuid::Uuid>,
354
/// # }
355
/// # impl CommandArgs {
356
/// #     pub fn new() -> Self { Self::default() }
357
/// #     pub fn symbol(mut self, v: String) -> Self { self.symbol = Some(v); self }
358
/// #     pub fn name(mut self, v: String) -> Self { self.name = Some(v); self }
359
/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
360
/// # }
361
/// #
362
/// # #[async_trait]
363
/// # pub trait Command: std::fmt::Debug {
364
/// #     type Args;
365
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
366
/// # }
367
///
368
/// // This creates a commodity in the financial system
369
/// command! {
370
///     CreateCommodity {
371
///         #[required]
372
///         symbol: String,
373
///         #[required]
374
///         name: String,
375
///         #[required]
376
///         user_id: Uuid,
377
///     } => {
378
///         // Individual typed variables are automatically available
379
///         Ok(Some(CmdResult::String(format!(
380
///             "Created commodity {} ({}) for user {}",
381
///             name, symbol, user_id
382
///         ))))
383
///     }
384
/// }
385
///
386
/// # #[tokio::main]
387
/// # async fn main() {
388
/// let result = CreateCommodity::new()
389
///     .symbol("USD".to_string())
390
///     .name("US Dollar".to_string())
391
///     .user_id(uuid::Uuid::new_v4())
392
///     .run()
393
///     .await
394
///     .unwrap();
395
/// # }
396
/// ```
397
///
398
/// ## Command with optional arguments
399
///
400
/// ```rust
401
/// # use supp_macro::command;
402
/// # use async_trait::async_trait;
403
/// # use uuid::Uuid;
404
/// #
405
/// # #[derive(Debug, Clone)]
406
/// # pub enum Argument {
407
/// #     String(String),
408
/// #     Uuid(Uuid),
409
/// # }
410
/// #
411
/// # #[derive(Debug)]
412
/// # pub enum CmdError {
413
/// #     Args(String),
414
/// # }
415
/// #
416
/// # impl std::fmt::Display for CmdError {
417
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
418
/// #         match self {
419
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
420
/// #         }
421
/// #     }
422
/// # }
423
/// #
424
/// # impl std::error::Error for CmdError {}
425
/// #
426
/// # #[derive(Debug)]
427
/// # pub enum CmdResult {
428
/// #     String(String),
429
/// # }
430
/// #
431
/// # #[derive(Debug, Default)]
432
/// # pub struct CommandArgs {
433
/// #     pub user_id: Option<uuid::Uuid>,
434
/// #     pub account: Option<String>,
435
/// # }
436
/// # impl CommandArgs {
437
/// #     pub fn new() -> Self { Self::default() }
438
/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
439
/// #     pub fn account(mut self, v: String) -> Self { self.account = Some(v); self }
440
/// # }
441
/// #
442
/// # #[async_trait]
443
/// # pub trait Command: std::fmt::Debug {
444
/// #     type Args;
445
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
446
/// # }
447
///
448
/// command! {
449
///     ListTransactions {
450
///         #[required]
451
///         user_id: Uuid,
452
///         #[optional]
453
///         account: String,
454
///     } => {
455
///         let filter = if let Some(account) = account {
456
///             format!(" for account {}", account)
457
///         } else {
458
///             String::new()
459
///         };
460
///         Ok(Some(CmdResult::String(format!("Listing transactions for user {}{}", user_id, filter))))
461
///     }
462
/// }
463
/// ```
464
///
465
/// ## Command with mixed required and optional arguments
466
///
467
/// ```rust
468
/// # use supp_macro::command;
469
/// # use async_trait::async_trait;
470
/// #
471
/// # #[derive(Debug, Clone)]
472
/// # pub enum Argument {
473
/// #     String(String),
474
/// #     Integer(i64),
475
/// #     Boolean(bool),
476
/// # }
477
/// #
478
/// # #[derive(Debug)]
479
/// # pub enum CmdError {
480
/// #     Args(String),
481
/// # }
482
/// #
483
/// # impl std::fmt::Display for CmdError {
484
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485
/// #         match self {
486
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
487
/// #         }
488
/// #     }
489
/// # }
490
/// #
491
/// # impl std::error::Error for CmdError {}
492
/// #
493
/// # #[derive(Debug)]
494
/// # pub enum CmdResult {
495
/// #     Success(String),
496
/// # }
497
/// #
498
/// # impl TryFrom<Argument> for String {
499
/// #     type Error = CmdError;
500
/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
501
/// #         match arg {
502
/// #             Argument::String(s) => Ok(s),
503
/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to String", arg))),
504
/// #         }
505
/// #     }
506
/// # }
507
/// #
508
/// # impl TryFrom<Argument> for i64 {
509
/// #     type Error = CmdError;
510
/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
511
/// #         match arg {
512
/// #             Argument::Integer(i) => Ok(i),
513
/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to i64", arg))),
514
/// #         }
515
/// #     }
516
/// # }
517
/// #
518
/// # impl TryFrom<Argument> for bool {
519
/// #     type Error = CmdError;
520
/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
521
/// #         match arg {
522
/// #             Argument::Boolean(b) => Ok(b),
523
/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to bool", arg))),
524
/// #         }
525
/// #     }
526
/// # }
527
/// #
528
/// # #[async_trait]
529
/// # pub trait TypedCommand {
530
/// #     type Args;
531
/// #     async fn run_typed(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
532
/// # }
533
/// #
534
/// # #[derive(Debug, Default)]
535
/// # pub struct CommandArgs {
536
/// #     pub user_id: Option<i64>,
537
/// #     pub username: Option<String>,
538
/// #     pub email: Option<String>,
539
/// #     pub is_admin: Option<bool>,
540
/// # }
541
/// # impl CommandArgs {
542
/// #     pub fn new() -> Self { Self::default() }
543
/// #     pub fn user_id(mut self, v: i64) -> Self { self.user_id = Some(v); self }
544
/// #     pub fn username(mut self, v: String) -> Self { self.username = Some(v); self }
545
/// #     pub fn email(mut self, v: String) -> Self { self.email = Some(v); self }
546
/// #     pub fn is_admin(mut self, v: bool) -> Self { self.is_admin = Some(v); self }
547
/// # }
548
/// #
549
/// # #[async_trait]
550
/// # pub trait Command {
551
/// #     type Args;
552
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
553
/// # }
554
///
555
/// command! {
556
///     CreateUserCommand {
557
///         #[required]
558
///         user_id: i64,
559
///         #[required]
560
///         username: String,
561
///         #[optional]
562
///         email: String,
563
///         #[optional]
564
///         is_admin: bool,
565
///     } => {
566
///         let email_str = email.map_or_else(|| format!("{}@example.com", username), |s| s.to_string());
567
///         let admin_status = is_admin.unwrap_or(false);
568
///
569
///         let message = format!(
570
///             "Created user {} (ID: {}, Email: {}, Admin: {})",
571
///             username, user_id, email_str, admin_status
572
///         );
573
///         Ok(Some(CmdResult::Success(message)))
574
///     }
575
/// }
576
///
577
/// # #[tokio::main]
578
/// # async fn main() {
579
/// let result = CreateUserCommand::new()
580
///     .user_id(123)
581
///     .username("alice".to_string())
582
///     .is_admin(true)
583
///     .run()
584
///     .await
585
///     .unwrap();
586
/// # }
587
/// ```
588
///
589
/// ## Server-compatible Command implementation
590
///
591
/// ```rust
592
/// # use supp_macro::command;
593
/// # use async_trait::async_trait;
594
/// #
595
/// # #[derive(Debug, Clone)]
596
/// # pub enum Argument {
597
/// #     String(String),
598
/// #     Integer(i64),
599
/// #     Boolean(bool),
600
/// # }
601
/// #
602
/// # #[derive(Debug)]
603
/// # pub enum CmdError {
604
/// #     Args(String),
605
/// # }
606
/// #
607
/// # impl std::fmt::Display for CmdError {
608
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609
/// #         match self {
610
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
611
/// #         }
612
/// #     }
613
/// # }
614
/// #
615
/// # impl std::error::Error for CmdError {}
616
/// #
617
/// # #[derive(Debug)]
618
/// # pub enum CmdResult {
619
/// #     Success(String),
620
/// # }
621
/// #
622
/// # #[derive(Debug, Default)]
623
/// # pub struct CommandArgs {
624
/// #     pub a: Option<i64>,
625
/// #     pub b: Option<i64>,
626
/// # }
627
/// # impl CommandArgs {
628
/// #     pub fn new() -> Self { Self::default() }
629
/// #     pub fn a(mut self, v: i64) -> Self { self.a = Some(v); self }
630
/// #     pub fn b(mut self, v: i64) -> Self { self.b = Some(v); self }
631
/// # }
632
/// #
633
/// # #[async_trait]
634
/// # pub trait Command {
635
/// #     type Args;
636
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
637
/// # }
638
///
639
/// command! {
640
///     CalculateCommand {
641
///         #[required]
642
///         a: i64,
643
///         #[required]
644
///         b: i64,
645
///     } => {
646
///         let result = a + b;
647
///         Ok(Some(CmdResult::Success(format!("{} + {} = {}", a, b, result))))
648
///     }
649
/// }
650
///
651
/// # #[tokio::main]
652
/// # async fn main() {
653
/// let result = CalculateCommand::new()
654
///     .a(10)
655
///     .b(20)
656
///     .run()
657
///     .await
658
///     .unwrap();
659
/// # }
660
/// ```
661
///
662
/// ## Migration from Manual Commands
663
///
664
/// The macro makes it easy to migrate from manual Command implementations:
665
///
666
/// ```rust,ignore
667
/// // BEFORE: Manual implementation
668
/// #[derive(Debug)]
669
/// pub struct GetConfig;
670
///
671
/// #[async_trait]
672
/// impl Command for GetConfig {
673
///     async fn run<'a>(&self, args: &'a HashMap<&'a str, &'a Argument>) -> Result<Option<CmdResult>, CmdError> {
674
///         if let Some(Argument::String(name)) = args.get("name") {
675
///             Ok(config(name).await?.map(|v| CmdResult::String(v)))
676
///         } else {
677
///             Err(CmdError::Args("No field name provided".to_string()))
678
///         }
679
///     }
680
/// }
681
///
682
/// // AFTER: Using the macro
683
/// command! {
684
///     GetConfig {
685
///         #[required]
686
///         name: String,
687
///     } => {
688
///         Ok(config(name).await?.map(|v| CmdResult::String(v)))
689
///     }
690
/// }
691
/// ```
692
///
693
/// # Error Handling
694
///
695
/// The new pure typed system provides compile-time error prevention:
696
///
697
/// - Missing required arguments are compile-time errors (cannot compile without them)
698
/// - Invalid argument types are compile-time errors (type checking at build time)
699
/// - Runtime errors only occur in the command body logic itself
700
/// - No argument validation overhead at runtime
701
///
702
/// # Supported Argument Types
703
///
704
/// The macro supports any Rust type for arguments:
705
/// - `String` - Text arguments
706
/// - `i64`, `u64`, etc. - Integer arguments
707
/// - `bool` - Boolean arguments
708
/// - `Rational64` - Rational number arguments (for financial precision)
709
/// - `Uuid` - UUID arguments
710
/// - `Vec<u8>` - Binary data arguments
711
/// - `DateTime<Utc>` - `DateTime` arguments
712
/// - Custom types - Any type can be used as an argument
713
/// - `Option<T>` - Automatically applied for optional arguments
714
///
715

            
716
#[proc_macro]
717
222
pub fn command(input: TokenStream) -> TokenStream {
718
222
    let input = parse_macro_input!(input as CommandInput);
719

            
720
222
    let name = &input.name;
721
222
    let required_args = &input.required_args;
722
222
    let optional_args = &input.optional_args;
723
222
    let body = &input.body;
724

            
725
    // Generate progressive runner types for all combinations of required fields
726
222
    let runner_types = generate_progressive_runner_types(name, required_args, optional_args, body);
727

            
728
    // Generate the main command struct
729
222
    let command_struct = quote! {
730
        #[derive(Debug)]
731
        pub struct #name;
732
    };
733

            
734
    // Generate the new() method that starts the builder chain
735
222
    let new_method = generate_new_method(name, required_args.len(), optional_args);
736

            
737
222
    let expanded = quote! {
738
        #command_struct
739

            
740
        #runner_types
741

            
742
        #new_method
743
    };
744

            
745
222
    TokenStream::from(expanded)
746
222
}
747

            
748
/// Generate all possible runner type combinations for required fields
749
222
fn generate_progressive_runner_types(
750
222
    command_name: &syn::Ident,
751
222
    required_args: &[(syn::Ident, syn::Type)],
752
222
    optional_args: &[(syn::Ident, syn::Type)],
753
222
    body: &syn::Block,
754
222
) -> proc_macro2::TokenStream {
755
222
    let num_required = required_args.len();
756
222
    let total_combinations = 1 << num_required; // 2^num_required
757

            
758
222
    let mut runner_types = Vec::new();
759

            
760
    // Generate a runner type for each possible combination of set required fields
761
1103
    for combination in 0..total_combinations {
762
1103
        let runner_type = generate_single_runner_type(
763
1103
            command_name,
764
1103
            required_args,
765
1103
            optional_args,
766
1103
            combination,
767
1103
            num_required,
768
1103
            body,
769
1103
        );
770
1103
        runner_types.push(runner_type);
771
1103
    }
772

            
773
222
    quote! {
774
        #(#runner_types)*
775
    }
776
222
}
777

            
778
/// Generate a single runner type for a specific combination of set fields
779
1103
fn generate_single_runner_type(
780
1103
    command_name: &syn::Ident,
781
1103
    required_args: &[(syn::Ident, syn::Type)],
782
1103
    optional_args: &[(syn::Ident, syn::Type)],
783
1103
    combination: usize,
784
1103
    num_required: usize,
785
1103
    body: &syn::Block,
786
1103
) -> proc_macro2::TokenStream {
787
    // Create binary representation for the runner type name
788
1103
    let binary_suffix = format!("{:0width$b}", combination, width = num_required.max(1));
789
1103
    let runner_name = syn::Ident::new(
790
1103
        &format!("{command_name}Runner{binary_suffix}"),
791
1103
        command_name.span(),
792
    );
793

            
794
    // Determine which required fields are set in this combination
795
1103
    let mut struct_fields = Vec::new();
796
3250
    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
797
3250
        if (combination >> i) & 1 == 1 {
798
1625
            // This required field is set in this combination
799
1625
            struct_fields.push(quote! {
800
1625
                pub #field_name: #field_type
801
1625
            });
802
1625
        }
803
    }
804

            
805
    // Always include optional fields in all runner types
806
1458
    for (field_name, field_type) in optional_args {
807
1458
        struct_fields.push(quote! {
808
1458
            pub #field_name: Option<#field_type>
809
1458
        });
810
1458
    }
811

            
812
    // Generate the struct definition
813
1103
    let struct_def = if struct_fields.is_empty() {
814
125
        quote! {
815
            #[derive(Debug)]
816
            pub struct #runner_name;
817
        }
818
    } else {
819
978
        quote! {
820
            #[derive(Debug)]
821
            pub struct #runner_name {
822
                #(#struct_fields),*
823
            }
824
        }
825
    };
826

            
827
    // Generate transition methods for this runner type
828
1103
    let transition_methods = generate_transition_methods(
829
1103
        command_name,
830
1103
        required_args,
831
1103
        optional_args,
832
1103
        combination,
833
1103
        num_required,
834
    );
835

            
836
    // Generate run method if this is the complete state (all required fields set)
837
1103
    let complete_mask = (1 << num_required) - 1;
838
1103
    let run_method = if combination == complete_mask {
839
222
        generate_run_method(command_name, required_args, optional_args, body)
840
    } else {
841
881
        quote! {}
842
    };
843

            
844
1103
    quote! {
845
        #struct_def
846

            
847
        impl #runner_name {
848
            #transition_methods
849
            #run_method
850
        }
851
    }
852
1103
}
853

            
854
/// Generate transition methods for a runner type (field setters)
855
1103
fn generate_transition_methods(
856
1103
    command_name: &syn::Ident,
857
1103
    required_args: &[(syn::Ident, syn::Type)],
858
1103
    optional_args: &[(syn::Ident, syn::Type)],
859
1103
    current_combination: usize,
860
1103
    num_required: usize,
861
1103
) -> proc_macro2::TokenStream {
862
1103
    let mut methods = Vec::new();
863

            
864
    // Generate setter methods for required fields not yet set
865
3250
    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
866
3250
        if (current_combination >> i) & 1 == 0 {
867
1625
            // This required field is not set yet, generate a setter
868
1625
            let new_combination = current_combination | (1 << i);
869
1625
            let binary_suffix =
870
1625
                format!("{:0width$b}", new_combination, width = num_required.max(1));
871
1625
            let target_runner = syn::Ident::new(
872
1625
                &format!("{command_name}Runner{binary_suffix}"),
873
1625
                command_name.span(),
874
1625
            );
875
1625

            
876
1625
            let method = generate_field_setter_method(
877
1625
                command_name,
878
1625
                required_args,
879
1625
                optional_args,
880
1625
                field_name,
881
1625
                field_type,
882
1625
                current_combination,
883
1625
                new_combination,
884
1625
                &target_runner,
885
1625
                num_required,
886
1625
            );
887
1625
            methods.push(method);
888
1625
        }
889
    }
890

            
891
    // Generate setter methods for optional fields (available on all runner types)
892
1458
    for (field_name, field_type) in optional_args {
893
1458
        let current_runner = syn::Ident::new(
894
1458
            &format!(
895
1458
                "{}Runner{:0width$b}",
896
1458
                command_name,
897
1458
                current_combination,
898
1458
                width = num_required.max(1)
899
1458
            ),
900
1458
            command_name.span(),
901
1458
        );
902
1458

            
903
1458
        let method = generate_optional_field_setter(
904
1458
            field_name,
905
1458
            field_type,
906
1458
            &current_runner,
907
1458
            required_args,
908
1458
            optional_args,
909
1458
            current_combination,
910
1458
            num_required,
911
1458
        );
912
1458
        methods.push(method);
913
1458
    }
914

            
915
1103
    quote! {
916
        #(#methods)*
917
    }
918
1103
}
919

            
920
/// Generate a setter method for a required field
921
1625
fn generate_field_setter_method(
922
1625
    _command_name: &syn::Ident,
923
1625
    required_args: &[(syn::Ident, syn::Type)],
924
1625
    optional_args: &[(syn::Ident, syn::Type)],
925
1625
    field_name: &syn::Ident,
926
1625
    field_type: &syn::Type,
927
1625
    current_combination: usize,
928
1625
    _new_combination: usize,
929
1625
    target_runner: &syn::Ident,
930
1625
    _num_required: usize,
931
1625
) -> proc_macro2::TokenStream {
932
    // Generate field assignments for the new state
933
1625
    let mut field_assignments = Vec::new();
934

            
935
    // Handle required fields
936
5917
    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
937
5917
        if req_field_name == field_name {
938
1625
            // This is the field being set
939
1625
            field_assignments.push(quote! {
940
1625
                #req_field_name: value
941
1625
            });
942
4292
        } else if (current_combination >> i) & 1 == 1 {
943
2146
            // This field was already set, move it from self
944
2146
            field_assignments.push(quote! {
945
2146
                #req_field_name: self.#req_field_name
946
2146
            });
947
2146
        }
948
        // Fields not set in either state are omitted
949
    }
950

            
951
    // Handle optional fields (always present, move from self)
952
2413
    for (opt_field_name, _) in optional_args {
953
2413
        field_assignments.push(quote! {
954
2413
            #opt_field_name: self.#opt_field_name
955
2413
        });
956
2413
    }
957

            
958
    // Generate the constructor call
959
1625
    let constructor = if field_assignments.is_empty() {
960
        quote! { #target_runner }
961
    } else {
962
1625
        quote! {
963
            #target_runner {
964
                #(#field_assignments),*
965
            }
966
        }
967
    };
968

            
969
1625
    quote! {
970
        pub fn #field_name(self, value: #field_type) -> #target_runner {
971
            #constructor
972
        }
973
    }
974
1625
}
975

            
976
/// Generate a setter method for an optional field
977
1458
fn generate_optional_field_setter(
978
1458
    field_name: &syn::Ident,
979
1458
    field_type: &syn::Type,
980
1458
    current_runner: &syn::Ident,
981
1458
    required_args: &[(syn::Ident, syn::Type)],
982
1458
    optional_args: &[(syn::Ident, syn::Type)],
983
1458
    current_combination: usize,
984
1458
    _num_required: usize,
985
1458
) -> proc_macro2::TokenStream {
986
    // Generate field assignments (same state, but update the optional field)
987
1458
    let mut field_assignments = Vec::new();
988

            
989
    // Handle required fields (move from self if set)
990
4826
    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
991
4826
        if (current_combination >> i) & 1 == 1 {
992
2413
            field_assignments.push(quote! {
993
2413
                #req_field_name: self.#req_field_name
994
2413
            });
995
2413
        }
996
    }
997

            
998
    // Handle optional fields
999
3966
    for (opt_field_name, _) in optional_args {
3966
        if opt_field_name == field_name {
1458
            // This is the field being set
1458
            field_assignments.push(quote! {
1458
                #opt_field_name: Some(value)
1458
            });
2508
        } else {
2508
            // Move other optional fields from self
2508
            field_assignments.push(quote! {
2508
                #opt_field_name: self.#opt_field_name
2508
            });
2508
        }
    }
1458
    let constructor = if field_assignments.is_empty() {
        quote! { #current_runner }
    } else {
1458
        quote! {
            #current_runner {
                #(#field_assignments),*
            }
        }
    };
1458
    quote! {
        pub fn #field_name(self, value: #field_type) -> #current_runner {
            #constructor
        }
    }
1458
}
/// Generate the run method for the complete runner state
222
fn generate_run_method(
222
    _command_name: &syn::Ident,
222
    required_args: &[(syn::Ident, syn::Type)],
222
    optional_args: &[(syn::Ident, syn::Type)],
222
    body: &syn::Block,
222
) -> proc_macro2::TokenStream {
    // Extract field values directly (no unwrap needed!)
222
    let mut variable_assignments = Vec::new();
    // Required fields - direct field access
394
    for (field_name, _) in required_args {
394
        variable_assignments.push(quote! {
394
            let #field_name = self.#field_name;
394
        });
394
    }
    // Optional fields - direct field access
240
    for (field_name, _) in optional_args {
240
        variable_assignments.push(quote! {
240
            let #field_name = self.#field_name;
240
        });
240
    }
222
    quote! {
        pub async fn run(self) -> Result<Option<CmdResult>, CmdError> {
            // Zero runtime checks - direct field access!
            #(#variable_assignments)*
            // Original command body
            #body
        }
    }
222
}
/// Generate the `new()` method for the command
222
fn generate_new_method(
222
    command_name: &syn::Ident,
222
    num_required: usize,
222
    optional_args: &[(syn::Ident, syn::Type)],
222
) -> proc_macro2::TokenStream {
222
    let initial_runner = syn::Ident::new(
222
        &format!(
222
            "{}Runner{:0width$b}",
222
            command_name,
222
            0,
222
            width = num_required.max(1)
222
        ),
222
        command_name.span(),
    );
    // Initial state has no required fields set, but has optional fields as None
222
    let constructor = if optional_args.is_empty() && num_required > 0 {
        // Unit struct (no fields at all in initial state)
102
        quote! { #initial_runner }
    } else {
        // Struct with optional fields initialized to None
240
        let optional_field_inits = optional_args.iter().map(|(field_name, _)| {
240
            quote! { #field_name: None }
240
        });
120
        if optional_field_inits.len() > 0 {
97
            quote! {
                #initial_runner {
                    #(#optional_field_inits),*
                }
            }
        } else {
23
            quote! { #initial_runner }
        }
    };
222
    quote! {
        impl #command_name {
            pub fn new() -> #initial_runner {
                #constructor
            }
        }
    }
222
}
struct CommandInput {
    name: syn::Ident,
    required_args: Vec<(syn::Ident, syn::Type)>,
    optional_args: Vec<(syn::Ident, syn::Type)>,
    body: syn::Block,
}
impl syn::parse::Parse for CommandInput {
222
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
222
        let name: syn::Ident = input.parse()?;
        let content;
222
        syn::braced!(content in input);
222
        let mut required_args = Vec::new();
222
        let mut optional_args = Vec::new();
856
        while !content.is_empty() {
            // Parse attributes
634
            let mut is_optional = false;
634
            let mut is_required = false;
1268
            while content.peek(syn::Token![#]) {
634
                content.parse::<syn::Token![#]>()?;
                let attr_content;
634
                syn::bracketed!(attr_content in content);
634
                let attr_name: syn::Ident = attr_content.parse()?;
634
                if attr_name == "optional" {
240
                    is_optional = true;
394
                } else if attr_name == "required" {
394
                    is_required = true;
394
                } else {
                    return Err(syn::Error::new(
                        attr_name.span(),
                        "Unknown attribute. Use #[required] or #[optional]",
                    ));
                }
            }
            // Parse the field
634
            let arg_name: syn::Ident = content.parse()?;
634
            content.parse::<syn::Token![:]>()?;
634
            let arg_type: syn::Type = content.parse()?;
634
            if content.peek(syn::Token![,]) {
634
                content.parse::<syn::Token![,]>()?;
            }
            // Determine if optional (default to required if no attribute specified)
634
            let is_optional_field = if is_required && is_optional {
                return Err(syn::Error::new(
                    arg_name.span(),
                    "Field cannot be both #[required] and #[optional]",
                ));
634
            } else if is_optional {
240
                true
            } else {
394
                false // Default to required
            };
634
            if is_optional_field {
240
                optional_args.push((arg_name, arg_type));
394
            } else {
394
                required_args.push((arg_name, arg_type));
394
            }
        }
        // The '=>' is outside the braces
222
        input.parse::<syn::Token![=>]>()?;
222
        let body: syn::Block = input.parse()?;
222
        Ok(CommandInput {
222
            name,
222
            required_args,
222
            optional_args,
222
            body,
222
        })
222
    }
}