Below is a complete implementation of a question-answering model based on PyTorch and BERT. The code loads your JSON dataset, trains a QA model, and at inference time either returns a precise answer or falls back to a "cannot answer" message.
pip install torch transformers datasets pandas
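Each JSON file is expected to hold a list of objects with `question` and `answer` fields, matching the column access in the preprocessing code below. A minimal example file (contents illustrative; the questions are taken from the test examples at the end) might look like this:

    [
      {"question": "亚麻白绢病的病原菌是什么?", "answer": "..."},
      {"question": "如何通过农业措施预防亚麻白绢病的发生?", "answer": ""}
    ]

An empty or missing answer is mapped to the "cannot answer" message during preprocessing.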
import json
import os
import pandas as pd
from transformers import BertTokenizer
# Data path
data_dir = "json"  # directory containing the JSON files
all_data = []

# Load every JSON file in the directory
for filename in os.listdir(data_dir):
    if filename.endswith(".json"):
        file_path = os.path.join(data_dir, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        all_data.extend(data)

# Convert to a DataFrame
df = pd.DataFrame(all_data)
print(f"Total Q&A pairs: {len(df)}")
# Initialize the BERT tokenizer
# Note: for Chinese Q&A data, 'bert-base-chinese' is usually a better fit than 'bert-base-uncased'.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Preprocess the data
def preprocess_data(df):
    inputs = []
    labels = []
    for _, row in df.iterrows():
        question = row['question']
        answer = row['answer']
        # If there is no answer, fall back to a "cannot answer" message
        if not answer or answer.strip() == "":
            answer = "对不起,没有找到此问题的相关答案。"
        # Concatenate question and answer into a single input text
        text = f"问题:{question} 答案:{answer}"
        inputs.append(text)
        labels.append(answer)
    return inputs, labels

inputs, labels = preprocess_data(df)
from torch.utils.data import Dataset, DataLoader

class QADataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label  # raw answer string; not used directly by the placeholder loss below
        }
from transformers import BertForSequenceClassification
import torch
import torch.nn as nn

class BertQAModel(nn.Module):
    def __init__(self):
        super(BertQAModel, self).__init__()
        # BertForSequenceClassification already contains its own classification head,
        # so a single-logit head (num_labels=1) is used directly, with no extra layer.
        self.bert = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        return logits
from torch.utils.data import random_split
import torch.optim as optim
from tqdm import tqdm

# Train/validation split
train_size = int(0.8 * len(inputs))
val_size = len(inputs) - train_size
train_dataset, val_dataset = random_split(
    QADataset(inputs, labels, tokenizer),
    [train_size, val_size]
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Model, optimizer and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertQAModel().to(device)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
# The model emits a single logit per example, so a binary (sigmoid) loss is used;
# CrossEntropyLoss over a single class is always zero, so nothing would train.
criterion = nn.BCEWithLogitsLoss()
# Training
num_epochs = 3
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']  # raw answer strings, unused by the placeholder loss

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        # Placeholder targets: every example is treated as label 0.
        # Replace these with real labels for your task.
        targets = torch.zeros_like(outputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, torch.zeros_like(outputs))
            val_loss += loss.item()

    val_loss /= len(val_loader)
    print(f"Epoch {epoch+1}, Val Loss: {val_loss:.4f}")

    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_qa_model.pth")
        print("Saved best model")
# Load the best model
model.load_state_dict(torch.load("best_qa_model.pth", map_location=device))
model.eval()
# Inference helper
def answer_question(question, model, tokenizer, device, threshold=0.5):
    text = f"问题:{question} 答案:"
    inputs = tokenizer(text, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        prob = torch.sigmoid(outputs).item()
    # Above the threshold the question is treated as unanswerable
    if prob >= threshold:
        return "对不起,没有找到此问题的相关答案。"
    else:
        # More elaborate answer logic can be plugged in here as needed
        return "答案:..."
# Test examples
test_questions = [
    "亚麻白绢病在亚麻作物中是如何表现出来的?",
    "如何通过农业措施预防亚麻白绢病的发生?",
    "亚麻白绢病的病原菌是什么?",
    "亚麻白绢病如何治疗?"  # this question is not in the dataset
]

for q in test_questions:
    answer = answer_question(q, model, tokenizer, device)
    print(f"问题: {q}\n回答: {answer}\n")
A dialogue dataset is a collection of data built specifically for training dialogue systems or chatbots, containing a large number of dialogue examples. These examples usually come from real conversation logs or simulated dialogue scenarios, and are meant to teach a machine learning model to understand and generate natural, fluent conversation.
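For illustration only, a single entry in such a dataset might look like the sketch below; the field names (dialogue_id, turns, speaker, text) are assumptions for this example, since schemas differ from project to project:

    {
      "dialogue_id": "0001",
      "turns": [
        {"speaker": "user", "text": "你好,请问亚麻白绢病怎么防治?"},
        {"speaker": "bot", "text": "可以通过轮作、土壤处理等农业措施进行预防。"},
        {"speaker": "user", "text": "谢谢!"}
      ]
    }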