1   /**
2    * Copyright (c) 2000-2009 Liferay, Inc. All rights reserved.
3    *
4    * Permission is hereby granted, free of charge, to any person obtaining a copy
5    * of this software and associated documentation files (the "Software"), to deal
6    * in the Software without restriction, including without limitation the rights
7    * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8    * copies of the Software, and to permit persons to whom the Software is
9    * furnished to do so, subject to the following conditions:
10   *
11   * The above copyright notice and this permission notice shall be included in
12   * all copies or substantial portions of the Software.
13   *
14   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20   * SOFTWARE.
21   */
22  
23  package com.liferay.portal.dao.shard;
24  
25  import com.liferay.counter.service.persistence.CounterPersistence;
26  import com.liferay.portal.NoSuchCompanyException;
27  import com.liferay.portal.PortalException;
28  import com.liferay.portal.SystemException;
29  import com.liferay.portal.kernel.log.Log;
30  import com.liferay.portal.kernel.log.LogFactoryUtil;
31  import com.liferay.portal.kernel.util.InitialThreadLocal;
32  import com.liferay.portal.kernel.util.StringPool;
33  import com.liferay.portal.kernel.util.StringUtil;
34  import com.liferay.portal.model.Company;
35  import com.liferay.portal.model.Shard;
36  import com.liferay.portal.security.auth.CompanyThreadLocal;
37  import com.liferay.portal.service.CompanyLocalServiceUtil;
38  import com.liferay.portal.service.ShardLocalServiceUtil;
39  import com.liferay.portal.service.persistence.ClassNamePersistence;
40  import com.liferay.portal.service.persistence.CompanyPersistence;
41  import com.liferay.portal.service.persistence.ReleasePersistence;
42  import com.liferay.portal.service.persistence.ShardPersistence;
43  import com.liferay.portal.util.PropsValues;
44  
45  import java.util.HashMap;
46  import java.util.Map;
47  import java.util.Stack;
48  
49  import javax.sql.DataSource;
50  
51  import org.aspectj.lang.ProceedingJoinPoint;
52  
53  /**
54   * <a href="ShardAdvice.java.html"><b><i>View Source</i></b></a>
55   *
56   * @author Michael Young
57   * @author Alexander Chow
58   *
59   */
60  public class ShardAdvice {
61  
62      public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
63          throws Throwable {
64  
65          Object[] arguments = proceedingJoinPoint.getArgs();
66  
67          long companyId = (Long)arguments[0];
68  
69          Shard shard = ShardLocalServiceUtil.getShard(
70              Company.class.getName(), companyId);
71  
72          String shardName = shard.getName();
73  
74          if (_log.isInfoEnabled()) {
75              _log.info(
76                  "Service being set to shard " + shardName + " for " +
77                      _getSignature(proceedingJoinPoint));
78          }
79  
80          Object returnValue = null;
81  
82          pushCompanyService(shardName);
83  
84          try {
85              returnValue = proceedingJoinPoint.proceed();
86          }
87          finally {
88              popCompanyService();
89          }
90  
91          return returnValue;
92      }
93  
94      public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
95          throws Throwable {
96  
97          String methodName = proceedingJoinPoint.getSignature().getName();
98          Object[] arguments = proceedingJoinPoint.getArgs();
99  
100         String shardName = PropsValues.SHARD_DEFAULT_NAME;
101 
102         if (methodName.equals("addCompany") && (arguments.length > 3)) {
103             String webId = (String)arguments[0];
104             String virtualHost = (String)arguments[1];
105             String mx = (String)arguments[2];
106             shardName = (String)arguments[3];
107 
108             shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
109 
110             arguments[3] = shardName;
111         }
112         else if (methodName.equals("checkCompany")) {
113             String webId = (String)arguments[0];
114 
115             if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
116                 if (arguments.length == 3) {
117                     String mx = (String)arguments[1];
118                     shardName = (String)arguments[2];
119 
120                     shardName = _getCompanyShardName(
121                         webId, null, mx, shardName);
122 
123                     arguments[2] = shardName;
124                 }
125 
126                 try {
127                     Company company = CompanyLocalServiceUtil.getCompanyByWebId(
128                         webId);
129 
130                     shardName = company.getShardName();
131                 }
132                 catch (NoSuchCompanyException nsce) {
133                 }
134             }
135         }
136         else if (methodName.startsWith("update")) {
137             long companyId = (Long)arguments[0];
138 
139             Shard shard = ShardLocalServiceUtil.getShard(
140                 Company.class.getName(), companyId);
141 
142             shardName = shard.getName();
143         }
144         else {
145             return proceedingJoinPoint.proceed();
146         }
147 
148         if (_log.isInfoEnabled()) {
149             _log.info(
150                 "Company service being set to shard " + shardName + " for " +
151                     _getSignature(proceedingJoinPoint));
152         }
153 
154         Object returnValue = null;
155 
156         pushCompanyService(shardName);
157 
158         try {
159             returnValue = proceedingJoinPoint.proceed(arguments);
160         }
161         finally {
162             popCompanyService();
163         }
164 
165         return returnValue;
166     }
167 
168     public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
169         throws Throwable {
170 
171         _globalCallThreadLocal.set(new Object());
172 
173         try {
174             if (_log.isInfoEnabled()) {
175                 _log.info(
176                     "All shards invoked for " +
177                         _getSignature(proceedingJoinPoint));
178             }
179 
180             for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
181                 _shardDataSourceTargetSource.setDataSource(shardName);
182                 _shardSessionFactoryTargetSource.setSessionFactory(shardName);
183 
184                 proceedingJoinPoint.proceed();
185             }
186         }
187         finally {
188             _globalCallThreadLocal.set(null);
189         }
190 
191         return null;
192     }
193 
194     public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
195         throws Throwable {
196 
197         Object target = proceedingJoinPoint.getTarget();
198 
199         if (target instanceof ClassNamePersistence ||
200             target instanceof CompanyPersistence ||
201             target instanceof CounterPersistence ||
202             target instanceof ReleasePersistence ||
203             target instanceof ShardPersistence) {
204 
205             _shardDataSourceTargetSource.setDataSource(
206                 PropsValues.SHARD_DEFAULT_NAME);
207             _shardSessionFactoryTargetSource.setSessionFactory(
208                 PropsValues.SHARD_DEFAULT_NAME);
209 
210             if (_log.isDebugEnabled()) {
211                 _log.debug(
212                     "Using default shard for " +
213                         _getSignature(proceedingJoinPoint));
214             }
215 
216             return proceedingJoinPoint.proceed();
217         }
218 
219         if (_globalCallThreadLocal.get() == null) {
220             _setShardNameByCompany();
221 
222             String shardName = _getShardName();
223 
224             _shardDataSourceTargetSource.setDataSource(shardName);
225             _shardSessionFactoryTargetSource.setSessionFactory(shardName);
226 
227             if (_log.isInfoEnabled()) {
228                 _log.info(
229                     "Using shard name " + shardName + " for " +
230                         _getSignature(proceedingJoinPoint));
231             }
232 
233             return proceedingJoinPoint.proceed();
234         }
235         else {
236             return proceedingJoinPoint.proceed();
237         }
238     }
239 
240     public void setShardDataSourceTargetSource(
241         ShardDataSourceTargetSource shardDataSourceTargetSource) {
242 
243         _shardDataSourceTargetSource = shardDataSourceTargetSource;
244     }
245 
246     public void setShardSessionFactoryTargetSource(
247         ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
248 
249         _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
250     }
251 
252     protected DataSource getDataSource() {
253         return _shardDataSourceTargetSource.getDataSource();
254     }
255 
256     protected String popCompanyService() {
257         return _getCompanyServiceStack().pop();
258     }
259 
260     protected void pushCompanyService(long companyId) {
261         try {
262             Shard shard = ShardLocalServiceUtil.getShard(
263                 Company.class.getName(), companyId);
264 
265             String shardName = shard.getName();
266 
267             pushCompanyService(shardName);
268         }
269         catch (Exception e) {
270             _log.error(e, e);
271         }
272     }
273 
274     protected void pushCompanyService(String shardName) {
275         _getCompanyServiceStack().push(shardName);
276     }
277 
278     private Stack<String> _getCompanyServiceStack() {
279         Stack<String> companyServiceStack = _companyServiceStack.get();
280 
281         if (companyServiceStack == null) {
282             companyServiceStack = new Stack<String>();
283 
284             _companyServiceStack.set(companyServiceStack);
285         }
286 
287         return companyServiceStack;
288     }
289 
290     private String _getCompanyShardName(
291         String webId, String virtualHost, String mx, String shardName) {
292 
293         Map<String, String> shardParams = new HashMap<String, String>();
294 
295         shardParams.put("webId", webId);
296         shardParams.put("mx", mx);
297 
298         if (virtualHost != null) {
299             shardParams.put("virtualHost", virtualHost);
300         }
301 
302         shardName = ShardUtil.getShardSelector().getShardName(
303             ShardUtil.COMPANY_SCOPE, shardName, shardParams);
304 
305         return shardName;
306     }
307 
308     private String _getShardName() {
309         return _shardNameThreadLocal.get();
310     }
311 
312     private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
313         String methodName = StringUtil.extractLast(
314             proceedingJoinPoint.getTarget().getClass().getName(),
315             StringPool.PERIOD);
316 
317         methodName +=
318             StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
319                 "()";
320 
321         return methodName;
322     }
323 
324     private void _setShardName(String shardName) {
325         _shardNameThreadLocal.set(shardName);
326     }
327 
328     private void _setShardNameByCompany() throws Throwable {
329         Stack<String> companyServiceStack = _getCompanyServiceStack();
330 
331         if (companyServiceStack.isEmpty()) {
332             long companyId = CompanyThreadLocal.getCompanyId();
333 
334             _setShardNameByCompanyId(companyId);
335         }
336         else {
337             String shardName = companyServiceStack.peek();
338 
339             _setShardName(shardName);
340         }
341     }
342 
343     private void _setShardNameByCompanyId(long companyId)
344         throws PortalException, SystemException {
345 
346         if (companyId == 0) {
347             _setShardName(PropsValues.SHARD_DEFAULT_NAME);
348         }
349         else {
350             Shard shard = ShardLocalServiceUtil.getShard(
351                 Company.class.getName(), companyId);
352 
353             String shardName = shard.getName();
354 
355             _setShardName(shardName);
356         }
357     }
358 
359     private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
360 
361     private static ThreadLocal<Stack<String>> _companyServiceStack =
362         new ThreadLocal<Stack<String>>();
363     private static ThreadLocal<Object> _globalCallThreadLocal =
364         new ThreadLocal<Object>();
365     private static ThreadLocal<String> _shardNameThreadLocal =
366         new InitialThreadLocal<String>(PropsValues.SHARD_DEFAULT_NAME);
367 
368     private ShardDataSourceTargetSource _shardDataSourceTargetSource;
369     private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
370 
371 }