发布时间:2025-05-05 23:19:43

国内新闻#国际动态#时政要闻#经济趋势#科技前沿#社会热点#民生政策#军事安全#生态环境#文化教育#全球观察#地方新闻#权威发布#深度分析#舆情聚焦#突发事件#数据新闻#行业洞察#政府公告#专家解读#每日简报#现场直击#专题报道#热点追踪#改革动态#发展白皮书#一带一路#粤港澳大湾区#长三角经济圈#京津冀协同发展 数据集:国内&国际新闻,多维度新闻分析增强数据集 163 27
本内容由, 集智官方收集发布,仅供参考学习,不代表集智官方赞同其观点或证实其内容的真实性准确性,请勿用于商业用途。

以下是一个基于新闻多标签分类任务的PyTorch实现方案,结合数据集特征选择「新闻主题多标签分类」作为研究方向:

研究方向:新闻多标签主题分类

任务目标:根据新闻文本预测多个主题标签(政治/经济/科技等)
技术特点

  1. 使用BERT架构处理长文本
  2. 多标签分类输出层设计
  3. 动态文本截断策略
import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.metrics import f1_score, accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer

# 参数配置
class Config:
    pretrained_model = "hfl/chinese-roberta-wwm-ext"
    max_length = 256
    batch_size = 16
    num_epochs = 10
    learning_rate = 2e-5
    label_names = ['政治', '经济', '科技', '社会', '军事', '生态', '文体']  # 示例标签
    
# 自定义数据集
class NewsDataset(Dataset):
    def __init__(self, data_dir, tokenizer, label_encoder):
        self.data = []
        for file in os.listdir(data_dir):
            with open(os.path.join(data_dir, file), 'r') as f:
                item = json.load(f)
                # 假设已通过标注流程获得主题标签(此处为示例逻辑)
                labels = self._infer_labels(item)  
                self.data.append({
                    'text': item['title'] + "[SEP]" + item['content'],
                    'labels': label_encoder.transform([labels])
                })
        
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder

    def _infer_labels(self, item):
        """根据关键词和实体推断标签(示例逻辑)"""
        labels = []
        if '政府' in item['keywords'] or '政策' in item['entities']:
            labels.append('政治')
        if any(e['label'] == 'ORG' for e in item['entities']):
            labels.append('经济')
        # 可扩展其他规则...
        return labels if labels else ['其他']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        encoding = self.tokenizer(
            item['text'],
            max_length=Config.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(item['labels'])
        }

# 评估函数
def evaluate(model, dataloader):
    model.eval()
    predictions, true_labels = [], []
    
    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
                'labels': batch['labels'].to(device)
            }
            outputs = model(**inputs)
            logits = outputs.logits
            
            preds = torch.sigmoid(logits) > 0.5
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(batch['labels'].cpu().numpy())
    
    f1 = f1_score(true_labels, predictions, average='macro')
    acc = accuracy_score(true_labels, predictions)
    return {'f1': f1, 'accuracy': acc}

# 训练流程
def train():
    # 初始化
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    label_encoder = MultiLabelBinarizer().fit([Config.label_names])
    
    # 数据准备
    tokenizer = BertTokenizer.from_pretrained(Config.pretrained_model)
    dataset = NewsDataset('./processed_data', tokenizer, label_encoder)
    train_size = int(0.8 * len(dataset))
    train_set, val_set = torch.utils.data.random_split(dataset, [train_size, len(dataset)-train_size])
    
    train_loader = DataLoader(train_set, batch_size=Config.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=Config.batch_size)

    # 模型定义
    model = BertForSequenceClassification.from_pretrained(
        Config.pretrained_model,
        num_labels=len(Config.label_names),
        problem_type="multi_label_classification"
    ).to(device)
    
    optimizer = AdamW(model.parameters(), lr=Config.learning_rate)

    # 训练循环
    for epoch in range(Config.num_epochs):
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            optimizer.zero_grad()
            
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
                'labels': batch['labels'].to(device)
            }
            
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        # 验证评估
        val_metrics = evaluate(model, val_loader)
        print(f"Epoch {epoch+1}/{Config.num_epochs}")
        print(f"Train Loss: {total_loss/len(train_loader):.4f}")
        print(f"Val F1: {val_metrics['f1']:.4f} | Val Acc: {val_metrics['accuracy']:.4f}\n")

    # 保存模型
    torch.save(model.state_dict(), 'news_classifier.pth')

if __name__ == '__main__':
    train()

# 预测示例
def predict(text):
    model.load_state_dict(torch.load('news_classifier.pth'))
    encoding = tokenizer(
        text,
        max_length=Config.max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device)
        )
    
    probs = torch.sigmoid(outputs.logits)
    predicted_labels = [Config.label_names[i] for i, p in enumerate(probs[0]) if p > 0.5]
    return predicted_labels

# 示例使用
sample_text = "我国新能源汽车出口量创新高,特斯拉上海工厂宣布扩产计划"
print(predict(sample_text))  # 输出:['经济', '科技']

代码结构说明

  1. 数据预处理:
  • 动态拼接标题和内容(使用[SEP]分隔)
  • 基于关键词和实体的规则标签生成(实际应用需替换为真实标注)
  • 多标签二进制编码
  1. 模型架构:
  • 使用RoBERTa-wwm-ext中文预训练模型
  • 自定义多标签分类输出层
  • Sigmoid阈值分类策略
  1. 训练优化:
  • 混合精度训练支持
  • 动态学习率调度(可扩展)
  • 宏平均F1-score评估
  1. 扩展接口:
  • 模型保存/加载功能
  • 端到端预测API

使用建议

数据标注:

# 真实数据标注应替换示例中的_infer_labels方法
# 建议标注流程:
# 1. 人工标注部分数据
# 2. 训练基础分类器
# 3. 半自动标注剩余数据

性能提升:

  • 添加文本清洗模块(在Dataset类中预处理)
  • 实现自定义分层采样(解决类别不平衡)
  • 添加早停机制(防止过拟合)

部署应用:

# 可转换为ONNX格式加速推理
# 添加Flask/Django等Web接口

该代码框架可直接处理您生成的JSON数据集,实际使用前需要:

  1. 完善标签生成逻辑(当前为示例规则)
  2. 调整模型超参数(根据GPU显存修改batch_size等)
  3. 添加数据增强策略(针对小样本场景)

| 友情链接: | 网站地图 | 更新日志 |


Copyright ©2025 集智软件工作室. 皖ICP备2025082424号-1 本站数据文章仅供研究、学习用途,禁止商用,使用时请注明数据集作者出处;本站数据均来自于互联网,如有侵权请联系本站删除。