溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

TensorFlow和Keras大數(shù)據(jù)量?jī)?nèi)存溢出怎么辦

發(fā)布時(shí)間:2020-07-06 11:32:07 來源:億速云 閱讀:334 作者:清晨 欄目:開發(fā)技術(shù)

這篇文章主要介紹TensorFlow和Keras大數(shù)據(jù)量?jī)?nèi)存溢出怎么辦,文中介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們一定要看完!

內(nèi)存溢出問題是參加kaggle比賽或者做大數(shù)據(jù)量實(shí)驗(yàn)的第一個(gè)攔路虎。

以前做的練手小項(xiàng)目導(dǎo)致新手產(chǎn)生一個(gè)慣性思維——讀取訓(xùn)練集圖片的時(shí)候把所有圖讀到內(nèi)存中,然后分批訓(xùn)練。

其實(shí)這是有問題的,很容易導(dǎo)致OOM?,F(xiàn)在內(nèi)存一般16G,而訓(xùn)練集圖片通常是上萬張,而且RGB圖,還很大,VGG16的圖片一般是224x224x3,上萬張圖片,16G內(nèi)存根本不夠用。這時(shí)候又會(huì)想起——設(shè)置batch,但是那個(gè)batch的輸入?yún)?shù)卻又是圖片,它只是把傳進(jìn)去的圖片分批送到顯卡,而我OOM的地方恰是那個(gè)“傳進(jìn)去”的圖片,怎么辦?

解決思路其實(shí)說來也簡(jiǎn)單,打破思維定式就好了,不是把所有圖片讀到內(nèi)存中,而是只把所有圖片的路徑一次性讀到內(nèi)存中。

大致的解決思路為:

將上萬張圖片的路徑一次性讀到內(nèi)存中,自己實(shí)現(xiàn)一個(gè)分批讀取函數(shù),在該函數(shù)中根據(jù)自己的內(nèi)存情況設(shè)置讀取圖片,只把這一批圖片讀入內(nèi)存中,然后交給模型,模型再對(duì)這一批圖片進(jìn)行分批訓(xùn)練,因?yàn)閮?nèi)存一般大于等于顯存,所以內(nèi)存的批次大小和顯存的批次大小通常不相同。

下面代碼分別介紹Tensorflow和Keras分批將數(shù)據(jù)讀到內(nèi)存中的關(guān)鍵函數(shù)。Tensorflow對(duì)初學(xué)者不太友好,所以我個(gè)人現(xiàn)階段更習(xí)慣用它的高層API Keras來做相關(guān)項(xiàng)目,下面的TF實(shí)現(xiàn)是之前不會(huì)用Keras分批讀時(shí)候參考的一些列資料,在模型訓(xùn)練上仍使用Keras,只有分批讀取用了TF的API。

Tensorlow

在input.py里寫get_batch函數(shù)。

def get_batch(X_train, y_train, img_w, img_h, color_type, batch_size, capacity):
  '''
  Args:
    X_train: train img path list
    y_train: train labels list
    img_w: image width
    img_h: image height
    batch_size: batch size
    capacity: the maximum elements in queue
  Returns:
    X_train_batch: 4D tensor [batch_size, width, height, chanel],\
            dtype=tf.float32
    y_train_batch: 1D tensor [batch_size], dtype=int32
  '''
  X_train = tf.cast(X_train, tf.string)

  y_train = tf.cast(y_train, tf.int32)
  
  # make an input queue
  input_queue = tf.train.slice_input_producer([X_train, y_train])

  y_train = input_queue[1]
  X_train_contents = tf.read_file(input_queue[0])
  X_train = tf.image.decode_jpeg(X_train_contents, channels=color_type)

  X_train = tf.image.resize_images(X_train, [img_h, img_w], 
                   tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  X_train_batch, y_train_batch = tf.train.batch([X_train, y_train],
                         batch_size=batch_size,
                         num_threads=64,
                         capacity=capacity)
  y_train_batch = tf.one_hot(y_train_batch, 10)

  return X_train_batch, y_train_batch

在train.py文件中訓(xùn)練(下面不是純TF代碼,model.fit是Keras的擬合,用純TF的替換就好了)。

X_train_batch, y_train_batch = inp.get_batch(X_train, y_train, 
                       img_w, img_h, color_type, 
                       train_batch_size, capacity)
X_valid_batch, y_valid_batch = inp.get_batch(X_valid, y_valid, 
                       img_w, img_h, color_type, 
                       valid_batch_size, capacity)
with tf.Session() as sess:

  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  try:
    for step in np.arange(max_step):
      if coord.should_stop() :
        break
      X_train, y_train = sess.run([X_train_batch, 
                       y_train_batch])
      X_valid, y_valid = sess.run([X_valid_batch,
                       y_valid_batch])
       
      ckpt_path = 'log/weights-{val_loss:.4f}.hdf5'
      ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_path, 
                           monitor='val_loss', 
                           verbose=1, 
                           save_best_only=True, 
                           mode='min')
      model.fit(X_train, y_train, batch_size=64, 
             epochs=50, verbose=1,
             validation_data=(X_valid, y_valid),
             callbacks=[ckpt])
      
      del X_train, y_train, X_valid, y_valid

  except tf.errors.OutOfRangeError:
    print('done!')
  finally:
    coord.request_stop()
  coord.join(threads)
  sess.close()

