import logging
import random
import sys
import time
import typing as ty
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
import numpy as np
import ray
import torch
[docs]class Dummy:
def __init__(self, *args, **kwargs):
pass
def __call__(self, *args, **kwargs):
return self
def __getattr__(self, *args, **kwargs):
return self
def __getitem__(self, *args, **kwargs):
return self
@ray.remote
class _Lock:
def __init__(self) -> None:
self.lock: bool = False
def acquire(self):
if not self.lock:
self.lock = True
return True
return False
def release(self):
if not self.lock:
raise ValueError("lock released too many times")
self.lock = False
[docs]class Lock:
def __init__(self, timeout: float | None = None) -> None:
# pylint: disable=no-member
self.lock = _Lock.remote() # type: ignore[attr-defined]
self.timeout = timeout
def __enter__(self):
self.acquire()
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
[docs] def acquire(self):
current_time = time.time()
n = 0
while True:
if ray.get(self.lock.acquire.remote()):
break
if self.timeout is not None and time.time() - current_time > self.timeout:
raise TimeoutError(
f"Could not obtain lock within {self.timeout:.2f} seconds"
)
n += 1
if n % 10000 == 0:
logging.warning(
"Lock(%s) takes an excessive time and it could be caused by a deadlock.",
id(self),
)
time.sleep(0.1)
[docs] def release(self):
ray.get(self.lock.release.remote())
[docs]def iter_to_numpy(iterable: Iterable) -> ty.Any:
"""
Convert elements of the input iterable to NumPy arrays if they are torch.Tensor objects.
Parameters
----------
iterable : Iterable
The input iterable.
Returns
-------
ty.Any
The iterable with torch.Tensor elements replaced with their NumPy array equivalents.
"""
return apply_lambda_to_iter(
iterable,
lambda v: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v,
)
[docs]def iter_to_device(
data_dict: Iterable, device: str
) -> ty.Union[Sequence[torch.Tensor], dict[str, torch.Tensor]]:
"""
Moving torch.Tensor elements to the specified device.
Parameters
----------
data_dict : Iterable
The input dictionary or list containing torch.Tensor elements.
device : str
The target device for the tensors.
Returns
-------
ty.Union[Sequence[torch.Tensor], dict[str, torch.Tensor]]
The input data with tensors moved to the target device.
"""
return apply_lambda_to_iter(
data_dict, lambda v: v.to(device) if isinstance(v, torch.Tensor) else v
)
[docs]def apply_lambda_to_iter(iterable: Iterable, fn: Callable) -> ty.Any:
"""
Applies a given function ``fn`` to each element of an iterable data structure.
This function recursively applies ``fn`` to elements within nested dictionaries or lists.
It can be used for converting torch.Tensor elements to NumPy arrays or moving tensors
to a specified device.
Parameters
----------
iterable : Iterable
The input iterable.
fn : Callable
The function to apply to each element.
Returns
-------
ty.Any
The resulting data structure after applying ``fn`` to each element of the input ``iterable``.
The type of the returned object matches the type of the input ``iterable``.
"""
if isinstance(iterable, dict):
return {
k: apply_lambda_to_iter(v, fn) if isinstance(v, (Iterable)) else fn(v)
for k, v in iterable.items()
}
if isinstance(iterable, list):
return [apply_lambda_to_iter(v, fn) for v in iterable]
return fn(iterable)
[docs]def set_seed(seed: int) -> int:
"""
Set the random seed.
Parameters
----------
seed : int
The random seed to set.
Returns
-------
int
The set random seed.
"""
assert seed is not None, "Must provide a seed"
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return seed
[docs]def get_lr(optimizer: torch.optim.Optimizer | dict) -> float:
"""
Get the learning rate from an optimizer.
Parameters
----------
optimizer : torch.optim.Optimizer | dict
The optimizer.
Returns
-------
float
The learning rate.
"""
if isinstance(optimizer, dict):
param_groups = optimizer["param_groups"]
else:
param_groups = optimizer.param_groups
return param_groups[0]["lr"]
[docs]def debugger_is_active() -> bool:
"""
Check if the debugger is currently active.
Returns
-------
bool
True if the debugger is active, False otherwise.
Notes
-----
Return if the debugger is currently active
"""
gettrace = getattr(sys, "gettrace", lambda: None)
return gettrace() is not None
[docs]def get_latest_chkpts(checkpoint_dir: Path) -> list[Path]:
"""
Get a list of all checkpoint files in a directory, sorted from the latest to the earliest.
Parameters
----------
checkpoint_dir : Path
The directory containing checkpoint files.
Returns
-------
list[Path]
A list of the checkpoint files sorted by filename.
"""
return sorted(list(checkpoint_dir.glob("*.pt")))[::-1]
[docs]def parse_device(device: ty.Union[str, list[str], int]) -> ty.Any:
"""
Parse a device string, an integer, or a list of device strings or integers.
Parameters
----------
device : ty.Union[str, list[str], int]
The target device for the tensors.
Returns
-------
ty.Any
The parsed device string, integer, or list of device strings or integers.
Raises
------
ValueError
If the device string is not one of {'cpu', 'cuda'} or doesn't start with 'cuda:'.
AssertionError
If cuda is not found on system or gpu number of device is not available.
Examples
--------
>>> parse_device("cpu")
'cpu'
>>> parse_device("cuda")
'cuda'
>>> parse_device("cuda:0")
'cuda:0'
>>> parse_device(["cpu", "cuda"])
['cpu', 'cuda']
>>> parse_device(["cpu", "cuda:0", "cuda:1", "cuda:2"])
['cpu', 'cuda:0', 'cuda:1', 'cuda:2']
"""
if isinstance(device, str):
if device == "cpu":
return device
if device == "cuda" or (device.startswith("cuda:") and device[5:].isdigit()):
assert (
torch.cuda.is_available()
), "Could not find a torch.cuda installation on your system."
if device.startswith("cuda:"):
gpu_number = int(device[5:])
assert (
gpu_number < torch.cuda.device_count()
), f"gpu {device} does not exist on this machine"
return device
raise ValueError
if isinstance(device, int):
return device
if isinstance(device, Iterable):
return [parse_device(_device) for _device in device]
return "cuda" if torch.cuda.is_available() else "cpu"
def _num_e_format(value: int | np.integer | float | np.floating, width: int) -> str:
# minimum width of 8 "x.xxe+nn" 6 fixed 1.xxe+00 and 2 available
diff = 6
if value < 0:
diff += 1
return f"{value:.{width-diff}e}"
def _num_format_int(value: int | np.integer, width: int) -> str:
str_value = str(value)
if len(str_value) > width:
return _num_e_format(value, width)
return f"{value:0{width}d}"
def _num_format_float(value: float | np.floating, width: int) -> str:
str_value = str(value)
if "e" in str_value:
return _num_e_format(value, width)
int_part, *float_part = str_value.split(".")
int_len = len(int_part)
if len(float_part):
_float_part = float_part[0]
float_len = len(_float_part)
leading_zeros = float_len - len(_float_part.lstrip("0"))
else:
_float_part = None
float_len = 0
leading_zeros = 0
if int_len > 0 and int_part != "0":
# xxx.xxxx format
if float_len + int_len >= width:
return _num_e_format(value, width)
return f"{value:.{width-int_len - 1}f}"
if int_part == "0" and leading_zeros == float_len:
return f"{value:.{width-2}f}"
if len(str_value) < width:
return f"{value:.{width-2}f}"
return _num_e_format(value, width)
[docs]def is_oom_exception(err: RuntimeError) -> bool:
"""
is_oom_exception checks whether the exception is caused by CUDA out of memory errors.
Parameters
----------
err : RuntimeError
the exception to parse
Returns
-------
bool
whether the exception indicates out of memory error.
"""
return any(
x in str(err)
for x in (
"CUDA out of memory",
"CUBLAS_STATUS_ALLOC_FAILED",
"CUDA error: out of memory",
)
)