import os
import re
import socket
from enum import EnumMeta
from random import choices
from typing import Any, Dict, List, Set, Union
from tqdm.auto import tqdm
# Public API of this module: the names exported by ``from <module> import *``.
# NOTE(review): "PrintableEnumMeta" and "extract_recursive" are listed here but
# their definitions are not visible in this part of the file — confirm they exist.
__all__ = [
"PrintableEnumMeta",
"ProgressStatusCallback",
"extract_recursive",
"in_jupyter_notebook",
"is_regex",
"sample_one_index",
"search_in_values_to_retrieve_key",
"validate_path_exists",
"is_internet_available",
]
[docs]
def search_in_values_to_retrieve_key(
    code_name: str, class_to_codes: Dict[Any, Set[str]]
):
    """Find the key whose collection of codes contains ``code_name``.

    ``code_name`` is lowercased before the membership test; the codes stored in
    ``class_to_codes`` are compared exactly as given, so an entry containing
    uppercase characters can never match.

    Args:
        code_name (str): The code to look up (case-insensitive on this side).
        class_to_codes (Dict[Any, Set[str]]): Mapping from a key to the codes
            that identify it. A code must not be listed more than once.

    Returns:
        The key whose codes contain the lowercased ``code_name``, or None if
        no key lists it.

    Raises:
        AssertionError: If a code is reused across the mapping.
    """
    # Verify there is no repetition/reuse of codes. This is raised explicitly
    # (instead of using an ``assert`` statement) so that the check is not
    # stripped when Python runs with ``-O``; the exception type is unchanged.
    all_codes = [code for codes in class_to_codes.values() for code in codes]
    if len(all_codes) != len(set(all_codes)):
        raise AssertionError("code reused for multiple keys")

    # Lowercase once; the value does not change between iterations.
    lowered_code = code_name.lower()
    for key, codes in class_to_codes.items():
        if lowered_code in codes:
            return key
    return None
[docs]
def sample_one_index(weights: List[float], temperature: float = 1.0) -> int:
    """Select an index based on the given relative weights.

    Each weight is raised to the power ``1 / temperature`` and one index is then
    drawn from the resulting weighted random distribution.

    Args:
        weights (List[float]): the relative weights corresponding to each index.
            Must be non-empty.
        temperature (float): The temperature value for controlling the sampling behavior.
            High temperature means sampling probabilities are more uniform (says random things).
            Low temperature means that sampling probabilities are higher for bigger weights.
            A temperature of exactly 0 selects greedily: the index of the largest
            weight is returned (ties are broken uniformly at random).
            Defaults to 1.0.

    Returns:
        int: The index of the chosen item.
    """
    if temperature == 0:
        # ``1 / temperature`` would raise ZeroDivisionError. As the temperature
        # approaches 0 the distribution collapses onto the largest weight(s),
        # so sample uniformly among those instead.
        largest = max(weights)
        best_indices = [i for i, w in enumerate(weights) if w == largest]
        return choices(best_indices, k=1)[0]

    exponent = 1 / temperature
    return choices(
        range(len(weights)),
        weights=[w**exponent for w in weights],
        k=1,
    )[0]
[docs]
def in_jupyter_notebook() -> bool:
    """
    Checks if the code is running in a Jupyter notebook.

    The check looks for an ``"IPKernelApp"`` entry in the configuration of the
    active IPython shell. Any failure along the way (IPython not installed, no
    active shell, ...) is treated as "not in a notebook".

    Returns:
        bool: True if running in a Jupyter notebook, False otherwise.
    """
    try:
        # Imported lazily so that IPython stays an optional dependency.
        from IPython import get_ipython  # type: ignore

        return "IPKernelApp" in get_ipython().config  # type: ignore
    # Best-effort probe: swallow ordinary errors, but unlike a bare ``except``
    # let KeyboardInterrupt / SystemExit propagate.
    except Exception:  # pylint: disable=broad-except
        return False
