Skip to content

Commit

Permalink
Add missing Python dependencies for ORT training (#7104)
Browse files Browse the repository at this point in the history
* Add missing Python dependencies for training

cerberus - option parsing
h5py - checkpoint
onnx - model proto
packaging/sympy - symbolic shape inference

* Separate requirements.txt for inference and training Python packages.
  • Loading branch information
KeDengMS authored Mar 24, 2021
1 parent fffe16c commit 6987106
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
8 changes: 8 additions & 0 deletions requirements-training.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
cerberus
flatbuffers
h5py
numpy >= 1.16.6
onnx
packaging
protobuf
sympy
9 changes: 6 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,14 @@ def run(self):
'onnxruntime.transformers.longformer',
]

requirements_file = "requirements.txt"

if '--enable_training' in sys.argv:
packages.extend(['onnxruntime.training',
'onnxruntime.training.amp',
'onnxruntime.training.optim'])
sys.argv.remove('--enable_training')
requirements_file = "requirements-training.txt"

package_data = {}
data_files = []
Expand Down Expand Up @@ -310,12 +313,12 @@ def run(self):
cmd_classes['bdist_wheel'] = bdist_wheel
cmd_classes['build_ext'] = build_ext

requirements_path = path.join(getcwd(), "requirements.txt")
requirements_path = path.join(getcwd(), requirements_file)
if not path.exists(requirements_path):
this = path.dirname(__file__)
requirements_path = path.join(this, "requirements.txt")
requirements_path = path.join(this, requirements_file)
if not path.exists(requirements_path):
raise FileNotFoundError("Unable to find 'requirements.txt'")
raise FileNotFoundError("Unable to find " + requirements_file)
with open(requirements_path) as f:
install_requires = f.read().splitlines()

Expand Down

0 comments on commit 6987106

Please sign in to comment.