Update geneformer/perturber_utils.py to be compatible with different versions of the datasets package.
Browse files — 'Column' object has no attribute 'to' error from evaluation_utils.py in Geneformer_tutorial.ipynb (#551)
geneformer/perturber_utils.py
CHANGED
@@ -9,6 +9,7 @@ from typing import List
|
|
9 |
import numpy as np
|
10 |
import pandas as pd
|
11 |
import torch
|
|
|
12 |
from datasets import Dataset, load_from_disk
|
13 |
from peft import LoraConfig, get_peft_model
|
14 |
from transformers import (
|
@@ -430,6 +431,11 @@ def remove_perturbed_indices_set(
|
|
430 |
def make_perturbation_batch(
|
431 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
432 |
) -> tuple[Dataset, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
433 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
434 |
if perturb_type in ["overexpress", "activate"]:
|
435 |
range_start = 1
|
|
|
9 |
import numpy as np
|
10 |
import pandas as pd
|
11 |
import torch
|
12 |
+
import datasets
|
13 |
from datasets import Dataset, load_from_disk
|
14 |
from peft import LoraConfig, get_peft_model
|
15 |
from transformers import (
|
|
|
431 |
def make_perturbation_batch(
|
432 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
433 |
) -> tuple[Dataset, List[int]]:
|
434 |
+
|
435 |
+
# For datasets>=4.0.0, convert to dict to avoid format issues
|
436 |
+
if int(datasets.__version__.split(".")[0]) >= 4:
|
437 |
+
example_cell = example_cell[:]
|
438 |
+
|
439 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
440 |
if perturb_type in ["overexpress", "activate"]:
|
441 |
range_start = 1
|