"""Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT.""" import base64 import json import re import sys from dataclasses import Field, fields, is_dataclass, make_dataclass from datetime import date, datetime from decimal import Decimal from email.message import Message from enum import Enum from typing import ( Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, ) from xmlrpc.client import boolean from typing_inspect import is_optional_type import dateutil.parser from dataclasses_json import DataClassJsonMixin def get_security(security: Any) -> Tuple[Dict[str, str], Dict[str, str]]: headers: Dict[str, str] = {} query_params: Dict[str, str] = {} if security is None: return headers, query_params sec_fields: Tuple[Field, ...] = fields(security) for sec_field in sec_fields: value = getattr(security, sec_field.name) if value is None: continue metadata = sec_field.metadata.get("security") if metadata is None: continue if metadata.get("option"): _parse_security_option(headers, query_params, value) return headers, query_params if metadata.get("scheme"): # Special case for basic auth which could be a flattened struct if metadata.get("sub_type") == "basic" and not is_dataclass(value): _parse_security_scheme(headers, query_params, metadata, security) else: _parse_security_scheme(headers, query_params, metadata, value) return headers, query_params def _parse_security_option( headers: Dict[str, str], query_params: Dict[str, str], option: Any ): opt_fields: Tuple[Field, ...] = fields(option) for opt_field in opt_fields: metadata = opt_field.metadata.get("security") if metadata is None or metadata.get("scheme") is None: continue _parse_security_scheme( headers, query_params, metadata, getattr(option, opt_field.name) ) def _parse_security_scheme( headers: Dict[str, str], query_params: Dict[str, str], scheme_metadata: Dict, scheme: Any, ): scheme_type = scheme_metadata.get("type") sub_type = scheme_metadata.get("sub_type") if is_dataclass(scheme): if scheme_type == "http" and sub_type == "basic": _parse_basic_auth_scheme(headers, scheme) return scheme_fields: Tuple[Field, ...] = fields(scheme) for scheme_field in scheme_fields: metadata = scheme_field.metadata.get("security") if metadata is None or metadata.get("field_name") is None: continue value = getattr(scheme, scheme_field.name) _parse_security_scheme_value( headers, query_params, scheme_metadata, metadata, value ) else: _parse_security_scheme_value( headers, query_params, scheme_metadata, scheme_metadata, scheme ) def _parse_security_scheme_value( headers: Dict[str, str], query_params: Dict[str, str], scheme_metadata: Dict, security_metadata: Dict, value: Any, ): scheme_type = scheme_metadata.get("type") sub_type = scheme_metadata.get("sub_type") header_name = str(security_metadata.get("field_name")) if scheme_type == "apiKey": if sub_type == "header": headers[header_name] = value elif sub_type == "query": query_params[header_name] = value else: raise Exception("not supported") elif scheme_type == "openIdConnect": headers[header_name] = _apply_bearer(value) elif scheme_type == "oauth2": if sub_type != "client_credentials": headers[header_name] = _apply_bearer(value) elif scheme_type == "http": if sub_type == "bearer": headers[header_name] = _apply_bearer(value) else: raise Exception("not supported") else: raise Exception("not supported") def _apply_bearer(token: str) -> str: return token.lower().startswith("bearer ") and token or f"Bearer {token}" def _parse_basic_auth_scheme(headers: Dict[str, str], scheme: Any): username = "" password = "" scheme_fields: Tuple[Field, ...] = fields(scheme) for scheme_field in scheme_fields: metadata = scheme_field.metadata.get("security") if metadata is None or metadata.get("field_name") is None: continue field_name = metadata.get("field_name") value = getattr(scheme, scheme_field.name) if field_name == "username": username = value if field_name == "password": password = value data = f"{username}:{password}".encode() headers["Authorization"] = f"Basic {base64.b64encode(data).decode()}" def generate_url( server_url: str, path: str, path_params: Any, gbls: Optional[Any] = None, ) -> str: path_param_values: Dict[str, str] = {} globals_already_populated = _populate_path_params( path_params, gbls, path_param_values, [] ) if gbls is not None: _populate_path_params(gbls, None, path_param_values, globals_already_populated) for key, value in path_param_values.items(): path = path.replace("{" + key + "}", value, 1) return remove_suffix(server_url, "/") + path def _populate_path_params( path_params: Any, gbls: Any, path_param_values: Dict[str, str], skip_fields: List[str], ) -> List[str]: globals_already_populated: List[str] = [] path_param_fields: Tuple[Field, ...] = fields(path_params) for field in path_param_fields: if field.name in skip_fields: continue param_metadata = field.metadata.get("path_param") if param_metadata is None: continue param = getattr(path_params, field.name) if path_params is not None else None param, global_found = _populate_from_globals( field.name, param, "path_param", gbls ) if global_found: globals_already_populated.append(field.name) if param is None: continue f_name = param_metadata.get("field_name", field.name) serialization = param_metadata.get("serialization", "") if serialization != "": serialized_params = _get_serialized_params( param_metadata, field.type, f_name, param ) for key, value in serialized_params.items(): path_param_values[key] = value else: if param_metadata.get("style", "simple") == "simple": if isinstance(param, List): pp_vals: List[str] = [] for pp_val in param: if pp_val is None: continue pp_vals.append(_val_to_string(pp_val)) path_param_values[param_metadata.get("field_name", field.name)] = ( ",".join(pp_vals) ) elif isinstance(param, Dict): pp_vals: List[str] = [] for pp_key in param: if param[pp_key] is None: continue if param_metadata.get("explode"): pp_vals.append(f"{pp_key}={_val_to_string(param[pp_key])}") else: pp_vals.append(f"{pp_key},{_val_to_string(param[pp_key])}") path_param_values[param_metadata.get("field_name", field.name)] = ( ",".join(pp_vals) ) elif not isinstance(param, (str, int, float, complex, bool, Decimal)): pp_vals: List[str] = [] param_fields: Tuple[Field, ...] = fields(param) for param_field in param_fields: param_value_metadata = param_field.metadata.get("path_param") if not param_value_metadata: continue param_name = param_value_metadata.get("field_name", field.name) param_field_val = getattr(param, param_field.name) if param_field_val is None: continue if param_metadata.get("explode"): pp_vals.append( f"{param_name}={_val_to_string(param_field_val)}" ) else: pp_vals.append( f"{param_name},{_val_to_string(param_field_val)}" ) path_param_values[param_metadata.get("field_name", field.name)] = ( ",".join(pp_vals) ) else: path_param_values[param_metadata.get("field_name", field.name)] = ( _val_to_string(param) ) return globals_already_populated def is_optional(field): return get_origin(field) is Union and type(None) in get_args(field) def template_url(url_with_params: str, params: Dict[str, str]) -> str: for key, value in params.items(): url_with_params = url_with_params.replace("{" + key + "}", value) return url_with_params def get_query_params( query_params: Any, gbls: Optional[Any] = None, ) -> Dict[str, List[str]]: params: Dict[str, List[str]] = {} globals_already_populated = _populate_query_params(query_params, gbls, params, []) if gbls is not None: _populate_query_params(gbls, None, params, globals_already_populated) return params def _populate_query_params( query_params: Any, gbls: Any, query_param_values: Dict[str, List[str]], skip_fields: List[str], ) -> List[str]: globals_already_populated: List[str] = [] param_fields: Tuple[Field, ...] = fields(query_params) for field in param_fields: if field.name in skip_fields: continue metadata = field.metadata.get("query_param") if not metadata: continue param_name = field.name value = getattr(query_params, param_name) if query_params is not None else None value, global_found = _populate_from_globals( param_name, value, "query_param", gbls ) if global_found: globals_already_populated.append(param_name) f_name = metadata.get("field_name") serialization = metadata.get("serialization", "") if serialization != "": serialized_parms = _get_serialized_params( metadata, field.type, f_name, value ) for key, value in serialized_parms.items(): if key in query_param_values: query_param_values[key].extend(value) else: query_param_values[key] = [value] else: style = metadata.get("style", "form") if style == "deepObject": _populate_deep_object_query_params( metadata, f_name, value, query_param_values ) elif style == "form": _populate_delimited_query_params( metadata, f_name, value, ",", query_param_values ) elif style == "pipeDelimited": _populate_delimited_query_params( metadata, f_name, value, "|", query_param_values ) else: raise Exception("not yet implemented") return globals_already_populated def get_headers(headers_params: Any, gbls: Optional[Any] = None) -> Dict[str, str]: headers: Dict[str, str] = {} globals_already_populated = [] if headers_params is not None: globals_already_populated = _populate_headers(headers_params, gbls, headers, []) if gbls is not None: _populate_headers(gbls, None, headers, globals_already_populated) return headers def _populate_headers( headers_params: Any, gbls: Any, header_values: Dict[str, str], skip_fields: List[str], ) -> List[str]: globals_already_populated: List[str] = [] param_fields: Tuple[Field, ...] = fields(headers_params) for field in param_fields: if field.name in skip_fields: continue metadata = field.metadata.get("header") if not metadata: continue value, global_found = _populate_from_globals( field.name, getattr(headers_params, field.name), "header", gbls ) if global_found: globals_already_populated.append(field.name) value = _serialize_header(metadata.get("explode", False), value) if value != "": header_values[metadata.get("field_name", field.name)] = value return globals_already_populated def _get_serialized_params( metadata: Dict, field_type: type, field_name: str, obj: Any ) -> Dict[str, str]: params: Dict[str, str] = {} serialization = metadata.get("serialization", "") if serialization == "json": params[metadata.get("field_name", field_name)] = marshal_json(obj, field_type) return params def _populate_deep_object_query_params( metadata: Dict, field_name: str, obj: Any, params: Dict[str, List[str]] ): if obj is None: return if is_dataclass(obj): _populate_deep_object_query_params_dataclass(metadata.get("field_name", field_name), obj, params) elif isinstance(obj, Dict): _populate_deep_object_query_params_dict(metadata.get("field_name", field_name), obj, params) def _populate_deep_object_query_params_dataclass( prior_params_key: str, obj: Any, params: Dict[str, List[str]] ): if obj is None: return if not is_dataclass(obj): return obj_fields: Tuple[Field, ...] = fields(obj) for obj_field in obj_fields: obj_param_metadata = obj_field.metadata.get("query_param") if not obj_param_metadata: continue obj_val = getattr(obj, obj_field.name) if obj_val is None: continue params_key = f'{prior_params_key}[{obj_param_metadata.get("field_name", obj_field.name)}]' if is_dataclass(obj_val): _populate_deep_object_query_params_dataclass(params_key, obj_val, params) elif isinstance(obj_val, Dict): _populate_deep_object_query_params_dict(params_key, obj_val, params) elif isinstance(obj_val, List): _populate_deep_object_query_params_list(params_key, obj_val, params) else: params[params_key] = [_val_to_string(obj_val)] def _populate_deep_object_query_params_dict( prior_params_key: str, value: Dict, params: Dict[str, List[str]] ): if value is None: return for key, val in value.items(): if val is None: continue params_key = f'{prior_params_key}[{key}]' if is_dataclass(val): _populate_deep_object_query_params_dataclass(params_key, val, params) elif isinstance(val, Dict): _populate_deep_object_query_params_dict(params_key, val, params) elif isinstance(val, List): _populate_deep_object_query_params_list(params_key, val, params) else: params[params_key] = [_val_to_string(val)] def _populate_deep_object_query_params_list( params_key: str, value: List, params: Dict[str, List[str]] ): if value is None: return for val in value: if val is None: continue if params.get(params_key) is None: params[params_key] = [] params[params_key].append(_val_to_string(val)) def _get_query_param_field_name(obj_field: Field) -> str: obj_param_metadata = obj_field.metadata.get("query_param") if not obj_param_metadata: return "" return obj_param_metadata.get("field_name", obj_field.name) def _populate_delimited_query_params( metadata: Dict, field_name: str, obj: Any, delimiter: str, query_param_values: Dict[str, List[str]], ): _populate_form( field_name, metadata.get("explode", True), obj, _get_query_param_field_name, delimiter, query_param_values, ) SERIALIZATION_METHOD_TO_CONTENT_TYPE = { "json": "application/json", "form": "application/x-www-form-urlencoded", "multipart": "multipart/form-data", "raw": "application/octet-stream", "string": "text/plain", } def serialize_request_body( request: Any, request_type: type, request_field_name: str, nullable: bool, optional: bool, serialization_method: str, encoder=None, ) -> Tuple[Optional[str], Optional[Any], Optional[Any]]: if request is None: if not nullable and optional: return None, None, None if not is_dataclass(request) or not hasattr(request, request_field_name): return serialize_content_type( request_field_name, request_type, SERIALIZATION_METHOD_TO_CONTENT_TYPE[serialization_method], request, encoder, ) request_val = getattr(request, request_field_name) if request_val is None: if not nullable and optional: return None, None, None request_fields: Tuple[Field, ...] = fields(request) request_metadata = None for field in request_fields: if field.name == request_field_name: request_metadata = field.metadata.get("request") break if request_metadata is None: raise Exception("invalid request type") return serialize_content_type( request_field_name, request_type, request_metadata.get("media_type", "application/octet-stream"), request_val, ) def serialize_content_type( field_name: str, request_type: Any, media_type: str, request: Any, encoder=None ) -> Tuple[Optional[str], Optional[Any], Optional[List[List[Any]]]]: if re.match(r"(application|text)\/.*?\+*json.*", media_type) is not None: return media_type, marshal_json(request, request_type, encoder), None if re.match(r"multipart\/.*", media_type) is not None: return serialize_multipart_form(media_type, request) if re.match(r"application\/x-www-form-urlencoded.*", media_type) is not None: return media_type, serialize_form_data(field_name, request), None if isinstance(request, (bytes, bytearray)): return media_type, request, None if isinstance(request, str): return media_type, request, None raise Exception( f"invalid request body type {type(request)} for mediaType {media_type}" ) def serialize_multipart_form( media_type: str, request: Any ) -> Tuple[str, Any, List[List[Any]]]: form: List[List[Any]] = [] request_fields = fields(request) for field in request_fields: val = getattr(request, field.name) if val is None: continue field_metadata = field.metadata.get("multipart_form") if not field_metadata: continue if field_metadata.get("file") is True: file_fields = fields(val) file_name = "" field_name = "" content = bytes() for file_field in file_fields: file_metadata = file_field.metadata.get("multipart_form") if file_metadata is None: continue if file_metadata.get("content") is True: content = getattr(val, file_field.name) else: field_name = file_metadata.get("field_name", file_field.name) file_name = getattr(val, file_field.name) if field_name == "" or file_name == "" or content == bytes(): raise Exception("invalid multipart/form-data file") form.append([field_name, [file_name, content]]) elif field_metadata.get("json") is True: to_append = [ field_metadata.get("field_name", field.name), [None, marshal_json(val, field.type), "application/json"], ] form.append(to_append) else: field_name = field_metadata.get("field_name", field.name) if isinstance(val, List): for value in val: if value is None: continue form.append([field_name + "[]", [None, _val_to_string(value)]]) else: form.append([field_name, [None, _val_to_string(val)]]) return media_type, None, form def serialize_dict( original: Dict, explode: bool, field_name, existing: Optional[Dict[str, List[str]]] ) -> Dict[str, List[str]]: if existing is None: existing = {} if explode is True: for key, val in original.items(): if key not in existing: existing[key] = [] existing[key].append(val) else: temp = [] for key, val in original.items(): temp.append(str(key)) temp.append(str(val)) if field_name not in existing: existing[field_name] = [] existing[field_name].append(",".join(temp)) return existing def serialize_form_data(field_name: str, data: Any) -> Dict[str, Any]: form: Dict[str, List[str]] = {} if is_dataclass(data): for field in fields(data): val = getattr(data, field.name) if val is None: continue metadata = field.metadata.get("form") if metadata is None: continue field_name = metadata.get("field_name", field.name) if metadata.get("json"): form[field_name] = [marshal_json(val, field.type)] else: if metadata.get("style", "form") == "form": _populate_form( field_name, metadata.get("explode", True), val, _get_form_field_name, ",", form, ) else: raise Exception(f"Invalid form style for field {field.name}") elif isinstance(data, Dict): for key, value in data.items(): form[key] = [_val_to_string(value)] else: raise Exception(f"Invalid request body type for field {field_name}") return form def _get_form_field_name(obj_field: Field) -> str: obj_param_metadata = obj_field.metadata.get("form") if not obj_param_metadata: return "" return obj_param_metadata.get("field_name", obj_field.name) def _populate_form( field_name: str, explode: boolean, obj: Any, get_field_name_func: Callable, delimiter: str, form: Dict[str, List[str]], ): if obj is None: return form if is_dataclass(obj): items = [] obj_fields: Tuple[Field, ...] = fields(obj) for obj_field in obj_fields: obj_field_name = get_field_name_func(obj_field) if obj_field_name == "": continue val = getattr(obj, obj_field.name) if val is None: continue if explode: form[obj_field_name] = [_val_to_string(val)] else: items.append(f"{obj_field_name}{delimiter}{_val_to_string(val)}") if len(items) > 0: form[field_name] = [delimiter.join(items)] elif isinstance(obj, Dict): items = [] for key, value in obj.items(): if value is None: continue if explode: form[key] = [_val_to_string(value)] else: items.append(f"{key}{delimiter}{_val_to_string(value)}") if len(items) > 0: form[field_name] = [delimiter.join(items)] elif isinstance(obj, List): items = [] for value in obj: if value is None: continue if explode: if not field_name in form: form[field_name] = [] form[field_name].append(_val_to_string(value)) else: items.append(_val_to_string(value)) if len(items) > 0: form[field_name] = [delimiter.join([str(item) for item in items])] else: form[field_name] = [_val_to_string(obj)] return form def _serialize_header(explode: bool, obj: Any) -> str: if obj is None: return "" if is_dataclass(obj): items = [] obj_fields: Tuple[Field, ...] = fields(obj) for obj_field in obj_fields: obj_param_metadata = obj_field.metadata.get("header") if not obj_param_metadata: continue obj_field_name = obj_param_metadata.get("field_name", obj_field.name) if obj_field_name == "": continue val = getattr(obj, obj_field.name) if val is None: continue if explode: items.append(f"{obj_field_name}={_val_to_string(val)}") else: items.append(obj_field_name) items.append(_val_to_string(val)) if len(items) > 0: return ",".join(items) elif isinstance(obj, Dict): items = [] for key, value in obj.items(): if value is None: continue if explode: items.append(f"{key}={_val_to_string(value)}") else: items.append(key) items.append(_val_to_string(value)) if len(items) > 0: return ",".join([str(item) for item in items]) elif isinstance(obj, List): items = [] for value in obj: if value is None: continue items.append(_val_to_string(value)) if len(items) > 0: return ",".join(items) else: return f"{_val_to_string(obj)}" return "" def unmarshal_json(data, typ, decoder=None, infer_missing=False): unmarshal = make_dataclass("Unmarshal", [("res", typ)], bases=(DataClassJsonMixin,)) json_dict = json.loads(data) try: out = unmarshal.from_dict({"res": json_dict}, infer_missing=infer_missing) except AttributeError as attr_err: raise AttributeError( f"unable to unmarshal {data} as {typ} - {attr_err}" ) from attr_err return out.res if decoder is None else decoder(out.res) def marshal_json(val, typ, encoder=None): if not is_optional_type(typ) and val is None: raise ValueError(f"Could not marshal None into non-optional type: {typ}") marshal = make_dataclass("Marshal", [("res", typ)], bases=(DataClassJsonMixin,)) marshaller = marshal(res=val) json_dict = marshaller.to_dict() val = json_dict["res"] if encoder is None else encoder(json_dict["res"]) return json.dumps(val, separators=(",", ":"), sort_keys=True) def match_content_type(content_type: str, pattern: str) -> boolean: if pattern in (content_type, "*", "*/*"): return True msg = Message() msg["content-type"] = content_type media_type = msg.get_content_type() if media_type == pattern: return True parts = media_type.split("/") if len(parts) == 2: if pattern in (f"{parts[0]}/*", f"*/{parts[1]}"): return True return False def match_status_codes(status_codes: List[str], status_code: int) -> bool: for code in status_codes: if code == str(status_code): return True if code.endswith("XX") and code.startswith(str(status_code)[:1]): return True return False def datetimeisoformat(optional: bool): def isoformatoptional(val): if optional and val is None: return None return _val_to_string(val) return isoformatoptional def dateisoformat(optional: bool): def isoformatoptional(val): if optional and val is None: return None return date.isoformat(val) return isoformatoptional def datefromisoformat(date_str: str): return dateutil.parser.parse(date_str).date() def bigintencoder(optional: bool): def bigintencode(val: int): if optional and val is None: return None return str(val) return bigintencode def bigintdecoder(val): if val is None: return None if isinstance(val, float): raise ValueError(f"{val} is a float") return int(val) def integerstrencoder(optional: bool): def integerstrencode(val: int): if optional and val is None: return None return str(val) return integerstrencode def integerstrdecoder(val): if val is None: return None if isinstance(val, float): raise ValueError(f"{val} is a float") return int(val) def numberstrencoder(optional: bool): def numberstrencode(val: float): if optional and val is None: return None return str(val) return numberstrencode def numberstrdecoder(val): if val is None: return None return float(val) def decimalencoder(optional: bool, as_str: bool): def decimalencode(val: Decimal): if optional and val is None: return None if as_str: return str(val) return float(val) return decimalencode def decimaldecoder(val): if val is None: return None return Decimal(str(val)) def map_encoder(optional: bool, value_encoder: Callable): def map_encode(val: Dict): if optional and val is None: return None encoded = {} for key, value in val.items(): encoded[key] = value_encoder(value) return encoded return map_encode def map_decoder(value_decoder: Callable): def map_decode(val: Dict): decoded = {} for key, value in val.items(): decoded[key] = value_decoder(value) return decoded return map_decode def list_encoder(optional: bool, value_encoder: Callable): def list_encode(val: List): if optional and val is None: return None encoded = [] for value in val: encoded.append(value_encoder(value)) return encoded return list_encode def list_decoder(value_decoder: Callable): def list_decode(val: List): decoded = [] for value in val: decoded.append(value_decoder(value)) return decoded return list_decode def union_encoder(all_encoders: Dict[str, Callable]): def selective_encoder(val: Any): if type(val) in all_encoders: return all_encoders[type(val)](val) return val return selective_encoder def union_decoder(all_decoders: List[Callable]): def selective_decoder(val: Any): decoded = val for decoder in all_decoders: try: decoded = decoder(val) break except (TypeError, ValueError): continue return decoded return selective_decoder def get_field_name(name): def override(_, _field_name=name): return _field_name return override def _val_to_string(val) -> str: if isinstance(val, bool): return str(val).lower() if isinstance(val, datetime): return str(val.isoformat().replace("+00:00", "Z")) if isinstance(val, Enum): return str(val.value) return str(val) def _populate_from_globals( param_name: str, value: Any, param_type: str, gbls: Any ) -> Tuple[Any, bool]: if gbls is None: return value, False global_fields = fields(gbls) found = False for field in global_fields: if field.name is not param_name: continue found = True if value is not None: return value, True global_value = getattr(gbls, field.name) param_metadata = field.metadata.get(param_type) if param_metadata is None: return value, True return global_value, True return value, found def decoder_with_discriminator(field_name): def decode_fx(obj): kls = getattr(sys.modules["sdk.models.shared"], obj[field_name]) return unmarshal_json(json.dumps(obj), kls) return decode_fx def remove_suffix(input_string, suffix): if suffix and input_string.endswith(suffix): return input_string[: -len(suffix)] return input_string