Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,13 @@ nash_mtl = [
cagrad = [
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
]
qpth = [
"qpth>=0.0.15",
]
full = [
"cvxpy>=1.3.0", # No Clarabel solver before 1.3.0
"ecos>=2.0.14", # Does not work before 2.0.14
"qpth>=0.0.15",
]

[tool.pytest.ini_options]
Expand Down
8 changes: 6 additions & 2 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class DualProjWeighting(GramianWeighting):
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
:param solver: The solver used to optimize the underlying optimization problem. Use
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
device of the input tensors (requires the optional ``qpth`` package).
"""

def __init__(
Expand Down Expand Up @@ -90,7 +92,9 @@ class DualProj(GramianWeightedAggregator):
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
:param solver: The solver used to optimize the underlying optimization problem. Use
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
device of the input tensors (requires the optional ``qpth`` package).
"""

gramian_weighting: DualProjWeighting
Expand Down
8 changes: 6 additions & 2 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ class UPGradWeighting(GramianWeighting):
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
:param solver: The solver used to optimize the underlying optimization problem. Use
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
device of the input tensors (requires the optional ``qpth`` package).
"""

def __init__(
Expand Down Expand Up @@ -93,7 +95,9 @@ class UPGrad(GramianWeightedAggregator):
numerical errors when computing the gramian, it might not exactly be positive definite.
This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian
ensures that it is positive definite.
:param solver: The solver used to optimize the underlying optimization problem.
:param solver: The solver used to optimize the underlying optimization problem. Use
``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the
device of the input tensors (requires the optional ``qpth`` package).
"""

gramian_weighting: UPGradWeighting
Expand Down
53 changes: 51 additions & 2 deletions src/torchjd/aggregation/_utils/dual_cone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from qpsolvers import solve_qp
from torch import Tensor

SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"]
SUPPORTED_SOLVER: TypeAlias = Literal["quadprog", "qpth"]


def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor:
Expand All @@ -15,10 +15,15 @@

:param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
:param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
:param solver: The quadratic programming solver to use.
:param solver: The quadratic programming solver to use. ``"quadprog"`` converts tensors to
CPU numpy arrays and uses qpsolvers. ``"qpth"`` solves natively on the same device as
the input tensors (e.g. CUDA) using the ``qpth`` package (optional dependency).
:return: A tensor of projection weights with the same shape as `U`.
"""

if solver == "qpth":
return _project_weights_qpth(U, G)

Check warning on line 25 in src/torchjd/aggregation/_utils/dual_cone.py

View check run for this annotation

Codecov / codecov/patch

src/torchjd/aggregation/_utils/dual_cone.py#L25

Added line #L25 was not covered by tests

G_ = _to_array(G)
U_ = _to_array(U)

Expand All @@ -27,6 +32,50 @@
return torch.as_tensor(W, device=G.device, dtype=G.dtype)


def _project_weights_qpth(U: Tensor, G: Tensor) -> Tensor:
    r"""
    Computes the tensor of projection weights using qpth, keeping the computation on the
    device of the input tensors and running without gradient tracking.

    :param U: The tensor of weights to project, of shape `[..., m]`.
    :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite.
    :return: A tensor of projection weights with the same shape and dtype as `U`'s Gramian.
    :raises ValueError: If the solver fails to produce a finite solution (NaNs in the output).
    """
    from qpth.qp import QPFunction  # Lazy import: qpth is an optional dependency.

    shape = U.shape
    m = shape[-1]
    # Flatten all leading dimensions into one batch dimension for the batched QP solver.
    batch_size = U.numel() // m
    device = G.device
    original_dtype = G.dtype

    # Use float64 for numerical precision, matching the quadprog solver's behavior.
    U_flat = U.reshape(batch_size, m).double()
    G_double = G.double()

    # QP formulation: minimize (1/2) v^T (2G) v + 0^T v subject to -I v <= -u (i.e., u <= v).
    Q = (2.0 * G_double).unsqueeze(0).expand(batch_size, m, m).contiguous()
    p = torch.zeros(batch_size, m, device=device, dtype=torch.float64)
    G_ineq = (
        (-torch.eye(m, device=device, dtype=torch.float64))
        .unsqueeze(0)
        .expand(batch_size, m, m)
        .contiguous()
    )
    h_ineq = -U_flat
    # Empty equality constraints (zero rows): qpth still expects A and b tensors.
    A = torch.zeros(batch_size, 0, m, device=device, dtype=torch.float64)
    b = torch.zeros(batch_size, 0, device=device, dtype=torch.float64)

    # The projection weights are used as constants by callers, so gradient tracking is disabled.
    with torch.no_grad():
        W_flat = QPFunction(verbose=False, maxIter=10, check_Q_spd=False, notImprovedLim=1)(
            Q, p, G_ineq, h_ineq, A, b
        )

    if torch.any(torch.isnan(W_flat)):
        raise ValueError("Failed to solve the quadratic programming problem.")

    return W_flat.to(original_dtype).reshape(shape)

Check warning on line 76 in src/torchjd/aggregation/_utils/dual_cone.py

View check run for this annotation

Codecov / codecov/patch

src/torchjd/aggregation/_utils/dual_cone.py#L76

Added line #L76 was not covered by tests


def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray:
r"""
Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
Expand Down
35 changes: 30 additions & 5 deletions tests/unit/aggregation/test_dualproj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib.util

import torch
from pytest import mark, raises
from pytest import mark, param, raises
from torch import Tensor
from utils.tensors import ones_

Expand All @@ -15,28 +17,44 @@
)
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices

_has_qpth = importlib.util.find_spec("qpth") is not None
_skip_no_qpth = mark.skipif(not _has_qpth, reason="qpth not installed")

scaled_pairs = [(DualProj(), matrix) for matrix in scaled_matrices]
typical_pairs = [(DualProj(), matrix) for matrix in typical_matrices]
non_strong_pairs = [(DualProj(), matrix) for matrix in non_strong_matrices]
requires_grad_pairs = [(DualProj(), ones_(3, 5, requires_grad=True))]

_qpth_typical_pairs = [
param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in typical_matrices
]
_qpth_non_strong_pairs = [
param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in non_strong_matrices
]
_qpth_scaled_pairs = [
param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in scaled_matrices
]

@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)

@mark.parametrize(
["aggregator", "matrix"],
scaled_pairs + typical_pairs + _qpth_scaled_pairs + _qpth_typical_pairs,
)
def test_expected_structure(aggregator: DualProj, matrix: Tensor) -> None:
assert_expected_structure(aggregator, matrix)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
def test_non_conflicting(aggregator: DualProj, matrix: Tensor) -> None:
assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
def test_permutation_invariant(aggregator: DualProj, matrix: Tensor) -> None:
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07)


@mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
@mark.parametrize(["aggregator", "matrix"], non_strong_pairs + _qpth_non_strong_pairs)
def test_strongly_stationary(aggregator: DualProj, matrix: Tensor) -> None:
assert_strongly_stationary(aggregator, matrix, threshold=3e-03)

Expand Down Expand Up @@ -66,6 +84,13 @@ def test_representations() -> None:
assert str(A) == "DualProj([1., 2., 3.])"


@_skip_no_qpth
def test_representations_qpth() -> None:
    """Check the repr and str of a DualProj configured with the qpth solver."""
    aggregator = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="qpth")
    expected_repr = "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='qpth')"
    assert repr(aggregator) == expected_repr
    assert str(aggregator) == "DualProj"


def test_pref_vector_setter_updates_value() -> None:
A = DualProj()
new_pref = torch.tensor([1.0, 2.0, 3.0])
Expand Down
37 changes: 31 additions & 6 deletions tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib.util

import torch
from pytest import mark, raises
from pytest import mark, param, raises
from torch import Tensor
from utils.tensors import ones_

Expand All @@ -16,33 +18,49 @@
)
from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices

_has_qpth = importlib.util.find_spec("qpth") is not None
_skip_no_qpth = mark.skipif(not _has_qpth, reason="qpth not installed")

scaled_pairs = [(UPGrad(), matrix) for matrix in scaled_matrices]
typical_pairs = [(UPGrad(), matrix) for matrix in typical_matrices]
non_strong_pairs = [(UPGrad(), matrix) for matrix in non_strong_matrices]
requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))]

_qpth_typical_pairs = [
param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in typical_matrices
]
_qpth_non_strong_pairs = [
param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in non_strong_matrices
]
_qpth_scaled_pairs = [
param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in scaled_matrices
]

@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs)

@mark.parametrize(
["aggregator", "matrix"],
scaled_pairs + typical_pairs + _qpth_scaled_pairs + _qpth_typical_pairs,
)
def test_expected_structure(aggregator: UPGrad, matrix: Tensor) -> None:
assert_expected_structure(aggregator, matrix)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None:
assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None:
assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07)


@mark.parametrize(["aggregator", "matrix"], typical_pairs)
@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs)
def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor) -> None:
assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=6e-02, rtol=6e-02)


@mark.parametrize(["aggregator", "matrix"], non_strong_pairs)
@mark.parametrize(["aggregator", "matrix"], non_strong_pairs + _qpth_non_strong_pairs)
def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor) -> None:
assert_strongly_stationary(aggregator, matrix, threshold=5e-03)

Expand Down Expand Up @@ -70,6 +88,13 @@ def test_representations() -> None:
assert str(A) == "UPGrad([1., 2., 3.])"


@_skip_no_qpth
def test_representations_qpth() -> None:
    """Check the repr and str of a UPGrad configured with the qpth solver."""
    aggregator = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="qpth")
    expected_repr = "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='qpth')"
    assert repr(aggregator) == expected_repr
    assert str(aggregator) == "UPGrad"


def test_pref_vector_setter_updates_value() -> None:
A = UPGrad()
new_pref = torch.tensor([1.0, 2.0, 3.0])
Expand Down
Loading