import curses
import html
import os
import time
import typing as ty
from collections import defaultdict
from pathlib import Path
import ray
from tabulate import tabulate
from tqdm import tqdm
from ablator.utils.base import num_format
# ipywidgets / IPython are optional dependencies that are only needed when the
# progress output is rendered inside a Jupyter notebook. Both names are always
# bound after this block, so later code can compare them against ``None``
# instead of hitting a NameError when the packages are missing.
try:
    import ipywidgets as widgets
    from IPython.display import display
except ImportError:
    widgets = None
    display = None
def in_notebook() -> bool:
    """Return whether the code is running under an IPython kernel.

    The check looks for the ``"IPKernelApp"`` key in the configuration of the
    active IPython shell.

    Returns
    -------
    bool
        ``True`` when an IPython shell is active and its config contains
        ``"IPKernelApp"``; ``False`` when IPython is not installed, no shell
        is active, or the shell exposes no ``config`` attribute.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from IPython import get_ipython
    except ImportError:
        return False

    shell = get_ipython()
    if shell is None:
        # IPython is installed but we are running in a plain interpreter.
        return False
    try:
        return "IPKernelApp" in shell.config  # pragma: no cover
    except AttributeError:
        return False
[docs]def get_last_line(filename: Path):
if filename is None or not filename.exists():
return None
with open(filename, "rb") as f:
try: # catch OSError in case of a one line file
f.seek(-2, os.SEEK_END)
while f.read(1) != b"\n":
f.seek(-2, os.SEEK_CUR)
except OSError:
f.seek(0)
last_line = f.readline().decode()
return last_line
# Delimiter inserted between adjacent fields when they are joined onto a single
# display line (between metric columns, and between a progress bar and its
# metrics text).
SEPERATOR = " | "
class Display:
    """Render lines of text either on a curses screen or in a notebook widget.

    When ``in_notebook()`` is false the instance initialises a curses screen
    and writes each text on its own row. Otherwise the texts are HTML-escaped,
    joined with ``<br>`` and pushed into an ``ipywidgets.HTML`` widget.

    The instance can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(self) -> None:
        self.is_terminal = not in_notebook()
        if self.is_terminal:
            # get existing stdout and redirect future stdout
            self._curses = curses
            self.stdscr = curses.initscr()
            self.stdscr.clear()
            self.nrows, self.ncols = self.stdscr.getmaxyx()
        else:
            assert (
                widgets is not None
            ), "Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html"
            # The widget has no fixed size, so use generous virtual dimensions.
            self.ncols = int(1e3)
            self.nrows = int(1e3)
            self.html_widget = widgets.HTML(value="")
            self.html_value = ""
            display(self.html_widget)

    def _refresh(self):
        """Flush the pending content to the output and reset the buffer."""
        if self.is_terminal:
            self.stdscr.refresh()
            # Clear the virtual screen so the next batch starts from blank.
            self.stdscr.clear()
        else:
            self.html_widget.value = self.html_value
            self.html_value = ""

    def _display(self, text, pos, is_last=False):
        """Queue ``text`` for output at row ``pos``.

        Parameters
        ----------
        text : str
            Line to show.
        pos : int
            Row index on the curses screen (unused in notebook mode, where
            lines are appended in call order).
        is_last : bool
            When true (terminal mode only) the text is truncated to
            ``ncols - 1`` characters.
        """
        if self.ncols is None or self.nrows is None:
            return
        if self.is_terminal:
            _text = text[: self.ncols - 1] if is_last else text
            try:
                self.stdscr.addstr(pos, 0, _text)
            except curses.error:
                # Raised by curses when the write does not fit the window;
                # dropping the line is preferable to aborting the caller.
                pass
        else:
            self.html_value += html.escape(text) + "<br>"

    def close(self):
        """Restore the terminal state and leave curses mode.

        Safe to call repeatedly and on an instance whose ``__init__`` did not
        complete. Each teardown call is attempted independently so that one
        failing call (``curs_set`` raises on terminals without cursor
        visibility support) does not prevent the remaining ones.
        """
        if not getattr(self, "is_terminal", False):
            return
        # Flip the flag first so a second call never repeats the teardown.
        self.is_terminal = False
        stdscr = getattr(self, "stdscr", None)
        if stdscr is None:
            # initscr() never succeeded; there is nothing to restore.
            return
        teardown_steps = (
            curses.nocbreak,
            lambda: stdscr.keypad(0),
            curses.echo,
            curses.endwin,
            lambda: curses.curs_set(1),  # Turn cursor back on
        )
        for step in teardown_steps:
            try:
                step()
            except curses.error:
                # Best-effort cleanup: keep going with the remaining steps.
                pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _update_screen_dims(self):
        """Re-read the curses window size (no-op in notebook mode)."""
        if self.is_terminal:
            self.nrows, self.ncols = self.stdscr.getmaxyx()

    def print_texts(self, texts):
        """Show ``texts``, one entry per row, replacing the previous output.

        Parameters
        ----------
        texts : Iterable[str]
            Lines to display, in order.
        """
        self._update_screen_dims()
        for row, text in enumerate(texts):
            self._display(text, row)
        self._refresh()
