Skip to content

Commit

Permalink
Updating MLFlowClient.jl version and multiple code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
pebeto committed Aug 12, 2023
1 parent 70e961e commit 951fb8f
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 116 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"

[compat]
MLFlowClient = "0.4.3"
MLFlowClient = "0.4.4"
MLJ = "0.19"
MLJBase = "0.21.11"
OrderedCollections = "1.1"
Expand Down
18 changes: 3 additions & 15 deletions src/MLJFlow.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,19 @@
module MLJFlow

using MLJBase: info, name, Model,
Machine

Machine, deep_params, flat_params
using OrderedCollections: LittleDict
using MLFlowClient: MLFlow, logparam, logmetric,
createrun, MLFlowRun, updaterun,
logartifact, getorcreateexperiment

using OrderedCollections: LittleDict
healthcheck, logartifact, getorcreateexperiment

import MLJBase: save, log_evaluation

include("types.jl")
include("base.jl")
include("client.jl")
include("utilities.jl")

# base.jl
export log_evaluation, save

# client.jl
export runs

# types.jl
export MLFlowLogger

# utilities.jl
export flat_params

end
12 changes: 1 addition & 11 deletions src/client.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function logmodelparams(client::MLFlow, run::MLFlowRun, model::Model)
model_params = params(model) |> flat_params |> collect
model_params = deep_params(model) |> flat_params |> collect
for (name, value) in model_params
logparam(client, run, name, value)
end
Expand All @@ -12,13 +12,3 @@ function logmachinemeasures(client::MLFlow, run::MLFlowRun, measures,
logmetric(client, run, name, value)
end
end

"""
runs(logger::MLFlowLogger)
Return a list of runs for the experiment specified by
`logger.experiment_name`. The list is returned as a
`Vector{MLFlowRun}`.
"""
runs(logger::MLFlowLogger) = searchruns(logger.client,
getexperiment(logger.client, logger.experiment_name))
29 changes: 19 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@ A wrapper around a MLFlow client, with an experiment name and an artifact
location. This is the type passed to the `logger` keyword argument of
multiple methods in MLJBase.
# Fields
- `client`: an MLFlow client(see [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow))
- `experiment_name`: the name of the experiment. If not provided, a default
experiment with the name "MLJ experiment" will be created.
- `artifact_location`: the location of the artifact store. If not provided,
a default artifact location will be defined by MLFlow. For more information,
see [MLFlow documentation](https://www.mlflow.org/docs/latest/tracking.html).
To use this logger, you need to have a MLFlow server running. For more
information, see [MLFlow documentation](https://www.mlflow.org/docs/latest/quickstart.html).
If it is not running, an informative error will be thrown.
# Return value
A `MLFlowLogger` object, containing the client, the experiment name and the
artifact location.
Depending on the MLFlow server configuration, the `baseuri` can be a local
server or a remote server. The `experiment_name` is used to identify the
experiment in the MLFlow server. If the experiment does not exist, it will be
created with the default name "MLJ experiment". The `artifact_location` is
used to store the artifacts of the experiment. If not provided, a default
artifact location will be defined by MLFlow. For more information, see
[MLFlow documentation](https://www.mlflow.org/docs/latest/tracking.html).
This constructor returns a `MLFlowLogger` object, containing the experiment
name and the artifact location specified previously. Also it contains a
`MLFlow` client, which is used to communicate with the MLFlow server. For
more information, see [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow).
"""
struct MLFlowLogger
Expand All @@ -27,5 +32,9 @@ end
function MLFlowLogger(baseuri; experiment_name="MLJ experiment",
artifact_location=nothing)
client = MLFlow(baseuri)

if ~healthcheck(client)
error("It seems that the MLFlow server is not running. For more information, see https://mlflow.org/docs/latest/quickstart.html")
end
MLFlowLogger(client, experiment_name, artifact_location)
end
62 changes: 0 additions & 62 deletions src/utilities.jl

This file was deleted.

9 changes: 5 additions & 4 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
clf = ConstantClassifier()
clf_machine = machine(clf, X, y)
e1 = evaluate!(clf_machine, resampling=CV(),
measures=[LogLoss(), Accuracy()], verbosity=1)
measures=[LogLoss(), Accuracy()], verbosity=1, logger=logger)

@testset "log_evaluation" begin
run = log_evaluation(logger, e1)
@test typeof(run) == MLFlowRun
runs = searchruns(logger.client,
getexperiment(logger.client, logger.experiment_name))
@test typeof(runs[1]) == MLFlowRun
end

@testset "save" begin
run = save(logger, clf_machine)
run = MLJ.save(logger, clf_machine)
@test typeof(run) == MLFlowRun
@test listartifacts(logger.client, run) |> length == 1
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using MLJ
using Test
using MLFlowClient

using MLJ
using MLJFlow
using MLFlowClient

include("base.jl")
include("types.jl")
include("utilities.jl")
10 changes: 0 additions & 10 deletions test/utilities.jl

This file was deleted.

0 comments on commit 951fb8f

Please sign in to comment.