import json
import logging
import os
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice
from string import Template
from typing import Any, Dict, List, Optional, Tuple, Union
from lbox.exceptions import (
InvalidQueryError,
LabelboxError,
ResourceCreationError,
ResourceNotFoundError,
) # type: ignore
import labelbox.schema.internal.data_row_uploader as data_row_uploader
from labelbox.orm import query
from labelbox.orm.comparison import Comparison
from labelbox.orm.db_object import DbObject, Deletable, Updateable
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.pagination import PaginatedCollection
from labelbox.schema.data_row import DataRow
from labelbox.schema.export_filters import DatasetExportFilters, build_filters
from labelbox.schema.export_params import (
CatalogExportParams,
validate_catalog_export_params,
)
from labelbox.schema.export_task import ExportTask
from labelbox.schema.iam_integration import IAMIntegration
from labelbox.schema.identifiable import GlobalKey, UniqueId
from labelbox.schema.internal.data_row_upsert_item import (
DataRowCreateItem,
DataRowItemBase,
DataRowUpsertItem,
)
from labelbox.schema.internal.datarow_upload_constants import (
FILE_UPLOAD_THREAD_COUNT,
UPSERT_CHUNK_SIZE_BYTES,
)
from labelbox.schema.task import DataUpsertTask, Task
# Module-level logger named after this module; used below to warn about
# ambiguous external ids.
logger = logging.getLogger(__name__)
[docs]class Dataset(DbObject, Updateable, Deletable):
"""A Dataset is a collection of DataRows.
Attributes:
name (str)
description (str)
updated_at (datetime)
created_at (datetime)
row_count (int): The number of rows in the dataset. Fetch the dataset again to update since this is cached.
created_by (Relationship): `ToOne` relationship to User
organization (Relationship): `ToOne` relationship to Organization
iam_integration (Relationship): `ToOne` relationship to IAMIntegration
"""
# Scalar fields
name = Field.String("name")
description = Field.String("description")
updated_at = Field.DateTime("updated_at")
created_at = Field.DateTime("created_at")
row_count = Field.Int("row_count")
# Relationships
created_by = Relationship.ToOne("User", False, "created_by")
organization = Relationship.ToOne("Organization", False)
# Declared with the additional "signer" argument; the dataset mutations in
# this class refer to the same integration as the dataset's signer.
iam_integration = Relationship.ToOne(
"IAMIntegration", False, "iam_integration", "signer"
)
def data_rows(
    self,
    from_cursor: Optional[str] = None,
    where: Optional[Comparison] = None,
) -> PaginatedCollection:
    """Page through the data rows of this dataset using a cursor.

    Args:
        from_cursor (str): Cursor (data row id) to start from. When it is
            None, pagination starts from the beginning.
        where (Comparison): Optional filter applied to the data rows, for
            example ``DataRow.external_id == "my_external_id"`` to get the
            data rows whose external id is ``my_external_id``.

    Returns:
        PaginatedCollection: Collection of `DataRow` objects read from the
        ``datasetDataRows`` query.

    NOTE:
        Order of retrieval is newest data row first.
        Deleted data rows are not retrieved.
        Failed data rows are not retrieved.
        Data rows in progress *may be* retrieved.
    """
    # Hardcoded page size to avoid overloading the server.
    rows_per_page = 500

    where_param = None
    if where is not None:
        where_param = query.where_as_dict(Entity.DataRow, where)

    selections = query.results_query_part(Entity.DataRow)
    # `$$` is the Template escape for a literal `$` (GraphQL variables);
    # only `$datarow_selections` is substituted.
    query_str = Template(
        """query DatasetDataRowsPyApi($$id: ID!, $$from: ID, $$first: Int, $$where: DatasetDataRowWhereInput) {
datasetDataRows(id: $$id, from: $$from, first: $$first, where: $$where)
{
nodes { $datarow_selections }
pageInfo { hasNextPage startCursor }
}
}
"""
    ).substitute(datarow_selections=selections)

    return PaginatedCollection(
        client=self.client,
        query=query_str,
        params={
            "id": self.uid,
            "from": from_cursor,
            "first": rows_per_page,
            "where": where_param,
        },
        dereferencing=["datasetDataRows", "nodes"],
        obj_class=Entity.DataRow,
        cursor_path=["datasetDataRows", "pageInfo", "startCursor"],
    )
