# Source code for latch_cli.services.register.register

import contextlib
import re
import shutil
import sys
import tempfile
import time
from pathlib import Path
from typing import Iterable, List, Optional

import click
import gql
import latch_sdk_gql.execute as l_gql
from scp import SCPClient

from ...centromere.ctx import _CentromereCtx
from ...centromere.utils import MaybeRemoteDir
from ...utils import WorkflowType, current_workspace
from ..workspace import _get_workspaces
from .constants import ANSI_REGEX, MAX_LINES
from .utils import (
    DockerBuildLogItem,
    build_image,
    register_serialized_pkg,
    serialize_pkg_in_container,
    upload_image,
)


def _delete_lines(num: int):
    """Erase the ``num`` terminal lines directly above the cursor.

    Assumes the cursor is at the start of the line just below the last line
    to erase. For every line, the escape sequence ``ESC[1F`` (move to the start
    of the previous line), ``ESC[0G`` (move to the first column) and ``ESC[2K``
    (clear the entire line) is emitted, so the cursor ends up at the start of
    the topmost erased line. Nothing is written when ``num`` is zero or
    negative.

    Args:
        num: Number of lines to erase.
    """
    if num > 0:
        # One write for all lines instead of one echo per line.
        click.echo("\x1b[1F\x1b[0G\x1b[2K" * num, nl=False)


def _print_window(cur_lines: List[str], line: str):
    """Show ``line`` as the newest row of a scrolling window of at most
    ``MAX_LINES`` rows.

    ``cur_lines`` are the rows currently on screen (oldest first), assumed to
    be the lines directly above the cursor. Text matched by ``ANSI_REGEX``
    (ANSI escape codes) is removed from ``line`` before it is stored or
    printed. Every row is then rendered in the same grey, whether or not the
    window is already full.

    - While fewer than ``MAX_LINES`` rows are shown, ``line`` is printed below
      them and appended to ``cur_lines`` in place.
    - Otherwise the shown rows are erased and ``cur_lines[1:] + [line]``
      (trimmed to ``MAX_LINES`` rows) is reprinted, so the oldest row scrolls
      off.

    Args:
        cur_lines: Rows currently displayed, oldest first.
        line: New row to display. Ignored if it is empty after removing
            escape codes.

    Returns:
        The rows now displayed. Callers must use the return value: in the
        full-window case it is a new list, not ``cur_lines``.
    """
    # Strip escape codes on every path (previously only once the window was
    # full), so stored rows and their rendering are consistent.
    line = ANSI_REGEX.sub("", line)
    if line == "":
        return cur_lines

    if len(cur_lines) < MAX_LINES:
        click.echo("\x1b[38;5;245m" + line + "\x1b[0m")
        cur_lines.append(line)
        return cur_lines

    new_lines = cur_lines[len(cur_lines) - MAX_LINES + 1 :]
    new_lines.append(line)
    _delete_lines(len(cur_lines))
    for s in new_lines:
        click.echo("\x1b[38;5;245m" + s + "\x1b[0m")
    return new_lines


# Matches lines that begin with "Step <n>/<total> :", the step header format
# printed by the legacy (non-BuildKit) `docker build` output.
docker_build_step_pat: "re.Pattern[str]" = re.compile(r"^Step [0-9]+/[0-9]+ :")








def _image_tag(image: str) -> str:
    """Return the tag of a Docker image reference, or ``"latest"`` if it has none.

    Only a ``:`` in the last ``/``-separated component separates the tag, so a
    registry port (``host:5000/repo``) is not mistaken for a tag.
    """
    _, sep, tag = image.rsplit("/", 1)[-1].rpartition(":")
    return tag if sep else "latest"


def _print_reg_resp(resp, image):
    """Report the result of registering ``image`` and exit on failure.

    Args:
        resp: Response mapping from the registration request. The keys read
            here are ``"success"``, ``"stderr"`` and ``"stdout"``.
        image: Image reference the workflow was registered with. Its tag is
            shown to the user as the workflow version.

    Exits the process with status 1 if ``resp`` does not report success, or if
    ``"Successfully registered file"`` does not appear in its stdout (reported
    as the version already existing). Otherwise echoes the stdout.
    """
    click.secho(f"Registering workflow {image}", bold=True)
    version = _image_tag(image)

    if not resp.get("success"):
        # Drop blank lines and lines starting with '{"json"' (they look like
        # machine-readable log records); show the remaining stderr lines.
        stderr = resp.get("stderr") or ""
        error_str = "Failed:\n\n" + "".join(
            line + "\n"
            for line in stderr.split("\n")
            if line and not line.startswith('{"json"')
        )

        if "task with different structure already exists" in error_str:
            error_str = (
                f"Version {version} already exists. Make sure that you've saved any"
                " changes you made."
            )

        click.secho(f"\n{error_str}", fg="red", bold=True)
        sys.exit(1)

    # A missing/None stdout used to crash here; treat it like any other
    # response without the success marker.
    stdout = resp.get("stdout") or ""
    if "Successfully registered file" not in stdout:
        click.secho(
            f"\nVersion ({version}) already exists."
            " Make sure that you've saved any changes you made.",
            fg="red",
            bold=True,
        )
        sys.exit(1)

    click.echo(stdout)





