
Adding upstream version 4.5.0+dfsg.

Signed-off-by: Daniel Baumann <daniel@debian.org>
Daniel Baumann 2025-02-09 20:21:14 +01:00
parent 27cd5628db
commit 6bd375ed5f
Signed by: daniel
GPG key ID: FBB4F0E80A80222F
108 changed files with 6514 additions and 0 deletions

@@ -0,0 +1,81 @@
from __future__ import annotations
import argparse
import math
import os
import subprocess
from typing import Sequence
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import zsplit
def filter_lfs_files(filenames: set[str]) -> None: # pragma: no cover (lfs)
"""Remove files tracked by git-lfs from the set."""
if not filenames:
return
check_attr = subprocess.run(
('git', 'check-attr', 'filter', '-z', '--stdin'),
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
encoding='utf-8',
check=True,
input='\0'.join(filenames),
)
stdout = zsplit(check_attr.stdout)
for i in range(0, len(stdout), 3):
filename, filter_tag = stdout[i], stdout[i + 2]
if filter_tag == 'lfs':
filenames.remove(filename)
def find_large_added_files(
filenames: Sequence[str],
maxkb: int,
*,
enforce_all: bool = False,
) -> int:
# Find all added files that are also in the list of files pre-commit tells
# us about
retv = 0
filenames_filtered = set(filenames)
filter_lfs_files(filenames_filtered)
if not enforce_all:
filenames_filtered &= added_files()
for filename in filenames_filtered:
kb = math.ceil(os.stat(filename).st_size / 1024)
if kb > maxkb:
print(f'{filename} ({kb} KB) exceeds {maxkb} KB.')
retv = 1
return retv
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',
help='Filenames pre-commit believes are changed.',
)
parser.add_argument(
'--enforce-all', action='store_true',
help='Enforce all files are checked, not just staged files.',
)
parser.add_argument(
'--maxkb', type=int, default=500,
help='Maximum allowable KB for added files',
)
args = parser.parse_args(argv)
return find_large_added_files(
args.filenames,
args.maxkb,
enforce_all=args.enforce_all,
)
if __name__ == '__main__':
raise SystemExit(main())
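
A minimal usage sketch for the hook above (assuming a git checkout containing a hypothetical ~600 KB file big.bin; --enforce-all skips the staged-files filter):

    >>> find_large_added_files(['big.bin'], 500, enforce_all=True)
    big.bin (600 KB) exceeds 500 KB.
    1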

@@ -0,0 +1,33 @@
from __future__ import annotations
import argparse
import ast
import platform
import sys
import traceback
from typing import Sequence
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
retval = 0
for filename in args.filenames:
try:
with open(filename, 'rb') as f:
ast.parse(f.read(), filename=filename)
except SyntaxError:
impl = platform.python_implementation()
version = sys.version.split()[0]
print(f'{filename}: failed parsing with {impl} {version}:')
tb = ' ' + traceback.format_exc().replace('\n', '\n ')
print(f'\n{tb}')
retval = 1
return retval
if __name__ == '__main__':
raise SystemExit(main())
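
For illustration, a sketch of the failure mode (bad.py is a hypothetical file containing invalid syntax such as `def f(:`; the version line depends on the interpreter and the traceback is elided):

    >>> main(['bad.py'])
    bad.py: failed parsing with CPython 3.11.4:
    ...
    1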

@@ -0,0 +1,105 @@
from __future__ import annotations
import argparse
import ast
from typing import NamedTuple
from typing import Sequence
BUILTIN_TYPES = {
'complex': '0j',
'dict': '{}',
'float': '0.0',
'int': '0',
'list': '[]',
'str': "''",
'tuple': '()',
}
class Call(NamedTuple):
name: str
line: int
column: int
class Visitor(ast.NodeVisitor):
def __init__(
self,
ignore: Sequence[str] | None = None,
allow_dict_kwargs: bool = True,
) -> None:
self.builtin_type_calls: list[Call] = []
self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs
def _check_dict_call(self, node: ast.Call) -> bool:
return self.allow_dict_kwargs and bool(node.keywords)
def visit_Call(self, node: ast.Call) -> None:
if not isinstance(node.func, ast.Name):
# Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what
# they're doing.
return
if node.func.id not in set(BUILTIN_TYPES).difference(self.ignore):
return
if node.func.id == 'dict' and self._check_dict_call(node):
return
elif node.args:
return
self.builtin_type_calls.append(
Call(node.func.id, node.lineno, node.col_offset),
)
def check_file(
filename: str,
ignore: Sequence[str] | None = None,
allow_dict_kwargs: bool = True,
) -> list[Call]:
with open(filename, 'rb') as f:
tree = ast.parse(f.read(), filename=filename)
visitor = Visitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs)
visitor.visit(tree)
return visitor.builtin_type_calls
def parse_ignore(value: str) -> set[str]:
return set(value.split(','))
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--ignore', type=parse_ignore, default=set())
mutex = parser.add_mutually_exclusive_group(required=False)
mutex.add_argument('--allow-dict-kwargs', action='store_true')
mutex.add_argument(
'--no-allow-dict-kwargs',
dest='allow_dict_kwargs', action='store_false',
)
mutex.set_defaults(allow_dict_kwargs=True)
args = parser.parse_args(argv)
rc = 0
for filename in args.filenames:
calls = check_file(
filename,
ignore=args.ignore,
allow_dict_kwargs=args.allow_dict_kwargs,
)
if calls:
rc = rc or 1
for call in calls:
print(
f'{filename}:{call.line}:{call.column}: '
f'replace {call.name}() with {BUILTIN_TYPES[call.name]}',
)
return rc
if __name__ == '__main__':
raise SystemExit(main())
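
The visitor can be exercised directly on parsed source; a minimal sketch:

    >>> import ast
    >>> visitor = Visitor()
    >>> visitor.visit(ast.parse('x = dict()'))
    >>> visitor.builtin_type_calls
    [Call(name='dict', line=1, column=4)]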

@@ -0,0 +1,24 @@
from __future__ import annotations
import argparse
from typing import Sequence
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
with open(filename, 'rb') as f:
if f.read(3) == b'\xef\xbb\xbf':
retv = 1
print(f'{filename}: Has a byte-order marker')
return retv
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,72 @@
from __future__ import annotations
import argparse
from typing import Iterable
from typing import Iterator
from typing import Sequence
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import cmd_output
def lower_set(iterable: Iterable[str]) -> set[str]:
return {x.lower() for x in iterable}
def parents(file: str) -> Iterator[str]:
path_parts = file.split('/')
path_parts.pop()
while path_parts:
yield '/'.join(path_parts)
path_parts.pop()
def directories_for(files: set[str]) -> set[str]:
return {parent for file in files for parent in parents(file)}
def find_conflicting_filenames(filenames: Sequence[str]) -> int:
repo_files = set(cmd_output('git', 'ls-files').splitlines())
repo_files |= directories_for(repo_files)
relevant_files = set(filenames) | added_files()
relevant_files |= directories_for(relevant_files)
repo_files -= relevant_files
retv = 0
# new file conflicts with existing file
conflicts = lower_set(repo_files) & lower_set(relevant_files)
# new file conflicts with other new file
lowercase_relevant_files = lower_set(relevant_files)
for filename in set(relevant_files):
if filename.lower() in lowercase_relevant_files:
lowercase_relevant_files.remove(filename.lower())
else:
conflicts.add(filename.lower())
if conflicts:
conflicting_files = [
x for x in repo_files | relevant_files
if x.lower() in conflicts
]
for filename in sorted(conflicting_files):
print(f'Case-insensitivity conflict found: {filename}')
retv = 1
return retv
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',
help='Filenames pre-commit believes are changed.',
)
args = parser.parse_args(argv)
return find_conflicting_filenames(args.filenames)
if __name__ == '__main__':
raise SystemExit(main())
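
For intuition, the directory-expansion helpers behave as follows (a quick sketch):

    >>> list(parents('a/b/c.txt'))
    ['a/b', 'a']
    >>> sorted(directories_for({'src/Main.py', 'src/util/io.py'}))
    ['src', 'src/util']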

@@ -0,0 +1,61 @@
from __future__ import annotations
import argparse
import io
import tokenize
from tokenize import tokenize as tokenize_tokenize
from typing import Sequence
NON_CODE_TOKENS = frozenset((
tokenize.COMMENT, tokenize.ENDMARKER, tokenize.NEWLINE, tokenize.NL,
tokenize.ENCODING,
))
def check_docstring_first(src: bytes, filename: str = '<unknown>') -> int:
"""Returns nonzero if the source has what looks like a docstring that is
not at the beginning of the source.
A string will be considered a docstring if it is a STRING token with a
col offset of 0.
"""
found_docstring_line = None
found_code_line = None
tok_gen = tokenize_tokenize(io.BytesIO(src).readline)
for tok_type, _, (sline, scol), _, _ in tok_gen:
# Looks like a docstring!
if tok_type == tokenize.STRING and scol == 0:
if found_docstring_line is not None:
print(
f'{filename}:{sline}: Multiple module docstrings '
f'(first docstring on line {found_docstring_line}).',
)
return 1
elif found_code_line is not None:
print(
f'{filename}:{sline}: Module docstring appears after code '
f'(code seen on line {found_code_line}).',
)
return 1
else:
found_docstring_line = sline
elif tok_type not in NON_CODE_TOKENS and found_code_line is None:
found_code_line = sline
return 0
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
with open(filename, 'rb') as f:
contents = f.read()
retv |= check_docstring_first(contents, filename=filename)
    return retv
if __name__ == '__main__':
    raise SystemExit(main())
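
A pure-function sketch of the check (the default filename placeholder is <unknown>):

    >>> check_docstring_first(b'x = 1\n"""not first"""\n')
    <unknown>:2: Module docstring appears after code (code seen on line 1).
    1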

@@ -0,0 +1,85 @@
"""Check that executable text files have a shebang."""
from __future__ import annotations
import argparse
import shlex
import sys
from typing import Generator
from typing import NamedTuple
from typing import Sequence
from pre_commit_hooks.util import cmd_output
from pre_commit_hooks.util import zsplit
EXECUTABLE_VALUES = frozenset(('1', '3', '5', '7'))
def check_executables(paths: list[str]) -> int:
fs_tracks_executable_bit = cmd_output(
'git', 'config', 'core.fileMode', retcode=None,
).strip()
if fs_tracks_executable_bit == 'false': # pragma: win32 cover
return _check_git_filemode(paths)
else: # pragma: win32 no cover
retv = 0
for path in paths:
if not has_shebang(path):
_message(path)
retv = 1
return retv
class GitLsFile(NamedTuple):
mode: str
filename: str
def git_ls_files(paths: Sequence[str]) -> Generator[GitLsFile, None, None]:
outs = cmd_output('git', 'ls-files', '-z', '--stage', '--', *paths)
for out in zsplit(outs):
metadata, filename = out.split('\t')
mode, _, _ = metadata.split()
yield GitLsFile(mode, filename)
def _check_git_filemode(paths: Sequence[str]) -> int:
seen: set[str] = set()
for ls_file in git_ls_files(paths):
is_executable = any(b in EXECUTABLE_VALUES for b in ls_file.mode[-3:])
if is_executable and not has_shebang(ls_file.filename):
_message(ls_file.filename)
seen.add(ls_file.filename)
return int(bool(seen))
def has_shebang(path: str) -> int:
with open(path, 'rb') as f:
first_bytes = f.read(2)
return first_bytes == b'#!'
def _message(path: str) -> None:
print(
f'{path}: marked executable but has no (or invalid) shebang!\n'
f" If it isn't supposed to be executable, try: "
f'`chmod -x {shlex.quote(path)}`\n'
f' If on Windows, you may also need to: '
f'`git add --chmod=-x {shlex.quote(path)}`\n'
f' If it is supposed to be executable, double-check its shebang.',
file=sys.stderr,
)
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
return check_executables(args.filenames)
if __name__ == '__main__':
raise SystemExit(main())
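
The executable test above inspects the last three octal digits of the git file mode; any digit with the x-bit set (1, 3, 5, 7) counts, e.g. (a sketch):

    >>> any(b in EXECUTABLE_VALUES for b in '100755'[-3:])
    True
    >>> any(b in EXECUTABLE_VALUES for b in '100644'[-3:])
    False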

@@ -0,0 +1,38 @@
from __future__ import annotations
import argparse
import json
from typing import Any
from typing import Sequence
def raise_duplicate_keys(
ordered_pairs: list[tuple[str, Any]],
) -> dict[str, Any]:
d = {}
for key, val in ordered_pairs:
if key in d:
raise ValueError(f'Duplicate key: {key}')
else:
d[key] = val
return d
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check.')
args = parser.parse_args(argv)
retval = 0
for filename in args.filenames:
with open(filename, 'rb') as f:
try:
json.load(f, object_pairs_hook=raise_duplicate_keys)
except ValueError as exc:
print(f'{filename}: Failed to json decode ({exc})')
retval = 1
return retval
if __name__ == '__main__':
raise SystemExit(main())
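
A quick sketch of the duplicate-key hook driven through json.loads:

    >>> import json
    >>> json.loads('{"a": 1, "a": 2}', object_pairs_hook=raise_duplicate_keys)
    Traceback (most recent call last):
        ...
    ValueError: Duplicate key: a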

@@ -0,0 +1,56 @@
from __future__ import annotations
import argparse
import os.path
from typing import Sequence
from pre_commit_hooks.util import cmd_output
CONFLICT_PATTERNS = [
b'<<<<<<< ',
b'======= ',
b'=======\r\n',
b'=======\n',
b'>>>>>>> ',
]
def is_in_merge() -> bool:
git_dir = cmd_output('git', 'rev-parse', '--git-dir').rstrip()
return (
os.path.exists(os.path.join(git_dir, 'MERGE_MSG')) and
(
os.path.exists(os.path.join(git_dir, 'MERGE_HEAD')) or
os.path.exists(os.path.join(git_dir, 'rebase-apply')) or
os.path.exists(os.path.join(git_dir, 'rebase-merge'))
)
)
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--assume-in-merge', action='store_true')
args = parser.parse_args(argv)
if not is_in_merge() and not args.assume_in_merge:
return 0
retcode = 0
for filename in args.filenames:
with open(filename, 'rb') as inputfile:
for i, line in enumerate(inputfile, start=1):
for pattern in CONFLICT_PATTERNS:
if line.startswith(pattern):
print(
f'{filename}:{i}: Merge conflict string '
f'{pattern.strip().decode()!r} found',
)
retcode = 1
return retcode
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,54 @@
"""Check that text files with a shebang are executable."""
from __future__ import annotations
import argparse
import shlex
import sys
from typing import Sequence
from pre_commit_hooks.check_executables_have_shebangs import EXECUTABLE_VALUES
from pre_commit_hooks.check_executables_have_shebangs import git_ls_files
from pre_commit_hooks.check_executables_have_shebangs import has_shebang
def check_shebangs(paths: list[str]) -> int:
# Cannot optimize on non-executability here if we intend this check to
# work on win32 -- and that's where problems caused by non-executability
# (elsewhere) are most likely to arise from.
return _check_git_filemode(paths)
def _check_git_filemode(paths: Sequence[str]) -> int:
seen: set[str] = set()
for ls_file in git_ls_files(paths):
is_executable = any(b in EXECUTABLE_VALUES for b in ls_file.mode[-3:])
if not is_executable and has_shebang(ls_file.filename):
_message(ls_file.filename)
seen.add(ls_file.filename)
return int(bool(seen))
def _message(path: str) -> None:
print(
f'{path}: has a shebang but is not marked executable!\n'
f' If it is supposed to be executable, try: '
f'`chmod +x {shlex.quote(path)}`\n'
f' If on Windows, you may also need to: '
f'`git add --chmod=+x {shlex.quote(path)}`\n'
        f' If it is not supposed to be executable, double-check its shebang '
f'is wanted.\n',
file=sys.stderr,
)
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
return check_shebangs(args.filenames)
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,27 @@
from __future__ import annotations
import argparse
import os.path
from typing import Sequence
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(description='Checks for broken symlinks.')
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
if (
os.path.islink(filename) and
not os.path.exists(filename)
): # pragma: no cover (symlink support required)
print(f'{filename}: Broken symlink')
retv = 1
return retv
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,30 @@
from __future__ import annotations
import argparse
import sys
from typing import Sequence
if sys.version_info >= (3, 11): # pragma: >=3.11 cover
import tomllib
else: # pragma: <3.11 cover
import tomli as tomllib
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check.')
args = parser.parse_args(argv)
retval = 0
for filename in args.filenames:
try:
with open(filename, mode='rb') as fp:
tomllib.load(fp)
except tomllib.TOMLDecodeError as exc:
print(f'{filename}: {exc}')
retval = 1
return retval
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,60 @@
from __future__ import annotations
import argparse
import re
import sys
from typing import Pattern
from typing import Sequence
def _get_pattern(domain: str) -> Pattern[bytes]:
regex = (
rf'https://{domain}/[^/ ]+/[^/ ]+/blob/'
r'(?![a-fA-F0-9]{4,64}/)([^/. ]+)/[^# ]+#L\d+'
)
return re.compile(regex.encode())
def _check_filename(filename: str, patterns: list[Pattern[bytes]]) -> int:
retv = 0
with open(filename, 'rb') as f:
for i, line in enumerate(f, 1):
for pattern in patterns:
if pattern.search(line):
sys.stdout.write(f'{filename}:{i}:')
sys.stdout.flush()
sys.stdout.buffer.write(line)
retv = 1
return retv
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument(
'--additional-github-domain',
dest='additional_github_domains',
action='append',
default=['github.com'],
)
args = parser.parse_args(argv)
patterns = [
_get_pattern(domain)
for domain in args.additional_github_domains
]
retv = 0
for filename in args.filenames:
retv |= _check_filename(filename, patterns)
if retv:
print()
print('Non-permanent github link detected.')
print('On any page on github press [y] to load a permalink.')
return retv
if __name__ == '__main__':
raise SystemExit(main())
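
For intuition, the pattern flags branch links but not hash permalinks (a sketch; the URLs are made up):

    >>> pattern = _get_pattern('github.com')
    >>> bool(pattern.search(b'https://github.com/o/r/blob/main/f.py#L1'))
    True
    >>> bool(pattern.search(b'https://github.com/o/r/blob/d6e4f9bd/f.py#L1'))
    False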

@@ -0,0 +1,26 @@
from __future__ import annotations
import argparse
import xml.sax.handler
from typing import Sequence
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='XML filenames to check.')
args = parser.parse_args(argv)
retval = 0
handler = xml.sax.handler.ContentHandler()
for filename in args.filenames:
try:
with open(filename, 'rb') as xml_file:
xml.sax.parse(xml_file, handler)
except xml.sax.SAXException as exc:
print(f'{filename}: Failed to xml parse ({exc})')
retval = 1
return retval
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,72 @@
from __future__ import annotations
import argparse
from typing import Any
from typing import Generator
from typing import NamedTuple
from typing import Sequence
import ruamel.yaml
yaml = ruamel.yaml.YAML(typ='safe')
def _exhaust(gen: Generator[str, None, None]) -> None:
for _ in gen:
pass
def _parse_unsafe(*args: Any, **kwargs: Any) -> None:
_exhaust(yaml.parse(*args, **kwargs))
def _load_all(*args: Any, **kwargs: Any) -> None:
_exhaust(yaml.load_all(*args, **kwargs))
class Key(NamedTuple):
multi: bool
unsafe: bool
LOAD_FNS = {
Key(multi=False, unsafe=False): yaml.load,
Key(multi=False, unsafe=True): _parse_unsafe,
Key(multi=True, unsafe=False): _load_all,
Key(multi=True, unsafe=True): _parse_unsafe,
}
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
'-m', '--multi', '--allow-multiple-documents', action='store_true',
)
parser.add_argument(
'--unsafe', action='store_true',
help=(
'Instead of loading the files, simply parse them for syntax. '
'A syntax-only check enables extensions and unsafe constructs '
'which would otherwise be forbidden. Using this option removes '
'all guarantees of portability to other yaml implementations. '
'Implies --allow-multiple-documents'
),
)
parser.add_argument('filenames', nargs='*', help='Filenames to check.')
args = parser.parse_args(argv)
load_fn = LOAD_FNS[Key(multi=args.multi, unsafe=args.unsafe)]
retval = 0
for filename in args.filenames:
try:
with open(filename, encoding='UTF-8') as f:
load_fn(f)
except ruamel.yaml.YAMLError as exc:
print(exc)
retval = 1
return retval
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,86 @@
from __future__ import annotations
import argparse
import ast
import traceback
from typing import NamedTuple
from typing import Sequence
DEBUG_STATEMENTS = {
'bpdb',
'ipdb',
'pdb',
'pdbr',
'pudb',
'pydevd_pycharm',
'q',
'rdb',
'rpdb',
'wdb',
}
class Debug(NamedTuple):
line: int
col: int
name: str
reason: str
class DebugStatementParser(ast.NodeVisitor):
def __init__(self) -> None:
self.breakpoints: list[Debug] = []
def visit_Import(self, node: ast.Import) -> None:
for name in node.names:
if name.name in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, name.name, 'imported')
self.breakpoints.append(st)
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, node.module, 'imported')
self.breakpoints.append(st)
def visit_Call(self, node: ast.Call) -> None:
"""python3.7+ breakpoint()"""
if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint':
st = Debug(node.lineno, node.col_offset, node.func.id, 'called')
self.breakpoints.append(st)
self.generic_visit(node)
def check_file(filename: str) -> int:
try:
with open(filename, 'rb') as f:
ast_obj = ast.parse(f.read(), filename=filename)
except SyntaxError:
print(f'{filename} - Could not parse ast')
print()
print('\t' + traceback.format_exc().replace('\n', '\n\t'))
print()
return 1
visitor = DebugStatementParser()
visitor.visit(ast_obj)
for bp in visitor.breakpoints:
print(f'{filename}:{bp.line}:{bp.col}: {bp.name} {bp.reason}')
return int(bool(visitor.breakpoints))
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to run')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
retv |= check_file(filename)
return retv
if __name__ == '__main__':
raise SystemExit(main())
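
A minimal sketch of the visitor on offending source:

    >>> import ast
    >>> visitor = DebugStatementParser()
    >>> visitor.visit(ast.parse('import pdb'))
    >>> visitor.breakpoints
    [Debug(line=1, col=0, name='pdb', reason='imported')]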

@@ -0,0 +1,92 @@
from __future__ import annotations
import argparse
import shlex
import subprocess
from typing import Sequence
from pre_commit_hooks.util import cmd_output
from pre_commit_hooks.util import zsplit
ORDINARY_CHANGED_ENTRIES_MARKER = '1'
PERMS_LINK = '120000'
PERMS_NONEXIST = '000000'
def find_destroyed_symlinks(files: Sequence[str]) -> list[str]:
destroyed_links: list[str] = []
if not files:
return destroyed_links
for line in zsplit(
cmd_output('git', 'status', '--porcelain=v2', '-z', '--', *files),
):
splitted = line.split(' ')
if splitted and splitted[0] == ORDINARY_CHANGED_ENTRIES_MARKER:
# https://git-scm.com/docs/git-status#_changed_tracked_entries
(
_, _, _,
mode_HEAD,
mode_index,
_,
hash_HEAD,
hash_index,
*path_splitted,
) = splitted
path = ' '.join(path_splitted)
if (
mode_HEAD == PERMS_LINK and
mode_index != PERMS_LINK and
mode_index != PERMS_NONEXIST
):
if hash_HEAD == hash_index:
                        # if the old and new hashes are equal, there's nothing
                        # more to check: we've found a destroyed symlink
destroyed_links.append(path)
else:
# if old and new hashes are *not* equal, it doesn't mean
# that everything is OK - new file may be altered
# by something like trailing-whitespace and/or
# mixed-line-ending hooks so we need to go deeper
SIZE_CMD = ('git', 'cat-file', '-s')
size_index = int(cmd_output(*SIZE_CMD, hash_index).strip())
size_HEAD = int(cmd_output(*SIZE_CMD, hash_HEAD).strip())
# in the worst case new file may have CRLF added
# so check content only if new file is bigger
# not more than 2 bytes compared to the old one
if size_index <= size_HEAD + 2:
head_content = subprocess.check_output(
('git', 'cat-file', '-p', hash_HEAD),
).rstrip()
index_content = subprocess.check_output(
('git', 'cat-file', '-p', hash_index),
).rstrip()
if head_content == index_content:
destroyed_links.append(path)
return destroyed_links
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check.')
args = parser.parse_args(argv)
destroyed_links = find_destroyed_symlinks(files=args.filenames)
if destroyed_links:
print('Destroyed symlinks:')
for destroyed_link in destroyed_links:
print(f'- {destroyed_link}')
print('You should unstage affected files:')
print(f'\tgit reset HEAD -- {shlex.join(destroyed_links)}')
print(
'And retry commit. As a long term solution '
'you may try to explicitly tell git that your '
'environment does not support symlinks:',
)
print('\tgit config core.symlinks false')
return 1
else:
return 0
if __name__ == '__main__':
raise SystemExit(main())
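
To make the unpacking above concrete, a hypothetical --porcelain=v2 entry for a symlink turned into a regular file looks like this (fields: marker, XY, sub, mode_HEAD, mode_index, mode_worktree, hash_HEAD, hash_index, path):

    1 .M N... 120000 100644 100644 <hash_HEAD> <hash_index> some-link

Here mode_HEAD is PERMS_LINK (120000) while mode_index is a regular-file mode, which is what triggers the size and content comparison above.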

@@ -0,0 +1,151 @@
from __future__ import annotations
import argparse
import configparser
import os
from typing import NamedTuple
from typing import Sequence
class BadFile(NamedTuple):
filename: str
key: str
def get_aws_cred_files_from_env() -> set[str]:
"""Extract credential file paths from environment variables."""
return {
os.environ[env_var]
for env_var in (
'AWS_CONFIG_FILE', 'AWS_CREDENTIAL_FILE',
'AWS_SHARED_CREDENTIALS_FILE', 'BOTO_CONFIG',
)
if env_var in os.environ
}
def get_aws_secrets_from_env() -> set[str]:
"""Extract AWS secrets from environment variables."""
keys = set()
for env_var in (
'AWS_SECRET_ACCESS_KEY', 'AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN',
):
if os.environ.get(env_var):
keys.add(os.environ[env_var])
return keys
def get_aws_secrets_from_file(credentials_file: str) -> set[str]:
"""Extract AWS secrets from configuration files.
Read an ini-style configuration file and return a set with all found AWS
secret access keys.
"""
aws_credentials_file_path = os.path.expanduser(credentials_file)
if not os.path.exists(aws_credentials_file_path):
return set()
parser = configparser.ConfigParser()
try:
parser.read(aws_credentials_file_path)
except configparser.MissingSectionHeaderError:
return set()
keys = set()
for section in parser.sections():
for var in (
'aws_secret_access_key', 'aws_security_token',
'aws_session_token',
):
try:
key = parser.get(section, var).strip()
if key:
keys.add(key)
except configparser.NoOptionError:
pass
return keys
def check_file_for_aws_keys(
filenames: Sequence[str],
keys: set[bytes],
) -> list[BadFile]:
"""Check if files contain AWS secrets.
Return a list of all files containing AWS secrets and keys found, with all
but the first four characters obfuscated to ease debugging.
"""
bad_files = []
for filename in filenames:
with open(filename, 'rb') as content:
text_body = content.read()
for key in keys:
# naively match the entire file, low chance of incorrect
# collision
if key in text_body:
key_hidden = key.decode()[:4].ljust(28, '*')
bad_files.append(BadFile(filename, key_hidden))
return bad_files
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Filenames to run')
parser.add_argument(
'--credentials-file',
dest='credentials_file',
action='append',
default=[
'~/.aws/config', '~/.aws/credentials', '/etc/boto.cfg', '~/.boto',
],
help=(
'Location of additional AWS credential file from which to get '
'secret keys. Can be passed multiple times.'
),
)
parser.add_argument(
'--allow-missing-credentials',
dest='allow_missing_credentials',
action='store_true',
help='Allow hook to pass when no credentials are detected.',
)
args = parser.parse_args(argv)
credential_files = set(args.credentials_file)
# Add the credentials files configured via environment variables to the set
    # of files to gather AWS secrets from.
credential_files |= get_aws_cred_files_from_env()
keys: set[str] = set()
for credential_file in credential_files:
keys |= get_aws_secrets_from_file(credential_file)
# Secrets might be part of environment variables, so add such secrets to
# the set of keys.
keys |= get_aws_secrets_from_env()
if not keys and args.allow_missing_credentials:
return 0
if not keys:
print(
'No AWS keys were found in the configured credential files and '
'environment variables.\nPlease ensure you have the correct '
'setting for --credentials-file',
)
return 2
keys_b = {key.encode() for key in keys}
bad_filenames = check_file_for_aws_keys(args.filenames, keys_b)
if bad_filenames:
for bad_file in bad_filenames:
print(f'AWS secret found in {bad_file.filename}: {bad_file.key}')
return 1
else:
return 0
if __name__ == '__main__':
raise SystemExit(main())
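
Reported keys keep only their first four characters; using AWS's documented example access key for illustration:

    >>> 'AKIAIOSFODNN7EXAMPLE'[:4].ljust(28, '*')
    'AKIA************************'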

@@ -0,0 +1,42 @@
from __future__ import annotations
import argparse
from typing import Sequence
BLACKLIST = [
b'BEGIN RSA PRIVATE KEY',
b'BEGIN DSA PRIVATE KEY',
b'BEGIN EC PRIVATE KEY',
b'BEGIN OPENSSH PRIVATE KEY',
b'BEGIN PRIVATE KEY',
b'PuTTY-User-Key-File-2',
b'BEGIN SSH2 ENCRYPTED PRIVATE KEY',
b'BEGIN PGP PRIVATE KEY BLOCK',
b'BEGIN ENCRYPTED PRIVATE KEY',
b'BEGIN OpenVPN Static key V1',
]
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
private_key_files = []
for filename in args.filenames:
with open(filename, 'rb') as f:
content = f.read()
if any(line in content for line in BLACKLIST):
private_key_files.append(filename)
if private_key_files:
for private_key_file in private_key_files:
print(f'Private key found: {private_key_file}')
return 1
else:
return 0
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,71 @@
from __future__ import annotations
import argparse
import os
from typing import IO
from typing import Sequence
def fix_file(file_obj: IO[bytes]) -> int:
# Test for newline at end of file
# Empty files will throw IOError here
try:
file_obj.seek(-1, os.SEEK_END)
except OSError:
return 0
last_character = file_obj.read(1)
# last_character will be '' for an empty file
if last_character not in {b'\n', b'\r'} and last_character != b'':
# Needs this seek for windows, otherwise IOError
file_obj.seek(0, os.SEEK_END)
file_obj.write(b'\n')
return 1
while last_character in {b'\n', b'\r'}:
# Deal with the beginning of the file
if file_obj.tell() == 1:
# If we've reached the beginning of the file and it is all
# linebreaks then we can make this file empty
file_obj.seek(0)
file_obj.truncate()
return 1
# Go back two bytes and read a character
file_obj.seek(-2, os.SEEK_CUR)
last_character = file_obj.read(1)
# Our current position is at the end of the file just before any amount of
# newlines. If we find extraneous newlines, then backtrack and trim them.
position = file_obj.tell()
remaining = file_obj.read()
for sequence in (b'\n', b'\r\n', b'\r'):
if remaining == sequence:
return 0
elif remaining.startswith(sequence):
file_obj.seek(position + len(sequence))
file_obj.truncate()
return 1
return 0
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
# Read as binary so we can read byte-by-byte
with open(filename, 'rb+') as file_obj:
ret_for_file = fix_file(file_obj)
if ret_for_file:
print(f'Fixing {filename}')
retv |= ret_for_file
return retv
if __name__ == '__main__':
raise SystemExit(main())
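
Since fix_file only needs a seekable binary stream, io.BytesIO makes a handy sketch:

    >>> import io
    >>> buf = io.BytesIO(b'hello')
    >>> fix_file(buf)
    1
    >>> buf.getvalue()
    b'hello\n'
    >>> buf = io.BytesIO(b'hi\n\n\n')
    >>> fix_file(buf)
    1
    >>> buf.getvalue()
    b'hi\n'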

@@ -0,0 +1,88 @@
"""
A very simple pre-commit hook that, when passed one or more filenames
as arguments, will sort the lines in those files.
An example use case for this: you have a deploy-allowlist.txt file
in a repo that contains a list of filenames that is used to specify
files to be included in a docker container. This file has one filename
per line. Various users are adding/removing lines from this file; using
this hook on that file should reduce the instances of git merge
conflicts and keep the file nicely ordered.
"""
from __future__ import annotations
import argparse
from typing import Any
from typing import Callable
from typing import IO
from typing import Iterable
from typing import Sequence
PASS = 0
FAIL = 1
def sort_file_contents(
f: IO[bytes],
key: Callable[[bytes], Any] | None,
*,
unique: bool = False,
) -> int:
before = list(f)
lines: Iterable[bytes] = (
line.rstrip(b'\n\r') for line in before if line.strip()
)
if unique:
lines = set(lines)
after = sorted(lines, key=key)
before_string = b''.join(before)
after_string = b'\n'.join(after)
if after_string:
after_string += b'\n'
if before_string == after_string:
return PASS
else:
f.seek(0)
f.write(after_string)
f.truncate()
return FAIL
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Files to sort')
parser.add_argument(
'--ignore-case',
action='store_const',
const=bytes.lower,
default=None,
help='fold lower case to upper case characters',
)
parser.add_argument(
'--unique',
action='store_true',
help='ensure each line is unique',
)
args = parser.parse_args(argv)
retv = PASS
for arg in args.filenames:
with open(arg, 'rb+') as file_obj:
ret_for_file = sort_file_contents(
file_obj, key=args.ignore_case, unique=args.unique,
)
if ret_for_file:
print(f'Sorting {arg}')
retv |= ret_for_file
return retv
if __name__ == '__main__':
raise SystemExit(main())
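
A quick sketch with an in-memory stream (key=None sorts bytes lexicographically; unique=True drops the duplicate):

    >>> import io
    >>> buf = io.BytesIO(b'pear\napple\npear\n')
    >>> sort_file_contents(buf, key=None, unique=True)
    1
    >>> buf.getvalue()
    b'apple\npear\n'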

@@ -0,0 +1,31 @@
from __future__ import annotations
import argparse
from typing import Sequence
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
with open(filename, 'rb') as f_b:
bts = f_b.read(3)
if bts == b'\xef\xbb\xbf':
with open(filename, newline='', encoding='utf-8-sig') as f:
contents = f.read()
with open(filename, 'w', newline='', encoding='utf-8') as f:
f.write(contents)
print(f'{filename}: removed byte-order marker')
retv = 1
return retv
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,149 @@
from __future__ import annotations
import argparse
from typing import IO
from typing import NamedTuple
from typing import Sequence
DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-'
def has_coding(line: bytes) -> bool:
if not line.strip():
return False
return (
line.lstrip()[:1] == b'#' and (
b'unicode' in line or
b'encoding' in line or
b'coding:' in line or
b'coding=' in line
)
)
class ExpectedContents(NamedTuple):
shebang: bytes
rest: bytes
# True: has exactly the coding pragma expected
# False: missing coding pragma entirely
# None: has a coding pragma, but it does not match
pragma_status: bool | None
ending: bytes
@property
def has_any_pragma(self) -> bool:
return self.pragma_status is not False
def is_expected_pragma(self, remove: bool) -> bool:
expected_pragma_status = not remove
return self.pragma_status is expected_pragma_status
def _get_expected_contents(
first_line: bytes,
second_line: bytes,
rest: bytes,
expected_pragma: bytes,
) -> ExpectedContents:
ending = b'\r\n' if first_line.endswith(b'\r\n') else b'\n'
if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
else:
shebang = b''
potential_coding = first_line
rest = second_line + rest
if potential_coding.rstrip(b'\r\n') == expected_pragma:
pragma_status: bool | None = True
elif has_coding(potential_coding):
pragma_status = None
else:
pragma_status = False
rest = potential_coding + rest
return ExpectedContents(
shebang=shebang, rest=rest, pragma_status=pragma_status, ending=ending,
)
def fix_encoding_pragma(
f: IO[bytes],
remove: bool = False,
expected_pragma: bytes = DEFAULT_PRAGMA,
) -> int:
expected = _get_expected_contents(
f.readline(), f.readline(), f.read(), expected_pragma,
)
# Special cases for empty files
if not expected.rest.strip():
# If a file only has a shebang or a coding pragma, remove it
if expected.has_any_pragma or expected.shebang:
f.seek(0)
f.truncate()
f.write(b'')
return 1
else:
return 0
if expected.is_expected_pragma(remove):
return 0
# Otherwise, write out the new file
f.seek(0)
f.truncate()
f.write(expected.shebang)
if not remove:
f.write(expected_pragma + expected.ending)
f.write(expected.rest)
return 1
def _normalize_pragma(pragma: str) -> bytes:
return pragma.encode().rstrip()
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser(
'Fixes the encoding pragma of python files',
)
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
parser.add_argument(
'--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
help=(
f'The encoding pragma to use. '
f'Default: {DEFAULT_PRAGMA.decode()}'
),
)
parser.add_argument(
'--remove', action='store_true',
help='Remove the encoding pragma (Useful in a python3-only codebase)',
)
args = parser.parse_args(argv)
retv = 0
if args.remove:
fmt = 'Removed encoding pragma from {filename}'
else:
fmt = 'Added `{pragma}` to {filename}'
for filename in args.filenames:
with open(filename, 'r+b') as f:
file_ret = fix_encoding_pragma(
f, remove=args.remove, expected_pragma=args.pragma,
)
retv |= file_ret
if file_ret:
print(
fmt.format(pragma=args.pragma.decode(), filename=filename),
)
return retv
if __name__ == '__main__':
raise SystemExit(main())
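
A minimal sketch with an in-memory stream:

    >>> import io
    >>> buf = io.BytesIO(b'import os\n')
    >>> fix_encoding_pragma(buf)
    1
    >>> buf.getvalue()
    b'# -*- coding: utf-8 -*-\nimport os\n'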

@@ -0,0 +1,48 @@
from __future__ import annotations
import argparse
import os
from typing import Sequence
from pre_commit_hooks.util import cmd_output
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
if (
'PRE_COMMIT_FROM_REF' in os.environ and
'PRE_COMMIT_TO_REF' in os.environ
):
diff_arg = '...'.join((
os.environ['PRE_COMMIT_FROM_REF'],
os.environ['PRE_COMMIT_TO_REF'],
))
else:
diff_arg = '--staged'
added_diff = cmd_output(
'git', 'diff', '--diff-filter=A', '--raw', diff_arg, '--',
*args.filenames,
)
retv = 0
for line in added_diff.splitlines():
metadata, filename = line.split('\t', 1)
new_mode = metadata.split(' ')[1]
if new_mode == '160000':
print(f'{filename}: new submodule introduced')
retv = 1
if retv:
print()
print('This commit introduces new submodules.')
print('Did you unintentionally `git add .`?')
print('To fix: git rm {thesubmodule} # no trailing slash')
print('Also check .gitmodules')
return retv
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,88 @@
from __future__ import annotations
import argparse
import collections
from typing import Sequence
CRLF = b'\r\n'
LF = b'\n'
CR = b'\r'
# Prefer LF to CRLF to CR, but detect CRLF before LF
ALL_ENDINGS = (CR, CRLF, LF)
FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF}
def _fix(filename: str, contents: bytes, ending: bytes) -> None:
new_contents = b''.join(
line.rstrip(b'\r\n') + ending for line in contents.splitlines(True)
)
with open(filename, 'wb') as f:
f.write(new_contents)
def fix_filename(filename: str, fix: str) -> int:
with open(filename, 'rb') as f:
contents = f.read()
counts: dict[bytes, int] = collections.defaultdict(int)
for line in contents.splitlines(True):
for ending in ALL_ENDINGS:
if line.endswith(ending):
counts[ending] += 1
break
# Some amount of mixed line endings
mixed = sum(bool(x) for x in counts.values()) > 1
if fix == 'no' or (fix == 'auto' and not mixed):
return mixed
if fix == 'auto':
max_ending = LF
max_lines = 0
# ordering is important here such that lf > crlf > cr
for ending_type in ALL_ENDINGS:
# also important, using >= to find a max that prefers the last
if counts[ending_type] >= max_lines:
max_ending = ending_type
max_lines = counts[ending_type]
_fix(filename, contents, max_ending)
return 1
else:
target_ending = FIX_TO_LINE_ENDING[fix]
# find if there are lines with *other* endings
        # It's possible there are no line endings of the target type
counts.pop(target_ending, None)
other_endings = bool(sum(counts.values()))
if other_endings:
_fix(filename, contents, target_ending)
return other_endings
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
'-f', '--fix',
choices=('auto', 'no') + tuple(FIX_TO_LINE_ENDING),
default='auto',
help='Replace line ending with the specified. Default is "auto"',
)
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
if fix_filename(filename, args.fix):
if args.fix == 'no':
print(f'{filename}: mixed line endings')
else:
print(f'{filename}: fixed mixed line endings')
retv = 1
return retv
if __name__ == '__main__':
raise SystemExit(main())
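
A sketch of auto mode on a hypothetical temp file (LF wins the tie against CRLF because of the >= scan order noted above):

    >>> import pathlib, tempfile
    >>> p = pathlib.Path(tempfile.mkdtemp()) / 'f.txt'
    >>> p.write_bytes(b'a\r\nb\n')
    5
    >>> fix_filename(str(p), 'auto')
    1
    >>> p.read_bytes()
    b'a\nb\n'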

@@ -0,0 +1,48 @@
from __future__ import annotations
import argparse
import re
from typing import AbstractSet
from typing import Sequence
from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output
def is_on_branch(
protected: AbstractSet[str],
patterns: AbstractSet[str] = frozenset(),
) -> bool:
try:
ref_name = cmd_output('git', 'symbolic-ref', 'HEAD')
except CalledProcessError:
return False
chunks = ref_name.strip().split('/')
branch_name = '/'.join(chunks[2:])
return branch_name in protected or any(
re.match(p, branch_name) for p in patterns
)
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
'-b', '--branch', action='append',
help='branch to disallow commits to, may be specified multiple times',
)
parser.add_argument(
'-p', '--pattern', action='append',
help=(
'regex pattern for branch name to disallow commits to, '
'may be specified multiple times'
),
)
args = parser.parse_args(argv)
protected = frozenset(args.branch or ('master', 'main'))
patterns = frozenset(args.pattern or ())
return int(is_on_branch(protected, patterns))
if __name__ == '__main__':
raise SystemExit(main())
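
The branch name is everything after refs/heads/, so nested branch names survive the split/join, e.g.:

    >>> '/'.join('refs/heads/feature/x'.split('/')[2:])
    'feature/x'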

@@ -0,0 +1,133 @@
from __future__ import annotations
import argparse
import json
import sys
from difflib import unified_diff
from typing import Mapping
from typing import Sequence
def _get_pretty_format(
contents: str,
indent: str,
ensure_ascii: bool = True,
sort_keys: bool = True,
top_keys: Sequence[str] = (),
) -> str:
def pairs_first(pairs: Sequence[tuple[str, str]]) -> Mapping[str, str]:
before = [pair for pair in pairs if pair[0] in top_keys]
before = sorted(before, key=lambda x: top_keys.index(x[0]))
after = [pair for pair in pairs if pair[0] not in top_keys]
if sort_keys:
after.sort()
return dict(before + after)
json_pretty = json.dumps(
json.loads(contents, object_pairs_hook=pairs_first),
indent=indent,
ensure_ascii=ensure_ascii,
)
return f'{json_pretty}\n'
def _autofix(filename: str, new_contents: str) -> None:
print(f'Fixing file {filename}')
with open(filename, 'w', encoding='UTF-8') as f:
f.write(new_contents)
def parse_num_to_int(s: str) -> int | str:
"""Convert string numbers to int, leaving strings as is."""
try:
return int(s)
except ValueError:
return s
def parse_topkeys(s: str) -> list[str]:
return s.split(',')
def get_diff(source: str, target: str, file: str) -> str:
source_lines = source.splitlines(True)
target_lines = target.splitlines(True)
diff = unified_diff(source_lines, target_lines, fromfile=file, tofile=file)
return ''.join(diff)
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
'--autofix',
action='store_true',
dest='autofix',
help='Automatically fixes encountered not-pretty-formatted files',
)
parser.add_argument(
'--indent',
type=parse_num_to_int,
default='2',
help=(
'The number of indent spaces or a string to be used as delimiter'
' for indentation level e.g. 4 or "\t" (Default: 2)'
),
)
parser.add_argument(
'--no-ensure-ascii',
action='store_true',
dest='no_ensure_ascii',
default=False,
help=(
'Do NOT convert non-ASCII characters to Unicode escape sequences '
'(\\uXXXX)'
),
)
parser.add_argument(
'--no-sort-keys',
action='store_true',
dest='no_sort_keys',
default=False,
help='Keep JSON nodes in the same order',
)
parser.add_argument(
'--top-keys',
type=parse_topkeys,
dest='top_keys',
default=[],
help='Ordered list of keys to keep at the top of JSON hashes',
)
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
status = 0
for json_file in args.filenames:
with open(json_file, encoding='UTF-8') as f:
contents = f.read()
try:
pretty_contents = _get_pretty_format(
contents, args.indent, ensure_ascii=not args.no_ensure_ascii,
sort_keys=not args.no_sort_keys, top_keys=args.top_keys,
)
except ValueError:
print(
f'Input File {json_file} is not a valid JSON, consider using '
f'check-json',
)
return 1
if contents != pretty_contents:
if args.autofix:
_autofix(json_file, pretty_contents)
else:
diff_output = get_diff(contents, pretty_contents, json_file)
sys.stdout.buffer.write(diff_output.encode())
status = 1
return status
if __name__ == '__main__':
raise SystemExit(main())
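
A pure-function sketch of the formatter, with and without --top-keys:

    >>> print(_get_pretty_format('{"b": 1, "a": 2}', 2), end='')
    {
      "a": 2,
      "b": 1
    }
    >>> print(_get_pretty_format('{"b": 1, "a": 2}', 2, top_keys=['b']), end='')
    {
      "b": 1,
      "a": 2
    }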

@@ -0,0 +1,16 @@
from __future__ import annotations
import sys
from typing import Sequence
def main(argv: Sequence[str] | None = None) -> int:
argv = argv if argv is not None else sys.argv[1:]
hookid, new_hookid, url = argv[:3]
raise SystemExit(
f'`{hookid}` has been removed -- use `{new_hookid}` from {url}',
)
if __name__ == '__main__':
raise SystemExit(main())

@@ -0,0 +1,153 @@
from __future__ import annotations
import argparse
import re
from typing import IO
from typing import Sequence
PASS = 0
FAIL = 1
class Requirement:
UNTIL_COMPARISON = re.compile(b'={2,3}|!=|~=|>=?|<=?')
UNTIL_SEP = re.compile(rb'[^;\s]+')
def __init__(self) -> None:
self.value: bytes | None = None
self.comments: list[bytes] = []
@property
def name(self) -> bytes:
assert self.value is not None, self.value
name = self.value.lower()
for egg in (b'#egg=', b'&egg='):
if egg in self.value:
return name.partition(egg)[-1]
m = self.UNTIL_SEP.match(name)
assert m is not None
name = m.group()
m = self.UNTIL_COMPARISON.search(name)
if not m:
return name
return name[:m.start()]
def __lt__(self, requirement: Requirement) -> bool:
# \n means top of file comment, so always return True,
# otherwise just do a string comparison with value.
assert self.value is not None, self.value
if self.value == b'\n':
return True
elif requirement.value == b'\n':
return False
else:
return self.name < requirement.name
def is_complete(self) -> bool:
return (
self.value is not None and
not self.value.rstrip(b'\r\n').endswith(b'\\')
)
def append_value(self, value: bytes) -> None:
if self.value is not None:
self.value += value
else:
self.value = value
def fix_requirements(f: IO[bytes]) -> int:
requirements: list[Requirement] = []
before = list(f)
after: list[bytes] = []
before_string = b''.join(before)
# adds new line in case one is missing
# AND a change to the requirements file is needed regardless:
if before and not before[-1].endswith(b'\n'):
before[-1] += b'\n'
# If the file is empty (i.e. only whitespace/newlines) exit early
if before_string.strip() == b'':
return PASS
for line in before:
# If the most recent requirement object has a value, then it's
# time to start building the next requirement object.
if not len(requirements) or requirements[-1].is_complete():
requirements.append(Requirement())
requirement = requirements[-1]
# If we see a newline before any requirements, then this is a
# top of file comment.
if len(requirements) == 1 and line.strip() == b'':
if (
len(requirement.comments) and
requirement.comments[0].startswith(b'#')
):
requirement.value = b'\n'
else:
requirement.comments.append(line)
elif line.lstrip().startswith(b'#') or line.strip() == b'':
requirement.comments.append(line)
else:
requirement.append_value(line)
# if a file ends in a comment, preserve it at the end
if requirements[-1].value is None:
rest = requirements.pop().comments
else:
rest = []
# find and remove pkg-resources==0.0.0
    # which is automatically added by a broken pip package under Debian
requirements = [
req for req in requirements
if req.value != b'pkg-resources==0.0.0\n'
]
for requirement in sorted(requirements):
after.extend(requirement.comments)
assert requirement.value, requirement.value
after.append(requirement.value)
after.extend(rest)
after_string = b''.join(after)
if before_string == after_string:
return PASS
else:
f.seek(0)
f.write(after_string)
f.truncate()
return FAIL
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
retv = PASS
for arg in args.filenames:
with open(arg, 'rb+') as file_obj:
ret_for_file = fix_requirements(file_obj)
if ret_for_file:
print(f'Sorting {arg}')
retv |= ret_for_file
return retv
if __name__ == '__main__':
raise SystemExit(main())
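
A sketch with an in-memory stream (the comment block travels with the requirement that follows it):

    >>> import io
    >>> buf = io.BytesIO(b'pytest\n# tool\nflake8\n')
    >>> fix_requirements(buf)
    1
    >>> buf.getvalue()
    b'# tool\nflake8\npytest\n'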

@@ -0,0 +1,125 @@
"""Sort a simple YAML file, keeping blocks of comments and definitions
together.
We assume a strict subset of YAML that looks like:
# block of header comments
# here that should always
# be at the top of the file
# optional comments
# can go here
key: value
key: value
key: value
In other words, we don't sort deeper than the top layer, and might corrupt
complicated YAML files.
"""
from __future__ import annotations
import argparse
from typing import Sequence
QUOTES = ["'", '"']
def sort(lines: list[str]) -> list[str]:
"""Sort a YAML file in alphabetical order, keeping blocks together.
:param lines: array of strings (without newlines)
:return: sorted array of strings
"""
# make a copy of lines since we will clobber it
lines = list(lines)
new_lines = parse_block(lines, header=True)
for block in sorted(parse_blocks(lines), key=first_key):
if new_lines:
new_lines.append('')
new_lines.extend(block)
return new_lines
def parse_block(lines: list[str], header: bool = False) -> list[str]:
"""Parse and return a single block, popping off the start of `lines`.
If parsing a header block, we stop after we reach a line that is not a
comment. Otherwise, we stop after reaching an empty line.
:param lines: list of lines
:param header: whether we are parsing a header block
:return: list of lines that form the single block
"""
block_lines = []
while lines and lines[0] and (not header or lines[0].startswith('#')):
block_lines.append(lines.pop(0))
return block_lines
def parse_blocks(lines: list[str]) -> list[list[str]]:
"""Parse and return all possible blocks, popping off the start of `lines`.
:param lines: list of lines
:return: list of blocks, where each block is a list of lines
"""
blocks = []
while lines:
if lines[0] == '':
lines.pop(0)
else:
blocks.append(parse_block(lines))
return blocks
def first_key(lines: list[str]) -> str:
"""Returns a string representing the sort key of a block.
The sort key is the first YAML key we encounter, ignoring comments, and
stripping leading quotes.
>>> print(test)
# some comment
'foo': true
>>> first_key(test)
'foo'
"""
for line in lines:
if line.startswith('#'):
continue
if any(line.startswith(quote) for quote in QUOTES):
return line[1:]
return line
else:
return '' # not actually reached in reality
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
retval = 0
for filename in args.filenames:
with open(filename, 'r+') as f:
lines = [line.rstrip() for line in f.readlines()]
new_lines = sort(lines)
if lines != new_lines:
print(f'Fixing file `{filename}`')
f.seek(0)
f.write('\n'.join(new_lines) + '\n')
f.truncate()
retval = 1
return retval
if __name__ == '__main__':
raise SystemExit(main())
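
A minimal sketch of the block sort (the header comment stays put; keyed blocks are reordered):

    >>> sort(['# header', '', 'b: 2', '', 'a: 1'])
    ['# header', '', 'a: 1', '', 'b: 2']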

@@ -0,0 +1,93 @@
from __future__ import annotations
import argparse
import io
import re
import sys
import tokenize
from typing import Sequence
if sys.version_info >= (3, 12): # pragma: >=3.12 cover
FSTRING_START = tokenize.FSTRING_START
FSTRING_END = tokenize.FSTRING_END
else: # pragma: <3.12 cover
FSTRING_START = FSTRING_END = -1
START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
def handle_match(token_text: str) -> str:
if '"""' in token_text or "'''" in token_text:
return token_text
match = START_QUOTE_RE.match(token_text)
if match is not None:
meat = token_text[match.end():-1]
if '"' in meat or "'" in meat:
return token_text
else:
return match.group().replace('"', "'") + meat + "'"
else:
return token_text
def get_line_offsets_by_line_no(src: str) -> list[int]:
# Padded so we can index with line number
offsets = [-1, 0]
for line in src.splitlines(True):
offsets.append(offsets[-1] + len(line))
return offsets
def fix_strings(filename: str) -> int:
with open(filename, encoding='UTF-8', newline='') as f:
contents = f.read()
line_offsets = get_line_offsets_by_line_no(contents)
# Basically a mutable string
splitcontents = list(contents)
fstring_depth = 0
# Iterate in reverse so the offsets are always correct
tokens_l = list(tokenize.generate_tokens(io.StringIO(contents).readline))
tokens = reversed(tokens_l)
for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens:
if token_type == FSTRING_START: # pragma: >=3.12 cover
fstring_depth += 1
elif token_type == FSTRING_END: # pragma: >=3.12 cover
fstring_depth -= 1
elif fstring_depth == 0 and token_type == tokenize.STRING:
new_text = handle_match(token_text)
splitcontents[
line_offsets[srow] + scol:
line_offsets[erow] + ecol
] = new_text
new_contents = ''.join(splitcontents)
if contents != new_contents:
with open(filename, 'w', encoding='UTF-8', newline='') as f:
f.write(new_contents)
return 1
else:
return 0
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
retv = 0
for filename in args.filenames:
return_value = fix_strings(filename)
if return_value != 0:
print(f'Fixing strings in {filename}')
retv |= return_value
return retv
if __name__ == '__main__':
raise SystemExit(main())
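
A sketch of the token rewriter (strings containing quotes, or triple-quoted strings, pass through unchanged):

    >>> handle_match('"hello"')
    "'hello'"
    >>> handle_match("'''doc'''")
    "'''doc'''"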

@@ -0,0 +1,53 @@
from __future__ import annotations
import argparse
import os.path
import re
from typing import Sequence
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
mutex = parser.add_mutually_exclusive_group()
mutex.add_argument(
'--pytest',
dest='pattern',
action='store_const',
const=r'.*_test\.py',
default=r'.*_test\.py',
help='(the default) ensure tests match %(const)s',
)
mutex.add_argument(
'--pytest-test-first',
dest='pattern',
action='store_const',
const=r'test_.*\.py',
help='ensure tests match %(const)s',
)
mutex.add_argument(
'--django', '--unittest',
dest='pattern',
action='store_const',
const=r'test.*\.py',
help='ensure tests match %(const)s',
)
args = parser.parse_args(argv)
retcode = 0
reg = re.compile(args.pattern)
for filename in args.filenames:
base = os.path.basename(filename)
if (
not reg.fullmatch(base) and
not base == '__init__.py' and
not base == 'conftest.py'
):
retcode = 1
print(f'{filename} does not match pattern "{args.pattern}"')
return retcode
if __name__ == '__main__':
raise SystemExit(main())
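
For reference, the default pattern accepts suffix-style test names only (a sketch; note the use of fullmatch):

    >>> import re
    >>> bool(re.fullmatch(r'.*_test\.py', 'foo_test.py'))
    True
    >>> bool(re.fullmatch(r'.*_test\.py', 'test_foo.py'))
    False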

@@ -0,0 +1,103 @@
from __future__ import annotations
import argparse
import os
from typing import Sequence
def _fix_file(
filename: str,
is_markdown: bool,
chars: bytes | None,
) -> bool:
with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown, chars) for line in lines]
if newlines != lines:
with open(filename, mode='wb') as file_processed:
for line in newlines:
file_processed.write(line)
return True
else:
return False
def _process_line(
line: bytes,
is_markdown: bool,
chars: bytes | None,
) -> bytes:
if line[-2:] == b'\r\n':
eol = b'\r\n'
line = line[:-2]
elif line[-1:] == b'\n':
eol = b'\n'
line = line[:-1]
else:
eol = b''
# preserve trailing two-space for non-blank lines in markdown files
    if is_markdown and (not line.isspace()) and line.endswith(b'  '):
        return line[:-2].rstrip(chars) + b'  ' + eol
return line.rstrip(chars) + eol
def main(argv: Sequence[str] | None = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
'--no-markdown-linebreak-ext',
action='store_true',
help=argparse.SUPPRESS,
)
parser.add_argument(
'--markdown-linebreak-ext',
action='append',
default=[],
metavar='*|EXT[,EXT,...]',
help=(
'Markdown extensions (or *) to not strip linebreak spaces. '
'default: %(default)s'
),
)
parser.add_argument(
'--chars',
help=(
'The set of characters to strip from the end of lines. '
'Defaults to all whitespace characters.'
),
)
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
if args.no_markdown_linebreak_ext:
print('--no-markdown-linebreak-ext now does nothing!')
md_args = args.markdown_linebreak_ext
if '' in md_args:
parser.error('--markdown-linebreak-ext requires a non-empty argument')
all_markdown = '*' in md_args
# normalize extensions; split at ',', lowercase, and force 1 leading '.'
md_exts = [
'.' + x.lower().lstrip('.') for x in ','.join(md_args).split(',')
]
# reject probable "eaten" filename as extension: skip leading '.' with [1:]
for ext in md_exts:
if any(c in ext[1:] for c in r'./\:'):
parser.error(
f'bad --markdown-linebreak-ext extension '
f'{ext!r} (has . / \\ :)\n'
f" (probably filename; use '--markdown-linebreak-ext=EXT')",
)
chars = None if args.chars is None else args.chars.encode()
return_code = 0
for filename in args.filenames:
_, extension = os.path.splitext(filename.lower())
md = all_markdown or extension in md_exts
if _fix_file(filename, md, chars):
print(f'Fixing {filename}')
return_code = 1
return return_code
if __name__ == '__main__':
raise SystemExit(main())
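
A pure-function sketch (markdown mode preserves the two-space hard linebreak):

    >>> _process_line(b'code = 1   \n', False, None)
    b'code = 1\n'
    >>> _process_line(b'markdown break  \n', True, None)
    b'markdown break  \n'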

pre_commit_hooks/util.py
@@ -0,0 +1,32 @@
from __future__ import annotations
import subprocess
from typing import Any
class CalledProcessError(RuntimeError):
pass
def added_files() -> set[str]:
cmd = ('git', 'diff', '--staged', '--name-only', '--diff-filter=A')
return set(cmd_output(*cmd).splitlines())
def cmd_output(*cmd: str, retcode: int | None = 0, **kwargs: Any) -> str:
kwargs.setdefault('stdout', subprocess.PIPE)
kwargs.setdefault('stderr', subprocess.PIPE)
proc = subprocess.Popen(cmd, **kwargs)
stdout, stderr = proc.communicate()
stdout = stdout.decode()
if retcode is not None and proc.returncode != retcode:
raise CalledProcessError(cmd, retcode, proc.returncode, stdout, stderr)
return stdout
def zsplit(s: str) -> list[str]:
s = s.strip('\0')
if s:
return s.split('\0')
else:
return []
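
A quick sketch of zsplit on NUL-delimited git output:

    >>> zsplit('a\x00b\x00')
    ['a', 'b']
    >>> zsplit('')
    []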