def create_data_row(self, items=None, **kwargs) -> "DataRow":
    """Create one DataRow in this dataset and return it.

    >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")

    The new data row is described either by the ``items`` dictionary or by
    keyword arguments, never both. The row is uploaded synchronously and
    then fetched back by the id reported in the upload task result.

    Args:
        items: Dictionary containing new `DataRow` data. At a minimum,
            must contain `row_data` or `DataRow.row_data`.
        **kwargs: Key-value arguments containing new `DataRow` data. At a
            minimum, must contain `row_data`.

    Returns:
        DataRow: The newly created data row.

    Raises:
        InvalidQueryError: If both a dictionary and `kwargs` are provided.
        ResourceCreationError: If the upload fails, or the finished task
            carries no result.

    NOTE(review): the content of the row (e.g. presence of `row_data`) is
    not checked in this method; any such validation happens further down
    the upload path.
    """
    if items is not None and kwargs:
        raise InvalidQueryError(
            "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum"
        )

    row_spec = kwargs if items is None else items

    # Exactly one row is submitted, so a single upload thread is used.
    finished_task = self._create_data_rows_sync(
        [row_spec], file_upload_thread_count=1
    )

    created_rows = finished_task.result
    if created_rows is None or len(created_rows) == 0:
        raise ResourceCreationError(
            f"Data row upload did not complete, task status {finished_task.status} task id {finished_task.uid}"
        )

    new_row_id = created_rows[0]["id"]
    return self.client.get_data_row(new_row_id)
def _create_data_rows_sync(
    self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
) -> "DataUpsertTask":
    """Upload data rows and block until the upload task has finished.

    Args:
        items: Data row descriptions, forwarded to `create_data_rows`.
        file_upload_thread_count (int): Number of upload threads; must be
            at least 1.

    Returns:
        DataUpsertTask: The finished task, whose status is "COMPLETE".

    Raises:
        ValueError: If `file_upload_thread_count` is smaller than 1.
        ResourceCreationError: If the task reports errors, or ends in any
            status other than "COMPLETE".
    """
    if file_upload_thread_count < 1:
        raise ValueError("file_upload_thread_count must be a positive integer")

    upload_task: DataUpsertTask = self.create_data_rows(
        items, file_upload_thread_count
    )
    upload_task.wait_till_done()

    # Reported errors take precedence over a non-COMPLETE status.
    if upload_task.has_errors():
        raise ResourceCreationError(
            f"Data row upload errors: {upload_task.errors}",
            cause=upload_task.uid,
        )

    if upload_task.status == "COMPLETE":
        return upload_task

    raise ResourceCreationError(
        f"Data row upload did not complete, task status {upload_task.status} task id {upload_task.uid}"
    )
def create_data_rows(
    self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
) -> "DataUpsertTask":
    """Asynchronously bulk upload data rows.

    Args:
        items (iterable of (dict or str)): Data rows to create. A dict
            describes one data row; a str is treated as a path to a local
            file, which is uploaded first and then turned into a data row.
            Any iterable is accepted, including one-shot generators.
        file_upload_thread_count (int): Number of threads used for
            uploading; must be at least 1. It is applied both to local
            file uploads and to the upload of the data row specs.

    Returns:
        Task representing the data import on the server side. The Task
        can be used for inspecting task progress and waiting until it's done.

    Raises:
        ValueError: When the upload parameters are invalid
            (`file_upload_thread_count` smaller than 1).
        InvalidQueryError, ResourceNotFoundError, InvalidAttributeError:
            May be raised by the upload steps this method delegates to
            (not visible in this method) when the items are malformed or
            the server rejects / loses track of the import.

    NOTE dict and string items should not be mixed in the same call: all
    dict items are placed before the rows built from string items, so the
    caller's ordering across the two kinds is not preserved. It is a
    responsibility of the caller to ensure that all items are of the same type.
    """
    if file_upload_thread_count < 1:
        raise ValueError("file_upload_thread_count must be a positive integer")

    # Forward the thread count so local file uploads honour it too.
    upload_items = self._separate_and_process_items(
        items, file_upload_thread_count
    )
    specs = DataRowCreateItem.build(self.uid, upload_items)
    return self._exec_upsert_data_rows(specs, file_upload_thread_count)


def _separate_and_process_items(
    self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
):
    """Split `items` into dicts and local paths and upload the latter.

    Args:
        items: Iterable of dict and/or str items. It is traversed exactly
            once, so generators are handled correctly. Items that are
            neither dict nor str are ignored.
        file_upload_thread_count (int): Number of threads used to upload
            the files referenced by the str items.

    Returns:
        list of dict: All dict items (in input order) followed by the
        data row dicts built from the str items.
    """
    path_items = []
    dict_items = []
    # Single pass: iterating twice would exhaust a generator on the first
    # pass and silently lose every dict item.
    for item in items:
        if isinstance(item, str):
            path_items.append(item)
        elif isinstance(item, dict):
            dict_items.append(item)

    if not path_items:
        return dict_items

    return dict_items + self._build_from_local_paths(
        path_items, file_upload_thread_count=file_upload_thread_count
    )
def _build_from_local_paths(
    self,
    items: List[str],
    file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT,
) -> List[dict]:
    """Upload local files and build one data row dict per uploaded file.

    Args:
        items: Paths of local files to upload. Entries that are not
            strings or do not exist on the local filesystem are skipped
            with a logged warning.
        file_upload_thread_count (int): Number of worker threads used for
            the uploads.

    Returns:
        list of dict: One ``{"row_data": <uploaded url>, "external_id":
        <path>}`` dict per uploaded path, in the same order as the paths
        were given.
    """

    def upload_file(item):
        item_url = self.client.upload_file(item)
        return {"row_data": item_url, "external_id": item}

    existing_paths = []
    for item in items:
        if isinstance(item, str) and os.path.exists(item):
            existing_paths.append(item)
        else:
            # Skipped rather than raised, but no longer silently.
            logger.warning(
                "Skipping item `%s`: it is not a path that exists on the local filesystem",
                item,
            )

    with ThreadPoolExecutor(file_upload_thread_count) as executor:
        # Executor.map yields results in input order, so the rows line up
        # with the given paths no matter which upload finishes first.
        return list(executor.map(upload_file, existing_paths))
def data_rows_for_external_id(
    self, external_id, limit=10
) -> List["DataRow"]:
    """Return up to `limit` data rows of this dataset with `external_id`.

    Args:
        external_id (str): External ID of the sought `DataRow`.
        limit (int): The maximum number of data rows to return for the
            given external_id.

    Returns:
        A list of at most `limit` `DataRow` objects with the given
        external ID.

    Raises:
        lbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
            in this `Dataset` with the given external ID.
    """
    data_row_entity = Entity.DataRow
    external_id_filter = data_row_entity.external_id == external_id

    # Take at most `limit` rows from the paginated result.
    matches = list(islice(self.data_rows(where=external_id_filter), limit))
    if len(matches) == 0:
        raise ResourceNotFoundError(data_row_entity, external_id_filter)
    return matches
def data_row_for_external_id(self, external_id) -> "DataRow":
    """Return a single data row of this dataset with `external_id`.

    Args:
        external_id (str): External ID of the sought `DataRow`.

    Returns:
        The first `DataRow` found with the given external ID. When more
        than one data row shares the external ID, a warning is logged and
        the first one is still returned.

    Raises:
        lbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
            in this `Dataset` with the given external ID (raised by
            `data_rows_for_external_id`).
    """
    # Two rows are enough to tell whether the external id is ambiguous.
    candidates = self.data_rows_for_external_id(
        external_id=external_id, limit=2
    )
    has_duplicates = len(candidates) > 1
    if has_duplicates:
        logger.warning(
            "More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all",
            external_id,
        )
    return candidates[0]
