import getpass
import logging
import socket
import traceback
from collections import defaultdict
from pathlib import Path
import paramiko
import psutil
import ray
from ray.util.state import list_nodes, list_tasks
from ablator.mp.utils import Resource, ray_init
from ablator.utils._nvml import get_gpu_mem
# Default timeout, in seconds, used as the `timeout` default of the
# NodeManager methods below (`utilization`, `available_resources`,
# `run_lambda` and `run_cmd`).
DEFAULT_TIMEOUT = 60
def make_private_key(home_path: Path):
    """Load, or create on first use, the RSA key pair used for SSH access.

    The private key is stored at ``<home_path>/.ssh/ablator_id_rsa``. When the
    file already exists it is loaded as-is; otherwise a new 2048-bit RSA key is
    generated and written to that path.

    Args:
        home_path: Directory that contains (or will contain) the ``.ssh``
            folder holding the private key.

    Returns:
        A tuple ``(pkey, key)`` where ``pkey`` is the ``paramiko.RSAKey`` and
        ``key`` is the matching public-key line of the form
        ``"<key type> <base64 key> ablator-<hostname>@<ip>"``.
    """
    pkey_path = Path(home_path).joinpath(".ssh", "ablator_id_rsa")
    # `parents=True` so that a not-yet-existing `home_path` does not abort key
    # creation; the mode only applies when the `.ssh` directory is created here.
    pkey_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
    if pkey_path.exists():
        pkey = paramiko.RSAKey.from_private_key_file(pkey_path.as_posix())
    else:
        pkey = paramiko.RSAKey.generate(bits=2048)
        # `write_private_key_file` creates the file with owner-only (0600)
        # permissions. Writing through a plain `open("w")` would leave the
        # private key readable by whoever the process umask allows.
        pkey.write_private_key_file(pkey_path.as_posix())
    name = pkey.get_name()
    public_key = pkey.get_base64()
    hostname = socket.gethostname()
    node_ip = socket.gethostbyname(hostname)
    # The trailing comment identifies which host generated the key.
    key = f"{name} {public_key} ablator-{hostname}@{node_ip}"
    return pkey, key
def utilization():
    """Take a snapshot of the resource usage of the machine running this call.

    Returns:
        A ``Resource`` populated with the result of ``get_gpu_mem("free")``,
        the virtual-memory usage percentage, the per-CPU usage percentages and
        the CPU count reported by ``psutil``.
    """
    # Keyword arguments are evaluated in order, so the probes run in the same
    # sequence as before: GPU memory, RAM, CPU usage, CPU count.
    return Resource(
        gpu_free_mem=get_gpu_mem("free"),
        mem=psutil.virtual_memory().percent,
        # `interval=2` makes psutil sample CPU usage over a 2 second window.
        cpu_usage=psutil.cpu_percent(interval=2, percpu=True),
        cpu_count=psutil.cpu_count(),
    )
@ray.remote(num_cpus=0.001)
def update_node(node_ip, key):
    """Authorize ``key`` for SSH login on the node executing this task.

    The key is appended to ``~/.ssh/authorized_keys`` of the user running the
    task unless the file already contains it.

    Args:
        node_ip: IP of the node the task was scheduled on; returned unchanged
            so the caller can associate the result with the node.
        key: Public-key line to add to ``authorized_keys``.

    Returns:
        A tuple ``(node_ip, username)`` where ``username`` is the login name of
        the user running the task.
    """
    ssh_dir = Path.home().joinpath(".ssh")
    ssh_dir.mkdir(exist_ok=True)
    username = getpass.getuser()
    authorized_keys = ssh_dir.joinpath("authorized_keys")
    existing = (
        authorized_keys.read_text(encoding="utf-8") if authorized_keys.exists() else ""
    )
    if key not in existing:
        # `authorized_keys` holds one key per line. If the current content does
        # not end with a newline, appending directly would glue the new key
        # onto the last entry and corrupt both, so start on a fresh line.
        separator = "\n" if existing and not existing.endswith("\n") else ""
        with authorized_keys.open("a", encoding="utf-8") as f:
            f.write(f"{separator}{key}\n")
    return node_ip, username