def _build_and_serialize(
    ctx: _CentromereCtx,
    image_name: str,
    context_path: Path,
    tmp_dir: str,
    dockerfile: Optional[Path] = None,
    *,
    progress_plain: bool = False,
):
    """Build one image, serialize the workflow into ``tmp_dir``, then upload the image.

    Snakemake workflows:
        A JIT-register wrapper is generated before the build and serialized
        afterwards with ``serialize_jit_register_workflow``.

    Other workflow types:
        Serialization runs inside a container of the freshly built image. The
        process exits with status 1 if that container finishes with a non-zero
        status code.

    Args:
        ctx: Active registration context; ``ctx.pkg_root`` must be set.
        image_name: Image to build, serialize with and upload.
        context_path: Build context passed to ``build_image``.
        tmp_dir: Directory the serialized output is written to. The caller
            may pass a directory on a remote build host.
        dockerfile: Optional Dockerfile passed to ``build_image``.
        progress_plain: Forwarded to the build-log printer.
    """
    assert ctx.pkg_root is not None

    jit_wrapper = None
    if ctx.workflow_type == WorkflowType.snakemake:
        assert ctx.snakefile is not None
        assert ctx.version is not None

        # Imported here rather than at module level, presumably so that
        # non-Snakemake registrations never load the Snakemake integration.
        from ...snakemake.serialize import generate_jit_register_code
        from ...snakemake.workflow import build_jit_register_wrapper

        jit_wrapper = build_jit_register_wrapper()
        generate_jit_register_code(
            jit_wrapper,
            ctx.pkg_root,
            ctx.snakefile,
            ctx.version,
            image_name,
            current_workspace(),
        )

    print_and_write_build_logs(
        build_image(ctx, image_name, context_path, dockerfile),
        image_name,
        ctx.pkg_root,
        progress_plain=progress_plain,
    )

    if ctx.workflow_type != WorkflowType.snakemake:
        logs, serialize_container = serialize_pkg_in_container(
            ctx, image_name, tmp_dir
        )
        print_serialize_logs(logs, image_name)

        assert ctx.dkr_client is not None
        status = ctx.dkr_client.wait(serialize_container)
        if status["StatusCode"] != 0:
            click.secho("\nWorkflow failed to serialize", fg="red", bold=True)
            sys.exit(1)
    else:
        assert jit_wrapper is not None
        assert ctx.dkr_repo is not None

        from ...snakemake.serialize import serialize_jit_register_workflow

        serialize_jit_register_workflow(jit_wrapper, tmp_dir, image_name, ctx.dkr_repo)

    click.echo()
    print_upload_logs(upload_image(ctx, image_name), image_name)


def _recursive_list(directory: Path) -> List[Path]:
    """Return every file and directory below ``directory``, excluding ``directory`` itself.

    The tree is walked iteratively with an explicit LIFO stack of directories
    still to visit. Each directory's entries are emitted in ``Path.iterdir``
    order, and parents are always emitted before their contents.
    Subdirectories are detected with ``Path.is_dir``, which follows symlinks.

    NOTE(review): a symlink pointing back at an ancestor directory would make
    this walk loop forever.
    """
    found: List[Path] = []
    pending = [directory]
    while pending:
        children = list(pending.pop().iterdir())
        found.extend(children)
        pending.extend(child for child in children if child.is_dir())
    return found


def _download_serialized_dir(
    stack: contextlib.ExitStack, scp: Optional["SCPClient"], remote_dir: str
) -> Path:
    """Copy the contents of ``remote_dir`` on the remote build host into a new
    local temporary directory and return that directory.

    The temporary directory is registered on ``stack`` and is deleted when the
    stack closes.
    """
    local_dir = Path(stack.enter_context(tempfile.TemporaryDirectory()))
    assert scp is not None
    scp.get(f"{remote_dir}/*", local_path=str(local_dir), recursive=True)
    return local_dir


