Source code for latch.registry.table

from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum, EnumMeta
from typing import (
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    TypedDict,
    Union,
    cast,
    get_args,
    get_origin,
    overload,
)

import gql
import graphql.language as l
import graphql.language.parser as lp
from latch_sdk_gql.execute import execute
from latch_sdk_gql.utils import (
    _GqlJsonValue,
    _json_value,
    _name_node,
    _parse_selection,
    _var_def_node,
    _var_node,
)
from typing_extensions import TypeAlias

from latch.registry.record import NoSuchColumnError, Record
from latch.registry.types import (
    Column,
    InvalidValue,
    LinkedRecordType,
    RecordValue,
    RegistryEnumDefinition,
    RegistryPythonType,
    RegistryPythonValue,
)
from latch.registry.upstream_types.types import DBType, RegistryType
from latch.registry.upstream_types.values import DBValue, EmptyCell
from latch.registry.utils import to_python_literal, to_python_type, to_registry_literal
from latch.types.directory import LatchDir
from latch.types.file import LatchFile

from ..types.json import JsonValue


class _AllRecordsNode(TypedDict):
    sampleId: str
    sampleName: str
    sampleDataKey: str
    sampleDataValue: DBValue


class _ColumnNode(TypedDict("_ColumnNodeReserved", {"def": DBValue})):
    key: str
    type: DBType


@dataclass
class _Cache:
    display_name: Optional[str] = None
    columns: Optional[Dict[str, Column]] = None
    project_id: Optional[str] = None



@dataclass(frozen=True)
class Table:
    """Registry table. Contains :class:`records <latch.registry.record.Record>`.

    :meth:`~latch.registry.project.Project.list_tables` is the typical way to
    get a :class:`Table`.
    """

    _cache: _Cache = field(
        default_factory=lambda: _Cache(),
        init=False,
        repr=False,
        hash=False,
        compare=False,
    )

    id: str
    """Unique identifier."""

    def load(self) -> None:
        """(Re-)populate this table instance's cache.

        Future calls to most getters will return immediately without making a
        network request.

        Always makes a network request.
        """
        data = execute(
            gql.gql("""
                query TableQuery($id: BigInt!) {
                    catalogExperiment(id: $id) {
                        id
                        displayName
                        catalogExperimentColumnDefinitionsByExperimentId {
                            nodes {
                                key
                                type
                                def
                            }
                        }
                        projectId
                    }
                }
                """),
            variables={"id": self.id},
        )["catalogExperiment"]
        # todo(maximsmol): deal with nonexistent tables

        self._cache.project_id = data["projectId"]
        self._cache.display_name = data["displayName"]

        self._cache.columns = {}
        columns: List[_ColumnNode] = data[
            "catalogExperimentColumnDefinitionsByExperimentId"
        ]["nodes"]
        for x in columns:
            py_type = to_python_type(x["type"]["type"])
            if x["type"]["allowEmpty"]:
                py_type = Union[py_type, EmptyCell]

            cur = Column(x["key"], py_type, x["type"])
            self._cache.columns[cur.key] = cur
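
    # A minimal usage sketch (illustrative only; the table ID "1234" below is a
    # made-up placeholder): `load` performs the single network request, after
    # which the cached getters return without any further requests.
    #
    #     table = Table(id="1234")
    #     table.load()
    #     name = table.get_display_name(load_if_missing=False)
    #     cols = table.get_columns(load_if_missing=False)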

    # get_project_id

    @overload
    def get_project_id(self, *, load_if_missing: Literal[True] = True) -> str: ...

    @overload
    def get_project_id(self, *, load_if_missing: bool) -> Optional[str]: ...

    def get_project_id(self, *, load_if_missing: bool = True) -> Optional[str]:
        """Get the ID of the project that contains this table.

        Args:
            load_if_missing:
                If true, :meth:`load` the project ID if not in cache.
                If false, return `None` if not in cache.

        Returns:
            ID of the :class:`Project` containing this table.
        """
        if self._cache.project_id is None:
            if not load_if_missing:
                return None

            self.load()

        return self._cache.project_id

    # get_display_name

    @overload
    def get_display_name(self, *, load_if_missing: Literal[True] = True) -> str: ...

    @overload
    def get_display_name(self, *, load_if_missing: bool) -> Optional[str]: ...

    def get_display_name(self, *, load_if_missing: bool = True) -> Optional[str]:
        """Get the display name of this table.

        This is an opaque string that can contain any valid Unicode data.

        Display names are *not unique* and *must never be used as identifiers*.
        Use :attr:`id` instead.

        Args:
            load_if_missing:
                If true, :meth:`load` the display name if not in cache.
                If false, return `None` if not in cache.

        Returns:
            Display name.
        """
        if self._cache.display_name is None and load_if_missing:
            self.load()

        return self._cache.display_name

    # get_columns

    @overload
    def get_columns(
        self, *, load_if_missing: Literal[True] = True
    ) -> Dict[str, Column]: ...

    @overload
    def get_columns(self, *, load_if_missing: bool) -> Optional[Dict[str, Column]]: ...

    def get_columns(
        self, *, load_if_missing: bool = True
    ) -> Optional[Dict[str, Column]]:
        """Get the columns of this table.

        Args:
            load_if_missing:
                If true, :meth:`load` the column list if not in cache.
                If false, return `None` if not in cache.

        Returns:
            Mapping between column keys and
            :class:`columns <latch.registry.types.Column>`.
        """
        if self._cache.columns is None and load_if_missing:
            self.load()

        return self._cache.columns

    def list_records(self, *, page_size: int = 100) -> Iterator[Dict[str, Record]]:
        """List Registry records contained in this table.

        Args:
            page_size:
                Maximum size of a page of records. The last page may be
                shorter than this value.

        Yields:
            Pages of records. Each page is a mapping between record IDs and
            :class:`records <latch.registry.record.Record>`.
        """
        cols = self.get_columns()

        # todo(maximsmol): because allSamples returns each column as its own
        # row, we can't paginate by samples because we don't know when a
        # sample is finished
        nodes: List[_AllRecordsNode] = execute(
            gql.gql("""
                query TableQuery($id: BigInt!) {
                    catalogExperiment(id: $id) {
                        allSamples {
                            nodes {
                                sampleId
                                sampleName
                                sampleDataKey
                                sampleDataValue
                            }
                        }
                    }
                }
                """),
            {
                "id": self.id,
            },
        )["catalogExperiment"]["allSamples"]["nodes"]
        # todo(maximsmol): deal with nonexistent tables

        record_names: Dict[str, str] = {}
        record_values: Dict[str, Dict[str, RecordValue]] = {}

        for node in nodes:
            record_names[node["sampleId"]] = node["sampleName"]
            vals = record_values.setdefault(node["sampleId"], {})

            col = cols.get(node["sampleDataKey"])
            if col is None:
                continue

            # todo(maximsmol): in the future, allow storing or yielding values
            # that failed to parse
            vals[col.key] = to_python_literal(
                node["sampleDataValue"],
                col.upstream_type["type"],
            )

        page: Dict[str, Record] = {}
        for id, values in record_values.items():
            for col in cols.values():
                if col.key in values:
                    continue

                if not col.upstream_type["allowEmpty"]:
                    values[col.key] = InvalidValue("")

            cur = Record(id)
            cur._cache.name = record_names[id]
            cur._cache.values = values
            cur._cache.columns = cols

            page[id] = cur

            if len(page) == page_size:
                yield page
                page = {}

        if len(page) > 0:
            yield page
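
    # Usage sketch (illustrative; assumes a `Table` instance named `table`):
    # records arrive in pages, each a dict keyed by record ID, so iteration
    # nests over pages and then over records.
    #
    #     for page in table.list_records(page_size=50):
    #         for record_id, record in page.items():
    #             print(record_id, record.get_name(), record.get_values())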

    def get_dataframe(self):
        """Get a pandas DataFrame of all records in this table.

        Returns:
            DataFrame representing all records in this table.
        """
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                "pandas needs to be installed to use get_dataframe. Install it with"
                " `pip install pandas` or `pip install latch[pandas]`."
            )

        records = []
        for page in self.list_records():
            for record in page.values():
                full_record = record.get_values()
                if full_record is not None:
                    full_record["Name"] = record.get_name()
                    records.append(full_record)

        if len(records) == 0:
            cols = self.get_columns()
            if cols is None:
                return pd.DataFrame()

            return pd.DataFrame(columns=list(cols.keys()))

        return pd.DataFrame(records)
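
    # Usage sketch (illustrative; assumes a `Table` instance named `table` and
    # that pandas is installed):
    #
    #     df = table.get_dataframe()
    #     print(df.head())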

    @contextmanager
    def update(self, *, reload_on_commit: bool = True) -> Iterator["TableUpdate"]:
        """Start an update transaction.

        The transaction will commit when the context manager closes unless an
        error occurs.

        No changes will occur until the transaction commits.

        The transaction can be cancelled by running :meth:`TableUpdate.clear`
        before closing the context manager.

        Args:
            reload_on_commit:
                If true, :meth:`load` this table after the transaction commits.

        Returns:
            Context manager for the new transaction.
        """
        upd = TableUpdate(self)
        yield upd
        upd.commit()

        if reload_on_commit:
            self.load()
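
    # Usage sketch (illustrative; the column key and record names below are
    # placeholders): updates queued inside each `with` block commit together
    # when the block exits without an exception. Creating the column in its
    # own transaction first means the default reload-on-commit refreshes the
    # cached column list before records referencing it are upserted.
    #
    #     with table.update() as upd:
    #         upd.upsert_column("replicate", int)
    #
    #     with table.update() as upd:
    #         upd.upsert_record("sample-1", replicate=1)
    #         upd.delete_record("stale-sample")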

    def __repr__(self):
        display_name = self.get_display_name(load_if_missing=False)
        if display_name is not None:
            return f"Table(id={self.id}, display_name={display_name})"

        return f"Table(id={self.id})"

    def __str__(self):
        return repr(self)


@dataclass(frozen=True)
class _TableRecordsUpsertData:
    name: str
    values: Dict[str, DBValue]


@dataclass(frozen=True)
class _TableRecordsDeleteData:
    name: str


@dataclass(frozen=True)
class _TableColumnUpsertData:
    key: str
    type: DBType


_TableRecordsMutationData: TypeAlias = Union[
    _TableRecordsUpsertData,
    _TableRecordsDeleteData,
    _TableColumnUpsertData,
]


class InvalidColumnTypeError(ValueError):
    """Failure to use an invalid column type.

    Attributes:
        key: Identifier of the invalid column.
        invalid_type: Requested column type.
    """

    def __init__(
        self, key: str, invalid_type: Union[Type[object], RegistryPythonType], msg: str
    ):
        super().__init__(
            f"invalid column type for {repr(key)}. {msg}: {repr(invalid_type)}"
        )
        self.key = key
        self.invalid_type = invalid_type


@dataclass(frozen=True)
class TableUpdate:
    """Ongoing :class:`Table` update transaction.

    Groups requested updates to commit everything together in one network
    request.

    Transactions are atomic. The entire transaction either commits or fails
    with an exception.
    """

    _record_mutations: List[_TableRecordsMutationData] = field(
        default_factory=list,
        init=False,
        repr=False,
        hash=False,
        compare=False,
    )

    table: Table

    # upsert record

    def upsert_record_raw_unsafe(
        self, *, name: str, values: Dict[str, DBValue]
    ) -> None:
        """DANGEROUSLY update or create a record using raw
        :class:`values <latch.registry.upstream_types.values.DBValue>`.

        Values are not checked against the existing columns.

        A transport error will be thrown if non-existent columns are updated.

        The update will succeed if values do not match column types, and
        future reads will produce
        :class:`invalid values <latch.registry.types.InvalidValue>`.

        Unsafe:
            The value dictionary is not validated in any way. It is possible
            to create completely invalid record values that are not a valid
            Registry value of any type. Future reads will fail
            catastrophically when trying to parse these values.

        Args:
            name: Target record name.
            values: Column values to set.
        """
        self._record_mutations.append(_TableRecordsUpsertData(name, values))

    def upsert_record(self, name: str, **values: RegistryPythonValue) -> None:
        """Update or create a record.

        A transport error will be thrown if non-existent columns are updated.

        It is possible that the column definitions changed since the table was
        last loaded and the update will be issued with values that do not
        match current column types. This will succeed with no error and future
        reads will produce
        :class:`invalid values <latch.registry.types.InvalidValue>`.

        Args:
            name: Target record name.
            values: Column values to set.
        """
        cols = self.table.get_columns()

        db_vals: Dict[str, DBValue] = {}
        for k, v in values.items():
            col = cols.get(k)
            if col is None:
                raise NoSuchColumnError(k)

            db_vals[k] = to_registry_literal(v, col.upstream_type["type"])

        self._record_mutations.append(_TableRecordsUpsertData(name, db_vals))
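
    # Usage sketch (illustrative; "replicate" stands in for an existing column
    # key): values are converted to Registry literals against the cached
    # column types, and unknown keys raise `NoSuchColumnError`.
    #
    #     with table.update() as upd:
    #         upd.upsert_record("sample-1", replicate=2)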

    def _add_record_upserts_selection(
        self,
        upserts: List[_TableRecordsUpsertData],
        mutations: List[l.SelectionNode],
        vars: Dict[str, Tuple[l.TypeNode, JsonValue]],
    ) -> None:
        if len(upserts) == 0:
            return

        names: _GqlJsonValue = [x.name for x in upserts]
        values: JsonValue = [cast(Dict[str, JsonValue], x.values) for x in upserts]

        res = _parse_selection("""
            catalogMultiUpsertSamples(input: {}) {
                clientMutationId
            }
        """)
        assert isinstance(res, l.FieldNode)

        argDataVar = f"upd{len(mutations)}ArgData"

        args = l.ArgumentNode()
        args.name = _name_node("input")
        args.value = _json_value(
            {
                "argExperimentId": self.table.id,
                "argNames": names,
                "argData": _var_node(argDataVar),
            }
        )

        res.alias = _name_node(f"upd{len(mutations)}")
        res.arguments = tuple([args])

        mutations.append(res)
        vars[argDataVar] = (l.parse_type("[JSON]"), values)

    # delete record

    def delete_record(self, name: str) -> None:
        """Delete a record.

        Args:
            name: Target record name.
        """
        self._record_mutations.append(_TableRecordsDeleteData(name))

    def _add_record_deletes_selection(
        self, deletes: List[_TableRecordsDeleteData], mutations: List[l.SelectionNode]
    ) -> None:
        if len(deletes) == 0:
            return

        names: _GqlJsonValue = [x.name for x in deletes]

        res = _parse_selection("""
            catalogMultiDeleteSampleByName(input: {}) {
                clientMutationId
            }
        """)
        assert isinstance(res, l.FieldNode)

        args = l.ArgumentNode()
        args.name = _name_node("input")
        args.value = _json_value(
            {
                "argExperimentId": self.table.id,
                "argNames": names,
            }
        )

        res.alias = _name_node(f"upd{len(mutations)}")
        res.arguments = tuple([args])

        mutations.append(res)

    # upsert column

    def upsert_column(
        self,
        key: str,
        type: RegistryPythonType,
        *,
        required: bool = False,
    ):
        """Create a column.

        Support for updating columns is planned.

        Args:
            key: Identifier of the new column.
            type:
                Type of the new column.

                Only a limited set of Python types is currently supported and
                will be expanded over time.
                :class:`latch.registry.types.RegistryPythonType` represents
                the currently supported types.
            required:
                If true, records without a value for this column are
                considered invalid.

                Note that an explicit `None` value is different from a
                missing/empty value. `None` is a valid value for an `Optional`
                (nullable) column marked as required.
        """
        registry_type: Optional[RegistryType] = None

        if type is str:
            registry_type = {"primitive": "string"}
        if type is int:
            registry_type = {"primitive": "integer"}
        if type is float:
            registry_type = {"primitive": "number"}
        if type is date:
            registry_type = {"primitive": "date"}
        if type is datetime:
            registry_type = {"primitive": "datetime"}
        if type is bool:
            registry_type = {"primitive": "boolean"}
        if type is LatchFile:
            registry_type = {"primitive": "blob"}
        if type is LatchDir:
            registry_type = {"primitive": "blob", "metadata": {"nodeType": "dir"}}

        origin = get_origin(type)
        if origin is not None:
            if issubclass(origin, List):
                inner_type = get_args(type)[0]
                inner_type_origin = get_origin(inner_type)
                if inner_type_origin is not None:
                    if issubclass(inner_type_origin, LinkedRecordType):
                        experiment_id = get_args(get_args(inner_type)[0])[0]
                        registry_type = {
                            "array": {
                                "primitive": "link",
                                "experimentId": experiment_id,
                            }
                        }
                    else:
                        raise InvalidColumnTypeError(
                            key, type, "Unsupported list inner type"
                        )
                elif issubclass(inner_type, LatchFile):
                    registry_type = {"array": {"primitive": "blob"}}
                elif issubclass(inner_type, LatchDir):
                    registry_type = {
                        "array": {"primitive": "blob", "metadata": {"nodeType": "dir"}}
                    }
                else:
                    raise InvalidColumnTypeError(
                        key, type, "Unsupported list inner type"
                    )

            if issubclass(origin, LinkedRecordType):
                experiment_id = get_args(get_args(type)[0])[0]
                registry_type = {"primitive": "link", "experimentId": experiment_id}

            if issubclass(origin, RegistryEnumDefinition):
                members = list(get_args(t)[0] for t in get_args(get_args(type)[0]))
                for x in members:
                    if isinstance(x, str):
                        continue

                    raise InvalidColumnTypeError(
                        key, type, f"Enum value {repr(x)} is not a string"
                    )

                registry_type = {
                    "primitive": "enum",
                    "members": members,
                }

        # enum classes are instances of EnumMeta (not of Enum), so check the
        # metaclass when the requested type is a Python Enum
        if isinstance(type, EnumMeta):
            members: List[str] = []
            for f in cast(Type[Enum], type):
                if not isinstance(f.value, str):
                    raise InvalidColumnTypeError(
                        key,
                        type,
                        f"Enum value for {repr(f.name)} ({repr(f.value)}) is not a"
                        " string",
                    )

                members.append(f.value)

            registry_type = {
                "primitive": "enum",
                "members": members,
            }

        if registry_type is None:
            raise InvalidColumnTypeError(key, type, "Unsupported type")

        db_type: DBType = {"type": registry_type, "allowEmpty": not required}
        self._record_mutations.append(_TableColumnUpsertData(key, db_type))
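
    # Usage sketch (illustrative; the column keys are placeholders) showing a
    # few of the currently supported column types:
    #
    #     with table.update() as upd:
    #         upd.upsert_column("label", str, required=True)
    #         upd.upsert_column("reads", LatchFile)
    #         upd.upsert_column("fastqs", List[LatchFile])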

    def _add_column_upserts_selection(
        self,
        upserts: List[_TableColumnUpsertData],
        mutations: List[l.SelectionNode],
        vars: Dict[str, Tuple[l.TypeNode, JsonValue]],
    ) -> None:
        if len(upserts) == 0:
            return

        keys: _GqlJsonValue = [x.key for x in upserts]
        types: JsonValue = [cast(JsonValue, x.type) for x in upserts]

        res = _parse_selection("""
            catalogExperimentColumnDefinitionMultiUpsert(input: {}) {
                clientMutationId
            }
        """)
        assert isinstance(res, l.FieldNode)

        argTypesVar = f"upd{len(mutations)}ArgTypes"

        args = l.ArgumentNode()
        args.name = _name_node("input")
        args.value = _json_value(
            {
                "argExperimentId": self.table.id,
                "argKeys": keys,
                "argTypes": _var_node(argTypesVar),
            }
        )

        res.alias = _name_node(f"upd{len(mutations)}")
        res.arguments = tuple([args])

        mutations.append(res)
        vars[argTypesVar] = (l.parse_type("[JSON]!"), types)

    # transaction

    def commit(self) -> None:
        """Commit this table update transaction.

        May be called multiple times.

        All pending updates are committed with one network request.

        Atomic. The entire transaction either commits or fails with an
        exception.
        """
        mutations: List[l.SelectionNode] = []
        vars: Dict[str, Tuple[l.TypeNode, JsonValue]] = {}

        if len(self._record_mutations) == 0:
            return

        def _add_record_data_selection(cur):
            if isinstance(cur[0], _TableRecordsUpsertData):
                self._add_record_upserts_selection(cur, mutations, vars)

            if isinstance(cur[0], _TableRecordsDeleteData):
                self._add_record_deletes_selection(cur, mutations)

            if isinstance(cur[0], _TableColumnUpsertData):
                self._add_column_upserts_selection(cur, mutations, vars)

        # group consecutive mutations of the same kind into one selection each
        cur = [self._record_mutations[0]]
        for mut in self._record_mutations[1:]:
            if isinstance(mut, type(cur[0])):
                cur.append(mut)
                continue

            _add_record_data_selection(cur)
            cur = [mut]

        _add_record_data_selection(cur)

        sel_set = l.SelectionSetNode()
        sel_set.selections = tuple(mutations)

        doc = l.parse("""
            mutation TableUpdate {
                placeholder
            }
        """)

        assert len(doc.definitions) == 1
        mut = doc.definitions[0]
        assert isinstance(mut, l.OperationDefinitionNode)

        mut.selection_set = sel_set
        mut.variable_definitions = tuple(
            _var_def_node(k, t) for k, (t, _) in vars.items()
        )

        # todo(maximsmol): catch errors here and raise appropriate Python exceptions
        # 1. column upsert: already exists
        execute(doc, {k: v for k, (_, v) in vars.items()})

        self.clear()

    def clear(self):
        """Remove pending updates.

        May be called to cancel any pending updates that have not been
        committed.
        """
        self._record_mutations.clear()