Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced text to sql sample rows #17479

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from llama_index.core.indices.base import BaseRetriever
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import (
RESPONSE_TYPE,
Expand Down Expand Up @@ -590,6 +590,7 @@ def __init__(
self,
sql_database: SQLDatabase,
table_retriever: ObjectRetriever[SQLTableSchema],
rows_retrievers: Optional[dict[str, BaseRetriever]] = None,
llm: Optional[LLM] = None,
text_to_sql_prompt: Optional[BasePromptTemplate] = None,
context_query_kwargs: Optional[dict] = None,
Expand All @@ -608,9 +609,11 @@ def __init__(
text_to_sql_prompt=text_to_sql_prompt,
context_query_kwargs=context_query_kwargs,
table_retriever=table_retriever,
rows_retrievers=rows_retrievers,
context_str_prefix=context_str_prefix,
sql_only=sql_only,
callback_manager=callback_manager,
verbose=kwargs.get("verbose", False),
)
super().__init__(
synthesize_response=synthesize_response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ class NLSQLRetriever(BaseRetriever, PromptMixin):
tables (Union[List[str], List[Table]]): List of table names or Table objects.
table_retriever (ObjectRetriever[SQLTableSchema]): Object retriever for
SQLTableSchema objects. Defaults to None.
rows_retriever (Dict[str, VectorIndexRetriever]): a mapping between table name and
a vector index retriever of its rows. Defaults to None.
context_str_prefix (str): Prefix for context string. Defaults to None.
return_raw (bool): Whether to return plain-text dump of SQL results, or parsed into Nodes.
handle_sql_errors (bool): Whether to handle SQL errors. Defaults to True.
Expand All @@ -205,6 +207,7 @@ def __init__(
context_query_kwargs: Optional[dict] = None,
tables: Optional[Union[List[str], List[Table]]] = None,
table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
rows_retrievers: Optional[dict[str, BaseRetriever]] = None,
context_str_prefix: Optional[str] = None,
sql_parser_mode: SQLParserMode = SQLParserMode.DEFAULT,
llm: Optional[LLM] = None,
Expand Down Expand Up @@ -232,6 +235,9 @@ def __init__(
self._handle_sql_errors = handle_sql_errors
self._sql_only = sql_only
self._verbose = verbose

# To retrieve relevant rows from each retrieved table
self._rows_retrievers = rows_retrievers
super().__init__(
callback_manager=callback_manager or Settings.callback_manager,
verbose=verbose,
Expand Down Expand Up @@ -391,26 +397,32 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
return retrieved_nodes

def _get_table_context(self, query_bundle: QueryBundle) -> str:
"""Get table context.

Get tables schema + optional context as a single string.

"""
"""Get table context string."""
table_schema_objs = self._get_tables(query_bundle.query_str)
context_strs = []
if self._context_str_prefix is not None:
context_strs = [self._context_str_prefix]

for table_schema_obj in table_schema_objs:
# first append table info + additional context
table_info = self._sql_database.get_single_table_info(
table_schema_obj.table_name
)

if table_schema_obj.context_str:
table_opt_context = " The table description is: "
table_opt_context += table_schema_obj.context_str
table_info += table_opt_context

# also lookup vector index to return relevant table rows
# if rows_retrievers was not passed, no rows will be returned
if self._rows_retrievers is not None:
rows_retriever = self._rows_retrievers[table_schema_obj.table_name]
relevant_nodes = rows_retriever.retrieve(query_bundle.query_str)
if len(relevant_nodes) > 0:
table_row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
for node in relevant_nodes:
table_row_context += str(node.get_content()) + "\n"
table_info += table_row_context

if self._verbose:
print(f"> Table Info: {table_info}")
context_strs.append(table_info)

return "\n\n".join(context_strs)
106 changes: 104 additions & 2 deletions llama-index-core/tests/indices/struct_store/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,32 @@

import pytest
from llama_index.core.async_utils import asyncio_run
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.indices.struct_store.base import default_output_parser
from llama_index.core.indices.struct_store.sql import SQLStructStoreIndex
from llama_index.core.indices.struct_store.sql_query import (
NLSQLTableQueryEngine,
NLStructStoreQueryEngine,
SQLStructStoreQueryEngine,
SQLTableRetrieverQueryEngine,
)
from llama_index.core.schema import Document
from llama_index.core.objects import (
SQLTableNodeMapping,
ObjectIndex,
SQLTableSchema,
)
from llama_index.core.schema import Document, TextNode
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from sqlalchemy import (
Column,
Integer,
MetaData,
String,
Table,
create_engine,
insert,
text,
)
from sqlalchemy.exc import OperationalError


Expand Down Expand Up @@ -227,3 +243,89 @@ def test_nl_query_engine_parser(
nl_query_engine._parse_response_to_sql(response)
== "SELECT * FROM table WHERE name = ''O''Reilly'';"
)


def test_sql_table_retriever_query_engine_with_rows_retriever(
patch_llm_predictor,
patch_token_text_splitter,
struct_kwargs: Tuple[Dict, Dict],
) -> None:
"""Test SQLTableRetrieverQueryEngine."""
index_kwargs, query_kwargs = struct_kwargs
sql_to_test = "SELECT user_id, foo FROM test_table"
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
table_name = "test_table"
# NOTE: table is created by tying to metadata_obj
table_instance = Table(
table_name,
metadata_obj,
Column("user_id", Integer, primary_key=True),
Column("foo", String(16), nullable=False),
)
metadata_obj.create_all(engine)
sql_database = SQLDatabase(engine)
# Inserting fake values into table
statement = insert(table_instance).values(
[{"user_id": 2, "foo": "bar"}, {"user_id": 8, "foo": "hello"}]
)
with engine.connect() as conn:
conn.execute(statement)
conn.commit()

# Building rows retriever
with engine.connect() as conn:
cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
result = cursor.fetchall()
row_tups = []
for row in result:
row_tups.append(tuple(row))
# index each row, put into vector store index
nodes = [TextNode(text=str(t)) for t in result]
index = VectorStoreIndex(nodes)
rows_retrievers = {table_name: index.as_retriever()}

# Building the table retriever
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
SQLTableSchema(
table_name=table_name,
context_str="This table contains information about user id and the foo attribute.",
)
] # add a SQLTableSchema for each table
obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
VectorStoreIndex, # type: ignore
)
table_retriever = obj_index.as_retriever()

# query the index with natural language
nl_query_engine = SQLTableRetrieverQueryEngine(
sql_database, table_retriever, rows_retrievers
)
response = nl_query_engine.query("test_table:user_id,foo")
assert str(response) == "[(2, 'bar'), (8, 'hello')]"

nl_table_engine = SQLTableRetrieverQueryEngine(
sql_database, table_retriever, rows_retrievers, sql_only=True
)
response = nl_table_engine.query("test_table:user_id,foo")
assert str(response) == sql_to_test

# query with markdown return
nl_table_engine = SQLTableRetrieverQueryEngine(
sql_database,
table_retriever,
rows_retrievers,
synthesize_response=False,
markdown_response=True,
)
response = nl_table_engine.query("test_table:user_id,foo")
assert (
str(response)
== """| user_id | foo |
|---|---|
| 2 | bar |
| 8 | hello |"""
)
Loading