Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync ORTModule branch with master and fix tests #6526

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
202 commits
Select commit Hold shift + click to select a range
64709b1
Deprecate Python global configuration functions [Part 1] (#5923)
edgchen1 Dec 15, 2020
297c824
remove dnnl_dll_path from post build copy (#6142)
jywu-msft Dec 15, 2020
980a93c
Model Fusion For Bart (#6105)
liuziyue Dec 15, 2020
ac62cf8
Unify IExecutionProvider and IExecutionProviderFactory interfaces (#6…
RyanUnderhill Dec 16, 2020
939cc9b
Enable running the mnist_training sample without cuda (#6085)
georgen117 Dec 16, 2020
b648bf6
nnapi add min max support (#6117)
guoyu-wang Dec 16, 2020
0978d2b
Fix CUDA test hang: (#6138)
toothache Dec 16, 2020
aa49e47
Fix TensorRT kernel conflict issue for subgraphs of control flow oper…
stevenlix Dec 16, 2020
8fd0858
Add gradient registration for Abs. (#6139)
Dec 16, 2020
8269048
Partition initial optimizer state for Zero-1 (#6093)
ashbhandare Dec 16, 2020
7250562
Fix edge case in BFCArena where allocation failures could lead to an …
skottmckay Dec 16, 2020
344a2a8
Revert "work around of the build break in mac (#6069)" (#6150)
snnn Dec 16, 2020
0fa04bd
Fix clean_docker_image_cache.py detection of image pushes. (#6151)
edgchen1 Dec 17, 2020
503b61d
MLAS: add NEON version of int8 depthwise convolution (#6152)
tracysh Dec 17, 2020
36c03b3
Using a map of of ops to stages as input of partition function. (#5940)
Dec 17, 2020
efa1b0d
Minor fix to satisfy c++14 (#6162)
pranavsharma Dec 17, 2020
32c67c2
Deprecating Horovod and refactored Adasum computations (#5468)
Dec 18, 2020
dec703b
Update TensorRT-ExecutionProvider.md (#6161)
jayrodge Dec 18, 2020
34725ae
Bugfix for topk cuda kernel (#6164)
duli2012 Dec 18, 2020
98d8a3e
Revert "Fuse MatMulIntegerToFloat only when scales are scalar (#6008)…
yufenglee Dec 18, 2020
c339bb2
Remove ignored build warnings for pybind on Mac (#6165)
guoyu-wang Dec 18, 2020
adc2071
save_checkpoint, load_checkpoint and aggregate_checkpoints (#6136)
baijumeswani Dec 18, 2020
824ef9a
Don't try to bind unused inputs in the Training frontend (#6166)
Dec 18, 2020
86493e6
Update documentation for contributing a PR and add deprecation notice…
pranavsharma Dec 18, 2020
39aedbc
aggregate model states only for the case when mixed precision was tru…
baijumeswani Dec 18, 2020
bbb52e9
[NNAPI EP] Enable per-channel quantization for QlinearConv (#6155)
guoyu-wang Dec 19, 2020
11b0a54
Fix typo in BERT pretraining script (#6175)
Dec 19, 2020
cd3a5ac
Update get_docker_image.py to enable use without image cache containe…
edgchen1 Dec 19, 2020
2da8060
Helper for compiling EP to generate deterministic unique ids for use …
skottmckay Dec 21, 2020
f874260
Backend APIs for checkpointing (#5803)
jingyanwangms Dec 21, 2020
201d0db
Android coverage dashboard (#6163)
satyajandhyala Dec 21, 2020
ea9cfa5
Add usage details of unified MCR container image (#6182)
smkarlap Dec 21, 2020
53307a5
improve perf for softmax (#6128)
weixingzhang Dec 21, 2020
67ac6ae
Tune fast Gelu to use exp(x) instead of tanh(x) on Rocm platform (#6174)
Dec 22, 2020
234e94b
Add Status.csv to EP Perf Tool (#6167)
oliviajain Dec 22, 2020
945fae8
Lochi/quantization tool for trt (#6103)
chilo-ms Dec 22, 2020
fc27074
Implement ScatterND for CUDA EP (#6184)
hariharans29 Dec 22, 2020
04b3e0e
Condition fix in Resize operator (#6193)
hariharans29 Dec 22, 2020
a8b4826
Clean up checkpoint tests to use the new checkpoint functions (#6188)
baijumeswani Dec 22, 2020
21395f8
Implement comparing outputs that are sequence of maps of strings to f…
Dec 22, 2020
c562952
Dockerfile to build onnxruntime with ROCm 4.0
jessebenson Dec 21, 2020
0494a0f
Add ability to skip GPU tests based on GPU adapter name (#6198)
Dec 22, 2020
7347996
Openvino ep 2021.2 (#6196)
sfatimar Dec 23, 2020
1fc7f92
Fix a memory leak in test_inference.cc (#6201)
snnn Dec 25, 2020
52228a7
Use TArray in AMD element-wise kernels, rather than manually copying …
jessebenson Dec 22, 2020
7ccdfed
Remove most ROCm-specific element-wise code and reuse CUDA element-wi…
jessebenson Dec 22, 2020
8a0f5c5
Minor change to improve performance for operator Pad. (#5537)
xadupre Dec 28, 2020
2d09db6
Support double for operators Log, Reciprocal, Sum (CPU) (#6032)
xadupre Dec 28, 2020
111ac29
Support double for operators Where, LpNormalisation (#6034)
xadupre Dec 28, 2020
df7e2f3
Support double for operators Relu, Tanh, Sigmoid (#6221)
xadupre Dec 29, 2020
bbb6b41
Fix ImportError in build.py (#6231)
mgoin Dec 30, 2020
5c584b2
Removed executor todo that looks dead. (#6234)
michaelgiba Dec 31, 2020
1b23b28
Remove MKLML/openblas/jemalloc build config (#6212)
snnn Dec 31, 2020
3911105
Remove python 3.5
snnn Dec 31, 2020
c15a858
Update the readme file
snnn Dec 31, 2020
39a988c
Upgrade build.py to assert for python 3.6+
WilliamTambellini Dec 1, 2020
4cc2ffe
Support MLFloat16 type in Pow opset-12 CUDA kernel (#6233)
hariharans29 Dec 31, 2020
ecb2e11
MLAS: handle MlasGemm(M/N/K==0) cases (#6238)
tracysh Dec 31, 2020
70e2f96
Support double for operator TopK + fix one bug in TopK implementation…
xadupre Dec 31, 2020
5968a91
Support double for operator Gemm + fix bug in gemm implementation for…
xadupre Dec 31, 2020
84addcd
Support double for operator ReduceMean, ReduceLogSumExp (#6217)
xadupre Dec 31, 2020
cd14c1a
Support double for operator ArgMin (#6222)
xadupre Dec 31, 2020
d5cb17c
Update BUILD.md
snnn Dec 31, 2020
1685167
Update manylinux docker image to the latest (#6242)
snnn Jan 1, 2021
ffb4b62
Fix allocator issue for TensorRT IOBinding (#6240)
HectorSVC Jan 1, 2021
46e0e4e
Tune BiasGeluGradDx kernel in approximation mode to avoid tanh(...) o…
Jan 2, 2021
c8de3f3
Refactor EP Perf Tool (#6202)
oliviajain Jan 4, 2021
93bf7c4
Documentation for distributed CI tests pipeline (#6140)
baijumeswani Jan 4, 2021
6fd9d34
Remove a debug log in provider_test_utils.cc (#6200)
snnn Jan 4, 2021
493bf93
Add the Concat Slice Elimination transform, fix constant_folding tran…
ashbhandare Jan 5, 2021
ce6161c
Add MakeStringLite which uses current locale, update some MakeString …
edgchen1 Jan 5, 2021
addb4b8
Liqun/speech model loop to scan (#6070)
liqunfu Jan 5, 2021
eea3806
model parallel refinement (#6244)
pengwa Jan 6, 2021
d42399e
Allow querying a GraphProto's doc_string as part of ModelMetadata (#6…
hariharans29 Jan 6, 2021
2347de4
Fix Linux/Mac error message on input type mismatch (#6256)
hariharans29 Jan 6, 2021
431604e
add bfloat16 to gathergrad type constrains (#6267)
souptc Jan 6, 2021
bbc9ed9
Fix VS 2017 build break (#6276)
hariharans29 Jan 7, 2021
d761571
Deprecate Python global configuration functions [Part 2] (#6171)
edgchen1 Jan 7, 2021
481a2cd
Add script to preprocess python documentation before publishing (#6129)
xadupre Jan 7, 2021
b80e8ce
rename past to past_key_values for GPT-2 (#6269)
tianleiwu Jan 7, 2021
c109486
Rename MakeString and ParseString functions. (#6272)
edgchen1 Jan 7, 2021
04287ec
Increase timeout for Linux GPU CUDA11 build. (#6280)
edgchen1 Jan 7, 2021
a72fcbd
Add helper to compare model with different precision (#6270)
wangyems Jan 8, 2021
7fc827a
Fix Min/Max CPU kernels for float16 type (#6205)
hariharans29 Jan 8, 2021
ac5ca2b
fix data_ptr assertion error for past_sequence_length=0 in GPT-2 (#6284)
tianleiwu Jan 8, 2021
da952a9
A list of changes in transformers tool (#6224)
wangyems Jan 8, 2021
1059bfa
Workaround for static_cast<double>(half)
jessebenson Jan 8, 2021
fa851bf
Add workaround to remove ROCm-specific binary-elementwise files.
jessebenson Jan 8, 2021
5084ce0
Update nuget build (#6297)
snnn Jan 11, 2021
84024bd
Enable ONNX backend test of SequenceProto input/output (#6043)
jcwchen Jan 11, 2021
938e65d
add --sequence_lengths option (#6285)
tianleiwu Jan 11, 2021
ac5b5e5
more dtype for Equal CUDA kernel (#6288)
centwang Jan 12, 2021
c43ca45
Force reinstall onnx python package on Windows (#6309)
snnn Jan 12, 2021
a038924
update transformers required package versions (#6315)
tianleiwu Jan 12, 2021
3b3e698
Remove abs in LpPool (#6303)
luyaor Jan 12, 2021
a825766
Support 1D input for Conv + Mul/Add fusion optimizer with test (#6295)
zhanghuanrong Jan 12, 2021
ec81e29
Add longformer to python package (#6314)
tianleiwu Jan 12, 2021
b491d7c
Avoid false sharing on thread pool data structures (#6298)
tlh20 Jan 12, 2021
0ed56d4
fix opset imports for function body (#6287)
askhade Jan 12, 2021
aacc8db
Remove false positive prefast warning from threadpool (#6324)
tlh20 Jan 12, 2021
6b73bae
Java: add Semmle to Java publishing pipelines (#6326)
yuslepukhin Jan 12, 2021
f77ff1b
Quantization support for split operator with its NHWC support (#6107)
zhanghuanrong Jan 13, 2021
aeca96c
Liqun/enable pipeline parallel test (#6331)
liqunfu Jan 13, 2021
5623cc6
Use onnxruntime_USE_FULL_PROTOBUF=OFF for the cuda execution provider…
alberto-magni Jan 13, 2021
87ec1f6
MLAS: add fallback implementation for quantized GEMM (#6335)
tracysh Jan 13, 2021
56ab216
Delete float16.py (#6336)
oliviajain Jan 13, 2021
62e4045
Enable add + softmax fusion for Rocm platform (#6259)
Jan 13, 2021
f7034b9
add external data support to tensor proto utils (#6257)
askhade Jan 13, 2021
d367941
changed wording. (#6337)
Jan 13, 2021
cfd6f10
Remove OpSchema dummy definition. Only needed for Function now, and w…
skottmckay Jan 13, 2021
fcd9fc9
remove gemmlowp submodule (#6341)
tracysh Jan 13, 2021
b220fee
[NNAPI] Add pow support (#6310)
guoyu-wang Jan 14, 2021
042053c
Add support for running Android emulator from build.py on Windows. (#…
edgchen1 Jan 14, 2021
e35db19
fix the pipeline failure (#6346)
guoyu-wang Jan 14, 2021
4df356d
Train BERT Using BFloat16 on A100 (#6090)
centwang Jan 14, 2021
5b9d993
Fix DerefNullPtr issues raised by SDLNativeRules. (#6348)
pranavsharma Jan 14, 2021
c24f295
update quantize to support basic optimization and e2e example for ima…
yufenglee Jan 14, 2021
fd21c84
Enable graph save for orttrainer (#6333)
ashbhandare Jan 14, 2021
ea6789b
Add PREfast to python packaging pipeline (#6343)
snnn Jan 14, 2021
5d9552c
fix longformer benchmark io_binding output_buffers (#6345)
wangyems Jan 14, 2021
e54e2f9
Use readelf for minimal build binary size checks. (#6338)
skottmckay Jan 14, 2021
6d0fb3e
Java: Set C language warnings to W4 and adjust JNI code (#6347)
yuslepukhin Jan 14, 2021
8ce252c
Pipeline Parallel Experimental Python API (#5815)
wschin Jan 15, 2021
961bb62
Add create session to WinML telemetry to track WinML Usage (#6356)
Jan 15, 2021
c8e37e3
Fix one more SDL warning (#6359)
pranavsharma Jan 15, 2021
f5a4f7f
fix -Wdangling-gsl (#6357)
askhade Jan 15, 2021
eab164e
Add python example of TensorRT INT8 inference on ResNet model (#6255)
stevenlix Jan 15, 2021
4db4982
This added telemetry isn't needed (#6363)
Jan 16, 2021
5b6753c
Wezuo/memory analysis (#5658)
wezuo Jan 19, 2021
baac7c9
Support MLFloat16 in CumSum Cuda op for Opset 14 (#6355)
tianleiwu Jan 19, 2021
ac36596
fix convert_common version retrival (#6382)
wangyems Jan 19, 2021
d7bdd96
Refine auto_pad based pad computation in ConvTranspose (#6305)
hariharans29 Jan 20, 2021
a1b5bfc
Fix SDL warning (#6390)
hariharans29 Jan 20, 2021
453431f
Add max_norm for gradient clipping. (#6289)
pengwa Jan 20, 2021
69af044
Add the custom op project information (#6334)
wenbingl Jan 20, 2021
33f60a0
Dont use default string marshalling in C# (#6219)
hariharans29 Jan 21, 2021
d9e4795
Fix Windows x86 compiler warnings in the optimizers project (#6377)
hariharans29 Jan 21, 2021
8574854
[Perf] Optimize Tile CPU and CUDA kernels for a corner case (#6376)
hariharans29 Jan 21, 2021
eb946c4
Unblock Android CI code coverage failure (#6393)
guoyu-wang Jan 21, 2021
99a38f4
fix build on cuda11 (#6394)
centwang Jan 21, 2021
98cc7b5
Load the model path correctly (#6369)
MartinMoon Jan 21, 2021
bba185a
Fix some compile warnings (#6316)
snnn Jan 22, 2021
4442d94
OpenVino docker file changes to bypass privileged mode
smkarlap Jan 22, 2021
60c772e
Megatron checkpointing (#6293)
ashbhandare Jan 22, 2021
61ecf52
Fix generate_submodule_cgmanifest.py Windows issues. (#6404)
edgchen1 Jan 22, 2021
3c3d363
Continue memory planning when unknown shape tensor is encountered. (#…
codemzs Jan 22, 2021
6507b4f
Reintroduce experimental api changes and fix remote build break (#6385)
Jan 22, 2021
e1dc268
Add support for custom ops to minimal build. (#6228)
skottmckay Jan 25, 2021
c20965f
enable pipeline to run quantization tests (#6416)
yufenglee Jan 25, 2021
24f1bd6
Minor cmake change (#6431)
hariharans29 Jan 25, 2021
6ed1240
Liqun/liqun/enable pipeline parallel test2 (#6399)
liqunfu Jan 25, 2021
f3a0344
Farewell TrainableDropout (#5793)
codemzs Jan 26, 2021
7e42840
fix null dereference warning (#6437)
yufenglee Jan 26, 2021
76dbd88
Expose graph ModelPath to TensorRT shared library (#6353)
stevenlix Jan 26, 2021
afd7b8b
add tool for generating test data for longformer (#6415)
tianleiwu Jan 27, 2021
0d20104
only build experimental api in redist (#6465)
smk2007 Jan 27, 2021
9835b46
Add an option to save the training graph after optimization (#6410)
ryotatomioka Jan 27, 2021
b5d1a49
Share allocator between CUDA EP & TRT EP. (#6332)
HectorSVC Jan 27, 2021
fd43806
fix max norm clipping test in python packaging pipeline test (#6468)
pengwa Jan 27, 2021
c05adb1
Initial version of CoreML EP (#6392)
guoyu-wang Jan 27, 2021
d5f51c4
Bug 31463811: Servicing: Redist (Nuget) conflicts with Microsoft.AI.M…
smk2007 Jan 27, 2021
f68eb35
dequantize 1st input of lstm back if it is quantized (#6444)
yufenglee Jan 27, 2021
0100f33
[java] Adds support for OrtEnvironment thread pools (#6406)
Craigacp Jan 27, 2021
1ce1a51
fix SDL native rule warning #6246 (#6461)
fs-eire Jan 27, 2021
ed1ebd2
fix SDL rule (#6464)
fs-eire Jan 27, 2021
b6ac35f
use tickcount64 (#6447)
Jan 27, 2021
7a0ab9c
Update pypi package metadata (#6354)
faxu Jan 28, 2021
91b19b8
Delete nuget extra configs (#6477)
snnn Jan 28, 2021
d850fa6
Op kernel type reduction infrastructure. (#6466)
edgchen1 Jan 28, 2021
77d0eb3
Fixing a leak in OnnxSequences with String keys or values. (#6473)
Craigacp Jan 28, 2021
2e228d7
Increase the distributes tests pipeline timeout to 120 minutes (#6479)
baijumeswani Jan 28, 2021
752627c
[CoreML EP] Add CI for CoreML EP (macOS) and add coreml_flags for EP …
guoyu-wang Jan 28, 2021
c84bb9d
Add ability to track per operator types in reduced build config. (#6428)
skottmckay Jan 28, 2021
00afd00
merge e2e with distributed pipeline (#6443)
liqunfu Jan 28, 2021
ea2b560
Fix test breaks in Windows ingestion pipeline (#6476)
smk2007 Jan 28, 2021
3f60b27
Speed up the Mac CI runs (#6483)
guoyu-wang Jan 28, 2021
ce46f37
expose learningmodelpixelrange property (#5877)
zhangxiang1993 Jan 28, 2021
d4e1f5a
Fix of support api version bug for [de]quantize (#6492)
guoyu-wang Jan 29, 2021
21b4842
SDL fixes: add proper casts/format specifiers (#6446)
Jan 29, 2021
3b1227c
SDL annotation fixes (#6448)
Jan 29, 2021
1a5b75a
[OpenVINO-EP] Remove support for OpenVINO 2020.2 (#6493)
suryasidd Jan 29, 2021
7abb5b6
Support pad operator in quantization and quantized nhwc transformer. …
zhanghuanrong Jan 29, 2021
066520f
Improve work distribution for Expand operator, and sharded LoopCounte…
tlh20 Jan 29, 2021
d3203ad
Update document of transformer optimization (#6487)
tianleiwu Jan 29, 2021
71389ff
nuphar test to avoid test data download to improve passing rate (#6467)
liqunfu Jan 29, 2021
a19c48f
Fuse cuda conv with activation (#6351)
RandySheriffH Jan 29, 2021
06a6c63
[CoreML EP] Add support for some activations/Transpose, move some sha…
guoyu-wang Jan 29, 2021
8306150
Refine transformers profiler output (#6502)
tianleiwu Jan 29, 2021
8c6d76a
Update to match new test setup. (#6496)
skottmckay Jan 29, 2021
76bc0e4
Enable dense sequence optimized version of Pytorch exported BERT-L on…
Jan 29, 2021
7f57317
Optimize GatherGrad for AMD GPU (#6381)
weixingzhang Jan 29, 2021
76f5d9e
add explicit barriers for buffer overread and overrwrite (#6484)
Jan 29, 2021
531eb06
fix sdl bugs for uninitialized variables and returns (#6450)
Jan 29, 2021
3a30ad7
handle hr error conditions (#6449)
Jan 29, 2021
a36f627
Dnnl training (#6045)
georgen117 Jan 30, 2021
7c5bfba
Lochi/refactor yolov3 quantization (#6290)
chilo-ms Jan 30, 2021
f2872ff
Print a warning message for using newer c_api header on old binary (#…
guoyu-wang Jan 30, 2021
e5cbcec
Fix issues with ArmNN build setup (#6495)
skottmckay Jan 30, 2021
5b69cbe
Fix Windows CI builds by updating test scripts to work with numpy 1.2…
skottmckay Feb 1, 2021
891181d
Fix ORTModule branch for orttraining-* pipelines
Jan 29, 2021
6b890c2
Merge remote-tracking branch 'origin/master' into thiagofc/fix-orttra…
Feb 1, 2021
0432fa7
Update pytorch nightly version dependency
Feb 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions onnxruntime/test/python/onnxruntime_test_ort_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,17 @@ def get_loaders(self):
args_test_batch_size = 1000

kwargs = {'num_workers': 0, 'pin_memory': True}
# set shuffle to False to get deterministic data set among different torch version
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(os.path.join(SCRIPT_DIR, 'data'), train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args_batch_size, shuffle=True, **kwargs)
batch_size=args_batch_size, shuffle=False, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(os.path.join(SCRIPT_DIR, 'data'), train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args_test_batch_size, shuffle=True, **kwargs)
batch_size=args_test_batch_size, shuffle=False, **kwargs)

return train_loader, test_loader

Expand Down Expand Up @@ -306,13 +307,13 @@ def run_mnist_training_and_testing(onnx_opset_ver):

learningRate = 0.01
args_epochs = 2
expected_losses = [2.333008289337158, 1.0680292844772339, 0.6300537586212158, 0.5279903411865234,
0.3710068166255951, 0.4044453501701355, 0.30482712388038635, 0.4595026969909668,
0.42305776476860046, 0.4797358512878418, 0.23006735742092133, 0.48427966237068176,
0.30716797709465027, 0.3238796889781952, 0.19543828070163727, 0.3561663031578064,
0.3089643716812134, 0.37738722562789917, 0.24883587658405304, 0.30744990706443787]
expected_test_losses = [0.31038025817871095, 0.25183824462890625]
expected_test_accuracies = [0.9125, 0.9304]
expected_losses = [2.312044143676758, 0.8018650412559509, 0.5819257497787476, 0.47025489807128906,
0.35800155997276306, 0.41124576330184937, 0.2731882333755493, 0.4201386570930481,
0.39458805322647095, 0.38380366563796997, 0.2722422480583191, 0.24230478703975677,
0.23505745828151703, 0.33442264795303345, 0.21140924096107483, 0.31545233726501465,
0.18556523323059082, 0.3453553020954132, 0.29598352313041687, 0.3595045208930969]
expected_test_losses = [0.3145490005493164, 0.256188737487793]
expected_test_accuracies = [0.9075, 0.9265]

actual_losses = []
actual_test_losses, actual_accuracies = [], []
Expand Down Expand Up @@ -356,11 +357,11 @@ def testMNISTResumeTrainingAndTesting(self):
args_epochs = 2
args_checkpoint_epoch = 1
# should match those in test without checkpointing
expected_losses = [0.23006735742092133, 0.48427966237068176,
0.30716797709465027, 0.3238796889781952, 0.19543828070163727, 0.3561663031578064,
0.3089643716812134, 0.37738722562789917, 0.24883587658405304, 0.30744990706443787]
expected_test_losses = [0.25183824462890625]
expected_test_accuracies = [0.9304]
expected_losses = [0.26509523391723633, 0.24135658144950867, 0.2397943139076233, 0.3351520597934723,
0.20998981595039368, 0.31488314270973206, 0.18481917679309845, 0.34727591276168823,
0.2971782684326172, 0.3609251379966736]
expected_test_losses = [0.25632242965698243]
expected_test_accuracies = [0.9264]

actual_losses = []
actual_test_losses, actual_accuracies = [], []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,8 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<i
config_.input_names_require_grad.begin(), config_.input_names_require_grad.end(),
std::inserter(x_node_arg_names, x_node_arg_names.begin()));
auto add_transformers = [&](TransformerLevel level) {
std::unordered_map<std::string, std::string> updated_weight_names{};
auto transformers_to_register = transformer_utils::GeneratePreTrainingTransformers(
level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider, updated_weight_names, {});
level, x_node_arg_names, graph_transformer_config, *cpu_execution_provider);
for (auto& entry : transformers_to_register) {
graph_transformation_mgr.Register(std::move(entry), level);
}
Expand Down
4 changes: 0 additions & 4 deletions orttraining/orttraining/core/session/training_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,16 +424,12 @@ Status TrainingSession::ConfigureForTraining(
// conflict. It is user's responsibility to make sure different rank is passed in with different. Also, to avoid
// writing conflict, only the ranks in first pipeline group write the partition file out.
// model_with_training_graph_path value.
#if 0
// TODO: Do not merge this on master
// This is being called above, before optimizers nodes are added
if ((IsRootNode(config) || (config.pipeline_config.has_value() &&
DistributedRunContext::GroupId(WorkerGroupType::ModelParallel) == 0)) &&
config.model_with_training_graph_path.has_value()) {
ORT_IGNORE_RETURN_VALUE(Save(
config.model_with_training_graph_path.value(), SaveOption::NO_RELOAD));
}
#endif

// After pipeline partition, we need to return the inputs allowed in this partition.
if (config.pipeline_config.has_value()) {
Expand Down
9 changes: 1 addition & 8 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ using namespace onnxruntime::logging;
using namespace onnxruntime::training;

struct TrainingParameters {
std::string model_with_loss_function_path;
std::string model_with_training_graph_path;

std::string loss_output_name;
std::unordered_set<std::string> weights_to_train;
std::unordered_set<std::string> weights_not_to_train;
Expand Down Expand Up @@ -87,9 +84,7 @@ TrainingConfigurationResult ConfigureSessionForTraining(
parameters.data_parallel_size = data_group_size;
}

training::TrainingSession::TrainingConfiguration config{};
config.model_with_loss_function_path = parameters.model_with_loss_function_path;
config.model_with_training_graph_path = parameters.model_with_training_graph_path;
training::PipelineTrainingSession::TrainingConfiguration config{};
config.weight_names_to_train = parameters.weights_to_train;
config.weight_names_to_not_train = parameters.weights_not_to_train;
config.immutable_weights = parameters.immutable_weights;
Expand Down Expand Up @@ -196,8 +191,6 @@ void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const lo
void addObjectMethodsForTraining(py::module& m) {
py::class_<TrainingParameters> parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc");
parameters.def(py::init())
.def_readwrite("model_with_loss_function_path", &TrainingParameters::model_with_loss_function_path)
.def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path)
.def_readwrite("loss_output_name", &TrainingParameters::loss_output_name)
.def_readwrite("immutable_weights", &TrainingParameters::immutable_weights)
.def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train)
Expand Down
2 changes: 1 addition & 1 deletion orttraining/orttraining/python/training/ortmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _create_training_session(self):
if self._device.type == 'cuda':
# Configure the InferenceSessions to use the specific GPU on which the model is placed.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
provider_options = [{"device_id": str(self._device.index)}]
provider_options = [{"device_id": str(self._device.index)}, {}]
elif self._device.type == 'cpu':
providers = ["CPUExecutionProvider"]
provider_options = [{}]
Expand Down
37 changes: 32 additions & 5 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,30 @@

from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

def monkey_patch_pytorch():
warnings.warn('ORTTrainer: Remove this monkey patch when https://github.com/pytorch/pytorch/pull/51396 is merged')

def ort_prim_ConstantChunk(g, self, chunks, dim):
input_shape = g.op("Shape", self)
axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
chunk_size_minus_1 = g.op("Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long))
input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1)
chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size)
res = []
for i in range(chunks):
index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long))
end = g.op("Mul", chunk_dim, index)
res.append(g.op("Slice", self, start, end, axis))
start = end
return res

import torch.onnx.symbolic_opset11
torch.onnx.symbolic_opset11.prim_ConstantChunk = ort_prim_ConstantChunk


class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.

Expand Down Expand Up @@ -122,6 +146,11 @@ class ORTTrainer(object):
"""

def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):

# DO NOT MERGE THIS ON MASTER
# TODO: Remove after https://github.com/pytorch/pytorch/pull/51396 is merged
monkey_patch_pytorch()

# Basic validation
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
Expand Down Expand Up @@ -621,8 +650,6 @@ def _create_ort_training_session(self, state_dict = {}):

# TrainingParameters
ort_parameters = ort.TrainingParameters()
ort_parameters.model_with_loss_function_path = self.options.debug.model_with_loss_function_path
ort_parameters.model_with_training_graph_path = self.options.debug.model_with_training_graph_path
ort_parameters.loss_output_name = loss_name
ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled
ort_parameters.world_rank = self.options.distributed.world_rank
Expand All @@ -631,7 +658,7 @@ def _create_ort_training_session(self, state_dict = {}):
ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
ort_parameters.set_gradients_as_graph_outputs = True
ort_parameters.set_gradients_as_graph_outputs = False
ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
ort_parameters.training_optimizer_name = self.optim_config.name
ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
Expand All @@ -651,8 +678,8 @@ def _create_ort_training_session(self, state_dict = {}):
# SessionOptions
session_options = ort.SessionOptions()
session_options.use_deterministic_compute = self.options.debug.deterministic_compute
if (self.options.graph_transformer.attn_dropout_recompute or
self.options.graph_transformer.gelu_recompute or
if (self.options.graph_transformer.attn_dropout_recompute or
self.options.graph_transformer.gelu_recompute or
self.options.graph_transformer.transformer_layer_recompute):
session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED

Expand Down
78 changes: 58 additions & 20 deletions orttraining/orttraining/python/training/orttrainer_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ORTTrainerOptions(object):
'transformer_layer_recompute': {
'type': 'boolean',
'default': False
},
},
'number_recompute_layers': {
'type': 'integer',
'min': 0,
Expand Down Expand Up @@ -176,13 +176,28 @@ class ORTTrainerOptions(object):
'type' : 'boolean',
'default' : False
},
'model_with_loss_function_path': {
'type': 'string',
'default': ''
},
'model_with_training_graph_path': {
'type': 'string',
'default': ''
'graph_save_paths' : {
'type' : 'dict',
'default': {},
'required': False,
'schema': {
'model_after_graph_transforms_path': {
'type': 'string',
'default': ''
},
'model_with_gradient_graph_path':{
'type': 'string',
'default': ''
},
'model_with_training_graph_path': {
'type': 'string',
'default': ''
},
'model_with_training_graph_after_optimization_path': {
'type': 'string',
'default': ''
},
},
},
}
},
Expand Down Expand Up @@ -281,10 +296,18 @@ class ORTTrainerOptions(object):
debug.check_model_export (bool, default is False)
compares PyTorch model outputs with ONNX model outputs in inference before the first
train step to ensure successful model export
debug.model_with_loss_function_path (str, default is '')
path to dump an ONNX file with model and loss function
debug.model_with_training_graph_path (str, default is '')
path to dump an ONNX file with full training graph
debug.graph_save_paths (dict):
paths used for dumping ONNX graphs for debugging purposes
debug.graph_save_paths.model_after_graph_transforms_path (str, default is "")
path to export the ONNX graph after training-related graph transforms have been applied.
No output when it is empty.
debug.graph_save_paths.model_with_gradient_graph_path (str, default is "")
path to export the ONNX graph with the gradient graph added. No output when it is empty.
debug.graph_save_paths.model_with_training_graph_path (str, default is "")
path to export the training ONNX graph with forward, gradient and optimizer nodes.
No output when it is empty.
debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "")
outputs the optimized training graph to the path if nonempty.
_internal_use (dict):
internal options, possibly undocumented, that might be removed without notice
_internal_use.enable_internal_postprocess (bool, default is True):
Expand Down Expand Up @@ -542,15 +565,30 @@ def _check_is_callable(field, value, error):
'type': 'boolean',
'default': False
},
'model_with_loss_function_path': {
'type': 'string',
'default': ''
},
'model_with_training_graph_path': {
'type': 'string',
'default': ''
'graph_save_paths' : {
'type' : 'dict',
'default_setter': lambda _: {},
'required': False,
'schema': {
'model_after_graph_transforms_path': {
'type': 'string',
'default': ''
},
'model_with_gradient_graph_path':{
'type': 'string',
'default': ''
},
'model_with_training_graph_path': {
'type': 'string',
'default': ''
},
'model_with_training_graph_after_optimization_path': {
'type': 'string',
'default': ''
},
},
},
}
},
},
'_internal_use': {
'type': 'dict',
Expand Down
6 changes: 1 addition & 5 deletions samples/python/mnist/ort_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def main():
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-path', type=str, default='',
help='Path for Saving the current Model')
help='Path for Saving the current Model state')

# Basic setup
args = parser.parse_args()
Expand Down Expand Up @@ -141,10 +141,6 @@ def main():
model_desc = mnist_model_description()
optim_config = optim.SGDConfig(lr=args.lr)
opts = {'device': {'id': device}}
if args.save_path:
opts.update({'debug': {
'model_with_loss_function_path' : os.path.join(args.save_path, 'model_with_loss.onnx'),
'model_with_training_graph_path' : os.path.join(args.save_path, 'model_with_training.onnx'),}})
opts = ORTTrainerOptions(opts)

trainer = ORTTrainer(model,
Expand Down