Skip to content

Commit 55327cb

Browse files
committed
code-update-to-handle-queries-feature
1 parent 8d44b9d commit 55327cb

File tree

4 files changed

+129
-22
lines changed

4 files changed

+129
-22
lines changed

textract/async-form-table/lambda-create-job/lambda_function.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import boto3
1414
import logging
1515
import traceback
16+
from ast import literal_eval
1617
from urllib.parse import unquote_plus
1718

1819
logger = logging.getLogger()
@@ -37,17 +38,18 @@ def process_error():
3738
OUTPUT_S3_PREFIX = os.environ["OUTPUT_S3_PREFIX"]
3839
SNS_TOPIC_ARN = os.environ["SNS_TOPIC_ARN"]
3940
SNS_ROLE_ARN = os.environ["SNS_ROLE_ARN"]
41+
FEATURES = literal_eval(os.environ["FEATURES"])
4042
logger.info(
41-
f"OUTPUT_BUCKET_NAME: {OUTPUT_BUCKET_NAME}, OUTPUT_S3_PREFIX: {OUTPUT_S3_PREFIX}, SNS_TOPIC_ARN: {SNS_TOPIC_ARN}, SNS_ROLE_ARN: {SNS_ROLE_ARN}"
43+
f"OUTPUT_BUCKET_NAME: {OUTPUT_BUCKET_NAME}, OUTPUT_S3_PREFIX: {OUTPUT_S3_PREFIX}, SNS_TOPIC_ARN: {SNS_TOPIC_ARN}, SNS_ROLE_ARN: {SNS_ROLE_ARN}, FEATURES: {FEATURES}"
4244
)
4345
except Exception as e:
4446
error_msg = process_error()
4547
logger.error(error_msg)
4648

4749

4850
def lambda_handler(event, context):
49-
5051
textract = boto3.client("textract")
52+
s3 = boto3.client("s3")
5153
try:
5254
if "Records" in event:
5355
logger.info(f"Event: {event}")
@@ -56,18 +58,44 @@ def lambda_handler(event, context):
5658
filename = unquote_plus(str(file_obj["s3"]["object"]["key"]))
5759
logger.info(f"Bucket: {bucketname} ::: Key: {filename}")
5860

