Skip to content

Commit

Permalink
Merge pull request #40 from LemurPwned/feat/explicit-stream-support
Browse files Browse the repository at this point in the history
Feat/explicit stream support
  • Loading branch information
LemurPwned authored Nov 11, 2024
2 parents f7b1d8c + 0e00b4b commit 25b2e71
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 9 deletions.
25 changes: 22 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Currently, it uses keyframe decoding, frame interval gating and perceptual hashi
- [Features](#features)
- [Installation and Usage](#installation-and-usage)
- [Basic usage](#basic-usage)
- [Streaming and RTSP support](#streaming-and-rtsp-support)
- [Image sampling](#image-sampling)
- [YT-DLP integration plugin](#yt-dlp-integration-plugin)
- [Extra YT-DLP options](#extra-yt-dlp-options)
Expand All @@ -38,9 +39,9 @@ Currently, it uses keyframe decoding, frame interval gating and perceptual hashi
- [Benchmarks](#benchmarks)
- [Benchmark videos](#benchmark-videos)
- [Flit commands](#flit-commands)
- [Build](#build)
- [Install](#install)
- [Publish](#publish)
- [Build](#build)
- [Install](#install)
- [Publish](#publish)
- [🛡 License](#-license)
- [📃 Citation](#-citation)

Expand Down Expand Up @@ -110,6 +111,24 @@ python3 -m video_sampler config ./configs/hash_base.yaml /my-video-folder/ ./my-

You can set the number of workers to use with the `n_workers` parameter. The default is 1.

#### Streaming and RTSP support

RTSP support is experimental and may not work for all RTSP servers, but it should work for most of them.
You can test out the RTSP support by running the following command:

```bash
python3 -m video_sampler config ./configs/hash_base.yaml rtsp://localhost:8554/some-stream ./sampled-stream/
```
[RTSP simple server](https://github.com/bhaney/rtsp-simple-server) is a good way to test RTSP streams.

Other streams (MJPEG) also work, e.g.

```bash
python3 -m video_sampler config ./configs/hash_base.yaml "http://honjin1.miemasu.net/nphMotionJpeg?Resolution=640x480&Quality=Standard" ./sampled-stream/
```

For proper streaming, you may want to adjust `min_frame_interval_sec` and the buffer sizes to get a shorter flush time. Keep in mind that a stream will be sampled until interrupted, so you may want to specify the end time of the stream with the [`end_time_s` parameter](./video_sampler/config.py#L81). This is especially important if the stream is a looped video -- otherwise, you'll end up overwriting the same frames over and over again.

#### Image sampling

If your frames are ordered, then you can use the `image_sampler` module to sample them. The images should have some concept of ordering, e.g. they should be named in a way that allows for sorting, e.g. `image_001.png`, `image_002.png`, etc, because the sampler will deduplicate based on the circular buffer of hashes.
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "video_sampler"
description = "Video Sampler -- sample frames from a video file"
url = "https://github.com/LemurPwned/video-sampler"
version = "0.12.0"
version = "0.12.1"
authors = [
{ name = "LemurPwned", email = "lemurpwned@gmail.com" }
]
Expand All @@ -27,7 +27,7 @@ keywords = [
]

dependencies = [
"av >=10.0.0",
"av >=13.0.0",
"ImageHash >=4.3.1",
"Pillow >=10.0.0",
"rich >= 13.5.3",
Expand Down
11 changes: 11 additions & 0 deletions video_sampler/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def parallel_video_processing(
future.result()


def is_video_path_url(video_path: str) -> bool:
    """
    Check whether a video path is a network URL rather than a local file.

    Args:
        video_path (str): Path or URL pointing to a video source.

    Returns:
        bool: True if the path uses an http(s) or rtsp(s) scheme.
    """
    # Match the scheme delimiter explicitly so local paths that merely start
    # with "http"/"rtsp" (e.g. "rtsp_dump.mp4") are not misclassified as URLs.
    # NOTE: the previous separate "https" check was dead code -- any string
    # starting with "https" also starts with "http". The tuple form makes
    # each accepted scheme explicit and adds the secure RTSP variant.
    return video_path.startswith(("http://", "https://", "rtsp://", "rtsps://"))


def delegate_workers(
video_path: str | Generator,
output_path: str,
Expand All @@ -126,6 +134,9 @@ def delegate_workers(
videos = video_path
msg = "Detected input as an URL generator"
is_url = True
elif is_video_path_url(video_path):
videos = iter([video_path])
msg = "Detected input as an URL"
elif not os.path.isfile(video_path) and not isinstance(cfg, ImageSamplerConfig):
if "*" not in video_path:
videos = glob.glob(os.path.join(video_path, "*"))
Expand Down
4 changes: 4 additions & 0 deletions video_sampler/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ def process_frame(
self.stats["gated"] += gated_obj.N
if gated_obj.frames:
yield gated_obj.frames

def debug_print(self, message: str):
    """Emit *message* to the console in bold red, but only in debug mode."""
    if not self.cfg.debug:
        return
    console.print(message, style=f"bold {Color.red.value}")
26 changes: 24 additions & 2 deletions video_sampler/samplers/image_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
class ImageSampler(BaseSampler):
"""
Image sampler -- sample frames from a folder of images
Args:
cfg (ImageSamplerConfig): Image sampler config
Methods:
sample(image_folder: str) -> Iterable[list[FrameObject]]: Sample frames from image folder
write_queue(image_path: str, q: Queue, _: str = None): Write frames to queue
"""

def __init__(self, cfg: ImageSamplerConfig):
Expand All @@ -29,6 +36,12 @@ def __init__(self, cfg: ImageSamplerConfig):
def extract_frame_time(self, image_path: str, default: str | None = None) -> str:
"""
Extract frame time from image path
Args:
image_path (str): Path to image
default (str | None): Default frame time to return if no regex is set
Returns:
str: Frame time
"""
if self.rgx:
if match := self.rgx.search(image_path):
Expand All @@ -47,8 +60,11 @@ def extract_frame_time(self, image_path: str, default: str | None = None) -> str
def sample(self, image_folder: str) -> Iterable[list[FrameObject]]:
"""
Sample frames from image folder
:param image_folder: path to image folder or glob pattern
:return: iterable of frames
Args:
image_folder (str): Path to image folder or glob pattern
Returns:
Iterable[list[FrameObject]]: Iterable of frames
"""
self.init_sampler()
if "*" in image_folder:
Expand All @@ -73,6 +89,12 @@ def sample(self, image_folder: str) -> Iterable[list[FrameObject]]:
yield from self.flush_buffer()

def write_queue(self, image_path: str, q: Queue, _: str = None):
"""
Write frames to queue.
Args:
image_path (str): Path to image
q (Queue): Queue to write frames to
"""
try:
for item in self.sample(image_path):
q.put(item)
Expand Down
25 changes: 23 additions & 2 deletions video_sampler/samplers/video_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def sample(self, video_path: str, subs: str = None) -> Iterable[list[FrameObject
of FrameObjects representing sampled frames.
"""
self.init_sampler()
with av.open(video_path) as container:
with av.open(
video_path,
metadata_errors="ignore",
) as container:
stream = container.streams.video[0]
if self.cfg.keyframes_only:
stream.codec_context.skip_frame = "NONKEY"
Expand Down Expand Up @@ -79,24 +82,42 @@ def sample(self, video_path: str, subs: str = None) -> Iterable[list[FrameObject
f"\n\t{e}",
style=f"bold {Color.red.value}",
)
avg_fps = float(stream.average_rate)
try:
avg_fps = float(stream.average_rate)
except (AttributeError, TypeError):
console.print(
"Failed to get average FPS, defaulting to 1. If you are using a URL handle, this is expected.",
style=f"bold {Color.yellow.value}",
)
avg_fps = 1
for frame in container.decode(stream):
if frame is None or frame.is_corrupt:
self.debug_print("Frame is None or corrupt, skipping.")
continue
try:
ftime = frame.time
except AttributeError:
self.debug_print("Failed to get frame time, skipping frame.")
continue
if self.cfg.start_time_s > 0 and ftime < self.cfg.start_time_s:
self.debug_print(
f"Frame time {ftime} is before start time {self.cfg.start_time_s}, skipping."
)
continue

if self.cfg.end_time_s is not None and ftime > self.cfg.end_time_s:
self.debug_print(
f"Frame time {ftime} is after end time {self.cfg.end_time_s}, stopping."
)
break
frame_index = int(ftime * avg_fps)
# skip frames if keyframes_only is True
time_diff = ftime - prev_time
self.stats["total"] += 1
if time_diff < self.cfg.min_frame_interval_sec:
self.debug_print(
f"Frame time {ftime} is too close to previous frame {prev_time}, skipping."
)
continue
prev_time = ftime
frame_pil = frame.to_image()
Expand Down

0 comments on commit 25b2e71

Please sign in to comment.