def _replace_task_protos(
    protos: List[Path], new_protos: List[Path], task_name: str
) -> List[Path]:
    """Return ``protos`` with entries replaced by task-specific protobufs.

    ``task_name`` is trimmed so it starts at its ``wf`` component (for example,
    ``pkg.wf.task`` becomes ``wf.task``). Every file in ``new_protos`` whose
    name contains that trimmed name replaces the entry of ``protos`` that has
    the same file name. ``protos`` itself is not modified.

    Raises:
        ValueError: If ``task_name`` has no ``wf`` component.
    """
    parts = task_name.split(".")
    try:
        wf_index = parts.index("wf")
    except ValueError as e:
        raise ValueError(
            f"Unable to match {task_name} to any of the protobuf files"
            f" in {new_protos}"
        ) from e
    short_name = ".".join(parts[wf_index:])

    # Later matches win on duplicate file names, as with repeated replacement.
    replacements = {p.name: p for p in new_protos if short_name in p.name}
    return [replacements.get(p.name, p) for p in protos]


def _wait_for_workflow_infos(
    wf_name: str, version: str, owner_id: str, *, max_retries: int = 5
) -> list:
    """Poll the Latch GraphQL API for the newly registered workflow.

    Runs the ``workflowInfos`` query up to ``max_retries + 1`` times, sleeping
    one second between attempts. Returns the first non-empty ``nodes`` list.
    If every attempt returns no nodes, prints a warning and returns ``[]``.
    """
    query = gql.gql("""
        query workflowQuery($name: String, $ownerId: BigInt, $version: String) {
            workflowInfos(condition: { name: $name, ownerId: $ownerId, version: $version}) {
                nodes {
                    id
                }
            }
        }
    """)
    variables = {"name": wf_name, "version": version, "ownerId": owner_id}

    for attempt in range(max_retries + 1):
        wf_infos = l_gql.execute(query, variables)["workflowInfos"]["nodes"]
        # Checked before giving up, so a late success no longer also prints the
        # failure message.
        if len(wf_infos) > 0:
            return wf_infos
        if attempt < max_retries:
            time.sleep(1)

    click.secho(
        f"Failed to query workflow ID in {max_retries} seconds.", fg="red", bold=True
    )
    click.secho(
        "This could be due to high demand or a bug in the platform.",
        fg="red",
    )
    click.secho(
        "If the workflow is not visible in latch console, contact support.",
        fg="red",
    )
    return []


