import os
import json
import logging
from itertools import islice
from multiprocessing.dummy import Pool as ThreadPool
import time
import ndjson
from io import StringIO
import requests
from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError
from labelbox.orm.db_object import DbObject, Updateable, Deletable
from labelbox.orm.model import Entity, Field, Relationship
# Module-level logger named after this module; used below for export
# progress messages and duplicate-external-id warnings.
logger = logging.getLogger(__name__)
class Dataset(DbObject, Updateable, Deletable):
    """ A Dataset is a collection of DataRows.

    Attributes:
        name (str)
        description (str)
        updated_at (datetime)
        created_at (datetime)

        projects (Relationship): `ToMany` relationship to Project
        data_rows (Relationship): `ToMany` relationship to DataRow
        created_by (Relationship): `ToOne` relationship to User
        organization (Relationship): `ToOne` relationship to Organization
    """
    name = Field.String("name")
    description = Field.String("description")
    updated_at = Field.DateTime("updated_at")
    created_at = Field.DateTime("created_at")

    # Relationships
    projects = Relationship.ToMany("Project", True)
    data_rows = Relationship.ToMany("DataRow", False)
    created_by = Relationship.ToOne("User", False, "created_by")
    organization = Relationship.ToOne("Organization", False)

    def create_data_row(self, **kwargs):
        """ Creates a single DataRow belonging to this dataset.

        >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg")

        Args:
            **kwargs: Key-value arguments containing new `DataRow` data. At a
                minimum, must contain `row_data`. If `row_data` is a path to
                an existing local file, the file is uploaded first and the
                value is replaced by what `client.upload_file` returns.
        Returns:
            Whatever `client._create` returns for the new `DataRow`.
        Raises:
            InvalidQueryError: If `DataRow.row_data` field value is not provided
                in `kwargs`.
            InvalidAttributeError: in case the DB object type does not contain
                any of the field names given in `kwargs`.
        """
        DataRow = Entity.DataRow
        row_data_key = DataRow.row_data.name
        if row_data_key not in kwargs:
            raise InvalidQueryError(
                "DataRow.row_data missing when creating DataRow.")

        # If row data is a local file path, upload it to server.
        # Only path-like values are probed: `os.path.exists` raises TypeError
        # for arbitrary objects and interprets an int as a file descriptor.
        row_data = kwargs[row_data_key]
        if (isinstance(row_data, (str, bytes, os.PathLike)) and
                os.path.exists(row_data)):
            kwargs[row_data_key] = self.client.upload_file(row_data)

        kwargs[DataRow.dataset.name] = self
        return self.client._create(DataRow, kwargs)

    def create_data_rows(self, items):
        """ Creates multiple DataRow objects based on the given `items`.

        Each element in `items` can be either a `str` or a `dict`. If
        it is a `str`, then it is interpreted as a local file path. The file
        is uploaded to Labelbox and a DataRow referencing it is created.

        If an item is a `dict`, then it could support one of the two
        following structures
            1. For static imagery, video, and text it should map `DataRow`
               fields (or their names) to values. At the minimum an `item`
               passed as a `dict` must contain a `DataRow.row_data` key and
               value.
            2. For tiled imagery the dict must match the import structure
               specified in the link below
               https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

        >>> dataset.create_data_rows([
        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
        >>>     "path/to/file2.jpg",
        >>>     {"tileLayerUrl" : "http://", ...}
        >>>     ])

        For an example showing how to upload tiled data_rows see the
        following notebook:
        https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb

        Args:
            items (iterable of (dict or str)): See above for details.
        Returns:
            Task representing the data import on the server side. The Task
            can be used for inspecting task progress and waiting until it's
            done.
        Raises:
            InvalidQueryError: If the `items` parameter does not conform to
                the specification above or if the server did not accept the
                DataRow creation request (unknown reason).
            ResourceNotFoundError: If unable to retrieve the Task for the
                import process. This could imply that the import failed.
            InvalidAttributeError: If there are fields in `items` not valid for
                a DataRow.
        """
        # Local import: keeps this method self-contained without touching
        # the module's import block.
        from collections.abc import Mapping

        file_upload_thread_count = 20
        DataRow = Entity.DataRow

        # Materialize once so the items can be validated before any upload
        # starts (the thread pool would have materialized them anyway).
        items = list(items)
        for item in items:
            if not isinstance(item, (str, Mapping)):
                raise InvalidQueryError(
                    "Each item must be a str (local file path) or a dict, "
                    "got %s." % type(item).__name__)

        def upload_if_necessary(item):
            if isinstance(item, str):
                item_url = self.client.upload_file(item)
                # Convert item from str into a dict so it gets processed
                # like all other dicts.
                item = {DataRow.row_data: item_url, DataRow.external_id: item}
            return item

        # Uploads are network-bound, so they are fanned out over threads.
        with ThreadPool(file_upload_thread_count) as thread_pool:
            items = thread_pool.map(upload_if_necessary, items)

        def convert_item(item):
            # Don't make any changes to tms data
            if "tileLayerUrl" in item:
                return item

            # Convert string names to fields.
            item = {
                key if isinstance(key, Field) else DataRow.field(key): value
                for key, value in item.items()
            }

            if DataRow.row_data not in item:
                raise InvalidQueryError(
                    "DataRow.row_data missing when creating DataRow.")

            invalid_keys = set(item) - set(DataRow.fields())
            if invalid_keys:
                raise InvalidAttributeError(DataRow, invalid_keys)

            # Item is valid, convert it to a dict {graphql_field_name: value}
            # Need to change the name of DataRow.row_data to "data"
            return {
                "data" if key == DataRow.row_data else key.graphql_name: value
                for key, value in item.items()
            }

        # Prepare and upload the descriptor file
        items = [convert_item(item) for item in items]
        data = json.dumps(items)
        descriptor_url = self.client.upload_data(data)

        # Create data source
        dataset_param = "datasetId"
        url_param = "jsonUrl"
        query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
            appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
            ){ taskId accepted } } """ % (dataset_param, url_param,
                                          dataset_param, url_param)

        res = self.client.execute(query_str, {
            dataset_param: self.uid,
            url_param: descriptor_url
        })
        res = res["appendRowsToDataset"]
        if not res["accepted"]:
            raise InvalidQueryError(
                "Server did not accept DataRow creation request")

        # Fetch and return the task.
        task_id = res["taskId"]
        user = self.client.get_user()
        tasks = list(user.created_tasks(where=Entity.Task.uid == task_id))
        if len(tasks) != 1:
            raise ResourceNotFoundError(Entity.Task, task_id)
        task = tasks[0]

        # Cache user in a private variable as the relationship can't be
        # resolved due to server-side limitations (see Task.created_by)
        # for more info.
        task._user = user
        return task

    def data_rows_for_external_id(self, external_id, limit=10):
        """ Convenience method for getting the `DataRow` objects belonging to
        this `Dataset` that have the given `external_id`.

        Args:
            external_id (str): External ID of the sought `DataRow`.
            limit (int): The maximum number of data rows to return for the
                given external_id
        Returns:
            A non-empty list of at most `limit` `DataRow` objects with the
            given external ID.
        Raises:
            labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
                in this `DataSet` with the given external ID.
        """
        DataRow = Entity.DataRow
        where = DataRow.external_id == external_id

        data_rows = self.data_rows(where=where)
        # Get at most `limit` data_rows.
        data_rows = list(islice(data_rows, limit))

        if not data_rows:
            raise ResourceNotFoundError(DataRow, where)
        return data_rows

    def data_row_for_external_id(self, external_id):
        """ Convenience method for getting a single `DataRow` belonging to this
        `Dataset` that has the given `external_id`.

        If more than one `DataRow` matches, a warning is logged and the first
        one found is returned; use `data_rows_for_external_id` to fetch them
        all.

        Args:
            external_id (str): External ID of the sought `DataRow`.
        Returns:
            A single `DataRow` with the given ID.
        Raises:
            labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
                in this `DataSet` with the given external ID.
        """
        # Two rows are enough to detect a duplicate external ID.
        data_rows = self.data_rows_for_external_id(external_id=external_id,
                                                   limit=2)
        if len(data_rows) > 1:
            logger.warning(
                "More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all",
                external_id)
        return data_rows[0]

    def export_data_rows(self, timeout_seconds=120):
        """ Returns a generator that produces all data rows that are currently
        attached to this dataset.

        The export mutation is re-issued every 2 seconds until the server
        reports the export as `COMPLETE` or `FAILED`, or the timeout elapses.

        Args:
            timeout_seconds (float): Max waiting time, in seconds.
        Returns:
            Generator that yields DataRow objects belonging to this dataset.
        Raises:
            LabelboxError: if the export fails or is unable to download within
                the specified time.
        """
        id_param = "datasetId"
        query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!)
            {exportDatasetDataRows(data:{datasetId: $%s }) {downloadUrl createdAt status}}
        """ % (id_param, id_param)
        sleep_time = 2
        # Count down on a separate variable so the error message can report
        # the timeout the caller actually asked for.
        remaining_seconds = timeout_seconds
        while True:
            res = self.client.execute(query_str, {id_param: self.uid})
            res = res["exportDatasetDataRows"]
            if res["status"] == "COMPLETE":
                download_url = res["downloadUrl"]
                response = requests.get(download_url)
                response.raise_for_status()
                reader = ndjson.reader(StringIO(response.text))
                return (
                    Entity.DataRow(self.client, result) for result in reader)
            elif res["status"] == "FAILED":
                raise LabelboxError("Data row export failed.")

            remaining_seconds -= sleep_time
            if remaining_seconds <= 0:
                raise LabelboxError(
                    f"Unable to export data rows within {timeout_seconds} seconds."
                )

            logger.debug("Dataset '%s' data row export, waiting for server...",
                         self.uid)
            time.sleep(sleep_time)