BAITS.st.tl.Cluster
- class BAITS.st.tl.Cluster(n_clusters=1, covariance_type='full', init_strategy='kmeans', init_means=None, convergence_tolerance=0.001, covariance_regularization=1e-06, batch_size=None, trainer_params=None, random_state=0)
Cluster cells or spots based on features.
- Parameters:
n_clusters (int, default: 1) – The number of components in the GMM. The dimensionality of each component is automatically inferred from the data.
covariance_type (str, default: 'full') – The type of covariance to assume for all Gaussian components.
init_strategy (str, default: 'kmeans') – The strategy for initializing component means and covariances.
init_means (Optional[Tensor], default: None) – An optional initial guess for the means of the components. If provided, it must be a tensor of shape [n_clusters, num_features]. If this is given, init_strategy is ignored and the means are handled as if K-means initialization has been run.
convergence_tolerance (float, default: 0.001) – The change in the per-datapoint negative log-likelihood below which training is considered to have converged.
covariance_regularization (float, default: 1e-06) – A small value added to the diagonal of the covariance matrix to keep it positive definite (and therefore invertible).
batch_size (default: None) – The batch size to use when fitting the model. If not provided, the full data is used as a single batch. Set this if the full data does not fit into memory.
trainer_params (Optional[dict], default: None) – Initialization parameters to use when initializing a PyTorch Lightning trainer. By default, it disables various stdout logs unless TorchGMM is configured to do verbose logging. Checkpointing and logging are disabled regardless of the log level. This estimator further sets the following overridable default: max_epochs=100.
random_state (Union[int, RandomState, None], default: 0) – Initialization seed.
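To make covariance_regularization concrete, here is a minimal stdlib-only sketch; the helpers covariance, regularize and det2 are hypothetical and are not BAITS or TorchGMM functions. Two perfectly correlated features produce a singular 2x2 covariance (determinant 0), which a Gaussian density cannot invert. Adding 1e-6 to the diagonal makes the matrix positive definite.

```python
# Illustrative helpers only; not part of BAITS or TorchGMM.

def covariance(points):
    """Maximum-likelihood covariance matrix of a list of 2-D points."""
    n = len(points)
    mx = sum(p[0] for p in points) / n
    my = sum(p[1] for p in points) / n
    sxx = sum((p[0] - mx) ** 2 for p in points) / n
    syy = sum((p[1] - my) ** 2 for p in points) / n
    sxy = sum((p[0] - mx) * (p[1] - my) for p in points) / n
    return [[sxx, sxy], [sxy, syy]]


def regularize(cov, reg=1e-6):
    """Return a copy of `cov` with `reg` added to every diagonal entry."""
    return [[v + (reg if i == j else 0.0) for j, v in enumerate(row)]
            for i, row in enumerate(cov)]


def det2(m):
    """Determinant of a 2x2 matrix."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]


# The second feature is exactly twice the first, so the covariance is singular.
points = [(0.0, 0.0), (1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
cov = covariance(points)
reg_cov = regularize(cov, reg=1e-6)  # mirrors covariance_regularization=1e-06
print(det2(cov))      # 0.0: not invertible
print(det2(reg_cov))  # small but positive: invertible
```

By Sylvester's criterion, a symmetric 2x2 matrix with a positive top-left entry and a positive determinant is positive definite, which the regularized matrix satisfies while the raw one does not.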
Examples
>>> import anndata
>>> import squidpy as sq
>>> import cellcharter as cc
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> sq.gr.spatial_neighbors(adata, coord_type='generic', delaunay=True)
>>> cc.gr.remove_long_links(adata)
>>> cc.gr.aggregate_neighbors(adata, n_layers=3)
>>> model = cc.tl.Cluster(n_clusters=11)
>>> model.fit(adata, use_rep='X_cellcharter')
- __init__(n_clusters=1, covariance_type='full', init_strategy='kmeans', init_means=None, convergence_tolerance=0.001, covariance_regularization=1e-06, batch_size=None, trainer_params=None, random_state=0)
- Args:
- num_components: The number of components in the GMM (n_clusters in this class's signature). The dimensionality of each component is automatically inferred from the data.
- covariance_type: The type of covariance to assume for all Gaussian components.
- init_strategy: The strategy for initializing component means and covariances.
- init_means: An optional initial guess for the means of the components. If provided, it must be a tensor of shape [num_components, num_features]. If this is given, init_strategy is ignored and the means are handled as if K-means initialization has been run.
- convergence_tolerance: The change in the per-datapoint negative log-likelihood below which training is considered to have converged.
- covariance_regularization: A small value added to the diagonal of the covariance matrix to keep it positive definite.
- batch_size: The batch size to use when fitting the model. If not provided, the full data is used as a single batch. Set this if the full data does not fit into memory.
- num_workers: The number of workers to use for loading the data. Only used if a PyTorch dataset is passed to fit() or related methods. This argument does not appear in this class's signature.
- trainer_params: Initialization parameters to use when initializing a PyTorch Lightning trainer. By default, it disables various stdout logs unless TorchGMM is configured to do verbose logging. Checkpointing and logging are disabled regardless of the log level. This estimator further sets the following overridable defaults:
max_epochs=100
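The init_means description says that supplied means are handled as if K-means initialization has been run. One plausible reading, sketched below for 1-D data under our own assumptions, is that k-means (or the caller) fixes the means, and the mixture weights and variances are then derived from hard nearest-mean assignments. The functions assign, kmeans_1d and init_from_means are hypothetical; TorchGMM's actual initializer may differ.

```python
# Hypothetical sketch for 1-D data; function names are illustrative only.

def assign(data, means):
    """Hard-assign each point to its nearest mean; returns one group per mean."""
    groups = [[] for _ in means]
    for x in data:
        nearest = min(range(len(means)), key=lambda j: abs(x - means[j]))
        groups[nearest].append(x)
    return groups


def kmeans_1d(data, means, n_iter=10):
    """Lloyd's k-means starting from the given means (init_strategy='kmeans'-like)."""
    means = list(means)
    for _ in range(n_iter):
        groups = assign(data, means)
        # A component that receives no points keeps its previous mean.
        means = [sum(g) / len(g) if g else m for g, m in zip(groups, means)]
    return means


def init_from_means(data, means, reg=1e-6):
    """Derive initial mixture weights and variances from fixed means."""
    groups = assign(data, means)
    weights = [len(g) / len(data) for g in groups]
    # `reg` stands in for covariance_regularization so no variance is exactly zero.
    variances = [
        (sum((x - m) ** 2 for x in g) / len(g) if g else 0.0) + reg
        for g, m in zip(groups, means)
    ]
    return weights, variances


data = [1.0, 1.2, 0.8, 5.0, 5.2, 4.8]
means = kmeans_1d(data, [0.0, 6.0])  # refine a rough starting guess
weights, variances = init_from_means(data, means)
print(means, weights, variances)
```

Passing means directly to init_from_means, without calling kmeans_1d, corresponds to the case where init_means is given and the k-means step is skipped.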
- Note:
The number of epochs passed to the initializer only defines the number of optimization epochs. Prior to that, initialization is run, which may perform additional iterations through the data.
- Note:
For batch training, the number of epochs run (i.e. the number of passes through the data) does not align with the number of epochs passed to the initializer. This is because the EM algorithm needs to be split up across two epochs. The actual minimum/maximum number of epochs is, thus, doubled. Nonetheless,
num_iter_ indicates how many EM iterations have been run.
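The note above says batch training spreads each EM iteration over more than one pass. The sketch below illustrates why batching works at all, under simplifying assumptions (1-D data, plain lists, hypothetical helper names): the E-step is accumulated batch by batch into sufficient statistics, and the M-step runs only after the whole pass, so a batched iteration yields the same parameters as a full-batch one. How TorchGMM divides this work across its two epochs is not reproduced here.

```python
import math


def gaussian_pdf(x, mean, var):
    """Density of a 1-D normal distribution."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)


def accumulate(batch, params, stats):
    """E-step on one batch: add responsibility-weighted sums into `stats`."""
    weights, means, variances = params
    for x in batch:
        joint = [w * gaussian_pdf(x, m, v) for w, m, v in zip(weights, means, variances)]
        total = sum(joint)
        for k, p in enumerate(joint):
            r = p / total             # responsibility of component k for point x
            stats[k][0] += r          # sum of responsibilities
            stats[k][1] += r * x      # responsibility-weighted sum of x
            stats[k][2] += r * x * x  # responsibility-weighted sum of x^2


def m_step(stats, n, reg=1e-6):
    """M-step from statistics accumulated over all n points."""
    weights = [s[0] / n for s in stats]
    means = [s[1] / s[0] for s in stats]
    variances = [s[2] / s[0] - m * m + reg for s, m in zip(stats, means)]
    return weights, means, variances


def em_iteration(data, params, batch_size):
    """One EM iteration: a full pass of batched E-steps, then a single M-step."""
    stats = [[0.0, 0.0, 0.0] for _ in params[0]]
    for start in range(0, len(data), batch_size):
        accumulate(data[start:start + batch_size], params, stats)
    return m_step(stats, len(data))


data = [1.0, 1.2, 0.8, 5.0, 5.2, 4.8]
params = ([0.5, 0.5], [0.0, 6.0], [1.0, 1.0])
full = em_iteration(data, params, batch_size=len(data))
batched = em_iteration(data, params, batch_size=4)
print(full)
print(batched)
```

Because the statistics are plain sums, the batch boundaries do not change the result; batch_size trades memory for extra passes rather than changing the model.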
Methods
__init__([n_clusters, covariance_type, ...]) – Initializes the estimator with the parameters described above.
clone() – Clones the estimator without copying any fitted attributes.
fit(adata[, use_rep]) – Fit data into n_clusters clusters.
fit_predict(data) – Fits the estimator using the provided data and subsequently predicts the labels for the data using the fitted estimator.
get_params([deep]) – Returns the estimator's parameters as passed to the initializer.
load(path) – Loads the estimator and (if available) the fitted model.
load_attributes(path) – Loads the fitted attributes stored at the given path.
load_parameters(path) – Initializes this estimator by loading its parameters.
predict(adata[, use_rep, store_labels, ...]) – Predict the labels for the data in use_rep using the fitted model.
predict_proba(data) – Computes a distribution over the components for each of the provided datapoints.
sample(num_datapoints) – Samples datapoints from the fitted Gaussian mixture.
save(path) – Saves the estimator to the provided directory.
save_attributes(path) – Saves the fitted attributes of this estimator.
save_parameters(path) – Saves the parameters of this estimator.
score(adata[, use_rep]) – Computes the average negative log-likelihood (NLL) of the data in use_rep under the fitted model.
score_samples(data) – Computes the negative log-likelihood (NLL) of each of the provided datapoints.
set_params(values) – Sets the provided values on the estimator.
trainer(**kwargs) – Returns the trainer as configured by the estimator.
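The method summaries imply how the prediction, scoring and sampling methods relate. The stdlib sketch below assumes a fixed two-component 1-D mixture with made-up parameters and uses plain Python lists in place of adata/use_rep and tensors. It assumes that predict picks the most probable component from predict_proba, that score_samples returns per-point NLL and that score averages it; it illustrates these relationships and is not the BAITS or TorchGMM implementation.

```python
import math
import random

# A fixed, already "fitted" 1-D mixture with two components (illustrative values).
weights = [0.3, 0.7]
means = [0.0, 5.0]
variances = [1.0, 2.0]


def log_joint(x):
    """log(w_k * N(x | mean_k, var_k)) for every component k."""
    return [math.log(w) - 0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
            for w, m, v in zip(weights, means, variances)]


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def predict_proba(data):
    """Posterior probability of each component for each point (rows sum to 1)."""
    rows = []
    for x in data:
        lj = log_joint(x)
        norm = logsumexp(lj)
        rows.append([math.exp(v - norm) for v in lj])
    return rows


def predict(data):
    """Index of the most probable component for each point."""
    return [max(range(len(p)), key=p.__getitem__) for p in predict_proba(data)]


def score_samples(data):
    """Negative log-likelihood of each point under the mixture."""
    return [-logsumexp(log_joint(x)) for x in data]


def score(data):
    """Average per-datapoint negative log-likelihood."""
    nll = score_samples(data)
    return sum(nll) / len(nll)


def sample(num_datapoints, seed=0):
    """Draw points: pick a component by weight, then sample from its Gaussian."""
    rng = random.Random(seed)
    comps = rng.choices(range(len(weights)), weights=weights, k=num_datapoints)
    return [rng.gauss(means[c], math.sqrt(variances[c])) for c in comps]


points = [0.1, 4.9]
print(predict(points), predict_proba(points), score(points))
```

Working in log space with logsumexp avoids underflow when a point lies far from every component, which matters once variances become small.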
Attributes
persistent_attributes – Returns the list of fitted attributes that should be saved and loaded.
model_ – The fitted PyTorch module with all estimated parameters.
converged_ – A boolean indicating whether the model converged during training.
num_iter_ – The number of iterations the model was fitted for, excluding initialization.
nll_ – The average per-datapoint negative log-likelihood at the last training step.
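The attributes converged_, num_iter_ and nll_ fit together with convergence_tolerance in a standard EM loop. The following full-batch sketch for 1-D data uses a hypothetical fit_gmm_1d function and assumes that convergence means the average per-datapoint NLL changed by less than convergence_tolerance between iterations, with max_epochs capping the number of iterations. It is not BAITS's or TorchGMM's code.

```python
import math


def fit_gmm_1d(data, means, convergence_tolerance=1e-3, max_epochs=100, reg=1e-6):
    """Full-batch EM for a 1-D GMM; returns parameters plus attribute-like fields."""
    k = len(means)
    weights, means, variances = [1.0 / k] * k, list(means), [1.0] * k
    prev_nll, nll, converged, num_iter = math.inf, math.inf, False, 0
    for _ in range(max_epochs):
        # E-step: responsibilities and the average per-datapoint NLL.
        resp, total_ll = [], 0.0
        for x in data:
            logs = [math.log(w) - 0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
                    for w, m, v in zip(weights, means, variances)]
            top = max(logs)
            log_px = top + math.log(sum(math.exp(lp - top) for lp in logs))
            resp.append([math.exp(lp - log_px) for lp in logs])
            total_ll += log_px
        nll = -total_ll / len(data)
        # M-step: re-estimate weights, means and variances (plus regularization).
        n_k = [sum(r[j] for r in resp) for j in range(k)]
        weights = [nj / len(data) for nj in n_k]
        means = [sum(r[j] * x for r, x in zip(resp, data)) / n_k[j] for j in range(k)]
        variances = [sum(r[j] * (x - means[j]) ** 2 for r, x in zip(resp, data)) / n_k[j] + reg
                     for j in range(k)]
        num_iter += 1
        # Stop once the per-datapoint NLL changes by less than the tolerance.
        if abs(prev_nll - nll) < convergence_tolerance:
            converged = True
            break
        prev_nll = nll
    return {"weights": weights, "means": means, "variances": variances,
            "converged_": converged, "num_iter_": num_iter, "nll_": nll}


result = fit_gmm_1d([1.0, 1.2, 0.8, 5.0, 5.2, 4.8], means=[0.0, 6.0])
print(result["converged_"], result["num_iter_"], result["nll_"])
```

If the loop exhausts max_epochs without meeting the tolerance, converged_ stays False while num_iter_ equals max_epochs, which is the situation the converged_ flag is meant to expose.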