Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive on structs with generics #79

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 18 additions & 24 deletions bilge-impl/src/bitsize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,39 +121,33 @@ fn analyze_enum(bitsize: BitSize, variants: Iter<Variant>) {
}

fn generate_struct(item: &ItemStruct, declared_bitsize: u8) -> TokenStream {
let ItemStruct { vis, ident, fields, .. } = item;
let ItemStruct { ident, fields, generics, .. } = item;
let declared_bitsize = declared_bitsize as usize;

let computed_bitsize = fields.iter().fold(quote!(0), |acc, next| {
let field_size = shared::generate_type_bitsize(&next.ty);
quote!(#acc + #field_size)
});

// we could remove this if the whole struct gets passed
let is_tuple_struct = fields.iter().any(|field| field.ident.is_none());
let fields_def = if is_tuple_struct {
let fields = fields.iter();
quote! {
( #(#fields,)* );
}
} else {
let fields = fields.iter();
quote! {
{ #(#fields,)* }
}
};
// The only part of the struct we don't want to pass through are the attributes
let mut item = item.clone();
item.attrs = Vec::new();

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
#vis struct #ident #fields_def

// constness: when we get const blocks evaluated at compile time, add a const computed_bitsize
const _: () = assert!(
(#computed_bitsize) == (#declared_bitsize),
concat!("struct size and declared bit size differ: ",
// stringify!(#computed_bitsize),
" != ",
stringify!(#declared_bitsize))
);
#item

impl #impl_generics #ident #ty_generics #where_clause {
// constness: when we get const blocks evaluated at compile time, add a const computed_bitsize
const _BITSIZE_CHECK: () = assert!(
(#computed_bitsize) == (#declared_bitsize),
concat!("struct size and declared bit size differ: ",
// stringify!(#computed_bitsize),
" != ",
stringify!(#declared_bitsize))
);
}
}
}

Expand Down
51 changes: 42 additions & 9 deletions bilge-impl/src/bitsize_internal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Attribute, Field, Item, ItemEnum, ItemStruct, Type};
use syn::{Attribute, Field, Generics, Item, ItemEnum, ItemStruct, Type};

use crate::shared::{self, unreachable};

Expand All @@ -12,6 +12,7 @@ struct ItemIr<'a> {
name: &'a Ident,
/// generated item (and setters, getters, constructor, impl Bitsized)
expanded: TokenStream,
generics: &'a Generics,
}

pub(super) fn bitsize_internal(args: TokenStream, item: TokenStream) -> TokenStream {
Expand All @@ -21,13 +22,25 @@ pub(super) fn bitsize_internal(args: TokenStream, item: TokenStream) -> TokenStr
let expanded = generate_struct(item, &arb_int);
let attrs = &item.attrs;
let name = &item.ident;
ItemIr { attrs, name, expanded }
let generics = &item.generics;
ItemIr {
attrs,
name,
expanded,
generics,
}
}
Item::Enum(ref item) => {
let expanded = generate_enum(item);
let attrs = &item.attrs;
let name = &item.ident;
ItemIr { attrs, name, expanded }
let generics = &item.generics;
ItemIr {
attrs,
name,
expanded,
generics,
}
}
_ => unreachable(()),
};
Expand All @@ -41,7 +54,19 @@ fn parse(item: TokenStream, args: TokenStream) -> (Item, TokenStream) {
}

fn generate_struct(struct_data: &ItemStruct, arb_int: &TokenStream) -> TokenStream {
let ItemStruct { vis, ident, fields, .. } = struct_data;
let ItemStruct {
vis,
ident,
fields,
generics,
..
} = struct_data;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let phantom_ty = generics.type_params().map(|e| &e.ident).map(|ident| quote!(#ident));
let phantom_lt = generics.lifetimes().map(|l| &l.lifetime).map(|lifetime| quote!(& #lifetime ()));
// TODO: integrate user-provided PhantomData somehow? (so that the user can set the variance)
let phantom = phantom_ty.chain(phantom_lt);

let mut fieldless_next_int = 0;
let mut previous_field_sizes = vec![];
Expand All @@ -67,11 +92,12 @@ fn generate_struct(struct_data: &ItemStruct, arb_int: &TokenStream) -> TokenStre
let const_ = if cfg!(feature = "nightly") { quote!(const) } else { quote!() };

quote! {
#vis struct #ident {
#vis struct #ident #generics #where_clause {
/// WARNING: modifying this value directly can break invariants
value: #arb_int,
_phantom: ::core::marker::PhantomData<(#(#phantom),*)>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it correct that adding a PhantomData here is only needed when generics are involved?

if so, do you think it's possible (and practical) to not generate a _phantom field unless it's actually needed, just by looking at item.generics?

Copy link
Contributor Author

@kitlith kitlith Sep 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible, yes, but it would create additional conditions at the construction sites (there are 4 in this PR) and potentially any external derive macros. We could refactor those to use a common constructor method, but I figured I'd go for the approach that requires the least conditional code.

I'm fine either way, though.

}
impl #ident {
impl #impl_generics #ident #ty_generics #where_clause {
// #[inline]
#[allow(clippy::too_many_arguments, clippy::type_complexity, unused_parens)]
pub #const_ fn new(#( #constructor_args )*) -> Self {
Expand All @@ -81,7 +107,7 @@ fn generate_struct(struct_data: &ItemStruct, arb_int: &TokenStream) -> TokenStre
let mut offset = 0;
let raw_value = #( #constructor_parts )|*;
let value = #arb_int::new(raw_value);
Self { value }
Self { value, _phantom: ::core::marker::PhantomData }
}
#( #accessors )*
}
Expand Down Expand Up @@ -221,12 +247,19 @@ fn generate_enum(enum_data: &ItemEnum) -> TokenStream {
/// We have _one_ `generate_common` function, which holds everything struct and enum have _in common_.
/// Everything else has its own `generate_` functions.
fn generate_common(ir: ItemIr, arb_int: &TokenStream) -> TokenStream {
let ItemIr { attrs, name, expanded } = ir;
let ItemIr {
attrs,
name,
expanded,
generics,
} = ir;

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
#(#attrs)*
#expanded
impl ::bilge::Bitsized for #name {
impl #impl_generics ::bilge::Bitsized for #name #ty_generics #where_clause {
type ArbitraryInt = #arb_int;
const BITS: usize = <Self::ArbitraryInt as Bitsized>::BITS;
const MAX: Self::ArbitraryInt = <Self::ArbitraryInt as Bitsized>::MAX;
Expand Down
22 changes: 18 additions & 4 deletions bilge-impl/src/debug_bits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use proc_macro2::{Ident, TokenStream};
use proc_macro_error::abort_call_site;
use quote::quote;
use syn::{Data, Fields};
use syn::{Data, Fields, WhereClause, WherePredicate};

use crate::shared::{self, unreachable};

Expand All @@ -17,7 +17,7 @@ pub(super) fn debug_bits(item: TokenStream) -> TokenStream {
};

let fmt_impl = match struct_data.fields {
Fields::Named(fields) => {
Fields::Named(ref fields) => {
let calls = fields.named.iter().map(|f| {
// We can unwrap since this is a named field
let call = f.ident.as_ref().unwrap();
Expand All @@ -30,7 +30,7 @@ pub(super) fn debug_bits(item: TokenStream) -> TokenStream {
#(#calls)*.finish()
}
}
Fields::Unnamed(fields) => {
Fields::Unnamed(ref fields) => {
let calls = fields.unnamed.iter().map(|_| {
let call: Ident = syn::parse_str(&format!("val_{}", fieldless_next_int)).unwrap_or_else(unreachable);
fieldless_next_int += 1;
Expand All @@ -45,8 +45,22 @@ pub(super) fn debug_bits(item: TokenStream) -> TokenStream {
Fields::Unit => todo!("this is a unit struct, which is not supported right now"),
};

let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();
let mut where_clause = where_clause.map(<_>::clone).unwrap_or_else(|| WhereClause {
where_token: <_>::default(),
predicates: <_>::default(),
});

// NOTE: This is not *ideal*, but it's approximately what the standard library does,
// for various reasons. see https://github.com/rust-lang/rust/issues/26925
where_clause.predicates.extend(derive_input.generics.type_params().map(|t| {
let ty = &t.ident;
let res: WherePredicate = syn::parse_quote!(#ty : ::core::fmt::Debug);
res
}));

quote! {
impl ::core::fmt::Debug for #name {
impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
#fmt_impl
}
Expand Down
30 changes: 22 additions & 8 deletions bilge-impl/src/default_bits.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,49 @@
use proc_macro2::{Ident, TokenStream};
use proc_macro_error::abort_call_site;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Type};
use syn::{Data, DeriveInput, Fields, Generics, Type, WhereClause, WherePredicate};

use crate::shared::{self, fallback::Fallback, unreachable, BitSize};

pub(crate) fn default_bits(item: TokenStream) -> TokenStream {
let derive_input = parse(item);
//TODO: does fallback need handling?
let (derive_data, _, name, ..) = analyze(&derive_input);
let (derive_data, _, name, generics, ..) = analyze(&derive_input);

match derive_data {
Data::Struct(data) => generate_struct_default_impl(name, &data.fields),
Data::Struct(data) => generate_struct_default_impl(name, &data.fields, generics),
Data::Enum(_) => abort_call_site!("use derive(Default) for enums"),
_ => unreachable(()),
}
}

fn generate_struct_default_impl(struct_name: &Ident, fields: &Fields) -> TokenStream {
fn generate_struct_default_impl(struct_name: &Ident, fields: &Fields, generics: &Generics) -> TokenStream {
let default_value = fields
.iter()
.map(|field| generate_default_inner(&field.ty))
.reduce(|acc, next| quote!(#acc | #next));

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let mut where_clause = where_clause.map(<_>::clone).unwrap_or_else(|| WhereClause {
where_token: <_>::default(),
predicates: <_>::default(),
});

// NOTE: This is not *ideal*, but it's approximately what the standard library does,
// for various reasons. see https://github.com/rust-lang/rust/issues/26925
where_clause.predicates.extend(generics.type_params().map(|t| {
let ty = &t.ident;
let res: WherePredicate = syn::parse_quote!(#ty : ::core::default::Default);
res
}));

quote! {
impl ::core::default::Default for #struct_name {
impl #impl_generics ::core::default::Default for #struct_name #ty_generics #where_clause {
fn default() -> Self {
let mut offset = 0;
let value = #default_value;
let value = <#struct_name as Bitsized>::ArbitraryInt::new(value);
Self { value }
let value = <#struct_name #ty_generics as Bitsized>::ArbitraryInt::new(value);
Self { value, _phantom: ::core::marker::PhantomData }
}
}
}
Expand Down Expand Up @@ -87,6 +101,6 @@ fn parse(item: TokenStream) -> DeriveInput {
shared::parse_derive(item)
}

fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, BitSize, Option<Fallback>) {
fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, &Generics, BitSize, Option<Fallback>) {
shared::analyze_derive(derive_input, false)
}
14 changes: 8 additions & 6 deletions bilge-impl/src/fmt_bits.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Variant};
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Generics, Variant};

use crate::shared::{self, discriminant_assigner::DiscriminantAssigner, fallback::Fallback, unreachable, BitSize};

pub(crate) fn binary(item: TokenStream) -> TokenStream {
let derive_input = parse(item);
let (derive_data, arb_int, name, bitsize, fallback) = analyze(&derive_input);
let (derive_data, arb_int, name, generics, bitsize, fallback) = analyze(&derive_input);

match derive_data {
Data::Struct(data) => generate_struct_binary_impl(name, &data.fields),
Data::Struct(data) => generate_struct_binary_impl(name, &data.fields, generics),
Data::Enum(data) => generate_enum_binary_impl(name, data.variants.iter(), arb_int, bitsize, fallback),
_ => unreachable(()),
}
}

fn generate_struct_binary_impl(struct_name: &Ident, fields: &Fields) -> TokenStream {
fn generate_struct_binary_impl(struct_name: &Ident, fields: &Fields, generics: &Generics) -> TokenStream {
let write_underscore = quote! { write!(f, "_")?; };

// fields are printed from most significant to least significant, separated by an underscore
Expand All @@ -37,8 +37,10 @@ fn generate_struct_binary_impl(struct_name: &Ident, fields: &Fields) -> TokenStr
})
.reduce(|acc, next| quote!(#acc #write_underscore #next));

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
impl ::core::fmt::Binary for #struct_name {
impl #impl_generics ::core::fmt::Binary for #struct_name #ty_generics #where_clause {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let struct_size = <#struct_name as Bitsized>::BITS;
let mut last_bit_pos = struct_size;
Expand Down Expand Up @@ -107,6 +109,6 @@ fn parse(item: TokenStream) -> DeriveInput {
shared::parse_derive(item)
}

fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, BitSize, Option<Fallback>) {
fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, &Generics, BitSize, Option<Fallback>) {
shared::analyze_derive(derive_input, false)
}
20 changes: 11 additions & 9 deletions bilge-impl/src/from_bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use itertools::Itertools;
use proc_macro2::{Ident, TokenStream};
use proc_macro_error::{abort, abort_call_site};
use quote::quote;
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Type, Variant};
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Generics, Type, Variant};

use crate::shared::{self, discriminant_assigner::DiscriminantAssigner, enum_fills_bitsize, fallback::Fallback, unreachable, BitSize};

pub(super) fn from_bits(item: TokenStream) -> TokenStream {
let derive_input = parse(item);
let (derive_data, arb_int, name, internal_bitsize, fallback) = analyze(&derive_input);
let (derive_data, arb_int, name, generics, internal_bitsize, fallback) = analyze(&derive_input);
let expanded = match &derive_data {
Data::Struct(struct_data) => generate_struct(arb_int, name, &struct_data.fields),
Data::Struct(struct_data) => generate_struct(arb_int, name, &struct_data.fields, generics),
Data::Enum(enum_data) => {
let variants = enum_data.variants.iter();
let match_arms = analyze_enum(variants, name, internal_bitsize, fallback.as_ref(), &arb_int);
Expand All @@ -25,7 +25,7 @@ fn parse(item: TokenStream) -> DeriveInput {
shared::parse_derive(item)
}

fn analyze(derive_input: &DeriveInput) -> (&syn::Data, TokenStream, &Ident, BitSize, Option<Fallback>) {
fn analyze(derive_input: &DeriveInput) -> (&syn::Data, TokenStream, &Ident, &Generics, BitSize, Option<Fallback>) {
shared::analyze_derive(derive_input, false)
}

Expand Down Expand Up @@ -141,7 +141,7 @@ fn generate_filled_check_for(ty: &Type, vec: &mut Vec<TokenStream>) {
}
}

fn generate_struct(arb_int: TokenStream, struct_type: &Ident, fields: &Fields) -> TokenStream {
fn generate_struct(arb_int: TokenStream, struct_type: &Ident, fields: &Fields, generics: &Generics) -> TokenStream {
let const_ = if cfg!(feature = "nightly") { quote!(const) } else { quote!() };

let mut assumes = Vec::new();
Expand All @@ -152,15 +152,17 @@ fn generate_struct(arb_int: TokenStream, struct_type: &Ident, fields: &Fields) -
// a single check per type is enough, so the checks can be deduped
let assumes = assumes.into_iter().unique_by(TokenStream::to_string);

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
impl #const_ ::core::convert::From<#arb_int> for #struct_type {
impl #impl_generics #const_ ::core::convert::From<#arb_int> for #struct_type #ty_generics #where_clause {
fn from(value: #arb_int) -> Self {
#( #assumes )*
Self { value }
Self { value, _phantom: ::core::marker::PhantomData }
}
}
impl #const_ ::core::convert::From<#struct_type> for #arb_int {
fn from(value: #struct_type) -> Self {
impl #impl_generics #const_ ::core::convert::From<#struct_type #ty_generics> for #arb_int #where_clause {
fn from(value: #struct_type #ty_generics) -> Self {
value.value
}
}
Expand Down
Loading