Skip to content

Commit 9211c1a

Browse files
committed
Allow specific problems to be selected
1 parent 6e5cf18 commit 9211c1a

File tree

2 files changed

+32
-10
lines changed

2 files changed

+32
-10
lines changed

backend/main.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,24 @@ def get(self, problemID):
187187

188188
class EntryListAPI(Resource):
189189
def __init__(self):
190-
self.parser = reqparse.RequestParser()
191-
self.parser.add_argument("problemID", type=str, required=True, location="json")
192-
self.parser.add_argument("userID", type=str, required=True, location="json")
193-
self.parser.add_argument("file", type=FileStorage, required=True, location="files")
194190
super(EntryListAPI, self).__init__()
195191

196192
def get(self):
197-
return jsonify([a for a in db.entry.find({})])
193+
parser = reqparse.RequestParser()
194+
parser.add_argument("problemID", type=str, required=False, location="args")
195+
196+
args = parser.parse_args()
197+
if args["problemID"] is not None:
198+
return jsonify([a for a in db.entry.find({"problemID": args["problemID"]})])
199+
else:
200+
return jsonify([a for a in db.entry.find({})])
198201

199202
def post(self):
200-
entry = self.parser.parse_args()
203+
parser = reqparse.RequestParser()
204+
parser.add_argument("problemID", type=str, required=True, location="json")
205+
parser.add_argument("userID", type=str, required=True, location="json")
206+
parser.add_argument("file", type=FileStorage, required=True, location="files")
207+
entry = parser.parse_args()
201208

202209
try:
203210
if db.problem.find_one({"_id": ObjectId(entry['problemID'])}) == None:
@@ -223,7 +230,6 @@ def post(self):
223230
status_code = 400
224231

225232
return jsonify(structuredGradingOutput, status=status_code)
226-
227233
class EntryAPI(Resource):
228234
def get(self, entryID):
229235
try:

backend/tests.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ def testGet(self):
181181

182182
INVALID_EXAMPLE_ENTRY = {"problemID": "incorrectproblemid", "userID": "incorrectuserid", "score": 12}
183183

184-
def generateExampleEntry(db):
185-
exampleUser = copy.deepcopy(EXAMPLE_USER)
186-
exampleProblem = copy.deepcopy(EXAMPLE_PROBLEM)
184+
def generateExampleEntry(db, exampleProblem=EXAMPLE_PROBLEM, exampleUser=EXAMPLE_USER):
185+
exampleUser = copy.deepcopy(exampleUser)
186+
exampleProblem = copy.deepcopy(exampleProblem)
187187

188188
db.user.insert_one(exampleUser)
189189
db.problem.insert_one(exampleProblem)
@@ -199,6 +199,22 @@ def testGetAll(self):
199199
newEntry = json.loads(self.app.get("/entries").data.decode("utf-8"))[0]
200200
assert areDicsEqual(exampleEntry, newEntry)
201201

202+
def testGetProblem(self):
203+
assert b'[]' in self.app.get("/entries").data
204+
205+
exampleEntry1 = generateExampleEntry(self.db)
206+
207+
exampleProblem2 = copy.deepcopy(EXAMPLE_PROBLEM)
208+
exampleProblem2["name"] = "Other Problem"
209+
exampleEntry2 = generateExampleEntry(self.db, exampleProblem=exampleProblem2)
210+
211+
self.db.entry.insert_one(exampleEntry1)
212+
self.db.entry.insert_one(exampleEntry2)
213+
214+
returnedEntries = json.loads(self.app.get("/entries", query_string={"problemID": exampleEntry1["problemID"]}).data.decode("utf-8"))
215+
assert areDicsEqual(exampleEntry1, returnedEntries[0])
216+
assert len(returnedEntries) == 1
217+
202218
def testGet(self):
203219
assert self.app.get("/entries/1").status_code == 404
204220

0 commit comments

Comments
 (0)