Commit a86e7d6
enhancement: add support for Playwright's storage_state parameter (#832)

* add support for Playwright `storage_state`

* add storage_state param to node_config

* add sleep for testing

* add sleep in _with_js_support for testing

* remove asyncio.sleep() from tests

* fix typo in existing example filename; add auth example

* add example `authenticated_playwright`

* update source link in example to /feed

* add `storage_state` to missing graphs
aflansburg authored Dec 3, 2024
1 parent fbb4252 commit a86e7d6
Showing 16 changed files with 585 additions and 329 deletions.
93 changes: 93 additions & 0 deletions examples/extras/authenticated_playwright.py
@@ -0,0 +1,93 @@
"""
Example leveraging a state file containing session cookies which
might be leveraged to authenticate to a website and scrape protected
content.
"""

import os
import random
from dotenv import load_dotenv

# import playwright so we can use it to create the state file
from playwright.async_api import async_playwright

from scrapegraphai.graphs import OmniScraperGraph
from scrapegraphai.utils import prettify_exec_info

load_dotenv()

# ************************************************
# Leveraging Playwright external to the invocation of the graph to
# login and create the state file
# ************************************************


# note this is just an example and probably won't actually work on
# LinkedIn, the implementation of the login is highly dependent on the website
async def do_login():
async with async_playwright() as playwright:
browser = await playwright.chromium.launch(
timeout=30000,
headless=False,
slow_mo=random.uniform(500, 1500),
)
page = await browser.new_page()

# very basic implementation of a login, in reality it may be trickier
await page.goto("https://www.linkedin.com/login")
await page.get_by_label("Email or phone").fill("some_bloke@some_domain.com")
await page.get_by_label("Password").fill("test1234")
await page.get_by_role("button", name="Sign in").click()
await page.wait_for_timeout(3000)

# assuming a successful login, we save the cookies to a file
await page.context.storage_state(path="./state.json")


async def main():
await do_login()

# ************************************************
# Define the configuration for the graph
# ************************************************

openai_api_key = os.getenv("OPENAI_APIKEY")

graph_config = {
"llm": {
"api_key": openai_api_key,
"model": "openai/gpt-4o",
},
"max_images": 10,
"headless": False,
# provide the path to the state file
"storage_state": "./state.json",
}

# ************************************************
# Create the OmniScraperGraph instance and run it
# ************************************************

omni_scraper_graph = OmniScraperGraph(
prompt="List me all the projects with their description.",
source="https://www.linkedin.com/feed/",
config=graph_config,
)

# the storage_state is used to load the cookies from the state file
# so we are authenticated and able to scrape protected content
result = omni_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = omni_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))


if __name__ == "__main__":
import asyncio

asyncio.run(main())
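Before running the graph it can be worth sanity-checking the state file. The sketch below is not part of this commit; it assumes only the documented shape of Playwright's storage state (a JSON object with "cookies" and "origins" keys):

# Sketch: verify state.json exists and actually contains cookies before
# scraping; an empty file usually means the login flow did not complete.
import json
from pathlib import Path

state_path = Path("./state.json")
assert state_path.exists(), "run do_login() first to create state.json"
state = json.loads(state_path.read_text())
print(f"{len(state.get('cookies', []))} cookies, "
      f"{len(state.get('origins', []))} origins saved")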
File renamed without changes.
35 changes: 26 additions & 9 deletions scrapegraphai/docloaders/chromium.py
@@ -8,6 +8,7 @@
 
 logger = get_logger("web-loader")
 
+
 class ChromiumLoader(BaseLoader):
     """Scrapes HTML pages from URLs using a (headless) instance of the
     Chromium web driver with proxy protection.
@@ -33,6 +34,7 @@ def __init__(
         proxy: Optional[Proxy] = None,
         load_state: str = "domcontentloaded",
         requires_js_support: bool = False,
+        storage_state: Optional[str] = None,
         **kwargs: Any,
     ):
         """Initialize the loader with a list of URL paths.
@@ -62,6 +64,7 @@ def __init__(
         self.urls = urls
         self.load_state = load_state
         self.requires_js_support = requires_js_support
+        self.storage_state = storage_state
 
     async def ascrape_undetected_chromedriver(self, url: str) -> str:
         """
@@ -91,7 +94,9 @@ async def ascrape_undetected_chromedriver(self, url: str) -> str:
                     attempt += 1
                     logger.error(f"Attempt {attempt} failed: {e}")
                     if attempt == self.RETRY_LIMIT:
-                        results = f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                        results = (
+                            f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                        )
         finally:
             driver.quit()
 
@@ -113,7 +118,9 @@ async def ascrape_playwright(self, url: str) -> str:
                     browser = await p.chromium.launch(
                         headless=self.headless, proxy=self.proxy, **self.browser_config
                     )
-                    context = await browser.new_context()
+                    context = await browser.new_context(
+                        storage_state=self.storage_state
+                    )
                     await Malenia.apply_stealth(context)
                     page = await context.new_page()
                     await page.goto(url, wait_until="domcontentloaded")
@@ -125,9 +132,11 @@ async def ascrape_playwright(self, url: str) -> str:
                 attempt += 1
                 logger.error(f"Attempt {attempt} failed: {e}")
                 if attempt == self.RETRY_LIMIT:
-                    raise RuntimeError(f"Failed to fetch {url} after {self.RETRY_LIMIT} attempts: {e}")
+                    raise RuntimeError(
+                        f"Failed to fetch {url} after {self.RETRY_LIMIT} attempts: {e}"
+                    )
             finally:
-                if 'browser' in locals():
+                if "browser" in locals():
                     await browser.close()
 
     async def ascrape_with_js_support(self, url: str) -> str:
@@ -138,7 +147,7 @@ async def ascrape_with_js_support(self, url: str) -> str:
             url (str): The URL to scrape.
 
         Returns:
-            str: The fully rendered HTML content after JavaScript execution, 
+            str: The fully rendered HTML content after JavaScript execution,
             or an error message if an exception occurs.
         """
         from playwright.async_api import async_playwright
@@ -153,7 +162,9 @@ async def ascrape_with_js_support(self, url: str) -> str:
                 browser = await p.chromium.launch(
                     headless=self.headless, proxy=self.proxy, **self.browser_config
                 )
-                context = await browser.new_context()
+                context = await browser.new_context(
+                    storage_state=self.storage_state
+                )
                 page = await context.new_page()
                 await page.goto(url, wait_until="networkidle")
                 results = await page.content()
@@ -163,7 +174,9 @@ async def ascrape_with_js_support(self, url: str) -> str:
                 attempt += 1
                 logger.error(f"Attempt {attempt} failed: {e}")
                 if attempt == self.RETRY_LIMIT:
-                    results = f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                    results = (
+                        f"Error: Network error after {self.RETRY_LIMIT} attempts - {e}"
+                    )
             finally:
                 await browser.close()
 
@@ -180,7 +193,9 @@ def lazy_load(self) -> Iterator[Document]:
             Document: The scraped content encapsulated within a Document object.
         """
         scraping_fn = (
-            self.ascrape_with_js_support if self.requires_js_support else getattr(self, f"ascrape_{self.backend}")
+            self.ascrape_with_js_support
+            if self.requires_js_support
+            else getattr(self, f"ascrape_{self.backend}")
         )
 
         for url in self.urls:
@@ -202,7 +217,9 @@ async def alazy_load(self) -> AsyncIterator[Document]:
             source URL as metadata.
         """
         scraping_fn = (
-            self.ascrape_with_js_support if self.requires_js_support else getattr(self, f"ascrape_{self.backend}")
+            self.ascrape_with_js_support
+            if self.requires_js_support
+            else getattr(self, f"ascrape_{self.backend}")
         )
 
         tasks = [scraping_fn(url) for url in self.urls]
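With this change, a ChromiumLoader can be driven directly with a saved session. A minimal sketch, assuming the constructor and lazy_load shown in the diff above; the import path mirrors the file's location, and arguments other than urls and storage_state are assumptions based on the attributes the class uses:

# Sketch: load a protected page through ChromiumLoader using cookies
# previously saved via Playwright's page.context.storage_state().
from scrapegraphai.docloaders.chromium import ChromiumLoader

loader = ChromiumLoader(
    ["https://www.linkedin.com/feed/"],
    headless=True,
    storage_state="./state.json",  # path to the saved session state
)

for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))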
95 changes: 65 additions & 30 deletions scrapegraphai/graphs/abstract_graph.py
@@ -1,6 +1,7 @@
 """
 AbstractGraph Module
 """
+
 from abc import ABC, abstractmethod
 from typing import Optional
 import uuid
@@ -9,12 +10,10 @@
 from langchain.chat_models import init_chat_model
 from langchain_core.rate_limiters import InMemoryRateLimiter
 from ..helpers import models_tokens
-from ..models import (
-    OneApi,
-    DeepSeek
-)
+from ..models import OneApi, DeepSeek
 from ..utils.logging import set_verbosity_warning, set_verbosity_info
 
+
 class AbstractGraph(ABC):
     """
     Scaffolding class for creating a graph representation and executing it.
@@ -39,14 +38,18 @@ class AbstractGraph(ABC):
     ...        # Implementation of graph creation here
     ...        return graph
     ...
-    >>> my_graph = MyGraph("Example Graph", 
+    >>> my_graph = MyGraph("Example Graph",
     {"llm": {"model": "gpt-3.5-turbo"}}, "example_source")
     >>> result = my_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict,
-                 source: Optional[str] = None, schema: Optional[BaseModel] = None):
-
+    def __init__(
+        self,
+        prompt: str,
+        config: dict,
+        source: Optional[str] = None,
+        schema: Optional[BaseModel] = None,
+    ):
         if config.get("llm").get("temperature") is None:
             config["llm"]["temperature"] = 0
 
@@ -55,14 +58,13 @@ def __init__(self, prompt: str, config: dict,
         self.config = config
         self.schema = schema
         self.llm_model = self._create_llm(config["llm"])
-        self.verbose = False if config is None else config.get(
-            "verbose", False)
-        self.headless = True if self.config is None else config.get(
-            "headless", True)
+        self.verbose = False if config is None else config.get("verbose", False)
+        self.headless = True if self.config is None else config.get("headless", True)
         self.loader_kwargs = self.config.get("loader_kwargs", {})
         self.cache_path = self.config.get("cache_path", False)
         self.browser_base = self.config.get("browser_base")
         self.scrape_do = self.config.get("scrape_do")
+        self.storage_state = self.config.get("storage_state")
 
         self.graph = self._create_graph()
         self.final_state = None
@@ -81,7 +83,7 @@ def __init__(self, prompt: str, config: dict,
             "loader_kwargs": self.loader_kwargs,
             "llm_model": self.llm_model,
             "cache_path": self.cache_path,
-            }
+        }
 
         self.set_common_params(common_params, overwrite=True)
 
@@ -129,7 +131,8 @@ def _create_llm(self, llm_config: dict) -> object:
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
                 llm_params["rate_limiter"] = InMemoryRateLimiter(
-                    requests_per_second=requests_per_second)
+                    requests_per_second=requests_per_second
+                )
             if max_retries is not None:
                 llm_params["max_retries"] = max_retries
 
@@ -140,30 +143,55 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
-                           "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
-                           "hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
-
-        if '/' in llm_params["model"]:
-            split_model_provider = llm_params["model"].split("/", 1)
-            llm_params["model_provider"] = split_model_provider[0]
-            llm_params["model"] = split_model_provider[1]
+        known_providers = {
+            "openai",
+            "azure_openai",
+            "google_genai",
+            "google_vertexai",
+            "ollama",
+            "oneapi",
+            "nvidia",
+            "groq",
+            "anthropic",
+            "bedrock",
+            "mistralai",
+            "hugging_face",
+            "deepseek",
+            "ernie",
+            "fireworks",
+            "togetherai",
+        }
+
+        if "/" in llm_params["model"]:
+            split_model_provider = llm_params["model"].split("/", 1)
+            llm_params["model_provider"] = split_model_provider[0]
+            llm_params["model"] = split_model_provider[1]
         else:
-            possible_providers = [provider for provider, models_d in models_tokens.items() if llm_params["model"] in models_d]
+            possible_providers = [
+                provider
+                for provider, models_d in models_tokens.items()
+                if llm_params["model"] in models_d
+            ]
             if len(possible_providers) <= 0:
                 raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                                  If possible, try to use a model instance instead.""")
             llm_params["model_provider"] = possible_providers[0]
-            print((f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
-                   "If it was not intended please specify the model provider in the graph configuration"))
+            print(
+                (
+                    f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
+                    "If it was not intended please specify the model provider in the graph configuration"
+                )
+            )
 
         if llm_params["model_provider"] not in known_providers:
             raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
                              If possible, try to use a model instance instead.""")
 
         if "model_tokens" not in llm_params:
             try:
-                self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
+                self.model_token = models_tokens[llm_params["model_provider"]][
+                    llm_params["model"]
+                ]
             except KeyError:
                 print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
                       using default token size (8192)""")
@@ -172,10 +200,17 @@ def _create_llm(self, llm_config: dict) -> object:
             self.model_token = llm_params["model_tokens"]
 
         try:
-            if llm_params["model_provider"] not in \
-                {"oneapi","nvidia","ernie","deepseek","togetherai"}:
+            if llm_params["model_provider"] not in {
+                "oneapi",
+                "nvidia",
+                "ernie",
+                "deepseek",
+                "togetherai",
+            }:
                 if llm_params["model_provider"] == "bedrock":
-                    llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
+                    llm_params["model_kwargs"] = {
+                        "temperature": llm_params.pop("temperature")
+                    }
                 with warnings.catch_warnings():
                     warnings.simplefilter("ignore")
                     return init_chat_model(**llm_params)
@@ -187,6 +222,7 @@ def _create_llm(self, llm_config: dict) -> object:
 
             if model_provider == "ernie":
                 from langchain_community.chat_models import ErnieBotChat
+
                 return ErnieBotChat(**llm_params)
 
             elif model_provider == "oneapi":
@@ -211,7 +247,6 @@ def _create_llm(self, llm_config: dict) -> object:
         except Exception as e:
            raise Exception(f"Error instancing model: {e}")
 
-
     def get_state(self, key=None) -> dict:
         """ ""
         Get the final state of the graph.
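Because AbstractGraph now reads storage_state from the top-level config, every graph built on it accepts the option with no extra wiring (see the "add storage_state to missing graphs" note in the commit message). A minimal sketch for illustration; SmartScraperGraph is just one subclass, and the source URL and model name here are hypothetical:

# Sketch: any AbstractGraph subclass now picks up storage_state from config.
from scrapegraphai.graphs import SmartScraperGraph

graph = SmartScraperGraph(
    prompt="Extract the page's main headline.",
    source="https://example.com/members-only",  # hypothetical protected page
    config={
        "llm": {"api_key": "sk-...", "model": "openai/gpt-4o-mini"},
        "headless": True,
        "storage_state": "./state.json",  # reuse the saved session cookies
    },
)

print(graph.run())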