Source code for ablator.utils.progress_bar

import curses
import html
import os
import time
import typing as ty
from collections import defaultdict
from pathlib import Path

import ray
from tabulate import tabulate
from tqdm import tqdm

from ablator.utils.base import num_format

try:
    import ipywidgets as widgets
    from IPython.display import display
except ImportError:
    widgets = None


def in_notebook() -> bool:
    """
    Report whether the code is running under an IPython kernel (notebook).

    Returns
    -------
    bool
        ``True`` only when IPython can be imported and the configuration of
        the active shell contains an ``"IPKernelApp"`` entry. ``False`` when
        IPython is missing, when there is no active shell, or when the entry
        is absent.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from IPython import get_ipython

        # ``get_ipython()`` returns None outside of IPython, which makes the
        # attribute access below raise AttributeError.
        kernel_present = "IPKernelApp" in get_ipython().config  # pragma: no cover
    except (ImportError, AttributeError):
        return False
    return kernel_present
[docs]def get_last_line(filename: Path | None) -> str | None: """ This functions gets the last line from the file. Parameters ---------- filename : Path | None The path of the filename. Returns ------- str | None None if file doesn't exists or the last line of the file as a string. """ if filename is None or not filename.exists(): return None with open(filename, "rb") as f: try: # catch OSError in case of a one line file f.seek(-2, os.SEEK_END) while f.read(1) != b"\n": f.seek(-2, os.SEEK_CUR) except OSError: f.seek(0) last_line = f.readline().decode() return last_line
# Delimiter inserted between adjacent fields that share one display line
# (metric columns, and the progress bar followed by its metrics).
SEPERATOR = " | "
class Display:
    """
    Class for handling display for terminal and notebook.

    In a terminal the text is drawn on a curses screen; in a notebook it is
    accumulated as escaped HTML and pushed to an ``ipywidgets`` HTML widget.

    Attributes
    ----------
    is_terminal : bool
        ``True`` while output goes to an open curses screen. ``False`` in a
        notebook and after ``close`` has been called.
    _curses : module
        The ``curses`` module (terminal mode only).
    stdscr : curses window
        The window returned by ``curses.initscr()`` (terminal mode only).
    nrows : int
        Height of the ``stdscr`` window (fixed to 1000 in a notebook).
    ncols : int
        Width of the ``stdscr`` window (fixed to 1000 in a notebook).
    html_widget : widgets.HTML
        The HTML widget that is displayed (notebook mode only).
    html_value : str
        HTML accumulated since the last refresh (notebook mode only).
    """

    # Set once a curses screen has been torn down. Kept at class level so that
    # instances whose ``__init__`` did not finish still expose the flag.
    _closed: bool = False

    def __init__(self) -> None:
        self.is_terminal = not in_notebook()
        if self.is_terminal:
            # get existing stdout and redirect future stdout
            self._curses = curses
            self.stdscr = curses.initscr()
            self.stdscr.clear()
            self.nrows, self.ncols = self.stdscr.getmaxyx()
        else:
            assert (
                widgets is not None
            ), "Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html"
            self.ncols = int(1e3)
            self.nrows = int(1e3)
            self.html_widget = widgets.HTML(value="")
            self.html_value = ""
            display(self.html_widget)

    def _refresh(self):
        """Flush what was drawn since the last refresh and start a new frame."""
        if self.is_terminal:
            self.stdscr.refresh()
            self.stdscr.clear()
        else:
            self.html_widget.value = self.html_value
            self.html_value = ""

    def _display(self, text: str, pos: int, is_last: bool = False):
        """
        Draw one line of text.

        Parameters
        ----------
        text : str
            The text to draw.
        pos : int
            The row at which the text is drawn (terminal mode only).
        is_last : bool
            When ``True`` the text is cut to ``ncols - 1`` characters before
            it is drawn (terminal mode only).
        """
        if self.ncols is None or self.nrows is None:
            return
        if self.is_terminal:
            try:
                _text = text[: self.ncols - 1] if is_last else text
                self.stdscr.addstr(pos, 0, _text)
            except curses.error:
                # Best effort: text that does not fit on the screen is dropped
                # rather than aborting the caller.
                pass
        else:
            self.html_value += html.escape(text) + "<br>"

    def close(self):
        """
        Restore the terminal and stop drawing to it.

        Does nothing in notebook mode, when the screen was already closed, or
        when the curses screen was never created.
        """
        if not getattr(self, "is_terminal", False) or not hasattr(self, "stdscr"):
            return
        self._curses.nocbreak()
        self.stdscr.keypad(0)
        self._curses.echo()
        self._curses.endwin()
        self._curses.curs_set(1)  # Turn cursor back on
        self.is_terminal = False
        self._closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _update_screen_dims(self):
        """Re-read the window size so that terminal resizes are picked up."""
        if self.is_terminal:
            self.nrows, self.ncols = self.stdscr.getmaxyx()

    def print_texts(self, texts: list[str]):
        """
        Draw every text on its own row, in order, and refresh the display.

        Parameters
        ----------
        texts : list[str]
            The lines to draw; the index of a line is the row it is drawn at.
        """
        if self._closed:
            # The curses screen is gone and there is no notebook widget to
            # fall back to, so there is nothing to draw on.
            return
        self._update_screen_dims()
        for row, text in enumerate(texts):
            self._display(text, row)
        self._refresh()
@ray.remote(num_cpus=0.001)
class RemoteProgressBar:
    """
    RemoteProgressBar is a ProgressBar that is passed on a training function
    to report back to a centralized server the metrics.

    Parameters
    ----------
    total_trials : int | None
        The total number of trials, or ``None`` when it is not known.

    Attributes
    ----------
    start_time : float
        Stores the start time of initializing remote progress bar.
    total_trials : int | float
        Stores the total_trials; ``float("inf")`` when ``None`` was given.
    closed : dict[str, bool]
        Stores the key-value pairs of trial's ``uid`` and boolean indicated it
        is closed or not.
    texts : dict[str, list[str]]
        Stores the text associated with the ``uid`` of the trial.
    finished_trials : int
        Tracks the total finished trials.
    """

    def __init__(self, total_trials: int | None):
        super().__init__()
        self.start_time: float = time.time()
        self.total_trials = total_trials if total_trials is not None else float("inf")
        self.closed: dict[str, bool] = defaultdict(bool)
        self.texts: dict[str, list[str]] = defaultdict(list)
        self.finished_trials: int = 0

    def __iter__(self):
        # ``range`` rejects the ``float("inf")`` used for an unknown total, so
        # count by hand; for an integer total this yields what ``range`` does.
        index = 0
        while index < self.total_trials:
            yield index
            index += 1

    def close(self, uid: str):
        """Mark the trial ``uid`` as closed so that it is no longer shown."""
        self.closed[uid] = True

    def make_bar(self):
        """Build the overall bar, counting closed trials against the total."""
        return ProgressBar.make_bar(
            current_iteration=self.current_iteration,
            start_time=self.start_time,
            total_steps=self.total_trials,
            epoch_len=self.total_trials,
            ncols=100,
        )

    @property
    def current_iteration(self):
        """Number of trials that have been closed."""
        return sum(self.closed.values())

    def make_print_texts(self) -> list[str]:
        """
        Build the lines to display.

        Returns
        -------
        list[str]
            The overall bar followed by the lines of every trial that is not
            closed, ordered by ``uid``.
        """

        def _concat_texts(texts) -> list[str]:
            if len(texts) < 2:
                # Not enough lines to pair; show what was reported as it is.
                return list(texts)
            # The second reported line goes first, followed by the first one.
            _texts = [f"{texts[1]}{SEPERATOR}{texts[0]}"]
            if len(texts) > 2:
                # Indent the third line by the length of the text in front of
                # the first ":" of the second line, plus two.
                padding = " " * (len(texts[1].split(":")[0]) + 2)
                _texts.append(f"{padding}{texts[2]}")
            return _texts

        texts: list[str] = [self.make_bar()]
        for uid in sorted(self.texts):
            if not self.closed[uid]:
                texts += _concat_texts(self.texts[uid])
        return texts

    def update(self, finished_trials: int):
        """Store the number of finished trials."""
        self.finished_trials = finished_trials

    def update_status(self, uid: str, texts: list[str]):
        """Replace the lines shown for the trial ``uid``."""
        self.texts[uid] = texts
class RemoteDisplay(Display):
    """
    A display that draws the lines produced by a remote progress bar.

    Parameters
    ----------
    remote_progress_bar : RemoteProgressBar
        The remote progress bar that is asked for the lines to draw.
    update_interval : int
        The minimum number of seconds between two unforced refreshes.
    """

    def __init__(
        self, remote_progress_bar: RemoteProgressBar, update_interval: int = 1
    ) -> None:
        super().__init__()
        self._prev_update_time = time.time()
        self.update_interval = update_interval
        self.remote_progress_bar = remote_progress_bar

    def refresh(self, force: bool = False):
        """
        Fetch the current lines from the remote progress bar and draw them.

        Parameters
        ----------
        force : bool
            Draw even when less than ``update_interval`` seconds have passed
            since the previous draw.
        """
        since_last_draw = time.time() - self._prev_update_time
        if not (since_last_draw > self.update_interval or force):
            return
        self._prev_update_time = time.time()
        pending = self.remote_progress_bar.make_print_texts.remote()  # type: ignore[attr-defined]
        self.print_texts(ray.get(pending))  # type: ignore[assignment]
class ProgressBar:
    """
    Class for using progress bar. [config.verbose = "progress"]

    Parameters
    ----------
    total_steps : int
        The total steps the progress bar is expected to iterate
    epoch_len : int | None
        The number of iterations for a single epoch that is used to calculate
        the time it takes per epoch. Defaults to ``total_steps``.
    logfile : Path | None
        Path of logfile to read from to display on console.
    update_interval : int
        The time interval by which the progress bar will update the displayed
        metrics.
    remote_display : ty.Optional[RemoteProgressBar]
        A Remote display that can be used to report the progress to instead of
        printing it directly on console
    uid : str | None
        The trial uid that is used to report the metrics.
    """

    def __init__(
        self,
        total_steps: int,
        epoch_len: int | None = None,
        logfile: Path | None = None,
        update_interval: int = 1,
        remote_display: ty.Optional[RemoteProgressBar] = None,
        uid: str | None = None,
    ):
        self.epoch_len = total_steps if epoch_len is None else epoch_len
        self.total_steps = total_steps
        self.update_interval = update_interval
        self.start_time = time.time()
        self._prev_update_time = time.time()
        self.current_iteration = 0
        self.metrics: dict[str, ty.Any] = {}
        self.logfile = logfile
        # Exactly one of ``display`` / ``remote_display`` is set: a local
        # display is created only when no remote one was given.
        self.display: Display | None = None
        self.remote_display: RemoteProgressBar | None = None
        if remote_display is None:
            self.display = Display()
        else:
            self.remote_display = remote_display
        self.uid = uid
        self._update()

    def __iter__(self):
        yield from range(self.epoch_len)

    def reset(self) -> None:
        """Set the current iteration back to 0."""
        self.current_iteration = 0

    def close(self):
        """Close the local display, or mark this trial closed on the remote one."""
        if self.display is not None:
            self.display.close()
        else:
            self.remote_display.close.remote(self.uid)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @classmethod
    def make_bar(
        cls,
        current_iteration: int,
        start_time: float,
        epoch_len: int | None,
        total_steps: int,
        ncols: int | None = None,
    ):
        """
        Format a progress meter with an estimate of the remaining time.

        Parameters
        ----------
        current_iteration : int
            The number of completed iterations.
        start_time : float
            The ``time.time()`` value at which the run started.
        epoch_len : int | None
            The total shown by the meter. ``None`` or ``float("inf")`` means
            unknown.
        total_steps : int
            The total used for the remaining-time estimate. ``float("inf")``
            means unknown.
        ncols : int | None
            The width passed on to ``tqdm.format_meter``.

        Returns
        -------
        str
            The formatted meter. The remaining time is ``"??"`` when it can
            not be estimated.
        """
        unknown_total = float("inf")
        elapsed = time.time() - start_time
        ftime = "??"
        # No estimate without progress, without elapsed time (division by
        # zero) or without a finite total (``tqdm.format_interval`` can not
        # format infinity).
        if current_iteration > 0 and elapsed > 0 and total_steps != unknown_total:
            rate = current_iteration / elapsed
            remaining_steps = max(total_steps - current_iteration, 0)
            ftime = tqdm.format_interval(remaining_steps / rate)
        post_fix = f"Remaining: {ftime}"
        return tqdm.format_meter(
            current_iteration,
            # tqdm treats ``None`` as an unknown total; infinity makes it fail.
            None if epoch_len == unknown_total else epoch_len,
            elapsed=elapsed,
            bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}",
            postfix=post_fix,
            ncols=ncols,
        )

    @classmethod
    def make_metrics_message(
        cls,
        metrics: dict[str, ty.Any],
        nrows: int | None = None,
        ncols: int | None = None,
    ) -> list:
        """
        Lay the metrics out as ``name: value`` fields joined by ``SEPERATOR``.

        Parameters
        ----------
        metrics : dict[str, ty.Any]
            The metrics to format; values are formatted with ``num_format``.
        nrows : int | None
            When given, stop adding fields once more than ``nrows`` lines
            have been completed.
        ncols : int | None
            When given, start a new line before a field that would make the
            current line longer than ``ncols``.

        Returns
        -------
        list
            The formatted lines.
        """
        rows = tabulate(
            [[k + ":", f"{num_format(v)}"] for k, v in metrics.items()],
            disable_numparse=True,
            tablefmt="plain",
            stralign="right",
        ).split("\n")
        text = ""
        texts = []
        for row in rows:
            row += SEPERATOR
            if ncols is not None and len(text) + len(row) > ncols:
                # The line is full: drop its trailing separator and start a
                # new one.
                text = text[: -len(SEPERATOR)]
                texts.append(text)
                text = ""
            text += row
            if nrows is not None and len(texts) > nrows:
                break
        # Flush the line that was being built when the loop ended.
        text = text[: -len(SEPERATOR)]
        texts.append(text)
        return texts

    @property
    def ncols(self):
        """Width of the local display, ``None`` when reporting remotely."""
        if self.display is not None:
            return self.display.ncols
        return None

    @property
    def nrows(self) -> int | None:
        """Height of the local display minus 5 rows of padding, ``None`` when reporting remotely."""
        if self.display is not None:
            return self.display.nrows - 5  # padding
        return None

    def make_print_message(self) -> list:
        """
        Build the lines to display.

        Returns
        -------
        list
            The metric lines, then the progress bar (prefixed with the uid
            when one is set), then the last line of the logfile when there is
            one.
        """
        texts = self.make_metrics_message(self.metrics, self.nrows, self.ncols)
        pbar = self.make_bar(
            current_iteration=self.current_iteration,
            start_time=self.start_time,
            total_steps=self.total_steps,
            epoch_len=self.epoch_len,
        )
        if self.uid is not None:
            texts.append(f"{self.uid}: {pbar}")
        else:
            texts.append(pbar)
        if (last_line := get_last_line(self.logfile)) is not None:
            texts.append(last_line)
        return texts

    def _update(self):
        """Draw the current state and close once the last iteration of the epoch is reached."""
        texts = self.make_print_message()
        if self.display is not None:
            self.display.print_texts(texts)
        else:
            # Wait for the remote call so the status is stored before returning.
            ray.get(self.remote_display.update_status.remote(self.uid, texts))
        if self.current_iteration + 1 == self.epoch_len:
            self.close()

    def update_metrics(self, metrics: dict[str, ty.Any], current_iteration: int):
        """
        Store the latest metrics and redraw when a redraw is due.

        A redraw happens on iteration 0, on the last iteration of the epoch,
        and otherwise when more than ``update_interval`` seconds have passed
        since the previous redraw.

        Parameters
        ----------
        metrics : dict[str, ty.Any]
            The metrics to display.
        current_iteration : int
            The iteration the metrics belong to.
        """
        self.metrics = metrics
        self.current_iteration = current_iteration
        if (
            current_iteration == 0
            or time.time() - self._prev_update_time > self.update_interval
            or current_iteration + 1 == self.epoch_len
        ):
            self._prev_update_time = time.time()
            self._update()