Source code for latch_cli.services.get_params

try:
    from typing import get_args, get_origin
except ImportError:
    from typing_extensions import get_args, get_origin

import enum
import json
import keyword
import typing
from typing import Optional

import google.protobuf.json_format as gpjson
from flyteidl.core.literals_pb2 import Literal as _Literal
from flyteidl.core.types_pb2 import LiteralType as _LiteralType
from flytekit.models.literals import Literal
from flytekit.models.types import LiteralType

from latch.types.directory import LatchDir
from latch.types.file import LatchFile
from latch.utils import retrieve_or_login
from latch_cli.services.launch import _get_workflow_interface


class _Unsupported:
    """Marker type used in `_simple_table` for simple types with no python mapping."""


# Maps the integer carried by a literal type's `simple` field to a python type.
# NOTE(review): the keys presumably follow the flyteidl `SimpleType` enum
# numbering -- confirm against the flyteidl version in use.
_simple_table = {
    0: type(None),
    1: int,
    2: float,
    3: str,
    4: bool,
    # Codes 5 through 9 have no python counterpart in this module.
    **dict.fromkeys(range(5, 10), _Unsupported),
}

# Placeholder value emitted for each primitive python type when a parameter
# has no default to show.
_primitive_table = {
    python_type: placeholder
    for python_type, placeholder in (
        (type(None), None),
        (int, 0),
        (float, 0.0),
        (str, "foo"),
        (bool, False),
    )
}

# TODO(ayush): fix this to
# (1) support records,
# (2) support fully qualified workflow names,
# (note from kenny) - pretty sure you intend to support the opposite,
# fqn are supported by default, address when you get to this todo
# (3) show a message indicating the generated filename,
# (4) optionally specify the output filename


def get_params(wf_name: str, wf_version: Optional[str] = None):
    """Constructs a parameter map for a workflow given its name and an optional
    version.

    This function creates a python parameter file that can be used by `launch`.
    You can specify the specific parameters by editing the file, and then
    launch an execution on Latch using those parameters with `launch`.

    Args:
        wf_name: The unique name of the workflow.
        wf_version: An optional workflow version. If this argument is not
            given, `get_params` will default to generating a parameter map of
            the most recent version of the workflow.

    Raises:
        ValueError: If a parameter's description is not valid JSON or does
            not contain a "name" key.

    Example:
        >>> get_params("wf.__init__.alphafold_wf")
            # creates a file called `wf.__init__.alphafold_wf.params.py` that
            # contains a template parameter map.
    """
    token = retrieve_or_login()
    _, wf_interface, wf_default_params = _get_workflow_interface(
        token, wf_name, wf_version
    )

    # Pass 1: resolve every workflow variable to (python type, value, has-default).
    params = {}
    wf_vars = wf_interface["variables"]
    default_wf_vars = wf_default_params["parameters"]
    for value in wf_vars.values():
        try:
            description_json = json.loads(value["description"])
            param_name = description_json["name"]
        except (json.decoder.JSONDecodeError, KeyError) as e:
            raise ValueError(
                f"Parameter description json for workflow {wf_name} is malformed"
            ) from e

        literal_type = gpjson.ParseDict(value["type"], _LiteralType())
        python_type = _guess_python_type(
            LiteralType.from_flyte_idl(literal_type), param_name
        )

        param_spec = default_wf_vars[param_name]
        # Anything not explicitly marked required is treated as carrying a default.
        has_default = param_spec.get("required") is not True
        if has_default:
            literal = gpjson.ParseDict(param_spec.get("default"), _Literal())
            val = _guess_python_val(Literal.from_flyte_idl(literal), python_type)
        else:
            val = _best_effort_default_val(python_type)

        params[param_name] = (python_type, val, has_default)

    import_statements = {
        LatchFile: "from latch.types import LatchFile",
        LatchDir: "from latch.types import LatchDir",
        enum.Enum: "from enum import Enum",
    }

    import_types = []
    enum_literals = []

    def _register_leaf(leaf_type):
        """Record the import and/or enum class definition a leaf type needs."""
        if leaf_type in import_statements and leaf_type not in import_types:
            import_types.append(leaf_type)

        if type(leaf_type) is not enum.EnumMeta:
            return

        if enum.Enum not in import_types:
            import_types.append(enum.Enum)

        enum_literal = f"class {leaf_type._name}(Enum):"
        for variant in leaf_type._variants:
            # A python keyword cannot be used as a member name; prefix it.
            variant_name = f"_{variant}" if variant in keyword.kwlist else variant
            # NOTE(review): single-space indent reproduced from the source as
            # seen; confirm the intended width of the generated class body.
            enum_literal += f"\n {variant_name} = '{variant}'"
        enum_literals.append(enum_literal)

    def _collect_dependencies(param_type):
        """Walk list/union types recursively so nested leaves are registered.

        Descending through every level means types such as
        `Optional[List[LatchFile]]` still contribute their import.
        """
        if get_origin(param_type) in (list, typing.Union):
            for arg in get_args(param_type):
                _collect_dependencies(arg)
        else:
            _register_leaf(param_type)

    # Pass 2: render the parameter map and gather imports / enum definitions.
    param_map_parts = [
        "\nparams = {",
        f'\n "_name": "{wf_name}", # Don\'t edit this value.',
    ]
    for param_name, (python_type, python_val, has_default) in params.items():
        _collect_dependencies(python_type)

        python_val, type_repr = _get_code_literal(python_val, python_type)
        default_marker = "DEFAULT. " if has_default else ""
        param_map_parts.append(
            f'\n "{param_name}": {python_val}, # {default_marker}{type_repr}'
        )
    param_map_parts.append("\n}")
    param_map_str = "".join(param_map_parts)

    # Explicit encoding so non-ASCII default values are written identically on
    # every platform.
    with open(f"{wf_name}.params.py", "w", encoding="utf-8") as f:
        f.write(
            f'"""Run `latch launch {wf_name}.params.py` to launch this workflow"""\n'
        )

        for t in import_types:
            f.write(f"\n{import_statements[t]}")
        for e in enum_literals:
            f.write(f"\n\n{e}\n")

        f.write("\n")
        f.write(param_map_str)


