Source code for labelbox.schema.dataset

import os
import json
import logging
from itertools import islice
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import ndjson
from io import StringIO
import requests

from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
from labelbox.orm.db_object import DbObject, Updateable, Deletable
from labelbox.orm.model import Entity, Field, Relationship

logger = logging.getLogger(__name__)


[docs]class Dataset(DbObject, Updateable, Deletable):
    """ A Dataset is a collection of DataRows.

    Attributes:
        name (str)
        description (str)
        updated_at (datetime)
        created_at (datetime)
        row_count (int): The number of rows in the dataset. Fetch the dataset
            again to update since this is cached.

        projects (Relationship): `ToMany` relationship to Project
        data_rows (Relationship): `ToMany` relationship to DataRow
        created_by (Relationship): `ToOne` relationship to User
        organization (Relationship): `ToOne` relationship to Organization
    """
    name = Field.String("name")
    description = Field.String("description")
    updated_at = Field.DateTime("updated_at")
    created_at = Field.DateTime("created_at")
    row_count = Field.Int("row_count")

    # Relationships
    projects = Relationship.ToMany("Project", True)
    data_rows = Relationship.ToMany("DataRow", False)
    created_by = Relationship.ToOne("User", False, "created_by")
    organization = Relationship.ToOne("Organization", False)
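    # Editor's usage sketch (not part of the library source): a Dataset is
    # typically created or fetched through a configured Client. The API key
    # and dataset name below are placeholders.
    #
    #   from labelbox import Client
    #   client = Client(api_key="<YOUR_API_KEY>")
    #   dataset = client.create_dataset(name="example-dataset")
    #   print(dataset.uid, dataset.name, dataset.row_count)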
[docs]    def create_data_row(self, **kwargs):
        """ Creates a single DataRow belonging to this dataset.

        >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")

        Args:
            **kwargs: Key-value arguments containing new `DataRow` data. At a
                minimum, must contain `row_data`.

        Raises:
            InvalidQueryError: If `DataRow.row_data` field value is not provided
                in `kwargs`.
            InvalidAttributeError: in case the DB object type does not contain
                any of the field names given in `kwargs`.
        """
        DataRow = Entity.DataRow
        if DataRow.row_data.name not in kwargs:
            raise InvalidQueryError(
                "DataRow.row_data missing when creating DataRow.")

        # If row data is a local file path, upload it to server.
        row_data = kwargs[DataRow.row_data.name]
        if os.path.exists(row_data):
            kwargs[DataRow.row_data.name] = self.client.upload_file(row_data)

        kwargs[DataRow.dataset.name] = self
        return self.client._create(DataRow, kwargs)
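    # Editor's usage sketch for create_data_row: row_data may be a publicly
    # accessible URL or a local file path (local paths are uploaded first).
    # The URL and external_id below are placeholders.
    #
    #   data_row = dataset.create_data_row(
    #       row_data="http://my_site.com/photos/img_01.jpg",
    #       external_id="img_01")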
[docs]    def create_data_rows(self, items):
        """ Creates multiple DataRow objects based on the given `items`.

        Each element in `items` can be either a `str` or a `dict`. If it is a
        `str`, it is interpreted as a local file path. The file is uploaded to
        Labelbox and a DataRow referencing it is created.

        If an item is a `dict`, it must take one of the following two structures:
            1. For static imagery, video, and text it should map `DataRow` fields
               (or their names) to values. At a minimum, an `item` passed as a
               `dict` must contain a `DataRow.row_data` key and value.
            2. For tiled imagery the dict must match the import structure specified
               in the link below:
               https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

        >>> dataset.create_data_rows([
        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
        >>>     "path/to/file2.jpg",
        >>>     {"tileLayerUrl" : "http://", ...}
        >>>     ])

        For an example showing how to upload tiled data_rows see the following notebook:
            https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb

        A usage sketch also follows this method definition below.

        Args:
            items (iterable of (dict or str)): See above for details.

        Returns:
            Task representing the data import on the server side. The Task
            can be used for inspecting task progress and waiting until it's done.

        Raises:
            InvalidQueryError: If the `items` parameter does not conform to
                the specification above or if the server did not accept the
                DataRow creation request (unknown reason).
            ResourceNotFoundError: If unable to retrieve the Task for the
                import process. This could imply that the import failed.
            InvalidAttributeError: If there are fields in `items` not valid for
                a DataRow.
        """
        file_upload_thread_count = 20
        DataRow = Entity.DataRow

        def upload_if_necessary(item):
            if isinstance(item, str):
                item_url = self.client.upload_file(item)
                # Convert item from str into a dict so it gets processed
                # like all other dicts.
                item = {DataRow.row_data: item_url, DataRow.external_id: item}
            return item

        with ThreadPoolExecutor(file_upload_thread_count) as executor:
            futures = [
                executor.submit(upload_if_necessary, item) for item in items
            ]
            items = [future.result() for future in as_completed(futures)]

        def convert_item(item):
            # Don't make any changes to tms data
            if "tileLayerUrl" in item:
                return item

            # Convert string names to fields.
            item = {
                key if isinstance(key, Field) else DataRow.field(key): value
                for key, value in item.items()
            }

            if DataRow.row_data not in item:
                raise InvalidQueryError(
                    "DataRow.row_data missing when creating DataRow.")

            invalid_keys = set(item) - set(DataRow.fields())
            if invalid_keys:
                raise InvalidAttributeError(DataRow, invalid_keys)

            # Item is valid, convert it to a dict {graphql_field_name: value}.
            # Need to change the name of DataRow.row_data to "data".
            return {
                "data" if key == DataRow.row_data else key.graphql_name: value
                for key, value in item.items()
            }

        # Prepare and upload the descriptor file.
        items = [convert_item(item) for item in items]
        data = json.dumps(items)
        descriptor_url = self.client.upload_data(data)

        # Create data source.
        dataset_param = "datasetId"
        url_param = "jsonUrl"
        query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
            appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
            ){ taskId accepted } } """ % (dataset_param, url_param,
                                          dataset_param, url_param)
        res = self.client.execute(query_str, {
            dataset_param: self.uid,
            url_param: descriptor_url
        })
        res = res["appendRowsToDataset"]
        if not res["accepted"]:
            raise InvalidQueryError(
                "Server did not accept DataRow creation request")

        # Fetch and return the task.
        task_id = res["taskId"]
        user = self.client.get_user()
        task = list(user.created_tasks(where=Entity.Task.uid == task_id))
        if len(task) != 1:
            raise ResourceNotFoundError(Entity.Task, task_id)
        task = task[0]
        # Cache user in a private variable as the relationship can't be
        # resolved due to server-side limitations (see Task.created_by for
        # more info).
        task._user = user
        return task
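    # Editor's usage sketch for create_data_rows: the method returns a Task,
    # so the import runs asynchronously on the server. Paths and URLs below
    # are placeholders; Task.wait_till_done is assumed from the same client
    # library.
    #
    #   task = dataset.create_data_rows([
    #       {"row_data": "http://my_site.com/photos/img_01.jpg",
    #        "external_id": "img_01"},
    #       "path/to/file2.jpg",
    #   ])
    #   task.wait_till_done()
    #   assert task.status == "COMPLETE"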
[docs]    def data_rows_for_external_id(self, external_id, limit=10):
        """ Convenience method for fetching the `DataRow` objects belonging to
        this `Dataset` that have the given `external_id`.

        Args:
            external_id (str): External ID of the sought `DataRow` objects.
            limit (int): The maximum number of data rows to return for the
                given external_id.

        Returns:
            A list of `DataRow` objects with the given external ID, containing
            at most `limit` elements.

        Raises:
            labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
                in this `Dataset` with the given external ID.
        """
        DataRow = Entity.DataRow
        where = DataRow.external_id == external_id

        data_rows = self.data_rows(where=where)
        # Get at most `limit` data_rows.
        data_rows = list(islice(data_rows, limit))

        if not len(data_rows):
            raise ResourceNotFoundError(DataRow, where)
        return data_rows
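    # Editor's usage sketch for data_rows_for_external_id: returns up to
    # `limit` matching rows and raises ResourceNotFoundError when none match.
    # The external_id below is a placeholder.
    #
    #   rows = dataset.data_rows_for_external_id("img_01", limit=5)
    #   for row in rows:
    #       print(row.uid, row.external_id)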
[docs]    def data_row_for_external_id(self, external_id):
        """ Convenience method for getting a single `DataRow` belonging to this
        `Dataset` that has the given `external_id`.

        Args:
            external_id (str): External ID of the sought `DataRow`.

        Returns:
            A single `DataRow` with the given external ID. If multiple rows
            share the ID, the first one found is returned and a warning is
            logged.

        Raises:
            labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
                in this `Dataset` with the given external ID.
        """
        data_rows = self.data_rows_for_external_id(external_id=external_id,
                                                   limit=2)
        if len(data_rows) > 1:
            logger.warning(
                "More than one data_row has the provided external_id: `%s`. "
                "Use data_rows_for_external_id to fetch all of them.",
                external_id)

        return data_rows[0]
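    # Editor's usage sketch for data_row_for_external_id: the external_id
    # below is a placeholder.
    #
    #   row = dataset.data_row_for_external_id("img_01")
    #   print(row.uid)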
[docs]    def export_data_rows(self, timeout_seconds=120):
        """ Returns a generator that produces all data rows that are currently
        attached to this dataset.

        Args:
            timeout_seconds (float): Max waiting time, in seconds.

        Returns:
            Generator that yields DataRow objects belonging to this dataset.

        Raises:
            LabelboxError: if the export fails or is unable to download within
                the specified time.
        """
        id_param = "datasetId"
        query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!)
            {exportDatasetDataRows(data:{datasetId: $%s }) {downloadUrl createdAt status}}
        """ % (id_param, id_param)
        sleep_time = 2
        # Remember the requested timeout so the error message can report it
        # accurately after the countdown below.
        requested_timeout = timeout_seconds
        while True:
            res = self.client.execute(query_str, {id_param: self.uid})
            res = res["exportDatasetDataRows"]
            if res["status"] == "COMPLETE":
                download_url = res["downloadUrl"]
                response = requests.get(download_url)
                response.raise_for_status()
                reader = ndjson.reader(StringIO(response.text))
                return (
                    Entity.DataRow(self.client, result) for result in reader)
            elif res["status"] == "FAILED":
                raise LabelboxError("Data row export failed.")

            timeout_seconds -= sleep_time
            if timeout_seconds <= 0:
                raise LabelboxError(
                    f"Unable to export data rows within {requested_timeout} seconds."
                )

            logger.debug("Dataset '%s' data row export, waiting for server...",
                         self.uid)
            time.sleep(sleep_time)
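    # Editor's usage sketch for export_data_rows: the generator is only
    # produced once the server-side export completes, so iterating it does
    # not block further. The timeout below is a placeholder.
    #
    #   for data_row in dataset.export_data_rows(timeout_seconds=300):
    #       print(data_row.uid, data_row.row_data)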