class NodeManager:
    """Track the alive nodes of a ray cluster and run code or commands on them.

    On construction the manager loads (or creates) an SSH key pair, connects to
    ray if needed, and registers its public key on every alive node through the
    ``update_node`` remote task.

    Attributes:
        pkey: Private key used to open SSH connections in ``run_cmd``.
        public_key: Public-key line distributed to the nodes.
        ray_address: GCS address of the ray instance in use.
        nodes: Mapping of node IP to the username reported by that node.
    """

    def __init__(self, private_key_home: Path, ray_address: str | None = None):
        """Connect to ray and register the SSH key on the alive nodes.

        Args:
            private_key_home: Directory passed to ``make_private_key``.
            ray_address: Address of the ray cluster to connect to. When ray is
                already initialized it must match the running instance.

        Raises:
            RuntimeError: If ray is already initialized with a GCS address
                different from ``ray_address``.
        """
        self.pkey, self.public_key = make_private_key(private_key_home)
        if (
            ray.is_initialized()
            and ray_address is not None
            and ray_address != ray.get_runtime_context().gcs_address
        ):
            raise RuntimeError(
                "`ray_address` does not match currently running ray instance. Can not initialize ray twice."
            )
        if not ray.is_initialized():
            ray_init(address=ray_address)
        self.ray_address = ray.get_runtime_context().gcs_address
        self.nodes: dict[str, str] = {}
        self.update()

    def update(self, timeout: int | None = 10):
        """Refresh ``self.nodes`` from the nodes ray currently reports as alive.

        Nodes that are already known keep their stored username. Newly seen
        nodes get the public key installed via ``update_node``; nodes for which
        that fails are logged and left out. Nodes that are no longer alive are
        dropped.

        Args:
            timeout: Timeout passed to ``list_nodes`` and to each ``ray.get``.
        """
        nodes = {}
        for node in list_nodes(address=self.ray_address, timeout=timeout):
            node_ip = node.node_ip
            # Skip dead nodes and IPs already handled during this pass.
            if node.state.lower() != "alive" or node_ip in nodes:
                continue
            if node_ip in self.nodes:
                nodes[node_ip] = self.nodes[node_ip]
                continue
            # Pin the task to the specific node through its `node:<ip>` resource.
            future = update_node.options(
                resources={f"node:{node_ip}": 0.001}
            ).remote(node_ip, self.public_key)
            try:
                node_ip, username = ray.get(future, timeout=timeout)
                nodes[node_ip] = username
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.error(
                    "Could not update node with %s. %s %s",
                    node_ip,
                    str(e),
                    traceback.format_exc(),
                )
        self.nodes = nodes

    def utilization(
        self, node_ips: list | str | None = None, timeout: int | None = DEFAULT_TIMEOUT
    ) -> dict[str, Resource]:
        """Run the module-level ``utilization`` function on the selected nodes.

        Args:
            node_ips: A single IP, a list of IPs, or ``None`` for all nodes.
            timeout: Timeout forwarded to ``run_lambda``.

        Returns:
            Mapping of node IP to the ``Resource`` reported by that node.
        """
        return self.run_lambda(utilization, node_ips, timeout=timeout)

    def available_resources(
        self, node_ips: list | str | None = None, timeout: int | None = DEFAULT_TIMEOUT
    ) -> dict[str, Resource]:
        """Return node utilization annotated with the tasks running on each node.

        Args:
            node_ips: A single IP, a list of IPs, or ``None`` for all nodes.
            timeout: Timeout used for the utilization query and the ray state
                API calls.

        Returns:
            Mapping of node IP to its ``Resource``, whose ``running_tasks`` is
            set to the names of the tasks ray reports as running on that node.
        """
        results = self.utilization(node_ips, timeout=timeout)
        node_id_map: dict[str, str] = {
            n.node_id: n.node_ip
            for n in list_nodes(address=self.ray_address, timeout=timeout)
        }
        node_ip_tasks: dict[str, list[str]] = defaultdict(list)
        running_tasks = list_tasks(
            address=self.ray_address,
            filters=[
                ("state", "=", "RUNNING"),
                # exclude the utilization lambda from above
                ("func_or_class_name", "!=", "utilization"),
            ],
            timeout=timeout,
        )
        for task in running_tasks:
            if task.node_id in node_id_map:
                node_ip_tasks[node_id_map[task.node_id]].append(task.name)
        for node_ip, resource in results.items():
            resource.running_tasks = node_ip_tasks[node_ip]
        return results

    def _parse_node_ips(self, node_ips: list | str | None = None) -> list[str]:
        """Normalize ``node_ips`` into a fresh list of known node IPs.

        Args:
            node_ips: ``None`` for every known node, a single IP string, or an
                iterable of IP strings.

        Returns:
            A new list of IPs; the caller's own list is never returned, so it
            is safe to mutate.

        Raises:
            RuntimeError: If any requested IP is not in ``self.nodes``.
        """
        if node_ips is None:
            _node_ips = self.node_ips
        elif isinstance(node_ips, str):
            _node_ips = [node_ips]
        else:
            # A list of IPs used to be silently ignored here. Copy it so that
            # callers which consume the result (`run_cmd` pops from it) do not
            # mutate the list that was passed in.
            _node_ips = list(node_ips)
        if any(node_ip not in self.nodes for node_ip in _node_ips):
            raise RuntimeError(
                f"Not all {set(_node_ips)} found in running nodes: {set(self.node_ips)}."
            )
        return _node_ips

    @property
    def node_ips(self) -> list[str]:
        """IPs of the currently known nodes, as a new list."""
        return list(self.nodes.keys())

    def run_lambda(
        self,
        fn,
        node_ips: list | str | None = None,
        timeout: int | None = DEFAULT_TIMEOUT,
    ):
        """Run the zero-argument callable ``fn`` as a ray task on each node.

        The node list is refreshed first (with ``update``'s default timeout).
        Failures on a node are logged and that node is omitted from the result.

        Args:
            fn: Callable invoked without arguments on every selected node.
            node_ips: A single IP, a list of IPs, or ``None`` for all nodes.
            timeout: Timeout passed to ``ray.get`` for each node.

        Returns:
            Mapping of node IP to the value returned by ``fn`` on that node.

        Raises:
            RuntimeError: If a requested IP is not a known node.
        """
        self.update()
        results = {}
        for node_ip in self._parse_node_ips(node_ips):
            try:
                results[node_ip] = ray.get(
                    ray.remote(fn)
                    .options(num_cpus=0.001, resources={f"node:{node_ip}": 0.001})
                    .remote(),
                    timeout=timeout,
                )
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.error(
                    "Error in `run_lambda` for node with IP %s %s %s",
                    node_ip,
                    str(e),
                    traceback.format_exc(),
                )
        return results

    def run_cmd(
        self, cmd, node_ips: list | str | None = None, timeout: int = DEFAULT_TIMEOUT
    ) -> dict[str, str]:
        """Execute the shell command ``cmd`` over SSH on each selected node.

        A node that fails is removed from ``self.nodes`` and queued for one
        retry. The retry only takes place if the node is known again by then,
        which happens when the failure was an authentication error and the
        following ``update`` re-registers the key on it.

        Args:
            cmd: Command string passed to ``SSHClient.exec_command``.
            node_ips: A single IP, a list of IPs, or ``None`` for all nodes.
            timeout: Timeout used for the node refresh and the SSH connection.

        Returns:
            Mapping of ``"<username>@<ip>"`` to the decoded stdout of ``cmd``
            for every node where the command could be run.

        Raises:
            RuntimeError: If a requested IP is not a known node.
        """
        self.update(timeout=timeout)
        result = {}
        node_ips = self._parse_node_ips(node_ips)
        errored_ips = []
        while len(node_ips) > 0:
            if (node_ip := node_ips.pop()) not in self.nodes:
                continue
            node_username = self.nodes[node_ip]
            client = paramiko.SSHClient()
            # NOTE: AutoAddPolicy accepts host keys that are not yet known.
            policy = paramiko.AutoAddPolicy()
            client.set_missing_host_key_policy(policy)
            node_name = f"{node_username}@{node_ip}"
            try:
                client.connect(
                    node_ip,
                    username=node_username,
                    pkey=self.pkey,
                    timeout=timeout,
                    banner_timeout=timeout,
                    auth_timeout=timeout,
                    channel_timeout=timeout,
                )
                _, _stdout, _ = client.exec_command(cmd)
                result[node_name] = _stdout.read().decode()
            # pylint: disable=broad-exception-caught
            except Exception as e:
                logging.error(
                    (
                        "Could not connect to %s. Make sure ssh is configured "
                        "and accessible from the ray head node. \n %s"
                    ),
                    node_name,
                    str(e),
                )
                # Forget the node so that `update` installs the key on it again.
                del self.nodes[node_ip]
                if node_ip not in errored_ips:
                    # repeat the IP
                    node_ips.append(node_ip)
                    errored_ips.append(node_ip)
                if "Authentication failed." in str(e):
                    self.update(timeout=timeout)
            finally:
                # Always release the connection; otherwise every node visited
                # leaks an open SSH transport/socket.
                client.close()
        return result