Source code for labelbox.schema.slice

from dataclasses import dataclass
from typing import Optional, Tuple, Union
import warnings
from labelbox.orm.db_object import DbObject, experimental
from labelbox.orm.model import Field
from labelbox.pagination import PaginatedCollection
from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
from labelbox.schema.export_task import ExportTask
from labelbox.schema.identifiable import GlobalKey, UniqueId
from labelbox.schema.task import Task


[docs]class Slice(DbObject): """ A Slice is a saved set of filters (saved query). This is an abstract class and should not be instantiated. Attributes: name (datetime) description (datetime) created_at (datetime) updated_at (datetime) filter (json) """ name = Field.String("name") description = Field.String("description") created_at = Field.DateTime("created_at") updated_at = Field.DateTime("updated_at") filter = Field.Json("filter")
[docs] @dataclass class DataRowIdAndGlobalKey: id: UniqueId global_key: Optional[GlobalKey] def __init__(self, id: str, global_key: Optional[str]): self.id = UniqueId(id) self.global_key = GlobalKey(global_key) if global_key else None def to_hash(self): return { "id": self.id.key, "global_key": self.global_key.key if self.global_key else None }
[docs]class CatalogSlice(Slice): """ Represents a Slice used for filtering data rows in Catalog. """
[docs] def get_data_row_ids(self) -> PaginatedCollection: """ Fetches all data row ids that match this Slice Returns: A PaginatedCollection of mapping of data row ids to global keys """ warnings.warn( "get_data_row_ids will be deprecated. Use get_data_row_identifiers instead" ) query_str = """ query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) { getDataRowIdsBySavedQuery(input: { savedQueryId: $id, after: $from first: $first }) { totalCount nodes pageInfo { endCursor hasNextPage } } } """ return PaginatedCollection( client=self.client, query=query_str, params={'id': str(self.uid)}, dereferencing=['getDataRowIdsBySavedQuery', 'nodes'], obj_class=lambda _, data_row_id: data_row_id, cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
[docs] def get_data_row_identifiers(self) -> PaginatedCollection: """ Fetches all data row ids and global keys (where defined) that match this Slice Returns: A PaginatedCollection of Slice.DataRowIdAndGlobalKey """ query_str = """ query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) { getDataRowIdentifiersBySavedQuery(input: { savedQueryId: $id, after: $from first: $first }) { totalCount nodes { id globalKey } pageInfo { endCursor hasNextPage } } } """ return PaginatedCollection( client=self.client, query=query_str, params={'id': str(self.uid)}, dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'], obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( data_row_id_and_gk.get('id'), data_row_id_and_gk.get('globalKey', None)), cursor_path=[ 'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor' ])
[docs] def export(self, task_name: Optional[str] = None, params: Optional[CatalogExportParams] = None) -> ExportTask: """ Creates a slice export task with the given params and returns the task. >>> slice = client.get_catalog_slice("SLICE_ID") >>> task = slice.export( >>> params={"performance_details": False, "label_details": True} >>> ) >>> task.wait_till_done() >>> task.result """ task, _ = self._export(task_name, params, streamable=True) return ExportTask(task)
[docs] def export_v2( self, task_name: Optional[str] = None, params: Optional[CatalogExportParams] = None, ) -> Union[Task, ExportTask]: """ Creates a slice export task with the given params and returns the task. >>> slice = client.get_catalog_slice("SLICE_ID") >>> task = slice.export_v2( >>> params={"performance_details": False, "label_details": True} >>> ) >>> task.wait_till_done() >>> task.result """ task, is_streamable = self._export(task_name, params) if (is_streamable): return ExportTask(task, True) return task
def _export( self, task_name: Optional[str] = None, params: Optional[CatalogExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: _params = params or CatalogExportParams({ "attachments": False, "metadata_fields": False, "data_row_details": False, "project_details": False, "performance_details": False, "label_details": False, "media_type_override": None, "model_run_ids": None, "project_ids": None, "interpolated_frames": False, "all_projects": False, "all_model_runs": False, }) validate_catalog_export_params(_params) mutation_name = "exportDataRowsInSlice" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInSliceInput!)" f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") media_type_override = _params.get('media_type_override', None) query_params = { "input": { "taskName": task_name, "filters": { "sliceId": self.uid }, "isStreamableReady": True, "params": { "mediaTypeOverride": media_type_override.value if media_type_override is not None else None, "includeAttachments": _params.get('attachments', False), "includeMetadata": _params.get('metadata_fields', False), "includeDataRowDetails": _params.get('data_row_details', False), "includeProjectDetails": _params.get('project_details', False), "includePerformanceDetails": _params.get('performance_details', False), "includeLabelDetails": _params.get('label_details', False), "includeInterpolatedFrames": _params.get('interpolated_frames', False), "projectIds": _params.get('project_ids', None), "modelRunIds": _params.get('model_run_ids', None), "allProjects": _params.get('all_projects', False), "allModelRuns": _params.get('all_model_runs', False), }, "streamable": streamable, } } res = self.client.execute(create_task_query_str, query_params, error_log_key="errors") res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable
[docs]class ModelSlice(Slice): """ Represents a Slice used for filtering data rows in Model. """ @classmethod def query_str(cls): query_str = """ query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $modelRunId: ID, $from: DataRowIdentifierCursorInput, $first: Int!) { getDataRowIdentifiersBySavedModelQuery(input: { savedQueryId: $id, modelRunId: $modelRunId, after: $from first: $first }) { totalCount nodes { id globalKey } pageInfo { endCursor { dataRowId globalKey } hasNextPage } } } """ return query_str
[docs] def get_data_row_ids(self, model_run_id: str) -> PaginatedCollection: """ Fetches all data row ids that match this Slice Params model_run_id: str, required, uid or cuid of model run Returns: A PaginatedCollection of data row ids """ return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), params={ 'id': str(self.uid), 'modelRunId': model_run_id }, dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id' ), cursor_path=[ 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', 'endCursor' ])
[docs] def get_data_row_identifiers(self, model_run_id: str) -> PaginatedCollection: """ Fetches all data row ids and global keys (where defined) that match this Slice Params: model_run_id : str, required, uid or cuid of model run Returns: A PaginatedCollection of Slice.DataRowIdAndGlobalKey """ return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), params={ 'id': str(self.uid), 'modelRunId': model_run_id }, dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( data_row_id_and_gk.get('id'), data_row_id_and_gk.get('globalKey', None)), cursor_path=[ 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', 'endCursor' ])