856 token indices sequence length #877

Merged 1 commit on Jan 12, 2025
4 changes: 2 additions & 2 deletions examples/local_models/smart_scraper_ollama.py
@@ -15,7 +15,7 @@
         "temperature": 0,
         "format": "json",  # Ollama needs the format to be specified explicitly
         # "base_url": "http://localhost:11434",  # set ollama URL arbitrarily
-        "model_tokens": 1024,
+        "model_tokens": 4096,
     },
     "verbose": True,
     "headless": False,
@@ -25,7 +25,7 @@
 # Create the SmartScraperGraph instance and run it
 # ************************************************
 smart_scraper_graph = SmartScraperGraph(
-    prompt="Find some information about what does the company do, the name and a contact email.",
+    prompt="Find some information about what does the company do and the list of founders.",
     source="https://scrapegraphai.com/",
     config=graph_config,
 )
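Reassembled from the hunks above, the updated example reads roughly as follows; this is a sketch assuming a local Ollama server with the llama3.1 model already pulled. The larger `model_tokens` value is what gives the model enough context and, presumably, avoids the "token indices sequence length" warning the PR title refers to.

```python
# Sketch of the updated example (assumes a local Ollama server with
# the llama3.1 model already pulled).
from scrapegraphai.graphs import SmartScraperGraph

graph_config = {
    "llm": {
        "model": "ollama/llama3.1",
        "temperature": 0,
        "format": "json",  # Ollama needs the format to be specified explicitly
        "model_tokens": 4096,  # was 1024, too small for typical pages
    },
    "verbose": True,
    "headless": False,
}

smart_scraper_graph = SmartScraperGraph(
    prompt="Find some information about what does the company do and the list of founders.",
    source="https://scrapegraphai.com/",
    config=graph_config,
)
print(smart_scraper_graph.run())
```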
29 changes: 18 additions & 11 deletions examples/local_models/smart_scraper_schema_ollama.py
@@ -1,31 +1,31 @@
 """
 Basic example of scraping pipeline using SmartScraper with schema
 """
 
+import json
 from typing import List
 
 from pydantic import BaseModel, Field
 
 from scrapegraphai.graphs import SmartScraperGraph
+from scrapegraphai.utils import prettify_exec_info
 
 
 # ************************************************
 # Define the configuration for the graph
 # ************************************************
 class Project(BaseModel):
     title: str = Field(description="The title of the project")
     description: str = Field(description="The description of the project")
 
 
 class Projects(BaseModel):
-    projects: List[Project]
+    projects: list[Project]
 
 
 graph_config = {
-    "llm": {
-        "model": "ollama/llama3.1",
-        "temperature": 0,
-        "format": "json",  # Ollama needs the format to be specified explicitly
-        # "base_url": "http://localhost:11434",  # set ollama URL arbitrarily
-    },
+    "llm": {"model": "ollama/llama3.2", "temperature": 0, "model_tokens": 4096},
     "verbose": True,
-    "headless": False
+    "headless": False,
 }
 
 # ************************************************
@@ -36,8 +36,15 @@ class Projects(BaseModel):
     prompt="List me all the projects with their description",
     source="https://perinim.github.io/projects/",
     schema=Projects,
-    config=graph_config
+    config=graph_config,
 )
 
 result = smart_scraper_graph.run()
+print(json.dumps(result, indent=4))
+
+# ************************************************
+# Get graph execution info
+# ************************************************
+
+graph_exec_info = smart_scraper_graph.get_execution_info()
+print(prettify_exec_info(graph_exec_info))
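Because `run()` returns a plain dict, the output can be checked against the same Pydantic schema the graph was given. A hypothetical follow-up, not part of this diff:

```python
# Hypothetical follow-up (not in the PR): round-trip the result through
# the schema to fail fast on shape mismatches.
projects = Projects.model_validate(result)
for project in projects.projects:
    print(f"{project.title}: {project.description}")
```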
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -30,8 +30,7 @@ dependencies = [
     "googlesearch-python>=1.2.5",
     "async-timeout>=4.0.3",
     "simpleeval>=1.0.0",
-    "jsonschema>=4.23.0",
-    "transformers>=4.46.3",
+    "jsonschema>=4.23.0"
 ]
 
 readme = "README.md"
4 changes: 2 additions & 2 deletions scrapegraphai/docloaders/chromium.py
@@ -61,15 +61,15 @@ def __init__(
 
         dynamic_import(backend, message)
 
-        self.backend = backend
         self.browser_config = kwargs
        self.headless = headless
         self.proxy = parse_or_search_proxy(proxy) if proxy else None
         self.urls = urls
         self.load_state = load_state
         self.requires_js_support = requires_js_support
         self.storage_state = storage_state
-        self.browser_name = browser_name
+        self.backend = kwargs.get("backend", backend)
+        self.browser_name = kwargs.get("browser_name", browser_name)
         self.retry_limit = kwargs.get("retry_limit", retry_limit)
         self.timeout = kwargs.get("timeout", timeout)
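`backend` and `browser_name` now use the same `kwargs.get(key, default)` fallback already applied to `retry_limit` and `timeout`, so entries forwarded through the loader's catch-all `kwargs` (also stored as `browser_config`) take precedence over the constructor defaults. A toy demonstration of the pattern, with made-up keys:

```python
# Toy demonstration of the kwargs.get fallback (made-up keys, not the
# real ChromiumLoader API).
defaults = {"retry_limit": 1, "timeout": 10, "browser_name": "chromium"}
overrides = {"timeout": 60}  # plays the role of **kwargs

resolved = {key: overrides.get(key, default) for key, default in defaults.items()}
print(resolved)  # {'retry_limit': 1, 'timeout': 60, 'browser_name': 'chromium'}
```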
5 changes: 3 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
@@ -203,8 +203,9 @@ def _create_llm(self, llm_config: dict) -> object:
             ]
         except KeyError:
             print(
-                f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
-                using default token size (8192)"""
+                f"""Max input tokens for model {llm_params['model_provider']}/{llm_params['model']} not found,
+                please specify the model_tokens parameter in the llm section of the graph configuration.
+                Using default token size: 8192"""
             )
             self.model_token = 8192
         else:
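The `KeyError` branch fires when the provider/model pair is absent from the library's internal max-token table, and the reworded message now names the escape hatch. A minimal sketch of a config that sidesteps the lookup, mirroring the schema example above:

```python
# Passing model_tokens explicitly skips the table lookup, so the
# KeyError fallback to 8192 never triggers.
graph_config = {
    "llm": {
        "model": "ollama/llama3.2",
        "temperature": 0,
        "model_tokens": 4096,
    },
}
```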
10 changes: 6 additions & 4 deletions scrapegraphai/nodes/generate_answer_node.py
@@ -10,7 +10,7 @@
 from langchain_community.chat_models import ChatOllama
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
-from langchain_openai import AzureChatOpenAI, ChatOpenAI
+from langchain_openai import ChatOpenAI
 from requests.exceptions import Timeout
 from tqdm import tqdm
@@ -59,7 +59,10 @@ def __init__(
         self.llm_model = node_config["llm_model"]
 
         if isinstance(node_config["llm_model"], ChatOllama):
-            self.llm_model.format = "json"
+            if node_config.get("schema", None) is None:
+                self.llm_model.format = "json"
+            else:
+                self.llm_model.format = self.node_config["schema"].model_json_schema()
 
         self.verbose = node_config.get("verbose", False)
         self.force = node_config.get("force", False)
@@ -123,8 +126,7 @@ def execute(self, state: dict) -> dict:
         format_instructions = ""
 
         if (
-            isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI))
-            and not self.script_creator
+            not self.script_creator
             or self.force
             and not self.script_creator
             or self.is_md_scraper
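The `ChatOllama` branch now distinguishes two cases: without a schema it keeps the bare `"json"` mode, while with a schema it assigns the Pydantic model's JSON schema, which recent Ollama releases accept in the `format` field for constrained decoding. A sketch of the two settings, assuming a `ChatOllama` build whose `format` field accepts a schema dict:

```python
# Sketch of the two format settings (assumes a ChatOllama build whose
# `format` field accepts a JSON-schema dict, as the node relies on).
from langchain_community.chat_models import ChatOllama
from pydantic import BaseModel


class Answer(BaseModel):
    answer: str


llm = ChatOllama(model="llama3.2")
llm.format = "json"                      # no schema: free-form JSON output
llm.format = Answer.model_json_schema()  # schema: output constrained to this shape
```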
13 changes: 10 additions & 3 deletions scrapegraphai/nodes/generate_answer_node_k_level.py
@@ -6,10 +6,11 @@
 
 from langchain.prompts import PromptTemplate
 from langchain_aws import ChatBedrock
+from langchain_community.chat_models import ChatOllama
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
 from langchain_mistralai import ChatMistralAI
-from langchain_openai import AzureChatOpenAI, ChatOpenAI
+from langchain_openai import ChatOpenAI
 from tqdm import tqdm
 
 from ..prompts import (
@@ -55,6 +56,13 @@
         super().__init__(node_name, "node", input, output, 2, node_config)
 
         self.llm_model = node_config["llm_model"]
+
+        if isinstance(node_config["llm_model"], ChatOllama):
+            if node_config.get("schema", None) is None:
+                self.llm_model.format = "json"
+            else:
+                self.llm_model.format = self.node_config["schema"].model_json_schema()
+
         self.embedder_model = node_config.get("embedder_model", None)
         self.verbose = node_config.get("verbose", False)
         self.force = node_config.get("force", False)
@@ -92,8 +100,7 @@ def execute(self, state: dict) -> dict:
         format_instructions = ""
 
         if (
-            isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI))
-            and not self.script_creator
+            not self.script_creator
             or self.force
             and not self.script_creator
             or self.is_md_scraper
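The same simplified guard appears in both nodes, and it leans on Python's operator precedence: `and` binds tighter than `or`, so the condition groups as shown below. A quick exhaustive check with hypothetical flag values:

```python
# Exhaustive check that the guard's implicit grouping matches the
# parenthesized reading (flag values are hypothetical).
from itertools import product

for script_creator, force, is_md_scraper in product([False, True], repeat=3):
    implicit = not script_creator or force and not script_creator or is_md_scraper
    explicit = (not script_creator) or (force and not script_creator) or is_md_scraper
    assert implicit == explicit
```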
4 changes: 1 addition & 3 deletions scrapegraphai/nodes/parse_node.py
@@ -96,7 +96,6 @@ def execute(self, state: dict) -> dict:
             chunks = split_text_into_chunks(
                 text=docs_transformed.page_content,
                 chunk_size=self.chunk_size - 250,
-                model=self.llm_model,
             )
         else:
             docs_transformed = docs_transformed[0]
@@ -115,11 +114,10 @@
             chunks = split_text_into_chunks(
                 text=docs_transformed.page_content,
                 chunk_size=chunk_size,
-                model=self.llm_model,
             )
         else:
             chunks = split_text_into_chunks(
-                text=docs_transformed, chunk_size=chunk_size, model=self.llm_model
+                text=docs_transformed, chunk_size=chunk_size
             )
 
         state.update({self.output[0]: chunks})
14 changes: 5 additions & 9 deletions scrapegraphai/utils/split_text_into_chunks.py
@@ -4,14 +4,10 @@
 
 from typing import List
 
-from langchain_core.language_models.chat_models import BaseChatModel
-
 from .tokenizer import num_tokens_calculus
 
 
-def split_text_into_chunks(
-    text: str, chunk_size: int, model: BaseChatModel, use_semchunk=True
-) -> List[str]:
+def split_text_into_chunks(text: str, chunk_size: int, use_semchunk=True) -> List[str]:
     """
     Splits the text into chunks based on the number of tokens.
 
@@ -27,17 +23,17 @@ def split_text_into_chunks(
         from semchunk import chunk
 
         def count_tokens(text):
-            return num_tokens_calculus(text, model)
+            return num_tokens_calculus(text)
 
-        chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
+        chunk_size = min(chunk_size, int(chunk_size * 0.9))
 
         chunks = chunk(
             text=text, chunk_size=chunk_size, token_counter=count_tokens, memoize=False
         )
         return chunks
 
     else:
-        tokens = num_tokens_calculus(text, model)
+        tokens = num_tokens_calculus(text)
 
         if tokens <= chunk_size:
             return [text]
@@ -48,7 +44,7 @@ def count_tokens(text):
 
         words = text.split()
         for word in words:
-            word_tokens = num_tokens_calculus(word, model)
+            word_tokens = num_tokens_calculus(word)
             if current_length + word_tokens > chunk_size:
                 chunks.append(" ".join(current_chunk))
                 current_chunk = [word]
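With the `model` argument gone the helper is model-agnostic. Note also that for any positive `chunk_size`, `min(chunk_size, int(chunk_size * 0.9))` always resolves to `int(chunk_size * 0.9)`, so the semchunk path now keeps a 10% headroom instead of the old flat 500-token margin. A usage sketch under the new signature:

```python
# Usage sketch of the new signature; the import path is an assumption
# based on the module's location in scrapegraphai/utils.
from scrapegraphai.utils.split_text_into_chunks import split_text_into_chunks

text = "ScrapeGraphAI splits long pages into token-bounded chunks. " * 200
chunks = split_text_into_chunks(text=text, chunk_size=512)
print(len(chunks), "chunks")
```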
28 changes: 4 additions & 24 deletions scrapegraphai/utils/tokenizer.py
@@ -2,35 +2,15 @@
 Module for counting tokens and splitting text into chunks
 """
 
-from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_mistralai import ChatMistralAI
-from langchain_ollama import ChatOllama
-from langchain_openai import ChatOpenAI
+from .tokenizers.tokenizer_openai import num_tokens_openai
 
 
-def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
+def num_tokens_calculus(string: str) -> int:
     """
     Returns the number of tokens in a text string.
     """
-    if isinstance(llm_model, ChatOpenAI):
-        from .tokenizers.tokenizer_openai import num_tokens_openai
-
-        num_tokens_fn = num_tokens_openai
-
-    elif isinstance(llm_model, ChatMistralAI):
-        from .tokenizers.tokenizer_mistral import num_tokens_mistral
-
-        num_tokens_fn = num_tokens_mistral
-
-    elif isinstance(llm_model, ChatOllama):
-        from .tokenizers.tokenizer_ollama import num_tokens_ollama
-
-        num_tokens_fn = num_tokens_ollama
-
-    else:
-        from .tokenizers.tokenizer_openai import num_tokens_openai
-
-        num_tokens_fn = num_tokens_openai
-
-    num_tokens = num_tokens_fn(string, llm_model)
+    num_tokens_fn = num_tokens_openai
+    num_tokens = num_tokens_fn(string)
     return num_tokens
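Every provider now shares the OpenAI tokenizer, so the result is a tiktoken-based approximation for Mistral and Ollama models rather than an exact per-provider count; this is also what makes the `transformers` dependency removed from pyproject.toml above unnecessary. Usage is a one-liner:

```python
# The count is now a tiktoken-based approximation for every provider.
from scrapegraphai.utils.tokenizer import num_tokens_calculus

print(num_tokens_calculus("Token counting is now model-agnostic."))
```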
6 changes: 2 additions & 4 deletions scrapegraphai/utils/tokenizers/tokenizer_openai.py
@@ -3,19 +3,17 @@
 """
 
 import tiktoken
-from langchain_core.language_models.chat_models import BaseChatModel
 
 from ..logging import get_logger
 
 
-def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int:
+def num_tokens_openai(text: str) -> int:
     """
     Estimate the number of tokens in a given text using OpenAI's tokenization method,
     adjusted for different OpenAI models.
 
     Args:
         text (str): The text to be tokenized and counted.
-        llm_model (BaseChatModel): The specific OpenAI model to adjust tokenization.
 
     Returns:
         int: The number of tokens in the text.
@@ -25,7 +23,7 @@ def num_tokens_openai(text: str, llm_model: BaseChatModel) -> int:
 
     logger.debug(f"Counting tokens for text of {len(text)} characters")
 
-    encoding = tiktoken.encoding_for_model("gpt-4")
+    encoding = tiktoken.encoding_for_model("gpt-4o")
 
     num_tokens = len(encoding.encode(text))
     return num_tokens
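`tiktoken.encoding_for_model("gpt-4")` resolves to the cl100k_base encoding, while "gpt-4o" resolves to o200k_base, so counts for the same text shift slightly after this change. A quick comparison, assuming tiktoken >= 0.7 (the first release that knows gpt-4o):

```python
# Old vs. new encoding (assumes tiktoken >= 0.7, which maps gpt-4o
# to o200k_base).
import tiktoken

old = tiktoken.encoding_for_model("gpt-4")   # cl100k_base
new = tiktoken.encoding_for_model("gpt-4o")  # o200k_base

text = "ScrapeGraphAI estimates prompt sizes with tiktoken."
print(len(old.encode(text)), len(new.encode(text)))  # counts may differ
```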
4 changes: 1 addition & 3 deletions uv.lock

Some generated files are not rendered by default.