JAX array
- Arrays are usually created with the API functions that JAX provides rather than by calling the jax.Array constructor directly.
- For example, jax.numpy.zeros() creates a familiar NumPy-style array in which every element is zero.
import jax.numpy as jnp
print(jnp.zeros((2, 3), dtype=bool))
print(jnp.zeros((3, 2), dtype=float))
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/l2.py
[[False False False]
[False False False]]
[[0. 0.]
[0. 0.]
[0. 0.]]
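To illustrate the point above with a few more constructors, the sketch below uses jnp.array, jnp.ones, jnp.full and jnp.arange, which are standard jax.numpy functions. The shapes and values are arbitrary choices for illustration. Each result is a jax.Array even though the jax.Array constructor is never called.

```python
import jax
import jax.numpy as jnp

# Build arrays with jax.numpy constructors instead of calling jax.Array directly.
from_list = jnp.array([[1, 2], [3, 4]])     # copy a nested Python list
ones = jnp.ones((2, 2), dtype=jnp.float32)  # every element is 1.0
sevens = jnp.full((2, 3), 7)                # every element is 7
steps = jnp.arange(0, 10, 2)                # 0, 2, 4, 6, 8

for name, arr in [("from_list", from_list), ("ones", ones),
                  ("sevens", sevens), ("steps", steps)]:
    # Every constructor returns an instance of jax.Array.
    print(name, isinstance(arr, jax.Array), arr.shape, arr.dtype)
```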
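JAX arrays are immutable, so instead of assigning to elements in place you use the array's .at[] property. Its update methods return a new, updated array:
- .set(): replace values
- .add(): add to values
- .multiply(): multiply values
- .min()/.max(): element-wise minimum/maximum with the given values
- .apply(): apply a custom (unary) function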
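Here is a minimal sketch of the .at[] update methods listed above, applied to a small float array with arbitrary values. Every call returns a new array and leaves x untouched. A direct item assignment such as x[0] = 10.0 raises a TypeError because JAX arrays are immutable.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)  # [0. 1. 2. 3. 4.]

# Each .at[...] update returns a NEW array; x itself is never modified.
set_x = x.at[0].set(10.0)            # x[0] replaced by 10
add_x = x.at[1].add(5.0)             # x[1] + 5
mul_x = x.at[2].multiply(3.0)        # x[2] * 3
min_x = x.at[3].min(-1.0)            # minimum(x[3], -1)
max_x = x.at[4].max(100.0)           # maximum(x[4], 100)
app_x = x.at[1:3].apply(jnp.square)  # square x[1] and x[2]

print(set_x, add_x, mul_x, min_x, max_x, app_x, sep="\n")
print("original:", x)

try:
    x[0] = 10.0  # in-place assignment is not supported
except TypeError:
    print("JAX arrays are immutable; use x.at[0].set(10.0) instead")
```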
- jax.numpy.zeros_like(x) creates an array filled with zeros that has the same shape and element type (dtype) as x.
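import jax.numpy as jnp
from jax import jit

@jit
def updateArr(arr):
    # arr.shape[0] is a static Python int, so this loop is unrolled at trace time.
    for i in range(arr.shape[0]):
        # Returns a new array in which element (i, 0) is set to 10.
        arr = arr.at[i, 0].set(10)
    return arr

a = jnp.zeros((2, 3), dtype=bool)
b = jnp.zeros((3, 2), dtype=float)
c = jnp.zeros_like(b)  # zeros with the same shape and dtype as b
print(a)
print(b)
print(c)
c = updateArr(c)
print(c)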
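Because arr.shape[0] is a plain Python int when updateArr is traced, its for loop is unrolled into one update per row. The sketch below is a loop-free alternative, assuming the goal is simply to set the whole first column: a slice inside .at[] does it in a single update. update_first_column is a hypothetical name. The snippet also checks that zeros_like copies shape and dtype, and that its dtype argument overrides the dtype.

```python
import jax
import jax.numpy as jnp

@jax.jit
def update_first_column(arr):
    # One update over the whole first column instead of one per row.
    return arr.at[:, 0].set(10)

b = jnp.zeros((3, 2), dtype=float)         # float32 unless x64 mode is enabled
c = jnp.zeros_like(b)                      # same shape (3, 2) and dtype as b
ints = jnp.zeros_like(b, dtype=jnp.int32)  # same shape, dtype overridden

print(update_first_column(c))
print(c.dtype, ints.dtype)
```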
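- To build an array whose elements are taken from one of two arrays depending on a condition, use jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None): where condition is True the result takes the element of x, otherwise the element of y.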
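Here is a small sketch of the three-argument form, using arbitrary input values. x and y can be arrays of the same shape as the condition, or values that broadcast against it.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 5.0, -7.0, 3.0])
y = jnp.zeros_like(x)

# Take x where the condition is True, otherwise y (negatives become 0).
relu = jnp.where(x > 0, x, y)

# |x| where |x| > 4, otherwise x unchanged.
magnitude = jnp.where(jnp.abs(x) > 4, jnp.abs(x), x)

print(relu)
print(magnitude)
```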
- jax.numpy.nonzero(a, *, size=None, fill_value=None) returns the indices of the nonzero elements of an array, as a tuple with one index array per dimension.
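The size and fill_value parameters matter under jit. The number of nonzero elements depends on the data, so jnp.nonzero can only be compiled when size fixes the output length. Extra matches beyond size are truncated, and missing ones are padded with fill_value. The sketch below uses a hypothetical helper named first_three_nonzero:

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_three_nonzero(a):
    # size=3 fixes the output shape so jit can compile it;
    # unused slots are padded with fill_value=-1.
    (idx,) = jnp.nonzero(a, size=3, fill_value=-1)
    return idx

print(first_three_nonzero(jnp.array([0, 5, 0, 7])))     # only 2 matches: padded
print(first_three_nonzero(jnp.array([4, 0, 6, 8, 9])))  # 4 matches: truncated
```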
If every argument of jax.numpy.where other than condition is left out, jax.numpy.where returns the indices where condition is True, just like jax.numpy.nonzero(condition).
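The sketch below checks that claim on a small matrix with arbitrary values. jnp.where with only a condition and jnp.nonzero on the same boolean array should give identical index tuples. Like nonzero, the one-argument where also needs size= before it can be used inside jit.

```python
import jax.numpy as jnp

m = jnp.array([[1, -2, 3],
               [-4, 5, -6]])
cond = m > 0

# One index array per dimension, from both functions.
rows_w, cols_w = jnp.where(cond)
rows_n, cols_n = jnp.nonzero(cond)

print(rows_w, cols_w)  # row and column indices of the positive entries
print(bool(jnp.array_equal(rows_w, rows_n) and jnp.array_equal(cols_w, cols_n)))
```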
The following example combines jax.numpy.where, jax.numpy.nonzero and jax.numpy.zeros_like:
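import jax.numpy as jnp
from jax import jit

@jit
def updateArr(arr):
    # From row 5 on, column 0 gets (10 + i) with alternating sign: 15, -16, 17, ...
    sg = -1
    for i in range(5, arr.shape[0]):
        sg = sg * (-1)
        arr = arr.at[i, 0].set((10 + i) * sg)
    return arr

a = jnp.zeros((2, 3), dtype=bool)
b = jnp.zeros((10, 2), dtype=float)
c = jnp.zeros_like(b)
print(a)
print(b)
print(c)
c = updateArr(c)
print(c)
# One-argument where: indices of the elements greater than 5.
c1 = jnp.where(c > 5)
# Three-argument where: |c| where |c| > 16, otherwise the original value.
c2 = jnp.where(jnp.abs(c) > 16, jnp.abs(c), c)
print(c1)
print(c2)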
# Create a 2-D array
arr = jnp.array([
    [0, 1, 0],
    [2, 0, 3],
    [0, 0, 4]
])
# Find the positions of the nonzero elements
rows, cols = jnp.nonzero(arr)
print("Row indices:", rows)     # prints: [0 1 1 2]
print("Column indices:", cols)  # prints: [1 0 2 2]
c3 = jnp.nonzero(c)
print(c3)
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/l2.py
[[False False False]
[False False False]]
[[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]]
[[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]
[0. 0.]]
[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[ 15. 0.]
[-16. 0.]
[ 17. 0.]
[-18. 0.]
[ 19. 0.]]
(Array([5, 7, 9], dtype=int32), Array([0, 0, 0], dtype=int32))
[[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[ 0. 0.]
[ 15. 0.]
[-16. 0.]
[ 17. 0.]
[ 18. 0.]
[ 19. 0.]]
Row indices: [0 1 1 2]
Column indices: [1 0 2 2]
(Array([5, 6, 7, 8, 9], dtype=int32), Array([0, 0, 0, 0, 0], dtype=int32))
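The index tuple returned by jnp.nonzero can be passed straight back into indexing to gather the matching values. It can also be stacked into one (row, column) pair per element. The sketch below uses the same 3x3 matrix as the example above, redefined so the snippet runs on its own:

```python
import jax.numpy as jnp

arr = jnp.array([
    [0, 1, 0],
    [2, 0, 3],
    [0, 0, 4],
])

idx = jnp.nonzero(arr)          # (row indices, column indices)
values = arr[idx]               # gather the nonzero values themselves
pairs = jnp.stack(idx, axis=1)  # one (row, col) pair per nonzero element

print(values)
print(pairs)
```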