import configparser
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
import requests
import sky
import typer
import yaml
from runhouse.globals import configs, rns_client
from runhouse.rns.utils.api import load_resp_content, read_resp_data
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
class Secrets:
    """Handles cluster secrets management (reading and writing) across all major cloud providers.
    Secrets are securely stored in Hashicorp Vault."""

    # Overridden by provider subclasses (e.g. an AWS or GCP secrets class).
    PROVIDER_NAME = None
    # Default local credentials file path; overridden by provider subclasses.
    CREDENTIALS_FILE = None

    # Vault API endpoints for user-scoped and group-scoped secrets.
    USER_ENDPOINT = "user/secret"
    GROUP_ENDPOINT = "group/secret"

    def __str__(self):
        # Display the concrete subclass name (e.g. "AWSSecrets").
        return type(self).__name__
@classmethod
def read_secrets(
cls, from_env: bool = False, file_path: Optional[str] = None
) -> Dict:
raise NotImplementedError
[docs] @classmethod
def save_secrets(cls, secrets: Dict, overwrite: bool = False) -> Dict:
"""Save secrets for providers to their respective configs."""
raise NotImplementedError
@classmethod
def has_secrets_file(cls) -> bool:
file_path = cls.default_credentials_path()
if not file_path:
return False
return cls.file_exists(file_path)
    @classmethod
    def default_credentials_path(cls):
        """Return the provider's default local credentials file path.

        The base class sets ``CREDENTIALS_FILE`` to None; provider subclasses
        are expected to override it with their own default path.
        """
        return cls.CREDENTIALS_FILE
