This rule raises an issue when np.where is used with only the condition parameter set.
Why is this an issue?
The NumPy function np.where provides a way to execute operations on an array under a certain condition:
import numpy as np
arr = np.array([1,2,3,4])
result = np.where(arr > 3, arr * 2, arr)
In the example above, the np.where function multiplies by 2 every element of the array that satisfies the condition element > 3. The elements that do not satisfy the condition are left untouched. The result array now holds the values 1, 2, 3 and 8.
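The second and third arguments do not have to be arrays. As a minimal sketch (the variable name capped is only illustrative), scalars are broadcast against the condition, and the input array itself is not modified:

```python
import numpy as np

arr = np.array([1, 2, 3, 4])

# Scalars are broadcast: elements greater than 3 become 0, others are kept.
capped = np.where(arr > 3, 0, arr)

print(capped)  # [1 2 3 0]
print(arr)     # [1 2 3 4], the input array is not modified
```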
It is also possible to call np.where with only the condition parameter set:
import numpy as np
arr = np.array([1,2,3,4])
result = np.where(arr > 2)
Even though this is perfectly valid NumPy code, it may not yield the expected results. When only the condition parameter is provided, np.where behaves like np.asarray(condition).nonzero() or np.nonzero(condition). Both return the indices of the elements that satisfy the condition, not the values themselves. This means the result variable now holds a tuple whose first element is an array of all the indices where the condition arr > 2 was satisfied: (array([2, 3]),).
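The equivalence described above can be checked directly. The following is a small sketch, with illustrative variable names, that also shows one way to recover the values if they were what was actually wanted:

```python
import numpy as np

arr = np.array([1, 2, 3, 4])
condition = arr > 2

indices_where = np.where(condition)      # single-argument form
indices_nonzero = np.nonzero(condition)  # explicit equivalent

# Both calls return a tuple holding one index array per dimension.
assert isinstance(indices_where, tuple)
assert np.array_equal(indices_where[0], indices_nonzero[0])

# Indexing the original array with that tuple recovers the matching values.
values = arr[indices_where]
print(values)  # [3 4]
```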
If the intention is to find the indices of the elements that satisfy a certain condition, it is preferable to use np.asarray(condition).nonzero() or np.nonzero(condition) instead.
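For arrays with more than one dimension, np.nonzero returns one index array per dimension. A brief sketch, assuming a hypothetical 2x2 array named grid:

```python
import numpy as np

grid = np.array([[1, 5],
                 [7, 2]])

# One index array per dimension: row indices and column indices.
rows, cols = np.nonzero(grid > 3)

# Pair them up to get the (row, column) position of each match.
pairs = list(zip(rows.tolist(), cols.tolist()))
print(pairs)  # [(0, 1), (1, 0)]
```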
How to fix it
To fix this issue, either:
- provide all three parameters to the np.where function (condition, value if the condition is satisfied, value if the condition is not satisfied), or
- use the np.nonzero function.
Code examples
Noncompliant code example
import numpy as np
def bigger_than_two():
    arr = np.array([1,2,3,4])
    result = np.where(arr > 2) # Noncompliant: only the condition parameter is provided to the np.where function.
Compliant solution
import numpy as np
def bigger_than_two():
    arr = np.array([1,2,3,4])
    result = np.where(arr > 2, arr + 1, arr) # Compliant
    indices = np.nonzero(arr > 2) # Compliant
Resources
Documentation
- NumPy Documentation - numpy.where
- NumPy Documentation - numpy.nonzero