Source code for labelbox.schema.model_run

# type: ignore
from typing import TYPE_CHECKING, Dict, Iterable, Union, List, Optional, Any
from pathlib import Path
import os
import time
import logging
import requests
import ndjson
from enum import Enum

from labelbox.pagination import PaginatedCollection
from labelbox.orm.query import results_query_part
from labelbox.orm.model import Field, Relationship, Entity
from labelbox.orm.db_object import DbObject, experimental

    from labelbox import MEAPredictionImport

logger = logging.getLogger(__name__)


[docs]class ModelRun(DbObject): name = Field.String("name") updated_at = Field.DateTime("updated_at") created_at = Field.DateTime("created_at") created_by_id = Field.String("created_by_id", "createdBy") model_id = Field.String("model_id") training_metadata = Field.Json("training_metadata")
[docs] def upsert_labels(self, label_ids, timeout_seconds=3600): """ Adds data rows and labels to a Model Run Args: label_ids (list): label ids to insert timeout_seconds (float): Max waiting time, in seconds. Returns: ID of newly generated async task """ if len(label_ids) < 1: raise ValueError("Must provide at least one label id") mutation_name = 'createMEAModelRunLabelRegistrationTask' create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) { %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})} """ % (mutation_name) res = self.client.execute(create_task_query_str, { 'modelRunId': self.uid, 'labelIds': label_ids }) task_id = res[mutation_name] status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage} } """ return self._wait_until_done(lambda: self.client.execute( status_query_str, {'where': { 'id': task_id }})['MEALabelRegistrationTaskStatus'], timeout_seconds=timeout_seconds)
[docs] def upsert_data_rows(self, data_row_ids, timeout_seconds=3600): """ Adds data rows to a Model Run without any associated labels Args: data_row_ids (list): data row ids to add to mea timeout_seconds (float): Max waiting time, in seconds. Returns: ID of newly generated async task """ if len(data_row_ids) < 1: raise ValueError("Must provide at least one data row id") mutation_name = 'createMEAModelRunDataRowRegistrationTask' create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds : [ID!]!) { %s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds})} """ % (mutation_name) res = self.client.execute(create_task_query_str, { 'modelRunId': self.uid, 'dataRowIds': data_row_ids }) task_id = res[mutation_name] status_query_str = """query MEADataRowRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEADataRowRegistrationTaskStatus(where: $where) {status errorMessage} } """ return self._wait_until_done(lambda: self.client.execute( status_query_str, {'where': { 'id': task_id }})['MEADataRowRegistrationTaskStatus'], timeout_seconds=timeout_seconds)
def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5): # Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change. original_timeout = timeout_seconds while True: res = status_fn() if res['status'] == 'COMPLETE': return True elif res['status'] == 'FAILED': raise Exception(f"Jop failed. Details : {res['errorMessage']}") timeout_seconds -= sleep_time if timeout_seconds <= 0: raise TimeoutError( f"Unable to complete import within {original_timeout} seconds." ) time.sleep(sleep_time)
[docs] def upsert_predictions_and_send_to_project( self, name: str, predictions: Union[str, Path, Iterable[Dict]], project_id: str, priority: Optional[int] = 5, ) -> 'MEAPredictionImport': # type: ignore """ Provides a convenient way to execute the following steps in a single function call: 1. Upload predictions to a Model 2. Create a batch from data rows that had predictions assocated with them 3. Attach the batch to a project 4. Add those same predictions to the project as MAL annotations Note that partial successes are possible. If it is important that all stages are successful then check the status of each individual task with task.errors. E.g. >>> mea_import_job, batch, mal_import_job = upsert_predictions_and_send_to_project(name, predictions, project_id) >>> # handle mea import job successfully created (check for job failure or partial failures) >>> print(mea_import_job.status, mea_import_job.errors) >>> if batch is None: >>> # Handle batch creation failure >>> if mal_import_job is None: >>> # Handle mal_import_job creation failure >>> else: >>> # handle mal import job successfully created (check for job failure or partial failures) >>> print(mal_import_job.status, mal_import_job.errors) Args: name (str): name of the AnnotationImport job as well as the name of the batch import predictions (Iterable): iterable of annotation rows project_id (str): id of the project to import into priority (int): priority of the job Returns: Tuple[MEAPredictionImport, Batch, MEAToMALPredictionImport] If any of these steps fail the return value will be None. """ kwargs = dict(client=self.client, model_run_id=self.uid, name=name) project = self.client.get_project(project_id) import_job = self.add_predictions(name, predictions) prediction_statuses = import_job.statuses mea_to_mal_data_rows = list( set([ row['dataRow']['id'] for row in prediction_statuses if row['status'] == 'SUCCESS' ])) if not mea_to_mal_data_rows: # 0 successful model predictions imported return import_job, None, None elif len(mea_to_mal_data_rows) >= DATAROWS_IMPORT_LIMIT: mea_to_mal_data_rows = mea_to_mal_data_rows[:DATAROWS_IMPORT_LIMIT] logger.warning( f"Exeeded max data row limit {len(mea_to_mal_data_rows)}, trimmed down to {DATAROWS_IMPORT_LIMIT} data rows." ) try: batch = project.create_batch(name, mea_to_mal_data_rows, priority) except Exception as e: logger.warning(f"Failed to create batch. Messsage : {e}.") # Unable to create batch return import_job, None, None try: mal_prediction_import = Entity.MEAToMALPredictionImport.create_for_model_run_data_rows( data_row_ids=mea_to_mal_data_rows, project_id=project_id, **kwargs) mal_prediction_import.wait_until_done() except Exception as e: logger.warning( f"Failed to create MEA to MAL prediction import. Message : {e}." ) # Unable to create mea to mal prediction import return import_job, batch, None return import_job, batch, mal_prediction_import
[docs] def add_predictions( self, name: str, predictions: Union[str, Path, Iterable[Dict]], ) -> 'MEAPredictionImport': # type: ignore """ Uploads predictions to a new Editor project. Args: name (str): name of the AnnotationImport job predictions (str or Path or Iterable): url that is publicly accessible by Labelbox containing an ndjson file OR local path to an ndjson file OR iterable of annotation rows Returns: AnnotationImport """ kwargs = dict(client=self.client, model_run_id=self.uid, name=name) if isinstance(predictions, str) or isinstance(predictions, Path): if os.path.exists(predictions): return Entity.MEAPredictionImport.create_from_file( path=str(predictions), **kwargs) else: return Entity.MEAPredictionImport.create_from_url( url=str(predictions), **kwargs) elif isinstance(predictions, Iterable): return Entity.MEAPredictionImport.create_from_objects( predictions=predictions, **kwargs) else: raise ValueError( f'Invalid predictions given of type: {type(predictions)}')
def model_run_data_rows(self): query_str = """query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){ annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first) {nodes{%s},pageInfo{endCursor}} } """ % (results_query_part(ModelRunDataRow)) return PaginatedCollection( self.client, query_str, {'modelRunId': self.uid}, ['annotationGroups', 'nodes'], lambda client, res: ModelRunDataRow(client, self.model_id, res), ['annotationGroups', 'pageInfo', 'endCursor'])
[docs] def delete(self): """ Deletes specified Model Run. Returns: Query execution success. """ ids_param = "ids" query_str = """mutation DeleteModelRunPyApi($%s: ID!) { deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param) self.client.execute(query_str, {ids_param: str(self.uid)})
[docs] def delete_model_run_data_rows(self, data_row_ids: List[str]): """ Deletes data rows from Model Runs. Args: data_row_ids (list): List of data row ids to delete from the Model Run. Returns: Query execution success. """ model_run_id_param = "modelRunId" data_row_ids_param = "dataRowIds" query_str = """mutation DeleteModelRunDataRowsPyApi($%s: ID!, $%s: [ID!]!) { deleteModelRunDataRows(where: {modelRunId: $%s, dataRowIds: $%s})}""" % ( model_run_id_param, data_row_ids_param, model_run_id_param, data_row_ids_param) self.client.execute(query_str, { model_run_id_param: self.uid, data_row_ids_param: data_row_ids })
@experimental def assign_data_rows_to_split(self, data_row_ids: List[str], split: Union[DataSplit, str], timeout_seconds=120): split_value = split.value if isinstance(split, DataSplit) else split valid_splits = DataSplit._member_names_ if split_value not in valid_splits: raise ValueError( f"`split` must be one of : `{valid_splits}`. Found : `{split}`") task_id = self.client.execute( """mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){ createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)} """, { 'modelRunId': self.uid, 'data': { 'assignments': [{ 'split': split_value, 'dataRowIds': data_row_ids }] } }, experimental=True)['createAssignDataRowsToDataSplitTask'] status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){ assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}} """ return self._wait_until_done(lambda: self.client.execute( status_query_str, {'id': task_id}, experimental=True)[ 'assignDataRowsToDataSplitTaskStatus'], timeout_seconds=timeout_seconds) @experimental def update_status(self, status: Union[str, "ModelRun.Status"], metadata: Optional[Dict[str, str]] = None, error_message: Optional[str] = None): status_value = status.value if isinstance(status, ModelRun.Status) else status if status_value not in ModelRun.Status._member_names_: raise ValueError( f"Status must be one of : `{ModelRun.Status._member_names_}`. Found : `{status_value}`" ) data: Dict[str, Any] = {'status': status_value} if error_message: data['errorMessage'] = error_message if metadata: data['metadata'] = metadata self.client.execute( """mutation setPipelineStatusPyApi($modelRunId: ID!, $data: UpdateTrainingPipelineInput!){ updateTrainingPipeline(modelRun: {id : $modelRunId}, data: $data){status} } """, { 'modelRunId': self.uid, 'data': data }, experimental=True)
[docs] @experimental def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]: """ Updates the Model Run's training metadata config Args: config (dict): A dictionary of keys and values Returns: Model Run id and updated training metadata """ data: Dict[str, Any] = {'config': config} res = self.client.execute( """mutation updateModelRunConfigPyApi($modelRunId: ID!, $data: UpdateModelRunConfigInput!){ updateModelRunConfig(modelRun: {id : $modelRunId}, data: $data){trainingMetadata} } """, { 'modelRunId': self.uid, 'data': data }, experimental=True) return res["updateModelRunConfig"]
[docs] @experimental def reset_config(self) -> Dict[str, Any]: """ Resets Model Run's training metadata config Returns: Model Run id and reset training metadata """ res = self.client.execute( """mutation resetModelRunConfigPyApi($modelRunId: ID!){ resetModelRunConfig(modelRun: {id : $modelRunId}){trainingMetadata} } """, {'modelRunId': self.uid}, experimental=True) return res["resetModelRunConfig"]
[docs] @experimental def get_config(self) -> Dict[str, Any]: """ Gets Model Run's training metadata Returns: training metadata as a dictionary """ res = self.client.execute("""query ModelRunPyApi($modelRunId: ID!){ modelRun(where: {id : $modelRunId}){trainingMetadata} } """, {'modelRunId': self.uid}, experimental=True) return res["modelRun"]["trainingMetadata"]
[docs] @experimental def export_labels( self, download: bool = False, timeout_seconds: int = 600 ) -> Optional[Union[str, List[Dict[Any, Any]]]]: """ Experimental. To use, make sure client has enable_experimental=True. Fetches Labels from the ModelRun Args: download (bool): Returns the url if False Returns: URL of the data file with this ModelRun's labels. If download=True, this instead returns the contents as NDJSON format. If the server didn't generate during the `timeout_seconds` period, None is returned. """ sleep_time = 2 query_str = """mutation exportModelRunAnnotationsPyApi($modelRunId: ID!) { exportModelRunAnnotations(data: {modelRunId: $modelRunId}) { downloadUrl createdAt status } } """ while True: url = self.client.execute( query_str, {'modelRunId': self.uid}, experimental=True)['exportModelRunAnnotations']['downloadUrl'] if url: if not download: return url else: response = requests.get(url) response.raise_for_status() return ndjson.loads(response.content) timeout_seconds -= sleep_time if timeout_seconds <= 0: return None logger.debug("ModelRun '%s' label export, waiting for server...", self.uid) time.sleep(sleep_time)
[docs]class ModelRunDataRow(DbObject): label_id = Field.String("label_id") model_run_id = Field.String("model_run_id") data_split = Field.Enum(DataSplit, "data_split") data_row = Relationship.ToOne("DataRow", False, cache=True) def __init__(self, client, model_id, *args, **kwargs): super().__init__(client, *args, **kwargs) self.model_id = model_id @property def url(self): app_url = self.client.app_url endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel" return endpoint