Skip to content

Commit

Permalink
Merge pull request #8 from SulRash/master
Browse files Browse the repository at this point in the history
Changed some code to make it compatible with tf2
  • Loading branch information
AnshulSood11 authored Dec 19, 2023
2 parents a1722d8 + f29bdb4 commit 19ec568
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
import md_config as cfg
from feature_collection import FeatureCollection


from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import CuDNNLSTM, Dense, TimeDistributed, GlobalAveragePooling1D, Activation, Concatenate, \
from tensorflow.keras.layers import LSTM, Dense, TimeDistributed, GlobalAveragePooling1D, Activation, Concatenate, \
InputLayer, PReLU

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
tf.keras.backend.set_session(session)
config = tf.compat.v1.ConfigProto()
#config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(session)

interval_duration = 10.0

Expand All @@ -39,12 +38,12 @@ def define_model(hparams, model_name):
input_shape=(current_time_step, hparams['FC1'][0])))

model.add(
CuDNNLSTM(current_lstm_units[0], return_sequences=True, input_shape=(current_time_step, current_input_units),
LSTM(current_lstm_units[0], return_sequences=True, input_shape=(current_time_step, current_input_units),
stateful=False))

if current_n_lstms > 1:
for idx in range(1, current_n_lstms):
model.add(CuDNNLSTM(current_lstm_units[idx], return_sequences=True))
model.add(LSTM(current_lstm_units[idx], return_sequences=True))

for idx in range(current_n_denses):
model.add(TimeDistributed(Dense(current_dense_units[idx], activation='relu')))
Expand Down Expand Up @@ -119,12 +118,12 @@ def startTimer():

graph1 = tf.Graph()
with graph1.as_default():
session1 = tf.Session()
session1 = tf.compat.v1.Session()
with session1.as_default():
eye_gaze_v1 = get_model(model_index=0)
graph2 = tf.Graph()
with graph2.as_default():
session2 = tf.Session()
session2 = tf.compat.v1.Session()
with session2.as_default():
eye_gaze_v2 = get_model(model_index=1)

Expand Down

0 comments on commit 19ec568

Please sign in to comment.