Source code for labelbox.schema.dataset

import os
import json
import logging
from itertools import islice
from multiprocessing.dummy import Pool as ThreadPool
import time
import ndjson
from io import StringIO
import requests

from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
from labelbox.orm.db_object import DbObject, Updateable, Deletable
from labelbox.orm.model import Entity, Field, Relationship

logger = logging.getLogger(__name__)


class Dataset(DbObject, Updateable, Deletable):
    """ A Dataset is a collection of DataRows.

    Attributes:
        name (str)
        description (str)
        updated_at (datetime)
        created_at (datetime)

        projects (Relationship): `ToMany` relationship to Project
        data_rows (Relationship): `ToMany` relationship to DataRow
        created_by (Relationship): `ToOne` relationship to User
        organization (Relationship): `ToOne` relationship to Organization
    """
    name = Field.String("name")
    description = Field.String("description")
    updated_at = Field.DateTime("updated_at")
    created_at = Field.DateTime("created_at")

    # Relationships
    projects = Relationship.ToMany("Project", True)
    data_rows = Relationship.ToMany("DataRow", False)
    created_by = Relationship.ToOne("User", False, "created_by")
    organization = Relationship.ToOne("Organization", False)
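    # A minimal usage sketch (not part of this module): obtaining a Dataset
    # instance through a Client. The placeholder API key and dataset ID are
    # assumptions for illustration only.
    #
    # >>> from labelbox import Client
    # >>> client = Client(api_key="<YOUR_API_KEY>")
    # >>> dataset = client.create_dataset(name="my-dataset")
    # >>> # or fetch an existing one:
    # >>> dataset = client.get_dataset("<DATASET_ID>")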
    def create_data_row(self, **kwargs):
        """ Creates a single DataRow belonging to this dataset.

        >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")

        Args:
            **kwargs: Key-value arguments containing new `DataRow` data.
                At a minimum, must contain `row_data`.

        Raises:
            InvalidQueryError: If the `DataRow.row_data` field value is not
                provided in `kwargs`.
            InvalidAttributeError: If any of the field names given in `kwargs`
                do not exist on the DB object type.
        """
        DataRow = Entity.DataRow
        if DataRow.row_data.name not in kwargs:
            raise InvalidQueryError(
                "DataRow.row_data missing when creating DataRow.")

        # If row data is a local file path, upload it to server.
        row_data = kwargs[DataRow.row_data.name]
        if os.path.exists(row_data):
            kwargs[DataRow.row_data.name] = self.client.upload_file(row_data)

        kwargs[DataRow.dataset.name] = self
        return self.client._create(DataRow, kwargs)
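    # Usage sketch (illustrative path and ID only): because `row_data` is
    # checked with `os.path.exists`, a local file path is uploaded
    # automatically before the DataRow is created.
    #
    # >>> data_row = dataset.create_data_row(row_data="path/to/local_image.jpg",
    # ...                                    external_id="local_image_01")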
    def create_data_rows(self, items):
        """ Creates multiple DataRow objects based on the given `items`.

        Each element in `items` can be either a `str` or a `dict`. If it is a
        `str`, then it is interpreted as a local file path. The file is
        uploaded to Labelbox and a DataRow referencing it is created.

        If an item is a `dict`, it must conform to one of the following two
        structures:
            1. For static imagery, video, and text it should map `DataRow`
               fields (or their names) to values. At a minimum, an `item`
               passed as a `dict` must contain a `DataRow.row_data` key and
               value.
            2. For tiled imagery the dict must match the import structure
               specified in the link below:
               https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

        >>> dataset.create_data_rows([
        >>>     {DataRow.row_data: "http://my_site.com/photos/img_01.jpg"},
        >>>     "path/to/file2.jpg",
        >>>     {"tileLayerUrl": "http://", ...}
        >>>     ])

        For an example showing how to upload tiled data_rows see the following
        notebook:
            https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb

        Args:
            items (iterable of (dict or str)): See above for details.

        Returns:
            Task representing the data import on the server side. The Task
            can be used for inspecting task progress and waiting until it's
            done.

        Raises:
            InvalidQueryError: If the `items` parameter does not conform to
                the specification above, or if the server did not accept the
                DataRow creation request (unknown reason).
            ResourceNotFoundError: If unable to retrieve the Task for the
                import process. This could imply that the import failed.
            InvalidAttributeError: If there are fields in `items` not valid
                for a DataRow.
        """
        file_upload_thread_count = 20
        DataRow = Entity.DataRow

        def upload_if_necessary(item):
            if isinstance(item, str):
                item_url = self.client.upload_file(item)
                # Convert item from str into a dict so it gets processed
                # like all other dicts.
                item = {DataRow.row_data: item_url, DataRow.external_id: item}
            return item

        with ThreadPool(file_upload_thread_count) as thread_pool:
            items = thread_pool.map(upload_if_necessary, items)

        def convert_item(item):
            # Don't make any changes to tms data
            if "tileLayerUrl" in item:
                return item

            # Convert string names to fields.
            item = {
                key if isinstance(key, Field) else DataRow.field(key): value
                for key, value in item.items()
            }

            if DataRow.row_data not in item:
                raise InvalidQueryError(
                    "DataRow.row_data missing when creating DataRow.")

            invalid_keys = set(item) - set(DataRow.fields())
            if invalid_keys:
                raise InvalidAttributeError(DataRow, invalid_keys)

            # Item is valid, convert it to a dict {graphql_field_name: value}.
            # Need to change the name of DataRow.row_data to "data".
            return {
                "data" if key == DataRow.row_data else key.graphql_name: value
                for key, value in item.items()
            }

        # Prepare and upload the descriptor file.
        items = [convert_item(item) for item in items]
        data = json.dumps(items)
        descriptor_url = self.client.upload_data(data)

        # Create data source.
        dataset_param = "datasetId"
        url_param = "jsonUrl"
        query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
            appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
            ){ taskId accepted } } """ % (dataset_param, url_param,
                                          dataset_param, url_param)
        res = self.client.execute(query_str, {
            dataset_param: self.uid,
            url_param: descriptor_url
        })
        res = res["appendRowsToDataset"]
        if not res["accepted"]:
            raise InvalidQueryError(
                "Server did not accept DataRow creation request")

        # Fetch and return the task.
        task_id = res["taskId"]
        user = self.client.get_user()
        task = list(user.created_tasks(where=Entity.Task.uid == task_id))
        # Cache user in a private variable as the relationship can't be
        # resolved due to server-side limitations (see Task.created_by
        # for more info).
        if len(task) != 1:
            raise ResourceNotFoundError(Entity.Task, task_id)
        task = task[0]
        task._user = user
        return task
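    # Usage sketch (illustrative values only): mixing local paths and dicts,
    # then waiting on the returned Task before inspecting its status.
    #
    # >>> task = dataset.create_data_rows([
    # ...     "path/to/file1.jpg",
    # ...     {"row_data": "http://my_site.com/photos/img_02.jpg",
    # ...      "external_id": "img_02"},
    # ... ])
    # >>> task.wait_till_done()
    # >>> task.status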
    def data_rows_for_external_id(self, external_id, limit=10):
        """ Convenience method for fetching the `DataRow`s belonging to this
        `Dataset` that have the given `external_id`.

        Args:
            external_id (str): External ID of the sought `DataRow`s.
            limit (int): The maximum number of data rows to return for the
                given external_id.

        Returns:
            A list of `DataRow`s with the given external ID, containing at
            most `limit` elements.

        Raises:
            labelbox.exceptions.ResourceNotFoundError: If there is no
                `DataRow` in this `Dataset` with the given external ID.
        """
        DataRow = Entity.DataRow
        where = DataRow.external_id == external_id

        data_rows = self.data_rows(where=where)
        # Get at most `limit` data_rows.
        data_rows = list(islice(data_rows, limit))

        if not len(data_rows):
            raise ResourceNotFoundError(DataRow, where)
        return data_rows
    def data_row_for_external_id(self, external_id):
        """ Convenience method for getting a single `DataRow` belonging to
        this `Dataset` that has the given `external_id`.

        Args:
            external_id (str): External ID of the sought `DataRow`.

        Returns:
            A single `DataRow` with the given external ID. If more than one
            `DataRow` matches, a warning is logged and the first one is
            returned.

        Raises:
            labelbox.exceptions.ResourceNotFoundError: If there is no
                `DataRow` in this `Dataset` with the given external ID.
        """
        data_rows = self.data_rows_for_external_id(external_id=external_id,
                                                   limit=2)
        if len(data_rows) > 1:
            logger.warning(
                "More than one data_row has the provided external_id: `%s`. "
                "Use data_rows_for_external_id to fetch all", external_id)
        return data_rows[0]
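    # Usage sketch (illustrative IDs only): looking up rows by external ID.
    # The singular variant warns and returns the first match if several exist.
    #
    # >>> rows = dataset.data_rows_for_external_id("img_02", limit=5)
    # >>> row = dataset.data_row_for_external_id("img_02")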
    def export_data_rows(self, timeout_seconds=120):
        """ Returns a generator that produces all data rows that are currently
        attached to this dataset.

        Args:
            timeout_seconds (float): Max waiting time, in seconds.

        Returns:
            Generator that yields DataRow objects belonging to this dataset.

        Raises:
            LabelboxError: If the export fails or is unable to download within
                the specified time.
        """
        id_param = "datasetId"
        query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!)
            {exportDatasetDataRows(data:{datasetId: $%s }) {downloadUrl createdAt status}}
        """ % (id_param, id_param)
        sleep_time = 2
        # Keep the originally requested timeout so the error message reports
        # the allotted time rather than the (already exhausted) remainder.
        original_timeout = timeout_seconds
        while True:
            res = self.client.execute(query_str, {id_param: self.uid})
            res = res["exportDatasetDataRows"]
            if res["status"] == "COMPLETE":
                download_url = res["downloadUrl"]
                response = requests.get(download_url)
                response.raise_for_status()
                reader = ndjson.reader(StringIO(response.text))
                return (
                    Entity.DataRow(self.client, result) for result in reader)
            elif res["status"] == "FAILED":
                raise LabelboxError("Data row export failed.")

            timeout_seconds -= sleep_time
            if timeout_seconds <= 0:
                raise LabelboxError(
                    f"Unable to export data rows within {original_timeout} seconds."
                )

            logger.debug("Dataset '%s' data row export, waiting for server...",
                         self.uid)
            time.sleep(sleep_time)
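# A minimal export sketch (illustrative only; assumes a valid API key, network
# access, and an existing dataset): the generator yields DataRow objects once
# the server-side export completes.
#
# >>> for data_row in dataset.export_data_rows(timeout_seconds=300):
# ...     print(data_row.external_id, data_row.row_data)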