Skip to content

Commit 616ff8b

Browse files
committed
Updated ranks in tensorkrylov.py
1 parent cb62516 commit 616ff8b

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

examples/tensorkrylov.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy.random import seed, rand
77
from scikit_tt.tensor_train import TT, build_core, residual_error, uniform, residual_error
88
from scikit_tt.solvers.sle import als
9+
from time import time
910

1011
class MatrixCollection(object):
1112

@@ -167,9 +168,13 @@ def _update_approximation(x_TT: "TT", V: "MatrixCollection", y_TT: "TT"):
167168
for s in range(x_TT.order):
168169

169170
x_TT.cores[s] = np.sum(V[s][None, :, :, None, None] @ y_TT.cores[s][:, None, :, :, :], axis = 2)
171+
x_TT.ranks[s] = x_TT.cores[s].shape[0]
172+
x_TT.ranks[s + 1] = x_TT.cores[s].shape[3]
173+
170174

171175
return
172176

177+
#def _residual_norm()
173178

174179
def symmetric_tensorkrylov(A: "MatrixCollection", b: List[np.ndarray], rank: int, nmax: int, tol = 1e-9):
175180

@@ -213,11 +218,9 @@ def symmetric_tensorkrylov(A: "MatrixCollection", b: List[np.ndarray], rank: int
213218

214219
y_TT = als(TT_operator, TT_guess, TT_rhs)
215220
_update_approximation(x_TT, V_minors, y_TT)
216-
print(A_TT)
217-
print(x_TT)
218-
print(b_TT)
219-
r_norm = residual_error(A_TT, x_TT, b_TT)
221+
#r_norm = residual_error(A_TT, x_TT, b_TT)
220222

223+
#print(r_norm)
221224
if r_norm <= tol:
222225

223226
return x_TT
@@ -239,8 +242,23 @@ def random_rhs(n: int):
239242

240243
A = MatrixCollection([ As for _ in range(d) ])
241244
b = [ bs for _ in range(d) ]
242-
rank = 5
245+
rank = 8
243246
ranks = [1] + ([rank] * (d - 1)) + [1]
244247

248+
row_dims = [n for _ in range(d)]
249+
col_dims = [1 for _ in range(d)]
250+
251+
x_TT = scikit_tt.tensor_train.rand(row_dims, col_dims, ranks)
252+
A_TT = _TT_operator(A, n - 1)
253+
b_TT = _TT_rhs(b)
254+
255+
256+
start = time()
257+
x_TT = als(A_TT, x_TT, b_TT)
258+
end = time()
259+
print("Done", end - start)
260+
261+
print(residual_error(A_TT, x_TT, b_TT))
262+
245263

246264
print(symmetric_tensorkrylov(A, b, rank, n, tol = 1e-9))

0 commit comments

Comments
 (0)