Source code for ablator.mp.node_manager

import getpass
import logging
import socket
import traceback
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import paramiko
import psutil
import ray
from ray.util.state import list_nodes, list_tasks

from ablator.utils.base import get_gpu_mem

# Default timeout used by the NodeManager methods below when the caller does not
# supply one; it is forwarded to ray (`ray.get`, `list_nodes`, `list_tasks`) and
# to the paramiko SSH connection (presumably seconds, per those APIs).
DEFAULT_TIMEOUT = 60


[docs]@dataclass class Resource: gpu_free_mem: dict[str, int] mem: int cpu_usage: float cpu_count: int running_tasks: list[str] = field(default_factory=lambda: []) @property def gpu_free_mem_arr(self) -> np.ndarray: return np.array(list(self.gpu_free_mem.values())) @property def cpu_mean_util(self) -> float: return np.array(self.cpu_usage).mean() @property def least_used_gpu(self): return min(self.gpu_free_mem, key=self.gpu_free_mem.get)
def make_private_key(home_path: Path):
    """Load, or create on first use, the ablator RSA key pair.

    The private key lives at ``<home_path>/.ssh/ablator_id_rsa``. When the file
    does not exist a new 2048-bit RSA key is generated and written there.

    Args:
        home_path: Directory under which the ``.ssh`` folder is located/created.

    Returns:
        A tuple ``(pkey, key)`` where ``pkey`` is the ``paramiko.RSAKey`` and
        ``key`` is the public key formatted as an ``authorized_keys`` line,
        ``"<type> <base64> ablator-<hostname>@<ip>"``.
    """
    pkey_path = Path(home_path).joinpath(".ssh", "ablator_id_rsa")
    # The mode only applies when the directory is actually created.
    pkey_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
    if not pkey_path.exists():
        pkey = paramiko.RSAKey.generate(bits=2048)
        # A private key must not be readable by other users: create the file
        # with owner-only permissions *before* any key material is written.
        pkey_path.touch(mode=0o600, exist_ok=True)
        pkey_path.chmod(0o600)
        with pkey_path.open("w", encoding="utf-8") as p:
            pkey.write_private_key(p)
    else:
        pkey = paramiko.RSAKey.from_private_key_file(pkey_path.as_posix())
    name = pkey.get_name()
    public_key = pkey.get_base64()
    hostname = socket.gethostname()
    node_ip = socket.gethostbyname(hostname)
    # The trailing field is the free-form comment of the public key line.
    key = f"{name} {public_key} ablator-{hostname}@{node_ip}"
    return pkey, key
def utilization():
    """Measure the resource usage of the machine this function runs on.

    Returns:
        A ``Resource`` holding the free GPU memory, the system memory usage
        percentage, the per-CPU usage sampled over a 2 second interval and the
        CPU count.
    """
    # Keyword arguments are evaluated left to right, i.e. GPU, memory, CPU
    # usage (blocking for the sampling interval) and finally the CPU count.
    return Resource(
        gpu_free_mem=get_gpu_mem("free"),
        mem=psutil.virtual_memory().percent,
        cpu_usage=psutil.cpu_percent(interval=2, percpu=True),
        cpu_count=psutil.cpu_count(),
    )
@ray.remote
def update_node(node_ip, key):
    """Authorize ``key`` for SSH login on the node executing this task.

    The key is appended to ``~/.ssh/authorized_keys`` of the current user unless
    the file already contains it.

    Args:
        node_ip: IP of the node; returned unchanged so the caller can match the
            result to the node.
        key: Public key line to authorize.

    Returns:
        A tuple ``(node_ip, username)`` with the name of the user the task ran as.
    """
    ssh_dir = Path.home().joinpath(".ssh")
    ssh_dir.mkdir(exist_ok=True)
    username = getpass.getuser()
    authorized_keys = ssh_dir.joinpath("authorized_keys")
    existing = ""
    if authorized_keys.exists():
        existing = authorized_keys.read_text(encoding="utf-8")
        # check if key in authorized keys
        if key in existing:
            return node_ip, username
    # If the file does not end with a newline, appending directly would glue
    # the new key onto the last existing entry and corrupt both lines.
    separator = "\n" if existing and not existing.endswith("\n") else ""
    with authorized_keys.open("a", encoding="utf-8") as f:
        f.write(f"{separator}{key}\n")
    return node_ip, username
