import argparse
import logging
import sys
import os
import asyncio
import getpass
import time
from .libs import AsyncCounter
import time
from sourceplus_sdk import __version__
__author__ = "Nick Padgett"
__copyright__ = "Spawning Inc"
__license__ = "MIT"
_logger = logging.getLogger(__name__)
DOWNLOAD_KEY_ENV_VAR = "SOURCEPLUS_DOWNLOAD_KEY"
import httpx
import pandas as pd
import aiofiles
[docs]
def download_images(
    file_path: str,
    output_folder: str,
    num_download_jobs: int = 50,
    limit_images: int = -1,
    url_column_name: str = "url",
    show_progress: bool = False,
):
    """Download images from the given file path to the destination folder.

    The download key is read from the ``SOURCEPLUS_DOWNLOAD_KEY`` environment
    variable; when it is unset or empty the user is prompted for it. If the
    destination folder does not exist the user is asked whether to create it.

    Args:
        file_path (str): The path to the file containing the URLs of the images to download.
        output_folder (str): The path to the folder where the images will be saved.
        num_download_jobs (int, optional): The maximum number of parallel downloads. Defaults to 50.
        limit_images (int, optional): The maximum number of images to download. Defaults to -1, which means all images.
        url_column_name (str, optional): The name of the column containing the URLs in the file. Defaults to "url".
        show_progress (bool, optional): Whether to show a progress bar. Defaults to False.

    Raises:
        ValueError: If the job count is out of range, the user declines to
            create the destination folder, the file format is unsupported,
            or the URL column is missing.
        FileNotFoundError: If the source file does not exist.
    """
    # validate the input arguments
    full_source_path, full_destination_path, destination_folder_exists = validate_args(
        file_path, output_folder, num_download_jobs
    )

    # check if API key is present in env, and if not, prompt for it
    download_key = os.getenv(DOWNLOAD_KEY_ENV_VAR)
    if not download_key:
        download_key = prompt_for_download_key()

    # prompt to create destination folder if it does not exist
    if not destination_folder_exists:
        prompt_destination_folder_creation(full_destination_path)

    df, num_images = validate_file(full_source_path, url_column_name, limit_images)

    # start_image_downloads is called without a column name here and reads
    # the URLs from a column literally named "url". Hand it a frame that
    # contains only the requested column under that name, so a custom
    # url_column_name is honoured (and cannot clash with an unrelated "url"
    # column that may also be present in the file).
    if url_column_name != "url":
        df = df[[url_column_name]].rename(columns={url_column_name: "url"})

    # download the images
    asyncio.run(
        start_image_downloads(
            df,
            num_download_jobs,
            num_images,
            full_destination_path,
            download_key,
            show_progress,
        )
    )
[docs]
def validate_args(file_path: str, output_folder: str, num_download_jobs: int):
    """
    Resolve the user-supplied paths and check the requested job count.

    Args:
        file_path (str): The path to the file containing the URLs of the images to download.
        output_folder (str): The path to the folder where the images will be saved.
        num_download_jobs (int): The maximum number of parallel downloads.

    Returns:
        tuple: ``(source_path, destination_path, destination_exists)`` where
        both paths have a leading ``~`` expanded and ``destination_exists``
        tells whether the destination path is already present on disk.

    Raises:
        ValueError: If ``num_download_jobs`` is smaller than 1 or larger than 100.
    """
    # "~" expansion happens first, exactly as callers may pass "~/..." paths
    source = os.path.expanduser(file_path)
    destination = os.path.expanduser(output_folder)

    # reject job counts outside the supported 1..100 window
    if num_download_jobs > 100 or num_download_jobs < 1:
        raise ValueError("max_parallel_downloads must be between 1 and 100.")

    # the caller decides what to do when the destination is missing
    return source, destination, os.path.exists(destination)
