import inspect
from dataclasses import is_dataclass
from textwrap import dedent
from typing import Callable, Dict, Union, get_args, get_origin
import click
from flytekit import workflow as _workflow
from flytekit.core.workflow import PythonFunctionWorkflow
from latch.types.metadata import LatchAuthor, LatchMetadata, LatchParameter
from latch_cli.utils import best_effort_display_name
def _generate_metadata(f: Callable) -> LatchMetadata:
    """Build a default ``LatchMetadata`` object for the function ``f``.

    The metadata is named after ``f.__name__`` and given a default-constructed
    ``LatchAuthor``. Every parameter in the signature of ``f`` receives a
    ``LatchParameter`` whose display name is produced by
    ``best_effort_display_name`` from the parameter's name.

    Args:
        f: The function whose signature is inspected.

    Returns:
        The newly created metadata object.
    """
    # Inspect the signature first so a function that cannot be introspected
    # fails before any metadata object is constructed.
    param_names = inspect.signature(f).parameters

    generated = LatchMetadata(f.__name__, LatchAuthor())

    parameters = {}
    for param_name in param_names:
        display_name = best_effort_display_name(param_name)
        parameters[param_name] = LatchParameter(display_name=display_name)
    generated.parameters = parameters

    return generated
def _inject_metadata(f: Callable, metadata: LatchMetadata) -> None:
    """Append the string form of ``metadata`` to the docstring of ``f`` in place.

    The docstring is split at its first newline into a short description
    (first line) and a long description (everything after it). The long
    description is dedented, and ``str(metadata)`` is appended after a blank
    line. When ``f`` has no docstring, a placeholder made of the function name
    and the text "Sample Description" is used instead.

    Args:
        f: The function whose ``__doc__`` is rewritten.
        metadata: The object whose ``str()`` representation is appended.
    """
    if f.__doc__ is None:
        f.__doc__ = f"{f.__name__}\n\nSample Description"

    # ``partition`` always yields three parts, so a one-line docstring (no
    # newline at all) simply produces an empty long description instead of
    # failing to unpack.
    short_desc, _, long_desc = f.__doc__.partition("\n")
    f.__doc__ = f"{short_desc}\n{dedent(long_desc)}\n\n" + str(metadata)
# this weird Union thing is to ensure backwards compatibility,
# so that when users call @workflow without any arguments or
# parentheses, the workflow still serializes as expected
def workflow(
    metadata: Union[LatchMetadata, Callable]
) -> Union[PythonFunctionWorkflow, Callable]:
    """Decorate a function as a workflow and embed Latch metadata in its docstring.

    Two call styles are supported:

    * ``@workflow`` (no parentheses): ``metadata`` is the decorated function
      itself. If its docstring is missing or does not contain the
      ``__metadata__:`` marker, default metadata is generated from the
      signature and injected into the docstring.
    * ``@workflow(metadata)``: ``metadata`` is a ``LatchMetadata`` object and a
      decorator is returned. The decorator checks the metadata against the
      function signature, fills in a default ``LatchParameter`` for every
      signature parameter the metadata does not mention, validates
      samplesheet parameters, and injects the metadata into the docstring.
      Note that ``metadata.parameters`` is replaced on the object passed in.

    Args:
        metadata: Either a ``LatchMetadata`` object, or the function being
            decorated when the decorator is applied without arguments.

    Returns:
        The result of wrapping the function with the underlying ``_workflow``
        decorator, or (when called with a metadata object) a decorator that
        does so.

    Raises:
        click.exceptions.Exit: With code 1, after printing an explanation, when
            the metadata names parameters that the function signature lacks or
            when a parameter marked as a samplesheet is not annotated as a
            list of dataclasses.
    """
    if callable(metadata):
        # Bare ``@workflow`` usage: the single argument is the function.
        f = metadata
        if f.__doc__ is None or "__metadata__:" not in f.__doc__:
            _inject_metadata(f, _generate_metadata(f))
        return _workflow(f)

    def decorator(f: Callable):
        wf_params = inspect.signature(f).parameters

        # Compare against the metadata exactly as the caller supplied it. This
        # must happen before the parameters are re-keyed below; afterwards only
        # signature names remain and the check could never fire.
        in_meta_not_in_wf = [
            meta_param
            for meta_param in metadata.parameters
            if meta_param not in wf_params
        ]
        if in_meta_not_in_wf:
            error_str = (
                "Inconsistency detected between parameters in your `LatchMetadata`"
                " object and parameters in your workflow signature.\n\n"
                "The following parameters appear in your `LatchMetadata` object"
                " but not in your workflow signature:\n\n"
            )
            error_str += "".join(
                f"    \x1b[1m{meta_param}\x1b[22m\n"
                for meta_param in in_meta_not_in_wf
            )
            error_str += (
                "\nPlease resolve these inconsistencies and ensure that your"
                " `LatchMetadata` object and workflow signature have the same"
                " parameters."
            )
            click.secho(error_str, fg="red")
            raise click.exceptions.Exit(1)

        # Re-key the parameters in signature order, keeping user-supplied
        # entries and generating a default one for anything left undocumented.
        metadata.parameters = {
            wf_param: (
                metadata.parameters[wf_param]
                if wf_param in metadata.parameters
                else LatchParameter(display_name=best_effort_display_name(wf_param))
            )
            for wf_param in wf_params
        }

        for name, meta_param in metadata.parameters.items():
            if meta_param.samplesheet is not True:
                continue

            annotation = wf_params[name].annotation
            origin = get_origin(annotation)
            args = get_args(annotation)
            valid = (
                # ``get_origin`` returns None for plain types and may return a
                # non-class special form (e.g. for Literal), which
                # ``issubclass`` would reject with a TypeError.
                isinstance(origin, type)
                and issubclass(origin, list)
                # An unparameterised ``List`` has no type arguments.
                and len(args) > 0
                and is_dataclass(args[0])
            )
            if not valid:
                click.secho(
                    f"parameter marked as samplesheet is not valid: {name} "
                    f"in workflow {f.__name__} must be a list of dataclasses",
                    fg="red",
                )
                raise click.exceptions.Exit(1)

        _inject_metadata(f, metadata)
        return _workflow(f)

    return decorator