class NodeManager:
    """Track the nodes of a ray cluster and run functions or commands on them.

    On creation the manager loads (or creates) an SSH key pair, connects to ray
    if needed and registers the public key on every alive node.

    Attributes:
        pkey: Private key used for the SSH connections of ``run_cmd``.
        public_key: Public key line installed on the nodes.
        ray_address: GCS address of the ray instance in use.
        nodes: Mapping of node IP to the username reported by that node.
    """

    def __init__(self, private_key_home: Path, ray_address: str | None = None):
        """Initialize the manager.

        Args:
            private_key_home: Directory passed to ``make_private_key``.
            ray_address: Address of the ray cluster; ``None`` lets ray decide.

        Raises:
            RuntimeError: If ray is already initialized with a different address.
        """
        self.pkey, self.public_key = make_private_key(private_key_home)
        if (
            ray.is_initialized()
            and ray_address is not None
            and ray_address != ray.get_runtime_context().gcs_address
        ):
            raise RuntimeError(
                "`ray_address` does not match currently running ray instance. Can not initialize ray twice."
            )
        if not ray.is_initialized():
            ray.init(address=ray_address)
        self.ray_address = ray.get_runtime_context().gcs_address
        self.nodes: dict[str, str] = {}
        self.update()

    def update(self, timeout: int | None = 10):
        """Refresh ``self.nodes`` from the nodes currently known to ray.

        Alive nodes that are not tracked yet get the public key installed via
        ``update_node``; already tracked alive nodes are kept; everything else
        (dead nodes, nodes whose update failed) is dropped.

        Args:
            timeout: Timeout forwarded to ``list_nodes`` and ``ray.get``.
        """
        nodes = {}
        for node in list_nodes(address=self.ray_address, timeout=timeout):
            node_ip = node.node_ip
            node_alive = node.state.lower() == "alive"
            if node_alive and node_ip not in self.nodes:
                # Pin the task to the node through its `node:<ip>` resource.
                future = update_node.options(  # type: ignore
                    resources={f"node:{node_ip}": 0.01}
                ).remote(node_ip, self.public_key)
                try:
                    node_ip, username = ray.get(future, timeout=timeout)
                    nodes[node_ip] = username
                # pylint: disable=broad-exception-caught
                except Exception as e:
                    logging.error(
                        "Could not update node with %s. %s %s",
                        node_ip,
                        str(e),
                        traceback.format_exc(),
                    )
            elif node_alive and node_ip not in nodes:
                nodes[node_ip] = self.nodes[node_ip]
        self.nodes = nodes

    def utilization(
        self, node_ips: list | str | None = None, timeout: int | None = DEFAULT_TIMEOUT
    ) -> dict[str, Resource]:
        """Return the ``Resource`` usage of the selected nodes, keyed by IP.

        Args:
            node_ips: Node IP(s) to query; ``None`` selects every tracked node.
            timeout: Timeout forwarded to ``run_lambda``.
        """
        return self.run_lambda(utilization, node_ips, timeout=timeout)

    def available_resources(
        self, node_ips: list | str | None = None, timeout: int | None = DEFAULT_TIMEOUT
    ) -> dict[str, Resource]:
        """Return node utilization together with the tasks running on each node.

        Args:
            node_ips: Node IP(s) to query; ``None`` selects every tracked node.
            timeout: Timeout forwarded to the ray calls.

        Returns:
            Mapping of node IP to ``Resource`` with ``running_tasks`` filled in.
        """
        results = self.utilization(node_ips, timeout=timeout)
        node_id_map: dict[str, str] = {
            n.node_id: n.node_ip
            for n in list_nodes(address=self.ray_address, timeout=timeout)
        }
        node_ip_tasks: dict[str, list[str]] = defaultdict(list)
        running_tasks = list_tasks(
            address=self.ray_address,
            filters=[
                ("state", "=", "RUNNING"),
                # exclude the utilization lambda from above
                ("func_or_class_name", "!=", "utilization"),
            ],
            timeout=timeout,
        )
        for task in running_tasks:
            if task.node_id in node_id_map:
                node_ip_tasks[node_id_map[task.node_id]].append(task.name)
        for node_ip, resource in results.items():
            resource.running_tasks = node_ip_tasks[node_ip]
        return results

    def _parse_node_ips(self, node_ips: list | str | None = None) -> list[str]:
        """Normalize ``node_ips`` into a fresh list of tracked node IPs.

        Args:
            node_ips: ``None`` for every tracked node, a single IP, or an
                iterable of IPs.

        Returns:
            A new list the caller may mutate freely.

        Raises:
            RuntimeError: If any requested IP is not a tracked node.
        """
        if node_ips is None:
            _node_ips = self.node_ips
        elif isinstance(node_ips, str):
            _node_ips = [node_ips]
        else:
            # An explicit list used to be ignored (yielding an empty selection).
            # Copy it, since `run_cmd` pops from the returned list.
            _node_ips = list(node_ips)
        if any(node_ip not in self.nodes for node_ip in _node_ips):
            raise RuntimeError(
                f"Not all {set(_node_ips)} found in running nodes: {set(self.node_ips)}."
            )
        return _node_ips

    @property
    def node_ips(self) -> list[str]:
        """IPs of the tracked nodes, as a new list."""
        return list(self.nodes.keys())

    def run_lambda(
        self,
        fn,
        node_ips: list | str | None = None,
        timeout: int | None = DEFAULT_TIMEOUT,
    ):
        """Run ``fn`` (without arguments) as a ray task on each selected node.

        Args:
            fn: Callable to execute remotely.
            node_ips: Node IP(s) to run on; ``None`` selects every tracked node.
            timeout: Timeout for waiting on each node's result.

        Returns:
            Mapping of node IP to the value returned by ``fn``. Nodes whose
            task failed are logged and left out of the result.

        Raises:
            RuntimeError: If a requested IP is not a tracked node.
        """
        self.update()
        results = {}
        for node_ip in self._parse_node_ips(node_ips):
            try:
                results[node_ip] = ray.get(
                    ray.remote(fn)
                    .options(resources={f"node:{node_ip}": 0.001})
                    .remote(),
                    timeout=timeout,
                )
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.error(
                    "Error in `run_lambda` for node with IP %s %s %s",
                    node_ip,
                    str(e),
                    traceback.format_exc(),
                )
        return results

    def run_cmd(
        self, cmd, node_ips: list | str | None = None, timeout: int = DEFAULT_TIMEOUT
    ) -> dict[str, str]:
        """Execute a shell command over SSH on each selected node.

        A node that fails is removed from ``self.nodes`` and queued for a single
        retry; the retry only takes place if the node gets tracked again, which
        happens when an authentication failure triggers ``update``.

        Args:
            cmd: Command to execute.
            node_ips: Node IP(s) to run on; ``None`` selects every tracked node.
            timeout: Timeout used for ``update`` and for the SSH connection.

        Returns:
            Mapping of ``"<username>@<ip>"`` to the decoded stdout of the command.

        Raises:
            RuntimeError: If a requested IP is not a tracked node.
        """
        self.update(timeout=timeout)
        result = {}
        node_ips = self._parse_node_ips(node_ips)
        errored_ips = []
        while node_ips:
            if (node_ip := node_ips.pop()) not in self.nodes:
                continue
            node_username = self.nodes[node_ip]
            client = paramiko.SSHClient()
            policy = paramiko.AutoAddPolicy()
            client.set_missing_host_key_policy(policy)
            node_name = f"{node_username}@{node_ip}"
            try:
                client.connect(
                    node_ip,
                    username=node_username,
                    pkey=self.pkey,
                    timeout=timeout,
                    banner_timeout=timeout,
                    auth_timeout=timeout,
                    channel_timeout=timeout,
                )
                # _stdin, _stdout, _stderr
                _, _stdout, _ = client.exec_command(cmd)
                result[node_name] = _stdout.read().decode()
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.error(
                    (
                        "Could not connect to %s. Make sure ssh is configured "
                        "and accessible from the ray head node. \n %s"
                    ),
                    node_name,
                    str(e),
                )
                del self.nodes[node_ip]
                if node_ip not in errored_ips:
                    # repeat the IP
                    node_ips.append(node_ip)
                    errored_ips.append(node_ip)
                # NOTE(review): kept as a sibling of the retry check above;
                # re-running `update` re-installs the key on the dropped node.
                if "Authentication failed." in str(e):
                    self.update(timeout=timeout)
            finally:
                # Always release the connection, whether or not the command ran.
                client.close()
        return result