Skip to content

Commit

Permalink
feat: replace simple with complex UDF implementation
Browse files Browse the repository at this point in the history
the reason being that we can access the captured
python function and serialize it if needed
  • Loading branch information
milenkovicm committed Jan 23, 2025
1 parent 78e72c9 commit 5f06cda
Showing 1 changed file with 121 additions and 53 deletions.
174 changes: 121 additions & 53 deletions src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,23 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;
use std::any::Any;

use pyo3::{prelude::*, types::PyTuple};

use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
use datafusion::arrow::array::{make_array, Array, ArrayData};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::FromPyArrow;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::common::Result;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::function::ScalarFunctionImplementation;
use datafusion::logical_expr::ScalarUDF;
use datafusion::logical_expr::{create_udf, ColumnarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Volatility};
use datafusion::logical_expr::{ScalarUDF, Signature};
use std::fmt::Debug;

use crate::expr::PyExpr;
use crate::utils::parse_volatility;

/// Create a Rust callable function from a python function that expects pyarrow arrays
fn pyarrow_function_to_rust(
    func: PyObject,
) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
        Python::with_gil(|py| {
            // Surface any python-side failure as a DataFusion execution error.
            let to_df_err = |e| DataFusionError::Execution(format!("{e:?}"));

            // 1. cast args to Pyarrow arrays
            let mut converted = Vec::with_capacity(args.len());
            for arg in args {
                converted.push(arg.into_data().to_pyarrow(py).map_err(to_df_err)?);
            }
            let py_args = PyTuple::new_bound(py, converted);

            // 2. call function
            let result = func.call_bound(py, py_args, None).map_err(to_df_err)?;

            // 3. cast to arrow::array::Array
            let data = ArrayData::from_pyarrow_bound(result.bind(py)).map_err(to_df_err)?;
            Ok(make_array(data))
        })
    }
}

/// Create a DataFusion's UDF implementation from a python function
/// that expects pyarrow arrays. This is more efficient as it performs
/// a zero-copy of the contents.
fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
    // Bridge the python callable into a rust closure over arrow arrays.
    let call_python = pyarrow_function_to_rust(func);

    // Adapt datafusion's ColumnarValue interface to plain arrow arrays:
    // scalars are expanded to arrays, the python result wraps back into a
    // ColumnarValue.
    Arc::new(move |args: &[ColumnarValue]| {
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let result = call_python(&arrays)?;
        Ok(ColumnarValue::from(result))
    })
}

/// Represents a PyScalarUDF
#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
Expand All @@ -94,14 +50,17 @@ impl PyScalarUDF {
return_type: PyArrowType<DataType>,
volatility: &str,
) -> PyResult<Self> {
let function = create_udf(
let function = PythonUDF::new(
name,
input_types.0,
return_type.0,
parse_volatility(volatility)?,
to_scalar_function_impl(func),
func,
);
Ok(Self { function })

Ok(Self {
function: function.into(),
})
}

/// creates a new PyExpr with the call of the udf
Expand All @@ -115,3 +74,112 @@ impl PyScalarUDF {
Ok(format!("ScalarUDF({})", self.function.name()))
}
}

/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
/// return type.
pub struct PythonUDF {
    /// Name the UDF is registered/invoked under.
    name: String,
    /// Accepted argument types and volatility of the function.
    signature: Signature,
    /// The single, fixed return type reported for every invocation.
    return_type: DataType,
    /// The captured python callable; kept so it can be re-serialized
    /// (e.g. for distributing the plan) via [`PythonUDF::py_func`].
    func: PyObject,
}

impl Debug for PythonUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The captured python callable has no meaningful Debug output, so it
        // is rendered as a fixed "<FUNC>" placeholder.
        let mut builder = f.debug_struct("PythonUDF");
        builder.field("name", &self.name);
        builder.field("signature", &self.signature);
        builder.field("return_type", &self.return_type);
        builder.field("func", &"<FUNC>");
        builder.finish()
    }
}

impl PythonUDF {
    /// Create a new `PythonUDF` from a name, input types, return type and a
    /// python callable. The resulting signature is exact: arguments must
    /// match `input_types` without coercion.
    pub fn new(
        name: impl Into<String>,
        input_types: Vec<DataType>,
        return_type: DataType,
        volatility: Volatility,
        func: PyObject,
    ) -> Self {
        let signature = Signature::exact(input_types, volatility);
        Self::new_with_signature(name, signature, return_type, func)
    }

    /// Create a new `PythonUDF` from a name, an arbitrary [`Signature`],
    /// return type and a python callable. Allows more flexibility than
    /// [`PythonUDF::new`].
    pub fn new_with_signature(
        name: impl Into<String>,
        signature: Signature,
        return_type: DataType,
        func: PyObject,
    ) -> Self {
        Self {
            name: name.into(),
            signature,
            return_type,
            func,
        }
    }

    /// Returns the underlying python function.
    ///
    /// The intention is to allow the function to be serialized by a logical
    /// plan encoder if we want to distribute it.
    #[allow(dead_code)]
    pub fn py_func(&self) -> &PyObject {
        &self.func
    }
}

impl ScalarUDFImpl for PythonUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The return type is fixed at construction time, regardless of the
    /// actual argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(self.return_type.clone())
    }

    /// Evaluate the UDF by calling the captured python function under the
    /// GIL. Arguments are converted to pyarrow arrays (scalars are first
    /// expanded to arrays by `values_to_arrays`), and the python result is
    /// converted back into an arrow array. Any python-side error is mapped
    /// to a `DataFusionError::Execution`.
    fn invoke_batch(&self, args: &[ColumnarValue], _number_rows: usize) -> Result<ColumnarValue> {
        let array_refs = ColumnarValue::values_to_arrays(args)?;
        // Everything touching python objects must run while holding the GIL.
        let array_data: Result<_> = Python::with_gil(|py| {
            // 1. cast args to PyArrow arrays
            let py_args = array_refs
                .iter()
                .map(|arg| {
                    arg.into_data()
                        .to_pyarrow(py)
                        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
                })
                .collect::<Result<Vec<_>, _>>()?;
            let py_args = PyTuple::new_bound(py, py_args);

            // 2. call function
            let value = self
                .func
                .call_bound(py, py_args, None)
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

            // 3. cast to arrow::array::Array
            ArrayData::from_pyarrow_bound(value.bind(py))
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
        });

        // Wrap the arrow array back into a ColumnarValue for datafusion.
        Ok(make_array(array_data?).into())
    }
}

0 comments on commit 5f06cda

Please sign in to comment.