def _type_repr(python_type: typing.Any):
    """Describe a python type for the trailing comment of a parameter line.

    Leaf types are returned unchanged (they are formatted by the caller);
    enum carriers, unions and lists are rendered as strings. The result depends
    only on the type, never on a value.
    """
    if type(python_type) is enum.EnumMeta:
        return f"<enum '{python_type._name}'>"

    origin = get_origin(python_type)
    if origin is typing.Union:
        variants = ", ".join(str(_type_repr(v)) for v in get_args(python_type))
        return f"typing.Union[{variants}]"
    if origin is list:
        return f"typing.List[{_type_repr(get_args(python_type)[0])}]"

    return python_type


def _enum_member_ref(enum_type: typing.Any, variant: str) -> str:
    """Build the `Name.member` reference for a variant of an enum carrier.

    Variants that are python keywords get a leading underscore, mirroring how
    `get_params` names the members of the enum class it writes out.
    """
    member = f"_{variant}" if variant in keyword.kwlist else variant
    return f"{enum_type._name}.{member}"


def _collection_item_type(python_type: typing.Any):
    """Return the element type to use for items of a collection literal.

    If the type is a union, its first list variant is unwrapped first, so that
    `Optional[List[X]]` yields `X` rather than `List[X]`. A type with no
    arguments is returned unchanged.
    """
    if get_origin(python_type) is typing.Union:
        for variant in get_args(python_type):
            if get_origin(variant) is list:
                python_type = variant
                break

    args = get_args(python_type)
    return args[0] if args else python_type


def _get_code_literal(python_val: typing.Any, python_type: typing.Any):
    """Construct value that is executable python when templated into a code block.

    Args:
        python_val: The native value produced by `_guess_python_val` or
            `_best_effort_default_val`.
        python_type: The type produced by `_guess_python_type`.

    Returns:
        A `(literal, type_repr)` pair: the text (or object) to template in as
        the value, and a description of the type for the trailing comment.
    """
    if python_type is str or (type(python_val) is str and str in get_args(python_type)):
        if isinstance(python_val, str):
            # json.dumps escapes quotes, backslashes and control characters, so
            # the emitted double-quoted literal stays valid python.
            return json.dumps(python_val, ensure_ascii=False), python_type
        return f'"{python_val}"', python_type

    if type(python_type) is enum.EnumMeta:
        return python_val, _type_repr(python_type)

    origin = get_origin(python_type)
    if origin is typing.Union:
        # The description is derived from the type alone; feeding the value
        # into every variant would fail for e.g. a scalar value against a
        # list variant.
        return python_val, _type_repr(python_type)

    if origin is list:
        type_repr = _type_repr(python_type)
        if python_val is None:
            return None, type_repr

        item_type = get_args(python_type)[0]
        items = ",".join(
            str(_get_code_literal(item, item_type)[0]) for item in python_val
        )
        return f"[{items}]", type_repr

    return python_val, python_type


