Skip to content

预训练数据

概述

预训练最重要的部分是数据的获取和处理,预训练数据量级在10T tokens,继续预训练则至少在100B tokens 的量级

预训练数据组成​预训练数据集包含各种类型、各种领域的未标记文本数据。

一般分为:​通用预训练数据集​和特定领域预训练数据集

数据预处理四个步骤:数据收集,数据过滤,数据去重,分词。

  • 数据收集:数据收集步骤即从对应数据源采样需求的数据

  • 数据过滤:数据过滤的目的是提高数据质量,主要的实现方式是对数据进行筛选和清洗。具体方法包括基于模型的方法和基于启发式的方法。

    • 基于模型的方法:训练筛选模型来过滤低质量的数据。高质量的预训练语料库可以用作正面样本,而要过滤的受污染文本用作负面样本,以训练用于过滤的分类器

    • 过滤低质量的数据:启发式规则通常手动定义,并设置为相关的质量指标

  • 数据去重:数据去重涉及删除语料库中的重复或高度相似的文本

    • TF-IDF(词频-逆文档频率)软去重:计算文本中每个词的TF-IDF权重来比较文本之间的相似性。相似度超过阈值的文本将被删除。TF-IDF权重是词在文本中的频率(TF)乘以在整个语料库中的逆文档频率(IDF)。较高的权重表明这个词在特定文本中频繁出现,但在整个语料库中不常见,使其成为文本的关键特征。

    • MinHash:该方法估计两个集合之间的相似性。文本通过随机哈希处理,获得一组最小哈希值。然后通过比较这些最小哈希值来估计相似性

    • SimHash:这个算法用于计算文本相似性。文本特征向量被哈希以生成固定长度的哈希码。通过比较文本哈希码之间的汉明距离来估计相似性,距离越小表示相似性越大

数据来源

预训练涉及学习大量数据,以全面了解世界及其各种复杂性。数据集就不再介绍了。

数据爬取

有些高质量的数据,比如论文书籍,往往是 pdf 格式,这时候还需要去调用效果较好的 pdf 服务。

​爬虫主要可分为:

  • 定向网站爬取​

  • 根据关键词从指定网站爬取​

  • 基于搜索引擎爬取

数据清洗:预训练处理数据最重要的部分

得到的数据质量、多样性决定了模型性能,可以多看大厂的report,数据处理的Pipeline好好看看,以后会详细讲解这部分

URL过滤:llama 3.1制定URL黑名单和计算URL分数决定内容是否保留。

#!/usr/bin/env python3

#@> IMPORTING ALL THE DEPENDENCIES
import re
import argparse
from urllib.parse import urlparse
from sys import stdin, stdout, exit

#@> PROCESSING COMMAND LINE ARGUMENTS
parser = argparse.ArgumentParser(epilog="\tExample:\nwaybackurls sub.domain.tld | parshu")
parser._optionals.title = "OPTIONS"

#@> OPTIONS
parser.add_argument('-x', dest='xss', action='store_true', help="Filter all URLS where you check for xss.")
parser.add_argument('-r', dest='redirect', action='store_true', help="Filter all URLS where you check for open-redirect.")
parser.add_argument('-l', dest='lfi', action='store_true', help="Filter all URLS where you check for lfi.")
parser.add_argument('-s', dest='sql', action='store_true', help="Filter all URLS where you check for sqli.")
parser.add_argument('-t', dest='ssti', action='store_true', help="Filter all URLS where you check for ssti.")
parser.add_argument('-f', dest='ssrf', action='store_true', help="Filter all URLS where you check for ssrf.")
parser.add_argument('-c', dest='rce', action='store_true', help="Filter all URLS where you check for rce.")
parser.add_argument('-q', dest='query', action='store_true', help="Filter all URLS which contains query parameters.")

cmd = parser.parse_args()

