"""Source code for ``ablator.utils.file``: filesystem, checkpoint, and JSON helpers."""

import copy
import json
import os
import typing as ty
from pathlib import Path

import numpy as np
import pandas as pd
import torch


[docs]def make_sub_dirs(parent: str | Path, *dir_names: str) -> list[Path]: """ Create subdirectories under the given parent directory. Parameters ---------- parent : str | Path Parent directory where subdirectories should be created. *dir_names : str Names of the subdirectories to create. Returns ------- list[Path] A list of created subdirectory paths. Examples -------- >>> parent_directory = "C:/example_parent_directory" >>> subdirectory_names = ["subdir1", "subdir2", "subdir3"] >>> created_subdirectories = make_sub_dirs(parent_directory, *subdirectory_names) >>> created_subdirectories [Path('C:/example_parent_directory/subdir1'), Path('C:/example_parent_directory/subdir2'), Path('C:/example_parent_directory/subdir3')] """ dirs: list[Path] = [] for dir_name in dir_names: dir_path = Path(parent).joinpath(dir_name) dir_path.mkdir(parents=True, exist_ok=True) dirs.append(dir_path) return dirs
def save_checkpoint(state: object, filename: str = "checkpoint.pt"):
    """
    Serialize ``state`` to disk with :func:`torch.save`.

    Parameters
    ----------
    state : object
        The object to persist, typically a model/optimizer state dictionary.
    filename : str
        Destination path of the checkpoint file; defaults to ``"checkpoint.pt"``.
    """
    torch.save(obj=state, f=filename)
def clean_checkpoints(checkpoint_folder: Path, n_checkpoints: int):
    """
    Delete all but ``n_checkpoints`` of the ``*.pt`` files in a directory.

    "Latest" is decided purely by file name: the files are sorted
    lexicographically and the ``n_checkpoints`` names that sort last are kept.
    Only files directly inside ``checkpoint_folder`` matching ``*.pt`` are
    considered; anything else is left untouched.

    Parameters
    ----------
    checkpoint_folder : Path
        Directory containing the checkpoint files.
    n_checkpoints : int
        Number of checkpoints to keep.
    """
    newest_first = sorted(checkpoint_folder.glob("*.pt"), reverse=True)
    if len(newest_first) <= n_checkpoints:
        # Nothing beyond the retention budget.
        return
    for stale in newest_first[n_checkpoints:]:
        Path(stale).unlink(missing_ok=True)
def default_val_parser(val: ty.Any) -> ty.Any:
    """
    Convert a value the :mod:`json` encoder cannot handle into a JSON-compatible one.

    Intended for use as the ``default=`` hook of :func:`json.dumps`, which calls
    it only for objects it cannot encode natively.

    Parameters
    ----------
    val : ty.Any
        The value to be converted.

    Returns
    -------
    ty.Any
        - ``np.ndarray``: nested Python lists (via ``tolist``).
        - ``torch.Tensor``: nested Python lists, or a Python scalar for a 0-d
          tensor (detached and moved to CPU first).
        - ``pd.DataFrame``: a list of row lists; ``pd.Series``: a list of values.
        - NumPy bool / integer / floating scalars: ``bool`` / ``int`` / ``float``.
        - Anything else: ``str(val)``.
    """
    if isinstance(val, np.ndarray):
        return val.tolist()
    if isinstance(val, torch.Tensor):
        # Tensor.tolist skips the NumPy round-trip, which fails for dtypes
        # NumPy has no equivalent for (e.g. torch.bfloat16).
        return val.detach().cpu().tolist()
    if isinstance(val, pd.DataFrame):
        return default_val_parser(np.array(val))
    if isinstance(val, pd.Series):
        return default_val_parser(val.to_numpy())
    # NumPy scalars such as np.int64 / np.float32 are not subclasses of the
    # Python int/float types, so json hands them to this hook; keep them
    # numeric instead of turning them into strings.
    if isinstance(val, np.bool_):
        return bool(val)
    if isinstance(val, np.integer):
        return int(val)
    if isinstance(val, np.floating):
        return float(val)
    return str(val)
