coreml : use Core ML encoder inference
ggerganov committed Mar 5, 2023
1 parent 72af0f5 commit b0ac915
Showing 9 changed files with 643 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,5 +1,7 @@
*.o
*.a
+*.mlmodel
+*.mlmodelc
.cache/
.vs/
.vscode/
68 changes: 59 additions & 9 deletions CMakeLists.txt
@@ -54,6 +54,8 @@ if (APPLE)
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)

+option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
else()
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
endif()
@@ -86,16 +88,33 @@ endif()

find_package(Threads REQUIRED)

-# on APPLE - include Accelerate framework
-if (APPLE AND NOT WHISPER_NO_ACCELERATE)
-    find_library(ACCELERATE_FRAMEWORK Accelerate)
-    if (ACCELERATE_FRAMEWORK)
-        message(STATUS "Accelerate framework found")
-        set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
-        set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
-    else()
-        message(WARNING "Accelerate framework not found")
-    endif()
-endif()
+# on APPLE
+if (APPLE)
+    # include Accelerate framework
+    if (NOT WHISPER_NO_ACCELERATE)
+        find_library(ACCELERATE_FRAMEWORK Accelerate)
+
+        if (ACCELERATE_FRAMEWORK)
+            message(STATUS "Accelerate framework found")
+
+            set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
+            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
+        else()
+            message(WARNING "Accelerate framework not found")
+        endif()
+    endif()
+
+    if (WHISPER_COREML)
+        find_library(FOUNDATION_FRAMEWORK Foundation)
+        find_library(COREML_FRAMEWORK CoreML)
+
+        if (COREML_FRAMEWORK)
+            message(STATUS "CoreML framework found")
+
+            set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
+        else()
+            message(WARNING "CoreML framework not found")
+        endif()
+    endif()
+endif()

@@ -181,6 +200,33 @@ if (WHISPER_PERF)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
endif()

+#
+# whisper.coreml - Core ML support
+#
+
+if (WHISPER_COREML)
+    set(TARGET whisper.coreml)
+
+    add_library(${TARGET}
+        coreml/whisper-encoder.h
+        coreml/whisper-encoder.mm
+        coreml/whisper-encoder-impl.h
+        coreml/whisper-encoder-impl.m
+        )
+
+    include(DefaultTargetOptions)
+
+    target_include_directories(${TARGET} PUBLIC
+        .
+        )
+
+    target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
+
+    set_target_properties(${TARGET} PROPERTIES
+        COMPILE_FLAGS "-fobjc-arc"
+        )
+endif()
+
#
# whisper - this is the main library of the project
#
@@ -200,6 +246,10 @@ target_include_directories(${TARGET} PUBLIC
.
)

+if (WHISPER_COREML)
+    target_link_libraries(${TARGET} PRIVATE whisper.coreml)
+endif()

if (MSVC)
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})

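Note: the new coreml/whisper-encoder.h and coreml/whisper-encoder.mm files added to the whisper.coreml target are among the changed files that are not expanded above. For orientation only, the glue between whisper.cpp and the Objective-C wrapper can be reduced to a small C-callable interface along the following lines; the names and signatures below are illustrative and are not taken from this commit.

// Illustrative sketch only — the actual coreml/whisper-encoder.h in the commit may differ.
// A plain C interface lets whisper.cpp call into the Core ML encoder without touching Objective-C.

#if __cplusplus
extern "C" {
#endif

// Opaque handle owning the loaded Core ML encoder instance (hypothetical name).
struct whisper_coreml_context;

// Load the compiled .mlmodelc from disk; returns NULL on failure (hypothetical signature).
struct whisper_coreml_context * whisper_coreml_init(const char * path_model);
void whisper_coreml_free(struct whisper_coreml_context * ctx);

// Run the encoder on one 80x3000 mel segment, writing the embeddings to out (hypothetical signature).
void whisper_coreml_encode(
        const struct whisper_coreml_context * ctx,
        float * mel,
        float * out);

#if __cplusplus
}
#endif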
44 changes: 30 additions & 14 deletions Makefile
@@ -132,6 +132,10 @@ ifndef WHISPER_NO_ACCELERATE
LDFLAGS += -framework Accelerate
endif
endif
+ifdef WHISPER_COREML
+CXXFLAGS += -DWHISPER_USE_COREML
+LDFLAGS += -framework Foundation -framework CoreML
+endif
ifdef WHISPER_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
LDFLAGS += -lopenblas
@@ -184,11 +188,23 @@ ggml.o: ggml.c ggml.h
whisper.o: whisper.cpp whisper.h
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o

-libwhisper.a: ggml.o whisper.o
-	$(AR) rcs libwhisper.a ggml.o whisper.o
+ifndef WHISPER_COREML
+WHISPER_OBJ = whisper.o
+else
+whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
+	$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
+
+whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
+	$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
+
+WHISPER_OBJ = whisper.o whisper-encoder.o whisper-encoder-impl.o
+endif
+
+libwhisper.a: ggml.o $(WHISPER_OBJ)
+	$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)

-libwhisper.so: ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
+libwhisper.so: ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)

clean:
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
@@ -202,21 +218,21 @@ CC_SDL=`sdl2-config --cflags --libs`
SRC_COMMON = examples/common.cpp
SRC_COMMON_SDL = examples/common-sdl.cpp

-main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
+main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
	./main -h

-stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
+stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)

-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)

-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
+talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)

-bench: examples/bench/bench.cpp ggml.o whisper.o
-	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
+bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)

#
# Audio samples
142 changes: 142 additions & 0 deletions coreml/whisper-encoder-impl.h
@@ -0,0 +1,142 @@
//
// CoremlEncoder.h
//
// This file was automatically generated and should not be edited.
//

#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#include <stdint.h>
#include <os/log.h>

NS_ASSUME_NONNULL_BEGIN


/// Model Prediction Input Type
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface CoremlEncoderInput : NSObject<MLFeatureProvider>

/// melSegment as 1 × 80 × 3000 3-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * melSegment;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithMelSegment:(MLMultiArray *)melSegment NS_DESIGNATED_INITIALIZER;

@end


/// Model Prediction Output Type
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface CoremlEncoderOutput : NSObject<MLFeatureProvider>

/// output as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * output;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;

@end


/// Class for model loading and prediction
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface CoremlEncoder : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;

/**
URL of the underlying .mlmodelc directory.
*/
+ (nullable NSURL *)URLOfModelInThisBundle;

/**
Initialize CoremlEncoder instance from an existing MLModel object.
Usually the application does not use this initializer unless it makes a subclass of CoremlEncoder.
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;

/**
Initialize CoremlEncoder instance with the model in this bundle.
*/
- (nullable instancetype)init;

/**
Initialize CoremlEncoder instance with the model in this bundle.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Initialize CoremlEncoder instance from the model URL.
@param modelURL URL to the .mlmodelc directory for CoremlEncoder.
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Initialize CoremlEncoder instance from the model URL.
@param modelURL URL to the .mlmodelc directory for CoremlEncoder.
@param configuration The model configuration object
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
*/
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Construct CoremlEncoder instance asynchronously with configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));

/**
Construct CoremlEncoder instance asynchronously with URL of .mlmodelc directory and optional configuration.
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
@param modelURL The model URL.
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));

/**
Make a prediction using the standard interface
@param input an instance of CoremlEncoderInput to predict from
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as CoremlEncoderOutput
*/
- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Make a prediction using the standard interface
@param input an instance of CoremlEncoderInput to predict from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as CoremlEncoderOutput
*/
- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Make a prediction using the convenience interface
@param melSegment as 1 × 80 × 3000 3-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as CoremlEncoderOutput
*/
- (nullable CoremlEncoderOutput *)predictionFromMelSegment:(MLMultiArray *)melSegment error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Batch prediction
@param inputArray array of CoremlEncoderInput instances to obtain predictions from
@param options prediction options
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the predictions as NSArray<CoremlEncoderOutput *>
*/
- (nullable NSArray<CoremlEncoderOutput *> *)predictionsFromInputs:(NSArray<CoremlEncoderInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
@end

NS_ASSUME_NONNULL_END
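
For context, a minimal Objective-C sketch of how the generated interface above can be driven. It is not part of this commit; the function name, the model path, and the all-zero input are illustrative only, while the CoremlEncoder and MLMultiArray calls are taken from the header as declared.

// Illustrative only — not part of this commit.
// Loads a compiled .mlmodelc from disk and runs one prediction on a 1x80x3000 mel segment.
#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>
#import "whisper-encoder-impl.h"

static void run_encoder_example(void) {
    NSError *error = nil;

    // Hypothetical path to the compiled Core ML model (see the new *.mlmodelc .gitignore entry).
    NSURL *modelURL = [NSURL fileURLWithPath:@"whisper-encoder.mlmodelc"];

    CoremlEncoder *encoder = [[CoremlEncoder alloc] initWithContentsOfURL:modelURL error:&error];
    if (encoder == nil) {
        NSLog(@"failed to load CoremlEncoder: %@", error);
        return;
    }

    // melSegment must be a 1 x 80 x 3000 array of floats, as documented above.
    // The array is zero-initialized here; a real caller would fill it with log-mel data.
    MLMultiArray *mel = [[MLMultiArray alloc] initWithShape:@[@1, @80, @3000]
                                                   dataType:MLMultiArrayDataTypeFloat32
                                                      error:&error];
    if (mel == nil) {
        NSLog(@"failed to allocate mel input: %@", error);
        return;
    }

    CoremlEncoderOutput *out = [encoder predictionFromMelSegment:mel error:&error];
    if (out == nil) {
        NSLog(@"prediction failed: %@", error);
        return;
    }

    NSLog(@"encoder output has %lu elements", (unsigned long)out.output.count);
}

In the commit itself, this kind of plumbing presumably lives in coreml/whisper-encoder.mm, which is compiled with -fobjc-arc via the CMake target properties and the Makefile rules shown earlier.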
