Polish the Job API and introduce Func.
- Func is a lighter version of Job and shares the same API (see the sketch below).

- Job#exec always returned JobState.DONE, so the result value was removed;
  the calling code already implicitly expected this.

- Corrected the rest of the code to be compatible with the API changes.
mmalohlava committed Apr 1, 2014
1 parent 60762b9 commit 41361e0
Showing 21 changed files with 78 additions and 159 deletions.
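
For orientation, here is a minimal, self-contained sketch of the refactored contract as it can be reconstructed from the diffs below. It is an illustration, not the actual H2O source: Request2 plumbing, keys and the DKV are replaced by trivial stand-ins, and the idea that the removed trash/cleanup machinery now lives in Func is an assumption inferred from "shares the same API", since Func itself is not shown in this diff.

// Sketch only: stand-in for the new water.Func base class (assumed shape).
abstract class Func {
  /** Validate inputs; throws IllegalArgumentException if they are unusable. */
  protected void init() throws IllegalArgumentException { }

  /** The actual work. Returns void now -- normal completion is implicit. */
  protected abstract void execImpl();

  /** Execute the implementation, always running cleanup() afterwards. */
  protected final void exec() {
    try { execImpl(); }
    finally { cleanup(); } // assumed new home of the temp-vector disposal
  }

  /** Run synchronously: validate, then execute. */
  public void invoke() { init(); exec(); }

  protected void cleanup() { }
}

// Sketch only: Job keeps its lifecycle but layers it over Func.
abstract class Job extends Func {
  @Override public void invoke() {
    init();
    start();  // stand-in for start(new H2OEmptyCompleter()) -- mark job started
    exec();   // exec always waits till the end of computation
    remove(); // job finished normally -- remove it
  }
  void start()  { System.out.println("job started"); }
  void remove() { System.out.println("job removed"); }
}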
3 changes: 1 addition & 2 deletions src/main/java/hex/GridSearch.java
@@ -20,7 +20,7 @@ public class GridSearch extends Job {
public GridSearch(){

}
-@Override protected JobState execImpl() {
+@Override protected void execImpl() {
UKV.put(destination_key, this);
int max = jobs[0].gridParallelism();
int head = 0, tail = 0;
@@ -35,7 +35,6 @@ public GridSearch(){
}
}
}
-return JobState.DONE;
}

@Override protected void onCancelled() {
6 changes: 3 additions & 3 deletions src/main/java/hex/KMeans2.java
@@ -46,7 +46,7 @@ public KMeans2() {
description = "K-means";
}

-@Override protected JobState execImpl() {
+@Override protected void execImpl() {
logStart();
source.read_lock(self());
String sourceArg = input("source");
@@ -111,7 +111,7 @@ public KMeans2() {
clusters = Utils.append(clusters, sampler._sampled);

if( !isRunning(self()) )
-return JobState.DONE;
+return;
model.centers = normalize ? denormalize(clusters, vecs) : clusters;
model.total_within_SS = sqr._sqr;
model.iterations++;
@@ -158,7 +158,7 @@ public KMeans2() {
}
model.unlock(self());
source.unlock(self());
-return JobState.DONE;
+return;
}

@Override protected Response redirect() {
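
The two return-site hunks above show the mechanical consequence of the signature change: early exits that used to return JobState.DONE become plain returns. A hedged sketch of the pattern, reusing the Job stand-in from the sketch near the top; isRunning() below is a stand-in for the real Job.isRunning(self()) check, and the loop body is invented:

// Sketch only: cancellation-aware execImpl() under the new void signature.
class KMeansLike extends Job {
  @Override protected void execImpl() {
    for( int iter = 0; iter < 10; iter++ ) {
      if( !isRunning() ) // was: if( !isRunning(self()) ) return JobState.DONE;
        return;          // now a plain return is enough
      // ... one clustering iteration would run here ...
    }
  }
  boolean isRunning() { return true; } // stand-in for the real cancellation check
}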
3 changes: 1 addition & 2 deletions src/main/java/hex/deeplearning/DeepLearning.java
@@ -398,7 +398,7 @@ public static String link(Key k, String content, Key cp, String response, Key va
* Train a Deep Learning model, assumes that all members are populated
* @return JobState
*/
-@Override public JobState execImpl() {
+@Override protected final void execImpl() {
DeepLearningModel cp;
if (checkpoint == null) cp = initModel();
else {
@@ -444,7 +444,6 @@ public static String link(Key k, String content, Key cp, String response, Key va
}
trainModel(cp);
delete();
-return JobState.DONE;
}

/**
4 changes: 2 additions & 2 deletions src/main/java/hex/deeplearning/DeepLearningModel.java
@@ -805,7 +805,7 @@ public double calcError(Frame ftest, Frame fpreds, Frame hitratio_fpreds, String
auc.predict = fpreds;
auc.vpredict = fpreds.vecs()[2]; //binary classifier (label, prob0, prob1 (THIS ONE), adaptedlabel)
auc.threshold_criterion = AUC.ThresholdCriterion.maximum_F1;
-auc.serve();
+auc.invoke();
auc.toASCII(sb);
error = auc.err(); //using optimal threshold for F1
}
@@ -816,7 +816,7 @@ public double calcError(Frame ftest, Frame fpreds, Frame hitratio_fpreds, String
cm.vactual = ftest.lastVec(); //original vector or adapted response (label) if CM adaptation was done
cm.predict = fpreds;
cm.vpredict = fpreds.vecs()[0]; //ditto
-cm.serve();
+cm.invoke();
cm.toASCII(sb);
error = isClassifier() ? new hex.ConfusionMatrix(cm.cm).err() : cm.mse;
}
18 changes: 8 additions & 10 deletions src/main/java/hex/drf/DRF.java
@@ -148,18 +148,16 @@ public static String link(Key k, String content) {

// ==========================================================================

-// Compute a DRF tree.
-
-// Start by splitting all the data according to some criteria (minimize
-// variance at the leaves). Record on each row which split it goes to, and
-// assign a split number to it (for next pass). On *this* pass, use the
-// split-number to build a per-split histogram, with a per-histogram-bucket
-// variance.
-
-@Override protected JobState execImpl() {
+/** Compute a DRF tree.
+*
+* Start by splitting all the data according to some criteria (minimize
+* variance at the leaves). Record on each row which split it goes to, and
+* assign a split number to it (for next pass). On *this* pass, use the
+* split-number to build a per-split histogram, with a per-histogram-bucket
+* variance. */
+@Override protected void execImpl() {
logStart();
buildModel();
-return JobState.DONE;
}

@Override protected Response redirect() {
3 changes: 1 addition & 2 deletions src/main/java/hex/gbm/GBM.java
@@ -119,10 +119,9 @@ public static String link(Key k, String content) {
return rs.toString();
}

-@Override protected JobState execImpl() {
+@Override protected void execImpl() {
logStart();
buildModel();
-return JobState.DONE;
}

@Override public int gridParallelism() {
3 changes: 1 addition & 2 deletions src/main/java/hex/nb/NaiveBayes.java
@@ -26,7 +26,7 @@ public class NaiveBayes extends ModelJob {
@API(help = "Laplace smoothing parameter", filter = Default.class, lmin = 0, lmax = 100000, json = true)
public int laplace = 0;

-@Override protected JobState execImpl() {
+@Override protected void execImpl() {
Frame fr = DataInfo.prepareFrame(source, response, ignored_cols, false, false);

// TODO: Temporarily reject data with missing entries until NA handling implemented
@@ -40,7 +40,6 @@ public class NaiveBayes extends ModelJob {
NBModel myModel = buildModel(dinfo, tsk, laplace);
myModel.delete_and_lock(self());
myModel.unlock(self());
-return JobState.DONE;
}

@Override protected void init() {
3 changes: 1 addition & 2 deletions src/main/java/hex/pca/PCA.java
@@ -55,7 +55,7 @@ public PCA(String desc, Key dest, Frame src, int max_pc, double tolerance, boole
this.standardize = standardize;
}

-@Override protected JobState execImpl() {
+@Override protected void execImpl() {
Frame fr = selectFrame(source);
Vec[] vecs = fr.vecs();

@@ -80,7 +80,6 @@ public PCA(String desc, Key dest, Frame src, int max_pc, double tolerance, boole
PCAModel myModel = buildModel(dinfo, tsk);
myModel.delete_and_lock(self());
myModel.unlock(self());
-return JobState.DONE;
}

@Override protected void init() {
3 changes: 1 addition & 2 deletions src/main/java/hex/pca/PCAImpute.java
@@ -21,10 +21,9 @@ public class PCAImpute extends FrameJob {
@API(help = "Scale columns by their standard deviations", filter = Default.class)
boolean scale = true;

-@Override protected JobState execImpl() {
+@Override protected void execImpl() {
Frame fr = source;
new Frame(destination_key,fr._names.clone(),fr.vecs().clone()).delete_and_lock(null).unlock(null);
-return JobState.DONE;
}

@Override protected void init() {
3 changes: 1 addition & 2 deletions src/main/java/hex/pca/PCAScore.java
@@ -31,7 +31,7 @@ public class PCAScore extends FrameJob {
@API(help = "Number of principal components to return", filter = Default.class, lmin = 1, lmax = 10000)
int num_pc = 1;

-@Override protected JobState execImpl() {
+@Override protected void execImpl() {
// Note: Source data MUST contain all features (matched by name) used to build PCA model!
// If additional columns exist in source, they are automatically ignored in scoring
new Frame(destination_key, new String[0], new Vec[0]).delete_and_lock(self());
@@ -47,7 +47,6 @@ public class PCAScore extends FrameJob {
domains[i] = null;
}
tsk.outputFrame(destination_key, names, domains).unlock(self());
-return JobState.DONE;
}

@Override protected void init() {
90 changes: 15 additions & 75 deletions src/main/java/water/Job.java
@@ -24,7 +24,7 @@
import java.util.HashMap;
import java.util.HashSet;

-public abstract class Job extends Request2 {
+public abstract class Job extends Func {
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

@@ -88,49 +88,6 @@ public int gridParallelism() {
return 1;
}

-/** A set containing a temporary vectors which are <strong>automatically</strong> deleted when job is done.
-* Deletion is by {@link #cleanup()} call. */
-private transient HashSet<Key> _gVecTrash = new HashSet<Key>();
-/** Local trash which can be deleted by user call */
-private transient HashSet<Key> _lVecTrash = new HashSet<Key>();
-/** Clean-up code which is executed after each {@link Job#exec()} call in any case (normal/exceptional). */
-protected void cleanup() {
-// Clean-up global list of temporary vectors
-Futures fs = new Futures();
-cleanupTrash(_gVecTrash, fs);
-if (!_lVecTrash.isEmpty()) cleanupTrash(_lVecTrash, fs);
-fs.blockForPending();
-}
-/** User call which empty local trash of vectors. */
-protected final void emptyLTrash() {
-if (_lVecTrash.isEmpty()) return;
-Futures fs = new Futures();
-cleanupTrash(_lVecTrash, fs);
-fs.blockForPending();
-}
-/** Append all vectors from given frame to a global clean up list.
-* @see #cleanup()
-* @see #_gVecTrash */
-protected final void gtrash(Frame fr) { gtrash(fr.vecs()); }
-/** Append given vector to clean up list.
-* @see #cleanup()*/
-protected final void gtrash(Vec ...vec) { appendToTrash(_gVecTrash, vec); }
-/** Put given frame vectors into local trash which can be emptied by a user calling the {@link #emptyLTrash()} method.
-* @see #emptyLTrash() */
-protected final void ltrash(Frame fr) { ltrash(fr.vecs()); }
-/** Put given vectors into local trash.
-* * @see #emptyLTrash() */
-protected final void ltrash(Vec ...vec) { appendToTrash(_lVecTrash, vec); }
-
-/** Put given vectors into a given trash. */
-private void appendToTrash(HashSet<Key> t, Vec[] vec) {
-for (Vec v : vec) t.add(v._key);
-}
-/** Delete all vectors in given trash. */
-private void cleanupTrash(HashSet<Key> trash, Futures fs) {
-for (Key k : trash) UKV.remove(k, fs);
-}
-
protected Key defaultJobKey() {
// Pinned to this node (i.e., the node invoked computation), because it should be almost always updated locally
return Key.make((byte) 0, Key.JOB, H2O.SELF);
@@ -367,33 +324,31 @@ public Job fork() {
init();
H2OCountedCompleter task = new H2OCountedCompleter() {
@Override public void compute2() {
-Throwable t = null;
try {
-JobState status = Job.this.exec();
-if(status == JobState.DONE)
+try {
+// Exec always waits till the end of computation
+exec();
Job.this.remove();
-} catch (Throwable t_) {
-t = t_;
-if(!(t instanceof ExpectedExceptionForDebug))
-Log.err(t);
+} catch (Throwable t) {
+if(!(t instanceof ExpectedExceptionForDebug))
+Log.err(t);
+Job.this.cancel(t);
+}
} finally {
-tryComplete();
+tryComplete();
}
-if(t != null)
-Job.this.cancel(t);
}
};
start(task);
H2O.submitTask(task);
return this;
}

-public void invoke() {
+@Override public void invoke() {
init();
-start(new H2OEmptyCompleter());
-JobState status = exec();
-if(status == JobState.DONE)
-remove();
+start(new H2OEmptyCompleter()); // mark job started
+exec(); // execute the implementation
+remove(); // remove the job
}

/**
@@ -403,25 +358,10 @@ public void invoke() {
* @throws IllegalArgumentException throws the exception if initialization fails to ensure
* correct job runtime environment.
*/
-protected void init() throws IllegalArgumentException {
+@Override protected void init() throws IllegalArgumentException {
if (destination_key == null) destination_key = defaultDestKey();
}

-/**
-* Actual job code.
-*
-* @return true if job is done, false if it will still be running after the method returns.
-*/
-private final JobState exec() {
-try {
-return execImpl(); // Execute job
-} finally {
-cleanup(); // Perform job cleanup
-}
-}
-
-protected JobState execImpl() { throw new RuntimeException("Job does not support exec call! Please implement execImpl method!"); };
-
/**
* Block synchronously waiting for a job to end, success or not.
* @param jobkey Job to wait for.
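
Taken together, the Job changes leave two entry points over the same core: invoke() runs init(), start(...), exec() and remove() synchronously, while fork() runs the same exec() inside an H2OCountedCompleter, removing the job on success and cancelling it on any throwable. A hedged usage sketch against the stand-ins from the top of this page; WordCountJob is invented for illustration:

// Sketch only: driving a Job through the polished API.
class WordCountJob extends Job {
  @Override protected void execImpl() {
    System.out.println("counting words..."); // the work; nothing to return
  }
}

class Caller {
  public static void main(String[] args) {
    new WordCountJob().invoke(); // init() -> start() -> exec() -> remove()
    // The real fork() (not modeled by the stand-in) schedules the same exec()
    // asynchronously and reports failure via Job.cancel(Throwable).
  }
}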
16 changes: 7 additions & 9 deletions src/main/java/water/api/AUC.java
@@ -3,17 +3,16 @@
import static java.util.Arrays.sort;
import hex.ConfusionMatrix;
import org.apache.commons.lang.StringEscapeUtils;
-import water.MRTask2;
-import water.Request2;
-import water.UKV;
+
+import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Utils;

import java.util.HashSet;

-public class AUC extends Request2 {
+public class AUC extends Func {
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.
public static final String DOC_GET = "AUC";
@@ -167,16 +166,18 @@ public AUC(hex.ConfusionMatrix[] cms, float[] thresh) {
computeMetrics();
}

-@Override public Response serve() {
-Vec va = null, vp;
+@Override protected void init() throws IllegalArgumentException {
+// Input handling
if( vactual==null || vpredict==null )
throw new IllegalArgumentException("Missing vactual or vpredict!");
if (vactual.length() != vpredict.length())
throw new IllegalArgumentException("Both arguments must have the same length!");
if (!vactual.isInt())
throw new IllegalArgumentException("Actual column must be integer class labels!");
+}
+
+@Override protected void execImpl() {
+Vec va = null, vp;
try {
va = vactual.toEnum(); // always returns TransfVec
actual_domain = va._domain;
@@ -215,9 +216,6 @@ public AUC(hex.ConfusionMatrix[] cms, float[] thresh) {
computeAUC();
findBestThresholds();
computeMetrics();
-return Response.done(this);
-} catch( Throwable t ) {
-return Response.error(t);
} finally { // Delete adaptation vectors
if (va!=null) UKV.remove(va._key);
}
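
AUC is the first non-Job adopter of Func in this commit: argument validation moves from serve() into init(), the computation into execImpl(), and the Response.done/Response.error wrapping disappears because invoke() simply lets exceptions propagate. Call sites switch from auc.serve() to auc.invoke(), as in DeepLearningModel above. A hedged sketch of the resulting shape, again using the Func stand-in; the Vec inputs are simplified to plain arrays:

// Sketch only: a Func-style request with validation split from computation.
class AucLike extends Func {
  double[] vactual, vpredict; // stand-ins for the real Vec inputs

  @Override protected void init() throws IllegalArgumentException {
    if( vactual == null || vpredict == null )
      throw new IllegalArgumentException("Missing vactual or vpredict!");
    if( vactual.length != vpredict.length )
      throw new IllegalArgumentException("Both arguments must have the same length!");
  }

  @Override protected void execImpl() {
    // compute thresholds, confusion matrices and the AUC here; adaptation
    // vectors are removed in a finally block, as in the hunk above
  }
}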
(Diffs for the remaining changed files were not loaded on this page.)