经验首页 前端设计 程序设计 Java相关 移动开发 数据库/运维 软件/图像 大数据/云计算 其他经验
当前位置:技术经验 » 大数据/云/AI » 人工智能基础 » 查看文章
机器学习常见的sampling策略 附PyTorch实现 - zh-jp
来源:cnblogs  作者:zh-jp  时间:2024/4/10 15:16:15  对本文有异议

简单的采样策略

首先介绍三种简单采样策略:

  1. Instance-balanced sampling, 实例平衡采样。
  2. Class-balanced sampling, 类平衡采样。
  3. Square-root sampling, 平方根采样。

它们可抽象为:

\[p_j=\frac{n_j^q}{\sum_{i=1}^Cn_i^q}, \]

\(p_j\)表示从j类采样数据的概率;\(C\)表示类别数量;\(n_j\)表示j类样本数;\(q\in\{1,0,\frac{1}{2}\}\)
Instance-balanced sampling
最常见的数据采样方式,其中每个训练样本被选择的概率相等(\(q=1\))。j类被采样的概率\(p^{\mathbf{IB}}_j\)与j类样本数\(n_j\)成正比,即\(p^{\mathbf{IB}}_j=\frac{n_j}{\sum_{i=1}^Cn_i}\)

Class-balanced sampling
实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样让所有的类有相同的被采样概率:\(p^{\mathbf{CB}}_j=\frac{1}{C}\)。采样可分为两个阶段:1. 从类集中统一选择一个类;2. 对该类中的实例进行统一采样。
Square-root sampling
平方根采样最常见的变体,\(q=\frac{1}{2}\)

由于这三种采样策略都是调整类别的采样概率(权重),因此可用PyTorch提供的WeightedRandomSampler实现:

  1. import numpy as np
  2. from torch.utils.data.sampler import WeightedRandomSampler
  3. def get_sampler(sampling_type, targets):
  4. cls_counts = np.bincount(targets)
  5. if sampling_type == 'instance-balanced':
  6. cls_weights = cls_counts / np.sum(cls_counts)
  7. elif sampling_type == 'class-balanced':
  8. cls_num = len(cls_counts)
  9. cls_weights = [1. / cls_num] * cls_num
  10. elif sampling_type == 'square-root':
  11. sqrt_and_sum = np.sum([num**0.5 for num in cls_counts])
  12. cls_weights = [num**0.5 / sqrt_and_sum for num in cls_counts]
  13. else:
  14. raise ValueError('sampling_type should be instance-balanced, class-balanced or square-root')
  15. cls_weights = np.array(cls_weights)
  16. return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)

WeightedRandomSampler,第一个参数表示每个样本的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。

在模拟的长尾数据集测试下:

  1. import torch
  2. from torch.utils.data import Dataset, DataLoader, Sampler
  3. torch.manual_seed(0)
  4. np.random.seed(0)
  5. class LongTailDataset(Dataset):
  6. def __init__(self, num_classes, max_samples_per_class):
  7. self.num_classes = num_classes
  8. self.max_samples_per_class = max_samples_per_class
  9. # Generate number of samples for each class inversely proportional to class index
  10. self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]
  11. self.total_samples = sum(self.samples_per_class)
  12. # Generate targets for the dataset
  13. self.targets = torch.cat([torch.full((samples,), i, dtype=torch.long) for i, samples in enumerate(self.samples_per_class)])
  14. def __len__(self):
  15. return self.total_samples
  16. def __getitem__(self, idx):
  17. # For simplicity, just return the index as the data
  18. return idx, self.targets[idx]
  19. # Parameters
  20. num_classes = 25
  21. max_samples_per_class = 1000
  22. # Create dataset
  23. dataset = LongTailDataset(num_classes, max_samples_per_class)
  24. # Create sampler
  25. batch_size = 128
  26. sampler1 = get_sampler('instance-balanced', dataset.targets.numpy())
  27. sampler2 = get_sampler('class-balanced', dataset.targets.numpy())
  28. sampler3 = get_sampler('square-root', dataset.targets.numpy())
  29. def test_sampler_in_one_batch(sampler:Sampler, inf:str):
  30. print(inf)
  31. for (_, target) in DataLoader(dataset, batch_size=64, sampler=sampler):
  32. cls_idx, cls_counts = np.unique(target.numpy(), return_counts=True)
  33. print(f'Class indices: {cls_idx}')
  34. print(f'Class counts: {cls_counts}')
  35. break # just show one batch
  36. print('-'*20)
  37. samplers = [sampler1, sampler2, sampler3]
  38. infs = ['Instance-balanced:', 'Class-balanced:', 'Square-root:']
  39. for sampler, inf in zip(samplers, infs):
  40. test_sampler_in_one_batch(sampler, inf)

Output:

  1. Instance-balanced:
  2. Class indices: [ 0 1 2 3 5 16 22 23]
  3. Class counts: [42 10 5 2 2 1 1 1]
  4. --------------------
  5. Class-balanced:
  6. Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 20 21 23]
  7. Class counts: [22 7 6 4 2 1 2 2 3 3 1 2 1 1 1 1 2 1 1 1]
  8. --------------------
  9. Square-root:
  10. Class indices: [ 0 1 2 3 4 5 6 9 10 21 22 23]
  11. Class counts: [37 8 3 6 3 1 1 1 1 1 1 1]
  12. --------------------

混合采样策略

最早的混合采样是在 \(0\le epoch\le t\)时采用Instance-balanced采样,\(t\le epoch\le T\)时采用Class-balanced采样,这需要设置合适的超参数t。在[1]中,作者提出了soft版本的混合采样策略:Progressively-balanced sampling。随着epoch的增加每个类的采样概率(权重)\(p_j\)也发生变化:

\[p_j^{\mathbf{PB}}(t)=(1-\frac tT)p_j^{\mathbf{IB}}+\frac tTp_j^{\mathbf{CB}} \]

t表示当前epoch,T表示总epoch数。

不平衡数据集下的采样策略

不平衡的数据集,特别是长尾数据集,为了照顾尾部类,通常设置每个类的采样概率(权重)为样本数的倒数,即\(p_j=\frac{1}{n_j}\)

  1. ...
  2. elif sampling_type == 'inverse':
  3. cls_weights = 1. / cls_counts
  4. ...

在[3]中提出了有效数(effective number)的概念,分母的位置不是简单的样本数,而是经过一定计算得到的,这里直接给出结果,证明请详见原论文。关于effective number的计算方式:

\[E_n=(1-\beta^n)/(1-\beta),\ \mathrm{where~}\beta=(N-1)/N. \]

这里N表示数据集样本总数。

相关代码:

  1. ...
  2. elif sampling_type == 'effective':
  3. beta = (len(targets) - 1) / len(targets)
  4. cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
  5. ...

Output

  1. Effective:
  2. Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 18 20 21 22 23 24]
  3. Class counts: [2 1 2 3 1 1 4 2 3 4 4 2 3 5 2 4 1 3 1 4 5 6 1]
  4. --------------------

在和上面一样的模拟长尾数据集上,采样的结果更加均衡。

参考文献

  1. Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2019.
  2. torch.utils.data.WeightedRandomSampler
  3. Cui, Yin, et al. "Class-balanced loss based on effective number of samples." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.

原文链接:https://www.cnblogs.com/zh-jp/p/18124824

 友情链接:直通硅谷  点职佳  北美留学生论坛

本站QQ群:前端 618073944 | Java 606181507 | Python 626812652 | C/C++ 612253063 | 微信 634508462 | 苹果 692586424 | C#/.net 182808419 | PHP 305140648 | 运维 608723728

W3xue 的所有内容仅供测试,对任何法律问题及风险不承担任何责任。通过使用本站内容随之而来的风险与本站无关。
关于我们  |  意见建议  |  捐助我们  |  报错有奖  |  广告合作、友情链接(目前9元/月)请联系QQ:27243702 沸活量
皖ICP备17017327号-2 皖公网安备34020702000426号