2025-02-05 13:43:43 +01:00
|
|
|
# flake8: noqa: F811
|
|
|
|
# pylint: disable=unused-argument
|
|
|
|
# pylint: disable=too-few-public-methods
|
|
|
|
|
|
|
|
"""download module"""
|
|
|
|
|
|
|
|
import os.path
|
|
|
|
import signal
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from threading import Event
|
2025-02-05 13:50:07 +01:00
|
|
|
from typing import Any, Iterable
|
2025-02-05 13:43:43 +01:00
|
|
|
|
|
|
|
import requests
|
|
|
|
import rich
|
|
|
|
from rich import console
|
2025-02-05 13:50:07 +01:00
|
|
|
from rich.progress import (
|
|
|
|
BarColumn,
|
|
|
|
DownloadColumn,
|
|
|
|
Progress,
|
|
|
|
TaskID,
|
|
|
|
TextColumn,
|
|
|
|
TimeElapsedColumn,
|
|
|
|
TransferSpeedColumn,
|
|
|
|
)
|
2025-02-05 13:43:43 +01:00
|
|
|
|
|
|
|
# Set by the SIGINT handler; download workers poll it between chunks so that
# Ctrl+C stops in-flight transfers instead of killing the process mid-write.
done_event = Event()

# Shared Rich console; the progress display renders through it.
console = rich.get_console()
|
|
|
|
|
|
|
|
|
|
|
|
def handle_sigint(signum: Any, frame: Any) -> None:
    """Signal handler that flags ``done_event`` so running downloads stop.

    Args:
        signum: Signal number supplied by the ``signal`` module (unused).
        frame: Interrupted stack frame supplied by the ``signal`` module
            (unused).
    """
    # Both arguments are imposed by the signal-handler calling convention.
    del signum, frame
    done_event.set()
|
|
|
|
|
|
|
|
|
|
|
|
# Route Ctrl+C (SIGINT) to handle_sigint at import time so an interrupt sets
# done_event instead of raising KeyboardInterrupt.
signal.signal(signal.SIGINT, handle_sigint)
|
|
|
|
|
|
|
|
|
2025-02-05 13:50:07 +01:00
|
|
|
class DownloadProgressBar:
    """
    Object to manage Download process with Progress Bar from Rich

    One progress task is created per URL. Transfers run on a thread pool and
    each one stops early once the module-level ``done_event`` is set.
    """

    def __init__(self) -> None:
        """
        Class Constructor

        Builds the Rich ``Progress`` display (file name, bar, percentage,
        speed, size and elapsed time) shared by every download task.
        """
        self.progress = Progress(
            TextColumn(
                "💾 Downloading [bold blue]{task.fields[filename]}", justify="right"
            ),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            TransferSpeedColumn(),
            "•",
            DownloadColumn(),
            "•",
            TimeElapsedColumn(),
            "•",
            console=console,
        )

    @staticmethod
    def _filename_from_url(url: str) -> str:
        """Derive a local file name from the last path segment of *url*.

        The query string and fragment are removed *before* the path is split,
        so a ``/`` inside a query parameter cannot be mistaken for a path
        separator. Trailing slashes are ignored.

        Args:
            url: URL of the resource to download.

        Returns:
            The last non-empty path segment, or ``"download"`` when the URL
            yields no usable segment.
        """
        without_fragment = url.split("#", 1)[0]
        without_query = without_fragment.split("?", 1)[0]
        name = without_query.rstrip("/").rsplit("/", 1)[-1]
        # An empty name would make the destination path point at the
        # directory itself, so fall back to a fixed placeholder.
        return name or "download"

    def _copy_url(
        self, task_id: TaskID, url: str, path: str, block_size: int = 1024
    ) -> bool:
        """Copy data from a url to a local file.

        Args:
            task_id: Progress task to update while the transfer runs.
            url: Address to fetch.
            path: Destination file; it is created or truncated.
            block_size: Number of bytes requested per streamed chunk.

        Returns:
            ``True`` when the transfer was abandoned because ``done_event``
            was set, ``False`` when the whole response body was written.

        Raises:
            requests.RequestException: If the request fails or the server
                answers with an HTTP error status.
            OSError: If *path* cannot be opened for writing.
        """
        # The context manager releases the connection even when the transfer
        # is interrupted or fails part-way through.
        with requests.get(url, stream=True, timeout=5) as response:
            # Fail before touching the disk so an HTTP error page is never
            # saved as if it were the requested file.
            response.raise_for_status()

            # Content-Length is optional (e.g. chunked transfer encoding).
            # Without it the task simply keeps the total it was created with.
            content_length = response.headers.get("Content-Length", "")
            if content_length.isdigit():
                self.progress.update(task_id, total=int(content_length))

            with open(path, "wb") as dest_file:
                self.progress.start_task(task_id)
                for data in response.iter_content(chunk_size=block_size):
                    dest_file.write(data)
                    self.progress.update(task_id, advance=len(data))
                    if done_event.is_set():
                        return True
        return False

    def download(self, urls: Iterable[str], dest_dir: str) -> None:
        """Download multiple files to the given directory.

        Each URL is fetched on a worker thread (at most four at a time) and
        stored in *dest_dir* under the name derived from the URL. A failed
        download is reported on the progress console and does not stop the
        others.

        Args:
            urls: Addresses of the files to fetch.
            dest_dir: Existing directory that receives the files.
        """
        with self.progress:
            submitted = []
            with ThreadPoolExecutor(max_workers=4) as pool:
                for url in urls:
                    filename = self._filename_from_url(url)
                    dest_path = os.path.join(dest_dir, filename)
                    task_id = self.progress.add_task(
                        "download", filename=filename, start=False
                    )
                    future = pool.submit(self._copy_url, task_id, url, dest_path)
                    submitted.append((url, future))

            # Leaving the executor block waits for every worker, so each
            # future is finished here. Surface failures that would otherwise
            # stay hidden inside the futures.
            for url, future in submitted:
                error = future.exception()
                if error is not None:
                    self.progress.console.print(
                        f"Failed to download {url}: {error}",
                        style="red",
                        markup=False,
                    )
|