Train config
import os
from mu.core.base_config import BaseConfig
from pathlib import Path
current_dir = Path(__file__).parent
class SelectiveAmnesiaConfig(BaseConfig):
def __init__(self, **kwargs):
# Training parameters
self.seed = 23 # Random seed
self.scale_lr = True # Flag for scaling learning rate
# Model configuration
self.config_path = current_dir / "train_config.yaml"
self.model_config_path = (
current_dir / "model_config.yaml"
) # Config path for model
self.ckpt_path = "models/compvis/style50/compvis.ckpt" # Checkpoint path for Stable Diffusion
self.full_fisher_dict_pkl_path = "mu/algorithms/selective_amnesia/data/full_fisher_dict.pkl" # Path for Fisher dict
# Dataset directories
self.raw_dataset_dir = "data/quick-canvas-dataset/sample"
self.processed_dataset_dir = "mu/algorithms/selective_amnesia/data"
self.dataset_type = (
"unlearncanvas" # Dataset type (choices: unlearncanvas, i2p)
)
self.template = "style" # Template to use
self.template_name = "Abstractionism" # Template name
self.replay_prompt_path = "mu/algorithms/selective_amnesia/data/fim_prompts_sample.txt" # Path for replay prompts
# Output configurations
self.output_dir = "outputs/selective_amnesia/finetuned_models" # Output directory to save results
# Device configuration
self.devices = "0," # CUDA devices (comma-separated)
# Additional flags
self.use_sample = True # Use sample dataset for training
# Data configuration
self.data = {
"target": "mu.algorithms.selective_amnesia.data_handler.SelectiveAmnesiaDataHandler",
"params": {
"train_batch_size": 4,
"val_batch_size": 6,
"num_workers": 1,
"num_val_workers": 0, # Avoid val dataloader issue
"train": {
"target": "stable_diffusion.ldm.data.ForgettingDataset",
"params": {
"forget_prompt": "An image in Artist_Sketch style",
"forget_dataset_path": "./q_dist/photo_style",
},
},
"validation": {
"target": "stable_diffusion.ldm.data.VisualizationDataset",
"params": {
"output_size": 512,
"n_gpus": 1, # Number of GPUs for validation
},
},
},
}
# Lightning configuration
self.lightning = {
"find_unused_parameters": False,
"modelcheckpoint": {
"params": {"every_n_epochs": 0, "save_top_k": 0, "monitor": None}
},
"callbacks": {
"image_logger": {
"target": "mu.algorithms.selective_amnesia.callbacks.ImageLogger",
"params": {
"batch_frequency": 1,
"max_images": 999,
"increase_log_steps": False,
"log_first_step": False,
"log_all_val": True,
"clamp": True,
"log_images_kwargs": {
"ddim_eta": 0,
"ddim_steps": 50,
"use_ema_scope": True,
"inpaint": False,
"plot_progressive_rows": False,
"plot_diffusion_rows": False,
"N": 6, # Number of validation prompts
"unconditional_guidance_scale": 7.5,
"unconditional_guidance_label": [""],
},
},
}
},
"trainer": {
"benchmark": True,
"num_sanity_val_steps": 0,
"max_epochs": 50, # Modify epochs here!
"check_val_every_n_epoch": 10,
},
}
# Update properties based on provided kwargs
for key, value in kwargs.items():
setattr(self, key, value)
def validate_config(self):
"""
Perform basic validation on the config parameters.
"""
# Check if necessary directories exist
if not os.path.exists(self.raw_dataset_dir):
raise FileNotFoundError(f"Directory {self.raw_dataset_dir} does not exist.")
if not os.path.exists(self.processed_dataset_dir):
raise FileNotFoundError(
f"Directory {self.processed_dataset_dir} does not exist."
)
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
# Check if model and checkpoint files exist
if not os.path.exists(self.model_config_path):
raise FileNotFoundError(
f"Model config file {self.model_config_path} does not exist."
)
if not os.path.exists(self.ckpt_path):
raise FileNotFoundError(f"Checkpoint file {self.ckpt_path} does not exist.")
if not os.path.exists(self.full_fisher_dict_pkl_path):
raise FileNotFoundError(
f"Fisher dictionary file {self.full_fisher_dict_pkl_path} does not exist."
)
# Check if replay prompts file exists
if not os.path.exists(self.replay_prompt_path):
raise FileNotFoundError(
f"Replay prompt file {self.replay_prompt_path} does not exist."
)
# Validate dataset type
if self.dataset_type not in ["unlearncanvas", "i2p"]:
raise ValueError(
f"Invalid dataset type {self.dataset_type}. Choose from ['unlearncanvas', 'i2p']"
)
# Validate batch sizes
if self.data["params"]["train_batch_size"] <= 0:
raise ValueError(f"train_batch_size should be a positive integer.")
if self.data["params"]["val_batch_size"] <= 0:
raise ValueError(f"val_batch_size should be a positive integer.")
# Validate lightning trainer max_epochs
if self.lightning["trainer"]["max_epochs"] <= 0:
raise ValueError(f"max_epochs should be a positive integer.")
selective_amnesia_config_unlearn_canvas = SelectiveAmnesiaConfig()
selective_amnesia_config_unlearn_canvas.dataset_type = "unlearncanvas"
selective_amnesia_config_unlearn_canvas.raw_dataset_dir = (
"data/quick-canvas-dataset/sample"
)
selective_amnesia_config_i2p = SelectiveAmnesiaConfig()
selective_amnesia_config_i2p.dataset_type = "i2p"
selective_amnesia_config_i2p.raw_dataset_dir = "data/i2p-dataset/sample"
Train config yaml file
# Training parameters
seed : 23
scale_lr : True
# Model configuration
model_config_path: "mu/algorithms/selective_amnesia/configs/model_config.yaml"
ckpt_path: "/home/ubuntu/Projects/UnlearnCanvas/UnlearnCanvas/machine_unlearning/models/compvis/style50/compvis.ckpt" # Checkpoint path for Stable Diffusion
full_fisher_dict_pkl_path : "mu/algorithms/selective_amnesia/data/full_fisher_dict.pkl"
# Dataset directories
raw_dataset_dir: "data/quick-canvas-dataset/sample"
processed_dataset_dir: "mu/algorithms/selective_amnesia/data"
dataset_type : "unlearncanvas"
template : "style"
template_name : "Abstractionism"
replay_prompt_path: "mu/algorithms/selective_amnesia/data/fim_prompts_sample.txt"
# Output configurations
output_dir: "outputs/selective_amnesia/finetuned_models" # Output directory to save results
# Sampling and image configurations
# Device configuration
devices: "0," # CUDA devices to train on (comma-separated)
# Additional flags
use_sample: True # Use the sample dataset for training
data:
target: mu.algorithms.selective_amnesia.data_handler.SelectiveAmnesiaDataHandler
params:
train_batch_size: 4
val_batch_size: 6
num_workers: 4
num_val_workers: 0 # Avoid a weird val dataloader issue (keep unchanged)
train:
target: stable_diffusion.ldm.data.ForgettingDataset
params:
forget_prompt: An image in Artist_Sketch style
forget_dataset_path: ./q_dist/photo_style
validation:
target: stable_diffusion.ldm.data.VisualizationDataset
params:
output_size: 512
n_gpus: 1 # CHANGE THIS TO NUMBER OF GPUS! small hack to sure we see all our logging samples
lightning:
find_unused_parameters: False
modelcheckpoint:
params:
every_n_epochs: 0
save_top_k: 0
monitor: null
callbacks:
image_logger:
target: mu.algorithms.selective_amnesia.callbacks.ImageLogger
params:
batch_frequency: 1
max_images: 999
increase_log_steps: False
log_first_step: False
log_all_val: True
clamp: True
log_images_kwargs:
ddim_eta: 0
ddim_steps: 50
use_ema_scope: True
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 6 # keep this the same as number of validation prompts!
unconditional_guidance_scale: 7.5
unconditional_guidance_label: [""]
trainer:
benchmark: True
num_sanity_val_steps: 0
max_epochs: 50 # modify epochs here!
check_val_every_n_epoch: 10
Configuration File description
Training Parameters
-
seed: Random seed for reproducibility.
- Type: int
- Example: 23
-
scale_lr: Whether to scale the base learning rate.
- Type: bool
- Example: True
Model Configuration
-
model_config_path: Path to the Stable Diffusion model configuration YAML file.
- Type: str
- Example: "/path/to/model_config.yaml"
-
ckpt_path: Path to the Stable Diffusion model checkpoint.
- Type: str
- Example: "/path/to/compvis.ckpt"
-
full_fisher_dict_pkl_path: Path to the full fisher dict pkl file
- Type: str
- Example: "full_fisher_dict.pkl"
Dataset Directories
-
raw_dataset_dir: Directory containing the raw dataset categorized by themes or classes.
- Type: str
- Example: "/path/to/raw_dataset"
-
processed_dataset_dir: Directory to save the processed dataset.
- Type: str
- Example: "/path/to/processed_dataset"
-
dataset_type: Specifies the dataset type for training. Use
generic
as type if you want to use your own dataset.- Choices: ["unlearncanvas", "i2p", "generic"]
- Example: "unlearncanvas"
-
template: Type of template to use during training.
- Choices: ["object", "style", "i2p"]
- Example: "style"
-
template_name: Name of the concept or style to erase.
- Choices: ["self-harm", "Abstractionism"]
- Example: "Abstractionism"
Output Configurations
- output_dir: Directory to save fine-tuned models and results.
- Type: str
- Example: "outputs/selective_amnesia/finetuned_models"
Device Configuration
- devices: CUDA devices for training (comma-separated).
- Type: str
- Example: "0"
Data Parameters
-
train_batch_size: Batch size for training.
- Type: int
- Example: 4
-
val_batch_size: Batch size for validation.
- Type: int
- Example: 6
-
num_workers: Number of worker threads for data loading.
- Type: int
- Example: 4
-
forget_prompt: Prompt to specify the style or concept to forget.
- Type: str
- Example: "An image in Artist_Sketch style"
Lightning Configuration
-
max_epochs: Maximum number of epochs for training.
- Type: int
- Example: 50
-
callbacks:
-
batch_frequency: Frequency for logging image batches.
- Type: int
- Example: 1
-
max_images: Maximum number of images to log.
- Type: int
- Example: 999
-