Source code for ablator.mp.node_manager

import getpass
import logging
import socket
import traceback
from collections import defaultdict
from pathlib import Path

import paramiko
import psutil
import ray
from ray.util.state import list_nodes, list_tasks

from ablator.mp.utils import Resource, ray_init
from ablator.utils._nvml import get_gpu_mem

DEFAULT_TIMEOUT = 60


def make_private_key(home_path: Path):
    """Load, or create on first use, the RSA key pair stored under ``home_path``.

    The private key is kept at ``<home_path>/.ssh/ablator_id_rsa``. When the
    file is missing a new 2048-bit RSA key is generated and written there;
    otherwise the existing file is loaded.

    Args:
        home_path: Directory that contains (or will contain) the ``.ssh``
            folder. It must already exist; only ``.ssh`` itself is created.

    Returns:
        A ``(pkey, key)`` tuple where ``pkey`` is the ``paramiko.RSAKey`` and
        ``key`` is the matching public key as a single
        ``"<type> <base64> ablator-<hostname>@<ip>"`` line.
    """
    pkey_path = Path(home_path).joinpath(".ssh", "ablator_id_rsa")
    # The mode only takes effect when the directory is created by this call;
    # an already existing ``.ssh`` keeps whatever permissions it has.
    pkey_path.parent.mkdir(mode=0o700, exist_ok=True)
    if not pkey_path.exists():
        pkey = paramiko.RSAKey.generate(bits=2048)
        # Create the file as owner-only *before* any key material is written,
        # so the private key is never readable by other users.
        pkey_path.touch(mode=0o600)
        with pkey_path.open("w", encoding="utf-8") as p:
            pkey.write_private_key(p)
    else:
        pkey = paramiko.RSAKey.from_private_key_file(pkey_path.as_posix())

    name = pkey.get_name()
    public_key = pkey.get_base64()
    hostname = socket.gethostname()
    node_ip = socket.gethostbyname(hostname)
    # The trailing field is the free-form comment of the public key line; it
    # identifies the machine that generated the key.
    key = f"{name} {public_key} ablator-{hostname}@{node_ip}"
    return pkey, key
def utilization():
    """Collect a snapshot of GPU, memory and CPU usage of the local machine.

    Returns:
        A ``Resource`` populated with ``gpu_free_mem``, ``mem``, ``cpu_usage``
        and ``cpu_count``.
    """
    return Resource(
        # NOTE(review): presumably the free memory of each visible GPU --
        # confirm against ``ablator.utils._nvml.get_gpu_mem``.
        gpu_free_mem=get_gpu_mem("free"),
        mem=psutil.virtual_memory().percent,
        # A positive ``interval`` makes psutil block for that many seconds
        # while it samples; ``percpu=True`` yields one value per CPU.
        cpu_usage=psutil.cpu_percent(interval=2, percpu=True),
        cpu_count=psutil.cpu_count(),
    )
@ray.remote(num_cpus=0.001)
def update_node(node_ip, key):
    """Make sure ``key`` is listed in the current user's ``authorized_keys``.

    Runs as a ray task on the node it is scheduled on and works on the
    ``~/.ssh/authorized_keys`` file of the user executing that task.

    Args:
        node_ip: IP of the node; returned unchanged so the caller can pair the
            result with the node it asked about.
        key: Public key line to authorize.

    Returns:
        A ``(node_ip, username)`` tuple, where ``username`` is the user the
        task ran as.
    """
    ssh_dir = Path.home().joinpath(".ssh")
    ssh_dir.mkdir(exist_ok=True)
    username = getpass.getuser()
    authorized_keys = ssh_dir.joinpath("authorized_keys")

    existing = (
        authorized_keys.read_text(encoding="utf-8")
        if authorized_keys.exists()
        else None
    )
    # Nothing to do when the key has already been authorized.
    if existing is not None and key in existing:
        return node_ip, username

    # A file whose last entry lacks a trailing newline would otherwise have
    # the new key glued onto that entry, corrupting both of them.
    separator = "\n" if existing and not existing.endswith("\n") else ""
    with authorized_keys.open("a", encoding="utf-8") as f:
        f.write(f"{separator}{key}\n")
    return node_ip, username
