日期:2025/04/02 16:07来源:未知 人气:52
Datasets库是Hugging Face的一个重要的数据集库。 当需要微调一个模型的时候,需要进行下面操作:
使用的示例数据集:
from datasets import load_dataset# 加载数据dataset = load_dataset(path='seamew/ChnSentiCorp', split='train')print(dataset)
打印结果:
Dataset({ features: ['text', 'label'], num_rows: 9600}){'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1}
sortData = dataset.sort('label')
shuffleData = sortData.shuffle(seed=20);
从数据集中取出某些指定的部分。
dataset.select([0,1,2,3])
def filter(data): return data['text'].startswith('1')b = dataset.filter(filter)
dataset.train_test_split(test_size=0.1)
把数据集切分,10%为测试集。
把数据集均数若干份,取其中的第几份。
dataset.shard(num_shards=5, index=0)
c = a.rename_column('text', 'newColumn')
d = c.remove_columns(['newColumn'])
set_format函数用来实现与其它库数据格式的转换;
遍历数据,对每个数据进行处理
def handler(data): data['text'] = 'Prefix' + data['text'] return datadatasetMap = dataset.map(handler)
dataset.save_to_disk('./')
from datasets import load_from_diskdataset = load_from_disk('./')
安装Evaluate库:
pip install evaluate
import evaluateaccuracy = evaluate.load("accuracy")
element_count = evaluate.load("lvwerra/element_count", module_type="measurement")
evaluate.list_evaluation_modules( module_type="comparison", include_community=False, with_details=True)
for ref, pred in zip([0,1,0,1], [1,0,0,1]): accuracy.add(references=ref, predictions=pred)accuracy.compute()
输出:
{'accuracy': 0.5}
批添加:
for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]): accuracy.add_batch(references=refs, predictions=preds)accuracy.compute()
import evaluatefrom evaluate.visualization import radar_plotdata = [ {"accuracy": 0.99, "precision": 0.8, "f1": 0.95, "latency_in_seconds": 33.6}, {"accuracy": 0.98, "precision": 0.87, "f1": 0.91, "latency_in_seconds": 11.2}, {"accuracy": 0.98, "precision": 0.78, "f1": 0.88, "latency_in_seconds": 87.6}, {"accuracy": 0.88, "precision": 0.78, "f1": 0.81, "latency_in_seconds": 101.6} ]model_names = ["Model 1", "Model 2", "Model 3", "Model 4"]plot = radar_plot(data=data, model_names=model_names)plot.show()