Skip to content
Open
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
75c1d43
Refactoring of basic functionality to create an empty Array
Jan 24, 2023
b14aa91
Replace dim4 with CShape
roaffix Jan 24, 2023
eadbe9b
Add tests. Minor fixes. Update CI
roaffix Jan 24, 2023
c13a59f
Fix CI
roaffix Jan 24, 2023
f0f57e8
Add arithmetic operators w/o tests
roaffix Jan 26, 2023
8cef774
Fix array init bug. Add __getitem__. Change pytest for active debug mode
roaffix Jan 27, 2023
a4c7ac9
Add reflected arithmetic and array operators
roaffix Jan 27, 2023
4140527
Place TODO for repr
roaffix Jan 28, 2023
4374d93
Add bitwise operators. Add in-place operators. Add missing reflected …
roaffix Jan 28, 2023
5a29ffa
Fix tests
roaffix Jan 28, 2023
4187b27
Add tests for arithmetic operators
roaffix Jan 28, 2023
cdb7a92
Added to_list and to_ctypes_array
roaffix Jan 28, 2023
9c0435a
Fix bug when scalar is empty returns None
roaffix Jan 28, 2023
769c16c
Fix typing in array object. Add tests
roaffix Jan 29, 2023
fb27e46
Change tests and found bug with reflected operators
roaffix Jan 29, 2023
0afb92e
Fix reflected operators bug. Add test coverage for the rest of the ar…
roaffix Jan 29, 2023
1d071be
Add required by specification methods
roaffix Jan 30, 2023
04fbb1b
Change file names
roaffix Jan 30, 2023
2d91b04
Change utils. Add docstrings
roaffix Jan 30, 2023
5939388
Add docstrings for operators
roaffix Jan 30, 2023
0231e27
Change TODOs
roaffix Jan 30, 2023
07c4206
Add docstrings for other operators. Remove docstrings from mocks
roaffix Jan 30, 2023
908447b
Change tags and typings
roaffix Feb 4, 2023
fa3ad06
Change typings from python 3.10 to python 3.8
roaffix Feb 4, 2023
0de9955
Add readme with reference to run tests
roaffix Feb 4, 2023
ae6be05
Revert changes accidentally made in original array
roaffix Feb 5, 2023
cfa9114
Add initial refactoring with backend mock
roaffix Feb 8, 2023
5de8694
Add c library methods for operators
roaffix Feb 9, 2023
b9ac1c5
Remove dependency on default backend
roaffix Feb 10, 2023
171ec88
Refactor backend and project structure
roaffix Feb 10, 2023
e984caa
Refactor backend library operators
roaffix Feb 11, 2023
0b164d4
Refactor used in array_object backend methods
roaffix Feb 11, 2023
282f860
Minor test fix
roaffix Feb 11, 2023
54f7ada
Refactor tests
roaffix Feb 13, 2023
51f6efd
Add comparison operators tests
roaffix Feb 13, 2023
23e2635
Minor fixes for tests
roaffix Feb 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add comparison operators tests
  • Loading branch information
roaffix committed Feb 13, 2023
commit 51f6efde2d4bf3c8b9aaa8472c4da72f7555181f
61 changes: 51 additions & 10 deletions arrayfire/array_api/tests/array_object/test_operators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import operator
from typing import Any, List, Union
from typing import Any, Callable, List, Union

import pytest

from arrayfire.array_api.array_object import Array
from arrayfire.array_api.dtypes import bool as af_bool

Operator = Callable[[Union[int, float, Array], Union[int, float, Array]], Array]


# HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/
def _round(list_: List[Union[int, float]], symbols: int = 4) -> List[Union[int, float]]:
# HACK replace for e.g. abs(x1-x2) < 1e-6 ~ https://davidamos.dev/the-right-way-to-compare-floats-in-python/
return [round(x, symbols) for x in list_]


Expand All @@ -17,15 +20,29 @@ def pytest_generate_tests(metafunc: Any) -> None:
[1, 2, 3],
# [4.2, 7.5, 5.41] # FIXME too big difference between python pow and af backend
])
if "op_origin" in metafunc.fixturenames:
metafunc.parametrize("op_origin", [
if "arithmetic_operator" in metafunc.fixturenames:
metafunc.parametrize("arithmetic_operator", [
"add", # __add__, __iadd__, __radd__
"sub", # __sub__, __isub__, __rsub__
"mul", # __mul__, __imul__, __rmul__
"truediv", # __truediv__, __itruediv__, __rtruediv__
# "floordiv", # __floordiv__, __ifloordiv__, __rfloordiv__ # TODO
"mod", # __mod__, __imod__, __rmod__
"pow" # __pow__, __ipow__, __rpow__,
"pow", # __pow__, __ipow__, __rpow__,
])
if "array_operator" in metafunc.fixturenames:
metafunc.parametrize("array_operator", [
operator.matmul,
operator.imatmul
])
if "comparison_operator" in metafunc.fixturenames:
metafunc.parametrize("comparison_operator", [
operator.lt,
operator.le,
operator.gt,
operator.ge,
operator.eq,
operator.ne
])
if "operand" in metafunc.fixturenames:
metafunc.parametrize("operand", [
Expand All @@ -43,10 +60,10 @@ def pytest_generate_tests(metafunc: Any) -> None:


def test_arithmetic_operators(
array_origin: List[Union[int, float]], op_origin: str,
array_origin: List[Union[int, float]], arithmetic_operator: str,
operand: Union[int, float, List[Union[int, float]]]) -> None:
op = getattr(operator, op_origin)
iop = getattr(operator, "i" + op_origin)
op = getattr(operator, arithmetic_operator)
iop = getattr(operator, "i" + arithmetic_operator)

if isinstance(operand, list):
ref = [op(x, y) for x, y in zip(array_origin, operand)]
Expand All @@ -73,8 +90,32 @@ def test_arithmetic_operators(


def test_arithmetic_operators_expected_to_raise_error(
array_origin: List[Union[int, float]], op_origin: str, false_operand: Any) -> None:
array_origin: List[Union[int, float]], arithmetic_operator: str, false_operand: Any) -> None:
array = Array(array_origin)
op = getattr(operator, op_origin)
op = getattr(operator, arithmetic_operator)
with pytest.raises(TypeError):
op(array, false_operand)


def test_comparison_operators(
array_origin: List[Union[int, float]], comparison_operator: Operator,
operand: Union[int, float, List[Union[int, float]]]) -> None:
if isinstance(operand, list):
ref = [comparison_operator(x, y) for x, y in zip(array_origin, operand)]
operand = Array(operand) # type: ignore[assignment]
else:
ref = [comparison_operator(x, operand) for x in array_origin]

array = Array(array_origin)
res = comparison_operator(array, operand) # type: ignore[arg-type]

assert res.to_list() == ref
assert res.dtype == af_bool


def test_comparison_operators_expected_to_raise_error(
array_origin: List[Union[int, float]], comparison_operator: Operator, false_operand: Any) -> None:
array = Array(array_origin)

with pytest.raises(TypeError):
comparison_operator(array, false_operand)