def register(
    pkg_root: str,
    disable_auto_version: bool = False,
    remote: bool = False,
    skip_confirmation: bool = False,
    snakefile: Optional[Path] = None,
    *,
    progress_plain: bool = False,
    use_new_centromere: bool = False,
):
    """Registers a workflow, defined as python code, with Latch.

    The major constituent steps are:

    - Constructing a Docker image
    - Serializing flyte objects within an instantiated container
    - Uploading the container with a latch-owned registry
    - Registering serialized objects + the container with latch.

    The Docker image is constructed by inferring relevant files + dependencies
    from the workflow package code itself. If a Dockerfile is provided
    explicitly, it will be used for image construction instead.

    Besides the default image, one extra image is built for every entry of
    ``ctx.container_map`` (tasks with their own container). The protobufs
    serialized in those images replace the matching default ones before
    registration.

    Authentication is not performed in this function. The registration
    request is sent with ``ctx.token`` from :class:`_CentromereCtx`.

    The registration flow makes heavy use of `Flyte`_. The Latch SDK modifies
    many components to play nicely with Latch (eg. platform API, user-specific
    auth), but the underlying concepts are nicely summarized in the
    `flytekit documentation`_.

    Args:
        pkg_root: A valid path pointing to the workflow code a user wishes to
            register. The path can be absolute or relative. The path is always
            a directory, with its structure exactly as constructed and
            described in the `cli.services.init` function.
        disable_auto_version: Forwarded to :class:`_CentromereCtx`.
        remote: Build the images on a remote host reached over SSH and copy
            the serialized protobufs back with SCP. Forced to ``False`` (with
            a warning) when ``snakefile`` is given.
        skip_confirmation: Do not ask for confirmation before registering.
        snakefile: Path to a Snakefile. When given, the workflow is registered
            as a Snakemake workflow through a JIT-register wrapper. The
            ``snakemake`` package must be installed, otherwise the process
            exits with status 1.
        progress_plain: Forwarded to the Docker build-log printer.
        use_new_centromere: Forwarded to :class:`_CentromereCtx`. Also prints a
            notice that the experimental registration server is used.

    Raises:
        ValueError: If building a task-specific image raises ``TypeError``
            (reported as an invalid Dockerfile path), or if a task name has no
            ``wf`` component to match protobufs against.

    Example:
        >>> register("./example_workflow")

    .. _Flyte: https://docs.flyte.org
    .. _flytekit documentation: https://docs.flyte.org/en/latest/concepts/registration.html
    """
    if snakefile is not None:
        if remote:
            click.secho(
                "Cannot use remote builds with Snakemake, switching to a local build\n",
                fg="yellow",
                bold=True,
            )
            remote = False

        try:
            import snakemake  # noqa: F401  (only checks that it is installed)
        except ImportError:
            click.secho("\n`snakemake` package is not installed.", fg="red", bold=True)
            sys.exit(1)

    with _CentromereCtx(
        Path(pkg_root),
        disable_auto_version=disable_auto_version,
        remote=remote,
        snakefile=snakefile,
        use_new_centromere=use_new_centromere,
    ) as ctx:
        assert ctx.workflow_name is not None, "Unable to determine workflow name"
        assert ctx.version is not None, "Unable to determine workflow version"

        workspace_id = current_workspace()

        # todo(maximsmol): we really want the workflow display name here
        click.echo(
            " ".join(
                [click.style("Workflow name:", fg="bright_blue"), ctx.workflow_name]
            )
        )
        click.echo(" ".join([click.style("Version:", fg="bright_blue"), ctx.version]))

        ws_name = next(
            (
                name
                for ws_id, name in _get_workspaces().items()
                if ws_id == workspace_id
                or (workspace_id == "" and name == "Personal Workspace")
            ),
            "N/A",
        )
        click.echo(
            " ".join(
                [
                    click.style("Target workspace:", fg="bright_blue"),
                    ws_name,
                    f"({workspace_id})",
                ]
            )
        )
        click.echo(
            " ".join(
                [
                    click.style("Workflow root:", fg="bright_blue"),
                    str(ctx.default_container.pkg_dir),
                ]
            )
        )

        if use_new_centromere:
            click.secho("Using experimental registration server.", fg="yellow")

        if not skip_confirmation:
            if not click.confirm("Start registration?"):
                click.secho("Cancelled", bold=True)
                return
        else:
            click.secho("Skipping confirmation because of --yes", bold=True)

        click.secho("\nInitializing registration", bold=True)

        click.echo(
            " ".join(
                [
                    click.style("Docker Image:", fg="bright_blue"),
                    ctx.default_container.image_name,
                ]
            )
        )
        click.echo()

        scp = None
        if remote:
            click.secho("Connecting to remote server for docker build\n", bold=True)
            assert ctx.ssh_client is not None
            transport = ctx.ssh_client.get_transport()
            assert transport is not None
            scp = SCPClient(transport=transport, sanitize=lambda x: x)

        with contextlib.ExitStack() as stack:
            td: str = stack.enter_context(MaybeRemoteDir(ctx.ssh_client))
            _build_and_serialize(
                ctx,
                ctx.default_container.image_name,
                ctx.default_container.pkg_dir,
                td,
                dockerfile=ctx.default_container.dockerfile,
                progress_plain=progress_plain,
            )
            local_td = _download_serialized_dir(stack, scp, td) if remote else Path(td)
            protos = _recursive_list(local_td)

            for task_name, container in ctx.container_map.items():
                task_td = stack.enter_context(MaybeRemoteDir(ctx.ssh_client))
                try:
                    _build_and_serialize(
                        ctx,
                        container.image_name,
                        ctx.default_container.pkg_dir,
                        task_td,
                        dockerfile=container.dockerfile,
                        progress_plain=progress_plain,
                    )
                except TypeError as e:
                    raise ValueError(
                        "The path to your provided dockerfile"
                        f" {container.dockerfile} given to {task_name} is invalid."
                    ) from e

                # Copy into (and list) the task's own local directory. Previously
                # remote builds wrote into the default container's directory,
                # overwriting its protobufs.
                local_task_td = (
                    _download_serialized_dir(stack, scp, task_td)
                    if remote
                    else Path(task_td)
                )
                new_protos = _recursive_list(local_task_td)
                protos = _replace_task_protos(protos, new_protos, task_name)

            # Must run inside the stack: the protobuf files live in its temp dirs.
            reg_resp = register_serialized_pkg(
                protos, ctx.token, ctx.version, workspace_id
            )
            _print_reg_resp(reg_resp, ctx.default_container.image_name)

            click.secho("Successfully registered workflow.", fg="green", bold=True)

            wf_name = ctx.workflow_name
            if snakefile is not None:
                # todo(maximsmol): this is quite awful
                wf_name = f"{wf_name}_jit_register"

            wf_infos = _wait_for_workflow_infos(wf_name, ctx.version, workspace_id)
            if len(wf_infos) > 0:
                if len(wf_infos) > 1:
                    click.secho(
                        f"Workflow {ctx.workflow_name}:{ctx.version} is not unique. The"
                        " link below might be wrong.",
                        fg="yellow",
                    )
                wf_id = wf_infos[0]["id"]
                click.secho(f"https://console.latch.bio/workflows/{wf_id}", fg="green")