1
22
23 package com.liferay.portal.dao.shard;
24
25 import com.liferay.counter.service.persistence.CounterPersistence;
26 import com.liferay.portal.NoSuchCompanyException;
27 import com.liferay.portal.PortalException;
28 import com.liferay.portal.SystemException;
29 import com.liferay.portal.kernel.log.Log;
30 import com.liferay.portal.kernel.log.LogFactoryUtil;
31 import com.liferay.portal.kernel.util.InitialThreadLocal;
32 import com.liferay.portal.kernel.util.StringPool;
33 import com.liferay.portal.kernel.util.StringUtil;
34 import com.liferay.portal.model.Company;
35 import com.liferay.portal.model.Shard;
36 import com.liferay.portal.security.auth.CompanyThreadLocal;
37 import com.liferay.portal.service.CompanyLocalServiceUtil;
38 import com.liferay.portal.service.ShardLocalServiceUtil;
39 import com.liferay.portal.service.persistence.ClassNamePersistence;
40 import com.liferay.portal.service.persistence.CompanyPersistence;
41 import com.liferay.portal.service.persistence.ReleasePersistence;
42 import com.liferay.portal.service.persistence.ShardPersistence;
43 import com.liferay.portal.util.PropsValues;
44
45 import java.util.HashMap;
46 import java.util.Map;
47 import java.util.Stack;
48
49 import javax.sql.DataSource;
50
51 import org.aspectj.lang.ProceedingJoinPoint;
52
53
60 public class ShardAdvice {
61
62 public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
63 throws Throwable {
64
65 Object[] arguments = proceedingJoinPoint.getArgs();
66
67 long companyId = (Long)arguments[0];
68
69 Shard shard = ShardLocalServiceUtil.getShard(
70 Company.class.getName(), companyId);
71
72 String shardName = shard.getName();
73
74 if (_log.isInfoEnabled()) {
75 _log.info(
76 "Service being set to shard " + shardName + " for " +
77 _getSignature(proceedingJoinPoint));
78 }
79
80 Object returnValue = null;
81
82 pushCompanyService(shardName);
83
84 try {
85 returnValue = proceedingJoinPoint.proceed();
86 }
87 finally {
88 popCompanyService();
89 }
90
91 return returnValue;
92 }
93
94 public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
95 throws Throwable {
96
97 String methodName = proceedingJoinPoint.getSignature().getName();
98 Object[] arguments = proceedingJoinPoint.getArgs();
99
100 String shardName = PropsValues.SHARD_DEFAULT_NAME;
101
102 if (methodName.equals("addCompany") && (arguments.length > 3)) {
103 String webId = (String)arguments[0];
104 String virtualHost = (String)arguments[1];
105 String mx = (String)arguments[2];
106 shardName = (String)arguments[3];
107
108 shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
109
110 arguments[3] = shardName;
111 }
112 else if (methodName.equals("checkCompany")) {
113 String webId = (String)arguments[0];
114
115 if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
116 if (arguments.length == 3) {
117 String mx = (String)arguments[1];
118 shardName = (String)arguments[2];
119
120 shardName = _getCompanyShardName(
121 webId, null, mx, shardName);
122
123 arguments[2] = shardName;
124 }
125
126 try {
127 Company company = CompanyLocalServiceUtil.getCompanyByWebId(
128 webId);
129
130 shardName = company.getShardName();
131 }
132 catch (NoSuchCompanyException nsce) {
133 }
134 }
135 }
136 else if (methodName.startsWith("update")) {
137 long companyId = (Long)arguments[0];
138
139 Shard shard = ShardLocalServiceUtil.getShard(
140 Company.class.getName(), companyId);
141
142 shardName = shard.getName();
143 }
144 else {
145 return proceedingJoinPoint.proceed();
146 }
147
148 if (_log.isInfoEnabled()) {
149 _log.info(
150 "Company service being set to shard " + shardName + " for " +
151 _getSignature(proceedingJoinPoint));
152 }
153
154 Object returnValue = null;
155
156 pushCompanyService(shardName);
157
158 try {
159 returnValue = proceedingJoinPoint.proceed(arguments);
160 }
161 finally {
162 popCompanyService();
163 }
164
165 return returnValue;
166 }
167
168 public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
169 throws Throwable {
170
171 _globalCallThreadLocal.set(new Object());
172
173 try {
174 if (_log.isInfoEnabled()) {
175 _log.info(
176 "All shards invoked for " +
177 _getSignature(proceedingJoinPoint));
178 }
179
180 for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
181 _shardDataSourceTargetSource.setDataSource(shardName);
182 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
183
184 proceedingJoinPoint.proceed();
185 }
186 }
187 finally {
188 _globalCallThreadLocal.set(null);
189 }
190
191 return null;
192 }
193
194 public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
195 throws Throwable {
196
197 Object target = proceedingJoinPoint.getTarget();
198
199 if (target instanceof ClassNamePersistence ||
200 target instanceof CompanyPersistence ||
201 target instanceof CounterPersistence ||
202 target instanceof ReleasePersistence ||
203 target instanceof ShardPersistence) {
204
205 _shardDataSourceTargetSource.setDataSource(
206 PropsValues.SHARD_DEFAULT_NAME);
207 _shardSessionFactoryTargetSource.setSessionFactory(
208 PropsValues.SHARD_DEFAULT_NAME);
209
210 if (_log.isDebugEnabled()) {
211 _log.debug(
212 "Using default shard for " +
213 _getSignature(proceedingJoinPoint));
214 }
215
216 return proceedingJoinPoint.proceed();
217 }
218
219 if (_globalCallThreadLocal.get() == null) {
220 _setShardNameByCompany();
221
222 String shardName = _getShardName();
223
224 _shardDataSourceTargetSource.setDataSource(shardName);
225 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
226
227 if (_log.isInfoEnabled()) {
228 _log.info(
229 "Using shard name " + shardName + " for " +
230 _getSignature(proceedingJoinPoint));
231 }
232
233 return proceedingJoinPoint.proceed();
234 }
235 else {
236 return proceedingJoinPoint.proceed();
237 }
238 }
239
240 public void setShardDataSourceTargetSource(
241 ShardDataSourceTargetSource shardDataSourceTargetSource) {
242
243 _shardDataSourceTargetSource = shardDataSourceTargetSource;
244 }
245
246 public void setShardSessionFactoryTargetSource(
247 ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
248
249 _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
250 }
251
252 protected DataSource getDataSource() {
253 return _shardDataSourceTargetSource.getDataSource();
254 }
255
256 protected String popCompanyService() {
257 return _getCompanyServiceStack().pop();
258 }
259
260 protected void pushCompanyService(long companyId) {
261 try {
262 Shard shard = ShardLocalServiceUtil.getShard(
263 Company.class.getName(), companyId);
264
265 String shardName = shard.getName();
266
267 pushCompanyService(shardName);
268 }
269 catch (Exception e) {
270 _log.error(e, e);
271 }
272 }
273
274 protected void pushCompanyService(String shardName) {
275 _getCompanyServiceStack().push(shardName);
276 }
277
278 private Stack<String> _getCompanyServiceStack() {
279 Stack<String> companyServiceStack = _companyServiceStack.get();
280
281 if (companyServiceStack == null) {
282 companyServiceStack = new Stack<String>();
283
284 _companyServiceStack.set(companyServiceStack);
285 }
286
287 return companyServiceStack;
288 }
289
290 private String _getCompanyShardName(
291 String webId, String virtualHost, String mx, String shardName) {
292
293 Map<String, String> shardParams = new HashMap<String, String>();
294
295 shardParams.put("webId", webId);
296 shardParams.put("mx", mx);
297
298 if (virtualHost != null) {
299 shardParams.put("virtualHost", virtualHost);
300 }
301
302 shardName = ShardUtil.getShardSelector().getShardName(
303 ShardUtil.COMPANY_SCOPE, shardName, shardParams);
304
305 return shardName;
306 }
307
308 private String _getShardName() {
309 return _shardNameThreadLocal.get();
310 }
311
312 private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
313 String methodName = StringUtil.extractLast(
314 proceedingJoinPoint.getTarget().getClass().getName(),
315 StringPool.PERIOD);
316
317 methodName +=
318 StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
319 "()";
320
321 return methodName;
322 }
323
324 private void _setShardName(String shardName) {
325 _shardNameThreadLocal.set(shardName);
326 }
327
328 private void _setShardNameByCompany() throws Throwable {
329 Stack<String> companyServiceStack = _getCompanyServiceStack();
330
331 if (companyServiceStack.isEmpty()) {
332 long companyId = CompanyThreadLocal.getCompanyId();
333
334 _setShardNameByCompanyId(companyId);
335 }
336 else {
337 String shardName = companyServiceStack.peek();
338
339 _setShardName(shardName);
340 }
341 }
342
343 private void _setShardNameByCompanyId(long companyId)
344 throws PortalException, SystemException {
345
346 if (companyId == 0) {
347 _setShardName(PropsValues.SHARD_DEFAULT_NAME);
348 }
349 else {
350 Shard shard = ShardLocalServiceUtil.getShard(
351 Company.class.getName(), companyId);
352
353 String shardName = shard.getName();
354
355 _setShardName(shardName);
356 }
357 }
358
359 private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
360
361 private static ThreadLocal<Stack<String>> _companyServiceStack =
362 new ThreadLocal<Stack<String>>();
363 private static ThreadLocal<Object> _globalCallThreadLocal =
364 new ThreadLocal<Object>();
365 private static ThreadLocal<String> _shardNameThreadLocal =
366 new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
367
368 private ShardDataSourceTargetSource _shardDataSourceTargetSource;
369 private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
370
371 }