gstreamer_analytics/
tensor.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use crate::ffi;
4use crate::*;
5use glib::translate::*;
6
glib::wrapper! {
    /// Holds tensor data.
    ///
    /// A tensor carries an identifying [`glib::Quark`], a
    /// [`TensorDataType`][crate::TensorDataType], a
    /// [`TensorDimOrder`][crate::TensorDimOrder], its dimensions and a
    /// [`gst::Buffer`][crate::gst::Buffer] with the raw element data.
    // NOTE(review): for a `Boxed` wrapper these derives presumably compare/hash the
    // underlying `GstTensor` pointer rather than the tensor contents — confirm against glib-rs.
    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    #[doc(alias = "GstTensor")]
    pub struct Tensor(Boxed<ffi::GstTensor>);

    // Boxed memory management: copies go through `gst_tensor_copy`, releases through
    // `gst_tensor_free`, and the registered GType comes from `gst_tensor_get_type`.
    match fn {
        copy => |ptr| ffi::gst_tensor_copy(ptr),
        free => |ptr| ffi::gst_tensor_free(ptr),
        type_ => || ffi::gst_tensor_get_type(),
    }
}
19
20unsafe impl Send for Tensor {}
21unsafe impl Sync for Tensor {}
22
23impl Tensor {
24    /// Allocates a new [`Tensor`][crate::Tensor] of `dims_order` ROW_MAJOR or COLUMN_MAJOR and
25    /// with an interleaved layout.
26    ///
27    /// For example, a two-dimensional tensor with 32 rows and 4 columns, `dims` would
28    /// be the two element array `[32, 4]`.
29    /// ## `id`
30    /// semantically identify the contents of the tensor
31    /// ## `data_type`
32    /// [`TensorDataType`][crate::TensorDataType] of tensor data
33    /// ## `data`
34    /// [`gst::Buffer`][crate::gst::Buffer] holding tensor data
35    /// ## `dims_order`
36    /// Indicate tensor dimension indexing order
37    /// ## `dims`
38    /// size of tensor in each dimension.
39    ///  A value of 0 means the dimension is dynamic.
40    ///
41    /// # Returns
42    ///
43    /// A newly allocated [`Tensor`][crate::Tensor]
44    #[doc(alias = "gst_tensor_new_simple")]
45    pub fn new_simple(
46        id: glib::Quark,
47        data_type: TensorDataType,
48        data: gst::Buffer,
49        dims_order: TensorDimOrder,
50        dims: &[usize],
51    ) -> Tensor {
52        skip_assert_initialized!();
53        unsafe {
54            from_glib_full(ffi::gst_tensor_new_simple(
55                id.into_glib(),
56                data_type.into_glib(),
57                data.into_glib_ptr(),
58                dims_order.into_glib(),
59                dims.len(),
60                dims.as_ptr() as *mut _,
61            ))
62        }
63    }
64
65    /// Gets the dimensions of the tensor.
66    ///
67    /// # Returns
68    ///
69    /// The dims array form the tensor
70    #[doc(alias = "gst_tensor_get_dims")]
71    #[doc(alias = "get_dims")]
72    pub fn dims(&self) -> &[usize] {
73        let mut num_dims: usize = 0;
74        unsafe {
75            let dims = ffi::gst_tensor_get_dims(self.as_ptr(), &mut num_dims);
76            std::slice::from_raw_parts(dims as *const _, num_dims)
77        }
78    }
79
80    #[inline]
81    pub fn id(&self) -> glib::Quark {
82        unsafe { from_glib(self.inner.id) }
83    }
84
85    #[inline]
86    pub fn data_type(&self) -> TensorDataType {
87        unsafe { from_glib(self.inner.data_type) }
88    }
89
90    #[inline]
91    pub fn data(&self) -> &gst::BufferRef {
92        unsafe { gst::BufferRef::from_ptr(self.inner.data) }
93    }
94
95    #[inline]
96    pub fn data_mut(&mut self) -> &mut gst::BufferRef {
97        unsafe {
98            self.inner.data = gst::ffi::gst_mini_object_make_writable(self.inner.data as _) as _;
99            gst::BufferRef::from_mut_ptr(self.inner.data)
100        }
101    }
102
103    #[inline]
104    pub fn dims_order(&self) -> TensorDimOrder {
105        unsafe { from_glib(self.inner.dims_order) }
106    }
107
108    /// Validate the tensor whether it mathces the reading order, dimensions and the data type.
109    /// Validate whether the [`gst::Buffer`][crate::gst::Buffer] has enough size to hold the tensor data.
110    /// ## `data_type`
111    /// The data type of the tensor
112    /// ## `order`
113    /// The order of the tensor to read from the memory
114    /// ## `dims`
115    /// An optional array of dimensions, where G_MAXSIZE means ANY.
116    ///
117    /// # Returns
118    ///
119    /// TRUE if the [`Tensor`][crate::Tensor] has the reading order from the memory matching `order`,
120    /// dimensions matching `num_dims`, data type matching `data_type`
121    /// Otherwise FALSE will be returned.
122    #[cfg(feature = "v1_28")]
123    #[cfg_attr(docsrs, doc(cfg(feature = "v1_28")))]
124    #[doc(alias = "gst_tensor_check_type")]
125    pub fn check_type(
126        &self,
127        data_type: crate::TensorDataType,
128        order: crate::TensorDimOrder,
129        dims: &[usize],
130    ) -> bool {
131        unsafe {
132            from_glib(ffi::gst_tensor_check_type(
133                self.to_glib_none().0,
134                data_type.into_glib(),
135                order.into_glib(),
136                dims.len(),
137                dims.as_ptr(),
138            ))
139        }
140    }
141}
142
#[cfg(test)]
mod tests {
    use crate::*;

