Source code for runhouse.resources.secrets.secret

import copy
import json
import logging
import os
from pathlib import Path

from typing import Dict, List, Optional, Union

from runhouse.globals import configs, rns_client
from runhouse.resources.hardware import _get_cluster_from, Cluster
from runhouse.resources.resource import Resource
from runhouse.resources.secrets.utils import _delete_vault_secrets, load_config
from runhouse.rns.utils.api import load_resp_content, read_resp_data
from runhouse.rns.utils.names import _generate_default_name

logger = logging.getLogger(__name__)

[docs]class Secret(Resource): RESOURCE_TYPE = "secret" USER_ENDPOINT = "user/secret" GROUP_ENDPOINT = "group/secret" DEFAULT_DIR = "~/.rh/secrets"
[docs] def __init__( self, name: Optional[str], values: Dict = None, dryrun: bool = False, **kwargs, ): """ Runhouse Secret object. .. note:: To create a Secret, please use one of the factory methods. """ super().__init__(name=name, dryrun=dryrun) self._values = values
@property def values(self): return self._values def config(self, condensed=True): config = super().config(condensed) if self._values: config.update( { "values": self._values, } ) return config @staticmethod def _write_shared_secret_to_local(config): import runhouse as rh new_creds_values = config["values"] folder_name = config["name"].replace("/", "_") path = f"{Secret.DEFAULT_DIR}/{folder_name}" private_key_value, public_key_value = new_creds_values.get( "private_key" ), new_creds_values.get("public_key") private_key_path = public_key_path = Path(f"{path}").expanduser() if private_key_value: if not private_key_path.exists(): os.makedirs(str(private_key_path)) private_file_path = private_key_path / "ssh-key" with open(str(private_file_path), "w") as f: f.write(private_key_value) private_file_path.chmod(0o600) if public_key_value: public_file_path = public_key_path / "" with open(str(public_key_path / ""), "w") as f: f.write(public_key_value) public_file_path.chmod(0o600) if private_key_value and public_key_value: new_creds_values = { "ssh_private_key": str(private_key_path / "ssh-key"), "ssh_public_key": str(public_key_path / ""), } if private_key_value and new_creds_values.get("ssh_user"): new_creds_values = { "ssh_private_key": str(private_key_path / "ssh-key"), "ssh_user": new_creds_values.get("ssh_user"), } return rh.secret( values=new_creds_values, name=f"loaded_secret_{config['name']}" )
[docs] @staticmethod def from_config(config: dict, dryrun: bool = False, _resolve_children=True): """Create a Secret object from a config dictionary.""" if "provider" in config: from runhouse.resources.secrets.provider_secrets.providers import ( _get_provider_class, ) provider_class = _get_provider_class(config["provider"]) return provider_class.from_config(config, dryrun=dryrun) # checks if the config is a of a shared secret current_user = configs.username owner_user = config["owner"]["username"] if "owner" in config.keys() else None if owner_user and current_user != owner_user and config["values"]: return Secret._write_shared_secret_to_local(config) return Secret(**config, dryrun=dryrun)
[docs] @classmethod def from_name(cls, name, dryrun=False): """Load existing Secret via its name.""" try: config = load_config(name, cls.USER_ENDPOINT) if config: return cls.from_config(config=config, dryrun=dryrun) except ValueError: pass if name in cls.builtin_providers(as_str=True): from runhouse.resources.secrets.provider_secrets.providers import ( _get_provider_class, ) provider_class = _get_provider_class(name) return provider_class(provider=name, dryrun=dryrun) raise ValueError(f"Could not locate secret {name}")
[docs] @classmethod def builtin_providers(cls, as_str: bool = False) -> List: """Return list of all Runhouse providers (as class objects) supported out of the box.""" from runhouse.resources.secrets.provider_secrets.providers import ( _str_to_provider_class, ) if as_str: return list(_str_to_provider_class.keys()) return list(_str_to_provider_class.values())
[docs] @classmethod def vault_secrets(cls, headers: Optional[Dict] = None) -> List[str]: """Get secret names that are stored in Vault""" uri = f"{rns_client.api_server_url}/{cls.USER_ENDPOINT}" resp = rns_client.session.get( uri, headers=headers or rns_client.request_headers(), ) if resp.status_code not in [200, 404]: raise Exception( f"Received [{resp.status_code}] from Den GET '{uri}': Failed to download secrets from Vault." ) response = read_resp_data(resp) return list(response.keys())
@classmethod def local_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]: if not os.path.exists(os.path.expanduser("~/.rh/secrets")): return {} all_names = [ Path(file).stem for file in os.listdir(os.path.expanduser("~/.rh/secrets")) if file.endswith("json") ] names = [name for name in names if name in all_names] if names else all_names secrets = {} for name in names: path = os.path.expanduser(f"~/.rh/secrets/{name}.json") try: with open(path, "r") as f: config = json.load(f) if config["name"].startswith("~") or config["name"].startswith("^"): config["name"] = config["name"][2:] secrets[name] = Secret.from_config(config) except json.JSONDecodeError: # Ignore any empty / corrupted files continue return secrets @classmethod def extract_provider_secrets(cls, names: List[str] = None) -> Dict[str, "Secret"]: from runhouse.resources.secrets.provider_secrets.providers import ( _str_to_provider_class, ) from runhouse.resources.secrets.secret_factory import provider_secret secrets = {} names = names or _str_to_provider_class.keys() for provider in names: if provider == "ssh": continue try: secret = provider_secret(provider=provider) secrets[provider] = secret except ValueError: continue # locally configured ssh secrets if "ssh" in names: default_ssh_folder = "~/.ssh" ssh_files = os.listdir(os.path.expanduser(default_ssh_folder)) for file in ssh_files: if file != "sky-key" and f"{file}.pub" in ssh_files: name = f"ssh-{file}" secret = provider_secret( provider="ssh", name=name, path=os.path.join(default_ssh_folder, file), ) secrets[name] = secret return secrets # TODO: refactor this code to reuse rns_client save_config code instead of rewriting
[docs] def save( self, name: str = None, save_values: bool = True, headers: Optional[Dict] = None, folder: str = None, ): """ Save the secret config to Den. Save the secret values into Vault if the user is logged in, or to local if not or if the resource is a local resource. If a folder is specified, save the secret to that folder in Den (e.g. saving secrets for a cluster associated with an organization). """ if name: = name elif not raise ValueError("A resource must have a name to be saved.") self._rns_folder = folder or self._rns_folder or rns_client.current_folder config = self.config() config["name"] = self.rns_address if "values" in config: # don't save values into Den config del config["values"] headers = headers or rns_client.request_headers() # Save metadata to Den if self.rns_address.startswith("/"):"Saving config for {} to Den") payload = rns_client.resource_request_payload(config) uri = f"{rns_client.api_server_url}/resource" resp = uri, data=json.dumps(payload), headers=headers, ) # If resource config hasn't changed (i.e. nothing to update) will return a 422 if resp.status_code not in [200, 422]: raise Exception( f"Received [{resp.status_code}] from Den POST '{uri}': Failed to save metadata to Den: {load_resp_content(resp)}" ) if save_values and self.values:"Saving secrets for {} to Vault") resource_uri = rns_client.resource_uri(self.rns_address) uri = f"{rns_client.api_server_url}/{self.USER_ENDPOINT}/{resource_uri}" resp = rns_client.session.put( uri, data=json.dumps( {"name": self.rns_address, "data": {"values": self.values}} ), headers=headers, ) if resp.status_code != 200: raise Exception( f"Received [{resp.status_code}] from Den PUT '{uri}': Failed to put resources in Vault: {load_resp_content(resp)}" ) else: config_path = os.path.expanduser(f"~/.rh/secrets/{}.json") os.makedirs(os.path.dirname(config_path), exist_ok=True) if save_values: config["values"] = self.values with open(config_path, "w") as f: json.dump(config, f, indent=4)"Saving config for {self.rns_address} to: {config_path}") return self
[docs] def delete(self, headers: Optional[Dict] = None): """Delete the secret config from Den and from Vault/local.""" if not (self.in_vault() or self.in_local()): logger.warning( "Can not delete a secret that has not been saved down to Vault or local." ) else: if self.rns_address and self.rns_address.startswith("/"): self._delete_secret_configs(headers) else: self._delete_local_config() configs.delete_provider(
def _delete_local_config(self): config_path = os.path.expanduser(f"~/.rh/secrets/{}.json") if os.path.exists(config_path): os.remove(config_path) def _delete_secret_configs(self, headers: Optional[Dict] = None): headers = headers or rns_client.request_headers() # Delete secrets in Vault resource_uri = rns_client.resource_uri(self.rns_address) _delete_vault_secrets(resource_uri, self.USER_ENDPOINT, headers=headers) # Delete RNS data for resource uri = f"{rns_client.api_server_url}/resource/{resource_uri}" resp = rns_client.session.delete( uri, headers=headers, ) if resp.status_code != 200: logger.error( f"Received [{resp.status_code}] from Den DELETE '{uri}': Failed to delete secret resource from Den: {load_resp_content(resp)}" )
[docs] def to( self, system: Union[str, Cluster], name: Optional[str] = None, env: Optional["Env"] = None, ): """Return a copy of the secret on a system. Args: system (str or Cluster): Cluster to send the secret to name (str, ooptional): Name to assign the resource on the cluster. Example: >>>, path=secret.path) """ new_secret = copy.deepcopy(self) = name or or _generate_default_name(prefix="secret") system = _get_cluster_from(system) if system.on_this_cluster(): else: env = env or system.default_env system.put_resource(new_secret, env=env) return new_secret
[docs] def in_local(self): """Whether the secret config is stored locally (as opposed to Vault).""" path = os.path.expanduser(f"~/.rh/secrets/{}.json") if os.path.exists(os.path.expanduser(path)): return True return False
[docs] def in_vault(self, headers=None): """Whether the secret is stored in Vault""" if not self.rns_address: return False resource_uri = rns_client.resource_uri(self.rns_address) resp = rns_client.session.get( f"{rns_client.api_server_url}/{self.USER_ENDPOINT}/{resource_uri}", headers=headers or rns_client.request_headers(), ) if resp.status_code != 200: return False response = read_resp_data(resp) # TODO: switch this to use once vault updates if response and response[list(response.keys())[0]]: return True return False
def is_present(self): if self.values: return True return False