class NodeManager:
    """Keep track of the alive nodes of a ray cluster and run work on them.

    ``self.nodes`` maps the IP of every known alive node to the username
    reported by ``update_node`` for it. That username, together with the
    private key created by ``make_private_key``, is what ``run_cmd`` uses to
    open SSH connections.

    Args:
        private_key_home: Directory under which the ``.ssh`` folder holding the
            private key is looked up or created.
        ray_address: Address of the ray cluster to attach to. When ray is
            already initialized it must either be ``None`` or match the GCS
            address of the running instance.

    Raises:
        RuntimeError: If ray is already initialized with a different address.
    """

    def __init__(self, private_key_home: Path, ray_address: str | None = None):
        self.pkey, self.public_key = make_private_key(private_key_home)
        if (
            ray.is_initialized()
            and ray_address is not None
            and ray_address != ray.get_runtime_context().gcs_address
        ):
            raise RuntimeError(
                "`ray_address` does not match currently running ray instance. Can not initialize ray twice."
            )
        if not ray.is_initialized():
            ray_init(address=ray_address)
        self.ray_address = ray.get_runtime_context().gcs_address
        # node IP -> username on that node
        self.nodes: dict[str, str] = {}
        self.update()

    def update(self, timeout: int | None = 10):
        """Refresh ``self.nodes`` from the nodes ray currently reports.

        Alive nodes that are not known yet get the public key authorized via
        ``update_node`` (scheduled on that node through its ``node:<ip>``
        resource). Alive nodes that are already known are carried over, and
        every other node is dropped.

        Args:
            timeout: Seconds to wait for the node listing and for each
                ``update_node`` task.
        """
        nodes = {}
        for node in list_nodes(address=self.ray_address, timeout=timeout):
            node_ip = node.node_ip
            node_alive = node.state.lower() == "alive"
            if node_alive and node_ip not in self.nodes:
                future = update_node.options(
                    resources={f"node:{node_ip}": 0.001}
                ).remote(node_ip, self.public_key)
                try:
                    node_ip, username = ray.get(future, timeout=timeout)
                    nodes[node_ip] = username
                # A single unreachable node must not prevent the remaining
                # nodes from being registered.
                # pylint: disable=broad-exception-caught
                except Exception as e:
                    logging.error(
                        "Could not update node with %s. %s %s",
                        node_ip,
                        str(e),
                        traceback.format_exc(),
                    )
            elif node_alive and node_ip not in nodes:
                nodes[node_ip] = self.nodes[node_ip]
        self.nodes = nodes

    def utilization(
        self, node_ips: list | str | None = None, timeout: int | None = DEFAULT_TIMEOUT
    ) -> dict[str, Resource]:
        """Run the module-level ``utilization`` function on the given nodes.

        Args:
            node_ips: A node IP, a list of node IPs, or ``None`` for all nodes.
            timeout: Seconds to wait for each node.

        Returns:
            Mapping of node IP to the ``Resource`` measured on it. Nodes whose
            task failed are missing from the mapping.
        """
        return self.run_lambda(utilization, node_ips, timeout=timeout)

    def available_resources(
        self, node_ips: list | str | None = None, timeout: int | None = DEFAULT_TIMEOUT
    ) -> dict[str, Resource]:
        """Return node utilization annotated with the tasks running per node.

        Args:
            node_ips: A node IP, a list of node IPs, or ``None`` for all nodes.
            timeout: Seconds to wait for each node and each ray state query.

        Returns:
            Mapping of node IP to its ``Resource`` whose ``running_tasks``
            attribute lists the names of the ray tasks in ``RUNNING`` state on
            that node.
        """
        results = self.utilization(node_ips, timeout=timeout)
        node_id_map: dict[str, str] = {
            n.node_id: n.node_ip
            for n in list_nodes(address=self.ray_address, timeout=timeout)
        }
        node_ip_tasks: dict[str, list[str]] = defaultdict(list)
        running_tasks = list_tasks(
            address=self.ray_address,
            filters=[
                ("state", "=", "RUNNING"),
                # exclude the utilization lambda from above
                ("func_or_class_name", "!=", "utilization"),
            ],
            timeout=timeout,
        )
        for task in running_tasks:
            if task.node_id in node_id_map:
                node_ip_tasks[node_id_map[task.node_id]].append(task.name)
        for node_ip, resource in results.items():
            resource.running_tasks = node_ip_tasks[node_ip]
        return results

    def _parse_node_ips(self, node_ips: list | str | None = None) -> list[str]:
        """Normalize ``node_ips`` to a list of known node IPs.

        Args:
            node_ips: ``None`` for every known node, a single IP, or a list of
                IPs.

        Returns:
            A new list of IPs; it is safe for the caller to mutate it.

        Raises:
            RuntimeError: If any requested IP is not among the known nodes.
        """
        if node_ips is None:
            _node_ips = self.node_ips
        elif isinstance(node_ips, str):
            _node_ips = [node_ips]
        else:
            # Copy: ``run_cmd`` pops from the returned list and must not
            # consume the list object owned by the caller.
            _node_ips = list(node_ips)
        if any(node_ip not in self.nodes for node_ip in _node_ips):
            raise RuntimeError(
                f"Not all {set(_node_ips)} found in running nodes: {set(self.node_ips)}."
            )
        return _node_ips

    @property
    def node_ips(self) -> list[str]:
        """IPs of the currently known nodes, as a new list."""
        return list(self.nodes.keys())

    def run_lambda(
        self,
        fn,
        node_ips: list | str | None = None,
        timeout: int | None = DEFAULT_TIMEOUT,
    ):
        """Run ``fn`` as a ray task on each requested node, one node at a time.

        The known nodes are refreshed first. Failures are logged and the
        affected node is left out of the result instead of raising.

        Args:
            fn: Callable taking no arguments; it is wrapped with ``ray.remote``.
            node_ips: A node IP, a list of node IPs, or ``None`` for all nodes.
            timeout: Seconds to wait for the result of each node.

        Returns:
            Mapping of node IP to the value ``fn`` returned on that node.

        Raises:
            RuntimeError: If a requested IP is not among the known nodes.
        """
        self.update()
        results = {}
        for node_ip in self._parse_node_ips(node_ips):
            try:
                results[node_ip] = ray.get(
                    ray.remote(fn)
                    # the ``node:<ip>`` resource pins the task to that node
                    .options(num_cpus=0.001, resources={f"node:{node_ip}": 0.001})
                    .remote(),
                    timeout=timeout,
                )
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.error(
                    "Error in `run_lambda` for node with IP %s %s %s",
                    node_ip,
                    str(e),
                    traceback.format_exc(),
                )
        return results

    def run_cmd(
        self, cmd, node_ips: list | str | None = None, timeout: int = DEFAULT_TIMEOUT
    ) -> dict[str, str]:
        """Execute the shell command ``cmd`` over SSH on the requested nodes.

        The known nodes are refreshed first. A node that fails is removed from
        ``self.nodes`` and queued for one more attempt; that attempt only
        happens if the node is known again by then, which is the case after
        the refresh triggered by an authentication failure.

        Args:
            cmd: Command line to execute on each node.
            node_ips: A node IP, a list of node IPs, or ``None`` for all nodes.
            timeout: Seconds used for the node refresh and for every SSH
                connection timeout.

        Returns:
            Mapping of ``"<username>@<ip>"`` to the decoded standard output of
            the command. Nodes that could not be reached are missing.

        Raises:
            RuntimeError: If a requested IP is not among the known nodes.
        """
        self.update(timeout=timeout)
        result = {}
        node_ips = self._parse_node_ips(node_ips)
        errored_ips = []
        while len(node_ips) > 0:
            if (node_ip := node_ips.pop()) not in self.nodes:
                continue
            node_username = self.nodes[node_ip]
            client = paramiko.SSHClient()
            # NOTE(review): unknown host keys are accepted automatically, i.e.
            # the identity of the remote host is not verified.
            policy = paramiko.AutoAddPolicy()
            client.set_missing_host_key_policy(policy)
            node_name = f"{node_username}@{node_ip}"
            try:
                client.connect(
                    node_ip,
                    username=node_username,
                    pkey=self.pkey,
                    timeout=timeout,
                    banner_timeout=timeout,
                    auth_timeout=timeout,
                    channel_timeout=timeout,
                )
                _, _stdout, _ = client.exec_command(cmd)
                result[node_name] = _stdout.read().decode()
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.error(
                    (
                        "Could not connect to %s. Make sure ssh is configured "
                        "and accessible from the ray head node. \n %s"
                    ),
                    node_name,
                    str(e),
                )
                del self.nodes[node_ip]
                if node_ip not in errored_ips:
                    # repeat the IP
                    node_ips.append(node_ip)
                    errored_ips.append(node_ip)
                if "Authentication failed." in str(e):
                    # The node is unknown again at this point, so the refresh
                    # re-authorizes the public key on it before the retry.
                    self.update(timeout=timeout)
            finally:
                # Release the connection whether or not the command succeeded.
                client.close()
        return result