Skip to content

Commit

Permalink
Switch metadata functionality to not be fixed to DataVec RecordMetaDa…
Browse files Browse the repository at this point in the history
…ta in Evaluation class
  • Loading branch information
AlexDBlack committed Oct 22, 2016
1 parent 38c7fd8 commit d6fb420
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -652,14 +652,16 @@ public void testEvaluationWithMetaData() throws Exception {

List<Prediction> errors = e.getPredictionErrors(); //*** New - get list of prediction errors from evaluation ***
List<RecordMetaData> metaForErrors = new ArrayList<>();
for(Prediction p : errors) metaForErrors.add(p.getRecordMetaData());
for(Prediction p : errors){
metaForErrors.add((RecordMetaData)p.getRecordMetaData());
}
DataSet ds = rrdsi.loadFromMetaData(metaForErrors); //*** New - dynamically load a subset of the data, just for prediction errors ***
INDArray output = net.output(ds.getFeatures());

int count = 0;
for(Prediction t : errors){
System.out.println(t
+ "\t\tRaw Data: " + csv.loadFromMetaData(t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) ***
+ "\t\tRaw Data: " + csv.loadFromMetaData((RecordMetaData)t.getRecordMetaData()).getRecord() //*** New - load subset of data from MetaData object (usually batched for efficiency) ***
+ "\tNormalized: " + ds.getFeatureMatrix().getRow(count) + "\tLabels: " + ds.getLabels().getRow(count)
+ "\tNetwork predictions: " + output.getRow(count));
count++;
Expand Down
7 changes: 0 additions & 7 deletions deeplearning4j-nn/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,6 @@
</properties>

<dependencies>
<!-- TODO: Do we want this here? Used only in Evaluation metadata...-->
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${datavec.version}</version>
</dependency>

<!-- ND4J API -->
<dependency>
<groupId>org.nd4j</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

package org.deeplearning4j.eval;

import org.datavec.api.records.metadata.RecordMetaData;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.eval.meta.Prediction;
Expand Down Expand Up @@ -60,7 +59,7 @@ public class Evaluation implements Serializable {
//What to output from the precision/recall function when we encounter an edge case
protected static final double DEFAULT_EDGE_VALUE = 0.0;

protected Map<Pair<Integer,Integer>,List<RecordMetaData>> confusionMatrixMetaData; //Pair: (Actual,Predicted)
protected Map<Pair<Integer,Integer>,List<Object>> confusionMatrixMetaData; //Pair: (Actual,Predicted)

// Empty constructor
public Evaluation() {
Expand Down Expand Up @@ -189,7 +188,7 @@ public void eval(INDArray trueLabels, INDArray input, MultiLayerNetwork network)
* @param guesses the guesses/prediction (usually a probability vector)
*/
public void eval(INDArray realOutcomes, INDArray guesses) {
eval(realOutcomes, guesses, (List<RecordMetaData>)null);
eval(realOutcomes, guesses, (List<Object>)null);
}

/**
Expand All @@ -200,7 +199,7 @@ public void eval(INDArray realOutcomes, INDArray guesses) {
* @param recordMetaData Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
*
*/
public void eval(INDArray realOutcomes, INDArray guesses, List<RecordMetaData> recordMetaData ) {
public void eval(INDArray realOutcomes, INDArray guesses, List<?> recordMetaData ) {
// Add the number of rows to numRowCounter
numRowCounter += realOutcomes.shape()[0];

Expand Down Expand Up @@ -265,7 +264,7 @@ public void eval(INDArray realOutcomes, INDArray guesses, List<RecordMetaData> r
confusion.add(actual,predicted);

if(recordMetaData != null && recordMetaData.size() > i){
RecordMetaData m = recordMetaData.get(i);
Object m = recordMetaData.get(i);
addToMetaConfusionMatrix(actual,predicted,m);
}
}
Expand Down Expand Up @@ -993,13 +992,13 @@ public String confusionToString() {
}


private void addToMetaConfusionMatrix(int actual, int predicted, RecordMetaData metaData){
private void addToMetaConfusionMatrix(int actual, int predicted, Object metaData){
if(confusionMatrixMetaData == null){
confusionMatrixMetaData = new HashMap<>();
}

Pair<Integer,Integer> p = new Pair<>(actual,predicted);
List<RecordMetaData> list = confusionMatrixMetaData.get(p);
List<Object> list = confusionMatrixMetaData.get(p);
if(list == null){
list = new ArrayList<>();
confusionMatrixMetaData.put(p,list);
Expand All @@ -1023,10 +1022,10 @@ public List<Prediction> getPredictionErrors() {

List<Prediction> list = new ArrayList<>();

List<Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>>> sorted = new ArrayList<>(confusionMatrixMetaData.entrySet());
Collections.sort(sorted, new Comparator<Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>>>() {
List<Map.Entry<Pair<Integer, Integer>, List<Object>>> sorted = new ArrayList<>(confusionMatrixMetaData.entrySet());
Collections.sort(sorted, new Comparator<Map.Entry<Pair<Integer, Integer>, List<Object>>>() {
@Override
public int compare(Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>> o1, Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>> o2) {
public int compare(Map.Entry<Pair<Integer, Integer>, List<Object>> o1, Map.Entry<Pair<Integer, Integer>, List<Object>> o2) {
Pair<Integer, Integer> p1 = o1.getKey();
Pair<Integer, Integer> p2 = o2.getKey();
int order = Integer.compare(p1.getFirst(), p2.getFirst());
Expand All @@ -1036,13 +1035,13 @@ public int compare(Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>> o1, M
}
});

for (Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>> entry : sorted) {
for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : sorted) {
Pair<Integer, Integer> p = entry.getKey();
if (p.getFirst().equals(p.getSecond())) {
//predicted = actual -> not an error -> skip
continue;
}
for (RecordMetaData m : entry.getValue()) {
for (Object m : entry.getValue()) {
list.add(new Prediction(p.getFirst(), p.getSecond(), m));
}
}
Expand All @@ -1066,11 +1065,11 @@ public List<Prediction> getPredictionsByActualClass(int actualClass) {
if (confusionMatrixMetaData == null) return null;

List<Prediction> out = new ArrayList<>();
for (Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>> entry : confusionMatrixMetaData.entrySet()) { //Entry Pair: (Actual,Predicted)
for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : confusionMatrixMetaData.entrySet()) { //Entry Pair: (Actual,Predicted)
if (entry.getKey().getFirst() == actualClass) {
int actual = entry.getKey().getFirst();
int predicted = entry.getKey().getSecond();
for (RecordMetaData m : entry.getValue()) {
for (Object m : entry.getValue()) {
out.add(new Prediction(actual, predicted, m));
}
}
Expand All @@ -1094,11 +1093,11 @@ public List<Prediction> getPredictionByPredictedClass(int predictedClass) {
if (confusionMatrixMetaData == null) return null;

List<Prediction> out = new ArrayList<>();
for (Map.Entry<Pair<Integer, Integer>, List<RecordMetaData>> entry : confusionMatrixMetaData.entrySet()) { //Entry Pair: (Actual,Predicted)
for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : confusionMatrixMetaData.entrySet()) { //Entry Pair: (Actual,Predicted)
if (entry.getKey().getSecond() == predictedClass) {
int actual = entry.getKey().getFirst();
int predicted = entry.getKey().getSecond();
for (RecordMetaData m : entry.getValue()) {
for (Object m : entry.getValue()) {
out.add(new Prediction(actual, predicted, m));
}
}
Expand All @@ -1117,10 +1116,10 @@ public List<Prediction> getPredictions(int actualClass, int predictedClass) {
if (confusionMatrixMetaData == null) return null;

List<Prediction> out = new ArrayList<>();
List<RecordMetaData> list = confusionMatrixMetaData.get(new Pair<>(actualClass, predictedClass));
List<Object> list = confusionMatrixMetaData.get(new Pair<>(actualClass, predictedClass));
if (list == null) return out;

for (RecordMetaData meta : list) {
for (Object meta : list) {
out.add(new Prediction(actualClass, predictedClass, meta));
}
return out;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,34 @@

import lombok.AllArgsConstructor;
import lombok.Data;
import org.datavec.api.records.metadata.RecordMetaData;

/**
* Created by Alex on 22/09/2016.
* Prediction: a prediction for classification, used with the {@link org.deeplearning4j.eval.Evaluation} class.
* Holds predicted and actual classes, along with an object for the example/record that produced this evaluation.
*
* @author Alex Black
*/
@AllArgsConstructor @Data
@AllArgsConstructor
@Data
public class Prediction {

private int actualClass;
private int predictedClass;
private RecordMetaData recordMetaData;
private Object recordMetaData;

@Override
public String toString(){
return "Prediction(actualClass=" + actualClass + ",predictedClass=" + predictedClass + ",RecordMetaData=" + recordMetaData.getLocation() + ")";
public String toString() {
return "Prediction(actualClass=" + actualClass + ",predictedClass=" + predictedClass + ",RecordMetaData=" + recordMetaData + ")";
}

/**
* Convenience method for getting the record meta data as a particular class (as an alternative to casting it manually).
* NOTE: This uses an unchecked cast inernally.
*
* @param recordMetaDataClass Class of the record metadata
* @param <T> Type to return
*/
public <T> T getRecordMetaData(Class<T> recordMetaDataClass) {
return (T) recordMetaData;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


import lombok.Setter;
import org.datavec.api.records.metadata.RecordMetaData;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
Expand Down Expand Up @@ -2416,7 +2415,11 @@ public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, in
} else {
out = this.output(features,false);
if(labels.rank() == 3 ) e.evalTimeSeries(labels,out);
else e.eval(labels,out,next.getExampleMetaData(RecordMetaData.class));
else{
List<Serializable> meta = next.getExampleMetaData();
List<Object> meta2 = (meta == null ? null : new ArrayList<Object>(meta));
e.eval(labels,out,meta2);
}
}
}

Expand Down

0 comments on commit d6fb420

Please sign in to comment.