Update `geneformer/perturber_utils.py` to be compatible with different versions of the Hugging Face `datasets` package.
#553
by
IchigoJiken
- opened
geneformer/perturber_utils.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import List
|
|
| 9 |
import numpy as np
|
| 10 |
import pandas as pd
|
| 11 |
import torch
|
|
|
|
| 12 |
from datasets import Dataset, load_from_disk
|
| 13 |
from peft import LoraConfig, get_peft_model
|
| 14 |
from transformers import (
|
|
@@ -430,6 +431,11 @@ def remove_perturbed_indices_set(
|
|
| 430 |
def make_perturbation_batch(
|
| 431 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 432 |
) -> tuple[Dataset, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 434 |
if perturb_type in ["overexpress", "activate"]:
|
| 435 |
range_start = 1
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
import pandas as pd
|
| 11 |
import torch
|
| 12 |
+
import datasets
|
| 13 |
from datasets import Dataset, load_from_disk
|
| 14 |
from peft import LoraConfig, get_peft_model
|
| 15 |
from transformers import (
|
|
|
|
| 431 |
def make_perturbation_batch(
|
| 432 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 433 |
) -> tuple[Dataset, List[int]]:
|
| 434 |
+
|
| 435 |
+
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 436 |
+
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 437 |
+
example_cell = example_cell[:]
|
| 438 |
+
|
| 439 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 440 |
if perturb_type in ["overexpress", "activate"]:
|
| 441 |
range_start = 1
|