Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object ConstEmbeddingsGlove {
val logger:Logger = LoggerFactory.getLogger(classOf[ConstEmbeddingsGlove])

// This is not marked private for debugging purposes
private var SINGLETON_WORD_EMBEDDING_MAP: Option[WordEmbeddingMap] = None
var SINGLETON_WORD_EMBEDDING_MAP: Option[WordEmbeddingMap] = None

// make sure the singleton is loaded
load()
Expand Down
74 changes: 48 additions & 26 deletions main/src/main/scala/org/clulab/dynet/TestOnnx.scala
Original file line number Diff line number Diff line change
@@ -1,50 +1,70 @@
package org.clulab.dynet

import org.clulab.embeddings.{CompactWordEmbeddingMap, WordEmbeddingMapPool}

import java.io.{FileWriter, PrintWriter}

import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
import com.typesafe.config.ConfigFactory
import org.clulab.dynet.Utils._
import org.clulab.utils.StringUtils
import org.clulab.utils.{StringUtils, Timer}

import scala.io.Source
import scala.util.parsing.json._

import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
import org.slf4j.{Logger, LoggerFactory}

import java.time.LocalDateTime
import java.time.Duration

import scala.io.Source


object TestOnnx extends App {

class TextEmbedder(filename: String) {

def get_embeddings(embed_file_path: String): Map[String,Array[Float]]={
val emb = Source.fromFile(embed_file_path)
var emb_map:Map[String,Array[Float]] = Map()
for (s<-emb.getLines){
if (s.split(" ")(0) == ""){
emb_map += ("<UNK>"-> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat))
}else{
emb_map += (s.split(" ")(0) -> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat))
}
val emb = Source.fromFile(embed_file_path)
var emb_map:Map[String,Array[Float]] = Map()
for (s<-emb.getLines){
if (s.split(" ")(0) == ""){
// TODO: These probably need to be normalized in both cases.
emb_map += ("<UNK>"-> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat))
}else{
emb_map += (s.split(" ")(0) -> s.split(" ").slice(1, s.split(" ").size).map(_.toFloat))
}
emb_map
}
emb_map
}

protected val map = {
val timer = new Timer("get_embeddings")
val result = timer.time(get_embeddings(filename))

println(timer.toString)
result
}
protected val unknown = map("<UNK>")

def apply(key: String): Array[Float] = map.getOrElse(key, unknown)
}

class ProcessorsEmbedder() {
val map = {
val timer = new Timer("SINGLETON_WORD_EMBEDDING_MAP")
val result = timer.time(ConstEmbeddingsGlove.SINGLETON_WORD_EMBEDDING_MAP.get)

println(timer.toString)
result
}

def apply(key: String): Array[Float] = map.getOrElseUnknown(key).toArray
}

val start_time = LocalDateTime.now()
val props = StringUtils.argsToProperties(args)

val configName = props.getProperty("conf")
val config = ConfigFactory.load(configName)
val taskManager = new TaskManager(config)

val embed_file_path: String = "/data1/home/zheng/processors/main/src/main/python/glove.840B.300d.10f.txt"
val wordEmbeddingMap = get_embeddings(embed_file_path)

// Pick one of these.
val embedder = new TextEmbedder("/data1/home/zheng/processors/main/src/main/python/glove.840B.300d.10f.txt")
// val embedder = new TextEmbedder("../glove/glove.840B.300d.10f.txt")
// val embedder = new ProcessorsEmbedder()

val jsonString = Source.fromFile("/data1/home/zheng/processors/ner.json").getLines.mkString
// val jsonString = Source.fromFile("../onnx/ner.json").getLines.mkString
val parsed = JSON.parseFull(jsonString)
val w2i = parsed.get.asInstanceOf[List[Any]](0).asInstanceOf[Map[String, Any]]("x2i").asInstanceOf[Map[String, Any]]("initialLayer").asInstanceOf[Map[String, Any]]("w2i").asInstanceOf[Map[String, Double]]
val c2i = parsed.get.asInstanceOf[List[Any]](0).asInstanceOf[Map[String, Any]]("x2i").asInstanceOf[Map[String, Any]]("initialLayer").asInstanceOf[Map[String, Any]]("c2i").asInstanceOf[Map[String, Double]]
Expand All @@ -53,8 +73,10 @@ object TestOnnx extends App {

val ortEnvironment = OrtEnvironment.getEnvironment
val modelpath1 = "/data1/home/zheng/processors/char.onnx"
/// val modelpath1 = "../onnx/char.onnx"
val session1 = ortEnvironment.createSession(modelpath1, new OrtSession.SessionOptions)
val modelpath2 = "/data1/home/zheng/processors/model.onnx"
/// val modelpath2 = "../onnx/model.onnx"
val session2 = ortEnvironment.createSession(modelpath2, new OrtSession.SessionOptions)

println(session1.getOutputInfo)
Expand All @@ -81,7 +103,7 @@ object TestOnnx extends App {
var char_embs:Array[Array[Float]] = new Array[Array[Float]](words.length)
for(i <- words.indices){
val word = words(i)
embeddings(i) = wordEmbeddingMap.getOrElse(word,wordEmbeddingMap.get( "<UNK>").get)
embeddings(i) = embedder(word)
wordIds(i) = w2i.getOrElse(word, 0).asInstanceOf[Number].longValue
val char_input = new java.util.HashMap[String, OnnxTensor]()
char_input.put("char_ids", OnnxTensor.createTensor(ortEnvironment, word.map(c => c2i.getOrElse(c.toString, 0).asInstanceOf[Number].longValue).toArray))
Expand Down