Fix download memory consumption (#60)

mephenor authored Aug 29, 2024
1 parent 8d4c200 commit 56122aa
Showing 6 changed files with 168 additions and 75 deletions.
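No description accompanies this commit, but the diff itself tells the story: httpx.AsyncClient instances, previously created and torn down inside every retried part request, are now opened once per transfer and shared across all parts; download validation no longer accumulates every part checksum in memory but verifies each part as it arrives; and garbage collection is forced periodically during large downloads. Hedged sketches of these patterns follow the relevant file diffs below.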
2 changes: 1 addition & 1 deletion .pyproject_generation/pyproject_custom.toml

@@ -1,6 +1,6 @@
 [project]
 name = "ghga_datasteward_kit"
-version = "4.2.0"
+version = "4.2.1"
 description = "GHGA Data Steward Kit - A utils package for GHGA data stewards."
 dependencies = [
     "crypt4gh >=1.6, <2",
2 changes: 1 addition & 1 deletion pyproject.toml

@@ -21,7 +21,7 @@ classifiers = [
     "Intended Audience :: Developers",
 ]
 name = "ghga_datasteward_kit"
-version = "4.2.0"
+version = "4.2.1"
 description = "GHGA Data Steward Kit - A utils package for GHGA data stewards."
 dependencies = [
     "crypt4gh >=1.6, <2",
2 changes: 1 addition & 1 deletion src/ghga_datasteward_kit/models.py

@@ -116,7 +116,7 @@ def serialize(self, output_path: Path):
         output["Symmetric file encryption secret ID"] = self.secret_id

         if not output_path.parent.exists():
-            output_path.mkdir(parents=True)
+            output_path.parent.mkdir(parents=True)

         # owner read-only
         with output_path.open("w") as file:
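Aside: the models.py one-liner fixes a path bug rather than memory. Path.mkdir acts on the path itself, so the old call created a directory where the output file was meant to be written. A minimal sketch with a hypothetical path (not taken from the kit):

from pathlib import Path

output_path = Path("out/metadata.json")  # hypothetical output location

# Old (buggy) line: output_path.mkdir(parents=True)
# That creates a *directory* named "out/metadata.json", so the later
# output_path.open("w") fails with IsADirectoryError.

# Fixed line: create only the missing parent directory.
if not output_path.parent.exists():
    output_path.parent.mkdir(parents=True)

with output_path.open("w") as file:
    file.write("{}")  # stand-in for the serialized metadata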
92 changes: 51 additions & 41 deletions src/ghga_datasteward_kit/s3_upload/downloader.py

@@ -16,8 +16,11 @@
 """Functionality related to downloading uploaded files for validation purposes."""

 import math
+from collections.abc import Coroutine
 from functools import partial
+from typing import Any

+import httpx
 from httpx import Response

 from ghga_datasteward_kit import models
@@ -57,60 +60,67 @@ def __init__( # noqa: PLR0913
         self.storage_cleaner = storage_cleaner
         self.retry_handler = configure_retries(config)

-    async def _download_parts(self, download_url):
+    async def _download_parts(self, fetch_url: partial[Coroutine[Any, Any, str]]):
         """Download file parts"""
-        for part_number, (start, stop) in enumerate(
-            get_ranges(file_size=self.file_size, part_size=self.config.part_size),
-            start=1,
-        ):
-            headers = {"Range": f"bytes={start}-{stop}"}
-            LOG.debug("Downloading part number %i. %s", part_number, headers)
-            try:
-                response: Response = await self.retry_handler(
-                    fn=self._run_request, url=download_url, headers=headers
-                )
-                yield response.content
-            except (
-                Exception,
-                KeyboardInterrupt,
-            ) as exc:
-                raise self.storage_cleaner.PartDownloadError(
-                    bucket_id=get_bucket_id(self.config),
-                    object_id=self.file_id,
-                    part_number=part_number,
-                ) from exc
+        async with httpx_client() as client:
+            for part_number, (start, stop) in enumerate(
+                get_ranges(file_size=self.file_size, part_size=self.config.part_size),
+                start=1,
+            ):
+                headers = {"Range": f"bytes={start}-{stop}"}
+                LOG.debug("Downloading part number %i. %s", part_number, headers)
+                try:
+                    response: Response = await self.retry_handler(
+                        fn=self._run_request,
+                        client=client,
+                        url=await fetch_url(),
+                        headers=headers,
+                    )
+                    yield response.content
+                except (
+                    Exception,
+                    KeyboardInterrupt,
+                ) as exc:
+                    raise self.storage_cleaner.PartDownloadError(
+                        bucket_id=get_bucket_id(self.config),
+                        object_id=self.file_id,
+                        part_number=part_number,
+                    ) from exc

-    async def _run_request(self, *, url: str, headers: dict[str, str]) -> Response:
+    async def _run_request(
+        self, *, client: httpx.AsyncClient, url: str, headers: dict[str, str]
+    ) -> Response:
         """Request to be wrapped by retry handler."""
-        async with httpx_client() as client:
-            response = await client.get(url, headers=headers)
-            return response
+        response = await client.get(url, headers=headers)
+        return response

     async def download(self):
         """Download file in parts and validate checksums"""
         LOG.info("(4/7) Downloading file %s for validation.", self.file_id)
-        download_url = await self.storage.get_object_download_url(
-            bucket_id=get_bucket_id(self.config), object_id=self.file_id
+        url_function = partial(
+            self.storage.get_object_download_url,
+            bucket_id=get_bucket_id(self.config),
+            object_id=self.file_id,
         )
         num_parts = math.ceil(self.file_size / self.part_size)
         decryptor = Decryptor(
-            file_secret=self.file_secret, num_parts=num_parts, part_size=self.part_size
+            file_secret=self.file_secret,
+            num_parts=num_parts,
+            part_size=self.part_size,
+            target_checksums=self.target_checksums,
         )
-        download_func = partial(self._download_parts, download_url=download_url)
-        await decryptor.process_parts(download_func)
-        await self.validate_checksums(checkums=decryptor.checksums)
+        download_func = partial(self._download_parts, fetch_url=url_function)

-    async def validate_checksums(self, checkums: models.Checksums):
-        """Confirm checksums for upload and download match"""
-        if self.target_checksums.get() != checkums.get():
-            message = (
-                "Checksum mismatch:\n"
-                + f"Upload:\n{checkums}\nDownload:\n{self.target_checksums}\n"
-                + "Uploaded file was deleted due to validation failure."
-            )
+        try:
+            await decryptor.process_parts(download_func)
+        except (
+            decryptor.FileChecksumValidationError,
+            decryptor.PartChecksumValidationError,
+        ) as error:
             raise self.storage_cleaner.ChecksumValidationError(
                 bucket_id=get_bucket_id(self.config),
                 object_id=self.file_id,
-                message=message,
-            )
+                message=str(error),
+            ) from error

         LOG.info("(6/7) Successfully validated checksums for %s.", self.file_id)
108 changes: 91 additions & 17 deletions src/ghga_datasteward_kit/s3_upload/file_decryption.py

@@ -15,6 +15,8 @@
 #
 """Functionality to decrypt Crypt4GH encrypted files on-the-fly for validation purposes."""

+import gc
+import hashlib
 from collections.abc import AsyncGenerator
 from functools import partial
 from time import time
@@ -25,15 +27,57 @@
 from ghga_datasteward_kit import models
 from ghga_datasteward_kit.s3_upload.utils import LOG, get_segments

+COLLECTION_LIMIT_MIB = 256 * 1024**2
+

 class Decryptor:
     """Handles on the fly decryption and checksum calculation"""

-    def __init__(self, file_secret: bytes, num_parts: int, part_size: int) -> None:
-        self.checksums = models.Checksums()
+    class FileChecksumValidationError(RuntimeError):
+        """Raised when checksum validation failed and the uploaded file needs removal."""
+
+        def __init__(self, *, current_checksum: str, upload_checksum: str):
+            message = (
+                "Checksum mismatch for file:\n"
+                + f"Upload:\n{current_checksum}\nDownload:\n{upload_checksum}\n"
+                + "Uploaded file was deleted due to validation failure."
+            )
+            self.current_checksum = current_checksum
+            self.upload_checksum = upload_checksum
+            super().__init__(message)
+
+    class PartChecksumValidationError(RuntimeError):
+        """Raised when checksum validation failed and the uploaded file needs removal."""
+
+        def __init__(
+            self,
+            *,
+            part_number: int,
+            current_part_checksum: str,
+            upload_part_checksum: str,
+        ):
+            message = (
+                f"Checksum mismatch for part no. {part_number}:\n"
+                + f"Upload:\n{current_part_checksum}\nDownload:\n{upload_part_checksum}\n"
+                + "Uploaded file was deleted due to validation failure."
+            )
+            self.part_number = part_number
+            self.current_part_checksum = current_part_checksum
+            self.upload_part_checksum = upload_part_checksum
+            super().__init__(message)
+
+    def __init__(
+        self,
+        *,
+        file_secret: bytes,
+        num_parts: int,
+        part_size: int,
+        target_checksums: models.Checksums,
+    ) -> None:
         self.file_secret = file_secret
         self.num_parts = num_parts
         self.part_size = part_size
+        self.target_checksums = target_checksums

     def _decrypt(self, part: bytes):
         """Decrypt file part"""
@@ -53,46 +97,76 @@ def _decrypt_segment(self, segment: bytes):
             ciphersegment=segment, session_keys=[self.file_secret]
         )

+    def _validate_current_checksum(self, *, file_part: bytes, part_number: int):
+        """Verify checksums match for the given file part."""
+        current_part_md5 = hashlib.md5(file_part, usedforsecurity=False).hexdigest()
+        current_part_sha256 = hashlib.sha256(file_part).hexdigest()
+
+        upload_part_md5 = self.target_checksums.encrypted_md5[part_number - 1]
+        upload_part_sha256 = self.target_checksums.encrypted_sha256[part_number - 1]
+
+        if current_part_md5 != upload_part_md5:
+            raise self.PartChecksumValidationError(
+                part_number=part_number,
+                current_part_checksum=current_part_md5,
+                upload_part_checksum=upload_part_md5,
+            )
+        elif current_part_sha256 != upload_part_sha256:
+            raise self.PartChecksumValidationError(
+                part_number=part_number,
+                current_part_checksum=current_part_sha256,
+                upload_part_checksum=upload_part_sha256,
+            )
+
     async def process_parts(self, download_files: partial[AsyncGenerator[bytes, Any]]):
         """Encrypt and upload file parts."""
         unprocessed_bytes = b""
         download_buffer = b""
+        unencrypted_sha256 = hashlib.sha256()

         start = time()

-        part_number = 0
+        part_number = 1
+        collection_tracker_mib = 0
         async for file_part in download_files():
-            # process unencrypted
-            self.checksums.update_encrypted(file_part)
             # process encrypted
+            self._validate_current_checksum(
+                file_part=file_part, part_number=part_number
+            )
             unprocessed_bytes += file_part
+            collection_tracker_mib += len(file_part)

-            # encrypt in chunks
+            # decrypt in chunks
             decrypted_bytes, unprocessed_bytes = self._decrypt(unprocessed_bytes)
             download_buffer += decrypted_bytes

-            # update checksums and yield if part size
-            if len(download_buffer) >= self.part_size:
-                current_part = download_buffer[: self.part_size]
-                self.checksums.update_unencrypted(current_part)
-                download_buffer = download_buffer[self.part_size :]
+            unencrypted_sha256.update(download_buffer)
+            download_buffer = b""

             delta = time() - start
             avg_speed = (part_number * (self.part_size / 1024**2)) / delta

             LOG.info(
                 "(5/7) Downloading part %i/%i (%.2f MiB/s)",
                 part_number,
                 self.num_parts,
                 avg_speed,
             )
             part_number += 1
+            if collection_tracker_mib >= COLLECTION_LIMIT_MIB:
+                collection_tracker_mib = 0
+                gc.collect()

         # process dangling bytes
         if unprocessed_bytes:
             download_buffer += self._decrypt_segment(unprocessed_bytes)

-        while len(download_buffer) >= self.part_size:
-            current_part = download_buffer[: self.part_size]
-            self.checksums.update_unencrypted(current_part)
-            download_buffer = download_buffer[self.part_size :]
+        unencrypted_sha256.update(download_buffer)
+        download_buffer = b""

-        if download_buffer:
-            self.checksums.update_unencrypted(download_buffer)
+        current_checksum = unencrypted_sha256.hexdigest()
+        upload_checksum = self.target_checksums.unencrypted_sha256.hexdigest()
+        if current_checksum != upload_checksum:
+            raise self.FileChecksumValidationError(
+                current_checksum=current_checksum, upload_checksum=upload_checksum
+            )
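The decryptor now validates each encrypted part against the recorded per-part MD5/SHA-256 as soon as it arrives, keeps one running SHA-256 over the decrypted stream instead of retaining part buffers, and calls gc.collect() at most once per 256 MiB downloaded. A standalone sketch of that shape, using hypothetical inputs rather than the kit's models.Checksums:

import gc
import hashlib
from collections.abc import Iterable

COLLECTION_LIMIT = 256 * 1024**2  # bytes between forced GC passes (mirrors the diff)


def validate_stream(
    parts: Iterable[bytes], part_md5s: list[str], expected_sha256: str
) -> None:
    """Check per-part MD5s and a whole-stream SHA-256 without buffering the file."""
    running_sha256 = hashlib.sha256()
    bytes_since_collect = 0

    for part_number, part in enumerate(parts, start=1):
        # Fail fast on the first corrupted part instead of after the full download.
        current_md5 = hashlib.md5(part, usedforsecurity=False).hexdigest()
        if current_md5 != part_md5s[part_number - 1]:
            raise ValueError(f"Checksum mismatch for part no. {part_number}")

        running_sha256.update(part)  # constant memory: no list of part hashes
        bytes_since_collect += len(part)
        if bytes_since_collect >= COLLECTION_LIMIT:
            bytes_since_collect = 0
            gc.collect()  # nudge CPython to release fragmented buffers

    if running_sha256.hexdigest() != expected_sha256:
        raise ValueError("Checksum mismatch for file")

The forced gc.collect() reads as a pragmatic throttle: the large bytes objects are freed by reference counting anyway, but periodic collection keeps allocator fragmentation from pinning memory during multi-GiB transfers, which appears to be the consumption the commit title targets.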
37 changes: 23 additions & 14 deletions src/ghga_datasteward_kit/s3_upload/uploader.py

@@ -22,6 +22,7 @@
 from uuid import uuid4

 import crypt4gh.lib  # type: ignore
+import httpx
 from httpx import Response

 from ghga_datasteward_kit.s3_upload.config import LegacyConfig
@@ -72,19 +73,24 @@ async def encrypt_and_upload(self):
         start = time()

         with open(self.input_path, "rb") as file:
-            async with MultipartUpload(
-                config=self.config,
-                file_id=self.file_id,
-                encrypted_file_size=encrypted_file_size,
-                part_size=self.config.part_size,
-                storage_cleaner=self._storage_cleaner,
-                debug_mode=self.debug_mode,
-            ) as upload:
+            async with (
+                MultipartUpload(
+                    config=self.config,
+                    file_id=self.file_id,
+                    encrypted_file_size=encrypted_file_size,
+                    part_size=self.config.part_size,
+                    storage_cleaner=self._storage_cleaner,
+                    debug_mode=self.debug_mode,
+                ) as upload,
+                httpx_client() as client,
+            ):
                 LOG.info("(1/7) Initialized file upload for %s.", upload.file_id)
                 for part_number, part in enumerate(
                     self.encryptor.process_file(file=file), start=1
                 ):
-                    await upload.send_part(part_number=part_number, part=part)
+                    await upload.send_part(
+                        client=client, part_number=part_number, part=part
+                    )

                 delta = time() - start
                 avg_speed = part_number * (self.config.part_size / 1024**2) / delta
@@ -150,7 +156,9 @@ async def __aexit__(self, exc_t, exc_v, exc_tb):
                 upload_id=self.upload_id,
             ) from exc

-    async def send_part(self, *, part: bytes, part_number: int):
+    async def send_part(
+        self, *, client: httpx.AsyncClient, part: bytes, part_number: int
+    ):
         """Handle upload of one file part"""
         try:
             upload_url = await self.storage.get_part_upload_url(
@@ -162,7 +170,7 @@ def send_part(self, *, part: bytes, part_number: int):
             # wait slightly before using the upload URL
             await asyncio.sleep(0.1)
             response: Response = await self.retry_handler(
-                fn=self._run_request, url=upload_url, part=part
+                fn=self._run_request, client=client, url=upload_url, part=part
             )

             status_code = response.status_code
@@ -179,8 +187,9 @@
                 upload_id=self.upload_id,
             ) from exc

-    async def _run_request(self, *, url: str, part: bytes) -> Response:
+    async def _run_request(
+        self, *, client: httpx.AsyncClient, url: str, part: bytes
+    ) -> Response:
         """Request to be wrapped by retry handler."""
-        async with httpx_client() as client:
-            response = await client.put(url=url, content=part)
+        response = await client.put(url=url, content=part)
         return response
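The upload path mirrors the download fix: the client is opened once alongside the MultipartUpload context manager via parenthesized async-with (Python 3.10+ syntax) and threaded through send_part into the retry-wrapped _run_request. A minimal sketch of that structure with stand-in names (FakeUpload and the URL are ours, not the kit's):

import asyncio

import httpx


class FakeUpload:
    """Stand-in for the kit's MultipartUpload async context manager."""

    async def __aenter__(self) -> "FakeUpload":
        return self

    async def __aexit__(self, exc_t, exc_v, exc_tb) -> None:
        return None

    async def send_part(
        self, *, client: httpx.AsyncClient, part_number: int, part: bytes
    ) -> None:
        # Each part reuses the caller's client instead of opening its own.
        await client.put(f"https://example.org/parts/{part_number}", content=part)


async def upload_all(parts: list[bytes]) -> None:
    # Parenthesized async-with (Python 3.10+): both resources share one scope,
    # so the client lives exactly as long as the upload it serves.
    async with (
        FakeUpload() as upload,
        httpx.AsyncClient() as client,
    ):
        for part_number, part in enumerate(parts, start=1):
            await upload.send_part(client=client, part_number=part_number, part=part)


# Hypothetical usage:
# asyncio.run(upload_all([b"chunk-1", b"chunk-2"]))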