[docs] @classmethod
def extract_and_upload(
cls,
headers: Optional[Dict] = None,
interactive=False,
providers: Optional[List[str]] = None,
):
"""Upload all locally configured secrets into Vault. Secrets are loaded from their local config files.
(ex: ``~/.aws/credentials)``. To upload custom secrets for custom providers, see Secrets.put()
Example:
>>> rh.Secrets.extract_and_upload(providers=["aws", "lambda"])
"""
secrets: dict = cls.load_provider_secrets(providers=providers)
for provider_name, provider_secrets in secrets.items():
if interactive:
upload_secrets = typer.confirm(f"Upload secrets for {provider_name}?")
if not upload_secrets:
secrets.pop(provider_name, None)
resp = requests.put(
f"{rns_client.api_server_url}/{cls.USER_ENDPOINT}",
data=json.dumps(secrets),
headers=headers or rns_client.request_headers,
)
if resp.status_code != 200:
raise Exception(
f"Failed to update secrets in Vault: {load_resp_content(resp)}"
)
logging.info(f"Uploaded secrets for to Vault for: {list(secrets)}")
[docs] @classmethod
def download_into_env(
cls,
save_locally: bool = True,
providers: Optional[List[str]] = None,
headers: Optional[Dict] = None,
check_enabled: bool = True,
) -> Dict:
"""Get all user secrets from Vault. Optionally save them down to local config files (where relevant).
Example:
>>> rh.Secrets.download_into_env(providers=["aws", "lambda"])
"""
logger.info("Getting secrets from Vault.")
resp = requests.get(
f"{rns_client.api_server_url}/{cls.USER_ENDPOINT}",
headers=headers or rns_client.request_headers,
)
if resp.status_code != 200:
raise Exception("Failed to download secrets from Vault")
secrets = read_resp_data(resp)
if providers is not None:
secrets = {p: secrets[p] for p in providers if p in secrets}
if save_locally and secrets:
cls.save_provider_secrets(secrets, check_enabled=check_enabled)
logger.info("Saved secrets from Vault to local config files")
else:
return secrets
[docs] @classmethod
def put(
cls,
provider: str,
from_env: bool = False,
file_path: Optional[str] = None,
secret: Optional[dict] = None,
group: Optional[str] = None,
headers: Optional[dict] = None,
):
"""Upload locally configured secrets for a specified provider into Vault.
To upload custom provider secrets, include the secret param and specify the keys and values to upload.
Args:
from_env (bool): Whether to read secrets from environment variables instead of local config files.
(Default: False)
file_path (str or None): If provided, will read secrets directly from specified file instead of default
config file.
secret (dict or None): Dict mapping provider secrets to value, if not loading from env or file.
group (str or None): If provided, will attribute secrets to the specified group.
Example:
>>> rh.put(provider="lambda", secret={"api_key": *****})
>>> rh.put(provider="aws", file_path="~/.aws/credentials")
"""
provider_name = provider.lower()
if not secret and provider_name in cls.enabled_providers(as_str=True):
# if a supported cloud provider is given and no secret is provided, extract it from its default location
p = cls.builtin_provider_class_from_name(provider_name)
if p is not None:
secret = p.read_secrets(from_env=from_env, file_path=file_path)
if not secret and not isinstance(secret, dict):
raise Exception(
f"No secrets dict found or provided for {provider}. Please make sure the credentials "
f"file exists in its default location, or provide credentials with the `secret` param"
)
endpoint = cls.set_endpoint(group)
resp = requests.put(
f"{rns_client.api_server_url}/{endpoint}/{provider}",
data=json.dumps(secret),
headers=headers or rns_client.request_headers,
)
if resp.status_code != 200:
raise Exception(f"Failed to update {provider} secrets in Vault")
[docs] @classmethod
def get(
cls, provider: str, save_to_env: bool = False, group: Optional[str] = None
) -> dict:
"""Read secrets from the Vault service for a given provider and optionally save them to their local config.
If group is provided will read secrets for the specified group.
Example:
>>> rh.get(provider="lambda")
>>> # returns {"api_key": *****}
"""
provider_name = provider.lower()
endpoint = cls.set_endpoint(group)
url = f"{rns_client.api_server_url}/{endpoint}/{provider_name}"
resp = requests.get(url, headers=rns_client.request_headers)
if resp.status_code != 200:
raise Exception(
f"Failed to get secrets from Vault for {provider_name}: {load_resp_content(resp)}"
)
secrets = read_resp_data(resp).get(provider_name, {})
if not secrets:
logger.info(f"No secrets found in Vault for {provider_name}")
return {}
p = cls.builtin_provider_class_from_name(provider_name)
if save_to_env and p is not None:
logger.info(f"Saving secrets for {provider_name} to local config")
p.save_secrets(secrets)
return secrets
[docs] @classmethod
def to(cls, system: Union[str, "Cluster"], providers: Optional[List[str]] = None):
"""Copy secrets to the desired cluster for a list of builtin providers.
Args:
system (str or Cluster): Cluster to send secrets to.
providers (List[str] or None): Providers to send secrets for. If no providers are specified,
will load all builtin providers that are already enabled.
Example:
>>> rh.Secrets.to(my_cluster, providers=["aws", "lambda"])
"""
if isinstance(system, str):
from runhouse import Cluster
system = Cluster.from_name(name=system)
cluster_name = system.name
if not system.is_up():
raise RuntimeError(
f"Cluster {cluster_name} is not up. Run `cluster_obj.up()` to re-up the cluster."
)
enabled_providers: list = cls.enabled_providers(as_str=True)
# Extract secrets from default paths
configured_secrets: dict = cls.load_provider_secrets(providers=providers)
if not configured_secrets or len(configured_secrets) < len(providers):
# If no secrets found in the enabled providers' credentials files check if they exist in Vault
missing_providers = list(set(providers) - set(list(configured_secrets)))
secrets_for_missing_providers: dict = cls.download_into_env(
save_locally=False, providers=missing_providers
)
# Add the missing provider secrets from Vault to the configured secrets
configured_secrets.update(secrets_for_missing_providers)
# Confirm all enabled providers are either configured locally or have secrets stored in Vault
if len(configured_secrets) < len(enabled_providers):
raise Exception(
f"Failed to find secrets locally or in Vault for providers: {missing_providers}. "
f"For enabling locally save the secrets to the provider's default credentials file, "
f"or upload the secrets directly to Vault (e.g: `rh.Secrets.put({missing_providers[0]})`)"
)
# Send provider secrets over RPC to the cluster, then save each provider's secrets into their default
# file paths on the cluster
failed_to_add_secrets: dict = system.add_secrets(configured_secrets)
if len(failed_to_add_secrets) == len(configured_secrets):
raise RuntimeError(
f"Failed to copy all secrets onto the {cluster_name} cluster: {failed_to_add_secrets}"
)
elif failed_to_add_secrets:
logger.warning(
f"Failed to copy some secrets onto the {cluster_name} cluster: {failed_to_add_secrets}"
)
else:
logger.info(f"Finished copying secrets onto the {cluster_name} cluster")
[docs] @classmethod
def update(cls, provider: str, secrets: dict):
"""Add new keys to existing secrets saved for a given provider in Vault.
Example:
>>> rh.Secrets.update(provider="lambda", secrets={"api_key": new_api_key})
"""
existing_secrets = cls.get(provider=provider)
if existing_secrets:
existing_secrets.update(secrets)
cls.put(provider, secret=existing_secrets)
[docs] @classmethod
def delete_from_local_env(cls, providers: Optional[List[str]] = None):
"""Delete secrets credential files and use in Runhouse configs for list of specified providers.
If none are provided, will delete secrets for all providers which have been enabled in the local environment.
Example:
>>> rh.Secrets.delete_from_local_env(provider=["lambda"])
"""
providers = providers or cls.enabled_providers(as_str=True)
for provider in providers:
p = cls.builtin_provider_class_from_name(provider)
if p is not None:
# Use the default credentials path defined in the builtin provider's class
creds_file_path = p.default_credentials_path()
else:
# See if we have the provider's path saved in the rh config
creds_file_path = configs.get("secrets", {}).get(provider)
if not Path(creds_file_path).exists():
creds_file_path = None
configs.delete(provider)
# Delete the local creds file
if creds_file_path is None:
logger.warning(
f"Unable to delete credentials file for {provider}. Please delete the file manually."
)
else:
logger.info(
f"Deleted {provider} credentials file from path: {creds_file_path}"
)
cls.delete_secrets_file(creds_file_path)
[docs] @classmethod
def delete_from_vault(cls, providers: Optional[List[str]] = None):
"""Delete secrets from Vault for specified providers.
Args:
providers (List[str] or None): Providers to delete from vault. If not set,
will delete secrets for all providers which have been enabled in the local environment.
Example:
>>> rh.Secrets.delete_from_vault()
"""
providers = providers or cls.enabled_providers(as_str=True)
for provider in providers:
url = f"{rns_client.api_server_url}/{cls.USER_ENDPOINT}/{provider}"
resp = requests.delete(url, headers=rns_client.request_headers)
if resp.status_code != 200:
logger.error(
f"Failed to delete secrets from Vault: {load_resp_content(resp)}"
)
[docs] @classmethod
def load_provider_secrets(
cls, from_env: bool = False, providers: Optional[List] = None
) -> Dict[str, Dict]:
"""Load secret credentials for all the providers which have been configured locally, or optionally
provide a list of specific providers to load. Returns a dictionary with provider name as the key and
secrets dictionary as value.
Example:
>>> rh.Secrets.load_provider_secrets(providers=["aws"])
"""
secrets = {}
providers = providers or cls.enabled_providers()
for provider in providers:
if isinstance(provider, str):
provider = cls.builtin_provider_class_from_name(provider)
if not provider:
continue
if not from_env and not provider.has_secrets_file():
# no secrets file configured for this provider
continue
provider_secrets = provider.read_secrets(from_env=from_env)
if provider_secrets:
secrets[provider.PROVIDER_NAME] = provider_secrets
return secrets
[docs] @classmethod
def save_provider_secrets(cls, secrets: dict, check_enabled=True):
"""Save secrets for each provider to their respective local configs.
Example:
>>> rh.Secrets.save_provider_secrets(secrets={"lambda": {"api_key": ******}})
"""
for provider_name, provider_secrets in secrets.items():
provider_cls = cls.builtin_provider_class_from_name(provider_name)
if provider_cls is not None:
try:
provider_cls.save_secrets(provider_secrets, overwrite=True)
except Exception as e:
logger.error(
f"Failed to save {provider_name} secrets to config: {e}"
)
continue
if check_enabled:
enabled_providers = cls.enabled_providers(as_str=True)
not_enabled = [
p
for p in secrets.keys()
if p not in enabled_providers
and p in cls.builtin_providers(as_str=True)
]
if not_enabled:
logger.warning(
f"Received secrets {not_enabled} which Runhouse did not auto-detect as configured. "
f"For cloud providers, you may want to run `sky check` to double check that they're "
f"enabled and to see instructions on how to enable them."
)
    @classmethod
    def enabled_providers(cls, as_str: bool = False) -> List:
        """Returns a list of cloud provider classes which Runhouse supports out of the box.
        If as_str is True, return the names of the providers as strings.

        Example:
            >>> rh.Secrets.enabled_providers(as_str=True)
        """
        # Ask sky which cloud credentials are currently valid on this machine.
        sky.check.check(quiet=True)
        clouds = sky.global_user_state.get_enabled_clouds()
        cloud_names = [str(c).lower() for c in clouds]
        if "local" in cloud_names:
            cloud_names.remove("local")
        # NOTE(review): "sky" appears to be appended unconditionally (sky SSH keys) — confirm intended.
        cloud_names.append("sky")
        try:
            import huggingface_hub  # noqa

            from .huggingface_secrets import HuggingFaceSecrets

            # Only count huggingface as enabled if a token can actually be read.
            if HuggingFaceSecrets.read_secrets():
                cloud_names.append("huggingface")
        except ModuleNotFoundError:
            # huggingface_hub not installed — simply not an enabled provider.
            pass
        # Add any SSH keys + GitHub token that were explicitly added
        config_secrets = configs.get("secrets", {})
        if config_secrets.get("ssh"):
            cloud_names.append("ssh")
        if config_secrets.get("github"):
            cloud_names.append("github")
        if as_str:
            return cloud_names
        # Map names to builtin provider classes (None entries possible for unknown names).
        return [cls.builtin_provider_class_from_name(c) for c in cloud_names]