Keras

keras文檔中對(duì)fit、predict、evaluate這些函數(shù)都有一個(gè)generator,這個(gè)generator就是解決分批問題的。

關(guān)鍵函數(shù):fit_generator

# 讀取圖片函數(shù)
def get_im_cv2(paths, img_rows, img_cols, color_type=1, normalize=True):
  '''
  參數(shù):
    paths:要讀取的圖片路徑列表
    img_rows:圖片行
    img_cols:圖片列
    color_type:圖片顏色通道
  返回: 
    imgs: 圖片數(shù)組
  '''
  # Load as grayscale
  imgs = []
  for path in paths:
    if color_type == 1:
      img = cv2.imread(path, 0)
    elif color_type == 3:
      img = cv2.imread(path)
    # Reduce size
    resized = cv2.resize(img, (img_cols, img_rows))
    if normalize:
      resized = resized.astype('float32')
      resized /= 127.5
      resized -= 1. 
    
    imgs.append(resized)
    
  return np.array(imgs).reshape(len(paths), img_rows, img_cols, color_type)

獲取批次函數(shù),其實(shí)就是一個(gè)generator

def get_train_batch(X_train, y_train, batch_size, img_w, img_h, color_type, is_argumentation):
  '''
  參數(shù):
    X_train:所有圖片路徑列表
    y_train: 所有圖片對(duì)應(yīng)的標(biāo)簽列表
    batch_size:批次
    img_w:圖片寬
    img_h:圖片高
    color_type:圖片類型
    is_argumentation:是否需要數(shù)據(jù)增強(qiáng)
  返回: 
    一個(gè)generator,x: 獲取的批次圖片 y: 獲取的圖片對(duì)應(yīng)的標(biāo)簽
  '''
  while 1:
    for i in range(0, len(X_train), batch_size):
      x = get_im_cv2(X_train[i:i+batch_size], img_w, img_h, color_type)
      y = y_train[i:i+batch_size]
      if is_argumentation:
        # 數(shù)據(jù)增強(qiáng)
        x, y = img_augmentation(x, y)
      # 最重要的就是這個(gè)yield,它代表返回,返回以后循環(huán)還是會(huì)繼續(xù),然后再返回。就比如有一個(gè)機(jī)器一直在作累加運(yùn)算,但是會(huì)把每次累加中間結(jié)果告訴你一樣,直到把所有數(shù)加完
      yield({'input': x}, {'output': y})

訓(xùn)練函數(shù)

result = model.fit_generator(generator=get_train_batch(X_train, y_train, train_batch_size, img_w, img_h, color_type, True), 
     steps_per_epoch=1351, 
     epochs=50, verbose=1,
     validation_data=get_train_batch(X_valid, y_valid, valid_batch_size,img_w, img_h, color_type, False),
     validation_steps=52,
     callbacks=[ckpt, early_stop],
     max_queue_size=capacity,
     workers=1)

就是這么簡(jiǎn)單。但是當(dāng)初從0到1的過程很難熬,每天都沒有進(jìn)展,沒有頭緒,急躁占據(jù)了思維的大部,熬過了這個(gè)階段,就會(huì)一切順利,不是運(yùn)氣,而是踩過的從0到1的每個(gè)腳印累積的靈感的爆發(fā),從0到1的腳印越多,后面的路越順利。

以上是TensorFlow和Keras大數(shù)據(jù)量?jī)?nèi)存溢出怎么辦的所有內(nèi)容,感謝各位的閱讀!希望分享的內(nèi)容對(duì)大家有幫助,更多相關(guān)知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI