Integrate intgemm into marian (#595)
Adds intgemm as a submodule for Marian. intgemm is @kpu's 8/16-bit GEMM library with support for architectures from SSE2 to AVX512VNNI.
Removes the outdated integer code related to the --optimize option.

Co-authored-by: Kenneth Heafield <github@kheafield.com>
Co-authored-by: Kenneth Heafield <kpu@users.noreply.github.com>
Co-authored-by: Ulrich Germann <ugermann@inf.ed.ac.uk>
Co-authored-by: Marcin Junczys-Dowmunt <marcinjd@microsoft.com>
Co-authored-by: Roman Grundkiewicz <rgrundkiewicz@gmail.com>
6 people authored Jan 25, 2021
1 parent 737f430 commit 600f5cb
Showing 39 changed files with 986 additions and 1,673 deletions.
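For readers unfamiliar with the library, below is a minimal sketch of intgemm's 8-bit quantize-multiply-unquantize flow that this commit wires into Marian's CPU backend. The call names follow intgemm's README; the matrix shapes and the assumed value range are illustrative, not taken from this diff.

```cpp
#include "intgemm/intgemm.h"

int main() {
  // Shapes are illustrative; intgemm wants the shared dimension to be a
  // multiple of 32 and B's columns to be a multiple of 8.
  const intgemm::Index A_rows = 4, width = 64, B_cols = 8;

  intgemm::AlignedVector<float> A(A_rows * width), B(width * B_cols), C(A_rows * B_cols);
  for(auto& v : A) v = 0.5f;
  for(auto& v : B) v = 0.25f;

  // Quantization multiplier: map the expected float range onto int8.
  const float quant_mult = 127.0f / 2.0f; // assumes |values| <= 2

  intgemm::AlignedVector<int8_t> A_prep(A.size()), B_prep(B.size());
  intgemm::Int8::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
  intgemm::Int8::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);

  // Multiply in int8 and unquantize the int32 accumulators back to float.
  intgemm::Int8::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
      intgemm::callbacks::UnquantizeAndWrite(1.0f / (quant_mult * quant_mult), C.begin()));
  return 0;
}
```

Note that PrepareB also reorders B into a CPU-specific layout; that is the same reordering which `prepareAndTransposeB` applies below when a hardware-agnostic binary model is loaded.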
3 changes: 3 additions & 0 deletions .gitmodules
@@ -14,6 +14,9 @@
path = src/3rd_party/fbgemm
url = https://github.com/marian-nmt/FBGEMM
branch = master
[submodule "src/3rd_party/intgemm"]
path = src/3rd_party/intgemm
url = https://github.com/marian-nmt/intgemm/
[submodule "src/3rd_party/simple-websocket-server"]
path = src/3rd_party/simple-websocket-server
url = https://github.com/marian-nmt/Simple-WebSocket-Server
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]

### Added
- Added `intgemm8(ssse3|avx2|avx512)?`, `intgemm16(sse2|avx2|avx512)?` types to marian-conv, which use the intgemm backend. Types `intgemm8` and `intgemm16` are hardware-agnostic; the others are hardware-specific.
- Shortlist is now always a multiple of eight.
- Added an architecture-agnostic binary format for intgemm 8-bit/16-bit integer models.
- Add --train-embedder-rank for fine-tuning any encoder(-decoder) model for multilingual similarity via softmax-margin loss
- Add --logical-epoch, which allows redefining the displayed epoch counter as a multiple of n data epochs, updates or labels. A second argument can define the width of the fractional part.
- Add --metrics chrf for computing ChrF according to https://www.aclweb.org/anthology/W15-3049/ and the SacreBLEU reference implementation
@@ -56,6 +59,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
- Fix runtime failures for FASTOPT on 32-bit builds (wasm just happens to be 32-bit), caused by hashing with an inconsistent mix of uint64_t and size_t.

### Changed
- Removed `--clip-gemm`, which was obsolete and never used anyway
- Removed the `--optimize` switch; the compute type is now determined from the binary model.
- Updated SentencePiece repository to version 8336bbd0c1cfba02a879afe625bf1ddaf7cd93c5 from https://github.com/google/sentencepiece.
- Enabled compilation of SentencePiece by default since no dependency on protobuf anymore.
- Changed default value of --sentencepiece-max-lines from 10000000 to 2000000, since the new version apparently no longer samples automatically (it is not quite clear how that affects the quality of the vocabulary).
5 changes: 3 additions & 2 deletions CMakeLists.txt
@@ -93,8 +93,8 @@ if(MSVC)
# Or maybe use these?
set(INTRINSICS "/arch:AVX2")
# set(INTRINSICS "/arch:AVX512")

set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS ${DISABLE_GLOBALLY}")
# /bigobj is necessary for expression_operators.cpp. See https://stackoverflow.com/questions/15110580/penalty-of-the-msvs-compiler-flag-bigobj
set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS /bigobj ${DISABLE_GLOBALLY}")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} /MT /O2 ${INTRINSICS} /Zi /MP /GL /DNDEBUG")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")

@@ -438,6 +438,7 @@ endif(USE_MPI)
###############################################################################
# Find BLAS library
if(COMPILE_CPU)
set(EXT_LIBS ${EXT_LIBS} intgemm) # Enable intgemm when compiling CPU
if(USE_APPLE_ACCELERATE)
if(NOT APPLE)
message(FATAL_ERROR "FATAL ERROR: Apple Accelerate only works on macOS.")
2 changes: 1 addition & 1 deletion regression-tests
Submodule regression-tests updated 65 files
+4 −0 README.md
+4 −0 tests/decoder/intgemm/.gitignore
+2 −0 tests/decoder/intgemm/README.md
+55 −0 tests/decoder/intgemm/_test_fbgemm_packed8_intgemm_int8.sh
+56 −0 tests/decoder/intgemm/_test_fbgemm_packed8_intgemm_int8_shifted.sh
+53 −0 tests/decoder/intgemm/_test_intgemm_16bit.sh
+53 −0 tests/decoder/intgemm/_test_intgemm_8bit.sh
+57 −0 tests/decoder/intgemm/_test_intgemm_8bit_shifted.sh
+100 −0 tests/decoder/intgemm/intgemm_16bit.avx.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit.avx.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_16bit.avx2.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit.avx2.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_16bit.avx512.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit.avx512.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_16bit_avx2.avx2.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit_avx2.avx2.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_16bit_avx2.avx512.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit_avx2.avx512.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_16bit_sse2.avx.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit_sse2.avx.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_16bit_sse2.avx2.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit_sse2.avx2.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_16bit_sse2.avx512.expected
+1 −0 tests/decoder/intgemm/intgemm_16bit_sse2.avx512.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit.avx.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit.avx.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit.avx2.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit.avx2.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit.avx512.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit.avx512.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit_avx2.avx2.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit_avx2.avx2.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit_avx2.avx512.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit_avx2.avx512.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit_ssse3.avx.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit_ssse3.avx.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit_ssse3.avx2.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit_ssse3.avx2.expected.bleu
+100 −0 tests/decoder/intgemm/intgemm_8bit_ssse3.avx512.expected
+1 −0 tests/decoder/intgemm/intgemm_8bit_ssse3.avx512.expected.bleu
+14 −0 tests/decoder/intgemm/setup.sh
+57 −0 tests/decoder/intgemm/test_intgemm_16bit.sh
+56 −0 tests/decoder/intgemm/test_intgemm_16bit_avx2.sh
+58 −0 tests/decoder/intgemm/test_intgemm_16bit_sse2.sh
+57 −0 tests/decoder/intgemm/test_intgemm_8bit.sh
+56 −0 tests/decoder/intgemm/test_intgemm_8bit_avx2.sh
+58 −0 tests/decoder/intgemm/test_intgemm_8bit_ssse3.sh
+29 −0 tests/decoder/intgemm/update_expected_outputs.sh
+3 −2 tests/models/wnmt18/.gitignore
+1 −1 tests/models/wnmt18/optimize_aan.bleu.expected
+41 −0 tests/models/wnmt18/test_student_small_aan_intgemm16.sh
+15 −8 tests/models/wnmt18/test_student_small_aan_intgemm8.sh
+1 −0 tests/training/features/quantized-model/.gitignore
+51 −0 tests/training/features/quantized-model/model_centers.expected
+10 −0 tests/training/features/quantized-model/quantized-log4bit.expected
+10 −0 tests/training/features/quantized-model/quantized-opt.expected
+10 −0 tests/training/features/quantized-model/quantized-with-bias.expected
+10 −0 tests/training/features/quantized-model/quantized.expected
+3 −0 tests/training/features/quantized-model/setup.sh
+31 −0 tests/training/features/quantized-model/test_quant_centers.sh
+36 −0 tests/training/features/quantized-model/test_quantmodel.sh
+36 −0 tests/training/features/quantized-model/test_quantmodel_log.sh
+43 −0 tests/training/features/quantized-model/test_quantmodel_with_bias.sh
+36 −0 tests/training/features/quantized-model/test_quantmodel_with_optimization.sh
+59 −0 tools/check-model-unique-vals.py
5 changes: 5 additions & 0 deletions src/3rd_party/CMakeLists.txt
@@ -8,6 +8,11 @@ add_subdirectory(./zlib)
add_subdirectory(./faiss)
include_directories(./faiss)

if(COMPILE_CPU)
set(INTGEMM_DONT_BUILD_TESTS ON CACHE BOOL "Disable intgemm tests")
add_subdirectory(./intgemm)
endif(COMPILE_CPU)

if(USE_FBGEMM)
# @TODO: find out if this is somehow harmful. This is suppressing CMake warnings for CMAKE_SUPPRESS_DEVELOPER_WARNINGS
# meant to silence CMakeFiles of 3rd_party tools.
1 change: 1 addition & 0 deletions src/3rd_party/intgemm
Submodule intgemm added at 874cee
8 changes: 4 additions & 4 deletions src/CMakeLists.txt
@@ -5,6 +5,8 @@ include_directories(3rd_party)
include_directories(3rd_party/SQLiteCpp/include)
include_directories(3rd_party/sentencepiece)
include_directories(3rd_party/fbgemm/include)
include_directories(3rd_party/intgemm)
include_directories(${CMAKE_BINARY_DIR}/src/3rd_party/intgemm) # running cmake on the intgemm submodule triggers config file generation in this directory.
include_directories(${CMAKE_BINARY_DIR}/local/include)

set(MARIAN_SOURCES
@@ -41,6 +43,7 @@ set(MARIAN_SOURCES

3rd_party/cnpy/cnpy.cpp
3rd_party/ExceptionWithCallStack.cpp

3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.cpp

3rd_party/phf/phf.cc
@@ -52,10 +55,7 @@ set(MARIAN_SOURCES
tensors/cpu/prod.cpp
tensors/cpu/topk.cpp
tensors/cpu/tensor_operators.cpp

tensors/cpu/sharp/int_gemm.cpp
tensors/cpu/sharp/avx_gemm.cpp
tensors/cpu/sharp/sse_gemm.cpp
tensors/cpu/integer_common.cpp
tensors/cpu/fbgemm/packed_gemm.cpp

graph/expression_graph.cpp
28 changes: 8 additions & 20 deletions src/command/marian_conv.cpp
@@ -1,12 +1,10 @@
#include "marian.h"

#include "common/cli_wrapper.h"
#include "tensors/cpu/expression_graph_packable.h"
#include "onnx/expression_graph_onnx_exporter.h"

#include <sstream>

#include "tensors/cpu/fbgemm/expression_graph_packable.h"
#include "onnx/expression_graph_onnx_exporter.h"

int main(int argc, char** argv) {
using namespace marian;

@@ -24,7 +22,9 @@ int main(int argc, char** argv) {
cli->add<std::string>("--from,-f", "Input model", "model.npz");
cli->add<std::string>("--to,-t", "Output model", "model.bin");
cli->add<std::string>("--export-as", "Kind of conversion: marian-bin or onnx-{encode,decoder-step,decoder-init,decoder-stop}", "marian-bin");
cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512", "float32");
cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512, "
"intgemm8, intgemm8ssse3, intgemm8avx2, intgemm8avx512, intgemm16, intgemm16sse2, intgemm16avx2, intgemm16avx512",
"float32");
cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export");
cli->parse(argc, argv);
options->merge(config);
@@ -35,19 +35,8 @@ int main(int argc, char** argv) {
auto exportAs = options->get<std::string>("export-as");
auto vocabPaths = options->get<std::vector<std::string>>("vocabs");// , std::vector<std::string>());

auto saveGemmTypeStr = options->get<std::string>("gemm-type", "float32");
Type saveGemmType;
if(saveGemmTypeStr == "float32") {
saveGemmType = Type::float32;
} else if(saveGemmTypeStr == "packed16") { // packed16 only supports AVX2. AVX512 might be added later
saveGemmType = Type::packed16;
} else if(saveGemmTypeStr == "packed8avx2") { // packed8 for AVX2
saveGemmType = Type::packed8avx2;
} else if(saveGemmTypeStr == "packed8avx512") { // packed8 for AVX512
saveGemmType = Type::packed8avx512;
} else {
ABORT("Unknown gemm-type: {}", saveGemmTypeStr);
}
// We accept any type here and will later croak during packAndSave if the type cannot be used for conversion
Type saveGemmType = typeFromString(options->get<std::string>("gemm-type", "float32"));

LOG(info, "Outputting {}, precision: {}", modelTo, saveGemmType);

@@ -58,12 +47,11 @@

auto load = [&](Ptr<ExpressionGraph> graph) {
graph->setDevice(CPU0);
graph->getBackend()->setOptimized(false);

graph->load(modelFrom);
graph->forward(); // run the initializers
};


if (exportAs == "marian-bin") {
auto graph = New<ExpressionGraphPackable>();
load(graph);
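Condensed from the hunks above, the simplified conversion flow in marian-conv now reads roughly as follows. This is a hedged sketch: the `packAndSave` call signature is an assumption based on the comment in the diff, not shown here.

```cpp
// Sketch of the new marian-conv flow (error handling omitted).
auto graph = New<ExpressionGraphPackable>();
graph->setDevice(CPU0);
graph->load(modelFrom);
graph->forward(); // run the initializers

// Any known type string is accepted here; conversion aborts later in
// packAndSave if the type cannot be used.
Type saveGemmType = typeFromString(options->get<std::string>("gemm-type", "float32"));
graph->packAndSave(modelTo, configStr, saveGemmType);
```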
21 changes: 20 additions & 1 deletion src/common/binary.cpp
@@ -3,6 +3,7 @@
#include "common/file_stream.h"
#include "common/io_item.h"
#include "common/types.h"
#include "tensors/cpu/integer_common.h"

#include <string>

@@ -57,13 +58,31 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
get<char>(current, offset);

for(int i = 0; i < numHeaders; ++i) {
// For intgemm, AVX512 and AVX512VNNI have the same arrangement, but the VNNI algorithm is faster.
// Change the type to the fastest one supported.
if (items[i].type == Type::intgemm8avx512) {
items[i].type = cpu::integer::getIntgemmType(Type::intgemm8);
}
if(items[i].mapped) { // memory-mapped, hence only set pointer
// @TODO: verify this actually works for the hardware-specific ones like intgemm8avx2
ABORT_IF(items[i].type == Type::intgemm8 || items[i].type == Type::intgemm16, "mmap format not supported for hardware non-specific intgemm matrices");
items[i].ptr = get<char>(current, headers[i].dataLength);
} else { // reading into item data
size_t len = headers[i].dataLength;
items[i].bytes.resize(len);
const char* ptr = get<char>(current, len);
std::copy(ptr, ptr + len, items[i].bytes.begin());
// Intgemm8/16 matrices in the binary model are only quantized; they also need to be reordered.
// Reordering depends on the architecture (SSE/AVX2/AVX512) so we read in the quantized matrices and
// then reorder them before adding them as a parameter in the graph.
if (matchType<intgemm8>(items[i].type)) {
items[i].type = cpu::integer::getIntgemmType(Type::intgemm8);
cpu::integer::prepareAndTransposeB<Type::intgemm8>(items[i], ptr);
} else if (matchType<intgemm16>(items[i].type)) {
items[i].type = cpu::integer::getIntgemmType(Type::intgemm16);
cpu::integer::prepareAndTransposeB<Type::intgemm16>(items[i], ptr);
} else {
std::copy(ptr, ptr + len, items[i].bytes.begin());
}
}
}
}
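The `cpu::integer::getIntgemmType` helper used above promotes a hardware-agnostic type to the fastest variant the running CPU supports. A hedged sketch of that dispatch follows; `intgemm::kCPU` and `CPUType` come from intgemm's headers, but the exact tier-to-type mapping here is an assumption, not copied from Marian's implementation in tensors/cpu/integer_common.h.

```cpp
// Sketch only: promote Type::intgemm8/16 to a hardware-specific variant.
static Type getIntgemmTypeSketch(Type vanillaType) {
  using intgemm::CPUType;
  if(vanillaType == Type::intgemm8) {
    if(intgemm::kCPU >= CPUType::AVX512VNNI) return Type::intgemm8avx512vnni;
    if(intgemm::kCPU >= CPUType::AVX512BW)   return Type::intgemm8avx512;
    if(intgemm::kCPU >= CPUType::AVX2)       return Type::intgemm8avx2;
    return Type::intgemm8ssse3; // SSSE3 is the 8-bit baseline
  } else { // Type::intgemm16
    if(intgemm::kCPU >= CPUType::AVX512BW)   return Type::intgemm16avx512;
    if(intgemm::kCPU >= CPUType::AVX2)       return Type::intgemm16avx2;
    return Type::intgemm16sse2; // SSE2 is the 16-bit baseline
  }
}
```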
15 changes: 3 additions & 12 deletions src/common/config_parser.cpp
@@ -134,8 +134,6 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
"Suppress logging for translation");
cli.add<size_t>("--seed",
"Seed for all random number generators. 0 means initialize randomly");
cli.add<float>("--clip-gemm",
"If not 0 clip GEMM input values to +/- arg");
cli.add<bool>("--interpolate-env-vars",
"allow the use of environment variables in paths, of the form ${VAR_NAME}");
cli.add<bool>("--relative-paths",
@@ -671,15 +669,13 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
addSuboptionsDevices(cli);
addSuboptionsBatching(cli);

cli.add<bool>("--optimize",
"Optimize speed aggressively sacrificing memory or precision");
cli.add<bool>("--skip-cost",
"Ignore model cost during translation, not recommended for beam-size > 1");
cli.add<bool>("--fp16",
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
cli.add<std::vector<std::string>>("--precision",
"Mixed precision for inference, set parameter type in expression graph",
{"float32"});
cli.add<bool>("--skip-cost",
"Ignore model cost during translation, not recommended for beam-size > 1");

cli.add<std::vector<std::string>>("--shortlist",
"Use softmax shortlist: path first best prune");
@@ -737,8 +733,6 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
addSuboptionsDevices(cli);
addSuboptionsBatching(cli);

cli.add<bool>("--optimize",
"Optimize speed aggressively sacrificing memory or precision");
cli.add<bool>("--fp16",
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
cli.add<std::vector<std::string>>("--precision",
@@ -776,12 +770,10 @@ void ConfigParser::addOptionsEmbedding(cli::CLIWrapper& cli) {
addSuboptionsDevices(cli);
addSuboptionsBatching(cli);

cli.add<bool>("--optimize",
"Optimize speed aggressively sacrificing memory or precision");
cli.add<bool>("--fp16",
"Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
cli.add<std::vector<std::string>>("--precision",
"Mixed precision for inference, set parameter type in expression graph",
"Mixed precision for inference, set parameter type in expression graph. Supported values: float32, float16",
{"float32"});

cli.switchGroup(previous_group);
@@ -934,7 +926,6 @@ void ConfigParser::addSuboptionsQuantization(cli::CLIWrapper& cli) {
// clang-format on
}


cli::mode ConfigParser::getMode() const { return mode_; }

Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
11 changes: 7 additions & 4 deletions src/common/types.cpp
@@ -26,13 +26,16 @@ size_t requiredBytes(const Shape& shape, Type type) {
ABORT("Not a supported data type: {}", type);
return 0;
}
}
#endif // USE_FBGEMM

if (isIntgemm(type)) {
/* Intgemm tensors have an extra float at the back that stores the quantization multiplier */
return shape.elements() * sizeOf(type) + sizeOf(Type::float32);
} else {
return shape.elements() * sizeOf(type);
}
#else
return shape.elements() * sizeOf(type);
#endif // USE_FBGEMM

}

}
} // namespace marian
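As a concrete check of the size rule above: an intgemm tensor stores one quantized value per element plus one trailing float holding the quantization multiplier. A toy verification, with an assumed shape of {256, 256} and 8-bit quantization chosen purely for illustration:

```cpp
#include <cassert>
#include <cstdint>
#include <cstddef>

int main() {
  // Stand-ins for shape.elements() and sizeOf(); the real helpers live in
  // src/common/shape.h and src/common/types.h.
  const std::size_t elements = 256 * 256;                      // shape {256, 256}
  const std::size_t bytes = elements * sizeof(std::int8_t)     // quantized values
                          + sizeof(float);                     // trailing multiplier
  assert(bytes == 65540);
  return 0;
}
```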