1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Unreleased]

### Added
- Convert JSON to CSV for search index tool result ([#140](https://github.com/opensearch-project/opensearch-mcp-server-py/pull/140))

### Fixed
- Fix AWS auth issues for cat based tools, pin OpenSearchPy to 2.18.0 ([#135](https://github.com/opensearch-project/opensearch-mcp-server-py/pull/135))
178 changes: 178 additions & 0 deletions src/opensearch/helper.py
@@ -3,6 +3,8 @@

import json
import logging
import csv
import io
from semver import Version
from tools.tool_params import *

@@ -277,6 +279,182 @@ async def get_nodes_info(args: GetNodesArgs) -> json:
return response


def convert_search_results_to_csv(search_results: dict) -> str:
"""Convert OpenSearch search results to CSV format.

Args:
search_results: The JSON response from search_index function

Returns:
str: CSV formatted string of the search results
"""
if not search_results:
return "No search results to convert"

# Handle aggregations-only queries
if 'aggregations' in search_results and ('hits' not in search_results or not search_results['hits']['hits']):
return _convert_aggregations_to_csv(search_results['aggregations'])

# Handle regular search results
if 'hits' not in search_results:
return "No search results to convert"

hits = search_results['hits']['hits']
if not hits:
return "No documents found in search results"

Review comment on lines +302 to +305: what about aggregations queries?
# Extract all unique field names from all documents (flattened)
all_fields = set()
for hit in hits:
if '_source' in hit:
_flatten_fields(hit['_source'], all_fields)
# Also include metadata fields
all_fields.update(['_index', '_id', '_score'])

# Convert to sorted list for consistent column order
fieldnames = sorted(list(all_fields))

# Create CSV in memory
output = io.StringIO()
writer = csv.DictWriter(output, fieldnames=fieldnames)
writer.writeheader()

# Write each document as a row
for hit in hits:
row = {}
# Add metadata fields
row['_index'] = hit.get('_index', '')
row['_id'] = hit.get('_id', '')
row['_score'] = hit.get('_score', '')

# Add source fields (flattened)
if '_source' in hit:
_flatten_object(hit['_source'], row)

writer.writerow(row)

return output.getvalue()
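
As a quick illustration (the index name and document below are invented), a single-hit response would convert like this:

# Hypothetical usage sketch for convert_search_results_to_csv.
sample_response = {
    'hits': {
        'hits': [
            {
                '_index': 'products',
                '_id': '1',
                '_score': 1.0,
                '_source': {'name': 'widget', 'price': 9.99},
            }
        ]
    }
}
print(convert_search_results_to_csv(sample_response))
# Expected output (columns sorted alphabetically):
# _id,_index,_score,name,price
# 1,products,1.0,widget,9.99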


def _convert_aggregations_to_csv(aggregations: dict) -> str:
"""Convert OpenSearch aggregations to CSV format.

Args:
aggregations: The aggregations section from search results

Returns:
str: CSV formatted string of the aggregations
"""
    rows = []
    root_row = {}
    _flatten_aggregations(aggregations, root_row, rows)
    # Metrics-only aggregations fill root_row without appending a row of their own.
    if not rows and root_row:
        rows.append(root_row)

if not rows:
return "No aggregation data to convert"

# Get all unique field names
all_fields = set()
for row in rows:
all_fields.update(row.keys())

fieldnames = sorted(list(all_fields))

# Create CSV in memory
output = io.StringIO()
writer = csv.DictWriter(output, fieldnames=fieldnames)
writer.writeheader()

for row in rows:
writer.writerow(row)

return output.getvalue()
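
A hedged example, assuming the metrics-only handling above (stat values invented): a stats aggregation flattens to a single row.

# Hypothetical sketch for _convert_aggregations_to_csv.
print(_convert_aggregations_to_csv(
    {'price_stats': {'count': 2, 'min': 1.0, 'max': 3.0, 'avg': 2.0, 'sum': 4.0}}
))
# Expected output:
# price_stats_avg,price_stats_count,price_stats_max,price_stats_min,price_stats_sum
# 2.0,2,3.0,1.0,4.0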


def _flatten_fields(obj: dict, fields: set, prefix: str = '') -> None:
"""Extract all field names from nested objects.

Args:
obj: Object to extract field names from
fields: Set to add field names to
prefix: Current field prefix
"""
for key, value in obj.items():
field_name = f'{prefix}{key}' if prefix else key
if isinstance(value, dict):
_flatten_fields(value, fields, f'{field_name}.')
elif isinstance(value, list) and value and isinstance(value[0], dict):
# For arrays of objects, flatten the first object to get field structure
_flatten_fields(value[0], fields, f'{field_name}.')
fields.add(field_name) # Also keep the array field itself
else:
fields.add(field_name)
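
For instance (the document shape here is hypothetical), nested objects and arrays of objects yield dotted names:

# Hypothetical sketch for _flatten_fields.
fields = set()
_flatten_fields({'user': {'name': 'kim'}, 'tags': ['a', 'b'], 'items': [{'sku': 'x1'}]}, fields)
print(sorted(fields))
# ['items', 'items.sku', 'tags', 'user.name']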


def _flatten_object(obj: dict, row: dict, prefix: str = '') -> None:
"""Flatten nested objects into separate columns.

Args:
obj: Object to flatten
row: Row dictionary to add flattened fields to
prefix: Current field prefix
"""
for key, value in obj.items():
field_name = f'{prefix}{key}' if prefix else key
if isinstance(value, dict):
_flatten_object(value, row, f'{field_name}.')
elif isinstance(value, list):
if value and isinstance(value[0], dict):
# For arrays of objects, flatten first object and keep array as JSON
_flatten_object(value[0], row, f'{field_name}.')
row[field_name] = json.dumps(value)
else:
# For simple arrays, convert to JSON
row[field_name] = json.dumps(value)
else:
row[field_name] = str(value) if value is not None else ''
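
A small sketch (sample document invented) of how one document becomes a row:

# Hypothetical sketch for _flatten_object.
row = {}
_flatten_object({'user': {'name': 'kim'}, 'tags': ['a', 'b']}, row)
print(row)
# {'user.name': 'kim', 'tags': '["a", "b"]'}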


def _flatten_aggregations(aggs: dict, current_row: dict, rows: list, prefix: str = '') -> None:
"""Recursively flatten aggregations into CSV rows.

Args:
aggs: Current aggregation level
current_row: Current row being built
rows: List to append completed rows
prefix: Current field prefix
"""
for agg_name, agg_data in aggs.items():
if isinstance(agg_data, dict):
# Handle bucket aggregations
if 'buckets' in agg_data:
for bucket in agg_data['buckets']:
new_row = current_row.copy()
bucket_key = f'{prefix}{agg_name}_key' if prefix else f'{agg_name}_key'
new_row[bucket_key] = str(bucket.get('key', ''))

if 'doc_count' in bucket:
count_key = f'{prefix}{agg_name}_doc_count' if prefix else f'{agg_name}_doc_count'
new_row[count_key] = bucket['doc_count']

                    # Handle nested aggregations
                    nested_aggs = {k: v for k, v in bucket.items() if k not in ['key', 'doc_count']}
                    if nested_aggs:
                        rows_before = len(rows)
                        _flatten_aggregations(nested_aggs, new_row, rows, f'{prefix}{agg_name}_')
                        # Metric-only sub-aggregations fill new_row without appending it,
                        # so append the bucket row here.
                        if len(rows) == rows_before:
                            rows.append(new_row)
                    else:
                        rows.append(new_row)

# Handle metric aggregations
elif 'value' in agg_data:
value_key = f'{prefix}{agg_name}' if prefix else agg_name
current_row[value_key] = agg_data['value']

# Handle stats aggregations
elif any(k in agg_data for k in ['count', 'min', 'max', 'avg', 'sum']):
for stat_name, stat_value in agg_data.items():
if stat_name in ['count', 'min', 'max', 'avg', 'sum']:
stat_key = f'{prefix}{agg_name}_{stat_name}' if prefix else f'{agg_name}_{stat_name}'
current_row[stat_key] = stat_value
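
For example (bucket values invented), a terms aggregation fans out into one row per bucket:

# Hypothetical sketch for _flatten_aggregations.
rows = []
_flatten_aggregations(
    {'by_status': {'buckets': [{'key': 'open', 'doc_count': 3}, {'key': 'closed', 'doc_count': 5}]}},
    {},
    rows,
)
print(rows)
# [{'by_status_key': 'open', 'by_status_doc_count': 3},
#  {'by_status_key': 'closed', 'by_status_doc_count': 5}]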


async def get_opensearch_version(args: baseToolArgs) -> Version:
"""Get the version of OpenSearch cluster.

1 change: 1 addition & 0 deletions src/tools/tool_params.py
@@ -80,6 +80,7 @@ class GetIndexMappingArgs(baseToolArgs):
class SearchIndexArgs(baseToolArgs):
index: str = Field(description='The name of the index to search in')
query: Any = Field(description='The search query in OpenSearch query DSL format')
format: str = Field(default='json', description='Output format: "json" or "csv"')


class GetShardsArgs(baseToolArgs):
26 changes: 18 additions & 8 deletions src/tools/tools.py
@@ -21,6 +21,7 @@
)
from .utils import is_tool_compatible
from opensearch.helper import (
convert_search_results_to_csv,
get_allocation,
get_cluster_state,
get_index,
@@ -109,14 +110,23 @@ async def search_index_tool(args: SearchIndexArgs) -> list[dict]:
try:
await check_tool_compatibility('SearchIndexTool', args)
result = await search_index(args)
formatted_result = json.dumps(result, indent=2)

return [
{
'type': 'text',
'text': f'Search results from {args.index}:\n{formatted_result}',
}
]

if args.format.lower() == 'csv':
csv_result = convert_search_results_to_csv(result)
return [
{
'type': 'text',
'text': f'Search results from {args.index} (CSV format):\n{csv_result}',
}
]
else:
formatted_result = json.dumps(result, indent=2)
return [
{
'type': 'text',
'text': f'Search results from {args.index} (JSON format):\n{formatted_result}',
}
]
except Exception as e:
return [{'type': 'text', 'text': f'Error searching index: {str(e)}'}]
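
A minimal usage sketch, assuming the inherited baseToolArgs fields are optional and a cluster is reachable (index name and query are invented):

# Hypothetical sketch: requesting CSV output from the search tool.
import asyncio

async def demo():
    args = SearchIndexArgs(index='products', query={'match_all': {}}, format='csv')
    for block in await search_index_tool(args):
        print(block['text'])

# asyncio.run(demo())  # uncomment to run against a live cluster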
