123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310 |
- # Protocol Buffers - Google's data interchange format
- # Copyright 2008 Google Inc. All rights reserved.
- #
- # Use of this source code is governed by a BSD-style
- # license that can be found in the LICENSE file or at
- # https://developers.google.com/open-source/licenses/bsd
- """Contains FieldMask class."""
- from google.protobuf.descriptor import FieldDescriptor
- class FieldMask(object):
- """Class for FieldMask message type."""
- __slots__ = ()
- def ToJsonString(self):
- """Converts FieldMask to string according to proto3 JSON spec."""
- camelcase_paths = []
- for path in self.paths:
- camelcase_paths.append(_SnakeCaseToCamelCase(path))
- return ','.join(camelcase_paths)
- def FromJsonString(self, value):
- """Converts string to FieldMask according to proto3 JSON spec."""
- if not isinstance(value, str):
- raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
- self.Clear()
- if value:
- for path in value.split(','):
- self.paths.append(_CamelCaseToSnakeCase(path))
- def IsValidForDescriptor(self, message_descriptor):
- """Checks whether the FieldMask is valid for Message Descriptor."""
- for path in self.paths:
- if not _IsValidPath(message_descriptor, path):
- return False
- return True
- def AllFieldsFromDescriptor(self, message_descriptor):
- """Gets all direct fields of Message Descriptor to FieldMask."""
- self.Clear()
- for field in message_descriptor.fields:
- self.paths.append(field.name)
- def CanonicalFormFromMask(self, mask):
- """Converts a FieldMask to the canonical form.
- Removes paths that are covered by another path. For example,
- "foo.bar" is covered by "foo" and will be removed if "foo"
- is also in the FieldMask. Then sorts all paths in alphabetical order.
- Args:
- mask: The original FieldMask to be converted.
- """
- tree = _FieldMaskTree(mask)
- tree.ToFieldMask(self)
- def Union(self, mask1, mask2):
- """Merges mask1 and mask2 into this FieldMask."""
- _CheckFieldMaskMessage(mask1)
- _CheckFieldMaskMessage(mask2)
- tree = _FieldMaskTree(mask1)
- tree.MergeFromFieldMask(mask2)
- tree.ToFieldMask(self)
- def Intersect(self, mask1, mask2):
- """Intersects mask1 and mask2 into this FieldMask."""
- _CheckFieldMaskMessage(mask1)
- _CheckFieldMaskMessage(mask2)
- tree = _FieldMaskTree(mask1)
- intersection = _FieldMaskTree()
- for path in mask2.paths:
- tree.IntersectPath(path, intersection)
- intersection.ToFieldMask(self)
- def MergeMessage(
- self, source, destination,
- replace_message_field=False, replace_repeated_field=False):
- """Merges fields specified in FieldMask from source to destination.
- Args:
- source: Source message.
- destination: The destination message to be merged into.
- replace_message_field: Replace message field if True. Merge message
- field if False.
- replace_repeated_field: Replace repeated field if True. Append
- elements of repeated field if False.
- """
- tree = _FieldMaskTree(self)
- tree.MergeMessage(
- source, destination, replace_message_field, replace_repeated_field)
- def _IsValidPath(message_descriptor, path):
- """Checks whether the path is valid for Message Descriptor."""
- parts = path.split('.')
- last = parts.pop()
- for name in parts:
- field = message_descriptor.fields_by_name.get(name)
- if (field is None or
- field.label == FieldDescriptor.LABEL_REPEATED or
- field.type != FieldDescriptor.TYPE_MESSAGE):
- return False
- message_descriptor = field.message_type
- return last in message_descriptor.fields_by_name
- def _CheckFieldMaskMessage(message):
- """Raises ValueError if message is not a FieldMask."""
- message_descriptor = message.DESCRIPTOR
- if (message_descriptor.name != 'FieldMask' or
- message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
- raise ValueError('Message {0} is not a FieldMask.'.format(
- message_descriptor.full_name))
- def _SnakeCaseToCamelCase(path_name):
- """Converts a path name from snake_case to camelCase."""
- result = []
- after_underscore = False
- for c in path_name:
- if c.isupper():
- raise ValueError(
- 'Fail to print FieldMask to Json string: Path name '
- '{0} must not contain uppercase letters.'.format(path_name))
- if after_underscore:
- if c.islower():
- result.append(c.upper())
- after_underscore = False
- else:
- raise ValueError(
- 'Fail to print FieldMask to Json string: The '
- 'character after a "_" must be a lowercase letter '
- 'in path name {0}.'.format(path_name))
- elif c == '_':
- after_underscore = True
- else:
- result += c
- if after_underscore:
- raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
- 'in path name {0}.'.format(path_name))
- return ''.join(result)
- def _CamelCaseToSnakeCase(path_name):
- """Converts a field name from camelCase to snake_case."""
- result = []
- for c in path_name:
- if c == '_':
- raise ValueError('Fail to parse FieldMask: Path name '
- '{0} must not contain "_"s.'.format(path_name))
- if c.isupper():
- result += '_'
- result += c.lower()
- else:
- result += c
- return ''.join(result)
- class _FieldMaskTree(object):
- """Represents a FieldMask in a tree structure.
- For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
- the FieldMaskTree will be:
- [_root] -+- foo -+- bar
- | |
- | +- baz
- |
- +- bar --- baz
- In the tree, each leaf node represents a field path.
- """
- __slots__ = ('_root',)
- def __init__(self, field_mask=None):
- """Initializes the tree by FieldMask."""
- self._root = {}
- if field_mask:
- self.MergeFromFieldMask(field_mask)
- def MergeFromFieldMask(self, field_mask):
- """Merges a FieldMask to the tree."""
- for path in field_mask.paths:
- self.AddPath(path)
- def AddPath(self, path):
- """Adds a field path into the tree.
- If the field path to add is a sub-path of an existing field path
- in the tree (i.e., a leaf node), it means the tree already matches
- the given path so nothing will be added to the tree. If the path
- matches an existing non-leaf node in the tree, that non-leaf node
- will be turned into a leaf node with all its children removed because
- the path matches all the node's children. Otherwise, a new path will
- be added.
- Args:
- path: The field path to add.
- """
- node = self._root
- for name in path.split('.'):
- if name not in node:
- node[name] = {}
- elif not node[name]:
- # Pre-existing empty node implies we already have this entire tree.
- return
- node = node[name]
- # Remove any sub-trees we might have had.
- node.clear()
- def ToFieldMask(self, field_mask):
- """Converts the tree to a FieldMask."""
- field_mask.Clear()
- _AddFieldPaths(self._root, '', field_mask)
- def IntersectPath(self, path, intersection):
- """Calculates the intersection part of a field path with this tree.
- Args:
- path: The field path to calculates.
- intersection: The out tree to record the intersection part.
- """
- node = self._root
- for name in path.split('.'):
- if name not in node:
- return
- elif not node[name]:
- intersection.AddPath(path)
- return
- node = node[name]
- intersection.AddLeafNodes(path, node)
- def AddLeafNodes(self, prefix, node):
- """Adds leaf nodes begin with prefix to this tree."""
- if not node:
- self.AddPath(prefix)
- for name in node:
- child_path = prefix + '.' + name
- self.AddLeafNodes(child_path, node[name])
- def MergeMessage(
- self, source, destination,
- replace_message, replace_repeated):
- """Merge all fields specified by this tree from source to destination."""
- _MergeMessage(
- self._root, source, destination, replace_message, replace_repeated)
- def _StrConvert(value):
- """Converts value to str if it is not."""
- # This file is imported by c extension and some methods like ClearField
- # requires string for the field name. py2/py3 has different text
- # type and may use unicode.
- if not isinstance(value, str):
- return value.encode('utf-8')
- return value
- def _MergeMessage(
- node, source, destination, replace_message, replace_repeated):
- """Merge all fields specified by a sub-tree from source to destination."""
- source_descriptor = source.DESCRIPTOR
- for name in node:
- child = node[name]
- field = source_descriptor.fields_by_name[name]
- if field is None:
- raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
- name, source_descriptor.full_name))
- if child:
- # Sub-paths are only allowed for singular message fields.
- if (field.label == FieldDescriptor.LABEL_REPEATED or
- field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
- raise ValueError('Error: Field {0} in message {1} is not a singular '
- 'message field and cannot have sub-fields.'.format(
- name, source_descriptor.full_name))
- if source.HasField(name):
- _MergeMessage(
- child, getattr(source, name), getattr(destination, name),
- replace_message, replace_repeated)
- continue
- if field.label == FieldDescriptor.LABEL_REPEATED:
- if replace_repeated:
- destination.ClearField(_StrConvert(name))
- repeated_source = getattr(source, name)
- repeated_destination = getattr(destination, name)
- repeated_destination.MergeFrom(repeated_source)
- else:
- if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
- if replace_message:
- destination.ClearField(_StrConvert(name))
- if source.HasField(name):
- getattr(destination, name).MergeFrom(getattr(source, name))
- else:
- setattr(destination, name, getattr(source, name))
- def _AddFieldPaths(node, prefix, field_mask):
- """Adds the field paths descended from node to field_mask."""
- if not node and prefix:
- field_mask.paths.append(prefix)
- return
- for name in sorted(node):
- if prefix:
- child_path = prefix + '.' + name
- else:
- child_path = name
- _AddFieldPaths(node[name], child_path, field_mask)
|