JAX study notes[2]

文章目录

JAX array

  1. to usually construct an array with API function provided by JAX instead of using jax.Array constructor.
  2. for example,the jax.numpy.zeros() create a familiar NumPy-style array which all element is zero .
import jax.numpy as jnp
print(jnp.zeros((2, 3), dtype=bool))
print(jnp.zeros((3, 2), dtype=float))
PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/l2.py
[[False False False] 
 [False False False]]
[[0. 0.] 
 [0. 0.] 
 [0. 0.]]

you can use following update methods to change the jax array:

.set() - Replace values

.add() - Add to values

.multiply() - Multiply values

.min()/.max() - Element-wise min/max

.apply() - Custom function
  1. it sets up an array which fill with zeros having the same shape and element type to use jax.numpy.zeros_like .
import jax.numpy as jnp
from jax import jit

@jit
def updateArr(arr):
    for i in range(arr.shape[0]):
        arr=arr.at[i,0].set(10)
    return arr

a=jnp.zeros((2, 3), dtype=bool)
b=jnp.zeros((3, 2), dtype=float)
c=jnp.zeros_like(b)
print(a)
print(b)
print(c)
c=updateArr(c)
print(c)
  1. forming the elements ,which get out of two arrays , used jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None) .

jax.numpy.nonzero(a, *, size=None, fill_value=None)[source] returns the indices of nonzero elements of an array.

int the case of to be ruling out all the parameters of the jax.numpy.where other than condition, jax.numpy.where return the indices.

an example that jax.numpy.where ,jax.numpy.nonzero and jax.zeros_like are integrate in it is as shown below:

import jax.numpy as jnp
from jax import jit

@jit
def updateArr(arr):
    sg=-1
    for i in range(5,arr.shape[0]):
        sg=sg*(-1)
        arr=arr.at[i,0].set((10+i)*sg)
    return arr

a=jnp.zeros((2, 3), dtype=bool)
b=jnp.zeros((10, 2), dtype=float)
c=jnp.zeros_like(b)
print(a)
print(b)
print(c)
c=updateArr(c)
print(c)
c1=jnp.where(c > 5)
c2=jnp.where(abs(c) > 16,abs(c),c)
print(c1)
print(c2)

# 创建一个二维数组
arr = jnp.array([
    [0, 1, 0],
    [2, 0, 3],
    [0, 0, 4]
])

# 查找非零元素的位置
rows, cols = jnp.nonzero(arr)

print("行索引:", rows)  # 输出: [0 1 1 2]
print("列索引:", cols)  # 输出: [1 0 2 2]

c3=jnp.nonzero(c)

print(c3)

PS E:\learn\learnpy> & D:/Python312/python.exe e:/learn/learnpy/l2.py
[[False False False]
 [False False False]]
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
[[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
[[  0.   0.]
 [  0.   0.]
 [  0.   0.]
 [  0.   0.]
 [  0.   0.]
 [ 15.   0.]
 [-16.   0.]
 [ 17.   0.]
 [-18.   0.]
 [ 19.   0.]]
(Array([5, 7, 9], dtype=int32), Array([0, 0, 0], dtype=int32))
[[  0.   0.]
 [  0.   0.]
 [  0.   0.]
 [  0.   0.]
 [  0.   0.]
 [ 15.   0.]
 [-16.   0.]
 [ 17.   0.]
 [ 18.   0.]
 [ 19.   0.]]
行索引: [0 1 1 2]
列索引: [1 0 2 2]
(Array([5, 6, 7, 8, 9], dtype=int32), Array([0, 0, 0, 0, 0], dtype=int32))

references

  1. https://docs.jax.dev
  2. deepseek
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

身在此心在彼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值