import json
import time
from uuid import UUID, uuid4
import functools
import logging
from pathlib import Path
from google.api_core import retry
from labelbox import parser
import requests
from labelbox import pydantic_compat
from typing_extensions import Literal
from typing import (Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union,
Type, Set, TYPE_CHECKING)
from labelbox import exceptions as lb_exceptions
from labelbox.orm.model import Entity
from labelbox import utils
from labelbox.orm import query
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field, Relationship
from labelbox.schema.enums import BulkImportRequestState
from labelbox.schema.serialization import serialize_labels
if TYPE_CHECKING:
from labelbox import Project
from labelbox.types import Label
NDJSON_MIME_TYPE = "application/x-ndjson"
logger = logging.getLogger(__name__)
def _make_file_name(project_id: str, name: str) -> str:
return f"{project_id}__{name}.ndjson"
# TODO(gszpak): move it to client.py
def _make_request_data(project_id: str, name: str, content_length: int,
file_name: str) -> dict:
query_str = """mutation createBulkImportRequestFromFilePyApi(
$projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
createBulkImportRequest(data: {
projectId: $projectId,
name: $name,
filePayload: {
file: $file,
contentLength: $contentLength
}
}) {
%s
}
}
""" % query.results_query_part(BulkImportRequest)
variables = {
"projectId": project_id,
"name": name,
"file": None,
"contentLength": content_length
}
operations = json.dumps({"variables": variables, "query": query_str})
return {
"operations": operations,
"map": (None, json.dumps({file_name: ["variables.file"]}))
}
def _send_create_file_command(
client, request_data: dict, file_name: str,
file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict:
response = client.execute(data=request_data, files={file_name: file_data})
if not response.get("createBulkImportRequest", None):
raise lb_exceptions.LabelboxError(
"Failed to create BulkImportRequest, message: %s" %
response.get("errors", None) or response.get("error", None))
return response
class BulkImportRequest(DbObject):
"""Represents the import job when importing annotations.
Attributes:
name (str)
state (Enum): FAILED, RUNNING, or FINISHED (Refers to the whole import job)
input_file_url (str): URL to your web-hosted NDJSON file
error_file_url (str): NDJSON that contains error messages for failed annotations
status_file_url (str): NDJSON that contains status for each annotation
created_at (datetime): UTC timestamp for date BulkImportRequest was created
project (Relationship): `ToOne` relationship to Project
created_by (Relationship): `ToOne` relationship to User
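Example:
    A minimal, hypothetical workflow (assumes a configured `client`, an
    existing `project`, and a hosted ndjson file)::

        bulk_import = BulkImportRequest.create_from_url(
            client, project.uid, "my-import",
            "https://example.com/predictions.ndjson")
        bulk_import.wait_until_done()
        print(bulk_import.errors)  # empty list when every annotation imported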
"""
name = Field.String("name")
state = Field.Enum(BulkImportRequestState, "state")
input_file_url = Field.String("input_file_url")
error_file_url = Field.String("error_file_url")
status_file_url = Field.String("status_file_url")
created_at = Field.DateTime("created_at")
project = Relationship.ToOne("Project")
created_by = Relationship.ToOne("User", False, "created_by")
@property
def inputs(self) -> List[Dict[str, Any]]:
"""
Inputs for each individual annotation uploaded.
This should match the ndjson annotations that you have uploaded.
Returns:
Uploaded ndjson.
* This information will expire after 24 hours.
"""
return self._fetch_remote_ndjson(self.input_file_url)
@property
def errors(self) -> List[Dict[str, Any]]:
"""
Errors for each individual annotation uploaded. This is a subset of statuses.
Returns:
List of dicts containing error messages. Empty list means there were no errors
See `BulkImportRequest.statuses` for more details.
* This information will expire after 24 hours.
"""
self.wait_until_done()
return self._fetch_remote_ndjson(self.error_file_url)
@property
def statuses(self) -> List[Dict[str, Any]]:
"""
Status for each individual annotation uploaded.
Returns:
A status for each annotation if the upload is done running.
See the table below for more details.
.. list-table::
:widths: 15 150
:header-rows: 1
* - Field
- Description
* - uuid
- Specifies the annotation for the status row.
* - dataRow
- JSON object containing the Labelbox data row ID for the annotation.
* - status
- Indicates SUCCESS or FAILURE.
* - errors
- An array of error messages included when status is FAILURE. Each error has a name, a message, and an optional additional_info key (which may be absent).
* This information will expire after 24 hours.
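Example:
    A hypothetical status row (ids are the placeholder values used in the
    `create_from_objects` format example)::

        {
            "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092",
            "dataRow": {"id": "ck1s02fqxm8fi0757f0e6qtdc"},
            "status": "SUCCESS"
        }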
"""
self.wait_until_done()
return self._fetch_remote_ndjson(self.status_file_url)
@functools.lru_cache()
def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
"""
Fetches the remote ndjson file and caches the results.
Args:
url (str): Can be any url pointing to an ndjson file.
Returns:
ndjson as a list of dicts.
"""
response = requests.get(url)
response.raise_for_status()
return parser.loads(response.text)
def refresh(self) -> None:
"""Synchronizes values of all fields with the database.
"""
query_str, params = query.get_single(BulkImportRequest, self.uid)
res = self.client.execute(query_str, params)
res = res[utils.camel_case(BulkImportRequest.type_name())]
self._set_field_values(res)
def wait_until_done(self, sleep_time_seconds: int = 5) -> None:
"""Blocks import job until certain conditions are met.
Blocks until the BulkImportRequest.state changes either to
`BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`,
periodically refreshing the object's state.
Args:
sleep_time_seconds (int): seconds to wait between subsequent API calls
"""
while self.state == BulkImportRequestState.RUNNING:
logger.info(f"Sleeping for {sleep_time_seconds} seconds...")
time.sleep(sleep_time_seconds)
self.__exponential_backoff_refresh()
@retry.Retry(predicate=retry.if_exception_type(lb_exceptions.ApiLimitError,
lb_exceptions.TimeoutError,
lb_exceptions.NetworkError))
def __exponential_backoff_refresh(self) -> None:
self.refresh()
@classmethod
def from_name(cls, client, project_id: str,
name: str) -> 'BulkImportRequest':
""" Fetches existing BulkImportRequest.
Args:
client (Client): a Labelbox client
project_id (str): BulkImportRequest's project id
name (str): name of BulkImportRequest
Returns:
BulkImportRequest object
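Example:
    A minimal sketch (assumes an import named "my-import" already exists
    in the project; the project id is a placeholder)::

        bulk_import = BulkImportRequest.from_name(
            client, project_id="<project_id>", name="my-import")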
"""
query_str = """query getBulkImportRequestPyApi(
$projectId: ID!, $name: String!) {
bulkImportRequest(where: {
projectId: $projectId,
name: $name
}) {
%s
}
}
""" % query.results_query_part(cls)
params = {"projectId": project_id, "name": name}
response = client.execute(query_str, params=params)
return cls(client, response['bulkImportRequest'])
@classmethod
def create_from_url(cls,
client,
project_id: str,
name: str,
url: str,
validate=True) -> 'BulkImportRequest':
"""
Creates a BulkImportRequest from a publicly accessible URL
to an ndjson file with predictions.
Args:
client (Client): a Labelbox client
project_id (str): id of project for which predictions will be imported
name (str): name of BulkImportRequest
url (str): publicly accessible URL pointing to ndjson file containing predictions
validate (bool): a flag indicating whether to validate that `url`
points to a valid ndjson file
Returns:
BulkImportRequest object
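Example:
    A minimal sketch (the URL, names, and ids are placeholders)::

        bulk_import = BulkImportRequest.create_from_url(
            client, project_id="<project_id>", name="my-import",
            url="https://example.com/predictions.ndjson", validate=True)
        bulk_import.wait_until_done()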
"""
if validate:
logger.warning(
"Validation is turned on. The file will be downloaded locally and processed before uploading."
)
res = requests.get(url)
data = parser.loads(res.text)
_validate_ndjson(data, client.get_project(project_id))
query_str = """mutation createBulkImportRequestPyApi(
$projectId: ID!, $name: String!, $fileUrl: String!) {
createBulkImportRequest(data: {
projectId: $projectId,
name: $name,
fileUrl: $fileUrl
}) {
%s
}
}
""" % query.results_query_part(cls)
params = {"projectId": project_id, "name": name, "fileUrl": url}
bulk_import_request_response = client.execute(query_str, params=params)
return cls(client,
bulk_import_request_response["createBulkImportRequest"])
@classmethod
def create_from_objects(cls,
client,
project_id: str,
name: str,
predictions: Union[Iterable[Dict],
Iterable["Label"]],
validate=True) -> 'BulkImportRequest':
"""
Creates a `BulkImportRequest` from an iterable of dictionaries.
Each dictionary must conform to the NDJSON predictions format, e.g.:
``{
"uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092",
"schemaId": "ckappz7d700gn0zbocmqkwd9i",
"dataRow": {
"id": "ck1s02fqxm8fi0757f0e6qtdc"
},
"bbox": {
"top": 48,
"left": 58,
"height": 865,
"width": 1512
}
}``
Args:
client (Client): a Labelbox client
project_id (str): id of project for which predictions will be imported
name (str): name of BulkImportRequest
predictions (Iterable[dict]): iterable of dictionaries representing predictions
validate (bool): a flag indicating whether to validate that
`predictions` is valid ndjson
Returns:
BulkImportRequest object
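Example:
    A minimal sketch (the ids are the placeholder values from the format
    example above)::

        predictions = [{
            "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092",
            "schemaId": "ckappz7d700gn0zbocmqkwd9i",
            "dataRow": {"id": "ck1s02fqxm8fi0757f0e6qtdc"},
            "bbox": {"top": 48, "left": 58, "height": 865, "width": 1512}
        }]
        bulk_import = BulkImportRequest.create_from_objects(
            client, project_id="<project_id>", name="my-import",
            predictions=predictions)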
"""
if not isinstance(predictions, list):
raise TypeError(
f"annotations must be in a form of Iterable. Found {type(predictions)}"
)
ndjson_predictions = serialize_labels(predictions)
if validate:
_validate_ndjson(ndjson_predictions, client.get_project(project_id))
data_str = parser.dumps(ndjson_predictions)
if not data_str:
raise ValueError('annotations cannot be empty')
data = data_str.encode('utf-8')
file_name = _make_file_name(project_id, name)
request_data = _make_request_data(project_id, name, len(data_str),
file_name)
file_data = (file_name, data, NDJSON_MIME_TYPE)
response_data = _send_create_file_command(client,
request_data=request_data,
file_name=file_name,
file_data=file_data)
return cls(client, response_data["createBulkImportRequest"])
@classmethod
def create_from_local_file(cls,
client,
project_id: str,
name: str,
file: Path,
validate_file=True) -> 'BulkImportRequest':
"""
Creates a BulkImportRequest from a local ndjson file with predictions.
Args:
client (Client): a Labelbox client
project_id (str): id of project for which predictions will be imported
name (str): name of BulkImportRequest
file (Path): local ndjson file with predictions
validate_file (bool): a flag indicating whether to validate that
`file` is a valid ndjson file
Returns:
BulkImportRequest object
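Example:
    A minimal sketch (the file path and ids are placeholders)::

        from pathlib import Path
        bulk_import = BulkImportRequest.create_from_local_file(
            client, project_id="<project_id>", name="my-import",
            file=Path("predictions.ndjson"))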
"""
file_name = _make_file_name(project_id, name)
content_length = file.stat().st_size
request_data = _make_request_data(project_id, name, content_length,
file_name)
with file.open('rb') as f:
if validate_file:
reader = parser.reader(f)
# ensure that the underlying json load call is valid
# https://github.com/rhgrant10/ndjson/blob/ff2f03c56b21f28f7271b27da35ca4a8bf9a05d0/ndjson/api.py#L53
# by iterating through the file so we only store
# each line in memory rather than the entire file
try:
_validate_ndjson(reader, client.get_project(project_id))
except ValueError:
raise ValueError(f"{file} is not a valid ndjson file")
else:
f.seek(0)
file_data = (file.name, f, NDJSON_MIME_TYPE)
response_data = _send_create_file_command(client, request_data,
file_name, file_data)
return cls(client, response_data["createBulkImportRequest"])
def delete(self) -> None:
""" Deletes the import job and also any annotations created by this import.
Returns:
None
"""
id_param = "bulk_request_id"
query_str = """mutation deleteBulkImportRequestPyApi($%s: ID!) {
deleteBulkImportRequest(where: {id: $%s}) {
id
name
}
}""" % (id_param, id_param)
self.client.execute(query_str, {id_param: self.uid})
def _validate_ndjson(lines: Iterable[Dict[str, Any]],
project: "Project") -> None:
"""
Client side validation of an ndjson object.
Does not guarantee that an upload will succeed, for the following reasons:
* We are not checking the data row types, which allows the following errors to slip through:
* Missing frame indices will not cause an error for videos
* Uploaded annotations for the wrong data type will pass (e.g. entity on images)
* We are not checking bounds of an asset (e.g. frame index, image height, text location)
Args:
lines (Iterable[Dict[str,Any]]): An iterable of ndjson lines
project (Project): the project for which predictions will be imported
Raises:
MALValidationError: Raise for invalid NDJson
UuidError: Duplicate UUID in upload
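Example:
    A minimal sketch (schema, data row, and project ids are placeholders)::

        lines = [{
            "uuid": str(uuid4()),
            "schemaId": "<schema_id>",
            "dataRow": {"id": "<data_row_id>"},
            "bbox": {"top": 48, "left": 58, "height": 865, "width": 1512}
        }]
        _validate_ndjson(lines, client.get_project("<project_id>"))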
"""
feature_schemas_by_id, feature_schemas_by_name = get_mal_schemas(
project.ontology())
uids: Set[str] = set()
for idx, line in enumerate(lines):
try:
annotation = NDAnnotation(**line)
annotation.validate_instance(feature_schemas_by_id,
feature_schemas_by_name)
uuid = str(annotation.uuid)
if uuid in uids:
raise lb_exceptions.UuidError(
f'{uuid} already used in this import job, '
'must be unique for the project.')
uids.add(uuid)
except (pydantic_compat.ValidationError, ValueError, TypeError,
KeyError) as e:
raise lb_exceptions.MALValidationError(
f"Invalid NDJson on line {idx}") from e
#The rest of this file contains objects for MAL validation
def parse_classification(tool):
"""
Parses a classification from an ontology. Only radio, checklist, and text are supported for MAL.
Args:
tool (dict)
Returns:
dict
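Example:
    A hypothetical radio tool and the dict it parses to (ids are placeholders)::

        tool = {
            "type": "radio",
            "name": "color",
            "featureSchemaId": "<schema_id>",
            "options": [{"value": "red", "featureSchemaId": "<option_id>"}]
        }
        parse_classification(tool)
        # {'tool': 'radio', 'featureSchemaId': '<schema_id>', 'name': 'color',
        #  'options': ['<option_id>', 'red']}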
"""
if tool['type'] in ['radio', 'checklist']:
option_schema_ids = [r['featureSchemaId'] for r in tool['options']]
option_names = [r['value'] for r in tool['options']]
return {
'tool': tool['type'],
'featureSchemaId': tool['featureSchemaId'],
'name': tool['name'],
'options': [*option_schema_ids, *option_names]
}
elif tool['type'] == 'text':
return {
'tool': tool['type'],
'name': tool['name'],
'featureSchemaId': tool['featureSchemaId']
}
def get_mal_schemas(ontology):
"""
Converts a project ontology to a dict for easier lookup during ndjson validation
Args:
ontology (Ontology)
Returns:
Tuple[Dict, Dict]: Useful for looking up a tool by feature schema id or by name
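Example:
    A minimal sketch of the two lookups (the tool name and schema id are
    placeholders)::

        by_schema_id, by_name = get_mal_schemas(project.ontology())
        by_name["<tool_name>"]["tool"]          # e.g. "rectangle"
        by_schema_id["<schema_id>"]["name"]     # tool name for that schema id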
"""
valid_feature_schemas_by_schema_id = {}
valid_feature_schemas_by_name = {}
for tool in ontology.normalized['tools']:
classifications = [
parse_classification(classification_tool)
for classification_tool in tool['classifications']
]
classifications_by_schema_id = {
v['featureSchemaId']: v for v in classifications
}
classifications_by_name = {v['name']: v for v in classifications}
valid_feature_schemas_by_schema_id[tool['featureSchemaId']] = {
'tool': tool['tool'],
'classificationsBySchemaId': classifications_by_schema_id,
'classificationsByName': classifications_by_name,
'name': tool['name']
}
valid_feature_schemas_by_name[tool['name']] = {
'tool': tool['tool'],
'classificationsBySchemaId': classifications_by_schema_id,
'classificationsByName': classifications_by_name,
'name': tool['name']
}
for tool in ontology.normalized['classifications']:
valid_feature_schemas_by_schema_id[
tool['featureSchemaId']] = parse_classification(tool)
valid_feature_schemas_by_name[tool['name']] = parse_classification(tool)
return valid_feature_schemas_by_schema_id, valid_feature_schemas_by_name
LabelboxID: str = pydantic_compat.Field(..., min_length=25, max_length=25)
class Bbox(pydantic_compat.BaseModel):
top: float
left: float
height: float
width: float
class Point(pydantic_compat.BaseModel):
x: float
y: float
class FrameLocation(pydantic_compat.BaseModel):
end: int
start: int
class VideoSupported(pydantic_compat.BaseModel):
#Note that frames are only allowed as top level inferences for video
frames: Optional[List[FrameLocation]]
#Base class for a special kind of union.
# Compatible with pydantic_compat. Improves error messages over a traditional union
class SpecialUnion:
def __new__(cls, **kwargs):
return cls.build(kwargs)
@classmethod
def __get_validators__(cls):
yield cls.build
@classmethod
def get_union_types(cls):
if not issubclass(cls, SpecialUnion):
raise TypeError("{} must be a subclass of SpecialUnion")
union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")]
if len(union_types) < 1:
raise TypeError(
"Class {cls} should inherit from a union of objects to build")
if len(union_types) > 1:
raise TypeError(
f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}"
)
return union_types[0].__args__[0].__args__
@classmethod
def build(cls: Any, data: Union[dict,
pydantic_compat.BaseModel]) -> "NDBase":
"""
Checks through all objects in the union to see which matches the input data.
Args:
data (Union[dict, pydantic_compat.BaseModel]) : The data for constructing one of the objects in the union
raises:
KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion
pydantic_compat.ValidationError: Error while trying to construct a specific object in the union
"""
if isinstance(data, pydantic_compat.BaseModel):
data = data.dict()
top_level_fields = []
max_match = 0
matched = None
for type_ in cls.get_union_types():
determinate_fields = type_.Config.determinants(type_)
top_level_fields.append(determinate_fields)
matches = sum([val in determinate_fields for val in data])
if matches == len(determinate_fields) and matches > max_match:
max_match = matches
matched = type_
if matched is not None:
#These two have the exact same top level keys
if matched in [NDRadio, NDText]:
if isinstance(data['answer'], dict):
matched = NDRadio
elif isinstance(data['answer'], str):
matched = NDText
else:
raise TypeError(
f"Unexpected type for answer field. Found {data['answer']}. Expected a string or a dict"
)
return matched(**data)
else:
raise KeyError(
f"Invalid annotation. Must have one of the following keys : {top_level_fields}. Found {data}."
)
@classmethod
def schema(cls):
results = {'definitions': {}}
for cl in cls.get_union_types():
schema = cl.schema()
results['definitions'].update(schema.pop('definitions'))
results[cl.__name__] = schema
return results
class DataRow(pydantic_compat.BaseModel):
id: str
class NDFeatureSchema(pydantic_compat.BaseModel):
schemaId: Optional[str] = None
name: Optional[str] = None
@pydantic_compat.root_validator
def must_set_one(cls, values):
if values['schemaId'] is None and values['name'] is None:
raise ValueError(
"Must set either schemaId or name for all feature schemas")
return values
class NDBase(NDFeatureSchema):
ontology_type: str
uuid: UUID
dataRow: DataRow
def validate_feature_schemas(self, valid_feature_schemas_by_id,
valid_feature_schemas_by_name):
if self.name:
if self.name not in valid_feature_schemas_by_name:
raise ValueError(
f"Name {self.name} is not valid for the provided project's ontology."
)
if self.ontology_type != valid_feature_schemas_by_name[
self.name]['tool']:
raise ValueError(
f"Name {self.name} does not map to the assigned tool {valid_feature_schemas_by_name[self.name]['tool']}"
)
if self.schemaId:
if self.schemaId not in valid_feature_schemas_by_id:
raise ValueError(
f"Schema id {self.schemaId} is not valid for the provided project's ontology."
)
if self.ontology_type != valid_feature_schemas_by_id[
self.schemaId]['tool']:
raise ValueError(
f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas_by_id[self.schemaId]['tool']}"
)
def validate_instance(self, valid_feature_schemas_by_id,
valid_feature_schemas_by_name):
self.validate_feature_schemas(valid_feature_schemas_by_id,
valid_feature_schemas_by_name)
class Config:
#Users shouldn't be able to add extra data to the payload
extra = 'forbid'
@staticmethod
def determinants(parent_cls) -> List[str]:
#This is a hack for better error messages
return [
k for k, v in parent_cls.__fields__.items()
if 'determinant' in v.field_info.extra
]
###### Classifications ######
class NDText(NDBase):
ontology_type: Literal["text"] = "text"
answer: str = pydantic_compat.Field(determinant=True)
#No feature schema to check
class NDChecklist(VideoSupported, NDBase):
ontology_type: Literal["checklist"] = "checklist"
answers: List[NDFeatureSchema] = pydantic_compat.Field(determinant=True)
@pydantic_compat.validator('answers', pre=True)
def validate_answers(cls, value, field):
#constr not working with mypy.
if not len(value):
raise ValueError("Checklist answers should not be empty")
return value
def validate_feature_schemas(self, valid_feature_schemas_by_id,
valid_feature_schemas_by_name):
#Test top level feature schema for this tool
super(NDChecklist,
self).validate_feature_schemas(valid_feature_schemas_by_id,
valid_feature_schemas_by_name)
#Test the feature schemas provided to the answer field
if len(set([answer.name or answer.schemaId for answer in self.answers
])) != len(self.answers):
raise ValueError(
f"Duplicated featureSchema found for checklist {self.uuid}")
for answer in self.answers:
options = (valid_feature_schemas_by_name[self.name]['options']
           if self.name else
           valid_feature_schemas_by_id[self.schemaId]['options'])
if answer.name not in options and answer.schemaId not in options:
raise ValueError(
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {answer}"
)
class NDRadio(VideoSupported, NDBase):
ontology_type: Literal["radio"] = "radio"
answer: NDFeatureSchema = pydantic_compat.Field(determinant=True)
def validate_feature_schemas(self, valid_feature_schemas_by_id,
valid_feature_schemas_by_name):
super(NDRadio,
self).validate_feature_schemas(valid_feature_schemas_by_id,
valid_feature_schemas_by_name)
options = valid_feature_schemas_by_name[
self.name]['options'] if self.name else valid_feature_schemas_by_id[
self.schemaId]['options']
if self.answer.name not in options and self.answer.schemaId not in options:
raise ValueError(
f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.name or self.answer.schemaId}"
)
#A union with custom construction logic to improve error messages
class NDClassification(
SpecialUnion,
Type[Union[ # type: ignore
NDText, NDRadio, NDChecklist]]):
...
###### Tools ######
class NDBaseTool(NDBase):
classifications: List[NDClassification] = []
#This is independent of our problem
def validate_feature_schemas(self, valid_feature_schemas_by_id,
valid_feature_schemas_by_name):
super(NDBaseTool,
self).validate_feature_schemas(valid_feature_schemas_by_id,
valid_feature_schemas_by_name)
for classification in self.classifications:
    if self.name:
        schemas = valid_feature_schemas_by_name[self.name]
    else:
        schemas = valid_feature_schemas_by_id[self.schemaId]
    classification.validate_feature_schemas(
        schemas['classificationsBySchemaId'],
        schemas['classificationsByName'])
@pydantic_compat.validator('classifications', pre=True)
def validate_subclasses(cls, value, field):
#Create a uuid and data row id so we don't have to define classification objects twice.
#These ids are required for top-level classifications but not for subclasses.
results = []
dummy_id = 'child'.center(25, '_')
for row in value:
results.append({
**row, 'dataRow': {
'id': dummy_id
},
'uuid': str(uuid4())
})
return results
class NDPolygon(NDBaseTool):
ontology_type: Literal["polygon"] = "polygon"
polygon: List[Point] = pydantic_compat.Field(determinant=True)
@pydantic_compat.validator('polygon')
def is_geom_valid(cls, v):
if len(v) < 3:
raise ValueError(
f"A polygon must have at least 3 points to be valid. Found {v}")
return v
class NDPolyline(NDBaseTool):
ontology_type: Literal["line"] = "line"
line: List[Point] = pydantic_compat.Field(determinant=True)
@pydantic_compat.validator('line')
def is_geom_valid(cls, v):
if len(v) < 2:
raise ValueError(
f"A line must have at least 2 points to be valid. Found {v}")
return v
class NDRectangle(NDBaseTool):
ontology_type: Literal["rectangle"] = "rectangle"
bbox: Bbox = pydantic_compat.Field(determinant=True)
#Could check if points are positive
class NDPoint(NDBaseTool):
ontology_type: Literal["point"] = "point"
point: Point = pydantic_compat.Field(determinant=True)
#Could check if points are positive
class EntityLocation(pydantic_compat.BaseModel):
start: int
end: int
class NDTextEntity(NDBaseTool):
ontology_type: Literal["named-entity"] = "named-entity"
location: EntityLocation = pydantic_compat.Field(determinant=True)
@pydantic_compat.validator('location')
def is_valid_location(cls, v):
if isinstance(v, pydantic_compat.BaseModel):
v = v.dict()
if len(v) < 2:
raise ValueError(
f"A line must have at least 2 points to be valid. Found {v}")
if v['start'] < 0:
raise ValueError(f"Text location must be positive. Found {v}")
if v['start'] > v['end']:
raise ValueError(
f"Text start location must be less or equal than end. Found {v}"
)
return v
class RLEMaskFeatures(pydantic_compat.BaseModel):
counts: List[int]
size: List[int]
@pydantic_compat.validator('counts')
def validate_counts(cls, counts):
if not all([count >= 0 for count in counts]):
raise ValueError(
"Found negative value for counts. They should all be zero or positive"
)
return counts
@pydantic_compat.validator('size')
def validate_size(cls, size):
if len(size) != 2:
raise ValueError(
f"Mask `size` should have two ints representing height and with. Found : {size}"
)
if not all([count > 0 for count in size]):
raise ValueError(
f"Mask `size` should be a postitive int. Found : {size}")
return size
class PNGMaskFeatures(pydantic_compat.BaseModel):
# base64 encoded png bytes
png: str
class URIMaskFeatures(pydantic_compat.BaseModel):
instanceURI: str
colorRGB: Union[List[int], Tuple[int, int, int]]
@pydantic_compat.validator('colorRGB')
def validate_color(cls, colorRGB):
#Does the dtype matter? Can it be a float?
if not isinstance(colorRGB, (tuple, list)):
raise ValueError(
f"Received color that is not a list or tuple. Found : {colorRGB}"
)
elif len(colorRGB) != 3:
raise ValueError(
f"Must provide RGB values for segmentation colors. Found : {colorRGB}"
)
elif not all([0 <= color <= 255 for color in colorRGB]):
raise ValueError(
f"All rgb colors must be between 0 and 255. Found : {colorRGB}")
return colorRGB
class NDMask(NDBaseTool):
ontology_type: Literal["superpixel"] = "superpixel"
mask: Union[URIMaskFeatures, PNGMaskFeatures,
RLEMaskFeatures] = pydantic_compat.Field(determinant=True)
#A union with custom construction logic to improve error messages
class NDTool(
SpecialUnion,
Type[Union[ # type: ignore
NDMask,
NDTextEntity,
NDPoint,
NDRectangle,
NDPolyline,
NDPolygon,
]]):
...
class NDAnnotation(
SpecialUnion,
Type[Union[ # type: ignore
NDTool, NDClassification]]):
@classmethod
def build(cls: Any, data) -> "NDBase":
if not isinstance(data, dict):
raise ValueError('value must be a dict')
errors = []
for cl in cls.get_union_types():
try:
return cl(**data)
except KeyError as e:
errors.append(f"{cl.__name__}: {e}")
raise ValueError('Unable to construct any annotation.\n{}'.format(
"\n".join(errors)))
@classmethod
def schema(cls):
data = {'definitions': {}}
for type_ in cls.get_union_types():
schema_ = type_.schema()
data['definitions'].update(schema_.pop('definitions'))
data[type_.__name__] = schema_
return data