def _guess_python_val(literal: "Literal", python_type: typing.Any):
    """Transform flyte literal value to native python value.

    Args:
        literal: The literal to convert; its `scalar` and `collection`
            attributes are inspected.
        python_type: The python type guessed for the parameter. It is used to
            turn enum strings into `Name.member` references and to pick the
            element type of collections.

    Raises:
        NotImplementedError: If the literal matches none of the handled forms.
    """
    if literal.scalar is not None:
        if literal.scalar.none_type is not None:
            return None

        if literal.scalar.primitive is not None:
            primitive = literal.scalar.primitive
            if primitive.string_value is not None:
                text = str(primitive.string_value)
                if type(python_type) is enum.EnumMeta:
                    return _enum_member_ref(python_type, text)
                return text
            if primitive.integer is not None:
                return int(primitive.integer)
            if primitive.float_value is not None:
                return float(primitive.float_value)
            if primitive.boolean is not None:
                return bool(primitive.boolean)

        if literal.scalar.blob is not None:
            blob = literal.scalar.blob
            # Dimensionality 0 is treated as a single file, anything else as a
            # directory.
            if blob.metadata.type.dimensionality == 0:
                return LatchFile(blob.uri)
            return LatchDir(blob.uri)

    # collection
    if literal.collection is not None:
        item_type = _collection_item_type(python_type)
        return [
            _guess_python_val(item, item_type)
            for item in literal.collection.literals
        ]

    # sum

    # enum

    raise NotImplementedError(
        f"The flyte literal {literal} cannot be transformed to a python type."
    )


def _guess_python_type(literal: "LiteralType", param_name: str):
    """Transform flyte type literal to native python type.

    Args:
        literal: The literal type to convert.
        param_name: Name of the parameter; used to name the carrier class
            created for enum types.

    Raises:
        NotImplementedError: If the literal type matches none of the handled
            forms.
    """
    if literal.simple is not None:
        return _simple_table[literal.simple]

    if literal.collection_type is not None:
        return typing.List[_guess_python_type(literal.collection_type, param_name)]

    if literal.blob is not None:
        # flyteidl BlobType message for reference:
        #   enum BlobDimensionality {
        #       SINGLE = 0;
        #       MULTIPART = 1;
        #   }
        return LatchFile if literal.blob.dimensionality == 0 else LatchDir

    if literal.union_type is not None:
        # De-duplicate while preserving order; a list with equality checks is
        # used instead of a set.
        unique_variants = []
        for variant in literal.union_type.variants:
            variant_type = _guess_python_type(variant, param_name)
            if variant_type not in unique_variants:
                unique_variants.append(variant_type)
        return typing.Union[tuple(unique_variants)]

    if literal.enum_type is not None:
        # We can hold the variants in a proxy class that is also type 'Enum',
        # s.t. we can parse the variants and define the object in the param
        # map code.
        class _VariantCarrier(enum.Enum): ...

        _VariantCarrier._variants = literal.enum_type.values
        # Use param name to uniquely identify each enum
        _VariantCarrier._name = param_name
        return _VariantCarrier

    raise NotImplementedError(
        f"The flyte literal {literal} cannot be transformed to a python type."
    )


def _best_effort_default_val(t: typing.Any):
    """Produce a "best-effort" default value given a python type.

    Raises:
        NotImplementedError: If no placeholder can be produced for `t`.
    """
    if t in _primitive_table:
        return _primitive_table[t]

    if t is list:
        return []

    # Build only the placeholder that is actually requested.
    if t is LatchDir:
        return LatchDir("latch:///foobar")
    if t is LatchFile:
        return LatchFile("latch:///foobar")

    if type(t) is enum.EnumMeta:
        return _enum_member_ref(t, t._variants[0])

    origin = get_origin(t)
    if origin is list:
        # One placeholder element per type argument.
        return [_best_effort_default_val(arg) for arg in get_args(t)]
    if origin is typing.Union:
        return _best_effort_default_val(get_args(t)[0])

    raise NotImplementedError(
        f"Unable to produce a best-effort value for the python type {t}"
    )