def export(
    self,
    task_name: Optional[str] = None,
    filters: Optional[DatasetExportFilters] = None,
    params: Optional[CatalogExportParams] = None,
) -> ExportTask:
    """Create a streamable export task for this dataset and return it.

    Args:
        task_name (str): Optional name of the export task; a default name
            is generated when omitted.
        filters (DatasetExportFilters): Optional filters selecting the
            data rows to export.
        params (CatalogExportParams): Optional flags selecting what is
            included in the export.

    Returns:
        ExportTask: Wrapper around the created export task.

    >>> dataset = client.get_dataset(DATASET_ID)
    >>> task = dataset.export(
    >>>     filters={
    >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
    >>>     },
    >>>     params={
    >>>         "performance_details": False,
    >>>         "label_details": True
    >>>     })
    >>> task.wait_till_done()
    >>> task.result
    """
    # The streamable flag returned by `_export` is not needed here.
    created_task, _unused_is_streamable = self._export(
        task_name, filters, params, streamable=True
    )
    export_task = ExportTask(created_task)
    return export_task
def export_v2(
    self,
    task_name: Optional[str] = None,
    filters: Optional[DatasetExportFilters] = None,
    params: Optional[CatalogExportParams] = None,
) -> Union[Task, ExportTask]:
    """Create a dataset export task (deprecated; emits a DeprecationWarning).

    Args:
        task_name (str): Optional name of the export task; a default name
            is generated when omitted.
        filters (DatasetExportFilters): Optional filters selecting the
            data rows to export.
        params (CatalogExportParams): Optional flags selecting what is
            included in the export.

    Returns:
        An `ExportTask` when the server reports the export as streamable,
        otherwise the plain `Task`.

    >>> dataset = client.get_dataset(DATASET_ID)
    >>> task = dataset.export_v2(
    >>>     filters={
    >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...]
    >>>     },
    >>>     params={
    >>>         "performance_details": False,
    >>>         "label_details": True
    >>>     })
    >>> task.wait_till_done()
    >>> task.result
    """
    # stacklevel=2 attributes the warning to the caller of export_v2.
    warnings.warn(
        "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods",
        DeprecationWarning,
        stacklevel=2,
    )
    created_task, server_says_streamable = self._export(
        task_name, filters, params
    )
    return ExportTask(created_task, True) if server_says_streamable else created_task
def _export(
    self,
    task_name: Optional[str] = None,
    filters: Optional[DatasetExportFilters] = None,
    params: Optional[CatalogExportParams] = None,
    streamable: bool = False,
) -> Tuple[Task, bool]:
    """Submit the `exportDataRowsInCatalog` mutation for this dataset.

    Args:
        task_name (str): Name of the export task. Defaults to
            "Export v2: dataset - <dataset name>".
        filters (DatasetExportFilters): Export filters; when falsy, a
            filter set with every entry set to None is used.
        params (CatalogExportParams): Export params; when falsy, a param
            set with every flag off / None is used.
        streamable (bool): Value sent as the mutation's "streamable" input.

    Returns:
        Tuple[Task, bool]: The created task and the `isStreamable` value
        reported by the server.
    """
    if params:
        _params = params
    else:
        _params = CatalogExportParams(
            {
                "attachments": False,
                "embeddings": False,
                "metadata_fields": False,
                "data_row_details": False,
                "project_details": False,
                "performance_details": False,
                "label_details": False,
                "media_type_override": None,
                "model_run_ids": None,
                "project_ids": None,
                "interpolated_frames": False,
                "all_projects": False,
                "all_model_runs": False,
            }
        )
    validate_catalog_export_params(_params)

    if filters:
        _filters = filters
    else:
        _filters = DatasetExportFilters(
            {
                "last_activity_at": None,
                "label_created_at": None,
                "data_row_ids": None,
                "global_keys": None,
            }
        )

    mutation_name = "exportDataRowsInCatalog"
    create_task_query_str = (
        f"mutation {mutation_name}PyApi"
        f"($input: ExportDataRowsInCatalogInput!)"
        f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}"
    )

    media_type_override = _params.get("media_type_override", None)

    if task_name is None:
        task_name = f"Export v2: dataset - {self.name}"

    # Only the `.value` of the override is sent, when one is set.
    media_type_value = None
    if media_type_override is not None:
        media_type_value = media_type_override.value

    # Map the snake_case SDK params onto the mutation's input fields.
    export_options: Dict[str, Any] = {
        "mediaTypeOverride": media_type_value,
        "includeAttachments": _params.get("attachments", False),
        "includeEmbeddings": _params.get("embeddings", False),
        "includeMetadata": _params.get("metadata_fields", False),
        "includeDataRowDetails": _params.get("data_row_details", False),
        "includeProjectDetails": _params.get("project_details", False),
        "includePerformanceDetails": _params.get("performance_details", False),
        "includeLabelDetails": _params.get("label_details", False),
        "includeInterpolatedFrames": _params.get("interpolated_frames", False),
        "includePredictions": _params.get("predictions", False),
        "projectIds": _params.get("project_ids", None),
        "modelRunIds": _params.get("model_run_ids", None),
        "allProjects": _params.get("all_projects", False),
        "allModelRuns": _params.get("all_model_runs", False),
    }

    # Restrict the caller's filters to this dataset.
    search_query = build_filters(self.client, _filters)
    search_query.append(
        {"ids": [self.uid], "operator": "is", "type": "dataset"}
    )

    query_params: Dict[str, Any] = {
        "input": {
            "taskName": task_name,
            "filters": {
                "searchQuery": {
                    "scope": None,
                    "query": search_query,
                }
            },
            "isStreamableReady": True,
            "params": export_options,
            "streamable": streamable,
        }
    }

    response = self.client.execute(
        create_task_query_str, query_params, error_log_key="errors"
    )
    payload = response[mutation_name]
    task_id = payload["taskId"]
    is_streamable = payload["isStreamable"]
    return Task.get_task(self.client, task_id), is_streamable
def upsert_data_rows(
    self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
) -> "DataUpsertTask":
    """Create or update data rows in this dataset.

    When "key" is provided, and it references an existing data row,
    an update will be performed. When "key" is not provided a new data row
    will be created. Keys may be `UniqueId` or `GlobalKey` instances.

    Args:
        items: Data row descriptions (dicts), optionally carrying a "key".
        file_upload_thread_count (int): Number of threads used for the
            upload.

    Returns:
        DataUpsertTask: Task tracking the upsert.

    >>> task = dataset.upsert_data_rows([
    >>>     # create new data row
    >>>     {
    >>>         "row_data": "http://my_site.com/photos/img_01.jpg",
    >>>         "global_key": "global_key1",
    >>>         "external_id": "ex_id1",
    >>>         "attachments": [
    >>>             {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"}
    >>>         ],
    >>>         "metadata": [
    >>>             {"name": "tag", "value": "tag value"},
    >>>         ]
    >>>     },
    >>>     # update global key of data row by existing global key
    >>>     {
    >>>         "key": GlobalKey("global_key1"),
    >>>         "global_key": "global_key1_updated"
    >>>     },
    >>>     # update data row by ID
    >>>     {
    >>>         "key": UniqueId(dr.uid),
    >>>         "external_id": "ex_id1_updated"
    >>>     },
    >>> ])
    >>> task.wait_till_done()
    """
    # Key classes accepted for the "key" entry of an item.
    accepted_key_types = (UniqueId, GlobalKey)
    upsert_specs = DataRowUpsertItem.build(
        self.uid, items, accepted_key_types
    )
    return self._exec_upsert_data_rows(upsert_specs, file_upload_thread_count)
