import os

import numpy as np
import numpy.typing as npt
import requests
import scipy.linalg
import torch
from einops import rearrange
from tqdm import tqdm
from typing import List, Optional, Union

from .utils.data_utils import get_dataloader, VID_EXTENSIONS
from .utils.metric_utils import seed_everything, FeatureStats
from .third_party.VideoMAEv2.utils import load_videomae_model, preprocess_videomae
from .third_party.i3d.utils import load_i3d_model, preprocess_i3d


def get_videomae_features(stats, model, videos, batchsize=16, device='cuda', model_dtype=torch.float32):
    '''Runs videos through the VideoMAE model in batches and appends the features to `stats`.'''
    vid_length = videos.shape[0]
    for i in range(0, videos.shape[0], batchsize):
        batch = videos[i:min(vid_length, i + batchsize)]
        input_data = preprocess_videomae(batch)  # torch.Size([B, 3, T, H, W])
        input_data = input_data.to(device=device, dtype=model_dtype)
        with torch.no_grad():
            features = model.forward_features(input_data)
        stats.append_torch(features, num_gpus=1, rank=0)
    return stats


def get_i3d_logits(stats, i3d, videos, batchsize=16, device='cuda', model_dtype=torch.float32):
    '''Runs videos through the I3D model in batches and appends the outputs to `stats`.'''
    vid_length = videos.shape[0]
    for i in range(0, vid_length, batchsize):
        batch = videos[i:min(vid_length, i + batchsize)]
        input_data = preprocess_i3d(batch)
        input_data = input_data.to(device=device, dtype=model_dtype)
        with torch.no_grad():
            features = i3d(input_data)
        stats.append_torch(features, num_gpus=1, rank=0)
    return stats


class cdfvd(object):
    '''This class loads a pretrained model (I3D or VideoMAE) and contains functions
    to compute the FVD score between real and fake videos.

    Args:
        model: Name of the model to use, either `videomae` or `i3d`.
        n_real: Number of real videos to use for computing the FVD score; if `'full'`, all the videos in the dataset will be used.
        n_fake: Number of fake videos to use for computing the FVD score.
        ckpt_path: Path to save the model checkpoint.
        seed: Random seed.
        compute_feats: Whether to compute all features or just mean and covariance.
        device: Device to use for computing the features.
        half_precision: Whether to use half precision for the model.
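
    Example:
        An illustrative end-to-end sketch. The import path assumes this module is
        installed as `cdfvd.fvd`, and the video folders are placeholder paths::

            from cdfvd import fvd

            evaluator = fvd.cdfvd('videomae', n_real='full', n_fake=2048, device='cuda')
            real_loader = evaluator.load_videos('path/to/real_videos/', data_type='video_folder')
            fake_loader = evaluator.load_videos('path/to/fake_videos/', data_type='video_folder')
            evaluator.compute_real_stats(real_loader)
            evaluator.compute_fake_stats(fake_loader)
            score = evaluator.compute_fvd_from_stats()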
    '''

    def __init__(self, model: str = 'i3d', n_real: str = 'full', n_fake: int = 2048, ckpt_path: Optional[str] = None,
                 seed: int = 42, compute_feats: bool = False, device: str = 'cuda', half_precision: bool = False,
                 *args, **kwargs):
        self.model_name = model
        self.ckpt_path = ckpt_path
        self.seed = seed
        self.device = device
        self.n_real = n_real
        self.n_fake = n_fake
        self.real_stats = FeatureStats(max_items=None if n_real == 'full' else n_real, capture_mean_cov=True, capture_all=compute_feats)
        self.fake_stats = FeatureStats(max_items=n_fake, capture_mean_cov=True, capture_all=compute_feats)
        self.model_dtype = torch.float16 if half_precision else torch.float32

        assert self.model_name in ['videomae', 'i3d']
        print('Loading %s model ...' % self.model_name)
        if self.model_name == 'videomae':
            self.model = load_videomae_model(torch.device(device), ckpt_path).eval().to(dtype=self.model_dtype)
            self.feature_fn = get_videomae_features
        else:
            self.model = load_i3d_model(torch.device(device), ckpt_path).eval().to(dtype=self.model_dtype)
            self.feature_fn = get_i3d_logits

    def compute_fvd_from_stats(self, fake_stats: Optional[FeatureStats] = None, real_stats: Optional[FeatureStats] = None) -> float:
        '''This function computes the FVD score between real and fake videos using precomputed features.
        If the stats are not provided, it uses the stats stored in the object.

        Args:
            fake_stats: `FeatureStats` object containing the features of the fake videos.
            real_stats: `FeatureStats` object containing the features of the real videos.

        Returns:
            FVD score between the real and fake videos.
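
        Example:
            A minimal sketch, assuming `evaluator` is a `cdfvd` instance whose real
            and fake stats have already been populated (the instance name is an
            assumption)::

                score = evaluator.compute_fvd_from_stats()

            The returned value is the Fréchet distance
            `||mu_r - mu_f||^2 + Tr(Sigma_r + Sigma_f - 2 * sqrtm(Sigma_r @ Sigma_f))`
            computed from the stored means and covariances.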
        '''
        fake_stats = self.fake_stats if fake_stats is None else fake_stats
        real_stats = self.real_stats if real_stats is None else real_stats
        mu_fake, sigma_fake = fake_stats.get_mean_cov()
        mu_real, sigma_real = real_stats.get_mean_cov()
        m = np.square(mu_real - mu_fake).sum()
        s, _ = scipy.linalg.sqrtm(np.dot(sigma_real, sigma_fake), disp=False)
        return np.real(m + np.trace(sigma_fake + sigma_real - s * 2))

    def compute_fvd(self, real_videos: npt.NDArray[np.uint8], fake_videos: npt.NDArray[np.uint8]) -> float:
        '''This function computes the FVD score between real and fake videos in the form of numpy arrays.

        Args:
            real_videos: A numpy array of videos with shape `(B, T, H, W, C)`, values in the range `[0, 255]`.
            fake_videos: A numpy array of videos with shape `(B, T, H, W, C)`, values in the range `[0, 255]`.

        Returns:
            FVD score between the real and fake videos.
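
        Example:
            An illustrative sketch with random data; `evaluator` is an assumed
            `cdfvd` instance and the array shapes are placeholders::

                real = np.random.randint(0, 256, size=(8, 16, 224, 224, 3), dtype=np.uint8)
                fake = np.random.randint(0, 256, size=(8, 16, 224, 224, 3), dtype=np.uint8)
                score = evaluator.compute_fvd(real, fake)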
        '''
        self.real_stats = self.feature_fn(self.real_stats, self.model, real_videos, device=self.device, model_dtype=self.model_dtype)
        self.fake_stats = self.feature_fn(self.fake_stats, self.model, fake_videos, device=self.device, model_dtype=self.model_dtype)
        return self.compute_fvd_from_stats(self.fake_stats, self.real_stats)

    def compute_real_stats(self, loader: Union[torch.utils.data.DataLoader, List, None] = None) -> FeatureStats:
        '''This function computes the real features from a dataset.

        Args:
            loader: Real videos, either as a dataloader or as a list of batches (as returned by `load_videos`).

        Returns:
            FeatureStats object containing the features of the real videos.
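
        Example:
            A minimal sketch; `evaluator` is an assumed `cdfvd` instance and the
            folder path is a placeholder::

                loader = evaluator.load_videos('path/to/real_videos/', data_type='video_folder')
                evaluator.compute_real_stats(loader)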
        '''
        seed_everything(self.seed)
        if loader is None:
            assert self.real_stats.max_items is not None
            return self.real_stats
        while self.real_stats.max_items is None or self.real_stats.num_items < self.real_stats.max_items:
            for batch in tqdm(loader):
                real_videos = rearrange(batch['video'] * 255, 'b c t h w -> b t h w c').byte().data.numpy()
                self.real_stats = self.feature_fn(self.real_stats, self.model, real_videos, device=self.device, model_dtype=self.model_dtype)
                if self.real_stats.max_items is not None and self.real_stats.num_items >= self.real_stats.max_items:
                    break
            if self.real_stats.max_items is None:
                break
        return self.real_stats

    def compute_fake_stats(self, loader: Union[torch.utils.data.DataLoader, List, None] = None) -> FeatureStats:
        '''This function computes the fake features from a dataset.

        Args:
            loader: Fake videos, either as a dataloader or as a list of batches (as returned by `load_videos`).

        Returns:
            FeatureStats object containing the features of the fake videos.
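
        Example:
            A minimal sketch using the list form of the loader. Each item mirrors
            what `load_videos` produces: a dict with a `(B, C, T, H, W)` float tensor
            in `[0, 1]` under the key `video`. The random tensor is a placeholder,
            and the sketch assumes the evaluator was created with a matching small
            `n_fake` (e.g. `cdfvd('i3d', n_fake=4)`)::

                fake_loader = [{'video': torch.rand(4, 3, 16, 256, 256)}]
                evaluator.compute_fake_stats(fake_loader)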
        '''
        seed_everything(self.seed)
        while self.fake_stats.max_items is None or self.fake_stats.num_items < self.fake_stats.max_items:
            for batch in tqdm(loader):
                fake_videos = rearrange(batch['video'] * 255, 'b c t h w -> b t h w c').byte().data.numpy()
                self.fake_stats = self.feature_fn(self.fake_stats, self.model, fake_videos, device=self.device, model_dtype=self.model_dtype)
                if self.fake_stats.max_items is not None and self.fake_stats.num_items >= self.fake_stats.max_items:
                    break
            if self.fake_stats.max_items is None:
                break
        return self.fake_stats

    def add_real_stats(self, real_videos: npt.NDArray[np.uint8]):
        '''This function adds features of real videos to the real_stats object.

        Args:
            real_videos: A numpy array of videos with shape `(B, T, H, W, C)`, values in the range `[0, 255]`.
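
        Example:
            A minimal sketch that accumulates a large array in chunks; the array
            name `all_real_videos` and the chunk size are assumptions::

                for i in range(0, all_real_videos.shape[0], 64):
                    evaluator.add_real_stats(all_real_videos[i:i + 64])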
        '''
        self.real_stats = self.feature_fn(self.real_stats, self.model, real_videos, device=self.device, model_dtype=self.model_dtype)

    def add_fake_stats(self, fake_videos: npt.NDArray[np.uint8]):
        '''This function adds features of fake videos to the fake_stats object.

        Args:
            fake_videos: A numpy array of videos with shape `(B, T, H, W, C)`, values in the range `[0, 255]`.
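
        Example:
            A minimal sketch that accumulates samples while they are generated;
            `sample_uint8_videos` is a hypothetical function returning
            `(B, T, H, W, C)` uint8 arrays::

                for _ in range(10):
                    evaluator.add_fake_stats(sample_uint8_videos(batch_size=16))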
        '''
        self.fake_stats = self.feature_fn(self.fake_stats, self.model, fake_videos, device=self.device, model_dtype=self.model_dtype)

    def empty_real_stats(self):
        '''This function empties the real_stats object.'''
        self.real_stats = FeatureStats(max_items=self.real_stats.max_items, capture_mean_cov=True)

    def empty_fake_stats(self):
        '''This function empties the fake_stats object.'''
        self.fake_stats = FeatureStats(max_items=self.fake_stats.max_items, capture_mean_cov=True)

    def save_real_stats(self, path: str):
        '''This function saves the real_stats object to a file.

        Args:
            path: Path to save the real_stats object.
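
        Example:
            A minimal sketch that caches real statistics for later runs; the paths
            are placeholders::

                evaluator.compute_real_stats(evaluator.load_videos('path/to/real_videos/', data_type='video_folder'))
                evaluator.save_real_stats('real_stats_cache.pkl')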
        '''
        self.real_stats.save(path)
        print('Real stats saved to %s' % path)

    def load_real_stats(self, path: str):
        '''This function loads the real_stats object from a file.

        Args:
            path: Path to load the real_stats object from.
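
        Example:
            A minimal sketch that reuses statistics cached by `save_real_stats`;
            the path and the fake loader are assumptions::

                evaluator.load_real_stats('real_stats_cache.pkl')
                evaluator.compute_fake_stats(fake_loader)
                score = evaluator.compute_fvd_from_stats()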
        '''
        self.real_stats = self.real_stats.load(path)
        print('Real stats loaded from %s' % path)

    def load_videos(self, video_info: str, resolution: int = 256, sequence_length: int = 16, sample_every_n_frames: int = 1,
                    data_type: str = 'video_numpy', num_workers: int = 4, batch_size: int = 16) -> Union[torch.utils.data.DataLoader, List, None]:
        '''This function loads videos from a source specified by `data_type`.

        `video_numpy` loads videos from a file containing a numpy array with the shape `(B, T, H, W, C)`.
        `video_folder` loads videos from a folder containing video files.
        `image_folder` loads videos from a folder containing image files.
        `stats_pkl` indicates that `video_info` is a dataset name with pre-computed features. Currently supports `ucf101`, `kinetics`, `sky`, `ffs`, and `taichi`.

        Args:
            video_info: Path to the video file or folder, or a dataset name for `stats_pkl`.
            resolution: Resolution of the video.
            sequence_length: Length of the video sequence.
            sample_every_n_frames: Number of frames to skip.
            data_type: Type of the video data, either `video_numpy`, `video_folder`, `image_folder`, or `stats_pkl`.
            num_workers: Number of workers for the dataloader.
            batch_size: Batch size for the dataloader.

        Returns:
            Dataloader or list of batches containing the videos, or `None` for `stats_pkl` (the pre-computed stats are loaded into `real_stats`).
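
        Example:
            Illustrative sketches for the different `data_type` values; the paths
            are placeholders and the pre-computed `stats_pkl` entries depend on what
            is available in the remote cache::

                # a .npy file containing a (B, T, H, W, C) uint8 array
                loader = evaluator.load_videos('fake_videos.npy', data_type='video_numpy')

                # a folder of video files
                loader = evaluator.load_videos('path/to/videos/', data_type='video_folder')

                # pre-computed real statistics; returns None and fills `real_stats`
                evaluator.load_videos('ucf101', data_type='stats_pkl', resolution=256, sequence_length=16)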
        '''
        if data_type == 'video_numpy' or video_info.endswith('.npy'):
            video_array = np.load(video_info)
            video_loader = [{'video': rearrange(torch.from_numpy(video_array[i:i + batch_size]) / 255., 'b t h w c -> b c t h w')} for i in range(0, video_array.shape[0], batch_size)]
        elif data_type == 'video_folder':
            print('Loading from video files ...')
            video_loader = get_dataloader(video_info, image_folder=False,
                                          resolution=resolution, sequence_length=sequence_length,
                                          sample_every_n_frames=sample_every_n_frames,
                                          batch_size=batch_size, num_workers=num_workers)
        elif data_type == 'image_folder':
            print('Loading from frame files ...')
            video_loader = get_dataloader(video_info, image_folder=True,
                                          resolution=resolution, sequence_length=sequence_length,
                                          sample_every_n_frames=sample_every_n_frames,
                                          batch_size=batch_size, num_workers=num_workers)
        elif data_type == 'stats_pkl':
            video_loader = None
            cache_name = '%s_%s_%s_res%d_len%d_skip%d_seed%d.pkl' % (
                self.model_name.lower(), video_info, self.n_real, resolution, sequence_length, sample_every_n_frames, 0)
            current_dir = os.path.dirname(os.path.abspath(__file__))
            ckpt_path = os.path.join(current_dir, 'fvd_stats_cache', cache_name)
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
            if not os.path.exists(ckpt_path):
                # download the ckpt to the path
                ckpt_url = 'https://content-debiased-fvd.github.io/files/%s' % cache_name
                response = requests.get(ckpt_url, stream=True, allow_redirects=True)
                total_size = int(response.headers.get("content-length", 0))
                block_size = 1024
                with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
                    with open(ckpt_path, "wb") as fw:
                        for data in response.iter_content(block_size):
                            progress_bar.update(len(data))
                            fw.write(data)
            self.real_stats = self.real_stats.load(ckpt_path)
        else:
            raise ValueError('Invalid real_video path')
        return video_loader

    def offload_model_to_cpu(self):
        '''This function offloads the model to the CPU to release GPU memory.'''
        self.model = self.model.cpu()
        torch.cuda.empty_cache()