from dataclasses import dataclass
from typing import Optional, Tuple, Union
import warnings
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field
from labelbox.pagination import PaginatedCollection
from labelbox.schema.export_params import (
CatalogExportParams,
validate_catalog_export_params,
)
from labelbox.schema.export_task import ExportTask
from labelbox.schema.identifiable import GlobalKey, UniqueId
from labelbox.schema.task import Task
class Slice(DbObject):
    """
    A Slice is a saved set of filters (saved query).
    This is an abstract class and should not be instantiated; use one of
    its subclasses (e.g. CatalogSlice or ModelSlice).

    Attributes:
        name (str)
        description (str)
        created_at (datetime)
        updated_at (datetime)
        filter (json)
    """

    name = Field.String("name")
    description = Field.String("description")
    created_at = Field.DateTime("created_at")
    updated_at = Field.DateTime("updated_at")
    filter = Field.Json("filter")

    @dataclass
    class DataRowIdAndGlobalKey:
        """
        Pairs a data row id with its global key, if the data row has one.

        This is the element type yielded by the ``get_data_row_identifiers``
        collections of CatalogSlice and ModelSlice.

        Attributes:
            id (UniqueId): wrapped data row id.
            global_key (Optional[GlobalKey]): wrapped global key, or None.
        """

        id: UniqueId
        global_key: Optional[GlobalKey]

        # An __init__ defined in the class body is not replaced by the
        # dataclass-generated one, so callers pass raw strings and they are
        # wrapped here. The decorator still supplies __repr__ and __eq__.
        def __init__(self, id: str, global_key: Optional[str]):
            self.id = UniqueId(id)
            # Any falsy key (None or "") is stored as None.
            self.global_key = GlobalKey(global_key) if global_key else None

        def to_hash(self) -> dict:
            """
            Returns:
                A dict with the unwrapped ``"id"`` and ``"global_key"``
                values; ``"global_key"`` is None when no key is set.
            """
            return {
                "id": self.id.key,
                "global_key": self.global_key.key if self.global_key else None,
            }
