Learning scikit-learn (Part 3): Visualizing the Stock Market Structure

This example employs several unsupervised learning techniques to extract the stock market structure from variations in historical quotes. The quantity we use is the daily variation in quote price: quotes that are linked tend to fluctuate in relation to one another over the course of a day.

Learning a graph structure

We use sparse inverse covariance estimation to find which quotes are correlated conditionally on the others. Specifically, sparse inverse covariance gives us a graph, i.e. a list of connections. For a given symbol, the symbols it is connected to are the ones useful for explaining its fluctuations.
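
As a minimal, self-contained sketch of this step on synthetic data (assuming a recent scikit-learn, where the estimator is named GraphicalLassoCV; the full script below uses the older GraphLassoCV name):

import numpy as np
from sklearn.covariance import GraphicalLassoCV

# Toy stand-in for the (n_days, n_symbols) matrix of daily variations.
rng = np.random.RandomState(0)
X = rng.randn(250, 5)
X /= X.std(axis=0)  # standardize so we effectively work with correlations

model = GraphicalLassoCV()
model.fit(X)

# Nonzero off-diagonal entries of the precision (inverse covariance)
# matrix are the edges of the conditional-dependence graph.
edges = np.abs(np.triu(model.precision_, k=1)) > 0.02
print(edges.astype(int))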

Clustering

We use clustering to group together quotes that behave similarly. Among the various clustering techniques available in scikit-learn, we use Affinity Propagation because it does not enforce equal-size clusters and it can choose the number of clusters automatically from the data.

Note that this gives us a different indication than the graph: the graph reflects conditional relations between variables, while the clustering reflects marginal properties. Variables clustered together can be considered as having a similar impact at the level of the full stock market.
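
A rough sketch of the clustering call, run on a hypothetical symmetric similarity matrix (in the full script this is the covariance matrix learned above; the random_state argument exists in scikit-learn >= 0.23):

import numpy as np
from sklearn import cluster

# Hypothetical similarity matrix standing in for the learned covariance;
# affinity propagation may warn about convergence on arbitrary matrices.
rng = np.random.RandomState(0)
A = rng.rand(5, 5)
S = (A + A.T) / 2  # affinity_propagation expects a symmetric matrix

_, labels = cluster.affinity_propagation(S, random_state=0)
print(labels)  # one integer cluster label per symbol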

Embedding in 2D space

For visualization purposes, we need to lay out the different symbols on a 2D canvas. For this we use manifold learning techniques to retrieve a 2D embedding.
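
A minimal sketch of the embedding step, again on toy data (each row stands in for one symbol's standardized variation series; the script below uses n_neighbors=6 on the real data):

import numpy as np
from sklearn import manifold

rng = np.random.RandomState(0)
X = rng.randn(8, 250)  # 8 hypothetical symbols, 250 trading days

node_position_model = manifold.LocallyLinearEmbedding(
    n_components=2, eigen_solver='dense', n_neighbors=3)
embedding = node_position_model.fit_transform(X).T
print(embedding.shape)  # (2, 8): one (x, y) position per symbol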

Visualization

The outputs of the three models are combined in a 2D plot where nodes represent the stocks and edges the connections between them:

  • cluster labels are used to define the color of the nodes
  • the sparse covariance model is used to display the strength of the edges
  • the 2D embedding is used to position the nodes in the plane

This example has a fair amount of visualization-related code, since visualization is crucial here for displaying the graph. One of the challenges is positioning the labels so as to minimize overlap; for this we use a heuristic based on the direction of each node's nearest neighbor along each axis, implemented toward the end of the script below.

from __future__ import print_function

# Author: Gael Varoquaux gael.varoquaux@normalesup.org
# License: BSD 3 clause

import sys
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from six.moves.urllib.request import urlopen
from six.moves.urllib.parse import urlencode
from sklearn import cluster, covariance, manifold

import requests
import os

print(__doc__)


def download(symbol, start_date, end_date):
    if os.path.exists(symbol + '.csv'):
        return
    params = {
        'q': symbol,
        'startdate': start_date.strftime('%Y-%m-%d'),
        'enddate': end_date.strftime('%Y-%m-%d'),
        'output': 'csv',
    }

    # Proxy settings (optional; adjust or remove for your environment)
    proxies = {
        'https': 'https://127.0.0.1:8118',
        'http': 'http://127.0.0.1:8118'
    }

    url = 'https://finance.google.com/finance/historical?' + urlencode(params)
    response = requests.get(url, proxies=proxies)
    with open(symbol + '.csv', "wb") as code:
        code.write(response.content)


def retry(f, n_attempts=3):
    "Wrapper function to retry function calls in case of exceptions"
    def wrapper(*args, **kwargs):
        for i in range(n_attempts):
            try:
                return f(*args, **kwargs)
            except Exception:
                if i == n_attempts - 1:
                    raise
    return wrapper


def quotes_historical_google(symbol, start_date, end_date):
    """Get the historical data from Google finance.

    Parameters
    ----------
    symbol : str
        Ticker symbol to query for, for example ``"DELL"``.
    start_date : datetime.datetime
        Start date.
    end_date : datetime.datetime
        End date.

    Returns
    -------
    X : array
        The columns are ``date`` -- date, ``open``, ``high``,
        ``low``, ``close`` and ``volume`` of type float.
    """
    params = {
        'q': symbol,
        'startdate': start_date.strftime('%Y-%m-%d'),
        'enddate': end_date.strftime('%Y-%m-%d'),
        'output': 'csv',
    }

    # Read the CSV that download() saved in the current working directory
    url = 'file://' + os.path.join(os.getcwd(), symbol + '.csv')
    response = urlopen(url)
    dtype = {
        'names': ['date', 'open', 'high', 'low', 'close', 'volume'],
        'formats': ['object', 'f4', 'f4', 'f4', 'f4', 'f4']
    }
    converters = {
        0: lambda s: datetime.strptime(s.decode(), '%d-%b-%y').date()}
    data = np.genfromtxt(response, delimiter=',', skip_header=1,
                         dtype=dtype, converters=converters,
                         missing_values='-', filling_values=-1)
    min_date = min(data['date'], default=datetime.min.date())
    max_date = max(data['date'], default=datetime.max.date())
    start_end_diff = (end_date - start_date).days
    min_max_diff = (max_date - min_date).days
    data_is_fine = (
        start_date <= min_date <= end_date and
        start_date <= max_date <= end_date and
        start_end_diff - 7 <= min_max_diff <= start_end_diff)

    if not data_is_fine:
        message = (
            'Data looks wrong for symbol {}, url {}\n'
            ' - start_date: {}, end_date: {}\n'
            ' - min_date: {}, max_date: {}\n'
            ' - start_end_diff: {}, min_max_diff: {}'.format(
                symbol, url,
                start_date, end_date,
                min_date, max_date,
                start_end_diff, min_max_diff))
        raise RuntimeError(message)
    return data

# #############################################################################
# Retrieve the data from Internet

# Choose a time period reasonably calm (not too long ago so that we get
# high-tech firms, and before the 2008 crash)
start_date = datetime(2003, 1, 1).date()
end_date = datetime(2008, 1, 1).date()

symbol_dict = {
    'NYSE:TOT': 'Total',
    'NYSE:XOM': 'Exxon',
    'NYSE:CVX': 'Chevron',
    'NYSE:COP': 'ConocoPhillips',
    'NYSE:VLO': 'Valero Energy',
    'NASDAQ:MSFT': 'Microsoft',
    'NYSE:IBM': 'IBM',
    'NYSE:TWX': 'Time Warner',
    'NASDAQ:CMCSA': 'Comcast',
    'NYSE:CVC': 'Cablevision',
    'NASDAQ:YHOO': 'Yahoo',
    'NASDAQ:DELL': 'Dell',
    'NYSE:HPQ': 'HP',
    'NASDAQ:AMZN': 'Amazon',
    'NYSE:TM': 'Toyota',
    'NYSE:CAJ': 'Canon',
    'NYSE:SNE': 'Sony',
    'NYSE:F': 'Ford',
    'NYSE:HMC': 'Honda',
    'NYSE:NAV': 'Navistar',
    'NYSE:NOC': 'Northrop Grumman',
    'NYSE:BA': 'Boeing',
    'NYSE:KO': 'Coca Cola',
    'NYSE:MMM': '3M',
    'NYSE:MCD': 'McDonald\'s',
    'NYSE:PEP': 'Pepsi',
    'NYSE:K': 'Kellogg',
    'NYSE:UN': 'Unilever',
    'NASDAQ:MAR': 'Marriott',
    'NYSE:PG': 'Procter Gamble',
    'NYSE:CL': 'Colgate-Palmolive',
    'NYSE:GE': 'General Electrics',
    'NYSE:WFC': 'Wells Fargo',
    'NYSE:JPM': 'JPMorgan Chase',
    'NYSE:AIG': 'AIG',
    'NYSE:AXP': 'American express',
    'NYSE:BAC': 'Bank of America',
    'NYSE:GS': 'Goldman Sachs',
    'NASDAQ:AAPL': 'Apple',
    'NYSE:SAP': 'SAP',
    'NASDAQ:CSCO': 'Cisco',
    'NASDAQ:TXN': 'Texas Instruments',
    'NYSE:XRX': 'Xerox',
    'NYSE:WMT': 'Wal-Mart',
    'NYSE:HD': 'Home Depot',
    'NYSE:GSK': 'GlaxoSmithKline',
    'NYSE:PFE': 'Pfizer',
    'NYSE:SNY': 'Sanofi-Aventis',
    'NYSE:NVS': 'Novartis',
    'NYSE:KMB': 'Kimberly-Clark',
    'NYSE:R': 'Ryder',
    'NYSE:GD': 'General Dynamics',
    'NYSE:RTN': 'Raytheon',
    'NYSE:CVS': 'CVS',
    'NYSE:CAT': 'Caterpillar',
    'NYSE:DD': 'DuPont de Nemours'}


symbols, names = np.array(sorted(symbol_dict.items())).T

# retry is used because quotes_historical_google can temporarily fail
# for various reasons (e.g. empty result from Google API).
quotes = []

for symbol in symbols:
    print('Fetching quote history for %r' % symbol, file=sys.stderr)
    download(symbol, start_date, end_date)

for symbol in symbols:
    quotes.append(retry(quotes_historical_google)(
        symbol, start_date, end_date))


close_prices = np.vstack([q['close'] for q in quotes])
open_prices = np.vstack([q['open'] for q in quotes])

# The daily variations of the quotes are what carry most information
variation = close_prices - open_prices


# #############################################################################
# Learn a graphical structure from the correlations
edge_model = covariance.GraphLassoCV()

# standardize the time series: using correlations rather than covariance
# is more efficient for structure recovery
X = variation.copy().T
X /= X.std(axis=0)
edge_model.fit(X)

# #############################################################################
# Cluster using affinity propagation

_, labels = cluster.affinity_propagation(edge_model.covariance_)
n_labels = labels.max()

for i in range(n_labels + 1):
    print('Cluster %i: %s' % ((i + 1), ', '.join(names[labels == i])))

# #############################################################################
# Find a low-dimension embedding for visualization: find the best position of
# the nodes (the stocks) on a 2D plane

# We use a dense eigen_solver to achieve reproducibility (arpack is
# initiated with random vectors that we don't control). In addition, we
# use a large number of neighbors to capture the large-scale structure.
node_position_model = manifold.LocallyLinearEmbedding(
    n_components=2, eigen_solver='dense', n_neighbors=6)

embedding = node_position_model.fit_transform(X.T).T

# #############################################################################
# Visualization
plt.figure(1, facecolor='w', figsize=(10, 8))
plt.clf()
ax = plt.axes([0., 0., 1., 1.])
plt.axis('off')

# Display a graph of the partial correlations
partial_correlations = edge_model.precision_.copy()
d = 1 / np.sqrt(np.diag(partial_correlations))
partial_correlations *= d
partial_correlations *= d[:, np.newaxis]
non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)

# Plot the nodes using the coordinates of our embedding
plt.scatter(embedding[0], embedding[1], s=100 * d ** 2, c=labels,
            cmap=plt.cm.spectral)

# Plot the edges
start_idx, end_idx = np.where(non_zero)
# a sequence of (*line0*, *line1*, *line2*), where::
# linen = (x0, y0), (x1, y1), ... (xm, ym)
segments = [[embedding[:, start], embedding[:, stop]]
            for start, stop in zip(start_idx, end_idx)]
values = np.abs(partial_correlations[non_zero])
lc = LineCollection(segments,
                    zorder=0, cmap=plt.cm.hot_r,
                    norm=plt.Normalize(0, .7 * values.max()))
lc.set_array(values)
lc.set_linewidths(15 * values)
ax.add_collection(lc)

# Add a label to each node. The challenge here is that we want to
# position the labels to avoid overlap with other labels
for index, (name, label, (x, y)) in enumerate(
        zip(names, labels, embedding.T)):

    dx = x - embedding[0]
    dx[index] = 1
    dy = y - embedding[1]
    dy[index] = 1
    this_dx = dx[np.argmin(np.abs(dy))]
    this_dy = dy[np.argmin(np.abs(dx))]
    if this_dx > 0:
        horizontalalignment = 'left'
        x = x + .002
    else:
        horizontalalignment = 'right'
        x = x - .002
    if this_dy > 0:
        verticalalignment = 'bottom'
        y = y + .002
    else:
        verticalalignment = 'top'
        y = y - .002
    plt.text(x, y, name, size=10,
             horizontalalignment=horizontalalignment,
             verticalalignment=verticalalignment,
             bbox=dict(facecolor='w',
                       edgecolor=plt.cm.spectral(label / float(n_labels)),
                       alpha=.6))

plt.xlim(embedding[0].min() - .15 * embedding[0].ptp(),
         embedding[0].max() + .10 * embedding[0].ptp())
plt.ylim(embedding[1].min() - .03 * embedding[1].ptp(),
         embedding[1].max() + .03 * embedding[1].ptp())

plt.show()
Running the script produces the following output:

Automatically created module for IPython interactive environment


Fetching quote history for 'NASDAQ:AAPL'
Fetching quote history for 'NASDAQ:AMZN'
Fetching quote history for 'NASDAQ:CMCSA'
Fetching quote history for 'NASDAQ:CSCO'
Fetching quote history for 'NASDAQ:DELL'
Fetching quote history for 'NASDAQ:MAR'
Fetching quote history for 'NASDAQ:MSFT'
Fetching quote history for 'NASDAQ:TXN'
Fetching quote history for 'NASDAQ:YHOO'
Fetching quote history for 'NYSE:AIG'
Fetching quote history for 'NYSE:AXP'
Fetching quote history for 'NYSE:BA'
Fetching quote history for 'NYSE:BAC'
Fetching quote history for 'NYSE:CAJ'
Fetching quote history for 'NYSE:CAT'
Fetching quote history for 'NYSE:CL'
Fetching quote history for 'NYSE:COP'
Fetching quote history for 'NYSE:CVC'
Fetching quote history for 'NYSE:CVS'
Fetching quote history for 'NYSE:CVX'
Fetching quote history for 'NYSE:DD'
Fetching quote history for 'NYSE:F'
Fetching quote history for 'NYSE:GD'
Fetching quote history for 'NYSE:GE'
Fetching quote history for 'NYSE:GS'
Fetching quote history for 'NYSE:GSK'
Fetching quote history for 'NYSE:HD'
Fetching quote history for 'NYSE:HMC'
Fetching quote history for 'NYSE:HPQ'
Fetching quote history for 'NYSE:IBM'
Fetching quote history for 'NYSE:JPM'
Fetching quote history for 'NYSE:K'
Fetching quote history for 'NYSE:KMB'
Fetching quote history for 'NYSE:KO'
Fetching quote history for 'NYSE:MCD'
Fetching quote history for 'NYSE:MMM'
Fetching quote history for 'NYSE:NAV'
Fetching quote history for 'NYSE:NOC'
Fetching quote history for 'NYSE:NVS'
Fetching quote history for 'NYSE:PEP'
Fetching quote history for 'NYSE:PFE'
Fetching quote history for 'NYSE:PG'
Fetching quote history for 'NYSE:R'
Fetching quote history for 'NYSE:RTN'
Fetching quote history for 'NYSE:SAP'
Fetching quote history for 'NYSE:SNE'
Fetching quote history for 'NYSE:SNY'
Fetching quote history for 'NYSE:TM'
Fetching quote history for 'NYSE:TOT'
Fetching quote history for 'NYSE:TWX'
Fetching quote history for 'NYSE:UN'
Fetching quote history for 'NYSE:VLO'
Fetching quote history for 'NYSE:WFC'
Fetching quote history for 'NYSE:WMT'
Fetching quote history for 'NYSE:XOM'
Fetching quote history for 'NYSE:XRX'



Cluster 1: Apple, Amazon, Yahoo
Cluster 2: Cisco, Dell, Microsoft, Texas Instruments, HP, IBM, SAP
Cluster 3: American express
Cluster 4: Boeing
Cluster 5: Cablevision
Cluster 6: ConocoPhillips, Chevron, Total, Valero Energy, Exxon
Cluster 7: Comcast, Marriott, AIG, Bank of America, CVS, DuPont de Nemours, Ford, General Electrics, Goldman Sachs, Home Depot, JPMorgan Chase, McDonald's, 3M, Pfizer, Ryder, Wells Fargo, Wal-Mart
Cluster 8: Navistar
Cluster 9: General Dynamics, Northrop Grumman, Raytheon
Cluster 10: GlaxoSmithKline, Novartis, Sanofi-Aventis
Cluster 11: Kellogg, Coca Cola, Pepsi
Cluster 12: Colgate-Palmolive, Kimberly-Clark, Procter Gamble
Cluster 13: Canon, Caterpillar, Honda, Sony, Toyota, Unilever, Xerox
Cluster 14: Time Warner

[Figure: stock market structure — nodes positioned by the 2D embedding and colored by cluster, with edge widths showing the strength of the partial correlations]
