Spaces:
Runtime error
Runtime error
| """ | |
| Class to handle flagging in Gradio to Gantry. | |
| Originally written by the FSDL educators at https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/main/app_gradio/flagging.py | |
| and adjusted for the geolocator project. | |
| """ | |
| import os | |
| from typing import List, Optional, Union | |
| import gantry | |
| import gradio as gr | |
| from gradio.components import Component | |
| from smart_open import open | |
| from .s3_util import ( | |
| add_access_policy, | |
| enable_bucket_versioning, | |
| get_or_create_bucket, | |
| get_uri_of, | |
| make_key, | |
| ) | |
| from .string_img_util import read_b64_string | |
class GantryImageToTextLogger(gr.FlaggingCallback):
    """
    Gradio flagging callback that forwards flagged image-to-text records to
    Gantry, storing the flagged image itself in an S3 bucket.
    """

    def __init__(
        self,
        application: str,
        version: Union[int, str, None] = None,
        api_key: Optional[str] = None,
    ):
        """Configure the logger and initialise the Gantry client.

        Flagged images are written to an S3 bucket whose name is taken from
        the flagging_dir handed to ``setup``; the record logged to Gantry
        carries the S3 URI of the image together with the two text outputs
        and the user's feedback.

        See https://gradio.app/docs/#flagging for how Gradio handles flagged
        data and https://docs.gantry.io for logging data to Gantry.

        Parameters
        ----------
        application
            Name of the Gantry application that flagged records are logged to.
        version
            Version passed to Gantry with every record. ``None`` leaves the
            choice of version to Gantry.
        api_key
            Gantry API key forwarded to ``gantry.init``. May be ``None``, in
            which case Gantry is initialised without an explicit key
            (presumably falling back to the GANTRY_API_KEY environment
            variable -- confirm against the Gantry docs).
        """
        self.application = application
        self.version = version

        gantry.init(api_key=api_key)

    def setup(self, components: List[Component], flagging_dir: str):
        """Prepare the S3 bucket and record which component holds what.

        Parameters
        ----------
        components
            The Gradio components whose values will be passed to ``flag``.
        flagging_dir
            Used as the name of the S3 bucket to create or attach to.
        """
        # Number of records flagged since setup; returned by ``flag``.
        self._counter = 0

        self.bucket = get_or_create_bucket(flagging_dir)
        for configure in (enable_bucket_versioning, add_access_policy):
            configure(self.bucket)

        positions = self._find_image_video_and_text_components(components)
        (
            self.image_component_idx,
            self.text_component_idx,
            self.text_component2_idx,
        ) = positions

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
        """Upload the flagged image to S3 and log the record to Gantry.

        Parameters
        ----------
        flag_data
            Values of the flagged components, indexed by the positions
            recorded in ``setup``.
        flag_option
            The flag chosen by the user; logged under the "flag" feedback key.
        flag_index
            Accepted for interface compatibility; not used here.
        username
            When not ``None``, logged under the "user" feedback key.

        Returns
        -------
        int
            How many records have been flagged since ``setup`` was called.
        """
        image_payload, location_text, coordinates_text = (
            flag_data[self.image_component_idx],
            flag_data[self.text_component_idx],
            flag_data[self.text_component2_idx],
        )

        # The "user" key is only present when a username was supplied.
        feedback = (
            {"flag": flag_option}
            if username is None
            else {"flag": flag_option, "user": username}
        )

        # The image arrives as a base64 string; decode it to a readable
        # buffer and learn its data type so the S3 key can reflect it.
        data_type, image_buffer = read_b64_string(image_payload, return_data_type=True)
        stored_at = self._to_s3(image_buffer.read(), filetype=data_type)

        self._to_gantry(
            input_image_url=stored_at,
            pred_location=location_text,
            pred_coordinates=coordinates_text,
            feedback=feedback,
        )

        self._counter += 1
        return self._counter

    def _to_gantry(self, input_image_url, pred_location, pred_coordinates, feedback):
        """Log a single record (image URL in, two predictions out) to Gantry."""
        gantry.log_record(
            self.application,
            self.version,
            inputs={"image": input_image_url},
            outputs={"location": pred_location, "coordinates": pred_coordinates},
            feedback=feedback,
        )

    def _to_s3(self, image_bytes, key=None, filetype=None):
        """Write ``image_bytes`` into the bucket and return the object's URI.

        When ``key`` is ``None`` a key is generated by ``make_key`` from the
        bytes and ``filetype``.
        """
        object_key = make_key(image_bytes, filetype=filetype) if key is None else key
        destination = get_uri_of(self.bucket, object_key)

        with open(destination, "wb") as sink:
            sink.write(image_bytes)

        return destination

    def _find_image_video_and_text_components(self, components: List[Component]):
        """
        Return the fixed positions of the image and the two text components.

        The positions are hard-coded (image first, then the two text
        components); ``components`` is not inspected.
        """
        return (0, 1, 2)
def get_api_key() -> Optional[str]:
    """Return the value of the GANTRY_API_KEY environment variable.

    Returns ``None`` when the variable is not set.
    """
    return os.getenv("GANTRY_API_KEY")