Function annotations used for static type checking are becoming an integral part of Python coding standards.
jax.Array is the base class that represents JAX arrays, and it is the type to use when annotating arrays in a Python project.
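An isinstance check is a quick way to see what jax.Array covers. The sketch below assumes a recent JAX release with NumPy installed. It shows that arrays created by jax.numpy are jax.Array instances. A plain NumPy array is not, even though JAX functions accept both as input.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax_x = jnp.arange(3.0)  # created by jax.numpy
np_x = np.arange(3.0)    # a plain NumPy ndarray

# Every array produced by JAX is an instance of the jax.Array base class.
print(isinstance(jax_x, jax.Array))          # True
# A NumPy array is not, even though jnp functions accept it as input.
print(isinstance(np_x, jax.Array))           # False
# Passing it through a jax.numpy function yields a jax.Array.
print(isinstance(jnp.sin(np_x), jax.Array))  # True
```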
Level 1: Annotations as documentation
import jax

def f(x: jax.Array) -> jax.Array:  # type annotations are valid for traced and non-traced types.
    return x
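This level is documentation only because Python does not enforce annotations at call time. The minimal sketch below uses a hypothetical function named `identity`. The annotation is stored on the function and can be read back. A non-array argument still passes straight through without any error.

```python
import typing

import jax


def identity(x: jax.Array) -> jax.Array:
    # The annotation documents intent; Python does not check it at runtime.
    return x


# The annotations are stored on the function object for readers and tools.
hints = typing.get_type_hints(identity)
print(hints)

# A Python float is not a jax.Array, yet the call succeeds unchanged.
print(identity(1.5))  # 1.5
```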
Level 2: Annotations for intelligent autocomplete
Many modern IDEs, such as VS Code, use type annotations to power their intelligent code completion systems.
Level 3: Annotations for static type-checking
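Package development with JAX must account for two Python static type checkers: pytype, developed by Google, and mypy, the most popular static type checking tool. Beyond that, JAX faces typing challenges of its own: array duck-typing (tracers stand in for arrays inside transformations), transformations and decorators (such as jit and vmap, which wrap and change functions), the lack of granularity in array annotations (jax.Array says nothing about shape or dtype), and imprecise APIs inherited from NumPy (functions whose return types depend on argument values).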
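For function signatures, the `jax.typing` module provides `ArrayLike` for inputs, which covers JAX arrays, NumPy arrays, and Python scalars. `jax.Array` is used for outputs. The sketch below follows that convention, and the function name `scale` is hypothetical. A static checker such as mypy or pytype reads these annotations and should flag a call like `scale("abc")`. At runtime the annotations alone do nothing.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


def scale(x: ArrayLike, factor: float = 2.0) -> jax.Array:
    # Accept anything array-like on input; always return a jax.Array.
    return jnp.asarray(x) * factor


print(scale(3))                        # Python scalar input
print(scale(np.array([1.0, 2.0])))     # NumPy array input
print(scale(jnp.ones(2), factor=0.5))  # jax.Array input
```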
JAX supports both static type annotations and runtime instance checks for duck-typed objects such as tracers.
Static type annotations
from typing import Union
from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnp
ArrayAnnotation = Union[Array, Tracer]

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
    assert isinstance(x, (Array, Tracer))  # Explicit check
    return x * 2
x = jnp.array([1.0,2.0,3.0])
result = f(x)
print(result)  # [2. 4. 6.] (jax.Array)

@jit
def g(x):
    return f(x)  # `x` is a Tracer here!

print(g(x))  # Same output, but internally traced

from jax import grad
df_dx = grad(lambda x: f(x).sum())  # Works with tracers
print(df_dx(x))  # [2. 2. 2.] (gradient of x*2)
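f("invalid_input")  # raises TypeError: jit cannot treat a str as an array
f(234)  # never reached; on its own it would return 468, since jit accepts Python scalars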
Output:
[2. 4. 6.]
[2. 4. 6.]
[2. 2. 2.]
Traceback (most recent call last):
  File "e:\learn\learnpy\l2.py", line 29, in <module>
    f("invalid_input")
TypeError: Error interpreting argument to <function f at 0x0000018EEFD999E0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The TypeError comes from jax.jit itself while it converts the arguments, before the explicit assert inside f ever runs.
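The error message above points to the usual remedy: a non-array argument to a jit-compiled function must be marked static. In this minimal sketch, the function `apply` and its `mode` parameter are hypothetical. `static_argnames` tells `jax.jit` to treat `mode` as a compile-time Python value rather than an abstract array, so a string is accepted. Each distinct static value triggers a separate compilation, and static values must be hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="mode")
def apply(x: jax.Array, mode: str) -> jax.Array:
    # `mode` is an ordinary Python string at trace time, so a plain `if` works.
    if mode == "double":
        return x * 2
    if mode == "negate":
        return -x
    raise ValueError(f"unknown mode: {mode!r}")


x = jnp.array([1.0, 2.0, 3.0])
print(apply(x, mode="double"))  # [2. 4. 6.]
print(apply(x, mode="negate"))  # [-1. -2. -3.]
```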
Runtime instance checks
from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnp

ArrayInstance = (Array, Tracer)  # a tuple of classes works with isinstance

def f(x):  # not decorated with @jit, so a direct call sees a concrete array
    return isinstance(x, ArrayInstance)

x = jnp.array([1, 2, 3])
assert f(x)  # x will be an array
assert jit(f)(x)  # x will be a tracer
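In recent JAX releases, the `Tracer` member of the check is not strictly needed. Tracers are treated as instances of `jax.Array`, so a single `isinstance(x, jax.Array)` check succeeds both outside and inside transformations. The sketch below assumes such a release and records the result of the check during tracing under `jit`, `grad`, and `vmap`. The helper `is_array` and the `results` list are hypothetical names, and every recorded check should be True.

```python
import jax
import jax.numpy as jnp

results = []


def is_array(x):
    # Runs as ordinary Python, so inside a transformation x is a tracer.
    results.append(isinstance(x, jax.Array))
    return jnp.sum(x * 2.0)


x = jnp.array([1.0, 2.0, 3.0])
is_array(x)            # concrete jax.Array
jax.jit(is_array)(x)   # traced by jit
jax.grad(is_array)(x)  # traced by grad
jax.vmap(is_array)(x)  # traced by vmap; x is a per-example tracer
print(results)
```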