Add code for interaction policies
suvaansh committed Oct 10, 2023
1 parent 480ee54 commit ff14c31
Showing 63 changed files with 8,688 additions and 0 deletions.
Binary file added Interactions/.DS_Store
Binary file added Interactions/models/.DS_Store
Empty file added Interactions/models/__init__.py
87 changes: 87 additions & 0 deletions Interactions/models/config/rewards.json
@@ -0,0 +1,87 @@
{
"Generic":
{
"success": 2,
"failure": -0.05,
"step_penalty": -0.01,
"goal_reward": 5
},
"BaseAction":
{
"positive": 1,
"negative": 0,
"neutral": 0,
"invalid_action": -0.1
},
"GotoLocationAction":
{
"positive": 4,
"negative": 0,
"neutral": 0,
"invalid_action": -0.5,
"min_reach_distance": 5
},
"PickupObjectAction":
{
"positive": 2,
"negative": -1,
"neutral": 0,
"invalid_action": -0.1
},
"PutObjectAction":
{
"positive": 2,
"negative": -1,
"neutral": 0,
"invalid_action": -0.1
},
"OpenObjectAction":
{
"positive": 2,
"negative": -0.05,
"neutral": 0,
"invalid_action": -0.1
},
"CloseObjectAction":
{
"positive": 1,
"negative": 0,
"neutral": 0,
"invalid_action": -0.1
},
"ToggleObjectAction":
{
"positive": 1,
"negative": -1,
"neutral": 0,
"invalid_action": -0.1
},
"SliceObjectAction":
{
"positive": 1,
"negative": -4,
"neutral": 0,
"invalid_action": -0.1
},
"CleanObjectAction":
{
"positive": 2,
"negative": 0,
"neutral": 0,
"invalid_action": -0.1
},
"HeatObjectAction":
{
"positive": 2,
"negative": 0,
"neutral": 0,
"invalid_action": -0.1
},
"CoolObjectAction":
{
"positive": 2,
"negative": 0,
"neutral": 0,
"invalid_action": -0.1
}
}
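The "Generic" section carries episode-level terms (task success/failure, per-step penalty, goal reward), while the per-action sections map outcome classes ("positive", "negative", "neutral", "invalid_action") to scalar rewards; the config is consumed through env.set_task(traj_data, args, reward_type=reward_type) in eval.py below. As a rough sketch of how such a table can be read, the snippet below loads the file and looks up a reward; the lookup_reward helper and its fallback to the BaseAction entry are hypothetical illustrations, not part of this commit.

import json

# load the per-action reward table added in this commit
with open('Interactions/models/config/rewards.json') as f:
    rewards = json.load(f)

def lookup_reward(action_type, outcome):
    '''hypothetical helper: scalar reward for an action outcome,
    falling back to the BaseAction entry for unlisted action types'''
    table = rewards.get(action_type, rewards['BaseAction'])
    return table[outcome]

print(lookup_reward('GotoLocationAction', 'positive'))       # 4
print(lookup_reward('SliceObjectAction', 'invalid_action'))  # -0.1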
136 changes: 136 additions & 0 deletions Interactions/models/eval/eval.py
@@ -0,0 +1,136 @@
import json
import pprint
import random
import time
import torch
import torch.multiprocessing as mp
from models.nn.resnet import Resnet
from data.preprocess import Dataset
from importlib import import_module

class Eval(object):

# tokens
STOP_TOKEN = "<<stop>>"
SEQ_TOKEN = "<<seg>>"
TERMINAL_TOKENS = [STOP_TOKEN, SEQ_TOKEN]
MANIPULATE_TOKEN = "Manipulate"

def __init__(self, args, manager):
# args and manager
self.args = args
self.manager = manager

# load splits
with open(self.args.splits) as f:
self.splits = json.load(f)
pprint.pprint({k: len(v) for k, v in self.splits.items()})

# load model
print("Loading: ", self.args.model_path)
M = import_module(self.args.model)
self.model, optimizer = M.Module.load(self.args.model_path)
self.model.share_memory()
self.model.eval()
        self.model.test_mode = True  # run the model in test mode for evaluation

# updated args
self.model.args.dout = self.args.model_path.replace(self.args.model_path.split('/')[-1], '')
self.model.args.data = self.args.data if self.args.data else self.model.args.data

# preprocess and save
if args.preprocess:
print("\nPreprocessing dataset and saving to %s folders ... This is will take a while. Do this once as required:" % self.model.args.pp_folder)
self.model.args.fast_epoch = self.args.fast_epoch
dataset = Dataset(self.model.args, self.model.vocab)
dataset.preprocess_splits(self.splits)

# load resnet
args.visual_model = 'resnet18'
self.resnet = Resnet(args, eval=True, share_memory=True, use_conv_feat=True)

# gpu
if self.args.gpu:
self.model = self.model.to(torch.device('cuda'))

# success and failure lists
self.create_stats()

# set random seed for shuffling
random.seed(int(time.time()))

def queue_tasks(self):
'''
create queue of trajectories to be evaluated
'''
task_queue = self.manager.Queue()
files = self.splits[self.args.eval_split]

# debugging: fast epoch
if self.args.fast_epoch:
files = files[:16]

if self.args.shuffle:
random.shuffle(files)
for traj in files:
task_queue.put(traj)
return task_queue

def spawn_threads(self):
'''
spawn multiple threads to run eval in parallel
'''
task_queue = self.queue_tasks()

# start threads
threads = []
lock = self.manager.Lock()
for n in range(self.args.num_threads):
thread = mp.Process(target=self.run, args=(self.model, self.resnet, task_queue, self.args, lock,
self.successes, self.failures, self.results))
thread.start()
threads.append(thread)

for t in threads:
t.join()

# save
self.save_results()

@classmethod
def setup_scene(cls, env, traj_data, r_idx, args, reward_type='dense'):
'''
        initialize the scene and agent from the task info
'''
# scene setup
scene_num = traj_data['scene']['scene_num']
object_poses = traj_data['scene']['object_poses']
dirty_and_empty = traj_data['scene']['dirty_and_empty']
object_toggles = traj_data['scene']['object_toggles']

scene_name = 'FloorPlan%d' % scene_num
env.reset(scene_name)
env.restore_scene(object_poses, object_toggles, dirty_and_empty)

# initialize to start position
env.step(dict(traj_data['scene']['init_action']))

# print goal instr
print("Task: %s" % (traj_data['turk_annotations']['anns'][r_idx]['task_desc']))

# setup task for reward
env.set_task(traj_data, args, reward_type=reward_type)

    # to be implemented by subclasses (e.g. EvalTask, EvalSubgoals)
    @classmethod
    def run(cls, model, resnet, task_queue, args, lock, successes, failures, results):
        raise NotImplementedError()

    @classmethod
    def evaluate(cls, env, model, r_idx, resnet, traj_data, args, lock, successes, failures, results):
        raise NotImplementedError()

def save_results(self):
raise NotImplementedError()

def create_stats(self):
raise NotImplementedError()
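EvalTask and EvalSubgoals (imported in eval_seq2seq.py below) supply these hooks. To make the contract concrete, here is a minimal hypothetical subclass; the class name, its bookkeeping, and the models.eval.eval import path are assumptions for illustration, not part of this commit.

from models.eval.eval import Eval

class EvalMinimal(Eval):
    '''hypothetical sketch of the abstract Eval contract'''

    def create_stats(self):
        # shared, process-safe containers handed to every worker
        self.successes = self.manager.list()
        self.failures = self.manager.list()
        self.results = self.manager.dict()

    @classmethod
    def run(cls, model, resnet, task_queue, args, lock, successes, failures, results):
        # each mp.Process drains the shared trajectory queue
        while not task_queue.empty():
            traj = task_queue.get()
            # ... roll out the model on traj here, then record the outcome
            with lock:
                successes.append(traj)

    def save_results(self):
        # persist whatever run() accumulated across workers
        import json
        with open('results.json', 'w') as f:
            json.dump({'n_success': len(self.successes),
                       'n_failure': len(self.failures)}, f)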
57 changes: 57 additions & 0 deletions Interactions/models/eval/eval_seq2seq.py
@@ -0,0 +1,57 @@
import os
import sys
sys.path.append(os.path.join(os.environ['ALFRED_ROOT']))
sys.path.append(os.path.join(os.environ['ALFRED_ROOT'], 'gen'))
sys.path.append(os.path.join(os.environ['ALFRED_ROOT'], 'models'))

import argparse
import torch.multiprocessing as mp
from eval_task import EvalTask
from eval_subgoals import EvalSubgoals


if __name__ == '__main__':
# multiprocessing settings
mp.set_start_method('spawn')
manager = mp.Manager()

# parser
parser = argparse.ArgumentParser()

# settings
parser.add_argument('--splits', type=str, default="data/splits/oct21.json")
parser.add_argument('--data', type=str, default="data/json_feat_2.1.0")
parser.add_argument('--reward_config', default='models/config/rewards.json')
parser.add_argument('--eval_split', type=str, choices=['train', 'valid_seen', 'valid_unseen'])
parser.add_argument('--model_path', type=str, default="exp/pretrained/pretrained.pth")
parser.add_argument('--model', type=str, default='models.model.seq2seq_im_mask')
parser.add_argument('--preprocess', dest='preprocess', action='store_true')
parser.add_argument('--shuffle', dest='shuffle', action='store_true')
parser.add_argument('--gpu', dest='gpu', action='store_true')
parser.add_argument('--num_threads', type=int, default=1)

# eval params
parser.add_argument('--max_steps', type=int, default=1000, help='max steps before episode termination')
parser.add_argument('--max_fails', type=int, default=10, help='max API execution failures before episode termination')

# eval settings
    parser.add_argument('--subgoals', type=str, help="subgoals to evaluate independently, e.g. 'all' or 'GotoLocation,PickupObject,...'", default="")
parser.add_argument('--smooth_nav', dest='smooth_nav', action='store_true', help='smooth nav actions (might be required based on training data)')
parser.add_argument('--skip_model_unroll_with_expert', action='store_true', help='forward model with expert actions')
parser.add_argument('--no_teacher_force_unroll_with_expert', action='store_true', help='no teacher forcing with expert')

# debug
parser.add_argument('--debug', dest='debug', action='store_true')
parser.add_argument('--fast_epoch', dest='fast_epoch', action='store_true')

# parse arguments
args = parser.parse_args()

    # eval mode: a non-empty --subgoals switches to independent subgoal evaluation
    if args.subgoals:
        evaluator = EvalSubgoals(args, manager)
    else:
        evaluator = EvalTask(args, manager)

    # start threads
    evaluator.spawn_threads()
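A typical invocation, assuming the script is launched from the directory that contains the models/ package (in this commit, Interactions) with ALFRED_ROOT pointing there; all paths below are placeholders:

# hypothetical paths; ALFRED_ROOT must contain the models/ package
export ALFRED_ROOT=/path/to/repo/Interactions
cd $ALFRED_ROOT

# full-task evaluation on the seen validation split
python models/eval/eval_seq2seq.py --eval_split valid_seen \
    --model_path exp/pretrained/pretrained.pth --gpu --num_threads 3

# independent subgoal evaluation (non-empty --subgoals selects EvalSubgoals)
python models/eval/eval_seq2seq.py --eval_split valid_seen --subgoals all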