[docs]
class ProgressStatusCallback:
    """
    Callable that forwards status dictionaries to a tqdm progress bar.

    Calling an instance shows the given status as the postfix of the wrapped
    progress bar and refreshes the bar.

    Args:
        tqdm_bar (tqdm): The progress bar that should display the status.

    Attributes:
        tqdm_bar (tqdm): The progress bar this callback writes to.

    Example:

        .. code-block:: python

            # Create the progress bar and wrap it in a callback
            progress_bar = tqdm(total=100, desc='Processing')
            callback = ProgressStatusCallback(tqdm_bar=progress_bar)

            # Somewhere else: report the current status
            status_info = {'Epoch': 1, 'Loss': 0.123, 'Accuracy': 0.95}
            callback(status_info)
    """

    def __init__(self, tqdm_bar: tqdm):
        """
        Store the progress bar that later calls will update.

        Args:
            tqdm_bar (tqdm): The progress bar this callback writes to.
        """
        self.tqdm_bar = tqdm_bar

    def __call__(self, status: Dict[str, Any]):
        """
        Show ``status`` as the postfix of the wrapped progress bar.

        Args:
            status (Dict[str, Any]): Custom status information to display.
                The bar is refreshed right after the postfix is set.
        """
        progress_bar = self.tqdm_bar
        progress_bar.set_postfix(status, refresh=True)
[docs]
def is_regex(string: Union[str, re.Pattern]) -> bool:
    """Decide whether the argument should be treated as a regular expression.

    A compiled pattern always counts as a regex. A string counts as a regex only
    if it contains at least one of the characters ``+*?|[]{}^$`` and compiles
    successfully with :func:`re.compile`.

    Args:
        string (str | Pattern): The value to be tested.

    Returns:
        bool: True for a regex, False for a regular string.
    """
    if isinstance(string, re.Pattern):
        return True

    # Without any of these characters the value is considered a plain string.
    regex_markers = set("+*?|[]{}^$")
    if regex_markers.isdisjoint(string):
        return False

    # A marker character is present: it is a regex only if it actually compiles.
    try:
        re.compile(string)
    except re.error:
        return False
    return True
[docs]
def validate_path_exists(path: str, overwrite: bool = False) -> str:
    """
    Check that a file path may be written to and prepare its parent directory.

    If something already exists at ``path`` and ``overwrite`` is ``False``, a
    ``FileExistsError`` is raised. Otherwise the path is made absolute, all
    missing parent directories are created, and the absolute path is returned.
    The file itself is neither created nor modified.

    Args:
        path (str): The file path to be validated.
        overwrite (bool, optional): Whether an existing file may be overwritten. Defaults to False.

    Raises:
        FileExistsError: If ``path`` already exists and ``overwrite`` is ``False``.

    Returns:
        str: The absolute version of ``path``.

    Examples:
        >>> validate_path_exists('/path/to/file.txt', overwrite=False)
        '/absolute/path/to/file.txt'
    """
    # The existence check is only needed when overwriting is not allowed.
    if not overwrite:
        if os.path.exists(path):
            raise FileExistsError(f"File already exists: '{path}' (Use overwrite=True)")

    absolute_path = os.path.abspath(path)
    parent_dir, _ = os.path.split(absolute_path)
    os.makedirs(parent_dir, exist_ok=True)
    return absolute_path
[docs]
def is_internet_available(
    host: str = "8.8.8.8", port: int = 53, timeout: float = 5
) -> bool:
    """Hit a well-known server (Google DNS by default) to check for internet availability.

    A TCP connection to ``(host, port)`` is attempted and closed again
    immediately; only whether the connection could be established matters.

    Args:
        host (str): Address of the server to probe. Defaults to "8.8.8.8" (Google DNS).
        port (int): TCP port to connect to. Defaults to 53.
        timeout (float): Seconds to wait for the connection. Defaults to 5.

    Returns:
        bool: True if the connection succeeded, False otherwise.
    """
    try:
        # The ``with`` block closes the probe socket instead of leaking it.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False