tool_support.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # util/tool_support.py
  2. # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
  7. # mypy: allow-untyped-defs, allow-untyped-calls
  8. """support routines for the helpers in tools/.
  9. These aren't imported by the enclosing util package as they are not
  10. needed for normal library use.
  11. """
  12. from __future__ import annotations
  13. from argparse import ArgumentParser
  14. from argparse import Namespace
  15. import contextlib
  16. import difflib
  17. import os
  18. from pathlib import Path
  19. import shlex
  20. import shutil
  21. import subprocess
  22. import sys
  23. from typing import Any
  24. from typing import Dict
  25. from typing import Iterator
  26. from typing import Optional
  27. from typing import Union
  28. from . import compat
  29. class code_writer_cmd:
  30. parser: ArgumentParser
  31. args: Namespace
  32. suppress_output: bool
  33. diffs_detected: bool
  34. source_root: Path
  35. pyproject_toml_path: Path
  36. def __init__(self, tool_script: str):
  37. self.source_root = Path(tool_script).parent.parent
  38. self.pyproject_toml_path = self.source_root / Path("pyproject.toml")
  39. assert self.pyproject_toml_path.exists()
  40. self.parser = ArgumentParser()
  41. self.parser.add_argument(
  42. "--stdout",
  43. action="store_true",
  44. help="Write to stdout instead of saving to file",
  45. )
  46. self.parser.add_argument(
  47. "-c",
  48. "--check",
  49. help="Don't write the files back, just return the "
  50. "status. Return code 0 means nothing would change. "
  51. "Return code 1 means some files would be reformatted",
  52. action="store_true",
  53. )
  54. def run_zimports(self, tempfile: str) -> None:
  55. self._run_console_script(
  56. str(tempfile),
  57. {
  58. "entrypoint": "zimports",
  59. "options": f"--toml-config {self.pyproject_toml_path}",
  60. },
  61. )
  62. def run_black(self, tempfile: str) -> None:
  63. self._run_console_script(
  64. str(tempfile),
  65. {
  66. "entrypoint": "black",
  67. "options": f"--config {self.pyproject_toml_path}",
  68. },
  69. )
  70. def _run_console_script(self, path: str, options: Dict[str, Any]) -> None:
  71. """Run a Python console application from within the process.
  72. Used for black, zimports
  73. """
  74. is_posix = os.name == "posix"
  75. entrypoint_name = options["entrypoint"]
  76. for entry in compat.importlib_metadata_get("console_scripts"):
  77. if entry.name == entrypoint_name:
  78. impl = entry
  79. break
  80. else:
  81. raise Exception(
  82. f"Could not find entrypoint console_scripts.{entrypoint_name}"
  83. )
  84. cmdline_options_str = options.get("options", "")
  85. cmdline_options_list = shlex.split(
  86. cmdline_options_str, posix=is_posix
  87. ) + [path]
  88. kw: Dict[str, Any] = {}
  89. if self.suppress_output:
  90. kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
  91. subprocess.run(
  92. [
  93. sys.executable,
  94. "-c",
  95. "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
  96. ]
  97. + cmdline_options_list,
  98. cwd=str(self.source_root),
  99. **kw,
  100. )
  101. def write_status(self, *text: str) -> None:
  102. if not self.suppress_output:
  103. sys.stderr.write(" ".join(text))
  104. def write_output_file_from_text(
  105. self, text: str, destination_path: Union[str, Path]
  106. ) -> None:
  107. if self.args.check:
  108. self._run_diff(destination_path, source=text)
  109. elif self.args.stdout:
  110. print(text)
  111. else:
  112. self.write_status(f"Writing {destination_path}...")
  113. Path(destination_path).write_text(
  114. text, encoding="utf-8", newline="\n"
  115. )
  116. self.write_status("done\n")
  117. def write_output_file_from_tempfile(
  118. self, tempfile: str, destination_path: str
  119. ) -> None:
  120. if self.args.check:
  121. self._run_diff(destination_path, source_file=tempfile)
  122. os.unlink(tempfile)
  123. elif self.args.stdout:
  124. with open(tempfile) as tf:
  125. print(tf.read())
  126. os.unlink(tempfile)
  127. else:
  128. self.write_status(f"Writing {destination_path}...")
  129. shutil.move(tempfile, destination_path)
  130. self.write_status("done\n")
  131. def _run_diff(
  132. self,
  133. destination_path: Union[str, Path],
  134. *,
  135. source: Optional[str] = None,
  136. source_file: Optional[str] = None,
  137. ) -> None:
  138. if source_file:
  139. with open(source_file, encoding="utf-8") as tf:
  140. source_lines = list(tf)
  141. elif source is not None:
  142. source_lines = source.splitlines(keepends=True)
  143. else:
  144. assert False, "source or source_file is required"
  145. with open(destination_path, encoding="utf-8") as dp:
  146. d = difflib.unified_diff(
  147. list(dp),
  148. source_lines,
  149. fromfile=Path(destination_path).as_posix(),
  150. tofile="<proposed changes>",
  151. n=3,
  152. lineterm="\n",
  153. )
  154. d_as_list = list(d)
  155. if d_as_list:
  156. self.diffs_detected = True
  157. print("".join(d_as_list))
  158. @contextlib.contextmanager
  159. def add_arguments(self) -> Iterator[ArgumentParser]:
  160. yield self.parser
  161. @contextlib.contextmanager
  162. def run_program(self) -> Iterator[None]:
  163. self.args = self.parser.parse_args()
  164. if self.args.check:
  165. self.diffs_detected = False
  166. self.suppress_output = True
  167. elif self.args.stdout:
  168. self.suppress_output = True
  169. else:
  170. self.suppress_output = False
  171. yield
  172. if self.args.check and self.diffs_detected:
  173. sys.exit(1)
  174. else:
  175. sys.exit(0)