[docs] @classmethod
def builtin_providers(cls, as_str: bool = False) -> list:
"""Return list of all Runhouse providers (as class objects) supported out of the box."""
from runhouse.rns.secrets.providers import Providers
if as_str:
return [e.name.lower() for e in Providers]
return [e.value for e in Providers]
@classmethod
def _check_secrets_for_mismatches(
cls, secrets_to_save: dict, secrets_path: str, overwrite: bool
):
"""When overwrite is set to `False` and a secrets file already exists, check if new secrets clash with
what may have already been saved."""
if overwrite or not cls.has_secrets_file():
# If explicitly overwriting or the secrets file does not exist we can ignore
return
existing_secrets: dict = cls.read_secrets(file_path=secrets_path)
provider = existing_secrets.pop("provider", None)
for existing_key, existing_val in existing_secrets.items():
new_val = secrets_to_save.get(existing_key)
if existing_key != new_val:
raise ValueError(
f"Mismatch in {provider} secrets for key `{existing_key}`! Secrets in config file {secrets_path} "
f"do not match those provided. If you intend to overwrite a particular secret key, "
f"please do so manually."
)
[docs] @classmethod
def delete_secrets_file(cls, file_path: Union[str, tuple] = None):
"""Delete local credentials file. If no path is provided will use the default path set for the provider.
Example:
>>> rh.Secrets.delete_secrets_file()
>>> rh.Secrets.delete_secrets_file("~/.aws/credentials")
"""
file_path = file_path or cls.default_credentials_path()
if isinstance(file_path, str):
Path(file_path).unlink(missing_ok=True)
if isinstance(file_path, tuple):
for f in file_path:
Secrets.delete_secrets_file(file_path=f)
@classmethod
def _add_provider_to_rh_config(cls, secrets_for_config: Optional[dict] = None):
"""Save the loaded provider config path to the runhouse config saved in the file system."""
config_secrets = secrets_for_config or {
cls.PROVIDER_NAME: cls.default_credentials_path()
}
configs.set_nested(key="secrets", value=config_secrets)
@classmethod
def set_endpoint(cls, group: Optional[str] = None):
return (
f"{cls.GROUP_ENDPOINT}/{group}" if group is not None else cls.USER_ENDPOINT
)
@staticmethod
def save_to_config_file(parser, file_path: str):
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w+") as f:
parser.write(f)
@staticmethod
def save_to_json_file(data: dict, file_path: str):
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w+") as f:
json.dump(data, f, indent=4)
@staticmethod
def read_json_file(file_path: str) -> Dict:
with open(file_path, "r") as config_file:
config_data = json.load(config_file)
return config_data
@staticmethod
def read_config_file(file_path: str):
config = configparser.ConfigParser()
config.read(file_path)
return config
@staticmethod
def read_yaml_file(file_path: str):
with open(file_path, "r") as stream:
config = yaml.safe_load(stream)
return config
@staticmethod
def save_to_yaml_file(data, file_path):
with open(file_path, "w") as yaml_file:
yaml.dump(data, yaml_file, default_flow_style=False)
@staticmethod
def builtin_provider_class_from_name(name: str):
try:
from runhouse.rns.secrets.providers import Providers
return Providers[name.upper()].value
except:
# could be a custom provider, in which case there is no built-in class
return None
@staticmethod
def file_exists(file_path: str) -> bool:
if not Path(file_path).exists():
return False
return True
# TODO AWS secrets (use https://github.com/99designs/aws-vault ?)
# TODO Azure secrets
# TODO GCP secrets
# TODO custom vault secrets