Перейти до вмісту

JAX

Матеріал з K2 ERP Wiki
Задача: пришвидшити числову функцію, яка викликається багато разів. Практична роль: Equinox зручний для користувачів, які хочуть поєднати JAX-підхід із простими Python-класами.

</syntaxhighlight>

</syntaxhighlight>

import jax.numpy as jnp

!</syntaxhighlight> jax.grad — це трансформація, яка створює функцію для обчислення gradient. Задача: застосувати функцію до batch прикладів.== Для чого застосовується для JAX ==

Небажаний підхід:

Приклад:

  • великі array operations;
  • jit-compiled functions;
  • vectorized code;
  • batch computation;
  • accelerator-friendly logic;
  • pure functions;
  • мінімум Python loops у compiled hot path. return x ** 2

Equinox

import jax.numpy as jnp

JAX і Scikit-learn

batched_square = jax.vmap(square)

'''Практична ідея:''' явні random keys роблять випадковість контрольованішою, відтворюванішою і суміснішою з functional programming.<syntaxhighlight lang="python">
state = []
Проблеми можуть виникати, якщо:

</div>

* Офіційна документація JAX. У JAX критично контролювати shape і dtype.== Тематичні мітки ==

Рекомендовано:
key = jax.random.PRNGKey(0)
=== Vectorization ===

<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">

<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">

Результат: compiled version функції для швидшого виконання.== JAX і TensorFlow ==

== Debugging у JAX ==

== JAX для neural networks ==

'''jax.numpy''' або '''jnp''' — це NumPy-подібний API у JAX. Критерій

XLA оптимізує:

Optax здатна використовуватися для:

  • arrays;
  • matrix operations;
  • linear algebra;
  • broadcasting;
  • elementwise functions;
  • reductions;
  • reshaping;
  • indexing;
  • mathematical functions.

Інструменти: JAX + Flax/Haiku/Equinox + Optax. def loss(w): Головне правило: у JAX shapes і dtypes — це частина дизайну програми, а не другорядна деталь. Для neural networks зазвичай використовують:

JAX використовує explicit random keys. JAX

Pytrees

  • спочатку запускати без jit;
  • перевіряти shapes;
  • перевіряти dtypes;
  • використовувати менші приклади;
  • уникати зайвої складності;
  • тестувати функції окремо;
  • додавати asserts там, де доречно;
  • розуміти tracing;
  • обережно працювати з print у compiled code.
Добре працюють: Вона дає можливість автоматизовано обчислювати похідні функцій. a = jax.random.normal(key1, shape=(3,))

Типові помилки в JAX

Типові помилки користувачів

Головна перевага: JAX дає можливість комбінувати математично чистий Python-код із потужними трансформаціями для gradients, compilation і vectorization. Результати JAX-обчислень потрібно тестувати, перевіряти і валідувати на реальних сценаріях.

import jax.numpy as jnp

JAX можна розглядати як систему перетворень для числових Python-функцій. jax.numpy підтримує багато знайомих операцій:

Задача: знайти gradient loss-функції. * JAX automatic differentiation documentation. JAX arrays схожі на NumPy arrays, але мають важливі відмінності: критично: у JAX стан моделі й параметри часто передаються явно, що здатна бути незвично для користувачів PyTorch або Keras.== JAX Array ==

Для research: JAX цінують за те, що transformations можна комбінувати: як приклад, grad + jit + vmap. Результат: векторизована функція без ручного Python loop. df = jax.grad(f)

Примітка: Haiku виступає як одним із варіантів neural network framework поверх JAX, але не виступає як єдиним стандартом. import jax @jax.jit

def impure_function(x):

'''Небезпека:''' JAX-код здатна бути дуже швидким, але неправильна технічна архітектура обчислень здатна зробити його повільним, нестабільним або важким для налагодження. ! JAX дуже популярний у research-середовищах, тому що він дає можливість оперативно експериментувати з математичними ідеями.== Див. так само ==
Приклади:
x = jnp.array([1.0, 2.0, 3.0])