#@> REGEX PATTERNS
PATTERN = "(\.asp|\.aspx|\.bat|\.cfm|\.cgi|\.css|\.dll|\.exe|\.htm|\.html|\.inc|\.jhtml|\.js|\.jsa|\.jsp|\.log|\.mdb|\.nsf|\.pcap|\.php|\.php2|\.php3|\.php4|\.php5|\.php6|\.php7|\.phps|\.pht|\.phtml|\.pl|\.reg|\.sh|\.shtml|\.sql|\.swf|\.txt|\.xml|\.ini|\,xml|\.bat|\.LOG|\.tn|\.bak|\.sql)"
XSS_REGEX = "(api=|api_key=|begindate=|callback=|=|categoryid=|csrf_token=|email=|emailto=|enddate=|id=|imagine=|immagine=|item=|jsonp=|key=|keyword=|keywords=|l=|lang=|list_type=|month=|name=|p=|page=|page_id=|password=|pid=|pp=|q=|query=|s=|search=|terms=|token=|type=|unsubscribe_token=|url=|username=|view=|year=)"
RED_REGEX = "(Lmage_url=|Open=|callback=|cgi-bin/redirect.cgi|cgi-bin/redirect.cgi?|checkout=|checkout_url=|continue=|data=|dest=|destination=|dir=|domain=|feed=|file=|file_name=|file_url=|folder=|folder_url=|forward=|from_url=|go=|goto=|host=|html=|image_url=|img_url=|load_file=|load_url=|login?to=|login_url=|logout=|navigation=|next=|next_page=|out=|page=|page_url=|path=|port=|redir=|redirect=|redirect_to=|redirect_uri=|redirect_url=|reference=|return=|returnTo=|return_path=|return_to=|return_url=|rt=|rurl=|show=|site=|target=|to=|uri=|url=|val=|validate=|view=|window=)"
SQL_REGEX = "(id=|select=|report=|role=|update=|query=|user=|name=|sort=|where=|search=|params=|process=|row=|view=|table=|from=|sel=|results=|sleep=|fetch=|order=|keyword=|column=|field=|delete=|string=|number=|filter=)"
STI_REGEX = "(template=|preview=|id=|view=|activity=|name=|content=|redirect=)"
SRF_REGEX = "(access=|admin=|dbg=|debug=|edit=|grant=|test=|alter=|clone=|create=|delete=|disable=|enable=|exec=|execute=|load=|make=|modify=|rename=|reset=|shell=|toggle=|adm=|root=|cfg=|dest=|redirect=|uri=|path=|continue=|url=|window=|next=|data=|reference=|site=|html=|val=|validate=|domain=|callback=|return=|page=|feed=|host=|port=|to=|out=|view=|dir=|show=|navigation=|open=|file=|document=|folder=|pg=|php_path=|style=|doc=|img=|filename=)"
LFI_REGEX = "(file=|document=|folder=|root=|path=|pg=|style=|pdf=|template=|php_path=|doc=|page=|name=|cat=|dir=|action=|board=|date=|detail=|download=|prefix=|include=|inc=|locate=|show=|site=|type=|view=|content=|layout=|mod=|conf=|url=)"
RCE_REGEX = "(daemon=|upload=|dir=|download=|log=|ip=|cli=|cmd=|exec=|command=|execute=|ping=|query=|jump=|code=|reg=|do=|func=|arg=|option=|load=|process=|step=|read=|function|req=|feature=|exe=|module=|payload=|run=|print=)"

