From f9cb4aac4142143fb03131be75e62d189ab93ea0 Mon Sep 17 00:00:00 2001 From: Melissa King Date: Thu, 16 Nov 2023 14:30:53 -0700 Subject: [PATCH] add support for Decimal128(S) --- src/types/block/builder.rs | 2 +- src/types/column/decimal.rs | 19 +++++++-- src/types/column/factory.rs | 4 +- src/types/column/iter/mod.rs | 6 ++- src/types/decimal.rs | 80 ++++++++++++++++++++++++++---------- src/types/value.rs | 2 +- src/types/value_ref.rs | 4 +- tests/clickhouse.rs | 10 +++-- 8 files changed, 94 insertions(+), 33 deletions(-) diff --git a/src/types/block/builder.rs b/src/types/block/builder.rs index 3ed3a2e9..06b482de 100644 --- a/src/types/block/builder.rs +++ b/src/types/block/builder.rs @@ -206,6 +206,6 @@ mod test { block.columns[14].sql_type(), SqlType::DateTime(DateTimeType::Chrono) ); - assert_eq!(block.columns[15].sql_type(), SqlType::Decimal(18, 4)); + assert_eq!(block.columns[15].sql_type(), SqlType::Decimal(38, 4)); } } diff --git a/src/types/column/decimal.rs b/src/types/column/decimal.rs index cab49c85..3898a28b 100644 --- a/src/types/column/decimal.rs +++ b/src/types/column/decimal.rs @@ -52,6 +52,7 @@ impl DecimalColumnData { let type_name = match nobits { NoBits::N32 => "Int32", NoBits::N64 => "Int64", + NoBits::N128 => "Int128", }; let inner = ::load_data::(reader, type_name, size, tz)?; @@ -161,6 +162,10 @@ impl ColumnData for DecimalColumnData { let internal: i64 = decimal.internal(); self.inner.push(internal.into()) } + NoBits::N128 => { + let internal: i128 = decimal.internal(); + self.inner.push(internal.into()) + } } } else { panic!("value should be decimal ({value:?})"); @@ -168,9 +173,10 @@ impl ColumnData for DecimalColumnData { } fn at(&self, index: usize) -> ValueRef { - let underlying: i64 = match self.nobits { - NoBits::N32 => i64::from(i32::from(self.inner.at(index))), - NoBits::N64 => i64::from(self.inner.at(index)), + let underlying: i128 = match self.nobits { + NoBits::N32 => i128::from(i32::from(self.inner.at(index))), + NoBits::N64 => i128::from(i64::from(self.inner.at(index))), + NoBits::N128 => i128::from(self.inner.at(index)), }; ValueRef::Decimal(Decimal { @@ -224,6 +230,10 @@ impl ColumnData for DecimalAdapter { let internal: i64 = decimal.internal(); encoder.write(internal); } + NoBits::N128 => { + let internal: i128 = decimal.internal(); + encoder.write(internal); + } } } else { panic!("should be decimal"); @@ -276,6 +286,9 @@ impl ColumnData for NullableDecimalAdapter { encoder.write(underlying as i32); } NoBits::N64 => { + encoder.write(underlying as i64); + } + NoBits::N128 => { encoder.write(underlying); } } diff --git a/src/types/column/factory.rs b/src/types/column/factory.rs index 274eca3e..88423e34 100644 --- a/src/types/column/factory.rs +++ b/src/types/column/factory.rs @@ -182,6 +182,7 @@ impl dyn ColumnData { let inner_type = match nobits { NoBits::N32 => SqlType::Int32, NoBits::N64 => SqlType::Int64, + NoBits::N128 => SqlType::Int128, }; W::wrap(DecimalColumnData { @@ -382,6 +383,7 @@ fn parse_decimal(source: &str) -> Option<(u8, u8, NoBits)> { let precision = match bits { NoBits::N32 => 9, NoBits::N64 => 18, + NoBits::N128 => 38, }; Some((precision, scale, bits)) } @@ -554,7 +556,7 @@ mod test { fn test_parse_decimal() { assert_eq!(parse_decimal("Decimal(9, 4)"), Some((9, 4, NoBits::N32))); assert_eq!(parse_decimal("Decimal(10, 4)"), Some((10, 4, NoBits::N64))); - assert_eq!(parse_decimal("Decimal(20, 4)"), None); + assert_eq!(parse_decimal("Decimal(20, 4)"), Some((20, 4, NoBits::N128))); assert_eq!(parse_decimal("Decimal(2000, 4)"), None); assert_eq!(parse_decimal("Decimal(3, 4)"), None); assert_eq!(parse_decimal("Decimal(20, -4)"), None); diff --git a/src/types/column/iter/mod.rs b/src/types/column/iter/mod.rs index f34956fc..a6659a61 100644 --- a/src/types/column/iter/mod.rs +++ b/src/types/column/iter/mod.rs @@ -309,7 +309,7 @@ impl<'a> DecimalIterator<'a> { unsafe fn next_unchecked_(&mut self) -> Decimal where T: Copy + Sized, - i64: From, + i128: From, { let current_value = *(self.ptr as *const T); self.ptr = (self.ptr as *const T).offset(1) as *const u8; @@ -327,6 +327,7 @@ impl<'a> DecimalIterator<'a> { match self.nobits { NoBits::N32 => self.next_unchecked_::(), NoBits::N64 => self.next_unchecked_::(), + NoBits::N128 => self.next_unchecked_::(), } } @@ -336,6 +337,7 @@ impl<'a> DecimalIterator<'a> { match self.nobits { NoBits::N32 => self.ptr = (self.ptr as *const i32).add(n) as *const u8, NoBits::N64 => self.ptr = (self.ptr as *const i64).add(n) as *const u8, + NoBits::N128 => self.ptr = (self.ptr as *const i128).add(n) as *const u8, } } } @@ -347,6 +349,7 @@ impl<'a> ExactSizeIterator for DecimalIterator<'a> { let size = match self.nobits { NoBits::N32 => mem::size_of::(), NoBits::N64 => mem::size_of::(), + NoBits::N128 => mem::size_of::(), }; (self.end as usize - self.ptr as usize) / size } @@ -983,6 +986,7 @@ impl<'a> Iterable<'a, Simple> for Decimal { match nobits { NoBits::N32 => (ptr as *const u32).add(size) as *const u8, NoBits::N64 => (ptr as *const u64).add(size) as *const u8, + NoBits::N128 => (ptr as *const u128).add(size) as *const u8, } }; diff --git a/src/types/decimal.rs b/src/types/decimal.rs index 800ca1b9..c17d2b24 100644 --- a/src/types/decimal.rs +++ b/src/types/decimal.rs @@ -4,7 +4,7 @@ use std::{ hash::{Hash, Hasher}, }; -static FACTORS10: &[i64] = &[ +static FACTORS10: &[i128] = &[ 1, 10, 100, @@ -24,27 +24,48 @@ static FACTORS10: &[i64] = &[ 10_000_000_000_000_000, 100_000_000_000_000_000, 1_000_000_000_000_000_000, + 10_000_000_000_000_000_000, + 100_000_000_000_000_000_000, + 1_000_000_000_000_000_000_000, + 10_000_000_000_000_000_000_000, + 100_000_000_000_000_000_000_000, + 1_000_000_000_000_000_000_000_000, + 10_000_000_000_000_000_000_000_000, + 100_000_000_000_000_000_000_000_000, + 1_000_000_000_000_000_000_000_000_000, + 10_000_000_000_000_000_000_000_000_000, + 100_000_000_000_000_000_000_000_000_000, + 1_000_000_000_000_000_000_000_000_000_000, + 10_000_000_000_000_000_000_000_000_000_000, + 100_000_000_000_000_000_000_000_000_000_000, + 1_000_000_000_000_000_000_000_000_000_000_000, + 10_000_000_000_000_000_000_000_000_000_000_000, + 100_000_000_000_000_000_000_000_000_000_000_000, + 1_000_000_000_000_000_000_000_000_000_000_000_000, + 10_000_000_000_000_000_000_000_000_000_000_000_000, + 100_000_000_000_000_000_000_000_000_000_000_000_000, ]; pub trait Base { - fn scale(self, scale: i64) -> i64; + fn scale(self, scale: i128) -> i128; } pub trait InternalResult { - fn get(underlying: i64) -> Self; + fn get(underlying: i128) -> Self; } #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub(crate) enum NoBits { N32, N64, + N128, } /// Provides arbitrary-precision floating point decimal. #[derive(Clone)] pub struct Decimal { - pub(crate) underlying: i64, - pub(crate) nobits: NoBits, // its domain is {32, 64} + pub(crate) underlying: i128, + pub(crate) nobits: NoBits, // its domain is {32, 64, 128} pub(crate) precision: u8, pub(crate) scale: u8, } @@ -73,8 +94,8 @@ macro_rules! base_for { ( $( $t:ty: $cast:expr ),* ) => { $( impl Base for $t { - fn scale(self, scale: i64) -> i64 { - $cast(self * (scale as $t)) as i64 + fn scale(self, scale: i128) -> i128 { + $cast(self * (scale as $t)) as i128 } } )* @@ -84,26 +105,35 @@ macro_rules! base_for { base_for! { f32: std::convert::identity, f64: std::convert::identity, - i8: i64::from, - i16: i64::from, - i32: i64::from, - i64: std::convert::identity, - u8: i64::from, - u16: i64::from, - u32: i64::from, - u64 : std::convert::identity + i8: i128::from, + i16: i128::from, + i32: i128::from, + i64: i128::from, + i128: std::convert::identity, + u8: i128::from, + u16: i128::from, + u32: i128::from, + u64: i128::from, + u128: std::convert::identity } impl InternalResult for i32 { #[inline(always)] - fn get(underlying: i64) -> Self { + fn get(underlying: i128) -> Self { underlying as Self } } impl InternalResult for i64 { #[inline(always)] - fn get(underlying: i64) -> Self { + fn get(underlying: i128) -> Self { + underlying as Self + } +} + +impl InternalResult for i128 { + #[inline(always)] + fn get(underlying: i128) -> Self { underlying } } @@ -114,6 +144,8 @@ impl NoBits { Some(NoBits::N32) } else if precision <= 18 { Some(NoBits::N64) + } else if precision <= 38 { + Some(NoBits::N128) } else { None } @@ -177,8 +209,8 @@ impl From for f64 { impl Decimal { /// Method of creating a Decimal. - pub fn new(underlying: i64, scale: u8) -> Decimal { - let precision = 18; + pub fn new(underlying: i128, scale: u8) -> Decimal { + let precision = 38; if scale > precision { panic!("scale can't be greater than 18"); } @@ -192,7 +224,7 @@ impl Decimal { } pub fn of(source: B, scale: u8) -> Decimal { - let precision = 18; + let precision = 38; if scale > precision { panic!("scale can't be greater than 18"); } @@ -210,7 +242,7 @@ impl Decimal { } } - /// Get the internal representation of decimal as [`i32`] or [`i64`]. + /// Get the internal representation of decimal as [`i32`] or [`i64`] or [`i128`]. /// /// example: /// ```rust @@ -305,6 +337,12 @@ mod test { assert_eq!(internal, 20000_i64); } + #[test] + fn test_internal128() { + let internal: i128 = Decimal::of(2, 4).internal(); + assert_eq!(internal, 20000_i128); + } + #[test] fn test_scale() { assert_eq!(Decimal::of(2, 4).scale(), 4); diff --git a/src/types/value.rs b/src/types/value.rs index 9b38b416..e86ef9c2 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -806,7 +806,7 @@ mod test { #[test] fn test_size_of() { use std::mem; - assert_eq!(56, mem::size_of::<[Value; 1]>()); + assert_eq!(64, mem::size_of::<[Value; 1]>()); } #[test] diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index 081c0b20..60475831 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -591,7 +591,7 @@ mod test { #[test] fn test_size_of() { use std::mem; - assert_eq!(32, mem::size_of::<[ValueRef<'_>; 1]>()); + assert_eq!(48, mem::size_of::<[ValueRef<'_>; 1]>()); } #[test] @@ -675,7 +675,7 @@ mod test { assert_eq!( SqlType::from(ValueRef::Decimal(Decimal::of(2.0_f64, 4))), - SqlType::Decimal(18, 4) + SqlType::Decimal(38, 4) ); assert_eq!( diff --git a/tests/clickhouse.rs b/tests/clickhouse.rs index e72a1ad2..0a2d7898 100644 --- a/tests/clickhouse.rs +++ b/tests/clickhouse.rs @@ -1712,14 +1712,16 @@ async fn test_decimal() -> Result<(), Error> { let ddl = " CREATE TABLE clickhouse_decimal ( x Decimal(8, 3), - ox Nullable(Decimal(10, 2)) + ox Nullable(Decimal(10, 2)), + xx Decimal(30, 4) ) Engine=Memory"; - let query = "SELECT x, ox FROM clickhouse_decimal"; + let query = "SELECT x, ox, xx FROM clickhouse_decimal"; let block = Block::new() .column("x", vec![Decimal::of(1.234, 3), Decimal::of(5, 3)]) - .column("ox", vec![None, Some(Decimal::of(1.23, 2))]); + .column("ox", vec![None, Some(Decimal::of(1.23, 2))]) + .column("xx", vec![Decimal::of(1.23456, 4), Decimal::of(5, 4)]); let pool = Pool::new(database_url()); @@ -1732,11 +1734,13 @@ async fn test_decimal() -> Result<(), Error> { let x: Decimal = block.get(0, "x")?; let ox: Option = block.get(1, "ox")?; let ox0: Option = block.get(0, "ox")?; + let xx: Decimal = block.get(0, "xx")?; assert_eq!(2, block.row_count()); assert_eq!(1.234, x.into()); assert_eq!(Some(1.23), ox.map(|v| v.into())); assert_eq!(None, ox0); + assert_eq!(1.2345, xx.into()); Ok(()) }