From 005f24c604543c47e6864de2058e0d51c4cd3cb0 Mon Sep 17 00:00:00 2001 From: Quratulain-bilal Date: Sat, 2 May 2026 16:45:41 +0500 Subject: [PATCH 1/3] fix(integration): refresh shared infra on integration switch --- src/specify_cli/__init__.py | 48 +++++- src/specify_cli/shared_infra.py | 79 ++++++++-- templates/constitution-template.md | 2 + .../test_integration_subcommand.py | 146 ++++++++++++++++++ 4 files changed, 259 insertions(+), 16 deletions(-) diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index ccd670d20e..94376c0cf0 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -765,6 +765,8 @@ def _install_shared_infra( tracker: StepTracker | None = None, force: bool = False, invoke_separator: str = ".", + refresh_managed: bool = False, + refresh_hint: str | None = None, ) -> bool: """Install shared infrastructure files into *project_path*. @@ -776,9 +778,23 @@ def _install_shared_infra( placeholders using *invoke_separator* (``"."`` for markdown agents, ``"-"`` for skills agents). - When *force* is ``True``, existing files are overwritten with the - latest bundled versions. When ``False`` (default), only missing - files are added and existing ones are skipped. + Overwrite policy: + + * ``force=True`` — overwrite every existing file (still skips symlinks + to avoid following links outside the project root). + * ``refresh_managed=True`` — overwrite only files whose on-disk hash + still matches the previously recorded manifest hash (i.e. unmodified + files installed by spec-kit). Files with diverging hashes are + treated as user customizations and preserved with a warning. + * Default — only add missing files; existing ones are skipped. + + *refresh_hint* — caller-supplied rich-text fragment shown after the + "Preserved customized files" warning to tell the user which flag/command + they should re-run with to overwrite their customizations. Each caller + passes the flag that's actually valid in its CLI surface (e.g. + ``--refresh-shared-infra`` for ``integration switch``, + ``--force`` for ``init``/``integration upgrade``). When ``None``, no + remediation hint is printed for customizations. Returns ``True`` on success. """ @@ -791,6 +807,8 @@ def _install_shared_infra( console=console, force=force, invoke_separator=invoke_separator, + refresh_managed=refresh_managed, + refresh_hint=refresh_hint, ) @@ -800,6 +818,8 @@ def _install_shared_infra_or_exit( tracker: StepTracker | None = None, force: bool = False, invoke_separator: str = ".", + refresh_managed: bool = False, + refresh_hint: str | None = None, ) -> bool: try: return _install_shared_infra( @@ -808,6 +828,8 @@ def _install_shared_infra_or_exit( tracker=tracker, force=force, invoke_separator=invoke_separator, + refresh_managed=refresh_managed, + refresh_hint=refresh_hint, ) except (ValueError, OSError) as exc: console.print(f"[red]Error:[/red] Failed to install shared infrastructure: {exc}") @@ -2578,7 +2600,8 @@ def integration_uninstall( def integration_switch( target: str = typer.Argument(help="Integration key to switch to"), script: str | None = typer.Option(None, "--script", help="Script type: sh or ps (default: from init-options.json or platform default)"), - force: bool = typer.Option(False, "--force", help="Force removal of modified files during uninstall"), + force: bool = typer.Option(False, "--force", help="Force removal of modified files during uninstall of the previous integration"), + refresh_shared_infra: bool = typer.Option(False, "--refresh-shared-infra", help="Also overwrite shared infrastructure files even if you customized them (otherwise customizations are preserved)"), integration_options: str | None = typer.Option(None, "--integration-options", help='Options for the target integration'), ): """Switch from the current integration to a different one.""" @@ -2749,14 +2772,27 @@ def integration_switch( target_integration, current, target, integration_options ) - # Ensure shared infrastructure is present (safe to run unconditionally; - # _install_shared_infra merges missing files without overwriting). + # Refresh shared infrastructure to the current CLI version. Switching + # integrations is exactly when stale vendored shared scripts (e.g. + # update-agent-context.sh that pre-dates the target integration's + # supported-agent list) would silently break the new integration. + # + # Use refresh_managed=True so only files that match their previously + # recorded hash are overwritten — user customizations are detected via + # hash divergence and preserved with a warning. Pass + # --refresh-shared-infra to overwrite customizations as well. See #2293. _install_shared_infra_or_exit( project_root, selected_script, + force=refresh_shared_infra, + refresh_managed=True, invoke_separator=_invoke_separator_for_integration( target_integration, current, target, parsed_options ), + refresh_hint=( + "To overwrite customizations, re-run with " + "[cyan]specify integration switch ... --refresh-shared-infra[/cyan]." + ), ) if os.name != "nt": ensure_executable_scripts(project_root) diff --git a/src/specify_cli/shared_infra.py b/src/specify_cli/shared_infra.py index 1e8be7b282..93e28fae43 100644 --- a/src/specify_cli/shared_infra.py +++ b/src/specify_cli/shared_infra.py @@ -242,13 +242,51 @@ def install_shared_infra( console: Any, force: bool = False, invoke_separator: str = ".", + refresh_managed: bool = False, + refresh_hint: str | None = None, ) -> bool: - """Install shared scripts and templates into *project_path*.""" + """Install shared scripts and templates into *project_path*. + + When ``refresh_managed`` is True, files whose on-disk hash still matches + the previously recorded manifest hash are overwritten with the bundled + version. Files whose hash diverges are treated as user customizations and + preserved with a warning. ``force=True`` overwrites everything regardless. + ``refresh_hint`` is shown after the customization warning to tell the user + which flag would overwrite their customizations. + """ + from .integrations.manifest import _sha256 + manifest = load_speckit_manifest(project_path, version=version, console=console) + prior_hashes = dict(manifest.files) + + def _is_managed(rel: str, dst: Path) -> bool: + expected = prior_hashes.get(rel) + if not expected or not dst.is_file() or dst.is_symlink(): + return False + try: + return _sha256(dst) == expected + except OSError: + return False + skipped_files: list[str] = [] + preserved_user_files: list[str] = [] planned_copies: list[tuple[Path, str, bytes, int]] = [] planned_templates: list[tuple[Path, str, str]] = [] + def _decide_overwrite(rel: str, dst: Path) -> tuple[bool, str | None]: + """Return (write, bucket) where bucket is 'skip', 'preserved', or None.""" + if not dst.exists(): + return True, None + if force: + return True, None + if refresh_managed: + if _is_managed(rel, dst): + return True, None + if rel in prior_hashes: + return False, "preserved" + return False, "skip" + return False, "skip" + scripts_src = shared_scripts_source(core_pack=core_pack, repo_root=repo_root) if scripts_src.is_dir(): dest_scripts = project_path / ".specify" / "scripts" @@ -265,12 +303,16 @@ def install_shared_infra( rel_path = src_path.relative_to(variant_src) dst_path = dest_variant / rel_path _ensure_safe_shared_destination(project_path, dst_path, parent_must_exist=False) - if dst_path.exists() and not force: - skipped_files.append(dst_path.relative_to(project_path).as_posix()) + rel = dst_path.relative_to(project_path).as_posix() + write, bucket = _decide_overwrite(rel, dst_path) + if not write: + if bucket == "preserved": + preserved_user_files.append(rel) + else: + skipped_files.append(rel) continue _ensure_safe_shared_directory(project_path, dst_path.parent) - rel = dst_path.relative_to(project_path).as_posix() planned_copies.append((dst_path, rel, src_path.read_bytes(), src_path.stat().st_mode & 0o777)) templates_src = shared_templates_source(core_pack=core_pack, repo_root=repo_root) @@ -283,13 +325,17 @@ def install_shared_infra( dst = dest_templates / src.name _ensure_safe_shared_destination(project_path, dst) - if dst.exists() and not force: - skipped_files.append(dst.relative_to(project_path).as_posix()) + rel = dst.relative_to(project_path).as_posix() + write, bucket = _decide_overwrite(rel, dst) + if not write: + if bucket == "preserved": + preserved_user_files.append(rel) + else: + skipped_files.append(rel) continue content = src.read_text(encoding="utf-8") content = IntegrationBase.resolve_command_refs(content, invoke_separator) - rel = dst.relative_to(project_path).as_posix() planned_templates.append((dst, rel, content)) for dst_path, rel, content, mode in planned_copies: @@ -307,11 +353,24 @@ def install_shared_infra( ) for path in skipped_files: console.print(f" {path}") + if refresh_hint: + console.print(refresh_hint) + else: + console.print( + "To refresh shared infrastructure, run " + "[cyan]specify init --here --force[/cyan] or " + "[cyan]specify integration upgrade --force[/cyan]." + ) + + if preserved_user_files: console.print( - "To refresh shared infrastructure, run " - "[cyan]specify init --here --force[/cyan] or " - "[cyan]specify integration upgrade --force[/cyan]." + f"[yellow]⚠[/yellow] Preserved {len(preserved_user_files)} customized shared " + "infrastructure file(s) (hash differs from previous install):" ) + for path in preserved_user_files: + console.print(f" {path}") + if refresh_hint: + console.print(refresh_hint) manifest.save() return True diff --git a/templates/constitution-template.md b/templates/constitution-template.md index a4670ff469..0dfcca066a 100644 --- a/templates/constitution-template.md +++ b/templates/constitution-template.md @@ -1,3 +1,5 @@ + + # [PROJECT_NAME] Constitution diff --git a/tests/integrations/test_integration_subcommand.py b/tests/integrations/test_integration_subcommand.py index 750bbb6efa..5163f98db1 100644 --- a/tests/integrations/test_integration_subcommand.py +++ b/tests/integrations/test_integration_subcommand.py @@ -901,6 +901,152 @@ def test_switch_preserves_shared_infra(self, tmp_path): assert shared_script.exists() assert shared_script.read_text(encoding="utf-8") == shared_content + def test_switch_refreshes_stale_managed_shared_infra(self, tmp_path): + """Regression for #2293: stale managed shared scripts get refreshed on switch.""" + import hashlib + + project = _init_project(tmp_path, "claude") + shared_script = project / ".specify" / "scripts" / "bash" / "common.sh" + bundled_bytes = shared_script.read_bytes() + + # Simulate a stale vendored script: write truncated content as bytes + # (write_text would translate \n→\r\n on Windows and break the hash) + # and update the speckit manifest hash so the stale copy is treated + # as "managed" (installed by spec-kit, not a user customization). + stale_bytes = b"#!/usr/bin/env bash\n# stale vendored copy\n" + shared_script.write_bytes(stale_bytes) + + manifest_path = project / ".specify" / "integrations" / "speckit.manifest.json" + manifest_data = json.loads(manifest_path.read_text(encoding="utf-8")) + manifest_data["files"][".specify/scripts/bash/common.sh"] = ( + hashlib.sha256(stale_bytes).hexdigest() + ) + manifest_path.write_text(json.dumps(manifest_data), encoding="utf-8") + + old_cwd = os.getcwd() + try: + os.chdir(project) + result = runner.invoke(app, [ + "integration", "switch", "copilot", + "--script", "sh", + ], catch_exceptions=False) + finally: + os.chdir(old_cwd) + assert result.exit_code == 0 + + # Stale managed file should be replaced by the bundled version + assert shared_script.read_bytes() == bundled_bytes + + def test_switch_preserves_user_customized_shared_infra(self, tmp_path): + """User customizations (hash divergence from manifest) survive switch without --refresh-shared-infra.""" + project = _init_project(tmp_path, "claude") + shared_script = project / ".specify" / "scripts" / "bash" / "common.sh" + + # User customization: append bytes but do NOT update manifest hash, + # so on-disk hash diverges from the recorded one. + original = shared_script.read_bytes() + custom_bytes = original + b"\n# user customization\n" + shared_script.write_bytes(custom_bytes) + + old_cwd = os.getcwd() + try: + os.chdir(project) + result = runner.invoke(app, [ + "integration", "switch", "copilot", + "--script", "sh", + ], catch_exceptions=False) + finally: + os.chdir(old_cwd) + assert result.exit_code == 0 + assert shared_script.read_bytes() == custom_bytes + assert "Preserved" in result.output + + def test_switch_refresh_shared_infra_overwrites_customizations(self, tmp_path): + """--refresh-shared-infra explicitly overwrites user customizations on switch.""" + project = _init_project(tmp_path, "claude") + shared_script = project / ".specify" / "scripts" / "bash" / "common.sh" + bundled_bytes = shared_script.read_bytes() + + # User customization (hash diverges from manifest) + custom_bytes = bundled_bytes + b"\n# user customization\n" + shared_script.write_bytes(custom_bytes) + + old_cwd = os.getcwd() + try: + os.chdir(project) + result = runner.invoke(app, [ + "integration", "switch", "copilot", + "--script", "sh", + "--refresh-shared-infra", + ], catch_exceptions=False) + finally: + os.chdir(old_cwd) + assert result.exit_code == 0 + # Customization is overwritten with the bundled version + assert shared_script.read_bytes() == bundled_bytes + + def test_switch_skips_symlinked_parent_directory(self, tmp_path): + """Regression: if .specify/scripts/bash is a symlink, switch must not write through it. + + Copilot follow-up on #2375: leaf-only symlink check let writes escape + when an *ancestor* directory was symlinked outside the project root. + """ + import sys + if sys.platform.startswith("win"): + import pytest as _pytest + _pytest.skip("Symlink creation typically requires admin on Windows") + + project = _init_project(tmp_path, "claude") + bash_dir = project / ".specify" / "scripts" / "bash" + outside = tmp_path / "outside" + outside.mkdir() + for child in bash_dir.iterdir(): + child.rename(outside / child.name) + bash_dir.rmdir() + bash_dir.symlink_to(outside, target_is_directory=True) + sentinel = (outside / "common.sh").read_bytes() + + old_cwd = os.getcwd() + try: + os.chdir(project) + result = runner.invoke(app, [ + "integration", "switch", "copilot", + "--script", "sh", + ], catch_exceptions=False) + finally: + os.chdir(old_cwd) + assert result.exit_code == 0 + # Symlinked tree reported, not written through. + assert "symlink" in result.output.lower() + # Outside dir contents unchanged. + assert (outside / "common.sh").read_bytes() == sentinel + + def test_switch_force_alone_does_not_overwrite_shared_customizations(self, tmp_path): + """--force (uninstall semantics) must NOT overwrite shared-infra customizations. + + Regression: ensures the decoupling of --force and --refresh-shared-infra. + """ + project = _init_project(tmp_path, "claude") + shared_script = project / ".specify" / "scripts" / "bash" / "common.sh" + bundled_bytes = shared_script.read_bytes() + + custom_bytes = bundled_bytes + b"\n# user customization\n" + shared_script.write_bytes(custom_bytes) + + old_cwd = os.getcwd() + try: + os.chdir(project) + result = runner.invoke(app, [ + "integration", "switch", "copilot", + "--script", "sh", + "--force", + ], catch_exceptions=False) + finally: + os.chdir(old_cwd) + assert result.exit_code == 0 + # --force alone preserves the customization + assert shared_script.read_bytes() == custom_bytes + def test_switch_from_nothing(self, tmp_path): """Switch when no integration is installed should just install the target.""" project = tmp_path / "bare" From ee9afea14410ee54029f7847e903e972738a5ec3 Mon Sep 17 00:00:00 2001 From: Quratulain-bilal Date: Tue, 5 May 2026 07:27:48 +0500 Subject: [PATCH 2/3] fix(integration): address Copilot review on switch shared-infra refresh - Clarify install_shared_infra docstring: force overwrites regular files but always preserves symlinks (safe-destination check refuses to follow). - Print refresh_hint only for preserved_user_files; skipped_files keeps the generic remediation. Avoids misleading guidance when files were merely skipped (not detected as customized). - Catch ValueError from the safe-destination check and bucket the path under a new symlinked_files warning instead of aborting the switch. - Restore templates/constitution-template.md to upstream (drop accidental leading blank lines). --- src/specify_cli/shared_infra.py | 58 +++++++++++++++++++++++------- templates/constitution-template.md | 2 -- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/specify_cli/shared_infra.py b/src/specify_cli/shared_infra.py index 93e28fae43..fd40497fee 100644 --- a/src/specify_cli/shared_infra.py +++ b/src/specify_cli/shared_infra.py @@ -250,9 +250,12 @@ def install_shared_infra( When ``refresh_managed`` is True, files whose on-disk hash still matches the previously recorded manifest hash are overwritten with the bundled version. Files whose hash diverges are treated as user customizations and - preserved with a warning. ``force=True`` overwrites everything regardless. - ``refresh_hint`` is shown after the customization warning to tell the user - which flag would overwrite their customizations. + preserved with a warning. ``force=True`` overwrites every regular file + (symlinks and symlinked-parent destinations are always preserved with a + warning — the safe-destination check refuses to follow them so writes + cannot escape the project root). ``refresh_hint`` is shown after the + customization warning to tell the user which flag would overwrite their + customizations. """ from .integrations.manifest import _sha256 @@ -270,6 +273,7 @@ def _is_managed(rel: str, dst: Path) -> bool: skipped_files: list[str] = [] preserved_user_files: list[str] = [] + symlinked_files: list[str] = [] planned_copies: list[tuple[Path, str, bytes, int]] = [] planned_templates: list[tuple[Path, str, str]] = [] @@ -287,6 +291,22 @@ def _decide_overwrite(rel: str, dst: Path) -> tuple[bool, str | None]: return False, "skip" return False, "skip" + def _safe_dest_or_bucket(dst: Path, rel: str, *, parent_must_exist: bool = True) -> bool: + """Run the safe-destination check and bucket symlinked paths. + + Returns True when the destination is safe to consider (write or skip). + Returns False (and records *rel* under ``symlinked_files``) when the + destination or any of its ancestors is a symlink — those paths can't + be written to safely, but they shouldn't abort the whole switch + either. They're surfaced as a separate "symlinked" warning bucket. + """ + try: + _ensure_safe_shared_destination(project_path, dst, parent_must_exist=parent_must_exist) + except ValueError: + symlinked_files.append(rel) + return False + return True + scripts_src = shared_scripts_source(core_pack=core_pack, repo_root=repo_root) if scripts_src.is_dir(): dest_scripts = project_path / ".specify" / "scripts" @@ -302,8 +322,9 @@ def _decide_overwrite(rel: str, dst: Path) -> tuple[bool, str | None]: rel_path = src_path.relative_to(variant_src) dst_path = dest_variant / rel_path - _ensure_safe_shared_destination(project_path, dst_path, parent_must_exist=False) rel = dst_path.relative_to(project_path).as_posix() + if not _safe_dest_or_bucket(dst_path, rel, parent_must_exist=False): + continue write, bucket = _decide_overwrite(rel, dst_path) if not write: if bucket == "preserved": @@ -324,8 +345,9 @@ def _decide_overwrite(rel: str, dst: Path) -> tuple[bool, str | None]: continue dst = dest_templates / src.name - _ensure_safe_shared_destination(project_path, dst) rel = dst.relative_to(project_path).as_posix() + if not _safe_dest_or_bucket(dst, rel): + continue write, bucket = _decide_overwrite(rel, dst) if not write: if bucket == "preserved": @@ -353,14 +375,24 @@ def _decide_overwrite(rel: str, dst: Path) -> tuple[bool, str | None]: ) for path in skipped_files: console.print(f" {path}") - if refresh_hint: - console.print(refresh_hint) - else: - console.print( - "To refresh shared infrastructure, run " - "[cyan]specify init --here --force[/cyan] or " - "[cyan]specify integration upgrade --force[/cyan]." - ) + console.print( + "To refresh shared infrastructure, run " + "[cyan]specify init --here --force[/cyan] or " + "[cyan]specify integration upgrade --force[/cyan]." + ) + + if symlinked_files: + console.print( + f"[yellow]⚠[/yellow] Skipped {len(symlinked_files)} symlinked shared " + "infrastructure file(s) — symlinks are never overwritten because they " + "may resolve outside the project root:" + ) + for path in symlinked_files: + console.print(f" {path}") + console.print( + "To restore the bundled version, remove or replace the symlink manually, " + "then re-run the command." + ) if preserved_user_files: console.print( diff --git a/templates/constitution-template.md b/templates/constitution-template.md index 0dfcca066a..a4670ff469 100644 --- a/templates/constitution-template.md +++ b/templates/constitution-template.md @@ -1,5 +1,3 @@ - - # [PROJECT_NAME] Constitution From 93850a7a6ea94886357c0938ac9e3b130f21dcd0 Mon Sep 17 00:00:00 2001 From: Quratulain-bilal Date: Wed, 6 May 2026 02:10:25 +0500 Subject: [PATCH 3/3] fix(integration): narrow symlink bucketing to dedicated exception MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Copilot feedback on shared_infra.py:305 — _safe_dest_or_bucket caught any ValueError as 'symlinked', which masked genuine safety errors (path escape, parent-not-a-directory). - Introduce SymlinkedSharedPathError(ValueError) raised only by the symlink-specific branches in _ensure_safe_shared_*(). - _safe_dest_or_bucket() now catches only SymlinkedSharedPathError; other ValueErrors propagate so the operation aborts with the real cause instead of being silently bucketed. - Wrap top-level dest_scripts/dest_variant/dest_templates mkdir calls in the same bucket helper so a symlinked .specify/scripts or .specify/templates is preserved with a warning rather than aborting the switch (matches the documented 'preserve customizations' behavior). - Update tests to expect the new bucket+warn behavior for leaf-level symlinked destinations. --- src/specify_cli/shared_infra.py | 120 ++++++++++++++++++++------------ tests/integrations/test_cli.py | 42 ++++++----- 2 files changed, 98 insertions(+), 64 deletions(-) diff --git a/src/specify_cli/shared_infra.py b/src/specify_cli/shared_infra.py index fd40497fee..0593dbba26 100644 --- a/src/specify_cli/shared_infra.py +++ b/src/specify_cli/shared_infra.py @@ -11,6 +11,15 @@ from .integrations.manifest import IntegrationManifest +class SymlinkedSharedPathError(ValueError): + """Raised when a shared infrastructure path or ancestor is a symlink. + + Distinct from other unsafe-path errors so callers can preserve symlinked + destinations as customizations while still letting genuine safety errors + (e.g. path escape, not-a-directory) propagate and abort the operation. + """ + + def load_speckit_manifest( project_path: Path, *, @@ -89,7 +98,7 @@ def _ensure_safe_shared_directory(project_path: Path, directory: Path, *, create current = current / part label = _shared_destination_label(project_path, current) if current.is_symlink(): - raise ValueError(f"Refusing to use symlinked shared infrastructure directory: {label}") + raise SymlinkedSharedPathError(f"Refusing to use symlinked shared infrastructure directory: {label}") if current.exists(): if not current.is_dir(): raise ValueError(f"Shared infrastructure directory path is not a directory: {label}") @@ -102,7 +111,7 @@ def _ensure_safe_shared_directory(project_path: Path, directory: Path, *, create raise ValueError(f"Shared infrastructure directory does not exist: {label}") current.mkdir() if current.is_symlink(): - raise ValueError(f"Refusing to use symlinked shared infrastructure directory: {label}") + raise SymlinkedSharedPathError(f"Refusing to use symlinked shared infrastructure directory: {label}") try: current.resolve().relative_to(root) except (OSError, ValueError): @@ -119,7 +128,7 @@ def _validate_safe_shared_directory(project_path: Path, directory: Path) -> None current = current / part label = _shared_destination_label(project_path, current) if current.is_symlink(): - raise ValueError(f"Refusing to use symlinked shared infrastructure directory: {label}") + raise SymlinkedSharedPathError(f"Refusing to use symlinked shared infrastructure directory: {label}") if not current.exists(): continue if not current.is_dir(): @@ -145,7 +154,7 @@ def _ensure_safe_shared_destination( _validate_safe_shared_directory(project_path, dest.parent) label = _shared_destination_label(project_path, dest) if dest.is_symlink(): - raise ValueError(f"Refusing to overwrite symlinked shared infrastructure path: {label}") + raise SymlinkedSharedPathError(f"Refusing to overwrite symlinked shared infrastructure path: {label}") if dest.exists(): try: @@ -299,33 +308,76 @@ def _safe_dest_or_bucket(dst: Path, rel: str, *, parent_must_exist: bool = True) destination or any of its ancestors is a symlink — those paths can't be written to safely, but they shouldn't abort the whole switch either. They're surfaced as a separate "symlinked" warning bucket. + + Other unsafe-path errors (e.g. path escape, parent-not-a-directory) + are NOT caught here: they re-raise so the operation aborts, since + treating them as "symlinked" would mask security-relevant failures. """ try: _ensure_safe_shared_destination(project_path, dst, parent_must_exist=parent_must_exist) - except ValueError: + except SymlinkedSharedPathError: symlinked_files.append(rel) return False return True + def _ensure_or_bucket_dir(directory: Path) -> bool: + """Create *directory* unless an ancestor is symlinked. + + Returns True when the directory is safe to use. Returns False (and + records the path under ``symlinked_files``) when a symlink ancestor + forces us to skip the whole subtree. Other unsafe-path errors + (escape, not-a-directory) re-raise so the operation aborts. + """ + try: + _ensure_safe_shared_directory(project_path, directory) + except SymlinkedSharedPathError: + symlinked_files.append(directory.relative_to(project_path).as_posix()) + return False + return True + scripts_src = shared_scripts_source(core_pack=core_pack, repo_root=repo_root) if scripts_src.is_dir(): dest_scripts = project_path / ".specify" / "scripts" - _ensure_safe_shared_directory(project_path, dest_scripts) - variant_dir = "bash" if script_type == "sh" else "powershell" - variant_src = scripts_src / variant_dir - if variant_src.is_dir(): - dest_variant = dest_scripts / variant_dir - _ensure_safe_shared_directory(project_path, dest_variant) - for src_path in variant_src.rglob("*"): - if not src_path.is_file(): + if _ensure_or_bucket_dir(dest_scripts): + variant_dir = "bash" if script_type == "sh" else "powershell" + variant_src = scripts_src / variant_dir + if variant_src.is_dir(): + dest_variant = dest_scripts / variant_dir + if _ensure_or_bucket_dir(dest_variant): + for src_path in variant_src.rglob("*"): + if not src_path.is_file(): + continue + + rel_path = src_path.relative_to(variant_src) + dst_path = dest_variant / rel_path + rel = dst_path.relative_to(project_path).as_posix() + if not _safe_dest_or_bucket(dst_path, rel, parent_must_exist=False): + continue + write, bucket = _decide_overwrite(rel, dst_path) + if not write: + if bucket == "preserved": + preserved_user_files.append(rel) + else: + skipped_files.append(rel) + continue + + if not _ensure_or_bucket_dir(dst_path.parent): + continue + planned_copies.append((dst_path, rel, src_path.read_bytes(), src_path.stat().st_mode & 0o777)) + + templates_src = shared_templates_source(core_pack=core_pack, repo_root=repo_root) + if templates_src.is_dir(): + dest_templates = project_path / ".specify" / "templates" + if _ensure_or_bucket_dir(dest_templates): + for src in templates_src.iterdir(): + if not src.is_file() or src.name == "vscode-settings.json" or src.name.startswith("."): continue - rel_path = src_path.relative_to(variant_src) - dst_path = dest_variant / rel_path - rel = dst_path.relative_to(project_path).as_posix() - if not _safe_dest_or_bucket(dst_path, rel, parent_must_exist=False): + dst = dest_templates / src.name + rel = dst.relative_to(project_path).as_posix() + if not _safe_dest_or_bucket(dst, rel): continue - write, bucket = _decide_overwrite(rel, dst_path) + write, bucket = _decide_overwrite(rel, dst) if not write: if bucket == "preserved": preserved_user_files.append(rel) @@ -333,35 +385,13 @@ def _safe_dest_or_bucket(dst: Path, rel: str, *, parent_must_exist: bool = True) skipped_files.append(rel) continue - _ensure_safe_shared_directory(project_path, dst_path.parent) - planned_copies.append((dst_path, rel, src_path.read_bytes(), src_path.stat().st_mode & 0o777)) - - templates_src = shared_templates_source(core_pack=core_pack, repo_root=repo_root) - if templates_src.is_dir(): - dest_templates = project_path / ".specify" / "templates" - _ensure_safe_shared_directory(project_path, dest_templates) - for src in templates_src.iterdir(): - if not src.is_file() or src.name == "vscode-settings.json" or src.name.startswith("."): - continue - - dst = dest_templates / src.name - rel = dst.relative_to(project_path).as_posix() - if not _safe_dest_or_bucket(dst, rel): - continue - write, bucket = _decide_overwrite(rel, dst) - if not write: - if bucket == "preserved": - preserved_user_files.append(rel) - else: - skipped_files.append(rel) - continue - - content = src.read_text(encoding="utf-8") - content = IntegrationBase.resolve_command_refs(content, invoke_separator) - planned_templates.append((dst, rel, content)) + content = src.read_text(encoding="utf-8") + content = IntegrationBase.resolve_command_refs(content, invoke_separator) + planned_templates.append((dst, rel, content)) for dst_path, rel, content, mode in planned_copies: - _ensure_safe_shared_directory(project_path, dst_path.parent) + if not _ensure_or_bucket_dir(dst_path.parent): + continue _write_shared_bytes(project_path, dst_path, content, mode=mode) manifest.record_existing(rel) diff --git a/tests/integrations/test_cli.py b/tests/integrations/test_cli.py index 7732d57300..d02b6c3ef0 100644 --- a/tests/integrations/test_cli.py +++ b/tests/integrations/test_cli.py @@ -297,7 +297,7 @@ def test_shared_infra_warns_when_manifest_cannot_be_decoded(self, tmp_path, caps assert "A new shared manifest will be created" in captured.out @pytest.mark.skipif(not hasattr(os, "symlink"), reason="symlinks are unavailable") - def test_shared_infra_refuses_symlinked_script_destination(self, tmp_path): + def test_shared_infra_refuses_symlinked_script_destination(self, tmp_path, capsys): """Shared script refreshes must not follow destination symlinks.""" from specify_cli import _install_shared_infra @@ -311,13 +311,14 @@ def test_shared_infra_refuses_symlinked_script_destination(self, tmp_path): scripts_dir.mkdir(parents=True) os.symlink(outside, scripts_dir / "common.sh") - with pytest.raises(ValueError, match="Refusing to overwrite symlinked"): - _install_shared_infra(project, "sh", force=True) + _install_shared_infra(project, "sh", force=True) + captured = capsys.readouterr() + assert "symlinked shared infrastructure" in captured.out assert outside.read_text(encoding="utf-8") == "# outside\n" @pytest.mark.skipif(not hasattr(os, "symlink"), reason="symlinks are unavailable") - def test_shared_infra_refuses_symlinked_template_destination(self, tmp_path): + def test_shared_infra_refuses_symlinked_template_destination(self, tmp_path, capsys): """Shared template installs must not follow destination symlinks.""" from specify_cli import _install_shared_infra @@ -331,9 +332,10 @@ def test_shared_infra_refuses_symlinked_template_destination(self, tmp_path): templates_dir.mkdir(parents=True) os.symlink(outside, templates_dir / "plan-template.md") - with pytest.raises(ValueError, match="Refusing to overwrite symlinked"): - _install_shared_infra(project, "sh", force=True) + _install_shared_infra(project, "sh", force=True) + captured = capsys.readouterr() + assert "symlinked shared infrastructure" in captured.out assert outside.read_text(encoding="utf-8") == "# outside\n" @pytest.mark.skipif(not hasattr(os, "symlink"), reason="symlinks are unavailable") @@ -358,7 +360,7 @@ def test_shared_template_refresh_refuses_symlinked_destination(self, tmp_path): @pytest.mark.skipif(not hasattr(os, "symlink"), reason="symlinks are unavailable") def test_shared_infra_refuses_symlinked_specify_directory_before_mkdir(self, tmp_path): - """Shared infra directory creation must not follow a symlinked .specify.""" + """Shared infra installs must not follow a symlinked .specify directory.""" from specify_cli import _install_shared_infra project = tmp_path / "symlink-dir-test" @@ -367,8 +369,10 @@ def test_shared_infra_refuses_symlinked_specify_directory_before_mkdir(self, tmp outside.mkdir() os.symlink(outside, project / ".specify") - with pytest.raises(ValueError, match="symlinked shared infrastructure directory"): + with pytest.raises(ValueError, match="symlinked"): _install_shared_infra(project, "sh", force=True) + # Nothing should have been written under the symlinked .specify target. + assert list(outside.iterdir()) == [] assert not (outside / "scripts").exists() assert not (outside / "templates").exists() @@ -463,19 +467,19 @@ def test_shared_infra_install_preflights_before_writing(self, tmp_path): outside.write_text("# outside\n", encoding="utf-8") os.symlink(outside, scripts_dir / "z.sh") - with pytest.raises(ValueError, match="Refusing to overwrite symlinked"): - install_shared_infra( - project, - "sh", - version="test", - core_pack=core_pack, - repo_root=tmp_path / "unused", - console=_NoopConsole(), - force=True, - ) + install_shared_infra( + project, + "sh", + version="test", + core_pack=core_pack, + repo_root=tmp_path / "unused", + console=_NoopConsole(), + force=True, + ) - assert existing.read_text(encoding="utf-8") == "# old a\n" + # Symlinked z.sh is preserved (bucketed); regular a.sh is overwritten. assert outside.read_text(encoding="utf-8") == "# outside\n" + assert existing.read_text(encoding="utf-8") == "# new a\n" def test_shared_infra_install_supports_nested_script_sources(self, tmp_path): """Nested script source files create safe destination parents at write time."""