2017年9月27日水曜日

PyTorch練習 02日目 2

二乗誤差を損失関数とする単純な線形回帰を、PyTorchの線形レイヤー1枚で実装する。

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Seed both RNGs so the whole demo is reproducible: torch seeds the initial
# weights of the Linear layer, NumPy seeds the synthetic data and noise
# generated by the scripts below (np.random.randn / np.random.uniform).
torch.manual_seed(1)
np.random.seed(1)
class LinearRegression(object):
    """
    Linear regression that minimises the summed squared error.

    The model is a single ``torch.nn.Linear`` layer whose weight and bias are
    trained with the Adam optimiser.

    methods:
        fit(X, y, lr, n_iter) fit linear model
            X: data matrix (numpy.array), shape (n_samples, n_features)
            y: targets (numpy.array), shape (n_samples, n_outputs) or
               (n_samples,); a 1-D y is treated as a single output column
            lr: learning rate, default 1e-3
            n_iter: number of iterations, default 5000
            return: self
            raises: ValueError if X / y are not 2-D (after the 1-D y
                    promotion) or their sample counts differ
        get_params() get parameters for this estimator
            return: list [weight, bias] of the fitted Linear layer
                    (weight shape (n_outputs, n_features), bias shape
                    (n_outputs,)), or None (after printing a message) if
                    the model has not been fitted yet
        predict(X) predict using the linear model
            X: data matrix (numpy.array), shape (n_samples, n_features)
            return: torch tensor of shape (n_samples, n_outputs), or None
                    (after printing a message) if not fitted yet
    """

    def __init__(self):
        # Set by fit(); None means "not fitted yet".
        self.model = None

    @staticmethod
    def _to_tensor(a):
        """Convert array-like ``a`` to a contiguous float32 torch tensor."""
        # float32 matches the default dtype of torch.nn.Linear's parameters;
        # NumPy defaults to float64, which the layer would reject. The
        # contiguous copy also accepts lists and negatively strided views,
        # which torch.from_numpy cannot wrap.
        return torch.from_numpy(np.ascontiguousarray(a, dtype=np.float32))

    def fit(self, X, y, lr=1e-3, n_iter=5000):
        """Fit weight and bias by running ``n_iter`` Adam steps; return self."""
        X = self._to_tensor(X)
        y = self._to_tensor(y)
        if y.dim() == 1:
            # Treat a flat target vector as one output column.
            y = y.unsqueeze(1)
        if X.dim() != 2 or y.dim() != 2:
            raise ValueError("X and y must be 2-D arrays of shape "
                             "(n_samples, n_features) / (n_samples, n_outputs)")
        if X.shape[0] != y.shape[0]:
            raise ValueError("X and y must have the same number of samples, "
                             "got %d and %d" % (X.shape[0], y.shape[0]))

        # One linear layer is exactly the linear-regression model y = X W^T + b.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(X.shape[1], y.shape[1]))
        # reduction="sum" is the current spelling of the deprecated
        # size_average=False: sum of squared errors over all samples.
        loss_fn = torch.nn.MSELoss(reduction="sum")
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        for _ in range(n_iter):
            optimizer.zero_grad()
            loss = loss_fn(self.model(X), y)
            loss.backward()
            optimizer.step()
        return self

    def get_params(self):
        """Return [weight, bias] of the fitted layer, or None if not fitted."""
        if self.model is None:
            print("We need to fit any data first")
            return None
        return list(self.model.parameters())

    def predict(self, X):
        """Return model predictions for ``X``, or None if not fitted."""
        if self.model is None:
            print("We need to fit any data first")
            return None
        return self.model(self._to_tensor(X))
# Two-feature example: y = 2*x1 + 1*x2 + 1 + Gaussian noise (std 0.5).
# x1 and x2 are drawn independently. If both were the same linspace, the
# features would be perfectly collinear: only the sum of the two weights
# could be recovered, and the data would lie on a line rather than a plane.
x1 = np.random.uniform(-1, 1, 100)
x2 = np.random.uniform(-1, 1, 100)
X = np.vstack((x1, x2)).T  # shape (100, 2)
y = 2 * x1 + 1 * x2 + 0.5 * np.random.randn(100) + 1
y = y.reshape([100, 1])

# Named `reg`, not `lr`, so it is not confused with fit()'s learning rate.
reg = LinearRegression()
reg.fit(X, y)

fig = plt.figure(figsize=(9, 9))
ax = fig.add_subplot(111, projection='3d')

# Flatten the (100, 1) columns into 1-D z-values for the 3-D scatter.
ax.scatter(x1, x2, y.ravel(), label='data')
ax.scatter(x1, x2, reg.predict(X).data.numpy().ravel(), label='prediction')
ax.legend()
plt.show()

[図: 2つの特徴量 x1, x2 に対するデータ（data）とモデルの予測値（prediction）の3次元散布図]

# One-feature example: y = 2x + Gaussian noise (std 0.5), X of shape (100, 1).
X = np.linspace(-1, 1, 100).reshape([100, 1])
y = 2 * X + 0.5 * np.random.randn(100, 1)

# Named `reg`, not `lr`, so it is not confused with fit()'s learning rate.
reg = LinearRegression()
reg.fit(X, y)
y_hat = reg.predict(X).data.numpy()

plt.scatter(X.ravel(), y.ravel(), label='data')
plt.scatter(X.ravel(), y_hat.ravel(), label='predict')
plt.legend()
plt.show()

[図: 1つの特徴量に対するデータ（data）とモデルの予測値（predict）の散布図]

0 件のコメント:

コメントを投稿