# Source code for labelbox.schema.annotation_import

import functools
import json
import logging
import os
import time
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    BinaryIO,
    Dict,
    List,
    Optional,
    Union,
    cast,
)

import requests
from google.api_core import retry
from lbox.exceptions import ApiLimitError, NetworkError, ResourceNotFoundError
from tqdm import tqdm  # type: ignore

import labelbox
from labelbox import parser
from labelbox.orm import query
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field, Relationship
from labelbox.schema.confidence_presence_checker import (
    LabelsConfidencePresenceChecker,
)
from labelbox.schema.enums import AnnotationImportState
from labelbox.schema.serialization import serialize_labels
from labelbox.utils import is_exactly_one_set

if TYPE_CHECKING:
    from labelbox.types import Label

NDJSON_MIME_TYPE = "application/x-ndjson"
ANNOTATION_PER_LABEL_LIMIT = 5000

logger = logging.getLogger(__name__)


class AnnotationImport(DbObject):
    """Base class for server-side annotation import jobs.

    Subclasses (MAL/MEA/Label imports) create the job; this class tracks its
    state and exposes the uploaded inputs, per-annotation statuses and errors,
    plus helpers to poll the job until it finishes.
    """

    name = Field.String("name")
    state = Field.Enum(AnnotationImportState, "state")
    input_file_url = Field.String("input_file_url")
    error_file_url = Field.String("error_file_url")
    status_file_url = Field.String("status_file_url")
    progress = Field.String("progress")

    created_by = Relationship.ToOne("User", False, "created_by")

    @property
    def inputs(self) -> List[Dict[str, Any]]:
        """
        Inputs for each individual annotation uploaded.
        This should match the ndjson annotations that you have uploaded.
        Returns:
            Uploaded ndjson.

        * This information will expire after 24 hours.
        """
        return self._fetch_remote_ndjson(self.input_file_url)

    @property
    def errors(self) -> List[Dict[str, Any]]:
        """
        Errors for each individual annotation uploaded. This is a subset of statuses
        Returns:
            List of dicts containing error messages. Empty list means there were no errors
            See `AnnotationImport.statuses` for more details.

        * This information will expire after 24 hours.
        """
        # Errors are only final once the job has stopped running.
        self.wait_until_done()
        return self._fetch_remote_ndjson(self.error_file_url)

    @property
    def statuses(self) -> List[Dict[str, Any]]:
        """
        Status for each individual annotation uploaded.

        Returns:
            A status for each annotation if the upload is done running.
            See below table for more details

        .. list-table::
           :widths: 15 150
           :header-rows: 1

           * - Field
             - Description
           * - uuid
             - Specifies the annotation for the status row.
           * - dataRow
             - JSON object containing the Labelbox data row ID for the annotation.
           * - status
             - Indicates SUCCESS or FAILURE.
           * - errors
             - An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info.

        * This information will expire after 24 hours.
        """
        self.wait_until_done()
        return self._fetch_remote_ndjson(self.status_file_url)

    def wait_till_done(
        self, sleep_time_seconds: int = 10, show_progress: bool = False
    ) -> None:
        """Deprecated alias for :meth:`wait_until_done`."""
        self.wait_until_done(sleep_time_seconds, show_progress)

    def wait_until_done(
        self, sleep_time_seconds: int = 10, show_progress: bool = False
    ) -> None:
        """Blocks import job until certain conditions are met.

        Blocks until the AnnotationImport.state changes either to
        `AnnotationImportState.FINISHED` or `AnnotationImportState.FAILED`,
        periodically refreshing object's state.

        Args:
            sleep_time_seconds (int): a time to block between subsequent API calls
            show_progress (bool): should show progress bar
        """
        pbar = (
            tqdm(
                total=100,
                bar_format="{n}% |{bar}| [{elapsed}, {rate_fmt}{postfix}]",
            )
            if show_progress
            else None
        )
        while self.state.value == AnnotationImportState.RUNNING.value:
            logger.info(f"Sleeping for {sleep_time_seconds} seconds...")
            time.sleep(sleep_time_seconds)
            self.__backoff_refresh()
            # Fixed: condition previously tested `self.progress` twice.
            if self.progress and pbar:
                # `progress` is a percentage string such as "42%".
                pbar.update(int(self.progress.replace("%", "")) - pbar.n)
        if pbar:
            pbar.update(100 - pbar.n)
            pbar.close()

    @retry.Retry(
        predicate=retry.if_exception_type(
            ApiLimitError,
            TimeoutError,
            NetworkError,
        )
    )
    def __backoff_refresh(self) -> None:
        # Retried with exponential backoff on transient API/network failures.
        self.refresh()

    @functools.lru_cache()
    def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
        """
        Fetches the remote ndjson file and caches the results.

        Args:
            url (str): Can be any url pointing to an ndjson file.
        Returns:
            ndjson as a list of dicts.
        Raises:
            ValueError: if the import job is in the FAILED state.
        """
        # NOTE(review): lru_cache on an instance method keeps `self` alive for
        # the cache's lifetime (ruff B019); acceptable here since the remote
        # files expire after 24 hours anyway.
        if self.state == AnnotationImportState.FAILED:
            raise ValueError("Import failed.")
        response = requests.get(url)
        response.raise_for_status()
        return parser.loads(response.text)

    @classmethod
    def _create_from_bytes(
        cls, client, variables, query_str, file_name, bytes_data
    ) -> Dict[str, Any]:
        # Multipart GraphQL upload: `operations` carries the mutation and its
        # variables; `map` wires the uploaded file into `variables.file`.
        operations = json.dumps({"variables": variables, "query": query_str})
        data = {
            "operations": operations,
            "map": (None, json.dumps({file_name: ["variables.file"]})),
        }
        file_data = (file_name, bytes_data, NDJSON_MIME_TYPE)
        files = {file_name: file_data}
        return client.execute(data=data, files=files)

    @classmethod
    def _get_ndjson_from_objects(
        cls,
        objects: Union[List[Dict[str, Any]], List["Label"]],
        object_name: str,
    ) -> BinaryIO:
        """Serialize and validate annotations, returning the ndjson payload.

        Raises:
            TypeError: if `objects` is not a list.
            ValueError: if serialization yields an empty payload or
                validation of the data rows fails.
        """
        if not isinstance(objects, list):
            raise TypeError(
                f"{object_name} must be in a form of list. Found {type(objects)}"
            )
        objects = serialize_labels(objects)
        cls._validate_data_rows(objects)
        data_str = parser.dumps(objects)
        if not data_str:
            raise ValueError(f"{object_name} cannot be empty")
        return data_str.encode(
            "utf-8"
        )  # NOTICE this method returns bytes, NOT BinaryIO... should have done io.BytesIO(...) but not going to change this at the moment since it works and fools mypy

    def refresh(self) -> None:
        """Synchronizes values of all fields with the database."""
        cls = type(self)
        res = cls.from_name(
            self.client, self.parent_id, self.name, as_json=True
        )
        self._set_field_values(res)

    @classmethod
    def _validate_data_rows(cls, objects: List[Dict[str, Any]]):
        """
        Validates annotations by checking 'dataRow' is provided
        and only one of 'id' or 'globalKey' is provided.

        Shows up to `max_num_errors` errors if invalidated, to prevent
        large number of error messages from being printed out
        """
        errors = []
        max_num_errors = 100
        # data row id -> label name -> annotation count, used to enforce the
        # per-label annotation limit.
        labels_per_datarow: Dict[str, Dict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        for annotation in objects:  # renamed from `object` (shadowed builtin)
            if "dataRow" not in annotation:
                errors.append(f"'dataRow' is missing in {annotation}")
                continue
            data_row_object = annotation["dataRow"]
            if not is_exactly_one_set(
                data_row_object.get("id"), data_row_object.get("globalKey")
            ):
                errors.append(
                    f"Must provide only one of 'id' or 'globalKey' for 'dataRow' in {annotation}"
                )
            else:
                data_row_id = data_row_object.get(
                    "globalKey"
                ) or data_row_object.get("id")
                name = annotation.get("name")
                if name:
                    labels_per_datarow[data_row_id][name] += 1
        for data_row_id, label_annotations in labels_per_datarow.items():
            for label_name, annotations in label_annotations.items():
                if annotations > ANNOTATION_PER_LABEL_LIMIT:
                    errors.append(
                        f"Row with id or global key {data_row_id} has {annotations} annotations for label {label_name}.\
 Imports are limited to {ANNOTATION_PER_LABEL_LIMIT} annotations per data row."
                    )
        if errors:
            errors_length = len(errors)
            formatted_errors = "\n".join(errors[:max_num_errors])
            if errors_length > max_num_errors:
                logger.warning(
                    f"Found more than {max_num_errors} errors. Showing first {max_num_errors} error messages..."
                )
            raise ValueError(
                f"Error while validating annotations. Found {errors_length} annotations with errors. Errors:\n{formatted_errors}"
            )

    @classmethod
    def from_name(
        cls,
        client: "labelbox.Client",
        parent_id: str,
        name: str,
        as_json: bool = False,
    ):
        raise NotImplementedError("Inheriting class must override")

    @property
    def parent_id(self) -> str:
        raise NotImplementedError("Inheriting class must override")
class CreatableAnnotationImport(AnnotationImport):
    """AnnotationImport that can be created from a url, a file path, or
    in-memory labels — exactly one of the three."""

    @classmethod
    def create(
        cls,
        client: "labelbox.Client",
        id: str,
        name: str,
        path: Optional[str] = None,
        url: Optional[str] = None,
        labels: Optional[Union[List[Dict[str, Any]], List["Label"]]] = None,
    ) -> "AnnotationImport":
        """Create an import job from exactly one of `url`, `path`, `labels`.

        Args:
            client: Labelbox Client for executing queries
            id: Parent identifier (project or model run id, per subclass)
            name: Name of the import job. Can be used to reference the task later
            path: Path to an ndjson file containing annotations
            url: Url pointing to a file to upload
            labels: List of labels / annotation dicts
        Returns:
            AnnotationImport subclass instance
        Raises:
            ValueError: if not exactly one of url, path, labels is provided.
        """
        # Fixed: mutable default argument (`labels=[]`) replaced with None.
        labels = [] if labels is None else labels
        if not is_exactly_one_set(url, labels, path):
            # Fixed: message previously said "predictions" for the `labels` arg.
            raise ValueError(
                "Must pass in a nonempty argument for one and only one of the following arguments: url, path, labels"
            )
        if url:
            return cls.create_from_url(client, id, name, url)
        if path:
            return cls.create_from_file(client, id, name, path)
        return cls.create_from_objects(client, id, name, labels)

    @classmethod
    def create_from_url(
        cls, client: "labelbox.Client", id: str, name: str, url: str
    ) -> "AnnotationImport":
        raise NotImplementedError("Inheriting class must override")

    @classmethod
    def create_from_file(
        cls, client: "labelbox.Client", id: str, name: str, path: str
    ) -> "AnnotationImport":
        raise NotImplementedError("Inheriting class must override")

    @classmethod
    def create_from_objects(
        cls,
        client: "labelbox.Client",
        id: str,
        name: str,
        labels: Union[List[Dict[str, Any]], List["Label"]],
    ) -> "AnnotationImport":
        raise NotImplementedError("Inheriting class must override")
class MEAPredictionImport(CreatableAnnotationImport):
    """Import job for model error analysis (MEA) predictions into a model run."""

    model_run_id = Field.String("model_run_id")

    @property
    def parent_id(self) -> str:
        """
        Identifier for this import. Used to refresh the status
        """
        return self.model_run_id

    @classmethod
    def create_from_file(
        cls, client: "labelbox.Client", model_run_id: str, name: str, path: str
    ) -> "MEAPredictionImport":
        """
        Create an MEA prediction import job from a file of annotations

        Args:
            client: Labelbox Client for executing queries
            model_run_id: Model run to import labels into
            name: Name of the import job. Can be used to reference the task later
            path: Path to ndjson file containing annotations
        Returns:
            MEAPredictionImport
        Raises:
            ValueError: if `path` does not exist.
        """
        if os.path.exists(path):
            with open(path, "rb") as f:
                return cls._create_mea_import_from_bytes(
                    client, model_run_id, name, f, os.stat(path).st_size
                )
        else:
            raise ValueError(f"File {path} is not accessible")

    @classmethod
    def create_from_objects(
        cls,
        client: "labelbox.Client",
        model_run_id: str,
        name: str,  # Fixed: annotation was missing.
        predictions: Union[List[Dict[str, Any]], List["Label"]],
    ) -> "MEAPredictionImport":
        """
        Create an MEA prediction import job from an in memory dictionary

        Args:
            client: Labelbox Client for executing queries
            model_run_id: Model run to import labels into
            name: Name of the import job. Can be used to reference the task later
            predictions: List of prediction annotations
        Returns:
            MEAPredictionImport
        """
        data = cls._get_ndjson_from_objects(predictions, "annotations")
        # Fixed: content length is the byte count of the payload; the previous
        # len(str(data)) measured the length of the bytes object's repr.
        return cls._create_mea_import_from_bytes(
            client, model_run_id, name, data, len(data)
        )

    @classmethod
    def create_from_url(
        cls, client: "labelbox.Client", model_run_id: str, name: str, url: str
    ) -> "MEAPredictionImport":
        """
        Create an MEA prediction import job from a url
        The url must point to a file containing prediction annotations.

        Args:
            client: Labelbox Client for executing queries
            model_run_id: Model run to import labels into
            name: Name of the import job. Can be used to reference the task later
            url: Url pointing to file to upload
        Returns:
            MEAPredictionImport
        Raises:
            ValueError: if a HEAD request to `url` does not succeed.
        """
        if requests.head(url):
            query_str = cls._get_url_mutation()
            return cls(
                client,
                client.execute(
                    query_str,
                    params={
                        "fileUrl": url,
                        "modelRunId": model_run_id,
                        "name": name,
                    },
                )["createModelErrorAnalysisPredictionImport"],
            )
        else:
            raise ValueError(f"Url {url} is not reachable")

    @classmethod
    def from_name(
        cls,
        client: "labelbox.Client",
        model_run_id: str,
        name: str,
        as_json: bool = False,
    ) -> "MEAPredictionImport":
        """
        Retrieves an MEA import job.

        Args:
            client: Labelbox Client for executing queries
            model_run_id: ID used for querying import jobs
            name: Name of the import job.
        Returns:
            MEAPredictionImport
        Raises:
            ResourceNotFoundError: if no matching import job exists.
        """
        query_str = """query getModelErrorAnalysisPredictionImportPyApi($modelRunId : ID!, $name: String!) {
            modelErrorAnalysisPredictionImport(
                where: {modelRunId: $modelRunId, name: $name}){
                    %s
            }}""" % query.results_query_part(cls)
        params = {
            "modelRunId": model_run_id,
            "name": name,
        }
        response = client.execute(query_str, params)
        if response is None:
            raise ResourceNotFoundError(MEAPredictionImport, params)
        response = response["modelErrorAnalysisPredictionImport"]
        if as_json:
            return response
        return cls(client, response)

    @classmethod
    def _get_url_mutation(cls) -> str:
        return """mutation createMEAPredictionImportByUrlPyApi($modelRunId : ID!, $name: String!, $fileUrl: String!) {
            createModelErrorAnalysisPredictionImport(data: {
                modelRunId: $modelRunId
                name: $name
                fileUrl: $fileUrl
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _get_file_mutation(cls) -> str:
        return """mutation createMEAPredictionImportByFilePyApi($modelRunId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
            createModelErrorAnalysisPredictionImport(data: {
                modelRunId: $modelRunId
                name: $name
                filePayload: { file: $file, contentLength: $contentLength}
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _create_mea_import_from_bytes(
        cls,
        client: "labelbox.Client",
        model_run_id: str,
        name: str,
        bytes_data: BinaryIO,
        content_len: int,
    ) -> "MEAPredictionImport":
        file_name = f"{model_run_id}__{name}.ndjson"
        variables = {
            "file": None,
            "contentLength": content_len,
            "modelRunId": model_run_id,
            "name": name,
        }
        query_str = cls._get_file_mutation()
        res = cls._create_from_bytes(
            client,
            variables,
            query_str,
            file_name,
            bytes_data,
        )
        return cls(client, res["createModelErrorAnalysisPredictionImport"])
class MEAToMALPredictionImport(AnnotationImport):
    """Import job that copies MEA model-run predictions into a project as
    model-assisted labeling (MAL) pre-labels."""

    project = Relationship.ToOne("Project", cache=True)

    @property
    def parent_id(self) -> str:
        """
        Identifier for this import. Used to refresh the status
        """
        return self.project().uid

    @classmethod
    def create_for_model_run_data_rows(
        cls,
        client: "labelbox.Client",
        model_run_id: str,
        data_row_ids: List[str],
        project_id: str,
        name: str,
    ) -> "MEAToMALPredictionImport":
        """
        Create an MEA to MAL prediction import job from a list of data row ids of a specific model run

        Args:
            client: Labelbox Client for executing queries
            data_row_ids: A list of data row ids
            model_run_id: model run id
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
        Returns:
            MEAToMALPredictionImport
        """
        query_str = cls._get_model_run_data_rows_mutation()
        return cls(
            client,
            client.execute(
                query_str,
                params={
                    "dataRowIds": data_row_ids,
                    "modelRunId": model_run_id,
                    "projectId": project_id,
                    "name": name,
                },
            )["createMalPredictionImportForModelRunDataRows"],
        )

    @classmethod
    def from_name(
        cls,
        client: "labelbox.Client",
        project_id: str,
        name: str,
        as_json: bool = False,
    ) -> "MEAToMALPredictionImport":
        """
        Retrieves an MEA to MAL import job.

        Args:
            client: Labelbox Client for executing queries
            project_id: ID used for querying import jobs
            name: Name of the import job.
        Returns:
            MEAToMALPredictionImport
        Raises:
            ResourceNotFoundError: if no matching import job exists.
        """
        query_str = """query getMEAToMALPredictionImportPyApi($projectId : ID!, $name: String!) {
            meaToMalPredictionImport(
                where: {projectId: $projectId, name: $name}){
                    %s
            }}""" % query.results_query_part(cls)
        params = {
            "projectId": project_id,
            "name": name,
        }
        response = client.execute(query_str, params)
        if response is None:
            # Fixed: previously raised with MALPredictionImport as the type.
            raise ResourceNotFoundError(MEAToMALPredictionImport, params)
        response = response["meaToMalPredictionImport"]
        if as_json:
            return response
        return cls(client, response)

    @classmethod
    def _get_model_run_data_rows_mutation(cls) -> str:
        return """mutation createMalPredictionImportForModelRunDataRowsPyApi($dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!, $projectId:ID!) {
            createMalPredictionImportForModelRunDataRows(data: {
                name: $name
                modelRunId: $modelRunId
                dataRowIds: $dataRowIds
                projectId: $projectId
            }) {%s}
        }""" % query.results_query_part(cls)
class MALPredictionImport(CreatableAnnotationImport):
    """Import job for model-assisted labeling (MAL) predictions into a project."""

    project = Relationship.ToOne("Project", cache=True)

    @property
    def parent_id(self) -> str:
        """
        Identifier for this import. Used to refresh the status
        """
        return self.project().uid

    @classmethod
    def create_from_file(
        cls, client: "labelbox.Client", project_id: str, name: str, path: str
    ) -> "MALPredictionImport":
        """
        Create an MAL prediction import job from a file of annotations

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            path: Path to ndjson file containing annotations
        Returns:
            MALPredictionImport
        Raises:
            ValueError: if `path` does not exist.
        """
        if os.path.exists(path):
            with open(path, "rb") as f:
                return cls._create_mal_import_from_bytes(
                    client, project_id, name, f, os.stat(path).st_size
                )
        else:
            raise ValueError(f"File {path} is not accessible")

    @classmethod
    def create_from_objects(
        cls,
        client: "labelbox.Client",
        project_id: str,
        name: str,
        predictions: Union[List[Dict[str, Any]], List["Label"]],
    ) -> "MALPredictionImport":
        """
        Create an MAL prediction import job from an in memory dictionary

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            predictions: List of prediction annotations
        Returns:
            MALPredictionImport
        """
        data = cls._get_ndjson_from_objects(predictions, "annotations")
        # Confidence scores are only checked when the input is raw dicts;
        # Label objects are handled by serialization.
        if len(predictions) > 0 and isinstance(predictions[0], dict):
            predictions_dicts = cast(List[Dict[str, Any]], predictions)
            has_confidence = LabelsConfidencePresenceChecker.check(
                predictions_dicts
            )
            if has_confidence:
                logger.warning("""
                Confidence scores are not supported in MAL Prediction Import.
                Corresponding confidence score values will be ignored.
                """)
        # Fixed: content length is the byte count of the payload; the previous
        # len(str(data)) measured the length of the bytes object's repr.
        return cls._create_mal_import_from_bytes(
            client, project_id, name, data, len(data)
        )

    @classmethod
    def create_from_url(
        cls, client: "labelbox.Client", project_id: str, name: str, url: str
    ) -> "MALPredictionImport":
        """
        Create an MAL prediction import job from a url
        The url must point to a file containing prediction annotations.

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            url: Url pointing to file to upload
        Returns:
            MALPredictionImport
        Raises:
            ValueError: if a HEAD request to `url` does not succeed.
        """
        if requests.head(url):
            query_str = cls._get_url_mutation()
            return cls(
                client,
                client.execute(
                    query_str,
                    params={
                        "fileUrl": url,
                        "projectId": project_id,
                        "name": name,
                    },
                )["createModelAssistedLabelingPredictionImport"],
            )
        else:
            raise ValueError(f"Url {url} is not reachable")

    @classmethod
    def from_name(
        cls,
        client: "labelbox.Client",
        project_id: str,
        name: str,
        as_json: bool = False,
    ) -> "MALPredictionImport":
        """
        Retrieves an MAL import job.

        Args:
            client: Labelbox Client for executing queries
            project_id: ID used for querying import jobs
            name: Name of the import job.
        Returns:
            MALPredictionImport
        Raises:
            ResourceNotFoundError: if no matching import job exists.
        """
        query_str = """query getModelAssistedLabelingPredictionImportPyApi($projectId : ID!, $name: String!) {
            modelAssistedLabelingPredictionImport(
                where: {projectId: $projectId, name: $name}){
                    %s
            }}""" % query.results_query_part(cls)
        params = {
            "projectId": project_id,
            "name": name,
        }
        response = client.execute(query_str, params)
        if response is None:
            raise ResourceNotFoundError(MALPredictionImport, params)
        response = response["modelAssistedLabelingPredictionImport"]
        if as_json:
            return response
        return cls(client, response)

    @classmethod
    def _get_url_mutation(cls) -> str:
        return """mutation createMALPredictionImportByUrlPyApi($projectId : ID!, $name: String!, $fileUrl: String!) {
            createModelAssistedLabelingPredictionImport(data: {
                projectId: $projectId
                name: $name
                fileUrl: $fileUrl
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _get_file_mutation(cls) -> str:
        return """mutation createMALPredictionImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
            createModelAssistedLabelingPredictionImport(data: {
                projectId: $projectId
                name: $name
                filePayload: { file: $file, contentLength: $contentLength}
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _create_mal_import_from_bytes(
        cls,
        client: "labelbox.Client",
        project_id: str,
        name: str,
        bytes_data: BinaryIO,
        content_len: int,
    ) -> "MALPredictionImport":
        file_name = f"{project_id}__{name}.ndjson"
        variables = {
            "file": None,
            "contentLength": content_len,
            "projectId": project_id,
            "name": name,
        }
        query_str = cls._get_file_mutation()
        res = cls._create_from_bytes(
            client, variables, query_str, file_name, bytes_data
        )
        return cls(client, res["createModelAssistedLabelingPredictionImport"])
class LabelImport(CreatableAnnotationImport):
    """Import job for ground-truth labels into a project."""

    project = Relationship.ToOne("Project", cache=True)

    @property
    def parent_id(self) -> str:
        """
        Identifier for this import. Used to refresh the status
        """
        return self.project().uid

    @classmethod
    def create_from_file(
        cls, client: "labelbox.Client", project_id: str, name: str, path: str
    ) -> "LabelImport":
        """
        Create a label import job from a file of annotations

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            path: Path to ndjson file containing annotations
        Returns:
            LabelImport
        Raises:
            ValueError: if `path` does not exist.
        """
        if os.path.exists(path):
            with open(path, "rb") as f:
                return cls._create_label_import_from_bytes(
                    client, project_id, name, f, os.stat(path).st_size
                )
        else:
            raise ValueError(f"File {path} is not accessible")

    @classmethod
    def create_from_objects(
        cls,
        client: "labelbox.Client",
        project_id: str,
        name: str,
        labels: Union[List[Dict[str, Any]], List["Label"]],
    ) -> "LabelImport":
        """
        Create a label import job from an in memory dictionary

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            labels: List of labels
        Returns:
            LabelImport
        """
        data = cls._get_ndjson_from_objects(labels, "labels")
        # Confidence scores are only checked when the input is raw dicts;
        # Label objects are handled by serialization.
        if len(labels) > 0 and isinstance(labels[0], dict):
            label_dicts = cast(List[Dict[str, Any]], labels)
            has_confidence = LabelsConfidencePresenceChecker.check(label_dicts)
            if has_confidence:
                logger.warning("""
                Confidence scores are not supported in Label Import.
                Corresponding confidence score values will be ignored.
                """)
        # Fixed: content length is the byte count of the payload; the previous
        # len(str(data)) measured the length of the bytes object's repr.
        return cls._create_label_import_from_bytes(
            client, project_id, name, data, len(data)
        )

    @classmethod
    def create_from_url(
        cls, client: "labelbox.Client", project_id: str, name: str, url: str
    ) -> "LabelImport":
        """
        Create a label annotation import job from a url
        The url must point to a file containing label annotations.

        Args:
            client: Labelbox Client for executing queries
            project_id: Project to import labels into
            name: Name of the import job. Can be used to reference the task later
            url: Url pointing to file to upload
        Returns:
            LabelImport
        Raises:
            ValueError: if a HEAD request to `url` does not succeed.
        """
        if requests.head(url):
            query_str = cls._get_url_mutation()
            return cls(
                client,
                client.execute(
                    query_str,
                    params={
                        "fileUrl": url,
                        "projectId": project_id,
                        "name": name,
                    },
                )["createLabelImport"],
            )
        else:
            raise ValueError(f"Url {url} is not reachable")

    @classmethod
    def from_name(
        cls,
        client: "labelbox.Client",
        project_id: str,
        name: str,
        as_json: bool = False,
    ) -> "LabelImport":
        """
        Retrieves an label import job.

        Args:
            client: Labelbox Client for executing queries
            project_id: ID used for querying import jobs
            name: Name of the import job.
        Returns:
            LabelImport
        Raises:
            ResourceNotFoundError: if no matching import job exists.
        """
        query_str = """query getLabelImportPyApi($projectId : ID!, $name: String!) {
            labelImport(
                where: {projectId: $projectId, name: $name}){
                    %s
            }}""" % query.results_query_part(cls)
        params = {
            "projectId": project_id,
            "name": name,
        }
        response = client.execute(query_str, params)
        if response is None:
            raise ResourceNotFoundError(LabelImport, params)
        response = response["labelImport"]
        if as_json:
            return response
        return cls(client, response)

    @classmethod
    def _get_url_mutation(cls) -> str:
        return """mutation createLabelImportByUrlPyApi($projectId : ID!, $name: String!, $fileUrl: String!) {
            createLabelImport(data: {
                projectId: $projectId
                name: $name
                fileUrl: $fileUrl
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _get_file_mutation(cls) -> str:
        return """mutation createLabelImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
            createLabelImport(data: {
                projectId: $projectId
                name: $name
                filePayload: { file: $file, contentLength: $contentLength}
            }) {%s}
        }""" % query.results_query_part(cls)

    @classmethod
    def _create_label_import_from_bytes(
        cls,
        client: "labelbox.Client",
        project_id: str,
        name: str,
        bytes_data: BinaryIO,
        content_len: int,
    ) -> "LabelImport":
        file_name = f"{project_id}__{name}.ndjson"
        variables = {
            "file": None,
            "contentLength": content_len,
            "projectId": project_id,
            "name": name,
        }
        query_str = cls._get_file_mutation()
        res = cls._create_from_bytes(
            client, variables, query_str, file_name, bytes_data
        )
        return cls(client, res["createLabelImport"])