llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: None (TsXor) <details> <summary>Changes</summary> Related: https://github.com/llvm/llvm-project/issues/76664 I used metadata reflection so that we can import C library functions just by declaring annotated Python functions. This makes the C function types visible to the type checker, which makes it easy to fix most typing errors. --- Patch is 153.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101784.diff 5 Files Affected: - (modified) clang/bindings/python/clang/cindex.py (+1145-751) - (added) clang/bindings/python/clang/ctyped.py (+334) - (modified) clang/bindings/python/tests/cindex/test_type.py (+1-1) - (added) clang/bindings/python/tests/ctyped/__init__.py () - (added) clang/bindings/python/tests/ctyped/test_stub_conversion.py (+359) ``````````diff diff --git a/clang/bindings/python/clang/cindex.py b/clang/bindings/python/clang/cindex.py index 2038ef6045c7d..521dc2829ae41 100644 --- a/clang/bindings/python/clang/cindex.py +++ b/clang/bindings/python/clang/cindex.py @@ -62,36 +62,50 @@ # # o implement additional SourceLocation, SourceRange, and File methods. 
-from ctypes import * +from ctypes import (c_byte, c_ubyte, c_short, c_ushort, c_int, c_uint, c_long, # pyright: ignore[reportUnusedImport] + c_ulong, c_longlong,c_ulonglong, c_size_t, c_ssize_t, # pyright: ignore[reportUnusedImport] + c_bool, c_char, c_wchar, c_float, c_double, c_longdouble, # pyright: ignore[reportUnusedImport] + c_char_p, c_wchar_p, c_void_p) # pyright: ignore[reportUnusedImport] +from ctypes import py_object, Structure, POINTER, byref, cast, cdll +from .ctyped import * +from .ctyped import ANNO_CONVERTIBLE, generate_metadata import os import sys from enum import Enum from typing import ( + cast as tcast, Any, Callable, + Dict, + Generator, Generic, + Iterator, + List, Optional, + Tuple, Type as TType, TypeVar, TYPE_CHECKING, Union as TUnion, ) +from typing_extensions import Annotated + if TYPE_CHECKING: - from ctypes import _Pointer - from typing_extensions import Protocol, TypeAlias + from typing_extensions import Protocol, Self, TypeAlias + from ctypes import CDLL StrPath: TypeAlias = TUnion[str, os.PathLike[str]] - LibFunc: TypeAlias = TUnion[ - "tuple[str, Optional[list[Any]]]", - "tuple[str, Optional[list[Any]], Any]", - "tuple[str, Optional[list[Any]], Any, Callable[..., Any]]", - ] - + StrOrBytes: TypeAlias = TUnion[str, bytes] + FsPath: TypeAlias = TUnion[StrOrBytes, os.PathLike[str]] TSeq = TypeVar("TSeq", covariant=True) + class SupportsReadStringData(Protocol): + def read(self) -> str | bytes: + ... + class NoSliceSequence(Protocol[TSeq]): def __len__(self) -> int: ... @@ -102,7 +116,7 @@ def __getitem__(self, key: int) -> TSeq: # Python 3 strings are unicode, translate them to/from utf8 for C-interop. 
class c_interop_string(c_char_p): - def __init__(self, p: str | bytes | None = None): + def __init__(self, p: 'CInteropString' = None): if p is None: p = "" if isinstance(p, str): @@ -120,7 +134,7 @@ def value(self) -> str | None: # type: ignore [override] return val.decode("utf8") @classmethod - def from_param(cls, param: str | bytes | None) -> c_interop_string: + def from_param(cls, param: 'CInteropString') -> c_interop_string: if isinstance(param, str): return cls(param) if isinstance(param, bytes): @@ -136,6 +150,8 @@ def from_param(cls, param: str | bytes | None) -> c_interop_string: def to_python_string(x: c_interop_string, *args: Any) -> str | None: return x.value +CInteropString = Annotated[TUnion[str, bytes, None], ANNO_CONVERTIBLE, c_interop_string] + def b(x: str | bytes) -> bytes: if isinstance(x, bytes): @@ -147,7 +163,8 @@ def b(x: str | bytes) -> bytes: # object. This is a problem, because it means that from_parameter will see an # integer and pass the wrong value on platforms where int != void*. Work around # this by marshalling object arguments as void**. -c_object_p: TType[_Pointer[Any]] = POINTER(c_void_p) +CObjectP = CPointer[c_void_p] +c_object_p: TType[CObjectP] = convert_annotation(CObjectP) ### Exception Classes ### @@ -183,7 +200,7 @@ class TranslationUnitSaveError(Exception): # Indicates that the translation unit was somehow invalid. 
ERROR_INVALID_TU = 3 - def __init__(self, enumeration, message): + def __init__(self, enumeration: int, message: str): assert isinstance(enumeration, int) if enumeration < 1 or enumeration > 3: @@ -241,7 +258,7 @@ def __del__(self) -> None: conf.lib.clang_disposeString(self) @staticmethod - def from_result(res: _CXString, fn: Any = None, args: Any = None) -> str: + def from_result(res: _CXString, fn: Optional[Callable[..., _CXString]] = None, args: Optional[Tuple[Any, ...]] = None) -> str: assert isinstance(res, _CXString) pystr: str | None = conf.lib.clang_getCString(res) if pystr is None: @@ -255,71 +272,73 @@ class SourceLocation(Structure): """ _fields_ = [("ptr_data", c_void_p * 2), ("int_data", c_uint)] - _data = None + _data: Optional[Tuple[Optional[File], int, int, int]] = None - def _get_instantiation(self): + def _get_instantiation(self) -> Tuple[Optional[File], int, int, int]: if self._data is None: - f, l, c, o = c_object_p(), c_uint(), c_uint(), c_uint() + fp, l, c, o = c_object_p(), c_uint(), c_uint(), c_uint() conf.lib.clang_getInstantiationLocation( - self, byref(f), byref(l), byref(c), byref(o) + self, byref(fp), byref(l), byref(c), byref(o) ) - if f: - f = File(f) + if fp: + f = File(fp) else: f = None self._data = (f, int(l.value), int(c.value), int(o.value)) return self._data @staticmethod - def from_position(tu, file, line, column): + def from_position(tu: TranslationUnit, file: File, line: int, column: int) -> SourceLocation: """ Retrieve the source location associated with a given file/line/column in a particular translation unit. """ - return conf.lib.clang_getLocation(tu, file, line, column) # type: ignore [no-any-return] + return conf.lib.clang_getLocation(tu, file, line, column) @staticmethod - def from_offset(tu, file, offset): + def from_offset(tu: TranslationUnit, file: File, offset: int) -> SourceLocation: """Retrieve a SourceLocation from a given character offset. 
tu -- TranslationUnit file belongs to file -- File instance to obtain offset from offset -- Integer character offset within file """ - return conf.lib.clang_getLocationForOffset(tu, file, offset) # type: ignore [no-any-return] + return conf.lib.clang_getLocationForOffset(tu, file, offset) @property - def file(self): + def file(self) -> Optional[File]: """Get the file represented by this source location.""" return self._get_instantiation()[0] @property - def line(self): + def line(self) -> int: """Get the line represented by this source location.""" return self._get_instantiation()[1] @property - def column(self): + def column(self) -> int: """Get the column represented by this source location.""" return self._get_instantiation()[2] @property - def offset(self): + def offset(self) -> int: """Get the file offset represented by this source location.""" return self._get_instantiation()[3] @property - def is_in_system_header(self): + def is_in_system_header(self) -> bool: """Returns true if the given source location is in a system header.""" - return conf.lib.clang_Location_isInSystemHeader(self) # type: ignore [no-any-return] + return conf.lib.clang_Location_isInSystemHeader(self) - def __eq__(self, other): - return conf.lib.clang_equalLocations(self, other) # type: ignore [no-any-return] + def __eq__(self, other: object) -> bool: + if not isinstance(other, SourceLocation): + return NotImplemented + return conf.lib.clang_equalLocations(self, other) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __repr__(self): + def __repr__(self) -> str: if self.file: filename = self.file.name else: @@ -346,40 +365,43 @@ class SourceRange(Structure): # FIXME: Eliminate this and make normal constructor? Requires hiding ctypes # object. 
@staticmethod - def from_locations(start, end): - return conf.lib.clang_getRange(start, end) # type: ignore [no-any-return] + def from_locations(start: SourceLocation, end: SourceLocation) -> SourceRange: + return conf.lib.clang_getRange(start, end) @property - def start(self): + def start(self) -> SourceLocation: """ Return a SourceLocation representing the first character within a source range. """ - return conf.lib.clang_getRangeStart(self) # type: ignore [no-any-return] + return conf.lib.clang_getRangeStart(self) @property - def end(self): + def end(self) -> SourceLocation: """ Return a SourceLocation representing the last character within a source range. """ - return conf.lib.clang_getRangeEnd(self) # type: ignore [no-any-return] + return conf.lib.clang_getRangeEnd(self) - def __eq__(self, other): - return conf.lib.clang_equalRanges(self, other) # type: ignore [no-any-return] + def __eq__(self, other: object) -> bool: + if not isinstance(other, SourceRange): + return NotImplemented + return conf.lib.clang_equalRanges(self, other) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __contains__(self, other): + def __contains__(self, other: object) -> bool: """Useful to detect the Token/Lexer bug""" if not isinstance(other, SourceLocation): return False - if other.file is None and self.start.file is None: - pass - elif ( - self.start.file.name != other.file.name - or other.file.name != self.end.file.name + if ( + other.file is not None + and self.start.file is not None + and self.end.file is not None + and (other.file.name != self.start.file.name + or other.file.name != self.end.file.name) ): # same file name return False @@ -396,7 +418,7 @@ def __contains__(self, other): return True return False - def __repr__(self): + def __repr__(self) -> str: return "<SourceRange start %r, end %r>" % (self.start, self.end) @@ -421,23 +443,25 @@ class Diagnostic: DisplayCategoryName = 0x20 _FormatOptionsMask = 0x3F - def 
__init__(self, ptr): + ptr: CObjectP + + def __init__(self, ptr: CObjectP): self.ptr = ptr - def __del__(self): + def __del__(self) -> None: conf.lib.clang_disposeDiagnostic(self) @property - def severity(self): - return conf.lib.clang_getDiagnosticSeverity(self) # type: ignore [no-any-return] + def severity(self) -> int: + return conf.lib.clang_getDiagnosticSeverity(self) @property - def location(self): - return conf.lib.clang_getDiagnosticLocation(self) # type: ignore [no-any-return] + def location(self) -> SourceLocation: + return conf.lib.clang_getDiagnosticLocation(self) @property - def spelling(self): - return conf.lib.clang_getDiagnosticSpelling(self) # type: ignore [no-any-return] + def spelling(self) -> str: + return conf.lib.clang_getDiagnosticSpelling(self) @property def ranges(self) -> NoSliceSequence[SourceRange]: @@ -451,7 +475,7 @@ def __len__(self) -> int: def __getitem__(self, key: int) -> SourceRange: if key >= len(self): raise IndexError - return conf.lib.clang_getDiagnosticRange(self.diag, key) # type: ignore [no-any-return] + return conf.lib.clang_getDiagnosticRange(self.diag, key) return RangeIterator(self) @@ -492,28 +516,28 @@ def __getitem__(self, key: int) -> Diagnostic: return ChildDiagnosticsIterator(self) @property - def category_number(self): + def category_number(self) -> int: """The category number for this diagnostic or 0 if unavailable.""" - return conf.lib.clang_getDiagnosticCategory(self) # type: ignore [no-any-return] + return conf.lib.clang_getDiagnosticCategory(self) @property - def category_name(self): + def category_name(self) -> str: """The string name of the category for this diagnostic.""" - return conf.lib.clang_getDiagnosticCategoryText(self) # type: ignore [no-any-return] + return conf.lib.clang_getDiagnosticCategoryText(self) @property - def option(self): + def option(self) -> str: """The command-line option that enables this diagnostic.""" - return conf.lib.clang_getDiagnosticOption(self, None) # type: ignore 
[no-any-return] + return conf.lib.clang_getDiagnosticOption(self, None) @property - def disable_option(self): + def disable_option(self) -> str: """The command-line option that disables this diagnostic.""" disable = _CXString() conf.lib.clang_getDiagnosticOption(self, byref(disable)) return _CXString.from_result(disable) - def format(self, options=None): + def format(self, options: Optional[int] = None) -> str: """ Format this diagnostic for display. The options argument takes Diagnostic.Display* flags, which can be combined using bitwise OR. If @@ -524,19 +548,19 @@ def format(self, options=None): options = conf.lib.clang_defaultDiagnosticDisplayOptions() if options & ~Diagnostic._FormatOptionsMask: raise ValueError("Invalid format options") - return conf.lib.clang_formatDiagnostic(self, options) # type: ignore [no-any-return] + return conf.lib.clang_formatDiagnostic(self, options) - def __repr__(self): + def __repr__(self) -> str: return "<Diagnostic severity %r, location %r, spelling %r>" % ( self.severity, self.location, self.spelling, ) - def __str__(self): + def __str__(self) -> str: return self.format() - def from_param(self): + def from_param(self) -> CObjectP: return self.ptr @@ -547,11 +571,14 @@ class FixIt: with the given value. """ - def __init__(self, range, value): + range: SourceRange + value: str + + def __init__(self, range: SourceRange, value: str): self.range = range self.value = value - def __repr__(self): + def __repr__(self) -> str: return "<FixIt range %r, value %r>" % (self.range, self.value) @@ -570,16 +597,20 @@ class TokenGroup: You should not instantiate this class outside of this module. 
""" - def __init__(self, tu, memory, count): + _tu: TranslationUnit + _memory: CPointer[Token] + _count: c_uint + + def __init__(self, tu: TranslationUnit, memory: CPointer[Token], count: c_uint): self._tu = tu self._memory = memory self._count = count - def __del__(self): + def __del__(self) -> None: conf.lib.clang_disposeTokens(self._tu, self._memory, self._count) @staticmethod - def get_tokens(tu, extent): + def get_tokens(tu: TranslationUnit, extent: SourceRange) -> Generator[Token, None, None]: """Helper method to return all tokens in an extent. This functionality is needed multiple places in this module. We define @@ -616,16 +647,16 @@ class BaseEnumeration(Enum): """ Common base class for named enumerations held in sync with Index.h values. """ + value: int # pyright: ignore[reportIncompatibleMethodOverride] - - def from_param(self): + def from_param(self) -> int: return self.value @classmethod - def from_id(cls, id): + def from_id(cls, id: int) -> Self: return cls(id) - def __repr__(self): + def __repr__(self) -> str: return "%s.%s" % ( self.__class__.__name__, self.name, @@ -636,7 +667,7 @@ class TokenKind(BaseEnumeration): """Describes a specific type of a Token.""" @classmethod - def from_value(cls, value): + def from_value(cls, value: int) -> Self: """Obtain a registered TokenKind instance from its value.""" return cls.from_id(value) @@ -653,45 +684,44 @@ class CursorKind(BaseEnumeration): """ @staticmethod - def get_all_kinds(): + def get_all_kinds() -> List[CursorKind]: """Return all CursorKind enumeration instances.""" return list(CursorKind) - def is_declaration(self): + def is_declaration(self) -> bool: """Test if this is a declaration kind.""" - return conf.lib.clang_isDeclaration(self) # type: ignore [no-any-return] + return conf.lib.clang_isDeclaration(self) - def is_reference(self): + def is_reference(self) -> bool: """Test if this is a reference kind.""" - return conf.lib.clang_isReference(self) # type: ignore [no-any-return] + return 
conf.lib.clang_isReference(self) - def is_expression(self): + def is_expression(self) -> bool: """Test if this is an expression kind.""" - return conf.lib.clang_isExpression(self) # type: ignore [no-any-return] - - def is_statement(self): + return conf.lib.clang_isExpression(self) + def is_statement(self) -> bool: """Test if this is a statement kind.""" - return conf.lib.clang_isStatement(self) # type: ignore [no-any-return] + return conf.lib.clang_isStatement(self) - def is_attribute(self): + def is_attribute(self) -> bool: """Test if this is an attribute kind.""" - return conf.lib.clang_isAttribute(self) # type: ignore [no-any-return] + return conf.lib.clang_isAttribute(self) - def is_invalid(self): + def is_invalid(self) -> bool: """Test if this is an invalid kind.""" - return conf.lib.clang_isInvalid(self) # type: ignore [no-any-return] + return conf.lib.clang_isInvalid(self) - def is_translation_unit(self): + def is_translation_unit(self) -> bool: """Test if this is a translation unit kind.""" - return conf.lib.clang_isTranslationUnit(self) # type: ignore [no-any-return] + return conf.lib.clang_isTranslationUnit(self) - def is_preprocessing(self): + def is_preprocessing(self) -> bool: """Test if this is a preprocessing kind.""" - return conf.lib.clang_isPreprocessing(self) # type: ignore [no-any-return] + return conf.lib.clang_isPreprocessing(self) - def is_unexposed(self): + def is_unexposed(self) -> bool: """Test if this is an unexposed kind.""" - return conf.lib.clang_isUnexposed(self) # type: ignore [no-any-return] + return conf.lib.clang_isUnexposed(self) ### @@ -1555,7 +1585,7 @@ class Cursor(Structure): _fields_ = [("_kind_id", c_int), ("xdata", c_int), ("data", c_void_p * 3)] @staticmethod - def from_location(tu, location): + def from_location(tu: TranslationUnit, location: SourceLocation) -> Cursor: # We store a reference to the TU in the instance so the TU won't get # collected before the cursor. 
cursor = conf.lib.clang_getCursor(tu, location) @@ -1563,54 +1593,56 @@ def from_location(tu, location): return cursor - def __eq__(self, other): - return conf.lib.clang_equalCursors(self, other) # type: ignore [no-any-return] + def __eq__(self, other: object) -> bool: + if not isinstance(other, Cursor): + return NotImplemented + return conf.lib.clang_equalCursors(self, other) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def is_definition(self): + def is_definition(self) -> bool: """ Returns true if the declaration pointed at by the cursor is also a definition of that entity. """ - return conf.lib.clang_isCursorDefinition(self) # type: ignore [no-any-return] + return conf.lib.clang_isCursorDefinition(self) - def is_const_method(self): + def is_const_method(self) -> bool: """Returns True if the cursor ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/101784 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits