Skip to content

Commit

Permalink
feat: support overriding scan kwargs for specific tables (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnadolny13 authored Jun 6, 2023
1 parent ce75d10 commit 75e7e80
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Built with the [Meltano Singer SDK](https://sdk.meltano.com).
|:------------------------|:--------:|:-------:|:------------|
| tables | False | None | An array of table names to extract from. |
| infer_schema_sample_size| False | 100 | The number of records to sample when inferring the schema. |
| table_scan_kwargs | False | None | A mapping of table name to the scan kwargs that should be used to override the default when querying that table. |
| aws_access_key_id | False | None | The access key for your AWS account. |
| aws_secret_access_key | False | None | The secret key for your AWS account. |
| aws_session_token | False | None | The session token for your AWS account. This is only needed when you are using temporary credentials. |
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "meltanolabs-tap-dynamodb"
version = "0.2.0"
version = "0.3.0"
description = "`tap-dynamodb` is a Singer tap for DynamoDB, built with the Meltano Singer SDK."
readme = "README.md"
authors = ["Pat Nadolny"]
Expand Down
26 changes: 17 additions & 9 deletions tap_dynamodb/dynamodb_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ def list_tables(self, include=None):
else:
return tables

def get_items_iter(
self, table_name: str, scan_kwargs: dict = {"ConsistentRead": True}
):
def get_items_iter(self, table_name: str, scan_kwargs_override: dict):
scan_kwargs = scan_kwargs_override.copy()
if "ConsistentRead" not in scan_kwargs:
scan_kwargs["ConsistentRead"] = True

table = self.resource.Table(table_name)
try:
done = False
Expand All @@ -85,20 +87,26 @@ def get_items_iter(
)
raise

def _get_sample_records(self, table_name: str, sample_size: int) -> list:
def _get_sample_records(
self, table_name: str, sample_size: int, scan_kwargs_override: dict
) -> list:
scan_kwargs = scan_kwargs_override.copy()
sample_records = []
for batch in self.get_items_iter(
table_name, scan_kwargs={"Limit": sample_size, "ConsistentRead": True}
):
if "ConsistentRead" not in scan_kwargs:
scan_kwargs["ConsistentRead"] = True
if "Limit" not in scan_kwargs:
scan_kwargs["Limit"] = sample_size

for batch in self.get_items_iter(table_name, scan_kwargs):
sample_records.extend(batch)
if len(sample_records) >= sample_size:
break
return sample_records

def get_table_json_schema(
self, table_name: str, sample_size, strategy: str = "infer"
self, table_name: str, sample_size, scan_kwargs: dict, strategy: str = "infer"
) -> dict:
sample_records = self._get_sample_records(table_name, sample_size)
sample_records = self._get_sample_records(table_name, sample_size, scan_kwargs)

if not sample_records:
raise EmptyTableException()
Expand Down
9 changes: 8 additions & 1 deletion tap_dynamodb/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def __init__(
self._table_name: str = name
self._schema: dict = {}
self._infer_schema_sample_size = infer_schema_sample_size
self._table_scan_kwargs: dict = tap.config.get("table_scan_kwargs", {}).get(
name, {}
)
if tap.input_catalog:
catalog_entry = tap.input_catalog.get(name)
if catalog_entry:
Expand All @@ -54,7 +57,10 @@ def __init__(
super().__init__(name=name, tap=tap)

def get_records(self, context: dict | None) -> Iterable[dict]:
for batch in self._dynamodb_conn.get_items_iter(self._table_name):
for batch in self._dynamodb_conn.get_items_iter(
self._table_name,
self._table_scan_kwargs,
):
for record in batch:
yield record

Expand All @@ -71,6 +77,7 @@ def schema(self) -> dict:
self._schema = self._dynamodb_conn.get_table_json_schema(
self._table_name,
self._infer_schema_sample_size,
self._table_scan_kwargs,
)
self._primary_keys = self._dynamodb_conn.get_table_key_properties(
self._table_name
Expand Down
8 changes: 8 additions & 0 deletions tap_dynamodb/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ class TapDynamoDB(Tap):
description="The amount of records to sample when inferring the schema.",
default=100,
),
th.Property(
"table_scan_kwargs",
th.ObjectType(),
description=(
"A mapping of table name to the scan kwargs that should be used to "
"override the default when querying that table."
),
),
).to_dict()

def discover_streams(self) -> list[streams.TableStream]:
Expand Down
53 changes: 50 additions & 3 deletions tests/test_dynamodb_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,35 @@ def test_get_items():
# END PREP

db_obj = DynamoDbConnector(SAMPLE_CONFIG)
records = list(db_obj.get_items_iter("table"))[0]
records = list(db_obj.get_items_iter("table", {}))[0]
assert len(records) == 1
# Type coercion
assert records[0].get("year") == "2023"
assert records[0].get("title") == "foo"
assert records[0].get("info") == {"plot": "bar"}


@mock_dynamodb
def test_get_items_w_kwargs():
    """Check that scan kwargs passed to ``get_items_iter`` reach the scan.

    The projection requests only ``title`` and ``info``, so the ``year``
    attribute written during setup must be missing from the returned record.
    """
    # PREP
    moto_conn = boto3.resource("dynamodb", region_name="us-west-2")
    table = create_table(moto_conn, "table")
    table.put_item(Item={"year": 2023, "title": "foo", "info": {"plot": "bar"}})
    # END PREP

    db_obj = DynamoDbConnector(SAMPLE_CONFIG)
    records = list(
        db_obj.get_items_iter(
            "table",
            {"Select": "SPECIFIC_ATTRIBUTES", "ProjectionExpression": "title, info"},
        )
    )[0]
    assert len(records) == 1
    # Projected attributes are returned unchanged.
    assert records[0].get("title") == "foo"
    assert records[0].get("info") == {"plot": "bar"}
    # Without this check the test would still pass if the kwargs were ignored.
    assert "year" not in records[0]


@mock_dynamodb
def test_get_items_paginate():
# PREP
Expand Down Expand Up @@ -113,7 +134,7 @@ def test_get_table_json_schema():
# END PREP

db_obj = DynamoDbConnector(SAMPLE_CONFIG)
schema = db_obj.get_table_json_schema("table", 5)
schema = db_obj.get_table_json_schema("table", 5, {})
assert schema == {
"type": "object",
"properties": {
Expand All @@ -124,6 +145,32 @@ def test_get_table_json_schema():
}


@mock_dynamodb
def test_get_table_json_schema_w_kwargs():
    """Infer a schema while a projection limits the scan to two attributes."""
    # PREP
    dynamodb = boto3.resource("dynamodb", region_name="us-west-2")
    seeded_table = create_table(dynamodb, "table")
    for idx in range(5):
        item = {"year": 2023, "title": f"foo_{idx}", "info": {"plot": "bar"}}
        seeded_table.put_item(Item=item)
    # END PREP

    projection = {
        "Select": "SPECIFIC_ATTRIBUTES",
        "ProjectionExpression": "title, info",
    }
    # Only the projected attributes are expected in the inferred schema.
    expected = {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
            "info": {"type": "object", "properties": {"plot": {"type": "string"}}},
        },
    }

    connector = DynamoDbConnector(SAMPLE_CONFIG)
    inferred = connector.get_table_json_schema("table", 5, projection)
    assert inferred == expected


@mock_dynamodb
def test_get_table_key_properties():
# PREP
Expand Down Expand Up @@ -159,5 +206,5 @@ def test_get_sample_records():
# END PREP

db_obj = DynamoDbConnector(SAMPLE_CONFIG)
records = db_obj._get_sample_records("table", 2)
records = db_obj._get_sample_records("table", 2, {})
assert len(records) == 2

0 comments on commit 75e7e80

Please sign in to comment.