Source code for lbox.request_client

# for the Labelbox Python SDK
import json
import logging
import os
from datetime import datetime, timezone
from types import MappingProxyType
from typing import Callable, Dict, Optional

import requests
import requests.exceptions
from google.api_core import retry
from lbox import exceptions
from lbox.call_info import call_info_as_str, python_version_info  # type: ignore

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Name of the environment variable that RequestClient falls back to when no
# `api_key` argument is supplied.
_LABELBOX_API_KEY = "LABELBOX_API_KEY"


class RequestClient:
    """A Labelbox request client.

    Contains info necessary for connecting to a Labelbox server (URL,
    authentication key) and sends requests to it through a single
    ``requests.Session`` that carries the default headers.
    """

    def __init__(
        self,
        sdk_version,
        api_key=None,
        endpoint="https://api.labelbox.com/graphql",
        enable_experimental=False,
        app_url="https://app.labelbox.com",
        rest_endpoint="https://api.labelbox.com/api/v1",
    ):
        """Creates and initializes a RequestClient.

        This class executes graphql and rest requests to the Labelbox server.

        Args:
            sdk_version (str): SDK version, sent to the server in the
                "X-User-Agent" header.
            api_key (str): API key. If None, the key is obtained from
                the "LABELBOX_API_KEY" environment variable.
            endpoint (str): URL of the Labelbox server to connect to.
            enable_experimental (bool): Indicates whether or not to use
                experimental features
            app_url (str) : host url for all links to the web app
            rest_endpoint (str): URL of the REST API; stored on the client
                as `rest_endpoint`.

        Raises:
            exceptions.AuthenticationError: If no `api_key`
                is provided as an argument or via the environment
                variable.
        """
        if api_key is None:
            api_key = os.environ.get(_LABELBOX_API_KEY)
            if api_key is None:
                raise exceptions.AuthenticationError(
                    "Labelbox API key not provided"
                )
        self.api_key = api_key

        self.enable_experimental = enable_experimental
        if enable_experimental:
            logger.info("Experimental features have been enabled")

        logger.info("Initializing Labelbox client at '%s'", endpoint)
        self.app_url = app_url
        self.endpoint = endpoint
        self.rest_endpoint = rest_endpoint
        self.sdk_version = sdk_version
        self._sdk_method = None
        self._connection: requests.Session = self._init_connection()

    def _init_connection(self) -> requests.Session:
        """Create the session used for all requests, with default headers."""
        connection = requests.Session()  # using default connection pool size of 10
        connection.headers.update(self._default_headers())
        return connection

    @property
    def headers(self) -> MappingProxyType:
        """Headers currently set on the underlying session.

        NOTE(review): this returns the session's own headers object rather
        than a read-only proxy, despite the annotation.
        """
        return self._connection.headers

    @property
    def sdk_method(self):
        """Value sent as the "X-SDK-Method" header; None means auto-detect."""
        return self._sdk_method

    @sdk_method.setter
    def sdk_method(self, value):
        self._sdk_method = value

    def _default_headers(self):
        """Build the headers attached to every request of this client."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json",
            "X-User-Agent": f"python-sdk {self.sdk_version}",
            "X-Python-Version": f"{python_version_info()}",
        }

    @staticmethod
    def _error_value(error, *path):
        """Return the value nested at `path` inside `error`, or None.

        Tolerates missing keys and non-dict intermediate values (for example
        an error whose "extensions" entry is absent or null), so malformed
        error payloads never raise from the lookup itself.
        """
        node = error
        for key in path:
            if not isinstance(node, dict):
                return None
            node = node.get(key)
        return node

    @retry.Retry(
        predicate=retry.if_exception_type(
            exceptions.InternalServerError,
            exceptions.TimeoutError,
        )
    )
    def execute(
        self,
        query=None,
        params=None,
        data=None,
        files=None,
        timeout=60.0,
        experimental=False,
        error_log_key="message",
        raise_return_resource_not_found=False,
        error_handlers: Optional[
            Dict[str, Callable[[requests.models.Response], None]]
        ] = None,
    ):
        """Sends a request to the server for the execution of the
        given query.

        Checks the response for errors and wraps errors
        in appropriate `exceptions.LabelboxError` subtypes.

        Args:
            query (str): The query to execute.
            params (dict): Query parameters referenced within the query.
                `datetime` values are converted to UTC strings.
            data (str): json string containing the query to execute.
                Ignored when `query` is given.
            files (dict): file arguments for request
            timeout (float): Max allowed time for query execution,
                in seconds.
            experimental (bool): If True, the request is sent to the
                endpoint with "/graphql" replaced by "/_gql".
            error_log_key (str): Key of a MALFORMED_REQUEST error whose
                value is used as the exception message.
            raise_return_resource_not_found: By default the client relies on
                the caller to raise the correct exception when a resource is
                not found. If this is set to True, the client will raise a
                ResourceNotFoundError exception automatically. This
                simplifies processing. We recommend using it only if the API
                returns a clear and well-formed error when a resource is not
                found for a given query.
            error_handlers (dict): A dictionary mapping graphql error code to
                handler functions. Allows a caller to handle specific errors
                reporting in a custom way or produce more user-friendly
                readable messages. Handlers are consulted for the
                MALFORMED_REQUEST and INTERNAL_SERVER_ERROR codes.

        Example - custom error handler:
            >>> def _raise_readable_errors(self, response):
            >>>     errors = response.json().get('errors', [])
            >>>     if errors:
            >>>         message = errors[0].get(
            >>>             'message', json.dumps([{
            >>>                 "errorMessage": "Unknown error"
            >>>             }]))
            >>>         errors = json.loads(message)
            >>>         error_messages = [error['errorMessage'] for error in errors]
            >>>     else:
            >>>         error_messages = ["Unknown error"]
            >>>     raise LabelboxError(". ".join(error_messages))

        Returns:
            dict, the "data" member of the parsed JSON response. None is
            returned when a RESOURCE_NOT_FOUND error is reported and
            `raise_return_resource_not_found` is False, or after a custom
            error handler has run without raising.

        Raises:
            exceptions.AuthenticationError: If authentication failed.
            exceptions.AuthorizationError: If the server reports an
                authorization error.
            exceptions.InvalidQueryError: If `query` is not
                syntactically or semantically valid (checked server-side).
            exceptions.ValidationFailedError: If the query complexity limit
                was exceeded.
            exceptions.ApiLimitError: If the server API limit was
                exceeded. See "How to import data" in the online documentation
                to see API limits.
            exceptions.ResourceNotFoundError: If a resource was not found and
                `raise_return_resource_not_found` is True.
            exceptions.ResourceConflict: If the server reports a resource
                conflict.
            exceptions.MalformedQueryException: If the server reports a
                malformed request and no custom handler is registered.
            exceptions.UnprocessableEntityError: If the server reports an
                internal error with status 422.
            exceptions.OperationNotAllowedException: If the operation is not
                allowed.
            exceptions.InternalServerError: If the server answered with a
                5xx status (retried by the decorator).
            exceptions.TimeoutError: If response was not received
                in `timeout` seconds (retried by the decorator).
            exceptions.NetworkError: If an unknown error occurred
                most likely due to connection issues.
            exceptions.LabelboxError: If an unknown error of any
                kind occurred.
            ValueError: If query and data are both None.
        """
        logger.debug("Query: %s, params: %r, data %r", query, params, data)

        # Convert datetimes to UTC strings.
        def convert_value(value):
            if isinstance(value, datetime):
                value = value.astimezone(timezone.utc)
                value = value.strftime("%Y-%m-%dT%H:%M:%SZ")
            return value

        if query is not None:
            if params is not None:
                params = {
                    key: convert_value(value) for key, value in params.items()
                }
            data = json.dumps({"query": query, "variables": params}).encode(
                "utf-8"
            )
        elif data is None:
            raise ValueError("query and data cannot both be none")

        endpoint = (
            self.endpoint.replace("/graphql", "/_gql")
            if experimental
            else self.endpoint
        )

        try:
            headers = self._connection.headers.copy()
            if files:
                # Drop the JSON content headers for file uploads. `pop` with
                # a default tolerates headers already removed from the
                # session.
                headers.pop("Content-Type", None)
                headers.pop("Accept", None)
            headers["X-SDK-Method"] = self.sdk_method or call_info_as_str()

            request = requests.Request(
                "POST",
                endpoint,
                headers=headers,
                data=data,
                files=files if files else None,
            )

            prepped: requests.PreparedRequest = request.prepare()

            # Pick up proxy / TLS settings from the environment, as
            # Session.request() would have done.
            settings = self._connection.merge_environment_settings(
                prepped.url, {}, None, None, None
            )
            response = self._connection.send(
                prepped, timeout=timeout, **settings
            )
            logger.debug("Response: %s", response.text)
        except requests.exceptions.Timeout as e:
            raise exceptions.TimeoutError(str(e)) from e
        except requests.exceptions.RequestException as e:
            logger.error("Unknown error: %s", str(e))
            raise exceptions.NetworkError(e) from e
        except Exception as e:
            raise exceptions.LabelboxError(
                "Unknown error during Client.query(): " + str(e), e
            ) from e

        # 5xx responses are never parsed as JSON; they are reported as
        # InternalServerError, which the retry decorator retries.
        if 500 <= response.status_code < 600:
            if (
                "upstream connect error or disconnect/reset before headers"
                in response.text
            ):
                raise exceptions.InternalServerError("Connection reset")
            if response.status_code == 502:
                raise exceptions.InternalServerError("502 Bad Gateway")
            raise exceptions.InternalServerError(
                f"Internal server http error {response.status_code}"
            )

        try:
            r_json = response.json()
        except Exception as e:
            raise exceptions.LabelboxError(
                "Failed to parse response as JSON: %s" % response.text
            ) from e

        # `or []` also covers a response carrying `"errors": null`.
        errors = r_json.get("errors") or []

        def check_errors(keywords, *path):
            """Helper that looks for any of the given `keywords` in any of
            current errors on paths (like error[path][component][to][keyword]).
            """
            for error in errors:
                if self._error_value(error, *path) in keywords:
                    return error
            return None

        def get_error_status_code(error: dict) -> int:
            """Status nested in the error's exception info; 500 if absent."""
            status = self._error_value(
                error, "extensions", "exception", "status"
            )
            try:
                return int(status)
            except (TypeError, ValueError, OverflowError):
                return 500

        if (
            check_errors(["AUTHENTICATION_ERROR"], "extensions", "code")
            is not None
        ):
            raise exceptions.AuthenticationError("Invalid API key")

        authorization_error = check_errors(
            ["AUTHORIZATION_ERROR"], "extensions", "code"
        )
        if authorization_error is not None:
            raise exceptions.AuthorizationError(authorization_error["message"])

        validation_error = check_errors(
            ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code"
        )
        if validation_error is not None:
            message = validation_error["message"]
            if message == "Query complexity limit exceeded":
                raise exceptions.ValidationFailedError(message)
            raise exceptions.InvalidQueryError(message)

        graphql_error = check_errors(
            ["GRAPHQL_PARSE_FAILED"], "extensions", "code"
        )
        if graphql_error is not None:
            raise exceptions.InvalidQueryError(graphql_error["message"])

        # Check if API limit was exceeded
        response_msg = r_json.get("message", "")
        if isinstance(response_msg, str) and response_msg.startswith(
            "You have exceeded"
        ):
            raise exceptions.ApiLimitError(response_msg)

        resource_not_found_error = check_errors(
            ["RESOURCE_NOT_FOUND"], "extensions", "code"
        )
        if resource_not_found_error is not None:
            if raise_return_resource_not_found:
                raise exceptions.ResourceNotFoundError(
                    message=resource_not_found_error["message"]
                )
            # Return None and let the caller methods raise an exception
            # as they already know which resource type and ID was requested
            return None

        resource_conflict_error = check_errors(
            ["RESOURCE_CONFLICT"], "extensions", "code"
        )
        if resource_conflict_error is not None:
            raise exceptions.ResourceConflict(
                resource_conflict_error["message"]
            )

        malformed_request_error = check_errors(
            ["MALFORMED_REQUEST"], "extensions", "code"
        )
        if malformed_request_error is not None:
            if error_handlers and "MALFORMED_REQUEST" in error_handlers:
                error_handlers["MALFORMED_REQUEST"](response)
                return None
            raise exceptions.MalformedQueryException(
                malformed_request_error[error_log_key]
            )

        # A lot of different error situations are now labeled serverside
        # as INTERNAL_SERVER_ERROR, when they are actually client errors.
        # TODO: fix this in the server API
        internal_server_error = check_errors(
            ["INTERNAL_SERVER_ERROR"], "extensions", "code"
        )
        if internal_server_error is not None:
            if error_handlers and "INTERNAL_SERVER_ERROR" in error_handlers:
                error_handlers["INTERNAL_SERVER_ERROR"](response)
                return None
            message = internal_server_error.get("message")
            error_status_code = get_error_status_code(internal_server_error)
            if error_status_code == 400:
                raise exceptions.InvalidQueryError(message)
            elif error_status_code == 422:
                raise exceptions.UnprocessableEntityError(message)
            elif error_status_code == 426:
                raise exceptions.OperationNotAllowedException(message)
            elif error_status_code == 500:
                raise exceptions.LabelboxError(message)
            else:
                raise exceptions.InternalServerError(message)

        not_allowed_error = check_errors(
            ["OPERATION_NOT_ALLOWED"], "extensions", "code"
        )
        if not_allowed_error is not None:
            message = not_allowed_error.get("message")
            raise exceptions.OperationNotAllowedException(message)

        if errors:
            logger.warning("Unparsed errors on query execution: %r", errors)
            # Tolerant lookups: an unrecognised error may lack "message" or
            # "extensions.code", and must still surface as a LabelboxError
            # rather than a KeyError.
            messages = [
                {
                    "message": self._error_value(error, "message"),
                    "code": self._error_value(error, "extensions", "code"),
                }
                for error in errors
            ]
            raise exceptions.LabelboxError("Unknown error: %s" % str(messages))

        # if we do return a proper error code, and didn't catch this above
        # reraise
        # this mainly catches a 401 for API access disabled for free tier
        # TODO: need to unify API errors to handle things more uniformly
        # in the SDK
        if response.status_code != requests.codes.ok:
            message = f"{response.status_code} {response.reason}"
            cause = r_json.get("message")
            raise exceptions.LabelboxError(message, cause)

        return r_json["data"]