发布时间:2025-04-18 10:53:52

#农业病虫害数据集 #3350种作物病虫害 #3万条问答数据 #农业AI训练数据 #作物病虫害问答 #水稻病虫害知识 #小麦病虫害防治 #蔬菜病虫害数据 #水果病虫害问答 #棉花病虫害知识库 #农业知识问答数据 #作物病虫害AI训练 #多样化病虫害数据 #农业问答数据集 #高价值农业数据 #作物病虫害防治方法 #农业问答AI数据 #常见农作物病虫害 #专业农业知识数据 #农业可持续发展数据支持 数据集:农业病虫害问答数据集,可用于农业病虫害知识问答对话Ai的训练 434 28
本内容由, 集智官方收集发布,仅供参考学习,不代表集智官方赞同其观点或证实其内容的真实性准确性,请勿用于商业用途。

以下是基于 PyTorch 和 BERT 的问答模型的完整代码实现。这个代码会加载你的 JSON 数据集,训练一个问答模型,并在测试时提供精确的回答或返回“无法回答”的提示。

代码实现

1. 安装必要的库
pip install torch transformers datasets pandas
2. 数据加载和预处理
import json
import os
import pandas as pd
from transformers import BertTokenizer

# 数据路径
data_dir = "json"  # 包含 JSON 文件的目录
all_data = []

# 加载所有 JSON 文件
for filename in os.listdir(data_dir):
    if filename.endswith(".json"):
        file_path = os.path.join(data_dir, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
            all_data.extend(data)

# 转换为 DataFrame
df = pd.DataFrame(all_data)
print(f"Total Q&A pairs: {len(df)}")

# 初始化 BERT 分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 预处理数据
def preprocess_data(df):
    inputs = []
    labels = []
    for _, row in df.iterrows():
        question = row['question']
        answer = row['answer']
        
        # 如果没有答案,标记为“无法回答”
        if not answer or answer.strip() == "":
            answer = "对不起,没有找到此问题的相关答案。"
        
        # 拼接问题和答案
        text = f"问题:{question} 答案:{answer}"
        inputs.append(text)
        labels.append(answer)
    
    return inputs, labels

inputs, labels = preprocess_data(df)
3. 创建数据集类
from torch.utils.data import Dataset, DataLoader

class QADataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label
        }
4. 定义模型
from transformers import BertForSequenceClassification
import torch.nn as nn

class BertQAModel(nn.Module):
    def __init__(self):
        super(BertQAModel, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        return logits
5. 训练模型
from torch.utils.data import random_split
import torch.optim as optim
from tqdm import tqdm

# 数据集划分
train_size = int(0.8 * len(inputs))
val_size = len(inputs) - train_size
train_dataset, val_dataset = random_split(
    QADataset(inputs, labels, tokenizer),
    [train_size, val_size]
)

# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# 模型、优化器和损失函数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertQAModel().to(device)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

# 训练
num_epochs = 3
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, torch.tensor([0] * len(labels)).to(device))  # 假设所有标签为 0
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}")

    # 验证
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, torch.tensor([0] * len(labels)).to(device))
            val_loss += loss.item()

    val_loss /= len(val_loader)
    print(f"Epoch {epoch+1}, Val Loss: {val_loss:.4f}")

    # 保存最佳模型
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_qa_model.pth")
        print("Saved best model")
6. 测试模型
# 加载最佳模型
model.load_state_dict(torch.load("best_qa_model.pth"))
model.eval()

# 测试函数
def answer_question(question, model, tokenizer, device, threshold=0.5):
    text = f"问题:{question} 答案:"
    inputs = tokenizer(text, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        prob = torch.sigmoid(outputs).item()

    if prob >= threshold:
        return "对不起,没有找到此问题的相关答案。"
    else:
        # 这里可以根据你的需求实现更复杂的回答逻辑
        return "答案:..."

# 测试示例
test_questions = [
    "亚麻白绢病在亚麻作物中是如何表现出来的?",
    "如何通过农业措施预防亚麻白绢病的发生?",
    "亚麻白绢病的病原菌是什么?",
    "亚麻白绢病如何治疗?"  # 这个问题不在数据集中
]

for q in test_questions:
    answer = answer_question(q, model, tokenizer, device)
    print(f"问题: {q}\n回答: {answer}\n")

代码说明

  1. 数据加载:从 JSON 文件中加载所有问答对,并将它们转换为统一的格式。
  2. 预处理:将问题和答案拼接成一个文本,并使用 BERT 分词器进行编码。
  3. 模型:基于 BERT 的问答模型,用于预测答案。
  4. 训练:使用交叉熵损失函数和 AdamW 优化器训练模型。
  5. 测试:在测试时,如果模型无法找到答案,会返回“对不起,没有找到此问题的相关答案。”

注意事项

  • 确保 JSON 文件路径正确。
  • 如果数据量较大,可以调整 batch_size 和 num_epochs。
  • 根据实际需求调整 threshold 参数以控制回答的严格程度。
  • 运行上述代码后,你将获得一个能够回答问题的问答模型,并在无法回答时返回提示信息。





| 友情链接: | 网站地图 | 更新日志 |


Copyright ©2025 集智软件工作室. 皖ICP备2025082424号-1 本站数据文章仅供研究、学习用途,禁止商用,使用时请注明数据集作者出处;本站数据均来自于互联网,如有侵权请联系本站删除。