From d5e7f67eb30b8d5753d00a0895af77d13f3ee2d9 Mon Sep 17 00:00:00 2001 From: Craig Watson Date: Fri, 1 May 2026 11:34:23 -0700 Subject: [PATCH 1/2] Take generic ArrayBase in concatenate and stack --- src/stacking.rs | 10 ++++++---- tests/stacking.rs | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/stacking.rs b/src/stacking.rs index 8737d6f60..b2a23f85c 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -33,8 +33,9 @@ use crate::imp_prelude::*; /// [3., 3.]])) /// ); /// ``` -pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> +pub fn concatenate(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> where + S: Data, A: Clone, D: RemoveAxis, { @@ -66,7 +67,7 @@ where }; for array in arrays { - res.append(axis, array.clone())?; + res.append(axis, array.view())?; } debug_assert_eq!(res.len_of(axis), stacked_dim); Ok(res) @@ -96,8 +97,9 @@ where /// ); /// # } /// ``` -pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> +pub fn stack(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> where + S: Data, A: Clone, D: Dimension, D::Larger: RemoveAxis, @@ -129,7 +131,7 @@ where }; for array in arrays { - res.append(axis, array.clone().insert_axis(axis))?; + res.append(axis, array.view().insert_axis(axis))?; } debug_assert_eq!(res.len_of(axis), arrays.len()); diff --git a/tests/stacking.rs b/tests/stacking.rs index bdfe478b4..be98d1697 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,4 +1,4 @@ -use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1}; +use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1, ViewRepr}; #[test] fn concatenating() @@ -29,7 +29,7 @@ fn concatenating() let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + let res: Result, _> = ndarray::concatenate::, _, _>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } @@ -50,6 +50,6 @@ fn stacking() let res = ndarray::stack(Axis(3), &[a.view(), a.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]); + let res: Result, _> = ndarray::stack::, Ix1, _>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } From 2816113dac3c21a1dbd02330622e969d23c03cde Mon Sep 17 00:00:00 2001 From: Craig Watson Date: Fri, 1 May 2026 13:32:00 -0700 Subject: [PATCH 2/2] Simplify trait bounds --- src/stacking.rs | 13 ++++++------- tests/stacking.rs | 18 ++++++------------ 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/src/stacking.rs b/src/stacking.rs index b2a23f85c..eb90d8524 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -33,10 +33,10 @@ use crate::imp_prelude::*; /// [3., 3.]])) /// ); /// ``` -pub fn concatenate(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> +pub fn concatenate(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> where - S: Data, - A: Clone, + S: Data, + S::Elem: Clone, D: RemoveAxis, { if arrays.is_empty() { @@ -97,12 +97,11 @@ where /// ); /// # } /// ``` -pub fn stack(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> +pub fn stack(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> where - S: Data, - A: Clone, + S: Data, + S::Elem: Clone, D: Dimension, - D::Larger: RemoveAxis, { if arrays.is_empty() { return Err(from_kind(ErrorKind::Unsupported)); diff --git a/tests/stacking.rs b/tests/stacking.rs index be98d1697..7dfceacfc 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,24 +1,19 @@ use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1, ViewRepr}; #[test] -fn concatenating() -{ +fn concatenating() { let a = arr2(&[[2., 2.], [3., 3.]]); let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]])); let c = concatenate![Axis(0), a, b]; - assert_eq!( - c, - arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]]) - ); + assert_eq!(c, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])); let d = concatenate![Axis(0), a.row(0), &[9., 9.]]; assert_eq!(d, aview1(&[2., 2., 9., 9.])); let d = concatenate![Axis(1), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; - assert_eq!(d, aview2(&[[2., 9.], - [2., 9.]])); + assert_eq!(d, aview2(&[[2., 9.], [2., 9.]])); let d = concatenate![Axis(0), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; assert_eq!(d, aview2(&[[2.], [2.], [9.], [9.]])); @@ -29,13 +24,12 @@ fn concatenating() let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::concatenate::, _, _>(Axis(0), &[]); + let res: Result, _> = ndarray::concatenate::, _>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } #[test] -fn stacking() -{ +fn stacking() { let a = arr2(&[[2., 2.], [3., 3.]]); let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); @@ -50,6 +44,6 @@ fn stacking() let res = ndarray::stack(Axis(3), &[a.view(), a.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::stack::, Ix1, _>(Axis(0), &[]); + let res: Result, _> = ndarray::stack::, Ix1>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); }