diff --git a/src/stacking.rs b/src/stacking.rs index 8737d6f60..eb90d8524 100644 --- a/src/stacking.rs +++ b/src/stacking.rs @@ -33,9 +33,10 @@ use crate::imp_prelude::*; /// [3., 3.]])) /// ); /// ``` -pub fn concatenate(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> +pub fn concatenate(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> where - A: Clone, + S: Data, + S::Elem: Clone, D: RemoveAxis, { if arrays.is_empty() { @@ -66,7 +67,7 @@ where }; for array in arrays { - res.append(axis, array.clone())?; + res.append(axis, array.view())?; } debug_assert_eq!(res.len_of(axis), stacked_dim); Ok(res) @@ -96,11 +97,11 @@ where /// ); /// # } /// ``` -pub fn stack(axis: Axis, arrays: &[ArrayView]) -> Result, ShapeError> +pub fn stack(axis: Axis, arrays: &[ArrayBase]) -> Result, ShapeError> where - A: Clone, + S: Data, + S::Elem: Clone, D: Dimension, - D::Larger: RemoveAxis, { if arrays.is_empty() { return Err(from_kind(ErrorKind::Unsupported)); @@ -129,7 +130,7 @@ where }; for array in arrays { - res.append(axis, array.clone().insert_axis(axis))?; + res.append(axis, array.view().insert_axis(axis))?; } debug_assert_eq!(res.len_of(axis), arrays.len()); diff --git a/tests/stacking.rs b/tests/stacking.rs index bdfe478b4..7dfceacfc 100644 --- a/tests/stacking.rs +++ b/tests/stacking.rs @@ -1,24 +1,19 @@ -use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1}; +use ndarray::{arr2, arr3, aview1, aview2, concatenate, stack, Array2, Axis, ErrorKind, Ix1, ViewRepr}; #[test] -fn concatenating() -{ +fn concatenating() { let a = arr2(&[[2., 2.], [3., 3.]]); let b = ndarray::concatenate(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.]])); let c = concatenate![Axis(0), a, b]; - assert_eq!( - c, - arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]]) - ); + assert_eq!(c, arr2(&[[2., 2.], [3., 3.], [2., 2.], [3., 3.], [2., 2.], [3., 3.]])); let d = concatenate![Axis(0), a.row(0), &[9., 9.]]; assert_eq!(d, aview1(&[2., 2., 9., 9.])); let d = concatenate![Axis(1), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; - assert_eq!(d, aview2(&[[2., 9.], - [2., 9.]])); + assert_eq!(d, aview2(&[[2., 9.], [2., 9.]])); let d = concatenate![Axis(0), a.row(0).insert_axis(Axis(1)), aview1(&[9., 9.]).insert_axis(Axis(1))]; assert_eq!(d, aview2(&[[2.], [2.], [9.], [9.]])); @@ -29,13 +24,12 @@ fn concatenating() let res = ndarray::concatenate(Axis(2), &[a.view(), c.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::concatenate(Axis(0), &[]); + let res: Result, _> = ndarray::concatenate::, _>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); } #[test] -fn stacking() -{ +fn stacking() { let a = arr2(&[[2., 2.], [3., 3.]]); let b = ndarray::stack(Axis(0), &[a.view(), a.view()]).unwrap(); assert_eq!(b, arr3(&[[[2., 2.], [3., 3.]], [[2., 2.], [3., 3.]]])); @@ -50,6 +44,6 @@ fn stacking() let res = ndarray::stack(Axis(3), &[a.view(), a.view()]); assert_eq!(res.unwrap_err().kind(), ErrorKind::OutOfBounds); - let res: Result, _> = ndarray::stack::<_, Ix1>(Axis(0), &[]); + let res: Result, _> = ndarray::stack::, Ix1>(Axis(0), &[]); assert_eq!(res.unwrap_err().kind(), ErrorKind::Unsupported); }