diff --git a/Tools/jit/_optimizers.py b/Tools/jit/_optimizers.py index f192783a55950c..ae5b30bf9f0d48 100644 --- a/Tools/jit/_optimizers.py +++ b/Tools/jit/_optimizers.py @@ -152,11 +152,17 @@ class _Block: fallthrough: bool = True # Whether this block can eventually reach the next uop (_JIT_CONTINUE): hot: bool = False + # Whether this block should be emitted in the hot section. This is separate + # from "hot": some cold fallthrough bridges must stay in the hot layout. + layout_hot: bool | None = None + # Whether this original assembler metadata/tail block should be preserved + # even if it is unreachable. + is_metadata: bool = False def resolve(self) -> typing.Self: """Find the first non-empty block reachable from this one.""" block = self - while block.link and not block.instructions: + while block.link and not block.instructions and block.fallthrough: block = block.link return block @@ -208,6 +214,8 @@ class Optimizer: const_reloc = "" _frame_pointer_modify: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH label_index: int = 0 + _cold_start: _Block | None = dataclasses.field(init=False, default=None) + _jump_name = "" def __post_init__(self) -> None: # Split the code into a linked list of basic blocks. 
A basic block is an @@ -339,18 +347,224 @@ def _lookup_label(self, label: str) -> _Block: def _is_far_target(self, label: str) -> bool: return not label.startswith(self.label_prefix) + def _continuation(self) -> _Block: + return self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE") + + def _cold_start_block(self) -> _Block: + if self._cold_start is None: + label = f"{self.symbol_prefix}_JIT_COLD_START" + self._cold_start = self._lookup_label(label) + self._cold_start.noninstructions.append(f"{label}:") + self._cold_start.layout_hot = False + return self._cold_start + + def _make_label(self) -> str: + label = f"{self.label_prefix}_JIT_LABEL_{self.label_index}" + self.label_index += 1 + return label + + def _ensure_label(self, block: _Block) -> str: + if block.label is None: + block.label = self._make_label() + self._labels[block.label] = block + block.noninstructions.insert(0, f"{block.label}:") + return block.label + + def _make_jump(self, target: _Block, *, hot: bool) -> _Block: + label = self._ensure_label(target) + return _Block( + instructions=[ + Instruction( + InstructionKind.JUMP, + self._jump_name, + f"\t{self._jump_name} {label}", + None, + label, + ) + ], + target=target, + fallthrough=False, + hot=hot, + layout_hot=True, + ) + + def _effective_layout_hot(self, block: _Block) -> bool: + resolved = block.resolve() + if resolved.layout_hot is not None: + return resolved.layout_hot + if block.layout_hot is not None: + return block.layout_hot + return resolved.hot + + def _same_layout_section(self, left: _Block, right: _Block) -> bool: + return self._effective_layout_hot(left) == self._effective_layout_hot(right) + + def _can_short_branch_to_layout( + self, inst: Instruction, source: _Block, target: _Block + ) -> bool: + if inst.kind != InstructionKind.SHORT_BRANCH: + return True + return self._same_layout_section(source, target) + + def _insert_fallthrough_bridge(self, block: _Block, target: _Block) -> _Block: + bridge = self._make_jump(target, 
hot=target.hot) + bridge.link = block.link + block.link = bridge + return bridge + + def _ensure_hot_fallthrough(self, block: _Block) -> None: + if block is self._continuation() or block is self._cold_start: + return + fallthrough = block.link + if ( + fallthrough is None + or not self._effective_layout_hot(block) + or not block.fallthrough + ): + return + + fallthrough_is_hot = self._effective_layout_hot(fallthrough) + target = block.target + inst = block.instructions[-1] if block.instructions else None + + # Keep AArch64 short branches in the hot layout: + # tbz x0, #0, .Lcold -> tbnz x0, #0, .Lhot + # .Lhot: b .Lcold + # .Lhot: + if ( + fallthrough_is_hot + and inst is not None + and inst.kind == InstructionKind.SHORT_BRANCH + and target is not None + and not self._effective_layout_hot(target) + ): + fallthrough_label = self._ensure_label(fallthrough) + inverted = self._invert_branch(inst, fallthrough_label) + assert inverted is not None + bridge = self._make_jump(target, hot=target.hot) + bridge.link = fallthrough + block.instructions[-1] = inverted + block.target = fallthrough + block.link = bridge + return + + if fallthrough_is_hot: + return + + # Make a hot-to-cold fallthrough explicit, preferably by inverting: + # b.eq .Lhot -> b.ne .Lcold + # .Lcold: .Lhot: + if ( + inst is not None + and inst.is_branch() + and target is not None + and self._effective_layout_hot(target) + ): + fallthrough_label = self._ensure_label(fallthrough) + inverted = None + if self._can_short_branch_to_layout(inst, block, fallthrough): + inverted = self._invert_branch(inst, fallthrough_label) + if inverted is not None: + bridge = self._make_jump(target, hot=True) + bridge.link = fallthrough + block.instructions[-1] = inverted + block.target = fallthrough + block.link = bridge + return + # If no inversion is possible, preserve the old fallthrough with: + # b .Lcold + self._insert_fallthrough_bridge(block, fallthrough) + + def _layout_units(self) -> list[tuple[bool, list[_Block]]]: + 
continuation = self._continuation() + cold_start = self._cold_start_block() + units: list[tuple[bool, list[_Block]]] = [] + unit: list[_Block] = [] + + def finish_unit() -> None: + nonlocal unit + if unit: + layout_hot = self._effective_layout_hot(unit[0]) + for unit_block in unit: + unit_block.layout_hot = layout_hot + units.append((layout_hot, unit)) + unit = [] + + for block in self._layout_blocks(): + if block is continuation or block is cold_start: + finish_unit() + continue + unit.append(block) + if block.instructions or not block.fallthrough: + finish_unit() + finish_unit() + return units + + def _metadata_blocks(self) -> list[_Block]: + return [block for block in self._blocks() if block.is_metadata] + + def _relink_blocks(self, blocks: list[_Block]) -> None: + for current, next_block in zip(blocks, blocks[1:]): + current.link = next_block + if blocks: + blocks[-1].link = None + + def _partition_hot_cold_blocks(self) -> None: + # The entry point must remain in the hot layout, even when it can't + # reach _JIT_CONTINUE. The stencil parser expects _JIT_ENTRY at code + # offset 0. 
+ entry_label = f"{self.symbol_prefix}_JIT_ENTRY" + for block in self._layout_blocks(): + if block.label == entry_label: + block.layout_hot = True + + for block in list(self._layout_blocks()): + self._ensure_hot_fallthrough(block) + + continuation = self._continuation() + continuation.layout_hot = True + continuation.fallthrough = False + cold_start = self._cold_start_block() + cold_start.layout_hot = False + cold_start.fallthrough = True + + units = self._layout_units() + hot_blocks = [ + block for layout_hot, unit in units if layout_hot for block in unit + ] + cold_blocks = [ + block for layout_hot, unit in units if not layout_hot for block in unit + ] + self._relink_blocks( + [ + *hot_blocks, + continuation, + cold_start, + *cold_blocks, + *self._metadata_blocks(), + ] + ) + def _blocks(self) -> typing.Generator[_Block, None, None]: block: _Block | None = self._root while block: yield block block = block.link + def _layout_blocks(self) -> typing.Generator[_Block, None, None]: + for block in self._blocks(): + if not block.is_metadata: + yield block + def _body(self) -> str: lines = ["#" + line for line in self.text.splitlines()] - hot = True + hot: bool | None = True for block in self._blocks(): - if hot != block.hot: - hot = block.hot + layout_hot = block.layout_hot + if layout_hot is None: + layout_hot = block.hot + if hot != layout_hot: + hot = layout_hot # Make it easy to tell at a glance where cold code is: lines.append(f"# JIT: {'HOT' if hot else 'COLD'} ".ljust(80, "#")) lines.extend(block.noninstructions) @@ -378,12 +592,17 @@ def _insert_continue_label(self) -> None: continuation = self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE") assert continuation.label continuation.noninstructions.append(f"{continuation.label}:") + continuation.layout_hot = True + tail = end.link + while tail: + tail.is_metadata = True + tail = tail.link end.link, continuation.link = continuation, end.link def _mark_hot_blocks(self) -> None: - # Start with the last block, and 
perform a DFS to find all blocks that - # can eventually reach it: - todo = list(self._blocks())[-1:] + # Start with the continuation block, and perform a DFS to find all + # blocks that can eventually reach it: + todo = [self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE")] while todo: block = todo.pop() block.hot = True @@ -413,11 +632,14 @@ def _invert_hot_branches(self) -> None: and len(jump.instructions) == 1 and list(self._predecessors(jump)) == [branch] ): + jump.layout_hot = True assert jump.target.label assert branch.target.label - inverted = self._invert_branch( - branch.instructions[-1], jump.target.label - ) + inst = branch.instructions[-1] + if inst.kind == InstructionKind.SHORT_BRANCH: + inverted = None + else: + inverted = self._invert_branch(inst, jump.target.label) # Check to see if the branch can even be inverted: if inverted is None: continue @@ -427,10 +649,12 @@ def _invert_hot_branches(self) -> None: ) branch.target, jump.target = jump.target, branch.target jump.hot = True + jump.layout_hot = True def _remove_redundant_jumps(self) -> None: # Zero-length jumps can be introduced by _insert_continue_label and # _invert_hot_branches: + continuation = self._continuation() for block in self._blocks(): target = block.target if target is None: @@ -441,7 +665,14 @@ def _remove_redundant_jumps(self) -> None: # FOO: # After: # FOO: - if block.link and target is block.link.resolve(): + if ( + block.link + and target is block.link.resolve() + and ( + self._same_layout_section(block, target) + or target is continuation + ) + ): block.target = None block.fallthrough = True block.instructions.pop() @@ -459,12 +690,15 @@ def _remove_redundant_jumps(self) -> None: ): assert target.target is not None assert target.target.label is not None + inst = block.instructions[-1] if block.instructions[ -1 ].kind == InstructionKind.SHORT_BRANCH and self._is_far_target( target.target.label ): continue + if not self._can_short_branch_to_layout(inst, block, target.target): 
+ continue block.target = target.target block.instructions[-1] = block.instructions[-1].update_target( target.target.label @@ -488,20 +722,38 @@ def _find_live_blocks(self) -> set[_Block]: def _remove_unreachable(self) -> None: live = self._find_live_blocks() - continuation = self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE") - # Keep blocks after continuation as they may contain data and - # metadata that the assembler needs + continuation = self._continuation() + cont_or_cold_blocks = {continuation} + if self._cold_start is not None: + cont_or_cold_blocks.add(self._cold_start) + # Keep only the original assembler tail. Cold code after _JIT_CONTINUE + # is ordinary code and can be removed when unreachable. prev: _Block | None = None block = self._root - while block is not continuation: + # We now walk the whole list, so keep explicit sentinel checks in place + # of the old "stop at _JIT_CONTINUE" loop invariant. + seen_continuation = False + seen_cold_start = self._cold_start is None + while block is not None: + if block is continuation: + seen_continuation = True + if block is self._cold_start: + seen_cold_start = True next = block.link - assert next is not None - if not block in live and prev: + if ( + block not in live + and block not in cont_or_cold_blocks + and not block.is_metadata + and prev is not None + ): prev.link = next else: prev = block block = next - assert prev.link is block + if prev is not None: + assert prev.link is block + assert seen_continuation + assert seen_cold_start def _fixup_external_labels(self) -> None: if self._supports_external_relocations: @@ -544,8 +796,10 @@ def run(self) -> None: self._mark_hot_blocks() # Removing branches can expose opportunities for more branch removal. # Repeat a few times. 2 would probably do, but it's fast enough with 4. 
- for _ in range(4): + for iter in range(4): self._invert_hot_branches() + if iter == 0: + self._partition_hot_cold_blocks() self._remove_redundant_jumps() self._remove_unreachable() self._fixup_external_labels() @@ -559,6 +813,7 @@ class OptimizerAArch64(Optimizer): # pylint: disable = too-few-public-methods _branches = _AARCH64_BRANCHES _short_branches = _AARCH64_SHORT_BRANCHES + _jump_name = "b" # Mach-O does not support the 19 bit branch locations needed for branch reordering _supports_external_relocations = False _branch_patterns = [name.replace(".", r"\.") for name in _AARCH64_BRANCHES] @@ -776,6 +1031,7 @@ class OptimizerX86(Optimizer): # pylint: disable = too-few-public-methods _branches = _X86_BRANCHES _short_branches = {} + _jump_name = "jmp" _re_branch = re.compile( rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)" )