Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions src/specify_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,8 @@ def _install_shared_infra(
tracker: StepTracker | None = None,
force: bool = False,
invoke_separator: str = ".",
refresh_managed: bool = False,
refresh_hint: str | None = None,
) -> bool:
"""Install shared infrastructure files into *project_path*.

Expand All @@ -776,9 +778,23 @@ def _install_shared_infra(
placeholders using *invoke_separator* (``"."`` for markdown agents,
``"-"`` for skills agents).

When *force* is ``True``, existing files are overwritten with the
latest bundled versions. When ``False`` (default), only missing
files are added and existing ones are skipped.
Overwrite policy:

* ``force=True`` — overwrite every existing file (still skips symlinks
to avoid following links outside the project root).
* ``refresh_managed=True`` — overwrite only files whose on-disk hash
still matches the previously recorded manifest hash (i.e. unmodified
files installed by spec-kit). Files with diverging hashes are
treated as user customizations and preserved with a warning.
* Default — only add missing files; existing ones are skipped.

*refresh_hint* — caller-supplied rich-text fragment shown after the
"Preserved customized files" warning to tell the user which flag/command
they should re-run with to overwrite their customizations. Each caller
passes the flag that's actually valid in its CLI surface (e.g.
``--refresh-shared-infra`` for ``integration switch``,
``--force`` for ``init``/``integration upgrade``). When ``None``, no
remediation hint is printed for customizations.

Returns ``True`` on success.
"""
Expand All @@ -791,6 +807,8 @@ def _install_shared_infra(
console=console,
force=force,
invoke_separator=invoke_separator,
refresh_managed=refresh_managed,
refresh_hint=refresh_hint,
)


Expand All @@ -800,6 +818,8 @@ def _install_shared_infra_or_exit(
tracker: StepTracker | None = None,
force: bool = False,
invoke_separator: str = ".",
refresh_managed: bool = False,
refresh_hint: str | None = None,
) -> bool:
try:
return _install_shared_infra(
Expand All @@ -808,6 +828,8 @@ def _install_shared_infra_or_exit(
tracker=tracker,
force=force,
invoke_separator=invoke_separator,
refresh_managed=refresh_managed,
refresh_hint=refresh_hint,
)
except (ValueError, OSError) as exc:
console.print(f"[red]Error:[/red] Failed to install shared infrastructure: {exc}")
Expand Down Expand Up @@ -2578,7 +2600,8 @@ def integration_uninstall(
def integration_switch(
target: str = typer.Argument(help="Integration key to switch to"),
script: str | None = typer.Option(None, "--script", help="Script type: sh or ps (default: from init-options.json or platform default)"),
force: bool = typer.Option(False, "--force", help="Force removal of modified files during uninstall"),
force: bool = typer.Option(False, "--force", help="Force removal of modified files during uninstall of the previous integration"),
refresh_shared_infra: bool = typer.Option(False, "--refresh-shared-infra", help="Also overwrite shared infrastructure files even if you customized them (otherwise customizations are preserved)"),
integration_options: str | None = typer.Option(None, "--integration-options", help='Options for the target integration'),
Comment thread
mnriem marked this conversation as resolved.
):
"""Switch from the current integration to a different one."""
Expand Down Expand Up @@ -2749,14 +2772,27 @@ def integration_switch(
target_integration, current, target, integration_options
)

# Ensure shared infrastructure is present (safe to run unconditionally;
# _install_shared_infra merges missing files without overwriting).
# Refresh shared infrastructure to the current CLI version. Switching
# integrations is exactly when stale vendored shared scripts (e.g.
# update-agent-context.sh that pre-dates the target integration's
# supported-agent list) would silently break the new integration.
#
# Use refresh_managed=True so only files that match their previously
# recorded hash are overwritten — user customizations are detected via
# hash divergence and preserved with a warning. Pass
# --refresh-shared-infra to overwrite customizations as well. See #2293.
_install_shared_infra_or_exit(
project_root,
selected_script,
force=refresh_shared_infra,
refresh_managed=True,
invoke_separator=_invoke_separator_for_integration(
target_integration, current, target, parsed_options
),
refresh_hint=(
"To overwrite customizations, re-run with "
"[cyan]specify integration switch ... --refresh-shared-infra[/cyan]."
),
)
Comment thread
Quratulain-bilal marked this conversation as resolved.
if os.name != "nt":
ensure_executable_scripts(project_root)
Expand Down
199 changes: 160 additions & 39 deletions src/specify_cli/shared_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
from .integrations.manifest import IntegrationManifest


class SymlinkedSharedPathError(ValueError):
    """Safety error for symlinked shared-infrastructure paths.

    Raised when a shared-infrastructure destination — or any ancestor
    directory on the way to it — turns out to be a symlink. Kept as a
    dedicated ``ValueError`` subclass so callers can catch exactly this
    case (and preserve the symlinked destination as a user customization)
    while other unsafe-path failures, such as a path escaping the project
    root or a parent that is not a directory, still propagate and abort
    the operation.
    """


def load_speckit_manifest(
project_path: Path,
*,
Expand Down Expand Up @@ -89,7 +98,7 @@ def _ensure_safe_shared_directory(project_path: Path, directory: Path, *, create
current = current / part
label = _shared_destination_label(project_path, current)
if current.is_symlink():
raise ValueError(f"Refusing to use symlinked shared infrastructure directory: {label}")
raise SymlinkedSharedPathError(f"Refusing to use symlinked shared infrastructure directory: {label}")
if current.exists():
if not current.is_dir():
raise ValueError(f"Shared infrastructure directory path is not a directory: {label}")
Expand All @@ -102,7 +111,7 @@ def _ensure_safe_shared_directory(project_path: Path, directory: Path, *, create
raise ValueError(f"Shared infrastructure directory does not exist: {label}")
current.mkdir()
if current.is_symlink():
raise ValueError(f"Refusing to use symlinked shared infrastructure directory: {label}")
raise SymlinkedSharedPathError(f"Refusing to use symlinked shared infrastructure directory: {label}")
try:
current.resolve().relative_to(root)
except (OSError, ValueError):
Expand All @@ -119,7 +128,7 @@ def _validate_safe_shared_directory(project_path: Path, directory: Path) -> None
current = current / part
label = _shared_destination_label(project_path, current)
if current.is_symlink():
raise ValueError(f"Refusing to use symlinked shared infrastructure directory: {label}")
raise SymlinkedSharedPathError(f"Refusing to use symlinked shared infrastructure directory: {label}")
if not current.exists():
continue
if not current.is_dir():
Expand All @@ -145,7 +154,7 @@ def _ensure_safe_shared_destination(
_validate_safe_shared_directory(project_path, dest.parent)
label = _shared_destination_label(project_path, dest)
if dest.is_symlink():
raise ValueError(f"Refusing to overwrite symlinked shared infrastructure path: {label}")
raise SymlinkedSharedPathError(f"Refusing to overwrite symlinked shared infrastructure path: {label}")

if dest.exists():
try:
Expand Down Expand Up @@ -242,58 +251,147 @@ def install_shared_infra(
console: Any,
force: bool = False,
invoke_separator: str = ".",
refresh_managed: bool = False,
refresh_hint: str | None = None,
) -> bool:
"""Install shared scripts and templates into *project_path*."""
"""Install shared scripts and templates into *project_path*.

When ``refresh_managed`` is True, files whose on-disk hash still matches
the previously recorded manifest hash are overwritten with the bundled
version. Files whose hash diverges are treated as user customizations and
preserved with a warning. ``force=True`` overwrites every regular file
(symlinks and symlinked-parent destinations are always preserved with a
warning — the safe-destination check refuses to follow them so writes
cannot escape the project root). ``refresh_hint`` is shown after the
customization warning to tell the user which flag would overwrite their
customizations.
"""
from .integrations.manifest import _sha256

manifest = load_speckit_manifest(project_path, version=version, console=console)
prior_hashes = dict(manifest.files)

def _is_managed(rel: str, dst: Path) -> bool:
    """Return True when *rel* is on disk and unmodified since the last install.

    A file counts as "managed" when its current SHA-256 equals the hash
    recorded for *rel* in the prior manifest — i.e. spec-kit installed it
    and the user has not edited it since. Paths that are missing, not
    regular files, or symlinks are never considered managed.
    """
    expected = prior_hashes.get(rel)
    if not expected or not dst.is_file() or dst.is_symlink():
        return False
    try:
        return _sha256(dst) == expected
    except OSError:
        # Unreadable file: err on the side of "not managed" so it is
        # preserved rather than overwritten.
        return False

skipped_files: list[str] = []
preserved_user_files: list[str] = []
symlinked_files: list[str] = []
planned_copies: list[tuple[Path, str, bytes, int]] = []
planned_templates: list[tuple[Path, str, str]] = []

def _decide_overwrite(rel: str, dst: Path) -> tuple[bool, str | None]:
    """Return (write, bucket) where bucket is 'skip', 'preserved', or None.

    Decision order:

    * destination missing — always write (bucket ``None``);
    * ``force`` — overwrite unconditionally;
    * ``refresh_managed`` — overwrite only files whose on-disk hash still
      matches the recorded manifest hash; previously installed files whose
      hash diverged are bucketed ``'preserved'`` (user customization),
      files never recorded in the manifest are bucketed ``'skip'``;
    * default — never overwrite existing files (``'skip'``).
    """
    if not dst.exists():
        return True, None
    if force:
        return True, None
    if refresh_managed:
        if _is_managed(rel, dst):
            return True, None
        if rel in prior_hashes:
            # Installed by spec-kit before, but edited since: preserve.
            return False, "preserved"
        return False, "skip"
    return False, "skip"

def _safe_dest_or_bucket(dst: Path, rel: str, *, parent_must_exist: bool = True) -> bool:
    """Run the safe-destination check and bucket symlinked paths.

    Returns True when the destination is safe to consider (write or skip).
    Returns False (and records *rel* under ``symlinked_files``) when the
    destination or any of its ancestors is a symlink — those paths can't
    be written to safely, but they shouldn't abort the whole switch
    either. They're surfaced as a separate "symlinked" warning bucket.

    Other unsafe-path errors (e.g. path escape, parent-not-a-directory)
    are NOT caught here: they re-raise so the operation aborts, since
    treating them as "symlinked" would mask security-relevant failures.
    """
    try:
        _ensure_safe_shared_destination(project_path, dst, parent_must_exist=parent_must_exist)
    except SymlinkedSharedPathError:
        # Only the narrow symlink case is downgraded to a warning bucket.
        symlinked_files.append(rel)
        return False
    return True

def _ensure_or_bucket_dir(directory: Path) -> bool:
    """Create *directory* unless an ancestor is symlinked.

    Returns True when the directory is safe to use. Returns False (and
    records the path under ``symlinked_files``) when a symlink ancestor
    forces us to skip the whole subtree. Other unsafe-path errors
    (escape, not-a-directory) re-raise so the operation aborts.
    """
    try:
        _ensure_safe_shared_directory(project_path, directory)
    except SymlinkedSharedPathError:
        # Record the project-relative POSIX path so the warning output
        # is stable across platforms.
        symlinked_files.append(directory.relative_to(project_path).as_posix())
        return False
    return True

scripts_src = shared_scripts_source(core_pack=core_pack, repo_root=repo_root)
if scripts_src.is_dir():
dest_scripts = project_path / ".specify" / "scripts"
_ensure_safe_shared_directory(project_path, dest_scripts)
variant_dir = "bash" if script_type == "sh" else "powershell"
variant_src = scripts_src / variant_dir
if variant_src.is_dir():
dest_variant = dest_scripts / variant_dir
_ensure_safe_shared_directory(project_path, dest_variant)
for src_path in variant_src.rglob("*"):
if not src_path.is_file():
continue

rel_path = src_path.relative_to(variant_src)
dst_path = dest_variant / rel_path
_ensure_safe_shared_destination(project_path, dst_path, parent_must_exist=False)
if dst_path.exists() and not force:
skipped_files.append(dst_path.relative_to(project_path).as_posix())
continue

_ensure_safe_shared_directory(project_path, dst_path.parent)
rel = dst_path.relative_to(project_path).as_posix()
planned_copies.append((dst_path, rel, src_path.read_bytes(), src_path.stat().st_mode & 0o777))
if _ensure_or_bucket_dir(dest_scripts):
variant_dir = "bash" if script_type == "sh" else "powershell"
variant_src = scripts_src / variant_dir
if variant_src.is_dir():
dest_variant = dest_scripts / variant_dir
if _ensure_or_bucket_dir(dest_variant):
for src_path in variant_src.rglob("*"):
if not src_path.is_file():
continue

rel_path = src_path.relative_to(variant_src)
dst_path = dest_variant / rel_path
rel = dst_path.relative_to(project_path).as_posix()
if not _safe_dest_or_bucket(dst_path, rel, parent_must_exist=False):
continue
write, bucket = _decide_overwrite(rel, dst_path)
if not write:
if bucket == "preserved":
preserved_user_files.append(rel)
else:
skipped_files.append(rel)
continue

if not _ensure_or_bucket_dir(dst_path.parent):
continue
planned_copies.append((dst_path, rel, src_path.read_bytes(), src_path.stat().st_mode & 0o777))

templates_src = shared_templates_source(core_pack=core_pack, repo_root=repo_root)
if templates_src.is_dir():
dest_templates = project_path / ".specify" / "templates"
_ensure_safe_shared_directory(project_path, dest_templates)
for src in templates_src.iterdir():
if not src.is_file() or src.name == "vscode-settings.json" or src.name.startswith("."):
continue
if _ensure_or_bucket_dir(dest_templates):
for src in templates_src.iterdir():
if not src.is_file() or src.name == "vscode-settings.json" or src.name.startswith("."):
continue

dst = dest_templates / src.name
_ensure_safe_shared_destination(project_path, dst)
if dst.exists() and not force:
skipped_files.append(dst.relative_to(project_path).as_posix())
continue
dst = dest_templates / src.name
rel = dst.relative_to(project_path).as_posix()
if not _safe_dest_or_bucket(dst, rel):
continue
write, bucket = _decide_overwrite(rel, dst)
if not write:
if bucket == "preserved":
preserved_user_files.append(rel)
else:
skipped_files.append(rel)
continue

content = src.read_text(encoding="utf-8")
content = IntegrationBase.resolve_command_refs(content, invoke_separator)
rel = dst.relative_to(project_path).as_posix()
planned_templates.append((dst, rel, content))
content = src.read_text(encoding="utf-8")
content = IntegrationBase.resolve_command_refs(content, invoke_separator)
planned_templates.append((dst, rel, content))

for dst_path, rel, content, mode in planned_copies:
_ensure_safe_shared_directory(project_path, dst_path.parent)
if not _ensure_or_bucket_dir(dst_path.parent):
continue
_write_shared_bytes(project_path, dst_path, content, mode=mode)
manifest.record_existing(rel)

Expand All @@ -313,5 +411,28 @@ def install_shared_infra(
"[cyan]specify integration upgrade --force[/cyan]."
)

if symlinked_files:
console.print(
f"[yellow]⚠[/yellow] Skipped {len(symlinked_files)} symlinked shared "
"infrastructure file(s) — symlinks are never overwritten because they "
"may resolve outside the project root:"
)
for path in symlinked_files:
console.print(f" {path}")
console.print(
"To restore the bundled version, remove or replace the symlink manually, "
"then re-run the command."
)

if preserved_user_files:
console.print(
f"[yellow]⚠[/yellow] Preserved {len(preserved_user_files)} customized shared "
"infrastructure file(s) (hash differs from previous install):"
)
for path in preserved_user_files:
console.print(f" {path}")
if refresh_hint:
console.print(refresh_hint)

manifest.save()
return True
Loading