from typing import Dict, Iterable, Union
from pathlib import Path
import os
import time
import warnings
from labelbox.pagination import PaginatedCollection
from labelbox.schema.annotation_import import MEAPredictionImport
from labelbox.orm.query import results_query_part
from labelbox.orm.model import Field, Relationship
from labelbox.orm.db_object import DbObject
[docs]class ModelRun(DbObject):
name = Field.String("name")
updated_at = Field.DateTime("updated_at")
created_at = Field.DateTime("created_at")
created_by_id = Field.String("created_by_id", "createdBy")
model_id = Field.String("model_id")
[docs] def upsert_labels(self, label_ids, timeout_seconds=60):
""" Adds data rows and labels to a model run
Args:
label_ids (list): label ids to insert
timeout_seconds (float): Max waiting time, in seconds.
Returns:
ID of newly generated async task
"""
if len(label_ids) < 1:
raise ValueError("Must provide at least one label id")
mutation_name = 'createMEAModelRunLabelRegistrationTask'
create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
""" % (mutation_name)
res = self.client.execute(create_task_query_str, {
'modelRunId': self.uid,
'labelIds': label_ids
})
task_id = res[mutation_name]
status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
}
"""
return self._wait_until_done(lambda: self.client.execute(
status_query_str, {'where': {
'id': task_id
}})['MEALabelRegistrationTaskStatus'],
timeout_seconds=timeout_seconds)
[docs] def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
""" Adds data rows to a model run without any associated labels
Args:
data_row_ids (list): data row ids to add to mea
timeout_seconds (float): Max waiting time, in seconds.
Returns:
ID of newly generated async task
"""
if len(data_row_ids) < 1:
raise ValueError("Must provide at least one data row id")
mutation_name = 'createMEAModelRunDataRowRegistrationTask'
create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds : [ID!]!) {
%s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds})}
""" % (mutation_name)
res = self.client.execute(create_task_query_str, {
'modelRunId': self.uid,
'dataRowIds': data_row_ids
})
task_id = res[mutation_name]
status_query_str = """query MEADataRowRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
MEADataRowRegistrationTaskStatus(where: $where) {status errorMessage}
}
"""
return self._wait_until_done(lambda: self.client.execute(
status_query_str, {'where': {
'id': task_id
}})['MEADataRowRegistrationTaskStatus'],
timeout_seconds=timeout_seconds)
def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
# Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
while True:
res = status_fn()
if res['status'] == 'COMPLETE':
return True
elif res['status'] == 'FAILED':
raise Exception(
f"MEA Import Failed. Details : {res['errorMessage']}")
timeout_seconds -= sleep_time
if timeout_seconds <= 0:
raise TimeoutError(
f"Unable to complete import within {timeout_seconds} seconds."
)
time.sleep(sleep_time)
[docs] def add_predictions(
self,
name: str,
predictions: Union[str, Path, Iterable[Dict]],
) -> 'MEAPredictionImport': # type: ignore
""" Uploads predictions to a new Editor project.
Args:
name (str): name of the AnnotationImport job
predictions (str or Path or Iterable):
url that is publicly accessible by Labelbox containing an
ndjson file
OR local path to an ndjson file
OR iterable of annotation rows
Returns:
AnnotationImport
"""
kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
if isinstance(predictions, str) or isinstance(predictions, Path):
if os.path.exists(predictions):
return MEAPredictionImport.create_from_file(
path=str(predictions), **kwargs)
else:
return MEAPredictionImport.create_from_url(url=str(predictions),
**kwargs)
elif isinstance(predictions, Iterable):
return MEAPredictionImport.create_from_objects(
predictions=predictions, **kwargs)
else:
raise ValueError(
f'Invalid predictions given of type: {type(predictions)}')
def model_run_data_rows(self):
query_str = """query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){
annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first)
{nodes{%s},pageInfo{endCursor}}
}
""" % (results_query_part(ModelRunDataRow))
return PaginatedCollection(
self.client, query_str, {'modelRunId': self.uid},
['annotationGroups', 'nodes'],
lambda client, res: ModelRunDataRow(client, self.model_id, res),
['annotationGroups', 'pageInfo', 'endCursor'])
[docs] def annotation_groups(self):
""" `ModelRun.annotation_groups will be removed after 2021-12-06
use ModelRun.model_run_data_rows instead`
"""
warnings.warn(
"`ModelRun.annotation_groups` will be removed after 2021-12-06 use "
"`ModelRun.model_run_data_rows` instead")
return self.model_run_data_rows()
[docs] def delete(self):
""" Deletes specified model run.
Returns:
Query execution success.
"""
ids_param = "ids"
query_str = """mutation DeleteModelRunPyApi($%s: ID!) {
deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param)
self.client.execute(query_str, {ids_param: str(self.uid)})
[docs] def delete_model_run_data_rows(self, data_row_ids):
""" Deletes data rows from model runs.
Args:
data_row_ids (list): List of data row ids to delete from the model run.
Returns:
Query execution success.
"""
model_run_id_param = "modelRunId"
data_row_ids_param = "dataRowIds"
query_str = """mutation DeleteModelRunDataRowsPyApi($%s: ID!, $%s: [ID!]!) {
deleteModelRunDataRows(where: {modelRunId: $%s, dataRowIds: $%s})}""" % (
model_run_id_param, data_row_ids_param, model_run_id_param,
data_row_ids_param)
self.client.execute(query_str, {
model_run_id_param: self.uid,
data_row_ids_param: data_row_ids
})
[docs] def delete_annotation_groups(self, data_row_ids):
""" `ModelRun.delete_annotation_groups will be removed after 2021-12-06
use ModelRun.delete_model_run_data_rows instead`
"""
warnings.warn(
"`ModelRun.delete_annotation_groups` will be removed after 2021-12-06 use "
"`ModelRun.delete_model_run_data_rows` instead")
return self.delete_model_run_data_rows(data_row_ids)
[docs]class ModelRunDataRow(DbObject):
label_id = Field.String("label_id")
model_run_id = Field.String("model_run_id")
data_row = Relationship.ToOne("DataRow", False, cache=True)
def __init__(self, client, model_id, *args, **kwargs):
super().__init__(client, *args, **kwargs)
self.model_id = model_id
@property
def url(self):
app_url = self.client.app_url
endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel"
return endpoint