[docs]
def validate_file(file_path: str, url_column_name: str, limit_images: int = -1):
    """
    Load the URL file and work out how many images should be downloaded.

    Args:
        file_path (str): The path to the file containing the URLs of the images to download.
        url_column_name (str): The name of the column containing the URLs in the file.
        limit_images (int, optional): Requested maximum number of images. Any
            negative value means "all rows". Defaults to -1.

    Returns:
        tuple: ``(df, limit)`` where ``df`` is the ``pd.DataFrame`` read from
        the file and ``limit`` is the requested limit capped at the number of
        rows in ``df``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file extension is neither ``.parquet`` nor ``.csv``,
            or if ``url_column_name`` is not a column of the file.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File '{file_path}' not found.")

    # pick the reader from the (case-insensitive) file extension
    file_extension = os.path.splitext(os.path.basename(file_path).lower())[1]
    readers = {".parquet": pd.read_parquet, ".csv": pd.read_csv}
    reader = readers.get(file_extension)
    if reader is None:
        raise ValueError(f"Unsupported file format: {file_extension}")
    df = reader(file_path)

    # negative limit -> everything; otherwise never more than the file holds
    row_count = len(df)
    limit_images = row_count if limit_images < 0 else min(limit_images, row_count)

    if url_column_name not in df.columns:
        raise ValueError(f"Column '{url_column_name}' not found in the file.")

    return df, limit_images
[docs]
def prompt_for_download_key():
    """
    Ask the user to type in the Source+ download key.

    The key is read through ``getpass`` rather than ``input``.

    Returns:
        str: The download key as entered by the user.
    """
    return getpass.getpass("Enter your Source+ download key: ")
[docs]
def prompt_destination_folder_creation(destination_folder: str):
    """
    Ask the user whether a missing destination folder should be created.

    Args:
        destination_folder (str): The path to the destination folder.

    Raises:
        ValueError: If the user answers anything other than ``y``/``Y``.
    """
    answer = input(f"The destination folder '{destination_folder}' does not exist. Create it? (y/n) ")

    # anything but an explicit "y" counts as a refusal
    if answer.lower() != "y":
        raise ValueError("Destination folder does not exist, and user chose not to create it. Exiting.")

    os.makedirs(destination_folder)
[docs]
async def start_image_downloads(
    df: pd.DataFrame,
    max_parallel_downloads: int,
    max_images: int,
    destination_folder: str,
    api_key: str,
    show_progress: bool = False,
    url_column_name: str = "url",
):
    """
    Starts downloading images from the given DataFrame.

    One download task is created per URL; a semaphore limits how many of them
    run at the same time. All tasks share a single HTTP client, which is
    closed when the downloads finish or an error propagates.

    Args:
        df (pd.DataFrame): The DataFrame containing the URLs of the images to download.
        max_parallel_downloads (int): The maximum number of parallel downloads.
        max_images (int): The maximum number of images to download. A value of
            zero or less means every row of ``df``; larger values are capped at
            the number of rows.
        destination_folder (str): The path to the folder where the images will be saved.
        api_key (str): The Source+ download key.
        show_progress (bool, optional): Whether to show a progress bar. Defaults to False.
        url_column_name (str, optional): The column of ``df`` holding the URLs. Defaults to "url".
    """
    semaphore = asyncio.Semaphore(max_parallel_downloads)
    success_counter = AsyncCounter()
    failure_counter = AsyncCounter()

    # Work out how many downloads will really be scheduled. Capping at
    # len(df) avoids running past the end of the data, and the same number
    # is given to the progress monitor so its total matches the task count.
    available = len(df)
    total_images = min(max_images, available) if max_images > 0 else available
    urls = df[url_column_name].iloc[:total_images]

    # the context manager closes the client even if task creation fails
    async with httpx.AsyncClient(timeout=10) as http_client:
        tasks = [
            download_image(
                http_client,
                url,
                semaphore,
                success_counter,
                failure_counter,
                destination_folder,
                api_key,
            )
            for url in urls
        ]

        # the progress monitor runs alongside the downloads
        if show_progress:
            tasks.insert(0, progress_monitor(success_counter, failure_counter, total_images))

        await asyncio.gather(*tasks)
[docs]
async def progress_monitor(success_counter: AsyncCounter, failure_counter: AsyncCounter, total_images: int):
    """
    Monitors the progress of the downloads.

    Polls both counters every 0.1 seconds and redraws a one-line progress bar
    in place (using an ANSI "cursor up" escape) until the number of finished
    downloads reaches ``total_images``, then prints a summary line.

    Args:
        success_counter (AsyncCounter): Counter of successful downloads; its awaited ``get()`` is read.
        failure_counter (AsyncCounter): Counter of failed downloads; its awaited ``get()`` is read.
        total_images (int): The total number of images to download. Zero or a
            negative value is treated as "nothing to wait for".
    """
    bar_length = 50

    # print empty lines for us to update
    print("\n")

    start_time = time.time()
    while True:
        successes = await success_counter.get()
        failures = await failure_counter.get()
        num_completed = successes + failures

        # ">=" rather than "==" so an overshoot (or a non-positive total)
        # still ends the loop instead of spinning forever
        finished = num_completed >= total_images

        # guard the division for an empty job and keep the bar within bounds
        if total_images <= 0:
            progress = 1.0
        else:
            progress = min(num_completed / total_images, 1.0)

        sys.stdout.write("\033[F")  # Move up cursor

        # print a progress bar
        bar = "#" * int(bar_length * progress)
        elapsed_time = time.time() - start_time
        print(f"[{bar.ljust(bar_length)}] {progress * 100:.2f}% - Elapsed Time: {elapsed_time:.2f}s")

        if finished:
            break
        await asyncio.sleep(.1)

    print(f"Total images: {total_images}. Successes: {successes}. Failures: {failures}")
[docs]
async def download_image(client: httpx.AsyncClient, url: str, semaphore: asyncio.Semaphore, success_counter: AsyncCounter, failure_counter: AsyncCounter, destination_folder: str, api_key: str):
    """
    Downloads an image from the given URL, using the semaphore to control the number of parallel downloads.

    The file is stored under the last ``/``-separated segment of the URL. A
    file that already exists at that path is not downloaded again and is
    counted as a success. Every outcome increments exactly one of the two
    counters; failures never raise out of this coroutine and are logged at
    INFO level.

    Args:
        client (httpx.AsyncClient): The HTTP client to use for downloading the image.
        url (str): The URL of the image to download.
        semaphore (asyncio.Semaphore): The semaphore to control the number of parallel downloads.
        success_counter (AsyncCounter): The counter to keep track of the number of successful downloads.
        failure_counter (AsyncCounter): The counter to keep track of the number of failed downloads.
        destination_folder (str): The path to the folder where the image will be saved.
        api_key (str): The Source+ download key.
    """
    async with semaphore:
        # get file path and name; a non-string URL (e.g. NaN from an empty
        # cell) yields no usable name instead of crashing the whole batch
        file_name = url.split("/")[-1] if isinstance(url, str) else ""

        # "", "." and ".." would resolve to a directory that already exists
        # and be mistaken for a previously downloaded image
        if file_name in ("", ".", ".."):
            # INFO (not WARNING/ERROR) so the default log level keeps the
            # console quiet while the progress bar is being redrawn
            _logger.info("Cannot derive a file name from URL: %s", url)
            await failure_counter.increment()
            return

        file_path = os.path.join(destination_folder, file_name)

        # if file exists, skip it, so we don't download it twice
        if os.path.exists(file_path):
            await success_counter.increment()
            return

        try:
            async with client.stream("GET", url, headers={"Authorization": f"API {api_key}"}, follow_redirects=True) as response:
                if response.status_code != 200:
                    _logger.info("Failed to download image from URL: %s - %s", url, response.status_code)
                    await failure_counter.increment()
                    return
                image_data = await response.aread()
            async with aiofiles.open(file_path, "wb") as f:
                await f.write(image_data)
        except Exception as e:
            # best effort: one bad URL must not abort the other downloads
            _logger.info("Failed to download image from URL: %s - %s", url, e)
            await failure_counter.increment()
        else:
            await success_counter.increment()
# ---- CLI ----
# The functions defined in this section are wrappers around the main Python
# API allowing them to be called directly from the terminal as a CLI
# executable/script.
[docs]
def _parse_bool(value):
    """Convert a command-line token such as ``true``/``no``/``1`` to a bool."""
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(args):
    """Parse command line parameters

    Args:
        args (List[str]): command line parameters as list of strings
            (for example ``["--help"]``).

    Returns:
        :obj:`argparse.Namespace`: command line parameters namespace
    """
    parser = argparse.ArgumentParser(description="Download images.")
    parser.add_argument("command", help="Command to execute. Valid commands: download_images", type=str)
    parser.add_argument("-f", "--file", dest="file_path", help="Path to the file which contains the image URLs.", type=str)
    parser.add_argument("-o", "--output", dest="output_folder", help="Path to the destination folder to download the images.", type=str)
    parser.add_argument("-j", "--jobs", dest="num_download_jobs", default=50, help="The maximum number of download jobs running in parallel. Default is 50. A value too high will cause resource contention and slow down the overall download rate.", type=int)
    parser.add_argument("-l", "--limit", dest="limit_images", default=-1, help="The maximum number of images you want to download from the file.", type=int)
    parser.add_argument("-n", "--name", dest="url_column_name", default="url", help="The column name for the url field.", type=str)
    # type=bool would call bool() on the raw string, and every non-empty
    # string (including "False") is truthy; parse the text explicitly.
    parser.add_argument("-p", "--progress", dest="show_progress", default=True, help="Show progress bar.", type=_parse_bool)
    parser.add_argument(
        "--version",
        action="version",
        version=f"sourceplus-sdk {__version__}",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        dest="loglevel",
        help="set loglevel to INFO",
        action="store_const",
        const=logging.INFO,
    )
    parser.add_argument(
        "-vv",
        "--very-verbose",
        dest="loglevel",
        help="set loglevel to DEBUG",
        action="store_const",
        const=logging.DEBUG,
    )
    return parser.parse_args(args)
[docs]
def setup_logging(loglevel):
    """Configure the root logger to write timestamped records to stdout

    Args:
        loglevel (int): minimum loglevel for emitting messages
    """
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stdout,
        level=loglevel,
    )
[docs]
def main(args):
    """Command-line front end for :func:`download_images`

    Parses the arguments, configures logging and dispatches on the requested
    command.

    Args:
        args (List[str]): command line parameters as list of strings

    Raises:
        ValueError: if the command is not ``download_images``
    """
    options = parse_args(args)
    setup_logging(options.loglevel)

    # only one command is supported; anything else is rejected up front
    if options.command != "download_images":
        raise ValueError(f"Invalid command: {options.command}")

    download_images(
        options.file_path,
        options.output_folder,
        options.num_download_jobs,
        options.limit_images,
        options.url_column_name,
        options.show_progress,
    )
[docs]
def run():
    """Invoke :func:`main` with the arguments found in :obj:`sys.argv`

    The program name (``sys.argv[0]``) is dropped. This function can be used
    as entry point to create console scripts with setuptools.
    """
    cli_arguments = sys.argv[1:]
    main(cli_arguments)


if __name__ == "__main__":
    run()