JAX study notes[8]

jax.typing

  1. Function annotations for static type checking have become an integral part of modern Python coding practice.
  2. jax.Array is the base class that represents JAX arrays.
  3. Type annotations can serve a Python project at three levels:
  • Level 1: Annotations as documentation
import jax

def f(x: jax.Array) -> jax.Array:  # the annotation is valid for both traced and non-traced arrays
  return x
  • Level 2: Annotations for intelligent autocomplete
    Many modern IDEs, such as VS Code, use type annotations to drive their intelligent code completion.
  • Level 3: Annotations for static type-checking
  1. Package development with JAX has to satisfy two Python type checkers: pytype, developed by Google, and mypy, known as the most popular static type checking tool. Beyond that, typing JAX code faces several challenges: array duck-typing, transformations and decorators, the lack of granularity in array annotations, and imprecise APIs inherited from NumPy.
  2. JAX supports both static type annotations and runtime instance checks for duck-typed objects.
  • Static type annotations
from typing import Union
from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnp

ArrayAnnotation = Union[Array, Tracer]

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
    assert isinstance(x, (Array, Tracer))  # Explicit check
    return x * 2

x = jnp.array([1.0, 2.0, 3.0])
result = f(x)
print(result)  # [2. 4. 6.] (jax.Array)

@jit
def g(x):
    return f(x)  # `x` is a Tracer here!

print(g(x))      # Same output, but internally traced


from jax import grad

df_dx = grad(lambda x: f(x).sum())  # Works with tracers
print(df_dx(x))  # [2. 2. 2.] (gradient of x*2)

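f("invalid_input")  # raises TypeError: a str cannot be interpreted as an abstract array
f(234)  # never reached, because the call above raises; run on its own, jit accepts a Python int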
[2. 4. 6.]
[2. 4. 6.]
[2. 2. 2.]
Traceback (most recent call last):
  File "e:\learn\learnpy\l2.py", line 29, in <module>
    f("invalid_input")
TypeError: Error interpreting argument to <function f at 0x0000018EEFD999E0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
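The traceback above shows jit rejecting a string at call time. As a minimal sketch of how annotations can state up front which inputs are acceptable, the example below annotates the input with `jax.typing.ArrayLike` and the output with `jax.Array`. The function name `double` and the variable names are illustrative and not part of the original notes.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


@jax.jit
def double(x: ArrayLike) -> jax.Array:
    # ArrayLike admits JAX arrays, NumPy arrays, and Python scalars.
    # jnp.asarray normalizes the input (under jit, x is already a tracer).
    return jnp.asarray(x) * 2


from_jax = double(jnp.array([1.0, 2.0, 3.0]))
from_numpy = double(np.array([1.0, 2.0, 3.0]))
from_scalar = double(234)

# A string is not array-like, so the jit-wrapped function raises TypeError.
try:
    double("invalid_input")
    rejected = False
except TypeError:
    rejected = True
```

Annotating inputs as `ArrayLike` and outputs as `jax.Array` lets a static checker flag the string call before the program runs, while the runtime behaviour stays the same as in the traceback above.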
  • Runtime instance checks
from typing import Union
from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnp

ArrayInstance = Union[Array, Tracer]

def f(x):  # left undecorated so that the two calls below see an array and a tracer respectively
  return isinstance(x, ArrayInstance)


x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer
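The check above unions `Array` with `Tracer` explicitly. As a hedged sketch, `jax.Array` on its own is designed to pass `isinstance` for both concrete arrays and tracers, so a single type may be enough for a runtime check. The helper name `is_jax_array` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(x) -> bool:
    # jax.Array is intended to match concrete arrays and tracers alike.
    return isinstance(x, jax.Array)


x = jnp.array([1, 2, 3])

eager_result = is_jax_array(x)                    # x is a concrete jax.Array
traced_result = bool(jax.jit(is_jax_array)(x))    # x is a tracer while jit traces
numpy_result = is_jax_array(np.array([1, 2, 3]))  # a NumPy array is not a jax.Array
```

The NumPy case illustrates why inputs are often annotated with the broader `jax.typing.ArrayLike`, while `jax.Array` is kept for values that JAX itself produced.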

references

https://docs.jax.dev/