try:
    if stdin.isatty():
        # This part of the code is likely intended to provide a default
        # behavior or instructions when no input is piped to the script.
        # The original snippet provided doesn't include the full content
        # of this block, but typically it would print usage information
        # or an error message.
        pass # Placeholder for content when stdin is a tty

    elif cmd.lfi:
        for URL in stdin.readlines():
            """
            STDOUT all the URLS which contains lfi parameters.
            """
            link = str(URL.strip())
            if re.search(LFI_REGEX, link):
                stdout.write(re.sub(r"'|(|)", "", link) + '\n')

    elif cmd.sql:
        for URL in stdin.readlines():
            """
            STDOUT all the URLS which contains sql parameters.
            """
            link = str(URL.strip())
            if re.search(SQL_REGEX, link):
                stdout.write(re.sub(r"'|(|)", "", link) + '\n')

    elif cmd.ssti:
        for URL in stdin.readlines():
            """
            STDOUT all the URLS which contains ssti parameters.
            """
            link = str(URL.strip())
            if re.search(STI_REGEX, link):
                stdout.write(re.sub(r"'|(|)", "", link) + '\n')

    elif cmd.ssrf:
        for URL in stdin.readlines():
            """
            STDOUT all the URLS which contains ssrf parameters.
            """
            link = str(URL.strip())
            if re.search(SRF_REGEX, link):
                stdout.write(re.sub(r"'|(|)", "", link) + '\n')

    elif cmd.rce:
        for URL in stdin.readlines():
            """
            STDOUT all the URLS which contains rce parameters.
            """
            link = str(URL.strip())
            if re.search(RCE_REGEX, link):
                stdout.write(re.sub(r"'|(|)", "", link) + '\n')

    elif cmd.query:
        for URL in stdin.readlines():
            """
            STDOUT all the URLS which contains query parameters.
            """
            link = str(URL.strip())
            if re.search("(=|&)", link):
                stdout.write(re.sub(r"'|(|)", "", link) + '\n')
    else:
        for URL in stdin.readlines():
            """
            STDOUT all the URLS except which contains query parameters and files.
            """
            links = re.split('/', urlparse(URL.strip()).path)[-1]

            if re.search(PATTERN, links):
                output = re.sub(r"([^/\\]+)(\.php|\.p|\%|\.asp|\.aspx|\.bat|\.cfm|\.cgi|\.css|\.dll|\.exe|\.htm|\.html|\.inc|\.jhtml|\.js|\.jsa|\.jsp|\.log|\.mdb|\.nsf|\.pcap|\.php2|\.php3|\.php4|\.php5|\.php6|\.php7|\.phps|\.pht|\.phtml|\.pl|\.reg|\.sh|\.shtml|\.sql|\.swf|\.txt|\.xml|\.ini|\,xml|\.bat|\.LOG|\.tn|\.bak|\.sql).*", "", str(URL.strip()))
                stdout.write(re.sub(r"'|(|)", "", output) + '\n')

            elif re.search(r"(&|=|\?|\%)", URL.strip()):
                pass

            else:
                link = urlparse(URL.strip()).scheme + "://" + urlparse(URL.strip()).netloc + urlparse(URL.strip()).path
                output = re.sub(r"([^/\\]+)(\.php|\.p|\%).*", "", link)
                stdout.write(str(re.sub(r"'|(|)", "", output)) + '\n')
except KeyboardInterrupt:
    exit(0)
except:
    pass

内容抽取

在获得过滤后的 URL 集合后,我们需要获取到这些 URL 中的「文本信息」,需要过滤并丢弃「目录」、「标题」、「广告」等无关的内容,使用 trafilatura 库来进行内容抓取,trafilatura 的使用方法也比较简单:

1
2
3
4
5
6
7
8
9
# load necessary components
from trafilatura import fetch_url, extract

url = 'https://github.blog/2019-03-29-leader-spotlight-erin-spiceland/'

downloaded = fetch_url(url)
downloaded is None
result = extract(downloaded)
print(result)

语音识别

1
2
3
4
5
1. 使用FastText训练一个语言识别模型,去掉那些语言阈值得分低于 0.65 的文章

2. 使用langid识别语言种类​3.

3. 使用data_juicer/ops/filter/language_id_score_filter.py进行识别
import langid
from typing import Dict, Any, Optional

# 定义常量和字段名
class StatsKeys:
    lang = "lang"
    lang_score = "lang_score"

class Fields:
    stats = "stats"

# 语言识别器类
class LanguageDetector:
    def __init__(self, text_key: str = "text", model_key: str = "lang_model"):
        """
        初始化语言识别器

        Args:
            text_key: 输入样本中包含文本的字段名
            model_key: 语言模型的键名
        """
        self.text_key = text_key
        self.model_key = model_key
        self.models = {}  # 模型缓存

    def get_model(self, model_key: str) -> Optional[Any]:
        """获取语言识别模型,支持模型缓存"""
        if model_key not in self.models:
            try:
                # 实际项目中这里可能是加载fastText等模型
                # 此处简化为直接使用langid
                self.models[model_key] = langid
                print(f"模型 {model_key} 加载成功")
            except Exception as e:
                print(f"模型加载失败: {e}")
                return None
        return self.models[model_key]

    def compute_stats_single(self, sample: Dict) -> Dict:
        """
        计算单个样本的语言识别结果

        Args:
            sample: 包含文本的样本字典

        Returns:
            处理后的样本字典,添加了语言识别结果
        """
        # 检查是否已经计算过
        if (StatsKeys.lang in sample.get(Fields.stats, {}) and 
            StatsKeys.lang_score in sample.get(Fields.stats, {})):
            return sample

        # 预处理文本
        text = sample[self.text_key].lower().replace('\n', ' ')

        # 获取语言识别模型
        ft_model = self.get_model(self.model_key)
        if ft_model is None:
            err_msg = 'Model not loaded. Please retry later.'
            print(err_msg)
            raise ValueError(err_msg)

        # 执行语言识别
        # 注意:langid的API与fastText不同,这里做了适配
        lang_id, lang_score = ft_model.classify(text)

        # 存储识别结果
        if Fields.stats not in sample:
            sample[Fields.stats] = {}

        sample[Fields.stats][StatsKeys.lang] = lang_id
        sample[Fields.stats][StatsKeys.lang_score] = lang_score

        return sample

# 示例用法
if __name__ == "__main__":
    # 初始化语言识别器
    detector = LanguageDetector()

    # 示例样本
    sample = {
        "text": "Tudo bem?",
        "other_field": "其他信息"
    }

    # 执行语言识别
    result = detector.compute_stats_single(sample)

    # 输出结果
    print(f"识别语言: {result[Fields.stats][StatsKeys.lang]}")
    print(f"置信度: {result[Fields.stats][StatsKeys.lang_score]:.4f}")

低质过滤

  • 篇章级别过滤:首先,去除一些文章内一直重复同一段内容的文章,以及一些包含错误信息的文章(例如:抓取超时等);其次,对每一篇文章,通过判断文章整体长度、标点符号占文章长度的比例等,来过滤掉那些不正规的文章

  • 句子级别过滤:过滤掉文章中的一些无用的句子,这种句子通常没有具体含义,但话术繁多,又难以穷举

  • 利用启发式规则进行 pretrain 数据质量筛选:通过精心设计的规则有针对性地识别和剔除低质量数据,如果训练特定语言,可以过滤其它语言文本

模型打分

  • 利用模型对 pretrain 数据的质量进行打分,同 size 下,BERT 结构的模型的表征能力是强于 transformer-decoder 模型, 因此打分模型可以使用BERT类型的模型进行训练,或者利用强闭源模型来对训练数据进行打分

  • 训练文本分类器进行数据清洗,对比学习,将高质量数据视为正样本,低质量数据视为负样本

数据去重

  • 训练数据集中的重复:用于训练的多个类型、来源的数据集内部和之间的重复。

    • 数据集内部的重复,包括单个文档内部重复的lines、paragraphs、n-grams等(说明文档本身质量较低)、多个文档之间的重复(基于完全匹配或者模糊匹配)

    • 数据集之间的重复,主要指多个数据集基于同一个来源但是进行了不同的预处理之后产出的数据,一个典型的例子是LLaMa等模型会同时使用基于CommonCrawl预处理过的数据以及之前T5模型基于当时CommonCrawl处理的C4数据,两者定然是有一定重复的​

    • 训练迭代设置的重复:在给定数据集上人为设定的重复轮次(Epochs),一般来说针对不同类型、来源、质量的数据采样训练的Epochs并不一致​

    • 训练与测试集的重复:例如,预训练过程中预留的观测Language Modeling Loss的测试集、用于评测效果的Benchmarks等,都应该从训练集合中去除相似的数据

数据多样性

任务多样性、语义多样性、语种多样性、数据来源多样性等,提高数据的质量和多样性一直是大家比较关注的问题。

分为三步:向量化\(\to\)聚类\(\to\)核心样本采样