Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions backend/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,53 @@ async def chat_bot(uri=Form(None),model=Form(None),userName=Form(None), password
finally:
gc.collect()

@app.post("/chat_bot_grounding")
async def chat_bot_grounding(
    uri=Form(None),
    model=Form(None),
    userName=Form(None),
    password=Form(None),
    database=Form(None),
    question=Form(None),
    document_names=Form(None),
    session_id=Form(None),
    mode=Form(None),
    email=Form(None),
    requireGrounding: bool = Form(True)
):
    """Answer a chat question through QA_RAG_GROUNDING and return an API response.

    All parameters arrive as form fields. ``uri``/``userName``/``password``/
    ``database`` describe the Neo4j connection; ``model``, ``question``,
    ``document_names``, ``session_id`` and ``mode`` are forwarded to
    ``QA_RAG_GROUNDING`` together with ``requireGrounding``. ``email`` is only
    written to the structured log entry.

    Returns the value of ``create_api_response``: a 'Success' response carrying
    the QA result (with ``info.response_time`` filled in), or a 'Failed'
    response carrying the error text and the requested ``mode``.
    """
    logging.info(f"QA_RAG (grounding) called at {datetime.now()}")
    logging.info(f"document_names = {document_names}")
    started_at = time.time()
    try:
        # "graph" mode talks to Neo4j through a schema-refreshing Neo4jGraph;
        # every other mode uses the shared connection helper.
        if mode != "graph":
            graph = create_graph_database_connection(uri, userName, password, database)
        else:
            graph = Neo4jGraph(
                url=uri,
                username=userName,
                password=password,
                database=database,
                sanitize=True,
                refresh_schema=True,
            )

        write_access = graphDBdataAccess(graph).check_account_access(database=database)

        # NOTE: requireGrounding is only forwarded here; whatever it selects
        # (e.g. a grounded vs. ungrounded system template) happens downstream.
        # The QA pipeline is synchronous, so it runs in a worker thread.
        result = await asyncio.to_thread(
            QA_RAG_GROUNDING,
            graph=graph,
            model=model,
            question=question,
            document_names=document_names,
            session_id=session_id,
            mode=mode,
            write_access=write_access,
            requireGrounding=requireGrounding,
        )

        elapsed = time.time() - started_at
        logging.info(f"Total Response time is {elapsed:.2f} seconds")
        result["info"]["response_time"] = round(elapsed, 2)

        # Structured audit record of the call (the password is deliberately absent).
        audit_record = {
            'api_name': 'chat_bot_grounding',
            'db_url': uri,
            'userName': userName,
            'database': database,
            'question': question,
            'document_names': document_names,
            'session_id': session_id,
            'mode': mode,
            'requireGrounding': requireGrounding,
            'logging_time': formatted_time(datetime.now(timezone.utc)),
            'elapsed_api_time': f'{elapsed:.2f}',
            'email': email,
        }
        logger.log_struct(audit_record, "INFO")

        return create_api_response('Success', data=result)
    except Exception as exc:
        failure_text = str(exc)
        logging.exception(f'Exception in chat bot grounding:{failure_text}')
        return create_api_response(
            "Failed",
            message="Unable to get chat response",
            error=failure_text,
            data=mode,
        )
    finally:
        gc.collect()

@app.post("/chunk_entities")
async def chunk_entities(uri=Form(None),userName=Form(None), password=Form(None), database=Form(None), nodedetails=Form(None),entities=Form(),mode=Form(),email=Form(None)):
try:
Expand Down
45 changes: 45 additions & 0 deletions backend/src/QA_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,4 +911,49 @@ def QA_RAG(graph,model, question, document_names, session_id, mode, write_access

result["session_id"] = session_id

return result

def QA_RAG_GROUNDING(graph, model, question, document_names, session_id, mode, write_access=True, requireGrounding=True):
    """Run one chat turn and return the response payload for the API layer.

    Args:
        graph: Graph connection handed to the history store and the responders.
        model: Model identifier forwarded to the responders.
        question: The user's question text.
        document_names: JSON-encoded list of document names to restrict the
            answer to. ``None`` or an empty string is treated as "no documents
            selected". Only read when ``mode`` is not ``CHAT_GRAPH_MODE``.
        session_id: Chat session id; used to load history and echoed back in
            the result.
        mode: Chat mode; ``CHAT_GRAPH_MODE`` is answered by
            ``process_graph_response``, every other mode by
            ``process_chat_response``.
        write_access: Forwarded to ``create_neo4j_chat_message_history``.
        requireGrounding: Forwarded unchanged to ``process_chat_response``.

    Returns:
        dict: The responder's result (or a canned refusal when documents are
        selected but the mode has no document filter), with ``session_id`` set.

    Raises:
        json.JSONDecodeError: If ``document_names`` is non-empty but not valid JSON.
    """
    logging.info(f"Chat Mode: {mode}")

    history = create_neo4j_chat_message_history(graph, session_id, write_access)
    messages = history.messages
    # Debug-level logging replaces the former print() calls so the chat
    # history is not dumped to stdout on every request.
    logging.debug("message history: %s", messages)

    user_question = HumanMessage(content=question)
    messages.append(user_question)

    # The extracted tool calls are currently only logged, not acted upon.
    tool_calls = extract_tool_calls(model, user_question)
    logging.info(tool_calls)

    if mode == CHAT_GRAPH_MODE:
        result = process_graph_response(model, graph, question, messages, history)
    else:
        chat_mode_settings = get_chat_mode_settings(mode=mode)
        # The form field is optional: a missing/empty value means "no filter"
        # instead of crashing in json.loads(None).
        document_names = list(map(str.strip, json.loads(document_names))) if document_names else []
        logging.info(f"chat_mode_settings['document_filter'] = {chat_mode_settings['document_filter']}")
        if document_names and not chat_mode_settings["document_filter"]:
            # This mode cannot filter by document, so ask the user to clear the selection.
            result = {
                "session_id": "",
                "message": "Please deselect all documents in the table before using this chat mode",
                "info": {
                    "sources": [],
                    "model": "",
                    "nodedetails": [],
                    "total_tokens": 0,
                    "response_time": 0,
                    "mode": chat_mode_settings["mode"],
                    "entities": [],
                    "metric_details": [],
                },
                "user": "chatbot"
            }
        else:
            result = process_chat_response(
                messages,
                history,
                question,
                model,
                graph,
                document_names,
                chat_mode_settings,
                extract_tools=False,
                filter_properties={},
                requireGrounding=requireGrounding,
            )

    result["session_id"] = session_id

    return result
25 changes: 25 additions & 0 deletions terraform/environments/metrix-demo-dev/.terraform.lock.hcl

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

85 changes: 85 additions & 0 deletions terraform/environments/metrix-demo-dev/main.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Terraform settings for this environment: CLI version pin, state backend and
# provider requirements.
terraform {
# Exact pin: only Terraform CLI 1.10.5 satisfies this constraint.
required_version = "1.10.5"

backend "s3" {
# Intentionally empty: the S3 backend settings are supplied at init time via
# --backend-config=metrix-demo-dev.backend.conf
}

required_providers {
aws = {
source = "hashicorp/aws"
# Pessimistic constraint: accepts 5.84.x patch releases only.
version = "~> 5.84.0"
}
}
}

# AWS provider for this environment.
# The region is taken from var.aws_region, the same value that is passed to the
# vpc and ecs modules, so the provider and the modules can never disagree.
# NOTE(review): this was previously hard-coded to "us-east-2"; confirm the
# environment's tfvars set aws_region to that value before applying.
provider "aws" {
  region = var.aws_region
}

# Values shared by every module call in this file.
locals {
# Project name passed as project_name to each module.
project_name = "graphbuilder"
}

# ECR module. Its outputs (repository URL, ARN and name for both backend and
# frontend) are consumed by the ecs module in this file.
module "my_ecr_repo" {
source = "../../modules/ecr" # Relative path to the shared ECR module
project_name = local.project_name
environment = var.environment
}

# VPC module. Its vpc_id, public_subnet and public_subnet_2 outputs are
# referenced by the sg, ecs and alb modules in this file.
module "vpc" {
source = "../../modules/vpc"
project_name = local.project_name
# Address range handed to the module for the VPC.
vpc_cidr_block = "10.0.0.0/16"
environment = var.environment
region = var.aws_region
}


# Security-group module, created inside the VPC above. Its `sg` output is
# attached to the EC2 instance through the ecs module.
module "sg" {
source = "../../modules/sg"
project_name = local.project_name
vpc_id = module.vpc.vpc_id
# Note: this module names its input `env`, not `environment` like the others.
env = var.environment
}

# ECS module: receives the EC2 instance settings (type, AMI, key pair, user
# data, subnet, security group) and the backend/frontend ECR repository
# details. Its ec2_instance_id output is registered with the ALB target group.
module "ecs" {
source = "../../modules/ecs"
project_name = local.project_name
environment = var.environment
aws_region = var.aws_region
instance_type = "t4g.large"
vpc_security_group_ids = [module.sg.sg.id]
vpc_id = module.vpc.vpc_id
# Only the first public subnet is used for the instance.
subnet_id = module.vpc.public_subnet.id
# AMI IDs are region-specific; this one must exist in the deployment region.
ami_id = "ami-016032b20e02dbcad"
# Name of the key pair passed to the module; presumably it must already exist
# in the account — TODO confirm.
ec2_keypair_name = "metrix-demo-dev-ec2-keypair"
# NOTE(review): the user-data file has a .ps1 extension; confirm the script
# type suits the operating system of the AMI above.
user_data_path = "./ec2-user-data.ps1"
backend_ecr_url = module.my_ecr_repo.backend_ecr_repository_url
backend_ecr_arn = module.my_ecr_repo.backend_repository_arn
backend_repository_name = module.my_ecr_repo.backend_repository_name
frontend_ecr_url = module.my_ecr_repo.frontend_ecr_repository_url
frontend_ecr_arn = module.my_ecr_repo.frontend_repository_arn
frontend_repository_name = module.my_ecr_repo.frontend_repository_name
}

# Application Load Balancer module: fronts the EC2 instance from the ecs
# module and health-checks it on the same port it forwards to (8000).
module "alb" {
source = "../../modules/alb"
# Domain used by the module to look up an ACM certificate (wildcard).
acm_certificate_domain = "*.futuretalk.ca"
environment = var.environment
project_name = local.project_name
vpc_id = module.vpc.vpc_id
# The load balancer is placed in both public subnets of the VPC module.
subnet_ids = [module.vpc.public_subnet.id, module.vpc.public_subnet_2.id]
target_group_port = 8000
ec2_instance_id = module.ecs.ec2_instance_id
health_check_path = "/health"
health_check_port = 8000
}

# Exposes the ID of the VPC created by the vpc module.
output "vpc_id" {
value = module.vpc.vpc_id
}

# Exposes the certificate ARN reported by the alb module.
output "certificate_arn" {
value = module.alb.certificate_arn
}
9 changes: 9 additions & 0 deletions terraform/environments/metrix-demo-dev/variables.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Required input (no default). Passed to the ecr, vpc, sg, ecs and alb modules
# in main.tf.
variable "environment" {
type = string
description = "environment name to deploy to (dev, prod, developer_name etc)"
}

# Required input (no default). Passed to the vpc and ecs modules in main.tf.
variable "aws_region" {
type = string
description = "aws region to deploy to"
}
137 changes: 137 additions & 0 deletions terraform/modules/alb/alb.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Application Load Balancer
# Internet-facing (internal = false) ALB placed in var.subnet_ids and guarded
# by the ALB security group defined in this file.
resource "aws_lb" "main" {
name = "${var.project_name}-${var.environment}-alb"
internal = false
load_balancer_type = "application"
security_groups = [aws_security_group.alb.id]
subnets = var.subnet_ids

# Deletion protection is off, so the load balancer can be destroyed by Terraform.
enable_deletion_protection = false

tags = {
Name = "${var.project_name}-${var.environment}-alb"
Environment = var.environment
Project = var.project_name
}
}

# Security Group for ALB
# Allows inbound TCP 443 and 80 from any IPv4 address and all outbound traffic.
resource "aws_security_group" "alb" {
name = "${var.project_name}-${var.environment}-alb-sg"
description = "Security group for Application Load Balancer"
vpc_id = var.vpc_id

ingress {
description = "HTTPS from Internet"
from_port = 443
to_port = 443
protocol = "tcp"
cidr_blocks = ["0.0.0.0/0"]
}

# Port 80 stays open so the HTTP listener can answer (redirect or forward).
ingress {
description = "HTTP from Internet"
from_port = 80
to_port = 80
protocol = "tcp"
cidr_blocks = ["0.0.0.0/0"]
}

# protocol "-1" with ports 0-0 means every protocol and port: unrestricted egress.
egress {
from_port = 0
to_port = 0
protocol = "-1"
cidr_blocks = ["0.0.0.0/0"]
}

tags = {
Name = "${var.project_name}-${var.environment}-alb-sg"
Environment = var.environment
Project = var.project_name
}
}

# Target Group
# Plain-HTTP target group on var.target_group_port; both listeners in this
# file forward to it.
resource "aws_lb_target_group" "main" {
name = "${var.project_name}-${var.environment}-tg"
port = var.target_group_port
protocol = "HTTP"
vpc_id = var.vpc_id

# HTTP health check against var.health_check_path on var.health_check_port;
# only a 200 response is accepted (matcher), and both thresholds are 2.
health_check {
enabled = true
healthy_threshold = 2
interval = 30
matcher = "200"
path = var.health_check_path
port = var.health_check_port
protocol = "HTTP"
timeout = 5
unhealthy_threshold = 2
}

tags = {
Name = "${var.project_name}-${var.environment}-tg"
Environment = var.environment
Project = var.project_name
}
}

# Target Group Attachment
# Registers the single EC2 instance (var.ec2_instance_id) with the target
# group on the same port the target group uses.
resource "aws_lb_target_group_attachment" "main" {
target_group_arn = aws_lb_target_group.main.arn
target_id = var.ec2_instance_id
port = var.target_group_port
}

# Data source for ACM certificate
# Looked up only when var.acm_certificate_domain is set (count is 0 otherwise).
# Restricted to certificates in ISSUED status; most_recent picks one when
# several certificates match the domain.
data "aws_acm_certificate" "main" {
count = var.acm_certificate_domain != null ? 1 : 0
domain = var.acm_certificate_domain
statuses = ["ISSUED"]
most_recent = true
}

# HTTPS Listener (only created when a certificate ARN or an ACM domain is supplied)
# Terminates TLS on port 443 and forwards to the target group.
resource "aws_lb_listener" "https" {
  count             = (var.certificate_arn != null || var.acm_certificate_domain != null) ? 1 : 0
  load_balancer_arn = aws_lb.main.arn
  port              = "443"
  protocol          = "HTTPS"

  # TLS 1.2/1.3-only policy. The previous "ELBSecurityPolicy-2016-08" still
  # negotiated TLS 1.0 and 1.1, which are deprecated.
  ssl_policy = "ELBSecurityPolicy-TLS13-1-2-2021-06"

  # An explicit ARN wins; otherwise use the certificate found by the ACM lookup.
  # The count condition above guarantees at least one of the two is available.
  certificate_arn = var.certificate_arn != null ? var.certificate_arn : data.aws_acm_certificate.main[0].arn

  default_action {
    type             = "forward"
    target_group_arn = aws_lb_target_group.main.arn
  }
}

# HTTP Listener (redirects to HTTPS if HTTPS listener exists, otherwise forwards to target group)
# The action type and the two dynamic blocks use complementary conditions, so
# exactly one of `redirect` / `forward` is rendered.
resource "aws_lb_listener" "http" {
load_balancer_arn = aws_lb.main.arn
port = "80"
protocol = "HTTP"

default_action {
# Same condition as the HTTPS listener's count.
type = (var.certificate_arn != null || var.acm_certificate_domain != null) ? "redirect" : "forward"

# Rendered only when a certificate is configured: permanent (301) redirect to port 443.
dynamic "redirect" {
for_each = (var.certificate_arn != null || var.acm_certificate_domain != null) ? [1] : []
content {
port = "443"
protocol = "HTTPS"
status_code = "HTTP_301"
}
}

# Rendered only when no certificate is configured: forward plain HTTP to the target group.
dynamic "forward" {
for_each = (var.certificate_arn == null && var.acm_certificate_domain == null) ? [1] : []
content {
target_group {
arn = aws_lb_target_group.main.arn
}
}
}
}
}
Loading