'''Optax''' — це бібліотека optimization algorithms для JAX.<div style="background:#fef2f2; border-left:6px solid #ef4444; padding:12px; margin:12px 0;">
== XLA ==
'''Суть immutable arrays:''' замість зміни масиву на місці JAX створює нове логічне представлення результату, що краще узгоджується з трансформаціями й компіляцією.== Flax ==

Типовий приклад:

</div>

це Python-бібліотека; так само реалізовано автоматичного диференціювання. '''Equinox''' — це бібліотека для JAX, яка дає можливість описувати neural networks і differentiable programs через Python-класи, сумісні з pytrees. |-
| базовий стиль
| Функціональні transformations: grad, jit, vmap
| Повна ML-платформа з Keras, TensorFlow Lite, Serving, TFX
|-
| Рівень
| Нижчий і гнучкіший для research
| Ширша production-екосистема
|-
| Neural networks
| Через Flax, Haiku, Equinox та інші бібліотеки
| Через Keras і TensorFlow API
|-
| Компіляція
| XLA через jit
| TensorFlow graph/XLA у відповідних сценаріях
|-
| Типове використання
| Research, differentiable programming, high-performance numeric code
| Production ML, deep learning, mobile/browser deployment
|}

'''Просте пояснення:''' JAX спочатку “дивиться” на функцію як на обчислення, яке можна трансформувати, а вже потім виконує оптимізований варіант. return jnp.sin(x) * jnp.cos(x) + x ** 2

<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">

<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">

* писати JAX-код як звичайний NumPy без урахування immutability;
* забувати розділяти random keys;
* додавати side effects у jit-функції;
* очікувати, що print працюватиме як у звичайному Python;
* створювати багато recompilations через змінні shapes;
* використовувати Python loops замість vmap або scan;
* переносити інформаційні дані між CPU і GPU занадто часто;
* не тестувати функції до jit;
* не контролювати dtype;
* не зберігати reproducibility.<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">

Equinox здатна бути корисним для:

== jax.numpy ==

JAX-документація зазначає, що autodiff у JAX дає можливість швидко обчислювати похідні вищих порядків, бо функції, які обчислюють derivatives, самі можуть бути диференційованими. '''Tracing''' — це механізм, через який JAX аналізує функцію для трансформацій на кшталт `jit`, `grad` або `vmap`.<div style="background:#e8f8f5; border-left:6px solid #16a085; padding:12px; margin:12px 0;">
Pytrees — це вкладені структури Python, які JAX здатна опрацьовувати як дерева даних.
Практична роль: XLA виступає як однією з причин, чому JAX здатна виконувати числові функції оперативно після компіляції. Для ефективного використання потрібно розуміти devices, sharding, data layout і синхронізацію.

До них належать:

Приклад:

b = jax.random.uniform(key2, shape=(3,))

Optax

