import json
import logging
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import warnings
import requests
from lbox.exceptions import ResourceNotFoundError
from labelbox import parser
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Entity, Field, Relationship
from labelbox.pagination import PaginatedCollection
from labelbox.schema.internal.datarow_upload_constants import (
DOWNLOAD_RESULT_PAGE_SIZE,
)
from labelbox.schema.taskstatus import TaskStatus
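# For static type checkers, lru_cache is stubbed below so that decorated methods
# keep a Dict[str, Any] return type; at runtime the real functools.lru_cache is
# imported instead.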
if TYPE_CHECKING:
from labelbox import User
def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]:
pass
else:
from functools import lru_cache
logger = logging.getLogger(__name__)
class Task(DbObject):
"""Represents a server-side process that might take a longer time to process.
Allows the Task state to be updated and checked on the client side.
Attributes:
updated_at (datetime)
created_at (datetime)
name (str)
status (str)
completion_percentage (float)
created_by (Relationship): `ToOne` relationship to User
organization (Relationship): `ToOne` relationship to Organization
"""
updated_at = Field.DateTime("updated_at")
created_at = Field.DateTime("created_at")
name = Field.String("name")
status = Field.String("status")
status_as_enum = Field.Enum(
TaskStatus, "status_as_enum", "status"
) # additional status for filtering
completion_percentage = Field.Float("completion_percentage")
result_url = Field.String("result_url", "result")
errors_url = Field.String("errors_url", "errors")
type = Field.String("type")
metadata = Field.Json("metadata")
_user: Optional["User"] = None
# Relationships
created_by = Relationship.ToOne("User", False, "created_by")
organization = Relationship.ToOne("Organization")
def __eq__(self, task):
return (
isinstance(task, Task)
and task.uid == self.uid
and task.type == self.type
)
def __hash__(self):
return hash(self.uid)
# Import and upsert have several instances of special casing
def is_creation_task(self) -> bool:
return self.name == "JSON Import" or self.type == "adv-upsert-data-rows"
    def refresh(self) -> None:
"""Refreshes Task data from the server."""
assert self._user is not None
tasks = list(self._user.created_tasks(where=Task.uid == self.uid))
if len(tasks) != 1:
raise ResourceNotFoundError(Task, self.uid)
for field in self.fields():
setattr(self, field.name, getattr(tasks[0], field.name))
if self.is_creation_task():
self.errors_url = self.result_url
def has_errors(self) -> bool:
if self.type == "export-data-rows":
# self.errors fetches the error content.
# This first condition prevents us from downloading the content for v2 exports
return bool(self.errors_url or self.errors)
if self.is_creation_task():
return bool(self.failed_data_rows)
return self.status == "FAILED"
def wait_until_done(
self, timeout_seconds: float = 300.0, check_frequency: float = 2.0
) -> None:
warnings.warn(
"The method wait_until_done for Task is deprecated and will be removed in the next major release. Use the wait_till_done method instead.",
DeprecationWarning,
stacklevel=2,
)
self.wait_till_done(timeout_seconds, check_frequency)
    def wait_till_done(
self, timeout_seconds: float = 300.0, check_frequency: float = 2.0
) -> None:
"""Waits until the task is completed. Periodically queries the server
to update the task attributes.
Args:
timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes.
            check_frequency (float): Frequency of queries to the server to update the task attributes, in seconds. Defaults to two seconds. Minimum value is two seconds.
"""
if check_frequency < 2.0:
raise ValueError(
"Expected check frequency to be two seconds or more"
)
while timeout_seconds > 0:
if self.status != "IN_PROGRESS":
if self.has_errors():
logger.warning(
"There are errors present. Please look at `task.errors` for more details"
)
return
logger.debug(
"Task.wait_till_done sleeping for %d seconds", check_frequency
)
time.sleep(check_frequency)
timeout_seconds -= check_frequency
self.refresh()
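    # Usage sketch (illustrative, not executed): polling a long-running task.
    # Assumes `client` is an authenticated labelbox.Client and `dataset` is an
    # existing Dataset; both names are placeholders for this example.
    #
    #     task = dataset.create_data_rows(
    #         [{"row_data": "https://example.com/image.png"}]
    #     )
    #     task.wait_till_done(timeout_seconds=600, check_frequency=5.0)
    #     if task.has_errors():
    #         logger.warning("Task errors: %s", task.errors)
    #     else:
    #         logger.info("Task finished with status %s", task.status)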
@property
def errors(self) -> Optional[Dict[str, Any]]:
"""Fetch the error associated with an import task."""
if self.is_creation_task():
if self.status == "FAILED":
result = self._fetch_remote_json()
return result["error"]
elif self.status == "COMPLETE":
return self.failed_data_rows
elif self.type == "export-data-rows":
return self._fetch_remote_json(remote_json_field="errors_url")
elif (
self.type == "add-data-rows-to-batch"
or self.type == "send-to-task-queue"
or self.type == "send-to-annotate"
):
if self.status == "FAILED":
# for these tasks, the error is embedded in the result itself
return json.loads(self.result_url)
return None
@property
def result(self) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
"""Fetch the result for an import task."""
if self.status == "FAILED":
raise ValueError(f"Job failed. Errors : {self.errors}")
else:
result = self._fetch_remote_json()
if self.type == "export-data-rows":
return result
return [
{
"id": data_row["id"],
"external_id": data_row.get("externalId"),
"row_data": data_row["rowData"],
"global_key": data_row.get("globalKey"),
}
for data_row in result["createdDataRows"]
]
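    # Usage sketch (illustrative): reading the rows created by a completed import
    # task. `task` is assumed to be a finished "JSON Import" task.
    #
    #     task.wait_till_done()
    #     for row in task.result:
    #         print(row["id"], row["external_id"], row["global_key"])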
@property
def failed_data_rows(self) -> Optional[Dict[str, Any]]:
"""Fetch data rows which failed to be created for an import task."""
result = self._fetch_remote_json()
if len(result.get("errors", [])) > 0:
return result["errors"]
else:
return None
@property
def created_data_rows(self) -> Optional[Dict[str, Any]]:
"""Fetch data rows which successfully created for an import task."""
result = self._fetch_remote_json()
if len(result.get("createdDataRows", [])) > 0:
return result["createdDataRows"]
else:
return None
@lru_cache()
def _fetch_remote_json(
self, remote_json_field: Optional[str] = None
) -> Dict[str, Any]:
"""Function for fetching and caching the result data."""
def download_result(remote_json_field: Optional[str], format: str):
url = getattr(self, remote_json_field or "result_url")
if url is None:
return None
response = requests.get(url)
response.raise_for_status()
if format == "json":
return response.json()
elif format == "ndjson":
return parser.loads(response.text)
else:
raise ValueError(
"Expected the result format to be either `ndjson` or `json`."
)
if self.is_creation_task():
format = "json"
elif self.type == "export-data-rows":
format = "ndjson"
else:
raise ValueError(
"Task result is only supported for `JSON Import` and `export` tasks."
" Download task.result_url manually to access the result for other tasks."
)
if self.status != "IN_PROGRESS":
return download_result(remote_json_field, format)
else:
self.wait_till_done(timeout_seconds=600)
if self.status == "IN_PROGRESS":
raise ValueError(
"Job status still in `IN_PROGRESS`. The result is not available. Call task.wait_till_done() with a larger timeout or contact support."
)
return download_result(remote_json_field, format)
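    # For task types other than imports and exports, `result` raises a ValueError
    # and the payload has to be downloaded manually. A minimal sketch, assuming the
    # task exposes a populated `result_url`:
    #
    #     import requests
    #
    #     response = requests.get(task.result_url)
    #     response.raise_for_status()
    #     payload = response.json()  # or parse as ndjson, depending on the task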
@staticmethod
def get_task(client, task_id):
user: User = client.get_user()
tasks: List[Task] = list(
user.created_tasks(where=Entity.Task.uid == task_id)
)
        # Cache the user in a private variable because the relationship can't be
        # resolved due to server-side limitations (see Task.created_by for more info).
if len(tasks) != 1:
raise ResourceNotFoundError(Entity.Task, {task_id: task_id})
task: Task = tasks[0]
task._user = user
return task
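    # Usage sketch (illustrative): fetching a task by id. "<task_id>" is a
    # placeholder for a real task uid.
    #
    #     task = Task.get_task(client, "<task_id>")
    #     task.wait_till_done()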
class DataUpsertTask(Task):
"""
Task class for data row upsert operations
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._user = None
@property
def result(self) -> Optional[List[Dict[str, Any]]]: # type: ignore
"""
Fetches all results.
Note, for large uploads (>150K data rows), it could take multiple minutes to complete
"""
if self.status == "FAILED":
raise ValueError(f"Job failed. Errors : {self.errors}")
return self._results_as_list()
@property
def errors(self) -> Optional[List[Dict[str, Any]]]: # type: ignore
"""
Fetches all errors.
Note, for large uploads / large number of errors (>150K), it could take multiple minutes to complete
"""
return self._errors_as_list()
@property
def created_data_rows( # type: ignore
self,
) -> Optional[List[Dict[str, Any]]]:
return self.result
@property
def failed_data_rows( # type: ignore
self,
) -> Optional[List[Dict[str, Any]]]:
return self.errors
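    # Usage sketch (illustrative): upsert tasks page their results from the server,
    # so `result` and `errors` return complete lists even for very large uploads.
    # `dataset.upsert_data_rows` is assumed to return a DataUpsertTask here.
    #
    #     task = dataset.upsert_data_rows([{"row_data": "https://example.com/a.png"}])
    #     task.wait_till_done()
    #     created = task.created_data_rows or []
    #     failed = task.failed_data_rows or []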
def _download_results_paginated(self) -> PaginatedCollection:
page_size = DOWNLOAD_RESULT_PAGE_SIZE
from_cursor = None
query_str = """query SuccessesfulDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) {
successesfulDataRowImports(data: { taskId: $taskId, first: $first, from: $from})
{
nodes {
id
externalId
globalKey
rowData
}
after
total
}
}
"""
params = {
"taskId": self.uid,
"first": page_size,
"from": from_cursor,
}
return PaginatedCollection(
client=self.client,
query=query_str,
params=params,
dereferencing=["successesfulDataRowImports", "nodes"],
obj_class=lambda _, data_row: {
"id": data_row.get("id"),
"external_id": data_row.get("externalId"),
"row_data": data_row.get("rowData"),
"global_key": data_row.get("globalKey"),
},
cursor_path=["successesfulDataRowImports", "after"],
)
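    # The PaginatedCollection above is lazy: pages of DOWNLOAD_RESULT_PAGE_SIZE rows
    # are fetched from the server only as the collection is iterated, e.g. by
    # _results_as_list below.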
def _download_errors_paginated(self) -> PaginatedCollection:
page_size = DOWNLOAD_RESULT_PAGE_SIZE # hardcode to avoid overloading the server
from_cursor = None
query_str = """query FailedDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) {
failedDataRowImports(data: { taskId: $taskId, first: $first, from: $from})
{
after
total
results {
message
spec {
externalId
globalKey
rowData
metadata {
schemaId
value
name
}
attachments {
type
value
name
}
}
}
}
}
"""
params = {
"taskId": self.uid,
"first": page_size,
"from": from_cursor,
}
def convert_errors_to_legacy_format(client, data_row):
spec = data_row.get("spec", {})
return {
"message": data_row.get("message"),
"failedDataRows": [
{
"externalId": spec.get("externalId"),
"rowData": spec.get("rowData"),
"globalKey": spec.get("globalKey"),
"metadata": spec.get("metadata", []),
"attachments": spec.get("attachments", []),
}
],
}
return PaginatedCollection(
client=self.client,
query=query_str,
params=params,
dereferencing=["failedDataRowImports", "results"],
obj_class=convert_errors_to_legacy_format,
cursor_path=["failedDataRowImports", "after"],
)
    def _results_as_list(self) -> Optional[List[Dict[str, Any]]]:
        results = list(self._download_results_paginated())
        if len(results) == 0:
            return None
        return results
    def _errors_as_list(self) -> Optional[List[Dict[str, Any]]]:
        errors = list(self._download_errors_paginated())
        if len(errors) == 0:
            return None
        return errors