Skip to content

Commit

Permalink
#[trace(bound = ...)]
Browse files Browse the repository at this point in the history
Summary: Mostly to support `'static` bound. Used in D64275149.

Reviewed By: JakobDegen

Differential Revision: D64275150

fbshipit-source-id: 3b0db29786664f6c958bb68fb0913a6ac0a7d462
  • Loading branch information
stepancheg authored and facebook-github-bot committed Oct 14, 2024
1 parent b3f36e4 commit c08af46
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 1 deletion.
1 change: 1 addition & 0 deletions starlark/src/tests/derive/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@
* limitations under the License.
*/

mod bounds;
mod enums;
mod statics;
37 changes: 37 additions & 0 deletions starlark/src/tests/derive/trace/bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2018 The Starlark in Rust Authors.
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#![allow(dead_code)]

use crate as starlark;
use crate::values::Trace;

#[derive(Trace)]
#[trace(bound = "A: Trace<'v>, B: 'static")]
struct TestTraceWithBounds<A, B> {
a: A,
#[trace(static)]
b: B,
}

struct NotTrace;

fn assert_trace<'v, T: Trace<'v>>() {}

fn test() {
assert_trace::<TestTraceWithBounds<String, NotTrace>>();
}
45 changes: 44 additions & 1 deletion starlark_derive/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use quote::ToTokens;
use syn::parse::ParseStream;
use syn::parse_macro_input;
use syn::parse_quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::Attribute;
use syn::DeriveInput;
Expand Down Expand Up @@ -54,6 +55,7 @@ fn derive_trace_impl(mut input: DeriveInput) -> syn::Result<syn::ItemImpl> {
let TraceAttrs {
unsafe_ignore,
trace_static,
bounds,
} = parse_attrs(&input.attrs)?;
if let Some(unsafe_ignore) = unsafe_ignore {
return Err(syn::Error::new_spanned(
Expand All @@ -72,14 +74,33 @@ fn derive_trace_impl(mut input: DeriveInput) -> syn::Result<syn::ItemImpl> {
let mut has_tick_v = false;
for param in &mut input.generics.params {
if let GenericParam::Type(type_param) = param {
type_param.bounds.push(bound.clone());
if bounds.is_none() {
type_param.bounds.push(bound.clone());
}
}
if let GenericParam::Lifetime(t) = param {
if t.lifetime.ident == "v" {
has_tick_v = true;
}
}
}
if let Some(bounds) = bounds {
'outer: for bound in bounds {
for param in &mut input.generics.params {
if let GenericParam::Type(type_param) = param {
if type_param.ident == bound.ident {
type_param.bounds.extend(bound.bounds);
continue 'outer;
}
}
}
return Err(syn::Error::new_spanned(
bound,
"Type parameter not found in the generic parameters",
));
}
}

let mut generics2 = input.generics.clone();

let (_, ty_generics, where_clause) = input.generics.split_for_impl();
Expand All @@ -101,13 +122,16 @@ fn derive_trace_impl(mut input: DeriveInput) -> syn::Result<syn::ItemImpl> {
}

syn::custom_keyword!(unsafe_ignore);
syn::custom_keyword!(bound);

#[derive(Default)]
struct TraceAttrs {
/// `#[trace(unsafe_ignore)]`
unsafe_ignore: Option<unsafe_ignore>,
/// `#[trace(static)]`
trace_static: Option<syn::Token![static]>,
/// `#[trace(bound = "A: 'static, B: Trace<'v>")]`
bounds: Option<Punctuated<syn::TypeParam, syn::Token![,]>>,
}

impl TraceAttrs {
Expand All @@ -131,6 +155,18 @@ impl TraceAttrs {
));
}
trace_attrs.trace_static = Some(trace_static);
} else if let Some(bound) = input.parse::<Option<bound>>()? {
if trace_attrs.bounds.is_some() {
return Err(syn::Error::new_spanned(
bound,
"Duplicate `bound` attribute",
));
}
input.parse::<syn::Token![=]>()?;
let bounds = input.parse::<syn::LitStr>()?;
let bounds = bounds
.parse_with(|parser: ParseStream| Punctuated::parse_terminated(parser))?;
trace_attrs.bounds = Some(bounds);
} else {
return Err(input.error("Unknown attribute"));
}
Expand Down Expand Up @@ -174,13 +210,20 @@ fn trace_impl(derive_input: &DeriveInput, generics: &Generics) -> syn::Result<sy
let TraceAttrs {
unsafe_ignore,
trace_static,
bounds,
} = parse_attrs(&field.attrs)?;
if let (Some(unsafe_ignore), Some(_trace_static)) = (unsafe_ignore, trace_static) {
return Err(syn::Error::new_spanned(
unsafe_ignore,
"Cannot have both `unsafe_ignore` and `static` attributes",
));
}
if let Some(bounds) = bounds {
return Err(syn::Error::new_spanned(
bounds,
"The `bound` attribute can only be used on the `#[derive(Trace)]`",
));
}
if unsafe_ignore.is_some() {
Ok(quote! {})
} else if trace_static.is_some() || is_static(&field.ty, &generic_types) {
Expand Down

0 comments on commit c08af46

Please sign in to comment.