Spaces:
Runtime error
Runtime error
| import base64 | |
| import json | |
| import mimetypes | |
| # import mimetypes | |
| import os | |
| import sys | |
| from io import BytesIO | |
| from typing import Dict, Tuple, Union | |
| import banana_dev as banana | |
| import geopy.distance | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly | |
| import plotly.express as px | |
| # import requests | |
| from dotenv import load_dotenv | |
| from smart_open import open as smartopen | |
| sys.path.append("..") | |
| from gantry_callback.gantry_util import GantryImageToTextLogger # noqa: E402 | |
| from gantry_callback.s3_util import ( # noqa: E402 | |
| add_access_policy, | |
| enable_bucket_versioning, | |
| get_or_create_bucket, | |
| get_uri_of, | |
| make_key, | |
| make_unique_bucket_name, | |
| ) | |
| from gantry_callback.string_img_util import read_b64_string # noqa: E402 | |
# Load deployment secrets/endpoints from a local .env file (if present) into os.environ.
load_dotenv()
# Legacy HTTP endpoint; only referenced by the removed `requests`-based code paths.
URL = os.getenv("ENDPOINT")
GANTRY_APP_NAME = os.getenv("GANTRY_APP_NAME")
GANTRY_KEY = os.getenv("GANTRY_API_KEY")
MAPBOX_TOKEN = os.getenv("MAPBOX_TOKEN")
BANANA_API_KEY = os.getenv("BANANA_API_KEY")
BANANA_MODEL_KEY = os.getenv("BANANA_MODEL_KEY")

# Example inputs for the gr.Examples widgets ("images", "video_urls", ... keys).
# Use a context manager so the file handle is closed immediately, and read as
# UTF-8 (the JSON interchange encoding) instead of the platform default.
with open("examples.json", encoding="utf-8") as _examples_file:
    examples = json.load(_examples_file)
def compute_distance(map_data: Dict[str, Dict[str, Union[str, float, None]]]):
    """Measure how far the coarse and fine predictions are from the hierarchy one.

    Args:
        map_data: Mapping with "hierarchy", "coarse" and "fine" entries, each
            holding at least "latitude" and "longitude" values.

    Returns:
        A ``(hierarchy_to_coarse, hierarchy_to_fine)`` tuple of geodesic
        distances in miles.
    """

    def _lat_long(subdivision):
        entry = map_data[subdivision]
        return entry["latitude"], entry["longitude"]

    anchor, coarse_point, fine_point = (
        _lat_long(name) for name in ("hierarchy", "coarse", "fine")
    )
    miles_to_coarse = geopy.distance.geodesic(anchor, coarse_point).miles
    miles_to_fine = geopy.distance.geodesic(anchor, fine_point).miles
    return miles_to_coarse, miles_to_fine
def get_plotly_graph(
    map_data: Dict[str, Dict[str, Union[str, float, None]]]
) -> plotly.graph_objects.Figure:
    """Plot the predicted locations on a dark Mapbox scatter map.

    The "hierarchy" prediction is always drawn (with a larger marker). The
    "coarse" prediction is added only when it lies more than 5000 miles from
    the hierarchy one, and "fine" only when it lies more than 30 miles away.

    Args:
        map_data: Mapping of subdivision name to a dict with "latitude",
            "longitude" and "location" entries.

    Returns:
        The configured plotly Figure.
    """
    miles_to_coarse, miles_to_fine = compute_distance(map_data)

    visible = {"hierarchy"}
    if miles_to_coarse > 5000:
        visible.add("coarse")
    if miles_to_fine > 30:
        visible.add("fine")

    marker_sizes = {"hierarchy": 3, "fine": 1, "coarse": 1}
    rows = [
        [
            name,
            float(entry["latitude"]),
            float(entry["longitude"]),
            entry["location"],
            marker_sizes[name],
        ]
        for name, entry in map_data.items()
        if name in visible
    ]
    frame = pd.DataFrame(
        rows,
        columns=["subdivision", "latitude", "longitude", "location", "size"],
    )

    px.set_mapbox_access_token(MAPBOX_TOKEN)
    figure = px.scatter_mapbox(
        frame,
        lat="latitude",
        lon="longitude",
        hover_name="location",
        hover_data=["latitude", "longitude", "subdivision"],
        color="subdivision",
        color_discrete_map={
            "hierarchy": "fuchsia",
            "coarse": "blue",
            "fine": "yellow",
        },
        zoom=2,
        height=500,
        size="size",
    )
    figure.update_layout(
        mapbox_style="dark",
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )
    return figure
