1
use proc_macro::TokenStream;
2
use quote::quote;
3
use syn::{Data, DeriveInput, Fields, ItemFn, Lit, Path, parse_macro_input};
4

            
5
#[proc_macro_attribute]
6
120
pub fn local_db_sqlx_test(_attr: TokenStream, item: TokenStream) -> TokenStream {
7
120
    let input = parse_macro_input!(item as ItemFn);
8
120
    let fn_name = &input.sig.ident;
9
120
    let block = &input.block;
10

            
11
120
    let expanded = quote! {
12
        #[sqlx::test(migrations = "../migrations")]
13
        async fn #fn_name(pool: PgPool) -> Result<(), anyhow::Error> {
14
            setup().await;
15
            DB_POOL.set(&pool);
16
            #block
17
        Ok(())
18
        }
19
    };
20

            
21
120
    TokenStream::from(expanded)
22
120
}
23

            
24
#[proc_macro_derive(Builder, attributes(builder))]
25
54
pub fn builder_macro(input: TokenStream) -> TokenStream {
26
    // Parse the input tokens into a syntax tree
27
54
    let input = parse_macro_input!(input as DeriveInput);
28
54
    let name = &input.ident;
29
54
    let generics = &input.generics; // Capture generics (including lifetimes)
30
54
    let builder_name = syn::Ident::new(&format!("{name}Builder"), name.span());
31

            
32
    // Check for custom error_kind attribute
33
54
    let mut error_kind = None;
34

            
35
    // Parse attributes
36
216
    for attr in &input.attrs {
37
162
        if attr.path().is_ident("builder") {
38
78
            attr.parse_nested_meta(|meta| {
39
54
                if meta.path.is_ident("error_kind")
40
54
                    && let Ok(Lit::Str(lit_str)) = meta.value()?.parse()
41
54
                {
42
54
                    error_kind = Some(lit_str.parse::<Path>().unwrap());
43
54
                }
44
54
                Ok(())
45
54
            })
46
54
            .unwrap();
47
108
        }
48
    }
49

            
50
    // Set a default error kind if none is provided
51
54
    let error_kind = error_kind.expect(
52
54
        "Error kind (e.g., FinanceError) must be specified with #[builder(error_kind = \"...\")]",
53
    );
54

            
55
    // Define a custom error type based on the struct name, e.g., CommodityError for Commodity
56
54
    let custom_error_name = syn::Ident::new(&format!("{name}Error"), name.span());
57

            
58
54
    let fields = if let Data::Struct(data) = &input.data {
59
54
        if let Fields::Named(fields) = &data.fields {
60
54
            fields.named.iter().collect::<Vec<_>>()
61
        } else {
62
            panic!("Builder macro only supports structs with named fields");
63
        }
64
    } else {
65
        panic!("Builder macro only supports structs");
66
    };
67

            
68
    // Generate builder struct fields with the same generics (including lifetimes)
69
276
    let builder_fields = fields.iter().map(|field| {
70
252
        let field_name = &field.ident;
71
252
        let field_ty = &field.ty;
72
252
        let builder_field_type = quote! { Option<#field_ty> };
73
252
        quote! {
74
            #field_name: #builder_field_type
75
        }
76
252
    });
77

            
78
    // Generate initialization in new()
79
276
    let builder_fields_init = fields.iter().map(|field| {
80
252
        let field_name = &field.ident;
81
252
        quote! {
82
            #field_name: None
83
        }
84
252
    });
85

            
86
    // Generate setter methods
87
276
    let setters = fields.iter().map(|field| {
88
252
        let field_name = &field.ident;
89
252
        let field_type = &field.ty;
90

            
91
252
        if is_option_type(field_type) {
92
63
            let inner_type = get_inner_type(field_type);
93
63
            if is_string_type(&inner_type) {
94
                // For Option<String>, accept &str
95
9
                quote! {
96
                    pub fn #field_name(&mut self, value: &str) -> &mut Self {
97
                        self.#field_name = Some(Some(value.to_string()));
98
                        self
99
                    }
100
                }
101
            } else {
102
                // For Option<T>, accept T directly
103
54
                quote! {
104
                    pub fn #field_name(&mut self, value: #inner_type) -> &mut Self {
105
                        self.#field_name = Some(Some(value));
106
                        self
107
                    }
108
                }
109
            }
110
189
        } else if is_string_type(field_type) {
111
            // For String, accept &str
112
18
            quote! {
113
                pub fn #field_name(&mut self, value: &str) -> &mut Self {
114
                    self.#field_name = Some(value.to_string());
115
                    self
116
                }
117
            }
118
        } else {
119
            // For non-Option<T> and non-String fields, accept T directly
120
171
            quote! {
121
                pub fn #field_name(&mut self, value: #field_type) -> &mut Self {
122
                    self.#field_name = Some(value);
123
                    self
124
                }
125
            }
126
        }
127
252
    });
128

            
129
    // Generate code to check for missing required fields
130
54
    let check_required_fields = fields
131
54
        .iter()
132
276
        .filter(|field| !is_option_type(&field.ty))
133
213
        .map(|field| {
134
189
            let field_name = &field.ident;
135
189
            let field_name_str = field_name.as_ref().unwrap().to_string();
136
189
            quote! {
137
                if self.#field_name.is_none() {
138
                    missing_fields.push(#field_name_str);
139
                }
140
            }
141
189
        });
142

            
143
    // Generate build_fields
144
276
    let build_fields = fields.iter().map(|field| {
145
252
        let field_name = &field.ident;
146
252
        if is_option_type(&field.ty) {
147
63
            quote! {
148
                #field_name: self.#field_name.clone().unwrap_or(None)
149
            }
150
        } else {
151
189
            quote! {
152
                #field_name: self.#field_name.clone().unwrap()
153
            }
154
        }
155
252
    });
156

            
157
    // Extract the lifetime parameters from generics for use in the builder struct
158
54
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
159

            
160
54
    let expanded = quote! {
161
        pub struct #builder_name #impl_generics #where_clause {
162
            #(#builder_fields),*
163
        }
164

            
165
        impl #impl_generics #builder_name #ty_generics #where_clause {
166
            pub fn new() -> Self {
167
                Self {
168
                    #(#builder_fields_init),*
169
                }
170
            }
171

            
172
            #(#setters)*
173

            
174
            pub fn build(&self) -> Result<#name #ty_generics, #error_kind> {
175
                let mut missing_fields = Vec::new();
176
                #(#check_required_fields)*
177

            
178
                if !missing_fields.is_empty() {
179
                    return Err(#error_kind::from(#custom_error_name::Build(format!(
180
                        "{} fields are missing: {}",
181
                        stringify!(#name),
182
                        missing_fields.join(", ")
183
                    ))));
184
                }
185

            
186
                Ok(#name {
187
                    #(#build_fields),*
188
                })
189
            }
190
        }
191

            
192
        impl #impl_generics #name #ty_generics #where_clause {
193
            pub fn builder() -> #builder_name #ty_generics {
194
                #builder_name::new()
195
            }
196
        }
197
    };
198

            
199
54
    TokenStream::from(expanded)
200
54
}
201

            
202
/// Helper function to determine if a type is an `Option<T>`
203
756
fn is_option_type(ty: &syn::Type) -> bool {
204
1092
    matches!(ty, syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.iter().any(|segment| segment.ident == "Option"))
205
756
}
206

            
207
/// Helper function to get the inner type of an `Option<T>`
208
63
fn get_inner_type(ty: &syn::Type) -> syn::Type {
209
63
    if let syn::Type::Path(type_path) = ty
210
63
        && let Some(segment) = type_path.path.segments.first()
211
63
        && segment.ident == "Option"
212
63
        && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
213
63
        && let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
214
    {
215
63
        return inner_type.clone();
216
    }
217
    ty.clone()
218
63
}
219

            
220
/// Helper function to check if the type is String
221
252
fn is_string_type(ty: &syn::Type) -> bool {
222
252
    if let syn::Type::Path(type_path) = ty
223
252
        && let Some(segment) = type_path.path.segments.last()
224
    {
225
252
        return segment.ident == "String";
226
    }
227
    false
228
252
}
229

            
230
/// A procedural macro for generating typed Command implementations with compile-time validation.
231
///
232
/// This macro provides pure value-based argument passing with compile-time type safety by generating:
/// - A typestate builder whose setters take each argument by value with its declared type
/// - Typed struct fields for the arguments instead of `HashMap`-based argument lookup
/// - Individual typed variables available directly in command scope
/// - Compile-time validation of argument types and required/optional fields
/// - Zero runtime argument parsing or validation overhead
238
///
239
/// # Syntax
240
///
241
/// ```ignore
242
/// command! {
243
///     CommandName {
244
///         #[required]
245
///         arg_name: Type,
246
///         #[optional]
247
///         opt_name: Type,
248
///     } => {
249
///         // Command implementation body
250
///         // Individual typed variables are available in scope
251
///     }
252
/// }
253
/// ```
254
///
255
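/// # Generated Code
///
/// The macro generates:
/// - A unit struct `CommandName`; `CommandName::new()` starts a typed builder chain
/// - One runner struct per subset of the required arguments, named `CommandNameRunner` followed
///   by a binary suffix recording which required arguments are already set. Required arguments
///   are stored by value; optional arguments are stored as `Option<Type>` on every runner
/// - Setters for each required argument that is not yet set, each moving to the runner with that
///   argument set, plus setters for optional arguments on every runner
/// - A `run()` method only on the runner where every required argument is set; inside the command
///   body each argument is available as an individual typed variable
/// - Pure compile-time type validation with no runtime overhead
///
/// Because `N` required arguments produce `2^N` runner structs, keep the number of required
/// arguments small.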
262
///
263
/// # Examples
264
///
265
/// ## Simple command with no arguments
266
///
267
/// ```rust
268
/// # use supp_macro::command;
269
/// # use async_trait::async_trait;
270
/// #
271
/// # #[derive(Debug)]
272
/// # pub enum CmdError {
273
/// #     Args(String),
274
/// # }
275
/// #
276
/// # impl std::fmt::Display for CmdError {
277
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278
/// #         match self {
279
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
280
/// #         }
281
/// #     }
282
/// # }
283
/// #
284
/// # impl std::error::Error for CmdError {}
285
/// #
286
/// # #[derive(Debug)]
287
/// # pub enum CmdResult {
288
/// #     String(String),
289
/// # }
290
/// #
291
/// # #[derive(Debug, Default)]
292
/// # pub struct CommandArgs {}
293
/// # impl CommandArgs { pub fn new() -> Self { Self::default() } }
294
/// #
295
/// # #[async_trait]
296
/// # pub trait Command: std::fmt::Debug {
297
/// #     type Args;
298
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
299
/// # }
300
///
301
/// command! {
302
///     GetVersion {
303
///     } => {
304
///         Ok(Some(CmdResult::String("1.0.0".to_string())))
305
///     }
306
/// }
307
///
308
/// # #[tokio::main]
309
/// # async fn main() {
310
/// let result = GetVersion::new().run().await.unwrap();
311
/// # }
312
/// ```
313
///
314
/// ## Command with required arguments (server-compatible types)
315
///
316
/// ```rust
317
/// # use supp_macro::command;
318
/// # use async_trait::async_trait;
319
/// # use uuid::Uuid;
320
/// # use num_rational::Rational64;
321
/// #
322
/// # #[derive(Debug, Clone)]
323
/// # pub enum Argument {
324
/// #     String(String),
325
/// #     Uuid(Uuid),
326
/// #     Rational(Rational64),
327
/// # }
328
/// #
329
/// # #[derive(Debug)]
330
/// # pub enum CmdError {
331
/// #     Args(String),
332
/// # }
333
/// #
334
/// # impl std::fmt::Display for CmdError {
335
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336
/// #         match self {
337
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
338
/// #         }
339
/// #     }
340
/// # }
341
/// #
342
/// # impl std::error::Error for CmdError {}
343
/// #
344
/// # #[derive(Debug)]
345
/// # pub enum CmdResult {
346
/// #     String(String),
347
/// # }
348
/// #
349
/// # #[derive(Debug, Default)]
350
/// # pub struct CommandArgs {
351
/// #     pub fraction: Option<num_rational::Rational64>,
352
/// #     pub symbol: Option<String>,
353
/// #     pub name: Option<String>,
354
/// #     pub user_id: Option<uuid::Uuid>,
355
/// # }
356
/// # impl CommandArgs {
357
/// #     pub fn new() -> Self { Self::default() }
358
/// #     pub fn fraction(mut self, v: num_rational::Rational64) -> Self { self.fraction = Some(v); self }
359
/// #     pub fn symbol(mut self, v: String) -> Self { self.symbol = Some(v); self }
360
/// #     pub fn name(mut self, v: String) -> Self { self.name = Some(v); self }
361
/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
362
/// # }
363
/// #
364
/// # #[async_trait]
365
/// # pub trait Command: std::fmt::Debug {
366
/// #     type Args;
367
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
368
/// # }
369
///
370
/// // This creates a commodity in the financial system
371
/// command! {
372
///     CreateCommodity {
373
///         #[required]
374
///         fraction: Rational64,
375
///         #[required]
376
///         symbol: String,
377
///         #[required]
378
///         name: String,
379
///         #[required]
380
///         user_id: Uuid,
381
///     } => {
382
///         // Individual typed variables are automatically available
383
///         Ok(Some(CmdResult::String(format!(
384
///             "Created commodity {} ({}) with fraction {} for user {}",
385
///             name, symbol, fraction, user_id
386
///         ))))
387
///     }
388
/// }
389
///
390
/// # #[tokio::main]
391
/// # async fn main() {
392
/// let result = CreateCommodity::new()
393
///     .fraction(num_rational::Rational64::new(1, 100))
394
///     .symbol("USD".to_string())
395
///     .name("US Dollar".to_string())
396
///     .user_id(uuid::Uuid::new_v4())
397
///     .run()
398
///     .await
399
///     .unwrap();
400
/// # }
401
/// ```
402
///
403
/// ## Command with optional arguments
404
///
405
/// ```rust
406
/// # use supp_macro::command;
407
/// # use async_trait::async_trait;
408
/// # use uuid::Uuid;
409
/// #
410
/// # #[derive(Debug, Clone)]
411
/// # pub enum Argument {
412
/// #     String(String),
413
/// #     Uuid(Uuid),
414
/// # }
415
/// #
416
/// # #[derive(Debug)]
417
/// # pub enum CmdError {
418
/// #     Args(String),
419
/// # }
420
/// #
421
/// # impl std::fmt::Display for CmdError {
422
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423
/// #         match self {
424
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
425
/// #         }
426
/// #     }
427
/// # }
428
/// #
429
/// # impl std::error::Error for CmdError {}
430
/// #
431
/// # #[derive(Debug)]
432
/// # pub enum CmdResult {
433
/// #     String(String),
434
/// # }
435
/// #
436
/// # #[derive(Debug, Default)]
437
/// # pub struct CommandArgs {
438
/// #     pub user_id: Option<uuid::Uuid>,
439
/// #     pub account: Option<String>,
440
/// # }
441
/// # impl CommandArgs {
442
/// #     pub fn new() -> Self { Self::default() }
443
/// #     pub fn user_id(mut self, v: uuid::Uuid) -> Self { self.user_id = Some(v); self }
444
/// #     pub fn account(mut self, v: String) -> Self { self.account = Some(v); self }
445
/// # }
446
/// #
447
/// # #[async_trait]
448
/// # pub trait Command: std::fmt::Debug {
449
/// #     type Args;
450
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
451
/// # }
452
///
453
/// command! {
454
///     ListTransactions {
455
///         #[required]
456
///         user_id: Uuid,
457
///         #[optional]
458
///         account: String,
459
///     } => {
460
///         let filter = if let Some(account) = account {
461
///             format!(" for account {}", account)
462
///         } else {
463
///             String::new()
464
///         };
465
///         Ok(Some(CmdResult::String(format!("Listing transactions for user {}{}", user_id, filter))))
466
///     }
467
/// }
468
/// ```
469
///
470
/// ## Command with mixed required and optional arguments
471
///
472
/// ```rust
473
/// # use supp_macro::command;
474
/// # use async_trait::async_trait;
475
/// #
476
/// # #[derive(Debug, Clone)]
477
/// # pub enum Argument {
478
/// #     String(String),
479
/// #     Integer(i64),
480
/// #     Boolean(bool),
481
/// # }
482
/// #
483
/// # #[derive(Debug)]
484
/// # pub enum CmdError {
485
/// #     Args(String),
486
/// # }
487
/// #
488
/// # impl std::fmt::Display for CmdError {
489
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490
/// #         match self {
491
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
492
/// #         }
493
/// #     }
494
/// # }
495
/// #
496
/// # impl std::error::Error for CmdError {}
497
/// #
498
/// # #[derive(Debug)]
499
/// # pub enum CmdResult {
500
/// #     Success(String),
501
/// # }
502
/// #
503
/// # impl TryFrom<Argument> for String {
504
/// #     type Error = CmdError;
505
/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
506
/// #         match arg {
507
/// #             Argument::String(s) => Ok(s),
508
/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to String", arg))),
509
/// #         }
510
/// #     }
511
/// # }
512
/// #
513
/// # impl TryFrom<Argument> for i64 {
514
/// #     type Error = CmdError;
515
/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
516
/// #         match arg {
517
/// #             Argument::Integer(i) => Ok(i),
518
/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to i64", arg))),
519
/// #         }
520
/// #     }
521
/// # }
522
/// #
523
/// # impl TryFrom<Argument> for bool {
524
/// #     type Error = CmdError;
525
/// #     fn try_from(arg: Argument) -> Result<Self, Self::Error> {
526
/// #         match arg {
527
/// #             Argument::Boolean(b) => Ok(b),
528
/// #             _ => Err(CmdError::Args(format!("Cannot convert {:?} to bool", arg))),
529
/// #         }
530
/// #     }
531
/// # }
532
/// #
533
/// # #[async_trait]
534
/// # pub trait TypedCommand {
535
/// #     type Args;
536
/// #     async fn run_typed(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
537
/// # }
538
/// #
539
/// # #[derive(Debug, Default)]
540
/// # pub struct CommandArgs {
541
/// #     pub user_id: Option<i64>,
542
/// #     pub username: Option<String>,
543
/// #     pub email: Option<String>,
544
/// #     pub is_admin: Option<bool>,
545
/// # }
546
/// # impl CommandArgs {
547
/// #     pub fn new() -> Self { Self::default() }
548
/// #     pub fn user_id(mut self, v: i64) -> Self { self.user_id = Some(v); self }
549
/// #     pub fn username(mut self, v: String) -> Self { self.username = Some(v); self }
550
/// #     pub fn email(mut self, v: String) -> Self { self.email = Some(v); self }
551
/// #     pub fn is_admin(mut self, v: bool) -> Self { self.is_admin = Some(v); self }
552
/// # }
553
/// #
554
/// # #[async_trait]
555
/// # pub trait Command {
556
/// #     type Args;
557
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
558
/// # }
559
///
560
/// command! {
561
///     CreateUserCommand {
562
///         #[required]
563
///         user_id: i64,
564
///         #[required]
565
///         username: String,
566
///         #[optional]
567
///         email: String,
568
///         #[optional]
569
///         is_admin: bool,
570
///     } => {
571
///         let email_str = email.map_or_else(|| format!("{}@example.com", username), |s| s.to_string());
572
///         let admin_status = is_admin.unwrap_or(false);
573
///
574
///         let message = format!(
575
///             "Created user {} (ID: {}, Email: {}, Admin: {})",
576
///             username, user_id, email_str, admin_status
577
///         );
578
///         Ok(Some(CmdResult::Success(message)))
579
///     }
580
/// }
581
///
582
/// # #[tokio::main]
583
/// # async fn main() {
584
/// let result = CreateUserCommand::new()
585
///     .user_id(123)
586
///     .username("alice".to_string())
587
///     .is_admin(true)
588
///     .run()
589
///     .await
590
///     .unwrap();
591
/// # }
592
/// ```
593
///
594
/// ## Server-compatible Command implementation
595
///
596
/// ```rust
597
/// # use supp_macro::command;
598
/// # use async_trait::async_trait;
599
/// #
600
/// # #[derive(Debug, Clone)]
601
/// # pub enum Argument {
602
/// #     String(String),
603
/// #     Integer(i64),
604
/// #     Boolean(bool),
605
/// # }
606
/// #
607
/// # #[derive(Debug)]
608
/// # pub enum CmdError {
609
/// #     Args(String),
610
/// # }
611
/// #
612
/// # impl std::fmt::Display for CmdError {
613
/// #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614
/// #         match self {
615
/// #             CmdError::Args(msg) => write!(f, "Argument error: {}", msg),
616
/// #         }
617
/// #     }
618
/// # }
619
/// #
620
/// # impl std::error::Error for CmdError {}
621
/// #
622
/// # #[derive(Debug)]
623
/// # pub enum CmdResult {
624
/// #     Success(String),
625
/// # }
626
/// #
627
/// # #[derive(Debug, Default)]
628
/// # pub struct CommandArgs {
629
/// #     pub a: Option<i64>,
630
/// #     pub b: Option<i64>,
631
/// # }
632
/// # impl CommandArgs {
633
/// #     pub fn new() -> Self { Self::default() }
634
/// #     pub fn a(mut self, v: i64) -> Self { self.a = Some(v); self }
635
/// #     pub fn b(mut self, v: i64) -> Self { self.b = Some(v); self }
636
/// # }
637
/// #
638
/// # #[async_trait]
639
/// # pub trait Command {
640
/// #     type Args;
641
/// #     async fn run(&self, args: Self::Args) -> Result<Option<CmdResult>, CmdError>;
642
/// # }
643
///
644
/// command! {
645
///     CalculateCommand {
646
///         #[required]
647
///         a: i64,
648
///         #[required]
649
///         b: i64,
650
///     } => {
651
///         let result = a + b;
652
///         Ok(Some(CmdResult::Success(format!("{} + {} = {}", a, b, result))))
653
///     }
654
/// }
655
///
656
/// # #[tokio::main]
657
/// # async fn main() {
658
/// let result = CalculateCommand::new()
659
///     .a(10)
660
///     .b(20)
661
///     .run()
662
///     .await
663
///     .unwrap();
664
/// # }
665
/// ```
666
///
667
/// ## Migration from Manual Commands
668
///
669
/// The macro makes it easy to migrate from manual Command implementations:
670
///
671
/// ```rust,ignore
672
/// // BEFORE: Manual implementation
673
/// #[derive(Debug)]
674
/// pub struct GetConfig;
675
///
676
/// #[async_trait]
677
/// impl Command for GetConfig {
678
///     async fn run<'a>(&self, args: &'a HashMap<&'a str, &'a Argument>) -> Result<Option<CmdResult>, CmdError> {
679
///         if let Some(Argument::String(name)) = args.get("name") {
680
///             Ok(config(name).await?.map(|v| CmdResult::String(v)))
681
///         } else {
682
///             Err(CmdError::Args("No field name provided".to_string()))
683
///         }
684
///     }
685
/// }
686
///
687
/// // AFTER: Using the macro
688
/// command! {
689
///     GetConfig {
690
///         #[required]
691
///         name: String,
692
///     } => {
693
///         Ok(config(name).await?.map(|v| CmdResult::String(v)))
694
///     }
695
/// }
696
/// ```
697
///
698
/// # Error Handling
699
///
700
/// The typed builder provides compile-time error prevention:
///
/// - Missing required arguments are compile-time errors (`run()` only exists once every required argument is set)
/// - Invalid argument types are compile-time errors (type checking at build time)
/// - Runtime errors only occur in the command body logic itself
/// - No argument validation overhead at runtime
706
///
707
/// # Supported Argument Types
708
///
709
/// The macro supports any Rust type for arguments:
710
/// - `String` - Text arguments
711
/// - `i64`, `u64`, etc. - Integer arguments
712
/// - `bool` - Boolean arguments
713
/// - `Rational64` - Rational number arguments (for financial precision)
714
/// - `Uuid` - UUID arguments
715
/// - `Vec<u8>` - Binary data arguments
716
/// - `DateTime<Utc>` - DateTime arguments
717
/// - Custom types - Any type can be used as an argument
718
/// - `Option<T>` - Automatically applied for optional arguments
719
///
720

            
721
#[proc_macro]
722
215
pub fn command(input: TokenStream) -> TokenStream {
723
215
    let input = parse_macro_input!(input as CommandInput);
724

            
725
215
    let name = &input.name;
726
215
    let required_args = &input.required_args;
727
215
    let optional_args = &input.optional_args;
728
215
    let body = &input.body;
729

            
730
    // Generate progressive runner types for all combinations of required fields
731
215
    let runner_types = generate_progressive_runner_types(name, required_args, optional_args, body);
732

            
733
    // Generate the main command struct
734
215
    let command_struct = quote! {
735
        #[derive(Debug)]
736
        pub struct #name;
737
    };
738

            
739
    // Generate the new() method that starts the builder chain
740
215
    let new_method = generate_new_method(name, required_args.len(), optional_args);
741

            
742
215
    let expanded = quote! {
743
        #command_struct
744

            
745
        #runner_types
746

            
747
        #new_method
748
    };
749

            
750
215
    TokenStream::from(expanded)
751
215
}
752

            
753
/// Generate all possible runner type combinations for required fields
754
215
fn generate_progressive_runner_types(
755
215
    command_name: &syn::Ident,
756
215
    required_args: &[(syn::Ident, syn::Type)],
757
215
    optional_args: &[(syn::Ident, syn::Type)],
758
215
    body: &syn::Block,
759
215
) -> proc_macro2::TokenStream {
760
215
    let num_required = required_args.len();
761
215
    let total_combinations = 1 << num_required; // 2^num_required
762

            
763
215
    let mut runner_types = Vec::new();
764

            
765
    // Generate a runner type for each possible combination of set required fields
766
1024
    for combination in 0..total_combinations {
767
1024
        let runner_type = generate_single_runner_type(
768
1024
            command_name,
769
1024
            required_args,
770
1024
            optional_args,
771
1024
            combination,
772
1024
            num_required,
773
1024
            body,
774
1024
        );
775
1024
        runner_types.push(runner_type);
776
1024
    }
777

            
778
215
    quote! {
779
        #(#runner_types)*
780
    }
781
215
}
782

            
783
/// Generate a single runner type for a specific combination of set fields
784
1024
fn generate_single_runner_type(
785
1024
    command_name: &syn::Ident,
786
1024
    required_args: &[(syn::Ident, syn::Type)],
787
1024
    optional_args: &[(syn::Ident, syn::Type)],
788
1024
    combination: usize,
789
1024
    num_required: usize,
790
1024
    body: &syn::Block,
791
1024
) -> proc_macro2::TokenStream {
792
    // Create binary representation for the runner type name
793
1024
    let binary_suffix = format!("{:0width$b}", combination, width = num_required.max(1));
794
1024
    let runner_name = syn::Ident::new(
795
1024
        &format!("{}Runner{}", command_name, binary_suffix),
796
1024
        command_name.span(),
797
    );
798

            
799
    // Determine which required fields are set in this combination
800
1024
    let mut struct_fields = Vec::new();
801
3080
    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
802
3080
        if (combination >> i) & 1 == 1 {
803
1540
            // This required field is set in this combination
804
1540
            struct_fields.push(quote! {
805
1540
                pub #field_name: #field_type
806
1540
            });
807
1540
        }
808
    }
809

            
810
    // Always include optional fields in all runner types
811
2078
    for (field_name, field_type) in optional_args {
812
1054
        struct_fields.push(quote! {
813
1054
            pub #field_name: Option<#field_type>
814
1054
        });
815
1054
    }
816

            
817
    // Generate the struct definition
818
1024
    let struct_def = if struct_fields.is_empty() {
819
134
        quote! {
820
            #[derive(Debug)]
821
            pub struct #runner_name;
822
        }
823
    } else {
824
890
        quote! {
825
            #[derive(Debug)]
826
            pub struct #runner_name {
827
                #(#struct_fields),*
828
            }
829
        }
830
    };
831

            
832
    // Generate transition methods for this runner type
833
1024
    let transition_methods = generate_transition_methods(
834
1024
        command_name,
835
1024
        required_args,
836
1024
        optional_args,
837
1024
        combination,
838
1024
        num_required,
839
    );
840

            
841
    // Generate run method if this is the complete state (all required fields set)
842
1024
    let complete_mask = (1 << num_required) - 1;
843
1024
    let run_method = if combination == complete_mask {
844
215
        generate_run_method(command_name, required_args, optional_args, body)
845
    } else {
846
809
        quote! {}
847
    };
848

            
849
1024
    quote! {
850
        #struct_def
851

            
852
        impl #runner_name {
853
            #transition_methods
854
            #run_method
855
        }
856
    }
857
1024
}
858

            
859
/// Generate transition methods for a runner type (field setters)
860
1024
fn generate_transition_methods(
861
1024
    command_name: &syn::Ident,
862
1024
    required_args: &[(syn::Ident, syn::Type)],
863
1024
    optional_args: &[(syn::Ident, syn::Type)],
864
1024
    current_combination: usize,
865
1024
    num_required: usize,
866
1024
) -> proc_macro2::TokenStream {
867
1024
    let mut methods = Vec::new();
868

            
869
    // Generate setter methods for required fields not yet set
870
3080
    for (i, (field_name, field_type)) in required_args.iter().enumerate() {
871
3080
        if (current_combination >> i) & 1 == 0 {
872
1540
            // This required field is not set yet, generate a setter
873
1540
            let new_combination = current_combination | (1 << i);
874
1540
            let binary_suffix =
875
1540
                format!("{:0width$b}", new_combination, width = num_required.max(1));
876
1540
            let target_runner = syn::Ident::new(
877
1540
                &format!("{}Runner{}", command_name, binary_suffix),
878
1540
                command_name.span(),
879
1540
            );
880
1540

            
881
1540
            let method = generate_field_setter_method(
882
1540
                command_name,
883
1540
                required_args,
884
1540
                optional_args,
885
1540
                field_name,
886
1540
                field_type,
887
1540
                current_combination,
888
1540
                new_combination,
889
1540
                &target_runner,
890
1540
                num_required,
891
1540
            );
892
1540
            methods.push(method);
893
1540
        }
894
    }
895

            
896
    // Generate setter methods for optional fields (available on all runner types)
897
2078
    for (field_name, field_type) in optional_args {
898
1054
        let current_runner = syn::Ident::new(
899
1054
            &format!(
900
1054
                "{}Runner{:0width$b}",
901
1054
                command_name,
902
1054
                current_combination,
903
1054
                width = num_required.max(1)
904
1054
            ),
905
1054
            command_name.span(),
906
1054
        );
907
1054

            
908
1054
        let method = generate_optional_field_setter(
909
1054
            field_name,
910
1054
            field_type,
911
1054
            &current_runner,
912
1054
            required_args,
913
1054
            optional_args,
914
1054
            current_combination,
915
1054
            num_required,
916
1054
        );
917
1054
        methods.push(method);
918
1054
    }
919

            
920
1024
    quote! {
921
        #(#methods)*
922
    }
923
1024
}
924

            
925
/// Generate a setter method for a required field
926
1540
fn generate_field_setter_method(
927
1540
    command_name: &syn::Ident,
928
1540
    required_args: &[(syn::Ident, syn::Type)],
929
1540
    optional_args: &[(syn::Ident, syn::Type)],
930
1540
    field_name: &syn::Ident,
931
1540
    field_type: &syn::Type,
932
1540
    current_combination: usize,
933
1540
    new_combination: usize,
934
1540
    target_runner: &syn::Ident,
935
1540
    num_required: usize,
936
1540
) -> proc_macro2::TokenStream {
937
    // Generate field assignments for the new state
938
1540
    let mut field_assignments = Vec::new();
939

            
940
    // Handle required fields
941
5888
    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
942
5888
        if req_field_name == field_name {
943
1540
            // This is the field being set
944
1540
            field_assignments.push(quote! {
945
1540
                #req_field_name: value
946
1540
            });
947
4348
        } else if (current_combination >> i) & 1 == 1 {
948
2174
            // This field was already set, move it from self
949
2174
            field_assignments.push(quote! {
950
2174
                #req_field_name: self.#req_field_name
951
2174
            });
952
2174
        }
953
        // Fields not set in either state are omitted
954
    }
955

            
956
    // Handle optional fields (always present, move from self)
957
3403
    for (opt_field_name, _) in optional_args {
958
1863
        field_assignments.push(quote! {
959
1863
            #opt_field_name: self.#opt_field_name
960
1863
        });
961
1863
    }
962

            
963
    // Generate the constructor call
964
1540
    let constructor = if field_assignments.is_empty() {
965
        quote! { #target_runner }
966
    } else {
967
1540
        quote! {
968
            #target_runner {
969
                #(#field_assignments),*
970
            }
971
        }
972
    };
973

            
974
1540
    quote! {
975
        pub fn #field_name(self, value: #field_type) -> #target_runner {
976
            #constructor
977
        }
978
    }
979
1540
}
980

            
981
/// Generate a setter method for an optional field
982
1054
fn generate_optional_field_setter(
983
1054
    field_name: &syn::Ident,
984
1054
    field_type: &syn::Type,
985
1054
    current_runner: &syn::Ident,
986
1054
    required_args: &[(syn::Ident, syn::Type)],
987
1054
    optional_args: &[(syn::Ident, syn::Type)],
988
1054
    current_combination: usize,
989
1054
    num_required: usize,
990
1054
) -> proc_macro2::TokenStream {
991
    // Generate field assignments (same state, but update the optional field)
992
1054
    let mut field_assignments = Vec::new();
993

            
994
    // Handle required fields (move from self if set)
995
3726
    for (i, (req_field_name, _)) in required_args.iter().enumerate() {
996
3726
        if (current_combination >> i) & 1 == 1 {
997
1863
            field_assignments.push(quote! {
998
1863
                #req_field_name: self.#req_field_name
999
1863
            });
1863
        }
    }
    // Handle optional fields
3924
    for (opt_field_name, _) in optional_args {
2870
        if opt_field_name == field_name {
1054
            // This is the field being set
1054
            field_assignments.push(quote! {
1054
                #opt_field_name: Some(value)
1054
            });
1816
        } else {
1816
            // Move other optional fields from self
1816
            field_assignments.push(quote! {
1816
                #opt_field_name: self.#opt_field_name
1816
            });
1816
        }
    }
1054
    let constructor = if field_assignments.is_empty() {
        quote! { #current_runner }
    } else {
1054
        quote! {
            #current_runner {
                #(#field_assignments),*
            }
        }
    };
1054
    quote! {
        pub fn #field_name(self, value: #field_type) -> #current_runner {
            #constructor
        }
    }
1054
}
/// Generate the run method for the complete runner state
215
fn generate_run_method(
215
    command_name: &syn::Ident,
215
    required_args: &[(syn::Ident, syn::Type)],
215
    optional_args: &[(syn::Ident, syn::Type)],
215
    body: &syn::Block,
215
) -> proc_macro2::TokenStream {
    // Extract field values directly (no unwrap needed!)
215
    let mut variable_assignments = Vec::new();
    // Required fields - direct field access
566
    for (field_name, _) in required_args {
351
        variable_assignments.push(quote! {
351
            let #field_name = self.#field_name;
351
        });
351
    }
    // Optional fields - direct field access
382
    for (field_name, _) in optional_args {
167
        variable_assignments.push(quote! {
167
            let #field_name = self.#field_name;
167
        });
167
    }
215
    quote! {
        pub async fn run(self) -> Result<Option<CmdResult>, CmdError> {
            // Zero runtime checks - direct field access!
            #(#variable_assignments)*
            // Original command body
            #body
        }
    }
215
}
/// Generate the new() method for the command
215
fn generate_new_method(
215
    command_name: &syn::Ident,
215
    num_required: usize,
215
    optional_args: &[(syn::Ident, syn::Type)],
215
) -> proc_macro2::TokenStream {
215
    let initial_runner = syn::Ident::new(
215
        &format!(
215
            "{}Runner{:0width$b}",
215
            command_name,
215
            0,
215
            width = num_required.max(1)
215
        ),
215
        command_name.span(),
    );
    // Initial state has no required fields set, but has optional fields as None
215
    let constructor = if optional_args.is_empty() && num_required > 0 {
        // Unit struct (no fields at all in initial state)
102
        quote! { #initial_runner }
    } else {
        // Struct with optional fields initialized to None
219
        let optional_field_inits = optional_args.iter().map(|(field_name, _)| {
167
            quote! { #field_name: None }
167
        });
113
        if optional_field_inits.len() > 0 {
81
            quote! {
                #initial_runner {
                    #(#optional_field_inits),*
                }
            }
        } else {
32
            quote! { #initial_runner }
        }
    };
215
    quote! {
        impl #command_name {
            pub fn new() -> #initial_runner {
                #constructor
            }
        }
    }
215
}
struct CommandInput {
    name: syn::Ident,
    required_args: Vec<(syn::Ident, syn::Type)>,
    optional_args: Vec<(syn::Ident, syn::Type)>,
    body: syn::Block,
}
impl syn::parse::Parse for CommandInput {
215
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
215
        let name: syn::Ident = input.parse()?;
        let content;
215
        syn::braced!(content in input);
215
        let mut required_args = Vec::new();
215
        let mut optional_args = Vec::new();
733
        while !content.is_empty() {
            // Parse attributes
518
            let mut is_optional = false;
518
            let mut is_required = false;
1036
            while content.peek(syn::Token![#]) {
518
                content.parse::<syn::Token![#]>()?;
                let attr_content;
518
                syn::bracketed!(attr_content in content);
518
                let attr_name: syn::Ident = attr_content.parse()?;
518
                if attr_name == "optional" {
167
                    is_optional = true;
351
                } else if attr_name == "required" {
351
                    is_required = true;
351
                } else {
                    return Err(syn::Error::new(
                        attr_name.span(),
                        "Unknown attribute. Use #[required] or #[optional]",
                    ));
                }
            }
            // Parse the field
518
            let arg_name: syn::Ident = content.parse()?;
518
            content.parse::<syn::Token![:]>()?;
518
            let arg_type: syn::Type = content.parse()?;
518
            if content.peek(syn::Token![,]) {
518
                content.parse::<syn::Token![,]>()?;
            }
            // Determine if optional (default to required if no attribute specified)
518
            let is_optional_field = if is_required && is_optional {
                return Err(syn::Error::new(
                    arg_name.span(),
                    "Field cannot be both #[required] and #[optional]",
                ));
518
            } else if is_optional {
167
                true
            } else {
351
                false // Default to required
            };
518
            if is_optional_field {
167
                optional_args.push((arg_name, arg_type));
351
            } else {
351
                required_args.push((arg_name, arg_type));
351
            }
        }
        // The '=>' is outside the braces
215
        input.parse::<syn::Token![=>]>()?;
215
        let body: syn::Block = input.parse()?;
215
        Ok(CommandInput {
215
            name,
215
            required_args,
215
            optional_args,
215
            body,
215
        })
215
    }
}