1
22
23 package com.liferay.portal.dao.shard;
24
25 import com.liferay.counter.service.persistence.CounterPersistence;
26 import com.liferay.portal.NoSuchCompanyException;
27 import com.liferay.portal.PortalException;
28 import com.liferay.portal.SystemException;
29 import com.liferay.portal.kernel.log.Log;
30 import com.liferay.portal.kernel.log.LogFactoryUtil;
31 import com.liferay.portal.kernel.util.InitialThreadLocal;
32 import com.liferay.portal.kernel.util.StringPool;
33 import com.liferay.portal.kernel.util.StringUtil;
34 import com.liferay.portal.model.Company;
35 import com.liferay.portal.model.Shard;
36 import com.liferay.portal.security.auth.CompanyThreadLocal;
37 import com.liferay.portal.service.CompanyLocalServiceUtil;
38 import com.liferay.portal.service.ShardLocalServiceUtil;
39 import com.liferay.portal.service.persistence.ClassNamePersistence;
40 import com.liferay.portal.service.persistence.CompanyPersistence;
41 import com.liferay.portal.service.persistence.ReleasePersistence;
42 import com.liferay.portal.service.persistence.ShardPersistence;
43 import com.liferay.portal.util.PropsValues;
44
45 import java.util.HashMap;
46 import java.util.Map;
47 import java.util.Stack;
48
49 import javax.sql.DataSource;
50
51 import org.aspectj.lang.ProceedingJoinPoint;
52
53
60 public class ShardAdvice {
61
62 public Object invokeAccountService(ProceedingJoinPoint proceedingJoinPoint)
63 throws Throwable {
64
65 String methodName = proceedingJoinPoint.getSignature().getName();
66 Object[] arguments = proceedingJoinPoint.getArgs();
67
68 String shardName = PropsValues.SHARD_DEFAULT_NAME;
69
70 if (methodName.equals("getAccount") && (arguments.length == 2)) {
71 long companyId = (Long)arguments[0];
72
73 Shard shard = ShardLocalServiceUtil.getShard(
74 Company.class.getName(), companyId);
75
76 shardName = shard.getName();
77 }
78 else {
79 return proceedingJoinPoint.proceed();
80 }
81
82 if (_log.isInfoEnabled()) {
83 _log.info(
84 "Company service being set to shard " + shardName + " for " +
85 _getSignature(proceedingJoinPoint));
86 }
87
88 Object returnValue = null;
89
90 pushCompanyService(shardName);
91
92 try {
93 returnValue = proceedingJoinPoint.proceed();
94 }
95 finally {
96 popCompanyService();
97 }
98
99 return returnValue;
100 }
101
102 public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
103 throws Throwable {
104
105 String methodName = proceedingJoinPoint.getSignature().getName();
106 Object[] arguments = proceedingJoinPoint.getArgs();
107
108 String shardName = PropsValues.SHARD_DEFAULT_NAME;
109
110 if (methodName.equals("addCompany")) {
111 String webId = (String)arguments[0];
112 String virtualHost = (String)arguments[1];
113 String mx = (String)arguments[2];
114 shardName = (String)arguments[3];
115
116 shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
117
118 arguments[3] = shardName;
119 }
120 else if (methodName.equals("checkCompany")) {
121 String webId = (String)arguments[0];
122
123 if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
124 if (arguments.length == 3) {
125 String mx = (String)arguments[1];
126 shardName = (String)arguments[2];
127
128 shardName = _getCompanyShardName(
129 webId, null, mx, shardName);
130
131 arguments[2] = shardName;
132 }
133
134 try {
135 Company company = CompanyLocalServiceUtil.getCompanyByWebId(
136 webId);
137
138 shardName = company.getShardName();
139 }
140 catch (NoSuchCompanyException nsce) {
141 }
142 }
143 }
144 else if (methodName.startsWith("update")) {
145 long companyId = (Long)arguments[0];
146
147 Shard shard = ShardLocalServiceUtil.getShard(
148 Company.class.getName(), companyId);
149
150 shardName = shard.getName();
151 }
152 else {
153 return proceedingJoinPoint.proceed();
154 }
155
156 if (_log.isInfoEnabled()) {
157 _log.info(
158 "Company service being set to shard " + shardName + " for " +
159 _getSignature(proceedingJoinPoint));
160 }
161
162 Object returnValue = null;
163
164 pushCompanyService(shardName);
165
166 try {
167 returnValue = proceedingJoinPoint.proceed(arguments);
168 }
169 finally {
170 popCompanyService();
171 }
172
173 return returnValue;
174 }
175
176 public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
177 throws Throwable {
178
179 _globalCallThreadLocal.set(new Object());
180
181 try {
182 if (_log.isInfoEnabled()) {
183 _log.info(
184 "All shards invoked for " +
185 _getSignature(proceedingJoinPoint));
186 }
187
188 for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
189 _shardDataSourceTargetSource.setDataSource(shardName);
190 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
191
192 proceedingJoinPoint.proceed();
193 }
194 }
195 finally {
196 _globalCallThreadLocal.set(null);
197 }
198
199 return null;
200 }
201
202 public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
203 throws Throwable {
204
205 Object target = proceedingJoinPoint.getTarget();
206
207 if (target instanceof ClassNamePersistence ||
208 target instanceof CompanyPersistence ||
209 target instanceof CounterPersistence ||
210 target instanceof ReleasePersistence ||
211 target instanceof ShardPersistence) {
212
213 _shardDataSourceTargetSource.setDataSource(
214 PropsValues.SHARD_DEFAULT_NAME);
215 _shardSessionFactoryTargetSource.setSessionFactory(
216 PropsValues.SHARD_DEFAULT_NAME);
217
218 if (_log.isDebugEnabled()) {
219 _log.debug(
220 "Using default shard for " +
221 _getSignature(proceedingJoinPoint));
222 }
223
224 return proceedingJoinPoint.proceed();
225 }
226
227 if (_globalCallThreadLocal.get() == null) {
228 _setShardNameByCompany();
229
230 String shardName = _getShardName();
231
232 _shardDataSourceTargetSource.setDataSource(shardName);
233 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
234
235 if (_log.isInfoEnabled()) {
236 _log.info(
237 "Using shard name " + shardName + " for " +
238 _getSignature(proceedingJoinPoint));
239 }
240
241 return proceedingJoinPoint.proceed();
242 }
243 else {
244 return proceedingJoinPoint.proceed();
245 }
246 }
247
248 public Object invokeUserService(ProceedingJoinPoint proceedingJoinPoint)
249 throws Throwable {
250
251 String methodName = proceedingJoinPoint.getSignature().getName();
252 Object[] arguments = proceedingJoinPoint.getArgs();
253
254 String shardName = PropsValues.SHARD_DEFAULT_NAME;
255
256 if (methodName.equals("searchCount")) {
257 long companyId = (Long)arguments[0];
258
259 Shard shard = ShardLocalServiceUtil.getShard(
260 Company.class.getName(), companyId);
261
262 shardName = shard.getName();
263 }
264 else {
265 return proceedingJoinPoint.proceed();
266 }
267
268 if (_log.isInfoEnabled()) {
269 _log.info(
270 "Company service being set to shard " + shardName + " for " +
271 _getSignature(proceedingJoinPoint));
272 }
273
274 Object returnValue = null;
275
276 pushCompanyService(shardName);
277
278 try {
279 returnValue = proceedingJoinPoint.proceed();
280 }
281 finally {
282 popCompanyService();
283 }
284
285 return returnValue;
286 }
287
288 public void setShardDataSourceTargetSource(
289 ShardDataSourceTargetSource shardDataSourceTargetSource) {
290
291 _shardDataSourceTargetSource = shardDataSourceTargetSource;
292 }
293
294 public void setShardSessionFactoryTargetSource(
295 ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
296
297 _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
298 }
299
300 protected DataSource getDataSource() {
301 return _shardDataSourceTargetSource.getDataSource();
302 }
303
304 protected String popCompanyService() {
305 return _getCompanyServiceStack().pop();
306 }
307
308 protected void pushCompanyService(long companyId) {
309 try {
310 Shard shard = ShardLocalServiceUtil.getShard(
311 Company.class.getName(), companyId);
312
313 String shardName = shard.getName();
314
315 pushCompanyService(shardName);
316 }
317 catch (Exception e) {
318 _log.error(e, e);
319 }
320 }
321
322 protected void pushCompanyService(String shardName) {
323 _getCompanyServiceStack().push(shardName);
324 }
325
326 private Stack<String> _getCompanyServiceStack() {
327 Stack<String> companyServiceStack = _companyServiceStack.get();
328
329 if (companyServiceStack == null) {
330 companyServiceStack = new Stack<String>();
331
332 _companyServiceStack.set(companyServiceStack);
333 }
334
335 return companyServiceStack;
336 }
337
338 private String _getCompanyShardName(
339 String webId, String virtualHost, String mx, String shardName) {
340
341 Map<String, String> shardParams = new HashMap<String, String>();
342
343 shardParams.put("webId", webId);
344 shardParams.put("mx", mx);
345
346 if (virtualHost != null) {
347 shardParams.put("virtualHost", virtualHost);
348 }
349
350 shardName = ShardUtil.getShardSelector().getShardName(
351 ShardUtil.COMPANY_SCOPE, shardName, shardParams);
352
353 return shardName;
354 }
355
356 private String _getShardName() {
357 return _shardNameThreadLocal.get();
358 }
359
360 private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
361 String methodName = StringUtil.extractLast(
362 proceedingJoinPoint.getTarget().getClass().getName(),
363 StringPool.PERIOD);
364
365 methodName +=
366 StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
367 "()";
368
369 return methodName;
370 }
371
372 private void _setShardName(String shardName) {
373 _shardNameThreadLocal.set(shardName);
374 }
375
376 private void _setShardNameByCompany() throws Throwable {
377 Stack<String> companyServiceStack = _getCompanyServiceStack();
378
379 if (companyServiceStack.isEmpty()) {
380 long companyId = CompanyThreadLocal.getCompanyId();
381
382 _setShardNameByCompanyId(companyId);
383 }
384 else {
385 String shardName = companyServiceStack.peek();
386
387 _setShardName(shardName);
388 }
389 }
390
391 private void _setShardNameByCompanyId(long companyId)
392 throws PortalException, SystemException {
393
394 if (companyId == 0) {
395 _setShardName(PropsValues.SHARD_DEFAULT_NAME);
396 }
397 else {
398 Shard shard = ShardLocalServiceUtil.getShard(
399 Company.class.getName(), companyId);
400
401 String shardName = shard.getName();
402
403 _setShardName(shardName);
404 }
405 }
406
407 private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
408
409 private static ThreadLocal<Stack<String>> _companyServiceStack =
410 new ThreadLocal<Stack<String>>();
411 private static ThreadLocal<Object> _globalCallThreadLocal =
412 new ThreadLocal<Object>();
413 private static ThreadLocal<String> _shardNameThreadLocal =
414 new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
415
416 private ShardDataSourceTargetSource _shardDataSourceTargetSource;
417 private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
418
419 }