From 5f06cdabbcf4d8066287c7a1d3fe0f0ae92a12dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Thu, 23 Jan 2025 21:48:42 +0000 Subject: [PATCH] feat: replace simple with complex UDF implementation the reason being is that we can access captured python function and serialize it if needed --- src/udf.rs | 174 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 121 insertions(+), 53 deletions(-) diff --git a/src/udf.rs b/src/udf.rs index 4570e77a..68941282 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -15,67 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +use std::any::Any; use pyo3::{prelude::*, types::PyTuple}; -use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef}; +use datafusion::arrow::array::{make_array, Array, ArrayData}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::FromPyArrow; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; +use datafusion::common::Result; use datafusion::error::DataFusionError; -use datafusion::logical_expr::function::ScalarFunctionImplementation; -use datafusion::logical_expr::ScalarUDF; -use datafusion::logical_expr::{create_udf, ColumnarValue}; +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Volatility}; +use datafusion::logical_expr::{ScalarUDF, Signature}; +use std::fmt::Debug; use crate::expr::PyExpr; use crate::utils::parse_volatility; -/// Create a Rust callable function from a python function that expects pyarrow arrays -fn pyarrow_function_to_rust( - func: PyObject, -) -> impl Fn(&[ArrayRef]) -> Result { - move |args: &[ArrayRef]| -> Result { - Python::with_gil(|py| { - // 1. cast args to Pyarrow arrays - let py_args = args - .iter() - .map(|arg| { - arg.into_data() - .to_pyarrow(py) - .map_err(|e| DataFusionError::Execution(format!("{e:?}"))) - }) - .collect::, _>>()?; - let py_args = PyTuple::new_bound(py, py_args); - - // 2. call function - let value = func - .call_bound(py, py_args, None) - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; - - // 3. cast to arrow::array::Array - let array_data = ArrayData::from_pyarrow_bound(value.bind(py)) - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; - Ok(make_array(array_data)) - }) - } -} - -/// Create a DataFusion's UDF implementation from a python function -/// that expects pyarrow arrays. This is more efficient as it performs -/// a zero-copy of the contents. -fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation { - // Make the python function callable from rust - let pyarrow_func = pyarrow_function_to_rust(func); - - // Convert input/output from datafusion ColumnarValue to arrow arrays - Arc::new(move |args: &[ColumnarValue]| { - let array_refs = ColumnarValue::values_to_arrays(args)?; - let array_result = pyarrow_func(&array_refs)?; - Ok(array_result.into()) - }) -} - /// Represents a PyScalarUDF #[pyclass(name = "ScalarUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] @@ -94,14 +50,17 @@ impl PyScalarUDF { return_type: PyArrowType, volatility: &str, ) -> PyResult { - let function = create_udf( + let function = PythonUDF::new( name, input_types.0, return_type.0, parse_volatility(volatility)?, - to_scalar_function_impl(func), + func, ); - Ok(Self { function }) + + Ok(Self { + function: function.into(), + }) } /// creates a new PyExpr with the call of the udf @@ -115,3 +74,112 @@ impl PyScalarUDF { Ok(format!("ScalarUDF({})", self.function.name())) } } + +/// Implements [`ScalarUDFImpl`] for functions that have a single signature and +/// return type. +pub struct PythonUDF { + name: String, + signature: Signature, + return_type: DataType, + func: PyObject, +} + +impl Debug for PythonUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("PythonUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("return_type", &self.return_type) + .field("func", &"") + .finish() + } +} + +impl PythonUDF { + /// Create a new `PythonUDF` from a name, input types, return type and + /// implementation. + pub fn new( + name: impl Into, + input_types: Vec, + return_type: DataType, + volatility: Volatility, + func: PyObject, + ) -> Self { + Self::new_with_signature( + name, + Signature::exact(input_types, volatility), + return_type, + func, + ) + } + + /// Create a new `SimpleScalarUDF` from a name, signature, return type and + /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility + pub fn new_with_signature( + name: impl Into, + signature: Signature, + return_type: DataType, + func: PyObject, + ) -> Self { + Self { + name: name.into(), + signature, + return_type, + func, + } + } + /// Returns underlying python function + /// + /// Intention is to allow the function serialization from + /// logical plan encoder if we want to distribute it + #[allow(dead_code)] + pub fn py_func(&self) -> &PyObject { + &self.func + } +} + +impl ScalarUDFImpl for PythonUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke_batch(&self, args: &[ColumnarValue], _number_rows: usize) -> Result { + let array_refs = ColumnarValue::values_to_arrays(args)?; + let array_data: Result<_> = Python::with_gil(|py| { + // 1. cast args to PyArrow arrays + let py_args = array_refs + .iter() + .map(|arg| { + arg.into_data() + .to_pyarrow(py) + .map_err(|e| DataFusionError::Execution(format!("{e:?}"))) + }) + .collect::, _>>()?; + let py_args = PyTuple::new_bound(py, py_args); + + // 2. call function + let value = self + .func + .call_bound(py, py_args, None) + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + // 3. cast to arrow::array::Array + ArrayData::from_pyarrow_bound(value.bind(py)) + .map_err(|e| DataFusionError::Execution(format!("{e:?}"))) + }); + + Ok(make_array(array_data?).into()) + } +}