Source code for latch_cli.services.execute

import json
import os
import select
import sys
import tempfile
import termios
import textwrap
from pathlib import Path
from tty import setraw
from typing import Tuple

import kubernetes
import requests
import websocket
from kubernetes.client.api import core_v1_api
from kubernetes.stream import stream
from latch_sdk_config.latch import config

from latch_cli.utils import account_id_from_token, current_workspace, retrieve_or_login


def _construct_kubeconfig(
    cert_auth_data: str,
    cluster_endpoint: str,
    account_id: str,
    access_key: str,
    secret_key: str,
    session_token: str,
) -> str:
    open_brack = "{"
    close_brack = "}"
    region_code = "us-west-2"
    cluster_name = "prion-prod"

    return textwrap.dedent(f"""apiVersion: v1
clusters:
- cluster:
    certificate-authority-data: {cert_auth_data}
    server: {cluster_endpoint}
  name: arn:aws:eks:{region_code}:{account_id}:cluster/{cluster_name}
contexts:
- context:
    cluster: arn:aws:eks:{region_code}:{account_id}:cluster/{cluster_name}
    user: arn:aws:eks:{region_code}:{account_id}:cluster/{cluster_name}
  name: arn:aws:eks:{region_code}:{account_id}:cluster/{cluster_name}
current-context: arn:aws:eks:{region_code}:{account_id}:cluster/{cluster_name}
kind: Config
preferences: {open_brack}{close_brack}
users:
- name: arn:aws:eks:{region_code}:{account_id}:cluster/{cluster_name}
  user:
    exec:
      apiVersion: client.authentication.k8s.io/v1beta1
      command: aws
      args:
        - --region
        - {region_code}
        - eks
        - get-token
        - --cluster-name
        - {cluster_name}
      env:
        - name: 'AWS_ACCESS_KEY_ID'
          value: '{access_key}'
        - name: 'AWS_SECRET_ACCESS_KEY'
          value: '{secret_key}'
        - name: 'AWS_SESSION_TOKEN'
          value: '{session_token}'""")


def _fetch_pod_info(
    token: str, task_name: str
) -> Tuple[str, str, str, str, str, str, str]:
    """Request temporary credentials and cluster details for a running task.

    POSTs ``task_name`` and the current workspace id to the
    ``config.api.execution.exec`` endpoint, authenticating with ``token`` as
    a bearer token.

    Args:
        token: Latch auth token.
        task_name: Name of the running task.

    Returns:
        ``(access_key, secret_key, session_token, cert_auth_data,
        cluster_endpoint, namespace, aws_account_id)``, read from the
        response fields ``tmp_access_key``, ``tmp_secret_key``,
        ``tmp_session_token``, ``cert_auth_data``, ``cluster_endpoint``,
        ``namespace`` and ``aws_account_id``.

    Raises:
        ValueError: If the response body is not JSON, is not an object, or
            lacks one of the expected fields. The temporary AWS credentials
            are redacted from the error message.
    """
    headers = {"Authorization": f"Bearer {token}"}
    data = {"task_name": task_name, "ws_account_id": current_workspace()}

    response = requests.post(config.api.execution.exec, headers=headers, json=data)

    try:
        payload = response.json()
    except ValueError as err:  # requests' JSON decode errors subclass ValueError
        raise ValueError(
            f"malformed response (HTTP {response.status_code}): body is not JSON"
        ) from err

    try:
        return (
            payload["tmp_access_key"],
            payload["tmp_secret_key"],
            payload["tmp_session_token"],
            payload["cert_auth_data"],
            payload["cluster_endpoint"],
            payload["namespace"],
            payload["aws_account_id"],
        )
    except (KeyError, TypeError) as err:
        # Never echo the temporary AWS credentials into an exception message,
        # because tracebacks end up in terminals, logs and bug reports.
        secret_fields = {"tmp_access_key", "tmp_secret_key", "tmp_session_token"}
        if isinstance(payload, dict):
            shown = {
                key: "<redacted>" if key in secret_fields else value
                for key, value in payload.items()
            }
        else:
            shown = payload
        raise ValueError(
            f"malformed response (HTTP {response.status_code}): {shown}"
        ) from err


def _load_kubeconfig(config_data: str) -> None:
    """Load ``config_data`` as the kubernetes client's active configuration.

    The kubeconfig embeds temporary AWS credentials. It is therefore written
    to a private temporary file: ``tempfile.mkstemp`` creates it readable and
    writable only by the current user. Writing to ``./config`` in the working
    directory instead would overwrite any existing file of that name and
    leave the credentials on disk.

    The file is deleted once ``load_kube_config`` returns.
    ``persist_config=False`` stops the client from trying to write back to
    the deleted path.
    """
    fd, path = tempfile.mkstemp(prefix="latch-exec-", suffix=".kubeconfig")
    try:
        with os.fdopen(fd, "w") as f:
            f.write(config_data)
        kubernetes.config.load_kube_config(config_file=path, persist_config=False)
    finally:
        os.unlink(path)


def execute(task_name: str):
    """Allows a user to start an interactive shell session in the remote machine
    that a task is running on.

    When running a workflow on Latch, it's often helpful while debugging to
    have a direct way of interacting with the machines on which tasks are run.
    Using `execute`, a user can easily get a shell into the machine on which
    the specified task is running.

    Args:
        task_name: The name of the running task you want a shell into. This
            is a hash that can be found in the sidebar in the browser display
            of the running workflow.

    Raises:
        ValueError: If the Kubernetes API rejects the exec request. This is
            reported to the user as the task name not being found.
        websocket.WebSocketException: If the session ends with a non-success
            status, or the stream carries an unexpected opcode or channel.

    Example:
        >>> execute("abcd1234-n0")
        root@1.2.3.4:~$
    """
    token = retrieve_or_login()
    (
        access_key,
        secret_key,
        session_token,
        cert_auth_data,
        cluster_endpoint,
        namespace,
        aws_account_id,
    ) = _fetch_pod_info(token, task_name)

    config_data = _construct_kubeconfig(
        cert_auth_data,
        cluster_endpoint,
        aws_account_id,
        access_key,
        secret_key,
        session_token,
    )
    _load_kubeconfig(config_data)

    core_v1 = core_v1_api.CoreV1Api()

    # TODO: the pod is currently looked up directly by the task name.
    pod_name = task_name

    stdin_channel = bytes([kubernetes.stream.ws_client.STDIN_CHANNEL])
    stdout_channel = kubernetes.stream.ws_client.STDOUT_CHANNEL
    stderr_channel = kubernetes.stream.ws_client.STDERR_CHANNEL
    error_channel = kubernetes.stream.ws_client.ERROR_CHANNEL

    class WSStream:
        """Thin wrapper over the exec websocket opened for ``/bin/sh`` in the pod."""

        def __init__(self):
            self._wssock = stream(
                core_v1.connect_get_namespaced_pod_exec,
                pod_name,
                namespace,
                command=["/bin/sh"],
                stderr=True,
                stdin=True,
                stdout=True,
                tty=True,
                _preload_content=False,
            ).sock

        def send(self, chunk: bytes):
            # Every outgoing frame is prefixed with the stdin channel byte.
            self._wssock.send(stdin_channel + chunk, websocket.ABNF.OPCODE_BINARY)

        def get_frame(self) -> Tuple[int, websocket.ABNF]:
            # control_frame=True also returns CLOSE/PING/PONG frames.
            return self._wssock.recv_data_frame(True)

        def has_buffered_data(self) -> bool:
            # A TLS socket can hold already-decrypted bytes that select()
            # cannot see. Report them so the loop does not stall on them.
            pending = getattr(self._wssock.sock, "pending", None)
            return bool(pending and pending())

        @property
        def socket(self):
            # Underlying socket object, used with select().
            return self._wssock.sock

        def close(self):
            self._wssock.close()

    class TTY:
        """Raw-mode local terminal backed by the given file descriptors."""

        def __init__(
            self,
            in_stream: int,
            out_stream: int,
            err_stream: int,
            raw: bool = True,
        ):
            if raw:
                setraw(in_stream)
            self._stdin = in_stream
            self._stdout = out_stream
            self._stderr = err_stream

        def flush(self) -> bytes:
            # Reads whatever input is available, up to 32 KiB.
            return os.read(self._stdin, 32 * 1024)

        def write_out(self, chunk: bytes):
            os.write(self._stdout, chunk)

        def write_err(self, chunk: bytes):
            os.write(self._stderr, chunk)

        @property
        def in_stream(self):
            return self._stdin

    stdin_fd = sys.stdin.fileno()
    # Capture the terminal settings before entering the try block, so the
    # finally clause always has something to restore. If this call fails,
    # the terminal has not been modified yet.
    old_settings = termios.tcgetattr(stdin_fd)
    try:
        tty_ = TTY(stdin_fd, sys.stdout.fileno(), sys.stderr.fileno())

        try:
            wsstream = WSStream()
        except kubernetes.client.rest.ApiException as err:
            raise ValueError(
                "Unable to find requested task name - make sure that you are in the"
                " correct workspace."
            ) from err

        rlist = [wsstream.socket, tty_.in_stream]
        while True:
            # Don't block in select() while TLS already holds websocket data.
            timeout = 0 if wsstream.has_buffered_data() else None
            rs, _, _ = select.select(rlist, [], [], timeout)

            if tty_.in_stream in rs:
                chunk = tty_.flush()
                if chunk:
                    wsstream.send(chunk)
                else:
                    # EOF on stdin. Stop polling it, otherwise select()
                    # reports it readable forever and the loop spins.
                    rlist.remove(tty_.in_stream)

            if wsstream.socket not in rs and not wsstream.has_buffered_data():
                continue

            opcode, frame = wsstream.get_frame()
            if opcode == websocket.ABNF.OPCODE_CLOSE:
                # The server closed the connection, and nothing more will
                # arrive on it.
                break
            if opcode in (websocket.ABNF.OPCODE_PING, websocket.ABNF.OPCODE_PONG):
                # recv_data_frame already answered the ping. Control frames
                # carry no session data.
                continue
            if opcode != websocket.ABNF.OPCODE_BINARY:
                raise websocket.WebSocketException(
                    f"Unexpected websocket opcode: {opcode}"
                )
            if not frame.data:
                continue

            # The first byte selects the channel; the rest is the payload.
            channel = frame.data[0]
            chunk = frame.data[1:]
            if channel == stdout_channel:
                if chunk:
                    tty_.write_out(chunk)
            elif channel == stderr_channel:
                if chunk:
                    tty_.write_err(chunk)
            elif channel == error_channel:
                # The error channel carries the final JSON status of the
                # exec session.
                wsstream.close()
                status = json.loads(chunk)
                if status.get("status") == "Success":
                    break
                raise websocket.WebSocketException(
                    f"Status: {status.get('status')} - Message: {status.get('message')}"
                )
            else:
                raise websocket.WebSocketException(f"Unexpected channel: {channel}")
    finally:
        termios.tcsetattr(stdin_fd, termios.TCSANOW, old_settings)