# alphagenome_test / streamlit_app.py
# Eitan177's picture
# Create streamlit_app.py
# 8cf26e7 verified
import streamlit as st
import pandas as pd
import altair as alt
from alphagenome.data import gene_annotation, genome, transcript
from alphagenome.models import dna_client
from alphagenome.visualization import plot_components
# --- Page Configuration ---
# Page-level Streamlit settings, gathered in one mapping and applied in a
# single call.
_PAGE_SETTINGS = {
    "page_title": "AlphaGenome Variant Visualizer",
    "page_icon": "🧬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# --- Global Variables & Constants ---
# Locations of the GENCODE gene annotation tables (Feather files) that are read
# with pandas when building transcript extractors: one for hg38, one for mm10.
HG38_GTF_FEATHER = (
    'https://storage.googleapis.com/alphagenome/reference/gencode/'
    'hg38/gencode.v46.annotation.gtf.gz.feather'
)
MM10_GTF_FEATHER = (
    'https://storage.googleapis.com/alphagenome/reference/gencode/'
    'mm10/gencode.vM23.annotation.gtf.gz.feather'
)

# Translates the organism label offered in the sidebar into the matching
# member of the AlphaGenome client's Organism enum.
ORGANISM_MAP = {
    'human': dna_client.Organism.HOMO_SAPIENS,
    'mouse': dna_client.Organism.MUS_MUSCULUS,
}
# --- Caching Functions ---
@st.cache_resource
def get_dna_model():
    """Build the AlphaGenome client from the configured API key.

    The key is looked up under ``ALPHAGENOME_API_KEY`` in Streamlit secrets.
    When the key is missing or empty, an error plus a setup hint are shown and
    ``st.stop()`` is called. When client creation raises, the error is shown
    and ``st.stop()`` is called as well.

    Returns:
        Whatever ``dna_client.create`` returns for the configured key.
    """
    # NOTE(review): st.secrets may itself raise when no secrets file exists at
    # all, in which case the hint below would not be shown -- confirm.
    secret_key = st.secrets.get("ALPHAGENOME_API_KEY")
    key_missing = not secret_key
    if key_missing:
        st.error("AlphaGenome API key not found. Please add it to your Streamlit secrets.")
        st.info("For local testing, create a file at .streamlit/secrets.toml with the content: \n\nALPHAGENOME_API_KEY = \"YOUR_API_KEY_HERE\"")
        st.stop()

    try:
        client = dna_client.create(secret_key)
    except Exception as e:
        # Top-level UI boundary: report any failure to the user and stop.
        st.error(f"Failed to create DNA model client. Please check your API key. Error: {e}")
        st.stop()
    else:
        return client
@st.cache_data
def get_transcript_extractors(organism_str):
    """Load the gene annotation for an organism and build two extractors.

    Args:
        organism_str: A key of ``ORGANISM_MAP`` (``'human'`` or ``'mouse'``).

    Returns:
        A ``(transcript_extractor, longest_transcript_extractor)`` tuple. The
        first is built from the annotation after the protein-coding and
        transcript-support-level (``['1']``) filters; the second from the same
        rows further reduced by ``filter_to_longest_transcript``.

    Raises:
        KeyError: If ``organism_str`` is not a key of ``ORGANISM_MAP``.
    """
    organism = ORGANISM_MAP[organism_str]
    annotation_sources = {
        dna_client.Organism.HOMO_SAPIENS: HG38_GTF_FEATHER,
        dna_client.Organism.MUS_MUSCULUS: MM10_GTF_FEATHER,
    }

    with st.spinner('Loading gene annotation...'):
        if organism in annotation_sources:
            gtf_path = annotation_sources[organism]
        else:
            st.error(f'Unsupported organism: {organism}')
            st.stop()

        annotation = pd.read_feather(gtf_path)
        protein_coding = gene_annotation.filter_protein_coding(annotation)
        gtf_transcript = gene_annotation.filter_transcript_support_level(
            protein_coding, ['1']
        )
        transcript_extractor = transcript.TranscriptExtractor(gtf_transcript)

        longest_only = gene_annotation.filter_to_longest_transcript(gtf_transcript)
        longest_transcript_extractor = transcript.TranscriptExtractor(longest_only)

    return transcript_extractor, longest_transcript_extractor
@st.cache_data
def predict_variant_cached(_dna_model, interval_str, variant_str, organism_str):
    """Run ``predict_variant`` on the model, cached by Streamlit.

    The interval, variant and organism are passed as strings and parsed back
    into their objects here before the model is called.

    The leading underscore on ``_dna_model`` is Streamlit's convention for an
    argument that is left out of the cache key, so cached results are keyed on
    the three string arguments only; the model object is still the one that
    performs the prediction.

    Args:
        _dna_model: Client object exposing ``predict_variant``.
        interval_str: String accepted by ``genome.Interval.from_str``.
        variant_str: String accepted by ``genome.Variant.from_str``.
        organism_str: Member name of ``dna_client.Organism``.

    Returns:
        The value returned by ``_dna_model.predict_variant``.
    """
    request = {
        'interval': genome.Interval.from_str(interval_str),
        'variant': genome.Variant.from_str(variant_str),
        'organism': dna_client.Organism[organism_str],
        # Ask for every output type; the caller decides which ones to draw.
        'requested_outputs': list(dna_client.OutputType),
        'ontology_terms': ['EFO:0001187', 'EFO:0002067', 'EFO:0002784'],
    }
    return _dna_model.predict_variant(**request)
# --- Sidebar UI ---
# Sidebar controls. Widgets are created in the order they should appear, so
# the statement order below is the on-screen order.
st.sidebar.title('Variant Visualizer')
st.sidebar.header('Variant and Plotting Options')

# Variant definition. The organism options match the keys of ORGANISM_MAP;
# the remaining defaults describe chr22:36201698 A>C.
organism_choice = st.sidebar.selectbox('Organism', ['human', 'mouse'])
variant_chromosome = st.sidebar.text_input('Chromosome', 'chr22')
variant_position = st.sidebar.number_input('Position', value=36201698, format="%d")
variant_reference_bases = st.sidebar.text_input('Reference Bases', 'A')
variant_alternate_bases = st.sidebar.text_input('Alternate Bases', 'C')

# index=4 pre-selects "1MB". The chosen label is later interpolated into a
# 'SEQUENCE_LENGTH_<label>' key for dna_client.SUPPORTED_SEQUENCE_LENGTHS.
sequence_length_choice = st.sidebar.selectbox('Sequence Length', ["2KB", "16KB", "100KB", "500KB", "1MB"], index=4)

st.sidebar.header('Plotting Options')

# One checkbox per output type; `value` is the initial checked state.
st.sidebar.subheader('Output Types')
plot_rna_seq = st.sidebar.checkbox('RNA-SEQ', value=True)
plot_splice_sites = st.sidebar.checkbox('SPLICE_SITES', value=True)
plot_splice_junctions = st.sidebar.checkbox('SPLICE_JUNCTIONS', value=True)
plot_cage = st.sidebar.checkbox('CAGE', value=True)
plot_atac = st.sidebar.checkbox('ATAC', value=False)
plot_dnase = st.sidebar.checkbox('DNASE', value=False)
plot_chip_histone = st.sidebar.checkbox('CHIP_HISTONE', value=False)
plot_chip_tf = st.sidebar.checkbox('CHIP_TF', value=False)
plot_contact_maps = st.sidebar.checkbox('CONTACT_MAPS', value=False)

# Whether to draw the transcript annotation track, and which of the two
# transcript extractors feeds it.
st.sidebar.subheader('Gene Annotation')
plot_gene_annotation = st.sidebar.checkbox('Plot Gene Annotation', value=True)
plot_longest_transcript_only = st.sidebar.checkbox('Plot Longest Transcript Only', value=True)

# index=0 makes 'None' (no strand filtering) the default choice.
st.sidebar.subheader('DNA Strand Filter')
strand_filter = st.sidebar.radio('Filter to Strand', ('None', 'Positive', 'Negative'), index=0)

# Width and shift (both in bp) of the plotted window, plus the colours used
# for the overlaid reference/alternate tracks.
st.sidebar.subheader('Visualization Settings')
plot_interval_width = st.sidebar.slider('Plot Interval Width (bp)', min_value=1024, max_value=200000, step=1024, value=43008)
plot_interval_shift = st.sidebar.slider('Plot Interval Shift (bp)', min_value=-524288, max_value=524288, step=2048, value=0)
ref_color = st.sidebar.color_picker('Reference Color', value='#808080')
alt_color = st.sidebar.color_picker('Alternate Color', value='#FF0000')
# Keyed by the same 'REF'/'ALT' labels used for the overlaid track data.
ref_alt_colors = {'REF': ref_color, 'ALT': alt_color}
# --- Main App Logic ---
# Main flow: create (or reuse) the model client, then, once the sidebar button
# is pressed, predict the variant effect and draw the selected tracks.
dna_model = get_dna_model()

if st.sidebar.button('Visualize Variant', use_container_width=True):
    st.header(f"Visualizing variant: {variant_chromosome}:{variant_position}:{variant_reference_bases}>{variant_alternate_bases}")

    variant = genome.Variant(
        chromosome=variant_chromosome,
        position=int(variant_position),
        reference_bases=variant_reference_bases,
        alternate_bases=variant_alternate_bases,
    )

    # Prediction window: the variant's reference interval resized to the
    # selected sequence length.
    sequence_length = dna_client.SUPPORTED_SEQUENCE_LENGTHS[f'SEQUENCE_LENGTH_{sequence_length_choice}']
    interval = variant.reference_interval.resize(sequence_length)

    transcript_extractor, longest_transcript_extractor = get_transcript_extractors(organism_choice)

    # Interval, variant and organism are passed as strings and parsed back
    # inside the cached wrapper.
    output = predict_variant_cached(
        dna_model,
        interval_str=str(interval),
        variant_str=str(variant),
        organism_str=ORGANISM_MAP[organism_choice].name,
    )

    ref, alt = output.reference, output.alternate
    if strand_filter == 'Positive':
        ref = ref.filter_to_strand(strand='+')
        alt = alt.filter_to_strand(strand='+')
    elif strand_filter == 'Negative':
        ref = ref.filter_to_strand(strand='-')
        alt = alt.filter_to_strand(strand='-')

    components = []
    if plot_gene_annotation:
        extractor = longest_transcript_extractor if plot_longest_transcript_only else transcript_extractor
        transcripts = extractor.extract(interval)
        components.append(plot_components.TranscriptAnnotation(transcripts))

    # One row per output type: (checkbox state, REF data, ALT data, label).
    # This must be a sequence, not a dict keyed by the checkbox state: the
    # states are booleans, so a dict would keep at most two of the nine rows
    # (one for True, one for False) and silently drop the rest.
    track_selection = [
        (plot_atac, ref.atac, alt.atac, 'ATAC'),
        (plot_cage, ref.cage, alt.cage, 'CAGE'),
        (plot_chip_histone, ref.chip_histone, alt.chip_histone, 'CHIP_HISTONE'),
        (plot_chip_tf, ref.chip_tf, alt.chip_tf, 'CHIP_TF'),
        (plot_contact_maps, ref.contact_maps, alt.contact_maps, 'CONTACT_MAPS'),
        (plot_dnase, ref.dnase, alt.dnase, 'DNASE'),
        (plot_rna_seq, ref.rna_seq, alt.rna_seq, 'RNA_SEQ'),
        (plot_splice_junctions, ref.splice_junctions, alt.splice_junctions, 'SPLICE_JUNCTIONS'),
        (plot_splice_sites, ref.splice_sites, alt.splice_sites, 'SPLICE_SITES'),
    ]

    for should_plot, ref_data, alt_data, output_type in track_selection:
        if not should_plot:
            continue

        # Nothing to draw when the output is absent or has no tracks left.
        if ref_data is None or ref_data.values.shape[-1] == 0:
            st.warning(f'No tracks exist for {output_type} with the current filters.')
            continue

        if output_type == 'CONTACT_MAPS':
            # Contact maps are shown as the ALT - REF difference.
            components.append(plot_components.ContactMapsDiff(tdata=alt_data - ref_data))
        elif output_type == 'SPLICE_JUNCTIONS':
            # Splice junctions get separate REF and ALT sashimi components.
            components.append(plot_components.Sashimi(ref_data, ylabel_template='REF: {track}'))
            components.append(plot_components.Sashimi(alt_data, ylabel_template='ALT: {track}'))
        else:
            components.append(plot_components.OverlaidTracks(
                tdata={'REF': ref_data, 'ALT': alt_data},
                colors=ref_alt_colors,
                ylabel_template='{track}',
            ))

    if not components:
        st.warning("No data to plot. Please select at least one output type to visualize.")
    elif plot_interval_width > interval.width:
        # A width equal to the sequence length is accepted, so the message
        # says "less than or equal to".
        st.error(f'Plot Interval Width ({plot_interval_width}) must be less than or equal to Sequence Length ({interval.width}).')
    else:
        with st.spinner('Generating plot...'):
            figure = plot_components.plot(
                components=components,
                interval=interval.shift(plot_interval_shift).resize(plot_interval_width),
                annotations=[plot_components.VariantAnnotation([variant])],
            )
        # plot_components.plot produces a matplotlib figure, so it is rendered
        # with st.pyplot rather than st.altair_chart.
        st.pyplot(figure)
else:
    st.info("Configure your variant in the sidebar and click 'Visualize Variant' to begin.")