# Source code for labelbox.schema.bulk_import_request

import functools
import json
import logging
import time
from pathlib import Path
from typing import (Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union,
                    Type, Set, TYPE_CHECKING)
from uuid import UUID, uuid4

import requests
from google.api_core import retry
from typing_extensions import Literal

from labelbox import exceptions as lb_exceptions
from labelbox import parser
from labelbox import pydantic_compat
from labelbox import utils
from labelbox.orm import query
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Entity
from labelbox.orm.model import Field, Relationship
from labelbox.schema.enums import BulkImportRequestState
from labelbox.schema.serialization import serialize_labels

if TYPE_CHECKING:
    from labelbox import Project
    from labelbox.types import Label

# MIME type used for the uploaded NDJSON annotation payloads.
NDJSON_MIME_TYPE = "application/x-ndjson"
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)

def _make_file_name(project_id: str, name: str) -> str:
    return f"{project_id}__{name}.ndjson"

# TODO(gszpak): move it to client.py
def _make_request_data(project_id: str, name: str, content_length: int,
                       file_name: str) -> dict:
    """Builds the multipart form payload for the createBulkImportRequest mutation.

    Args:
        project_id (str): id of the project the import belongs to
        name (str): name of the BulkImportRequest
        content_length (int): byte length of the ndjson payload being uploaded
        file_name (str): form-field name under which the file part is attached
    Returns:
        dict with the "operations" (GraphQL query + variables) and "map"
        fields of a GraphQL multipart request.
    """
    query_str = """mutation createBulkImportRequestFromFilePyApi(
            $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
        createBulkImportRequest(data: {
            projectId: $projectId,
            name: $name,
            filePayload: {
                file: $file,
                contentLength: $contentLength
            }
        }) {
            %s
        }
    }
    """ % query.results_query_part(BulkImportRequest)
    variables = {
        "projectId": project_id,
        "name": name,
        "file": None,
        "contentLength": content_length
    }
    operations = json.dumps({"variables": variables, "query": query_str})

    # The file itself travels as a separate multipart part; the "map" entry
    # tells the server to splice that part into variables.file.
    return {
        "operations": operations,
        "map": (None, json.dumps({file_name: ["variables.file"]}))
    }

