pytorch Dataset类

官网:Dataset类

简介

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader这个类来更加快捷的对数据进行操作。在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。


官网参数

参数解释:

dataset:需要load的数据
batch_size:每个batch的大小,一次返回几条数据处理
shuffle:是否进行shuffle操作
num_workers:加载数据的时候使用几个子进程

使用方式很固定,直接记住代码格式:
import torch.utils.data as Data

input_batch, target_batch = make_data(sentence)
dataset = Data.TensorDataset(input_batch, target_batch)
loader = Data.DataLoader(dataset, batch_size=16, shuffle=True)
#训练
for epoch in range(10000):
  for x, y in loader:
      pred = model(x)
      loss = criterion(pred, y)
      if (epoch + 1) % 1000 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
      optimizer.zero_grad()#torch三件套
      loss.backward()
      optimizer.step()
####分割线就是我
#将数据处理成需要的格式tensor类型
train_inputs, train_token, train_mask, train_labels=make_data(sentence)
train_data = Data.TensorDataset(train_inputs, train_token, train_mask, train_labels)
train_dataloader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
for _ in range(2):
    for i, batch in enumerate(train_dataloader):#因为数据是四种,所以用迭代器的方式处理,bach=[tensor1,tensor2,tensor3,tensor4]
            batch = tuple(t.to(device) for t in batch)#转换成tuple并且指定cuda运行
            loss = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2], labels=batch[3])[0]
            print(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if i % 10 == 0:
              eval(model, validation_dataloader)

以上图第二段例子代码来看:
首先是把原句子输入预处理函数 make_data;这个函数是自己定义的,按照任务要求,一般调用 tokenizer 来处理;举个例子:
下面是一个预处理的函数示例:

def make_data(sentences,labels):
    input_ids,token_type_ids,attention_mask=[],[],[]
    #input_ids是每个词对应的索引idx ;token_type_ids是对应的0和1,标识是第几个句子;attention_mask是对句子长度做pad
    #input_ids=[22,21,...499] token_type_ids=[0,0,0,0,1,1,1,1] ;attention_mask=[1,1,1,1,1,0,0,0]补零
    for i in range(len(sentences)):
        encoded_dict = tokenizer.encode_plus(
        sentences[i],
        add_special_tokens = True,      # 添加 '[CLS]' 和 '[SEP]'
        max_length = 96,           # 填充 & 截断长度
        pad_to_max_length = True,
        return_tensors = 'pt',         # 返回 pytorch tensors 格式的数据
        )
        input_ids.append(encoded_dict['input_ids'])
        token_type_ids.append(encoded_dict['token_type_ids'])
        attention_mask.append(encoded_dict['attention_mask'])
  
    input_ids = torch.LongTensor(input_ids)#每个词对应的索引
    token_type_ids = torch.LongTensor(token_type_ids)#0&1标识是哪个句子
    attention_mask = torch.LongTensor(attention_mask)#[11100]padding之后的句子
    labels = torch.LongTensor(labels)#所有实例的label对应的索引idx

  return input_ids, token_type_ids, attention_mask, labels

经过这个处理之后的返回值 然后调用 Data.TensorDataset 和 Data.DataLoader;这个其实就是把所有需要的数据都装载到了一个dict 里面,然后被封装好,设定 batch 的大小,然后进行训练:

train_inputs, train_token, train_mask, train_labels=make_data(sentence)
train_data = Data.TensorDataset(train_inputs, train_token, train_mask, train_labels)
train_dataloader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
for i, batch in enumerate(train_dataloader):
   batch = tuple(t.to(device) for t in batch)#转换成tuple并且指定cuda运行
   loss = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2], labels=batch[3])[0]
 #这里通过batch下标就可以达到访问各个类别训练数据的目的

改写dataset 类

在我们自己的模型训练中,常常需要使用非官方自制的数据集。
我们可以通过改写torch.utils.data.Dataset中的getitemlen来载入我们自己的数据集。
getitem:获取数据集中的数据
len:获取整个数据集的长度(即个数)
看下面的一个例子:

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item
#item={"input_ids":[] ,"token_type_ids":[],"attention_mask":[],"labels":[]}

    def __len__(self):
        return len(self.labels)

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

train_dataset = IMDbDataset(train_encodings, train_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
for epoch in range(3):
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optim.step()

这个例子跟上文例子其实相差不多,只是这个地方把 作为分类任务的 label 数据传入 重写的 dataset 类,加入训练数据的 dict 中
item={"input_ids":[] ,"token_type_ids":[],"attention_mask":[],"labels":[]}
上文的例子是直接把 label 和其他类别的训练数据如:input_ids ,attention_mask 等一起传入 Data.TensorDataset() 其实也是一样的目的,这个地方只是给一个简单的例子,表示在实际的任务中,如果需要更加复杂的封装训练数据,需要重写 dataset 类中的这些函数来达到加载数据的目的;

参数 collate_fn 用法:

将一个list的sample组成一个mini-batch的函数,可以自定义这个参数,将自己的数据处理成一个 batch,padding mask 等操作:

def custom_collate(batch):
    transposed = list(zip(*batch))
    lst = []
    # transposed[0]: list of token ids of text
    padded_seq = []
    max_seq_len = len(max(transposed[0], key=len))
    for seq in transposed[0]:
        padded_seq.append(seq + [0] * (max_seq_len - len(seq)))
    lst.append(torch.LongTensor(padded_seq))

    # tansposed[1]: list of tag ids of SAME LENGTH!
    padded_tag = []
    att_mask = []
    for seq in transposed[1]:
        padded_tag.append(seq + [0] * (max_seq_len - len(seq)))
        att_mask.append([1] * len(seq) + [0] * (max_seq_len - len(seq)))
    lst.append(torch.LongTensor(padded_tag))
    lst.append(torch.FloatTensor(att_mask))

    return lst

dataset = MyDataset(train_tokens, train_tags)

train_loader = DataLoader(dataset=dataset, batch_size=train_batch, collate_fn=custom_collate, shuffle=True)

  • simple example:
# a simple custom collate function, just to show the idea
def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]

1.

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
"""定义dataloader,在训练阶段shuffle数据,预测阶段不需要shuffle"""

train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=batch_size)


参考:
huggingface fine-tuning with custom dataset
BERT Fine-Tuning Tutorial with PyTorch · Chris McCormick (mccormickml.com)
https://ptorch.com/docs/1/utils-data

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 162,825评论 4 377
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 68,887评论 2 308
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 112,425评论 0 255
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,801评论 0 224
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 53,252评论 3 299
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 41,089评论 1 226
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 32,216评论 2 322
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 31,005评论 0 215
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,747评论 1 250
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,883评论 2 255
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,354评论 1 265
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,694评论 3 265
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,406评论 3 246
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,222评论 0 9
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,996评论 0 201
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 36,242评论 2 287
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 36,017评论 2 281

推荐阅读更多精彩内容