def json_to_dict(json_: str) -> dict:
    """
    Parse a JSON document with :func:`json.loads` and return the result.

    Parameters
    ----------
    json_ : str
        The JSON text to decode.

    Returns
    -------
    dict
        The decoded object (a ``dict`` when the document is a JSON object).
    """
    return json.loads(json_)
def dict_to_json(dict_: dict) -> str:
    """
    Serialize a dictionary to a JSON string.

    Values the :mod:`json` encoder cannot handle natively are converted by
    :func:`default_val_parser`. The produced text is decoded once with
    :func:`json_to_dict` before being returned, so an undecodable result
    raises instead of being handed back to the caller.

    Parameters
    ----------
    dict_ : dict
        The dictionary to serialize.

    Returns
    -------
    str
        The JSON text, produced with ``indent=0`` (one item per line, no indentation).
    """
    encoded = json.dumps(dict_, default=default_val_parser, indent=0)
    # Round-trip sanity check; the decoded value itself is not needed.
    json_to_dict(encoded)
    return encoded
def nested_set(dict_: dict, keys: list[str], value: ty.Any) -> dict:
    """
    Return a copy of ``dict_`` with ``value`` stored at the nested path ``keys``.

    ``dict_`` itself is NOT modified: it is deep-copied first, and the update is
    applied to the copy. Intermediate keys that are missing along the path are
    created as empty dictionaries in the copy.

    Parameters
    ----------
    dict_ : dict
        The source dictionary (left unchanged).
    keys : list[str]
        Keys describing the nested path, outermost first.
    value : ty.Any
        The value to store at the end of the path.

    Returns
    -------
    dict
        A deep copy of ``dict_`` with the new value set.

    Raises
    ------
    IndexError
        If ``keys`` is empty.
    TypeError
        Typically raised when an existing intermediate value along the path is
        not a mapping.

    Examples
    --------
    >>> dict_ = {'a': {'b': {'c': 1}}}
    >>> nested_set(dict_, ['a', 'b', 'c'], 2)
    {'a': {'b': {'c': 2}}}
    >>> dict_
    {'a': {'b': {'c': 1}}}
    """
    updated = copy.deepcopy(dict_)
    node = updated
    for key in keys[:-1]:
        # Create missing intermediate levels on the way down.
        if key not in node:
            node[key] = {}
        node = node[key]
    node[keys[-1]] = value
    return updated
def truncate_utf8_chars(filename: Path, last_char: str):
    """
    Truncate a UTF-8 file in place, cutting at the last occurrence of a character.

    The file is scanned backwards for the UTF-8 encoding of ``last_char``. The
    file is then truncated so that the match and everything after it are
    removed. Matching is done on the full encoded byte sequence, so non-ASCII
    characters are found only at real character boundaries.

    Parameters
    ----------
    filename : Path
        Path of the file to truncate (opened in ``"rb+"`` mode).
    last_char : str
        A single character; the file is cut just before its last occurrence.

    Raises
    ------
    ValueError
        If ``last_char`` is not exactly one character long.
    RuntimeError
        If ``last_char`` does not occur in the file (the file is left unchanged).
    """
    if len(last_char) != 1:
        raise ValueError(
            f"Can only truncate up to a single character. `last_char`: {last_char}"
        )
    needle = last_char.encode("utf-8")
    # Bytes carried into the next (earlier) window so that a multi-byte needle
    # straddling a chunk boundary is still found.
    overlap = len(needle) - 1
    chunk_size = 1 << 16
    with open(filename, "rb+") as f:
        pos = f.seek(0, os.SEEK_END)
        carry = b""
        # Walk backwards chunk by chunk; the first hit is the last occurrence,
        # because everything after the current window was already searched.
        while pos > 0:
            start = max(0, pos - chunk_size)
            f.seek(start)
            window = f.read(pos - start) + carry
            idx = window.rfind(needle)
            if idx != -1:
                f.truncate(start + idx)
                return
            carry = window[:overlap]
            pos = start
    raise RuntimeError(
        f"Could not truncate {filename} since `last_char`: {last_char} was not found in the file."
    )