Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adds Borrrow trait bound wherever necessary.
  • Loading branch information
nikhilr612 committed Jan 23, 2025
commit 0752688b4adfc4f392d644347774d93a142d1016
22 changes: 15 additions & 7 deletions src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use super::core::{
HANDLE_ERROR,
};

use core::borrow::Borrow;

use libc::{c_double, c_int, c_uint};

extern "C" {
Expand Down Expand Up @@ -145,14 +147,20 @@ extern "C" {
macro_rules! dim_reduce_func_def {
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
#[doc=$doc_str]
pub fn $fn_name<T>(input: &Array<T>, dim: i32) -> Array<$out_type>
pub fn $fn_name<T, A>(input: A, dim: i32) -> Array<$out_type>
where
T: HasAfEnum,
$out_type: HasAfEnum,
A: Borrow<Array<T>>,
{
let mut temp: af_array = std::ptr::null_mut();
let err_val =
unsafe { $ffi_name(std::ptr::from_mut::<af_array>(&mut temp), input.get(), dim) };
let err_val = unsafe {
$ffi_name(
std::ptr::from_mut::<af_array>(&mut temp),
input.borrow().get(),
dim,
)
};
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
Expand Down Expand Up @@ -1239,7 +1247,7 @@ where
/// - `key` is the key Array
/// - `input` is the data on which scan is to be performed
/// - `dim` is the dimension along which scan operation is to be performed
/// - `op` takes value of [BinaryOp](./enum.BinaryOp.html) enum indicating
/// - `op` takes value of [`BinaryOp`](./enum.BinaryOp.html) enum indicating
/// the type of scan operation
/// - `inclusive` says if inclusive/exclusive scan is to be performed
///
Expand All @@ -1261,7 +1269,7 @@ where
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe {
af_scan_by_key(
&mut temp as *mut af_array,
std::ptr::from_mut::<af_array>(&mut temp),
key.get(),
input.get(),
dim,
Expand Down Expand Up @@ -1301,8 +1309,8 @@ macro_rules! dim_reduce_by_key_func_def {
let mut out_vals: af_array = std::ptr::null_mut();
let err_val = unsafe {
$ffi_name(
&mut out_keys as *mut af_array,
&mut out_vals as *mut af_array,
std::ptr::from_mut::<af_array>(&mut out_keys),
std::ptr::from_mut::<af_array>(&mut out_vals),
keys.get(),
vals.get(),
dim,
Expand Down
4 changes: 2 additions & 2 deletions src/core/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ binary_func!(
af_hypot
);

/// Type Trait to convert to an [Array](./struct.Array.html)
/// Type Trait to convert to an [`Array`](./struct.Array.html)
///
/// Generic functions that overload the binary operations such as add, div, mul, rem, ge etc. are
/// bound by this trait to allow combinations of scalar values and Array objects as parameters
Expand Down Expand Up @@ -991,7 +991,7 @@ where
T: HasAfEnum + IntegralType,
{
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe { af_bitnot(&mut temp as *mut af_array, input.get()) };
let err_val = unsafe { af_bitnot(std::ptr::from_mut::<af_array>(&mut temp), input.get()) };
HANDLE_ERROR(AfError::from(err_val));
temp.into()
}
35 changes: 19 additions & 16 deletions src/core/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ extern "C" {
///
/// ### NOTE
///
/// All operators(traits) from std::ops module implemented for Array object
/// All operators(traits) from [`std::ops`] module implemented for Array object
/// carry out element wise operations. For example, `*` does multiplication of
/// elements at corresponding locations in two different Arrays.
pub struct Array<T> {
Expand All @@ -190,7 +190,7 @@ macro_rules! is_func {
pub fn $fn_name(&self) -> bool {
unsafe {
let mut ret_val: bool = false;
let err_val = $ffi_fn(&mut ret_val as *mut bool, self.handle);
let err_val = $ffi_fn(std::ptr::from_mut::<bool>(&mut ret_val), self.handle);
HANDLE_ERROR(AfError::from(err_val));
ret_val
}
Expand All @@ -214,7 +214,7 @@ where
/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
/// print(&indices);
/// ```
/// An example of creating an Array from half::f16 array
/// An example of creating an Array from [`half::f16`] array
///
/// ```rust
/// use arrayfire::{Array, Dim4, is_half_available, print};
Expand All @@ -239,10 +239,10 @@ where
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe {
af_create_array(
&mut temp as *mut af_array,
slice.as_ptr() as *const c_void,
std::ptr::from_mut::<af_array>(&mut temp),
slice.as_ptr().cast::<c_void>(),
dims.ndims() as c_uint,
dims.get().as_ptr() as *const c_longlong,
dims.get().as_ptr().cast::<c_longlong>(),
aftype as c_uint,
)
};
Expand All @@ -259,12 +259,12 @@ where
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe {
af_create_strided_array(
&mut temp as *mut af_array,
slice.as_ptr() as *const c_void,
std::ptr::from_mut::<af_array>(&mut temp),
slice.as_ptr().cast::<c_void>(),
offset as dim_t,
dims.ndims() as c_uint,
dims.get().as_ptr() as *const c_longlong,
strides.get().as_ptr() as *const c_longlong,
dims.get().as_ptr().cast::<c_longlong>(),
strides.get().as_ptr().cast::<c_longlong>(),
aftype as c_uint,
1_u32,
)
Expand All @@ -287,9 +287,9 @@ where
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe {
af_create_handle(
&mut temp as *mut af_array,
std::ptr::from_mut::<af_array>(&mut temp),
dims.ndims() as c_uint,
dims.get().as_ptr() as *const c_longlong,
dims.get().as_ptr().cast::<c_longlong>(),
aftype as c_uint,
)
};
Expand Down Expand Up @@ -374,7 +374,7 @@ where
let mut temp: af_array = std::ptr::null_mut();
let err_val = unsafe {
af_device_array(
&mut temp as *mut af_array,
std::ptr::from_mut::<af_array>(&mut temp),
dev_ptr as *mut c_void,
dims.ndims() as c_uint,
dims.get().as_ptr() as *const dim_t,
Expand All @@ -393,7 +393,8 @@ where
/// was active when Array was created.
pub fn get_backend(&self) -> Backend {
let mut ret_val: u32 = 0;
let err_val = unsafe { af_get_backend_id(&mut ret_val as *mut c_uint, self.handle) };
let err_val =
unsafe { af_get_backend_id(std::ptr::from_mut::<c_uint>(&mut ret_val), self.handle) };
HANDLE_ERROR(AfError::from(err_val));
match (err_val, ret_val) {
(0, 1) => Backend::CPU,
Expand All @@ -410,15 +411,17 @@ where
/// Return the device id on which Array was created.
pub fn get_device_id(&self) -> i32 {
let mut ret_val: i32 = 0;
let err_val = unsafe { af_get_device_id(&mut ret_val as *mut c_int, self.handle) };
let err_val =
unsafe { af_get_device_id(std::ptr::from_mut::<c_int>(&mut ret_val), self.handle) };
HANDLE_ERROR(AfError::from(err_val));
ret_val
}

/// Returns the number of elements in the Array
pub fn elements(&self) -> usize {
let mut ret_val: dim_t = 0;
let err_val = unsafe { af_get_elements(&mut ret_val as *mut dim_t, self.handle) };
let err_val =
unsafe { af_get_elements(std::ptr::from_mut::<dim_t>(&mut ret_val), self.handle) };
HANDLE_ERROR(AfError::from(err_val));
ret_val as usize
}
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)]
#![warn(missing_docs)]
#![allow(non_camel_case_types)]

#[macro_use]
extern crate lazy_static;

Expand Down