field_mask.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # Protocol Buffers - Google's data interchange format
  2. # Copyright 2008 Google Inc. All rights reserved.
  3. #
  4. # Use of this source code is governed by a BSD-style
  5. # license that can be found in the LICENSE file or at
  6. # https://developers.google.com/open-source/licenses/bsd
  7. """Contains FieldMask class."""
  8. from google.protobuf.descriptor import FieldDescriptor
  9. class FieldMask(object):
  10. """Class for FieldMask message type."""
  11. __slots__ = ()
  12. def ToJsonString(self):
  13. """Converts FieldMask to string according to proto3 JSON spec."""
  14. camelcase_paths = []
  15. for path in self.paths:
  16. camelcase_paths.append(_SnakeCaseToCamelCase(path))
  17. return ','.join(camelcase_paths)
  18. def FromJsonString(self, value):
  19. """Converts string to FieldMask according to proto3 JSON spec."""
  20. if not isinstance(value, str):
  21. raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
  22. self.Clear()
  23. if value:
  24. for path in value.split(','):
  25. self.paths.append(_CamelCaseToSnakeCase(path))
  26. def IsValidForDescriptor(self, message_descriptor):
  27. """Checks whether the FieldMask is valid for Message Descriptor."""
  28. for path in self.paths:
  29. if not _IsValidPath(message_descriptor, path):
  30. return False
  31. return True
  32. def AllFieldsFromDescriptor(self, message_descriptor):
  33. """Gets all direct fields of Message Descriptor to FieldMask."""
  34. self.Clear()
  35. for field in message_descriptor.fields:
  36. self.paths.append(field.name)
  37. def CanonicalFormFromMask(self, mask):
  38. """Converts a FieldMask to the canonical form.
  39. Removes paths that are covered by another path. For example,
  40. "foo.bar" is covered by "foo" and will be removed if "foo"
  41. is also in the FieldMask. Then sorts all paths in alphabetical order.
  42. Args:
  43. mask: The original FieldMask to be converted.
  44. """
  45. tree = _FieldMaskTree(mask)
  46. tree.ToFieldMask(self)
  47. def Union(self, mask1, mask2):
  48. """Merges mask1 and mask2 into this FieldMask."""
  49. _CheckFieldMaskMessage(mask1)
  50. _CheckFieldMaskMessage(mask2)
  51. tree = _FieldMaskTree(mask1)
  52. tree.MergeFromFieldMask(mask2)
  53. tree.ToFieldMask(self)
  54. def Intersect(self, mask1, mask2):
  55. """Intersects mask1 and mask2 into this FieldMask."""
  56. _CheckFieldMaskMessage(mask1)
  57. _CheckFieldMaskMessage(mask2)
  58. tree = _FieldMaskTree(mask1)
  59. intersection = _FieldMaskTree()
  60. for path in mask2.paths:
  61. tree.IntersectPath(path, intersection)
  62. intersection.ToFieldMask(self)
  63. def MergeMessage(
  64. self, source, destination,
  65. replace_message_field=False, replace_repeated_field=False):
  66. """Merges fields specified in FieldMask from source to destination.
  67. Args:
  68. source: Source message.
  69. destination: The destination message to be merged into.
  70. replace_message_field: Replace message field if True. Merge message
  71. field if False.
  72. replace_repeated_field: Replace repeated field if True. Append
  73. elements of repeated field if False.
  74. """
  75. tree = _FieldMaskTree(self)
  76. tree.MergeMessage(
  77. source, destination, replace_message_field, replace_repeated_field)
  78. def _IsValidPath(message_descriptor, path):
  79. """Checks whether the path is valid for Message Descriptor."""
  80. parts = path.split('.')
  81. last = parts.pop()
  82. for name in parts:
  83. field = message_descriptor.fields_by_name.get(name)
  84. if (field is None or
  85. field.label == FieldDescriptor.LABEL_REPEATED or
  86. field.type != FieldDescriptor.TYPE_MESSAGE):
  87. return False
  88. message_descriptor = field.message_type
  89. return last in message_descriptor.fields_by_name
  90. def _CheckFieldMaskMessage(message):
  91. """Raises ValueError if message is not a FieldMask."""
  92. message_descriptor = message.DESCRIPTOR
  93. if (message_descriptor.name != 'FieldMask' or
  94. message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
  95. raise ValueError('Message {0} is not a FieldMask.'.format(
  96. message_descriptor.full_name))
  97. def _SnakeCaseToCamelCase(path_name):
  98. """Converts a path name from snake_case to camelCase."""
  99. result = []
  100. after_underscore = False
  101. for c in path_name:
  102. if c.isupper():
  103. raise ValueError(
  104. 'Fail to print FieldMask to Json string: Path name '
  105. '{0} must not contain uppercase letters.'.format(path_name))
  106. if after_underscore:
  107. if c.islower():
  108. result.append(c.upper())
  109. after_underscore = False
  110. else:
  111. raise ValueError(
  112. 'Fail to print FieldMask to Json string: The '
  113. 'character after a "_" must be a lowercase letter '
  114. 'in path name {0}.'.format(path_name))
  115. elif c == '_':
  116. after_underscore = True
  117. else:
  118. result += c
  119. if after_underscore:
  120. raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
  121. 'in path name {0}.'.format(path_name))
  122. return ''.join(result)
  123. def _CamelCaseToSnakeCase(path_name):
  124. """Converts a field name from camelCase to snake_case."""
  125. result = []
  126. for c in path_name:
  127. if c == '_':
  128. raise ValueError('Fail to parse FieldMask: Path name '
  129. '{0} must not contain "_"s.'.format(path_name))
  130. if c.isupper():
  131. result += '_'
  132. result += c.lower()
  133. else:
  134. result += c
  135. return ''.join(result)
  136. class _FieldMaskTree(object):
  137. """Represents a FieldMask in a tree structure.
  138. For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
  139. the FieldMaskTree will be:
  140. [_root] -+- foo -+- bar
  141. | |
  142. | +- baz
  143. |
  144. +- bar --- baz
  145. In the tree, each leaf node represents a field path.
  146. """
  147. __slots__ = ('_root',)
  148. def __init__(self, field_mask=None):
  149. """Initializes the tree by FieldMask."""
  150. self._root = {}
  151. if field_mask:
  152. self.MergeFromFieldMask(field_mask)
  153. def MergeFromFieldMask(self, field_mask):
  154. """Merges a FieldMask to the tree."""
  155. for path in field_mask.paths:
  156. self.AddPath(path)
  157. def AddPath(self, path):
  158. """Adds a field path into the tree.
  159. If the field path to add is a sub-path of an existing field path
  160. in the tree (i.e., a leaf node), it means the tree already matches
  161. the given path so nothing will be added to the tree. If the path
  162. matches an existing non-leaf node in the tree, that non-leaf node
  163. will be turned into a leaf node with all its children removed because
  164. the path matches all the node's children. Otherwise, a new path will
  165. be added.
  166. Args:
  167. path: The field path to add.
  168. """
  169. node = self._root
  170. for name in path.split('.'):
  171. if name not in node:
  172. node[name] = {}
  173. elif not node[name]:
  174. # Pre-existing empty node implies we already have this entire tree.
  175. return
  176. node = node[name]
  177. # Remove any sub-trees we might have had.
  178. node.clear()
  179. def ToFieldMask(self, field_mask):
  180. """Converts the tree to a FieldMask."""
  181. field_mask.Clear()
  182. _AddFieldPaths(self._root, '', field_mask)
  183. def IntersectPath(self, path, intersection):
  184. """Calculates the intersection part of a field path with this tree.
  185. Args:
  186. path: The field path to calculates.
  187. intersection: The out tree to record the intersection part.
  188. """
  189. node = self._root
  190. for name in path.split('.'):
  191. if name not in node:
  192. return
  193. elif not node[name]:
  194. intersection.AddPath(path)
  195. return
  196. node = node[name]
  197. intersection.AddLeafNodes(path, node)
  198. def AddLeafNodes(self, prefix, node):
  199. """Adds leaf nodes begin with prefix to this tree."""
  200. if not node:
  201. self.AddPath(prefix)
  202. for name in node:
  203. child_path = prefix + '.' + name
  204. self.AddLeafNodes(child_path, node[name])
  205. def MergeMessage(
  206. self, source, destination,
  207. replace_message, replace_repeated):
  208. """Merge all fields specified by this tree from source to destination."""
  209. _MergeMessage(
  210. self._root, source, destination, replace_message, replace_repeated)
  211. def _StrConvert(value):
  212. """Converts value to str if it is not."""
  213. # This file is imported by c extension and some methods like ClearField
  214. # requires string for the field name. py2/py3 has different text
  215. # type and may use unicode.
  216. if not isinstance(value, str):
  217. return value.encode('utf-8')
  218. return value
  219. def _MergeMessage(
  220. node, source, destination, replace_message, replace_repeated):
  221. """Merge all fields specified by a sub-tree from source to destination."""
  222. source_descriptor = source.DESCRIPTOR
  223. for name in node:
  224. child = node[name]
  225. field = source_descriptor.fields_by_name[name]
  226. if field is None:
  227. raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
  228. name, source_descriptor.full_name))
  229. if child:
  230. # Sub-paths are only allowed for singular message fields.
  231. if (field.label == FieldDescriptor.LABEL_REPEATED or
  232. field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
  233. raise ValueError('Error: Field {0} in message {1} is not a singular '
  234. 'message field and cannot have sub-fields.'.format(
  235. name, source_descriptor.full_name))
  236. if source.HasField(name):
  237. _MergeMessage(
  238. child, getattr(source, name), getattr(destination, name),
  239. replace_message, replace_repeated)
  240. continue
  241. if field.label == FieldDescriptor.LABEL_REPEATED:
  242. if replace_repeated:
  243. destination.ClearField(_StrConvert(name))
  244. repeated_source = getattr(source, name)
  245. repeated_destination = getattr(destination, name)
  246. repeated_destination.MergeFrom(repeated_source)
  247. else:
  248. if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
  249. if replace_message:
  250. destination.ClearField(_StrConvert(name))
  251. if source.HasField(name):
  252. getattr(destination, name).MergeFrom(getattr(source, name))
  253. else:
  254. setattr(destination, name, getattr(source, name))
  255. def _AddFieldPaths(node, prefix, field_mask):
  256. """Adds the field paths descended from node to field_mask."""
  257. if not node and prefix:
  258. field_mask.paths.append(prefix)
  259. return
  260. for name in sorted(node):
  261. if prefix:
  262. child_path = prefix + '.' + name
  263. else:
  264. child_path = name
  265. _AddFieldPaths(node[name], child_path, field_mask)