aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/unstructured_client/utils/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/unstructured_client/utils/utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/unstructured_client/utils/utils.py1116
1 files changed, 1116 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/unstructured_client/utils/utils.py b/.venv/lib/python3.12/site-packages/unstructured_client/utils/utils.py
new file mode 100644
index 00000000..f21a65d9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/unstructured_client/utils/utils.py
@@ -0,0 +1,1116 @@
+"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
+
+import base64
+import json
+import re
+import sys
+from dataclasses import Field, fields, is_dataclass, make_dataclass
+from datetime import date, datetime
+from decimal import Decimal
+from email.message import Message
+from enum import Enum
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ get_args,
+ get_origin,
+)
+from xmlrpc.client import boolean
+from typing_inspect import is_optional_type
+import dateutil.parser
+from dataclasses_json import DataClassJsonMixin
+
+
+def get_security(security: Any) -> Tuple[Dict[str, str], Dict[str, str]]:
+ headers: Dict[str, str] = {}
+ query_params: Dict[str, str] = {}
+
+ if security is None:
+ return headers, query_params
+
+ sec_fields: Tuple[Field, ...] = fields(security)
+ for sec_field in sec_fields:
+ value = getattr(security, sec_field.name)
+ if value is None:
+ continue
+
+ metadata = sec_field.metadata.get("security")
+ if metadata is None:
+ continue
+ if metadata.get("option"):
+ _parse_security_option(headers, query_params, value)
+ return headers, query_params
+ if metadata.get("scheme"):
+ # Special case for basic auth which could be a flattened struct
+ if metadata.get("sub_type") == "basic" and not is_dataclass(value):
+ _parse_security_scheme(headers, query_params, metadata, security)
+ else:
+ _parse_security_scheme(headers, query_params, metadata, value)
+
+ return headers, query_params
+
+
+def _parse_security_option(
+ headers: Dict[str, str], query_params: Dict[str, str], option: Any
+):
+ opt_fields: Tuple[Field, ...] = fields(option)
+ for opt_field in opt_fields:
+ metadata = opt_field.metadata.get("security")
+ if metadata is None or metadata.get("scheme") is None:
+ continue
+ _parse_security_scheme(
+ headers, query_params, metadata, getattr(option, opt_field.name)
+ )
+
+
+def _parse_security_scheme(
+ headers: Dict[str, str],
+ query_params: Dict[str, str],
+ scheme_metadata: Dict,
+ scheme: Any,
+):
+ scheme_type = scheme_metadata.get("type")
+ sub_type = scheme_metadata.get("sub_type")
+
+ if is_dataclass(scheme):
+ if scheme_type == "http" and sub_type == "basic":
+ _parse_basic_auth_scheme(headers, scheme)
+ return
+
+ scheme_fields: Tuple[Field, ...] = fields(scheme)
+ for scheme_field in scheme_fields:
+ metadata = scheme_field.metadata.get("security")
+ if metadata is None or metadata.get("field_name") is None:
+ continue
+
+ value = getattr(scheme, scheme_field.name)
+
+ _parse_security_scheme_value(
+ headers, query_params, scheme_metadata, metadata, value
+ )
+ else:
+ _parse_security_scheme_value(
+ headers, query_params, scheme_metadata, scheme_metadata, scheme
+ )
+
+
+def _parse_security_scheme_value(
+ headers: Dict[str, str],
+ query_params: Dict[str, str],
+ scheme_metadata: Dict,
+ security_metadata: Dict,
+ value: Any,
+):
+ scheme_type = scheme_metadata.get("type")
+ sub_type = scheme_metadata.get("sub_type")
+
+ header_name = str(security_metadata.get("field_name"))
+
+ if scheme_type == "apiKey":
+ if sub_type == "header":
+ headers[header_name] = value
+ elif sub_type == "query":
+ query_params[header_name] = value
+ else:
+ raise Exception("not supported")
+ elif scheme_type == "openIdConnect":
+ headers[header_name] = _apply_bearer(value)
+ elif scheme_type == "oauth2":
+ if sub_type != "client_credentials":
+ headers[header_name] = _apply_bearer(value)
+ elif scheme_type == "http":
+ if sub_type == "bearer":
+ headers[header_name] = _apply_bearer(value)
+ else:
+ raise Exception("not supported")
+ else:
+ raise Exception("not supported")
+
+
+def _apply_bearer(token: str) -> str:
+ return token.lower().startswith("bearer ") and token or f"Bearer {token}"
+
+
+def _parse_basic_auth_scheme(headers: Dict[str, str], scheme: Any):
+ username = ""
+ password = ""
+
+ scheme_fields: Tuple[Field, ...] = fields(scheme)
+ for scheme_field in scheme_fields:
+ metadata = scheme_field.metadata.get("security")
+ if metadata is None or metadata.get("field_name") is None:
+ continue
+
+ field_name = metadata.get("field_name")
+ value = getattr(scheme, scheme_field.name)
+
+ if field_name == "username":
+ username = value
+ if field_name == "password":
+ password = value
+
+ data = f"{username}:{password}".encode()
+ headers["Authorization"] = f"Basic {base64.b64encode(data).decode()}"
+
+
+def generate_url(
+ server_url: str,
+ path: str,
+ path_params: Any,
+ gbls: Optional[Any] = None,
+) -> str:
+ path_param_values: Dict[str, str] = {}
+
+ globals_already_populated = _populate_path_params(
+ path_params, gbls, path_param_values, []
+ )
+ if gbls is not None:
+ _populate_path_params(gbls, None, path_param_values, globals_already_populated)
+
+ for key, value in path_param_values.items():
+ path = path.replace("{" + key + "}", value, 1)
+
+ return remove_suffix(server_url, "/") + path
+
+
+def _populate_path_params(
+ path_params: Any,
+ gbls: Any,
+ path_param_values: Dict[str, str],
+ skip_fields: List[str],
+) -> List[str]:
+ globals_already_populated: List[str] = []
+
+ path_param_fields: Tuple[Field, ...] = fields(path_params)
+ for field in path_param_fields:
+ if field.name in skip_fields:
+ continue
+
+ param_metadata = field.metadata.get("path_param")
+ if param_metadata is None:
+ continue
+
+ param = getattr(path_params, field.name) if path_params is not None else None
+ param, global_found = _populate_from_globals(
+ field.name, param, "path_param", gbls
+ )
+ if global_found:
+ globals_already_populated.append(field.name)
+
+ if param is None:
+ continue
+
+ f_name = param_metadata.get("field_name", field.name)
+ serialization = param_metadata.get("serialization", "")
+ if serialization != "":
+ serialized_params = _get_serialized_params(
+ param_metadata, field.type, f_name, param
+ )
+ for key, value in serialized_params.items():
+ path_param_values[key] = value
+ else:
+ if param_metadata.get("style", "simple") == "simple":
+ if isinstance(param, List):
+ pp_vals: List[str] = []
+ for pp_val in param:
+ if pp_val is None:
+ continue
+ pp_vals.append(_val_to_string(pp_val))
+ path_param_values[param_metadata.get("field_name", field.name)] = (
+ ",".join(pp_vals)
+ )
+ elif isinstance(param, Dict):
+ pp_vals: List[str] = []
+ for pp_key in param:
+ if param[pp_key] is None:
+ continue
+ if param_metadata.get("explode"):
+ pp_vals.append(f"{pp_key}={_val_to_string(param[pp_key])}")
+ else:
+ pp_vals.append(f"{pp_key},{_val_to_string(param[pp_key])}")
+ path_param_values[param_metadata.get("field_name", field.name)] = (
+ ",".join(pp_vals)
+ )
+ elif not isinstance(param, (str, int, float, complex, bool, Decimal)):
+ pp_vals: List[str] = []
+ param_fields: Tuple[Field, ...] = fields(param)
+ for param_field in param_fields:
+ param_value_metadata = param_field.metadata.get("path_param")
+ if not param_value_metadata:
+ continue
+
+ param_name = param_value_metadata.get("field_name", field.name)
+
+ param_field_val = getattr(param, param_field.name)
+ if param_field_val is None:
+ continue
+ if param_metadata.get("explode"):
+ pp_vals.append(
+ f"{param_name}={_val_to_string(param_field_val)}"
+ )
+ else:
+ pp_vals.append(
+ f"{param_name},{_val_to_string(param_field_val)}"
+ )
+ path_param_values[param_metadata.get("field_name", field.name)] = (
+ ",".join(pp_vals)
+ )
+ else:
+ path_param_values[param_metadata.get("field_name", field.name)] = (
+ _val_to_string(param)
+ )
+
+ return globals_already_populated
+
+
+def is_optional(field):
+ return get_origin(field) is Union and type(None) in get_args(field)
+
+
+def template_url(url_with_params: str, params: Dict[str, str]) -> str:
+ for key, value in params.items():
+ url_with_params = url_with_params.replace("{" + key + "}", value)
+
+ return url_with_params
+
+
+def get_query_params(
+ query_params: Any,
+ gbls: Optional[Any] = None,
+) -> Dict[str, List[str]]:
+ params: Dict[str, List[str]] = {}
+
+ globals_already_populated = _populate_query_params(query_params, gbls, params, [])
+ if gbls is not None:
+ _populate_query_params(gbls, None, params, globals_already_populated)
+
+ return params
+
+
+def _populate_query_params(
+ query_params: Any,
+ gbls: Any,
+ query_param_values: Dict[str, List[str]],
+ skip_fields: List[str],
+) -> List[str]:
+ globals_already_populated: List[str] = []
+
+ param_fields: Tuple[Field, ...] = fields(query_params)
+ for field in param_fields:
+ if field.name in skip_fields:
+ continue
+
+ metadata = field.metadata.get("query_param")
+ if not metadata:
+ continue
+
+ param_name = field.name
+ value = getattr(query_params, param_name) if query_params is not None else None
+
+ value, global_found = _populate_from_globals(
+ param_name, value, "query_param", gbls
+ )
+ if global_found:
+ globals_already_populated.append(param_name)
+
+ f_name = metadata.get("field_name")
+ serialization = metadata.get("serialization", "")
+ if serialization != "":
+ serialized_parms = _get_serialized_params(
+ metadata, field.type, f_name, value
+ )
+ for key, value in serialized_parms.items():
+ if key in query_param_values:
+ query_param_values[key].extend(value)
+ else:
+ query_param_values[key] = [value]
+ else:
+ style = metadata.get("style", "form")
+ if style == "deepObject":
+ _populate_deep_object_query_params(
+ metadata, f_name, value, query_param_values
+ )
+ elif style == "form":
+ _populate_delimited_query_params(
+ metadata, f_name, value, ",", query_param_values
+ )
+ elif style == "pipeDelimited":
+ _populate_delimited_query_params(
+ metadata, f_name, value, "|", query_param_values
+ )
+ else:
+ raise Exception("not yet implemented")
+
+ return globals_already_populated
+
+
+def get_headers(headers_params: Any, gbls: Optional[Any] = None) -> Dict[str, str]:
+ headers: Dict[str, str] = {}
+
+ globals_already_populated = []
+ if headers_params is not None:
+ globals_already_populated = _populate_headers(headers_params, gbls, headers, [])
+ if gbls is not None:
+ _populate_headers(gbls, None, headers, globals_already_populated)
+
+ return headers
+
+
+def _populate_headers(
+ headers_params: Any,
+ gbls: Any,
+ header_values: Dict[str, str],
+ skip_fields: List[str],
+) -> List[str]:
+ globals_already_populated: List[str] = []
+
+ param_fields: Tuple[Field, ...] = fields(headers_params)
+ for field in param_fields:
+ if field.name in skip_fields:
+ continue
+
+ metadata = field.metadata.get("header")
+ if not metadata:
+ continue
+
+ value, global_found = _populate_from_globals(
+ field.name, getattr(headers_params, field.name), "header", gbls
+ )
+ if global_found:
+ globals_already_populated.append(field.name)
+ value = _serialize_header(metadata.get("explode", False), value)
+
+ if value != "":
+ header_values[metadata.get("field_name", field.name)] = value
+
+ return globals_already_populated
+
+
+def _get_serialized_params(
+ metadata: Dict, field_type: type, field_name: str, obj: Any
+) -> Dict[str, str]:
+ params: Dict[str, str] = {}
+
+ serialization = metadata.get("serialization", "")
+ if serialization == "json":
+ params[metadata.get("field_name", field_name)] = marshal_json(obj, field_type)
+
+ return params
+
+
+def _populate_deep_object_query_params(
+ metadata: Dict, field_name: str, obj: Any, params: Dict[str, List[str]]
+):
+ if obj is None:
+ return
+
+ if is_dataclass(obj):
+ _populate_deep_object_query_params_dataclass(metadata.get("field_name", field_name), obj, params)
+ elif isinstance(obj, Dict):
+ _populate_deep_object_query_params_dict(metadata.get("field_name", field_name), obj, params)
+
+
+def _populate_deep_object_query_params_dataclass(
+ prior_params_key: str, obj: Any, params: Dict[str, List[str]]
+):
+ if obj is None:
+ return
+
+ if not is_dataclass(obj):
+ return
+
+ obj_fields: Tuple[Field, ...] = fields(obj)
+ for obj_field in obj_fields:
+ obj_param_metadata = obj_field.metadata.get("query_param")
+ if not obj_param_metadata:
+ continue
+
+ obj_val = getattr(obj, obj_field.name)
+ if obj_val is None:
+ continue
+
+ params_key = f'{prior_params_key}[{obj_param_metadata.get("field_name", obj_field.name)}]'
+
+ if is_dataclass(obj_val):
+ _populate_deep_object_query_params_dataclass(params_key, obj_val, params)
+ elif isinstance(obj_val, Dict):
+ _populate_deep_object_query_params_dict(params_key, obj_val, params)
+ elif isinstance(obj_val, List):
+ _populate_deep_object_query_params_list(params_key, obj_val, params)
+ else:
+ params[params_key] = [_val_to_string(obj_val)]
+
+
+def _populate_deep_object_query_params_dict(
+ prior_params_key: str, value: Dict, params: Dict[str, List[str]]
+):
+ if value is None:
+ return
+
+ for key, val in value.items():
+ if val is None:
+ continue
+
+ params_key = f'{prior_params_key}[{key}]'
+
+ if is_dataclass(val):
+ _populate_deep_object_query_params_dataclass(params_key, val, params)
+ elif isinstance(val, Dict):
+ _populate_deep_object_query_params_dict(params_key, val, params)
+ elif isinstance(val, List):
+ _populate_deep_object_query_params_list(params_key, val, params)
+ else:
+ params[params_key] = [_val_to_string(val)]
+
+
+def _populate_deep_object_query_params_list(
+ params_key: str, value: List, params: Dict[str, List[str]]
+):
+ if value is None:
+ return
+
+ for val in value:
+ if val is None:
+ continue
+
+ if params.get(params_key) is None:
+ params[params_key] = []
+
+ params[params_key].append(_val_to_string(val))
+
+
+def _get_query_param_field_name(obj_field: Field) -> str:
+ obj_param_metadata = obj_field.metadata.get("query_param")
+
+ if not obj_param_metadata:
+ return ""
+
+ return obj_param_metadata.get("field_name", obj_field.name)
+
+
+def _populate_delimited_query_params(
+ metadata: Dict,
+ field_name: str,
+ obj: Any,
+ delimiter: str,
+ query_param_values: Dict[str, List[str]],
+):
+ _populate_form(
+ field_name,
+ metadata.get("explode", True),
+ obj,
+ _get_query_param_field_name,
+ delimiter,
+ query_param_values,
+ )
+
+
+SERIALIZATION_METHOD_TO_CONTENT_TYPE = {
+ "json": "application/json",
+ "form": "application/x-www-form-urlencoded",
+ "multipart": "multipart/form-data",
+ "raw": "application/octet-stream",
+ "string": "text/plain",
+}
+
+
+def serialize_request_body(
+ request: Any,
+ request_type: type,
+ request_field_name: str,
+ nullable: bool,
+ optional: bool,
+ serialization_method: str,
+ encoder=None,
+) -> Tuple[Optional[str], Optional[Any], Optional[Any]]:
+ if request is None:
+ if not nullable and optional:
+ return None, None, None
+
+ if not is_dataclass(request) or not hasattr(request, request_field_name):
+ return serialize_content_type(
+ request_field_name,
+ request_type,
+ SERIALIZATION_METHOD_TO_CONTENT_TYPE[serialization_method],
+ request,
+ encoder,
+ )
+
+ request_val = getattr(request, request_field_name)
+
+ if request_val is None:
+ if not nullable and optional:
+ return None, None, None
+
+ request_fields: Tuple[Field, ...] = fields(request)
+ request_metadata = None
+
+ for field in request_fields:
+ if field.name == request_field_name:
+ request_metadata = field.metadata.get("request")
+ break
+
+ if request_metadata is None:
+ raise Exception("invalid request type")
+
+ return serialize_content_type(
+ request_field_name,
+ request_type,
+ request_metadata.get("media_type", "application/octet-stream"),
+ request_val,
+ )
+
+
+def serialize_content_type(
+ field_name: str, request_type: Any, media_type: str, request: Any, encoder=None
+) -> Tuple[Optional[str], Optional[Any], Optional[List[List[Any]]]]:
+ if re.match(r"(application|text)\/.*?\+*json.*", media_type) is not None:
+ return media_type, marshal_json(request, request_type, encoder), None
+ if re.match(r"multipart\/.*", media_type) is not None:
+ return serialize_multipart_form(media_type, request)
+ if re.match(r"application\/x-www-form-urlencoded.*", media_type) is not None:
+ return media_type, serialize_form_data(field_name, request), None
+ if isinstance(request, (bytes, bytearray)):
+ return media_type, request, None
+ if isinstance(request, str):
+ return media_type, request, None
+
+ raise Exception(
+ f"invalid request body type {type(request)} for mediaType {media_type}"
+ )
+
+
+def serialize_multipart_form(
+ media_type: str, request: Any
+) -> Tuple[str, Any, List[List[Any]]]:
+ form: List[List[Any]] = []
+ request_fields = fields(request)
+
+ for field in request_fields:
+ val = getattr(request, field.name)
+ if val is None:
+ continue
+
+ field_metadata = field.metadata.get("multipart_form")
+ if not field_metadata:
+ continue
+
+ if field_metadata.get("file") is True:
+ file_fields = fields(val)
+
+ file_name = ""
+ field_name = ""
+ content = bytes()
+
+ for file_field in file_fields:
+ file_metadata = file_field.metadata.get("multipart_form")
+ if file_metadata is None:
+ continue
+
+ if file_metadata.get("content") is True:
+ content = getattr(val, file_field.name)
+ else:
+ field_name = file_metadata.get("field_name", file_field.name)
+ file_name = getattr(val, file_field.name)
+ if field_name == "" or file_name == "" or content == bytes():
+ raise Exception("invalid multipart/form-data file")
+
+ form.append([field_name, [file_name, content]])
+ elif field_metadata.get("json") is True:
+ to_append = [
+ field_metadata.get("field_name", field.name),
+ [None, marshal_json(val, field.type), "application/json"],
+ ]
+ form.append(to_append)
+ else:
+ field_name = field_metadata.get("field_name", field.name)
+ if isinstance(val, List):
+ for value in val:
+ if value is None:
+ continue
+ form.append([field_name + "[]", [None, _val_to_string(value)]])
+ else:
+ form.append([field_name, [None, _val_to_string(val)]])
+ return media_type, None, form
+
+
+def serialize_dict(
+ original: Dict, explode: bool, field_name, existing: Optional[Dict[str, List[str]]]
+) -> Dict[str, List[str]]:
+ if existing is None:
+ existing = {}
+
+ if explode is True:
+ for key, val in original.items():
+ if key not in existing:
+ existing[key] = []
+ existing[key].append(val)
+ else:
+ temp = []
+ for key, val in original.items():
+ temp.append(str(key))
+ temp.append(str(val))
+ if field_name not in existing:
+ existing[field_name] = []
+ existing[field_name].append(",".join(temp))
+ return existing
+
+
+def serialize_form_data(field_name: str, data: Any) -> Dict[str, Any]:
+ form: Dict[str, List[str]] = {}
+
+ if is_dataclass(data):
+ for field in fields(data):
+ val = getattr(data, field.name)
+ if val is None:
+ continue
+
+ metadata = field.metadata.get("form")
+ if metadata is None:
+ continue
+
+ field_name = metadata.get("field_name", field.name)
+
+ if metadata.get("json"):
+ form[field_name] = [marshal_json(val, field.type)]
+ else:
+ if metadata.get("style", "form") == "form":
+ _populate_form(
+ field_name,
+ metadata.get("explode", True),
+ val,
+ _get_form_field_name,
+ ",",
+ form,
+ )
+ else:
+ raise Exception(f"Invalid form style for field {field.name}")
+ elif isinstance(data, Dict):
+ for key, value in data.items():
+ form[key] = [_val_to_string(value)]
+ else:
+ raise Exception(f"Invalid request body type for field {field_name}")
+
+ return form
+
+
+def _get_form_field_name(obj_field: Field) -> str:
+ obj_param_metadata = obj_field.metadata.get("form")
+
+ if not obj_param_metadata:
+ return ""
+
+ return obj_param_metadata.get("field_name", obj_field.name)
+
+
+def _populate_form(
+ field_name: str,
+ explode: boolean,
+ obj: Any,
+ get_field_name_func: Callable,
+ delimiter: str,
+ form: Dict[str, List[str]],
+):
+ if obj is None:
+ return form
+
+ if is_dataclass(obj):
+ items = []
+
+ obj_fields: Tuple[Field, ...] = fields(obj)
+ for obj_field in obj_fields:
+ obj_field_name = get_field_name_func(obj_field)
+ if obj_field_name == "":
+ continue
+
+ val = getattr(obj, obj_field.name)
+ if val is None:
+ continue
+
+ if explode:
+ form[obj_field_name] = [_val_to_string(val)]
+ else:
+ items.append(f"{obj_field_name}{delimiter}{_val_to_string(val)}")
+
+ if len(items) > 0:
+ form[field_name] = [delimiter.join(items)]
+ elif isinstance(obj, Dict):
+ items = []
+ for key, value in obj.items():
+ if value is None:
+ continue
+
+ if explode:
+ form[key] = [_val_to_string(value)]
+ else:
+ items.append(f"{key}{delimiter}{_val_to_string(value)}")
+
+ if len(items) > 0:
+ form[field_name] = [delimiter.join(items)]
+ elif isinstance(obj, List):
+ items = []
+
+ for value in obj:
+ if value is None:
+ continue
+
+ if explode:
+ if not field_name in form:
+ form[field_name] = []
+ form[field_name].append(_val_to_string(value))
+ else:
+ items.append(_val_to_string(value))
+
+ if len(items) > 0:
+ form[field_name] = [delimiter.join([str(item) for item in items])]
+ else:
+ form[field_name] = [_val_to_string(obj)]
+
+ return form
+
+
+def _serialize_header(explode: bool, obj: Any) -> str:
+ if obj is None:
+ return ""
+
+ if is_dataclass(obj):
+ items = []
+ obj_fields: Tuple[Field, ...] = fields(obj)
+ for obj_field in obj_fields:
+ obj_param_metadata = obj_field.metadata.get("header")
+
+ if not obj_param_metadata:
+ continue
+
+ obj_field_name = obj_param_metadata.get("field_name", obj_field.name)
+ if obj_field_name == "":
+ continue
+
+ val = getattr(obj, obj_field.name)
+ if val is None:
+ continue
+
+ if explode:
+ items.append(f"{obj_field_name}={_val_to_string(val)}")
+ else:
+ items.append(obj_field_name)
+ items.append(_val_to_string(val))
+
+ if len(items) > 0:
+ return ",".join(items)
+ elif isinstance(obj, Dict):
+ items = []
+
+ for key, value in obj.items():
+ if value is None:
+ continue
+
+ if explode:
+ items.append(f"{key}={_val_to_string(value)}")
+ else:
+ items.append(key)
+ items.append(_val_to_string(value))
+
+ if len(items) > 0:
+ return ",".join([str(item) for item in items])
+ elif isinstance(obj, List):
+ items = []
+
+ for value in obj:
+ if value is None:
+ continue
+
+ items.append(_val_to_string(value))
+
+ if len(items) > 0:
+ return ",".join(items)
+ else:
+ return f"{_val_to_string(obj)}"
+
+ return ""
+
+
+def unmarshal_json(data, typ, decoder=None, infer_missing=False):
+ unmarshal = make_dataclass("Unmarshal", [("res", typ)], bases=(DataClassJsonMixin,))
+ json_dict = json.loads(data)
+ try:
+ out = unmarshal.from_dict({"res": json_dict}, infer_missing=infer_missing)
+ except AttributeError as attr_err:
+ raise AttributeError(
+ f"unable to unmarshal {data} as {typ} - {attr_err}"
+ ) from attr_err
+
+ return out.res if decoder is None else decoder(out.res)
+
+
+def marshal_json(val, typ, encoder=None):
+ if not is_optional_type(typ) and val is None:
+ raise ValueError(f"Could not marshal None into non-optional type: {typ}")
+
+ marshal = make_dataclass("Marshal", [("res", typ)], bases=(DataClassJsonMixin,))
+ marshaller = marshal(res=val)
+ json_dict = marshaller.to_dict()
+ val = json_dict["res"] if encoder is None else encoder(json_dict["res"])
+
+ return json.dumps(val, separators=(",", ":"), sort_keys=True)
+
+
+def match_content_type(content_type: str, pattern: str) -> boolean:
+ if pattern in (content_type, "*", "*/*"):
+ return True
+
+ msg = Message()
+ msg["content-type"] = content_type
+ media_type = msg.get_content_type()
+
+ if media_type == pattern:
+ return True
+
+ parts = media_type.split("/")
+ if len(parts) == 2:
+ if pattern in (f"{parts[0]}/*", f"*/{parts[1]}"):
+ return True
+
+ return False
+
+
+def match_status_codes(status_codes: List[str], status_code: int) -> bool:
+ for code in status_codes:
+ if code == str(status_code):
+ return True
+
+ if code.endswith("XX") and code.startswith(str(status_code)[:1]):
+ return True
+ return False
+
+
+def datetimeisoformat(optional: bool):
+ def isoformatoptional(val):
+ if optional and val is None:
+ return None
+ return _val_to_string(val)
+
+ return isoformatoptional
+
+
+def dateisoformat(optional: bool):
+ def isoformatoptional(val):
+ if optional and val is None:
+ return None
+ return date.isoformat(val)
+
+ return isoformatoptional
+
+
+def datefromisoformat(date_str: str):
+ return dateutil.parser.parse(date_str).date()
+
+
+def bigintencoder(optional: bool):
+ def bigintencode(val: int):
+ if optional and val is None:
+ return None
+ return str(val)
+
+ return bigintencode
+
+
+def bigintdecoder(val):
+ if val is None:
+ return None
+
+ if isinstance(val, float):
+ raise ValueError(f"{val} is a float")
+ return int(val)
+
+def integerstrencoder(optional: bool):
+ def integerstrencode(val: int):
+ if optional and val is None:
+ return None
+ return str(val)
+
+ return integerstrencode
+
+
+def integerstrdecoder(val):
+ if val is None:
+ return None
+
+ if isinstance(val, float):
+ raise ValueError(f"{val} is a float")
+ return int(val)
+
+
+def numberstrencoder(optional: bool):
+ def numberstrencode(val: float):
+ if optional and val is None:
+ return None
+ return str(val)
+
+ return numberstrencode
+
+
+def numberstrdecoder(val):
+ if val is None:
+ return None
+
+ return float(val)
+
+
+def decimalencoder(optional: bool, as_str: bool):
+ def decimalencode(val: Decimal):
+ if optional and val is None:
+ return None
+
+ if as_str:
+ return str(val)
+
+ return float(val)
+
+ return decimalencode
+
+
+def decimaldecoder(val):
+ if val is None:
+ return None
+
+ return Decimal(str(val))
+
+
+def map_encoder(optional: bool, value_encoder: Callable):
+ def map_encode(val: Dict):
+ if optional and val is None:
+ return None
+
+ encoded = {}
+ for key, value in val.items():
+ encoded[key] = value_encoder(value)
+
+ return encoded
+
+ return map_encode
+
+
+def map_decoder(value_decoder: Callable):
+ def map_decode(val: Dict):
+ decoded = {}
+ for key, value in val.items():
+ decoded[key] = value_decoder(value)
+
+ return decoded
+
+ return map_decode
+
+
+def list_encoder(optional: bool, value_encoder: Callable):
+ def list_encode(val: List):
+ if optional and val is None:
+ return None
+
+ encoded = []
+ for value in val:
+ encoded.append(value_encoder(value))
+
+ return encoded
+
+ return list_encode
+
+
+def list_decoder(value_decoder: Callable):
+ def list_decode(val: List):
+ decoded = []
+ for value in val:
+ decoded.append(value_decoder(value))
+
+ return decoded
+
+ return list_decode
+
+
+def union_encoder(all_encoders: Dict[str, Callable]):
+ def selective_encoder(val: Any):
+ if type(val) in all_encoders:
+ return all_encoders[type(val)](val)
+ return val
+
+ return selective_encoder
+
+
+def union_decoder(all_decoders: List[Callable]):
+ def selective_decoder(val: Any):
+ decoded = val
+ for decoder in all_decoders:
+ try:
+ decoded = decoder(val)
+ break
+ except (TypeError, ValueError):
+ continue
+ return decoded
+
+ return selective_decoder
+
+
+def get_field_name(name):
+ def override(_, _field_name=name):
+ return _field_name
+
+ return override
+
+
+def _val_to_string(val) -> str:
+ if isinstance(val, bool):
+ return str(val).lower()
+ if isinstance(val, datetime):
+ return str(val.isoformat().replace("+00:00", "Z"))
+ if isinstance(val, Enum):
+ return str(val.value)
+
+ return str(val)
+
+
+def _populate_from_globals(
+ param_name: str, value: Any, param_type: str, gbls: Any
+) -> Tuple[Any, bool]:
+ if gbls is None:
+ return value, False
+
+ global_fields = fields(gbls)
+
+ found = False
+ for field in global_fields:
+ if field.name is not param_name:
+ continue
+
+ found = True
+
+ if value is not None:
+ return value, True
+
+ global_value = getattr(gbls, field.name)
+
+ param_metadata = field.metadata.get(param_type)
+ if param_metadata is None:
+ return value, True
+
+ return global_value, True
+
+ return value, found
+
+
+def decoder_with_discriminator(field_name):
+ def decode_fx(obj):
+ kls = getattr(sys.modules["sdk.models.shared"], obj[field_name])
+ return unmarshal_json(json.dumps(obj), kls)
+
+ return decode_fx
+
+
+def remove_suffix(input_string, suffix):
+ if suffix and input_string.endswith(suffix):
+ return input_string[: -len(suffix)]
+ return input_string