解析目标结构体的元数据

93
2024/07/03 13:40:37

本章我们将讨论如何解析目标结构体的元数据,包括:结构体的名称、结构体的字段(包括可见性、字段名和数据类型)。

Db 稍微变的有趣

上一章,我们的 Derive 只是返回了一个空的 TokenStream,我们让它变的稍微有趣一点,返回一个 println!("Hello, axum.rs!");

// src/lib.rs
#[proc_macro_derive(Db)]
pub fn db_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    println!("{:#?}", input);

    r#"{println!("Hello, axum.rs!");}"#.parse().unwrap()
}
// src/lib.rs

#[proc_macro_derive(Db)]
pub fn db_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    println!("{:#?}", input);

    r#"
    impl User {
        pub fn hi(&self) -> &'static str {
            "Hello, axum.rs!"
        }
    }
    "#
    .parse()
    .unwrap()
}

调用一下试试:

// examples/ch02-parse-struct-meta.rs

use db_derive::Db;

#[derive(Debug, Default, Db)]
pub struct User {
    pub id: String,
    pub email: String,
    pub password: String,
    pub nickname: String,
    pub dateline: chrono::DateTime<chrono::Local>,
}

fn main() {
    let u = User::default();
    let s = u.hi();
    println!("{}", s);
}

神奇的事情再度发生,竟然真的能调用 hi() 方法!

从中我们可以得出结论,proc_macro::TokenStream 其实就是抽象,一种【对 rust 代码】的抽象。

但是,这代码代码中的问题很严重:我们写死了结构体的名称:impl User {},如果结构体不叫 User呢?具体叫什么,我们不得而知,需要从 input 参数中解析出来。你可以手写代码来解析 input 参数,但优秀的第三方库已经为我们做好了,可以大大提升效率。

解析结构体名称

  • syn:用于从 TokenStream 中解析元数据
  • quote:用于方便地编写 TokenStream

先看一段代码:

// src/lib.rs

#[proc_macro_derive(Db)]
pub fn db_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = ast.ident;
    println!("{:?}", name);

    quote::quote! {}.into()
}
  • 首先,我们对输入的 TokenStream 进行解析:let ast = syn::parse_macro_input!(input as syn::DeriveInput);
  • 然后,便可以拿到目标结构体的名称了:let name = ast.ident;
    • 它的数据类型是Ident,用于抽象所有【标识符】,包括但不限于:结构体名字、字段名称等等
  • 最后,我们返回一个空的 TokenStreamquote::quote! {}.into()
    • quote! 可以让我们非常方便的编写 TokenStream
    • 它的返回值是 proc_macro2::TokenStream
    • proc_macro2是对标准库proc_macro 的封装,所以 proc_marco2::TokenStream 可以很容易的 Intoproc_macro::TokenStream

下面,我们使用 quote! 来为结构体实现 hi(&self) 方法:

// src/lib.rs

#[proc_macro_derive(Db)]
pub fn db_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = ast.ident;

    quote::quote! {
        impl #name {
            pub fn hi(&self) -> &'static str {
                "Hello, axum.rs!"
            }
        }
    }
    .into()
}
  • 通过 #变量名 可以在 quote! 内部引用外部定义的变量。本例中,引用了 name 变量,即目标结构体的标识符。

解析字段

目标结构体的字段信息保存在 DeriveInputdata 字段里,它是 Data枚举:

对于我们而言,我们需要的是 Data::Struct(DataStruct),因为我们的宏只用于修饰结构体。 DataStruct维护了目标结构体的字段信息:

pub struct DataStruct {
    pub struct_token: Struct,
    pub fields: Fields,
    pub semi_token: Option<Semi>,
}

我们关心的是 fields,它是Fields枚举:

pub enum Fields {
    Named(FieldsNamed),
    Unnamed(FieldsUnnamed),
    Unit,
}

而其中,我们需要的是 Named(FieldsNamed),因为我们的字段都是有名字的。FieldsNamed中,就有我们需要的字段信息了:

pub struct FieldsNamed {
    pub brace_token: Brace,
    pub named: Punctuated<Field, Comma>,
}

其中的 namedPunctuated<Field, Comma>泛型结构体,两个泛型分别是字段和字段之间的分隔符。自然,我们只关心 Field结构体:

