from typing import Type, Optional
import logging

from pydantic import BaseModel, Field
from elasticsearch import Elasticsearch
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.tools.base import BaseTool

from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA

logging.basicConfig(level="INFO")
logger = logging.getLogger("elasticsearch_playground")

# Module-level client shared by the tool below; credentials come from the
# project's SEMANTIC_ELASTIC_QA connection settings.
es = Elasticsearch(
    cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
    api_key=SEMANTIC_ELASTIC_QA.api_key,
    verify_certs=True,
    request_timeout=60 * 3,  # three minutes, expressed in seconds
)


class IndexShowDataInput(BaseModel):
    """Input for the index show data tool."""

    index_name: str = Field(
        ...,
        description="The name of the index for which the data is to be retrieved",
    )


class IndexShowDataTool(BaseTool):
    """Tool for getting a list of entries from an ElasticSearch index.

    Helpful to figure out what data is available in a given index.
    """

    name: str = "elastic_index_show_data"
    description: str = (
        "Input is an index name, output is a JSON based string with an extract of the data of the index"
    )
    # Schema describing the arguments accepted by `_run`.
    args_schema: Optional[Type[BaseModel]] = IndexShowDataInput

    def _run(
        self,
        index_name: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return a string rendering of a sample of documents from an index.

        Runs a ``match_all`` search against ``index_name`` requesting the
        first 20 results and returns the ``hits`` section of the response
        converted with ``str()``.

        Args:
            index_name: Name of the Elasticsearch index to sample.
            run_manager: Optional LangChain callback manager; not used by
                this method.

        Returns:
            ``str(res["hits"])`` on success, or an empty string if the
            search (or reading the response) raises any exception. The
            failure is logged with its traceback rather than propagated.
        """
        try:
            res = es.search(
                index=index_name,
                from_=0,
                size=20,
                query={"match_all": {}},
            )
            # NOTE(review): str() presumably yields a Python-style repr rather
            # than strict JSON, while the tool description advertises a JSON
            # based string — confirm whether consumers need json.dumps here.
            return str(res["hits"])
        except Exception:
            # Best-effort tool: logger.exception records the message and the
            # full traceback, so no separate print of the error is needed.
            logger.exception("Could not fetch index data for %s", index_name)
            return ""