Source code for ablator.mp.node_manager

import getpass
import logging
import socket
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path

import paramiko
import ray
from ray.util.state import list_nodes, list_tasks
import numpy as np
import psutil
from ablator.utils.base import get_gpu_mem


[docs]@dataclass class Resource: gpu_free_mem: dict[str, int] mem: int cpu_usage: float cpu_count: int running_tasks: list[str] = field(default_factory=lambda: []) @property def gpu_free_mem_arr(self) -> np.ndarray: return np.array(list(self.gpu_free_mem.values())) @property def cpu_mean_util(self) -> float: return np.array(self.cpu_usage).mean() @property def least_used_gpu(self): return min(self.gpu_free_mem, key=self.gpu_free_mem.get)
[docs]def make_private_key(home_path: Path): pkey_path = Path(home_path).joinpath(".ssh", "ablator_id_rsa") pkey_path.parent.mkdir(exist_ok=True) if not pkey_path.exists(): pkey = paramiko.RSAKey.generate(bits=2048) with pkey_path.open("w", encoding="utf-8") as p: pkey.write_private_key(p) else: pkey = paramiko.RSAKey.from_private_key_file(pkey_path.as_posix()) name = pkey.get_name() public_key = pkey.get_base64() hostname = socket.gethostname() node_ip = socket.gethostbyname(hostname) key = f"{name} {public_key} ablator-{hostname}@{node_ip}" return pkey, key
[docs]def utilization(): free_gpu = get_gpu_mem("free") mem_usage = psutil.virtual_memory().percent cpu_usage = psutil.cpu_percent(interval=2, percpu=True) cpu_count = psutil.cpu_count() return Resource( gpu_free_mem=free_gpu, mem=mem_usage, cpu_usage=cpu_usage, cpu_count=cpu_count )
@ray.remote def update_node(node_ip, key): # check if key in authorized keys ssh_dir = Path.home().joinpath(".ssh") ssh_dir.mkdir(exist_ok=True) username = getpass.getuser() authorized_keys = ssh_dir.joinpath("authorized_keys") if authorized_keys.exists() and key in authorized_keys.read_text(encoding="utf-8"): return node_ip, username with authorized_keys.open("a", encoding="utf-8") as f: f.write(f"{key}\n") return node_ip, username
[docs]class NodeManager: def __init__(self, private_key_home: Path, ray_address: str | None = None): self.pkey, self.public_key = make_private_key(private_key_home) if not ray.is_initialized(): ray.init(address=ray_address) elif not ray.is_initialized(): raise RuntimeError("No ray cluster was found running.") self.ray_address = ray.get_runtime_context().gcs_address self.nodes: dict[str, str] = {} self.update()
[docs] def update(self): futures = [] nodes = {} for node in ray.nodes(): node_ip = node["NodeManagerAddress"] node_alive = node["Alive"] if node_alive and node_ip not in self.nodes: futures.append( update_node.options(resources={f"node:{node_ip}": 0.01}).remote( node_ip, self.public_key ) ) elif node_alive and node_ip not in nodes: nodes[node_ip] = self.nodes[node_ip] nodes.update(dict(ray.get(futures))) self.nodes = nodes
[docs] def utilization(self, node_ips: list | str | None = None) -> dict[str, Resource]: return self.run_lambda(utilization, node_ips)
[docs] def available_resources( self, node_ips: list | str | None = None ) -> dict[str, Resource]: results = self.utilization(node_ips) node_id_map: dict[str, str] = { n.node_id: n.node_ip for n in list_nodes(address=self.ray_address) } node_ip_tasks: dict[str, list[str]] = defaultdict(lambda: []) running_tasks = list_tasks( address=self.ray_address, filters=[ ("state", "=", "RUNNING"), # exclude the utilization lambda from above ("func_or_class_name", "!=", "utilization"), ], ) for task in running_tasks: if task.node_id in node_id_map: node_ip_tasks[node_id_map[task.node_id]].append(task.name) for node_ip in results: results[node_ip].running_tasks = node_ip_tasks[node_ip] return results
def _parse_node_ips(self, node_ips: list | str | None = None) -> list[str]: _node_ips = [] if node_ips is None: _node_ips = self.node_ips if isinstance(node_ips, str): _node_ips = [node_ips] if any(node_ip not in self.nodes for node_ip in _node_ips): raise RuntimeError( f"Not all {set(_node_ips)} found in running nodes: {set(self.node_ips)}." ) return _node_ips @property def node_ips(self) -> list[str]: return list(self.nodes.keys())
[docs] def run_lambda(self, fn, node_ips: list | str | None = None): self.update() futures = {} for node_ip in self._parse_node_ips(node_ips): futures[node_ip] = ( ray.remote(fn).options(resources={f"node:{node_ip}": 0.001}).remote() ) return {k: ray.get(v) for k, v in futures.items()}
[docs] def run_cmd( self, cmd, node_ips: list | str | None = None, timeout: int = 20 ) -> dict[str, str]: self.update() result = {} for node_ip in self._parse_node_ips(node_ips): node_username = self.nodes[node_ip] client = paramiko.SSHClient() policy = paramiko.AutoAddPolicy() client.set_missing_host_key_policy(policy) node_name = f"{node_username}@{node_ip}" try: client.connect( node_ip, username=node_username, pkey=self.pkey, timeout=timeout ) # _stdin, _stdout, _stderr _, _stdout, _ = client.exec_command(cmd) result[node_name] = _stdout.read().decode() # pylint: disable=broad-exception-caught except Exception as e: logging.error( ( "Could not connect to %s. Make sure ssh is configured " "and accessible from the ray head node. \n %s" ), node_name, str(e), ) return result