Skip to content

Commit

Permalink
Kernel API simplifications and SVMs improvements (#80)
Browse files Browse the repository at this point in the history
* moved kernel transform inside fit for svms

* removed dataset from kernel, moved w_sum in svm

* svm stores only support vectors

* fixed alpha calculation for linear kernel regression

* keep only kernel method in svm

* adapted other crates to new kernel definition

* addresses pr comments
  • Loading branch information
Sauro98 authored Feb 7, 2021
1 parent 7198962 commit 1a7982b
Show file tree
Hide file tree
Showing 10 changed files with 498 additions and 401 deletions.
13 changes: 5 additions & 8 deletions linfa-hierarchical/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ impl<F: Float> HierarchicalCluster<F> {
}
}

impl<'b: 'a, 'a, F: Float> Transformer<Kernel<'a, F>, DatasetBase<Kernel<'a, F>, Vec<usize>>>
impl<F: Float> Transformer<Kernel<F>, DatasetBase<Kernel<F>, Vec<usize>>>
for HierarchicalCluster<F>
{
/// Perform hierarchical clustering of a similarity matrix
///
/// Returns the class id for each data point
fn transform(&self, kernel: Kernel<'a, F>) -> DatasetBase<Kernel<'a, F>, Vec<usize>> {
fn transform(&self, kernel: Kernel<F>) -> DatasetBase<Kernel<F>, Vec<usize>> {
// ignore all similarities below this value
let threshold = F::from(1e-6).unwrap();

Expand Down Expand Up @@ -129,17 +129,14 @@ impl<'b: 'a, 'a, F: Float> Transformer<Kernel<'a, F>, DatasetBase<Kernel<'a, F>,
}
}

impl<'a, F: Float, T: Targets>
Transformer<DatasetBase<Kernel<'a, F>, T>, DatasetBase<Kernel<'a, F>, Vec<usize>>>
impl<F: Float, T: Targets>
Transformer<DatasetBase<Kernel<F>, T>, DatasetBase<Kernel<F>, Vec<usize>>>
for HierarchicalCluster<F>
{
/// Perform hierarchical clustering of a similarity matrix
///
/// Returns the class id for each data point
fn transform(
&self,
dataset: DatasetBase<Kernel<'a, F>, T>,
) -> DatasetBase<Kernel<'a, F>, Vec<usize>> {
fn transform(&self, dataset: DatasetBase<Kernel<F>, T>) -> DatasetBase<Kernel<F>, Vec<usize>> {
//let Dataset { records, .. } = dataset;
self.transform(dataset.records)
}
Expand Down
4 changes: 4 additions & 0 deletions linfa-kernel/src/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use serde_crate::{Deserialize, Serialize};
use sprs::{CsMat, CsMatView};
use std::ops::Mul;

/// Specifies the methods an inner matrix of a kernel must
/// be able to provide
pub trait Inner {
type Elem: Float;

Expand All @@ -18,6 +20,8 @@ pub trait Inner {
fn diagonal(&self) -> Array1<Self::Elem>;
}

/// Allows a kernel to have either a dense or a sparse inner
/// matrix in a way that is transparent to the user
pub enum KernelInner<K1: Inner, K2: Inner> {
Dense(K1),
Sparse(K2),
Expand Down
Loading

0 comments on commit 1a7982b

Please sign in to comment.