Skip to content

Commit

Permalink
feat: add partial_fit method to SGDRegressor
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Aug 25, 2023
1 parent 946c0a3 commit d3f448e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
19 changes: 19 additions & 0 deletions rumale-linear_model/lib/rumale/linear_model/sgd_regressor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,25 @@ def fit(x, y)
self
end

# Perform 1-epoch of stochastic gradient descent optimization with given training data.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::DFloat] (shape: [n_samples]) The single target variables to be used for fitting the model.
# @return [SGDRegressor] The learned regressor itself.
def partial_fit(x, y)
  x = Rumale::Validation.check_convert_sample_array(x)
  y = Rumale::Validation.check_convert_target_value_array(y)
  Rumale::Validation.check_sample_size(x, y)

  # The internal weight carries an extra slot for the bias term when fit_bias is on.
  # (Re-)initialize it whenever no weight has been learned yet or its size no
  # longer matches the incoming feature dimensionality.
  expected_size = x.shape[1] + (fit_bias? ? 1 : 0)
  first_call = @weight.nil? || @weight.shape[0] != expected_size

  @weight_vec, @bias_term = partial_fit_(x, y, max_iter: 1, init: first_call)

  self
end

# Predict values for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the values.
Expand Down
59 changes: 59 additions & 0 deletions rumale-linear_model/spec/rumale/linear_model/sgd_regressor_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,68 @@
end
end

# Shared examples exercising SGDRegressor#partial_fit, which performs a single
# epoch of stochastic gradient descent per call.
# NOTE(review): depends on `let` helpers declared in the enclosing describe block
# (x, single_target, n_features, n_samples, fit_bias, n_jobs, predicted, score)
# that are outside this diff hunk — verify against the full spec file.
shared_examples 'partially fitted regression problems' do
  let(:y) { single_target }
  let(:estimator) do
    described_class.new(loss: loss, reg_param: 1, epsilon: 0.1, fit_bias: fit_bias, n_jobs: n_jobs, random_seed: 1)
  end

  it 'learns the linear model', :aggregate_failures do
    # A single call must initialize the weight vector with the expected shape.
    estimator.partial_fit(x, y)

    expect(estimator.weight_vec).to be_a(Numo::DFloat)
    expect(estimator.weight_vec).to be_contiguous
    expect(estimator.weight_vec.ndim).to eq(1)
    expect(estimator.weight_vec.shape[0]).to eq(n_features)
    # fit_bias is presumably false in this outer context, so no bias is learned — TODO confirm.
    expect(estimator.bias_term).to be_zero

    # Early in training another 999 epochs should still move the weights substantially...
    prev_weight = estimator.weight_vec
    999.times { estimator.partial_fit(x, y) }
    curr_weight = estimator.weight_vec
    expect((prev_weight - curr_weight).abs.max).to be > 1.0

    # ...but after 1000 epochs one more epoch should barely change them (near convergence).
    prev_weight = curr_weight
    estimator.partial_fit(x, y)
    curr_weight = estimator.weight_vec
    expect((prev_weight - curr_weight).abs.max).to be < 0.5

    expect(predicted).to be_a(Numo::DFloat)
    expect(predicted).to be_contiguous
    expect(predicted.ndim).to eq(1)
    expect(predicted.shape[0]).to eq(n_samples)
    expect(score).to be_within(0.01).of(1.0)
  end

  context 'when fit_bias parameter is true' do
    let(:fit_bias) { true }

    it 'learns the linear model with bias term', :aggregate_failures do
      estimator.partial_fit(x, y)

      # weight_vec excludes the bias term, so its size is still n_features.
      expect(estimator.weight_vec.ndim).to eq(1)
      expect(estimator.weight_vec.shape[0]).to eq(n_features)
      expect(estimator.bias_term).not_to be_zero

      # Same convergence pattern as above, observed on the bias term:
      # it moves early in training and settles after many epochs.
      prev_bias = estimator.bias_term
      999.times { estimator.partial_fit(x, y) }
      curr_bias = estimator.bias_term
      expect((prev_bias - curr_bias).abs).to be > 0.01

      prev_bias = curr_bias
      estimator.partial_fit(x, y)
      curr_bias = estimator.bias_term
      expect((prev_bias - curr_bias).abs).to be < 0.01

      expect(score).to be_within(0.01).of(1.0)
    end
  end
end

context 'when loss is "squared_error"' do
let(:loss) { 'squared_error' }

it_behaves_like 'regression problems'
it_behaves_like 'partially fitted regression problems'

it 'dumps and restores itself using Marshal module', :aggregate_failures do
expect(copied.class).to eq(estimator.class)
Expand All @@ -110,6 +168,7 @@
let(:loss) { 'epsilon_insensitive' }

it_behaves_like 'regression problems'
it_behaves_like 'partially fitted regression problems'

it 'dumps and restores itself using Marshal module', :aggregate_failures do
expect(copied.class).to eq(estimator.class)
Expand Down

0 comments on commit d3f448e

Please sign in to comment.