def _send_create_file_command(
        client, request_data: dict, file_name: str,
        file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict:

    response = client.execute(data=request_data, files={file_name: file_data})

    if not response.get("createBulkImportRequest", None):
        raise lb_exceptions.LabelboxError(
            "Failed to create BulkImportRequest, message: %s" %
            response.get("errors", None) or response.get("error", None))

    return response

[docs]class BulkImportRequest(DbObject): """Represents the import job when importing annotations. Attributes: name (str) state (Enum): FAILED, RUNNING, or FINISHED (Refers to the whole import job) input_file_url (str): URL to your web-hosted NDJSON file error_file_url (str): NDJSON that contains error messages for failed annotations status_file_url (str): NDJSON that contains status for each annotation created_at (datetime): UTC timestamp for date BulkImportRequest was created project (Relationship): `ToOne` relationship to Project created_by (Relationship): `ToOne` relationship to User """ name = Field.String("name") state = Field.Enum(BulkImportRequestState, "state") input_file_url = Field.String("input_file_url") error_file_url = Field.String("error_file_url") status_file_url = Field.String("status_file_url") created_at = Field.DateTime("created_at") project = Relationship.ToOne("Project") created_by = Relationship.ToOne("User", False, "created_by") @property def inputs(self) -> List[Dict[str, Any]]: """ Inputs for each individual annotation uploaded. This should match the ndjson annotations that you have uploaded. Returns: Uploaded ndjson. * This information will expire after 24 hours. """ return self._fetch_remote_ndjson(self.input_file_url) @property def errors(self) -> List[Dict[str, Any]]: """ Errors for each individual annotation uploaded. This is a subset of statuses Returns: List of dicts containing error messages. Empty list means there were no errors See `BulkImportRequest.statuses` for more details. * This information will expire after 24 hours. """ self.wait_until_done() return self._fetch_remote_ndjson(self.error_file_url) @property def statuses(self) -> List[Dict[str, Any]]: """ Status for each individual annotation uploaded. Returns: A status for each annotation if the upload is done running. See below table for more details .. 
list-table:: :widths: 15 150 :header-rows: 1 * - Field - Description * - uuid - Specifies the annotation for the status row. * - dataRow - JSON object containing the Labelbox data row ID for the annotation. * - status - Indicates SUCCESS or FAILURE. * - errors - An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info. * This information will expire after 24 hours. """ self.wait_until_done() return self._fetch_remote_ndjson(self.status_file_url) @functools.lru_cache() def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: """ Fetches the remote ndjson file and caches the results. Args: url (str): Can be any url pointing to an ndjson file. Returns: ndjson as a list of dicts. """ response = requests.get(url) response.raise_for_status() return parser.loads(response.text)
[docs] def refresh(self) -> None: """Synchronizes values of all fields with the database. """ query_str, params = query.get_single(BulkImportRequest, self.uid) res = self.client.execute(query_str, params) res = res[utils.camel_case(BulkImportRequest.type_name())] self._set_field_values(res)
[docs] def wait_until_done(self, sleep_time_seconds: int = 5) -> None: """Blocks import job until certain conditions are met. Blocks until the BulkImportRequest.state changes either to `BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`, periodically refreshing object's state. Args: sleep_time_seconds (str): a time to block between subsequent API calls """ while self.state == BulkImportRequestState.RUNNING:"Sleeping for {sleep_time_seconds} seconds...") time.sleep(sleep_time_seconds) self.__exponential_backoff_refresh()
@retry.Retry(predicate=retry.if_exception_type(lb_exceptions.ApiLimitError, lb_exceptions.TimeoutError, lb_exceptions.NetworkError)) def __exponential_backoff_refresh(self) -> None: self.refresh() @classmethod def from_name(cls, client, project_id: str, name: str) -> 'BulkImportRequest': """ Fetches existing BulkImportRequest. Args: client (Client): a Labelbox client project_id (str): BulkImportRequest's project id name (str): name of BulkImportRequest Returns: BulkImportRequest object """ query_str = """query getBulkImportRequestPyApi( $projectId: ID!, $name: String!) { bulkImportRequest(where: { projectId: $projectId, name: $name }) { %s } } """ % query.results_query_part(cls) params = {"projectId": project_id, "name": name} response = client.execute(query_str, params=params) return cls(client, response['bulkImportRequest']) @classmethod def create_from_url(cls, client, project_id: str, name: str, url: str, validate=True) -> 'BulkImportRequest': """ Creates a BulkImportRequest from a publicly accessible URL to an ndjson file with predictions. Args: client (Client): a Labelbox client project_id (str): id of project for which predictions will be imported name (str): name of BulkImportRequest url (str): publicly accessible URL pointing to ndjson file containing predictions validate (bool): a flag indicating if there should be a validation if `url` is valid ndjson Returns: BulkImportRequest object """ if validate: logger.warn( "Validation is turned on. The file will be downloaded locally and processed before uploading." ) res = requests.get(url) data = parser.loads(res.text) _validate_ndjson(data, client.get_project(project_id)) query_str = """mutation createBulkImportRequestPyApi( $projectId: ID!, $name: String!, $fileUrl: String!) 
{ createBulkImportRequest(data: { projectId: $projectId, name: $name, fileUrl: $fileUrl }) { %s } } """ % query.results_query_part(cls) params = {"projectId": project_id, "name": name, "fileUrl": url} bulk_import_request_response = client.execute(query_str, params=params) return cls(client, bulk_import_request_response["createBulkImportRequest"]) @classmethod def create_from_objects(cls, client, project_id: str, name: str, predictions: Union[Iterable[Dict], Iterable["Label"]], validate=True) -> 'BulkImportRequest': """ Creates a `BulkImportRequest` from an iterable of dictionaries. Conforms to JSON predictions format, e.g.: ``{ "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", "schemaId": "ckappz7d700gn0zbocmqkwd9i", "dataRow": { "id": "ck1s02fqxm8fi0757f0e6qtdc" }, "bbox": { "top": 48, "left": 58, "height": 865, "width": 1512 } }`` Args: client (Client): a Labelbox client project_id (str): id of project for which predictions will be imported name (str): name of BulkImportRequest predictions (Iterable[dict]): iterable of dictionaries representing predictions validate (bool): a flag indicating if there should be a validation if `predictions` is valid ndjson Returns: BulkImportRequest object """ if not isinstance(predictions, list): raise TypeError( f"annotations must be in a form of Iterable. 
Found {type(predictions)}" ) ndjson_predictions = serialize_labels(predictions) if validate: _validate_ndjson(ndjson_predictions, client.get_project(project_id)) data_str = parser.dumps(ndjson_predictions) if not data_str: raise ValueError('annotations cannot be empty') data = data_str.encode('utf-8') file_name = _make_file_name(project_id, name) request_data = _make_request_data(project_id, name, len(data_str), file_name) file_data = (file_name, data, NDJSON_MIME_TYPE) response_data = _send_create_file_command(client, request_data=request_data, file_name=file_name, file_data=file_data) return cls(client, response_data["createBulkImportRequest"]) @classmethod def create_from_local_file(cls, client, project_id: str, name: str, file: Path, validate_file=True) -> 'BulkImportRequest': """ Creates a BulkImportRequest from a local ndjson file with predictions. Args: client (Client): a Labelbox client project_id (str): id of project for which predictions will be imported name (str): name of BulkImportRequest file (Path): local ndjson file with predictions validate_file (bool): a flag indicating if there should be a validation if `file` is a valid ndjson file Returns: BulkImportRequest object """ file_name = _make_file_name(project_id, name) content_length = file.stat().st_size request_data = _make_request_data(project_id, name, content_length, file_name) with'rb') as f: if validate_file: reader = parser.reader(f) # ensure that the underlying json load call is valid # # by iterating through the file so we only store # each line in memory rather than the entire file try: _validate_ndjson(reader, client.get_project(project_id)) except ValueError: raise ValueError(f"{file} is not a valid ndjson file") else: file_data = (, f, NDJSON_MIME_TYPE) response_data = _send_create_file_command(client, request_data, file_name, file_data) return cls(client, response_data["createBulkImportRequest"])
[docs] def delete(self) -> None: """ Deletes the import job and also any annotations created by this import. Returns: None """ id_param = "bulk_request_id" query_str = """mutation deleteBulkImportRequestPyApi($%s: ID!) { deleteBulkImportRequest(where: {id: $%s}) { id name } }""" % (id_param, id_param) self.client.execute(query_str, {id_param: self.uid})
def _validate_ndjson(lines: Iterable[Dict[str, Any]],
                     project: "Project") -> None:
    """ Client side validation of an ndjson object.

    Does not guarantee that an upload will succeed for the following reasons:
        * We are not checking the data row types which will cause the following errors to slip through
            * Missing frame indices will not cause an error for videos
        * Uploaded annotations for the wrong data type will pass (Eg. entity on images)
        * We are not checking bounds of an asset (Eg. frame index, image height, text location)

    Args:
        lines (Iterable[Dict[str,Any]]): An iterable of ndjson lines
        project (Project): id of project for which predictions will be imported

    Raises:
        MALValidationError: Raise for invalid NDJson
        UuidError: Duplicate UUID in upload
    """
    feature_schemas_by_id, feature_schemas_by_name = get_mal_schemas(
        project.ontology())
    uids: Set[str] = set()
    for idx, line in enumerate(lines):
        try:
            annotation = NDAnnotation(**line)
            annotation.validate_instance(feature_schemas_by_id,
                                         feature_schemas_by_name)
            uuid = str(annotation.uuid)
            if uuid in uids:
                raise lb_exceptions.UuidError(
                    f'{uuid} already used in this import job, '
                    'must be unique for the project.')
            uids.add(uuid)
        except (pydantic_compat.ValidationError, ValueError, TypeError,
                KeyError) as e:
            # Chain the original validation failure so callers can inspect it.
            raise lb_exceptions.MALValidationError(
                f"Invalid NDJson on line {idx}") from e


# The rest of this file contains objects for MAL validation
def parse_classification(tool):
    """ Parses a classification from an ontology.

    Only radio, checklist, and text are supported for mal.

    Args:
        tool (dict)
    Returns:
        dict (or None for unsupported classification types)
    """
    kind = tool['type']
    if kind in ('radio', 'checklist'):
        schema_ids = [option['featureSchemaId'] for option in tool['options']]
        value_names = [option['value'] for option in tool['options']]
        return {
            'tool': kind,
            'featureSchemaId': tool['featureSchemaId'],
            'name': tool['name'],
            'options': schema_ids + value_names,
        }
    if kind == 'text':
        return {
            'tool': kind,
            'name': tool['name'],
            'featureSchemaId': tool['featureSchemaId'],
        }
def get_mal_schemas(ontology):
    """ Converts a project ontology to a dict for easier lookup during ndjson validation

    Args:
        ontology (Ontology)
    Returns:
        Dict, Dict : Useful for looking up a tool from a given feature schema id or name
    """
    by_schema_id = {}
    by_name = {}

    for tool in ontology.normalized['tools']:
        parsed = [
            parse_classification(child) for child in tool['classifications']
        ]
        children_by_schema_id = {c['featureSchemaId']: c for c in parsed}
        children_by_name = {c['name']: c for c in parsed}
        # Register the same tool description under both its schema id
        # and its name (two independent dicts, mirroring both lookups).
        by_schema_id[tool['featureSchemaId']] = {
            'tool': tool['tool'],
            'classificationsBySchemaId': children_by_schema_id,
            'classificationsByName': children_by_name,
            'name': tool['name']
        }
        by_name[tool['name']] = {
            'tool': tool['tool'],
            'classificationsBySchemaId': children_by_schema_id,
            'classificationsByName': children_by_name,
            'name': tool['name']
        }

    for classification in ontology.normalized['classifications']:
        by_schema_id[classification['featureSchemaId']] = parse_classification(
            classification)
        by_name[classification['name']] = parse_classification(classification)

    return by_schema_id, by_name
LabelboxID: str = pydantic_compat.Field(..., min_length=25, max_length=25) class Bbox(pydantic_compat.BaseModel): top: float left: float height: float width: float class Point(pydantic_compat.BaseModel): x: float y: float class FrameLocation(pydantic_compat.BaseModel): end: int start: int class VideoSupported(pydantic_compat.BaseModel): #Note that frames are only allowed as top level inferences for video frames: Optional[List[FrameLocation]] #Base class for a special kind of union. # Compatible with pydantic_compat. Improves error messages over a traditional union class SpecialUnion: def __new__(cls, **kwargs): return @classmethod def __get_validators__(cls): yield @classmethod def get_union_types(cls): if not issubclass(cls, SpecialUnion): raise TypeError("{} must be a subclass of SpecialUnion") union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")] if len(union_types) < 1: raise TypeError( "Class {cls} should inherit from a union of objects to build") if len(union_types) > 1: raise TypeError( f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}" ) return union_types[0].__args__[0].__args__ @classmethod def build(cls: Any, data: Union[dict, pydantic_compat.BaseModel]) -> "NDBase": """ Checks through all objects in the union to see which matches the input data. 
Args: data (Union[dict, pydantic_compat.BaseModel]) : The data for constructing one of the objects in the union raises: KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion pydantic_compat.ValidationError: Error while trying to construct a specific object in the union """ if isinstance(data, pydantic_compat.BaseModel): data = data.dict() top_level_fields = [] max_match = 0 matched = None for type_ in cls.get_union_types(): determinate_fields = type_.Config.determinants(type_) top_level_fields.append(determinate_fields) matches = sum([val in determinate_fields for val in data]) if matches == len(determinate_fields) and matches > max_match: max_match = matches matched = type_ if matched is not None: #These two have the exact same top level keys if matched in [NDRadio, NDText]: if isinstance(data['answer'], dict): matched = NDRadio elif isinstance(data['answer'], str): matched = NDText else: raise TypeError( f"Unexpected type for answer field. Found {data['answer']}. Expected a string or a dict" ) return matched(**data) else: raise KeyError( f"Invalid annotation. Must have one of the following keys : {top_level_fields}. Found {data}." 
) @classmethod def schema(cls): results = {'definitions': {}} for cl in cls.get_union_types(): schema = cl.schema() results['definitions'].update(schema.pop('definitions')) results[cl.__name__] = schema return results class DataRow(pydantic_compat.BaseModel): id: str class NDFeatureSchema(pydantic_compat.BaseModel): schemaId: Optional[str] = None name: Optional[str] = None @pydantic_compat.root_validator def must_set_one(cls, values): if values['schemaId'] is None and values['name'] is None: raise ValueError( "Must set either schemaId or name for all feature schemas") return values class NDBase(NDFeatureSchema): ontology_type: str uuid: UUID dataRow: DataRow def validate_feature_schemas(self, valid_feature_schemas_by_id, valid_feature_schemas_by_name): if if not in valid_feature_schemas_by_name: raise ValueError( f"Name {} is not valid for the provided project's ontology." ) if self.ontology_type != valid_feature_schemas_by_name[]['tool']: raise ValueError( f"Name {} does not map to the assigned tool {valid_feature_schemas_by_name[]['tool']}" ) if self.schemaId: if self.schemaId not in valid_feature_schemas_by_id: raise ValueError( f"Schema id {self.schemaId} is not valid for the provided project's ontology." 
) if self.ontology_type != valid_feature_schemas_by_id[ self.schemaId]['tool']: raise ValueError( f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas_by_id[self.schemaId]['tool']}" ) def validate_instance(self, valid_feature_schemas_by_id, valid_feature_schemas_by_name): self.validate_feature_schemas(valid_feature_schemas_by_id, valid_feature_schemas_by_name) class Config: #Users shouldn't to add extra data to the payload extra = 'forbid' @staticmethod def determinants(parent_cls) -> List[str]: #This is a hack for better error messages return [ k for k, v in parent_cls.__fields__.items() if 'determinant' in v.field_info.extra ] ###### Classifications ###### class NDText(NDBase): ontology_type: Literal["text"] = "text" answer: str = pydantic_compat.Field(determinant=True) #No feature schema to check class NDChecklist(VideoSupported, NDBase): ontology_type: Literal["checklist"] = "checklist" answers: List[NDFeatureSchema] = pydantic_compat.Field(determinant=True) @pydantic_compat.validator('answers', pre=True) def validate_answers(cls, value, field): #constr not working with mypy. if not len(value): raise ValueError("Checklist answers should not be empty") return value def validate_feature_schemas(self, valid_feature_schemas_by_id, valid_feature_schemas_by_name): #Test top level feature schema for this tool super(NDChecklist, self).validate_feature_schemas(valid_feature_schemas_by_id, valid_feature_schemas_by_name) #Test the feature schemas provided to the answer field if len(set([ or answer.schemaId for answer in self.answers ])) != len(self.answers): raise ValueError( f"Duplicated featureSchema found for checklist {self.uuid}") for answer in self.answers: options = valid_feature_schemas_by_name[ self. name]['options'] if else valid_feature_schemas_by_id[ self.schemaId]['options'] if not in options and answer.schemaId not in options: raise ValueError( f"Feature schema provided to {self.ontology_type} invalid. 
Expected on of {options}. Found {answer}" ) class NDRadio(VideoSupported, NDBase): ontology_type: Literal["radio"] = "radio" answer: NDFeatureSchema = pydantic_compat.Field(determinant=True) def validate_feature_schemas(self, valid_feature_schemas_by_id, valid_feature_schemas_by_name): super(NDRadio, self).validate_feature_schemas(valid_feature_schemas_by_id, valid_feature_schemas_by_name) options = valid_feature_schemas_by_name[]['options'] if else valid_feature_schemas_by_id[ self.schemaId]['options'] if not in options and self.answer.schemaId not in options: raise ValueError( f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found { or self.answer.schemaId}" ) #A union with custom construction logic to improve error messages class NDClassification( SpecialUnion, Type[Union[ # type: ignore NDText, NDRadio, NDChecklist]]): ... ###### Tools ###### class NDBaseTool(NDBase): classifications: List[NDClassification] = [] #This is indepdent of our problem def validate_feature_schemas(self, valid_feature_schemas_by_id, valid_feature_schemas_by_name): super(NDBaseTool, self).validate_feature_schemas(valid_feature_schemas_by_id, valid_feature_schemas_by_name) for classification in self.classifications: classification.validate_feature_schemas( valid_feature_schemas_by_name[]['classificationsBySchemaId'] if else valid_feature_schemas_by_id[self.schemaId] ['classificationsBySchemaId'], valid_feature_schemas_by_name[]['classificationsByName'] if else valid_feature_schemas_by_id[ self.schemaId]['classificationsByName']) @pydantic_compat.validator('classifications', pre=True) def validate_subclasses(cls, value, field): #Create uuid and datarow id so we don't have to define classification objects twice #This is caused by the fact that we require these ids for top level classifications but not for subclasses results = [] dummy_id = 'child'.center(25, '_') for row in value: results.append({ **row, 'dataRow': { 'id': dummy_id }, 'uuid': 
str(uuid4()) }) return results class NDPolygon(NDBaseTool): ontology_type: Literal["polygon"] = "polygon" polygon: List[Point] = pydantic_compat.Field(determinant=True) @pydantic_compat.validator('polygon') def is_geom_valid(cls, v): if len(v) < 3: raise ValueError( f"A polygon must have at least 3 points to be valid. Found {v}") return v class NDPolyline(NDBaseTool): ontology_type: Literal["line"] = "line" line: List[Point] = pydantic_compat.Field(determinant=True) @pydantic_compat.validator('line') def is_geom_valid(cls, v): if len(v) < 2: raise ValueError( f"A line must have at least 2 points to be valid. Found {v}") return v class NDRectangle(NDBaseTool): ontology_type: Literal["rectangle"] = "rectangle" bbox: Bbox = pydantic_compat.Field(determinant=True) #Could check if points are positive class NDPoint(NDBaseTool): ontology_type: Literal["point"] = "point" point: Point = pydantic_compat.Field(determinant=True) #Could check if points are positive class EntityLocation(pydantic_compat.BaseModel): start: int end: int class NDTextEntity(NDBaseTool): ontology_type: Literal["named-entity"] = "named-entity" location: EntityLocation = pydantic_compat.Field(determinant=True) @pydantic_compat.validator('location') def is_valid_location(cls, v): if isinstance(v, pydantic_compat.BaseModel): v = v.dict() if len(v) < 2: raise ValueError( f"A line must have at least 2 points to be valid. Found {v}") if v['start'] < 0: raise ValueError(f"Text location must be positive. Found {v}") if v['start'] > v['end']: raise ValueError( f"Text start location must be less or equal than end. Found {v}" ) return v class RLEMaskFeatures(pydantic_compat.BaseModel): counts: List[int] size: List[int] @pydantic_compat.validator('counts') def validate_counts(cls, counts): if not all([count >= 0 for count in counts]): raise ValueError( "Found negative value for counts. 
They should all be zero or positive" ) return counts @pydantic_compat.validator('size') def validate_size(cls, size): if len(size) != 2: raise ValueError( f"Mask `size` should have two ints representing height and with. Found : {size}" ) if not all([count > 0 for count in size]): raise ValueError( f"Mask `size` should be a postitive int. Found : {size}") return size class PNGMaskFeatures(pydantic_compat.BaseModel): # base64 encoded png bytes png: str class URIMaskFeatures(pydantic_compat.BaseModel): instanceURI: str colorRGB: Union[List[int], Tuple[int, int, int]] @pydantic_compat.validator('colorRGB') def validate_color(cls, colorRGB): #Does the dtype matter? Can it be a float? if not isinstance(colorRGB, (tuple, list)): raise ValueError( f"Received color that is not a list or tuple. Found : {colorRGB}" ) elif len(colorRGB) != 3: raise ValueError( f"Must provide RGB values for segmentation colors. Found : {colorRGB}" ) elif not all([0 <= color <= 255 for color in colorRGB]): raise ValueError( f"All rgb colors must be between 0 and 255. Found : {colorRGB}") return colorRGB class NDMask(NDBaseTool): ontology_type: Literal["superpixel"] = "superpixel" mask: Union[URIMaskFeatures, PNGMaskFeatures, RLEMaskFeatures] = pydantic_compat.Field(determinant=True) #A union with custom construction logic to improve error messages class NDTool( SpecialUnion, Type[Union[ # type: ignore NDMask, NDTextEntity, NDPoint, NDRectangle, NDPolyline, NDPolygon, ]]): ... 
class NDAnnotation(
        SpecialUnion,
        Type[Union[  # type: ignore
            NDTool, NDClassification]]):
    """Top-level annotation union: tries each member type until one accepts the data."""

    @classmethod
    def build(cls: Any, data) -> "NDBase":
        if not isinstance(data, dict):
            raise ValueError('value must be dict')
        failures = []
        for member in cls.get_union_types():
            try:
                return member(**data)
            except KeyError as exc:
                # Collect each member's rejection so the final error is informative.
                failures.append(f"{member.__name__}: {exc}")
        raise ValueError('Unable to construct any annotation.\n{}'.format(
            "\n".join(failures)))

    @classmethod
    def schema(cls):
        merged = {'definitions': {}}
        for member in cls.get_union_types():
            member_schema = member.schema()
            merged['definitions'].update(member_schema.pop('definitions'))
            merged[member.__name__] = member_schema
        return merged