1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 package org.dbunit.ant;
22
23 import java.io.File;
24 import java.io.IOException;
25 import java.net.MalformedURLException;
26 import java.sql.SQLException;
27 import java.util.ArrayList;
28 import java.util.Iterator;
29 import java.util.List;
30
31 import org.apache.tools.ant.ProjectComponent;
32 import org.dbunit.DatabaseUnitException;
33 import org.dbunit.database.CachedResultSetTableFactory;
34 import org.dbunit.database.DatabaseConfig;
35 import org.dbunit.database.ForwardOnlyResultSetTableFactory;
36 import org.dbunit.database.IDatabaseConnection;
37 import org.dbunit.database.IResultSetTableFactory;
38 import org.dbunit.database.QueryDataSet;
39 import org.dbunit.dataset.CachedDataSet;
40 import org.dbunit.dataset.CompositeDataSet;
41 import org.dbunit.dataset.DataSetException;
42 import org.dbunit.dataset.IDataSet;
43 import org.dbunit.dataset.csv.CsvProducer;
44 import org.dbunit.dataset.excel.XlsDataSet;
45 import org.dbunit.dataset.stream.IDataSetProducer;
46 import org.dbunit.dataset.stream.StreamingDataSet;
47 import org.dbunit.dataset.xml.FlatDtdProducer;
48 import org.dbunit.dataset.xml.FlatXmlProducer;
49 import org.dbunit.dataset.xml.XmlProducer;
50 import org.dbunit.util.FileHelper;
51 import org.slf4j.Logger;
52 import org.slf4j.LoggerFactory;
53 import org.xml.sax.InputSource;
54
55
56
57
58
59
60
61 public abstract class AbstractStep extends ProjectComponent implements DbUnitTaskStep
62 {
63
64
65
66
67 private static final Logger logger = LoggerFactory.getLogger(AbstractStep.class);
68
69 public static final String FORMAT_FLAT = "flat";
70 public static final String FORMAT_XML = "xml";
71 public static final String FORMAT_DTD = "dtd";
72 public static final String FORMAT_CSV = "csv";
73 public static final String FORMAT_XLS = "xls";
74
75 private boolean ordered = false;
76
77
78 protected IDataSet getDatabaseDataSet(IDatabaseConnection connection,
79 List tables, boolean forwardonly) throws DatabaseUnitException
80 {
81 if (logger.isDebugEnabled())
82 {
83 logger.debug("getDatabaseDataSet(connection={}, tables={}, forwardonly={}) - start",
84 new Object[] { connection, tables, String.valueOf(forwardonly) });
85 }
86
87 try
88 {
89
90 IResultSetTableFactory factory = null;
91 if (forwardonly)
92 {
93 factory = new ForwardOnlyResultSetTableFactory();
94 }
95 else
96 {
97 factory = new CachedResultSetTableFactory();
98 }
99 DatabaseConfig config = connection.getConfig();
100 config.setProperty(DatabaseConfig.PROPERTY_RESULTSET_TABLE_FACTORY, factory);
101
102
103 if (tables.size() == 0)
104 {
105 logger.debug("Retrieving the whole database because tables/queries have not been specified");
106 return connection.createDataSet();
107 }
108
109 List queryDataSets = createQueryDataSet(tables, connection);
110 IDataSet[] dataSetsArray = (IDataSet[])queryDataSets.toArray( new IDataSet[queryDataSets.size()] );
111 return new CompositeDataSet(dataSetsArray);
112 }
113 catch (SQLException e)
114 {
115 throw new DatabaseUnitException(e);
116 }
117 }
118
119
120 private List createQueryDataSet(List tables, IDatabaseConnection connection)
121 throws DataSetException, SQLException
122 {
123 logger.debug("createQueryDataSet(tables={}, connection={})", tables, connection);
124
125 List queryDataSets = new ArrayList();
126
127 QueryDataSet queryDataSet = new QueryDataSet(connection);
128
129 for (Iterator it = tables.iterator(); it.hasNext();)
130 {
131 Object item = it.next();
132
133 if(item instanceof QuerySet) {
134 if(queryDataSet.getTableNames().length > 0)
135 queryDataSets.add(queryDataSet);
136
137 QueryDataSet newQueryDataSet = (((QuerySet)item).getQueryDataSet(connection));
138 queryDataSets.add(newQueryDataSet);
139 queryDataSet = new QueryDataSet(connection);
140 }
141 else if (item instanceof Query)
142 {
143 Query queryItem = (Query)item;
144 queryDataSet.addTable(queryItem.getName(), queryItem.getSql());
145 }
146 else if (item instanceof Table)
147 {
148 Table tableItem = (Table)item;
149 queryDataSet.addTable(tableItem.getName());
150 }
151 else
152 {
153 throw new IllegalArgumentException("Unsupported element type " + item.getClass().getName() + ".");
154 }
155 }
156
157 if(queryDataSet.getTableNames().length > 0)
158 queryDataSets.add(queryDataSet);
159
160 return queryDataSets;
161 }
162
163
164 protected IDataSet getSrcDataSet(File src, String format,
165 boolean forwardonly) throws DatabaseUnitException
166 {
167 if (logger.isDebugEnabled())
168 {
169 logger.debug("getSrcDataSet(src={}, format={}, forwardonly={}) - start",
170 new Object[]{ src, format, String.valueOf(forwardonly) });
171 }
172
173 try
174 {
175 IDataSetProducer producer = null;
176 if (format.equalsIgnoreCase(FORMAT_XML))
177 {
178 producer = new XmlProducer(getInputSource(src));
179 }
180 else if (format.equalsIgnoreCase(FORMAT_CSV))
181 {
182 producer = new CsvProducer(src);
183 }
184 else if (format.equalsIgnoreCase(FORMAT_FLAT))
185 {
186 producer = new FlatXmlProducer(getInputSource(src), true, true);
187 }
188 else if (format.equalsIgnoreCase(FORMAT_DTD))
189 {
190 producer = new FlatDtdProducer(getInputSource(src));
191 }
192 else if (format.equalsIgnoreCase(FORMAT_XLS))
193 {
194 return new CachedDataSet(new XlsDataSet(src));
195 }
196 else
197 {
198 throw new IllegalArgumentException("Type must be either 'flat'(default), 'xml', 'csv', 'xls' or 'dtd' but was: " + format);
199 }
200
201 if (forwardonly)
202 {
203 return new StreamingDataSet(producer);
204 }
205 return new CachedDataSet(producer);
206 }
207 catch (IOException e)
208 {
209 throw new DatabaseUnitException(e);
210 }
211 }
212
213
214
215
216
217
218
219
220
221
222
223 public boolean isDataFormat(String format)
224 {
225 logger.debug("isDataFormat(format={}) - start", format);
226
227 if (format.equalsIgnoreCase(FORMAT_FLAT)
228 || format.equalsIgnoreCase(FORMAT_XML)
229 || format.equalsIgnoreCase(FORMAT_CSV)
230 || format.equalsIgnoreCase(FORMAT_XLS)
231 )
232 {
233 return true;
234 }
235 else
236 {
237 return false;
238 }
239 }
240
241
242
243
244
245
246
247
248
249
250 protected void checkDataFormat(String format)
251 {
252 logger.debug("checkDataFormat(format={}) - start", format);
253
254 if (!isDataFormat(format))
255 {
256 throw new IllegalArgumentException("format must be either 'flat'(default), 'xml', 'csv' or 'xls' but was: " + format);
257 }
258 }
259
260
261
262
263
264
265
266
267 public static InputSource getInputSource(File file) throws MalformedURLException
268 {
269 InputSource source = FileHelper.createInputSource(file);
270 return source;
271 }
272
273 public boolean isOrdered()
274 {
275 return ordered;
276 }
277
278 public void setOrdered(boolean ordered)
279 {
280 this.ordered = ordered;
281 }
282
283 public String toString()
284 {
285 StringBuffer result = new StringBuffer();
286 result.append("AbstractStep: ");
287 result.append("ordered=").append(this.ordered);
288 return result.toString();
289 }
290
291 }