gstreamer_analytics/
tensor.rs

// Take a look at the license at the top of the repository in the LICENSE file.

3use crate::ffi;
4use crate::*;
5use glib::translate::*;
6
7glib::wrapper! {
8    /// Hold tensor data
9    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
10    #[doc(alias = "GstTensor")]
11    pub struct Tensor(Boxed<ffi::GstTensor>);
12
13    match fn {
14        copy => |ptr| ffi::gst_tensor_copy(ptr),
15        free => |ptr| ffi::gst_tensor_free(ptr),
16        type_ => || ffi::gst_tensor_get_type(),
17    }
18}
19
20unsafe impl Send for Tensor {}
21unsafe impl Sync for Tensor {}
22
23impl Tensor {
24    /// Allocates a new [`Tensor`][crate::Tensor] of `dims_order` ROW_MAJOR or COLUMN_MAJOR and
25    /// with an interleaved layout
26    /// ## `id`
27    /// semantically identify the contents of the tensor
28    /// ## `data_type`
29    /// [`TensorDataType`][crate::TensorDataType] of tensor data
30    /// ## `data`
31    /// [`gst::Buffer`][crate::gst::Buffer] holding tensor data
32    /// ## `dims_order`
33    /// Indicate tensor dimension indexing order
34    /// ## `dims`
35    /// tensor dimensions. Value of 0 mean the
36    /// dimension is dynamic.
37    ///
38    /// # Returns
39    ///
40    /// A newly allocated [`Tensor`][crate::Tensor]
41    #[doc(alias = "gst_tensor_new_simple")]
42    pub fn new_simple(
43        id: glib::Quark,
44        data_type: TensorDataType,
45        data: gst::Buffer,
46        dims_order: TensorDimOrder,
47        dims: &[usize],
48    ) -> Tensor {
49        skip_assert_initialized!();
50        unsafe {
51            from_glib_full(ffi::gst_tensor_new_simple(
52                id.into_glib(),
53                data_type.into_glib(),
54                data.into_glib_ptr(),
55                dims_order.into_glib(),
56                dims.len(),
57                dims.as_ptr() as *mut _,
58            ))
59        }
60    }
61
62    /// Gets the dimensions of the tensor.
63    ///
64    /// # Returns
65    ///
66    /// The dims array form the tensor
67    #[doc(alias = "gst_tensor_get_dims")]
68    #[doc(alias = "get_dims")]
69    pub fn dims(&self) -> &[usize] {
70        let mut num_dims: usize = 0;
71        unsafe {
72            let dims = ffi::gst_tensor_get_dims(self.as_ptr(), &mut num_dims);
73            std::slice::from_raw_parts(dims as *const _, num_dims)
74        }
75    }
76
77    #[inline]
78    pub fn id(&self) -> glib::Quark {
79        unsafe { from_glib(self.inner.id) }
80    }
81
82    #[inline]
83    pub fn data_type(&self) -> TensorDataType {
84        unsafe { from_glib(self.inner.data_type) }
85    }
86
87    #[inline]
88    pub fn data(&self) -> &gst::BufferRef {
89        unsafe { gst::BufferRef::from_ptr(self.inner.data) }
90    }
91
92    #[inline]
93    pub fn data_mut(&mut self) -> &mut gst::BufferRef {
94        unsafe {
95            self.inner.data = gst::ffi::gst_mini_object_make_writable(self.inner.data as _) as _;
96            gst::BufferRef::from_mut_ptr(self.inner.data)
97        }
98    }
99
100    #[inline]
101    pub fn dims_order(&self) -> TensorDimOrder {
102        unsafe { from_glib(self.inner.dims_order) }
103    }
104
105    /// Validate the tensor whether it mathces the reading order, dimensions and the data type.
106    /// Validate whether the [`gst::Buffer`][crate::gst::Buffer] has enough size to hold the tensor data.
107    /// ## `order`
108    /// The order of the tensor to read from the memory
109    /// ## `num_dims`
110    /// The number of dimensions that the tensor can have
111    /// ## `data_type`
112    /// The data type of the tensor
113    /// ## `data`
114    /// [`gst::Buffer`][crate::gst::Buffer] holding tensor data
115    ///
116    /// # Returns
117    ///
118    /// TRUE if the [`Tensor`][crate::Tensor] has the reading order from the memory matching `order`,
119    /// dimensions matching `num_dims`, data type matching `data_type` and the [`gst::Buffer`][crate::gst::Buffer] mathcing `data`
120    /// has enough size to hold the tensor data.
121    /// Otherwise FALSE will be returned.
122    #[cfg(feature = "v1_28")]
123    #[cfg_attr(docsrs, doc(cfg(feature = "v1_28")))]
124    #[doc(alias = "gst_tensor_check_type")]
125    pub fn check_type(
126        &self,
127        order: crate::TensorDimOrder,
128        num_dims: usize,
129        data_type: crate::TensorDataType,
130        data: &gst::BufferRef,
131    ) -> bool {
132        unsafe {
133            from_glib(ffi::gst_tensor_check_type(
134                self.to_glib_none().0,
135                order.into_glib(),
136                num_dims,
137                data_type.into_glib(),
138                mut_override(data.as_ptr()),
139            ))
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn create_tensor() {
        gst::init().unwrap();

        let buf = gst::Buffer::with_size(2 * 3 * 4 * 5).unwrap();
        assert_eq!(buf.size(), 2 * 3 * 4 * 5);

        let mut tensor = Tensor::new_simple(
            glib::Quark::from_str("me"),
            TensorDataType::Int16,
            buf,
            TensorDimOrder::RowMajor,
            &[3, 4, 5],
        );

        assert_eq!(tensor.id(), glib::Quark::from_str("me"));
        assert_eq!(tensor.data_type(), TensorDataType::Int16);
        assert_eq!(tensor.dims_order(), TensorDimOrder::RowMajor);
        assert_eq!(tensor.dims(), &[3, 4, 5]);
        assert_eq!(tensor.data().size(), 2 * 3 * 4 * 5);

        // A write through `data_mut()` must be visible through `data()`.
        {
            let mut map = tensor.data_mut().map_writable().unwrap();
            map.as_mut_slice()[0] = 0x42;
        }
        assert_eq!(tensor.data().map_readable().unwrap().as_slice()[0], 0x42);

        // Mutating the original after a clone must not leak into the clone:
        // `data_mut()` has to make the (possibly shared) buffer writable first.
        let copy = tensor.clone();
        {
            let mut map = tensor.data_mut().map_writable().unwrap();
            map.as_mut_slice()[0] = 0x24;
        }
        assert_eq!(tensor.data().map_readable().unwrap().as_slice()[0], 0x24);
        assert_eq!(copy.data().map_readable().unwrap().as_slice()[0], 0x42);
        assert_eq!(copy.dims(), &[3, 4, 5]);
    }
}