def _exec_upsert_data_rows(
    self,
    specs: List[DataRowItemBase],
    file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT,
) -> "DataUpsertTask":
    """Upload the specs in chunks and start the server-side upsert.

    The chunk upload produces a manifest, which is serialized to JSON,
    uploaded as ``manifest.json`` and referenced by the `upsertDataRows`
    mutation.

    Args:
        specs: Data row specs to upsert.
        file_upload_thread_count (int): Number of threads used for the
            chunk upload.

    Returns:
        DataUpsertTask: Task built from the mutation response.
    """
    chunk_manifest = data_row_uploader.upload_in_chunks(
        client=self.client,
        specs=specs,
        file_upload_thread_count=file_upload_thread_count,
        max_chunk_size_bytes=UPSERT_CHUNK_SIZE_BYTES,
    )

    manifest_json = json.dumps(chunk_manifest.model_dump())
    manifest_uri = self.client.upload_data(
        manifest_json.encode("utf-8"),
        content_type="application/json",
        filename="manifest.json",
    )

    upsert_mutation = """
mutation UpsertDataRowsPyApi($manifestUri: String!) {
upsertDataRows(data: { manifestUri: $manifestUri }) {
id createdAt updatedAt name status completionPercentage result errors type metadata
}
}
"""
    response = self.client.execute(
        upsert_mutation, {"manifestUri": manifest_uri}
    )

    upsert_task = DataUpsertTask(self.client, response["upsertDataRows"])
    # Attach the current user to the task.
    upsert_task._user = self.client.get_user()
    return upsert_task
def add_iam_integration(
    self, iam_integration: Union[str, IAMIntegration]
) -> IAMIntegration:
    """
    Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets.

    Args:
        iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id.

    Returns:
        IAMIntegration: IAM integration object.

    Raises:
        ResourceNotFoundError: If the server returns an empty response to
            the mutation.
        LabelboxError: If the IAM integration can't be read back from the
            response or found among the organization's integrations. The
            underlying error is attached as ``__cause__``.

    Examples:

        >>> # Get all IAM integrations
        >>> iam_integrations = client.get_organization().get_iam_integrations()
        >>>
        >>> # Get IAM integration id
        >>> iam_integration_id = [integration.uid for integration
        >>>                       in iam_integrations
        >>>                       if integration.name == "My S3 integration"][0]
        >>>
        >>> # Set IAM integration for integration id
        >>> dataset.add_iam_integration(iam_integration_id)
        >>>
        >>> # Get IAM integration object
        >>> iam_integration = [integration for integration
        >>>                    in iam_integrations
        >>>                    if integration.name == "My S3 integration"][0]
        >>>
        >>> # Set IAM integration for IAMIntegration object
        >>> dataset.add_iam_integration(iam_integration)
    """
    iam_integration_id = (
        iam_integration.uid
        if isinstance(iam_integration, IAMIntegration)
        else iam_integration
    )

    # Named `mutation` so the imported `query` module is not shadowed.
    mutation = """
mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) {
setSignerForDataset(
data: { signerId: $signerId }
where: { id: $datasetId }
) {
id
signer {
id
}
}
}
"""
    response = self.client.execute(
        mutation, {"signerId": iam_integration_id, "datasetId": self.uid}
    )

    if not response:
        raise ResourceNotFoundError(
            IAMIntegration,
            {"signerId": iam_integration_id, "datasetId": self.uid},
        )

    try:
        # `or {}` also covers keys that are present but null.
        signer = (response.get("setSignerForDataset") or {}).get(
            "signer"
        ) or {}
        iam_integration_id = signer["id"]

        return [
            integration
            for integration in self.client.get_organization().get_iam_integrations()
            if integration.uid == iam_integration_id
        ][0]
    except Exception as exc:
        # `Exception` rather than a bare except: KeyboardInterrupt and
        # SystemExit propagate; every ordinary failure is still reported
        # as LabelboxError, with the original error chained.
        raise LabelboxError(
            f"Can't retrieve IAM integration {iam_integration_id}"
        ) from exc
def remove_iam_integration(self) -> None:
    """
    Unsets the IAM integration for the dataset by running the
    `clearSignerForDataset` mutation for this dataset's id.

    Args:
        None

    Returns:
        None

    Raises:
        ResourceNotFoundError: If the server returns an empty response to
            the mutation.

    Examples:
        >>> dataset.remove_iam_integration()
    """
    detach_mutation = """
mutation DetachSignerPyApi($id: ID!) {
clearSignerForDataset(where: { id: $id }) {
id
}
}
"""
    result = self.client.execute(detach_mutation, {"id": self.uid})
    if result:
        return
    raise ResourceNotFoundError(Dataset, {"id": self.uid})