BAIR Robot Pushing Datasetについて

BAIR Robot Pushing Datasetは、ロボットを押す動画で構成されたデータセットである。

約44,000の例を含む。

 

Tensorflowの公式サイトにも説明がある。

bair_robot_pushing_small  |  TensorFlow Datasets

 

公式サイトにも示されている様に、1つのデータの集まりにおいて2つのシーケンスが含まれている(image_aux1 と image_main)。

'image_aux1': Image(shape=(64, 64, 3), dtype=tf.uint8),
'image_main': Image(shape=(64, 64, 3), dtype=tf.uint8),

 

例えばimage_aux1は

f:id:jskangaroo:20201018180019j:plain ...

image_mainは

f:id:jskangaroo:20201018180045j:plain ...

のように、別の角度から撮影された動画がセットとなっている。

 

BAIR Robot Pushing Datasetの容量は30GBほどで、以下のサイトからダウンロードできる。

https://sites.google.com/view/sna-visual-mpc/

 

(私は以下のgithubを参考にしたため、別方法でダウンロードしている)

https://github.com/alexlee-gk/video_prediction

 

ダウンロードはmp4やpngではなく、tensorflow独自の.tfrecordsという形式となっているはずだ。

 

tfrecordsを画像に戻すのには苦戦したが、以下のサイトでできることが確認できた。

https://datascience.stackexchange.com/questions/52061/how-to-download-bair-action-free-robot-pushing-dataset

 

サイトが削除されたときのことを考えて、以下にもコードを添付する。 

赤文字で示した部分のみ変更すれば上手くいくはずだ。

上手く行けばtfrecordsに記録されている画像と動画を得ることができる。

(skvideoがインストールされていない場合は pip install sk-video が必要)

 

import datetime
import os
import time

import cv2
import numpy as np
import skvideo.io
import tensorflow as tf
from PIL import Image
from tensorflow.python.platform import gfile

def get_next_video_data(data_dir):
    filenames = gfile.Glob(os.path.join(data_dir, '*'))
    if not filenames:
        raise RuntimeError('No data files found.')

    for f in filenames:
        k = 0
        for serialized_example in tf.python_io.tf_record_iterator(f):
            example = tf.train.Example()
            example.ParseFromString(serialized_example)
            # print(example)        # To know what all features are present

            actions = np.empty((0, 4), dtype='float')
            endeffector_positions = np.empty((0, 3), dtype='float')
            frames_aux1 = []
            frames_main = []
            i = 0
            while True:
                action_name = str(i) + '/action'
                action_value = np.array(example.features.feature[action_name].float_list.value)
                if action_value.shape == (0,):      # End of frames/data
                    break
                actions = np.vstack((actions, action_value))

                endeffector_pos_name = str(i) + '/endeffector_pos'
                endeffector_pos_value = list(example.features.feature[endeffector_pos_name].float_list.value)
                endeffector_positions = np.vstack((endeffector_positions, endeffector_pos_value))

                aux1_image_name = str(i) + '/image_aux1/encoded'
                aux1_byte_str = example.features.feature[aux1_image_name].bytes_list.value[0]
                aux1_img = Image.frombytes('RGB', (64, 64), aux1_byte_str)
                aux1_arr = np.array(aux1_img.getdata()).reshape((aux1_img.size[1], aux1_img.size[0], 3))
                frames_aux1.append(aux1_arr.reshape(1, 64, 64, 3))

                main_image_name = str(i) + '/image_main/encoded'
                main_byte_str = example.features.feature[main_image_name].bytes_list.value[0]
                main_img = Image.frombytes('RGB', (64, 64), main_byte_str)
                main_arr = np.array(main_img.getdata()).reshape((main_img.size[1], main_img.size[0], 3))
                frames_main.append(main_arr.reshape(1, 64, 64, 3))
                i += 1

            np_frames_aux1 = np.concatenate(frames_aux1, axis=0)
            np_frames_main = np.concatenate(frames_main, axis=0)
            yield f, k, actions, endeffector_positions, np_frames_aux1, np_frames_main
            k = k + 1


def extract_data(data_dir, output_dir, frame_rate):
    """
    Extracts data in tfrecord format to gifs, frames and text files
    :param data_dir:
    :param output_dir:
    :param frame_rate:
    :return:
    """
    if os.path.exists(output_dir):
        if os.listdir(output_dir):
            raise RuntimeError('Directory not empty: {0}'.format(output_dir))
    else:
        os.makedirs(output_dir)

    seq_generator = get_next_video_data(data_dir)
    while True:
        try:
            _, k, actions, endeff_pos, aux1_frames, main_frames = next(seq_generator)
        except StopIteration:
            break
        video_out_dir = os.path.join(output_dir, '{0:03}'.format(k))
        os.makedirs(video_out_dir)

        # noinspection PyTypeChecker
        np.savetxt(os.path.join(video_out_dir, 'actions.csv'), actions, delimiter=',')
        # noinspection PyTypeChecker
        np.savetxt(os.path.join(video_out_dir, 'endeffector_positions.csv'), endeff_pos, delimiter=',')
        skvideo.io.vwrite(os.path.join(video_out_dir, 'aux1.gif'), aux1_frames, inputdict={'-r': str(frame_rate)})
        skvideo.io.vwrite(os.path.join(video_out_dir, 'main.gif'), main_frames, inputdict={'-r': str(frame_rate)})
        skvideo.io.vwrite(os.path.join(video_out_dir, 'aux1.mp4'), aux1_frames, inputdict={'-r': str(frame_rate)})
        skvideo.io.vwrite(os.path.join(video_out_dir, 'main.mp4'), main_frames, inputdict={'-r': str(frame_rate)})

        # Save frames
        aux1_folder_path = os.path.join(video_out_dir, 'aux1_frames')
        os.makedirs(aux1_folder_path)
        for i, frame in enumerate(aux1_frames):
            filepath = os.path.join(aux1_folder_path, 'frame_{0:03}.bmp'.format(i))
            cv2.imwrite(filepath, cv2.cvtColor(frame.astype('uint8'), cv2.COLOR_RGB2BGR))
        main_folder_path = os.path.join(video_out_dir, 'main_frames')
        os.makedirs(main_folder_path)
        for i, frame in enumerate(main_frames):
            filepath = os.path.join(main_folder_path, 'frame_{0:03}.bmp'.format(i))
            cv2.imwrite(filepath, cv2.cvtColor(frame.astype('uint8'), cv2.COLOR_RGB2BGR))
        print('Saved video: {0:03}'.format(k))


def main():
    data_dir = './data/bair/test'
    output_dir = './ExtractedData/test'
    frame_rate = 4
    extract_data(data_dir, output_dir, frame_rate)
    return


if __name__ == '__main__':
    print('Program started at ' + datetime.datetime.now().strftime('%d/%m/%Y %I:%M:%S %p'))
    start_time = time.time()
    main()
    end_time = time.time()
    print('Program ended at ' + datetime.datetime.now().strftime('%d/%m/%Y %I:%M:%S %p'))
    print('Execution time: ' + str(datetime.timedelta(seconds=end_time - start_time)))

 

 (code copied from https://datascience.stackexchange.com/questions/52061/how-to-download-bair-action-free-robot-pushing-dataset)

 

この様にBAIRデータセットから画像を得られる。