File size: 2,781 Bytes
f4d7da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import torch
import torch.nn as nn

class GaussianProcessRegressor(nn.Module):
    """Exact Gaussian-process regressor with a squared-exponential kernel.

    The kernel between two inputs is::

        k(x, z) = amplitude_scale * exp(-0.5 * sum_d (x_d - z_d)**2 / length_scale_d)

    i.e. the features are divided by ``sqrt(length_scale)`` before the
    squared Euclidean distance is taken, so ``length_scale`` acts as a
    *squared* length scale. ``noise_scale`` is added to the diagonal of the
    training covariance.

    The three hyperparameters are stored in log space as the trainable
    parameters ``length_scale_``, ``noise_scale_`` and ``amplitude_scale_``
    so gradient-based optimisation keeps them positive.

    Args:
        length_scale: positive real scalar (one scale shared by all features)
            or a 1-D ``np.ndarray`` with one entry per input feature.
        noise_scale: positive value added to the covariance diagonal.
        amplitude_scale: positive multiplier of the kernel (k(x, x)).
    """

    def __init__(
        self,
        length_scale=1.0,
        noise_scale=1.0,
        amplitude_scale=1.0,
    ):
        super().__init__()
        # Accept any real scalar (Python or NumPy int/float), not only float.
        if isinstance(length_scale, (int, float, np.integer, np.floating)):
            length_scale = np.array([length_scale], dtype=float)
        elif isinstance(length_scale, np.ndarray):
            if length_scale.ndim != 1:
                raise ValueError(
                    f"length_scale array must be 1-D, got ndim={length_scale.ndim}"
                )
        else:
            raise TypeError(
                "length_scale must be a real scalar or a 1-D np.ndarray, "
                f"got {type(length_scale).__name__}"
            )

        # Log-parameterisation keeps every hyperparameter positive.
        # NOTE: length_scale_ uses torch's default dtype (as torch.Tensor(ndarray)
        # did), while the two scalar parameters keep NumPy's float64 from np.log.
        self.length_scale_ = nn.Parameter(
            torch.tensor(np.log(length_scale), dtype=torch.get_default_dtype()),
            requires_grad=True,
        )
        self.noise_scale_ = nn.Parameter(
            torch.tensor(np.log(noise_scale)), requires_grad=True
        )
        self.amplitude_scale_ = nn.Parameter(
            torch.tensor(np.log(amplitude_scale)), requires_grad=True
        )

        self.nll = None

    def forward(self, x):
        """Return the GP posterior mean at query inputs ``x``.

        Uses the training inputs stored by ``fit`` and ``alpha = K^{-1} y``
        cached by the most recent ``log_marginal_likelihood`` call (``fit``
        refreshes that cache with its final hyperparameters).

        Args:
            x: query inputs of shape ``(m, D)``.

        Returns:
            Tensor of shape ``(m, k)``, where ``k`` is the number of target
            columns used in training.
        """
        k_train_query = self.Kxy(self.X, x)  # (n, m), noise-free cross covariance
        return k_train_query.T.mm(self.alpha)

    def log_marginal_likelihood(self, X, y):
        """Compute log p(y | X) under the current hyperparameters.

        Each column of ``y`` is treated as an independent output sharing the
        same covariance ``K``, and one log-likelihood is returned per column.

        Args:
            X: training inputs of shape ``(n, D)``.
            y: training targets of shape ``(n, k)``.

        Returns:
            Tensor of shape ``(1, k)``; summing it gives the joint
            log-likelihood of all outputs.

        Raises:
            torch.linalg.LinAlgError: if ``K`` is not positive definite.

        Side effects:
            Caches ``self.L`` (Cholesky factor), ``self.alpha`` (``K^{-1} y``)
            and ``self.K`` for use by ``forward``.
        """
        n = X.shape[0]  # number of training points (not features)
        K = self.Kxx(X)
        L = torch.linalg.cholesky(K)
        # K^{-1} y via two triangular solves with the Cholesky factor.
        alpha = torch.cholesky_solve(y, L)
        # Per-column quadratic form y_k^T K^{-1} y_k; avoids the cross terms
        # that y.T @ alpha would produce for multi-column targets.
        data_fit = -0.5 * (y * alpha).sum(dim=0, keepdim=True)
        half_log_det = torch.log(torch.diag(L)).sum()  # = 0.5 * log|K|
        marginal_likelihood = data_fit - half_log_det - 0.5 * n * np.log(2 * np.pi)
        self.L = L
        self.alpha = alpha
        self.K = K
        return marginal_likelihood

    def Kxx(self, X):
        """Training covariance: kernel of ``X`` with itself plus diagonal noise."""
        noise = self.noise_scale_.exp()
        eye = torch.eye(len(X), dtype=X.dtype, device=X.device)
        return self.Kxy(X, X) + noise * eye

    def Kxy(self, X, Z):
        """Noise-free kernel matrix between ``X`` (n, D) and ``Z`` (m, D)."""
        scale = self.length_scale_.exp().sqrt()
        sqdist = torch.cdist(X / scale[None], Z / scale[None]) ** 2
        return self.amplitude_scale_.exp() * torch.exp(-0.5 * sqdist)

    def fit(self, X, y, opt, num_steps):
        """Optimise the hyperparameters by minimising the negative log marginal likelihood.

        Args:
            X: training inputs of shape ``(n, D)``; stored as ``self.X``.
            y: training targets of shape ``(n, k)``; stored as ``self.y``.
            opt: a ``torch.optim`` optimizer, presumably over
                ``self.parameters()``. It is wrapped in an ``ExponentialLR``
                scheduler.
            num_steps: maximum number of optimisation steps.

        Returns:
            List of the NLL value at each completed step. It is shorter than
            ``num_steps`` if the Cholesky factorisation failed and the loop
            stopped early.

        Raises:
            ValueError: if ``length_scale`` has neither 1 entry nor one entry
                per feature of ``X``.
        """
        n_features = X.shape[1]
        n_scales = len(self.length_scale_)
        # A single length scale broadcasts over all features (isotropic kernel).
        if n_scales not in (1, n_features):
            raise ValueError(
                f"X has {n_features} features but length_scale has {n_scales} "
                f"entries; expected 1 or {n_features}"
            )
        self.y = y
        self.X = X

        scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)

        self.train()
        nll_hist = []
        for it in range(num_steps):
            opt.zero_grad()
            try:
                nll = -self.log_marginal_likelihood(self.X, self.y).sum()
            except torch.linalg.LinAlgError:
                # K is not positive definite for these hyperparameters: stop.
                break
            nll.backward()
            opt.step()
            # Multiply the learning rate by 0.9 at steps 0, 10, ..., 990.
            if it % 10 == 0 and it < 1000:
                scheduler.step()
            nll_hist.append(nll.item())

        # The cache built inside the loop predates the last opt.step(), so it
        # is one update stale (or missing if num_steps == 0). Rebuild it from
        # the final hyperparameters; if those make K singular, keep the last
        # successfully computed cache (deliberate best effort).
        with torch.no_grad():
            try:
                self.log_marginal_likelihood(self.X, self.y)
            except torch.linalg.LinAlgError:
                pass
        return nll_hist