Commit
Default to ZSTD compression when writing Parquet (#981)
* fix: update default compression to ZSTD and improve documentation for write_parquet method

* fix: clarify compression level documentation for ZSTD in write_parquet method

* fix: update default compression level for ZSTD to 4 in write_parquet method

* fix: improve docstring formatting for DataFrame parquet writing method

* feat: implement Compression enum and update write_parquet method to use it

* add test

* fix: remove unused import and update default compression to ZSTD in the Rust write_parquet method

* fix: update compression type strings to lowercase in DataFrame parquet writing method doc

* test: update parquet compression tests to validate invalid and default compression levels

* add comment on source of Compression

* docs: enhance Compression enum documentation and add default level method

* test: include gzip in default compression level tests for write_parquet

* refactor: simplify Compression enum methods and improve type handling in DataFrame.write_parquet

* docs: update Compression enum methods to include return type descriptions

* move comment to within test

* Ruff format

---------

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
kosiew and timsaucer authored Jan 11, 2025
1 parent db1bc62 commit 2d8b1d3
Showing 3 changed files with 101 additions and 9 deletions.
94 changes: 88 additions & 6 deletions python/datafusion/dataframe.py
@@ -21,7 +21,16 @@
 
 from __future__ import annotations
 import warnings
-from typing import Any, Iterable, List, TYPE_CHECKING, Literal, overload
+from typing import (
+    Any,
+    Iterable,
+    List,
+    TYPE_CHECKING,
+    Literal,
+    overload,
+    Optional,
+    Union,
+)
 from datafusion.record_batch import RecordBatchStream
 from typing_extensions import deprecated
 from datafusion.plan import LogicalPlan, ExecutionPlan
@@ -35,6 +44,60 @@
 
 from datafusion._internal import DataFrame as DataFrameInternal
 from datafusion.expr import Expr, SortExpr, sort_or_default
+from enum import Enum
+
+
+# excerpt from deltalake
+# https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
+class Compression(Enum):
+    """Enum representing the available compression types for Parquet files."""
+
+    UNCOMPRESSED = "uncompressed"
+    SNAPPY = "snappy"
+    GZIP = "gzip"
+    BROTLI = "brotli"
+    LZ4 = "lz4"
+    LZO = "lzo"
+    ZSTD = "zstd"
+    LZ4_RAW = "lz4_raw"
+
+    @classmethod
+    def from_str(cls, value: str) -> "Compression":
+        """Convert a string to a Compression enum value.
+
+        Args:
+            value: The string representation of the compression type.
+
+        Returns:
+            The Compression enum value matching the lowercased input string.
+
+        Raises:
+            ValueError: If the string does not match any Compression enum value.
+        """
+        try:
+            return cls(value.lower())
+        except ValueError:
+            raise ValueError(
+                f"{value} is not a valid Compression. "
+                f"Valid values are: {[item.value for item in Compression]}"
+            ) from None
+
+    def get_default_level(self) -> Optional[int]:
+        """Get the default compression level for the compression type.
+
+        Returns:
+            The default compression level, or None if the type has no
+            associated level.
+        """
+        # GZIP, BROTLI default values from deltalake repo
+        # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163
+        # ZSTD default value from delta-rs
+        # https://github.com/apache/datafusion-python/pull/981#discussion_r1904789223
+        if self == Compression.GZIP:
+            return 6
+        elif self == Compression.BROTLI:
+            return 1
+        elif self == Compression.ZSTD:
+            return 4
+        return None
 
 
 class DataFrame:
@@ -620,17 +683,36 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
     def write_parquet(
         self,
         path: str | pathlib.Path,
-        compression: str = "uncompressed",
+        compression: Union[str, Compression] = Compression.ZSTD,
         compression_level: int | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
 
         Args:
             path: Path of the Parquet file to write.
-            compression: Compression type to use.
-            compression_level: Compression level to use.
-        """
-        self.df.write_parquet(str(path), compression, compression_level)
+            compression: Compression type to use. Default is "zstd".
+                Available compression types are:
+
+                - "uncompressed": No compression.
+                - "snappy": Snappy compression.
+                - "gzip": Gzip compression.
+                - "brotli": Brotli compression.
+                - "lzo": LZO compression.
+                - "lz4": LZ4 compression.
+                - "lz4_raw": LZ4_RAW compression.
+                - "zstd": Zstandard compression.
+            compression_level: Compression level to use. For ZSTD, the
+                recommended range is 1 to 22, with the default being 4. Higher
+                levels provide better compression but slower speed.
+        """
+        # Convert string to Compression enum if necessary
+        if isinstance(compression, str):
+            compression = Compression.from_str(compression)
+
+        if compression in {Compression.GZIP, Compression.BROTLI, Compression.ZSTD}:
+            if compression_level is None:
+                compression_level = compression.get_default_level()
+
+        self.df.write_parquet(str(path), compression.value, compression_level)
 
     def write_json(self, path: str | pathlib.Path) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a JSON file.
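For context, here is a minimal usage sketch of the API this diff introduces. It is not part of the commit: the in-memory `SessionContext`, the small table, and the output paths are illustrative assumptions.

```python
from datafusion import SessionContext
from datafusion.dataframe import Compression

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})  # small illustrative table

# New default: ZSTD, with compression_level filled in as 4.
df.write_parquet("out_zstd.parquet")

# Strings are normalized through Compression.from_str (case-insensitive);
# gzip picks up its default level of 6 when none is given.
df.write_parquet("out_gzip.parquet", compression="gzip")

# Enum members and explicit levels are also accepted.
df.write_parquet(
    "out_brotli.parquet", compression=Compression.BROTLI, compression_level=2
)

# Unknown codecs raise ValueError listing the valid values.
assert Compression.from_str("ZSTD") is Compression.ZSTD
assert Compression.GZIP.get_default_level() == 6
```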
14 changes: 12 additions & 2 deletions python/tests/test_dataframe.py
@@ -1107,14 +1107,24 @@ def test_write_compressed_parquet_wrong_compression_level(
         )
 
 
-@pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
-def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compression):
+@pytest.mark.parametrize("compression", ["wrong"])
+def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
     path = tmp_path
 
     with pytest.raises(ValueError):
         df.write_parquet(str(path), compression=compression)
 
 
+@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"])
+def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
+    # Write with zstd, brotli, or gzip without specifying a compression
+    # level; the default level should be applied and the write should
+    # complete without error.
+    path = tmp_path
+
+    df.write_parquet(str(path), compression=compression)
+
+
 def test_dataframe_export(df) -> None:
     # Guarantees that we have the canonical implementation
     # reading our dataframe export
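As a possible extension of these tests (sketched here, not part of the commit), the codec actually written to disk can be verified through pyarrow's Parquet metadata. This assumes pyarrow is available and reuses the same `df` and `tmp_path` fixtures as the tests above.

```python
import pyarrow.parquet as pq


def test_write_parquet_default_is_zstd(df, tmp_path):
    # No compression argument: the new default should be ZSTD at level 4.
    df.write_parquet(str(tmp_path))

    for file in tmp_path.rglob("*.parquet"):
        metadata = pq.ParquetFile(file).metadata
        for i in range(metadata.num_row_groups):
            # ColumnChunkMetaData.compression reports the codec name, e.g. "ZSTD".
            assert metadata.row_group(i).column(0).compression == "ZSTD"
```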
2 changes: 1 addition & 1 deletion src/dataframe.rs
@@ -463,7 +463,7 @@ impl PyDataFrame {
     /// Write a `DataFrame` to a Parquet file.
     #[pyo3(signature = (
         path,
-        compression="uncompressed",
+        compression="zstd",
         compression_level=None
     ))]
     fn write_parquet(