def gradio_error():
    """Abort the current Gradio event with a user-facing message.

    Raises:
        gr.Error: Always; Gradio shows the message in the UI.
    """
    failure_message = "Unable to detect the location!"
    raise gr.Error(failure_message)
def get_outputs(
    data: Dict[str, Dict[str, Union[str, float, None]]]
) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Turn a model prediction into the three values shown in the UI.

    Args:
        data: Model output with a "hierarchy" entry holding "location",
            "latitude" and "longitude", plus the entries get_plotly_graph
            needs. May be None when the model produced nothing.

    Returns:
        ``(location, "lat,long", figure)`` for the text boxes and the plot.

    Raises:
        gr.Error: Via gradio_error, when ``data`` or the hierarchy location
            is None.
    """
    if data is None:
        gradio_error()

    best_guess = data["hierarchy"]
    location = best_guess["location"]
    latitude = best_guess["latitude"]
    longitude = best_guess["longitude"]

    if location is None:
        gradio_error()

    coordinates = f"{latitude},{longitude}"
    figure = get_plotly_graph(map_data=data)
    return location, coordinates, figure
def image_gradio(img_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Geolocate an uploaded image using the Banana-hosted model.

    Args:
        img_file: Path to the image on disk (gr.Image uses type="filepath").

    Returns:
        The ``(location, coordinates, figure)`` triple from get_outputs.
    """
    with open(img_file, "rb") as image_file:
        image_bytes = image_file.read()

    # The model receives the image as a base64 text payload plus its file name.
    payload = {
        "image": base64.b64encode(image_bytes).decode("utf-8"),
        "filename": os.path.basename(img_file),
    }
    data = banana.run(BANANA_API_KEY, BANANA_MODEL_KEY, payload)["modelOutputs"][0]
    return get_outputs(data=data)
def _upload_video_to_s3(video_b64_string):
    """Store a base64-encoded video in the app's S3 bucket.

    Each call makes sure the bucket exists and that versioning and the access
    policy are enabled on it. It then decodes the payload and writes it under
    a key derived from the video bytes.

    Args:
        video_b64_string: Base64 video; presumably a ``data:<mime>;base64,...``
            URI, given how video_gradio builds it — verify against
            read_b64_string.

    Returns:
        The S3 URI the video was written to.
    """
    bucket_name = make_unique_bucket_name(prefix="geolocator-app", seed="420")
    bucket = get_or_create_bucket(bucket_name)
    enable_bucket_versioning(bucket)
    add_access_policy(bucket)

    data_type, decoded_buffer = read_b64_string(
        video_b64_string, return_data_type=True
    )
    payload = decoded_buffer.read()
    destination = get_uri_of(bucket, make_key(payload, filetype=data_type))

    with smartopen(destination, "wb") as sink:
        sink.write(payload)
    return destination
def video_gradio(video_file: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Geolocate an uploaded video: stage it in S3, then run the Banana model.

    Args:
        video_file: Path to the video file on disk.

    Returns:
        The ``(location, coordinates, figure)`` triple from get_outputs.
    """
    # Use a distinct name for the handle. Rebinding `video_file` to the file
    # object broke the later guess_type()/basename() calls, which need the path.
    with open(video_file, "rb") as video_handle:
        video_b64_string = base64.b64encode(video_handle.read()).decode("utf8")

    # guess_type returns None for unknown extensions; fall back to a generic
    # binary type rather than embedding the literal "None" in the data URI.
    video_mime = mimetypes.guess_type(video_file)[0] or "application/octet-stream"
    s3_uri = _upload_video_to_s3(f"data:{video_mime};base64," + video_b64_string)

    data = banana.run(
        BANANA_API_KEY,
        BANANA_MODEL_KEY,
        {
            "video": s3_uri,
            "filename": os.path.basename(video_file),
        },
    )["modelOutputs"][0]
    return get_outputs(data=data)
def url_gradio(url: str) -> Tuple[str, str, plotly.graph_objects.Figure]:
    """Geolocate a YouTube video given its link, using the Banana-hosted model.

    Args:
        url: The video URL typed into the "Link" textbox.

    Returns:
        The ``(location, coordinates, figure)`` triple from get_outputs.
    """
    response = banana.run(BANANA_API_KEY, BANANA_MODEL_KEY, {"url": url})
    prediction = response["modelOutputs"][0]
    return get_outputs(data=prediction)
# Gradio UI: an "Image" tab and a "YouTube Link" tab. Each tab has an input, a
# Location/Coordinates readout, a map plot and a "Go locate!" button wired to
# the matching *_gradio handler. The image tab also has a Gantry flag button.
with gr.Blocks() as demo:
    gr.Markdown("# GeoLocator")
    # NOTE(review): the "π" characters in this string look like mis-encoded
    # emoji from the original source; confirm the intended characters.
    gr.Markdown(
        "### An app that guesses the location of an image π or a YouTube video link π."
    )
    with gr.Tab("Image"):
        with gr.Row():
            # type="filepath" gives image_gradio a path on disk, which it opens itself.
            img_input = gr.Image(type="filepath", label="Image")
            with gr.Column():
                img_text_output = gr.Textbox(label="Location")
                img_coordinates = gr.Textbox(label="Coordinates")
                img_plot = gr.Plot()
        img_text_button = gr.Button("Go locate!")
        with gr.Row():
            # Flag button (its click handler, registered below, logs to Gantry)
            img_flag_button = gr.Button("Flag this output")
        gr.Examples(examples["images"], inputs=[img_input])
    # Video tab is disabled; video_gradio remains defined but is not wired up.
    # with gr.Tab("Video"):
    #     with gr.Row():
    #         video_input = gr.Video(type="filepath", label="Video")
    #         with gr.Column():
    #             video_text_output = gr.Textbox(label="Location")
    #             video_coordinates = gr.Textbox(label="Coordinates")
    #             video_plot = gr.Plot()
    #     video_text_button = gr.Button("Go locate!")
    #     gr.Examples(examples["videos"], inputs=[video_input])
    with gr.Tab("YouTube Link"):
        with gr.Row():
            url_input = gr.Textbox(label="Link")
            with gr.Column():
                url_text_output = gr.Textbox(label="Location")
                url_coordinates = gr.Textbox(label="Coordinates")
                url_plot = gr.Plot()
        url_text_button = gr.Button("Go locate!")
        gr.Examples(examples["video_urls"], inputs=[url_input])

    # Gantry flagging for image #
    callback = GantryImageToTextLogger(application=GANTRY_APP_NAME, api_key=GANTRY_KEY)
    # NOTE(review): setup() registers two components, but the click handler
    # below passes three inputs (img_coordinates as well). Confirm that
    # GantryImageToTextLogger.flag expects this.
    callback.setup(
        components=[img_input, img_text_output],
        flagging_dir=make_unique_bucket_name(prefix=GANTRY_APP_NAME, seed="420"),
    )
    # preprocess=False: the lambda gets the components' unprocessed values,
    # which are forwarded to callback.flag as a single tuple.
    img_flag_button.click(
        fn=lambda *args: callback.flag(args),
        inputs=[img_input, img_text_output, img_coordinates],
        outputs=None,
        preprocess=False,
    )
    ###################
    img_text_button.click(
        image_gradio,
        inputs=img_input,
        outputs=[img_text_output, img_coordinates, img_plot],
    )
    # video_text_button.click(
    #     video_gradio,
    #     inputs=video_input,
    #     outputs=[video_text_output, video_coordinates, video_plot],
    # )
    url_text_button.click(
        url_gradio,
        inputs=url_input,
        outputs=[url_text_output, url_coordinates, url_plot],
    )
    gr.Markdown(
        "Check out the [GitHub repository](https://github.com/samhita-alla/geolocator) that this demo is based off of."
    )
    gr.Markdown(
        "#### To understand what subdivision means, refer to the [Geolocation paper](https://openaccess.thecvf.com/content_ECCV_2018/papers/Eric_Muller-Budack_Geolocation_Estimation_of_ECCV_2018_paper.pdf)."
    )
    gr.Markdown(
        "#### TL;DR Fine and Coarse are spatial resolutions and Hierarchy generates predictions at fine scale but incorporates knowledge from coarse and middle partitionings."
    )

# Start the Gradio server for the UI defined above (runs at import time).
demo.launch()