001    /**
002     * Copyright (c) 2000-2011 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.dao.shard;
016    
017    import com.liferay.counter.service.persistence.CounterFinder;
018    import com.liferay.counter.service.persistence.CounterPersistence;
019    import com.liferay.portal.NoSuchCompanyException;
020    import com.liferay.portal.kernel.exception.PortalException;
021    import com.liferay.portal.kernel.exception.SystemException;
022    import com.liferay.portal.kernel.log.Log;
023    import com.liferay.portal.kernel.log.LogFactoryUtil;
024    import com.liferay.portal.kernel.util.InfrastructureUtil;
025    import com.liferay.portal.kernel.util.InitialThreadLocal;
026    import com.liferay.portal.kernel.util.StringPool;
027    import com.liferay.portal.kernel.util.StringUtil;
028    import com.liferay.portal.model.Company;
029    import com.liferay.portal.model.Shard;
030    import com.liferay.portal.security.auth.CompanyThreadLocal;
031    import com.liferay.portal.service.CompanyLocalServiceUtil;
032    import com.liferay.portal.service.ShardLocalServiceUtil;
033    import com.liferay.portal.service.persistence.ClassNamePersistence;
034    import com.liferay.portal.service.persistence.CompanyPersistence;
035    import com.liferay.portal.service.persistence.ReleasePersistence;
036    import com.liferay.portal.service.persistence.ShardPersistence;
037    import com.liferay.portal.util.PropsValues;
038    
039    import java.util.EmptyStackException;
040    import java.util.HashMap;
041    import java.util.Map;
042    import java.util.Stack;
043    
044    import javax.sql.DataSource;
045    
046    import org.aspectj.lang.ProceedingJoinPoint;
047    
048    /**
049     * @author Michael Young
050     * @author Alexander Chow
051     */
052    public class ShardAdvice {
053    
054            public void afterPropertiesSet() {
055                    if (_shardDataSourceTargetSource == null) {
056                            _shardDataSourceTargetSource =
057                                    (ShardDataSourceTargetSource)InfrastructureUtil.
058                                            getShardDataSourceTargetSource();
059                    }
060    
061                    if (_shardSessionFactoryTargetSource == null) {
062                            _shardSessionFactoryTargetSource =
063                                    (ShardSessionFactoryTargetSource)InfrastructureUtil.
064                                            getShardSessionFactoryTargetSource();
065                    }
066            }
067    
068            public Object invokeByParameter(ProceedingJoinPoint proceedingJoinPoint)
069                    throws Throwable {
070    
071                    Object[] arguments = proceedingJoinPoint.getArgs();
072    
073                    long companyId = (Long)arguments[0];
074    
075                    Shard shard = ShardLocalServiceUtil.getShard(
076                            Company.class.getName(), companyId);
077    
078                    String shardName = shard.getName();
079    
080                    if (_log.isInfoEnabled()) {
081                            _log.info(
082                                    "Service being set to shard " + shardName + " for " +
083                                            _getSignature(proceedingJoinPoint));
084                    }
085    
086                    Object returnValue = null;
087    
088                    pushCompanyService(shardName);
089    
090                    try {
091                            returnValue = proceedingJoinPoint.proceed();
092                    }
093                    finally {
094                            popCompanyService();
095                    }
096    
097                    return returnValue;
098            }
099    
100            public Object invokeCompanyService(ProceedingJoinPoint proceedingJoinPoint)
101                    throws Throwable {
102    
103                    String methodName = proceedingJoinPoint.getSignature().getName();
104                    Object[] arguments = proceedingJoinPoint.getArgs();
105    
106                    String shardName = PropsValues.SHARD_DEFAULT_NAME;
107    
108                    if (methodName.equals("addCompany")) {
109                            String webId = (String)arguments[0];
110                            String virtualHost = (String)arguments[1];
111                            String mx = (String)arguments[2];
112                            shardName = (String)arguments[3];
113    
114                            shardName = _getCompanyShardName(webId, virtualHost, mx, shardName);
115    
116                            arguments[3] = shardName;
117                    }
118                    else if (methodName.equals("checkCompany")) {
119                            String webId = (String)arguments[0];
120    
121                            if (!webId.equals(PropsValues.COMPANY_DEFAULT_WEB_ID)) {
122                                    if (arguments.length == 3) {
123                                            String mx = (String)arguments[1];
124                                            shardName = (String)arguments[2];
125    
126                                            shardName = _getCompanyShardName(
127                                                    webId, null, mx, shardName);
128    
129                                            arguments[2] = shardName;
130                                    }
131    
132                                    try {
133                                            Company company = CompanyLocalServiceUtil.getCompanyByWebId(
134                                                    webId);
135    
136                                            shardName = company.getShardName();
137                                    }
138                                    catch (NoSuchCompanyException nsce) {
139                                    }
140                            }
141                    }
142                    else if (methodName.startsWith("update")) {
143                            long companyId = (Long)arguments[0];
144    
145                            Shard shard = ShardLocalServiceUtil.getShard(
146                                    Company.class.getName(), companyId);
147    
148                            shardName = shard.getName();
149                    }
150                    else {
151                            return proceedingJoinPoint.proceed();
152                    }
153    
154                    if (_log.isInfoEnabled()) {
155                            _log.info(
156                                    "Company service being set to shard " + shardName + " for " +
157                                            _getSignature(proceedingJoinPoint));
158                    }
159    
160                    Object returnValue = null;
161    
162                    pushCompanyService(shardName);
163    
164                    try {
165                            returnValue = proceedingJoinPoint.proceed(arguments);
166                    }
167                    finally {
168                            popCompanyService();
169                    }
170    
171                    return returnValue;
172            }
173    
174            /**
175             * Invoke a join point across all shards while ignoring the company service
176             * stack.
177             *
178             * @see #invokeIteratively
179             */
180            public Object invokeGlobally(ProceedingJoinPoint proceedingJoinPoint)
181                    throws Throwable {
182    
183                    _globalCall.set(new Object());
184    
185                    try {
186                            if (_log.isInfoEnabled()) {
187                                    _log.info(
188                                            "All shards invoked for " +
189                                                    _getSignature(proceedingJoinPoint));
190                            }
191    
192                            for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
193                                    _shardDataSourceTargetSource.setDataSource(shardName);
194                                    _shardSessionFactoryTargetSource.setSessionFactory(shardName);
195    
196                                    proceedingJoinPoint.proceed();
197                            }
198                    }
199                    finally {
200                            _globalCall.set(null);
201                    }
202    
203                    return null;
204            }
205    
206            /**
207             * Invoke a join point across all shards while using the company service
208             * stack.
209             *
210             * @see #invokeGlobally
211             */
212            public Object invokeIteratively(ProceedingJoinPoint proceedingJoinPoint)
213                    throws Throwable {
214    
215                    if (_log.isInfoEnabled()) {
216                            _log.info(
217                                    "Iterating through all shards for " +
218                                            _getSignature(proceedingJoinPoint));
219                    }
220    
221                    for (String shardName : PropsValues.SHARD_AVAILABLE_NAMES) {
222                            pushCompanyService(shardName);
223    
224                            try {
225                                    proceedingJoinPoint.proceed();
226                            }
227                            finally {
228                                    popCompanyService();
229                            }
230                    }
231    
232                    return null;
233            }
234    
235            public Object invokePersistence(ProceedingJoinPoint proceedingJoinPoint)
236                    throws Throwable {
237    
238                    if ((_shardDataSourceTargetSource == null) ||
239                            (_shardSessionFactoryTargetSource == null)) {
240    
241                            return proceedingJoinPoint.proceed();
242                    }
243    
244                    Object target = proceedingJoinPoint.getTarget();
245    
246                    if (target instanceof ClassNamePersistence ||
247                            target instanceof CompanyPersistence ||
248                            target instanceof CounterFinder ||
249                            target instanceof CounterPersistence ||
250                            target instanceof ReleasePersistence ||
251                            target instanceof ShardPersistence) {
252    
253                            _shardDataSourceTargetSource.setDataSource(
254                                    PropsValues.SHARD_DEFAULT_NAME);
255                            _shardSessionFactoryTargetSource.setSessionFactory(
256                                    PropsValues.SHARD_DEFAULT_NAME);
257    
258                            if (_log.isDebugEnabled()) {
259                                    _log.debug(
260                                            "Using default shard for " +
261                                                    _getSignature(proceedingJoinPoint));
262                            }
263    
264                            return proceedingJoinPoint.proceed();
265                    }
266    
267                    if (_globalCall.get() == null) {
268                            _setShardNameByCompany();
269    
270                            String shardName = _getShardName();
271    
272                            _shardDataSourceTargetSource.setDataSource(shardName);
273                            _shardSessionFactoryTargetSource.setSessionFactory(shardName);
274    
275                            if (_log.isInfoEnabled()) {
276                                    _log.info(
277                                            "Using shard name " + shardName + " for " +
278                                                    _getSignature(proceedingJoinPoint));
279                            }
280    
281                            return proceedingJoinPoint.proceed();
282                    }
283                    else {
284                            return proceedingJoinPoint.proceed();
285                    }
286            }
287    
288            public void setShardDataSourceTargetSource(
289                    ShardDataSourceTargetSource shardDataSourceTargetSource) {
290    
291                    _shardDataSourceTargetSource = shardDataSourceTargetSource;
292            }
293    
294            public void setShardSessionFactoryTargetSource(
295                    ShardSessionFactoryTargetSource shardSessionFactoryTargetSource) {
296    
297                    _shardSessionFactoryTargetSource = shardSessionFactoryTargetSource;
298            }
299    
300            protected String getCurrentShardName() {
301                    String shardName = null;
302    
303                    try {
304                            shardName = _getCompanyServiceStack().peek();
305                    }
306                    catch (EmptyStackException ese) {
307                    }
308    
309                    if (shardName == null) {
310                            shardName = PropsValues.SHARD_DEFAULT_NAME;
311                    }
312    
313                    return shardName;
314            }
315    
316            protected DataSource getDataSource() {
317                    return _shardDataSourceTargetSource.getDataSource();
318            }
319    
320            protected String popCompanyService() {
321                    return _getCompanyServiceStack().pop();
322            }
323    
324            protected void pushCompanyService(long companyId) {
325                    try {
326                            Shard shard = ShardLocalServiceUtil.getShard(
327                                    Company.class.getName(), companyId);
328    
329                            String shardName = shard.getName();
330    
331                            pushCompanyService(shardName);
332                    }
333                    catch (Exception e) {
334                            _log.error(e, e);
335                    }
336            }
337    
338            protected void pushCompanyService(String shardName) {
339                    _getCompanyServiceStack().push(shardName);
340            }
341    
342            private Stack<String> _getCompanyServiceStack() {
343                    Stack<String> companyServiceStack = _companyServiceStack.get();
344    
345                    if (companyServiceStack == null) {
346                            companyServiceStack = new Stack<String>();
347    
348                            _companyServiceStack.set(companyServiceStack);
349                    }
350    
351                    return companyServiceStack;
352            }
353    
354            private String _getCompanyShardName(
355                    String webId, String virtualHost, String mx, String shardName) {
356    
357                    Map<String, String> shardParams = new HashMap<String, String>();
358    
359                    shardParams.put("webId", webId);
360                    shardParams.put("mx", mx);
361    
362                    if (virtualHost != null) {
363                            shardParams.put("virtualHost", virtualHost);
364                    }
365    
366                    shardName = _shardSelector.getShardName(
367                            ShardSelector.COMPANY_SCOPE, shardName, shardParams);
368    
369                    return shardName;
370            }
371    
372            private String _getShardName() {
373                    return _shardName.get();
374            }
375    
376            private String _getSignature(ProceedingJoinPoint proceedingJoinPoint) {
377                    String methodName = StringUtil.extractLast(
378                            proceedingJoinPoint.getTarget().getClass().getName(),
379                            StringPool.PERIOD);
380    
381                    methodName +=
382                            StringPool.PERIOD + proceedingJoinPoint.getSignature().getName() +
383                                    "()";
384    
385                    return methodName;
386            }
387    
388            private void _setShardName(String shardName) {
389                    _shardName.set(shardName);
390            }
391    
392            private void _setShardNameByCompany() throws Throwable {
393                    Stack<String> companyServiceStack = _getCompanyServiceStack();
394    
395                    if (companyServiceStack.isEmpty()) {
396                            long companyId = CompanyThreadLocal.getCompanyId();
397    
398                            _setShardNameByCompanyId(companyId);
399                    }
400                    else {
401                            String shardName = companyServiceStack.peek();
402    
403                            _setShardName(shardName);
404                    }
405            }
406    
407            private void _setShardNameByCompanyId(long companyId)
408                    throws PortalException, SystemException {
409    
410                    if (companyId == 0) {
411                            _setShardName(PropsValues.SHARD_DEFAULT_NAME);
412                    }
413                    else {
414                            Shard shard = ShardLocalServiceUtil.getShard(
415                                    Company.class.getName(), companyId);
416    
417                            String shardName = shard.getName();
418    
419                            _setShardName(shardName);
420                    }
421            }
422    
423            private static Log _log = LogFactoryUtil.getLog(ShardAdvice.class);
424    
425            private static ThreadLocal<Stack<String>> _companyServiceStack =
426                    new ThreadLocal<Stack<String>>();
427            private static ThreadLocal<Object> _globalCall = new ThreadLocal<Object>();
428            private static ThreadLocal<String> _shardName =
429                    new InitialThreadLocal<String>(
430                            ShardAdvice.class + "._shardName", PropsValues.SHARD_DEFAULT_NAME);
431            private static ShardSelector _shardSelector;
432    
433            private ShardDataSourceTargetSource _shardDataSourceTargetSource;
434            private ShardSessionFactoryTargetSource _shardSessionFactoryTargetSource;
435    
436            static {
437                    try {
438                            _shardSelector = (ShardSelector)Class.forName(
439                                    PropsValues.SHARD_SELECTOR).newInstance();
440                    }
441                    catch (Exception e) {
442                            _log.error(e, e);
443                    }
444            }
445    
446    }