Add model.
PiperOrigin-RevId: 293786720
VilmosProkaj authored and vbardiovskyg committed Feb 11, 2020
1 parent d3cc4e6 commit 5a09f24
tfhub_dev/assets/prvi/tf2-nq/1.md (141 additions)
# Module prvi/tf2nq/1
Question-answering model for the Kaggle TensorFlow 2.0 Question Answering challenge [1].

<!-- asset-path: https://hpz400.cs.elte.hu/model/model.tar.gz -->
<!-- module-type: text-question-answering -->
<!-- network-architecture: Transformer -->
<!-- dataset: Natural Questions -->
<!-- language: en -->
<!-- fine-tunable: true -->
<!-- format: saved_model_2 -->

## Overview

This is a fine-tuned BERT-large model with a special head; the core model was pretrained with whole-word masking.

The input to this model is a tokenized text that starts with a question and
continues with part of a Wikipedia page.
The aim is to find the answer to the question within the Wikipedia page. The input and
output formats differ somewhat from those of the BERT joint baseline. More details can be found
in the short description of this solution [2] and in the pre- and post-processing code [3].
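As an illustration only, the input layout can be sketched as below. The exact tokenization and packing are defined by the pre-processing code in [3]; the token ids here are made up, and whether `question_len` and `data_len` include the special tokens is an assumption of this sketch.

```python
# Hypothetical sketch of the input layout: a question followed by part of a
# Wikipedia page, packed into one token-id sequence. Real ids come from a
# BERT vocabulary; these are placeholders.
CLS, SEP, PAD = 101, 102, 0           # standard BERT special-token ids
question_ids = [2054, 2003, 1996]     # made-up ids for the question tokens
context_ids = [1996, 3007, 1997]      # made-up ids for the page tokens

seq_length = 16
input_ids = [CLS] + question_ids + [SEP] + context_ids + [SEP]
question_len = 1 + len(question_ids) + 1      # [CLS] question [SEP]
data_len = len(input_ids)                     # length before padding
input_ids = input_ids + [PAD] * (seq_length - len(input_ids))

print(question_len, data_len, len(input_ids))
```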

#### Example use

The model can be loaded with
```
import tensorflow_hub as hub
model = hub.load("https://tfhub.dev/prvi/tf2nq/1")
```

Since this model outputs not only the start and end logits of short spans but also a logit for each possible short span, post-processing is required. E.g., to keep the 100 most promising candidates one can define
```
import tensorflow as tf

def output(unique_id, model_output, n_keep=100):
    pos_logits, ans_logits, long_mask, short_mask, cross = model_output
    # Long-answer spans: push masked positions down with a large negative bias.
    long_span_logits = pos_logits
    mask = tf.cast(tf.expand_dims(long_mask, -1), long_span_logits.dtype)
    long_span_logits = long_span_logits - 10000 * mask
    long_p = tf.nn.softmax(long_span_logits, axis=1)
    # Short-answer spans: same masking trick on the start/end logits.
    short_span_logits = pos_logits
    short_span_logits -= 10000 * tf.cast(tf.expand_dims(short_mask, -1), short_span_logits.dtype)
    start_logits, end_logits = short_span_logits[:, :, 0], short_span_logits[:, :, 1]
    batch_size, seq_length = short_span_logits.shape[0], short_span_logits.shape[1]
    # Rule out spans whose start index exceeds their end index.
    seq = tf.range(seq_length)
    start_gt_end_mask = tf.cast(tf.expand_dims(seq, 1) > tf.expand_dims(seq, 0), short_span_logits.dtype)
    start_gt_end_mask = tf.expand_dims(start_gt_end_mask, 0)
    # Joint logit of span (i, j): start logit + end logit + pairwise cross term.
    logits = tf.expand_dims(start_logits, 2) + tf.expand_dims(end_logits, 1) + cross
    logits -= 10000 * start_gt_end_mask
    logits = tf.reshape(logits, [batch_size, seq_length * seq_length])
    short_p = tf.nn.softmax(logits)
    # Keep only the n_keep most probable spans (as flattened indices).
    indices = tf.argsort(short_p, axis=1, direction='DESCENDING')[:, :n_keep]
    short_p = tf.gather(short_p, indices, batch_dims=1)
    return dict(unique_id=unique_id,
                ans_logits=ans_logits,
                long_p=long_p,
                short_p=short_p,
                short_p_indices=indices)
```
and apply it to the model output.
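The masking trick used above can be checked in isolation: adding a large negative bias to a logit drives its softmax probability to (numerically) zero, so only spans with start ≤ end survive. A small NumPy sketch, not part of the model code, and omitting the pairwise `cross` term:

```python
import numpy as np

seq_length = 4
rng = np.random.default_rng(0)
start = rng.normal(size=seq_length)
end = rng.normal(size=seq_length)

# Joint logit of span (i, j) = start[i] + end[j], as in the model code.
logits = start[:, None] + end[None, :]
seq = np.arange(seq_length)
invalid = (seq[:, None] > seq[None, :]).astype(logits.dtype)  # start > end
logits = logits - 10000 * invalid

# Softmax over the flattened span logits.
flat = logits.reshape(-1)
p = np.exp(flat - flat.max())
p = p / p.sum()

# Probability mass on the invalid (strictly lower-triangular) spans vanishes.
print(p.reshape(seq_length, seq_length)[np.tril_indices(seq_length, -1)].sum())
```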

E.g., for a minibatch `b` that has the fields `unique_id`, `input_ids`, `data_len`, and `question_len` one can use

```
# b is a minibatch
unique_id = b.pop('unique_id').numpy()
b = [b['data_len'], b['input_ids'], b['question_len']]
out_dict = output(unique_id, model(b, training=False))
for k, v in out_dict.items():
    if isinstance(v, tf.Tensor):
        out_dict[k] = v.numpy()
out_dict
```
Then `out_dict` contains the `unique_id`, the answer-type logits, the long-span probabilities, and the short spans with the highest probabilities together with their indices.
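Because the span logits are flattened to length `seq_length * seq_length` before sorting, each entry of `short_p_indices` encodes a `(start, end)` pair, recoverable with integer division and modulo. A sketch with made-up index values (the model's actual sequence length is 512; a small value is used here for readability):

```python
seq_length = 6                 # small value for illustration; the model uses 512
for flat_index in (8, 16):     # made-up entries of short_p_indices
    start, end = divmod(flat_index, seq_length)
    print(start, end)          # token positions of the span; prints 1 2, then 2 4
```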


It can also be used within Keras:
```
layer = hub.KerasLayer("https://tfhub.dev/prvi/tf2nq/1")
def post_process(model_output, n_keep=100):
    pos_logits, ans_logits, long_mask, short_mask, cross = model_output
    # Long-answer spans: push masked positions down with a large negative bias.
    long_span_logits = pos_logits
    mask = tf.cast(tf.expand_dims(long_mask, -1), long_span_logits.dtype)
    long_span_logits = long_span_logits - 10000 * mask
    long_p = tf.nn.softmax(long_span_logits, axis=1)
    # Short-answer spans: same masking trick on the start/end logits.
    short_span_logits = pos_logits
    short_span_logits -= 10000 * tf.cast(tf.expand_dims(short_mask, -1), short_span_logits.dtype)
    start_logits, end_logits = short_span_logits[:, :, 0], short_span_logits[:, :, 1]
    batch_size, seq_length = short_span_logits.shape[0], short_span_logits.shape[1]
    # Rule out spans whose start index exceeds their end index.
    seq = tf.range(seq_length)
    start_gt_end_mask = tf.cast(tf.expand_dims(seq, 1) > tf.expand_dims(seq, 0), short_span_logits.dtype)
    start_gt_end_mask = tf.expand_dims(start_gt_end_mask, 0)
    # Joint logit of span (i, j): start logit + end logit + pairwise cross term.
    logits = tf.expand_dims(start_logits, 2) + tf.expand_dims(end_logits, 1) + cross
    logits -= 10000 * start_gt_end_mask
    logits = tf.reshape(logits, [-1, seq_length * seq_length])
    short_p = tf.nn.softmax(logits)
    # Keep only the n_keep most probable spans (as flattened indices).
    indices = tf.argsort(short_p, axis=1, direction='DESCENDING')[:, :n_keep]
    short_p = tf.gather(short_p, indices, batch_dims=1)
    return dict(ans_logits=ans_logits,
                long_p=long_p,
                short_p=short_p,
                short_p_indices=indices)
token_ids = tf.keras.Input(shape=[512], dtype=tf.int32)
data_len = tf.keras.Input(shape=[], dtype=tf.int32)
question_len = tf.keras.Input(shape=[], dtype=tf.int32)
pos_logits, ans_logits, long_mask, short_mask, cross = layer([data_len, token_ids, question_len])
output = post_process([pos_logits, ans_logits, long_mask, short_mask, cross])
keras_model = tf.keras.Model(inputs=dict(token_ids=token_ids,
                                         question_len=question_len,
                                         data_len=data_len),
                             outputs=output)
keras_model.summary()
```

For a further usage example, see the inference kernel on Kaggle [4].

#### References

[1] https://www.kaggle.com/c/tensorflow2-question-answering

[2] https://www.kaggle.com/c/tensorflow2-question-answering/discussion/127521

[3] https://www.kaggle.com/prokaj/bert-baseline-pre-and-post-process

[4] https://www.kaggle.com/prokaj/fork-of-baseline-html-tokens-v5