@ray.remote
class RemoteProgressBar:
    """Ray actor that collects per-``uid`` status texts and an overall bar.

    Each reporter pushes its texts with :meth:`update_status` and marks itself
    done with :meth:`close`. The overall progress counts the uids that have
    been closed.

    Parameters
    ----------
    total_trials : int | None
        Expected number of trials. ``None`` is stored as ``float("inf")``.
    """

    def __init__(self, total_trials: int | None):
        super().__init__()
        self.start_time: float = time.time()
        self.total_trials = total_trials if total_trials is not None else float("inf")
        self.closed: dict[str, bool] = defaultdict(bool)
        self.texts: dict[str, list[str]] = defaultdict(list)
        self.finished_trials: int = 0

    def __iter__(self):
        # ``range`` cannot take float("inf"), which is the stored value when
        # the number of trials is unknown, so count manually instead.
        index = 0
        while index < self.total_trials:
            yield index
            index += 1

    def close(self, uid):
        """Mark ``uid`` as finished; its texts are no longer printed."""
        self.closed[uid] = True

    def make_bar(self):
        """Return the formatted overall progress bar string."""
        return ProgressBar.make_bar(
            current_iteration=self.current_iteration,
            start_time=self.start_time,
            total_steps=self.total_trials,
            epoch_len=self.total_trials,
            ncols=100,
        )

    @property
    def current_iteration(self):
        """Number of uids that have been closed."""
        return sum(self.closed.values())

    def make_print_texts(self) -> list[str]:
        """Build the lines to display: the overall bar, then one entry per open uid.

        For each open uid the second text is placed first, followed by the
        separator and the first text. A third text, when present, goes on its
        own line, indented to line up after the ``"<prefix>: "`` part of the
        second text.
        """

        def _concat_texts(texts) -> list[str]:
            if len(texts) < 2:
                # Not enough parts to combine; show whatever was reported.
                return list(texts)
            _texts = [f"{texts[1]}{SEPERATOR}{texts[0]}"]
            if len(texts) > 2:
                padding = " " * (len(texts[1].split(":")[0]) + 2)
                _texts.append(f"{padding}{texts[2]}")
            return _texts

        lines: list[str] = [self.make_bar()]
        for uid in sorted(self.texts):
            # .get avoids inserting a default entry for uids never closed.
            if not self.closed.get(uid, False):
                lines += _concat_texts(self.texts[uid])
        return lines

    def update(self, finished_trials):
        """Store the externally reported number of finished trials."""
        self.finished_trials = finished_trials

    def update_status(self, uid: str, texts: list[str]):
        """Replace the status texts shown for ``uid``."""
        self.texts[uid] = texts
class RemoteDisplay(Display):
    """Display that pulls its content from a ``RemoteProgressBar`` actor.

    Parameters
    ----------
    remote_progress_bar : RemoteProgressBar
        Actor handle whose ``make_print_texts`` provides the lines to show.
    update_interval : int
        Minimum number of seconds between two non-forced refreshes.
    """

    def __init__(
        self, remote_progress_bar: RemoteProgressBar, update_interval: int = 1
    ) -> None:
        super().__init__()
        self._prev_update_time = time.time()
        self.update_interval = update_interval
        self.remote_progress_bar = remote_progress_bar

    def refresh(self, force=False):
        """Fetch the latest texts from the actor and print them.

        Parameters
        ----------
        force : bool
            When true, refresh even if ``update_interval`` seconds have not
            yet elapsed since the previous refresh.
        """
        elapsed = time.time() - self._prev_update_time
        if not force and elapsed <= self.update_interval:
            return
        self._prev_update_time = time.time()
        # ray.get blocks until the actor has produced the lines.
        texts = ray.get(self.remote_progress_bar.make_print_texts.remote())
        self.print_texts(texts)
