Fine tuning questions

#462
by ZYSK-huggingface - opened

Hi!

We wanted to fine-tune the 95M model with our own data for cell state perturbation, but we could not obtain a good loss. When checking your examples, we found some functions and parameters but are not sure of their effect or usage:

① In your cell_classification.ipynb:
Is 'n_hyperopt_trials=n' in cc.validate() essential for fine-tuning?

And are functions like 'train_all_data', 'train_classifier', and 'hyperopt_classifier' useful? (These functions are listed in https://geneformer.readthedocs.io/en/latest/geneformer.classifier.html but do not appear in your example code.)

And is a quantized model better or not?

Finally, we also see the 'evaluate_model' function; is it different from 'cc.validate'?

In conclusion, we are confused by these extra functions and parameters and are not sure which could be used to optimize fine-tuning.

② Create multi-task_cell_classification.ipynb:
We found this example and wonder whether it could be used to optimize a fine-tuned model for cell state perturbation.

And what is the difference between it and the code above? What is the best choice for fine-tuning a 95M model for cell state perturbation?

Thank you so much !

Thanks for your questions.

  • As stated in the documentation, and as for all deep learning applications, hyperparameter tuning can be highly beneficial for fine-tuning.

  • Multi-task classification is useful if you would like to train on multiple tasks, for example to understand how disease may affect different cell types by learning both aspects simultaneously.

  • The quantized models are more efficient.

  • The other functions you mention are useful depending on your goal. Please read the documentation to understand them and decide which are relevant to your goal.
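To make "more efficient" concrete, here is a back-of-the-envelope sketch assuming simple symmetric per-tensor 8-bit quantization. This is a generic scheme chosen for illustration, not necessarily the one Geneformer's quantized models use: storing 95M weights as 8-bit integers instead of 32-bit floats cuts weight memory about fourfold, at the cost of a rounding error of at most half the scale per weight.

```python
# Generic symmetric int8 quantization sketch (assumption: one shared scale per tensor).
# This is NOT Geneformer's actual quantization code.


def weight_memory_mb(n_params: int, bits: int) -> float:
    """Memory needed to store n_params weights at the given bit width, in MB (1e6 bytes)."""
    return n_params * bits / 8 / 1e6


def quantize_int8(weights: list[float]) -> tuple[list[int], float]:
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(w) for w in weights) or 1.0  # avoid a zero scale for all-zero input
    scale = max_abs / 127
    codes = [round(w / scale) for w in weights]
    return codes, scale


def dequantize(codes: list[int], scale: float) -> list[float]:
    """Recover approximate float weights from the int8 codes."""
    return [c * scale for c in codes]
```

For example, `weight_memory_mb(95_000_000, 32)` gives 380 MB versus 95 MB at 8 bits; whether the rounding error affects downstream accuracy for a given task has to be checked empirically.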

ctheodoris changed discussion status to closed


Thank you for your answer. For 'train_all_data', 'train_classifier', and 'hyperopt_classifier', I do not see a difference in the documentation. Do they serve the same aim but with different capabilities? And in your cell_classification.ipynb, could I just insert these functions before or after cc.prepare?

Our goal is to fine-tune a 95M model to classify cells in different states and then apply it to cell state perturbation. Could you please give some more specific guidance?

Best wishes

Besides, multi-task_cell_classification.ipynb seems to include both cell classification and perturbation. So is it similar to running cell_classification.ipynb first and cell state perturbation second? In other words, just for a fine-tuned cell classifier, could multi-task_cell_classification.ipynb replace cell_classification.ipynb and perform better?

Thank you again!

Thank you for your questions! "hyperopt_classifier" is for hyperparameter optimization. "train_all_data" trains on all the provided data and does not hold any out for validation, and there is therefore no hyperparameter optimization. "train_classifier" fine-tunes a model on the provided training data and then evaluates on the provided validation data, also without hyperparameter optimization.
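The division of labor described above can be sketched on a toy problem. The three function names below mirror Geneformer's, but their bodies and signatures are hypothetical stand-ins (a 1-D threshold classifier whose only "hyperparameter" is a margin), not Geneformer's implementation. They only show which data each mode trains on and whether anything is held out.

```python
import random

# Hypothetical toy data: (feature, label) pairs standing in for tokenized cells.
Example = tuple[float, int]


def fit_threshold(data: list[Example], margin: float) -> float:
    """'Train' a 1-D classifier: threshold midway between class means, shifted by margin."""
    pos = [x for x, y in data if y == 1]
    neg = [x for x, y in data if y == 0]
    return (sum(pos) / len(pos) + sum(neg) / len(neg)) / 2 + margin


def accuracy(threshold: float, data: list[Example]) -> float:
    """Fraction of examples whose predicted label matches the true label."""
    return sum((x > threshold) == bool(y) for x, y in data) / len(data)


def train_classifier(train: list[Example], val: list[Example],
                     margin: float) -> tuple[float, float]:
    """Fit on the training split only, then report accuracy on the validation split."""
    model = fit_threshold(train, margin)
    return model, accuracy(model, val)


def train_all_data(all_data: list[Example], margin: float) -> float:
    """Fit on every example with fixed hyperparameters; nothing is held out."""
    return fit_threshold(all_data, margin)


def hyperopt_classifier(train: list[Example], val: list[Example],
                        n_trials: int, seed: int = 0) -> tuple[float, float, float]:
    """Random search: repeat train_classifier with sampled margins and keep the
    (margin, model, validation score) triple with the best validation score."""
    rng = random.Random(seed)
    best = None
    for _ in range(n_trials):
        margin = rng.uniform(-1.0, 1.0)
        model, score = train_classifier(train, val, margin)
        if best is None or score > best[2]:
            best = (margin, model, score)
    return best
```

Note that only `hyperopt_classifier` ever compares hyperparameter settings; the other two take them as given.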

"validate" is the most common way the model should be used, as provided in the examples. If you want to perform hyperparameter optimization, you can just set n_hyperopt_trials to be >0. The other functions are mostly used internally by "validate" and similar functions, but we provide documentation in case others find them helpful for their use case. To accomplish your goal of fine-tuning a model for cell state classification, I would suggest following the example and using "validate", adjusting n_hyperopt_trials if you want to perform hyperparameter tuning.

For the multi-task cell state classification, yes, this would replace the single-task cell state classification, while the next step of in silico perturbation would be the same for both. The reason to use the multi-task version is if you have multiple tasks you are interested in fine-tuning the model towards.
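The core idea of multi-task fine-tuning can be sketched as one shared representation feeding several task-specific heads, trained on a weighted sum of per-task losses. This is a generic illustration of multi-task learning under that assumption; the task names, weights, and helper functions are hypothetical and are not taken from Geneformer's multi-task code.

```python
import math


def cross_entropy(logits: list[float], label: int) -> float:
    """Softmax cross-entropy for one example, computed in a numerically stable way."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def multitask_loss(task_logits: dict[str, list[float]],
                   task_labels: dict[str, int],
                   task_weights: dict[str, float]) -> float:
    """Weighted sum of per-task losses for one cell. In real training, back-propagating
    this total through a shared encoder lets every task shape the same embedding."""
    return sum(task_weights[t] * cross_entropy(task_logits[t], task_labels[t])
               for t in task_logits)
```

For instance, a cell could carry both a hypothetical "cell_type" label and a "disease" label, each scored by its own head; with a single task this reduces to ordinary single-task classification.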

Once finding the optimal hyperparameters with validate, do we retrain the model with train_all_data?
Let's say my use case is disease classification of cells. I assume you 1) initialize a cell classifier model using the V2 Geneformer path, 2) run classifier.validate(n_hyperopt_trials=x), and 3) what happens at this step?

Should the model just be evaluated from its saved path (aka using evaluate_saved_model)?
Should the classifier be reinitialized with new training arguments and then trained and evaluated?
Should the classifier stay as it was before and simply train from the saved path?
Lastly, when and why should one use "train_all_data" or "train_classifier"? While I understand the differences in how they split data and evaluate, the documentation doesn't explain why or for what purpose either of these functions should be used.

Once finding the optimal hyperparameters with validate, do we retrain the model with train_all_data?

Generally in machine learning, the data are split into three sets: training, validation, and test. The validate function is set up to fine-tune the model while optimizing hyperparameters and evaluating on held-out validation data. After training is finalized, the best model can be evaluated on the held-out test data to confirm generalizability. Once the model is confirmed to be generalizable, one common strategy in machine learning is to go back and train the model the same way on all the data to maximize its potential for novel discoveries when using it for inference in new domains. Another use of train_all_data is to train on all the data with predefined hyperparameters, without hyperparameter tuning or validation; it can also be used inside another validation scheme that the user may design. You do not have to go back and retrain the model from validate if you don't want to, though; the models are saved for each trial.
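The train/validation/test workflow described above, followed by the optional final refit on all data, can be sketched as below. The IDs are split once into disjoint sets so that nothing used for tuning leaks into the test set; the helper names and the 70/15/15 proportions are illustrative assumptions, not Geneformer defaults.

```python
import random


def three_way_split(ids: list[str], frac_train: float = 0.7, frac_val: float = 0.15,
                    seed: int = 0) -> dict[str, list[str]]:
    """Shuffle IDs once and cut them into disjoint train/validation/test lists."""
    shuffled = ids[:]
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * frac_train)
    n_val = round(len(shuffled) * frac_val)
    return {
        "train": shuffled[:n_train],                # used to fit model weights
        "val": shuffled[n_train:n_train + n_val],   # used to choose hyperparameters
        "test": shuffled[n_train + n_val:],         # used once, to check generalization
    }


def final_refit_ids(splits: dict[str, list[str]]) -> list[str]:
    """After tuning on val and confirming on test, a final model may optionally be
    retrained with the chosen hyperparameters on the union of all splits."""
    return splits["train"] + splits["val"] + splits["test"]
```

The refit model no longer has an untouched test set, which is why generalizability is confirmed before this step rather than after it.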

Let's say my use case is disease classification of cells. I assume you 1) initialize a cell classifier model using the V2 Geneformer path, 2) run classifier.validate(n_hyperopt_trials=x), and 3) what happens at this step?

The hyperparameters are tuned and the best hyperparameters are chosen based on the validation data. If test data is provided, the best model is evaluated on the test data.
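The flow described above (tune on the validation data, keep the best trial, and use the test data only if it is provided) can be sketched as follows. `run_trial` is a hypothetical callback standing in for one fine-tuning run, and only the learning rate is searched for brevity; this mirrors the described behavior of `validate` with `n_hyperopt_trials`, not its actual code.

```python
import random
from typing import Callable, Optional

# Hypothetical: run_trial(hparams) -> (model, validation_score) for one fine-tuning run.
TrialFn = Callable[[dict], tuple[object, float]]


def validate_sketch(run_trial: TrialFn,
                    default_hparams: dict,
                    n_hyperopt_trials: int = 0,
                    evaluate_test: Optional[Callable[[object], float]] = None,
                    seed: int = 0) -> dict:
    rng = random.Random(seed)
    if n_hyperopt_trials == 0:
        # No tuning: a single run with the given hyperparameters.
        candidates = [default_hparams]
    else:
        # Tuning: sample n_hyperopt_trials configurations (log-uniform learning rate).
        candidates = [{**default_hparams, "learning_rate": 10 ** rng.uniform(-5, -3)}
                      for _ in range(n_hyperopt_trials)]

    best_hparams, best_model, best_score = None, None, float("-inf")
    for hp in candidates:
        model, val_score = run_trial(hp)
        if val_score > best_score:  # selection uses validation data only
            best_hparams, best_model, best_score = hp, model, val_score

    result = {"hparams": best_hparams, "val_score": best_score}
    if evaluate_test is not None:
        # The held-out test set is scored once, on the selected model only.
        result["test_score"] = evaluate_test(best_model)
    return result
```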

Should the model just be evaluated from its saved path (aka using evaluate_saved_model)?

This is another way you can evaluate the model if that better fits your workflow.
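Evaluating from a saved path decouples training from evaluation: the model is written to disk once and can be reloaded and scored later, for example on a new test set, without retraining. The sketch below shows that pattern with a toy threshold model stored as JSON; the file layout and helper names are hypothetical and do not reflect the format that Geneformer's `evaluate_saved_model` reads.

```python
import json
import os


def save_model(threshold: float, label_names: list[str], directory: str) -> str:
    """Persist a toy threshold classifier plus its class-id mapping; return the file path."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, "model.json")
    with open(path, "w") as f:
        json.dump({"threshold": threshold,
                   "id_class_dict": dict(enumerate(label_names))}, f)
    return path


def evaluate_saved(path: str, test_data: list[tuple[float, str]]) -> float:
    """Reload the saved model and report accuracy; no training happens here."""
    with open(path) as f:
        saved = json.load(f)
    # JSON stores dict keys as strings, so convert the class ids back to ints.
    id_class = {int(k): v for k, v in saved["id_class_dict"].items()}
    correct = 0
    for x, label in test_data:
        pred = id_class[1] if x > saved["threshold"] else id_class[0]
        correct += pred == label
    return correct / len(test_data)
```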

Should the classifier be reinitialized with new training arguments and then trained and evaluated?

It is unclear at what step "reinitialized" is happening in this question. You can use a model saved from "validate" or train a new one if needed.

Should the classifier stay as it was before and simply train from the saved path?

It is unclear what this question means. If the model has been saved to a path, it is already trained and does not need to be trained further, unless you want to expose it to new data.

Lastly, when and why should one use "train_all_data" or "train_classifier"? While I understand the differences in how they split data and evaluate, the documentation doesn't explain why or for what purpose either of these functions should be used.

Various functions are provided to give users flexibility depending on how they structure their analysis. When to use each one really depends on the project; it is a general machine learning question of how one sets up the analysis and is not specific to Geneformer. Many of these functions are used internally in validate, but they are also exposed to users in case they would like to use them in other workflows. We provide an example protocol for fine-tuning a model towards disease state classification and predicting candidate therapeutic targets in the Colab tutorial linked in the model card.
