Skip to content

Commit

Permalink
fix(versionable): Handle generics in NotVersioned
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Jan 20, 2025
1 parent 1f254d6 commit 1174a1a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
31 changes: 22 additions & 9 deletions utils/tfhe-versionable-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub(crate) const UNVERSIONIZE_ERROR_NAME: &str = crate_full_path!("UnversionizeE

pub(crate) const SERIALIZE_TRAIT_NAME: &str = "::serde::Serialize";
pub(crate) const DESERIALIZE_TRAIT_NAME: &str = "::serde::Deserialize";
pub(crate) const DESERIALIZE_OWNED_TRAIT_NAME: &str = "::serde::de::DeserializeOwned";
pub(crate) const FROM_TRAIT_NAME: &str = "::core::convert::From";
pub(crate) const TRY_INTO_TRAIT_NAME: &str = "::core::convert::TryInto";
pub(crate) const INTO_TRAIT_NAME: &str = "::core::convert::Into";
Expand Down Expand Up @@ -316,14 +317,26 @@ pub fn derive_versionize(input: TokenStream) -> TokenStream {
pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);

let mut generics = input.generics.clone();
// Versionize needs T to impl Serialize
let mut versionize_generics = input.generics.clone();
syn_unwrap!(add_trait_where_clause(
&mut generics,
&mut versionize_generics,
&[parse_quote! { Self }],
&["Clone"]
&[SERIALIZE_TRAIT_NAME]
));

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
// VersionizeOwned needs T to impl Serialize and DeserializeOwned
let mut versionize_owned_generics = input.generics.clone();
syn_unwrap!(add_trait_where_clause(
&mut versionize_owned_generics,
&[parse_quote! { Self }],
&[SERIALIZE_TRAIT_NAME, DESERIALIZE_OWNED_TRAIT_NAME]
));

let (impl_generics, ty_generics, versionize_where_clause) =
versionize_generics.split_for_impl();
let (_, _, versionize_owned_where_clause) = versionize_owned_generics.split_for_impl();

let input_ident = &input.ident;

let versionize_trait: Path = parse_const_str(VERSIONIZE_TRAIT_NAME);
Expand All @@ -334,16 +347,16 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream {

quote! {
#[automatically_derived]
impl #impl_generics #versionize_trait for #input_ident #ty_generics #where_clause {
type Versioned<#lifetime> = &#lifetime Self;
impl #impl_generics #versionize_trait for #input_ident #ty_generics #versionize_where_clause {
type Versioned<#lifetime> = &#lifetime Self where Self: 'vers;

fn versionize(&self) -> Self::Versioned<'_> {
self
}
}

#[automatically_derived]
impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #where_clause {
impl #impl_generics #versionize_owned_trait for #input_ident #ty_generics #versionize_owned_where_clause {
type VersionedOwned = Self;

fn versionize_owned(self) -> Self::VersionedOwned {
Expand All @@ -352,14 +365,14 @@ pub fn derive_not_versioned(input: TokenStream) -> TokenStream {
}

#[automatically_derived]
impl #impl_generics #unversionize_trait for #input_ident #ty_generics #where_clause {
impl #impl_generics #unversionize_trait for #input_ident #ty_generics #versionize_owned_where_clause {
fn unversionize(versioned: Self::VersionedOwned) -> Result<Self, #unversionize_error> {
Ok(versioned)
}
}

#[automatically_derived]
impl NotVersioned for #input_ident #ty_generics #where_clause {}
impl #impl_generics NotVersioned for #input_ident #ty_generics #versionize_owned_where_clause {}

}
.into()
Expand Down
6 changes: 3 additions & 3 deletions utils/tfhe-versionable/examples/not_versioned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use serde::{Deserialize, Serialize};
use tfhe_versionable::{NotVersioned, Versionize, VersionsDispatch};

#[derive(Clone, Serialize, Deserialize, NotVersioned)]
struct MyStructNotVersioned {
val: u32,
struct MyStructNotVersioned<Inner> {
val: Inner,
}

#[derive(Versionize)]
#[versionize(MyStructVersions)]
struct MyStruct {
inner: MyStructNotVersioned,
inner: MyStructNotVersioned<u32>,
}

#[derive(VersionsDispatch)]
Expand Down

0 comments on commit 1174a1a

Please sign in to comment.