[docs]class ProgressBar:
def __init__(
self,
total_steps,
epoch_len: int | None = None,
logfile: Path | None = None,
update_interval: int = 1,
remote_display: ty.Optional[RemoteProgressBar] = None,
uid: str | None = None,
):
if epoch_len is None:
self.epoch_len = total_steps
else:
self.epoch_len = epoch_len
self.total_steps = total_steps
self.update_interval = update_interval
self.start_time = time.time()
self._prev_update_time = time.time()
self.current_iteration = 0
self.metrics: dict[str, ty.Any] = {}
self.logfile = logfile
self.display: Display | None = None
self.remote_display: RemoteProgressBar | None = None
if remote_display is None:
self.display = Display()
else:
self.remote_display = remote_display
self.uid = uid
self._update()
def __iter__(self):
for obj in range(self.epoch_len):
yield obj
[docs] def reset(self) -> None:
self.current_iteration = 0
[docs] def close(self):
if self.display is not None:
self.display.close()
else:
self.remote_display.close.remote(self.uid)
def __exit__(self, exc_type, exc_value, traceback):
self.close()
[docs] @classmethod
def make_bar(
cls,
current_iteration: int,
start_time: float,
epoch_len: int,
total_steps: int,
ncols: int | None = None,
):
if current_iteration > 0:
rate = current_iteration / (time.time() - start_time)
time_remaining = (total_steps - current_iteration) / rate
ftime = tqdm.format_interval(time_remaining)
else:
ftime = "??"
post_fix = f"Remaining: {ftime}"
return tqdm.format_meter(
current_iteration,
epoch_len,
elapsed=time.time() - start_time,
bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}",
postfix=post_fix,
ncols=ncols,
)
[docs] @classmethod
def make_metrics_message(
cls,
metrics: dict[str, ty.Any],
nrows: int | None = None,
ncols: int | None = None,
):
rows = tabulate(
[[k + ":", f"{num_format(v)}"] for k, v in metrics.items()],
disable_numparse=True,
tablefmt="plain",
stralign="right",
).split("\n")
text = ""
texts = []
for row in rows:
row += SEPERATOR
if ncols is not None and len(text) + len(row) > ncols:
text = text[: -len(SEPERATOR)]
texts.append(text)
text = ""
text += row
if nrows is not None and len(texts) > nrows:
break
text = text[: -len(SEPERATOR)]
texts.append(text)
return texts
@property
def ncols(self):
if self.display is not None:
return self.display.ncols
return None
@property
def nrows(self):
if self.display is not None:
return self.display.nrows - 5 # padding
return None
[docs] def make_print_message(self):
texts = self.make_metrics_message(self.metrics, self.nrows, self.ncols)
pbar = self.make_bar(
current_iteration=self.current_iteration,
start_time=self.start_time,
total_steps=self.total_steps,
epoch_len=self.epoch_len,
)
if self.uid is not None:
texts.append(f"{self.uid}: {pbar}")
else:
texts.append(pbar)
if (last_line := get_last_line(self.logfile)) is not None:
texts.append(last_line)
return texts
def _update(self):
texts = self.make_print_message()
if self.display is not None:
self.display.print_texts(texts)
else:
ray.get(self.remote_display.update_status.remote(self.uid, texts))
if self.current_iteration + 1 == self.epoch_len:
self.close()
[docs] def update_metrics(self, metrics: dict[str, ty.Any], current_iteration: int):
self.metrics = metrics
self.current_iteration = current_iteration
if (
current_iteration == 0
or time.time() - self._prev_update_time > self.update_interval
or current_iteration + 1 == self.epoch_len
):
self._prev_update_time = time.time()
self._update()