Keras 3 是一个多后端深度学习框架,支持 TensorFlow、JAX 和 PyTorch。
Keras 3 在 PyPI 上以keras
.请注意,Keras 2 仍可作为tf-keras
软件包使用。
- 安装
keras
:
pip install keras --upgrade
- 安装后端包。
要使用keras
,您还应该安装选择的后端:tensorflow
、jax
或torch
。请注意,这tensorflow
是使用某些 Keras 3 功能所必需的:某些预处理层以及tf.data
管道。
Keras 3 兼容 Linux 和 MacOS 系统。对于 Windows 用户,我们建议使用 WSL2 来运行 Keras。要安装本地开发版本:
- 安装依赖项:
pip install -r requirements.txt
- 从根目录运行安装命令。
python pip_build.py --install
该requirements.txt
文件将安装仅 CPU 版本的 TensorFlow、JAX 和 PyTorch。对于 GPU 支持,我们还requirements-{backend}-cuda.txt
为 TensorFlow、JAX 和 PyTorch提供单独的支持。这些通过安装所有 CUDA 依赖项pip
,并期望预安装 NVIDIA 驱动程序。我们建议每个后端使用干净的 python 环境,以避免 CUDA 版本不匹配。作为示例,以下是如何使用以下命令创建 Jax GPU 环境conda
:
conda create -y -n keras-jax python=3.10 conda activate keras-jax pip install -r requirements-jax-cuda.txt python pip_build.py --install
您可以导出环境变量KERAS_BACKEND
,也可以编辑本地配置文件来~/.keras/keras.json
配置后端。可用的后端选项有:"tensorflow"
、"jax"
、"torch"
。例子:
export KERAS_BACKEND="jax"
在 Colab 中,您可以执行以下操作:
import os os.environ["KERAS_BACKEND"] = "jax"import keras
import keras" tabindex="0" role="button">
注意:导入前必须配置后端keras
,导入包后不能更改后端。
Keras 3 旨在作为tf.keras
(使用 TensorFlow 后端时)的直接替代品。只需使用现有的tf.keras
代码,确保您的调用model.save()
使用最新的.keras
格式,然后就完成了。
如果您的tf.keras
模型不包含自定义组件,您可以立即开始在 JAX 或 PyTorch 上运行它。
如果它确实包含自定义组件(例如自定义层或自定义train_step()
),通常可以在短短几分钟内将其转换为与后端无关的实现。
此外,Keras 模型可以使用任何格式的数据集,无论您使用什么后端:您可以使用现有tf.data.Dataset
管道或 PyTorch来训练模型DataLoaders
。
- 在任何框架之上运行高级 Keras 工作流程 - 随意受益于每个框架的优势,例如 JAX 的可扩展性和性能或 TensorFlow 的生产生态系统选项。
- 编写可在任何框架的低级工作流中使用的自定义组件(例如层、模型、指标)。
- 您可以采用 Keras 模型,并在用本机 TF、JAX 或 PyTorch 从头开始编写的训练循环中对其进行训练。
- 您可以采用 Keras 模型并将其用作 PyTorch 原生的一部分
Module
或 JAX 原生模型函数的一部分。
- 通过避免框架锁定,让您的 ML 代码面向未来。
- 作为 PyTorch 用户:终于可以使用 Keras 的强大功能和可用性了!
- 作为 JAX 用户:可以访问功能齐全、经过实战测试、文档齐全的建模和培训库。
请阅读Keras 3 发布公告了解更多信息。