Learn practical skills, build real-world projects, and advance your career
!conda install matplotlib -y
!conda install -c pytorch pytorch -y
from mpl_toolkits import mplot3d
import torch
import numpy
import re
import matplotlib.pyplot as plt

# Matches one bracketed integer triple such as "[12, -3, 7]" (spaces allowed
# around the numbers) and captures the three integers as strings.
# Raw string: "\[" and "\d" in a plain literal are invalid escape sequences.
p = re.compile(r'\[ *(-?\d+) *, *(-?\d+) *, *(-?\d+) *]')

# Load the whole dataset as a list of lines (trailing newlines kept).
# NOTE(review): this rebinds the builtin name `input` for the rest of the file.
with open('dataset.txt') as dataset_file:
    input = list(dataset_file)

class Point:
    """An integer 3-D coordinate used as a key into the parsed dataset.

    Two points compare equal, and hash equally, when all three coordinates
    match. This makes them usable as dictionary keys.
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    @classmethod
    def from_string(cls, str):
        """Parse a header line of the form ``'X=<int>, Y=<int>, Z=<int>'``.

        The line is split on ``','``. For each of the first three fields,
        the text after ``'='`` is converted with ``int``. ``int`` tolerates
        surrounding whitespace, including a trailing newline.

        Raises:
            IndexError: if there are fewer than three fields, or a field
                lacks ``'='``.
            ValueError: if a value is not an integer.
        """
        # The parameter keeps its original name for keyword callers; it only
        # shadows the builtin ``str`` inside this method.
        fields = str.split(',')
        x = int(fields[0].split('=')[1])
        y = int(fields[1].split('=')[1])
        z = int(fields[2].split('=')[1])
        return cls(x, y, z)

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y, self.z) == (other.x, other.y, other.z)

    def __hash__(self):
        # Hash the coordinate tuple. Concatenated digit strings would make
        # e.g. (1, 23, 4) and (12, 3, 4) collide.
        return hash((self.x, self.y, self.z))

    def __repr__(self):
        return f'Point({self.x!r}, {self.y!r}, {self.z!r})'

    def __str__(self):
        return 'X={x}, Y={y}, Z={z}'.format(x=self.x, y=self.y, z=self.z)

def get_data_array(start_idx, input, n_rows=9, pattern=None):
    """Collect the bracketed ``[x, y, z]`` triples from a run of lines.

    Scans ``input[start_idx:start_idx + n_rows]`` and extracts every integer
    triple matched by ``pattern``. The triples are returned split into three
    1-D tensors.

    Args:
        start_idx: index of the first line to scan.
        input: list of text lines, as returned by ``readlines``.
        n_rows: number of lines to scan. Defaults to 9 lines per block.
        pattern: compiled regex with three capture groups. Defaults to the
            module-level ``p``.

    Returns:
        A tuple ``(x_arr, y_arr, z_arr)`` of ``torch.float64`` tensors. Each
        has one element per matched triple, in file order.
    """
    if pattern is None:
        pattern = p

    elements = []
    # Slicing, rather than indexing, tolerates a final block that is cut
    # short by the end of the file instead of raising IndexError.
    for line in input[start_idx:start_idx + n_rows]:
        elements.extend(pattern.findall(line))

    def to_tensor(column):
        # ``numpy.float`` (an alias of builtin float, i.e. float64) was
        # removed in NumPy 1.24, so spell out float64 explicitly.
        values = numpy.array([e[column] for e in elements], dtype=numpy.float64)
        return torch.from_numpy(values)

    return (to_tensor(0), to_tensor(1), to_tensor(2))

# Walk the dataset. Each point block starts with an 'X=..., Y=..., Z=...'
# header line and is followed by its data rows.
i = 0
data = {}      # Point -> (x_arr, y_arr, z_arr)
dataset = []   # (Point, (x_arr, y_arr, z_arr)) pairs, in file order
while i < len(input):
    line = input[i]
    if line.startswith('X'):
        point = Point.from_string(line)
        data_arr = get_data_array(i, input)
        data[point] = data_arr
        dataset.append((point, data_arr))
        # Skip the 9 lines scanned for this block (header included).
        i += 9
    else:
        # Blank lines and any other non-header line advance by one. Advancing
        # on every line, not just '\n', guarantees the loop terminates.
        i += 1


#print('data: ', data)
#print('data set size: ', len(data))

def display(point, shape=(8, 32)):
    """Show the x, y and z arrays stored for ``point`` as three heat maps.

    Args:
        point: key into the module-level ``data`` dict.
        shape: 2-D shape each array is reshaped to before plotting. The
            default ``(8, 32)`` requires exactly 256 elements per array.

    Raises:
        KeyError: if ``point`` is not in ``data``.
        RuntimeError: from ``reshape`` if an array's size does not match
            ``shape``.
    """
    x, y, z = data[point]
    print('Point: ', point)
    print('x.shape: ', x.shape)
    for arr in (x, y, z):
        plt.matshow(arr.reshape(shape))
        plt.colorbar()
        plt.show()

def plot3d(point):
    """Draw the samples stored for ``point`` as a connected 3-D line.

    Each sample i is plotted at ``(x[i], y[i], z[i])``. ``point`` must be a
    key of the module-level ``data`` dict, otherwise KeyError is raised.
    """
    x, y, z = data[point]
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.plot3D(x, y, z)
    # Render now, like the other plotting helpers in this file.
    plt.show()
    

def quiver(point, scale=13):
    """Draw a 2-D quiver plot of ``point``'s samples in the x-y plane.

    Arrows are anchored at ``(x, y)`` and have components
    ``(cos(x), sin(y))``.

    Args:
        point: key into the module-level ``data`` dict.
        scale: passed through to ``plt.quiver``. Larger values draw
            shorter arrows.
    """
    x, y, _ = data[point]
    # NOTE(review): u uses x while v uses y. If a unit direction vector was
    # intended, both should use the same angle; confirm.
    u = torch.cos(x)
    v = torch.sin(y)
    plt.quiver(x, y, u, v, scale=scale)
    plt.show()
"""
p = Point(70, 0, 100)
quiver(p)
#display(p)
#plot3d(p)

p = Point(-11, -55, 15)
quiver(p)
#display(p)
#plot3d(p)

p = Point(0, 0, 60)
quiver(p)
#display(p)
#plot3d(p)

p = Point(-15, -25, 45)
quiver(p)
#display(p)

p = Point(-5, -5, 45)
quiver(p)

p = Point(0, -35, 60)
quiver(p)
"""
'\np = Point(70, 0, 100)\nquiver(p)\n#display(p)\n#plot3d(p)\n\np = Point(-11, -55, 15)\nquiver(p)\n#display(p)\n#plot3d(p)\n\np = Point(0, 0, 60)\nquiver(p)\n#display(p)\n#plot3d(p)\n\np = Point(-15, -25, 45)\nquiver(p)\n#display(p)\n\np = Point(-5, -5, 45)\nquiver(p)\n\np = Point(0, -35, 60)\nquiver(p)\n'
import torch.nn as nn
from torch.utils.data import random_split