    /// Builds the 3x4x5 `Int16` row-major tensor used by the tests below; the
    /// buffer is exactly large enough (60 elements of 2 bytes each).
    fn make_tensor() -> Tensor {
        gst::init().unwrap();

        let buf = gst::Buffer::with_size(2 * 3 * 4 * 5).unwrap();
        assert_eq!(buf.size(), 2 * 3 * 4 * 5);

        Tensor::new_simple(
            glib::Quark::from_str("me"),
            TensorDataType::Int16,
            buf,
            TensorDimOrder::RowMajor,
            &[3, 4, 5],
        )
    }

    #[test]
    fn create_tensor() {
        let tensor = make_tensor();

        assert_eq!(tensor.id(), glib::Quark::from_str("me"));
        assert_eq!(tensor.data_type(), TensorDataType::Int16);
        assert_eq!(tensor.dims_order(), TensorDimOrder::RowMajor);
        // Comparing the whole slice also pins the number of dimensions.
        assert_eq!(tensor.dims(), &[3, 4, 5]);
        assert_eq!(tensor.data().size(), 2 * 3 * 4 * 5);
    }

    #[test]
    fn data_mut_is_copy_on_write() {
        let mut tensor = make_tensor();

        // `Buffer::with_size` does not initialise the memory; give it known contents.
        tensor.data_mut().map_writable().unwrap().fill(0);

        let snapshot = tensor.clone();
        tensor.data_mut().map_writable().unwrap()[0] = 0xab;

        assert_eq!(tensor.data().map_readable().unwrap()[0], 0xab);
        // Writes through the original must not be visible through the clone.
        assert_eq!(snapshot.data().map_readable().unwrap()[0], 0);
        assert_eq!(snapshot.data().size(), tensor.data().size());
    }

    #[cfg(feature = "v1_28")]
    #[test]
    fn check_type() {
        let tensor = make_tensor();

        assert!(tensor.check_type(TensorDataType::Int16, TensorDimOrder::RowMajor, &[3, 4, 5]));
        // `usize::MAX` acts as a wildcard for a single dimension.
        assert!(tensor.check_type(
            TensorDataType::Int16,
            TensorDimOrder::RowMajor,
            &[usize::MAX, 4, 5],
        ));
        assert!(!tensor.check_type(TensorDataType::Int16, TensorDimOrder::RowMajor, &[5, 4, 3]));
    }
}