Source code for latch_cli.centromere.utils

import builtins
import contextlib
import os
import random
import string
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Iterator, List, Optional

import docker
import paramiko
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from import module_loader

from latch_cli.constants import latch_constants

@dataclass
class RemoteConnInfo:
    """Connection details for reaching a remote machine over SSH."""

    # Address passed as the hostname when connecting to the remote machine.
    ip: str
    # Login name used on the remote machine.
    username: str
    # Private key file used to authenticate against the jump host.
    jump_key_path: Path
    # Private key file used to authenticate against the remote machine.
    ssh_key_path: Path
@contextlib.contextmanager
def _add_sys_paths(paths: List[Path]) -> Iterator[None]:
    """Temporarily put ``paths`` at the front of ``sys.path``.

    Each path is inserted at index 0 in turn, so the last path given ends up
    first on ``sys.path``. Every inserted entry is removed again on exit,
    whether or not the body raised.
    """
    str_paths = [os.fspath(p) for p in paths]
    try:
        for p in str_paths:
            sys.path.insert(0, p)
        yield
    finally:
        for p in str_paths:
            # The body may have edited sys.path itself; an entry that is
            # already gone must not raise here and mask the body's own error.
            with contextlib.suppress(ValueError):
                sys.path.remove(p)


def _import_flyte_objects(paths: List[Path], module_name: str = "wf"):
    """Import ``module_name`` with ``paths`` importable and return what was loaded.

    While the import runs, ``builtins.__import__`` is replaced by a version
    that returns a permissive stand-in module whenever the real import raises
    ``ModuleNotFoundError`` or ``AttributeError``, and a temporary Flyte
    context created with ``inspect_objects_only=True`` is pushed. Both are
    undone before returning, even if loading fails.

    Args:
        paths: Directories added to the front of ``sys.path`` for the
            duration of the import.
        module_name: Name handed to ``module_loader.iterate_modules``.

    Returns:
        A list of everything ``module_loader.iterate_modules`` yielded
        (presumably the imported module objects -- verify against flytekit).
    """
    with _add_sys_paths(paths):

        class FakeModule(ModuleType):
            def __getattr__(self, key):
                # This value is designed to be used in the execution of top
                # level code without throwing errors but need not do anything.
                # Here are some cases where this value can be used:
                #
                # ```
                # from foo import bar
                #
                # # Referencing an attribute from an attribute
                # a = bar.x
                #
                # # Calling an attribute
                # b = bar()
                #
                # # Subclassing from an attribute
                # class C(bar):
                #     ...
                # ```
                class FakeAttr:
                    def __new__(*args, **kwargs):
                        return None

                return FakeAttr

            __all__ = []

        real_import = builtins.__import__

        def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
            try:
                return real_import(
                    name,
                    globals=globals,
                    locals=locals,
                    fromlist=fromlist,
                    level=level,
                )
            except (ModuleNotFoundError, AttributeError):
                return FakeModule(name)

        # Temporary ctx tells flytekit to skip local execution when
        # inspecting objects
        # NOTE(review): the sandbox directory made by mkdtemp is never removed.
        fap = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="foo"),
            raw_output_prefix="bar",
        )
        tmp_context = FlyteContext(fap, inspect_objects_only=True)

        FlyteContextManager.push_context(tmp_context)
        builtins.__import__ = fake_import
        try:
            mods = list(module_loader.iterate_modules([module_name]))
        finally:
            # Always undo the global patches: a failing workflow import must
            # not leave the interpreter with a fake __import__ or a stale
            # Flyte context.
            builtins.__import__ = real_import
            FlyteContextManager.pop_context()

        return mods


def _construct_dkr_client(ssh_host: Optional[str] = None):
    """Try several ways of establishing a valid connection to Docker.

    - If `ssh_host` is passed, we attempt to make a connection with a remote
      machine.
    - Otherwise, if ``DOCKER_HOST`` is set to a non-empty value, the client is
      configured from the ``DOCKER_*`` environment variables.
    - Otherwise the local socket ``unix://var/run/docker.sock`` is used.

    NOTE(review): the environment handling resembles docker-py's
    ``kwargs_from_env``, except that an unset ``DOCKER_TLS_VERIFY`` enables
    TLS here -- confirm that this is intended.

    Returns:
        A ``docker.APIClient``.

    Raises:
        OSError: If a ``docker.errors.DockerException`` is raised while
            constructing the client.
    """
    environment = os.environ

    def _from_env():
        host = environment.get("DOCKER_HOST")

        # empty string for cert path is the same as unset.
        cert_path = environment.get("DOCKER_CERT_PATH")
        if cert_path == "":
            cert_path = None

        # empty string for tls verify counts as "false".
        # Any value or 'unset' counts as true.
        tls_verify = environment.get("DOCKER_TLS_VERIFY") != ""

        enable_tls = tls_verify or cert_path is not None

        dkr_client = None
        try:
            if not enable_tls:
                dkr_client = docker.APIClient(host)
            else:
                if not cert_path:
                    cert_path = os.path.join(os.path.expanduser("~"), ".docker")

                tls_config = docker.tls.TLSConfig(
                    client_cert=(
                        os.path.join(cert_path, "cert.pem"),
                        os.path.join(cert_path, "key.pem"),
                    ),
                    ca_cert=os.path.join(cert_path, "ca.pem"),
                    verify=tls_verify,
                )
                dkr_client = docker.APIClient(host, tls=tls_config)
        except docker.errors.DockerException as de:
            raise OSError(
                "Unable to establish a connection to Docker. Make sure that"
                " Docker is running and properly configured before attempting"
                " to register a workflow."
            ) from de

        return dkr_client

    if ssh_host is not None:
        try:
            return docker.APIClient(ssh_host, version="1.41")
        except docker.errors.DockerException as de:
            raise OSError(
                f"Unable to establish a connection to remote docker host {ssh_host}."
            ) from de

    if environment.get("DOCKER_HOST"):
        return _from_env()

    try:
        # TODO: platform specific socket defaults
        return docker.APIClient(base_url="unix://var/run/docker.sock")
    except docker.errors.DockerException as de:
        raise OSError(
            "Docker is not running. Make sure that"
            " Docker is running before attempting to register a workflow."
        ) from de


