1   /**
2    * Copyright (c) 2000-2009 Liferay, Inc. All rights reserved.
3    *
4    * Permission is hereby granted, free of charge, to any person obtaining a copy
5    * of this software and associated documentation files (the "Software"), to deal
6    * in the Software without restriction, including without limitation the rights
7    * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8    * copies of the Software, and to permit persons to whom the Software is
9    * furnished to do so, subject to the following conditions:
10   *
11   * The above copyright notice and this permission notice shall be included in
12   * all copies or substantial portions of the Software.
13   *
14   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20   * SOFTWARE.
21   */
22  
23  package com.liferay.portal.dao.shard;
24  
25  import com.liferay.counter.service.persistence.CounterPersistence;
26  import com.liferay.portal.NoSuchCompanyException;
27  import com.liferay.portal.PortalException;
28  import com.liferay.portal.SystemException;
29  import com.liferay.portal.kernel.log.Log;
30  import com.liferay.portal.kernel.log.LogFactoryUtil;
31  import com.liferay.portal.kernel.util.InitialThreadLocal;
32  import com.liferay.portal.kernel.util.StringPool;
33  import com.liferay.portal.kernel.util.StringUtil;
34  import com.liferay.portal.model.Company;
35  import com.liferay.portal.model.Shard;
36  import com.liferay.portal.security.auth.CompanyThreadLocal;
37  import com.liferay.portal.service.CompanyLocalServiceUtil;
38  import com.liferay.portal.service.ShardLocalServiceUtil;
39  import com.liferay.portal.service.persistence.ClassNamePersistence;
40  import com.liferay.portal.service.persistence.CompanyPersistence;
41  import com.liferay.portal.service.persistence.ReleasePersistence;
42  import com.liferay.portal.service.persistence.ShardPersistence;
43  import com.liferay.portal.util.PropsValues;
44  
45  import java.util.HashMap;
46  import java.util.Map;
47  import java.util.Stack;
48  
49  import javax.sql.DataSource;
50  
51  import org.aspectj.lang.ProceedingJoinPoint;
52  
53  /**
54   * <a href="ShardAdvice.java.html"><b><i>View Source</i></b></a>
55   *
56   * @author Michael Young
57   * @author Alexander Chow
58   *
59   */
60  public class ShardAdvice {
61  
62      public Object invokeAccountService(ProceedingJoinPoint proceedingJoinPoint)
63          throws Throwable {
64  
65          String methodName = proceedingJoinPoint.getSignature().getName();
66          Object[] arguments = proceedingJoinPoint.getArgs();
67  
68          String shardName = PropsValues.SHARD_DEFAULT_NAME;
69  
70          if (methodName.equals("getAccount") && (arguments.length == 2)) {
71              long companyId = (Long)arguments[0];
72  
73              Shard shard = ShardLocalServiceUtil.getShard(
74                  Company.class.getName(), companyId);
75  
76              shardName = shard.getName();
77          }
78          else {
79              return proceedingJoinPoint.proceed();
80          }
81  
82          if (_log.isInfoEnabled()) {
83              _log.info(
84                  "Company service being set to shard " + shardName + " for " +
85                      _getSignature(proceedingJoinPoint));
86          }
87  
88          Object returnValue = null;
89  
90          pushCompanyService(shardName);
91  
92          try {
93              returnValue = proceedingJoinPoint.proceed();
94          }
95          finally {
96              popCompanyService();
97          }
98  
99          return returnValue;
100     }
101 
102     public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
103         throws Throwable {
104 
105         String methodName = proceedingJoinPoint.getSignature().getName();
106         Object[] arguments = proceedingJoinPoint.getArgs();
107 
108         String shardName = PropsValues.SHARD_DEFAULT_NAME;
109 
110         if (methodName.equals("addCompany")) {
111             String webId = (String)arguments[0];
112             String virtualHost = (String)arguments[1];
113             String mx = (String)arguments[2];
114             shardName = (String)arguments[3];
115 
116             shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
117 
118             arguments[3] = shardName;
119         }
120         else if (methodName.equals("checkCompany")) {
121             String webId = (String)arguments[0];
122 
123             if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
124                 if (arguments.length == 3) {
125                     String mx = (String)arguments[1];
126                     shardName = (String)arguments[2];
127 
128                     shardName = _getCompanyShardName(
129                         webId, null, mx, shardName);
130 
131                     arguments[2] = shardName;
132                 }
133 
134                 try {
135                     Company company = CompanyLocalServiceUtil.getCompanyByWebId(
136                         webId);
137 
138                     shardName = company.getShardName();
139                 }
140                 catch (NoSuchCompanyException nsce) {
141                 }
142             }
143         }
144         else if (methodName.startsWith("update")) {
145             long companyId = (Long)arguments[0];
146 
147             Shard shard = ShardLocalServiceUtil.getShard(
148                 Company.class.getName(), companyId);
149 
150             shardName = shard.getName();
151         }
152         else {
153             return proceedingJoinPoint.proceed();
154         }
155 
156         if (_log.isInfoEnabled()) {
157             _log.info(
158                 "Company service being set to shard " + shardName + " for " +
159                     _getSignature(proceedingJoinPoint));
160         }
161 
162         Object returnValue = null;
163 
164         pushCompanyService(shardName);
165 
166         try {
167             returnValue = proceedingJoinPoint.proceed(arguments);
168         }
169         finally {
170             popCompanyService();
171         }
172 
173         return returnValue;
174     }
175 
176     public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
177         throws Throwable {
178 
179         _globalCallThreadLocal.set(new Object());
180 
181         try {
182             if (_log.isInfoEnabled()) {
183                 _log.info(
184                     "All shards invoked for " +
185                         _getSignature(proceedingJoinPoint));
186             }
187 
188             for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
189                 _shardDataSourceTargetSource.setDataSource(shardName);
190                 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
191 
192                 proceedingJoinPoint.proceed();
193             }
194         }
195         finally {
196             _globalCallThreadLocal.set(null);
197         }
198 
199         return null;
200     }
201 
202     public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
203         throws Throwable {
204 
205         Object target = proceedingJoinPoint.getTarget();
206 
207         if (target instanceof ClassNamePersistence ||
208             target instanceof CompanyPersistence ||
209             target instanceof CounterPersistence ||
210             target instanceof ReleasePersistence ||
211             target instanceof ShardPersistence) {
212 
213             _shardDataSourceTargetSource.setDataSource(
214                 PropsValues.SHARD_DEFAULT_NAME);
215             _shardSessionFactoryTargetSource.setSessionFactory(
216                 PropsValues.SHARD_DEFAULT_NAME);
217 
218             if (_log.isDebugEnabled()) {
219                 _log.debug(
220                     "Using default shard for " +
221                         _getSignature(proceedingJoinPoint));
222             }
223 
224             return proceedingJoinPoint.proceed();
225         }
226 
227         if (_globalCallThreadLocal.get() == null) {
228             _setShardNameByCompany();
229 
230             String shardName = _getShardName();
231 
232             _shardDataSourceTargetSource.setDataSource(shardName);
233             _shardSessionFactoryTargetSource.setSessionFactory(shardName);
234 
235             if (_log.isInfoEnabled()) {
236                 _log.info(
237                     "Using shard name " + shardName + " for " +
238                         _getSignature(proceedingJoinPoint));
239             }
240 
241             return proceedingJoinPoint.proceed();
242         }
243         else {
244             return proceedingJoinPoint.proceed();
245         }
246     }
247 
248     public Object invokeUserService(ProceedingJoinPoint proceedingJoinPoint)
249         throws Throwable {
250 
251         String methodName = proceedingJoinPoint.getSignature().getName();
252         Object[] arguments = proceedingJoinPoint.getArgs();
253 
254         String shardName = PropsValues.SHARD_DEFAULT_NAME;
255 
256         if (methodName.equals("searchCount")) {
257             long companyId = (Long)arguments[0];
258 
259             Shard shard = ShardLocalServiceUtil.getShard(
260                 Company.class.getName(), companyId);
261 
262             shardName = shard.getName();
263         }
264         else {
265             return proceedingJoinPoint.proceed();
266         }
267 
268         if (_log.isInfoEnabled()) {
269             _log.info(
270                 "Company service being set to shard " + shardName + " for " +
271                     _getSignature(proceedingJoinPoint));
272         }
273 
274         Object returnValue = null;
275 
276         pushCompanyService(shardName);
277 
278         try {
279             returnValue = proceedingJoinPoint.proceed();
280         }
281         finally {
282             popCompanyService();
283         }
284 
285         return returnValue;
286     }
287 
288     public void setShardDataSourceTargetSource(
289         ShardDataSourceTargetSource shardDataSourceTargetSource) {
290 
291         _shardDataSourceTargetSource = shardDataSourceTargetSource;
292     }
293 
294     public void setShardSessionFactoryTargetSource(
295         ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
296 
297         _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
298     }
299 
300     protected DataSource getDataSource() {
301         return _shardDataSourceTargetSource.getDataSource();
302     }
303 
304     protected String popCompanyService() {
305         return _getCompanyServiceStack().pop();
306     }
307 
308     protected void pushCompanyService(long companyId) {
309         try {
310             Shard shard = ShardLocalServiceUtil.getShard(
311                 Company.class.getName(), companyId);
312 
313             String shardName = shard.getName();
314 
315             pushCompanyService(shardName);
316         }
317         catch (Exception e) {
318             _log.error(e, e);
319         }
320     }
321 
322     protected void pushCompanyService(String shardName) {
323         _getCompanyServiceStack().push(shardName);
324     }
325 
326     private Stack<String> _getCompanyServiceStack() {
327         Stack<String> companyServiceStack = _companyServiceStack.get();
328 
329         if (companyServiceStack == null) {
330             companyServiceStack = new Stack<String>();
331 
332             _companyServiceStack.set(companyServiceStack);
333         }
334 
335         return companyServiceStack;
336     }
337 
338     private String _getCompanyShardName(
339         String webId, String virtualHost, String mx, String shardName) {
340 
341         Map<String, String> shardParams = new HashMap<String, String>();
342 
343         shardParams.put("webId", webId);
344         shardParams.put("mx", mx);
345 
346         if (virtualHost != null) {
347             shardParams.put("virtualHost", virtualHost);
348         }
349 
350         shardName = ShardUtil.getShardSelector().getShardName(
351             ShardUtil.COMPANY_SCOPE, shardName, shardParams);
352 
353         return shardName;
354     }
355 
356     private String _getShardName() {
357         return _shardNameThreadLocal.get();
358     }
359 
360     private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
361         String methodName = StringUtil.extractLast(
362             proceedingJoinPoint.getTarget().getClass().getName(),
363             StringPool.PERIOD);
364 
365         methodName +=
366             StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
367                 "()";
368 
369         return methodName;
370     }
371 
372     private void _setShardName(String shardName) {
373         _shardNameThreadLocal.set(shardName);
374     }
375 
376     private void _setShardNameByCompany() throws Throwable {
377         Stack<String> companyServiceStack = _getCompanyServiceStack();
378 
379         if (companyServiceStack.isEmpty()) {
380             long companyId = CompanyThreadLocal.getCompanyId();
381 
382             _setShardNameByCompanyId(companyId);
383         }
384         else {
385             String shardName = companyServiceStack.peek();
386 
387             _setShardName(shardName);
388         }
389     }
390 
391     private void _setShardNameByCompanyId(long companyId)
392         throws PortalException, SystemException {
393 
394         if (companyId == 0) {
395             _setShardName(PropsValues.SHARD_DEFAULT_NAME);
396         }
397         else {
398             Shard shard = ShardLocalServiceUtil.getShard(
399                 Company.class.getName(), companyId);
400 
401             String shardName = shard.getName();
402 
403             _setShardName(shardName);
404         }
405     }
406 
407     private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
408 
409     private static ThreadLocal<Stack<String>> _companyServiceStack =
410         new ThreadLocal<Stack<String>>();
411     private static ThreadLocal<Object> _globalCallThreadLocal =
412         new ThreadLocal<Object>();
413     private static ThreadLocal<String> _shardNameThreadLocal =
414         new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
415 
416     private ShardDataSourceTargetSource _shardDataSourceTargetSource;
417     private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
418 
419 }