about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/google/protobuf/internal/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/google/protobuf/internal/decoder.py')
-rw-r--r--.venv/lib/python3.12/site-packages/google/protobuf/internal/decoder.py1036
1 files changed, 1036 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/google/protobuf/internal/decoder.py b/.venv/lib/python3.12/site-packages/google/protobuf/internal/decoder.py
new file mode 100644
index 00000000..dcde1d94
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/google/protobuf/internal/decoder.py
@@ -0,0 +1,1036 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc.  All rights reserved.
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Code for decoding protocol buffer primitives.
+
+This code is very similar to encoder.py -- read the docs for that module first.
+
+A "decoder" is a function with the signature:
+  Decode(buffer, pos, end, message, field_dict)
+The arguments are:
+  buffer:     The string containing the encoded message.
+  pos:        The current position in the string.
+  end:        The position in the string where the current message ends.  May be
+              less than len(buffer) if we're reading a sub-message.
+  message:    The message object into which we're parsing.
+  field_dict: message._fields (avoids a hashtable lookup).
+The decoder reads the field and stores it into field_dict, returning the new
+buffer position.  A decoder for a repeated field may proactively decode all of
+the elements of that field, if they appear consecutively.
+
+Note that decoders may throw any of the following:
+  IndexError:  Indicates a truncated message.
+  struct.error:  Unpacking of a fixed-width field failed.
+  message.DecodeError:  Other errors.
+
+Decoders are expected to raise an exception if they are called with pos > end.
+This allows callers to be lax about bounds checking:  it's fineto read past
+"end" as long as you are sure that someone else will notice and throw an
+exception later on.
+
+Something up the call stack is expected to catch IndexError and struct.error
+and convert them to message.DecodeError.
+
+Decoders are constructed using decoder constructors with the signature:
+  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
+The arguments are:
+  field_number:  The field number of the field we want to decode.
+  is_repeated:   Is the field a repeated field? (bool)
+  is_packed:     Is the field a packed field? (bool)
+  key:           The key to use when looking up the field within field_dict.
+                 (This is actually the FieldDescriptor but nothing in this
+                 file should depend on that.)
+  new_default:   A function which takes a message object as a parameter and
+                 returns a new instance of the default value for this field.
+                 (This is called for repeated fields and sub-messages, when an
+                 instance does not already exist.)
+
+As with encoders, we define a decoder constructor for every type of field.
+Then, for every field of every message class we construct an actual decoder.
+That decoder goes into a dict indexed by tag, so when we decode a message
+we repeatedly read a tag, look up the corresponding decoder, and invoke it.
+"""
+
+__author__ = 'kenton@google.com (Kenton Varda)'
+
+import math
+import struct
+
+from google.protobuf import message
+from google.protobuf.internal import containers
+from google.protobuf.internal import encoder
+from google.protobuf.internal import wire_format
+
+
+# This is not for optimization, but rather to avoid conflicts with local
+# variables named "message".
+_DecodeError = message.DecodeError
+
+
+def _VarintDecoder(mask, result_type):
+  """Return an encoder for a basic varint value (does not include tag).
+
+  Decoded values will be bitwise-anded with the given mask before being
+  returned, e.g. to limit them to 32 bits.  The returned decoder does not
+  take the usual "end" parameter -- the caller is expected to do bounds checking
+  after the fact (often the caller can defer such checking until later).  The
+  decoder returns a (value, new_pos) pair.
+  """
+
+  def DecodeVarint(buffer, pos: int=None):
+    result = 0
+    shift = 0
+    while 1:
+      if pos is None:
+        # Read from BytesIO
+        try:
+          b = buffer.read(1)[0]
+        except IndexError as e:
+          if shift == 0:
+            # End of BytesIO.
+            return None
+          else:
+            raise ValueError('Fail to read varint %s' % str(e))
+      else:
+        b = buffer[pos]
+        pos += 1
+      result |= ((b & 0x7f) << shift)
+      if not (b & 0x80):
+        result &= mask
+        result = result_type(result)
+        return result if pos is None else (result, pos)
+      shift += 7
+      if shift >= 64:
+        raise _DecodeError('Too many bytes when decoding varint.')
+
+  return DecodeVarint
+
+
+def _SignedVarintDecoder(bits, result_type):
+  """Like _VarintDecoder() but decodes signed values."""
+
+  signbit = 1 << (bits - 1)
+  mask = (1 << bits) - 1
+
+  def DecodeVarint(buffer, pos):
+    result = 0
+    shift = 0
+    while 1:
+      b = buffer[pos]
+      result |= ((b & 0x7f) << shift)
+      pos += 1
+      if not (b & 0x80):
+        result &= mask
+        result = (result ^ signbit) - signbit
+        result = result_type(result)
+        return (result, pos)
+      shift += 7
+      if shift >= 64:
+        raise _DecodeError('Too many bytes when decoding varint.')
+  return DecodeVarint
+
+# All 32-bit and 64-bit values are represented as int.
+_DecodeVarint = _VarintDecoder((1 << 64) - 1, int)
+_DecodeSignedVarint = _SignedVarintDecoder(64, int)
+
+# Use these versions for values which must be limited to 32 bits.
+_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
+_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
+
+
+def ReadTag(buffer, pos):
+  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
+
+  We return the raw bytes of the tag rather than decoding them.  The raw
+  bytes can then be used to look up the proper decoder.  This effectively allows
+  us to trade some work that would be done in pure-python (decoding a varint)
+  for work that is done in C (searching for a byte string in a hash table).
+  In a low-level language it would be much cheaper to decode the varint and
+  use that, but not in Python.
+
+  Args:
+    buffer: memoryview object of the encoded bytes
+    pos: int of the current position to start from
+
+  Returns:
+    Tuple[bytes, int] of the tag data and new position.
+  """
+  start = pos
+  while buffer[pos] & 0x80:
+    pos += 1
+  pos += 1
+
+  tag_bytes = buffer[start:pos].tobytes()
+  return tag_bytes, pos
+
+
+# --------------------------------------------------------------------
+
+
+def _SimpleDecoder(wire_type, decode_value):
+  """Return a constructor for a decoder for fields of a particular type.
+
+  Args:
+      wire_type:  The field's wire type.
+      decode_value:  A function which decodes an individual value, e.g.
+        _DecodeVarint()
+  """
+
+  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
+                      clear_if_default=False):
+    if is_packed:
+      local_DecodeVarint = _DecodeVarint
+      def DecodePackedField(buffer, pos, end, message, field_dict):
+        value = field_dict.get(key)
+        if value is None:
+          value = field_dict.setdefault(key, new_default(message))
+        (endpoint, pos) = local_DecodeVarint(buffer, pos)
+        endpoint += pos
+        if endpoint > end:
+          raise _DecodeError('Truncated message.')
+        while pos < endpoint:
+          (element, pos) = decode_value(buffer, pos)
+          value.append(element)
+        if pos > endpoint:
+          del value[-1]   # Discard corrupt value.
+          raise _DecodeError('Packed element was truncated.')
+        return pos
+      return DecodePackedField
+    elif is_repeated:
+      tag_bytes = encoder.TagBytes(field_number, wire_type)
+      tag_len = len(tag_bytes)
+      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+        value = field_dict.get(key)
+        if value is None:
+          value = field_dict.setdefault(key, new_default(message))
+        while 1:
+          (element, new_pos) = decode_value(buffer, pos)
+          value.append(element)
+          # Predict that the next tag is another copy of the same repeated
+          # field.
+          pos = new_pos + tag_len
+          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
+            # Prediction failed.  Return.
+            if new_pos > end:
+              raise _DecodeError('Truncated message.')
+            return new_pos
+      return DecodeRepeatedField
+    else:
+      def DecodeField(buffer, pos, end, message, field_dict):
+        (new_value, pos) = decode_value(buffer, pos)
+        if pos > end:
+          raise _DecodeError('Truncated message.')
+        if clear_if_default and not new_value:
+          field_dict.pop(key, None)
+        else:
+          field_dict[key] = new_value
+        return pos
+      return DecodeField
+
+  return SpecificDecoder
+
+
+def _ModifiedDecoder(wire_type, decode_value, modify_value):
+  """Like SimpleDecoder but additionally invokes modify_value on every value
+  before storing it.  Usually modify_value is ZigZagDecode.
+  """
+
+  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
+  # not enough to make a significant difference.
+
+  def InnerDecode(buffer, pos):
+    (result, new_pos) = decode_value(buffer, pos)
+    return (modify_value(result), new_pos)
+  return _SimpleDecoder(wire_type, InnerDecode)
+
+
+def _StructPackDecoder(wire_type, format):
+  """Return a constructor for a decoder for a fixed-width field.
+
+  Args:
+      wire_type:  The field's wire type.
+      format:  The format string to pass to struct.unpack().
+  """
+
+  value_size = struct.calcsize(format)
+  local_unpack = struct.unpack
+
+  # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
+  # not enough to make a significant difference.
+
+  # Note that we expect someone up-stack to catch struct.error and convert
+  # it to _DecodeError -- this way we don't have to set up exception-
+  # handling blocks every time we parse one value.
+
+  def InnerDecode(buffer, pos):
+    new_pos = pos + value_size
+    result = local_unpack(format, buffer[pos:new_pos])[0]
+    return (result, new_pos)
+  return _SimpleDecoder(wire_type, InnerDecode)
+
+
+def _FloatDecoder():
+  """Returns a decoder for a float field.
+
+  This code works around a bug in struct.unpack for non-finite 32-bit
+  floating-point values.
+  """
+
+  local_unpack = struct.unpack
+
+  def InnerDecode(buffer, pos):
+    """Decode serialized float to a float and new position.
+
+    Args:
+      buffer: memoryview of the serialized bytes
+      pos: int, position in the memory view to start at.
+
+    Returns:
+      Tuple[float, int] of the deserialized float value and new position
+      in the serialized data.
+    """
+    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
+    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
+    new_pos = pos + 4
+    float_bytes = buffer[pos:new_pos].tobytes()
+
+    # If this value has all its exponent bits set, then it's non-finite.
+    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
+    # To avoid that, we parse it specially.
+    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
+      # If at least one significand bit is set...
+      if float_bytes[0:3] != b'\x00\x00\x80':
+        return (math.nan, new_pos)
+      # If sign bit is set...
+      if float_bytes[3:4] == b'\xFF':
+        return (-math.inf, new_pos)
+      return (math.inf, new_pos)
+
+    # Note that we expect someone up-stack to catch struct.error and convert
+    # it to _DecodeError -- this way we don't have to set up exception-
+    # handling blocks every time we parse one value.
+    result = local_unpack('<f', float_bytes)[0]
+    return (result, new_pos)
+  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
+
+
+def _DoubleDecoder():
+  """Returns a decoder for a double field.
+
+  This code works around a bug in struct.unpack for not-a-number.
+  """
+
+  local_unpack = struct.unpack
+
+  def InnerDecode(buffer, pos):
+    """Decode serialized double to a double and new position.
+
+    Args:
+      buffer: memoryview of the serialized bytes.
+      pos: int, position in the memory view to start at.
+
+    Returns:
+      Tuple[float, int] of the decoded double value and new position
+      in the serialized data.
+    """
+    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
+    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
+    new_pos = pos + 8
+    double_bytes = buffer[pos:new_pos].tobytes()
+
+    # If this value has all its exponent bits set and at least one significand
+    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
+    # as inf or -inf.  To avoid that, we treat it specially.
+    if ((double_bytes[7:8] in b'\x7F\xFF')
+        and (double_bytes[6:7] >= b'\xF0')
+        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
+      return (math.nan, new_pos)
+
+    # Note that we expect someone up-stack to catch struct.error and convert
+    # it to _DecodeError -- this way we don't have to set up exception-
+    # handling blocks every time we parse one value.
+    result = local_unpack('<d', double_bytes)[0]
+    return (result, new_pos)
+  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
+
+
+def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
+                clear_if_default=False):
+  """Returns a decoder for enum field."""
+  enum_type = key.enum_type
+  if is_packed:
+    local_DecodeVarint = _DecodeVarint
+    def DecodePackedField(buffer, pos, end, message, field_dict):
+      """Decode serialized packed enum to its value and a new position.
+
+      Args:
+        buffer: memoryview of the serialized bytes.
+        pos: int, position in the memory view to start at.
+        end: int, end position of serialized data
+        message: Message object to store unknown fields in
+        field_dict: Map[Descriptor, Any] to store decoded values in.
+
+      Returns:
+        int, new position in serialized data.
+      """
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      (endpoint, pos) = local_DecodeVarint(buffer, pos)
+      endpoint += pos
+      if endpoint > end:
+        raise _DecodeError('Truncated message.')
+      while pos < endpoint:
+        value_start_pos = pos
+        (element, pos) = _DecodeSignedVarint32(buffer, pos)
+        # pylint: disable=protected-access
+        if element in enum_type.values_by_number:
+          value.append(element)
+        else:
+          if not message._unknown_fields:
+            message._unknown_fields = []
+          tag_bytes = encoder.TagBytes(field_number,
+                                       wire_format.WIRETYPE_VARINT)
+
+          message._unknown_fields.append(
+              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+          # pylint: enable=protected-access
+      if pos > endpoint:
+        if element in enum_type.values_by_number:
+          del value[-1]   # Discard corrupt value.
+        else:
+          del message._unknown_fields[-1]
+          # pylint: enable=protected-access
+        raise _DecodeError('Packed element was truncated.')
+      return pos
+    return DecodePackedField
+  elif is_repeated:
+    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      """Decode serialized repeated enum to its value and a new position.
+
+      Args:
+        buffer: memoryview of the serialized bytes.
+        pos: int, position in the memory view to start at.
+        end: int, end position of serialized data
+        message: Message object to store unknown fields in
+        field_dict: Map[Descriptor, Any] to store decoded values in.
+
+      Returns:
+        int, new position in serialized data.
+      """
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
+        # pylint: disable=protected-access
+        if element in enum_type.values_by_number:
+          value.append(element)
+        else:
+          if not message._unknown_fields:
+            message._unknown_fields = []
+          message._unknown_fields.append(
+              (tag_bytes, buffer[pos:new_pos].tobytes()))
+        # pylint: enable=protected-access
+        # Predict that the next tag is another copy of the same repeated
+        # field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
+          # Prediction failed.  Return.
+          if new_pos > end:
+            raise _DecodeError('Truncated message.')
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      """Decode serialized repeated enum to its value and a new position.
+
+      Args:
+        buffer: memoryview of the serialized bytes.
+        pos: int, position in the memory view to start at.
+        end: int, end position of serialized data
+        message: Message object to store unknown fields in
+        field_dict: Map[Descriptor, Any] to store decoded values in.
+
+      Returns:
+        int, new position in serialized data.
+      """
+      value_start_pos = pos
+      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
+      if pos > end:
+        raise _DecodeError('Truncated message.')
+      if clear_if_default and not enum_value:
+        field_dict.pop(key, None)
+        return pos
+      # pylint: disable=protected-access
+      if enum_value in enum_type.values_by_number:
+        field_dict[key] = enum_value
+      else:
+        if not message._unknown_fields:
+          message._unknown_fields = []
+        tag_bytes = encoder.TagBytes(field_number,
+                                     wire_format.WIRETYPE_VARINT)
+        message._unknown_fields.append(
+            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+        # pylint: enable=protected-access
+      return pos
+    return DecodeField
+
+
+# --------------------------------------------------------------------
+
+
+Int32Decoder = _SimpleDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
+
+Int64Decoder = _SimpleDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
+
+UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
+UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
+
+SInt32Decoder = _ModifiedDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
+SInt64Decoder = _ModifiedDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
+
+# Note that Python conveniently guarantees that when using the '<' prefix on
+# formats, they will also have the same size across all platforms (as opposed
+# to without the prefix, where their sizes depend on the C compiler's basic
+# type sizes).
+Fixed32Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
+Fixed64Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
+SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
+SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
+FloatDecoder = _FloatDecoder()
+DoubleDecoder = _DoubleDecoder()
+
+BoolDecoder = _ModifiedDecoder(
+    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
+
+
+def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
+                  clear_if_default=False):
+  """Returns a decoder for a string field."""
+
+  local_DecodeVarint = _DecodeVarint
+
+  def _ConvertToUnicode(memview):
+    """Convert byte to unicode."""
+    byte_str = memview.tobytes()
+    try:
+      value = str(byte_str, 'utf-8')
+    except UnicodeDecodeError as e:
+      # add more information to the error message and re-raise it.
+      e.reason = '%s in field: %s' % (e, key.full_name)
+      raise
+
+    return value
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        (size, pos) = local_DecodeVarint(buffer, pos)
+        new_pos = pos + size
+        if new_pos > end:
+          raise _DecodeError('Truncated string.')
+        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      (size, pos) = local_DecodeVarint(buffer, pos)
+      new_pos = pos + size
+      if new_pos > end:
+        raise _DecodeError('Truncated string.')
+      if clear_if_default and not size:
+        field_dict.pop(key, None)
+      else:
+        field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
+      return new_pos
+    return DecodeField
+
+
+def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
+                 clear_if_default=False):
+  """Returns a decoder for a bytes field."""
+
+  local_DecodeVarint = _DecodeVarint
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        (size, pos) = local_DecodeVarint(buffer, pos)
+        new_pos = pos + size
+        if new_pos > end:
+          raise _DecodeError('Truncated string.')
+        value.append(buffer[pos:new_pos].tobytes())
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      (size, pos) = local_DecodeVarint(buffer, pos)
+      new_pos = pos + size
+      if new_pos > end:
+        raise _DecodeError('Truncated string.')
+      if clear_if_default and not size:
+        field_dict.pop(key, None)
+      else:
+        field_dict[key] = buffer[pos:new_pos].tobytes()
+      return new_pos
+    return DecodeField
+
+
+def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
+  """Returns a decoder for a group field."""
+
+  end_tag_bytes = encoder.TagBytes(field_number,
+                                   wire_format.WIRETYPE_END_GROUP)
+  end_tag_len = len(end_tag_bytes)
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_START_GROUP)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        value = field_dict.get(key)
+        if value is None:
+          value = field_dict.setdefault(key, new_default(message))
+        # Read sub-message.
+        pos = value.add()._InternalParse(buffer, pos, end)
+        # Read end tag.
+        new_pos = pos+end_tag_len
+        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
+          raise _DecodeError('Missing group end tag.')
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      # Read sub-message.
+      pos = value._InternalParse(buffer, pos, end)
+      # Read end tag.
+      new_pos = pos+end_tag_len
+      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
+        raise _DecodeError('Missing group end tag.')
+      return new_pos
+    return DecodeField
+
+
+def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
+  """Returns a decoder for a message field."""
+
+  local_DecodeVarint = _DecodeVarint
+
+  assert not is_packed
+  if is_repeated:
+    tag_bytes = encoder.TagBytes(field_number,
+                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
+    tag_len = len(tag_bytes)
+    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      while 1:
+        # Read length.
+        (size, pos) = local_DecodeVarint(buffer, pos)
+        new_pos = pos + size
+        if new_pos > end:
+          raise _DecodeError('Truncated message.')
+        # Read sub-message.
+        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
+          # The only reason _InternalParse would return early is if it
+          # encountered an end-group tag.
+          raise _DecodeError('Unexpected end-group tag.')
+        # Predict that the next tag is another copy of the same repeated field.
+        pos = new_pos + tag_len
+        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+          # Prediction failed.  Return.
+          return new_pos
+    return DecodeRepeatedField
+  else:
+    def DecodeField(buffer, pos, end, message, field_dict):
+      value = field_dict.get(key)
+      if value is None:
+        value = field_dict.setdefault(key, new_default(message))
+      # Read length.
+      (size, pos) = local_DecodeVarint(buffer, pos)
+      new_pos = pos + size
+      if new_pos > end:
+        raise _DecodeError('Truncated message.')
+      # Read sub-message.
+      if value._InternalParse(buffer, pos, new_pos) != new_pos:
+        # The only reason _InternalParse would return early is if it encountered
+        # an end-group tag.
+        raise _DecodeError('Unexpected end-group tag.')
+      return new_pos
+    return DecodeField
+
+
+# --------------------------------------------------------------------
+
+MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
+
+def MessageSetItemDecoder(descriptor):
+  """Returns a decoder for a MessageSet item.
+
+  The parameter is the message Descriptor.
+
+  The message set message looks like this:
+    message MessageSet {
+      repeated group Item = 1 {
+        required int32 type_id = 2;
+        required string message = 3;
+      }
+    }
+  """
+
+  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
+  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
+  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
+
+  local_ReadTag = ReadTag
+  local_DecodeVarint = _DecodeVarint
+  local_SkipField = SkipField
+
+  def DecodeItem(buffer, pos, end, message, field_dict):
+    """Decode serialized message set to its value and new position.
+
+    Args:
+      buffer: memoryview of the serialized bytes.
+      pos: int, position in the memory view to start at.
+      end: int, end position of serialized data
+      message: Message object to store unknown fields in
+      field_dict: Map[Descriptor, Any] to store decoded values in.
+
+    Returns:
+      int, new position in serialized data.
+    """
+    message_set_item_start = pos
+    type_id = -1
+    message_start = -1
+    message_end = -1
+
+    # Technically, type_id and message can appear in any order, so we need
+    # a little loop here.
+    while 1:
+      (tag_bytes, pos) = local_ReadTag(buffer, pos)
+      if tag_bytes == type_id_tag_bytes:
+        (type_id, pos) = local_DecodeVarint(buffer, pos)
+      elif tag_bytes == message_tag_bytes:
+        (size, message_start) = local_DecodeVarint(buffer, pos)
+        pos = message_end = message_start + size
+      elif tag_bytes == item_end_tag_bytes:
+        break
+      else:
+        pos = SkipField(buffer, pos, end, tag_bytes)
+        if pos == -1:
+          raise _DecodeError('Missing group end tag.')
+
+    if pos > end:
+      raise _DecodeError('Truncated message.')
+
+    if type_id == -1:
+      raise _DecodeError('MessageSet item missing type_id.')
+    if message_start == -1:
+      raise _DecodeError('MessageSet item missing message.')
+
+    extension = message.Extensions._FindExtensionByNumber(type_id)
+    # pylint: disable=protected-access
+    if extension is not None:
+      value = field_dict.get(extension)
+      if value is None:
+        message_type = extension.message_type
+        if not hasattr(message_type, '_concrete_class'):
+          message_factory.GetMessageClass(message_type)
+        value = field_dict.setdefault(
+            extension, message_type._concrete_class())
+      if value._InternalParse(buffer, message_start,message_end) != message_end:
+        # The only reason _InternalParse would return early is if it encountered
+        # an end-group tag.
+        raise _DecodeError('Unexpected end-group tag.')
+    else:
+      if not message._unknown_fields:
+        message._unknown_fields = []
+      message._unknown_fields.append(
+          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
+      # pylint: enable=protected-access
+
+    return pos
+
+  return DecodeItem
+
+
+def UnknownMessageSetItemDecoder():
+  """Returns a decoder for a Unknown MessageSet item."""
+
+  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
+  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
+  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
+
+  def DecodeUnknownItem(buffer):
+    pos = 0
+    end = len(buffer)
+    message_start = -1
+    message_end = -1
+    while 1:
+      (tag_bytes, pos) = ReadTag(buffer, pos)
+      if tag_bytes == type_id_tag_bytes:
+        (type_id, pos) = _DecodeVarint(buffer, pos)
+      elif tag_bytes == message_tag_bytes:
+        (size, message_start) = _DecodeVarint(buffer, pos)
+        pos = message_end = message_start + size
+      elif tag_bytes == item_end_tag_bytes:
+        break
+      else:
+        pos = SkipField(buffer, pos, end, tag_bytes)
+        if pos == -1:
+          raise _DecodeError('Missing group end tag.')
+
+    if pos > end:
+      raise _DecodeError('Truncated message.')
+
+    if type_id == -1:
+      raise _DecodeError('MessageSet item missing type_id.')
+    if message_start == -1:
+      raise _DecodeError('MessageSet item missing message.')
+
+    return (type_id, buffer[message_start:message_end].tobytes())
+
+  return DecodeUnknownItem
+
+# --------------------------------------------------------------------
+
+def MapDecoder(field_descriptor, new_default, is_message_map):
+  """Returns a decoder for a map field."""
+
+  key = field_descriptor
+  tag_bytes = encoder.TagBytes(field_descriptor.number,
+                               wire_format.WIRETYPE_LENGTH_DELIMITED)
+  tag_len = len(tag_bytes)
+  local_DecodeVarint = _DecodeVarint
+  # Can't read _concrete_class yet; might not be initialized.
+  message_type = field_descriptor.message_type
+
+  def DecodeMap(buffer, pos, end, message, field_dict):
+    submsg = message_type._concrete_class()
+    value = field_dict.get(key)
+    if value is None:
+      value = field_dict.setdefault(key, new_default(message))
+    while 1:
+      # Read length.
+      (size, pos) = local_DecodeVarint(buffer, pos)
+      new_pos = pos + size
+      if new_pos > end:
+        raise _DecodeError('Truncated message.')
+      # Read sub-message.
+      submsg.Clear()
+      if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
+        # The only reason _InternalParse would return early is if it
+        # encountered an end-group tag.
+        raise _DecodeError('Unexpected end-group tag.')
+
+      if is_message_map:
+        value[submsg.key].CopyFrom(submsg.value)
+      else:
+        value[submsg.key] = submsg.value
+
+      # Predict that the next tag is another copy of the same repeated field.
+      pos = new_pos + tag_len
+      if buffer[new_pos:pos] != tag_bytes or new_pos == end:
+        # Prediction failed.  Return.
+        return new_pos
+
+  return DecodeMap
+
+# --------------------------------------------------------------------
+# Optimization is not as heavy here because calls to SkipField() are rare,
+# except for handling end-group tags.
+
+def _SkipVarint(buffer, pos, end):
+  """Skip a varint value.  Returns the new position."""
+  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
+  # With this code, ord(b'') raises TypeError.  Both are handled in
+  # python_message.py to generate a 'Truncated message' error.
+  while ord(buffer[pos:pos+1].tobytes()) & 0x80:
+    pos += 1
+  pos += 1
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+def _SkipFixed64(buffer, pos, end):
+  """Skip a fixed64 value.  Returns the new position."""
+
+  pos += 8
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+
+def _DecodeFixed64(buffer, pos):
+  """Decode a fixed64."""
+  new_pos = pos + 8
+  return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
+
+
+def _SkipLengthDelimited(buffer, pos, end):
+  """Skip a length-delimited value.  Returns the new position."""
+
+  (size, pos) = _DecodeVarint(buffer, pos)
+  pos += size
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+
+def _SkipGroup(buffer, pos, end):
+  """Skip sub-group.  Returns the new position."""
+
+  while 1:
+    (tag_bytes, pos) = ReadTag(buffer, pos)
+    new_pos = SkipField(buffer, pos, end, tag_bytes)
+    if new_pos == -1:
+      return pos
+    pos = new_pos
+
+
+def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
+  """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position."""
+
+  unknown_field_set = containers.UnknownFieldSet()
+  while end_pos is None or pos < end_pos:
+    (tag_bytes, pos) = ReadTag(buffer, pos)
+    (tag, _) = _DecodeVarint(tag_bytes, 0)
+    field_number, wire_type = wire_format.UnpackTag(tag)
+    if wire_type == wire_format.WIRETYPE_END_GROUP:
+      break
+    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+    # pylint: disable=protected-access
+    unknown_field_set._add(field_number, wire_type, data)
+
+  return (unknown_field_set, pos)
+
+
+def _DecodeUnknownField(buffer, pos, wire_type):
+  """Decode a unknown field.  Returns the UnknownField and new position."""
+
+  if wire_type == wire_format.WIRETYPE_VARINT:
+    (data, pos) = _DecodeVarint(buffer, pos)
+  elif wire_type == wire_format.WIRETYPE_FIXED64:
+    (data, pos) = _DecodeFixed64(buffer, pos)
+  elif wire_type == wire_format.WIRETYPE_FIXED32:
+    (data, pos) = _DecodeFixed32(buffer, pos)
+  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
+    (size, pos) = _DecodeVarint(buffer, pos)
+    data = buffer[pos:pos+size].tobytes()
+    pos += size
+  elif wire_type == wire_format.WIRETYPE_START_GROUP:
+    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
+  elif wire_type == wire_format.WIRETYPE_END_GROUP:
+    return (0, -1)
+  else:
+    raise _DecodeError('Wrong wire type in tag.')
+
+  return (data, pos)
+
+
+def _EndGroup(buffer, pos, end):
+  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
+
+  return -1
+
+
+def _SkipFixed32(buffer, pos, end):
+  """Skip a fixed32 value.  Returns the new position."""
+
+  pos += 4
+  if pos > end:
+    raise _DecodeError('Truncated message.')
+  return pos
+
+
+def _DecodeFixed32(buffer, pos):
+  """Decode a fixed32."""
+
+  new_pos = pos + 4
+  return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
+
+
+def _RaiseInvalidWireType(buffer, pos, end):
+  """Skip function for unknown wire types.  Raises an exception."""
+
+  raise _DecodeError('Tag had invalid wire type.')
+
+def _FieldSkipper():
+  """Constructs the SkipField function."""
+
+  WIRETYPE_TO_SKIPPER = [
+      _SkipVarint,
+      _SkipFixed64,
+      _SkipLengthDelimited,
+      _SkipGroup,
+      _EndGroup,
+      _SkipFixed32,
+      _RaiseInvalidWireType,
+      _RaiseInvalidWireType,
+      ]
+
+  wiretype_mask = wire_format.TAG_TYPE_MASK
+
+  def SkipField(buffer, pos, end, tag_bytes):
+    """Skips a field with the specified tag.
+
+    |pos| should point to the byte immediately after the tag.
+
+    Returns:
+        The new position (after the tag value), or -1 if the tag is an end-group
+        tag (in which case the calling loop should break).
+    """
+
+    # The wire type is always in the first byte since varints are little-endian.
+    wire_type = ord(tag_bytes[0:1]) & wiretype_mask
+    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
+
+  return SkipField
+
+SkipField = _FieldSkipper()