def _construct_ssh_client(
    remote_conn_info: RemoteConnInfo, *, use_gateway: bool = True
) -> paramiko.SSHClient:
    """Open an SSH connection to the machine described by ``remote_conn_info``.

    Args:
        remote_conn_info: Address, username and key paths for the target.
        use_gateway: When true, first connect to ``latch_constants.jump_host``
            with the jump key and tunnel the target connection through a
            ``direct-tcpip`` channel to port 22 of the target. When false,
            connect to the target directly.

    Returns:
        A connected ``paramiko.SSHClient`` with keepalives enabled.

    Raises:
        ConnectionError: If either client reports no transport after
            connecting.
    """
    # Clients created so far; closed again if setup fails part-way so a
    # half-built tunnel is not left open.
    opened = []
    try:
        sock = None
        if use_gateway:
            gateway = paramiko.SSHClient()
            opened.append(gateway)
            gateway.load_system_host_keys()
            gateway.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
            gateway_pkey = paramiko.PKey.from_path(
                path=str(remote_conn_info.jump_key_path.resolve())
            )
            gateway.connect(
                latch_constants.jump_host,
                username=latch_constants.jump_user,
                pkey=gateway_pkey,
            )

            gateway_transport = gateway.get_transport()
            if gateway_transport is None:
                raise ConnectionError("unable to create connection to jump host")

            sock = gateway_transport.open_channel(
                kind="direct-tcpip",
                dest_addr=(remote_conn_info.ip, 22),
                src_addr=("", 0),
            )

        pkey = paramiko.PKey.from_path(
            path=str(remote_conn_info.ssh_key_path.resolve())
        )

        ssh = paramiko.SSHClient()
        opened.append(ssh)
        ssh.load_system_host_keys()
        ssh.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
        ssh.connect(
            remote_conn_info.ip,
            username=remote_conn_info.username,
            sock=sock,
            pkey=pkey,
        )

        transport = ssh.get_transport()
        if transport is None:
            raise ConnectionError(
                "unable to create connection from jump host to centromere deployment"
            )

        # (kenny) Equivalent of OpenSSH configuration `ServerAliveInterval`
        # No analogue for `ServerAliveCountMax` in paramiko I could find.
        transport.set_keepalive(latch_constants.centromere_keepalive_interval)
    except BaseException:
        for client in opened:
            client.close()
        raise

    return ssh
class MaybeRemoteDir:
    """A temporary directory that exists locally or on a remote machine.

    With no ``ssh_client`` the directory is a local
    ``tempfile.TemporaryDirectory``. With an ``ssh_client`` it is a randomly
    named directory under ``/tmp`` on the remote machine, created and removed
    by running shell commands over that client.

    Usable as a context manager: entering returns the directory path as a
    string, exiting removes the directory.
    """

    def __init__(self, ssh_client: Optional[paramiko.SSHClient] = None):
        self.ssh_client = ssh_client
        # Either a TemporaryDirectory (local), a path string (remote), or
        # None when no directory is currently held.
        self._tempdir = None

    def __enter__(self):
        return self.create()

    def __exit__(self, exc_type, exc_value, tb):
        self.cleanup()

    def create(self):
        """Create the directory and return its path as a string."""
        if self.ssh_client is None:
            self._tempdir = tempfile.TemporaryDirectory()
            return str(Path(self._tempdir.name).resolve())

        td = build_random_string()
        self._tempdir = f"/tmp/{td}"
        _, stdout, _ = self.ssh_client.exec_command(f"mkdir {self._tempdir}")
        # exec_command returns as soon as the command is started; wait for
        # mkdir to finish so the path exists before it is handed back.
        stdout.channel.recv_exit_status()
        return self._tempdir

    def cleanup(self):
        """Remove the directory; does nothing if none is currently held."""
        tempdir = self._tempdir
        if tempdir is None:
            return
        self._tempdir = None

        if self.ssh_client is not None:
            # Best effort: the remote removal is started but not awaited.
            self.ssh_client.exec_command(f"rm -rf {tempdir}")
            return

        if not isinstance(tempdir, str):
            tempdir.cleanup()
def build_random_string(k: int = 8) -> str:
    """Return ``k`` characters drawn at random from ASCII letters and digits.

    Characters are picked with replacement using the ``random`` module.
    """
    alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits
    picked = random.choices(alphabet, k=k)
    return "".join(picked)