"""Source code for runhouse.resources.hardware.on_demand_cluster (SkyPilot-backed on-demand clusters)."""

import asyncio
import contextlib
import json
import subprocess
import time
import warnings
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Union

import requests

import rich.errors
import yaml

try:
    import sky
    from sky.backends import backend_utils
except ImportError:
    pass

from runhouse.constants import (
    DEFAULT_HTTP_PORT,
    DEFAULT_HTTPS_PORT,
    DEFAULT_SERVER_PORT,
    DOCKER_LOGIN_ENV_VARS,
    LOCAL_HOSTS,
)

from runhouse.globals import configs, obj_store, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.utils import (
    ResourceServerStatus,
    ServerConnectionType,
    up_cluster_helper,
)

from .cluster import Cluster

# Module-level logger, named after this module via runhouse's ``get_logger`` helper.
logger = get_logger(__name__)


class OnDemandCluster(Cluster):
    """On-demand cluster launched and managed through SkyPilot."""

    RESOURCE_TYPE = "cluster"
    RECONNECT_TIMEOUT = 5
    DEFAULT_KEYFILE = "~/.ssh/sky-key"

    def __init__(
        self,
        name,
        instance_type: str = None,
        num_instances: int = None,
        provider: str = None,
        default_env: "Env" = None,
        dryrun: bool = False,
        autostop_mins: int = None,
        use_spot: bool = False,
        image_id: str = None,
        memory: Union[int, str] = None,
        disk_size: Union[int, str] = None,
        open_ports: Union[int, str, List[int]] = None,
        server_host: str = None,
        server_port: int = None,
        server_connection_type: str = None,
        ssl_keyfile: str = None,
        ssl_certfile: str = None,
        domain: str = None,
        den_auth: bool = False,
        region: str = None,
        sky_kwargs: Dict = None,
        **kwargs,  # We have this here to ignore extra arguments when calling from from_config
    ):
        """
        On-demand `SkyPilot <https://github.com/skypilot-org/skypilot/>`__ Cluster.

        .. note::
            To build a cluster, please use the factory method :func:`cluster`.

        Args:
            provider (str, optional): Cloud provider name. Falls back to the ``default_provider``
                config value when not given.
            autostop_mins (int, optional): Idle minutes before autostop. Falls back to the
                ``default_autostop`` config value when ``None``.
            open_ports (int or str or List[int], optional): Ports to open. Normalized to a list of
                strings by :meth:`set_connection_defaults`.
            sky_kwargs (Dict, optional): Extra SkyPilot arguments. The ``"resources"`` entry is
                forwarded to ``sky.Resources`` and the ``"launch"`` entry to ``sky.launch``.
        """
        super().__init__(
            name=name,
            default_env=default_env,
            server_host=server_host,
            server_port=server_port,
            server_connection_type=server_connection_type,
            ssl_keyfile=ssl_keyfile,
            ssl_certfile=ssl_certfile,
            domain=domain,
            den_auth=den_auth,
            dryrun=dryrun,
            **kwargs,
        )

        self.instance_type = instance_type
        self.num_instances = num_instances
        self.provider = provider or configs.get("default_provider")
        self._autostop_mins = (
            autostop_mins
            if autostop_mins is not None
            else configs.get("default_autostop")
        )

        self.open_ports = open_ports
        # NOTE(review): ``use_spot`` defaults to ``False`` (not ``None``), so the ``use_spot``
        # config value is only consulted when a caller explicitly passes ``use_spot=None``.
        self.use_spot = use_spot if use_spot is not None else configs.get("use_spot")
        self.image_id = image_id
        self.region = region
        self.memory = memory
        self.disk_size = disk_size
        self.sky_kwargs = sky_kwargs or {}

        self.stable_internal_external_ips = kwargs.get(
            "stable_internal_external_ips", None
        )
        # A config may carry an explicit ``None``; normalize so dict lookups below are safe.
        self.launched_properties = kwargs.get("launched_properties") or {}
        self._docker_user = None

        # Checks if state info is in local sky db, populates if so.
        if not dryrun and not self.ips and not self.creds_values:
            # Cluster status is set to INIT in the Sky DB right after starting, so we need to refresh once
            self._update_from_sky_status(dryrun=False)

    @property
    def client(self):
        """HTTP client for the cluster server.

        If the parent class cannot build a client because no address is known, try loading the
        address from the local Sky DB (without refreshing) before giving up.

        Raises:
            ValueError: If no address can be determined for the cluster.
        """
        try:
            return super().client
        except ValueError as e:
            if not self.address:
                # Try loading in from local Sky DB
                self._update_from_sky_status(dryrun=True)
                if not self.address:
                    raise ValueError(
                        f"Could not determine address for ondemand cluster <{self.name}>. "
                        "Up the cluster with `cluster.up_if_not`."
                    )
                return super().client
            raise e

    @property
    def autostop_mins(self):
        """Idle minutes before the cluster is autostopped (and torn down)."""
        return self._autostop_mins

    @autostop_mins.setter
    def autostop_mins(self, mins):
        self._autostop_mins = mins
        if self.on_this_cluster():
            # Running on the cluster itself: update the object store's cluster config directly.
            obj_store.set_cluster_config_value("autostop_mins", mins)
        else:
            # TODO: consider verifying SkyPilot is installed on the cluster before setting autostop.
            self.call_client_method("set_settings", {"autostop_mins": mins})
            sky.autostop(self.name, mins, down=True)

    @property
    def docker_user(self) -> str:
        """User inside the docker container, or ``None`` if the cluster is not docker-based.

        The result is cached after the first successful lookup via ``get_docker_user``.
        """
        if self._docker_user:
            return self._docker_user

        # TODO detect whether this is a k8s cluster properly, and handle the user setting / SSH properly
        #  (e.g. SkyPilot's new KubernetesCommandRunner)
        if not self.image_id or "docker:" not in self.image_id:
            return None

        launched_properties = self.launched_properties or {}
        # ``launched_properties`` is empty until the cluster has been launched/refreshed, so use
        # ``.get`` rather than indexing to avoid a KeyError.
        if launched_properties.get("cloud") == "kubernetes":
            return launched_properties.get(
                "docker_user", launched_properties.get("ssh_user", "root")
            )

        from runhouse.resources.hardware.sky_command_runner import get_docker_user

        if not self._creds:
            return None
        self._docker_user = get_docker_user(self, self._creds.values)

        return self._docker_user

    def config(self, condensed=True):
        """Return the parent cluster config extended with on-demand specific attributes."""
        config = super().config(condensed)
        self.save_attrs_to_config(
            config,
            [
                "instance_type",
                "num_instances",
                "provider",
                "open_ports",
                "use_spot",
                "image_id",
                "region",
                "stable_internal_external_ips",
                "memory",
                "disk_size",
                "sky_kwargs",
                "launched_properties",
            ],
        )
        config["autostop_mins"] = self._autostop_mins
        return config

    def endpoint(self, external: bool = False):
        """Return the cluster's endpoint.

        Returns ``None`` when the cluster has no address, when running on the cluster itself,
        or when the server does not respond to a check (``ConnectionError``).
        """
        if not self.address or self.on_this_cluster():
            return None

        try:
            self.client.check_server()
        except ConnectionError:
            return None

        return super().endpoint(external)

    def _copy_sky_yaml_from_cluster(self, abs_yaml_path: str):
        """Fetch the cluster's ``~/.sky/sky_ray.yml`` (if missing locally) and register SSH config."""
        if not Path(abs_yaml_path).exists():
            Path(abs_yaml_path).parent.mkdir(parents=True, exist_ok=True)
            self.rsync("~/.sky/sky_ray.yml", abs_yaml_path, up=False)

        # Save SSH info to the ~/.ssh/config
        with open(abs_yaml_path, "r") as f:
            ray_yaml = yaml.safe_load(f)
        backend_utils.SSHConfigHelper.add_cluster(
            self.name, [self.address], ray_yaml["auth"]
        )

    @staticmethod
    def relative_yaml_path(yaml_path):
        """Rewrite an absolute yaml path to ``~/.sky/generated/<filename>``; return others as-is."""
        if Path(yaml_path).is_absolute():
            yaml_path = "~/.sky/generated/" + Path(yaml_path).name
        return yaml_path

    def set_connection_defaults(self):
        """Fill in default connection type, server port and open ports, warning on risky setups.

        ``open_ports`` is normalized to a list of strings.
        """
        if not self.server_connection_type:
            if self.ssl_keyfile or self.ssl_certfile:
                self.server_connection_type = ServerConnectionType.TLS
            else:
                self.server_connection_type = ServerConnectionType.SSH

        if self.server_port is None:
            if self.server_connection_type == ServerConnectionType.TLS:
                self.server_port = DEFAULT_HTTPS_PORT
            elif self.server_connection_type == ServerConnectionType.NONE:
                self.server_port = DEFAULT_HTTP_PORT
            else:
                self.server_port = DEFAULT_SERVER_PORT

        if (
            self.server_connection_type
            in [ServerConnectionType.TLS, ServerConnectionType.NONE]
            and self.server_host in LOCAL_HOSTS
        ):
            warnings.warn(
                f"Server connection type: {self.server_connection_type}, server host: {self.server_host}. "
                f"Note that this will require opening an SSH tunnel to forward traffic from"
                f" {self.server_host} to the server."
            )

        # Normalize to a list: None -> [], single int/str -> [value].
        self.open_ports = (
            []
            if self.open_ports is None
            else [self.open_ports]
            if isinstance(self.open_ports, (int, str))
            else self.open_ports
        )

        if self.open_ports:
            self.open_ports = [str(p) for p in self.open_ports]
            if str(self.server_port) in self.open_ports:
                if (
                    self.server_connection_type
                    in [ServerConnectionType.TLS, ServerConnectionType.NONE]
                    and not self.den_auth
                ):
                    warnings.warn(
                        "Server is insecure and must be inside a VPC or have `den_auth` enabled to secure it."
                    )
            else:
                warnings.warn(
                    f"Server port {self.server_port} not included in open ports. Note you are responsible for opening "
                    f"the port or ensure you have access to it via a VPC."
                )
        else:
            # If using HTTP or HTTPS must enable traffic on the relevant port
            if self.server_connection_type in [
                ServerConnectionType.TLS,
                ServerConnectionType.NONE,
            ]:
                if self.server_port:
                    warnings.warn(
                        f"No open ports specified. Setting default port {self.server_port} to open."
                    )
                    self.open_ports = [str(self.server_port)]
                else:
                    warnings.warn(
                        f"No open ports specified. Make sure the relevant port is open. "
                        f"HTTPS default: {DEFAULT_HTTPS_PORT} and HTTP "
                        f"default: {DEFAULT_HTTP_PORT}."
                    )

    # ----------------- Launch/Lifecycle Methods -----------------

    def is_up(self) -> bool:
        """Whether the cluster is up.

        Always ``True`` when called from on the cluster itself; otherwise pings the server,
        refreshing connection info from the Sky DB once if the first ping fails.

        Example:
            >>> rh.ondemand_cluster("rh-cpu").is_up()
        """
        if self.on_this_cluster():
            return True
        return self._ping(retry=True)

    def _sky_status(self, refresh: bool = True, retry: bool = True):
        """
        Get status of Sky cluster.

        Returns ``None`` if the cluster is not in the local Sky DB. Otherwise returns a dict like:

        .. code-block::

            {'name': 'sky-cpunode-donny',
            'launched_at': 1662317201,
            'handle': ResourceHandle(
                        cluster_name=sky-cpunode-donny,
                        head_ip=54.211.97.164,
                        cluster_yaml=/Users/donny/.sky/generated/sky-cpunode-donny.yml,
                        launched_resources=1x AWS(m6i.2xlarge),
                        tpu_create_script=None,
                        tpu_delete_script=None),
            'last_use': 'sky cpunode',
            'status': <ClusterStatus.UP: 'UP'>,
            'autostop': -1,
            'metadata': {}}

        .. note:: For more information see SkyPilot's :code:`ResourceHandle` `class
        <https://github.com/skypilot-org/skypilot/blob/0c2b291b03abe486b521b40a3069195e56b62324/sky/backends/cloud_vm_ray_backend.py#L1457>`__.
        """
        if not sky.global_user_state.get_cluster_from_name(self.name):
            return None

        try:
            state = sky.status(cluster_names=[self.name], refresh=refresh)
        except rich.errors.LiveError as e:
            # We can't have more than one Live display at once, so if we've already launched one (e.g. the first
            # time we call status), we can retry without refreshing
            if not retry:
                raise e
            return self._sky_status(refresh=False, retry=False)

        # We still need to check if the cluster present in case the cluster went down and was removed from the DB
        if len(state) == 0:
            return None
        return state[0]

    @property
    def internal_ips(self):
        """Internal IPs of the cluster nodes, refreshed from the Sky DB if not yet known."""
        if not self.stable_internal_external_ips:
            self._update_from_sky_status()
        return [int_ip for int_ip, _ in self.stable_internal_external_ips]

    def _start_ray_workers(self, ray_port, env):
        """Run ``ray start`` on every worker node, pointing at the head node's internal IP."""
        # Find the internal IP corresponding to the public_head_ip and the rest are workers
        internal_head_ip = None
        worker_ips = []
        stable_internal_external_ips = self._sky_status()[
            "handle"
        ].stable_internal_external_ips
        for internal, external in stable_internal_external_ips:
            if external == self.address:
                internal_head_ip = internal
            else:
                # NOTE: Using external worker address here because we're running from local
                worker_ips.append(external)

        logger.debug(f"Internal head IP: {internal_head_ip}")

        for host in worker_ips:
            logger.info(
                f"Starting Ray on worker {host} with head node at {internal_head_ip}:{ray_port}."
            )
            self.run(
                commands=[
                    f"ray start --address={internal_head_ip}:{ray_port} --disable-usage-stats",
                ],
                node=host,
                env=env,
            )
        time.sleep(5)

    def _populate_connection_from_status_dict(self, cluster_dict: Dict[str, Any]):
        """Populate address, IPs, creds and ``launched_properties`` from a Sky status dict.

        If the dict is missing or the cluster is not UP/INIT, connection state is cleared.

        Raises:
            ValueError: If the Sky handle lacks the head IP or the internal/external IP pairs.
        """
        if cluster_dict and cluster_dict["status"].name in ["UP", "INIT"]:
            handle = cluster_dict["handle"]
            self.address = handle.head_ip
            self.stable_internal_external_ips = handle.stable_internal_external_ips
            if self.stable_internal_external_ips is None or self.address is None:
                raise ValueError(
                    "Sky's cluster status does not have the necessary information to connect to the cluster. "
                    "Please check if the cluster is up via `sky status`. "
                    "Consider bringing down the cluster with `sky down` if you are still having issues."
                )
            yaml_path = handle.cluster_yaml
            if Path(yaml_path).exists():
                ssh_values = backend_utils.ssh_credential_from_yaml(
                    yaml_path, ssh_user=handle.ssh_user
                )
                if not self.creds_values:
                    from runhouse.resources.secrets.utils import setup_cluster_creds

                    self._creds = setup_cluster_creds(ssh_values, self.name)

            # Add worker IPs if multi-node cluster - keep the head node as the first IP
            self.ips = [ext for _, ext in self.stable_internal_external_ips]

            launched_resource = handle.launched_resources
            cloud = str(launched_resource.cloud).lower()
            self.launched_properties = {
                "cloud": cloud,
                "instance_type": launched_resource.instance_type,
                "region": launched_resource.region,
                "cost_per_hour": str(launched_resource.get_cost(60 * 60)),
                "disk_size": launched_resource.disk_size,
                "num_cpus": launched_resource.cpus,
            }
            if launched_resource.accelerators:
                self.launched_properties[
                    "accelerators"
                ] = launched_resource.accelerators
            if handle.ssh_user:
                self.launched_properties["ssh_user"] = handle.ssh_user
            if handle.docker_user:
                self.launched_properties["docker_user"] = handle.docker_user
            if cloud == "kubernetes":
                try:
                    import kubernetes

                    _, context = kubernetes.config.list_kube_config_contexts()
                    if "namespace" in context["context"]:
                        namespace = context["context"]["namespace"]
                    else:
                        namespace = "default"
                except Exception:
                    # Best effort: missing package or kube config falls back to the default
                    # namespace (without swallowing KeyboardInterrupt/SystemExit).
                    namespace = "default"
                pod_name = f"{handle.cluster_name_on_cloud}-head"

                self.launched_properties["namespace"] = namespace
                self.launched_properties["pod_name"] = pod_name
        else:
            self.address = None
            self._creds = None
            self.stable_internal_external_ips = None
            self.launched_properties = {}

    def _update_from_sky_status(self, dryrun: bool = False):
        """Refresh connection state from the local Sky DB (no-op for shared clusters).

        Args:
            dryrun (bool): If True, read the Sky status without refreshing it from the cloud.
        """
        # Try to get the cluster status from SkyDB
        if self.is_shared:
            # If the cluster is shared can ignore, since the sky data will only be saved on the machine where
            # the cluster was initially upped
            return
        cluster_dict = self._sky_status(refresh=not dryrun)
        self._populate_connection_from_status_dict(cluster_dict)

    def get_instance_type(self):
        """Returns instance type of the cluster.

        ``None`` when ``instance_type`` holds an accelerator or CPU spec (contains ``:`` or
        ``CPU``), unless it uses the Kubernetes ``--`` syntax.
        """
        if self.instance_type and "--" in self.instance_type:  # K8s specific syntax
            return self.instance_type
        elif (
            self.instance_type
            and ":" not in self.instance_type
            and "CPU" not in self.instance_type
        ):
            return self.instance_type

        return None

    def accelerators(self):
        """Returns the accelerator type, or None if is a CPU."""
        if (
            self.instance_type
            and ":" in self.instance_type
            and "CPU" not in self.instance_type
        ):
            return self.instance_type

        return None

    def num_cpus(self):
        """Return the number of CPUs (as a string) for a CPU cluster spec like ``"CPU:4"``, else None."""
        if (
            self.instance_type
            and ":" in self.instance_type
            and "CPU" in self.instance_type
        ):
            return self.instance_type.rsplit(":", 1)[1]

        return None

    async def a_up(self, capture_output: Union[bool, str] = True):
        """Up the cluster async in another process, so it can be parallelized and logs can be captured sanely.

        Args:
            capture_output (bool or str): If True, suppress the output of the cluster creation process.
                If False, print the output normally. If a string, write the output to the file at that path.
        """
        with ProcessPoolExecutor() as executor:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                executor, up_cluster_helper, self, capture_output
            )
        return self

    async def a_up_if_not(self, capture_output: Union[bool, str] = True):
        """Async :meth:`a_up`, only if the cluster is not already up."""
        if not self.is_up():
            # Don't store stale IPs
            self.ips = None
            await self.a_up(capture_output=capture_output)
        return self

    def up(self):
        """Up the cluster.

        Launches via ``sky.launch`` (autostop set to ``autostop_mins`` with ``down=True``), refreshes
        connection info from the Sky DB, restarts the server, and saves if autosave is enabled.

        Raises:
            ValueError: If ``provider`` is not ``"cheapest"`` or a SkyPilot-registered cloud.
            TypeError: If an argument is passed both in ``sky_kwargs`` and explicitly.

        Example:
            >>> rh.ondemand_cluster("rh-cpu").up()
        """
        if self.on_this_cluster():
            return self

        supported_providers = ["cheapest"] + list(sky.clouds.CLOUD_REGISTRY)
        if self.provider not in supported_providers:
            raise ValueError(
                f"Cluster provider {self.provider} not supported. "
                f"Must be one of {supported_providers} supported by SkyPilot."
            )

        task = sky.Task(num_nodes=self.num_instances)
        cloud_provider = (
            sky.clouds.CLOUD_REGISTRY.from_str(self.provider)
            if self.provider != "cheapest"
            else None
        )

        try:
            task.set_resources(
                sky.Resources(
                    # TODO: confirm if passing instance type in old way (without --) works when provider is k8s
                    cloud=cloud_provider,
                    instance_type=self.get_instance_type(),
                    accelerators=self.accelerators(),
                    cpus=self.num_cpus(),
                    memory=self.memory,
                    region=self.region or configs.get("default_region"),
                    disk_size=self.disk_size,
                    ports=self.open_ports,
                    image_id=self.image_id,
                    use_spot=self.use_spot,
                    **self.sky_kwargs.get("resources", {}),
                )
            )
            if self.image_id:
                import os

                # Forward any docker registry login env vars set locally to the launch task.
                docker_env_vars = {
                    env_var: os.getenv(env_var)
                    for env_var in DOCKER_LOGIN_ENV_VARS
                    if os.getenv(env_var)
                }
                if docker_env_vars:
                    task.update_envs(docker_env_vars)
            sky.launch(
                task,
                cluster_name=self.name,
                idle_minutes_to_autostop=self._autostop_mins,
                down=True,
                **self.sky_kwargs.get("launch", {}),
            )
        # Make sure no args are passed both in sky_kwargs and as explicit args
        except TypeError as e:
            if "got multiple values for keyword argument" in str(e):
                raise TypeError(
                    f"{str(e)}. If argument is in `sky_kwargs`, it may need to be passed directly through the "
                    f"ondemand_cluster constructor (see `ondemand_cluster docs "
                    f"<https://www.run.house/docs/api/python/cluster#runhouse.ondemand_cluster>`__)."
                )
            raise e

        self._update_from_sky_status()

        if self.domain:
            logger.info(
                f"Cluster has been launched with the custom domain '{self.domain}'. "
                "Please add an A record to your DNS provider to point this domain to the cluster's "
                f"public IP address ({self.address}) to ensure successful requests."
            )

        self.restart_server()

        if rns_client.autosave_resources():
            self.save()

        return self

    def keep_warm(self, mins: int = -1):
        """Keep the cluster warm for given number of minutes after inactivity.

        Args:
            mins (int): Amount of time (in min) to keep the cluster warm after inactivity.
                If set to -1, keep cluster warm indefinitely. (Default: `-1`)
        """
        self.autostop_mins = mins
        return self

    def teardown(self):
        """Teardown cluster.

        First reports a ``terminated`` status to Den on a best-effort basis (failures are only
        logged), then tears the cluster down with ``sky.down`` and clears ``address``.

        Example:
            >>> rh.ondemand_cluster("rh-cpu").teardown()
        """
        try:
            cluster_status_data = self.status()
            status_data = {
                "status": ResourceServerStatus.terminated,
                "resource_type": self.__class__.__base__.__name__.lower(),
                "data": cluster_status_data,
            }
            cluster_uri = rns_client.format_rns_address(self.rns_address)
            api_server_url = rns_client.api_server_url
            status_resp = requests.post(
                f"{api_server_url}/resource/{cluster_uri}/cluster/status",
                data=json.dumps(status_data),
                headers=rns_client.request_headers(),
            )
            # 404 means that the cluster is not saved in den, it is fine that the status is not updated.
            if status_resp.status_code not in [200, 404]:
                logger.warning("Failed to update Den with terminated cluster status")
        except Exception as e:
            logger.warning(e)

        # Tear down through SkyPilot
        sky.down(self.name)
        self.address = None

    def teardown_and_delete(self):
        """Teardown cluster and delete it from configs.

        Example:
            >>> rh.ondemand_cluster("rh-cpu").teardown_and_delete()
        """
        self.teardown()
        rns_client.delete_configs(resource=self)

    @contextlib.contextmanager
    def pause_autostop(self):
        """Context manager to temporarily pause autostop.

        Autostop is restored to ``autostop_mins`` (with ``down=True``) on exit, even if the
        body of the ``with`` block raises.

        Example:
            >>> cluster = rh.ondemand_cluster("rh-cpu")
            >>> with cluster.pause_autostop():
            >>>     cluster.run(["python train.py"])
        """
        sky.autostop(self.name, idle_minutes=-1)
        try:
            yield
        finally:
            sky.autostop(self.name, idle_minutes=self._autostop_mins, down=True)

    # ----------------- SSH Methods ----------------- #

    @staticmethod
    def cluster_ssh_key(path_to_file: Path):
        """Retrieve SSH key for the cluster.

        Args:
            path_to_file (Path): Path of the private key associated with the cluster. ``~`` is expanded.

        Returns:
            str: Contents of the private key file.

        Raises:
            Exception: If the key file does not exist.

        Example:
            >>> ssh_priv_key = rh.ondemand_cluster("rh-cpu").cluster_ssh_key("~/.ssh/id_rsa")
        """
        key_path = Path(path_to_file).expanduser()
        try:
            with open(key_path, "r") as f:
                return f.read()
        except FileNotFoundError as e:
            raise Exception(f"File with ssh key not found in: {path_to_file}") from e

    def ssh(self, node: str = None):
        """SSH into the cluster.

        Args:
            node: Node to SSH into. If no node is specified, will SSH onto the head node.
                Ignored for Kubernetes clusters. (Default: ``None``)

        Example:
            >>> rh.ondemand_cluster("rh-cpu").ssh()
            >>> rh.ondemand_cluster("rh-cpu").ssh(node="3.89.174.234")
        """
        if self.provider == "kubernetes":
            try:
                # Argument list (no shell) so the cluster name is never shell-interpreted.
                output = subprocess.check_output(
                    ["kubectl", "get", "pods"], text=True
                )
            except subprocess.CalledProcessError as e:
                raise Exception(f"Error: {e}") from e

            matching_pods = [
                line for line in output.splitlines() if self.name in line
            ]
            if not matching_pods:
                logger.info("No matching pods found.")
                return

            pod_name = matching_pods[0].split()[0]
            subprocess.run(
                ["kubectl", "exec", "-it", pod_name, "--", "/bin/bash"], check=True
            )
        else:
            # If SSHing onto a specific node, which requires the default sky public key for verification
            from runhouse.resources.hardware.sky_command_runner import SshMode

            sky_key = Path(
                self.creds_values.get("ssh_private_key", self.DEFAULT_KEYFILE)
            ).expanduser()

            if not sky_key.exists():
                raise FileNotFoundError(f"Expected default sky key in path: {sky_key}")

            runner = self._command_runner(node=node)
            cmd = runner.run(
                cmd="bash --rcfile <(echo '. ~/.bashrc; conda deactivate')",
                ssh_mode=SshMode.INTERACTIVE,
                port_forward=None,
                return_cmd=True,
            )
            subprocess.run(cmd, shell=True)

    def _ping(self, timeout=5, retry=False):
        """Ping the cluster server.

        If the first ping fails and ``retry`` is True, refresh connection info from the Sky DB
        and ping exactly once more.
        """
        if super()._ping(timeout=timeout, retry=False):
            return True
        if retry:
            self._update_from_sky_status(dryrun=False)
            return super()._ping(timeout=timeout, retry=False)
        return False