class CatalogSlice(Slice):
    """
    Represents a Slice used for filtering data rows in Catalog.
    """

    def get_data_row_identifiers(self) -> PaginatedCollection:
        """
        Fetches all data row ids and global keys (where defined) that match this Slice

        Returns:
            A PaginatedCollection of Slice.DataRowIdAndGlobalKey, one per
            data row matched by the saved query.
        """
        query_str = """
query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
getDataRowIdentifiersBySavedQuery(input: {
savedQueryId: $id,
after: $from
first: $first
}) {
totalCount
nodes
{
id
globalKey
}
pageInfo {
endCursor
hasNextPage
}
}
}
"""
        result_root = "getDataRowIdentifiersBySavedQuery"

        # PaginatedCollection hands each node to this converter together
        # with the client; the client argument is not needed here.
        def _node_to_identifier(_client, node):
            return Slice.DataRowIdAndGlobalKey(
                node.get("id"),
                node.get("globalKey", None),
            )

        return PaginatedCollection(
            client=self.client,
            query=query_str,
            params={"id": str(self.uid)},
            dereferencing=[result_root, "nodes"],
            obj_class=_node_to_identifier,
            cursor_path=[result_root, "pageInfo", "endCursor"],
        )

    def export(
        self,
        task_name: Optional[str] = None,
        params: Optional[CatalogExportParams] = None,
    ) -> ExportTask:
        """
        Creates a streamable slice export task with the given params and
        returns it wrapped in an ExportTask.

        >>> slice = client.get_catalog_slice("SLICE_ID")
        >>> task = slice.export(
        >>>     params={"performance_details": False, "label_details": True}
        >>> )
        >>> task.wait_till_done()
        >>> task.result
        """
        export_task, _streamable_flag = self._export(
            task_name, params, streamable=True
        )
        return ExportTask(export_task)

    def export_v2(
        self,
        task_name: Optional[str] = None,
        params: Optional[CatalogExportParams] = None,
    ) -> Union[Task, ExportTask]:
        """
        Deprecated: use ``export`` instead.

        Creates a slice export task with the given params. Returns an
        ExportTask when the backend reports the task as streamable,
        otherwise the plain Task.

        >>> slice = client.get_catalog_slice("SLICE_ID")
        >>> task = slice.export_v2(
        >>>     params={"performance_details": False, "label_details": True}
        >>> )
        >>> task.wait_till_done()
        >>> task.result
        """
        warnings.warn(
            "The method export_v2 for CatalogSlice is deprecated and will be removed in the next major release. Use the export method instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        export_task, streamable_ready = self._export(task_name, params)
        return ExportTask(export_task, True) if streamable_ready else export_task

    def _export(
        self,
        task_name: Optional[str] = None,
        params: Optional[CatalogExportParams] = None,
        streamable: bool = False,
    ) -> Tuple[Task, bool]:
        """
        Submits the exportDataRowsInSlice mutation.

        Args:
            task_name: optional name for the created task.
            params: export options; when falsy, every option is disabled.
            streamable: value sent as the mutation's ``streamable`` flag.

        Returns:
            (task, is_streamable) where ``is_streamable`` is the flag
            reported back by the mutation.
        """
        export_params = (
            params
            if params
            else CatalogExportParams(
                {
                    "attachments": False,
                    "embeddings": False,
                    "metadata_fields": False,
                    "data_row_details": False,
                    "project_details": False,
                    "performance_details": False,
                    "label_details": False,
                    "media_type_override": None,
                    "model_run_ids": None,
                    "project_ids": None,
                    "interpolated_frames": False,
                    "all_projects": False,
                    "all_model_runs": False,
                    "predictions": False,
                }
            )
        )
        validate_catalog_export_params(export_params)

        mutation_name = "exportDataRowsInSlice"
        mutation = (
            "mutation " + mutation_name + "PyApi"
            + "($input: ExportDataRowsInSliceInput!)"
            + "{" + mutation_name + "(input: $input){taskId isStreamable}}"
        )

        # (GraphQL field, export-param key, fallback when the key is absent)
        # in the order the fields are sent after mediaTypeOverride.
        option_table = (
            ("includeAttachments", "attachments", False),
            ("includeEmbeddings", "embeddings", False),
            ("includeMetadata", "metadata_fields", False),
            ("includeDataRowDetails", "data_row_details", False),
            ("includeProjectDetails", "project_details", False),
            ("includePerformanceDetails", "performance_details", False),
            ("includeLabelDetails", "label_details", False),
            ("includeInterpolatedFrames", "interpolated_frames", False),
            ("projectIds", "project_ids", None),
            ("modelRunIds", "model_run_ids", None),
            ("allProjects", "all_projects", False),
            ("allModelRuns", "all_model_runs", False),
            ("includePredictions", "predictions", False),
        )

        override = export_params.get("media_type_override", None)
        export_options = {
            # The override object is sent by its .value.
            "mediaTypeOverride": None if override is None else override.value
        }
        export_options.update(
            (gql_field, export_params.get(param_key, fallback))
            for gql_field, param_key, fallback in option_table
        )

        variables = {
            "input": {
                "taskName": task_name,
                "filters": {"sliceId": self.uid},
                "isStreamableReady": True,
                "params": export_options,
                "streamable": streamable,
            }
        }
        response = self.client.execute(
            mutation, variables, error_log_key="errors"
        )
        payload = response[mutation_name]
        # Read both fields before fetching the task so a malformed payload
        # fails without an extra request.
        task_id = payload["taskId"]
        streamable_ready = payload["isStreamable"]
        return Task.get_task(self.client, task_id), streamable_ready
class ModelSlice(Slice):
    """
    Represents a Slice used for filtering data rows in Model.
    """

    @classmethod
    def query_str(cls):
        """
        Returns:
            The GraphQL query that pages through the data row identifiers
            matched by a saved model query for a given model run.
        """
        return """
query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $modelRunId: ID, $from: DataRowIdentifierCursorInput, $first: Int!) {
getDataRowIdentifiersBySavedModelQuery(input: {
savedQueryId: $id,
modelRunId: $modelRunId,
after: $from
first: $first
}) {
totalCount
nodes
{
id
globalKey
}
pageInfo {
endCursor {
dataRowId
globalKey
}
hasNextPage
}
}
}
"""

    def _identifier_collection(self, model_run_id: str, convert):
        """Build a PaginatedCollection over this slice's query results for
        ``model_run_id``, turning each result node into ``convert(client, node)``."""
        result_root = "getDataRowIdentifiersBySavedModelQuery"
        return PaginatedCollection(
            client=self.client,
            query=ModelSlice.query_str(),
            params={"id": str(self.uid), "modelRunId": model_run_id},
            dereferencing=[result_root, "nodes"],
            obj_class=convert,
            cursor_path=[result_root, "pageInfo", "endCursor"],
        )

    def get_data_row_ids(self, model_run_id: str) -> PaginatedCollection:
        """
        Deprecated: use ``get_data_row_identifiers`` instead.

        Fetches all data row ids that match this Slice

        Params:
            model_run_id: str, required, uid or cuid of model run
        Returns:
            A PaginatedCollection of data row ids
        """
        warnings.warn(
            "The method get_data_row_ids for ModelSlice is deprecated and will be removed in the next major release. Use the get_data_row_identifiers method instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._identifier_collection(
            model_run_id, lambda _client, node: node.get("id")
        )

    def get_data_row_identifiers(
        self, model_run_id: str
    ) -> PaginatedCollection:
        """
        Fetches all data row ids and global keys (where defined) that match this Slice

        Params:
            model_run_id: str, required, uid or cuid of model run
        Returns:
            A PaginatedCollection of Slice.DataRowIdAndGlobalKey
        """

        def _node_to_identifier(_client, node):
            return Slice.DataRowIdAndGlobalKey(
                node.get("id"),
                node.get("globalKey", None),
            )

        return self._identifier_collection(model_run_id, _node_to_identifier)