59-
response = textract.start_document_analysis(
60-
DocumentLocation={"S3Object": {"Bucket": bucketname, "Name": filename}},
61-
FeatureTypes=["TABLES", "FORMS", "SIGNATURES"],
62-
OutputConfig={
63-
"S3Bucket": OUTPUT_BUCKET_NAME,
64-
"S3Prefix": OUTPUT_S3_PREFIX,
65-
},
66-
NotificationChannel={
67-
"SNSTopicArn": SNS_TOPIC_ARN,
68-
"RoleArn": SNS_ROLE_ARN,
69-
},
70-
)
61+
if "QUERIES" in FEATURES:
62+
file_obj = s3.get_object(
63+
Bucket=bucketname, Key="async-input/queries.json"
64+
)
65+
queries_data = json.loads(file_obj["Body"].read().decode("utf-8"))
66+
67+
response = textract.start_document_analysis(
68+
DocumentLocation={
69+
"S3Object": {"Bucket": bucketname, "Name": filename}
70+
},
71+
FeatureTypes=FEATURES,
72+
QueriesConfig=queries_data,
73+
OutputConfig={
74+
"S3Bucket": OUTPUT_BUCKET_NAME,
75+
"S3Prefix": OUTPUT_S3_PREFIX,
76+
},
77+
NotificationChannel={
78+
"SNSTopicArn": SNS_TOPIC_ARN,
79+
"RoleArn": SNS_ROLE_ARN,
80+
},
81+
)
82+
83+
else:
84+
response = textract.start_document_analysis(
85+
DocumentLocation={
86+
"S3Object": {"Bucket": bucketname, "Name": filename}
87+
},
88+
FeatureTypes=FEATURES,
89+
OutputConfig={
90+
"S3Bucket": OUTPUT_BUCKET_NAME,
91+
"S3Prefix": OUTPUT_S3_PREFIX,
92+
},
93+
NotificationChannel={
94+
"SNSTopicArn": SNS_TOPIC_ARN,
95+
"RoleArn": SNS_ROLE_ARN,
96+
},
97+
)
98+
7199
if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
72100
logger.info(f"Job created successfully")
73101
return {

textract/async-form-table/lambda-process-response/helper/helper.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ def save_text_csv(keys, values, job_id, BUCKET_NAME):
101101
upload_to_s3(csv_buffer, BUCKET_NAME, key)
102102

103103

104+
def save_queries_csv(queries, job_id, BUCKET_NAME):
    """Write the query/answer results of a job to a CSV object in S3.

    Args:
        queries: Mapping of query block Id -> dict of query fields, as
            returned by ``Parse.get_queries_info``. May be empty or ``None``
            when the analysed document contained no QUERY blocks.
        job_id: Job identifier; used to build the object key.
        BUCKET_NAME: Destination bucket handed to ``upload_to_s3``.

    The object key is ``queries/{job_id}/queryAnswer.csv``. The helper-only
    ``answer_ids`` column is not written to the CSV.
    """
    key = f"queries/{job_id}/queryAnswer.csv"

    # One row per query; the block Ids used as dict keys are discarded in
    # favour of a plain 0..n-1 row index.
    df = pd.DataFrame.from_dict(queries or {}, orient="index").reset_index(drop=True)

    # ``errors="ignore"``: with no queries the frame has no "answer_ids"
    # column, and an unconditional drop would raise KeyError and abort the
    # whole response-processing run. In that case a CSV with no data rows
    # is uploaded instead.
    df = df.drop(columns=["answer_ids"], errors="ignore")

    csv_buffer = io.StringIO()
    df.to_csv(csv_buffer)
    upload_to_s3(csv_buffer, BUCKET_NAME, key)
111+
112+
104113
def map_word_id(response):
105114
word_map = {}
106115
for block in response["Blocks"]:
@@ -112,7 +121,13 @@ def map_word_id(response):
112121

113122

114123
def process_response(
115-
BUCKET_NAME, job_id, get_table=True, get_kv=True, get_text=True, get_signatures=True
124+
BUCKET_NAME,
125+
job_id,
126+
get_table=True,
127+
get_kv=True,
128+
get_text=True,
129+
get_signatures=True,
130+
get_queries=True,
116131
):
117132
textract = boto3.client("textract")
118133

@@ -149,8 +164,9 @@ def process_response(
149164
get_kv=get_kv,
150165
get_text=get_text,
151166
get_signatures=get_signatures,
167+
get_queries=get_queries,
152168
)
153-
table, final_map, text, sign = parse.process_response()
169+
table, final_map, text, sign, queries = parse.process_response()
154170

155171
if get_kv:
156172
keys = list(map(itemgetter(0), final_map))
@@ -165,4 +181,6 @@ def process_response(
165181
save_text_csv(text_key, text_value, job_id, BUCKET_NAME)
166182
if get_signatures:
167183
save_sign_csv(sign, job_id, BUCKET_NAME)
184+
if get_queries:
185+
save_queries_csv(queries, job_id, BUCKET_NAME)
168186
logger.info("Parsing completed")

textract/async-form-table/lambda-process-response/helper/parser.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
class Parse:
19-
def __init__(self, page, get_table, get_kv, get_text, get_signatures):
19+
def __init__(self, page, get_table, get_kv, get_text, get_signatures, get_queries):
2020
self.response = page
2121
self.word_map = {}
2222
self.table_page_map = {}
@@ -28,6 +28,7 @@ def __init__(self, page, get_table, get_kv, get_text, get_signatures):
2828
self.get_kv = get_kv
2929
self.get_text = get_text
3030
self.get_signatures = get_signatures
31+
self.get_queries = get_queries
3132

3233
def extract_text(self, extract_by="LINE"):
3334
for block in self.response:
@@ -57,7 +58,6 @@ def extract_table_info(self):
5758
response_block_len = len(self.response) - 1
5859

5960
for n, block in enumerate(self.response):
60-
6161
if block["BlockType"] == "TABLE":
6262
key = f"table_{uuid.uuid4().hex}_page_{block['Page']}"
6363
temp_table = []
@@ -94,7 +94,6 @@ def extract_table_info(self):
9494

9595
def get_key_map(self):
9696
for block in self.response:
97-
9897
if block["BlockType"] == "KEY_VALUE_SET" and "KEY" in block["EntityTypes"]:
9998
for relation in block["Relationships"]:
10099
if relation["Type"] == "VALUE":
@@ -136,8 +135,66 @@ def get_signature_info(self):
136135
temp_counter += 1
137136
return (page, signature, confidence)
138137

138+
def get_queries_info(self):
    """Collect every QUERY block in ``self.response`` with its answer(s).

    Returns:
        dict: QUERY block Id -> dict with the keys

        * ``query`` / ``alias`` - text and alias of the query,
        * ``answer_ids`` - Ids listed in the block's ANSWER relationships,
          or ``None`` when the block has no ``Relationships`` entry,
        * ``answer`` - text of the matching QUERY_RESULT block(s); several
          results for one query are joined with a single space,
        * ``confidence`` / ``page`` - taken from the last matching
          QUERY_RESULT block,

        Queries without a matching result keep ``None`` for ``answer``,
        ``confidence`` and ``page``. An empty dict is returned when the
        response holds no QUERY blocks.
    """
    queries = {}
    # QUERY_RESULT block Id -> Id of the QUERY block that references it.
    # Resolving answers through this map (instead of "the most recent QUERY
    # seen") keeps the result independent of block order and avoids an
    # IndexError when a QUERY_RESULT shows up with no QUERY before it.
    query_for_answer = {}

    # Pass 1: register all queries and the answer Ids they declare.
    for block in self.response:
        if block["BlockType"] != "QUERY":
            continue

        query = block.get("Query")
        if "Relationships" in block:
            answer_ids = [
                answer_id
                for rel in block.get("Relationships")
                if rel.get("Type") == "ANSWER"
                for answer_id in rel["Ids"]
            ]
        else:
            answer_ids = None

        query_id = block.get("Id")
        queries[query_id] = {
            "query": query.get("Text"),
            "alias": query.get("Alias"),
            "answer_ids": answer_ids,
            "answer": None,
            "confidence": None,
            "page": None,
        }
        for answer_id in answer_ids or ():
            query_for_answer[answer_id] = query_id

    # Pass 2: attach each QUERY_RESULT to the query that declared it.
    for block in self.response:
        if block["BlockType"] != "QUERY_RESULT":
            continue

        entry = queries.get(query_for_answer.get(block.get("Id")))
        if entry is None:
            # Result not referenced by any known query: nothing to attach.
            continue

        text = block.get("Text")
        if entry["answer"]:
            entry["answer"] = f"{entry['answer']} {text}"
        else:
            entry["answer"] = text
        entry["confidence"] = block.get("Confidence")
        entry["page"] = block.get("Page")

    return queries
189+
139190
def process_response(self):
140-
final_map, table_info, text = None, None, None
191+
final_map, table_info, text, sign_info, queries_info = (
192+
None,
193+
None,
194+
None,
195+
None,
196+
None,
197+
)
141198

142199
logging.info("Mapping Id with word")
143200
self.map_word_id()
@@ -160,4 +217,8 @@ def process_response(self):
160217
logging.info("Extracting signature information")
161218
sign_info = self.get_signature_info()
162219

163-
return table_info, final_map, text, sign_info
220+
if self.get_queries:
221+
logging.info("Extracting queries information")
222+
queries_info = self.get_queries_info()
223+
224+
return table_info, final_map, text, sign_info, queries_info

textract/async-form-table/lambda-process-response/lambda_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818

1919
def lambda_handler(event, context):
20-
2120
try:
2221
BUCKET_NAME = os.environ["BUCKET_NAME"]
2322
logger.info(f"Destination bucket: {BUCKET_NAME}")
@@ -39,6 +38,7 @@ def lambda_handler(event, context):
3938
get_kv=True,
4039
get_text=True,
4140
get_signatures=True,
41+
get_queries=True,
4242
)
4343

4444
return {

0 commit comments

Comments
 (0)