Skip to content

Commit

Permalink
simbert update
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmetaydin123 committed Jan 24, 2022
1 parent 1b65d56 commit a296f2a
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 45 deletions.
4 changes: 3 additions & 1 deletion src/main/java/edu/anadolu/datasets/DataSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import org.clueweb09.InfoNeed;
import org.clueweb09.tracks.Track;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
Expand Down Expand Up @@ -48,7 +50,7 @@ public Path home() {
}

public Path indexesPath() {
return Paths.get(tfd_home, collection.toString(), "indexes");
return Paths.get("/indexes/TFD-HOME", collection.toString(), "indexes");
}


Expand Down
24 changes: 19 additions & 5 deletions src/main/java/edu/anadolu/ltr/DocFeatureBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public class DocFeatureBase {
List<String> keyword;
List<String> description;
List<String> hTags;
float[] vectorlistContent;
float[] vectortitle;
float[] vectorkeyword;
float[] vectordescription;
float[] vectorhTags;
IndexSearcher searcher;
IndexReader reader;
Map<String,Integer> mapTf;
Expand Down Expand Up @@ -85,6 +90,13 @@ public class DocFeatureBase {
.map(e -> e.text())
.map(String::trim)
.filter(notEmpty).collect(Collectors.toList());
if(bert!=null){
vectorlistContent = bertVector(this.listContent);
vectortitle = bertVector(this.title);
vectorkeyword = bertVector(this.keyword);
vectordescription = bertVector(this.description);
vectorhTags = bertVector(this.hTags);
}
} catch (Exception exception) {
System.err.println("jdoc exception " + warcRecord.id());
exception.printStackTrace();
Expand Down Expand Up @@ -273,14 +285,16 @@ protected double cosSim(String str1, String str2){
return score;
}

protected double bertSim(String str1, String str2){
if(str1.length()==0 || str2.length()==0) return 0;
float[][] embeddings = bert.embedSequences(str1,str2);
protected float[] bertVector(List<String> str){
if(str.size()==0) return null;
return bert.embedSequence(String.join(" ",str));
}

protected double bertSim(float[] vectorA, float[] vectorB){
if(vectorA==null || vectorB==null) return 0;
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
float[] vectorA = embeddings[0];
float[] vectorB = embeddings[1];
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += Math.pow(vectorA[i], 2);
Expand Down
60 changes: 33 additions & 27 deletions src/main/java/edu/anadolu/ltr/SEOTool.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package edu.anadolu.ltr;

import com.robrua.nlp.bert.Bert;
import edu.anadolu.Indexer;
import edu.anadolu.analysis.Analyzers;
import edu.anadolu.analysis.Tag;
Expand Down Expand Up @@ -149,31 +150,35 @@ public void run(Properties props) throws Exception {

///////////////////////////// Index Reading for stats ///////////////////////////////////////////
Path indexPath=null;
if(this.tag == null)
indexPath = Files.newDirectoryStream(dataset.indexesPath(), Files::isDirectory).iterator().next();
else {
try (DirectoryStream<Path> stream = Files.newDirectoryStream(dataset.indexesPath(), Files::isDirectory)) {
for (Path path : stream) {
if(!tag.equals(path.getFileName().toString())) continue;
indexPath = path;
}
} catch (IOException e) {
e.printStackTrace();
}
}

if(indexPath == null)
throw new RuntimeException(tag + " index not found");


this.indexTag = indexPath.getFileName().toString();
this.analyzerTag = Tag.tag(indexTag);

this.reader = DirectoryReader.open(FSDirectory.open(indexPath));

IndexSearcher searcher = new IndexSearcher(reader);
CollectionStatistics collectionStatistics = searcher.collectionStatistics(Indexer.FIELD_CONTENTS);

// if(this.tag == null)
// indexPath = Files.newDirectoryStream(dataset.indexesPath(), Files::isDirectory).iterator().next();
// else {
// try (DirectoryStream<Path> stream = Files.newDirectoryStream(dataset.indexesPath(), Files::isDirectory)) {
// for (Path path : stream) {
// if(!tag.equals(path.getFileName().toString())) continue;
// indexPath = path;
// }
// } catch (IOException e) {
// e.printStackTrace();
// }
// }
//
// if(indexPath == null)
// throw new RuntimeException(tag + " index not found");
//
//
// this.indexTag = indexPath.getFileName().toString();
// this.analyzerTag = Tag.tag(indexTag);
this.analyzerTag = Tag.tag(tag);

// this.reader = DirectoryReader.open(FSDirectory.open(indexPath));

IndexSearcher searcher = null;
// IndexSearcher searcher = new IndexSearcher(reader);
CollectionStatistics collectionStatistics = null;
// CollectionStatistics collectionStatistics = searcher.collectionStatistics(Indexer.FIELD_CONTENTS);

Bert bert = null;


List<IDocFeature> features = new ArrayList<>();
Expand Down Expand Up @@ -272,6 +277,7 @@ public void run(Properties props) throws Exception {
features.add(new SimContentTitle(type));
features.add(new SimTitleH(type));
features.add(new SimTitleKeyword(type));
bert = Bert.load("com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12");
}
}
}
Expand All @@ -297,8 +303,8 @@ public void run(Properties props) throws Exception {
features.add(new CDD());
}

Traverser traverser = new Traverser(dataset, docsPath, docIdSet, features, collectionStatistics, analyzerTag, searcher, reader, resultsettype);
System.out.println("Average Doc Len = "+(double)collectionStatistics.sumTotalTermFreq()/collectionStatistics.docCount());
Traverser traverser = new Traverser(dataset, docsPath, docIdSet, features, collectionStatistics, analyzerTag, searcher, reader, resultsettype, bert);
// System.out.println("Average Doc Len = "+(double)collectionStatistics.sumTotalTermFreq()/collectionStatistics.docCount());

final int numThreads = props.containsKey("numThreads") ? Integer.parseInt(props.getProperty("numThreads")) : Runtime.getRuntime().availableProcessors();
System.out.println(numThreads + " threads are running.");
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimContentDescription.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.listContent, base.description);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.listContent),String.join(" ",base.description));
return base.bertSim(base.vectorlistContent,base.vectordescription);
return base.cosSim(String.join(" ",base.listContent),String.join(" ",base.description));
}
}
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimContentH.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.listContent, base.hTags);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.listContent),String.join(" ",base.hTags));
return base.bertSim(base.vectorlistContent,base.vectorhTags);
return base.cosSim(String.join(" ",base.listContent),String.join(" ",base.hTags));
}
}
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimContentKeyword.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.listContent, base.keyword);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.listContent),String.join(" ",base.keyword));
return base.bertSim(base.vectorlistContent,base.vectorkeyword);
return base.cosSim(String.join(" ",base.listContent),String.join(" ",base.keyword));
}
}
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimContentTitle.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.listContent, base.title);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.listContent),String.join(" ",base.title));
return base.bertSim(base.vectorlistContent,base.vectortitle);
return base.cosSim(String.join(" ",base.listContent),String.join(" ",base.title));
}
}
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimDescriptionH.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.description, base.hTags);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.description),String.join(" ",base.hTags));
return base.bertSim(base.vectordescription,base.vectorhTags);
return base.cosSim(String.join(" ",base.description),String.join(" ",base.hTags));
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimKeywordDescription.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.keyword, base.description);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.keyword),String.join(" ",base.description));
return base.bertSim(base.vectorkeyword,base.vectordescription);
return base.cosSim(String.join(" ",base.keyword),String.join(" ",base.description));
}
}
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimKeywordH.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.keyword, base.hTags);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.keyword),String.join(" ",base.hTags));
return base.bertSim(base.vectorkeyword,base.vectorhTags);
return base.cosSim(String.join(" ",base.keyword),String.join(" ",base.hTags));
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimTitleDescription.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.title, base.description);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.title),String.join(" ",base.description));
return base.bertSim(base.vectortitle,base.vectordescription);
return base.cosSim(String.join(" ",base.title),String.join(" ",base.description));
}
}
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimTitleH.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.title, base.hTags);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.title),String.join(" ",base.hTags));
return base.bertSim(base.vectortitle, base.vectorhTags);
return base.cosSim(String.join(" ",base.title),String.join(" ",base.hTags));
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/edu/anadolu/ltr/SimTitleKeyword.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public String toString() {
public double calculate(DocFeatureBase base) throws IOException, NullPointerException {
// return base.textSimilarity(base.title, base.keyword);
if("bert".equals(this.type))
return base.bertSim(String.join(" ",base.title),String.join(" ",base.keyword));
return base.bertSim(base.vectortitle,base.vectorkeyword);
return base.cosSim(String.join(" ",base.title),String.join(" ",base.keyword));
}
}
6 changes: 4 additions & 2 deletions src/main/java/edu/anadolu/ltr/Traverser.java
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ protected boolean skip(String docId) {
private IndexSearcher searcher;
private IndexReader reader;
private String resultsettype;
private Bert bert;

Traverser(DataSet dataset, String docsDir, Set<String> docIdSet, List<IDocFeature> featureList, CollectionStatistics collectionStatistics, Tag analyzerTag, IndexSearcher searcher, IndexReader reader, String resultsettype) {
Traverser(DataSet dataset, String docsDir, Set<String> docIdSet, List<IDocFeature> featureList, CollectionStatistics collectionStatistics, Tag analyzerTag, IndexSearcher searcher, IndexReader reader, String resultsettype, Bert bert) {
this.collection = dataset.collection();
this.docIdSet = docIdSet;
this.featureList = featureList;
Expand All @@ -205,6 +206,7 @@ protected boolean skip(String docId) {
this.searcher = searcher;
this.reader = reader;
this.resultsettype=resultsettype;
this.bert=bert;

docsPath = Paths.get(docsDir);
if (!Files.exists(docsPath) || !Files.isReadable(docsPath) || !Files.isDirectory(docsPath)) {
Expand All @@ -218,7 +220,7 @@ protected boolean skip(String docId) {
* Traverse based on Java8's parallel streams
*/
void traverseParallel(Path resultPath, int numThreads) throws IOException {
Bert bert = Bert.load("com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12");


// RelatednessCalculator rc1 = new WuPalmer(new NictWordNet());

Expand Down

0 comments on commit a296f2a

Please sign in to comment.