Learn practical skills, build real-world projects, and advance your career
Updated 4 years ago
!conda install matplotlib -y
!conda install -c pytorch pytorch -y
from mpl_toolkits import mplot3d
import torch
import numpy
import re
import matplotlib.pyplot as plt
# Matches one bracketed integer triple such as "[ 1, -2, 3 ]" and captures the
# three numbers as strings (x, y, z). Written as a raw string so "\[" and "\d"
# reach the regex engine unchanged instead of being invalid string escapes
# (DeprecationWarning, SyntaxWarning on Python 3.12+).
p = re.compile(r'\[ *(-?\d+) *, *(-?\d+) *, *(-?\d+) *]')
# Load the raw dataset once as a list of lines (newlines kept); every later
# stage indexes into this list. The name ``input`` shadows the builtin, but
# the parsing code below refers to it, so it is kept.
with open('dataset.txt', 'r') as dataset_file:
    input = list(dataset_file)
class Point:
    """An integer 3-D coordinate used as a key into the parsed dataset.

    Points compare and hash by value, so ``Point(70, 0, 100)`` finds the same
    dict entry as the point parsed from the line ``"X=70, Y=0, Z=100"``.
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    @classmethod
    def from_string(cls, str):
        """Parse a header line of the form ``"X=<int>, Y=<int>, Z=<int>"``.

        The parameter keeps its historical name ``str`` for keyword callers;
        it only shadows the builtin inside this method.

        Raises:
            IndexError: if the line has fewer than three comma-separated
                fields or a field lacks an ``=``.
            ValueError: if a value is not an integer.
        """
        fields = str.split(',')
        # int() tolerates the surrounding spaces and the trailing newline.
        x, y, z = (int(fields[k].split('=')[1]) for k in range(3))
        return cls(x, y, z)

    def __eq__(self, other):
        try:
            return (self.x, self.y, self.z) == (other.x, other.y, other.z)
        except AttributeError:
            # Not point-like: let Python try the reflected comparison /
            # identity fallback instead of crashing (e.g. ``point == None``).
            return NotImplemented

    def __hash__(self):
        # Hash the coordinate tuple. The old digit-string concatenation made
        # e.g. (1, 23, 4) and (12, 3, 4) collide.
        return hash((self.x, self.y, self.z))

    def __repr__(self):
        return 'Point(x={x}, y={y}, z={z})'.format(x=self.x, y=self.y, z=self.z)

    def __str__(self):
        return 'X={x}, Y={y}, Z={z}'.format(x=self.x, y=self.y, z=self.z)
def get_data_array(start_idx, input, num_lines=9, pattern=None):
    """Collect the ``[x, y, z]`` triples of one record into three float tensors.

    A record spans ``num_lines`` consecutive lines of *input* starting at
    ``start_idx``. Every bracketed integer triple matched by *pattern* on those
    lines contributes one x, one y and one z value. A header line of the
    ``"X=.., Y=.., Z=.."`` form contains no brackets and so adds nothing.

    Args:
        start_idx: index of the record's first line in *input*.
        input: list of text lines. The name shadows the builtin and is kept
            for keyword compatibility.
        num_lines: number of lines that make up one record (default 9).
        pattern: compiled regex with three integer capture groups. Defaults
            to the module-level ``p``.

    Returns:
        ``(x_arr, y_arr, z_arr)``: 1-D ``torch.float64`` tensors of equal
        length, one element per matched triple, in file order.

    Raises:
        IndexError: if *input* ends before ``start_idx + num_lines``.
    """
    if pattern is None:
        pattern = p
    end_idx = start_idx + num_lines
    if end_idx > len(input):
        # Same exception type that per-line indexing raised on truncated input.
        raise IndexError(
            'record at line {} needs {} lines but input has only {}'.format(
                start_idx, num_lines, len(input)))

    # Each match is a tuple of three digit strings: (x, y, z).
    elements = [
        triple
        for line in input[start_idx:end_idx]
        for triple in pattern.findall(line)
    ]

    def column(axis):
        # numpy.float (an alias of builtin float, i.e. float64) was removed in
        # NumPy 1.24, so spell the dtype out explicitly.
        values = numpy.array([float(t[axis]) for t in elements], dtype=numpy.float64)
        return torch.from_numpy(values)

    x_arr, y_arr, z_arr = column(0), column(1), column(2)
    return (x_arr, y_arr, z_arr)
# Parse ``input`` into records. A record begins at a line starting with "X"
# (the "X=.., Y=.., Z=.." header read by Point.from_string) and occupies 9
# lines including that header. ``data`` maps each Point to its (x, y, z)
# tensors; ``dataset`` keeps the same pairs in file order.
i = 0
data = {}
dataset = []
while i < len(input):
    line = input[i]
    if line.startswith('X'):
        point = Point.from_string(line)
        data_arr = get_data_array(i, input)
        data[point] = data_arr
        dataset.append((point, data_arr))
        i += 9
    else:
        # Skip blank separators and any other line one at a time
        # (e.g. "\r\n" or stray whitespace). Without this catch-all such a
        # line would never advance ``i`` and the loop would never terminate.
        i += 1
def display(point, shape=(8, 32)):
    """Show the x, y and z components stored for *point* as three heat maps.

    Prints the point and the shape of its x tensor, then for each component
    draws a ``matshow`` image with a colour bar and shows it.

    Args:
        point: key into the module-level ``data`` dict.
        shape: 2-D grid each component is reshaped to before plotting. The
            product must equal the number of samples (8 * 32 = 256 for the
            default).
    """
    x, y, z = data[point]
    print('Point: ', point)
    print('x.shape: ', x.shape)
    # Same plot sequence for every component instead of three pasted copies.
    for component in (x, y, z):
        plt.matshow(component.reshape(shape))
        plt.colorbar()
        plt.show()
def plot3d(point):
    """Render the samples stored for *point* as one 3-D polyline.

    A new figure is opened first so the curve does not land on whatever
    figure was previously current.
    """
    xs, ys, zs = data[point]
    plt.figure()
    axes3d = plt.axes(projection='3d')
    axes3d.plot3D(xs, ys, zs)
def quiver(point, scale=13):
    """Draw and show a quiver plot of the stored samples for *point*.

    Each arrow is anchored at (x, y) and has components (cos(x), sin(y)).

    Args:
        point: key into the module-level ``data`` dict.
        scale: forwarded to ``plt.quiver`` (defaults to the previously
            hard-coded 13).
    """
    # The z component is unpacked only to keep the 3-tuple contract.
    x, y, _ = data[point]
    # NOTE(review): direction mixes cos of x with sin of y; confirm intended.
    u = torch.cos(x)
    v = torch.sin(y)
    plt.quiver(x, y, u, v, scale=scale)
    plt.show()
# Disabled driver code kept from the notebook. The triple-quoted string below
# is a bare expression and does nothing at runtime. It lists sample points
# that were previewed with quiver() (and optionally display()/plot3d()). To
# run them, move the calls out of the string.
"""
p = Point(70, 0, 100)
quiver(p)
#display(p)
#plot3d(p)
p = Point(-11, -55, 15)
quiver(p)
#display(p)
#plot3d(p)
p = Point(0, 0, 60)
quiver(p)
#display(p)
#plot3d(p)
p = Point(-15, -25, 45)
quiver(p)
#display(p)
p = Point(-5, -5, 45)
quiver(p)
p = Point(0, -35, 60)
quiver(p)
"""
# Notebook output echo of the string above, captured into the export. It is
# also a bare no-op expression and can be deleted safely.
'\np = Point(70, 0, 100)\nquiver(p)\n#display(p)\n#plot3d(p)\n\np = Point(-11, -55, 15)\nquiver(p)\n#display(p)\n#plot3d(p)\n\np = Point(0, 0, 60)\nquiver(p)\n#display(p)\n#plot3d(p)\n\np = Point(-15, -25, 45)\nquiver(p)\n#display(p)\n\np = Point(-5, -5, 45)\nquiver(p)\n\np = Point(0, -35, 60)\nquiver(p)\n'
import torch.nn as nn
from torch.utils.data import random_split