Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variant data type support #170

Merged
merged 11 commits into from
Nov 21, 2024
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ required-features = ["rustls-tls"]
name = "data_types_derive_simple"
required-features = ["time", "uuid"]

[[example]]
name = "data_types_variant"
required-features = ["time"]

[profile.release]
debug = true

Expand Down
170 changes: 170 additions & 0 deletions examples/data_types_variant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
use clickhouse_derive::Row;
use serde::{Deserialize, Serialize};

use clickhouse::sql::Identifier;
use clickhouse::{error::Result, Client};

// See also: https://clickhouse.com/docs/en/sql-reference/data-types/variant

// Entry point of the example: (re)creates a table with a `Variant` column,
// inserts one row per Variant alternative, then selects everything back and
// pretty-prints the fetched rows.
#[tokio::main]
async fn main() -> Result<()> {
    let table_name = "chrs_data_types_variant";
    let client = Client::default().with_url("http://localhost:8123");

    // The order in which the inner types are listed in the DDL does not matter:
    // this particular Variant will always be sorted as follows:
    // Variant(Array(UInt16), Bool, Date, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, UInt64, UInt8)
    let create_table_ddl = "
CREATE OR REPLACE TABLE ?
(
`id` UInt64,
`var` Variant(
Array(UInt16),
Bool,
Date,
FixedString(6),
Float32, Float64,
Int128, Int16, Int32, Int64, Int8,
String,
UInt128, UInt16, UInt32, UInt64, UInt8
)
)
ENGINE = MergeTree
ORDER BY id";

    client
        .query(create_table_ddl)
        .bind(Identifier(table_name))
        .with_option("allow_experimental_variant_type", "1")
        // Needed only because similar types are mixed inside the Variant
        // (several Int/UInt widths, Float32/Float64, String/FixedString).
        // Drop this option when the definition contains no similar types.
        .with_option("allow_suspicious_variant_types", "1")
        .execute()
        .await?;

    // Write every sample row through a single INSERT.
    let mut insert = client.insert(table_name)?;
    for row in get_rows() {
        insert.write(&row).await?;
    }
    insert.end().await?;

    // Read everything back; `?fields` is bound to the fields of `MyRow`.
    let rows: Vec<MyRow> = client
        .query("SELECT ?fields FROM ?")
        .bind(Identifier(table_name))
        .fetch_all()
        .await?;

    println!("{rows:#?}");
    Ok(())
}

// Builds the sample data set: one row per `MyRowVariant` alternative,
// with ids assigned sequentially starting from 1.
fn get_rows() -> Vec<MyRow> {
    // Payloads in insertion order; the position in this list (1-based) becomes the row id.
    let payloads = vec![
        MyRowVariant::Array(vec![1, 2]),
        MyRowVariant::Boolean(true),
        MyRowVariant::Date(time::Date::from_calendar_date(2021, time::Month::January, 1).unwrap()),
        MyRowVariant::FixedString(*b"foobar"),
        MyRowVariant::Float32(100.5),
        MyRowVariant::Float64(200.1),
        MyRowVariant::Int8(2),
        MyRowVariant::Int16(3),
        MyRowVariant::Int32(4),
        MyRowVariant::Int64(5),
        MyRowVariant::Int128(6),
        MyRowVariant::String("my_string".to_string()),
        MyRowVariant::UInt8(7),
        MyRowVariant::UInt16(8),
        MyRowVariant::UInt32(9),
        MyRowVariant::UInt64(10),
        MyRowVariant::UInt128(11),
    ];

    payloads
        .into_iter()
        .zip(1u64..)
        .map(|(var, id)| MyRow { id, var })
        .collect()
}

// As the inner Variant types are _always_ sorted alphabetically,
slvrtrn marked this conversation as resolved.
Show resolved Hide resolved
// it should be defined in _exactly_ the same order in the enum.
//
// Rust enum variants names are irrelevant, only the order of the types matters.
// This enum represents Variant(Array(UInt16), Bool, Date, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, UInt64, UInt8)
// Rust-side counterpart of the `var` Variant column created in `main`.
// Every payload type matches both the width and the signedness of the
// ClickHouse type it stands for, so the full value range round-trips.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum MyRowVariant {
    // Array(UInt16): the elements are unsigned, hence `u16` and not `i16`
    Array(Vec<u16>),
    Boolean(bool),
    // attributes should work in this case, too
    #[serde(with = "clickhouse::serde::time::date")]
    Date(time::Date),
    // NB: by default, fetched as raw bytes
    FixedString([u8; 6]),
    Float32(f32),
    Float64(f64),
    Int128(i128),
    Int16(i16),
    Int32(i32),
    Int64(i64),
    Int8(i8),
    String(String),
    UInt128(u128),
    // UInt16 is unsigned: `u16` keeps values above `i16::MAX` intact
    UInt16(u16),
    UInt32(u32),
    UInt64(u64),
    // UInt8 is unsigned: `u8` keeps values above `i8::MAX` intact
    UInt8(u8),
}

// One row of the example table; used both for the INSERT and for the
// `SELECT ?fields` query in `main`.
#[derive(Debug, PartialEq, Row, Serialize, Deserialize)]
struct MyRow {
// Corresponds to the `id` UInt64 column of the DDL in `main`.
id: u64,
// Corresponds to the `var` Variant(...) column of the DDL in `main`.
var: MyRowVariant,
}
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub enum Error {
InvalidUtf8Encoding(#[from] Utf8Error),
#[error("tag for enum is not valid")]
InvalidTagEncoding(usize),
#[error("max number of types in the Variant data type is 255, got {0}")]
VariantDiscriminatorIsOutOfBound(usize),
#[error("a custom error message from serde: {0}")]
Custom(String),
#[error("bad response: {0}")]
Expand Down
80 changes: 70 additions & 10 deletions src/rowbinary/de.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::{convert::TryFrom, mem, str};

use crate::error::{Error, Result};
use bytes::Buf;
use serde::de::{EnumAccess, VariantAccess};
use serde::{
de::{DeserializeSeed, Deserializer, SeqAccess, Visitor},
Deserialize,
};

use crate::error::{Error, Result};

/// Deserializes a value from `input` with a row encoded in `RowBinary`.
///
/// It accepts _a reference to_ a byte slice because it somehow leads to a more
Expand Down Expand Up @@ -146,14 +146,79 @@ impl<'cursor, 'data> Deserializer<'data> for &mut RowBinaryDeserializer<'cursor,
visitor.visit_byte_buf(self.read_vec(size)?)
}

/// Deserializes an identifier by forwarding to `deserialize_u8`.
///
/// NOTE(review): presumably this is what decodes the enum/Variant
/// discriminator requested via `variant_seed` in `deserialize_enum` — confirm.
#[inline]
fn deserialize_identifier<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
self.deserialize_u8(visitor)
}

#[inline]
fn deserialize_enum<V: Visitor<'data>>(
self,
name: &'static str,
_name: &'static str,
_variants: &'static [&'static str],
_visitor: V,
visitor: V,
) -> Result<V::Value> {
panic!("enums are unsupported: `{name}`");
/// `EnumAccess` adapter handed to `visitor.visit_enum`; it only borrows the
/// row deserializer for the duration of the enum deserialization.
struct Access<'de, 'cursor, 'data> {
deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>,
}
/// `VariantAccess` adapter produced by `Access::variant_seed`; it reads the
/// payload of the selected variant from the same borrowed row deserializer.
struct VariantDeserializer<'de, 'cursor, 'data> {
deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>,
}
impl<'data> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data> {
type Error = Error;

fn unit_variant(self) -> Result<()> {
Ok(())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to be an error on the definition side (SomeEnum::A without any payload), innit?

We should return an error here to avoid the generic "not enough data" error. Probably, it's time to introduce Error::Unsupported instead of panics, but up to you here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added Error::Unsupported and used it in this case. The rest of the panics can be changed as a follow-up.

}

/// Deserializes the single payload of a newtype variant by running `seed`
/// against the underlying row deserializer.
fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
where
T: DeserializeSeed<'data>,
{
// Reborrow so the seed receives `&mut RowBinaryDeserializer` directly.
DeserializeSeed::deserialize(seed, &mut *self.deserializer)
}

/// Deserializes a tuple variant by delegating to `deserialize_tuple` with
/// the same `len` and `visitor`.
fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
where
V: Visitor<'data>,
{
self.deserializer.deserialize_tuple(len, visitor)
}

/// Deserializes a struct variant as a tuple of `fields.len()` elements.
///
/// Only the number of fields is used here; the field names themselves are
/// not consulted.
fn struct_variant<V>(
self,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value>
where
V: Visitor<'data>,
{
self.deserializer.deserialize_tuple(fields.len(), visitor)
}
}

impl<'de, 'cursor, 'data> EnumAccess<'data> for Access<'de, 'cursor, 'data> {
type Error = Error;
type Variant = VariantDeserializer<'de, 'cursor, 'data>;

fn variant_seed<T>(
self,
seed: T,
) -> std::result::Result<(T::Value, Self::Variant), Self::Error>
slvrtrn marked this conversation as resolved.
Show resolved Hide resolved
where
T: DeserializeSeed<'data>,
{
seed.deserialize(&mut *self.deserializer).map(|v| {
slvrtrn marked this conversation as resolved.
Show resolved Hide resolved
(
v,
VariantDeserializer {
deserializer: self.deserializer,
},
)
})
}
}
visitor.visit_enum(Access { deserializer: self })
}

#[inline]
Expand Down Expand Up @@ -222,11 +287,6 @@ impl<'cursor, 'data> Deserializer<'data> for &mut RowBinaryDeserializer<'cursor,
self.deserialize_tuple(fields.len(), visitor)
}

#[inline]
fn deserialize_identifier<V: Visitor<'data>>(self, _visitor: V) -> Result<V::Value> {
panic!("identifiers are unsupported");
}

#[inline]
fn deserialize_newtype_struct<V: Visitor<'data>>(
self,
Expand Down
18 changes: 13 additions & 5 deletions src/rowbinary/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,20 @@ impl<'a, B: BufMut> Serializer for &'a mut RowBinarySerializer<B> {
#[inline]
fn serialize_newtype_variant<T: Serialize + ?Sized>(
self,
name: &'static str,
_variant_index: u32,
variant: &'static str,
_value: &T,
_name: &'static str,
variant_index: u32,
_variant: &'static str,
value: &T,
) -> Result<()> {
panic!("newtype variant types are unsupported: `{name}::{variant}`");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this code implicitly allows using enums at the top level. Usually, it's avoided by using either the parameter stored in the serializer struct or dedicated SerializeStruct/etc associated types.

In general, it's okay (the user will get some error anyway), but instead of more descriptive panic, it ends with "not enough data."

Also, it produces an unclear message for forgotten serde_repr (Enum8 and Enum16).

It's not a blocker for merge because it's an error anyway. But, probably, we should leave "TODO" at least in the code about it.

// Max number of types in the Variant data type is 255
// See also: https://github.com/ClickHouse/ClickHouse/issues/54864
if variant_index > 255 {
return Err(Error::VariantDiscriminatorIsOutOfBound(
variant_index as usize,
));
}
self.buffer.put_u8(variant_index as u8);
value.serialize(self)
}

#[inline]
Expand Down
1 change: 1 addition & 0 deletions tests/it/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ mod query;
mod time;
mod user_agent;
mod uuid;
mod variant;
mod watch;

const HOST: &str = "localhost:8123";
Expand Down
Loading