pub struct Field {
    pub attrs: Vec<Attribute>,
    pub vis: Visibility,
    pub mutability: FieldMutability,
    pub ident: Option<Ident>,
    pub colon_token: Option<Colon>,
    pub ty: Type,
}

来了来了,经过层层套娃它终于来了。我们重点关注:

  • attrs:字段的属性,我们后续章节会讲到
  • vis:字段的可访问性
  • ident:字段标识符,即字段名
  • ty:字段的数据类型

通过上面的分析,我们可以这样拿到目标结构体的字段:

let fileds = match ast.data {
        syn::Data::Struct(s) => match s.fields {
            syn::Fields::Named(n) => n.named,
            _ => unreachable!(),
        },
        _ => unreachable!(),
    };

你也可以用更简洁的方式,通过解构语法一步到位:

let fileds = if let syn::Data::Struct(syn::DataStruct {
        fields: syn::Fields::Named(syn::FieldsNamed { named, .. }, ..),
        ..
    }) = ast.data
    {
        named
    } else {
        unreachable!()
    };

我们现在改一下 hi(&self) ,让它返回字段列表。

#[proc_macro_derive(Db)]
pub fn db_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = ast.ident;

    let fileds = if let syn::Data::Struct(syn::DataStruct {
        fields: syn::Fields::Named(syn::FieldsNamed { named, .. }, ..),
        ..
    }) = ast.data
    {
        named
    } else {
        unreachable!()
    };

    let field_str = fileds
        .iter()
        .map(|f| f.ident.clone().unwrap().to_string())
        .collect::<Vec<_>>()
        .join(",");

    quote::quote! {
        impl #name {
            pub fn hi(&self) -> &'static str {
                #field_str
            }
        }
    }
    .into()
}

为每个字段生成 gettersetter

对于我们需求而言,gettersetter 不是必须的。但为了演示 quote 里的重要内容:循环,我们还是通过本小节来为目标结构体的每个字段生成对应的 gettersetter

quote 的循环和声明宏里的循环语法很类似:

#( #可迭代对象 )[*+]
  • 循环必须是以 #( ) 包裹
  • 包裹的内容必须是 rust 里可迭代的,比如 Vec
  • 循环次数是由最后的标识指定
    • * 循环0次到迭代结束,在迭代对象可能为空的情况下使用
    • + 循环1次到迭代结束,在迭代对象保证不为空的情况下使用

第一步:为每个字段生成同名的方法

首先,我们为每个字段生成同名的方法,该方法啥都不做:

 let field_idents = fileds
    .iter()
    .map(|f| f.ident.clone().unwrap().clone())
    .collect::<Vec<_>>();

quote::quote! {
    impl #name {
      #(
         pub fn #field_idents(&self) {

        }
      )*
    }
}
.into()
  • 通过 map 将每个字段的标识符收集为一个 Vec
  • quote! 中,遍历这个Vec
    • 每次遍历都通过 pub fn ... 生成同名方法

第二步:为每个字段的同名方法返回对应的字段值的引用

let field_types = fileds.iter().map(|f| f.ty.clone()).collect::<Vec<_>>();
  • 通过 map 将每个字段的数据类型收集为一个 Vec
quote::quote! {
    impl #name {
      #(
         pub fn #field_idents(&self) -> &#field_types {
            &self.#field_idents
        }
      )*
    }
}
  • 将字段的数据类型作为返回值的类型

第三步:为每个字段生成 setter

let setter_idents = field_idents
    .iter()
    .map(|f| {
        let ident_str = format!("set_{}", f.to_string());
        syn::Ident::new(&ident_str, f.span())
    })
    .collect::<Vec<_>>();
  • 通过 map 生成每个字段的 setter 方法,方法名是 set_字段名
 #(
    pub fn #setter_idents(&mut self, v:#field_types) {
        self.#field_idents = v;
    }
  )*
  • 生成 setter 方法

完整代码如下:

调用一下试试:

use db_derive::Db;

#[derive(Debug, Default, Db)]
pub struct User {
    pub id: String,
    pub email: String,
    pub password: String,
    pub nickname: String,
    pub dateline: chrono::DateTime<chrono::Local>,
}

fn main() {
    let mut u = User::default();
    // 调用 setter
    u.set_email("[email protected]".into());
    // 调用 getter
    let email = u.email();
    println!("email is: {}", email);
}

本章代码位于02/解析结构体元数据 分支。