# Spaces:
# Runtime error
# Runtime error
| import re | |
| from itertools import count, islice | |
| from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload | |
| from datasets import Features, Value, get_dataset_config_info | |
| from datasets.features.features import FeatureType, _visit | |
| from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult | |
# A dataset row maps column names to arbitrary cell values.
Row = dict[str, Any]
T = TypeVar("T")

# Number of rows sent to the analyzer per batch.
# NOTE(review): 1 looks deliberate (presumably to bound latency/memory per call) — confirm before raising.
BATCH_SIZE = 1
# Cap on the text extracted from each cell before analysis (see presidio_scan_entities).
MAX_TEXT_LENGTH = 500

# Module-level Presidio engines, shared by all calls in this module.
analyzer = AnalyzerEngine()
batch_analyzer = BatchAnalyzerEngine(analyzer)
class PresidioEntity(TypedDict):
    """One PII entity detected by Presidio in a scanned dataset cell."""

    text: str  # matched substring, sliced from the analyzed (prefixed) text using the recognizer offsets
    type: str  # Presidio entity type (recognizer_result.entity_type), e.g. "EMAIL_ADDRESS"
    row_idx: int  # index of the row within the scanned row iterable
    column_name: str  # name of the column the entity was found in
@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Yield successive batches of up to ``n`` items from ``it``.

    The last batch may be shorter than ``n``. With ``with_indices=True``,
    each batch is yielded as ``(indices, items)`` where ``indices`` are the
    global 0-based positions of the items in the input iterable.

    Fix: the three stub signatures above were missing their ``@overload``
    decorators (the ``overload`` import was unused), making them dead
    redefinitions at runtime; the decorators are restored.
    """
    it, indices = iter(it), count()
    while batch := list(islice(it, n)):
        yield (list(islice(indices, len(batch))), batch) if with_indices else batch
def mask(text: str) -> str:
    """Partially redact ``text``: keep the first (up to) two characters of each
    space-separated word and replace every remaining ASCII letter/digit with ``*``.

    Punctuation and non-ASCII characters are left untouched.
    """
    masked_words = []
    for word in text.split(" "):
        keep = min(2, len(word) - 1)  # single-char words keep nothing; empty words stay empty
        visible, hidden = word[:keep], word[keep:]
        masked_words.append(visible + re.sub("[A-Za-z0-9]", "*", hidden))
    return " ".join(masked_words)
def get_strings(row_content: Any) -> str:
    """Recursively collect every string nested inside ``row_content``,
    joined with newlines. Non-string leaves contribute nothing.

    Dicts containing a "src" key are skipped entirely (could be image or audio).
    """
    if isinstance(row_content, str):
        return row_content
    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""  # could be image or audio
        row_content = list(row_content.values())
    if not isinstance(row_content, list):
        return ""
    collected = (get_strings(item) for item in row_content)
    return "\n".join(part for part in collected if part)
| def _simple_analyze_iterator_cache( | |
| batch_analyzer: BatchAnalyzerEngine, | |
| texts: Iterable[str], | |
| language: str, | |
| score_threshold: float, | |
| cache: dict[str, list[RecognizerResult]], | |
| ) -> list[list[RecognizerResult]]: | |
| not_cached_results = iter( | |
| batch_analyzer.analyze_iterator( | |
| (text for text in texts if text not in cache), language=language, score_threshold=score_threshold | |
| ) | |
| ) | |
| results = [cache[text] if text in cache else next(not_cached_results) for text in texts] | |
| # cache the last results | |
| cache.clear() | |
| cache.update(dict(zip(texts, results))) | |
| return results | |
def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Run Presidio over a batch of rows and return the detected entities.

    Each cell is wrapped in a context prompt ("The following is {description}
    data:\\n\\n{value}") before analysis, so the recognizer offsets refer to the
    prefixed text, not the raw cell value.

    Args:
        batch: one dict per row, mapping column name to the cell's text.
        indices: global row indices, aligned with ``batch``.
        scanned_columns / columns_descriptions: parallel lists; one prompt is
            built per (row, column) pair, flattened row-major into ``texts``.
        cache: optional per-text result cache passed through to
            ``_simple_analyze_iterator_cache``; a fresh dict is used if None.

    NOTE(review): language ("en") and score_threshold (0.8) are hard-coded here.
    """
    cache = {} if cache is None else cache
    texts = [
        f"The following is {columns_description} data:\n\n{example[column_name] or ''}"
        for example in batch
        for column_name, columns_description in zip(scanned_columns, columns_descriptions)
    ]
    return [
        PresidioEntity(
            # slice the matched substring out of the prompt text for this (row i, column j) pair;
            # texts is row-major, hence the flat index i * len(scanned_columns) + j
            text=texts[i * len(scanned_columns) + j][recognizer_result.start : recognizer_result.end],
            type=recognizer_result.entity_type,
            row_idx=row_idx,
            column_name=column_name,
        )
        # results come back one list per text; regroup them into one chunk per row
        for i, row_idx, recognizer_row_results in zip(
            count(),
            indices,
            batched(_simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache), len(scanned_columns)),
        )
        # walk the row's columns in the same order the prompts were built
        for j, column_name, columns_description, recognizer_results in zip(
            count(), scanned_columns, columns_descriptions, recognizer_row_results
        )
        for recognizer_result in recognizer_results
        # drop matches that fall inside the injected context prefix rather than the cell value
        if recognizer_result.start >= len(f"The following is {columns_description} data:\n\n")
    ]
def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily scan ``rows`` for PII, yielding one PresidioEntity per detection.

    Cells are flattened to text via ``get_strings`` and truncated to
    MAX_TEXT_LENGTH before being analyzed in batches of BATCH_SIZE; a single
    result cache is shared across all batches.
    """
    result_cache: dict[str, list[RecognizerResult]] = {}

    def scanned_cells_of(row: Row) -> dict[str, str]:
        # keep only the scanned columns, as (truncated) text
        return {name: get_strings(row[name])[:MAX_TEXT_LENGTH] for name in scanned_columns}

    prepared_rows = (scanned_cells_of(row) for row in rows)
    for row_indices, row_batch in batched(prepared_rows, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=row_batch,
            indices=row_indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=result_cache,
        )
def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names of all columns whose feature type contains a string
    ``Value`` anywhere in its (possibly nested) structure."""

    def contains_string(feature: FeatureType) -> bool:
        found = False

        def check(inner: FeatureType) -> None:
            nonlocal found
            if isinstance(inner, Value) and inner.dtype == "string":
                found = True

        _visit(feature, check)
        return found

    return [str(column) for column, feature in features.items() if contains_string(feature)]
def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column for the analysis prompt: the column name, plus the
    names of any nested fields found in its feature structure."""
    nested_fields: list[str] = []

    def collect(inner: FeatureType) -> None:
        # dict-shaped features expose their nested field names as keys
        if isinstance(inner, dict):
            nested_fields.extend(inner)

    _visit(feature, collect)
    if not nested_fields:
        return column_name
    return f"{column_name} (with {', '.join(nested_fields)})"