базовий фокус Числові обчислення, autodiff, JIT, research ML Класичне машинне навчання
Типові задачі Neural networks, optimization, differentiable programming Classification, regression, clustering, preprocessing
API Функціональні transformations fit/predict/transform
Для табличного ML Можна, але часто потребує більше коду Дуже інтуїтивно
Для gradients Сильна сторона Не базовий фокус
{| class="wikitable"

! ! Якщо задача проста й таблична, Scikit-learn або NumPy можуть бути практичнішими. * Документація Equinox. '''Просте пояснення:''' pytree дає можливість JAX працювати не лише з одним масивом, а з цілою вкладеною структурою масивів. Це низькорівнева й гнучка платформа числових обчислень і трансформацій, поверх якої часто використовують додаткові бібліотеки. Результат: функція, яка повертає похідну або gradients параметрів.</div>
grad_loss = jax.grad(loss)
JAX застосовується для там, де потрібні швидкі числові обчислення і gradients. Основні конкурентні переваги JAX:

</syntaxhighlight>

  • physics simulations;
  • optimization;
  • differential equations;
  • computational biology;
  • probabilistic modeling;
  • numerical methods;
  • inverse problems;
  • differentiable rendering;
  • scientific machine learning. Висновок: Scikit-learn краще підходить для класичного tabular ML, а JAX — для задач, де потрібні gradients, JIT і custom numerical computation. * shape змінюється між викликами jit-функції;
  • dtype не той, який очікувався;
  • інформаційні дані не на тому device;
  • модель очікує batch, а отримує один приклад;
  • vmap застосований по неправильній осі;
  • broadcasting функціонує не так, як очікувалося. Увага: JAX не автоматизовано пришвидшує будь-який Python-код. Pure function — це функція, яка:

jit

Він дає можливість:

Обмеження JAX

  • NumPy-подібний API;
  • automatic differentiation;
  • jit compilation;
  • vmap для vectorization;
  • pmap для parallelism;
  • GPU/TPU support;
  • composable transformations;
  • functional programming style;
  • зручність для research;
  • сильний для optimization;
  • підходить для differentiable programming;
  • ERP-платформа Flax, Optax, Haiku, Equinox.== Haiku ==
  • optimization;
  • training neural networks;
  • loss functions;
  • scientific computing;
  • differentiable simulations;
  • gradient-based methods.

Практична роль: якщо JAX — це обчислювальний фундамент, то Flax часто застосовується для як high-level neural network library поверх JAX. return x * 2

Під час роботи з JAX часто виникають типові помилки. Просте пояснення: JAX Array — це масив для числових обчислень, який здатна працювати в JAX-світі: з gradients, JIT і прискорювачами. return x * 2

* ліцензію JAX;
* ліцензії залежностей;
* ліцензії моделей;
* ліцензії датасетів;
* умови використання accelerator-середовища;
* політики організації;
* вимоги до attribution. У JAX робота з випадковістю відрізняється від NumPy. Вона дає можливість застосувати функцію до batch даних без ручного написання циклу. def f(x):
Приклади:
'''Основна ідея:''' JAX дає можливість писати код у стилі NumPy, але додавати до нього automatic differentiation, JIT-компіляцію, векторизацію і прискорення на GPU/TPU. def compute(x):
<syntaxhighlight lang="python">
== Immutable arrays ==

Типові задачі:

print(batched_square(jnp.array([1, 2, 3, 4])))

* писати pure functions;
* передавати state явно;
* використовувати jax.numpy замість numpy у JAX-функціях;
* спочатку перевіряти код без jit;
* використовувати jit для гарячих обчислень;
* використовувати vmap замість ручних циклів;
* контролювати shapes і dtypes;
* правильно працювати з PRNG keys;
* зберігати прості й тестовані функції;
* вимірювати продуктивність;
* уникати зайвих device-host transfers;
* документувати numerical assumptions;
* тестувати gradients.</div>
=== Automatic differentiation ===
<div style="background:#fff4e5; border-left:6px solid #f39c12; padding:12px; margin:12px 0;">

Перед використанням у продукті потрібно перевіряти:

</div>
== Джерела ==

JAX — це інструмент для обчислень і ML, тому відповідальність за моделі та їхнє використання залишається за розробником. Для кількох випадкових операцій key потрібно розділяти:

  • багато дрібних Python-викликів;
  • часті передачі даних між host і device;
  • side effects;
  • динамічні форми масивів;
  • погано структурований код;
  • надмірна recompilation.

import jax

</syntaxhighlight>

Продуктивність

!
'''Суть jax.numpy:''' розробник пише код у стилі NumPy, але отримує можливість використовувати JAX-трансформації: grad, jit, vmap та інші. '''Суть jit:''' JAX компілює Python-функцію у швидший обчислювальний код, який здатна ефективно для бізнесу виконуватися на accelerator hardware.<div style="background:#e8f8f5; border-left:6px solid #16a085; padding:12px; margin:12px 0;">
|-
| базовий стиль
| Functional programming і transformations
| Imperative/eager style із dynamic computation graph
|-
| Autodiff
| grad як функціональна трансформація
| autograd через tensor operations
|-
| Neural network API
| Зазвичай через Flax, Haiku, Equinox
| torch.nn вбудований у PyTorch
|-
| Research
| Сильний у composable transformations і accelerator-oriented code
| Дуже популярний у deep learning research
|-
| Стан моделі
| Часто передається явно
| Часто зберігається в modules/objects
|}

'''Помилка:''' обирати JAX лише тому, що він швидкий.<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
'''критично:''' JAX  це не повна high-level ML-платформа на кшталт TensorFlow або PyTorch. Поширені помилки:

'''Висновок:''' NumPy  базова бібліотека числових обчислень, а JAX додає до NumPy-подібного стилю autodiff, JIT і accelerator support. ! '''Суть екосистеми:''' JAX дає фундаментальні трансформації й обчислення, а додаткові бібліотеки додають neural networks, optimizers, checkpoints, probabilistic programming та інші інструменти.== Висновок ==
Інструмент: jax.grad. ! Репозиторій JAX поширюється під ліцензією Apache 2.0. !=== JIT-компіляція ===

</div>
'''Flax'''  це бібліотека для neural networks на JAX. JAX виступає як open-source проєктом. * писати NumPy-подібний код;
* автоматизовано обчислювати gradients;
* компілювати функції через jit;
* векторизувати функції через vmap;
* паралелити обчислення через pmap;
* працювати з GPU і TPU;
* будувати neural networks через додаткові бібліотеки;
* створювати differentiable programs;
* оптимізувати числові функції;
* виконувати research-oriented ML-експерименти. '''jax.vmap'''  це трансформація для автоматичної векторизації функцій. '''Практична роль:''' Optax часто застосовується для разом із JAX і Flax для навчання neural networks.{{SEO
|title=JAX  Python-бібліотека для високопродуктивних обчислень, automatic differentiation, NumPy API і машинного навчання
|description=JAX  Wiki-стаття про Python-бібліотеку для високопродуктивних числових обчислень, automatic differentiation, JIT-компіляції, NumPy-подібного API, GPU/TPU-прискорення і machine learning. Розглянуто jax.numpy, grad, jit, vmap, pmap, XLA, pure functions, immutable arrays, PRNG, JAX ecosystem, Flax, Optax, Haiku, Equinox, переваги, обмеження, безпеку і відповідальне використання.
|keywords=JAX, jax.numpy, jnp, Google JAX, Python JAX, automatic differentiation, autograd, jit, vmap, pmap, XLA, GPU, TPU, NumPy API, machine learning, deep learning, high-performance computing, differentiable programming, Flax, Optax, Haiku, Equinox, neural networks, functional programming, JAX arrays
|alternativeTo=ручна реалізація automatic differentiation; повільні NumPy-обчислення без GPU/TPU; самописна JIT-компіляція; складне масштабування числових обчислень; ручне векторизування циклів; окремі інструменти для gradient-based optimization; класичні Python-обчислення без accelerator support
}}

* model parameters;
* forward function;
* loss function;
* grad;
* optimizer update;
* jit;
* batch processing;
* evaluation.== vmap ==

JAX найкраще функціонує з '''pure functions'''.

Підказка: JAX варто вивчати через маленькі функції: спочатку jnp, потім grad, потім jit, потім vmap. Водночас JAX потребує розуміння functional programming, immutable arrays, explicit random keys, tracing, shapes, dtypes і особливостей compiled execution.== ліцензійний пакет == офіційний GitHub-репозиторій JAX описує його як систему для composable transformations of Python+NumPy programs, а серед ключових трансформацій виділяє `grad`, `jit` і `vmap`. result = compute(jnp.ones((1000,))) Потрібно враховувати:

y = x.at [0].set(10)

Критично: швидка модель не означає правильна модель. Висновок: JAX більше схожий на гнучку систему числових трансформацій, а TensorFlow — на ширшу end-to-end ML-платформу.

* функція викликається багато разів;
* обчислення великі;
* застосовується для GPU або TPU;
* виступає як багато array operations;
* код підходить для компіляції. Критерій

!=== Neural network training ===

<div style="background:#fff4e5; border-left:6px solid #f39c12; padding:12px; margin:12px 0;">

JAX можна використовувати в різних сценаріях. ! * Документація Haiku. JAX не намагається бути однією великою бібліотекою для всього. '''Haiku'''  це бібліотека для neural networks на JAX, створювалась як DeepMind.</div>

* якість даних;
* bias;
* correctness of gradients;
* reproducibility;
* numerical stability;
* privacy;
* security of model deployment;
* ліцензії даних;
* вплив ML-рішень на користувачів;
* моніторинг після deployment. '''Практична роль:''' grad дає можливість писати математичну функцію напряму, а похідні для оптимізації отримувати автоматизовано. '''Головне правило:''' JAX найкраще функціонує тоді, коли код написаний функціонально, інформаційні дані мають стабільні shapes, а transformations використовуються усвідомлено. |-
| базовий фокус
| Прискорені числові обчислення, transformations, autodiff
| Загальні числові обчислення в Python
|-
| GPU/TPU
| технічна підтримка accelerator execution
| Зазвичай CPU-орієнтований
|-
| Automatic differentiation
| Вбудовано через grad
| Немає вбудованого autodiff
|-
| JIT
| виступає як через jax.jit
| Немає стандартного JIT у NumPy
|-
| Mutability
| Functional-style updates
| Часто in-place mutation
|}

<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
JAX дуже схожий на NumPy за стилем API, але має важливі відмінності. y = jnp.sin(x) + x ** 2
Pytree здатна містити:

<syntaxhighlight lang="python">

</div>

JAX так само часто порівнюють із PyTorch. JAX

<div style="background:#e8f8f5; border-left:6px solid #16a085; padding:12px; margin:12px 0;">

* навчання neural network;
* custom optimization;
* differentiable physics simulation;
* research prototype;
* reinforcement learning;
* probabilistic modeling;
* scientific computing;
* gradient-based calibration;
* vectorized numerical experiments;
* high-performance array computation;
* TPU-based experiments;
* custom loss functions. Навколо нього існує ERP-платформа бібліотек.== Automatic differentiation ==

'''Суть automatic differentiation:''' JAX здатна сам побудувати функцію, яка обчислює gradient іншої функції. import jax
'''Небезпека:''' код здатна виглядати схожим на NumPy, але поводитися інакше через JAX-трансформації, компіляцію і immutable arrays. * Документація Flax. JIT означає '''Just-In-Time compilation'''. * параметрів моделей;
* gradients;
* optimizer state;
* batch data;
* structured outputs;
* tree transformations. Для налагодження корисно:

Практична цінність: якщо наукова модель диференційована, JAX здатна допомогти оптимізувати її параметри через gradients. Вона поєднує NumPy-подібний API із потужними функціональними трансформаціями: `grad`, `jit`, `vmap`, `pmap`. Flax застосовується для для:

JAX часто порівнюють із TensorFlow. Інструмент: jax.jit. Код потрібно писати з урахуванням JIT, vectorization і device execution.

Приклади:

* очікування NumPy-style mutation;
* використання side effects у jit-функціях;
* неправильна робота з random keys;
* надмірна recompilation;
* Python control flow там, де потрібен JAX control flow;
* змішування NumPy і jax.numpy без розуміння наслідків;
* передача Python objects у jit без static_argnums;
* часті device-host transfers;
* неправильне використання vmap;
* недостатнє розуміння shapes.<syntaxhighlight lang="python">

JAX arrays зазвичай розглядаються як immutable. JAX
</div>
Приклад:

JAX — це Python-бібліотека для високопродуктивних числових обчислень, automatic differentiation, JIT-компіляції, векторизації і роботи з accelerator hardware. Практична порада: якщо задача потребує gradients, accelerator execution і кастомної математики, JAX здатна бути дуже сильним вибором.

Вона оптимізує:

  • neural networks;
  • scientific computing;
  • differentiable programming;
  • structured models;
  • research code;
  • функціонального стилю з класами.

критично: JAX-трансформації краще працюють із функціональним стилем програмування, де стан передається явно, а не змінюється приховано. NumPy

print(grad_loss(2.0))

Pytrees часто використовуються для:

JAX ecosystem

state.append(x)

</syntaxhighlight> JAX часто застосовують, коли потрібно в машинному навчанні, deep learning, наукових обчисленнях, optimization, differentiable programming, research-проєктах і задачах, де потрібне поєднання гнучкого Python-коду з високою продуктивністю. * defining neural networks;

  • training models;
  • research experiments;
  • transformer models;
  • model state;
  • neural network modules;
  • integration with Optax;
  • large-scale ML research. Просте пояснення: vmap бере функцію для одного прикладу і автоматизовано робить її функцією для batch. * компілювати array operations;
  • оптимізувати граф обчислень;
  • виконувати код на CPU, GPU або TPU;
  • об’єднувати операції;
  • зменшувати overhead;
  • пришвидшувати великі обчислення. Приклад:

`grad` часто застосовується для для:

JAX для research

Tracing

Shape і dtype

Замість in-place mutation застосовується для функціональний стиль оновлення версій. Гірше працюють:

Загальний SEO-опис

  • вищий поріг входу;
  • незвичний functional style;
  • immutable arrays;
  • explicit PRNG keys;
  • складніші помилки при jit;
  • потрібно розуміти tracing;
  • не всі NumPy-патерни переносяться напряму;
  • neural network API винесений в окремі бібліотеки;
  • production deployment здатна потребувати додаткової роботи;
  • складніше debugging у compiled code;
  • можливі проблеми сумісності з версіями CUDA/TPU stack.
  • SGD;
  • Adam;
  • AdamW;
  • learning rate schedules;
  • gradient transformations;
  • gradient clipping;
  • optimizer state;
  • training loops. Інструмент: jax.vmap. Практична порада: перед оптимізацією через jit спочатку варто переконатися, що функція правильно функціонує у звичайному режимі.

def pure_function(x):

pmap

критично: pmap складніший за grad, jit і vmap. JAX сам по собі не має такого центрального high-level neural network API, як `torch.nn` у PyTorch або Keras у TensorFlow. * Flax;

  • Haiku;
  • Equinox;
  • custom JAX code;
  • Optax для optimizers.

def square(x):

return x ** 2 + 3 * x + 1

JAX особливо корисний для research, differentiable programming, optimization, neural networks, scientific computing і задач, де потрібно поєднати математичну гнучкість із продуктивністю.== конкурентні переваги JAX ==

print(df(2.0)) Тут `y` — новий масив із оновленим значенням. * control flow;

  • shapes;
  • static arguments;
  • error messages;
  • recompilation;
  • debug behavior. Задача: навчити neural network. XLA або Accelerated Linear Algebra — це компілятор, який застосовується для JAX для оптимізації числових обчислень. * Документація Optax.
    == PRNG у JAX ==
    Результат: training loop із gradients, optimizer update і evaluation. * створювати modules;
    * керувати parameters;
    * будувати neural networks;
    * працювати з JAX transformations;
    * організовувати model code. Scikit-learn
    
    import jax.numpy as jnp
    
    * machine learning research;
    * deep learning;
    * neural networks;
    * optimization;
    * automatic differentiation;
    * scientific computing;
    * simulation;
    * probabilistic modeling;
    * differentiable programming;
    * reinforcement learning;
    * large-scale numerical computing;
    * GPU/TPU acceleration. * list;
    * tuple;
    * dict;
    * dataclass;
    * nested structures;
    * arrays;
    * parameters of neural networks. Окремо варто відзначити JIT-компіляції, векторизації, роботи з NumPy-подібним API і запуску обчислень на CPU, GPU і TPU виступає ключовою рисою високопродуктивних числових обчислень забезпечується через '''JAX'''. JAX здатна бути дуже швидким, але продуктивність залежить від стилю коду. JAX і Scikit-learn мають різні ролі. TensorFlow
    
    <syntaxhighlight lang="text">
    
    * Flax;
    * Optax;
    * Haiku;
    * Equinox;
    * Orbax;
    * Chex;
    * JAXopt;
    * NumPyro;
    * Distrax;
    * TFP on JAX.== grad ==
    
    == Pure functions ==
    
    * можуть виконуватися на accelerator hardware;
    * підтримують JAX-трансформації;
    * зазвичай виступає як immutable;
    * можуть бути частиною compiled computation;
    * можуть брати участь в automatic differentiation;
    * можуть переноситися між devices. * JAX Quickstart.</div>
    JAX має обмеження. JAX
    Debugging у JAX здатна бути складнішим, ніж у звичайному Python, особливо всередині `jit`. Критерій
    
    `jit` здатна пришвидшити обчислення, особливо якщо:
    
    '''Головна думка:''' JAX  це не без зусиль швидкий NumPy, а платформа composable transformations для Python-функцій, яка відкриває потужні функції ERP для gradients, JIT, vectorization і accelerator-based computing. import jax.numpy as jnp
    
    </div>
    
    * batch processing;
    * per-example gradients;
    * vectorized evaluation;
    * заміни Python loops;
    * прискорення обчислень;
    * cleaner code. x = jax.random.normal(key, shape=(3,))
    
    JAX застосовується для не лише для нейронних мереж, а й для наукових обчислень. key1, key2 = jax.random.split(key)
    {| class="wikitable"
    </div>
    
    <syntaxhighlight lang="text">
    
    == JAX для наукових обчислень ==
    
    pmap здатна використовуватися для:
    
    * залежить лише від своїх аргументів;
    * не змінює зовнішній стан;
    * не має прихованих побічних ефектів;
    * для однакових входів повертає однаковий результат. * JAX documentation щодо jit, vmap, pmap і pytrees. Критерій
    
    import jax
    
    <div style="background:#f0eaff; border-left:6px solid #8e44ad; padding:12px; margin:12px 0;">
    
    '''JAX Array'''  це базовий тип масиву в JAX.<div style="background:#eef2ff; border-left:6px solid #4f46e5; padding:12px; margin:12px 0;">
    <div style="background:#fef2f2; border-left:6px solid #ef4444; padding:12px; margin:12px 0;">
    
    <div style="background:#eef2ff; border-left:6px solid #4f46e5; padding:12px; margin:12px 0;">
    
    </div>
    
    '''jax.pmap'''  це трансформація для паралельного виконання обчислень на кількох devices.== JAX і PyTorch ==
    
    Приклад:
    
    '''Automatic differentiation'''  одна з ключових можливостей JAX. '''критично:''' open-source ліцензійний пакет JAX не скасовує обмежень на інформаційні дані, моделі або сторонні бібліотеки, які використовуються разом із ним. Він корисний для:
    {| class="wikitable"
    

Хороші практики роботи з JAX

  • custom loss functions;
  • differentiable simulations;
  • optimization algorithms;
  • neural architectures;
  • reinforcement learning;
  • probabilistic programming;
  • scientific ML;
  • large-scale research;
  • vectorized experiments;
  • accelerator-friendly code.== Безпека і відповідальне використання ==

JAX і NumPy

Перевага: JAX поєднує знайомий стиль NumPy із сучасними можливостями для machine learning і high-performance computing. Приклад:

Під час tracing JAX не завжди має звичайні Python-значення, а функціонує з абстрактними представленнями. * multi-GPU training;

  • multi-TPU computation;
  • паралельного виконання batch;
  • distributed-style обчислень;
  • масштабування ML-експериментів. import jax.numpy as jnp

jax.jit — це трансформація, яка компілює функцію для швидшого виконання. * JAX GitHub repository. Це означає, що масив не змінюється “на місці” так само, як це часто роблять у NumPy.== Типові сценарії використання ==

!

import jax

Він дає можливість писати код, схожий на NumPy: x = jnp.array([1, 2, 3])

Приклади задач

Висновок: PyTorch часто зручніший для класичного object-oriented deep learning workflow, а JAX — для функціонального, трансформаційного і research-oriented підходу. PyTorch

Типовий training loop у JAX складається з:

`vmap` корисний для:

return (w - 5.0) ** 2

Це здатна впливати на:

Можливі складнощі: