chaokunyang commented on code in PR #3658:
URL: https://github.com/apache/fory/pull/3658#discussion_r3208487761
##########
compiler/fory_compiler/generators/rust.py:
##########
@@ -68,6 +76,427 @@ class RustGenerator(BaseGenerator):
PrimitiveKind.ANY: "Box<dyn Any>",
}
+ # Strict and reserved keywords defined in Rust (https://doc.rust-lang.org/reference/keywords.html).
+ # Weak keywords are intentionally excluded because they are usable outside their special syntax contexts.
+ RUST_RAW_IDENTIFIER_KEYWORDS = {
+ "as",
+ "async",
+ "await",
+ "abstract",
+ "become",
+ "box",
+ "break",
+ "const",
+ "continue",
+ "do",
+ "dyn",
+ "else",
+ "enum",
+ "extern",
+ "false",
+ "final",
+ "fn",
+ "for",
+ "gen",
+ "if",
+ "impl",
+ "in",
+ "let",
+ "loop",
+ "macro",
+ "match",
+ "mod",
+ "move",
+ "mut",
+ "override",
+ "priv",
+ "pub",
+ "ref",
+ "return",
+ "static",
+ "struct",
+ "trait",
+ "true",
+ "try",
+ "type",
+ "typeof",
+ "unsafe",
+ "unsized",
+ "use",
+ "virtual",
+ "where",
+ "while",
+ "yield",
+ }
+
+ # Reserved identifiers in Rust (https://doc.rust-lang.org/reference/identifiers.html#railroad-RESERVED_RAW_IDENTIFIER).
+ # These tokens are invalid even with an `r#` prefix, so escape them by suffixing `_` instead.
+ RUST_RESERVED_IDENTIFIERS = {"_", "self", "Self", "super", "crate"}
+
+ def sanitize_identifier(self, normalized: str) -> str:
+ """Escape an already-normalized Rust name."""
+ if normalized in self.RUST_RESERVED_IDENTIFIERS:
+ return f"{normalized}_"
+ if normalized and normalized[0].isnumeric():
+ return f"_{normalized}" # Rust identifiers cannot start with a
digit.
+ if normalized in self.RUST_RAW_IDENTIFIER_KEYWORDS:
+ return f"r#{normalized}"
+ return normalized
+
+ def to_rust_snake(self, source: str) -> str:
+ """Convert an IDL name to a sanitized Rust snake_case identifier."""
+ return self.sanitize_identifier(self.to_snake_case(source))
+
+ def to_rust_upper_camel(self, source: str) -> str:
+ """Convert an IDL name to a sanitized Rust UpperCamelCase
identifier."""
+ return self.sanitize_identifier(self.to_pascal_case(source))
+
+ def get_top_level_module_identifier(self, package: Optional[str]) -> str:
+ """Get the Rust module identifier used to reference one schema file."""
+ # e.g., `foo.bar` defined in the IDL will be `foo_bar` in the generated Rust code.
+ module_name = package.replace(".", "_") if package else "generated"
+ return self.to_rust_snake(module_name)
+
+ def get_top_level_module_filename(self, package: Optional[str]) -> str:
+ """Return the generated Rust filename for one schema module."""
+ module_name = self.get_top_level_module_identifier(package)
+ # e.g., when resolving the file for `pub mod r#type`, Rust looks for `type.rs`, not `r#type.rs`.
+ if module_name.startswith("r#"):
+ return module_name[2:]
+ return module_name
+
+ def get_type_identifier(self, name: str, type_def: Optional[object]) ->
str:
+ """Get the allocated Rust identifier for a type declaration or
reference."""
+ # Look up the cache first.
+ if type_def is not None:
+ self._ensure_name_caches(self._schema_for_type_def(type_def))
+ allocated =
self._type_identifier_cache.get(self._cache_key(type_def))
+ if allocated is not None:
+ return allocated
+ return self.to_rust_upper_camel(name)
+
+ def get_module_identifier(self, message: Message) -> str:
+ """Get the allocated Rust module name for a message's nested-type
scope."""
+ # Look up the cache first.
+ self._ensure_name_caches(self._schema_for_type_def(message))
+ allocated = self._module_identifier_cache.get(self._cache_key(message))
+ if allocated is not None:
+ return allocated
+ return self.to_rust_snake(message.name)
+
+ def get_field_identifier(self, message: Message, field: Field) -> str:
+ """Get the allocated Rust field name within one message."""
+ # Look up the cache first.
+ self._ensure_name_caches(self._schema_for_type_def(message))
+ allocated = self._field_identifier_cache.get(self._cache_key(message),
{}).get(
+ self._cache_key(field)
+ )
+ if allocated is not None:
+ return allocated
+ return self.to_rust_snake(field.name)
+
+ def get_enum_value_identifier(self, enum: Enum, value: object) -> str:
+ """Get the allocated Rust variant name for one enum value."""
+ # Look up the cache first
+ self._ensure_name_caches(self._schema_for_type_def(enum))
+ allocated = self._enum_value_identifier_cache.get(
+ self._cache_key(enum), {}
+ ).get(self._cache_key(value))
+ if allocated is not None:
+ return allocated
+ return self.to_rust_upper_camel(self.strip_enum_prefix(enum.name,
value.name))
+
+ def get_union_case_identifier(self, union: Union, field: Field) -> str:
+ """Get the allocated Rust variant name for one union case."""
+ # Look up the cache first
+ self._ensure_name_caches(self._schema_for_type_def(union))
+ allocated = self._union_case_identifier_cache.get(
+ self._cache_key(union), {}
+ ).get(self._cache_key(field))
+ if allocated is not None:
+ return allocated
+ return self.to_rust_upper_camel(field.name)
+
+ def _allocate_scoped_identifier(
+ self,
+ normalized_name: str,
+ used_names: Dict[str, str],
+ scope: str,
+ source_name: str,
+ ) -> str:
+ """Validate one sanitized identifier inside a single generated scope. Raise an error on collision."""
+ escaped = self.sanitize_identifier(normalized_name)
+ if not escaped:
+ raise RustNameError(
+ f"Rust identifier for {source_name!r} in {scope} is empty"
+ )
+ previous_source = used_names.get(escaped)
+ if previous_source is not None:
+ raise RustNameCollisionError(
+ f"Rust name collision in {scope}: {previous_source!r} and "
+ f"{source_name!r} both map to Rust identifier {escaped!r}"
+ )
+ used_names[escaped] = source_name
+ return escaped
+
+ def _is_local_to_schema(self, type_def: object, schema: Schema) -> bool:
+ """Return whether a type definition belongs to the given schema file.
+
+ Name allocation is done per schema file so imported types keep the names assigned by their own generator run.
+ This helper filters out imported definitions when building the local caches.
+ """
+ if not schema.source_file:
+ return True
+ location = getattr(type_def, "location", None)
+ file_path = getattr(location, "file", None) if location else None
+ if not file_path:
+ return True
+ try:
+ return Path(file_path).resolve() ==
Path(schema.source_file).resolve()
+ except Exception:
+ return file_path == schema.source_file
+
+ def _cache_key(self, node: object) -> Tuple[object, ...]:
+ """Get a stable cache key across reparsed schemas when source
locations exist."""
+ location = getattr(node, "location", None)
+ if location is not None and getattr(location, "file", None):
+ try:
+ file_path = str(Path(location.file).resolve())
+ except Exception:
+ file_path = location.file
+ return (
+ type(node).__name__,
+ file_path,
+ location.line,
+ location.column,
+ )
+ return (type(node).__name__, id(node))
+
+ def _source_file_key(self, file_path: Optional[str]) -> Optional[str]:
+ """Normalize a source file path for stable lookups across parsed
schemas."""
+ if not file_path:
+ return None
+ try:
+ return str(Path(file_path).resolve())
+ except Exception:
+ return file_path
+
+ def _node_source_file_key(self, node: object) -> Optional[str]:
+ """Return the normalized source file key for an AST node."""
+ location = getattr(node, "location", None)
+ file_path = getattr(location, "file", None) if location else None
+ return self._source_file_key(file_path)
+
+ def _schema_source_file_key(self, schema: Schema) -> Optional[str]:
+ """Return the normalized source file key for a schema."""
+ return self._source_file_key(schema.source_file)
+
+ def _schema_for_source_file(self, file_path: Optional[str]) ->
Optional[Schema]:
+ """Build a schema view containing declarations from one source file."""
+ source_key = self._source_file_key(file_path)
+ if source_key is None:
+ return None
+ schema_source_key = self._schema_source_file_key(self.schema)
+ if source_key == schema_source_key:
+ return self.schema
+ enums = [
+ enum
+ for enum in self.schema.enums
+ if self._node_source_file_key(enum) == source_key
+ ]
+ unions = [
+ union
+ for union in self.schema.unions
+ if self._node_source_file_key(union) == source_key
+ ]
+ messages = [
+ message
+ for message in self.schema.messages
+ if self._node_source_file_key(message) == source_key
+ ]
+ if enums or unions or messages:
+ return Schema(
+ package=self._package_for_source_file(file_path),
+ enums=enums,
+ messages=messages,
+ unions=unions,
+ source_file=file_path,
+ source_format=self.schema.source_format,
+ )
+ return None
+
+ def _package_for_source_file(self, file_path: Optional[str]) ->
Optional[str]:
+ """Get the original package for a source file in the resolved
schema."""
+ source_key = self._source_file_key(file_path)
+ if source_key is None:
+ return None
+ if source_key == self._schema_source_file_key(self.schema):
+ return self.schema.package
+ if source_key in self.schema.source_packages:
+ return self.schema.source_packages[source_key]
+ return None
+
+ def _schema_for_type_def(self, type_def: object) -> Schema:
+ """Load the owning schema for a type definition when available.
+
+ Imported types need to look up names in the schema where they were
+ declared, not in the current schema that references them.
+ """
+ location = getattr(type_def, "location", None)
+ file_path = getattr(location, "file", None) if location else None
+ schema = self._schema_for_source_file(file_path) or
self._load_schema(file_path)
+ return schema or self.schema
+
+ def _local_top_level_types(
+ self, schema: Schema
+ ) -> Tuple[List[Enum], List[Union], List[Message]]:
+ """Collect top-level types that are declared directly in one schema
file."""
+ enums = [
+ enum for enum in schema.enums if self._is_local_to_schema(enum,
schema)
+ ]
+ unions = [
+ union for union in schema.unions if
self._is_local_to_schema(union, schema)
+ ]
+ messages = [
+ message
+ for message in schema.messages
+ if self._is_local_to_schema(message, schema)
+ ]
+ return enums, unions, messages
+
+ def _local_top_level_messages(self, schema: Schema) -> List[Message]:
+ """Return only the local top-level messages from a schema file."""
+ return self._local_top_level_types(schema)[2]
+
+ def _resolve_message_path(self, schema: Schema, parts: List[str]) ->
List[Message]:
+ """Resolve a dotted message path to the concrete message lineage.
+
+ Imported type references are stored as dotted names such as
+ `outer.inner.Type`. To reuse allocated module names we need the actual
+ `Message` objects for each parent segment.
+ """
+ lineage: List[Message] = []
+ scope = self._local_top_level_messages(schema)
+ for part in parts:
+ match = next((message for message in scope if message.name ==
part), None)
+ if match is None:
+ return []
+ lineage.append(match)
+ scope = match.nested_messages
+ return lineage
+
+ def _allocate_type_scope_names(self, type_defs: List[object], scope: str)
-> None:
Review Comment:
good idea, absolute paths are better
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]